diff --git a/core/CMakeLists.txt b/core/CMakeLists.txt index d3cfcc4b09..fe39bb12ba 100644 --- a/core/CMakeLists.txt +++ b/core/CMakeLists.txt @@ -1,20 +1,203 @@ -project(sulvim_core) +#------------------------------------------------------------------------------- +# Copyright (C) 2019-2020 Zilliz. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License +# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing permissions and limitations under the License. +#------------------------------------------------------------------------------- +cmake_minimum_required( VERSION 3.14 ) + +add_definitions(-DELPP_THREAD_SAFE) set(CMAKE_POSITION_INDEPENDENT_CODE ON) +message( STATUS "Building using CMake version: ${CMAKE_VERSION}" ) + +set( CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_CURRENT_SOURCE_DIR}/cmake" ) +include( Utils ) + +# **************************** Build time, type and code version **************************** +get_current_time( BUILD_TIME ) +message( STATUS "Build time = ${BUILD_TIME}" ) + +get_build_type( TARGET BUILD_TYPE + DEFAULT "Release" ) +message( STATUS "Build type = ${BUILD_TYPE}" ) + +get_milvus_version( TARGET MILVUS_VERSION + DEFAULT "0.10.0" ) +message( STATUS "Build version = ${MILVUS_VERSION}" ) + +get_last_commit_id( LAST_COMMIT_ID ) +message( STATUS "LAST_COMMIT_ID = ${LAST_COMMIT_ID}" ) + +#configure_file( ${CMAKE_CURRENT_SOURCE_DIR}/src/version.h.in +# ${CMAKE_CURRENT_SOURCE_DIR}/src/version.h @ONLY ) + +# unset(CMAKE_EXPORT_COMPILE_COMMANDS CACHE) +set( CMAKE_EXPORT_COMPILE_COMMANDS ON ) + +# **************************** Project **************************** +project( milvus VERSION "${MILVUS_VERSION}" ) -cmake_minimum_required(VERSION 3.16) set( CMAKE_CXX_STANDARD 17 ) set( CMAKE_CXX_STANDARD_REQUIRED on ) -set (CMAKE_MODULE_PATH "${CMAKE_MODULE_PATH};${CMAKE_CURRENT_SOURCE_DIR}/cmake") -include_directories(src) -add_subdirectory(src) -add_subdirectory(unittest) -install( - DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/src/dog_segment/ - DESTINATION ${CMAKE_CURRENT_SOURCE_DIR}/include - FILES_MATCHING PATTERN "*_c.h" -) +set( MILVUS_SOURCE_DIR ${PROJECT_SOURCE_DIR} ) +set( MILVUS_BINARY_DIR ${PROJECT_BINARY_DIR} ) +set( MILVUS_ENGINE_SRC ${PROJECT_SOURCE_DIR}/src ) +set( MILVUS_THIRDPARTY_SRC ${PROJECT_SOURCE_DIR}/thirdparty ) -install(FILES ${CMAKE_BINARY_DIR}/src/dog_segment/libmilvus_dog_segment.so - DESTINATION ${CMAKE_CURRENT_SOURCE_DIR}/lib) +# This will set RPATH to all excutable TARGET +# self-installed dynamic libraries will be correctly linked by excutable +set( CMAKE_INSTALL_RPATH "/usr/lib" "${CMAKE_INSTALL_PREFIX}/lib" ) +set( CMAKE_INSTALL_RPATH_USE_LINK_PATH TRUE ) +# **************************** Dependencies **************************** +include( CTest ) +include( BuildUtils ) +include( DefineOptions ) + +include( ExternalProject ) +include( FetchContent ) +include_directories(thirdparty) +set( FETCHCONTENT_BASE_DIR ${MILVUS_BINARY_DIR}/3rdparty_download ) +set(FETCHCONTENT_QUIET OFF) +include( ThirdPartyPackages ) +find_package(OpenMP REQUIRED) +# **************************** Compiler arguments **************************** +message( STATUS "Building Milvus CPU version" ) + + +#append_flags( CMAKE_CXX_FLAGS +# FLAGS +# "-fPIC" +# "-DELPP_THREAD_SAFE" +# "-fopenmp" +# "-Werror" +# ) + +# **************************** Coding style check tools **************************** +find_package( ClangTools ) +set( BUILD_SUPPORT_DIR "${CMAKE_SOURCE_DIR}/build-support" ) +message(STATUS "CMAKE_SOURCE_DIR is at ${CMAKE_SOURCE_DIR}" ) + +if("$ENV{CMAKE_EXPORT_COMPILE_COMMANDS}" STREQUAL "1" OR CLANG_TIDY_FOUND) + # Generate a Clang compile_commands.json "compilation database" file for use + # with various development tools, such as Vim's YouCompleteMe plugin. + # See http://clang.llvm.org/docs/JSONCompilationDatabase.html + set(CMAKE_EXPORT_COMPILE_COMMANDS ON) +endif() + +# +# "make lint" target +# +if ( NOT MILVUS_VERBOSE_LINT ) + set( MILVUS_LINT_QUIET "--quiet" ) +endif () + +if ( NOT LINT_EXCLUSIONS_FILE ) + # source files matching a glob from a line in this file + # will be excluded from linting (cpplint, clang-tidy, clang-format) + set( LINT_EXCLUSIONS_FILE ${BUILD_SUPPORT_DIR}/lint_exclusions.txt ) +endif () + +find_program( CPPLINT_BIN NAMES cpplint cpplint.py HINTS ${BUILD_SUPPORT_DIR} ) +message( STATUS "Found cpplint executable at ${CPPLINT_BIN}" ) + +# +# "make lint" targets +# +add_custom_target( lint + ${PYTHON_EXECUTABLE} ${BUILD_SUPPORT_DIR}/run_cpplint.py + --cpplint_binary ${CPPLINT_BIN} + --exclude_globs ${LINT_EXCLUSIONS_FILE} + --source_dir ${CMAKE_CURRENT_SOURCE_DIR} + ${MILVUS_LINT_QUIET} + ) + +# +# "make clang-format" and "make check-clang-format" targets +# +if ( ${CLANG_FORMAT_FOUND} ) + # runs clang format and updates files in place. + add_custom_target( clang-format + ${PYTHON_EXECUTABLE} ${BUILD_SUPPORT_DIR}/run_clang_format.py + --clang_format_binary ${CLANG_FORMAT_BIN} + --exclude_globs ${LINT_EXCLUSIONS_FILE} + --source_dir ${CMAKE_CURRENT_SOURCE_DIR}/src + --fix + ${MILVUS_LINT_QUIET} ) + + # runs clang format and exits with a non-zero exit code if any files need to be reformatted + add_custom_target( check-clang-format + ${PYTHON_EXECUTABLE} ${BUILD_SUPPORT_DIR}/run_clang_format.py + --clang_format_binary ${CLANG_FORMAT_BIN} + --exclude_globs ${LINT_EXCLUSIONS_FILE} + --source_dir ${CMAKE_CURRENT_SOURCE_DIR}/src + ${MILVUS_LINT_QUIET} ) +endif () + +# +# "make clang-tidy" and "make check-clang-tidy" targets +# +if ( ${CLANG_TIDY_FOUND} ) + # runs clang-tidy and attempts to fix any warning automatically + add_custom_target( clang-tidy + ${PYTHON_EXECUTABLE} ${BUILD_SUPPORT_DIR}/run_clang_tidy.py + --clang_tidy_binary ${CLANG_TIDY_BIN} + --exclude_globs ${LINT_EXCLUSIONS_FILE} + --compile_commands ${CMAKE_BINARY_DIR}/compile_commands.json + --source_dir ${CMAKE_CURRENT_SOURCE_DIR}/src + --fix + ${MILVUS_LINT_QUIET} ) + + # runs clang-tidy and exits with a non-zero exit code if any errors are found. + add_custom_target( check-clang-tidy + ${PYTHON_EXECUTABLE} ${BUILD_SUPPORT_DIR}/run_clang_tidy.py + --clang_tidy_binary ${CLANG_TIDY_BIN} + --exclude_globs ${LINT_EXCLUSIONS_FILE} + --compile_commands ${CMAKE_BINARY_DIR}/compile_commands.json + --source_dir ${CMAKE_CURRENT_SOURCE_DIR}/src + ${MILVUS_LINT_QUIET} ) +endif () + +# +# Validate and print out Milvus configuration options +# + +config_summary() + +# **************************** Source files **************************** + +add_subdirectory( thirdparty ) +add_subdirectory( src ) + +# Unittest lib +if ( BUILD_UNIT_TEST STREQUAL "ON" ) + if ( BUILD_COVERAGE STREQUAL "ON" ) + append_flags( CMAKE_CXX_FLAGS + FLAGS + "-fprofile-arcs" + "-ftest-coverage" + ) + endif () + append_flags( CMAKE_CXX_FLAGS FLAGS "-DELPP_DISABLE_LOGS") + + add_subdirectory( ${CMAKE_CURRENT_SOURCE_DIR}/unittest ) +endif () + + + +add_custom_target( Clean-All COMMAND ${CMAKE_BUILD_TOOL} clean ) + +# **************************** Install **************************** + +if ( NOT MILVUS_DB_PATH ) + set( MILVUS_DB_PATH "${CMAKE_INSTALL_PREFIX}" ) +endif () + +set( GPU_ENABLE "false" ) diff --git a/core/CMakeLists_old.txt b/core/CMakeLists_old.txt new file mode 100644 index 0000000000..d3cfcc4b09 --- /dev/null +++ b/core/CMakeLists_old.txt @@ -0,0 +1,20 @@ +project(sulvim_core) + +set(CMAKE_POSITION_INDEPENDENT_CODE ON) + +cmake_minimum_required(VERSION 3.16) +set( CMAKE_CXX_STANDARD 17 ) +set( CMAKE_CXX_STANDARD_REQUIRED on ) +set (CMAKE_MODULE_PATH "${CMAKE_MODULE_PATH};${CMAKE_CURRENT_SOURCE_DIR}/cmake") +include_directories(src) +add_subdirectory(src) +add_subdirectory(unittest) + +install( + DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/src/dog_segment/ + DESTINATION ${CMAKE_CURRENT_SOURCE_DIR}/include + FILES_MATCHING PATTERN "*_c.h" +) + +install(FILES ${CMAKE_BINARY_DIR}/src/dog_segment/libmilvus_dog_segment.so + DESTINATION ${CMAKE_CURRENT_SOURCE_DIR}/lib) diff --git a/core/build.sh b/core/build.sh index c731f07a03..4417ca5f22 100755 --- a/core/build.sh +++ b/core/build.sh @@ -1,8 +1,156 @@ #!/bin/bash -if [[ -d "./build" ]]; then - rm -rf build + +# Compile jobs variable; Usage: $ jobs=12 ./build.sh ... +if [[ ! ${jobs+1} ]]; then + jobs=$(nproc) fi -mkdir build && cd build -cmake .. -make -j8 && make install +BUILD_OUTPUT_DIR="cmake_build" +BUILD_TYPE="Debug" +BUILD_UNITTEST="OFF" +INSTALL_PREFIX=$(pwd)/milvus +MAKE_CLEAN="OFF" +BUILD_COVERAGE="OFF" +DB_PATH="/tmp/milvus" +PROFILING="OFF" +RUN_CPPLINT="OFF" +CUDA_COMPILER=/usr/local/cuda/bin/nvcc +GPU_VERSION="OFF" #defaults to CPU version +WITH_PROMETHEUS="ON" +CUDA_ARCH="DEFAULT" +CUSTOM_THIRDPARTY_PATH="" + +while getopts "p:d:t:s:f:ulrcghzme" arg; do + case $arg in + f) + CUSTOM_THIRDPARTY_PATH=$OPTARG + ;; + p) + INSTALL_PREFIX=$OPTARG + ;; + d) + DB_PATH=$OPTARG + ;; + t) + BUILD_TYPE=$OPTARG # BUILD_TYPE + ;; + u) + echo "Build and run unittest cases" + BUILD_UNITTEST="ON" + ;; + l) + RUN_CPPLINT="ON" + ;; + r) + if [[ -d ${BUILD_OUTPUT_DIR} ]]; then + MAKE_CLEAN="ON" + fi + ;; + c) + BUILD_COVERAGE="ON" + ;; + z) + PROFILING="ON" + ;; + g) + GPU_VERSION="ON" + ;; + e) + WITH_PROMETHEUS="OFF" + ;; + s) + CUDA_ARCH=$OPTARG + ;; + h) # help + echo " + +parameter: +-f: custom paths of thirdparty downloaded files(default: NULL) +-p: install prefix(default: $(pwd)/milvus) +-d: db data path(default: /tmp/milvus) +-t: build type(default: Debug) +-u: building unit test options(default: OFF) +-l: run cpplint, clang-format and clang-tidy(default: OFF) +-r: remove previous build directory(default: OFF) +-c: code coverage(default: OFF) +-z: profiling(default: OFF) +-g: build GPU version(default: OFF) +-e: build without prometheus(default: OFF) +-s: build with CUDA arch(default:DEFAULT), for example '-gencode=compute_61,code=sm_61;-gencode=compute_75,code=sm_75' +-h: help + +usage: +./build.sh -p \${INSTALL_PREFIX} -t \${BUILD_TYPE} -s \${CUDA_ARCH} -f\${CUSTOM_THIRDPARTY_PATH} [-u] [-l] [-r] [-c] [-z] [-g] [-m] [-e] [-h] + " + exit 0 + ;; + ?) + echo "ERROR! unknown argument" + exit 1 + ;; + esac +done + +if [[ ! -d ${BUILD_OUTPUT_DIR} ]]; then + mkdir ${BUILD_OUTPUT_DIR} +fi + +cd ${BUILD_OUTPUT_DIR} + +# remove make cache since build.sh -l use default variables +# force update the variables each time +make rebuild_cache >/dev/null 2>&1 + + +if [[ ${MAKE_CLEAN} == "ON" ]]; then + echo "Runing make clean in ${BUILD_OUTPUT_DIR} ..." + make clean + exit 0 +fi + +CMAKE_CMD="cmake \ +-DBUILD_UNIT_TEST=${BUILD_UNITTEST} \ +-DCMAKE_INSTALL_PREFIX=${INSTALL_PREFIX} +-DCMAKE_BUILD_TYPE=${BUILD_TYPE} \ +-DOpenBLAS_SOURCE=AUTO \ +-DCMAKE_CUDA_COMPILER=${CUDA_COMPILER} \ +-DBUILD_COVERAGE=${BUILD_COVERAGE} \ +-DMILVUS_DB_PATH=${DB_PATH} \ +-DENABLE_CPU_PROFILING=${PROFILING} \ +-DMILVUS_GPU_VERSION=${GPU_VERSION} \ +-DMILVUS_WITH_PROMETHEUS=${WITH_PROMETHEUS} \ +-DMILVUS_CUDA_ARCH=${CUDA_ARCH} \ +-DCUSTOM_THIRDPARTY_DOWNLOAD_PATH=${CUSTOM_THIRDPARTY_PATH} \ +../" +echo ${CMAKE_CMD} +${CMAKE_CMD} + + +if [[ ${RUN_CPPLINT} == "ON" ]]; then + # cpplint check + make lint + if [ $? -ne 0 ]; then + echo "ERROR! cpplint check failed" + exit 1 + fi + echo "cpplint check passed!" + + # clang-format check + make check-clang-format + if [ $? -ne 0 ]; then + echo "ERROR! clang-format check failed" + exit 1 + fi + echo "clang-format check passed!" + + # clang-tidy check + make check-clang-tidy + if [ $? -ne 0 ]; then + echo "ERROR! clang-tidy check failed" + exit 1 + fi + echo "clang-tidy check passed!" +else + # compile and build + make -j ${jobs} install || exit 1 +fi diff --git a/core/cmake/BuildUtils.cmake b/core/cmake/BuildUtils.cmake new file mode 100644 index 0000000000..a739ce243d --- /dev/null +++ b/core/cmake/BuildUtils.cmake @@ -0,0 +1,231 @@ +# Define a function that check last file modification +function(Check_Last_Modify cache_check_lists_file_path working_dir last_modified_commit_id) + if(EXISTS "${working_dir}") + if(EXISTS "${cache_check_lists_file_path}") + set(GIT_LOG_SKIP_NUM 0) + set(_MATCH_ALL ON CACHE BOOL "Match all") + set(_LOOP_STATUS ON CACHE BOOL "Whether out of loop") + file(STRINGS ${cache_check_lists_file_path} CACHE_IGNORE_TXT) + while(_LOOP_STATUS) + foreach(_IGNORE_ENTRY ${CACHE_IGNORE_TXT}) + if(NOT _IGNORE_ENTRY MATCHES "^[^#]+") + continue() + endif() + + set(_MATCH_ALL OFF) + execute_process(COMMAND git log --no-merges -1 --skip=${GIT_LOG_SKIP_NUM} --name-status --pretty= WORKING_DIRECTORY ${working_dir} OUTPUT_VARIABLE CHANGE_FILES) + if(NOT CHANGE_FILES STREQUAL "") + string(REPLACE "\n" ";" _CHANGE_FILES ${CHANGE_FILES}) + foreach(_FILE_ENTRY ${_CHANGE_FILES}) + string(REGEX MATCH "[^ \t]+$" _FILE_NAME ${_FILE_ENTRY}) + execute_process(COMMAND sh -c "echo ${_FILE_NAME} | grep ${_IGNORE_ENTRY}" RESULT_VARIABLE return_code) + if (return_code EQUAL 0) + execute_process(COMMAND git log --no-merges -1 --skip=${GIT_LOG_SKIP_NUM} --pretty=%H WORKING_DIRECTORY ${working_dir} OUTPUT_VARIABLE LAST_MODIFIED_COMMIT_ID) + set (${last_modified_commit_id} ${LAST_MODIFIED_COMMIT_ID} PARENT_SCOPE) + set(_LOOP_STATUS OFF) + endif() + endforeach() + else() + set(_LOOP_STATUS OFF) + endif() + endforeach() + + if(_MATCH_ALL) + execute_process(COMMAND git log --no-merges -1 --skip=${GIT_LOG_SKIP_NUM} --pretty=%H WORKING_DIRECTORY ${working_dir} OUTPUT_VARIABLE LAST_MODIFIED_COMMIT_ID) + set (${last_modified_commit_id} ${LAST_MODIFIED_COMMIT_ID} PARENT_SCOPE) + set(_LOOP_STATUS OFF) + endif() + + math(EXPR GIT_LOG_SKIP_NUM "${GIT_LOG_SKIP_NUM} + 1") + endwhile(_LOOP_STATUS) + else() + execute_process(COMMAND git log --no-merges -1 --skip=${GIT_LOG_SKIP_NUM} --pretty=%H WORKING_DIRECTORY ${working_dir} OUTPUT_VARIABLE LAST_MODIFIED_COMMIT_ID) + set (${last_modified_commit_id} ${LAST_MODIFIED_COMMIT_ID} PARENT_SCOPE) + endif() + else() + message(FATAL_ERROR "The directory ${working_dir} does not exist") + endif() +endfunction() + +# Define a function that extracts a cached package +function(ExternalProject_Use_Cache project_name package_file install_path) + message(STATUS "Will use cached package file: ${package_file}") + + ExternalProject_Add(${project_name} + DOWNLOAD_COMMAND ${CMAKE_COMMAND} -E echo + "No download step needed (using cached package)" + CONFIGURE_COMMAND ${CMAKE_COMMAND} -E echo + "No configure step needed (using cached package)" + BUILD_COMMAND ${CMAKE_COMMAND} -E echo + "No build step needed (using cached package)" + INSTALL_COMMAND ${CMAKE_COMMAND} -E echo + "No install step needed (using cached package)" + ) + + # We want our tar files to contain the Install/ prefix (not for any + # very special reason, only for consistency and so that we can identify them + # in the extraction logs) which means that we must extract them in the + # binary (top-level build) directory to have them installed in the right + # place for subsequent ExternalProjects to pick them up. It seems that the + # only way to control the working directory is with Add_Step! + ExternalProject_Add_Step(${project_name} extract + ALWAYS 1 + COMMAND + ${CMAKE_COMMAND} -E echo + "Extracting ${package_file} to ${install_path}" + COMMAND + ${CMAKE_COMMAND} -E tar xzf ${package_file} ${install_path} + WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR} + ) + + ExternalProject_Add_StepTargets(${project_name} extract) +endfunction() + +# Define a function that to create a new cached package +function(ExternalProject_Create_Cache project_name package_file install_path cache_username cache_password cache_path) + if(EXISTS ${package_file}) + message(STATUS "Removing existing package file: ${package_file}") + file(REMOVE ${package_file}) + endif() + + string(REGEX REPLACE "(.+)/.+$" "\\1" package_dir ${package_file}) + if(NOT EXISTS ${package_dir}) + file(MAKE_DIRECTORY ${package_dir}) + endif() + + message(STATUS "Will create cached package file: ${package_file}") + + ExternalProject_Add_Step(${project_name} package + DEPENDEES install + BYPRODUCTS ${package_file} + COMMAND ${CMAKE_COMMAND} -E echo "Updating cached package file: ${package_file}" + COMMAND ${CMAKE_COMMAND} -E tar czvf ${package_file} ${install_path} + COMMAND ${CMAKE_COMMAND} -E echo "Uploading package file ${package_file} to ${cache_path}" + COMMAND curl -u${cache_username}:${cache_password} -T ${package_file} ${cache_path} + ) + + ExternalProject_Add_StepTargets(${project_name} package) +endfunction() + +function(ADD_THIRDPARTY_LIB LIB_NAME) + set(options) + set(one_value_args SHARED_LIB STATIC_LIB) + set(multi_value_args DEPS INCLUDE_DIRECTORIES) + cmake_parse_arguments(ARG + "${options}" + "${one_value_args}" + "${multi_value_args}" + ${ARGN}) + if(ARG_UNPARSED_ARGUMENTS) + message(SEND_ERROR "Error: unrecognized arguments: ${ARG_UNPARSED_ARGUMENTS}") + endif() + + if(ARG_STATIC_LIB AND ARG_SHARED_LIB) + if(NOT ARG_STATIC_LIB) + message(FATAL_ERROR "No static or shared library provided for ${LIB_NAME}") + endif() + + set(AUG_LIB_NAME "${LIB_NAME}_static") + add_library(${AUG_LIB_NAME} STATIC IMPORTED) + set_target_properties(${AUG_LIB_NAME} + PROPERTIES IMPORTED_LOCATION "${ARG_STATIC_LIB}") + if(ARG_DEPS) + set_target_properties(${AUG_LIB_NAME} + PROPERTIES INTERFACE_LINK_LIBRARIES "${ARG_DEPS}") + endif() + message(STATUS "Added static library dependency ${AUG_LIB_NAME}: ${ARG_STATIC_LIB}") + if(ARG_INCLUDE_DIRECTORIES) + set_target_properties(${AUG_LIB_NAME} + PROPERTIES INTERFACE_INCLUDE_DIRECTORIES + "${ARG_INCLUDE_DIRECTORIES}") + endif() + + set(AUG_LIB_NAME "${LIB_NAME}_shared") + add_library(${AUG_LIB_NAME} SHARED IMPORTED) + + if(WIN32) + # Mark the ".lib" location as part of a Windows DLL + set_target_properties(${AUG_LIB_NAME} + PROPERTIES IMPORTED_IMPLIB "${ARG_SHARED_LIB}") + else() + set_target_properties(${AUG_LIB_NAME} + PROPERTIES IMPORTED_LOCATION "${ARG_SHARED_LIB}") + endif() + if(ARG_DEPS) + set_target_properties(${AUG_LIB_NAME} + PROPERTIES INTERFACE_LINK_LIBRARIES "${ARG_DEPS}") + endif() + message(STATUS "Added shared library dependency ${AUG_LIB_NAME}: ${ARG_SHARED_LIB}") + if(ARG_INCLUDE_DIRECTORIES) + set_target_properties(${AUG_LIB_NAME} + PROPERTIES INTERFACE_INCLUDE_DIRECTORIES + "${ARG_INCLUDE_DIRECTORIES}") + endif() + elseif(ARG_STATIC_LIB) + set(AUG_LIB_NAME "${LIB_NAME}_static") + add_library(${AUG_LIB_NAME} STATIC IMPORTED) + set_target_properties(${AUG_LIB_NAME} + PROPERTIES IMPORTED_LOCATION "${ARG_STATIC_LIB}") + if(ARG_DEPS) + set_target_properties(${AUG_LIB_NAME} + PROPERTIES INTERFACE_LINK_LIBRARIES "${ARG_DEPS}") + endif() + message(STATUS "Added static library dependency ${AUG_LIB_NAME}: ${ARG_STATIC_LIB}") + if(ARG_INCLUDE_DIRECTORIES) + set_target_properties(${AUG_LIB_NAME} + PROPERTIES INTERFACE_INCLUDE_DIRECTORIES + "${ARG_INCLUDE_DIRECTORIES}") + endif() + elseif(ARG_SHARED_LIB) + set(AUG_LIB_NAME "${LIB_NAME}_shared") + add_library(${AUG_LIB_NAME} SHARED IMPORTED) + + if(WIN32) + # Mark the ".lib" location as part of a Windows DLL + set_target_properties(${AUG_LIB_NAME} + PROPERTIES IMPORTED_IMPLIB "${ARG_SHARED_LIB}") + else() + set_target_properties(${AUG_LIB_NAME} + PROPERTIES IMPORTED_LOCATION "${ARG_SHARED_LIB}") + endif() + message(STATUS "Added shared library dependency ${AUG_LIB_NAME}: ${ARG_SHARED_LIB}") + if(ARG_DEPS) + set_target_properties(${AUG_LIB_NAME} + PROPERTIES INTERFACE_LINK_LIBRARIES "${ARG_DEPS}") + endif() + if(ARG_INCLUDE_DIRECTORIES) + set_target_properties(${AUG_LIB_NAME} + PROPERTIES INTERFACE_INCLUDE_DIRECTORIES + "${ARG_INCLUDE_DIRECTORIES}") + endif() + else() + message(FATAL_ERROR "No static or shared library provided for ${LIB_NAME}") + endif() +endfunction() + +MACRO (import_mysql_inc) + find_path (MYSQL_INCLUDE_DIR + NAMES "mysql.h" + PATH_SUFFIXES "mysql") + + if (${MYSQL_INCLUDE_DIR} STREQUAL "MYSQL_INCLUDE_DIR-NOTFOUND") + message(FATAL_ERROR "Could not found MySQL include directory") + else () + include_directories(${MYSQL_INCLUDE_DIR}) + endif () +ENDMACRO (import_mysql_inc) + +MACRO(using_ccache_if_defined MILVUS_USE_CCACHE) + if (MILVUS_USE_CCACHE) + find_program(CCACHE_FOUND ccache) + if (CCACHE_FOUND) + message(STATUS "Using ccache: ${CCACHE_FOUND}") + set_property(GLOBAL PROPERTY RULE_LAUNCH_COMPILE ${CCACHE_FOUND}) + set_property(GLOBAL PROPERTY RULE_LAUNCH_LINK ${CCACHE_FOUND}) + # let ccache preserve C++ comments, because some of them may be + # meaningful to the compiler + set(ENV{CCACHE_COMMENTS} "1") + endif (CCACHE_FOUND) + endif () +ENDMACRO(using_ccache_if_defined) + diff --git a/core/cmake/DefineOptions.cmake b/core/cmake/DefineOptions.cmake new file mode 100644 index 0000000000..1533d1082a --- /dev/null +++ b/core/cmake/DefineOptions.cmake @@ -0,0 +1,156 @@ + +macro(set_option_category name) + set(MILVUS_OPTION_CATEGORY ${name}) + list(APPEND "MILVUS_OPTION_CATEGORIES" ${name}) +endmacro() + +macro(define_option name description default) + option(${name} ${description} ${default}) + list(APPEND "MILVUS_${MILVUS_OPTION_CATEGORY}_OPTION_NAMES" ${name}) + set("${name}_OPTION_DESCRIPTION" ${description}) + set("${name}_OPTION_DEFAULT" ${default}) + set("${name}_OPTION_TYPE" "bool") +endmacro() + +function(list_join lst glue out) + if ("${${lst}}" STREQUAL "") + set(${out} "" PARENT_SCOPE) + return() + endif () + + list(GET ${lst} 0 joined) + list(REMOVE_AT ${lst} 0) + foreach (item ${${lst}}) + set(joined "${joined}${glue}${item}") + endforeach () + set(${out} ${joined} PARENT_SCOPE) +endfunction() + +macro(define_option_string name description default) + set(${name} ${default} CACHE STRING ${description}) + list(APPEND "MILVUS_${MILVUS_OPTION_CATEGORY}_OPTION_NAMES" ${name}) + set("${name}_OPTION_DESCRIPTION" ${description}) + set("${name}_OPTION_DEFAULT" "\"${default}\"") + set("${name}_OPTION_TYPE" "string") + + set("${name}_OPTION_ENUM" ${ARGN}) + list_join("${name}_OPTION_ENUM" "|" "${name}_OPTION_ENUM") + if (NOT ("${${name}_OPTION_ENUM}" STREQUAL "")) + set_property(CACHE ${name} PROPERTY STRINGS ${ARGN}) + endif () +endmacro() + +#---------------------------------------------------------------------- +set_option_category("Milvus Build Option") + +define_option(MILVUS_GPU_VERSION "Build GPU version" OFF) + +#---------------------------------------------------------------------- +set_option_category("Thirdparty") + +set(MILVUS_DEPENDENCY_SOURCE_DEFAULT "BUNDLED") + +define_option_string(MILVUS_DEPENDENCY_SOURCE + "Method to use for acquiring MILVUS's build dependencies" + "${MILVUS_DEPENDENCY_SOURCE_DEFAULT}" + "AUTO" + "BUNDLED" + "SYSTEM") + +define_option(MILVUS_USE_CCACHE "Use ccache when compiling (if available)" ON) + +define_option(MILVUS_VERBOSE_THIRDPARTY_BUILD + "Show output from ExternalProjects rather than just logging to files" ON) + +define_option(MILVUS_WITH_EASYLOGGINGPP "Build with Easylogging++ library" ON) + +define_option(MILVUS_WITH_GRPC "Build with GRPC" OFF) + +define_option(MILVUS_WITH_ZLIB "Build with zlib compression" ON) + +define_option(MILVUS_WITH_OPENTRACING "Build with Opentracing" ON) + +define_option(MILVUS_WITH_YAMLCPP "Build with yaml-cpp library" ON) + +define_option(MILVUS_WITH_PULSAR "Build with pulsar-client" ON) + +#---------------------------------------------------------------------- +set_option_category("Test and benchmark") + +unset(MILVUS_BUILD_TESTS CACHE) +if (BUILD_UNIT_TEST) + define_option(MILVUS_BUILD_TESTS "Build the MILVUS googletest unit tests" ON) +else () + define_option(MILVUS_BUILD_TESTS "Build the MILVUS googletest unit tests" OFF) +endif (BUILD_UNIT_TEST) + +#---------------------------------------------------------------------- +macro(config_summary) + message(STATUS "---------------------------------------------------------------------") + message(STATUS "MILVUS version: ${MILVUS_VERSION}") + message(STATUS) + message(STATUS "Build configuration summary:") + + message(STATUS " Generator: ${CMAKE_GENERATOR}") + message(STATUS " Build type: ${CMAKE_BUILD_TYPE}") + message(STATUS " Source directory: ${CMAKE_CURRENT_SOURCE_DIR}") + if (${CMAKE_EXPORT_COMPILE_COMMANDS}) + message( + STATUS " Compile commands: ${CMAKE_CURRENT_BINARY_DIR}/compile_commands.json") + endif () + + foreach (category ${MILVUS_OPTION_CATEGORIES}) + + message(STATUS) + message(STATUS "${category} options:") + + set(option_names ${MILVUS_${category}_OPTION_NAMES}) + + set(max_value_length 0) + foreach (name ${option_names}) + string(LENGTH "\"${${name}}\"" value_length) + if (${max_value_length} LESS ${value_length}) + set(max_value_length ${value_length}) + endif () + endforeach () + + foreach (name ${option_names}) + if ("${${name}_OPTION_TYPE}" STREQUAL "string") + set(value "\"${${name}}\"") + else () + set(value "${${name}}") + endif () + + set(default ${${name}_OPTION_DEFAULT}) + set(description ${${name}_OPTION_DESCRIPTION}) + string(LENGTH ${description} description_length) + if (${description_length} LESS 70) + string( + SUBSTRING + " " + ${description_length} -1 description_padding) + else () + set(description_padding " + ") + endif () + + set(comment "[${name}]") + + if ("${value}" STREQUAL "${default}") + set(comment "[default] ${comment}") + endif () + + if (NOT ("${${name}_OPTION_ENUM}" STREQUAL "")) + set(comment "${comment} [${${name}_OPTION_ENUM}]") + endif () + + string( + SUBSTRING "${value} " + 0 ${max_value_length} value) + + message(STATUS " ${description} ${description_padding} ${value} ${comment}") + endforeach () + + endforeach () + +endmacro() diff --git a/core/cmake/FindClangTools.cmake b/core/cmake/FindClangTools.cmake new file mode 100644 index 0000000000..6b47327a06 --- /dev/null +++ b/core/cmake/FindClangTools.cmake @@ -0,0 +1,111 @@ +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Tries to find the clang-tidy and clang-format modules +# +# Usage of this module as follows: +# +# find_package(ClangTools) +# +# Variables used by this module, they can change the default behaviour and need +# to be set before calling find_package: +# +# ClangToolsBin_HOME - +# When set, this path is inspected instead of standard library binary locations +# to find clang-tidy and clang-format +# +# This module defines +# CLANG_TIDY_BIN, The path to the clang tidy binary +# CLANG_TIDY_FOUND, Whether clang tidy was found +# CLANG_FORMAT_BIN, The path to the clang format binary +# CLANG_TIDY_FOUND, Whether clang format was found + +find_program(CLANG_TIDY_BIN + NAMES + clang-tidy-7.0 + clang-tidy-7 + clang-tidy-6.0 + clang-tidy-5.0 + clang-tidy-4.0 + clang-tidy-3.9 + clang-tidy-3.8 + clang-tidy-3.7 + clang-tidy-3.6 + clang-tidy + PATHS ${ClangTools_PATH} $ENV{CLANG_TOOLS_PATH} /usr/local/bin /usr/bin + NO_DEFAULT_PATH +) + +if ( "${CLANG_TIDY_BIN}" STREQUAL "CLANG_TIDY_BIN-NOTFOUND" ) + set(CLANG_TIDY_FOUND 0) + message("clang-tidy not found") +else() + set(CLANG_TIDY_FOUND 1) + message("clang-tidy found at ${CLANG_TIDY_BIN}") +endif() + +if (CLANG_FORMAT_VERSION) + find_program(CLANG_FORMAT_BIN + NAMES clang-format-${CLANG_FORMAT_VERSION} + PATHS + ${ClangTools_PATH} + $ENV{CLANG_TOOLS_PATH} + /usr/local/bin /usr/bin + NO_DEFAULT_PATH + ) + + # If not found yet, search alternative locations + if (("${CLANG_FORMAT_BIN}" STREQUAL "CLANG_FORMAT_BIN-NOTFOUND") AND APPLE) + # Homebrew ships older LLVM versions in /usr/local/opt/llvm@version/ + STRING(REGEX REPLACE "^([0-9]+)\\.[0-9]+" "\\1" CLANG_FORMAT_MAJOR_VERSION "${CLANG_FORMAT_VERSION}") + STRING(REGEX REPLACE "^[0-9]+\\.([0-9]+)" "\\1" CLANG_FORMAT_MINOR_VERSION "${CLANG_FORMAT_VERSION}") + if ("${CLANG_FORMAT_MINOR_VERSION}" STREQUAL "0") + find_program(CLANG_FORMAT_BIN + NAMES clang-format + PATHS /usr/local/opt/llvm@${CLANG_FORMAT_MAJOR_VERSION}/bin + NO_DEFAULT_PATH + ) + else() + find_program(CLANG_FORMAT_BIN + NAMES clang-format + PATHS /usr/local/opt/llvm@${CLANG_FORMAT_VERSION}/bin + NO_DEFAULT_PATH + ) + endif() + endif() +else() + find_program(CLANG_FORMAT_BIN + NAMES + clang-format-7.0 + clang-format-7 + clang-format-6.0 + clang-format-5.0 + clang-format-4.0 + clang-format-3.9 + clang-format-3.8 + clang-format-3.7 + clang-format-3.6 + clang-format + PATHS ${ClangTools_PATH} $ENV{CLANG_TOOLS_PATH} /usr/local/bin /usr/bin + NO_DEFAULT_PATH + ) +endif() + +if ( "${CLANG_FORMAT_BIN}" STREQUAL "CLANG_FORMAT_BIN-NOTFOUND" ) + set(CLANG_FORMAT_FOUND 0) + message("clang-format not found") +else() + set(CLANG_FORMAT_FOUND 1) + message("clang-format found at ${CLANG_FORMAT_BIN}") +endif() + diff --git a/core/cmake/ThirdPartyPackages.cmake b/core/cmake/ThirdPartyPackages.cmake new file mode 100644 index 0000000000..c129b46d63 --- /dev/null +++ b/core/cmake/ThirdPartyPackages.cmake @@ -0,0 +1,172 @@ +# Copyright (C) 2019-2020 Zilliz. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License +# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing permissions and limitations under the License. + + +message(STATUS "Using ${MILVUS_DEPENDENCY_SOURCE} approach to find dependencies") + +# For each dependency, set dependency source to global default, if unset +foreach (DEPENDENCY ${MILVUS_THIRDPARTY_DEPENDENCIES}) + if ("${${DEPENDENCY}_SOURCE}" STREQUAL "") + set(${DEPENDENCY}_SOURCE ${MILVUS_DEPENDENCY_SOURCE}) + endif () +endforeach () + +# ---------------------------------------------------------------------- +# Identify OS +if (UNIX) + if (APPLE) + set(CMAKE_OS_NAME "osx" CACHE STRING "Operating system name" FORCE) + else (APPLE) + ## Check for Debian GNU/Linux ________________ + find_file(DEBIAN_FOUND debian_version debconf.conf + PATHS /etc + ) + if (DEBIAN_FOUND) + set(CMAKE_OS_NAME "debian" CACHE STRING "Operating system name" FORCE) + endif (DEBIAN_FOUND) + ## Check for Fedora _________________________ + find_file(FEDORA_FOUND fedora-release + PATHS /etc + ) + if (FEDORA_FOUND) + set(CMAKE_OS_NAME "fedora" CACHE STRING "Operating system name" FORCE) + endif (FEDORA_FOUND) + ## Check for RedHat _________________________ + find_file(REDHAT_FOUND redhat-release inittab.RH + PATHS /etc + ) + if (REDHAT_FOUND) + set(CMAKE_OS_NAME "redhat" CACHE STRING "Operating system name" FORCE) + endif (REDHAT_FOUND) + ## Extra check for Ubuntu ____________________ + if (DEBIAN_FOUND) + ## At its core Ubuntu is a Debian system, with + ## a slightly altered configuration; hence from + ## a first superficial inspection a system will + ## be considered as Debian, which signifies an + ## extra check is required. + find_file(UBUNTU_EXTRA legal issue + PATHS /etc + ) + if (UBUNTU_EXTRA) + ## Scan contents of file + file(STRINGS ${UBUNTU_EXTRA} UBUNTU_FOUND + REGEX Ubuntu + ) + ## Check result of string search + if (UBUNTU_FOUND) + set(CMAKE_OS_NAME "ubuntu" CACHE STRING "Operating system name" FORCE) + set(DEBIAN_FOUND FALSE) + + find_program(LSB_RELEASE_EXEC lsb_release) + execute_process(COMMAND ${LSB_RELEASE_EXEC} -rs + OUTPUT_VARIABLE LSB_RELEASE_ID_SHORT + OUTPUT_STRIP_TRAILING_WHITESPACE + ) + STRING(REGEX REPLACE "\\." "_" UBUNTU_VERSION "${LSB_RELEASE_ID_SHORT}") + endif (UBUNTU_FOUND) + endif (UBUNTU_EXTRA) + endif (DEBIAN_FOUND) + endif (APPLE) +endif (UNIX) + +# ---------------------------------------------------------------------- +# thirdparty directory +set(THIRDPARTY_DIR "${MILVUS_SOURCE_DIR}/thirdparty") + +# ---------------------------------------------------------------------- +# ExternalProject options + +string(TOUPPER ${CMAKE_BUILD_TYPE} UPPERCASE_BUILD_TYPE) + +set(EP_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${CMAKE_CXX_FLAGS_${UPPERCASE_BUILD_TYPE}}") +set(EP_C_FLAGS "${CMAKE_C_FLAGS} ${CMAKE_C_FLAGS_${UPPERCASE_BUILD_TYPE}}") + +# Set -fPIC on all external projects +set(EP_CXX_FLAGS "${EP_CXX_FLAGS} -fPIC") +set(EP_C_FLAGS "${EP_C_FLAGS} -fPIC") + +# CC/CXX environment variables are captured on the first invocation of the +# builder (e.g make or ninja) instead of when CMake is invoked into to build +# directory. This leads to issues if the variables are exported in a subshell +# and the invocation of make/ninja is in distinct subshell without the same +# environment (CC/CXX). +set(EP_COMMON_TOOLCHAIN -DCMAKE_C_COMPILER=${CMAKE_C_COMPILER} + -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}) + +if (CMAKE_AR) + set(EP_COMMON_TOOLCHAIN ${EP_COMMON_TOOLCHAIN} -DCMAKE_AR=${CMAKE_AR}) +endif () + +if (CMAKE_RANLIB) + set(EP_COMMON_TOOLCHAIN ${EP_COMMON_TOOLCHAIN} -DCMAKE_RANLIB=${CMAKE_RANLIB}) +endif () + +# External projects are still able to override the following declarations. +# cmake command line will favor the last defined variable when a duplicate is +# encountered. This requires that `EP_COMMON_CMAKE_ARGS` is always the first +# argument. +set(EP_COMMON_CMAKE_ARGS + ${EP_COMMON_TOOLCHAIN} + -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} + -DCMAKE_C_FLAGS=${EP_C_FLAGS} + -DCMAKE_C_FLAGS_${UPPERCASE_BUILD_TYPE}=${EP_C_FLAGS} + -DCMAKE_CXX_FLAGS=${EP_CXX_FLAGS} + -DCMAKE_CXX_FLAGS_${UPPERCASE_BUILD_TYPE}=${EP_CXX_FLAGS}) + +if (NOT MILVUS_VERBOSE_THIRDPARTY_BUILD) + set(EP_LOG_OPTIONS LOG_CONFIGURE 1 LOG_BUILD 1 LOG_INSTALL 1 LOG_DOWNLOAD 1) +else () + set(EP_LOG_OPTIONS) +endif () + +# Ensure that a default make is set +if ("${MAKE}" STREQUAL "") + find_program(MAKE make) +endif () + +if (NOT DEFINED MAKE_BUILD_ARGS) + set(MAKE_BUILD_ARGS "-j8") +endif () +message(STATUS "Third Party MAKE_BUILD_ARGS = ${MAKE_BUILD_ARGS}") + +# ---------------------------------------------------------------------- +# Find pthreads + +set(THREADS_PREFER_PTHREAD_FLAG ON) +find_package(Threads REQUIRED) + +# ---------------------------------------------------------------------- +# Versions and URLs for toolchain builds, which also can be used to configure +# offline builds + +# Read toolchain versions from cpp/thirdparty/versions.txt +file(STRINGS "${THIRDPARTY_DIR}/versions.txt" TOOLCHAIN_VERSIONS_TXT) +foreach (_VERSION_ENTRY ${TOOLCHAIN_VERSIONS_TXT}) + # Exclude comments + if (NOT _VERSION_ENTRY MATCHES "^[^#][A-Za-z0-9-_]+_VERSION=") + continue() + endif () + + string(REGEX MATCH "^[^=]*" _LIB_NAME ${_VERSION_ENTRY}) + string(REPLACE "${_LIB_NAME}=" "" _LIB_VERSION ${_VERSION_ENTRY}) + + # Skip blank or malformed lines + if (${_LIB_VERSION} STREQUAL "") + continue() + endif () + + # For debugging + #message(STATUS "${_LIB_NAME}: ${_LIB_VERSION}") + + set(${_LIB_NAME} "${_LIB_VERSION}") +endforeach () + diff --git a/core/cmake/Utils.cmake b/core/cmake/Utils.cmake new file mode 100644 index 0000000000..f057a48457 --- /dev/null +++ b/core/cmake/Utils.cmake @@ -0,0 +1,102 @@ +# get build time +MACRO(get_current_time CURRENT_TIME) + execute_process(COMMAND "date" "+%Y-%m-%d %H:%M.%S" OUTPUT_VARIABLE ${CURRENT_TIME}) + string(REGEX REPLACE "\n" "" ${CURRENT_TIME} ${${CURRENT_TIME}}) +ENDMACRO(get_current_time) + +# get build type +MACRO(get_build_type) + cmake_parse_arguments(BUILD_TYPE "" "TARGET;DEFAULT" "" ${ARGN}) + if (NOT DEFINED CMAKE_BUILD_TYPE) + set(${BUILD_TYPE_TARGET} ${BUILD_TYPE_DEFAULT}) + elseif (CMAKE_BUILD_TYPE STREQUAL "Release") + set(${BUILD_TYPE_TARGET} "Release") + elseif (CMAKE_BUILD_TYPE STREQUAL "Debug") + set(${BUILD_TYPE_TARGET} "Debug") + else () + set(${BUILD_TYPE_TARGET} ${BUILD_TYPE_DEFAULT}) + endif () +ENDMACRO(get_build_type) + +# get git branch name +MACRO(get_git_branch_name GIT_BRANCH_NAME) + set(GIT_BRANCH_NAME_REGEX "[0-9]+\\.[0-9]+\\.[0-9]") + + execute_process(COMMAND sh "-c" "git log --decorate | head -n 1 | sed 's/.*(\\(.*\\))/\\1/' | sed 's/.*, //' | sed 's=[a-zA-Z]*\/==g'" + OUTPUT_VARIABLE ${GIT_BRANCH_NAME}) + + if (NOT GIT_BRANCH_NAME MATCHES "${GIT_BRANCH_NAME_REGEX}") + execute_process(COMMAND "git" rev-parse --abbrev-ref HEAD OUTPUT_VARIABLE ${GIT_BRANCH_NAME}) + endif () + + if (NOT GIT_BRANCH_NAME MATCHES "${GIT_BRANCH_NAME_REGEX}") + execute_process(COMMAND "git" symbolic-ref -q --short HEAD OUTPUT_VARIABLE ${GIT_BRANCH_NAME}) + endif () + + message(DEBUG "GIT_BRANCH_NAME = ${GIT_BRANCH_NAME}") + + # Some unexpected case + if (NOT GIT_BRANCH_NAME STREQUAL "") + string(REGEX REPLACE "\n" "" GIT_BRANCH_NAME ${GIT_BRANCH_NAME}) + else () + set(GIT_BRANCH_NAME "#") + endif () +ENDMACRO(get_git_branch_name) + +# get last commit id +MACRO(get_last_commit_id LAST_COMMIT_ID) + execute_process(COMMAND sh "-c" "git log --decorate | head -n 1 | awk '{print $2}'" + OUTPUT_VARIABLE ${LAST_COMMIT_ID}) + + message(DEBUG "LAST_COMMIT_ID = ${${LAST_COMMIT_ID}}") + + if (NOT LAST_COMMIT_ID STREQUAL "") + string(REGEX REPLACE "\n" "" ${LAST_COMMIT_ID} ${${LAST_COMMIT_ID}}) + else () + set(LAST_COMMIT_ID "Unknown") + endif () +ENDMACRO(get_last_commit_id) + +# get milvus version +MACRO(get_milvus_version) + cmake_parse_arguments(VER "" "TARGET;DEFAULT" "" ${ARGN}) + + # Step 1: get branch name + get_git_branch_name(GIT_BRANCH_NAME) + message(DEBUG ${GIT_BRANCH_NAME}) + + # Step 2: match MAJOR.MINOR.PATCH format or set DEFAULT value + string(REGEX MATCH "([0-9]+)\\.([0-9]+)\\.([0-9]+)" ${VER_TARGET} ${GIT_BRANCH_NAME}) + if (NOT ${VER_TARGET}) + set(${VER_TARGET} ${VER_DEFAULT}) + endif() +ENDMACRO(get_milvus_version) + +# set definition +MACRO(set_milvus_definition DEF_PASS_CMAKE MILVUS_DEF) + if (${${DEF_PASS_CMAKE}}) + add_compile_definitions(${MILVUS_DEF}) + endif() +ENDMACRO(set_milvus_definition) + +MACRO(append_flags target) + cmake_parse_arguments(M "" "" "FLAGS" ${ARGN}) + foreach(FLAG IN ITEMS ${M_FLAGS}) + set(${target} "${${target}} ${FLAG}") + endforeach() +ENDMACRO(append_flags) + +macro(create_executable) + cmake_parse_arguments(E "" "TARGET" "SRCS;LIBS;DEFS" ${ARGN}) + add_executable(${E_TARGET}) + target_sources(${E_TARGET} PRIVATE ${E_SRCS}) + target_link_libraries(${E_TARGET} PRIVATE ${E_LIBS}) + target_compile_definitions(${E_TARGET} PRIVATE ${E_DEFS}) +endmacro() + +macro(create_library) + cmake_parse_arguments(L "" "TARGET" "SRCS;LIBS;DEFS" ${ARGN}) + add_library(${L_TARGET} ${L_SRCS}) + target_link_libraries(${L_TARGET} PRIVATE ${L_LIBS}) + target_compile_definitions(${L_TARGET} PRIVATE ${L_DEFS}) +endmacro() \ No newline at end of file diff --git a/core/src/CMakeLists.txt b/core/src/CMakeLists.txt index 746c72aac8..38fc44234a 100644 --- a/core/src/CMakeLists.txt +++ b/core/src/CMakeLists.txt @@ -1,4 +1,75 @@ -add_subdirectory(utils) -add_subdirectory(dog_segment) -#add_subdirectory(index) -add_subdirectory(query) +#------------------------------------------------------------------------------- +# Copyright (C) 2019-2020 Zilliz. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License +# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing permissions and limitations under the License. +#------------------------------------------------------------------------------- + +include_directories(${MILVUS_SOURCE_DIR}) +include_directories(${MILVUS_ENGINE_SRC}) +include_directories(${MILVUS_THIRDPARTY_SRC}) + +#include_directories(${MILVUS_ENGINE_SRC}/grpc) +set(FOUND_OPENBLAS "unknown") +add_subdirectory(index) +set(INDEX_INCLUDE_DIRS ${INDEX_INCLUDE_DIRS} PARENT_SCOPE) +foreach (DIR ${INDEX_INCLUDE_DIRS}) + include_directories(${DIR}) +endforeach () + + +add_subdirectory( utils ) +add_subdirectory( log) +add_subdirectory( dog_segment) +add_subdirectory( cache ) +add_subdirectory( query ) +# add_subdirectory( db ) # target milvus_engine +# add_subdirectory( server ) + +# set(link_lib +# milvus_engine +# # dog_segment +# #query +# utils +# curl +# ) + + +# set( BOOST_LIB libboost_system.a +# libboost_filesystem.a +# libboost_serialization.a +# ) + +# set( THIRD_PARTY_LIBS yaml-cpp +# ) + + +# target_link_libraries( server +# PUBLIC ${link_lib} +# ${THIRD_PARTY_LIBS} +# ${BOOST_LIB} +# ) + +# # **************************** Get&Print Include Directories **************************** +# get_property( dirs DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} PROPERTY INCLUDE_DIRECTORIES ) + +# foreach ( dir ${dirs} ) +# message( STATUS "Current Include DIRS: ") +# endforeach () + +# set( SERVER_LIBS server ) + + +# add_executable( milvus_server ${CMAKE_CURRENT_SOURCE_DIR}/main.cpp +# ) +# #target_include_directories(db PUBLIC ${PROJECT_BINARY_DIR}/thirdparty/pulsar-client-cpp/pulsar-client-cpp-src/pulsar-client-cpp/include) + + +# target_link_libraries( milvus_server PRIVATE ${SERVER_LIBS} ) +# install( TARGETS milvus_server DESTINATION bin ) diff --git a/core/src/CMakeLists_old.txt b/core/src/CMakeLists_old.txt new file mode 100644 index 0000000000..746c72aac8 --- /dev/null +++ b/core/src/CMakeLists_old.txt @@ -0,0 +1,4 @@ +add_subdirectory(utils) +add_subdirectory(dog_segment) +#add_subdirectory(index) +add_subdirectory(query) diff --git a/core/src/cache/CMakeLists.txt b/core/src/cache/CMakeLists.txt new file mode 100644 index 0000000000..55ea8bc8a3 --- /dev/null +++ b/core/src/cache/CMakeLists.txt @@ -0,0 +1,20 @@ +#------------------------------------------------------------------------------- +# Copyright (C) 2019-2020 Zilliz. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License +# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing permissions and limitations under the License. +#------------------------------------------------------------------------------- +aux_source_directory( ${MILVUS_ENGINE_SRC}/cache CACHE_FILES ) +add_library( cache STATIC ) +target_sources( cache PRIVATE ${CACHE_FILES} + CacheMgr.inl + Cache.inl + ) +target_include_directories( cache PUBLIC ${MILVUS_ENGINE_SRC}/cache ) + diff --git a/core/src/cache/Cache.h b/core/src/cache/Cache.h new file mode 100644 index 0000000000..34594c4573 --- /dev/null +++ b/core/src/cache/Cache.h @@ -0,0 +1,104 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#pragma once + +#include "LRU.h" +#include "utils/Log.h" + +#include +#include +#include +#include + +namespace milvus { +namespace cache { + +template +class Cache { + public: + // mem_capacity, units:GB + Cache(int64_t capacity_gb, int64_t cache_max_count, const std::string& header = ""); + ~Cache() = default; + + int64_t + usage() const { + return usage_; + } + + // unit: BYTE + int64_t + capacity() const { + return capacity_; + } + + // unit: BYTE + void + set_capacity(int64_t capacity); + + double + freemem_percent() const { + return freemem_percent_; + } + + void + set_freemem_percent(double percent) { + freemem_percent_ = percent; + } + + size_t + size() const; + + bool + exists(const std::string& key); + + ItemObj + get(const std::string& key); + + void + insert(const std::string& key, const ItemObj& item); + + void + erase(const std::string& key); + + bool + reserve(const int64_t size); + + void + print(); + + void + clear(); + + private: + void + insert_internal(const std::string& key, const ItemObj& item); + + void + erase_internal(const std::string& key); + + void + free_memory_internal(const int64_t target_size); + + private: + std::string header_; + int64_t usage_; + int64_t capacity_; + double freemem_percent_; + + LRU lru_; + mutable std::mutex mutex_; +}; + +} // namespace cache +} // namespace milvus + +#include "cache/Cache.inl" diff --git a/core/src/cache/Cache.inl b/core/src/cache/Cache.inl new file mode 100644 index 0000000000..c371efe515 --- /dev/null +++ b/core/src/cache/Cache.inl @@ -0,0 +1,191 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +namespace milvus { +namespace cache { + +constexpr double DEFAULT_THRESHOLD_PERCENT = 0.7; + +template +Cache::Cache(int64_t capacity, int64_t cache_max_count, const std::string& header) + : header_(header), + usage_(0), + capacity_(capacity), + freemem_percent_(DEFAULT_THRESHOLD_PERCENT), + lru_(cache_max_count) { +} + +template +void +Cache::set_capacity(int64_t capacity) { + std::lock_guard lock(mutex_); + if (capacity > 0) { + capacity_ = capacity; + free_memory_internal(capacity); + } +} + +template +size_t +Cache::size() const { + std::lock_guard lock(mutex_); + return lru_.size(); +} + +template +bool +Cache::exists(const std::string& key) { + std::lock_guard lock(mutex_); + return lru_.exists(key); +} + +template +ItemObj +Cache::get(const std::string& key) { + std::lock_guard lock(mutex_); + if (!lru_.exists(key)) { + return nullptr; + } + return lru_.get(key); +} + +template +void +Cache::insert(const std::string& key, const ItemObj& item) { + std::lock_guard lock(mutex_); + insert_internal(key, item); +} + +template +void +Cache::erase(const std::string& key) { + std::lock_guard lock(mutex_); + erase_internal(key); +} + +template +bool +Cache::reserve(const int64_t item_size) { + std::lock_guard lock(mutex_); + if (item_size > capacity_) { + LOG_SERVER_ERROR_ << header_ << " item size " << (item_size >> 20) << "MB too big to insert into cache capacity" + << (capacity_ >> 20) << "MB"; + return false; + } + if (item_size > capacity_ - usage_) { + free_memory_internal(capacity_ - item_size); + } + return true; +} + +template +void +Cache::clear() { + std::lock_guard lock(mutex_); + lru_.clear(); + usage_ = 0; + LOG_SERVER_DEBUG_ << header_ << " Clear cache !"; +} + + +template +void +Cache::print() { + std::lock_guard lock(mutex_); + size_t cache_count = lru_.size(); + // for (auto it = lru_.begin(); it != lru_.end(); ++it) { + // LOG_SERVER_DEBUG_ << it->first; + // } + LOG_SERVER_DEBUG_ << header_ << " [item count]: " << cache_count << ", [usage] " << (usage_ >> 20) + << "MB, [capacity] " << (capacity_ >> 20) << "MB"; +} + +template +void +Cache::insert_internal(const std::string& key, const ItemObj& item) { + if (item == nullptr) { + return; + } + + size_t item_size = item->Size(); + + // if key already exist, subtract old item size + if (lru_.exists(key)) { + const ItemObj& old_item = lru_.get(key); + usage_ -= old_item->Size(); + } + + // plus new item size + usage_ += item_size; + + // if usage exceed capacity, free some items + if (usage_ > capacity_) { + LOG_SERVER_DEBUG_ << header_ << " Current usage " << (usage_ >> 20) << "MB is too high for capacity " + << (capacity_ >> 20) << "MB, start free memory"; + free_memory_internal(capacity_); + } + + // insert new item + lru_.put(key, item); + LOG_SERVER_DEBUG_ << header_ << " Insert " << key << " size: " << (item_size >> 20) << "MB into cache"; + LOG_SERVER_DEBUG_ << header_ << " Count: " << lru_.size() << ", Usage: " << (usage_ >> 20) << "MB, Capacity: " + << (capacity_ >> 20) << "MB"; +} + +template +void +Cache::erase_internal(const std::string& key) { + if (!lru_.exists(key)) { + return; + } + + const ItemObj& item = lru_.get(key); + size_t item_size = item->Size(); + + lru_.erase(key); + + usage_ -= item_size; + LOG_SERVER_DEBUG_ << header_ << " Erase " << key << " size: " << (item_size >> 20) << "MB from cache"; + LOG_SERVER_DEBUG_ << header_ << " Count: " << lru_.size() << ", Usage: " << (usage_ >> 20) << "MB, Capacity: " + << (capacity_ >> 20) << "MB"; +} + +template +void +Cache::free_memory_internal(const int64_t target_size) { + int64_t threshold = std::min((int64_t)(capacity_ * freemem_percent_), target_size); + int64_t delta_size = usage_ - threshold; + if (delta_size <= 0) { + delta_size = 1; // ensure at least one item erased + } + + std::set key_array; + int64_t released_size = 0; + + auto it = lru_.rbegin(); + while (it != lru_.rend() && released_size < delta_size) { + auto& key = it->first; + auto& obj_ptr = it->second; + + key_array.emplace(key); + released_size += obj_ptr->Size(); + ++it; + } + + LOG_SERVER_DEBUG_ << header_ << " To be released memory size: " << (released_size >> 20) << "MB"; + + for (auto& key : key_array) { + erase_internal(key); + } +} + +} // namespace cache +} // namespace milvus diff --git a/core/src/cache/CacheMgr.h b/core/src/cache/CacheMgr.h new file mode 100644 index 0000000000..69bd224c73 --- /dev/null +++ b/core/src/cache/CacheMgr.h @@ -0,0 +1,72 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#pragma once + +#include "Cache.h" +// #include "s/Metrics.h" +#include "utils/Log.h" + +#include +#include + +namespace milvus { +namespace cache { + +template +class CacheMgr { + public: + virtual uint64_t + ItemCount() const; + + virtual bool + ItemExists(const std::string& key); + + virtual ItemObj + GetItem(const std::string& key); + + virtual void + InsertItem(const std::string& key, const ItemObj& data); + + virtual void + EraseItem(const std::string& key); + + virtual bool + Reserve(const int64_t size); + + virtual void + PrintInfo(); + + virtual void + ClearCache(); + + int64_t + CacheUsage() const; + + int64_t + CacheCapacity() const; + + void + SetCapacity(int64_t capacity); + + protected: + CacheMgr(); + + virtual ~CacheMgr(); + + protected: + std::shared_ptr> cache_; +}; + +} // namespace cache +} // namespace milvus + +#include "cache/CacheMgr.inl" diff --git a/core/src/cache/CacheMgr.inl b/core/src/cache/CacheMgr.inl new file mode 100644 index 0000000000..8f225c59be --- /dev/null +++ b/core/src/cache/CacheMgr.inl @@ -0,0 +1,137 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +namespace milvus { +namespace cache { + +template +CacheMgr::CacheMgr() { +} + +template +CacheMgr::~CacheMgr() { +} + +template +uint64_t +CacheMgr::ItemCount() const { + if (cache_ == nullptr) { + LOG_SERVER_ERROR_ << "Cache doesn't exist"; + return 0; + } + return (uint64_t)(cache_->size()); +} + +template +bool +CacheMgr::ItemExists(const std::string& key) { + if (cache_ == nullptr) { + LOG_SERVER_ERROR_ << "Cache doesn't exist"; + return false; + } + return cache_->exists(key); +} + +template +ItemObj +CacheMgr::GetItem(const std::string& key) { + if (cache_ == nullptr) { + LOG_SERVER_ERROR_ << "Cache doesn't exist"; + return nullptr; + } + // server::Metrics::GetInstance().CacheAccessTotalIncrement(); + return cache_->get(key); +} + +template +void +CacheMgr::InsertItem(const std::string& key, const ItemObj& data) { + if (cache_ == nullptr) { + LOG_SERVER_ERROR_ << "Cache doesn't exist"; + return; + } + cache_->insert(key, data); + // server::Metrics::GetInstance().CacheAccessTotalIncrement(); +} + +template +void +CacheMgr::EraseItem(const std::string& key) { + if (cache_ == nullptr) { + LOG_SERVER_ERROR_ << "Cache doesn't exist"; + return; + } + cache_->erase(key); + // server::Metrics::GetInstance().CacheAccessTotalIncrement(); +} + +template +bool +CacheMgr::Reserve(const int64_t size) { + if (cache_ == nullptr) { + LOG_SERVER_ERROR_ << "Cache doesn't exist"; + return false; + } + return cache_->reserve(size); +} + +template +void +CacheMgr::PrintInfo() { + if (cache_ == nullptr) { + LOG_SERVER_ERROR_ << "Cache doesn't exist"; + return; + } + cache_->print(); +} + +template +void +CacheMgr::ClearCache() { + if (cache_ == nullptr) { + LOG_SERVER_ERROR_ << "Cache doesn't exist"; + return; + } + cache_->clear(); +} + +template +int64_t +CacheMgr::CacheUsage() const { + if (cache_ == nullptr) { + LOG_SERVER_ERROR_ << "Cache doesn't exist"; + return 0; + } + return cache_->usage(); +} + +template +int64_t +CacheMgr::CacheCapacity() const { + if (cache_ == nullptr) { + LOG_SERVER_ERROR_ << "Cache doesn't exist"; + return 0; + } + return cache_->capacity(); +} + +template +void +CacheMgr::SetCapacity(int64_t capacity) { + if (cache_ == nullptr) { + LOG_SERVER_ERROR_ << "Cache doesn't exist"; + return; + } + cache_->set_capacity(capacity); +} + +} // namespace cache +} // namespace milvus diff --git a/core/src/cache/CpuCacheMgr.cpp b/core/src/cache/CpuCacheMgr.cpp new file mode 100644 index 0000000000..e5aaf1d5f7 --- /dev/null +++ b/core/src/cache/CpuCacheMgr.cpp @@ -0,0 +1,49 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include "cache/CpuCacheMgr.h" + +#include + +// #include + +#include "config/ServerConfig.h" +#include "utils/Log.h" + +namespace milvus { +namespace cache { + +CpuCacheMgr::CpuCacheMgr() { + // cache_ = std::make_shared>(config.cache.cache_size(), 1UL << 32, "[CACHE CPU]"); + + // if (config.cache.cpu_cache_threshold() > 0.0) { + // cache_->set_freemem_percent(config.cache.cpu_cache_threshold()); + // } + ConfigMgr::GetInstance().Attach("cache.cache_size", this); +} + +CpuCacheMgr::~CpuCacheMgr() { + ConfigMgr::GetInstance().Detach("cache.cache_size", this); +} + +CpuCacheMgr& +CpuCacheMgr::GetInstance() { + static CpuCacheMgr s_mgr; + return s_mgr; +} + +void +CpuCacheMgr::ConfigUpdate(const std::string& name) { + // SetCapacity(config.cache.cache_size()); +} + +} // namespace cache +} // namespace milvus diff --git a/core/src/cache/CpuCacheMgr.h b/core/src/cache/CpuCacheMgr.h new file mode 100644 index 0000000000..455479f41c --- /dev/null +++ b/core/src/cache/CpuCacheMgr.h @@ -0,0 +1,40 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#pragma once + +#include +#include + +#include "cache/CacheMgr.h" +#include "cache/DataObj.h" +#include "config/ConfigMgr.h" + +namespace milvus { +namespace cache { + +class CpuCacheMgr : public CacheMgr, public ConfigObserver { + private: + CpuCacheMgr(); + + ~CpuCacheMgr(); + + public: + static CpuCacheMgr& + GetInstance(); + + public: + void + ConfigUpdate(const std::string& name) override; +}; + +} // namespace cache +} // namespace milvus diff --git a/core/src/cache/DataObj.h b/core/src/cache/DataObj.h new file mode 100644 index 0000000000..3aea7dea24 --- /dev/null +++ b/core/src/cache/DataObj.h @@ -0,0 +1,28 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#pragma once + +#include + +namespace milvus { +namespace cache { + +class DataObj { + public: + virtual int64_t + Size() = 0; +}; + +using DataObjPtr = std::shared_ptr; + +} // namespace cache +} // namespace milvus diff --git a/core/src/cache/GpuCacheMgr.cpp b/core/src/cache/GpuCacheMgr.cpp new file mode 100644 index 0000000000..3877ad6899 --- /dev/null +++ b/core/src/cache/GpuCacheMgr.cpp @@ -0,0 +1,63 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include "cache/GpuCacheMgr.h" +#include "config/ServerConfig.h" +#include "utils/Log.h" + +// #include +#include +#include + +namespace milvus { +namespace cache { + +#ifdef MILVUS_GPU_VERSION +std::mutex GpuCacheMgr::global_mutex_; +std::unordered_map GpuCacheMgr::instance_; + +GpuCacheMgr::GpuCacheMgr(int64_t gpu_id) : gpu_id_(gpu_id) { + std::string header = "[CACHE GPU" + std::to_string(gpu_id) + "]"; + cache_ = std::make_shared>(config.gpu.cache_size(), 1UL << 32, header); + + if (config.gpu.cache_threshold() > 0.0) { + cache_->set_freemem_percent(config.gpu.cache_threshold()); + } + ConfigMgr::GetInstance().Attach("gpu.cache_threshold", this); +} + +GpuCacheMgr::~GpuCacheMgr() { + ConfigMgr::GetInstance().Detach("gpu.cache_threshold", this); +} + +GpuCacheMgrPtr +GpuCacheMgr::GetInstance(int64_t gpu_id) { + if (instance_.find(gpu_id) == instance_.end()) { + std::lock_guard lock(global_mutex_); + if (instance_.find(gpu_id) == instance_.end()) { + instance_[gpu_id] = std::make_shared(gpu_id); + } + } + return instance_[gpu_id]; +} + +void +GpuCacheMgr::ConfigUpdate(const std::string& name) { + std::lock_guard lock(global_mutex_); + for (auto& it : instance_) { + it.second->SetCapacity(config.gpu.cache_size()); + } +} + +#endif + +} // namespace cache +} // namespace milvus diff --git a/core/src/cache/GpuCacheMgr.h b/core/src/cache/GpuCacheMgr.h new file mode 100644 index 0000000000..8c648ea5dc --- /dev/null +++ b/core/src/cache/GpuCacheMgr.h @@ -0,0 +1,51 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include +#include +#include +#include +#include + +#include "cache/CacheMgr.h" +#include "cache/DataObj.h" +#include "config/ConfigMgr.h" + +namespace milvus { +namespace cache { + +#ifdef MILVUS_GPU_VERSION +class GpuCacheMgr; +using GpuCacheMgrPtr = std::shared_ptr; +using MutexPtr = std::shared_ptr; + +class GpuCacheMgr : public CacheMgr, public ConfigObserver { + public: + explicit GpuCacheMgr(int64_t gpu_id); + + ~GpuCacheMgr(); + + static GpuCacheMgrPtr + GetInstance(int64_t gpu_id); + + public: + void + ConfigUpdate(const std::string& name) override; + + private: + int64_t gpu_id_; + static std::mutex global_mutex_; + static std::unordered_map instance_; +}; +#endif + +} // namespace cache +} // namespace milvus diff --git a/core/src/cache/LRU.h b/core/src/cache/LRU.h new file mode 100644 index 0000000000..e8a853cea0 --- /dev/null +++ b/core/src/cache/LRU.h @@ -0,0 +1,116 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include + +namespace milvus { +namespace cache { + +template +class LRU { + public: + typedef typename std::pair key_value_pair_t; + typedef typename std::list::iterator list_iterator_t; + typedef typename std::list::reverse_iterator reverse_list_iterator_t; + + explicit LRU(size_t max_size) : max_size_(max_size) { + } + + void + put(const key_t& key, const value_t& value) { + auto it = cache_items_map_.find(key); + cache_items_list_.push_front(key_value_pair_t(key, value)); + if (it != cache_items_map_.end()) { + cache_items_list_.erase(it->second); + cache_items_map_.erase(it); + } + cache_items_map_[key] = cache_items_list_.begin(); + + if (cache_items_map_.size() > max_size_) { + auto last = cache_items_list_.end(); + last--; + cache_items_map_.erase(last->first); + cache_items_list_.pop_back(); + } + } + + const value_t& + get(const key_t& key) { + auto it = cache_items_map_.find(key); + if (it == cache_items_map_.end()) { + throw std::range_error("There is no such key in cache"); + } else { + cache_items_list_.splice(cache_items_list_.begin(), cache_items_list_, it->second); + return it->second->second; + } + } + + void + erase(const key_t& key) { + auto it = cache_items_map_.find(key); + if (it != cache_items_map_.end()) { + cache_items_list_.erase(it->second); + cache_items_map_.erase(it); + } + } + + bool + exists(const key_t& key) const { + return cache_items_map_.find(key) != cache_items_map_.end(); + } + + size_t + size() const { + return cache_items_map_.size(); + } + + list_iterator_t + begin() { + iter_ = cache_items_list_.begin(); + return iter_; + } + + list_iterator_t + end() { + return cache_items_list_.end(); + } + + reverse_list_iterator_t + rbegin() { + return cache_items_list_.rbegin(); + } + + reverse_list_iterator_t + rend() { + return cache_items_list_.rend(); + } + + void + clear() { + cache_items_list_.clear(); + cache_items_map_.clear(); + } + + private: + std::list cache_items_list_; + std::unordered_map cache_items_map_; + size_t max_size_; + list_iterator_t iter_; +}; + +} // namespace cache +} // namespace milvus diff --git a/core/src/config/CMakeLists.txt b/core/src/config/CMakeLists.txt new file mode 100644 index 0000000000..49c03988c9 --- /dev/null +++ b/core/src/config/CMakeLists.txt @@ -0,0 +1,31 @@ +#------------------------------------------------------------------------------- +# Copyright (C) 2019-2020 Zilliz. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License +# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing permissions and limitations under the License. +#------------------------------------------------------------------------------- + +# library +set( CONFIG_SRCS ConfigMgr.h + ConfigMgr.cpp + ConfigType.h + ConfigType.cpp + ServerConfig.h + ServerConfig.cpp + ) + +set( CONFIG_LIBS yaml-cpp + ) + +create_library( + TARGET config + SRCS ${CONFIG_SRCS} + LIBS ${CONFIG_LIBS} +) + diff --git a/core/src/config/ConfigMgr.cpp b/core/src/config/ConfigMgr.cpp new file mode 100644 index 0000000000..22e830a69e --- /dev/null +++ b/core/src/config/ConfigMgr.cpp @@ -0,0 +1,223 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include +#include +#include +#include +#include +#include "config/ConfigMgr.h" +#include "config/ServerConfig.h" + +namespace { +const int64_t MB = (1024ll * 1024); +const int64_t GB = (1024ll * 1024 * 1024); + +void +Flatten(const YAML::Node& node, std::unordered_map& target, const std::string& prefix) { + for (auto& it : node) { + auto key = prefix.empty() ? it.first.as() : prefix + "." + it.first.as(); + switch (it.second.Type()) { + case YAML::NodeType::Null: { + target[key] = ""; + break; + } + case YAML::NodeType::Scalar: { + target[key] = it.second.as(); + break; + } + case YAML::NodeType::Sequence: { + std::string value; + for (auto& sub : it.second) value += sub.as() + ","; + target[key] = value; + break; + } + case YAML::NodeType::Map: { + Flatten(it.second, target, key); + break; + } + case YAML::NodeType::Undefined: { + throw "Unexpected"; + } + default: + break; + } + } +} + +void +ThrowIfNotSuccess(const milvus::ConfigStatus& cs) { + if (cs.set_return != milvus::SetReturn::SUCCESS) { + throw cs; + } +} + +}; // namespace + +namespace milvus { + +ConfigMgr ConfigMgr::instance; + +ConfigMgr::ConfigMgr() { + config_list_ = { + + /* general */ + {"timezone", + CreateStringConfig("timezone", false, &config.timezone.value, "UTC+8", nullptr, nullptr)}, + + /* network */ + {"network.address", CreateStringConfig("network.address", false, &config.network.address.value, + "0.0.0.0", nullptr, nullptr)}, + {"network.port", CreateIntegerConfig("network.port", false, 0, 65535, &config.network.port.value, + 19530, nullptr, nullptr)}, + + + /* pulsar */ + {"pulsar.address", CreateStringConfig("pulsar.address", false, &config.pulsar.address.value, + "localhost", nullptr, nullptr)}, + {"pulsar.port", CreateIntegerConfig("pulsar.port", false, 0, 65535, &config.pulsar.port.value, + 6650, nullptr, nullptr)}, + + + /* log */ + {"logs.level", CreateStringConfig("logs.level", false, &config.logs.level.value, "debug", nullptr, nullptr)}, + {"logs.trace.enable", + CreateBoolConfig("logs.trace.enable", false, &config.logs.trace.enable.value, true, nullptr, nullptr)}, + {"logs.path", + CreateStringConfig("logs.path", false, &config.logs.path.value, "/var/lib/milvus/logs", nullptr, nullptr)}, + {"logs.max_log_file_size", CreateSizeConfig("logs.max_log_file_size", false, 512 * MB, 4096 * MB, + &config.logs.max_log_file_size.value, 1024 * MB, nullptr, nullptr)}, + {"logs.log_rotate_num", CreateIntegerConfig("logs.log_rotate_num", false, 0, 1024, + &config.logs.log_rotate_num.value, 0, nullptr, nullptr)}, + + /* tracing */ + {"tracing.json_config_path", CreateStringConfig("tracing.json_config_path", false, + &config.tracing.json_config_path.value, "", nullptr, nullptr)}, + + /* invisible */ + /* engine */ + {"engine.build_index_threshold", + CreateIntegerConfig("engine.build_index_threshold", false, 0, std::numeric_limits::max(), + &config.engine.build_index_threshold.value, 4096, nullptr, nullptr)}, + {"engine.search_combine_nq", + CreateIntegerConfig("engine.search_combine_nq", true, 0, std::numeric_limits::max(), + &config.engine.search_combine_nq.value, 64, nullptr, nullptr)}, + {"engine.use_blas_threshold", + CreateIntegerConfig("engine.use_blas_threshold", true, 0, std::numeric_limits::max(), + &config.engine.use_blas_threshold.value, 1100, nullptr, nullptr)}, + {"engine.omp_thread_num", + CreateIntegerConfig("engine.omp_thread_num", true, 0, std::numeric_limits::max(), + &config.engine.omp_thread_num.value, 0, nullptr, nullptr)}, + {"engine.simd_type", CreateEnumConfig("engine.simd_type", false, &SimdMap, &config.engine.simd_type.value, + SimdType::AUTO, nullptr, nullptr)}, + }; +} + +void +ConfigMgr::Init() { + std::lock_guard lock(GetConfigMutex()); + for (auto& kv : config_list_) { + kv.second->Init(); + } +} + +void +ConfigMgr::Load(const std::string& path) { + /* load from milvus.yaml */ + auto yaml = YAML::LoadFile(path); + /* make it flattened */ + std::unordered_map flattened; + // auto proxy_yaml = yaml["porxy"]; + auto other_yaml = YAML::Node{}; + other_yaml["pulsar"] = yaml["pulsar"]; + Flatten(yaml["proxy"], flattened, ""); + Flatten(other_yaml, flattened, ""); + // Flatten(yaml["proxy"], flattened, ""); + /* update config */ + for (auto& it : flattened) Set(it.first, it.second, false); +} + +void +ConfigMgr::Set(const std::string& name, const std::string& value, bool update) { + std::cout<<"InSet Config "<< name < lock(GetConfigMutex()); + /* update=false when loading from config file */ + if (not update) { + ThrowIfNotSuccess(config->Set(value, update)); + } else if (config->modifiable_) { + /* set manually */ + ThrowIfNotSuccess(config->Set(value, update)); + lock.unlock(); + Notify(name); + } else { + throw ConfigStatus(SetReturn::IMMUTABLE, "Config " + name + " is not modifiable"); + } + } catch (ConfigStatus& cs) { + throw cs; + } catch (...) { + throw "Config " + name + " not found."; + } +} + +std::string +ConfigMgr::Get(const std::string& name) const { + try { + auto& config = config_list_.at(name); + std::lock_guard lock(GetConfigMutex()); + return config->Get(); + } catch (...) { + throw "Config " + name + " not found."; + } +} + +std::string +ConfigMgr::Dump() const { + std::stringstream ss; + for (auto& kv : config_list_) { + auto& config = kv.second; + ss << config->name_ << ": " << config->Get() << std::endl; + } + return ss.str(); +} + +void +ConfigMgr::Attach(const std::string& name, ConfigObserver* observer) { + std::lock_guard lock(observer_mutex_); + observers_[name].push_back(observer); +} + +void +ConfigMgr::Detach(const std::string& name, ConfigObserver* observer) { + std::lock_guard lock(observer_mutex_); + if (observers_.find(name) == observers_.end()) + return; + auto& ob_list = observers_[name]; + ob_list.remove(observer); +} + +void +ConfigMgr::Notify(const std::string& name) { + std::lock_guard lock(observer_mutex_); + if (observers_.find(name) == observers_.end()) + return; + auto& ob_list = observers_[name]; + for (auto& ob : ob_list) { + ob->ConfigUpdate(name); + } +} + +} // namespace milvus diff --git a/core/src/config/ConfigMgr.h b/core/src/config/ConfigMgr.h new file mode 100644 index 0000000000..802ae25064 --- /dev/null +++ b/core/src/config/ConfigMgr.h @@ -0,0 +1,92 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "config/ServerConfig.h" + +namespace milvus { + +class ConfigObserver { + public: + virtual ~ConfigObserver() { + } + + virtual void + ConfigUpdate(const std::string& name) = 0; +}; +using ConfigObserverPtr = std::shared_ptr; + +class ConfigMgr { + public: + static ConfigMgr& + GetInstance() { + return instance; + } + + private: + static ConfigMgr instance; + + public: + ConfigMgr(); + + ConfigMgr(const ConfigMgr&) = delete; + ConfigMgr& + operator=(const ConfigMgr&) = delete; + + ConfigMgr(ConfigMgr&&) = delete; + ConfigMgr& + operator=(ConfigMgr&&) = delete; + + public: + void + Init(); + + void + Load(const std::string& path); + + void + Set(const std::string& name, const std::string& value, bool update = true); + + std::string + Get(const std::string& name) const; + + std::string + Dump() const; + + public: + // Shared pointer should not be used here + void + Attach(const std::string& name, ConfigObserver* observer); + + void + Detach(const std::string& name, ConfigObserver* observer); + + private: + void + Notify(const std::string& name); + + private: + std::unordered_map config_list_; + std::mutex mutex_; + + std::unordered_map> observers_; + std::mutex observer_mutex_; +}; + +} // namespace milvus diff --git a/core/src/config/ConfigType.cpp b/core/src/config/ConfigType.cpp new file mode 100644 index 0000000000..2db20175ba --- /dev/null +++ b/core/src/config/ConfigType.cpp @@ -0,0 +1,528 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include "config/ConfigType.h" +#include "config/ServerConfig.h" + +#include +#include +#include +#include +#include +#include + +namespace { +std::unordered_map BYTE_UNITS = { + {"b", 1}, + {"k", 1024}, + {"m", 1024 * 1024}, + {"g", 1024 * 1024 * 1024}, +}; + +bool +is_integer(const std::string& s) { + if (not s.empty() && (std::isdigit(s[0]) || s[0] == '-')) { + auto ss = s.substr(1); + return std::find_if(ss.begin(), ss.end(), [](unsigned char c) { return !std::isdigit(c); }) == ss.end(); + } + return false; +} + +bool +is_number(const std::string& s) { + return !s.empty() && std::find_if(s.begin(), s.end(), [](unsigned char c) { return !std::isdigit(c); }) == s.end(); +} + +bool +is_alpha(const std::string& s) { + return !s.empty() && std::find_if(s.begin(), s.end(), [](unsigned char c) { return !std::isalpha(c); }) == s.end(); +} + +template +bool +boundary_check(T val, T lower_bound, T upper_bound) { + return lower_bound <= val && val <= upper_bound; +} + +bool +parse_bool(const std::string& str, std::string& err) { + if (!strcasecmp(str.c_str(), "true")) + return true; + else if (!strcasecmp(str.c_str(), "false")) + return false; + else + err = "The specified value must be true or false"; + return false; +} + +std::string +str_tolower(std::string s) { + std::transform(s.begin(), s.end(), s.begin(), [](unsigned char c) { return std::tolower(c); }); + return s; +} + +int64_t +parse_bytes(const std::string& str, std::string& err) { + try { + if (str.find_first_of('-') != std::string::npos) { + std::stringstream ss; + ss << "The specified value for memory (" << str << ") should be a positive integer."; + err = ss.str(); + return 0; + } + + std::string s = str; + if (is_number(s)) + return std::stoll(s); + if (s.length() == 0) + return 0; + + auto last_two = s.substr(s.length() - 2, 2); + auto last_one = s.substr(s.length() - 1); + if (is_alpha(last_two) && is_alpha(last_one)) + if (last_one == "b" or last_one == "B") + s = s.substr(0, s.length() - 1); + auto& units = BYTE_UNITS; + auto suffix = str_tolower(s.substr(s.length() - 1)); + + std::string digits_part; + if (is_number(suffix)) { + digits_part = s; + suffix = 'b'; + } else { + digits_part = s.substr(0, s.length() - 1); + } + + if (is_number(digits_part) && (units.find(suffix) != units.end() || is_number(suffix))) { + auto digits = std::stoll(digits_part); + return digits * units[suffix]; + } else { + std::stringstream ss; + ss << "The specified value for memory (" << str << ") should specify the units." + << "The postfix should be one of the `b` `k` `m` `g` characters"; + err = ss.str(); + } + } catch (...) { + err = "Unknown error happened on parse bytes."; + } + return 0; +} + +} // namespace + +// Use (void) to silent unused warnings. +#define assertm(exp, msg) assert(((void)msg, exp)) + +namespace milvus { + +std::vector +OptionValue(const configEnum& ce) { + std::vector ret; + for (auto& e : ce) { + ret.emplace_back(e.first); + } + return ret; +} + +BaseConfig::BaseConfig(const char* name, const char* alias, bool modifiable) + : name_(name), alias_(alias), modifiable_(modifiable) { +} + +void +BaseConfig::Init() { + assertm(not inited_, "already initialized"); + inited_ = true; +} + +BoolConfig::BoolConfig(const char* name, const char* alias, bool modifiable, bool* config, bool default_value, + std::function is_valid_fn, + std::function update_fn) + : BaseConfig(name, alias, modifiable), + config_(config), + default_value_(default_value), + is_valid_fn_(std::move(is_valid_fn)), + update_fn_(std::move(update_fn)) { +} + +void +BoolConfig::Init() { + BaseConfig::Init(); + assert(config_ != nullptr); + *config_ = default_value_; +} + +ConfigStatus +BoolConfig::Set(const std::string& val, bool update) { + assertm(inited_, "uninitialized"); + try { + if (update and not modifiable_) { + std::stringstream ss; + ss << "Config " << name_ << " is immutable."; + return ConfigStatus(SetReturn::IMMUTABLE, ss.str()); + } + + std::string err; + bool value = parse_bool(val, err); + if (not err.empty()) + return ConfigStatus(SetReturn::INVALID, err); + + if (is_valid_fn_ && not is_valid_fn_(value, err)) + return ConfigStatus(SetReturn::INVALID, err); + + bool prev = *config_; + *config_ = value; + if (update && update_fn_ && not update_fn_(value, prev, err)) { + *config_ = prev; + return ConfigStatus(SetReturn::UPDATE_FAILURE, err); + } + + return ConfigStatus(SetReturn::SUCCESS, ""); + } catch (std::exception& e) { + return ConfigStatus(SetReturn::EXCEPTION, e.what()); + } catch (...) { + return ConfigStatus(SetReturn::UNEXPECTED, "unexpected"); + } +} + +std::string +BoolConfig::Get() { + assertm(inited_, "uninitialized"); + return *config_ ? "true" : "false"; +} + +StringConfig::StringConfig( + const char* name, const char* alias, bool modifiable, std::string* config, const char* default_value, + std::function is_valid_fn, + std::function update_fn) + : BaseConfig(name, alias, modifiable), + config_(config), + default_value_(default_value), + is_valid_fn_(std::move(is_valid_fn)), + update_fn_(std::move(update_fn)) { +} + +void +StringConfig::Init() { + BaseConfig::Init(); + assert(config_ != nullptr); + *config_ = default_value_; +} + +ConfigStatus +StringConfig::Set(const std::string& val, bool update) { + assertm(inited_, "uninitialized"); + try { + if (update and not modifiable_) { + std::stringstream ss; + ss << "Config " << name_ << " is immutable."; + return ConfigStatus(SetReturn::IMMUTABLE, ss.str()); + } + + std::string err; + if (is_valid_fn_ && not is_valid_fn_(val, err)) + return ConfigStatus(SetReturn::INVALID, err); + + std::string prev = *config_; + *config_ = val; + if (update && update_fn_ && not update_fn_(val, prev, err)) { + *config_ = prev; + return ConfigStatus(SetReturn::UPDATE_FAILURE, err); + } + + return ConfigStatus(SetReturn::SUCCESS, ""); + } catch (std::exception& e) { + return ConfigStatus(SetReturn::EXCEPTION, e.what()); + } catch (...) { + return ConfigStatus(SetReturn::UNEXPECTED, "unexpected"); + } +} + +std::string +StringConfig::Get() { + assertm(inited_, "uninitialized"); + return *config_; +} + +EnumConfig::EnumConfig(const char* name, const char* alias, bool modifiable, configEnum* enumd, int64_t* config, + int64_t default_value, std::function is_valid_fn, + std::function update_fn) + : BaseConfig(name, alias, modifiable), + config_(config), + enum_value_(enumd), + default_value_(default_value), + is_valid_fn_(std::move(is_valid_fn)), + update_fn_(std::move(update_fn)) { +} + +void +EnumConfig::Init() { + BaseConfig::Init(); + assert(enum_value_ != nullptr); + assertm(not enum_value_->empty(), "enum value empty"); + assert(config_ != nullptr); + *config_ = default_value_; +} + +ConfigStatus +EnumConfig::Set(const std::string& val, bool update) { + assertm(inited_, "uninitialized"); + try { + if (update and not modifiable_) { + std::stringstream ss; + ss << "Config " << name_ << " is immutable."; + return ConfigStatus(SetReturn::IMMUTABLE, ss.str()); + } + + if (enum_value_->find(val) == enum_value_->end()) { + auto option_values = OptionValue(*enum_value_); + std::stringstream ss; + ss << "Config " << name_ << "(" << val << ") must be one of following: "; + for (size_t i = 0; i < option_values.size() - 1; ++i) { + ss << option_values[i] << ", "; + } + ss << option_values.back() << "."; + return ConfigStatus(SetReturn::ENUM_VALUE_NOTFOUND, ss.str()); + } + + int64_t value = enum_value_->at(val); + std::string err; + if (is_valid_fn_ && not is_valid_fn_(value, err)) { + return ConfigStatus(SetReturn::INVALID, err); + } + + int64_t prev = *config_; + *config_ = value; + if (update && update_fn_ && not update_fn_(value, prev, err)) { + *config_ = prev; + return ConfigStatus(SetReturn::UPDATE_FAILURE, err); + } + + return ConfigStatus(SetReturn::SUCCESS, ""); + } catch (std::exception& e) { + return ConfigStatus(SetReturn::EXCEPTION, e.what()); + } catch (...) { + return ConfigStatus(SetReturn::UNEXPECTED, "unexpected"); + } +} + +std::string +EnumConfig::Get() { + assertm(inited_, "uninitialized"); + for (auto& it : *enum_value_) { + if (*config_ == it.second) { + return it.first; + } + } + return "unknown"; +} + +IntegerConfig::IntegerConfig(const char* name, const char* alias, bool modifiable, int64_t lower_bound, + int64_t upper_bound, int64_t* config, int64_t default_value, + std::function is_valid_fn, + std::function update_fn) + : BaseConfig(name, alias, modifiable), + config_(config), + lower_bound_(lower_bound), + upper_bound_(upper_bound), + default_value_(default_value), + is_valid_fn_(std::move(is_valid_fn)), + update_fn_(std::move(update_fn)) { +} + +void +IntegerConfig::Init() { + BaseConfig::Init(); + assert(config_ != nullptr); + *config_ = default_value_; +} + +ConfigStatus +IntegerConfig::Set(const std::string& val, bool update) { + assertm(inited_, "uninitialized"); + try { + if (update and not modifiable_) { + std::stringstream ss; + ss << "Config " << name_ << " is immutable."; + return ConfigStatus(SetReturn::IMMUTABLE, ss.str()); + } + + if (not is_integer(val)) { + std::stringstream ss; + ss << "Config " << name_ << "(" << val << ") must be a integer."; + return ConfigStatus(SetReturn::INVALID, ss.str()); + } + + int64_t value = std::stoll(val); + if (not boundary_check(value, lower_bound_, upper_bound_)) { + std::stringstream ss; + ss << "Config " << name_ << "(" << val << ") must in range [" << lower_bound_ << ", " << upper_bound_ + << "]."; + return ConfigStatus(SetReturn::OUT_OF_RANGE, ss.str()); + } + + std::string err; + if (is_valid_fn_ && not is_valid_fn_(value, err)) + return ConfigStatus(SetReturn::INVALID, err); + + int64_t prev = *config_; + *config_ = value; + if (update && update_fn_ && not update_fn_(value, prev, err)) { + *config_ = prev; + return ConfigStatus(SetReturn::UPDATE_FAILURE, err); + } + + return ConfigStatus(SetReturn::SUCCESS, ""); + } catch (std::exception& e) { + return ConfigStatus(SetReturn::EXCEPTION, e.what()); + } catch (...) { + return ConfigStatus(SetReturn::UNEXPECTED, "unexpected"); + } +} + +std::string +IntegerConfig::Get() { + assertm(inited_, "uninitialized"); + return std::to_string(*config_); +} + +FloatingConfig::FloatingConfig(const char* name, const char* alias, bool modifiable, double lower_bound, + double upper_bound, double* config, double default_value, + std::function is_valid_fn, + std::function update_fn) + : BaseConfig(name, alias, modifiable), + config_(config), + lower_bound_(lower_bound), + upper_bound_(upper_bound), + default_value_(default_value), + is_valid_fn_(std::move(is_valid_fn)), + update_fn_(std::move(update_fn)) { +} + +void +FloatingConfig::Init() { + BaseConfig::Init(); + assert(config_ != nullptr); + *config_ = default_value_; +} + +ConfigStatus +FloatingConfig::Set(const std::string& val, bool update) { + assertm(inited_, "uninitialized"); + try { + if (update and not modifiable_) { + std::stringstream ss; + ss << "Config " << name_ << " is immutable."; + return ConfigStatus(SetReturn::IMMUTABLE, ss.str()); + } + + double value = std::stod(val); + if (not boundary_check(value, lower_bound_, upper_bound_)) { + std::stringstream ss; + ss << "Config " << name_ << "(" << val << ") must in range [" << lower_bound_ << ", " << upper_bound_ + << "]."; + return ConfigStatus(SetReturn::OUT_OF_RANGE, ss.str()); + } + + std::string err; + if (is_valid_fn_ && not is_valid_fn_(value, err)) + return ConfigStatus(SetReturn::INVALID, err); + + double prev = *config_; + *config_ = value; + if (update && update_fn_ && not update_fn_(value, prev, err)) { + *config_ = prev; + + return ConfigStatus(SetReturn::UPDATE_FAILURE, err); + } + + return ConfigStatus(SetReturn::SUCCESS, ""); + } catch (std::exception& e) { + return ConfigStatus(SetReturn::EXCEPTION, e.what()); + } catch (...) { + return ConfigStatus(SetReturn::UNEXPECTED, "unexpected"); + } +} + +std::string +FloatingConfig::Get() { + assertm(inited_, "uninitialized"); + return std::to_string(*config_); +} + +SizeConfig::SizeConfig(const char* name, const char* alias, bool modifiable, int64_t lower_bound, int64_t upper_bound, + int64_t* config, int64_t default_value, + std::function is_valid_fn, + std::function update_fn) + : BaseConfig(name, alias, modifiable), + config_(config), + lower_bound_(lower_bound), + upper_bound_(upper_bound), + default_value_(default_value), + is_valid_fn_(std::move(is_valid_fn)), + update_fn_(std::move(update_fn)) { +} + +void +SizeConfig::Init() { + BaseConfig::Init(); + assert(config_ != nullptr); + *config_ = default_value_; +} + +ConfigStatus +SizeConfig::Set(const std::string& val, bool update) { + assertm(inited_, "uninitialized"); + try { + if (update and not modifiable_) { + std::stringstream ss; + ss << "Config " << name_ << " is immutable."; + return ConfigStatus(SetReturn::IMMUTABLE, ss.str()); + } + + std::string err; + int64_t value = parse_bytes(val, err); + if (not err.empty()) { + return ConfigStatus(SetReturn::INVALID, err); + } + + if (not boundary_check(value, lower_bound_, upper_bound_)) { + std::stringstream ss; + ss << "Config " << name_ << "(" << val << ") must in range [" << lower_bound_ << " Byte, " << upper_bound_ + << " Byte]."; + return ConfigStatus(SetReturn::OUT_OF_RANGE, ss.str()); + } + + if (is_valid_fn_ && not is_valid_fn_(value, err)) { + return ConfigStatus(SetReturn::INVALID, err); + } + + int64_t prev = *config_; + *config_ = value; + if (update && update_fn_ && not update_fn_(value, prev, err)) { + *config_ = prev; + return ConfigStatus(SetReturn::UPDATE_FAILURE, err); + } + + return ConfigStatus(SetReturn::SUCCESS, ""); + } catch (std::exception& e) { + return ConfigStatus(SetReturn::EXCEPTION, e.what()); + } catch (...) { + return ConfigStatus(SetReturn::UNEXPECTED, "unexpected"); + } +} + +std::string +SizeConfig::Get() { + assertm(inited_, "uninitialized"); + return std::to_string(*config_); +} + +} // namespace milvus diff --git a/core/src/config/ConfigType.h b/core/src/config/ConfigType.h new file mode 100644 index 0000000000..5e1a8e92e8 --- /dev/null +++ b/core/src/config/ConfigType.h @@ -0,0 +1,235 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace milvus { + +using configEnum = const std::unordered_map; +std::vector +OptionValue(const configEnum& ce); + +enum SetReturn { + SUCCESS = 1, + IMMUTABLE, + ENUM_VALUE_NOTFOUND, + INVALID, + OUT_OF_RANGE, + UPDATE_FAILURE, + EXCEPTION, + UNEXPECTED, +}; + +struct ConfigStatus { + ConfigStatus(SetReturn sr, std::string msg) : set_return(sr), message(std::move(msg)) { + } + SetReturn set_return; + std::string message; +}; + +class BaseConfig { + public: + BaseConfig(const char* name, const char* alias, bool modifiable); + virtual ~BaseConfig() = default; + + public: + bool inited_ = false; + const char* name_; + const char* alias_; + const bool modifiable_; + + public: + virtual void + Init(); + + virtual ConfigStatus + Set(const std::string& value, bool update) = 0; + + virtual std::string + Get() = 0; +}; +using BaseConfigPtr = std::shared_ptr; + +class BoolConfig : public BaseConfig { + public: + BoolConfig(const char* name, const char* alias, bool modifiable, bool* config, bool default_value, + std::function is_valid_fn, + std::function update_fn); + + private: + bool* config_; + const bool default_value_; + std::function is_valid_fn_; + std::function update_fn_; + + public: + void + Init() override; + + ConfigStatus + Set(const std::string& value, bool update) override; + + std::string + Get() override; +}; + +class StringConfig : public BaseConfig { + public: + StringConfig(const char* name, const char* alias, bool modifiable, std::string* config, const char* default_value, + std::function is_valid_fn, + std::function update_fn); + + private: + std::string* config_; + const char* default_value_; + std::function is_valid_fn_; + std::function update_fn_; + + public: + void + Init() override; + + ConfigStatus + Set(const std::string& value, bool update) override; + + std::string + Get() override; +}; + +class EnumConfig : public BaseConfig { + public: + EnumConfig(const char* name, const char* alias, bool modifiable, configEnum* enumd, int64_t* config, + int64_t default_value, std::function is_valid_fn, + std::function update_fn); + + private: + int64_t* config_; + configEnum* enum_value_; + const int64_t default_value_; + std::function is_valid_fn_; + std::function update_fn_; + + public: + void + Init() override; + + ConfigStatus + Set(const std::string& value, bool update) override; + + std::string + Get() override; +}; + +class IntegerConfig : public BaseConfig { + public: + IntegerConfig(const char* name, const char* alias, bool modifiable, int64_t lower_bound, int64_t upper_bound, + int64_t* config, int64_t default_value, + std::function is_valid_fn, + std::function update_fn); + + private: + int64_t* config_; + int64_t lower_bound_; + int64_t upper_bound_; + const int64_t default_value_; + std::function is_valid_fn_; + std::function update_fn_; + + public: + void + Init() override; + + ConfigStatus + Set(const std::string& value, bool update) override; + + std::string + Get() override; +}; + +class FloatingConfig : public BaseConfig { + public: + FloatingConfig(const char* name, const char* alias, bool modifiable, double lower_bound, double upper_bound, + double* config, double default_value, std::function is_valid_fn, + std::function update_fn); + + private: + double* config_; + double lower_bound_; + double upper_bound_; + const double default_value_; + std::function is_valid_fn_; + std::function update_fn_; + + public: + void + Init() override; + + ConfigStatus + Set(const std::string& value, bool update) override; + + std::string + Get() override; +}; + +class SizeConfig : public BaseConfig { + public: + SizeConfig(const char* name, const char* alias, bool modifiable, int64_t lower_bound, int64_t upper_bound, + int64_t* config, int64_t default_value, std::function is_valid_fn, + std::function update_fn); + + private: + int64_t* config_; + int64_t lower_bound_; + int64_t upper_bound_; + const int64_t default_value_; + std::function is_valid_fn_; + std::function update_fn_; + + public: + void + Init() override; + + ConfigStatus + Set(const std::string& value, bool update) override; + + std::string + Get() override; +}; + +#define CreateBoolConfig(name, modifiable, config_addr, default, is_valid, update) \ + std::make_shared(name, nullptr, modifiable, config_addr, (default), is_valid, update) + +#define CreateStringConfig(name, modifiable, config_addr, default, is_valid, update) \ + std::make_shared(name, nullptr, modifiable, config_addr, (default), is_valid, update) + +#define CreateEnumConfig(name, modifiable, enumd, config_addr, default, is_valid, update) \ + std::make_shared(name, nullptr, modifiable, enumd, config_addr, (default), is_valid, update) + +#define CreateIntegerConfig(name, modifiable, lower_bound, upper_bound, config_addr, default, is_valid, update) \ + std::make_shared(name, nullptr, modifiable, lower_bound, upper_bound, config_addr, (default), \ + is_valid, update) + +#define CreateFloatingConfig(name, modifiable, lower_bound, upper_bound, config_addr, default, is_valid, update) \ + std::make_shared(name, nullptr, modifiable, lower_bound, upper_bound, config_addr, (default), \ + is_valid, update) + +#define CreateSizeConfig(name, modifiable, lower_bound, upper_bound, config_addr, default, is_valid, update) \ + std::make_shared(name, nullptr, modifiable, lower_bound, upper_bound, config_addr, (default), \ + is_valid, update) + +} // namespace milvus diff --git a/core/src/config/ConfigTypeTest1.cpp b/core/src/config/ConfigTypeTest1.cpp new file mode 100644 index 0000000000..72aa5d7227 --- /dev/null +++ b/core/src/config/ConfigTypeTest1.cpp @@ -0,0 +1,493 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include +#include + +#include "config/ServerConfig.h" +#include "gtest/gtest.h" + +namespace milvus { + +#define _MODIFIABLE (true) +#define _IMMUTABLE (false) + +template +class Utils { + public: + bool + validate_fn(const T& value, std::string& err) { + validate_value = value; + return true; + } + + bool + update_fn(const T& value, const T& prev, std::string& err) { + new_value = value; + prev_value = prev; + return true; + } + + protected: + T validate_value; + T new_value; + T prev_value; +}; + +/* ValidBoolConfigTest */ +class ValidBoolConfigTest : public testing::Test, public Utils { + protected: +}; + +TEST_F(ValidBoolConfigTest, init_load_update_get_test) { + auto validate = std::bind(&ValidBoolConfigTest::validate_fn, this, std::placeholders::_1, std::placeholders::_2); + auto update = std::bind(&ValidBoolConfigTest::update_fn, this, std::placeholders::_1, std::placeholders::_2, + std::placeholders::_3); + + bool bool_value = true; + auto bool_config = CreateBoolConfig("b", _MODIFIABLE, &bool_value, false, validate, update); + ASSERT_EQ(bool_value, true); + ASSERT_EQ(bool_config->modifiable_, true); + + bool_config->Init(); + ASSERT_EQ(bool_value, false); + ASSERT_EQ(bool_config->Get(), "false"); + + { + // now `bool_value` is `false`, calling Set(update=false) to set it to `true`, but not notify update_fn() + validate_value = false; + new_value = false; + prev_value = true; + + ConfigStatus status(SetReturn::SUCCESS, ""); + status = bool_config->Set("true", false); + + EXPECT_EQ(status.set_return, SetReturn::SUCCESS); + EXPECT_EQ(bool_value, true); + EXPECT_EQ(bool_config->Get(), "true"); + + // expect change + EXPECT_EQ(validate_value, true); + // expect not change + EXPECT_EQ(new_value, false); + EXPECT_EQ(prev_value, true); + } + + { + // now `bool_value` is `true`, calling Set(update=true) to set it to `false`, will notify update_fn() + validate_value = true; + new_value = true; + prev_value = false; + + ConfigStatus status(SetReturn::SUCCESS, ""); + status = bool_config->Set("false", true); + + EXPECT_EQ(status.set_return, SetReturn::SUCCESS); + EXPECT_EQ(bool_value, false); + EXPECT_EQ(bool_config->Get(), "false"); + + // expect change + EXPECT_EQ(validate_value, false); + EXPECT_EQ(new_value, false); + EXPECT_EQ(prev_value, true); + } +} + +/* ValidStringConfigTest */ +class ValidStringConfigTest : public testing::Test, public Utils { + protected: +}; + +TEST_F(ValidStringConfigTest, init_load_update_get_test) { + auto validate = std::bind(&ValidStringConfigTest::validate_fn, this, std::placeholders::_1, std::placeholders::_2); + auto update = std::bind(&ValidStringConfigTest::update_fn, this, std::placeholders::_1, std::placeholders::_2, + std::placeholders::_3); + + std::string string_value; + auto string_config = CreateStringConfig("s", _MODIFIABLE, &string_value, "Magic", validate, update); + ASSERT_EQ(string_value, ""); + ASSERT_EQ(string_config->modifiable_, true); + + string_config->Init(); + ASSERT_EQ(string_value, "Magic"); + ASSERT_EQ(string_config->Get(), "Magic"); + + { + // now `string_value` is `Magic`, calling Set(update=false) to set it to `cigaM`, but not notify update_fn() + validate_value = ""; + new_value = ""; + prev_value = ""; + + ConfigStatus status(SetReturn::SUCCESS, ""); + status = string_config->Set("cigaM", false); + + EXPECT_EQ(status.set_return, SetReturn::SUCCESS); + EXPECT_EQ(string_value, "cigaM"); + EXPECT_EQ(string_config->Get(), "cigaM"); + + // expect change + EXPECT_EQ(validate_value, "cigaM"); + // expect not change + EXPECT_EQ(new_value, ""); + EXPECT_EQ(prev_value, ""); + } + + { + // now `string_value` is `cigaM`, calling Set(update=true) to set it to `Check`, will notify update_fn() + validate_value = ""; + new_value = ""; + prev_value = ""; + + ConfigStatus status(SetReturn::SUCCESS, ""); + status = string_config->Set("Check", true); + + EXPECT_EQ(status.set_return, SetReturn::SUCCESS); + EXPECT_EQ(string_value, "Check"); + EXPECT_EQ(string_config->Get(), "Check"); + + // expect change + EXPECT_EQ(validate_value, "Check"); + EXPECT_EQ(new_value, "Check"); + EXPECT_EQ(prev_value, "cigaM"); + } +} + +/* ValidIntegerConfigTest */ +class ValidIntegerConfigTest : public testing::Test, public Utils { + protected: +}; + +TEST_F(ValidIntegerConfigTest, init_load_update_get_test) { + auto validate = std::bind(&ValidIntegerConfigTest::validate_fn, this, std::placeholders::_1, std::placeholders::_2); + auto update = std::bind(&ValidIntegerConfigTest::update_fn, this, std::placeholders::_1, std::placeholders::_2, + std::placeholders::_3); + + int64_t integer_value = 0; + auto integer_config = CreateIntegerConfig("i", _MODIFIABLE, -100, 100, &integer_value, 42, validate, update); + ASSERT_EQ(integer_value, 0); + ASSERT_EQ(integer_config->modifiable_, true); + + integer_config->Init(); + ASSERT_EQ(integer_value, 42); + ASSERT_EQ(integer_config->Get(), "42"); + + { + // now `integer_value` is `42`, calling Set(update=false) to set it to `24`, but not notify update_fn() + validate_value = 0; + new_value = 0; + prev_value = 0; + + ConfigStatus status(SetReturn::SUCCESS, ""); + status = integer_config->Set("24", false); + + EXPECT_EQ(status.set_return, SetReturn::SUCCESS); + EXPECT_EQ(integer_value, 24); + EXPECT_EQ(integer_config->Get(), "24"); + + // expect change + EXPECT_EQ(validate_value, 24); + // expect not change + EXPECT_EQ(new_value, 0); + EXPECT_EQ(prev_value, 0); + } + + { + // now `integer_value` is `24`, calling Set(update=true) to set it to `36`, will notify update_fn() + validate_value = 0; + new_value = 0; + prev_value = 0; + + ConfigStatus status(SetReturn::SUCCESS, ""); + status = integer_config->Set("36", true); + + EXPECT_EQ(status.set_return, SetReturn::SUCCESS); + EXPECT_EQ(integer_value, 36); + EXPECT_EQ(integer_config->Get(), "36"); + + // expect change + EXPECT_EQ(validate_value, 36); + EXPECT_EQ(new_value, 36); + EXPECT_EQ(prev_value, 24); + } +} + +/* ValidFloatingConfigTest */ +class ValidFloatingConfigTest : public testing::Test, public Utils { + protected: +}; + +TEST_F(ValidFloatingConfigTest, init_load_update_get_test) { + auto validate = + std::bind(&ValidFloatingConfigTest::validate_fn, this, std::placeholders::_1, std::placeholders::_2); + auto update = std::bind(&ValidFloatingConfigTest::update_fn, this, std::placeholders::_1, std::placeholders::_2, + std::placeholders::_3); + + double floating_value = 0.0; + auto floating_config = CreateFloatingConfig("f", _MODIFIABLE, -10.0, 10.0, &floating_value, 3.14, validate, update); + ASSERT_FLOAT_EQ(floating_value, 0.0); + ASSERT_EQ(floating_config->modifiable_, true); + + floating_config->Init(); + ASSERT_FLOAT_EQ(floating_value, 3.14); + ASSERT_FLOAT_EQ(std::stof(floating_config->Get()), 3.14); + + { + // now `floating_value` is `3.14`, calling Set(update=false) to set it to `6.22`, but not notify update_fn() + validate_value = 0.0; + new_value = 0.0; + prev_value = 0.0; + + ConfigStatus status(SetReturn::SUCCESS, ""); + status = floating_config->Set("6.22", false); + + EXPECT_EQ(status.set_return, SetReturn::SUCCESS); + ASSERT_FLOAT_EQ(floating_value, 6.22); + ASSERT_FLOAT_EQ(std::stof(floating_config->Get()), 6.22); + + // expect change + ASSERT_FLOAT_EQ(validate_value, 6.22); + // expect not change + ASSERT_FLOAT_EQ(new_value, 0.0); + ASSERT_FLOAT_EQ(prev_value, 0.0); + } + + { + // now `integer_value` is `6.22`, calling Set(update=true) to set it to `-3.14`, will notify update_fn() + validate_value = 0.0; + new_value = 0.0; + prev_value = 0.0; + + ConfigStatus status(SetReturn::SUCCESS, ""); + status = floating_config->Set("-3.14", true); + + EXPECT_EQ(status.set_return, SetReturn::SUCCESS); + ASSERT_FLOAT_EQ(floating_value, -3.14); + ASSERT_FLOAT_EQ(std::stof(floating_config->Get()), -3.14); + + // expect change + ASSERT_FLOAT_EQ(validate_value, -3.14); + ASSERT_FLOAT_EQ(new_value, -3.14); + ASSERT_FLOAT_EQ(prev_value, 6.22); + } +} + +/* ValidEnumConfigTest */ +class ValidEnumConfigTest : public testing::Test, public Utils { + protected: +}; + +// template <> +// int64_t Utils::validate_value = 0; +// template <> +// int64_t Utils::new_value = 0; +// template <> +// int64_t Utils::prev_value = 0; + +TEST_F(ValidEnumConfigTest, init_load_update_get_test) { + auto validate = std::bind(&ValidEnumConfigTest::validate_fn, this, std::placeholders::_1, std::placeholders::_2); + auto update = std::bind(&ValidEnumConfigTest::update_fn, this, std::placeholders::_1, std::placeholders::_2, + std::placeholders::_3); + + configEnum testEnum{ + {"a", 1}, + {"b", 2}, + {"c", 3}, + }; + int64_t enum_value = 0; + auto enum_config = CreateEnumConfig("e", _MODIFIABLE, &testEnum, &enum_value, 1, validate, update); + ASSERT_EQ(enum_value, 0); + ASSERT_EQ(enum_config->modifiable_, true); + + enum_config->Init(); + ASSERT_EQ(enum_value, 1); + ASSERT_EQ(enum_config->Get(), "a"); + + { + // now `enum_value` is `a`, calling Set(update=false) to set it to `b`, but not notify update_fn() + validate_value = 0; + new_value = 0; + prev_value = 0; + + ConfigStatus status(SetReturn::SUCCESS, ""); + status = enum_config->Set("b", false); + + EXPECT_EQ(status.set_return, SetReturn::SUCCESS); + ASSERT_EQ(enum_value, 2); + ASSERT_EQ(enum_config->Get(), "b"); + + // expect change + ASSERT_EQ(validate_value, 2); + // expect not change + ASSERT_EQ(new_value, 0); + ASSERT_EQ(prev_value, 0); + } + + { + // now `enum_value` is `b`, calling Set(update=true) to set it to `c`, will notify update_fn() + validate_value = 0; + new_value = 0; + prev_value = 0; + + ConfigStatus status(SetReturn::SUCCESS, ""); + status = enum_config->Set("c", true); + + EXPECT_EQ(status.set_return, SetReturn::SUCCESS); + ASSERT_EQ(enum_value, 3); + ASSERT_EQ(enum_config->Get(), "c"); + + // expect change + ASSERT_EQ(validate_value, 3); + ASSERT_EQ(new_value, 3); + ASSERT_EQ(prev_value, 2); + } +} + +/* ValidSizeConfigTest */ +class ValidSizeConfigTest : public testing::Test, public Utils { + protected: +}; + +// template <> +// int64_t Utils::validate_value = 0; +// template <> +// int64_t Utils::new_value = 0; +// template <> +// int64_t Utils::prev_value = 0; + +TEST_F(ValidSizeConfigTest, init_load_update_get_test) { + auto validate = std::bind(&ValidSizeConfigTest::validate_fn, this, std::placeholders::_1, std::placeholders::_2); + auto update = std::bind(&ValidSizeConfigTest::update_fn, this, std::placeholders::_1, std::placeholders::_2, + std::placeholders::_3); + + int64_t size_value = 0; + auto size_config = CreateSizeConfig("i", _MODIFIABLE, 0, 1024 * 1024, &size_value, 1024, validate, update); + ASSERT_EQ(size_value, 0); + ASSERT_EQ(size_config->modifiable_, true); + + size_config->Init(); + ASSERT_EQ(size_value, 1024); + ASSERT_EQ(size_config->Get(), "1024"); + + { + // now `size_value` is `1024`, calling Set(update=false) to set it to `4096`, but not notify update_fn() + validate_value = 0; + new_value = 0; + prev_value = 0; + + ConfigStatus status(SetReturn::SUCCESS, ""); + status = size_config->Set("4096", false); + + EXPECT_EQ(status.set_return, SetReturn::SUCCESS); + EXPECT_EQ(size_value, 4096); + EXPECT_EQ(size_config->Get(), "4096"); + + // expect change + EXPECT_EQ(validate_value, 4096); + // expect not change + EXPECT_EQ(new_value, 0); + EXPECT_EQ(prev_value, 0); + } + + { + // now `size_value` is `4096`, calling Set(update=true) to set it to `256kb`, will notify update_fn() + validate_value = 0; + new_value = 0; + prev_value = 0; + + ConfigStatus status(SetReturn::SUCCESS, ""); + status = size_config->Set("256kb", true); + + EXPECT_EQ(status.set_return, SetReturn::SUCCESS); + EXPECT_EQ(size_value, 256 * 1024); + EXPECT_EQ(size_config->Get(), "262144"); + + // expect change + EXPECT_EQ(validate_value, 262144); + EXPECT_EQ(new_value, 262144); + EXPECT_EQ(prev_value, 4096); + } +} + +class ValidTest : public testing::Test { + protected: + configEnum family{ + {"ipv4", 1}, + {"ipv6", 2}, + }; + + struct Server { + bool running = true; + std::string hostname; + int64_t family = 0; + int64_t port = 0; + double uptime = 0; + }; + + Server server; + + protected: + void + SetUp() override { + config_list = { + CreateBoolConfig("running", true, &server.running, true, nullptr, nullptr), + CreateStringConfig("hostname", true, &server.hostname, "Magic", nullptr, nullptr), + CreateEnumConfig("socket_family", false, &family, &server.family, 2, nullptr, nullptr), + CreateIntegerConfig("port", true, 1024, 65535, &server.port, 19530, nullptr, nullptr), + CreateFloatingConfig("uptime", true, 0, 9999.0, &server.uptime, 0, nullptr, nullptr), + }; + } + + void + TearDown() override { + } + + protected: + void + Init() { + for (auto& config : config_list) { + config->Init(); + } + } + + void + Load() { + std::unordered_map config_file{ + {"running", "false"}, + }; + + for (auto& c : config_file) Set(c.first, c.second, false); + } + + void + Set(const std::string& name, const std::string& value, bool update = true) { + for (auto& config : config_list) { + if (std::strcmp(name.c_str(), config->name_) == 0) { + config->Set(value, update); + return; + } + } + throw "Config " + name + " not found."; + } + + std::string + Get(const std::string& name) { + for (auto& config : config_list) { + if (std::strcmp(name.c_str(), config->name_) == 0) { + return config->Get(); + } + } + throw "Config " + name + " not found."; + } + + std::vector config_list; +}; + +} // namespace milvus diff --git a/core/src/config/ConfigTypeTest2.cpp b/core/src/config/ConfigTypeTest2.cpp new file mode 100644 index 0000000000..42b903c7b1 --- /dev/null +++ b/core/src/config/ConfigTypeTest2.cpp @@ -0,0 +1,861 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include "config/ServerConfig.h" +#include "gtest/gtest.h" + +namespace milvus { + +#define _MODIFIABLE (true) +#define _IMMUTABLE (false) + +template +class Utils { + public: + static bool + valid_check_failure(const T& value, std::string& err) { + err = "Value is invalid."; + return false; + } + + static bool + update_failure(const T& value, const T& prev, std::string& err) { + err = "Update is failure"; + return false; + } + + static bool + valid_check_raise_string(const T& value, std::string& err) { + throw "string exception"; + } + + static bool + valid_check_raise_exception(const T& value, std::string& err) { + throw std::bad_alloc(); + } +}; + +/* BoolConfigTest */ +class BoolConfigTest : public testing::Test, public Utils {}; + +TEST_F(BoolConfigTest, nullptr_init_test) { + auto bool_config = CreateBoolConfig("b", _MODIFIABLE, nullptr, true, nullptr, nullptr); + ASSERT_DEATH(bool_config->Init(), "nullptr"); +} + +TEST_F(BoolConfigTest, init_twice_test) { + bool bool_value; + auto bool_config = CreateBoolConfig("b", _MODIFIABLE, &bool_value, true, nullptr, nullptr); + ASSERT_DEATH( + { + bool_config->Init(); + bool_config->Init(); + }, + "initialized"); +} + +TEST_F(BoolConfigTest, non_init_test) { + bool bool_value; + auto bool_config = CreateBoolConfig("b", _MODIFIABLE, &bool_value, true, nullptr, nullptr); + ASSERT_DEATH(bool_config->Set("false", true), "uninitialized"); + ASSERT_DEATH(bool_config->Get(), "uninitialized"); +} + +TEST_F(BoolConfigTest, immutable_update_test) { + bool bool_value = false; + auto bool_config = CreateBoolConfig("b", _IMMUTABLE, &bool_value, true, nullptr, nullptr); + bool_config->Init(); + ASSERT_EQ(bool_value, true); + + ConfigStatus status(SUCCESS, ""); + status = bool_config->Set("false", true); + ASSERT_EQ(status.set_return, SetReturn::IMMUTABLE); + ASSERT_EQ(bool_value, true); +} + +TEST_F(BoolConfigTest, set_invalid_value_test) { + bool bool_value; + auto bool_config = CreateBoolConfig("b", _MODIFIABLE, &bool_value, true, nullptr, nullptr); + bool_config->Init(); + + ConfigStatus status(SUCCESS, ""); + status = bool_config->Set(" false", true); + ASSERT_EQ(status.set_return, SetReturn::INVALID); + ASSERT_EQ(bool_config->Get(), "true"); + + status = bool_config->Set("false ", true); + ASSERT_EQ(status.set_return, SetReturn::INVALID); + ASSERT_EQ(bool_config->Get(), "true"); + + status = bool_config->Set("afalse", true); + ASSERT_EQ(status.set_return, SetReturn::INVALID); + ASSERT_EQ(bool_config->Get(), "true"); + + status = bool_config->Set("falsee", true); + ASSERT_EQ(status.set_return, SetReturn::INVALID); + ASSERT_EQ(bool_config->Get(), "true"); + + status = bool_config->Set("abcdefg", true); + ASSERT_EQ(status.set_return, SetReturn::INVALID); + ASSERT_EQ(bool_config->Get(), "true"); + + status = bool_config->Set("123456", true); + ASSERT_EQ(status.set_return, SetReturn::INVALID); + ASSERT_EQ(bool_config->Get(), "true"); + + status = bool_config->Set("", true); + ASSERT_EQ(status.set_return, SetReturn::INVALID); + ASSERT_EQ(bool_config->Get(), "true"); +} + +TEST_F(BoolConfigTest, valid_check_fail_test) { + bool bool_value; + auto bool_config = CreateBoolConfig("b", _MODIFIABLE, &bool_value, true, valid_check_failure, nullptr); + bool_config->Init(); + + ConfigStatus status(SUCCESS, ""); + status = bool_config->Set("123456", true); + ASSERT_EQ(status.set_return, SetReturn::INVALID); + ASSERT_EQ(bool_config->Get(), "true"); +} + +TEST_F(BoolConfigTest, update_fail_test) { + bool bool_value; + auto bool_config = CreateBoolConfig("b", _MODIFIABLE, &bool_value, true, nullptr, update_failure); + bool_config->Init(); + + ConfigStatus status(SUCCESS, ""); + status = bool_config->Set("false", true); + ASSERT_EQ(status.set_return, SetReturn::UPDATE_FAILURE); + ASSERT_EQ(bool_config->Get(), "true"); +} + +TEST_F(BoolConfigTest, string_exception_test) { + bool bool_value; + auto bool_config = CreateBoolConfig("b", _MODIFIABLE, &bool_value, true, valid_check_raise_string, nullptr); + bool_config->Init(); + + ConfigStatus status(SUCCESS, ""); + status = bool_config->Set("false", true); + ASSERT_EQ(status.set_return, SetReturn::UNEXPECTED); + ASSERT_EQ(bool_config->Get(), "true"); +} + +TEST_F(BoolConfigTest, standard_exception_test) { + bool bool_value; + auto bool_config = CreateBoolConfig("b", _MODIFIABLE, &bool_value, true, valid_check_raise_exception, nullptr); + bool_config->Init(); + + ConfigStatus status(SUCCESS, ""); + status = bool_config->Set("false", true); + ASSERT_EQ(status.set_return, SetReturn::EXCEPTION); + ASSERT_EQ(bool_config->Get(), "true"); +} + +/* StringConfigTest */ +class StringConfigTest : public testing::Test, public Utils {}; + +TEST_F(StringConfigTest, nullptr_init_test) { + auto string_config = CreateStringConfig("s", true, nullptr, "Magic", nullptr, nullptr); + ASSERT_DEATH(string_config->Init(), "nullptr"); +} + +TEST_F(StringConfigTest, init_twice_test) { + std::string string_value; + auto string_config = CreateStringConfig("s", _MODIFIABLE, &string_value, "Magic", nullptr, nullptr); + ASSERT_DEATH( + { + string_config->Init(); + string_config->Init(); + }, + "initialized"); +} + +TEST_F(StringConfigTest, non_init_test) { + std::string string_value; + auto string_config = CreateStringConfig("s", _MODIFIABLE, &string_value, "Magic", nullptr, nullptr); + ASSERT_DEATH(string_config->Set("value", true), "uninitialized"); + ASSERT_DEATH(string_config->Get(), "uninitialized"); +} + +TEST_F(StringConfigTest, immutable_update_test) { + std::string string_value; + auto string_config = CreateStringConfig("s", _IMMUTABLE, &string_value, "Magic", nullptr, nullptr); + string_config->Init(); + ASSERT_EQ(string_value, "Magic"); + + ConfigStatus status(SUCCESS, ""); + status = string_config->Set("cigaM", true); + ASSERT_EQ(status.set_return, SetReturn::IMMUTABLE); + ASSERT_EQ(string_value, "Magic"); +} + +TEST_F(StringConfigTest, valid_check_fail_test) { + std::string string_value; + auto string_config = CreateStringConfig("s", _MODIFIABLE, &string_value, "Magic", valid_check_failure, nullptr); + string_config->Init(); + + ConfigStatus status(SUCCESS, ""); + status = string_config->Set("123456", true); + ASSERT_EQ(status.set_return, SetReturn::INVALID); + ASSERT_EQ(string_config->Get(), "Magic"); +} + +TEST_F(StringConfigTest, update_fail_test) { + std::string string_value; + auto string_config = CreateStringConfig("s", _MODIFIABLE, &string_value, "Magic", nullptr, update_failure); + string_config->Init(); + + ConfigStatus status(SUCCESS, ""); + status = string_config->Set("Mi", true); + ASSERT_EQ(status.set_return, SetReturn::UPDATE_FAILURE); + ASSERT_EQ(string_config->Get(), "Magic"); +} + +TEST_F(StringConfigTest, string_exception_test) { + std::string string_value; + auto string_config = + CreateStringConfig("s", _MODIFIABLE, &string_value, "Magic", valid_check_raise_string, nullptr); + string_config->Init(); + + ConfigStatus status(SUCCESS, ""); + status = string_config->Set("any", true); + ASSERT_EQ(status.set_return, SetReturn::UNEXPECTED); + ASSERT_EQ(string_config->Get(), "Magic"); +} + +TEST_F(StringConfigTest, standard_exception_test) { + std::string string_value; + auto string_config = + CreateStringConfig("s", _MODIFIABLE, &string_value, "Magic", valid_check_raise_exception, nullptr); + string_config->Init(); + + ConfigStatus status(SUCCESS, ""); + status = string_config->Set("any", true); + ASSERT_EQ(status.set_return, SetReturn::EXCEPTION); + ASSERT_EQ(string_config->Get(), "Magic"); +} + +/* IntegerConfigTest */ +class IntegerConfigTest : public testing::Test, public Utils {}; + +TEST_F(IntegerConfigTest, nullptr_init_test) { + auto integer_config = CreateIntegerConfig("i", true, 1024, 65535, nullptr, 19530, nullptr, nullptr); + ASSERT_DEATH(integer_config->Init(), "nullptr"); +} + +TEST_F(IntegerConfigTest, init_twice_test) { + int64_t integer_value; + auto integer_config = CreateIntegerConfig("i", true, 1024, 65535, &integer_value, 19530, nullptr, nullptr); + ASSERT_DEATH( + { + integer_config->Init(); + integer_config->Init(); + }, + "initialized"); +} + +TEST_F(IntegerConfigTest, non_init_test) { + int64_t integer_value; + auto integer_config = CreateIntegerConfig("i", true, 1024, 65535, &integer_value, 19530, nullptr, nullptr); + ASSERT_DEATH(integer_config->Set("42", true), "uninitialized"); + ASSERT_DEATH(integer_config->Get(), "uninitialized"); +} + +TEST_F(IntegerConfigTest, immutable_update_test) { + int64_t integer_value; + auto integer_config = CreateIntegerConfig("i", _IMMUTABLE, 1024, 65535, &integer_value, 19530, nullptr, nullptr); + integer_config->Init(); + ASSERT_EQ(integer_value, 19530); + + ConfigStatus status(SUCCESS, ""); + status = integer_config->Set("2048", true); + ASSERT_EQ(status.set_return, SetReturn::IMMUTABLE); + ASSERT_EQ(integer_value, 19530); +} + +TEST_F(IntegerConfigTest, set_invalid_value_test) { +} + +TEST_F(IntegerConfigTest, valid_check_fail_test) { + int64_t integer_value; + auto integer_config = + CreateIntegerConfig("i", true, 1024, 65535, &integer_value, 19530, valid_check_failure, nullptr); + integer_config->Init(); + ConfigStatus status(SUCCESS, ""); + status = integer_config->Set("2048", true); + ASSERT_EQ(status.set_return, SetReturn::INVALID); + ASSERT_EQ(integer_config->Get(), "19530"); +} + +TEST_F(IntegerConfigTest, update_fail_test) { + int64_t integer_value; + auto integer_config = CreateIntegerConfig("i", true, 1024, 65535, &integer_value, 19530, nullptr, update_failure); + integer_config->Init(); + ConfigStatus status(SUCCESS, ""); + status = integer_config->Set("2048", true); + ASSERT_EQ(status.set_return, SetReturn::UPDATE_FAILURE); + ASSERT_EQ(integer_config->Get(), "19530"); +} + +TEST_F(IntegerConfigTest, string_exception_test) { + int64_t integer_value; + auto integer_config = + CreateIntegerConfig("i", true, 1024, 65535, &integer_value, 19530, valid_check_raise_string, nullptr); + integer_config->Init(); + + ConfigStatus status(SUCCESS, ""); + status = integer_config->Set("2048", true); + ASSERT_EQ(status.set_return, SetReturn::UNEXPECTED); + ASSERT_EQ(integer_config->Get(), "19530"); +} + +TEST_F(IntegerConfigTest, standard_exception_test) { + int64_t integer_value; + auto integer_config = + CreateIntegerConfig("i", true, 1024, 65535, &integer_value, 19530, valid_check_raise_exception, nullptr); + integer_config->Init(); + + ConfigStatus status(SUCCESS, ""); + status = integer_config->Set("2048", true); + ASSERT_EQ(status.set_return, SetReturn::EXCEPTION); + ASSERT_EQ(integer_config->Get(), "19530"); +} + +TEST_F(IntegerConfigTest, out_of_range_test) { + int64_t integer_value; + auto integer_config = CreateIntegerConfig("i", true, 1024, 65535, &integer_value, 19530, nullptr, nullptr); + integer_config->Init(); + + { + ConfigStatus status(SUCCESS, ""); + status = integer_config->Set("1023", true); + ASSERT_EQ(status.set_return, SetReturn::OUT_OF_RANGE); + ASSERT_EQ(integer_config->Get(), "19530"); + } + + { + ConfigStatus status(SUCCESS, ""); + status = integer_config->Set("65536", true); + ASSERT_EQ(status.set_return, SetReturn::OUT_OF_RANGE); + ASSERT_EQ(integer_config->Get(), "19530"); + } +} + +TEST_F(IntegerConfigTest, invalid_bound_test) { + int64_t integer_value; + auto integer_config = CreateIntegerConfig("i", true, 100, 0, &integer_value, 50, nullptr, nullptr); + integer_config->Init(); + + ConfigStatus status(SUCCESS, ""); + status = integer_config->Set("30", true); + ASSERT_EQ(status.set_return, SetReturn::OUT_OF_RANGE); + ASSERT_EQ(integer_config->Get(), "50"); +} + +TEST_F(IntegerConfigTest, invalid_format_test) { + int64_t integer_value; + auto integer_config = CreateIntegerConfig("i", true, 0, 100, &integer_value, 50, nullptr, nullptr); + integer_config->Init(); + + { + ConfigStatus status(SUCCESS, ""); + status = integer_config->Set("3-0", true); + ASSERT_EQ(status.set_return, SetReturn::INVALID); + ASSERT_EQ(integer_config->Get(), "50"); + } + + { + ConfigStatus status(SUCCESS, ""); + status = integer_config->Set("30-", true); + ASSERT_EQ(status.set_return, SetReturn::INVALID); + ASSERT_EQ(integer_config->Get(), "50"); + } + + { + ConfigStatus status(SUCCESS, ""); + status = integer_config->Set("+30", true); + ASSERT_EQ(status.set_return, SetReturn::INVALID); + ASSERT_EQ(integer_config->Get(), "50"); + } + + { + ConfigStatus status(SUCCESS, ""); + status = integer_config->Set("a30", true); + ASSERT_EQ(status.set_return, SetReturn::INVALID); + ASSERT_EQ(integer_config->Get(), "50"); + } + + { + ConfigStatus status(SUCCESS, ""); + status = integer_config->Set("30a", true); + ASSERT_EQ(status.set_return, SetReturn::INVALID); + ASSERT_EQ(integer_config->Get(), "50"); + } + + { + ConfigStatus status(SUCCESS, ""); + status = integer_config->Set("3a0", true); + ASSERT_EQ(status.set_return, SetReturn::INVALID); + ASSERT_EQ(integer_config->Get(), "50"); + } +} + +/* FloatingConfigTest */ +class FloatingConfigTest : public testing::Test, public Utils {}; + +TEST_F(FloatingConfigTest, nullptr_init_test) { + auto floating_config = CreateFloatingConfig("f", true, 1.0, 9.9, nullptr, 4.5, nullptr, nullptr); + ASSERT_DEATH(floating_config->Init(), "nullptr"); +} + +TEST_F(FloatingConfigTest, init_twice_test) { + double floating_value; + auto floating_config = CreateFloatingConfig("f", true, 1.0, 9.9, &floating_value, 4.5, nullptr, nullptr); + ASSERT_DEATH( + { + floating_config->Init(); + floating_config->Init(); + }, + "initialized"); +} + +TEST_F(FloatingConfigTest, non_init_test) { + double floating_value; + auto floating_config = CreateFloatingConfig("f", true, 1.0, 9.9, &floating_value, 4.5, nullptr, nullptr); + ASSERT_DEATH(floating_config->Set("3.14", true), "uninitialized"); + ASSERT_DEATH(floating_config->Get(), "uninitialized"); +} + +TEST_F(FloatingConfigTest, immutable_update_test) { + double floating_value; + auto floating_config = CreateFloatingConfig("f", _IMMUTABLE, 1.0, 9.9, &floating_value, 4.5, nullptr, nullptr); + floating_config->Init(); + ASSERT_FLOAT_EQ(floating_value, 4.5); + + ConfigStatus status(SUCCESS, ""); + status = floating_config->Set("1.23", true); + ASSERT_EQ(status.set_return, SetReturn::IMMUTABLE); + ASSERT_FLOAT_EQ(std::stof(floating_config->Get()), 4.5); +} + +TEST_F(FloatingConfigTest, set_invalid_value_test) { +} + +TEST_F(FloatingConfigTest, valid_check_fail_test) { + double floating_value; + auto floating_config = + CreateFloatingConfig("f", true, 1.0, 9.9, &floating_value, 4.5, valid_check_failure, nullptr); + floating_config->Init(); + + ConfigStatus status(SUCCESS, ""); + status = floating_config->Set("1.23", true); + ASSERT_EQ(status.set_return, SetReturn::INVALID); + ASSERT_FLOAT_EQ(std::stof(floating_config->Get()), 4.5); +} + +TEST_F(FloatingConfigTest, update_fail_test) { + double floating_value; + auto floating_config = CreateFloatingConfig("f", true, 1.0, 9.9, &floating_value, 4.5, nullptr, update_failure); + floating_config->Init(); + + ConfigStatus status(SUCCESS, ""); + status = floating_config->Set("1.23", true); + ASSERT_EQ(status.set_return, SetReturn::UPDATE_FAILURE); + ASSERT_FLOAT_EQ(std::stof(floating_config->Get()), 4.5); +} + +TEST_F(FloatingConfigTest, string_exception_test) { + double floating_value; + auto floating_config = + CreateFloatingConfig("f", true, 1.0, 9.9, &floating_value, 4.5, valid_check_raise_string, nullptr); + floating_config->Init(); + + ConfigStatus status(SUCCESS, ""); + status = floating_config->Set("1.23", true); + ASSERT_EQ(status.set_return, SetReturn::UNEXPECTED); + ASSERT_FLOAT_EQ(std::stof(floating_config->Get()), 4.5); +} + +TEST_F(FloatingConfigTest, standard_exception_test) { + double floating_value; + auto floating_config = + CreateFloatingConfig("f", true, 1.0, 9.9, &floating_value, 4.5, valid_check_raise_exception, nullptr); + floating_config->Init(); + + ConfigStatus status(SUCCESS, ""); + status = floating_config->Set("1.23", true); + ASSERT_EQ(status.set_return, SetReturn::EXCEPTION); + ASSERT_FLOAT_EQ(std::stof(floating_config->Get()), 4.5); +} + +TEST_F(FloatingConfigTest, out_of_range_test) { + double floating_value; + auto floating_config = + CreateFloatingConfig("f", true, 1.0, 9.9, &floating_value, 4.5, valid_check_raise_exception, nullptr); + floating_config->Init(); + + { + ConfigStatus status(SUCCESS, ""); + status = floating_config->Set("0.99", true); + ASSERT_EQ(status.set_return, SetReturn::OUT_OF_RANGE); + ASSERT_FLOAT_EQ(std::stof(floating_config->Get()), 4.5); + } + + { + ConfigStatus status(SUCCESS, ""); + status = floating_config->Set("10.00", true); + ASSERT_EQ(status.set_return, SetReturn::OUT_OF_RANGE); + ASSERT_FLOAT_EQ(std::stof(floating_config->Get()), 4.5); + } +} + +TEST_F(FloatingConfigTest, invalid_bound_test) { + double floating_value; + auto floating_config = + CreateFloatingConfig("f", true, 9.9, 1.0, &floating_value, 4.5, valid_check_raise_exception, nullptr); + floating_config->Init(); + + ConfigStatus status(SUCCESS, ""); + status = floating_config->Set("6.0", true); + ASSERT_EQ(status.set_return, SetReturn::OUT_OF_RANGE); + ASSERT_FLOAT_EQ(std::stof(floating_config->Get()), 4.5); +} + +TEST_F(FloatingConfigTest, DISABLED_invalid_format_test) { + double floating_value; + auto floating_config = CreateFloatingConfig("f", true, 1.0, 100.0, &floating_value, 4.5, nullptr, nullptr); + floating_config->Init(); + + { + ConfigStatus status(SUCCESS, ""); + status = floating_config->Set("6.0.1", true); + ASSERT_EQ(status.set_return, SetReturn::INVALID); + ASSERT_FLOAT_EQ(std::stof(floating_config->Get()), 4.5); + } + + { + ConfigStatus status(SUCCESS, ""); + status = floating_config->Set("6a0", true); + ASSERT_EQ(status.set_return, SetReturn::INVALID); + ASSERT_FLOAT_EQ(std::stof(floating_config->Get()), 4.5); + } +} + +/* EnumConfigTest */ +class EnumConfigTest : public testing::Test, public Utils {}; + +TEST_F(EnumConfigTest, nullptr_init_test) { + configEnum testEnum{ + {"e", 1}, + }; + int64_t testEnumValue; + auto enum_config_1 = CreateEnumConfig("e", _MODIFIABLE, &testEnum, nullptr, 2, nullptr, nullptr); + ASSERT_DEATH(enum_config_1->Init(), "nullptr"); + + auto enum_config_2 = CreateEnumConfig("e", _MODIFIABLE, nullptr, &testEnumValue, 2, nullptr, nullptr); + ASSERT_DEATH(enum_config_2->Init(), "nullptr"); + + auto enum_config_3 = CreateEnumConfig("e", _MODIFIABLE, nullptr, nullptr, 2, nullptr, nullptr); + ASSERT_DEATH(enum_config_3->Init(), "nullptr"); +} + +TEST_F(EnumConfigTest, init_twice_test) { + configEnum testEnum{ + {"e", 1}, + }; + int64_t enum_value; + auto enum_config = CreateEnumConfig("e", _MODIFIABLE, &testEnum, &enum_value, 2, nullptr, nullptr); + ASSERT_DEATH( + { + enum_config->Init(); + enum_config->Init(); + }, + "initialized"); +} + +TEST_F(EnumConfigTest, non_init_test) { + configEnum testEnum{ + {"e", 1}, + }; + int64_t enum_value; + auto enum_config = CreateEnumConfig("e", _MODIFIABLE, &testEnum, &enum_value, 2, nullptr, nullptr); + ASSERT_DEATH(enum_config->Set("e", true), "uninitialized"); + ASSERT_DEATH(enum_config->Get(), "uninitialized"); +} + +TEST_F(EnumConfigTest, immutable_update_test) { + configEnum testEnum{ + {"a", 1}, + {"b", 2}, + {"c", 3}, + }; + int64_t enum_value = 0; + auto enum_config = CreateEnumConfig("e", _IMMUTABLE, &testEnum, &enum_value, 1, nullptr, nullptr); + enum_config->Init(); + ASSERT_EQ(enum_value, 1); + + ConfigStatus status(SUCCESS, ""); + status = enum_config->Set("b", true); + ASSERT_EQ(status.set_return, SetReturn::IMMUTABLE); + ASSERT_EQ(enum_value, 1); +} + +TEST_F(EnumConfigTest, set_invalid_value_check) { + configEnum testEnum{ + {"a", 1}, + }; + int64_t enum_value = 0; + auto enum_config = CreateEnumConfig("e", _MODIFIABLE, &testEnum, &enum_value, 1, nullptr, nullptr); + enum_config->Init(); + + ConfigStatus status(SUCCESS, ""); + status = enum_config->Set("b", true); + ASSERT_EQ(status.set_return, SetReturn::ENUM_VALUE_NOTFOUND); + ASSERT_EQ(enum_config->Get(), "a"); +} + +TEST_F(EnumConfigTest, empty_enum_test) { + configEnum testEnum{}; + int64_t enum_value; + auto enum_config = CreateEnumConfig("e", _MODIFIABLE, &testEnum, &enum_value, 2, nullptr, nullptr); + ASSERT_DEATH(enum_config->Init(), "empty"); +} + +TEST_F(EnumConfigTest, valid_check_fail_test) { + configEnum testEnum{ + {"a", 1}, + {"b", 2}, + {"c", 3}, + }; + int64_t enum_value; + auto enum_config = CreateEnumConfig("e", _MODIFIABLE, &testEnum, &enum_value, 1, valid_check_failure, nullptr); + enum_config->Init(); + + ConfigStatus status(SUCCESS, ""); + status = enum_config->Set("b", true); + ASSERT_EQ(status.set_return, SetReturn::INVALID); + ASSERT_EQ(enum_config->Get(), "a"); +} + +TEST_F(EnumConfigTest, update_fail_test) { + configEnum testEnum{ + {"a", 1}, + {"b", 2}, + {"c", 3}, + }; + int64_t enum_value; + auto enum_config = CreateEnumConfig("e", _MODIFIABLE, &testEnum, &enum_value, 1, nullptr, update_failure); + enum_config->Init(); + + ConfigStatus status(SUCCESS, ""); + status = enum_config->Set("b", true); + ASSERT_EQ(status.set_return, SetReturn::UPDATE_FAILURE); + ASSERT_EQ(enum_config->Get(), "a"); +} + +TEST_F(EnumConfigTest, string_exception_test) { + configEnum testEnum{ + {"a", 1}, + {"b", 2}, + {"c", 3}, + }; + int64_t enum_value; + auto enum_config = CreateEnumConfig("e", _MODIFIABLE, &testEnum, &enum_value, 1, valid_check_raise_string, nullptr); + enum_config->Init(); + + ConfigStatus status(SUCCESS, ""); + status = enum_config->Set("b", true); + ASSERT_EQ(status.set_return, SetReturn::UNEXPECTED); + ASSERT_EQ(enum_config->Get(), "a"); +} + +TEST_F(EnumConfigTest, standard_exception_test) { + configEnum testEnum{ + {"a", 1}, + {"b", 2}, + {"c", 3}, + }; + int64_t enum_value; + auto enum_config = + CreateEnumConfig("e", _MODIFIABLE, &testEnum, &enum_value, 1, valid_check_raise_exception, nullptr); + enum_config->Init(); + + ConfigStatus status(SUCCESS, ""); + status = enum_config->Set("b", true); + ASSERT_EQ(status.set_return, SetReturn::EXCEPTION); + ASSERT_EQ(enum_config->Get(), "a"); +} + +/* SizeConfigTest */ +class SizeConfigTest : public testing::Test, public Utils {}; + +TEST_F(SizeConfigTest, nullptr_init_test) { + auto size_config = CreateSizeConfig("i", true, 1024, 4096, nullptr, 2048, nullptr, nullptr); + ASSERT_DEATH(size_config->Init(), "nullptr"); +} + +TEST_F(SizeConfigTest, init_twice_test) { + int64_t size_value; + auto size_config = CreateSizeConfig("i", true, 1024, 4096, &size_value, 2048, nullptr, nullptr); + ASSERT_DEATH( + { + size_config->Init(); + size_config->Init(); + }, + "initialized"); +} + +TEST_F(SizeConfigTest, non_init_test) { + int64_t size_value; + auto size_config = CreateSizeConfig("i", true, 1024, 4096, &size_value, 2048, nullptr, nullptr); + ASSERT_DEATH(size_config->Set("3000", true), "uninitialized"); + ASSERT_DEATH(size_config->Get(), "uninitialized"); +} + +TEST_F(SizeConfigTest, immutable_update_test) { + int64_t size_value = 0; + auto size_config = CreateSizeConfig("i", _IMMUTABLE, 1024, 4096, &size_value, 2048, nullptr, nullptr); + size_config->Init(); + ASSERT_EQ(size_value, 2048); + + ConfigStatus status(SUCCESS, ""); + status = size_config->Set("3000", true); + ASSERT_EQ(status.set_return, SetReturn::IMMUTABLE); + ASSERT_EQ(size_value, 2048); +} + +TEST_F(SizeConfigTest, set_invalid_value_test) { +} + +TEST_F(SizeConfigTest, valid_check_fail_test) { + int64_t size_value; + auto size_config = CreateSizeConfig("i", true, 1024, 4096, &size_value, 2048, valid_check_failure, nullptr); + size_config->Init(); + + ConfigStatus status(SUCCESS, ""); + status = size_config->Set("3000", true); + ASSERT_EQ(status.set_return, SetReturn::INVALID); + ASSERT_EQ(size_config->Get(), "2048"); +} + +TEST_F(SizeConfigTest, update_fail_test) { + int64_t size_value; + auto size_config = CreateSizeConfig("i", true, 1024, 4096, &size_value, 2048, nullptr, update_failure); + size_config->Init(); + ConfigStatus status(SUCCESS, ""); + status = size_config->Set("3000", true); + ASSERT_EQ(status.set_return, SetReturn::UPDATE_FAILURE); + ASSERT_EQ(size_config->Get(), "2048"); +} + +TEST_F(SizeConfigTest, string_exception_test) { + int64_t size_value; + auto size_config = CreateSizeConfig("i", true, 1024, 4096, &size_value, 2048, valid_check_raise_string, nullptr); + size_config->Init(); + + ConfigStatus status(SUCCESS, ""); + status = size_config->Set("3000", true); + ASSERT_EQ(status.set_return, SetReturn::UNEXPECTED); + ASSERT_EQ(size_config->Get(), "2048"); +} + +TEST_F(SizeConfigTest, standard_exception_test) { + int64_t size_value; + auto size_config = CreateSizeConfig("i", true, 1024, 4096, &size_value, 2048, valid_check_raise_exception, nullptr); + size_config->Init(); + + ConfigStatus status(SUCCESS, ""); + status = size_config->Set("3000", true); + ASSERT_EQ(status.set_return, SetReturn::EXCEPTION); + ASSERT_EQ(size_config->Get(), "2048"); +} + +TEST_F(SizeConfigTest, out_of_range_test) { + int64_t size_value; + auto size_config = CreateSizeConfig("i", true, 1024, 4096, &size_value, 2048, nullptr, nullptr); + size_config->Init(); + + { + ConfigStatus status(SUCCESS, ""); + status = size_config->Set("1023", true); + ASSERT_EQ(status.set_return, SetReturn::OUT_OF_RANGE); + ASSERT_EQ(size_config->Get(), "2048"); + } + + { + ConfigStatus status(SUCCESS, ""); + status = size_config->Set("4097", true); + ASSERT_EQ(status.set_return, SetReturn::OUT_OF_RANGE); + ASSERT_EQ(size_config->Get(), "2048"); + } +} + +TEST_F(SizeConfigTest, negative_integer_test) { + int64_t size_value; + auto size_config = CreateSizeConfig("i", true, 1024, 4096, &size_value, 2048, nullptr, nullptr); + size_config->Init(); + ConfigStatus status(SUCCESS, ""); + status = size_config->Set("-3KB", true); + ASSERT_EQ(status.set_return, SetReturn::INVALID); + ASSERT_EQ(size_config->Get(), "2048"); +} + +TEST_F(SizeConfigTest, invalid_bound_test) { + int64_t size_value; + auto size_config = CreateSizeConfig("i", true, 100, 0, &size_value, 50, nullptr, nullptr); + size_config->Init(); + + ConfigStatus status(SUCCESS, ""); + status = size_config->Set("30", true); + ASSERT_EQ(status.set_return, SetReturn::OUT_OF_RANGE); + ASSERT_EQ(size_config->Get(), "50"); +} + +TEST_F(SizeConfigTest, invalid_unit_test) { + int64_t size_value; + auto size_config = CreateSizeConfig("i", true, 1024, 4096, &size_value, 2048, nullptr, nullptr); + size_config->Init(); + + ConfigStatus status(SUCCESS, ""); + status = size_config->Set("1 TB", true); + ASSERT_EQ(status.set_return, SetReturn::INVALID); + ASSERT_EQ(size_config->Get(), "2048"); +} + +TEST_F(SizeConfigTest, invalid_format_test) { + int64_t size_value; + auto size_config = CreateSizeConfig("i", true, 1024, 4096, &size_value, 2048, nullptr, nullptr); + size_config->Init(); + + { + ConfigStatus status(SUCCESS, ""); + status = size_config->Set("a10GB", true); + ASSERT_EQ(status.set_return, SetReturn::INVALID); + ASSERT_EQ(size_config->Get(), "2048"); + } + + { + ConfigStatus status(SUCCESS, ""); + status = size_config->Set("200*0", true); + ASSERT_EQ(status.set_return, SetReturn::INVALID); + ASSERT_EQ(size_config->Get(), "2048"); + } + + { + ConfigStatus status(SUCCESS, ""); + status = size_config->Set("10AB", true); + ASSERT_EQ(status.set_return, SetReturn::INVALID); + ASSERT_EQ(size_config->Get(), "2048"); + } +} + +} // namespace milvus diff --git a/core/src/config/ServerConfig.cpp b/core/src/config/ServerConfig.cpp new file mode 100644 index 0000000000..0526d7f992 --- /dev/null +++ b/core/src/config/ServerConfig.cpp @@ -0,0 +1,62 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include +#include +#include +#include +#include + +#include "config/ServerConfig.h" + +namespace milvus { + +std::mutex config_mutex; + +std::mutex& +GetConfigMutex() { + return config_mutex; +} + +ServerConfig config; + +std::vector +ParsePreloadCollection(const std::string& str) { + std::stringstream ss(str); + std::vector collections; + std::string collection; + + while (std::getline(ss, collection, ',')) { + collections.push_back(collection); + } + return collections; +} + +std::vector +ParseGPUDevices(const std::string& str) { + std::stringstream ss(str); + std::vector devices; + std::unordered_set device_set; + std::string device; + + while (std::getline(ss, device, ',')) { + if (device.length() < 4) { + /* Invalid format string */ + return {}; + } + device_set.insert(std::stoll(device.substr(3))); + } + + for (auto dev : device_set) devices.push_back(dev); + return devices; +} + +} // namespace milvus diff --git a/core/src/config/ServerConfig.h b/core/src/config/ServerConfig.h new file mode 100644 index 0000000000..a057b49d3b --- /dev/null +++ b/core/src/config/ServerConfig.h @@ -0,0 +1,112 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#pragma once + +#include +#include +#include +#include + +#include "config/ConfigType.h" + +namespace milvus { + +extern std::mutex& +GetConfigMutex(); + +template +class ConfigValue { + public: + explicit ConfigValue(T init_value) : value(std::move(init_value)) { + } + + const T& + operator()() { + std::lock_guard lock(GetConfigMutex()); + return value; + } + + public: + T value; +}; + +enum ClusterRole { + RW = 1, + RO, +}; + +enum SimdType { + AUTO = 1, + SSE, + AVX2, + AVX512, +}; + +const configEnum SimdMap{ + {"auto", SimdType::AUTO}, + {"sse", SimdType::SSE}, + {"avx2", SimdType::AVX2}, + {"avx512", SimdType::AVX512}, +}; + +struct ServerConfig { + using String = ConfigValue; + using Bool = ConfigValue; + using Integer = ConfigValue; + using Floating = ConfigValue; + + String timezone{"unknown"}; + + struct Network { + String address{"unknown"}; + Integer port{0}; + } network; + + struct Pulsar{ + String address{"localhost"}; + Integer port{6650}; + }pulsar; + + + struct Engine { + Integer build_index_threshold{4096}; + Integer search_combine_nq{0}; + Integer use_blas_threshold{0}; + Integer omp_thread_num{0}; + Integer simd_type{0}; + } engine; + + struct Tracing { + String json_config_path{"unknown"}; + } tracing; + + + struct Logs { + String level{"unknown"}; + struct Trace { + Bool enable{false}; + } trace; + String path{"unknown"}; + Integer max_log_file_size{0}; + Integer log_rotate_num{0}; + } logs; +}; + +extern ServerConfig config; +extern std::mutex _config_mutex; + +std::vector +ParsePreloadCollection(const std::string&); + +std::vector +ParseGPUDevices(const std::string&); +} // namespace milvus diff --git a/core/src/config/ServerConfigTest.cpp b/core/src/config/ServerConfigTest.cpp new file mode 100644 index 0000000000..76e0f844a7 --- /dev/null +++ b/core/src/config/ServerConfigTest.cpp @@ -0,0 +1,19 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include + +#include "config/ServerConfig.h" + +TEST(ServerConfigTest, parse_invalid_devices) { + auto collections = milvus::ParseGPUDevices("gpu0,gpu1"); + ASSERT_EQ(collections.size(), 0); +} diff --git a/core/src/dog_segment/CMakeLists.txt b/core/src/dog_segment/CMakeLists.txt index 10d770ecb3..e87bad6daf 100644 --- a/core/src/dog_segment/CMakeLists.txt +++ b/core/src/dog_segment/CMakeLists.txt @@ -13,5 +13,4 @@ add_library(milvus_dog_segment SHARED ) #add_dependencies( segment sqlite mysqlpp ) -target_link_libraries(milvus_dog_segment tbb milvus_utils pthread) - +target_link_libraries(milvus_dog_segment tbb utils pthread knowhere log) diff --git a/core/src/dog_segment/ConcurrentVector.h b/core/src/dog_segment/ConcurrentVector.h index 82dcab3339..5b71d33c6e 100644 --- a/core/src/dog_segment/ConcurrentVector.h +++ b/core/src/dog_segment/ConcurrentVector.h @@ -38,6 +38,7 @@ namespace milvus::dog_segment { template using FixedVector = std::vector; +constexpr int64_t DefaultElementPerChunk = 32 * 1024; template class ThreadSafeVector { @@ -91,7 +92,7 @@ class VectorBase { virtual void set_data_raw(ssize_t element_offset, void* source, ssize_t element_count) = 0; }; -template +template class ConcurrentVector : public VectorBase { public: // constants diff --git a/core/src/dog_segment/DeletedRecord.h b/core/src/dog_segment/DeletedRecord.h new file mode 100644 index 0000000000..af4572ec3f --- /dev/null +++ b/core/src/dog_segment/DeletedRecord.h @@ -0,0 +1,37 @@ +#pragma once + +#include "AckResponder.h" +#include "SegmentDefs.h" + +namespace milvus::dog_segment { + +struct DeletedRecord { + std::atomic reserved = 0; + AckResponder ack_responder_; + ConcurrentVector timestamps_; + ConcurrentVector uids_; + struct TmpBitmap { + // Just for query + int64_t del_barrier = 0; + std::vector bitmap; + + }; + std::shared_ptr lru_; + std::shared_mutex shared_mutex_; + + DeletedRecord(): lru_(std::make_shared()) {} + auto get_lru_entry() { + std::shared_lock lck(shared_mutex_); + return lru_; + } + void insert_lru_entry(std::shared_ptr new_entry) { + std::lock_guard lck(shared_mutex_); + if(new_entry->del_barrier <= lru_->del_barrier) { + // DO NOTHING + return; + } + lru_ = std::move(new_entry); + } +}; + +} diff --git a/core/src/dog_segment/IndexMeta.cpp b/core/src/dog_segment/IndexMeta.cpp index fbb05f8545..66faea04f5 100644 --- a/core/src/dog_segment/IndexMeta.cpp +++ b/core/src/dog_segment/IndexMeta.cpp @@ -1,56 +1,55 @@ -// #include "IndexMeta.h" -// #include -// #include -// namespace milvus::dog_segment { -// -// Status -// IndexMeta::AddEntry(const std::string& index_name, const std::string& field_name, IndexType type, IndexMode mode, -// IndexConfig config) { -// Entry entry{ -// index_name, -// field_name, -// type, -// mode, -// std::move(config) -// }; -// VerifyEntry(entry); -// -// if (entries_.count(index_name)) { -// throw std::invalid_argument("duplicate index_name"); -// } -// // TODO: support multiple indexes for single field -// assert(!lookups_.count(field_name)); -// lookups_[field_name] = index_name; -// entries_[index_name] = std::move(entry); -// -// return Status::OK(); -// } -// -// Status -// IndexMeta::DropEntry(const std::string& index_name) { -// assert(entries_.count(index_name)); -// auto entry = std::move(entries_[index_name]); -// if(lookups_[entry.field_name] == index_name) { -// lookups_.erase(entry.field_name); -// } -// return Status::OK(); -// } -// -// void IndexMeta::VerifyEntry(const Entry &entry) { -// auto is_mode_valid = std::set{IndexMode::MODE_CPU, IndexMode::MODE_GPU}.count(entry.mode); -// if(!is_mode_valid) { -// throw std::invalid_argument("invalid mode"); -// } -// -// auto& schema = *schema_; -// auto& field_meta = schema[entry.index_name]; -// // TODO checking -// if(field_meta.is_vector()) { -// assert(entry.type == knowhere::IndexEnum::INDEX_FAISS_IVFFLAT); -// } else { -// assert(false); -// } -// } -// -// } // namespace milvus::dog_segment -// \ No newline at end of file +#include "IndexMeta.h" +#include +#include +namespace milvus::dog_segment { + +Status +IndexMeta::AddEntry(const std::string& index_name, const std::string& field_name, IndexType type, IndexMode mode, + IndexConfig config) { + Entry entry{ + index_name, + field_name, + type, + mode, + std::move(config) + }; + VerifyEntry(entry); + + if (entries_.count(index_name)) { + throw std::invalid_argument("duplicate index_name"); + } + // TODO: support multiple indexes for single field + assert(!lookups_.count(field_name)); + lookups_[field_name] = index_name; + entries_[index_name] = std::move(entry); + + return Status::OK(); +} + +Status +IndexMeta::DropEntry(const std::string& index_name) { + assert(entries_.count(index_name)); + auto entry = std::move(entries_[index_name]); + if(lookups_[entry.field_name] == index_name) { + lookups_.erase(entry.field_name); + } + return Status::OK(); +} + +void IndexMeta::VerifyEntry(const Entry &entry) { + auto is_mode_valid = std::set{IndexMode::MODE_CPU, IndexMode::MODE_GPU}.count(entry.mode); + if(!is_mode_valid) { + throw std::invalid_argument("invalid mode"); + } + + auto& schema = *schema_; + auto& field_meta = schema[entry.field_name]; + // TODO checking + if(field_meta.is_vector()) { + assert(entry.type == knowhere::IndexEnum::INDEX_FAISS_IVFPQ); + } else { + assert(false); + } +} + +} // namespace milvus::dog_segment diff --git a/core/src/dog_segment/IndexMeta.h b/core/src/dog_segment/IndexMeta.h index 3c89b60463..04856fc49d 100644 --- a/core/src/dog_segment/IndexMeta.h +++ b/core/src/dog_segment/IndexMeta.h @@ -3,55 +3,56 @@ //#include // //#include "SegmentDefs.h" -//#include "knowhere/index/IndexType.h" -// +// #include "dog_segment/SegmentBase.h" +#include "dog_segment/SegmentDefs.h" +#include "knowhere/index/IndexType.h" +#include "knowhere/common/Config.h" +#include #include class IndexMeta; namespace milvus::dog_segment { -//// TODO: this is -//class IndexMeta { -// public: -// IndexMeta(SchemaPtr schema) : schema_(schema) { -// } -// using IndexType = knowhere::IndexType; -// using IndexMode = knowhere::IndexMode; -// using IndexConfig = knowhere::Config; -// -// struct Entry { -// std::string index_name; -// std::string field_name; -// IndexType type; -// IndexMode mode; -// IndexConfig config; -// }; -// -// Status -// AddEntry(const std::string& index_name, const std::string& field_name, IndexType type, IndexMode mode, -// IndexConfig config); -// -// Status -// DropEntry(const std::string& index_name); -// -// const std::map& -// get_entries() { -// return entries_; -// } -// -// const Entry& lookup_by_field(const std::string& field_name) { -// auto index_name = lookups_.at(field_name); -// return entries_.at(index_name); -// } -// private: -// void -// VerifyEntry(const Entry& entry); -// -// private: -// SchemaPtr schema_; -// std::map entries_; // index_name => Entry -// std::map lookups_; // field_name => index_name -//}; -// +// TODO: this is +class IndexMeta { + public: + IndexMeta(SchemaPtr schema) : schema_(schema) { + } + using IndexType = knowhere::IndexType; + using IndexMode = knowhere::IndexMode; + using IndexConfig = knowhere::Config; + + struct Entry { + std::string index_name; + std::string field_name; + IndexType type; + IndexMode mode; + IndexConfig config; + }; + + Status + AddEntry(const std::string& index_name, const std::string& field_name, IndexType type, IndexMode mode, + IndexConfig config); + + Status + DropEntry(const std::string& index_name); + + const std::map& + get_entries() { + return entries_; + } + + const Entry& lookup_by_field(const std::string& field_name) { + auto index_name = lookups_.at(field_name); + return entries_.at(index_name); + } + private: + void + VerifyEntry(const Entry& entry); + + private: + SchemaPtr schema_; + std::map entries_; // index_name => Entry + std::map lookups_; // field_name => index_name +}; + using IndexMetaPtr = std::shared_ptr; -// } // namespace milvus::dog_segment -// \ No newline at end of file diff --git a/core/src/dog_segment/SegmentDefs.h b/core/src/dog_segment/SegmentDefs.h index 20dafe8bfd..b94e6572b1 100644 --- a/core/src/dog_segment/SegmentDefs.h +++ b/core/src/dog_segment/SegmentDefs.h @@ -7,6 +7,7 @@ #include "utils/Types.h" // #include "knowhere/index/Index.h" #include "utils/Status.h" +#include "dog_segment/IndexMeta.h" namespace milvus::dog_segment { using Timestamp = uint64_t; // TODO: use TiKV-like timestamp @@ -152,6 +153,13 @@ class Schema { return sizeof_infos_; } + std::optional get_offset(const std::string& field_name) { + if(!offsets_.count(field_name)) { + return std::nullopt; + } else { + return offsets_[field_name]; + } + } const FieldMeta& operator[](const std::string& field_name) const { @@ -160,7 +168,6 @@ class Schema { auto offset = offset_iter->second; return (*this)[offset]; } - private: // this is where data holds std::vector fields_; @@ -173,5 +180,6 @@ class Schema { }; using SchemaPtr = std::shared_ptr; +using idx_t = int64_t; } // namespace milvus::dog_segment diff --git a/core/src/dog_segment/SegmentNaive.cpp b/core/src/dog_segment/SegmentNaive.cpp index 94fdfe4f75..1c2746e5c3 100644 --- a/core/src/dog_segment/SegmentNaive.cpp +++ b/core/src/dog_segment/SegmentNaive.cpp @@ -5,6 +5,10 @@ #include #include +#include +#include + + namespace milvus::dog_segment { int TestABI() { @@ -13,12 +17,31 @@ TestABI() { std::unique_ptr CreateSegment(SchemaPtr schema, IndexMetaPtr remote_index_meta) { + + if (remote_index_meta == nullptr) { + auto index_meta = std::make_shared(schema); + auto dim = schema->operator[]("fakevec").get_dim(); + // TODO: this is merge of query conf and insert conf + // TODO: should be splitted into multiple configs + auto conf = milvus::knowhere::Config{ + {milvus::knowhere::meta::DIM, dim}, + {milvus::knowhere::IndexParams::nlist, 100}, + {milvus::knowhere::IndexParams::nprobe, 4}, + {milvus::knowhere::IndexParams::m, 4}, + {milvus::knowhere::IndexParams::nbits, 8}, + {milvus::knowhere::Metric::TYPE, milvus::knowhere::Metric::L2}, + {milvus::knowhere::meta::DEVICEID, 0}, + }; + index_meta->AddEntry("fakeindex", "fakevec", knowhere::IndexEnum::INDEX_FAISS_IVFPQ, + knowhere::IndexMode::MODE_CPU, conf); + remote_index_meta = index_meta; + } auto segment = std::make_unique(schema, remote_index_meta); return segment; } -SegmentNaive::Record::Record(const Schema& schema) : uids_(1), timestamps_(1) { - for (auto& field : schema) { +SegmentNaive::Record::Record(const Schema &schema) : uids_(1), timestamps_(1) { + for (auto &field : schema) { if (field.is_vector()) { assert(field.get_data_type() == DataType::VECTOR_FLOAT); entity_vec_.emplace_back(std::make_shared>(field.get_dim())); @@ -41,31 +64,32 @@ SegmentNaive::PreDelete(int64_t size) { return reserved_begin; } -auto SegmentNaive::get_deleted_bitmap(int64_t del_barrier, Timestamp query_timestamp, int64_t insert_barrier) -> std::shared_ptr { +auto SegmentNaive::get_deleted_bitmap(int64_t del_barrier, Timestamp query_timestamp, + int64_t insert_barrier) -> std::shared_ptr { auto old = deleted_record_.get_lru_entry(); - if(old->del_barrier == del_barrier) { + if (old->del_barrier == del_barrier) { return old; } auto current = std::make_shared(*old); - auto& vec = current->bitmap; + auto &vec = current->bitmap; - if(del_barrier < old->del_barrier) { - for(auto del_index = del_barrier; del_index < old->del_barrier; ++del_index) { + if (del_barrier < old->del_barrier) { + for (auto del_index = del_barrier; del_index < old->del_barrier; ++del_index) { // get uid in delete logs auto uid = deleted_record_.uids_[del_index]; // map uid to corrensponding offsets, select the max one, which should be the target // the max one should be closest to query_timestamp, so the delete log should refer to it int64_t the_offset = -1; - auto [iter_b, iter_e] = uid2offset_.equal_range(uid); - for(auto iter = iter_b; iter != iter_e; ++iter) { + auto[iter_b, iter_e] = uid2offset_.equal_range(uid); + for (auto iter = iter_b; iter != iter_e; ++iter) { auto offset = iter->second; - if(record_.timestamps_[offset] < query_timestamp) { + if (record_.timestamps_[offset] < query_timestamp) { assert(offset < vec.size()); the_offset = std::max(the_offset, offset); } } // if not found, skip - if(the_offset == -1) { + if (the_offset == -1) { continue; } // otherwise, clear the flag @@ -74,29 +98,29 @@ auto SegmentNaive::get_deleted_bitmap(int64_t del_barrier, Timestamp query_times return current; } else { vec.resize(insert_barrier); - for(auto del_index = old->del_barrier; del_index < del_barrier; ++del_index) { + for (auto del_index = old->del_barrier; del_index < del_barrier; ++del_index) { // get uid in delete logs auto uid = deleted_record_.uids_[del_index]; // map uid to corrensponding offsets, select the max one, which should be the target // the max one should be closest to query_timestamp, so the delete log should refer to it int64_t the_offset = -1; - auto [iter_b, iter_e] = uid2offset_.equal_range(uid); - for(auto iter = iter_b; iter != iter_e; ++iter) { + auto[iter_b, iter_e] = uid2offset_.equal_range(uid); + for (auto iter = iter_b; iter != iter_e; ++iter) { auto offset = iter->second; - if(offset >= insert_barrier){ + if (offset >= insert_barrier) { continue; } - if(offset >= vec.size()) { + if (offset >= vec.size()) { continue; } - if(record_.timestamps_[offset] < query_timestamp) { + if (record_.timestamps_[offset] < query_timestamp) { assert(offset < vec.size()); the_offset = std::max(the_offset, offset); } } // if not found, skip - if(the_offset == -1) { + if (the_offset == -1) { continue; } @@ -109,11 +133,11 @@ auto SegmentNaive::get_deleted_bitmap(int64_t del_barrier, Timestamp query_times } Status -SegmentNaive::Insert(int64_t reserved_begin, int64_t size, const int64_t* uids_raw, const Timestamp* timestamps_raw, - const DogDataChunk& entities_raw) { +SegmentNaive::Insert(int64_t reserved_begin, int64_t size, const int64_t *uids_raw, const Timestamp *timestamps_raw, + const DogDataChunk &entities_raw) { assert(entities_raw.count == size); assert(entities_raw.sizeof_per_row == schema_->get_total_sizeof()); - auto raw_data = reinterpret_cast(entities_raw.raw_data); + auto raw_data = reinterpret_cast(entities_raw.raw_data); // std::vector entities(raw_data, raw_data + size * len_per_row); auto len_per_row = entities_raw.sizeof_per_row; @@ -138,7 +162,7 @@ SegmentNaive::Insert(int64_t reserved_begin, int64_t size, const int64_t* uids_r std::vector timestamps(size); // #pragma omp parallel for for (int index = 0; index < size; ++index) { - auto [t, uid, order_index] = ordering[index]; + auto[t, uid, order_index] = ordering[index]; timestamps[index] = t; uids[index] = uid; for (int fid = 0; fid < schema_->size(); ++fid) { @@ -156,7 +180,7 @@ SegmentNaive::Insert(int64_t reserved_begin, int64_t size, const int64_t* uids_r record_.entity_vec_[fid]->set_data_raw(reserved_begin, entities[fid].data(), size); } - for(int i = 0; i < uids.size(); ++i) { + for (int i = 0; i < uids.size(); ++i) { auto uid = uids[i]; // NOTE: this must be the last step, cannot be put above uid2offset_.insert(std::make_pair(uid, reserved_begin + i)); @@ -197,7 +221,8 @@ SegmentNaive::Insert(int64_t reserved_begin, int64_t size, const int64_t* uids_r } Status -SegmentNaive::Delete(int64_t reserved_begin, int64_t size, const int64_t* uids_raw, const Timestamp* timestamps_raw) { +SegmentNaive::Delete(int64_t reserved_begin, int64_t size, const int64_t *uids_raw, + const Timestamp *timestamps_raw) { std::vector> ordering; ordering.resize(size); // #pragma omp parallel for @@ -209,7 +234,7 @@ SegmentNaive::Delete(int64_t reserved_begin, int64_t size, const int64_t* uids_r std::vector timestamps(size); // #pragma omp parallel for for (int index = 0; index < size; ++index) { - auto [t, uid] = ordering[index]; + auto[t, uid] = ordering[index]; timestamps[index] = t; uids[index] = uid; } @@ -228,44 +253,15 @@ SegmentNaive::Delete(int64_t reserved_begin, int64_t size, const int64_t* uids_r // TODO: remove mock Status -SegmentNaive::QueryImpl(const query::QueryPtr& query, Timestamp timestamp, QueryResult& result) { - throw std::runtime_error("unimplemented"); - // auto ack_count = ack_count_.load(); - // assert(query == nullptr); - // assert(schema_->size() >= 1); - // const auto& field = schema_->operator[](0); - // assert(field.get_data_type() == DataType::VECTOR_FLOAT); - // assert(field.get_name() == "fakevec"); - // auto dim = field.get_dim(); - // // assume query vector is [0, 0, ..., 0] - // std::vector query_vector(dim, 0); - // auto& target_vec = record.entity_vecs_[0]; - // int current_index = -1; - // float min_diff = std::numeric_limits::max(); - // for (int index = 0; index < ack_count; ++index) { - // float diff = 0; - // int offset = index * dim; - // for (auto d = 0; d < dim; ++d) { - // auto v = target_vec[offset + d] - query_vector[d]; - // diff += v * v; - // } - // if (diff < min_diff) { - // min_diff = diff; - // current_index = index; - // } - // } - // QueryResult query_result; - // query_result.row_num_ = 1; - // query_result.result_distances_.push_back(min_diff); - // query_result.result_ids_.push_back(record.uids_[current_index]); - // query_result.data_chunk_ = nullptr; - // result = std::move(query_result); - // return Status::OK(); +SegmentNaive::QueryImpl(const query::QueryPtr &query, Timestamp timestamp, QueryResult &result) { +// assert(query); + + throw std::runtime_error("unimplemnted"); } template -int64_t get_barrier(const RecordType& record, Timestamp timestamp) { - auto& vec = record.timestamps_; +int64_t get_barrier(const RecordType &record, Timestamp timestamp) { + auto &vec = record.timestamps_; int64_t beg = 0; int64_t end = record.ack_responder_.GetAck(); while (beg < end) { @@ -280,11 +276,11 @@ int64_t get_barrier(const RecordType& record, Timestamp timestamp) { } Status -SegmentNaive::Query(query::QueryPtr query_info, Timestamp timestamp, QueryResult& result) { +SegmentNaive::Query(query::QueryPtr query_info, Timestamp timestamp, QueryResult &result) { // TODO: enable delete // TODO: enable index - if(query_info == nullptr) { + if (query_info == nullptr) { query_info = std::make_shared(); query_info->field_name = "fakevec"; query_info->topK = 10; @@ -294,12 +290,12 @@ SegmentNaive::Query(query::QueryPtr query_info, Timestamp timestamp, QueryResult std::default_random_engine e(42); std::uniform_real_distribution<> dis(0.0, 1.0); query_info->query_raw_data.resize(query_info->num_queries * dim); - for(auto& x: query_info->query_raw_data) { + for (auto &x: query_info->query_raw_data) { x = dis(e); } } - auto& field = schema_->operator[](query_info->field_name); + auto &field = schema_->operator[](query_info->field_name); assert(field.get_data_type() == DataType::VECTOR_FLOAT); auto dim = field.get_dim(); @@ -308,7 +304,7 @@ SegmentNaive::Query(query::QueryPtr query_info, Timestamp timestamp, QueryResult auto barrier = get_barrier(record_, timestamp); auto del_barrier = get_barrier(deleted_record_, timestamp); - auto bitmap_holder = get_deleted_bitmap(del_barrier, timestamp, barrier); + auto bitmap_holder = get_deleted_bitmap(del_barrier, timestamp, barrier); if (!bitmap_holder) { throw std::runtime_error("fuck"); @@ -316,13 +312,13 @@ SegmentNaive::Query(query::QueryPtr query_info, Timestamp timestamp, QueryResult auto bitmap = &bitmap_holder->bitmap; - if(topK > barrier) { + if (topK > barrier) { topK = barrier; } - auto get_L2_distance = [dim](const float* a, const float* b) { + auto get_L2_distance = [dim](const float *a, const float *b) { float L2_distance = 0; - for(auto i = 0; i < dim; ++i) { + for (auto i = 0; i < dim; ++i) { auto d = a[i] - b[i]; L2_distance += d * d; } @@ -332,18 +328,18 @@ SegmentNaive::Query(query::QueryPtr query_info, Timestamp timestamp, QueryResult std::vector>> records(num_queries); // TODO: optimize auto vec_ptr = std::static_pointer_cast>(record_.entity_vec_[0]); - for(int64_t i = 0; i < barrier; ++i) { - if(i < bitmap->size() && bitmap->at(i)) { + for (int64_t i = 0; i < barrier; ++i) { + if (i < bitmap->size() && bitmap->at(i)) { continue; } auto element = vec_ptr->get_element(i); - for(auto query_id = 0; query_id < num_queries; ++query_id) { + for (auto query_id = 0; query_id < num_queries; ++query_id) { auto query_blob = query_info->query_raw_data.data() + query_id * dim; auto dis = get_L2_distance(query_blob, element); - auto& record = records[query_id]; - if(record.size() < topK) { + auto &record = records[query_id]; + if (record.size() < topK) { record.emplace(dis, i); - } else if(record.top().first > dis) { + } else if (record.top().first > dis) { record.emplace(dis, i); record.pop(); } @@ -359,11 +355,11 @@ SegmentNaive::Query(query::QueryPtr query_info, Timestamp timestamp, QueryResult result.result_ids_.resize(row_num); result.result_distances_.resize(row_num); - for(int q_id = 0; q_id < num_queries; ++q_id) { + for (int q_id = 0; q_id < num_queries; ++q_id) { // reverse - for(int i = 0; i < topK; ++i) { + for (int i = 0; i < topK; ++i) { auto dst_id = topK - 1 - i + q_id * topK; - auto [dis, offset] = records[q_id].top(); + auto[dis, offset] = records[q_id].top(); records[q_id].pop(); result.result_ids_[dst_id] = record_.uids_[offset]; result.result_distances_[dst_id] = dis; @@ -384,39 +380,62 @@ SegmentNaive::Query(query::QueryPtr query_info, Timestamp timestamp, QueryResult Status SegmentNaive::Close() { + if (this->record_.reserved != this->record_.ack_responder_.GetAck()) { + std::runtime_error("insert not ready"); + } + if (this->deleted_record_.reserved != this->record_.ack_responder_.GetAck()) { + std::runtime_error("delete not ready"); + } state_ = SegmentState::Closed; return Status::OK(); - // auto src_record = GetMutableRecord(); - // assert(src_record); - // - // auto dst_record = std::make_shared(schema_->size()); - // - // auto data_move = [](auto& dst_vec, const auto& src_vec) { - // assert(dst_vec.size() == 0); - // dst_vec.insert(dst_vec.begin(), src_vec.begin(), src_vec.end()); - // }; - // data_move(dst_record->uids_, src_record->uids_); - // data_move(dst_record->timestamps_, src_record->uids_); - // - // assert(src_record->entity_vecs_.size() == schema_->size()); - // assert(dst_record->entity_vecs_.size() == schema_->size()); - // for (int i = 0; i < schema_->size(); ++i) { - // data_move(dst_record->entity_vecs_[i], src_record->entity_vecs_[i]); - // } - // bool ready_old = false; - // record_immutable_ = dst_record; - // ready_immutable_.compare_exchange_strong(ready_old, true); - // if (ready_old) { - // throw std::logic_error("Close may be called twice, with potential race condition"); - // } - // return Status::OK(); +} + +template +knowhere::IndexPtr SegmentNaive::BuildVecIndexImpl(const IndexMeta::Entry &entry) { + auto offset_opt = schema_->get_offset(entry.field_name); + assert(offset_opt.has_value()); + auto offset = offset_opt.value(); + auto field = (*schema_)[offset]; + auto dim = field.get_dim(); + + auto indexing = knowhere::VecIndexFactory::GetInstance().CreateVecIndex(entry.type, entry.mode); + auto chunk_size = record_.uids_.chunk_size(); + + auto &uids = record_.uids_; + auto entities = record_.get_vec_entity(offset); + + std::vector datasets; + for (int chunk_id = 0; chunk_id < uids.chunk_size(); ++chunk_id) { + auto &uids_chunk = uids.get_chunk(chunk_id); + auto &entities_chunk = entities->get_chunk(chunk_id); + int64_t count = chunk_id == uids.chunk_size() - 1 ? record_.reserved - chunk_id * DefaultElementPerChunk + : DefaultElementPerChunk; + datasets.push_back(knowhere::GenDatasetWithIds(count, dim, entities_chunk.data(), uids_chunk.data())); + } + for (auto &ds: datasets) { + indexing->Train(ds, entry.config); + } + for (auto &ds: datasets) { + indexing->Add(ds, entry.config); + } + return indexing; } Status SegmentNaive::BuildIndex() { - throw std::runtime_error("unimplemented"); - // assert(ready_immutable_); - // throw std::runtime_error("unimplemented"); + for (auto&[index_name, entry]: index_meta_->get_entries()) { + assert(entry.index_name == index_name); + const auto &field = (*schema_)[entry.field_name]; + + if (field.is_vector()) { + assert(field.get_data_type() == engine::DataType::VECTOR_FLOAT); + auto index_ptr = BuildVecIndexImpl(entry); + indexings_[index_name] = index_ptr; + } else { + throw std::runtime_error("unimplemented"); + } + } + return Status::OK(); } } // namespace milvus::dog_segment diff --git a/core/src/dog_segment/SegmentNaive.h b/core/src/dog_segment/SegmentNaive.h index 7180c8db88..fbecbe50ce 100644 --- a/core/src/dog_segment/SegmentNaive.h +++ b/core/src/dog_segment/SegmentNaive.h @@ -4,6 +4,7 @@ #include #include +#include #include "AckResponder.h" #include "ConcurrentVector.h" @@ -11,7 +12,7 @@ // #include "knowhere/index/structured_index/StructuredIndex.h" #include "query/GeneralQuery.h" #include "utils/Status.h" -using idx_t = int64_t; +#include "dog_segment/DeletedRecord.h" namespace milvus::dog_segment { struct ColumnBasedDataChunk { @@ -87,6 +88,27 @@ class SegmentNaive : public SegmentBase { return Status::OK(); } +public: + ssize_t + get_row_count() const override { + return record_.ack_responder_.GetAck(); + } + SegmentState + get_state() const override { + return state_.load(std::memory_order_relaxed); + } + ssize_t + get_deleted_count() const override { + return 0; + } + +public: + friend std::unique_ptr + CreateSegment(SchemaPtr schema, IndexMetaPtr index_meta); + explicit SegmentNaive(SchemaPtr schema, IndexMetaPtr index_meta) + : schema_(schema), index_meta_(index_meta), record_(*schema) { + } + private: struct MutableRecord { ConcurrentVector uids_; @@ -103,79 +125,31 @@ class SegmentNaive : public SegmentBase { ConcurrentVector uids_; std::vector> entity_vec_; Record(const Schema& schema); + template + auto get_vec_entity(int offset) { + return std::static_pointer_cast>(entity_vec_[offset]); + } }; tbb::concurrent_unordered_multimap uid2offset_; - struct DeletedRecord { - std::atomic reserved = 0; - AckResponder ack_responder_; - ConcurrentVector timestamps_; - ConcurrentVector uids_; - struct TmpBitmap { - // Just for query - int64_t del_barrier = 0; - std::vector bitmap; - }; - std::shared_ptr lru_; - std::shared_mutex shared_mutex_; - - DeletedRecord(): lru_(std::make_shared()) {} - auto get_lru_entry() { - std::shared_lock lck(shared_mutex_); - return lru_; - } - void insert_lru_entry(std::shared_ptr new_entry) { - std::lock_guard lck(shared_mutex_); - if(new_entry->del_barrier <= lru_->del_barrier) { - // DO NOTHING - return; - } - lru_ = std::move(new_entry); - } - }; - std::shared_ptr get_deleted_bitmap(int64_t del_barrier, Timestamp query_timestamp, int64_t insert_barrier); Status QueryImpl(const query::QueryPtr& query, Timestamp timestamp, QueryResult& results); - public: - ssize_t - get_row_count() const override { - return record_.ack_responder_.GetAck(); - } - SegmentState - get_state() const override { - return state_.load(std::memory_order_relaxed); - } - ssize_t - get_deleted_count() const override { - return 0; - } - - public: - friend std::unique_ptr - CreateSegment(SchemaPtr schema, IndexMetaPtr index_meta); - explicit SegmentNaive(SchemaPtr schema, IndexMetaPtr index_meta) - : schema_(schema), index_meta_(index_meta), record_(*schema) { - } + template + knowhere::IndexPtr BuildVecIndexImpl(const IndexMeta::Entry& entry); private: SchemaPtr schema_; - IndexMetaPtr index_meta_; std::atomic state_ = SegmentState::Open; Record record_; DeletedRecord deleted_record_; - // tbb::concurrent_unordered_map internal_indexes_; - // std::shared_ptr record_mutable_; - // // to determined that if immutable data if available - // std::shared_ptr record_immutable_ = nullptr; - // std::unordered_map vec_indexings_; - // // TODO: scalar indexing - // // std::unordered_map scalar_indexings_; - // tbb::concurrent_unordered_multimap delete_logs_; + + IndexMetaPtr index_meta_; + std::unordered_map indexings_; // index_name => indexing }; } // namespace milvus::dog_segment diff --git a/core/src/dog_segment/segment_c.cpp b/core/src/dog_segment/segment_c.cpp index 21fae04fa6..b411c63e1b 100644 --- a/core/src/dog_segment/segment_c.cpp +++ b/core/src/dog_segment/segment_c.cpp @@ -3,6 +3,9 @@ #include "SegmentBase.h" #include "segment_c.h" #include "Partition.h" +#include +#include +#include CSegmentBase @@ -46,9 +49,6 @@ Insert(CSegmentBase c_segment, dataChunk.count = count; auto res = segment->Insert(reserved_offset, size, primary_keys, timestamps, dataChunk); - - // TODO: delete print - // std::cout << "do segment insert, sizeof_per_row = " << sizeof_per_row << std::endl; return res.code(); } @@ -58,7 +58,7 @@ PreInsert(CSegmentBase c_segment, long int size) { auto segment = (milvus::dog_segment::SegmentBase*)c_segment; // TODO: delete print - // std::cout << "PreInsert segment " << std::endl; + std::cout << "PreInsert segment " << std::endl; return segment->PreInsert(size); } @@ -81,7 +81,7 @@ PreDelete(CSegmentBase c_segment, long int size) { auto segment = (milvus::dog_segment::SegmentBase*)c_segment; // TODO: delete print - // std::cout << "PreDelete segment " << std::endl; + std::cout << "PreDelete segment " << std::endl; return segment->PreDelete(size); } @@ -114,6 +114,13 @@ Close(CSegmentBase c_segment) { return status.code(); } +int +BuildIndex(CSegmentBase c_segment) { + auto segment = (milvus::dog_segment::SegmentBase*)c_segment; + auto status = segment->BuildIndex(); + return status.code(); +} + bool IsOpened(CSegmentBase c_segment) { diff --git a/core/src/dog_segment/segment_c.h b/core/src/dog_segment/segment_c.h index 937ec69578..9d64d9f68e 100644 --- a/core/src/dog_segment/segment_c.h +++ b/core/src/dog_segment/segment_c.h @@ -50,6 +50,9 @@ Search(CSegmentBase c_segment, int Close(CSegmentBase c_segment); +int +BuildIndex(CSegmentBase c_segment); + bool IsOpened(CSegmentBase c_segment); diff --git a/core/src/index/.gitignore b/core/src/index/.gitignore new file mode 100644 index 0000000000..c263e61d36 --- /dev/null +++ b/core/src/index/.gitignore @@ -0,0 +1 @@ +cmake_build \ No newline at end of file diff --git a/core/src/index/CMakeLists.txt b/core/src/index/CMakeLists.txt new file mode 100644 index 0000000000..5770ead717 --- /dev/null +++ b/core/src/index/CMakeLists.txt @@ -0,0 +1,84 @@ +#------------------------------------------------------------------------------- +# Copyright (C) 2019-2020 Zilliz. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License +# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing permissions and limitations under the License. +#------------------------------------------------------------------------------- + +cmake_minimum_required(VERSION 3.12) +message(STATUS "------------------------------KNOWHERE-----------------------------------") +message(STATUS "Building using CMake version: ${CMAKE_VERSION}") + +project(knowhere LANGUAGES C CXX) +set(CMAKE_CXX_STANDARD 17) + +# if no build build type is specified, default to release builds +if (NOT CMAKE_BUILD_TYPE) + set(CMAKE_BUILD_TYPE Release) +endif (NOT CMAKE_BUILD_TYPE) + +if (CMAKE_SYSTEM_PROCESSOR MATCHES "(x86)|(X86)|(amd64)|(AMD64)") + message(STATUS "building milvus_engine on x86 architecture") + set(KNOWHERE_BUILD_ARCH x86_64) +elseif (CMAKE_SYSTEM_PROCESSOR MATCHES "(ppc)") + message(STATUS "building milvus_engine on ppc architecture") + set(KNOWHERE_BUILD_ARCH ppc64le) +else () + message(WARNING "unknown processor type") + message(WARNING "CMAKE_SYSTEM_PROCESSOR=${CMAKE_SYSTEM_PROCESSOR}") + set(KNOWHERE_BUILD_ARCH unknown) +endif () + +if (CMAKE_BUILD_TYPE STREQUAL "Release") + set(BUILD_TYPE "release") +else () + set(BUILD_TYPE "debug") +endif () +message(STATUS "Build type = ${BUILD_TYPE}") + +set(INDEX_SOURCE_DIR ${PROJECT_SOURCE_DIR}) +set(INDEX_BINARY_DIR ${PROJECT_BINARY_DIR}) +set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${INDEX_SOURCE_DIR}/cmake") + +include(ExternalProject) +include(DefineOptionsCore) +include(BuildUtilsCore) + +using_ccache_if_defined( KNOWHERE_USE_CCACHE ) + +message(STATUS "Building Knowhere CPU version") + +if (MILVUS_SUPPORT_SPTAG) + message(STATUS "Building Knowhere with SPTAG supported") + add_compile_definitions("MILVUS_SUPPORT_SPTAG") +endif () + +include(ThirdPartyPackagesCore) + +if (CMAKE_BUILD_TYPE STREQUAL "Release") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3 -fPIC -DELPP_THREAD_SAFE -fopenmp") +else () + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O0 -g -fPIC -DELPP_THREAD_SAFE -fopenmp") +endif () + +add_subdirectory(knowhere) + +if (BUILD_COVERAGE STREQUAL "ON") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fprofile-arcs -ftest-coverage") +endif () + +set(INDEX_INCLUDE_DIRS ${INDEX_INCLUDE_DIRS} PARENT_SCOPE) + +#if (KNOWHERE_BUILD_TESTS) +# set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DELPP_DISABLE_LOGS") +# add_subdirectory(unittest) +#endif () + +config_summary() + diff --git a/core/src/index/archive/KnowhereResource.cpp b/core/src/index/archive/KnowhereResource.cpp new file mode 100644 index 0000000000..e8d457c208 --- /dev/null +++ b/core/src/index/archive/KnowhereResource.cpp @@ -0,0 +1,112 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include "index/archive/KnowhereResource.h" +#ifdef MILVUS_GPU_VERSION +#include "knowhere/index/vector_index/helpers/FaissGpuResourceMgr.h" +#endif + +#include "config/ServerConfig.h" +#include "faiss/FaissHook.h" +// #include "scheduler/Utils.h" +#include "utils/Error.h" +#include "utils/Log.h" + +// #include +#include +#include +#include +#include +#include + +namespace milvus { +namespace engine { + +constexpr int64_t M_BYTE = 1024 * 1024; + +Status +KnowhereResource::Initialize() { + auto simd_type = config.engine.simd_type(); + if (simd_type == SimdType::AVX512) { + faiss::faiss_use_avx512 = true; + faiss::faiss_use_avx2 = false; + faiss::faiss_use_sse = false; + } else if (simd_type == SimdType::AVX2) { + faiss::faiss_use_avx512 = false; + faiss::faiss_use_avx2 = true; + faiss::faiss_use_sse = false; + } else if (simd_type == SimdType::SSE) { + faiss::faiss_use_avx512 = false; + faiss::faiss_use_avx2 = false; + faiss::faiss_use_sse = true; + } else { + faiss::faiss_use_avx512 = true; + faiss::faiss_use_avx2 = true; + faiss::faiss_use_sse = true; + } + std::string cpu_flag; + if (faiss::hook_init(cpu_flag)) { + std::cout << "FAISS hook " << cpu_flag << std::endl; + LOG_ENGINE_DEBUG_ << "FAISS hook " << cpu_flag; + } else { + return Status(KNOWHERE_UNEXPECTED_ERROR, "FAISS hook fail, CPU not supported!"); + } + +#ifdef MILVUS_GPU_VERSION + bool enable_gpu = config.gpu.enable(); + // fiu_do_on("KnowhereResource.Initialize.disable_gpu", enable_gpu = false); + if (!enable_gpu) { + return Status::OK(); + } + + struct GpuResourceSetting { + int64_t pinned_memory = 256 * M_BYTE; + int64_t temp_memory = 256 * M_BYTE; + int64_t resource_num = 2; + }; + using GpuResourcesArray = std::map; + GpuResourcesArray gpu_resources; + + // get build index gpu resource + std::vector build_index_gpus = ParseGPUDevices(config.gpu.build_index_devices()); + + for (auto gpu_id : build_index_gpus) { + gpu_resources.insert(std::make_pair(gpu_id, GpuResourceSetting())); + } + + // get search gpu resource + std::vector search_gpus = ParseGPUDevices(config.gpu.search_devices()); + + for (auto& gpu_id : search_gpus) { + gpu_resources.insert(std::make_pair(gpu_id, GpuResourceSetting())); + } + + // init gpu resources + for (auto& gpu_resource : gpu_resources) { + knowhere::FaissGpuResourceMgr::GetInstance().InitDevice(gpu_resource.first, gpu_resource.second.pinned_memory, + gpu_resource.second.temp_memory, + gpu_resource.second.resource_num); + } +#endif + + return Status::OK(); +} + +Status +KnowhereResource::Finalize() { +#ifdef MILVUS_GPU_VERSION + knowhere::FaissGpuResourceMgr::GetInstance().Free(); // free gpu resource. +#endif + return Status::OK(); +} + +} // namespace engine +} // namespace milvus diff --git a/core/src/index/archive/KnowhereResource.h b/core/src/index/archive/KnowhereResource.h new file mode 100644 index 0000000000..1e863cf27e --- /dev/null +++ b/core/src/index/archive/KnowhereResource.h @@ -0,0 +1,29 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#pragma once + +#include "utils/Status.h" + +namespace milvus { +namespace engine { + +class KnowhereResource { + public: + static Status + Initialize(); + + static Status + Finalize(); +}; + +} // namespace engine +} // namespace milvus diff --git a/core/src/index/build.sh b/core/src/index/build.sh new file mode 100755 index 0000000000..512d6316de --- /dev/null +++ b/core/src/index/build.sh @@ -0,0 +1,77 @@ +#!/bin/bash + +BUILD_TYPE="Debug" +BUILD_UNITTEST="OFF" +INSTALL_PREFIX=$(pwd)/cmake_build +MAKE_CLEAN="OFF" +PROFILING="OFF" + +while getopts "p:d:t:uhrcgm" arg +do + case $arg in + t) + BUILD_TYPE=$OPTARG # BUILD_TYPE + ;; + u) + echo "Build and run unittest cases" ; + BUILD_UNITTEST="ON"; + ;; + p) + INSTALL_PREFIX=$OPTARG + ;; + r) + if [[ -d cmake_build ]]; then + rm ./cmake_build -r + MAKE_CLEAN="ON" + fi + ;; + g) + PROFILING="ON" + ;; + h) # help + echo " + +parameter: +-t: build type(default: Debug) +-u: building unit test options(default: OFF) +-p: install prefix(default: $(pwd)/knowhere) +-r: remove previous build directory(default: OFF) +-g: profiling(default: OFF) + +usage: +./build.sh -t \${BUILD_TYPE} [-u] [-h] [-g] [-r] [-c] + " + exit 0 + ;; + ?) + echo "unknown argument" + exit 1 + ;; + esac +done + +if [[ ! -d cmake_build ]]; then + mkdir cmake_build + MAKE_CLEAN="ON" +fi + +cd cmake_build + +CUDA_COMPILER=/usr/local/cuda/bin/nvcc + +if [[ ${MAKE_CLEAN} == "ON" ]]; then + CMAKE_CMD="cmake -DBUILD_UNIT_TEST=${BUILD_UNITTEST} \ + -DCMAKE_INSTALL_PREFIX=${INSTALL_PREFIX} + -DCMAKE_BUILD_TYPE=${BUILD_TYPE} \ + -DCMAKE_CUDA_COMPILER=${CUDA_COMPILER} \ + -DMILVUS_ENABLE_PROFILING=${PROFILING} \ + ../" + echo ${CMAKE_CMD} + + ${CMAKE_CMD} + make clean +fi + +make -j 8 || exit 1 + +make install || exit 1 diff --git a/core/src/index/cmake/BuildUtilsCore.cmake b/core/src/index/cmake/BuildUtilsCore.cmake new file mode 100644 index 0000000000..6a85bb0b3d --- /dev/null +++ b/core/src/index/cmake/BuildUtilsCore.cmake @@ -0,0 +1,218 @@ +# Define a function that check last file modification +function(Check_Last_Modify cache_check_lists_file_path working_dir last_modified_commit_id) + if (EXISTS "${working_dir}") + if (EXISTS "${cache_check_lists_file_path}") + set(GIT_LOG_SKIP_NUM 0) + set(_MATCH_ALL ON CACHE BOOL "Match all") + set(_LOOP_STATUS ON CACHE BOOL "Whether out of loop") + file(STRINGS ${cache_check_lists_file_path} CACHE_IGNORE_TXT) + while (_LOOP_STATUS) + foreach (_IGNORE_ENTRY ${CACHE_IGNORE_TXT}) + if (NOT _IGNORE_ENTRY MATCHES "^[^#]+") + continue() + endif () + + set(_MATCH_ALL OFF) + execute_process(COMMAND git log --no-merges -1 --skip=${GIT_LOG_SKIP_NUM} --name-status --pretty= WORKING_DIRECTORY ${working_dir} OUTPUT_VARIABLE CHANGE_FILES) + if (NOT CHANGE_FILES STREQUAL "") + string(REPLACE "\n" ";" _CHANGE_FILES ${CHANGE_FILES}) + foreach (_FILE_ENTRY ${_CHANGE_FILES}) + string(REGEX MATCH "[^ \t]+$" _FILE_NAME ${_FILE_ENTRY}) + execute_process(COMMAND sh -c "echo ${_FILE_NAME} | grep ${_IGNORE_ENTRY}" RESULT_VARIABLE return_code) + if (return_code EQUAL 0) + execute_process(COMMAND git log --no-merges -1 --skip=${GIT_LOG_SKIP_NUM} --pretty=%H WORKING_DIRECTORY ${working_dir} OUTPUT_VARIABLE LAST_MODIFIED_COMMIT_ID) + set(${last_modified_commit_id} ${LAST_MODIFIED_COMMIT_ID} PARENT_SCOPE) + set(_LOOP_STATUS OFF) + endif () + endforeach () + else () + set(_LOOP_STATUS OFF) + endif () + endforeach () + + if (_MATCH_ALL) + execute_process(COMMAND git log --no-merges -1 --skip=${GIT_LOG_SKIP_NUM} --pretty=%H WORKING_DIRECTORY ${working_dir} OUTPUT_VARIABLE LAST_MODIFIED_COMMIT_ID) + set(${last_modified_commit_id} ${LAST_MODIFIED_COMMIT_ID} PARENT_SCOPE) + set(_LOOP_STATUS OFF) + endif () + + math(EXPR GIT_LOG_SKIP_NUM "${GIT_LOG_SKIP_NUM} + 1") + endwhile (_LOOP_STATUS) + else () + execute_process(COMMAND git log --no-merges -1 --skip=${GIT_LOG_SKIP_NUM} --pretty=%H WORKING_DIRECTORY ${working_dir} OUTPUT_VARIABLE LAST_MODIFIED_COMMIT_ID) + set(${last_modified_commit_id} ${LAST_MODIFIED_COMMIT_ID} PARENT_SCOPE) + endif () + else () + message(FATAL_ERROR "The directory ${working_dir} does not exist") + endif () +endfunction() + +# Define a function that extracts a cached package +function(ExternalProject_Use_Cache project_name package_file install_path) + message(STATUS "Will use cached package file: ${package_file}") + + ExternalProject_Add(${project_name} + DOWNLOAD_COMMAND ${CMAKE_COMMAND} -E echo + "No download step needed (using cached package)" + CONFIGURE_COMMAND ${CMAKE_COMMAND} -E echo + "No configure step needed (using cached package)" + BUILD_COMMAND ${CMAKE_COMMAND} -E echo + "No build step needed (using cached package)" + INSTALL_COMMAND ${CMAKE_COMMAND} -E echo + "No install step needed (using cached package)" + ) + + # We want our tar files to contain the Install/ prefix (not for any + # very special reason, only for consistency and so that we can identify them + # in the extraction logs) which means that we must extract them in the + # binary (top-level build) directory to have them installed in the right + # place for subsequent ExternalProjects to pick them up. It seems that the + # only way to control the working directory is with Add_Step! + ExternalProject_Add_Step(${project_name} extract + ALWAYS 1 + COMMAND + ${CMAKE_COMMAND} -E echo + "Extracting ${package_file} to ${install_path}" + COMMAND + ${CMAKE_COMMAND} -E tar xzf ${package_file} ${install_path} + WORKING_DIRECTORY ${INDEX_BINARY_DIR} + ) + + ExternalProject_Add_StepTargets(${project_name} extract) +endfunction() + +# Define a function that to create a new cached package +function(ExternalProject_Create_Cache project_name package_file install_path cache_username cache_password cache_path) + if (EXISTS ${package_file}) + message(STATUS "Removing existing package file: ${package_file}") + file(REMOVE ${package_file}) + endif () + + string(REGEX REPLACE "(.+)/.+$" "\\1" package_dir ${package_file}) + if (NOT EXISTS ${package_dir}) + file(MAKE_DIRECTORY ${package_dir}) + endif () + + message(STATUS "Will create cached package file: ${package_file}") + + ExternalProject_Add_Step(${project_name} package + DEPENDEES install + BYPRODUCTS ${package_file} + COMMAND ${CMAKE_COMMAND} -E echo "Updating cached package file: ${package_file}" + COMMAND ${CMAKE_COMMAND} -E tar czvf ${package_file} ${install_path} + COMMAND ${CMAKE_COMMAND} -E echo "Uploading package file ${package_file} to ${cache_path}" + COMMAND curl -u${cache_username}:${cache_password} -T ${package_file} ${cache_path} + ) + + ExternalProject_Add_StepTargets(${project_name} package) +endfunction() + +function(ADD_THIRDPARTY_LIB LIB_NAME) + set(options) + set(one_value_args SHARED_LIB STATIC_LIB) + set(multi_value_args DEPS INCLUDE_DIRECTORIES) + cmake_parse_arguments(ARG + "${options}" + "${one_value_args}" + "${multi_value_args}" + ${ARGN}) + if (ARG_UNPARSED_ARGUMENTS) + message(SEND_ERROR "Error: unrecognized arguments: ${ARG_UNPARSED_ARGUMENTS}") + endif () + + if (ARG_STATIC_LIB AND ARG_SHARED_LIB) + if (NOT ARG_STATIC_LIB) + message(FATAL_ERROR "No static or shared library provided for ${LIB_NAME}") + endif () + + set(AUG_LIB_NAME "${LIB_NAME}_static") + add_library(${AUG_LIB_NAME} STATIC IMPORTED) + set_target_properties(${AUG_LIB_NAME} + PROPERTIES IMPORTED_LOCATION "${ARG_STATIC_LIB}") + if (ARG_DEPS) + set_target_properties(${AUG_LIB_NAME} + PROPERTIES INTERFACE_LINK_LIBRARIES "${ARG_DEPS}") + endif () + message(STATUS "Added static library dependency ${AUG_LIB_NAME}: ${ARG_STATIC_LIB}") + if (ARG_INCLUDE_DIRECTORIES) + set_target_properties(${AUG_LIB_NAME} + PROPERTIES INTERFACE_INCLUDE_DIRECTORIES + "${ARG_INCLUDE_DIRECTORIES}") + endif () + + set(AUG_LIB_NAME "${LIB_NAME}_shared") + add_library(${AUG_LIB_NAME} SHARED IMPORTED) + + if (WIN32) + # Mark the ".lib" location as part of a Windows DLL + set_target_properties(${AUG_LIB_NAME} + PROPERTIES IMPORTED_IMPLIB "${ARG_SHARED_LIB}") + else () + set_target_properties(${AUG_LIB_NAME} + PROPERTIES IMPORTED_LOCATION "${ARG_SHARED_LIB}") + endif () + if (ARG_DEPS) + set_target_properties(${AUG_LIB_NAME} + PROPERTIES INTERFACE_LINK_LIBRARIES "${ARG_DEPS}") + endif () + message(STATUS "Added shared library dependency ${AUG_LIB_NAME}: ${ARG_SHARED_LIB}") + if (ARG_INCLUDE_DIRECTORIES) + set_target_properties(${AUG_LIB_NAME} + PROPERTIES INTERFACE_INCLUDE_DIRECTORIES + "${ARG_INCLUDE_DIRECTORIES}") + endif () + elseif (ARG_STATIC_LIB) + set(AUG_LIB_NAME "${LIB_NAME}_static") + add_library(${AUG_LIB_NAME} STATIC IMPORTED) + set_target_properties(${AUG_LIB_NAME} + PROPERTIES IMPORTED_LOCATION "${ARG_STATIC_LIB}") + if (ARG_DEPS) + set_target_properties(${AUG_LIB_NAME} + PROPERTIES INTERFACE_LINK_LIBRARIES "${ARG_DEPS}") + endif () + message(STATUS "Added static library dependency ${AUG_LIB_NAME}: ${ARG_STATIC_LIB}") + if (ARG_INCLUDE_DIRECTORIES) + set_target_properties(${AUG_LIB_NAME} + PROPERTIES INTERFACE_INCLUDE_DIRECTORIES + "${ARG_INCLUDE_DIRECTORIES}") + endif () + elseif (ARG_SHARED_LIB) + set(AUG_LIB_NAME "${LIB_NAME}_shared") + add_library(${AUG_LIB_NAME} SHARED IMPORTED) + + if (WIN32) + # Mark the ".lib" location as part of a Windows DLL + set_target_properties(${AUG_LIB_NAME} + PROPERTIES IMPORTED_IMPLIB "${ARG_SHARED_LIB}") + else () + set_target_properties(${AUG_LIB_NAME} + PROPERTIES IMPORTED_LOCATION "${ARG_SHARED_LIB}") + endif () + message(STATUS "Added shared library dependency ${AUG_LIB_NAME}: ${ARG_SHARED_LIB}") + if (ARG_DEPS) + set_target_properties(${AUG_LIB_NAME} + PROPERTIES INTERFACE_LINK_LIBRARIES "${ARG_DEPS}") + endif () + if (ARG_INCLUDE_DIRECTORIES) + set_target_properties(${AUG_LIB_NAME} + PROPERTIES INTERFACE_INCLUDE_DIRECTORIES + "${ARG_INCLUDE_DIRECTORIES}") + endif () + else () + message(FATAL_ERROR "No static or shared library provided for ${LIB_NAME}") + endif () +endfunction() + +MACRO(using_ccache_if_defined KNOWHERE_USE_CCACHE) + if (MILVUS_USE_CCACHE) + find_program(CCACHE_FOUND ccache) + if (CCACHE_FOUND) + message(STATUS "Using ccache: ${CCACHE_FOUND}") + set_property(GLOBAL PROPERTY RULE_LAUNCH_COMPILE ${CCACHE_FOUND}) + set_property(GLOBAL PROPERTY RULE_LAUNCH_LINK ${CCACHE_FOUND}) + # let ccache preserve C++ comments, because some of them may be + # meaningful to the compiler + set(ENV{CCACHE_COMMENTS} "1") + endif (CCACHE_FOUND) + endif () +ENDMACRO(using_ccache_if_defined) diff --git a/core/src/index/cmake/DefineOptionsCore.cmake b/core/src/index/cmake/DefineOptionsCore.cmake new file mode 100644 index 0000000000..5db0fa7d04 --- /dev/null +++ b/core/src/index/cmake/DefineOptionsCore.cmake @@ -0,0 +1,169 @@ +#------------------------------------------------------------------------------- +# Copyright (C) 2019-2020 Zilliz. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License +# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing permissions and limitations under the License. +#------------------------------------------------------------------------------- + +macro(set_option_category name) + set(KNOWHERE_OPTION_CATEGORY ${name}) + list(APPEND "KNOWHERE_OPTION_CATEGORIES" ${name}) +endmacro() + +macro(define_option name description default) + option(${name} ${description} ${default}) + list(APPEND "KNOWHERE_${KNOWHERE_OPTION_CATEGORY}_OPTION_NAMES" ${name}) + set("${name}_OPTION_DESCRIPTION" ${description}) + set("${name}_OPTION_DEFAULT" ${default}) + set("${name}_OPTION_TYPE" "bool") +endmacro() + +function(list_join lst glue out) + if ("${${lst}}" STREQUAL "") + set(${out} "" PARENT_SCOPE) + return() + endif () + + list(GET ${lst} 0 joined) + list(REMOVE_AT ${lst} 0) + foreach (item ${${lst}}) + set(joined "${joined}${glue}${item}") + endforeach () + set(${out} ${joined} PARENT_SCOPE) +endfunction() + +macro(define_option_string name description default) + set(${name} ${default} CACHE STRING ${description}) + list(APPEND "KNOWHERE_${KNOWHERE_OPTION_CATEGORY}_OPTION_NAMES" ${name}) + set("${name}_OPTION_DESCRIPTION" ${description}) + set("${name}_OPTION_DEFAULT" "\"${default}\"") + set("${name}_OPTION_TYPE" "string") + + set("${name}_OPTION_ENUM" ${ARGN}) + list_join("${name}_OPTION_ENUM" "|" "${name}_OPTION_ENUM") + if (NOT ("${${name}_OPTION_ENUM}" STREQUAL "")) + set_property(CACHE ${name} PROPERTY STRINGS ${ARGN}) + endif () +endmacro() + +#---------------------------------------------------------------------- +set_option_category("Thirdparty") + +set(KNOWHERE_DEPENDENCY_SOURCE_DEFAULT "BUNDLED") + +define_option_string(KNOWHERE_DEPENDENCY_SOURCE + "Method to use for acquiring KNOWHERE's build dependencies" + "${KNOWHERE_DEPENDENCY_SOURCE_DEFAULT}" + "AUTO" + "BUNDLED" + "SYSTEM") + +define_option(KNOWHERE_USE_CCACHE "Use ccache when compiling (if available)" OFF) + +define_option(KNOWHERE_VERBOSE_THIRDPARTY_BUILD + "Show output from ExternalProjects rather than just logging to files" ON) + +define_option(KNOWHERE_BOOST_USE_SHARED "Rely on boost shared libraries where relevant" OFF) + +define_option(KNOWHERE_BOOST_VENDORED "Use vendored Boost instead of existing Boost. \ +Note that this requires linking Boost statically" OFF) + +define_option(KNOWHERE_BOOST_HEADER_ONLY "Use only BOOST headers" OFF) + +define_option(KNOWHERE_WITH_ARROW "Build with ARROW" OFF) + +define_option(KNOWHERE_WITH_OPENBLAS "Build with OpenBLAS library" ON) + +define_option(KNOWHERE_WITH_FAISS "Build with FAISS library" ON) + +define_option(KNOWHERE_WITH_FAISS_GPU_VERSION "Build with FAISS GPU version" OFF) + +define_option(FAISS_WITH_MKL "Build FAISS with MKL" OFF) + +define_option(MILVUS_CUDA_ARCH "Build with CUDA arch" "DEFAULT") + +#---------------------------------------------------------------------- +set_option_category("Test and benchmark") + +if (BUILD_UNIT_TEST) + define_option(KNOWHERE_BUILD_TESTS "Build the KNOWHERE googletest unit tests" ON) +else () + define_option(KNOWHERE_BUILD_TESTS "Build the KNOWHERE googletest unit tests" OFF) +endif (BUILD_UNIT_TEST) + +#---------------------------------------------------------------------- +macro(config_summary) + message(STATUS "---------------------------------------------------------------------") + message(STATUS "KNOWHERE version: ${KNOWHERE_VERSION}") + message(STATUS) + message(STATUS "Build configuration summary:") + + message(STATUS " Generator: ${CMAKE_GENERATOR}") + message(STATUS " Build type: ${CMAKE_BUILD_TYPE}") + message(STATUS " Source directory: ${CMAKE_CURRENT_SOURCE_DIR}") + if (${CMAKE_EXPORT_COMPILE_COMMANDS}) + message( + STATUS " Compile commands: ${CMAKE_CURRENT_BINARY_DIR}/compile_commands.json") + endif () + + foreach (category ${KNOWHERE_OPTION_CATEGORIES}) + + message(STATUS) + message(STATUS "${category} options:") + + set(option_names ${KNOWHERE_${category}_OPTION_NAMES}) + + set(max_value_length 0) + foreach (name ${option_names}) + string(LENGTH "\"${${name}}\"" value_length) + if (${max_value_length} LESS ${value_length}) + set(max_value_length ${value_length}) + endif () + endforeach () + + foreach (name ${option_names}) + if ("${${name}_OPTION_TYPE}" STREQUAL "string") + set(value "\"${${name}}\"") + else () + set(value "${${name}}") + endif () + + set(default ${${name}_OPTION_DEFAULT}) + set(description ${${name}_OPTION_DESCRIPTION}) + string(LENGTH ${description} description_length) + if (${description_length} LESS 70) + string( + SUBSTRING + " " + ${description_length} -1 description_padding) + else () + set(description_padding " + ") + endif () + + set(comment "[${name}]") + + if ("${value}" STREQUAL "${default}") + set(comment "[default] ${comment}") + endif () + + if (NOT ("${${name}_OPTION_ENUM}" STREQUAL "")) + set(comment "${comment} [${${name}_OPTION_ENUM}]") + endif () + + string( + SUBSTRING "${value} " + 0 ${max_value_length} value) + + message(STATUS " ${description} ${description_padding} ${value} ${comment}") + endforeach () + + endforeach () + +endmacro() diff --git a/core/src/index/cmake/FindArrow.cmake b/core/src/index/cmake/FindArrow.cmake new file mode 100644 index 0000000000..fdf7c1437f --- /dev/null +++ b/core/src/index/cmake/FindArrow.cmake @@ -0,0 +1,431 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# - Find Arrow (arrow/api.h, libarrow.a, libarrow.so) +# This module defines +# ARROW_FOUND, whether Arrow has been found +# ARROW_FULL_SO_VERSION, full shared object version of found Arrow "100.0.0" +# ARROW_IMPORT_LIB, path to libarrow's import library (Windows only) +# ARROW_INCLUDE_DIR, directory containing headers +# ARROW_LIBS, deprecated. Use ARROW_LIB_DIR instead +# ARROW_LIB_DIR, directory containing Arrow libraries +# ARROW_SHARED_IMP_LIB, deprecated. Use ARROW_IMPORT_LIB instead +# ARROW_SHARED_LIB, path to libarrow's shared library +# ARROW_SO_VERSION, shared object version of found Arrow such as "100" +# ARROW_STATIC_LIB, path to libarrow.a +# ARROW_VERSION, version of found Arrow +# ARROW_VERSION_MAJOR, major version of found Arrow +# ARROW_VERSION_MINOR, minor version of found Arrow +# ARROW_VERSION_PATCH, patch version of found Arrow + +include(FindPkgConfig) +include(FindPackageHandleStandardArgs) + +set(ARROW_SEARCH_LIB_PATH_SUFFIXES) +if(CMAKE_LIBRARY_ARCHITECTURE) + list(APPEND ARROW_SEARCH_LIB_PATH_SUFFIXES "lib/${CMAKE_LIBRARY_ARCHITECTURE}") +endif() +list(APPEND ARROW_SEARCH_LIB_PATH_SUFFIXES + "lib64" + "lib32" + "lib" + "bin") +set(ARROW_CONFIG_SUFFIXES + "_RELEASE" + "_RELWITHDEBINFO" + "_MINSIZEREL" + "_DEBUG" + "") +if(CMAKE_BUILD_TYPE) + string(TOUPPER ${CMAKE_BUILD_TYPE} ARROW_CONFIG_SUFFIX_PREFERRED) + set(ARROW_CONFIG_SUFFIX_PREFERRED "_${ARROW_CONFIG_SUFFIX_PREFERRED}") + list(INSERT ARROW_CONFIG_SUFFIXES 0 "${ARROW_CONFIG_SUFFIX_PREFERRED}") +endif() + +if(NOT DEFINED ARROW_MSVC_STATIC_LIB_SUFFIX) + if(MSVC) + set(ARROW_MSVC_STATIC_LIB_SUFFIX "_static") + else() + set(ARROW_MSVC_STATIC_LIB_SUFFIX "") + endif() +endif() + +# Internal function. +# +# Set shared library name for ${base_name} to ${output_variable}. +# +# Example: +# arrow_build_shared_library_name(ARROW_SHARED_LIBRARY_NAME arrow) +# # -> ARROW_SHARED_LIBRARY_NAME=libarrow.so on Linux +# # -> ARROW_SHARED_LIBRARY_NAME=libarrow.dylib on macOS +# # -> ARROW_SHARED_LIBRARY_NAME=arrow.dll with MSVC on Windows +# # -> ARROW_SHARED_LIBRARY_NAME=libarrow.dll with MinGW on Windows +function(arrow_build_shared_library_name output_variable base_name) + set(${output_variable} + "${CMAKE_SHARED_LIBRARY_PREFIX}${base_name}${CMAKE_SHARED_LIBRARY_SUFFIX}" + PARENT_SCOPE) +endfunction() + +# Internal function. +# +# Set import library name for ${base_name} to ${output_variable}. +# This is useful only for MSVC build. Import library is used only +# with MSVC build. +# +# Example: +# arrow_build_import_library_name(ARROW_IMPORT_LIBRARY_NAME arrow) +# # -> ARROW_IMPORT_LIBRARY_NAME=arrow on Linux (meaningless) +# # -> ARROW_IMPORT_LIBRARY_NAME=arrow on macOS (meaningless) +# # -> ARROW_IMPORT_LIBRARY_NAME=arrow.lib with MSVC on Windows +# # -> ARROW_IMPORT_LIBRARY_NAME=libarrow.dll.a with MinGW on Windows +function(arrow_build_import_library_name output_variable base_name) + set(${output_variable} + "${CMAKE_IMPORT_LIBRARY_PREFIX}${base_name}${CMAKE_IMPORT_LIBRARY_SUFFIX}" + PARENT_SCOPE) +endfunction() + +# Internal function. +# +# Set static library name for ${base_name} to ${output_variable}. +# +# Example: +# arrow_build_static_library_name(ARROW_STATIC_LIBRARY_NAME arrow) +# # -> ARROW_STATIC_LIBRARY_NAME=libarrow.a on Linux +# # -> ARROW_STATIC_LIBRARY_NAME=libarrow.a on macOS +# # -> ARROW_STATIC_LIBRARY_NAME=arrow.lib with MSVC on Windows +# # -> ARROW_STATIC_LIBRARY_NAME=libarrow.dll.a with MinGW on Windows +function(arrow_build_static_library_name output_variable base_name) + set( + ${output_variable} + "${CMAKE_STATIC_LIBRARY_PREFIX}${base_name}${ARROW_MSVC_STATIC_LIB_SUFFIX}${CMAKE_STATIC_LIBRARY_SUFFIX}" + PARENT_SCOPE) +endfunction() + +# Internal function. +# +# Set macro value for ${macro_name} in ${header_content} to ${output_variable}. +# +# Example: +# arrow_extract_macro_value(version_major +# "ARROW_VERSION_MAJOR" +# "#define ARROW_VERSION_MAJOR 1.0.0") +# # -> version_major=1.0.0 +function(arrow_extract_macro_value output_variable macro_name header_content) + string(REGEX MATCH "#define +${macro_name} +[^\r\n]+" macro_definition + "${header_content}") + string(REGEX + REPLACE "^#define +${macro_name} +(.+)$" "\\1" macro_value "${macro_definition}") + set(${output_variable} "${macro_value}" PARENT_SCOPE) +endfunction() + +# Internal macro only for arrow_find_package. +# +# Find package in HOME. +macro(arrow_find_package_home) + find_path(${prefix}_include_dir "${header_path}" + PATHS "${home}" + PATH_SUFFIXES "include" + NO_DEFAULT_PATH) + set(include_dir "${${prefix}_include_dir}") + set(${prefix}_INCLUDE_DIR "${include_dir}" PARENT_SCOPE) + + if(MSVC) + set(CMAKE_SHARED_LIBRARY_SUFFIXES_ORIGINAL ${CMAKE_FIND_LIBRARY_SUFFIXES}) + # .dll isn't found by find_library with MSVC because .dll isn't included in + # CMAKE_FIND_LIBRARY_SUFFIXES. + list(APPEND CMAKE_FIND_LIBRARY_SUFFIXES "${CMAKE_SHARED_LIBRARY_SUFFIX}") + endif() + find_library(${prefix}_shared_lib + NAMES "${shared_lib_name}" + PATHS "${home}" + PATH_SUFFIXES ${ARROW_SEARCH_LIB_PATH_SUFFIXES} + NO_DEFAULT_PATH) + if(MSVC) + set(CMAKE_SHARED_LIBRARY_SUFFIXES ${CMAKE_FIND_LIBRARY_SUFFIXES_ORIGINAL}) + endif() + set(shared_lib "${${prefix}_shared_lib}") + set(${prefix}_SHARED_LIB "${shared_lib}" PARENT_SCOPE) + if(shared_lib) + add_library(${target_shared} SHARED IMPORTED) + set_target_properties(${target_shared} PROPERTIES IMPORTED_LOCATION "${shared_lib}") + if(include_dir) + set_target_properties(${target_shared} + PROPERTIES INTERFACE_INCLUDE_DIRECTORIES "${include_dir}") + endif() + find_library(${prefix}_import_lib + NAMES "${import_lib_name}" + PATHS "${home}" + PATH_SUFFIXES ${ARROW_SEARCH_LIB_PATH_SUFFIXES} + NO_DEFAULT_PATH) + set(import_lib "${${prefix}_import_lib}") + set(${prefix}_IMPORT_LIB "${import_lib}" PARENT_SCOPE) + if(import_lib) + set_target_properties(${target_shared} PROPERTIES IMPORTED_IMPLIB "${import_lib}") + endif() + endif() + + find_library(${prefix}_static_lib + NAMES "${static_lib_name}" + PATHS "${home}" + PATH_SUFFIXES ${ARROW_SEARCH_LIB_PATH_SUFFIXES} + NO_DEFAULT_PATH) + set(static_lib "${${prefix}_static_lib}") + set(${prefix}_STATIC_LIB "${static_lib}" PARENT_SCOPE) + if(static_lib) + add_library(${target_static} STATIC IMPORTED) + set_target_properties(${target_static} PROPERTIES IMPORTED_LOCATION "${static_lib}") + if(include_dir) + set_target_properties(${target_static} + PROPERTIES INTERFACE_INCLUDE_DIRECTORIES "${include_dir}") + endif() + endif() +endmacro() + +# Internal macro only for arrow_find_package. +# +# Find package by CMake package configuration. +macro(arrow_find_package_cmake_package_configuration) + # ARROW-5575: We need to split target files for each component + if(TARGET ${target_shared} OR TARGET ${target_static}) + set(${cmake_package_name}_FOUND TRUE) + else() + find_package(${cmake_package_name} CONFIG) + endif() + if(${cmake_package_name}_FOUND) + set(${prefix}_USE_CMAKE_PACKAGE_CONFIG TRUE PARENT_SCOPE) + if(TARGET ${target_shared}) + foreach(suffix ${ARROW_CONFIG_SUFFIXES}) + get_target_property(shared_lib ${target_shared} IMPORTED_LOCATION${suffix}) + if(shared_lib) + # Remove shared library version: + # libarrow.so.100.0.0 -> libarrow.so + # Because ARROW_HOME and pkg-config approaches don't add + # shared library version. + string(REGEX + REPLACE "(${CMAKE_SHARED_LIBRARY_SUFFIX})[.0-9]+$" "\\1" shared_lib + "${shared_lib}") + set(${prefix}_SHARED_LIB "${shared_lib}" PARENT_SCOPE) + break() + endif() + endforeach() + endif() + if(TARGET ${target_static}) + foreach(suffix ${ARROW_CONFIG_SUFFIXES}) + get_target_property(static_lib ${target_static} IMPORTED_LOCATION${suffix}) + if(static_lib) + set(${prefix}_STATIC_LIB "${static_lib}" PARENT_SCOPE) + break() + endif() + endforeach() + endif() + endif() +endmacro() + +# Internal macro only for arrow_find_package. +# +# Find package by pkg-config. +macro(arrow_find_package_pkg_config) + pkg_check_modules(${prefix}_PC ${pkg_config_name}) + if(${prefix}_PC_FOUND) + set(${prefix}_USE_PKG_CONFIG TRUE PARENT_SCOPE) + + set(include_dir "${${prefix}_PC_INCLUDEDIR}") + set(lib_dir "${${prefix}_PC_LIBDIR}") + set(shared_lib_paths "${${prefix}_PC_LINK_LIBRARIES}") + # Use the first shared library path as the IMPORTED_LOCATION + # for ${target_shared}. This assumes that the first shared library + # path is the shared library path for this module. + list(GET shared_lib_paths 0 first_shared_lib_path) + # Use the rest shared library paths as the INTERFACE_LINK_LIBRARIES + # for ${target_shared}. This assumes that the rest shared library + # paths are dependency library paths for this module. + list(LENGTH shared_lib_paths n_shared_lib_paths) + if(n_shared_lib_paths LESS_EQUAL 1) + set(rest_shared_lib_paths) + else() + list(SUBLIST + shared_lib_paths + 1 + -1 + rest_shared_lib_paths) + endif() + + set(${prefix}_VERSION "${${prefix}_PC_VERSION}" PARENT_SCOPE) + set(${prefix}_INCLUDE_DIR "${include_dir}" PARENT_SCOPE) + set(${prefix}_SHARED_LIB "${first_shared_lib_path}" PARENT_SCOPE) + + add_library(${target_shared} SHARED IMPORTED) + set_target_properties(${target_shared} + PROPERTIES INTERFACE_INCLUDE_DIRECTORIES + "${include_dir}" + INTERFACE_LINK_LIBRARIES + "${rest_shared_lib_paths}" + IMPORTED_LOCATION + "${first_shared_lib_path}") + + find_library(${prefix}_static_lib + NAMES "${static_lib_name}" + PATHS "${lib_dir}" + NO_DEFAULT_PATH) + set(static_lib "${${prefix}_static_lib}") + set(${prefix}_STATIC_LIB "${static_lib}" PARENT_SCOPE) + if(static_lib) + add_library(${target_static} STATIC IMPORTED) + set_target_properties(${target_static} + PROPERTIES INTERFACE_INCLUDE_DIRECTORIES "${include_dir}" + IMPORTED_LOCATION "${static_lib}") + endif() + endif() +endmacro() + +function(arrow_find_package + prefix + home + base_name + header_path + cmake_package_name + pkg_config_name) + arrow_build_shared_library_name(shared_lib_name ${base_name}) + arrow_build_import_library_name(import_lib_name ${base_name}) + arrow_build_static_library_name(static_lib_name ${base_name}) + + set(target_shared ${base_name}_shared) + set(target_static ${base_name}_static) + + if(home) + arrow_find_package_home() + set(${prefix}_FIND_APPROACH "HOME: ${home}" PARENT_SCOPE) + else() + arrow_find_package_cmake_package_configuration() + if(${cmake_package_name}_FOUND) + set(${prefix}_FIND_APPROACH + "CMake package configuration: ${cmake_package_name}" + PARENT_SCOPE) + else() + arrow_find_package_pkg_config() + set(${prefix}_FIND_APPROACH "pkg-config: ${pkg_config_name}" PARENT_SCOPE) + endif() + endif() + + if(NOT include_dir) + if(TARGET ${target_shared}) + get_target_property(include_dir ${target_shared} INTERFACE_INCLUDE_DIRECTORIES) + elseif(TARGET ${target_static}) + get_target_property(include_dir ${target_static} INTERFACE_INCLUDE_DIRECTORIES) + endif() + endif() + if(include_dir) + set(${prefix}_INCLUDE_DIR "${include_dir}" PARENT_SCOPE) + endif() + + if(shared_lib) + get_filename_component(lib_dir "${shared_lib}" DIRECTORY) + elseif(static_lib) + get_filename_component(lib_dir "${static_lib}" DIRECTORY) + else() + set(lib_dir NOTFOUND) + endif() + set(${prefix}_LIB_DIR "${lib_dir}" PARENT_SCOPE) + # For backward compatibility + set(${prefix}_LIBS "${lib_dir}" PARENT_SCOPE) +endfunction() + +if(NOT "$ENV{ARROW_HOME}" STREQUAL "") + file(TO_CMAKE_PATH "$ENV{ARROW_HOME}" ARROW_HOME) +endif() +arrow_find_package(ARROW + "${ARROW_HOME}" + arrow + arrow/api.h + Arrow + arrow) + +if(ARROW_HOME) + if(ARROW_INCLUDE_DIR) + file(READ "${ARROW_INCLUDE_DIR}/arrow/util/config.h" ARROW_CONFIG_H_CONTENT) + arrow_extract_macro_value(ARROW_VERSION_MAJOR "ARROW_VERSION_MAJOR" + "${ARROW_CONFIG_H_CONTENT}") + arrow_extract_macro_value(ARROW_VERSION_MINOR "ARROW_VERSION_MINOR" + "${ARROW_CONFIG_H_CONTENT}") + arrow_extract_macro_value(ARROW_VERSION_PATCH "ARROW_VERSION_PATCH" + "${ARROW_CONFIG_H_CONTENT}") + if("${ARROW_VERSION_MAJOR}" STREQUAL "" + OR "${ARROW_VERSION_MINOR}" STREQUAL "" + OR "${ARROW_VERSION_PATCH}" STREQUAL "") + set(ARROW_VERSION "0.0.0") + else() + set(ARROW_VERSION + "${ARROW_VERSION_MAJOR}.${ARROW_VERSION_MINOR}.${ARROW_VERSION_PATCH}") + endif() + + arrow_extract_macro_value(ARROW_SO_VERSION_QUOTED "ARROW_SO_VERSION" + "${ARROW_CONFIG_H_CONTENT}") + string(REGEX REPLACE "^\"(.+)\"$" "\\1" ARROW_SO_VERSION "${ARROW_SO_VERSION_QUOTED}") + arrow_extract_macro_value(ARROW_FULL_SO_VERSION_QUOTED "ARROW_FULL_SO_VERSION" + "${ARROW_CONFIG_H_CONTENT}") + string(REGEX + REPLACE "^\"(.+)\"$" "\\1" ARROW_FULL_SO_VERSION + "${ARROW_FULL_SO_VERSION_QUOTED}") + endif() +else() + if(ARROW_USE_CMAKE_PACKAGE_CONFIG) + find_package(Arrow CONFIG) + elseif(ARROW_USE_PKG_CONFIG) + pkg_get_variable(ARROW_SO_VERSION arrow so_version) + pkg_get_variable(ARROW_FULL_SO_VERSION arrow full_so_version) + endif() +endif() + +set(ARROW_ABI_VERSION ${ARROW_SO_VERSION}) + +mark_as_advanced(ARROW_ABI_VERSION + ARROW_CONFIG_SUFFIXES + ARROW_FULL_SO_VERSION + ARROW_IMPORT_LIB + ARROW_INCLUDE_DIR + ARROW_LIBS + ARROW_LIB_DIR + ARROW_SEARCH_LIB_PATH_SUFFIXES + ARROW_SHARED_IMP_LIB + ARROW_SHARED_LIB + ARROW_SO_VERSION + ARROW_STATIC_LIB + ARROW_VERSION + ARROW_VERSION_MAJOR + ARROW_VERSION_MINOR + ARROW_VERSION_PATCH) + +find_package_handle_standard_args(Arrow REQUIRED_VARS + # The first required variable is shown + # in the found message. So this list is + # not sorted alphabetically. + ARROW_INCLUDE_DIR + ARROW_LIB_DIR + ARROW_FULL_SO_VERSION + ARROW_SO_VERSION + VERSION_VAR + ARROW_VERSION) +set(ARROW_FOUND ${Arrow_FOUND}) + +if(Arrow_FOUND AND NOT Arrow_FIND_QUIETLY) + message(STATUS "Arrow version: ${ARROW_VERSION} (${ARROW_FIND_APPROACH})") + message(STATUS "Arrow SO and ABI version: ${ARROW_SO_VERSION}") + message(STATUS "Arrow full SO version: ${ARROW_FULL_SO_VERSION}") + message(STATUS "Found the Arrow core shared library: ${ARROW_SHARED_LIB}") + message(STATUS "Found the Arrow core import library: ${ARROW_IMPORT_LIB}") + message(STATUS "Found the Arrow core static library: ${ARROW_STATIC_LIB}") +endif() diff --git a/core/src/index/cmake/FindOpenBLAS.cmake b/core/src/index/cmake/FindOpenBLAS.cmake new file mode 100644 index 0000000000..f8936889da --- /dev/null +++ b/core/src/index/cmake/FindOpenBLAS.cmake @@ -0,0 +1,93 @@ + +if (OpenBLAS_FOUND) # the git version propose a OpenBLASConfig.cmake + message(STATUS "OpenBLASConfig found") + set(OpenBLAS_INCLUDE_DIR ${OpenBLAS_INCLUDE_DIRS}) +else() + message("OpenBLASConfig not found") + unset(OpenBLAS_DIR CACHE) + set(OpenBLAS_INCLUDE_SEARCH_PATHS + /usr/local/openblas/include + /usr/include + /usr/include/openblas + /usr/include/openblas-base + /usr/local/include + /usr/local/include/openblas + /usr/local/include/openblas-base + /opt/OpenBLAS/include + /usr/local/opt/openblas/include + $ENV{OpenBLAS_HOME} + $ENV{OpenBLAS_HOME}/include + ) + + set(OpenBLAS_LIB_SEARCH_PATHS + /usr/local/openblas/lib + /lib/ + /lib/openblas-base + /lib64/ + /usr/lib + /usr/lib/openblas-base + /usr/lib64 + /usr/local/lib + /usr/local/lib64 + /usr/local/opt/openblas/lib + /opt/OpenBLAS/lib + $ENV{OpenBLAS} + $ENV{OpenBLAS}/lib + $ENV{OpenBLAS_HOME} + $ENV{OpenBLAS_HOME}/lib + ) + set(DEFAULT_OpenBLAS_LIB_PATH + /usr/local/openblas/lib + ${OPENBLAS_PREFIX}/lib) + + message("DEFAULT_OpenBLAS_LIB_PATH: ${DEFAULT_OpenBLAS_LIB_PATH}") + find_path(OpenBLAS_INCLUDE_DIR NAMES openblas_config.h lapacke.h PATHS ${OpenBLAS_INCLUDE_SEARCH_PATHS}) + find_library(OpenBLAS_LIB NAMES openblas PATHS ${DEFAULT_OpenBLAS_LIB_PATH} NO_DEFAULT_PATH) + find_library(OpenBLAS_LIB NAMES openblas PATHS ${OpenBLAS_LIB_SEARCH_PATHS}) + # mostly for debian + find_library(Lapacke_LIB NAMES lapacke PATHS ${DEFAULT_OpenBLAS_LIB_PATH} NO_DEFAULT_PATH) + find_library(Lapacke_LIB NAMES lapacke PATHS ${OpenBLAS_LIB_SEARCH_PATHS}) + + set(OpenBLAS_FOUND ON) + + # Check include files + if(NOT OpenBLAS_INCLUDE_DIR) + set(OpenBLAS_FOUND OFF) + message(STATUS "Could not find OpenBLAS include. Turning OpenBLAS_FOUND off") + else() + message(STATUS "find OpenBLAS include:${OpenBLAS_INCLUDE_DIR} ") + endif() + + # Check libraries + if(NOT OpenBLAS_LIB) + set(OpenBLAS_FOUND OFF) + message(STATUS "Could not find OpenBLAS lib. Turning OpenBLAS_FOUND off") + else() + message(STATUS "find OpenBLAS lib:${OpenBLAS_LIB} ") + endif() + + if (OpenBLAS_FOUND) + set(FOUND_OPENBLAS "true" PARENT_SCOPE) + set(OpenBLAS_LIBRARIES ${OpenBLAS_LIB}) + STRING(REGEX REPLACE "/libopenblas.so" "" OpenBLAS_LIB_DIR ${OpenBLAS_LIBRARIES}) + message(STATUS "find OpenBLAS libraries:${OpenBLAS_LIBRARIES} ") + if (Lapacke_LIB) + set(OpenBLAS_LIBRARIES ${OpenBLAS_LIBRARIES} ${Lapacke_LIB}) + endif() + if (NOT OpenBLAS_FIND_QUIETLY) + message(STATUS "Found OpenBLAS libraries: ${OpenBLAS_LIBRARIES}") + message(STATUS "Found OpenBLAS include: ${OpenBLAS_INCLUDE_DIR}") + endif() + else() + set(FOUND_OPENBLAS "false" PARENT_SCOPE) + if (OpenBLAS_FIND_REQUIRED) + message(FATAL_ERROR "Could not find OpenBLAS") + endif() + endif() +endif() + +mark_as_advanced( + OpenBLAS_INCLUDE_DIR + OpenBLAS_LIBRARIES + OpenBLAS_LIB_DIR +) diff --git a/core/src/index/cmake/ThirdPartyPackagesCore.cmake b/core/src/index/cmake/ThirdPartyPackagesCore.cmake new file mode 100644 index 0000000000..f8e769147f --- /dev/null +++ b/core/src/index/cmake/ThirdPartyPackagesCore.cmake @@ -0,0 +1,655 @@ +#------------------------------------------------------------------------------- +# Copyright (C) 2019-2020 Zilliz. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License +# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing permissions and limitations under the License. +#------------------------------------------------------------------------------- + +set(KNOWHERE_THIRDPARTY_DEPENDENCIES + Arrow + FAISS + GTest + OpenBLAS + MKL + ) + +message(STATUS "Using ${KNOWHERE_DEPENDENCY_SOURCE} approach to find dependencies") + +# For each dependency, set dependency source to global default, if unset +foreach (DEPENDENCY ${KNOWHERE_THIRDPARTY_DEPENDENCIES}) + if ("${${DEPENDENCY}_SOURCE}" STREQUAL "") + set(${DEPENDENCY}_SOURCE ${KNOWHERE_DEPENDENCY_SOURCE}) + endif () +endforeach () + +macro(build_dependency DEPENDENCY_NAME) + if ("${DEPENDENCY_NAME}" STREQUAL "Arrow") + build_arrow() + elseif ("${DEPENDENCY_NAME}" STREQUAL "GTest") + build_gtest() + elseif ("${DEPENDENCY_NAME}" STREQUAL "OpenBLAS") + build_openblas() + elseif ("${DEPENDENCY_NAME}" STREQUAL "FAISS") + build_faiss() + elseif ("${DEPENDENCY_NAME}" STREQUAL "MKL") + build_mkl() + else () + message(FATAL_ERROR "Unknown thirdparty dependency to build: ${DEPENDENCY_NAME}") + endif () +endmacro() + +macro(resolve_dependency DEPENDENCY_NAME) + if (${DEPENDENCY_NAME}_SOURCE STREQUAL "AUTO") + find_package(${DEPENDENCY_NAME} MODULE) + if (NOT ${${DEPENDENCY_NAME}_FOUND}) + build_dependency(${DEPENDENCY_NAME}) + endif () + elseif (${DEPENDENCY_NAME}_SOURCE STREQUAL "BUNDLED") + build_dependency(${DEPENDENCY_NAME}) + elseif (${DEPENDENCY_NAME}_SOURCE STREQUAL "SYSTEM") + find_package(${DEPENDENCY_NAME} REQUIRED) + endif () +endmacro() + +# ---------------------------------------------------------------------- +# Identify OS +if (UNIX) + if (APPLE) + set(CMAKE_OS_NAME "osx" CACHE STRING "Operating system name" FORCE) + else (APPLE) + ## Check for Debian GNU/Linux ________________ + find_file(DEBIAN_FOUND debian_version debconf.conf + PATHS /etc + ) + if (DEBIAN_FOUND) + set(CMAKE_OS_NAME "debian" CACHE STRING "Operating system name" FORCE) + endif (DEBIAN_FOUND) + ## Check for Fedora _________________________ + find_file(FEDORA_FOUND fedora-release + PATHS /etc + ) + if (FEDORA_FOUND) + set(CMAKE_OS_NAME "fedora" CACHE STRING "Operating system name" FORCE) + endif (FEDORA_FOUND) + ## Check for RedHat _________________________ + find_file(REDHAT_FOUND redhat-release inittab.RH + PATHS /etc + ) + if (REDHAT_FOUND) + set(CMAKE_OS_NAME "redhat" CACHE STRING "Operating system name" FORCE) + endif (REDHAT_FOUND) + ## Extra check for Ubuntu ____________________ + if (DEBIAN_FOUND) + ## At its core Ubuntu is a Debian system, with + ## a slightly altered configuration; hence from + ## a first superficial inspection a system will + ## be considered as Debian, which signifies an + ## extra check is required. + find_file(UBUNTU_EXTRA legal issue + PATHS /etc + ) + if (UBUNTU_EXTRA) + ## Scan contents of file + file(STRINGS ${UBUNTU_EXTRA} UBUNTU_FOUND + REGEX Ubuntu + ) + ## Check result of string search + if (UBUNTU_FOUND) + set(CMAKE_OS_NAME "ubuntu" CACHE STRING "Operating system name" FORCE) + set(DEBIAN_FOUND FALSE) + endif (UBUNTU_FOUND) + endif (UBUNTU_EXTRA) + endif (DEBIAN_FOUND) + endif (APPLE) +endif (UNIX) + + +# ---------------------------------------------------------------------- +# thirdparty directory +set(THIRDPARTY_DIR "${INDEX_SOURCE_DIR}/thirdparty") + +# ---------------------------------------------------------------------- +# ExternalProject options + +string(TOUPPER ${CMAKE_BUILD_TYPE} UPPERCASE_BUILD_TYPE) + +set(FAISS_FLAGS "-DELPP_THREAD_SAFE -fopenmp -Werror=return-type") +set(EP_CXX_FLAGS "${FAISS_FLAGS} ${CMAKE_CXX_FLAGS_${UPPERCASE_BUILD_TYPE}}") +set(EP_C_FLAGS "${FAISS_FLAGS} ${CMAKE_C_FLAGS_${UPPERCASE_BUILD_TYPE}}") + +if (NOT MSVC) + # Set -fPIC on all external projects + set(EP_CXX_FLAGS "${EP_CXX_FLAGS} -fPIC") + set(EP_C_FLAGS "${EP_C_FLAGS} -fPIC") +endif () + +# CC/CXX environment variables are captured on the first invocation of the +# builder (e.g make or ninja) instead of when CMake is invoked into to build +# directory. This leads to issues if the variables are exported in a subshell +# and the invocation of make/ninja is in distinct subshell without the same +# environment (CC/CXX). +set(EP_COMMON_TOOLCHAIN -DCMAKE_C_COMPILER=${CMAKE_C_COMPILER} + -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}) + +if (CMAKE_AR) + set(EP_COMMON_TOOLCHAIN ${EP_COMMON_TOOLCHAIN} -DCMAKE_AR=${CMAKE_AR}) +endif () + +if (CMAKE_RANLIB) + set(EP_COMMON_TOOLCHAIN ${EP_COMMON_TOOLCHAIN} -DCMAKE_RANLIB=${CMAKE_RANLIB}) +endif () + +# External projects are still able to override the following declarations. +# cmake command line will favor the last defined variable when a duplicate is +# encountered. This requires that `EP_COMMON_CMAKE_ARGS` is always the first +# argument. +set(EP_COMMON_CMAKE_ARGS + ${EP_COMMON_TOOLCHAIN} + -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} + -DCMAKE_C_FLAGS=${EP_C_FLAGS} + -DCMAKE_C_FLAGS_${UPPERCASE_BUILD_TYPE}=${EP_C_FLAGS} + -DCMAKE_CXX_FLAGS=${EP_CXX_FLAGS} + -DCMAKE_CXX_FLAGS_${UPPERCASE_BUILD_TYPE}=${EP_CXX_FLAGS}) + +if (NOT KNOWHERE_VERBOSE_THIRDPARTY_BUILD) + set(EP_LOG_OPTIONS LOG_CONFIGURE 1 LOG_BUILD 1 LOG_INSTALL 1 LOG_DOWNLOAD 1) +else () + set(EP_LOG_OPTIONS) +endif () + +# Ensure that a default make is set +if ("${MAKE}" STREQUAL "") + if (NOT MSVC) + find_program(MAKE make) + endif () +endif () + +set(MAKE_BUILD_ARGS "-j6") + + +# ---------------------------------------------------------------------- +# Find pthreads + +set(THREADS_PREFER_PTHREAD_FLAG ON) +find_package(Threads REQUIRED) + +# ---------------------------------------------------------------------- +# Versions and URLs for toolchain builds, which also can be used to configure +# offline builds + +# Read toolchain versions from cpp/thirdparty/versions.txt +file(STRINGS "${THIRDPARTY_DIR}/versions.txt" TOOLCHAIN_VERSIONS_TXT) +foreach (_VERSION_ENTRY ${TOOLCHAIN_VERSIONS_TXT}) + # Exclude comments + if (NOT _VERSION_ENTRY MATCHES "^[^#][A-Za-z0-9-_]+_VERSION=") + continue() + endif () + + string(REGEX MATCH "^[^=]*" _LIB_NAME ${_VERSION_ENTRY}) + string(REPLACE "${_LIB_NAME}=" "" _LIB_VERSION ${_VERSION_ENTRY}) + + # Skip blank or malformed lines + if (${_LIB_VERSION} STREQUAL "") + continue() + endif () + + # For debugging + #message(STATUS "${_LIB_NAME}: ${_LIB_VERSION}") + + set(${_LIB_NAME} "${_LIB_VERSION}") +endforeach () + +set(FAISS_SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/thirdparty/faiss) + +if (DEFINED ENV{KNOWHERE_ARROW_URL}) + set(ARROW_SOURCE_URL "$ENV{KNOWHERE_ARROW_URL}") +else () + set(ARROW_SOURCE_URL + "https://github.com/apache/arrow.git" + ) +endif () + +if (DEFINED ENV{KNOWHERE_GTEST_URL}) + set(GTEST_SOURCE_URL "$ENV{KNOWHERE_GTEST_URL}") +else () + set(GTEST_SOURCE_URL + "https://github.com/google/googletest/archive/release-${GTEST_VERSION}.tar.gz") +endif () + +if (DEFINED ENV{KNOWHERE_OPENBLAS_URL}) + set(OPENBLAS_SOURCE_URL "$ENV{KNOWHERE_OPENBLAS_URL}") +else () + set(OPENBLAS_SOURCE_URL + "https://github.com/xianyi/OpenBLAS/archive/v${OPENBLAS_VERSION}.tar.gz") +endif () + +# ---------------------------------------------------------------------- +# ARROW +set(ARROW_PREFIX "${INDEX_BINARY_DIR}/arrow_ep-prefix/src/arrow_ep/cpp") + +macro(build_arrow) + message(STATUS "Building Apache ARROW-${ARROW_VERSION} from source") + set(ARROW_STATIC_LIB_NAME arrow) + set(ARROW_LIB_DIR "${ARROW_PREFIX}/lib") + set(ARROW_STATIC_LIB + "${ARROW_LIB_DIR}/${CMAKE_STATIC_LIBRARY_PREFIX}${ARROW_STATIC_LIB_NAME}${CMAKE_STATIC_LIBRARY_SUFFIX}" + ) + set(ARROW_INCLUDE_DIR "${ARROW_PREFIX}/include") + + set(ARROW_CMAKE_ARGS + ${EP_COMMON_CMAKE_ARGS} + -DARROW_BUILD_STATIC=ON + -DARROW_BUILD_SHARED=OFF + -DARROW_USE_GLOG=OFF + -DCMAKE_INSTALL_PREFIX=${ARROW_PREFIX} + -DCMAKE_INSTALL_LIBDIR=${ARROW_LIB_DIR} + -DARROW_CUDA=OFF + -DARROW_FLIGHT=OFF + -DARROW_GANDIVA=OFF + -DARROW_GANDIVA_JAVA=OFF + -DARROW_HDFS=OFF + -DARROW_HIVESERVER2=OFF + -DARROW_ORC=OFF + -DARROW_PARQUET=OFF + -DARROW_PLASMA=OFF + -DARROW_PLASMA_JAVA_CLIENT=OFF + -DARROW_PYTHON=OFF + -DARROW_WITH_BZ2=OFF + -DARROW_WITH_ZLIB=OFF + -DARROW_WITH_LZ4=OFF + -DARROW_WITH_SNAPPY=OFF + -DARROW_WITH_ZSTD=OFF + -DARROW_WITH_BROTLI=OFF + -DCMAKE_BUILD_TYPE=Release + -DARROW_DEPENDENCY_SOURCE=BUNDLED #Build all arrow dependencies from source instead of calling find_package first + -DBOOST_SOURCE=AUTO #try to find BOOST in the system default locations and build from source if not found + ) + + externalproject_add(arrow_ep + GIT_REPOSITORY + ${ARROW_SOURCE_URL} + GIT_TAG + ${ARROW_VERSION} + GIT_SHALLOW + TRUE + SOURCE_SUBDIR + cpp + ${EP_LOG_OPTIONS} + CMAKE_ARGS + ${ARROW_CMAKE_ARGS} + BUILD_COMMAND + "" + INSTALL_COMMAND + ${MAKE} ${MAKE_BUILD_ARGS} install + BUILD_BYPRODUCTS + "${ARROW_STATIC_LIB}" + ) + + file(MAKE_DIRECTORY "${ARROW_INCLUDE_DIR}") + add_library(arrow STATIC IMPORTED) + set_target_properties(arrow + PROPERTIES IMPORTED_LOCATION "${ARROW_STATIC_LIB}" + INTERFACE_INCLUDE_DIRECTORIES "${ARROW_INCLUDE_DIR}") + add_dependencies(arrow arrow_ep) + + set(JEMALLOC_PREFIX "${INDEX_BINARY_DIR}/arrow_ep-prefix/src/arrow_ep-build/jemalloc_ep-prefix/src/jemalloc_ep") + + add_custom_command(TARGET arrow_ep POST_BUILD + COMMAND ${CMAKE_COMMAND} -E make_directory ${ARROW_LIB_DIR} + COMMAND ${CMAKE_COMMAND} -E copy ${JEMALLOC_PREFIX}/lib/libjemalloc_pic.a ${ARROW_LIB_DIR} + DEPENDS ${JEMALLOC_PREFIX}/lib/libjemalloc_pic.a) + +endmacro() + +if (KNOWHERE_WITH_ARROW AND NOT TARGET arrow_ep) + + resolve_dependency(Arrow) + + link_directories(SYSTEM ${ARROW_LIB_DIR}) + include_directories(SYSTEM ${ARROW_INCLUDE_DIR}) +endif () + +# ---------------------------------------------------------------------- +# OpenBLAS +set(OPENBLAS_PREFIX "${INDEX_BINARY_DIR}/openblas_ep-prefix/src/openblas_ep") +macro(build_openblas) + message(STATUS "Building OpenBLAS-${OPENBLAS_VERSION} from source") + set(OpenBLAS_INCLUDE_DIR "${OPENBLAS_PREFIX}/include") + set(OpenBLAS_LIB_DIR "${OPENBLAS_PREFIX}/lib") + set(OPENBLAS_SHARED_LIB + "${OPENBLAS_PREFIX}/lib/${CMAKE_SHARED_LIBRARY_PREFIX}openblas${CMAKE_SHARED_LIBRARY_SUFFIX}") + set(OPENBLAS_STATIC_LIB + "${OPENBLAS_PREFIX}/lib/${CMAKE_STATIC_LIBRARY_PREFIX}openblas${CMAKE_STATIC_LIBRARY_SUFFIX}") + set(OPENBLAS_CMAKE_ARGS + ${EP_COMMON_CMAKE_ARGS} + -DCMAKE_BUILD_TYPE=Release + -DBUILD_SHARED_LIBS=ON + -DBUILD_STATIC_LIBS=ON + -DTARGET=CORE2 + -DDYNAMIC_ARCH=1 + -DDYNAMIC_OLDER=1 + -DUSE_THREAD=0 + -DUSE_OPENMP=0 + -DFC=gfortran + -DCC=gcc + -DINTERFACE64=0 + -DNUM_THREADS=128 + -DNO_LAPACKE=1 + "-DVERSION=${OPENBLAS_VERSION}" + "-DCMAKE_INSTALL_PREFIX=${OPENBLAS_PREFIX}" + -DCMAKE_INSTALL_LIBDIR=lib) + + externalproject_add(openblas_ep + URL + ${OPENBLAS_SOURCE_URL} + ${EP_LOG_OPTIONS} + CMAKE_ARGS + ${OPENBLAS_CMAKE_ARGS} + BUILD_COMMAND + ${MAKE} + ${MAKE_BUILD_ARGS} + BUILD_IN_SOURCE + 1 + INSTALL_COMMAND + ${MAKE} + PREFIX=${OPENBLAS_PREFIX} + install + BUILD_BYPRODUCTS + ${OPENBLAS_SHARED_LIB} + ${OPENBLAS_STATIC_LIB}) + + file(MAKE_DIRECTORY "${OpenBLAS_INCLUDE_DIR}") + add_library(openblas SHARED IMPORTED) + set_target_properties( + openblas + PROPERTIES + IMPORTED_LOCATION "${OPENBLAS_SHARED_LIB}" + LIBRARY_OUTPUT_NAME "openblas" + INTERFACE_INCLUDE_DIRECTORIES "${OpenBLAS_INCLUDE_DIR}") + add_dependencies(openblas openblas_ep) + get_target_property(OpenBLAS_INCLUDE_DIR openblas INTERFACE_INCLUDE_DIRECTORIES) + set(OpenBLAS_LIBRARIES "${OPENBLAS_SHARED_LIB}") +endmacro() + +if (KNOWHERE_WITH_OPENBLAS) + resolve_dependency(OpenBLAS) + include_directories(SYSTEM "${OpenBLAS_INCLUDE_DIR}") + link_directories(SYSTEM "${OpenBLAS_LIB_DIR}") +endif() + +# ---------------------------------------------------------------------- +# Google gtest + +macro(build_gtest) + message(STATUS "Building gtest-${GTEST_VERSION} from source") + set(GTEST_VENDORED TRUE) + set(GTEST_CMAKE_CXX_FLAGS "${EP_CXX_FLAGS}") + + if (APPLE) + set(GTEST_CMAKE_CXX_FLAGS + ${GTEST_CMAKE_CXX_FLAGS} + -DGTEST_USE_OWN_TR1_TUPLE=1 + -Wno-unused-value + -Wno-ignored-attributes) + endif () + + set(GTEST_PREFIX "${INDEX_BINARY_DIR}/googletest_ep-prefix/src/googletest_ep") + set(GTEST_INCLUDE_DIR "${GTEST_PREFIX}/include") + set(GTEST_STATIC_LIB + "${GTEST_PREFIX}/lib/${CMAKE_STATIC_LIBRARY_PREFIX}gtest${CMAKE_STATIC_LIBRARY_SUFFIX}") + set(GTEST_MAIN_STATIC_LIB + "${GTEST_PREFIX}/lib/${CMAKE_STATIC_LIBRARY_PREFIX}gtest_main${CMAKE_STATIC_LIBRARY_SUFFIX}") + + set(GTEST_CMAKE_ARGS + ${EP_COMMON_CMAKE_ARGS} + "-DCMAKE_INSTALL_PREFIX=${GTEST_PREFIX}" + "-DCMAKE_INSTALL_LIBDIR=lib" + -DCMAKE_CXX_FLAGS=${GTEST_CMAKE_CXX_FLAGS} + -DCMAKE_BUILD_TYPE=Release) + + set(GMOCK_INCLUDE_DIR "${GTEST_PREFIX}/include") + set(GMOCK_STATIC_LIB + "${GTEST_PREFIX}/lib/${CMAKE_STATIC_LIBRARY_PREFIX}gmock${CMAKE_STATIC_LIBRARY_SUFFIX}" + ) + + ExternalProject_Add(googletest_ep + URL + ${GTEST_SOURCE_URL} + BUILD_COMMAND + ${MAKE} + ${MAKE_BUILD_ARGS} + BUILD_BYPRODUCTS + ${GTEST_STATIC_LIB} + ${GTEST_MAIN_STATIC_LIB} + ${GMOCK_STATIC_LIB} + CMAKE_ARGS + ${GTEST_CMAKE_ARGS} + ${EP_LOG_OPTIONS}) + + # The include directory must exist before it is referenced by a target. + file(MAKE_DIRECTORY "${GTEST_INCLUDE_DIR}") + + add_library(gtest STATIC IMPORTED) + set_target_properties(gtest + PROPERTIES IMPORTED_LOCATION "${GTEST_STATIC_LIB}" + INTERFACE_INCLUDE_DIRECTORIES "${GTEST_INCLUDE_DIR}") + + add_library(gtest_main STATIC IMPORTED) + set_target_properties(gtest_main + PROPERTIES IMPORTED_LOCATION "${GTEST_MAIN_STATIC_LIB}" + INTERFACE_INCLUDE_DIRECTORIES "${GTEST_INCLUDE_DIR}") + + add_library(gmock STATIC IMPORTED) + set_target_properties(gmock + PROPERTIES IMPORTED_LOCATION "${GMOCK_STATIC_LIB}" + INTERFACE_INCLUDE_DIRECTORIES "${GTEST_INCLUDE_DIR}") + + add_dependencies(gtest googletest_ep) + add_dependencies(gtest_main googletest_ep) + add_dependencies(gmock googletest_ep) + +endmacro() + +# if (KNOWHERE_BUILD_TESTS AND NOT TARGET googletest_ep) +if ( NOT TARGET gtest AND KNOWHERE_BUILD_TESTS ) + resolve_dependency(GTest) + + if (NOT GTEST_VENDORED) + endif () + + # TODO: Don't use global includes but rather target_include_directories + get_target_property(GTEST_INCLUDE_DIR gtest INTERFACE_INCLUDE_DIRECTORIES) + link_directories(SYSTEM "${GTEST_PREFIX}/lib") + include_directories(SYSTEM ${GTEST_INCLUDE_DIR}) +endif () + +# ---------------------------------------------------------------------- +# MKL + +macro(build_mkl) + + if (FAISS_WITH_MKL) + if (EXISTS "/proc/cpuinfo") + FILE(READ /proc/cpuinfo PROC_CPUINFO) + + SET(VENDOR_ID_RX "vendor_id[ \t]*:[ \t]*([a-zA-Z]+)\n") + STRING(REGEX MATCH "${VENDOR_ID_RX}" VENDOR_ID "${PROC_CPUINFO}") + STRING(REGEX REPLACE "${VENDOR_ID_RX}" "\\1" VENDOR_ID "${VENDOR_ID}") + + if (NOT ${VENDOR_ID} STREQUAL "GenuineIntel") + set(FAISS_WITH_MKL OFF) + endif () + endif () + + find_path(MKL_LIB_PATH + NAMES "libmkl_intel_ilp64.a" "libmkl_gnu_thread.a" "libmkl_core.a" + PATH_SUFFIXES "intel/compilers_and_libraries_${MKL_VERSION}/linux/mkl/lib/intel64/") + if (${MKL_LIB_PATH} STREQUAL "MKL_LIB_PATH-NOTFOUND") + message(FATAL_ERROR "Could not find MKL libraries") + endif () + message(STATUS "MKL lib path = ${MKL_LIB_PATH}") + + set(MKL_LIBS + ${MKL_LIB_PATH}/libmkl_intel_ilp64.a + ${MKL_LIB_PATH}/libmkl_gnu_thread.a + ${MKL_LIB_PATH}/libmkl_core.a + ) + endif () +endmacro() + +# ---------------------------------------------------------------------- +# FAISS + +macro(build_faiss) + message(STATUS "Building FAISS-${FAISS_VERSION} from source") + + set(FAISS_PREFIX "${INDEX_BINARY_DIR}/faiss_ep-prefix/src/faiss_ep") + set(FAISS_INCLUDE_DIR "${FAISS_PREFIX}/include") + set(FAISS_STATIC_LIB + "${FAISS_PREFIX}/lib/${CMAKE_STATIC_LIBRARY_PREFIX}faiss${CMAKE_STATIC_LIBRARY_SUFFIX}") + + if (CCACHE_FOUND) + set(FAISS_C_COMPILER "${CCACHE_FOUND} ${CMAKE_C_COMPILER}") + if (MILVUS_GPU_VERSION) + set(FAISS_CXX_COMPILER "${CMAKE_CXX_COMPILER}") + set(FAISS_CUDA_COMPILER "${CCACHE_FOUND} ${CMAKE_CUDA_COMPILER}") + else () + set(FAISS_CXX_COMPILER "${CCACHE_FOUND} ${CMAKE_CXX_COMPILER}") + endif() + else () + set(FAISS_C_COMPILER "${CMAKE_C_COMPILER}") + set(FAISS_CXX_COMPILER "${CMAKE_CXX_COMPILER}") + endif() + + set(FAISS_CONFIGURE_ARGS + "--prefix=${FAISS_PREFIX}" + "CC=${FAISS_C_COMPILER}" + "CXX=${FAISS_CXX_COMPILER}" + "NVCC=${FAISS_CUDA_COMPILER}" + "CFLAGS=${EP_C_FLAGS}" + "CXXFLAGS=${EP_CXX_FLAGS} -mf16c -O3" + --without-python) + + if (FAISS_WITH_MKL) + set(FAISS_CONFIGURE_ARGS ${FAISS_CONFIGURE_ARGS} + "CPPFLAGS=-DFINTEGER=long -DMKL_ILP64 -m64 -I${MKL_LIB_PATH}/../../include" + "LDFLAGS=-L${MKL_LIB_PATH}" + ) + else () + message(STATUS "Build Faiss with OpenBlas/LAPACK") + if(OpenBLAS_FOUND) + set(FAISS_CONFIGURE_ARGS ${FAISS_CONFIGURE_ARGS} + "LDFLAGS=-L${OpenBLAS_LIB_DIR}") + else() + set(FAISS_CONFIGURE_ARGS ${FAISS_CONFIGURE_ARGS} + "LDFLAGS=-L${OPENBLAS_PREFIX}/lib") + endif() + endif () + + if (MILVUS_GPU_VERSION) + if (NOT MILVUS_CUDA_ARCH OR MILVUS_CUDA_ARCH STREQUAL "DEFAULT") + set(FAISS_CONFIGURE_ARGS ${FAISS_CONFIGURE_ARGS} + "--with-cuda=${CUDA_TOOLKIT_ROOT_DIR}" + "--with-cuda-arch=-gencode=arch=compute_60,code=sm_60 -gencode=arch=compute_61,code=sm_61 -gencode=arch=compute_70,code=sm_70 -gencode=arch=compute_75,code=sm_75" + ) + else() + STRING(REPLACE ";" " " MILVUS_CUDA_ARCH "${MILVUS_CUDA_ARCH}") + set(FAISS_CONFIGURE_ARGS ${FAISS_CONFIGURE_ARGS} + "--with-cuda=${CUDA_TOOLKIT_ROOT_DIR}" + "--with-cuda-arch=${MILVUS_CUDA_ARCH}" + ) + endif () + else () + set(FAISS_CONFIGURE_ARGS ${FAISS_CONFIGURE_ARGS} + "CPPFLAGS=-DUSE_CPU" + --without-cuda) + endif () + + message(STATUS "Building FAISS with configure args -${FAISS_CONFIGURE_ARGS}") + + if (DEFINED ENV{FAISS_SOURCE_URL}) + set(FAISS_SOURCE_URL "$ENV{FAISS_SOURCE_URL}") + externalproject_add(faiss_ep + URL + ${FAISS_SOURCE_URL} + ${EP_LOG_OPTIONS} + CONFIGURE_COMMAND + "./configure" + ${FAISS_CONFIGURE_ARGS} + BUILD_COMMAND + ${MAKE} ${MAKE_BUILD_ARGS} all + BUILD_IN_SOURCE + 1 + INSTALL_COMMAND + ${MAKE} install + BUILD_BYPRODUCTS + ${FAISS_STATIC_LIB}) + else () + externalproject_add(faiss_ep + DOWNLOAD_COMMAND + "" + SOURCE_DIR + ${FAISS_SOURCE_DIR} + ${EP_LOG_OPTIONS} + CONFIGURE_COMMAND + "./configure" + ${FAISS_CONFIGURE_ARGS} + BUILD_COMMAND + ${MAKE} ${MAKE_BUILD_ARGS} all + BUILD_IN_SOURCE + 1 + INSTALL_COMMAND + ${MAKE} install + BUILD_BYPRODUCTS + ${FAISS_STATIC_LIB}) + endif () + + if(NOT OpenBLAS_FOUND) + message("add faiss dependencies: openblas_ep") + ExternalProject_Add_StepDependencies(faiss_ep configure openblas_ep) + endif() + + file(MAKE_DIRECTORY "${FAISS_INCLUDE_DIR}") + add_library(faiss STATIC IMPORTED) + + set_target_properties( + faiss + PROPERTIES + IMPORTED_LOCATION "${FAISS_STATIC_LIB}" + INTERFACE_INCLUDE_DIRECTORIES "${FAISS_INCLUDE_DIR}" + ) + if (FAISS_WITH_MKL) + set_target_properties( + faiss + PROPERTIES + INTERFACE_LINK_LIBRARIES "${MKL_LIBS}") + else () + set_target_properties( + faiss + PROPERTIES + INTERFACE_LINK_LIBRARIES "${OpenBLAS_LIBRARIES}") + endif () + + add_dependencies(faiss faiss_ep) + +endmacro() + +if (KNOWHERE_WITH_FAISS AND NOT TARGET faiss_ep) + + if (FAISS_WITH_MKL) + resolve_dependency(MKL) + else () + message("faiss with no mkl") + endif () + + resolve_dependency(FAISS) + get_target_property(FAISS_INCLUDE_DIR faiss INTERFACE_INCLUDE_DIRECTORIES) + include_directories(SYSTEM "${FAISS_INCLUDE_DIR}") + link_directories(SYSTEM ${FAISS_PREFIX}/lib/) +endif () diff --git a/core/src/index/knowhere/CMakeLists.txt b/core/src/index/knowhere/CMakeLists.txt new file mode 100644 index 0000000000..d27614e31f --- /dev/null +++ b/core/src/index/knowhere/CMakeLists.txt @@ -0,0 +1,140 @@ +#------------------------------------------------------------------------------- +# Copyright (C) 2019-2020 Zilliz. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License +# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing permissions and limitations under the License. +#------------------------------------------------------------------------------- + +include_directories(${INDEX_SOURCE_DIR}/knowhere) +include_directories(${INDEX_SOURCE_DIR}/thirdparty) + +if (MILVUS_SUPPORT_SPTAG) + include_directories(${INDEX_SOURCE_DIR}/thirdparty/SPTAG/AnnService) + + set(SPTAG_SOURCE_DIR ${INDEX_SOURCE_DIR}/thirdparty/SPTAG) + file(GLOB HDR_FILES + ${SPTAG_SOURCE_DIR}/AnnService/inc/Core/*.h + ${SPTAG_SOURCE_DIR}/AnnService/inc/Core/Common/*.h + ${SPTAG_SOURCE_DIR}/AnnService/inc/Core/BKT/*.h + ${SPTAG_SOURCE_DIR}/AnnService/inc/Core/KDT/*.h + ${SPTAG_SOURCE_DIR}/AnnService/inc/Helper/*.h) + file(GLOB SRC_FILES + ${SPTAG_SOURCE_DIR}/AnnService/src/Core/*.cpp + ${SPTAG_SOURCE_DIR}/AnnService/src/Core/Common/*.cpp + ${SPTAG_SOURCE_DIR}/AnnService/src/Core/BKT/*.cpp + ${SPTAG_SOURCE_DIR}/AnnService/src/Core/KDT/*.cpp + ${SPTAG_SOURCE_DIR}/AnnService/src/Helper/*.cpp) + + if (NOT TARGET SPTAGLibStatic) + add_library(SPTAGLibStatic STATIC ${SRC_FILES} ${HDR_FILES}) + endif () +endif () + +set(external_srcs + knowhere/common/Exception.cpp + knowhere/common/Log.cpp + knowhere/common/Timer.cpp + ) + +set(vector_index_srcs + knowhere/index/vector_index/adapter/VectorAdapter.cpp + knowhere/index/vector_index/helpers/FaissIO.cpp + knowhere/index/vector_index/helpers/IndexParameter.cpp + knowhere/index/vector_index/impl/nsg/Distance.cpp + knowhere/index/vector_index/impl/nsg/NSG.cpp + knowhere/index/vector_index/impl/nsg/NSGHelper.cpp + knowhere/index/vector_index/impl/nsg/NSGIO.cpp + knowhere/index/vector_index/ConfAdapter.cpp + knowhere/index/vector_index/ConfAdapterMgr.cpp + knowhere/index/vector_index/FaissBaseBinaryIndex.cpp + knowhere/index/vector_index/FaissBaseIndex.cpp + knowhere/index/vector_index/IndexBinaryIDMAP.cpp + knowhere/index/vector_index/IndexBinaryIVF.cpp + knowhere/index/vector_index/IndexIDMAP.cpp + knowhere/index/vector_index/IndexIVF.cpp + knowhere/index/vector_index/IndexIVFPQ.cpp + knowhere/index/vector_index/IndexIVFSQ.cpp + knowhere/index/IndexType.cpp + knowhere/index/vector_index/VecIndexFactory.cpp + knowhere/index/vector_index/IndexAnnoy.cpp + knowhere/index/vector_index/IndexRHNSW.cpp + knowhere/index/vector_index/IndexHNSW.cpp + knowhere/index/vector_index/IndexRHNSWFlat.cpp + knowhere/index/vector_index/IndexRHNSWSQ.cpp + knowhere/index/vector_index/IndexRHNSWPQ.cpp + ) + +set(vector_offset_index_srcs + knowhere/index/vector_offset_index/OffsetBaseIndex.cpp + knowhere/index/vector_offset_index/IndexIVF_NM.cpp + knowhere/index/vector_offset_index/IndexNSG_NM.cpp + ) + +if (MILVUS_SUPPORT_SPTAG) + set(vector_index_srcs + knowhere/index/vector_index/adapter/SptagAdapter.cpp + knowhere/index/vector_index/helpers/SPTAGParameterMgr.cpp + knowhere/index/vector_index/IndexSPTAG.cpp + ${vector_index_srcs} + ) +endif () + +set(depend_libs + faiss + gomp + gfortran + pthread + ) + +if (MILVUS_SUPPORT_SPTAG) + set(depend_libs + SPTAGLibStatic + ${depend_libs} + ) +endif () + + +if (NOT TARGET knowhere) + add_library( + knowhere STATIC + ${external_srcs} + ${vector_index_srcs} + ${vector_offset_index_srcs} + ) +endif () + +target_link_libraries( + knowhere + ${depend_libs} +) + +set(INDEX_INCLUDE_DIRS + ${INDEX_SOURCE_DIR}/knowhere + ${INDEX_SOURCE_DIR}/thirdparty + ${FAISS_INCLUDE_DIR} + ${OpenBLAS_INCLUDE_DIR} + ${LAPACK_INCLUDE_DIR} + ) + +if (MILVUS_SUPPORT_SPTAG) + set(INDEX_INCLUDE_DIRS + ${INDEX_SOURCE_DIR}/thirdparty/SPTAG/AnnService + ${INDEX_INCLUDE_DIRS} + ) +endif () + +set(INDEX_INCLUDE_DIRS ${INDEX_INCLUDE_DIRS} PARENT_SCOPE) + +# **************************** Get&Print Include Directories **************************** +get_property( dirs DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} PROPERTY INCLUDE_DIRECTORIES ) + +foreach ( dir ${dirs} ) + message( STATUS "Knowhere Current Include DIRS: " ${dir} ) +endforeach () + diff --git a/core/src/index/knowhere/knowhere/common/BinarySet.h b/core/src/index/knowhere/knowhere/common/BinarySet.h new file mode 100644 index 0000000000..90930b1d49 --- /dev/null +++ b/core/src/index/knowhere/knowhere/common/BinarySet.h @@ -0,0 +1,87 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License + +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace milvus { +namespace knowhere { + +struct Binary { + std::shared_ptr data; + int64_t size = 0; +}; +using BinaryPtr = std::shared_ptr; + +inline uint8_t* +CopyBinary(const BinaryPtr& bin) { + uint8_t* newdata = new uint8_t[bin->size]; + memcpy(newdata, bin->data.get(), bin->size); + return newdata; +} + +class BinarySet { + public: + BinaryPtr + GetByName(const std::string& name) const { + return binary_map_.at(name); + } + + void + Append(const std::string& name, BinaryPtr binary) { + binary_map_[name] = std::move(binary); + } + + void + Append(const std::string& name, std::shared_ptr data, int64_t size) { + auto binary = std::make_shared(); + binary->data = data; + binary->size = size; + binary_map_[name] = std::move(binary); + } + + // void + // Append(const std::string &name, void *data, int64_t size, ID id) { + // Binary binary; + // binary.data = data; + // binary.size = size; + // binary.id = id; + // binary_map_[name] = binary; + //} + + BinaryPtr + Erase(const std::string& name) { + BinaryPtr result = nullptr; + auto it = binary_map_.find(name); + if (it != binary_map_.end()) { + result = it->second; + binary_map_.erase(it); + } + return result; + } + + void + clear() { + binary_map_.clear(); + } + + public: + std::map binary_map_; +}; + +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/common/Config.h b/core/src/index/knowhere/knowhere/common/Config.h new file mode 100644 index 0000000000..7e9161e1ae --- /dev/null +++ b/core/src/index/knowhere/knowhere/common/Config.h @@ -0,0 +1,22 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#pragma once + +#include "utils/Json.h" + +namespace milvus { +namespace knowhere { + +using Config = milvus::json; + +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/common/Dataset.h b/core/src/index/knowhere/knowhere/common/Dataset.h new file mode 100644 index 0000000000..2d742799f9 --- /dev/null +++ b/core/src/index/knowhere/knowhere/common/Dataset.h @@ -0,0 +1,61 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License + +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace milvus { +namespace knowhere { + +using Value = std::any; +using ValuePtr = std::shared_ptr; + +class Dataset { + public: + Dataset() = default; + + template + void + Set(const std::string& k, T&& v) { + std::lock_guard lk(mutex_); + data_[k] = std::make_shared(std::forward(v)); + } + + template + T + Get(const std::string& k) { + std::lock_guard lk(mutex_); + try { + return std::any_cast(*(data_.at(k))); + } catch (...) { + throw std::logic_error("Can't find this key"); + } + } + + const std::map& + data() const { + return data_; + } + + private: + std::mutex mutex_; + std::map data_; +}; +using DatasetPtr = std::shared_ptr; + +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/common/Exception.cpp b/core/src/index/knowhere/knowhere/common/Exception.cpp new file mode 100644 index 0000000000..7c379ab790 --- /dev/null +++ b/core/src/index/knowhere/knowhere/common/Exception.cpp @@ -0,0 +1,46 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License + +#include +#include + +#include "Log.h" +#include "knowhere/common/Exception.h" + +namespace milvus { +namespace knowhere { + +KnowhereException::KnowhereException(std::string msg) : msg_(std::move(msg)) { +} + +KnowhereException::KnowhereException(const std::string& m, const char* funcName, const char* file, int line) { + std::string filename; + try { + size_t pos; + std::string file_path(file); + pos = file_path.find_last_of('/'); + filename = file_path.substr(pos + 1); + } catch (std::exception& e) { + LOG_KNOWHERE_ERROR_ << e.what(); + } + + int size = snprintf(nullptr, 0, "Error in %s at %s:%d: %s", funcName, filename.c_str(), line, m.c_str()); + msg_.resize(size + 1); + snprintf(&msg_[0], msg_.size(), "Error in %s at %s:%d: %s", funcName, filename.c_str(), line, m.c_str()); +} + +const char* +KnowhereException::what() const noexcept { + return msg_.c_str(); +} + +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/common/Exception.h b/core/src/index/knowhere/knowhere/common/Exception.h new file mode 100644 index 0000000000..709c3f6ec7 --- /dev/null +++ b/core/src/index/knowhere/knowhere/common/Exception.h @@ -0,0 +1,49 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License + +#pragma once + +#include +#include + +namespace milvus { +namespace knowhere { + +class KnowhereException : public std::exception { + public: + explicit KnowhereException(std::string msg); + + KnowhereException(const std::string& msg, const char* funName, const char* file, int line); + + const char* + what() const noexcept override; + + std::string msg_; +}; + +#define KNOHWERE_ERROR_MSG(MSG) printf("%s", KnowhereException(MSG, __PRETTY_FUNCTION__, __FILE__, __LINE__).what()) + +#define KNOWHERE_THROW_MSG(MSG) \ + do { \ + throw KnowhereException(MSG, __PRETTY_FUNCTION__, __FILE__, __LINE__); \ + } while (false) + +#define KNOHERE_THROW_FORMAT(FMT, ...) \ + do { \ + std::string __s; \ + int __size = snprintf(nullptr, 0, FMT, __VA_ARGS__); \ + __s.resize(__size + 1); \ + snprintf(&__s[0], __s.size(), FMT, __VA_ARGS__); \ + throw faiss::FaissException(__s, __PRETTY_FUNCTION__, __FILE__, __LINE__); \ + } while (false) + +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/common/Log.cpp b/core/src/index/knowhere/knowhere/common/Log.cpp new file mode 100644 index 0000000000..587f9fbcd1 --- /dev/null +++ b/core/src/index/knowhere/knowhere/common/Log.cpp @@ -0,0 +1,85 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include "knowhere/common/Log.h" + +#include +#include +#include +#include + +namespace milvus { +namespace knowhere { + +std::string +LogOut(const char* pattern, ...) { + size_t len = strnlen(pattern, 1024) + 256; + auto str_p = std::make_unique(len); + memset(str_p.get(), 0, len); + + va_list vl; + va_start(vl, pattern); + vsnprintf(str_p.get(), len, pattern, vl); // NOLINT + va_end(vl); + + return std::string(str_p.get()); +} + +void +SetThreadName(const std::string& name) { + pthread_setname_np(pthread_self(), name.c_str()); +} + +std::string +GetThreadName() { + std::string thread_name = "unamed"; + char name[16]; + size_t len = 16; + auto err = pthread_getname_np(pthread_self(), name, len); + if (not err) { + thread_name = name; + } + + return thread_name; +} + +void +log_trace_(const std::string& s) { + LOG_KNOWHERE_TRACE_ << s; +} + +void +log_debug_(const std::string& s) { + LOG_KNOWHERE_DEBUG_ << s; +} + +void +log_info_(const std::string& s) { + LOG_KNOWHERE_INFO_ << s; +} + +void +log_warning_(const std::string& s) { + LOG_KNOWHERE_WARNING_ << s; +} + +void +log_error_(const std::string& s) { + LOG_KNOWHERE_ERROR_ << s; +} + +void +log_fatal_(const std::string& s) { + LOG_KNOWHERE_FATAL_ << s; +} + +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/common/Log.h b/core/src/index/knowhere/knowhere/common/Log.h new file mode 100644 index 0000000000..db1677e320 --- /dev/null +++ b/core/src/index/knowhere/knowhere/common/Log.h @@ -0,0 +1,74 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#pragma once + +#include + +#include "easyloggingpp/easylogging++.h" + +namespace milvus { +namespace knowhere { + +std::string +LogOut(const char* pattern, ...); + +void +SetThreadName(const std::string& name); + +std::string +GetThreadName(); + +void +log_trace_(const std::string&); + +void +log_debug_(const std::string&); + +void +log_info_(const std::string&); + +void +log_warning_(const std::string&); + +void +log_error_(const std::string&); + +void +log_fatal_(const std::string&); + +/* + * Please use LOG_MODULE_LEVEL_C macro in member function of class + * and LOG_MODULE_LEVEL_ macro in other functions. + */ + +///////////////////////////////////////////////////////////////////////////////////////////////// +#define KNOWHERE_MODULE_NAME "KNOWHERE" +#define KNOWHERE_MODULE_CLASS_FUNCTION \ + LogOut("[%s][%s::%s][%s] ", KNOWHERE_MODULE_NAME, (typeid(*this).name()), __FUNCTION__, GetThreadName().c_str()) +#define KNOWHERE_MODULE_FUNCTION LogOut("[%s][%s][%s] ", KNOWHERE_MODULE_NAME, __FUNCTION__, GetThreadName().c_str()) + +#define LOG_KNOWHERE_TRACE_C LOG(TRACE) << KNOWHERE_MODULE_CLASS_FUNCTION +#define LOG_KNOWHERE_DEBUG_C LOG(DEBUG) << KNOWHERE_MODULE_CLASS_FUNCTION +#define LOG_KNOWHERE_INFO_C LOG(INFO) << KNOWHERE_MODULE_CLASS_FUNCTION +#define LOG_KNOWHERE_WARNING_C LOG(WARNING) << KNOWHERE_MODULE_CLASS_FUNCTION +#define LOG_KNOWHERE_ERROR_C LOG(ERROR) << KNOWHERE_MODULE_CLASS_FUNCTION +#define LOG_KNOWHERE_FATAL_C LOG(FATAL) << KNOWHERE_MODULE_CLASS_FUNCTION + +#define LOG_KNOWHERE_TRACE_ LOG(TRACE) << KNOWHERE_MODULE_FUNCTION +#define LOG_KNOWHERE_DEBUG_ LOG(DEBUG) << KNOWHERE_MODULE_FUNCTION +#define LOG_KNOWHERE_INFO_ LOG(INFO) << KNOWHERE_MODULE_FUNCTION +#define LOG_KNOWHERE_WARNING_ LOG(WARNING) << KNOWHERE_MODULE_FUNCTION +#define LOG_KNOWHERE_ERROR_ LOG(ERROR) << KNOWHERE_MODULE_FUNCTION +#define LOG_KNOWHERE_FATAL_ LOG(FATAL) << KNOWHERE_MODULE_FUNCTION + +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/common/Timer.cpp b/core/src/index/knowhere/knowhere/common/Timer.cpp new file mode 100644 index 0000000000..e78e39112a --- /dev/null +++ b/core/src/index/knowhere/knowhere/common/Timer.cpp @@ -0,0 +1,74 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License + +#include +#include + +#include "knowhere/common/Log.h" +#include "knowhere/common/Timer.h" + +namespace milvus { +namespace knowhere { + +TimeRecorder::TimeRecorder(std::string hdr, int64_t log_level) : header_(std::move(hdr)), log_level_(log_level) { + start_ = last_ = stdclock::now(); +} + +std::string +TimeRecorder::GetTimeSpanStr(double span) { + std::string str_sec = std::to_string(span * 0.000001) + ((span > 1000000) ? " seconds" : " second"); + std::string str_ms = std::to_string(span * 0.001) + " ms"; + + return str_sec + " [" + str_ms + "]"; +} + +void +TimeRecorder::PrintTimeRecord(const std::string& msg, double span) { + std::string str_log; + if (!header_.empty()) { + str_log += header_ + ": "; + } + str_log += msg; + str_log += " ("; + str_log += TimeRecorder::GetTimeSpanStr(span); + str_log += ")"; + + switch (log_level_) { + case 0: + std::cout << str_log << std::endl; + break; + default: + LOG_KNOWHERE_DEBUG_ << str_log; + break; + } +} + +double +TimeRecorder::RecordSection(const std::string& msg) { + stdclock::time_point curr = stdclock::now(); + double span = (std::chrono::duration(curr - last_)).count(); + last_ = curr; + + PrintTimeRecord(msg, span); + return span; +} + +double +TimeRecorder::ElapseFromBegin(const std::string& msg) { + stdclock::time_point curr = stdclock::now(); + double span = (std::chrono::duration(curr - start_)).count(); + + PrintTimeRecord(msg, span); + return span; +} + +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/common/Timer.h b/core/src/index/knowhere/knowhere/common/Timer.h new file mode 100644 index 0000000000..79a038535a --- /dev/null +++ b/core/src/index/knowhere/knowhere/common/Timer.h @@ -0,0 +1,49 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License + +#pragma once + +#include +#include + +namespace milvus { +namespace knowhere { + +class TimeRecorder { + using stdclock = std::chrono::high_resolution_clock; + + public: + // trace = 0, debug = 1, info = 2, warn = 3, error = 4, critical = 5 + explicit TimeRecorder(std::string hdr, int64_t log_level = 0); + virtual ~TimeRecorder() = default; + + double + RecordSection(const std::string& msg); + + double + ElapseFromBegin(const std::string& msg); + + static std::string + GetTimeSpanStr(double span); + + private: + void + PrintTimeRecord(const std::string& msg, double span); + + private: + std::string header_; + stdclock::time_point start_; + stdclock::time_point last_; + int64_t log_level_; +}; + +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/common/Typedef.h b/core/src/index/knowhere/knowhere/common/Typedef.h new file mode 100644 index 0000000000..8b43a8159d --- /dev/null +++ b/core/src/index/knowhere/knowhere/common/Typedef.h @@ -0,0 +1,29 @@ + +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License + +#pragma once + +#include +#include + +namespace milvus { +namespace knowhere { + +using MetricType = std::string; +// using IndexType = std::string; +using IDType = int64_t; +using FloatType = float; +using BinaryType = uint8_t; +using GraphType = std::vector>; + +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/Index.h b/core/src/index/knowhere/knowhere/index/Index.h new file mode 100644 index 0000000000..e58460034e --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/Index.h @@ -0,0 +1,35 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License + +#pragma once + +#include + +#include "cache/DataObj.h" +#include "knowhere/common/BinarySet.h" +#include "knowhere/common/Config.h" + +namespace milvus { +namespace knowhere { + +class Index : public milvus::cache::DataObj { + public: + virtual BinarySet + Serialize(const Config& config) = 0; + + virtual void + Load(const BinarySet&) = 0; +}; + +using IndexPtr = std::shared_ptr; + +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/IndexType.cpp b/core/src/index/knowhere/knowhere/index/IndexType.cpp new file mode 100644 index 0000000000..ef3391a816 --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/IndexType.cpp @@ -0,0 +1,43 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License + +#include + +#include "knowhere/common/Exception.h" +#include "knowhere/index/IndexType.h" + +namespace milvus { +namespace knowhere { + +/* used in 0.8.0 */ +namespace IndexEnum { +const char* INVALID = ""; +const char* INDEX_FAISS_IDMAP = "FLAT"; +const char* INDEX_FAISS_IVFFLAT = "IVF_FLAT"; +const char* INDEX_FAISS_IVFPQ = "IVF_PQ"; +const char* INDEX_FAISS_IVFSQ8 = "IVF_SQ8"; +const char* INDEX_FAISS_IVFSQ8H = "IVF_SQ8_HYBRID"; +const char* INDEX_FAISS_BIN_IDMAP = "BIN_FLAT"; +const char* INDEX_FAISS_BIN_IVFFLAT = "BIN_IVF_FLAT"; +const char* INDEX_NSG = "NSG"; +#ifdef MILVUS_SUPPORT_SPTAG +const char* INDEX_SPTAG_KDT_RNT = "SPTAG_KDT_RNT"; +const char* INDEX_SPTAG_BKT_RNT = "SPTAG_BKT_RNT"; +#endif +const char* INDEX_HNSW = "HNSW"; +const char* INDEX_RHNSWFlat = "RHNSW_FLAT"; +const char* INDEX_RHNSWPQ = "RHNSW_PQ"; +const char* INDEX_RHNSWSQ = "RHNSW_SQ"; +const char* INDEX_ANNOY = "ANNOY"; +} // namespace IndexEnum + +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/IndexType.h b/core/src/index/knowhere/knowhere/index/IndexType.h new file mode 100644 index 0000000000..41140a2e15 --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/IndexType.h @@ -0,0 +1,72 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License + +#pragma once + +#include + +namespace milvus { +namespace knowhere { + +/* used in 0.7.0 */ +enum class OldIndexType { + INVALID = 0, + FAISS_IDMAP = 1, + FAISS_IVFFLAT_CPU, + FAISS_IVFFLAT_GPU, + FAISS_IVFFLAT_MIX, // build on gpu and search on cpu + FAISS_IVFPQ_CPU, + FAISS_IVFPQ_GPU, + SPTAG_KDT_RNT_CPU, + FAISS_IVFSQ8_MIX, + FAISS_IVFSQ8_CPU, + FAISS_IVFSQ8_GPU, + FAISS_IVFSQ8_HYBRID, // only support build on gpu. + NSG_MIX, + FAISS_IVFPQ_MIX, + SPTAG_BKT_RNT_CPU, + HNSW, + ANNOY, + RHNSW_FLAT, + RHNSW_PQ, + RHNSW_SQ, + FAISS_BIN_IDMAP = 100, + FAISS_BIN_IVFLAT_CPU = 101, +}; + +using IndexType = std::string; + +/* used in 0.8.0 */ +namespace IndexEnum { +extern const char* INVALID; +extern const char* INDEX_FAISS_IDMAP; +extern const char* INDEX_FAISS_IVFFLAT; +extern const char* INDEX_FAISS_IVFPQ; +extern const char* INDEX_FAISS_IVFSQ8; +extern const char* INDEX_FAISS_IVFSQ8H; +extern const char* INDEX_FAISS_BIN_IDMAP; +extern const char* INDEX_FAISS_BIN_IVFFLAT; +extern const char* INDEX_NSG; +#ifdef MILVUS_SUPPORT_SPTAG +extern const char* INDEX_SPTAG_KDT_RNT; +extern const char* INDEX_SPTAG_BKT_RNT; +#endif +extern const char* INDEX_HNSW; +extern const char* INDEX_RHNSWFlat; +extern const char* INDEX_RHNSWPQ; +extern const char* INDEX_RHNSWSQ; +extern const char* INDEX_ANNOY; +} // namespace IndexEnum + +enum class IndexMode { MODE_CPU = 0, MODE_GPU = 1 }; + +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/preprocessor/Preprocessor.h b/core/src/index/knowhere/knowhere/index/preprocessor/Preprocessor.h new file mode 100644 index 0000000000..6cda87122f --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/preprocessor/Preprocessor.h @@ -0,0 +1,30 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License + +#pragma once + +#include + +#include "knowhere/common/Dataset.h" + +namespace milvus { +namespace knowhere { + +class Preprocessor { + public: + virtual DatasetPtr + Preprocess(const DatasetPtr& input) = 0; +}; + +using PreprocessorPtr = std::shared_ptr; + +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/structured_index/StructuredIndex.h b/core/src/index/knowhere/knowhere/index/structured_index/StructuredIndex.h new file mode 100644 index 0000000000..6ad310ac43 --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/structured_index/StructuredIndex.h @@ -0,0 +1,86 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License + +#pragma once + +#include +#include +#include +#include "faiss/utils/ConcurrentBitset.h" +#include "knowhere/index/Index.h" + +namespace milvus { +namespace knowhere { + +enum OperatorType { LT = 0, LE = 1, GT = 3, GE = 4 }; + +static std::map s_map_operator_type = { + {"LT", OperatorType::LT}, + {"LE", OperatorType::LE}, + {"GT", OperatorType::GT}, + {"GE", OperatorType::GE}, +}; + +template +struct IndexStructure { + IndexStructure() : a_(0), idx_(0) { + } + explicit IndexStructure(const T a) : a_(a), idx_(0) { + } + IndexStructure(const T a, const size_t idx) : a_(a), idx_(idx) { + } + bool + operator<(const IndexStructure& b) const { + return a_ < b.a_; + } + bool + operator<=(const IndexStructure& b) const { + return a_ <= b.a_; + } + bool + operator>(const IndexStructure& b) const { + return a_ > b.a_; + } + bool + operator>=(const IndexStructure& b) const { + return a_ >= b.a_; + } + bool + operator==(const IndexStructure& b) const { + return a_ == b.a_; + } + T a_; + size_t idx_; +}; + +template +class StructuredIndex : public Index { + public: + virtual void + Build(const size_t n, const T* values) = 0; + + virtual const faiss::ConcurrentBitsetPtr + In(const size_t n, const T* values) = 0; + + virtual const faiss::ConcurrentBitsetPtr + NotIn(const size_t n, const T* values) = 0; + + virtual const faiss::ConcurrentBitsetPtr + Range(const T value, const OperatorType op) = 0; + + virtual const faiss::ConcurrentBitsetPtr + Range(const T lower_bound_value, bool lb_inclusive, const T upper_bound_value, bool ub_inclusive) = 0; +}; + +template +using StructuredIndexPtr = std::shared_ptr>; +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/structured_index/StructuredIndexFlat-inl.h b/core/src/index/knowhere/knowhere/index/structured_index/StructuredIndexFlat-inl.h new file mode 100644 index 0000000000..b13b3c2fd4 --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/structured_index/StructuredIndexFlat-inl.h @@ -0,0 +1,153 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License + +#include +#include +#include +#include +#include "knowhere/index/structured_index/StructuredIndexFlat.h" + +namespace milvus { +namespace knowhere { + +template +StructuredIndexFlat::StructuredIndexFlat() : is_built_(false), data_() { +} + +template +StructuredIndexFlat::StructuredIndexFlat(const size_t n, const T* values) : is_built_(false) { + Build(n, values); +} + +template +StructuredIndexFlat::~StructuredIndexFlat() { +} + +template +void +StructuredIndexFlat::Build(const size_t n, const T* values) { + data_.reserve(n); + T* p = const_cast(values); + for (size_t i = 0; i < n; ++i) { + data_.emplace_back(IndexStructure(*p++, i)); + } + is_built_ = true; +} + +template +const faiss::ConcurrentBitsetPtr +StructuredIndexFlat::In(const size_t n, const T* values) { + if (!is_built_) { + build(); + } + faiss::ConcurrentBitsetPtr bitset = std::make_shared(data_.size()); + for (size_t i = 0; i < n; ++i) { + for (const auto& index : data_) { + if (index->a_ == *(values + i)) { + bitset->set(index->idx_); + } + } + } + return bitset; +} + +template +const faiss::ConcurrentBitsetPtr +StructuredIndexFlat::NotIn(const size_t n, const T* values) { + if (!is_built_) { + build(); + } + faiss::ConcurrentBitsetPtr bitset = std::make_shared(data_.size(), 0xff); + for (size_t i = 0; i < n; ++i) { + for (const auto& index : data_) { + if (index->a_ == *(values + i)) { + bitset->clear(index->idx_); + } + } + } + return bitset; +} + +template +const faiss::ConcurrentBitsetPtr +StructuredIndexFlat::Range(const T value, const OperatorType op) { + if (!is_built_) { + build(); + } + faiss::ConcurrentBitsetPtr bitset = std::make_shared(data_.size()); + auto lb = data_.begin(); + auto ub = data_.end(); + for (; lb <= ub; lb++) { + switch (op) { + case OperatorType::LT: + if (lb < IndexStructure(value)) { + bitset->set(lb->idx_); + } + break; + case OperatorType::LE: + if (lb <= IndexStructure(value)) { + bitset->set(lb->idx_); + } + break; + case OperatorType::GT: + if (lb > IndexStructure(value)) { + bitset->set(lb->idx_); + } + break; + case OperatorType::GE: + if (lb >= IndexStructure(value)) { + bitset->set(lb->idx_); + } + break; + default: + KNOWHERE_THROW_MSG("Invalid OperatorType:" + std::to_string((int)op) + "!"); + } + } + return bitset; +} + +template +const faiss::ConcurrentBitsetPtr +StructuredIndexFlat::Range(T lower_bound_value, bool lb_inclusive, T upper_bound_value, bool ub_inclusive) { + if (!is_built_) { + build(); + } + faiss::ConcurrentBitsetPtr bitset = std::make_shared(data_.size()); + if (lower_bound_value > upper_bound_value) { + std::swap(lower_bound_value, upper_bound_value); + std::swap(lb_inclusive, ub_inclusive); + } + auto lb = data_.begin(); + auto ub = data_.end(); + for (; lb <= ub; ++lb) { + if (lb_inclusive && ub_inclusive) { + if (lb >= IndexStructure(lower_bound_value) && lb <= IndexStructure(upper_bound_value)) { + bitset->set(lb->idx_); + } + } else if (lb_inclusive && !ub_inclusive) { + if (lb >= IndexStructure(lower_bound_value) && lb < IndexStructure(upper_bound_value)) { + bitset->set(lb->idx_); + } + } else if (!lb_inclusive && ub_inclusive) { + if (lb > IndexStructure(lower_bound_value) && lb <= IndexStructure(upper_bound_value)) { + bitset->set(lb->idx_); + } + } else { + if (lb > IndexStructure(lower_bound_value) && lb < IndexStructure(upper_bound_value)) { + bitset->set(lb->idx_); + } + } + } + return bitset; +} + +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/structured_index/StructuredIndexFlat.h b/core/src/index/knowhere/knowhere/index/structured_index/StructuredIndexFlat.h new file mode 100644 index 0000000000..a929ae9fe0 --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/structured_index/StructuredIndexFlat.h @@ -0,0 +1,80 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License + +#pragma once + +#include +#include +#include +#include +#include "knowhere/common/Exception.h" +#include "knowhere/index/structured_index/StructuredIndex.h" + +namespace milvus { +namespace knowhere { + +template +class StructuredIndexFlat : public StructuredIndex { + public: + StructuredIndexFlat(); + StructuredIndexFlat(const size_t n, const T* values); + ~StructuredIndexFlat(); + + BinarySet + Serialize(const Config& config = Config()) override; + + void + Load(const BinarySet& index_binary) override; + + void + Build(const size_t n, const T* values) override; + + void + build(); + + const faiss::ConcurrentBitsetPtr + In(const size_t n, const T* values) override; + + const faiss::ConcurrentBitsetPtr + NotIn(const size_t n, const T* values) override; + + const faiss::ConcurrentBitsetPtr + Range(const T value, const OperatorType op) override; + + const faiss::ConcurrentBitsetPtr + Range(T lower_bound_value, bool lb_inclusive, T upper_bound_value, bool ub_inclusive) override; + + const std::vector>& + GetData() { + return data_; + } + + int64_t + Size() override { + return (int64_t)data_.size(); + } + + bool + IsBuilt() const { + return is_built_; + } + + private: + bool is_built_; + std::vector> data_; +}; + +template +using StructuredIndexFlatPtr = std::shared_ptr>; +} // namespace knowhere +} // namespace milvus + +#include "knowhere/index/structured_index/StructuredIndexFlat-inl.h" diff --git a/core/src/index/knowhere/knowhere/index/structured_index/StructuredIndexSort-inl.h b/core/src/index/knowhere/knowhere/index/structured_index/StructuredIndexSort-inl.h new file mode 100644 index 0000000000..33b4e5e3c3 --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/structured_index/StructuredIndexSort-inl.h @@ -0,0 +1,199 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License + +#include +#include +#include +#include +#include "knowhere/index/structured_index/StructuredIndexSort.h" + +namespace milvus { +namespace knowhere { + +template +StructuredIndexSort::StructuredIndexSort() : is_built_(false), data_() { +} + +template +StructuredIndexSort::StructuredIndexSort(const size_t n, const T* values) : is_built_(false) { + StructuredIndexSort::Build(n, values); +} + +template +StructuredIndexSort::~StructuredIndexSort() { +} + +template +void +StructuredIndexSort::Build(const size_t n, const T* values) { + data_.reserve(n); + T* p = const_cast(values); + for (size_t i = 0; i < n; ++i) { + data_.emplace_back(IndexStructure(*p++, i)); + } + build(); +} + +template +void +StructuredIndexSort::build() { + if (is_built_) + return; + if (data_.size() == 0) { + // todo: throw an exception + KNOWHERE_THROW_MSG("StructuredIndexSort cannot build null values!"); + } + std::sort(data_.begin(), data_.end()); + is_built_ = true; +} + +template +BinarySet +StructuredIndexSort::Serialize(const milvus::knowhere::Config& config) { + if (!is_built_) { + build(); + } + + auto index_data_size = data_.size() * sizeof(IndexStructure); + std::shared_ptr index_data(new uint8_t[index_data_size]); + memcpy(index_data.get(), data_.data(), index_data_size); + + std::shared_ptr index_length(new uint8_t[sizeof(size_t)]); + auto index_size = data_.size(); + memcpy(index_length.get(), &index_size, sizeof(size_t)); + + BinarySet res_set; + res_set.Append("index_data", index_data, index_data_size); + res_set.Append("index_length", index_length, sizeof(size_t)); + return res_set; +} + +template +void +StructuredIndexSort::Load(const milvus::knowhere::BinarySet& index_binary) { + try { + size_t index_size; + auto index_length = index_binary.GetByName("index_length"); + memcpy(&index_size, index_length->data.get(), (size_t)index_length->size); + + auto index_data = index_binary.GetByName("index_data"); + data_.resize(index_size); + memcpy(data_.data(), index_data->data.get(), (size_t)index_data->size); + is_built_ = true; + } catch (...) { + KNOHWERE_ERROR_MSG("StructuredIndexSort Load failed!"); + } +} + +template +const faiss::ConcurrentBitsetPtr +StructuredIndexSort::In(const size_t n, const T* values) { + if (!is_built_) { + build(); + } + faiss::ConcurrentBitsetPtr bitset = std::make_shared(data_.size()); + for (size_t i = 0; i < n; ++i) { + auto lb = std::lower_bound(data_.begin(), data_.end(), IndexStructure(*(values + i))); + auto ub = std::upper_bound(data_.begin(), data_.end(), IndexStructure(*(values + i))); + for (; lb < ub; ++lb) { + if (lb->a_ != *(values + i)) { + LOG_KNOWHERE_ERROR_ << "error happens in StructuredIndexSort::In, experted value is: " + << *(values + i) << ", but real value is: " << lb->a_; + } + bitset->set(lb->idx_); + } + } + return bitset; +} + +template +const faiss::ConcurrentBitsetPtr +StructuredIndexSort::NotIn(const size_t n, const T* values) { + if (!is_built_) { + build(); + } + faiss::ConcurrentBitsetPtr bitset = std::make_shared(data_.size(), 0xff); + for (size_t i = 0; i < n; ++i) { + auto lb = std::lower_bound(data_.begin(), data_.end(), IndexStructure(*(values + i))); + auto ub = std::upper_bound(data_.begin(), data_.end(), IndexStructure(*(values + i))); + for (; lb < ub; ++lb) { + if (lb->a_ != *(values + i)) { + LOG_KNOWHERE_ERROR_ << "error happens in StructuredIndexSort::NotIn, experted value is: " + << *(values + i) << ", but real value is: " << lb->a_; + } + bitset->clear(lb->idx_); + } + } + return bitset; +} + +template +const faiss::ConcurrentBitsetPtr +StructuredIndexSort::Range(const T value, const OperatorType op) { + if (!is_built_) { + build(); + } + faiss::ConcurrentBitsetPtr bitset = std::make_shared(data_.size()); + auto lb = data_.begin(); + auto ub = data_.end(); + switch (op) { + case OperatorType::LT: + ub = std::lower_bound(data_.begin(), data_.end(), IndexStructure(value)); + break; + case OperatorType::LE: + ub = std::upper_bound(data_.begin(), data_.end(), IndexStructure(value)); + break; + case OperatorType::GT: + lb = std::upper_bound(data_.begin(), data_.end(), IndexStructure(value)); + break; + case OperatorType::GE: + lb = std::lower_bound(data_.begin(), data_.end(), IndexStructure(value)); + break; + default: + KNOWHERE_THROW_MSG("Invalid OperatorType:" + std::to_string((int)op) + "!"); + } + for (; lb < ub; ++lb) { + bitset->set(lb->idx_); + } + return bitset; +} + +template +const faiss::ConcurrentBitsetPtr +StructuredIndexSort::Range(T lower_bound_value, bool lb_inclusive, T upper_bound_value, bool ub_inclusive) { + if (!is_built_) { + build(); + } + faiss::ConcurrentBitsetPtr bitset = std::make_shared(data_.size()); + if (lower_bound_value > upper_bound_value) { + std::swap(lower_bound_value, upper_bound_value); + std::swap(lb_inclusive, ub_inclusive); + } + auto lb = data_.begin(); + auto ub = data_.end(); + if (lb_inclusive) { + lb = std::lower_bound(data_.begin(), data_.end(), IndexStructure(lower_bound_value)); + } else { + lb = std::upper_bound(data_.begin(), data_.end(), IndexStructure(lower_bound_value)); + } + if (ub_inclusive) { + ub = std::upper_bound(data_.begin(), data_.end(), IndexStructure(upper_bound_value)); + } else { + ub = std::lower_bound(data_.begin(), data_.end(), IndexStructure(upper_bound_value)); + } + for (; lb < ub; ++lb) { + bitset->set(lb->idx_); + } + return bitset; +} + +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/structured_index/StructuredIndexSort.h b/core/src/index/knowhere/knowhere/index/structured_index/StructuredIndexSort.h new file mode 100644 index 0000000000..35b73fe573 --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/structured_index/StructuredIndexSort.h @@ -0,0 +1,80 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License + +#pragma once + +#include +#include +#include +#include +#include "knowhere/common/Exception.h" +#include "knowhere/index/structured_index/StructuredIndex.h" + +namespace milvus { +namespace knowhere { + +template +class StructuredIndexSort : public StructuredIndex { + public: + StructuredIndexSort(); + StructuredIndexSort(const size_t n, const T* values); + ~StructuredIndexSort(); + + BinarySet + Serialize(const Config& config = Config()) override; + + void + Load(const BinarySet& index_binary) override; + + void + Build(const size_t n, const T* values) override; + + void + build(); + + const faiss::ConcurrentBitsetPtr + In(const size_t n, const T* values) override; + + const faiss::ConcurrentBitsetPtr + NotIn(const size_t n, const T* values) override; + + const faiss::ConcurrentBitsetPtr + Range(const T value, const OperatorType op) override; + + const faiss::ConcurrentBitsetPtr + Range(T lower_bound_value, bool lb_inclusive, T upper_bound_value, bool ub_inclusive) override; + + const std::vector>& + GetData() { + return data_; + } + + int64_t + Size() override { + return (int64_t)data_.size(); + } + + bool + IsBuilt() const { + return is_built_; + } + + private: + bool is_built_; + std::vector> data_; +}; + +template +using StructuredIndexSortPtr = std::shared_ptr>; +} // namespace knowhere +} // namespace milvus + +#include "knowhere/index/structured_index/StructuredIndexSort-inl.h" diff --git a/core/src/index/knowhere/knowhere/index/vector_index/ConfAdapter.cpp b/core/src/index/knowhere/knowhere/index/vector_index/ConfAdapter.cpp new file mode 100644 index 0000000000..ef13ab9da4 --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/vector_index/ConfAdapter.cpp @@ -0,0 +1,372 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include "knowhere/index/vector_index/ConfAdapter.h" +#include +#include +#include +#include +#include +#include "knowhere/index/vector_index/helpers/IndexParameter.h" + +#ifdef MILVUS_GPU_VERSION +#include "faiss/gpu/utils/DeviceUtils.h" +#endif + +namespace milvus { +namespace knowhere { + +static const int64_t MIN_NLIST = 1; +static const int64_t MAX_NLIST = 1LL << 20; +static const int64_t MIN_NPROBE = 1; +static const int64_t MAX_NPROBE = MAX_NLIST; +static const int64_t DEFAULT_MIN_DIM = 1; +static const int64_t DEFAULT_MAX_DIM = 32768; +static const int64_t DEFAULT_MIN_ROWS = 1; // minimum size for build index +static const int64_t DEFAULT_MAX_ROWS = 50000000; +static const std::vector METRICS{knowhere::Metric::L2, knowhere::Metric::IP}; + +#define CheckIntByRange(key, min, max) \ + if (!oricfg.contains(key) || !oricfg[key].is_number_integer() || oricfg[key].get() > max || \ + oricfg[key].get() < min) { \ + return false; \ + } + +#define CheckIntByValues(key, container) \ + if (!oricfg.contains(key) || !oricfg[key].is_number_integer()) { \ + return false; \ + } else { \ + auto finder = std::find(std::begin(container), std::end(container), oricfg[key].get()); \ + if (finder == std::end(container)) { \ + return false; \ + } \ + } + +#define CheckStrByValues(key, container) \ + if (!oricfg.contains(key) || !oricfg[key].is_string()) { \ + return false; \ + } else { \ + auto finder = std::find(std::begin(container), std::end(container), oricfg[key].get()); \ + if (finder == std::end(container)) { \ + return false; \ + } \ + } + +bool +ConfAdapter::CheckTrain(Config& oricfg, const IndexMode mode) { + CheckIntByRange(knowhere::meta::DIM, DEFAULT_MIN_DIM, DEFAULT_MAX_DIM); + CheckStrByValues(knowhere::Metric::TYPE, METRICS); + return true; +} + +bool +ConfAdapter::CheckSearch(Config& oricfg, const IndexType type, const IndexMode mode) { + const int64_t DEFAULT_MIN_K = 1; + const int64_t DEFAULT_MAX_K = 16384; + CheckIntByRange(knowhere::meta::TOPK, DEFAULT_MIN_K - 1, DEFAULT_MAX_K); + return true; +} + +int64_t +MatchNlist(int64_t size, int64_t nlist) { + const int64_t TYPICAL_COUNT = 1000000; + const int64_t PER_NLIST = 16384; + + if (nlist * TYPICAL_COUNT > size * PER_NLIST) { + // nlist is too large, adjust to a proper value + nlist = std::max(1L, size * PER_NLIST / TYPICAL_COUNT); + } + return nlist; +} + +bool +IVFConfAdapter::CheckTrain(Config& oricfg, const IndexMode mode) { + CheckIntByRange(knowhere::IndexParams::nlist, MIN_NLIST, MAX_NLIST); + CheckIntByRange(knowhere::meta::ROWS, DEFAULT_MIN_ROWS, DEFAULT_MAX_ROWS); + + // int64_t nlist = oricfg[knowhere::IndexParams::nlist]; + // CheckIntByRange(knowhere::meta::ROWS, nlist, DEFAULT_MAX_ROWS); + + // auto tune params + auto nq = oricfg[knowhere::meta::ROWS].get(); + auto nlist = oricfg[knowhere::IndexParams::nlist].get(); + oricfg[knowhere::IndexParams::nlist] = MatchNlist(nq, nlist); + + // Best Practice + // static int64_t MIN_POINTS_PER_CENTROID = 40; + // static int64_t MAX_POINTS_PER_CENTROID = 256; + // CheckIntByRange(knowhere::meta::ROWS, MIN_POINTS_PER_CENTROID * nlist, MAX_POINTS_PER_CENTROID * nlist); + + return ConfAdapter::CheckTrain(oricfg, mode); +} + +bool +IVFConfAdapter::CheckSearch(Config& oricfg, const IndexType type, const IndexMode mode) { + int64_t max_nprobe = MAX_NPROBE; +#ifdef MILVUS_GPU_VERSION + if (mode == IndexMode::MODE_GPU) { + max_nprobe = faiss::gpu::getMaxKSelection(); + } +#endif + CheckIntByRange(knowhere::IndexParams::nprobe, MIN_NPROBE, max_nprobe); + + return ConfAdapter::CheckSearch(oricfg, type, mode); +} + +bool +IVFSQConfAdapter::CheckTrain(Config& oricfg, const IndexMode mode) { + const int64_t DEFAULT_NBITS = 8; + oricfg[knowhere::IndexParams::nbits] = DEFAULT_NBITS; + + return IVFConfAdapter::CheckTrain(oricfg, mode); +} + +bool +IVFPQConfAdapter::CheckTrain(Config& oricfg, const IndexMode mode) { + const int64_t DEFAULT_NBITS = 8; + + oricfg[knowhere::IndexParams::nbits] = DEFAULT_NBITS; + + CheckStrByValues(knowhere::Metric::TYPE, METRICS); + CheckIntByRange(knowhere::meta::DIM, DEFAULT_MIN_DIM, DEFAULT_MAX_DIM); + CheckIntByRange(knowhere::meta::ROWS, DEFAULT_MIN_ROWS, DEFAULT_MAX_ROWS); + CheckIntByRange(knowhere::IndexParams::nlist, MIN_NLIST, MAX_NLIST); + + // int64_t nlist = oricfg[knowhere::IndexParams::nlist]; + // CheckIntByRange(knowhere::meta::ROWS, nlist, DEFAULT_MAX_ROWS); + + // auto tune params + oricfg[knowhere::IndexParams::nlist] = + MatchNlist(oricfg[knowhere::meta::ROWS].get(), oricfg[knowhere::IndexParams::nlist].get()); + + // Best Practice + // static int64_t MIN_POINTS_PER_CENTROID = 40; + // static int64_t MAX_POINTS_PER_CENTROID = 256; + // CheckIntByRange(knowhere::meta::ROWS, MIN_POINTS_PER_CENTROID * nlist, MAX_POINTS_PER_CENTROID * nlist); + + std::vector resset; + auto dimension = oricfg[knowhere::meta::DIM].get(); + IVFPQConfAdapter::GetValidMList(dimension, resset); + + CheckIntByValues(knowhere::IndexParams::m, resset); + + return true; +} + +void +IVFPQConfAdapter::GetValidMList(int64_t dimension, std::vector& resset) { + resset.clear(); + /* + * Faiss 1.6 + * Only 1, 2, 3, 4, 6, 8, 10, 12, 16, 20, 24, 28, 32 dims per sub-quantizer are currently supported with + * no precomputed codes. Precomputed codes supports any number of dimensions, but will involve memory overheads. + */ + static const std::vector support_dim_per_subquantizer{32, 28, 24, 20, 16, 12, 10, 8, 6, 4, 3, 2, 1}; + static const std::vector support_subquantizer{96, 64, 56, 48, 40, 32, 28, 24, 20, 16, 12, 8, 4, 3, 2, 1}; + + for (const auto& dimperquantizer : support_dim_per_subquantizer) { + if (!(dimension % dimperquantizer)) { + auto subquantzier_num = dimension / dimperquantizer; + auto finder = std::find(support_subquantizer.begin(), support_subquantizer.end(), subquantzier_num); + if (finder != support_subquantizer.end()) { + resset.push_back(subquantzier_num); + } + } + } +} + +bool +NSGConfAdapter::CheckTrain(Config& oricfg, const IndexMode mode) { + const int64_t MIN_KNNG = 5; + const int64_t MAX_KNNG = 300; + const int64_t MIN_SEARCH_LENGTH = 10; + const int64_t MAX_SEARCH_LENGTH = 300; + const int64_t MIN_OUT_DEGREE = 5; + const int64_t MAX_OUT_DEGREE = 300; + const int64_t MIN_CANDIDATE_POOL_SIZE = 50; + const int64_t MAX_CANDIDATE_POOL_SIZE = 1000; + + CheckStrByValues(knowhere::Metric::TYPE, METRICS); + CheckIntByRange(knowhere::meta::ROWS, DEFAULT_MIN_ROWS, DEFAULT_MAX_ROWS); + CheckIntByRange(knowhere::IndexParams::knng, MIN_KNNG, MAX_KNNG); + CheckIntByRange(knowhere::IndexParams::search_length, MIN_SEARCH_LENGTH, MAX_SEARCH_LENGTH); + CheckIntByRange(knowhere::IndexParams::out_degree, MIN_OUT_DEGREE, MAX_OUT_DEGREE); + CheckIntByRange(knowhere::IndexParams::candidate, MIN_CANDIDATE_POOL_SIZE, MAX_CANDIDATE_POOL_SIZE); + + // auto tune params + oricfg[knowhere::IndexParams::nlist] = MatchNlist(oricfg[knowhere::meta::ROWS].get(), 8192); + + int64_t nprobe = int(oricfg[knowhere::IndexParams::nlist].get() * 0.1); + oricfg[knowhere::IndexParams::nprobe] = nprobe < 1 ? 1 : nprobe; + + return true; +} + +bool +NSGConfAdapter::CheckSearch(Config& oricfg, const IndexType type, const IndexMode mode) { + static int64_t MIN_SEARCH_LENGTH = 1; + static int64_t MAX_SEARCH_LENGTH = 300; + + CheckIntByRange(knowhere::IndexParams::search_length, MIN_SEARCH_LENGTH, MAX_SEARCH_LENGTH); + + return ConfAdapter::CheckSearch(oricfg, type, mode); +} + +bool +HNSWConfAdapter::CheckTrain(Config& oricfg, const IndexMode mode) { + static int64_t MIN_EFCONSTRUCTION = 8; + static int64_t MAX_EFCONSTRUCTION = 512; + static int64_t MIN_M = 4; + static int64_t MAX_M = 64; + + CheckIntByRange(knowhere::meta::ROWS, DEFAULT_MIN_ROWS, DEFAULT_MAX_ROWS); + CheckIntByRange(knowhere::IndexParams::efConstruction, MIN_EFCONSTRUCTION, MAX_EFCONSTRUCTION); + CheckIntByRange(knowhere::IndexParams::M, MIN_M, MAX_M); + + return ConfAdapter::CheckTrain(oricfg, mode); +} + +bool +HNSWConfAdapter::CheckSearch(Config& oricfg, const IndexType type, const IndexMode mode) { + static int64_t MAX_EF = 4096; + + CheckIntByRange(knowhere::IndexParams::ef, oricfg[knowhere::meta::TOPK], MAX_EF); + + return ConfAdapter::CheckSearch(oricfg, type, mode); +} + +bool +RHNSWFlatConfAdapter::CheckTrain(Config& oricfg, const IndexMode mode) { + static int64_t MIN_EFCONSTRUCTION = 8; + static int64_t MAX_EFCONSTRUCTION = 512; + static int64_t MIN_M = 4; + static int64_t MAX_M = 64; + + CheckIntByRange(knowhere::meta::ROWS, DEFAULT_MIN_ROWS, DEFAULT_MAX_ROWS); + CheckIntByRange(knowhere::IndexParams::efConstruction, MIN_EFCONSTRUCTION, MAX_EFCONSTRUCTION); + CheckIntByRange(knowhere::IndexParams::M, MIN_M, MAX_M); + + return ConfAdapter::CheckTrain(oricfg, mode); +} + +bool +RHNSWFlatConfAdapter::CheckSearch(Config& oricfg, const IndexType type, const IndexMode mode) { + static int64_t MAX_EF = 4096; + + CheckIntByRange(knowhere::IndexParams::ef, oricfg[knowhere::meta::TOPK], MAX_EF); + + return ConfAdapter::CheckSearch(oricfg, type, mode); +} + +bool +RHNSWPQConfAdapter::CheckTrain(Config& oricfg, const IndexMode mode) { + static int64_t MIN_EFCONSTRUCTION = 8; + static int64_t MAX_EFCONSTRUCTION = 512; + static int64_t MIN_M = 4; + static int64_t MAX_M = 64; + + CheckIntByRange(knowhere::meta::ROWS, DEFAULT_MIN_ROWS, DEFAULT_MAX_ROWS); + CheckIntByRange(knowhere::IndexParams::efConstruction, MIN_EFCONSTRUCTION, MAX_EFCONSTRUCTION); + CheckIntByRange(knowhere::IndexParams::M, MIN_M, MAX_M); + + std::vector resset; + auto dimension = oricfg[knowhere::meta::DIM].get(); + IVFPQConfAdapter::GetValidMList(dimension, resset); + + CheckIntByValues(knowhere::IndexParams::PQM, resset); + return ConfAdapter::CheckTrain(oricfg, mode); +} + +bool +RHNSWPQConfAdapter::CheckSearch(Config& oricfg, const IndexType type, const IndexMode mode) { + static int64_t MAX_EF = 4096; + + CheckIntByRange(knowhere::IndexParams::ef, oricfg[knowhere::meta::TOPK], MAX_EF); + + return ConfAdapter::CheckSearch(oricfg, type, mode); +} + +bool +RHNSWSQConfAdapter::CheckTrain(Config& oricfg, const IndexMode mode) { + static int64_t MIN_EFCONSTRUCTION = 8; + static int64_t MAX_EFCONSTRUCTION = 512; + static int64_t MIN_M = 4; + static int64_t MAX_M = 64; + + CheckIntByRange(knowhere::meta::ROWS, DEFAULT_MIN_ROWS, DEFAULT_MAX_ROWS); + CheckIntByRange(knowhere::IndexParams::efConstruction, MIN_EFCONSTRUCTION, MAX_EFCONSTRUCTION); + CheckIntByRange(knowhere::IndexParams::M, MIN_M, MAX_M); + + return ConfAdapter::CheckTrain(oricfg, mode); +} + +bool +RHNSWSQConfAdapter::CheckSearch(Config& oricfg, const IndexType type, const IndexMode mode) { + static int64_t MAX_EF = 4096; + + CheckIntByRange(knowhere::IndexParams::ef, oricfg[knowhere::meta::TOPK], MAX_EF); + + return ConfAdapter::CheckSearch(oricfg, type, mode); +} + +bool +BinIDMAPConfAdapter::CheckTrain(Config& oricfg, const IndexMode mode) { + static const std::vector METRICS{knowhere::Metric::HAMMING, knowhere::Metric::JACCARD, + knowhere::Metric::TANIMOTO, knowhere::Metric::SUBSTRUCTURE, + knowhere::Metric::SUPERSTRUCTURE}; + + CheckIntByRange(knowhere::meta::DIM, DEFAULT_MIN_DIM, DEFAULT_MAX_DIM); + CheckStrByValues(knowhere::Metric::TYPE, METRICS); + + return true; +} + +bool +BinIVFConfAdapter::CheckTrain(Config& oricfg, const IndexMode mode) { + static const std::vector METRICS{knowhere::Metric::HAMMING, knowhere::Metric::JACCARD, + knowhere::Metric::TANIMOTO}; + + CheckIntByRange(knowhere::meta::ROWS, DEFAULT_MIN_ROWS, DEFAULT_MAX_ROWS); + CheckIntByRange(knowhere::meta::DIM, DEFAULT_MIN_DIM, DEFAULT_MAX_DIM); + CheckIntByRange(knowhere::IndexParams::nlist, MIN_NLIST, MAX_NLIST); + CheckStrByValues(knowhere::Metric::TYPE, METRICS); + + int64_t nlist = oricfg[knowhere::IndexParams::nlist]; + CheckIntByRange(knowhere::meta::ROWS, nlist, DEFAULT_MAX_ROWS); + + // Best Practice + // static int64_t MIN_POINTS_PER_CENTROID = 40; + // static int64_t MAX_POINTS_PER_CENTROID = 256; + // CheckIntByRange(knowhere::meta::ROWS, MIN_POINTS_PER_CENTROID * nlist, MAX_POINTS_PER_CENTROID * nlist); + + return true; +} + +bool +ANNOYConfAdapter::CheckTrain(Config& oricfg, const IndexMode mode) { + static int64_t MIN_NTREES = 1; + // too large of n_trees takes much time, if there is real requirement, change this threshold. + static int64_t MAX_NTREES = 1024; + + CheckIntByRange(knowhere::IndexParams::n_trees, MIN_NTREES, MAX_NTREES); + + return ConfAdapter::CheckTrain(oricfg, mode); +} + +bool +ANNOYConfAdapter::CheckSearch(Config& oricfg, const IndexType type, const IndexMode mode) { + CheckIntByRange(knowhere::IndexParams::search_k, std::numeric_limits::min(), + std::numeric_limits::max()); + return ConfAdapter::CheckSearch(oricfg, type, mode); +} + +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/vector_index/ConfAdapter.h b/core/src/index/knowhere/knowhere/index/vector_index/ConfAdapter.h new file mode 100644 index 0000000000..506d2a308f --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/vector_index/ConfAdapter.h @@ -0,0 +1,124 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#pragma once + +#include +#include +#include + +#include "knowhere/common/Config.h" +#include "knowhere/index/IndexType.h" + +namespace milvus { +namespace knowhere { + +class ConfAdapter { + public: + virtual bool + CheckTrain(Config& oricfg, const IndexMode mode); + + virtual bool + CheckSearch(Config& oricfg, const IndexType type, const IndexMode mode); +}; +using ConfAdapterPtr = std::shared_ptr; + +class IVFConfAdapter : public ConfAdapter { + public: + bool + CheckTrain(Config& oricfg, const IndexMode mode) override; + + bool + CheckSearch(Config& oricfg, const IndexType type, const IndexMode mode) override; +}; + +class IVFSQConfAdapter : public IVFConfAdapter { + public: + bool + CheckTrain(Config& oricfg, const IndexMode mode) override; +}; + +class IVFPQConfAdapter : public IVFConfAdapter { + public: + bool + CheckTrain(Config& oricfg, const IndexMode mode) override; + + static void + GetValidMList(int64_t dimension, std::vector& resset); +}; + +class NSGConfAdapter : public IVFConfAdapter { + public: + bool + CheckTrain(Config& oricfg, const IndexMode mode) override; + + bool + CheckSearch(Config& oricfg, const IndexType type, const IndexMode mode) override; +}; + +class BinIDMAPConfAdapter : public ConfAdapter { + public: + bool + CheckTrain(Config& oricfg, const IndexMode mode) override; +}; + +class BinIVFConfAdapter : public IVFConfAdapter { + public: + bool + CheckTrain(Config& oricfg, const IndexMode mode) override; +}; + +class HNSWConfAdapter : public ConfAdapter { + public: + bool + CheckTrain(Config& oricfg, const IndexMode mode) override; + + bool + CheckSearch(Config& oricfg, const IndexType type, const IndexMode mode) override; +}; + +class ANNOYConfAdapter : public ConfAdapter { + public: + bool + CheckTrain(Config& oricfg, const IndexMode mode) override; + + bool + CheckSearch(Config& oricfg, const IndexType type, const IndexMode mode) override; +}; + +class RHNSWFlatConfAdapter : public ConfAdapter { + public: + bool + CheckTrain(Config& oricfg, const IndexMode mode) override; + + bool + CheckSearch(Config& oricfg, const IndexType type, const IndexMode mode) override; +}; + +class RHNSWPQConfAdapter : public ConfAdapter { + public: + bool + CheckTrain(Config& oricfg, const IndexMode mode) override; + + bool + CheckSearch(Config& oricfg, const IndexType type, const IndexMode mode) override; +}; + +class RHNSWSQConfAdapter : public ConfAdapter { + public: + bool + CheckTrain(Config& oricfg, const IndexMode mode) override; + + bool + CheckSearch(Config& oricfg, const IndexType type, const IndexMode mode) override; +}; +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/vector_index/ConfAdapterMgr.cpp b/core/src/index/knowhere/knowhere/index/vector_index/ConfAdapterMgr.cpp new file mode 100644 index 0000000000..add0d6a665 --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/vector_index/ConfAdapterMgr.cpp @@ -0,0 +1,59 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include "knowhere/index/vector_index/ConfAdapterMgr.h" + +#include "knowhere/common/Exception.h" +#include "knowhere/common/Log.h" + +namespace milvus { +namespace knowhere { + +ConfAdapterPtr +AdapterMgr::GetAdapter(const IndexType type) { + if (!init_) { + RegisterAdapter(); + } + + try { + return collection_.at(type)(); + } catch (...) { + KNOWHERE_THROW_MSG("Can not find confadapter: " + type); + } +} + +#define REGISTER_CONF_ADAPTER(T, TYPE, NAME) static AdapterMgr::register_t reg_##NAME##_(TYPE) + +void +AdapterMgr::RegisterAdapter() { + init_ = true; + + REGISTER_CONF_ADAPTER(ConfAdapter, IndexEnum::INDEX_FAISS_IDMAP, idmap_adapter); + REGISTER_CONF_ADAPTER(IVFConfAdapter, IndexEnum::INDEX_FAISS_IVFFLAT, ivf_adapter); + REGISTER_CONF_ADAPTER(IVFPQConfAdapter, IndexEnum::INDEX_FAISS_IVFPQ, ivfpq_adapter); + REGISTER_CONF_ADAPTER(IVFSQConfAdapter, IndexEnum::INDEX_FAISS_IVFSQ8, ivfsq8_adapter); + REGISTER_CONF_ADAPTER(IVFSQConfAdapter, IndexEnum::INDEX_FAISS_IVFSQ8H, ivfsq8h_adapter); + REGISTER_CONF_ADAPTER(BinIDMAPConfAdapter, IndexEnum::INDEX_FAISS_BIN_IDMAP, idmap_bin_adapter); + REGISTER_CONF_ADAPTER(BinIDMAPConfAdapter, IndexEnum::INDEX_FAISS_BIN_IVFFLAT, ivf_bin_adapter); + REGISTER_CONF_ADAPTER(NSGConfAdapter, IndexEnum::INDEX_NSG, nsg_adapter); +#ifdef MILVUS_SUPPORT_SPTAG + REGISTER_CONF_ADAPTER(ConfAdapter, IndexEnum::INDEX_SPTAG_KDT_RNT, sptag_kdt_adapter); + REGISTER_CONF_ADAPTER(ConfAdapter, IndexEnum::INDEX_SPTAG_BKT_RNT, sptag_bkt_adapter); +#endif + REGISTER_CONF_ADAPTER(HNSWConfAdapter, IndexEnum::INDEX_HNSW, hnsw_adapter); + REGISTER_CONF_ADAPTER(ANNOYConfAdapter, IndexEnum::INDEX_ANNOY, annoy_adapter); + REGISTER_CONF_ADAPTER(RHNSWFlatConfAdapter, IndexEnum::INDEX_RHNSWFlat, rhnswflat_adapter); + REGISTER_CONF_ADAPTER(RHNSWPQConfAdapter, IndexEnum::INDEX_RHNSWPQ, rhnswpq_adapter); + REGISTER_CONF_ADAPTER(RHNSWSQConfAdapter, IndexEnum::INDEX_RHNSWSQ, rhnswsq_adapter); +} + +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/vector_index/ConfAdapterMgr.h b/core/src/index/knowhere/knowhere/index/vector_index/ConfAdapterMgr.h new file mode 100644 index 0000000000..83b9d0f584 --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/vector_index/ConfAdapterMgr.h @@ -0,0 +1,51 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#pragma once + +#include +#include +#include + +#include "knowhere/index/IndexType.h" +#include "knowhere/index/vector_index/ConfAdapter.h" + +namespace milvus { +namespace knowhere { + +class AdapterMgr { + public: + template + struct register_t { + explicit register_t(const IndexType type) { + AdapterMgr::GetInstance().collection_[type] = ([] { return std::make_shared(); }); + } + }; + + static AdapterMgr& + GetInstance() { + static AdapterMgr instance; + return instance; + } + + ConfAdapterPtr + GetAdapter(const IndexType indexType); + + void + RegisterAdapter(); + + protected: + bool init_ = false; + std::unordered_map> collection_; +}; + +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/vector_index/FaissBaseBinaryIndex.cpp b/core/src/index/knowhere/knowhere/index/vector_index/FaissBaseBinaryIndex.cpp new file mode 100644 index 0000000000..568b70457f --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/vector_index/FaissBaseBinaryIndex.cpp @@ -0,0 +1,51 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include + +#include "knowhere/common/Exception.h" +#include "knowhere/index/vector_index/FaissBaseBinaryIndex.h" +#include "knowhere/index/vector_index/helpers/FaissIO.h" + +namespace milvus { +namespace knowhere { + +BinarySet +FaissBaseBinaryIndex::SerializeImpl(const IndexType& type) { + try { + faiss::IndexBinary* index = index_.get(); + + MemoryIOWriter writer; + faiss::write_index_binary(index, &writer); + std::shared_ptr data(writer.data_); + + BinarySet res_set; + res_set.Append("BinaryIVF", data, writer.rp); + return res_set; + } catch (std::exception& e) { + KNOWHERE_THROW_MSG(e.what()); + } +} + +void +FaissBaseBinaryIndex::LoadImpl(const BinarySet& index_binary, const IndexType& type) { + auto binary = index_binary.GetByName("BinaryIVF"); + + MemoryIOReader reader; + reader.total = binary->size; + reader.data_ = binary->data.get(); + + faiss::IndexBinary* index = faiss::read_index_binary(&reader); + index_.reset(index); +} + +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/vector_index/FaissBaseBinaryIndex.h b/core/src/index/knowhere/knowhere/index/vector_index/FaissBaseBinaryIndex.h new file mode 100644 index 0000000000..5db3fb4982 --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/vector_index/FaissBaseBinaryIndex.h @@ -0,0 +1,42 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#pragma once + +#include +#include + +#include + +#include "knowhere/common/BinarySet.h" +#include "knowhere/common/Dataset.h" +#include "knowhere/index/IndexType.h" + +namespace milvus { +namespace knowhere { + +class FaissBaseBinaryIndex { + protected: + explicit FaissBaseBinaryIndex(std::shared_ptr index) : index_(std::move(index)) { + } + + virtual BinarySet + SerializeImpl(const IndexType& type); + + virtual void + LoadImpl(const BinarySet& index_binary, const IndexType& type); + + public: + std::shared_ptr index_ = nullptr; +}; + +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/vector_index/FaissBaseIndex.cpp b/core/src/index/knowhere/knowhere/index/vector_index/FaissBaseIndex.cpp new file mode 100644 index 0000000000..8d0e9426db --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/vector_index/FaissBaseIndex.cpp @@ -0,0 +1,61 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License + +#include + +#include "knowhere/common/Exception.h" +#include "knowhere/index/IndexType.h" +#include "knowhere/index/vector_index/FaissBaseIndex.h" +#include "knowhere/index/vector_index/helpers/FaissIO.h" + +namespace milvus { +namespace knowhere { + +BinarySet +FaissBaseIndex::SerializeImpl(const IndexType& type) { + try { + faiss::Index* index = index_.get(); + + MemoryIOWriter writer; + faiss::write_index(index, &writer); + std::shared_ptr data(writer.data_); + + BinarySet res_set; + // TODO(linxj): use virtual func Name() instead of raw string. + res_set.Append("IVF", data, writer.rp); + return res_set; + } catch (std::exception& e) { + KNOWHERE_THROW_MSG(e.what()); + } +} + +void +FaissBaseIndex::LoadImpl(const BinarySet& binary_set, const IndexType& type) { + auto binary = binary_set.GetByName("IVF"); + + MemoryIOReader reader; + reader.total = binary->size; + reader.data_ = binary->data.get(); + + faiss::Index* index = faiss::read_index(&reader); + index_.reset(index); + + SealImpl(); +} + +void +FaissBaseIndex::SealImpl() { +} + +// FaissBaseIndex::~FaissBaseIndex() {} +// +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/vector_index/FaissBaseIndex.h b/core/src/index/knowhere/knowhere/index/vector_index/FaissBaseIndex.h new file mode 100644 index 0000000000..70604ab74f --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/vector_index/FaissBaseIndex.h @@ -0,0 +1,44 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License + +#pragma once + +#include +#include + +#include + +#include "knowhere/common/BinarySet.h" +#include "knowhere/index/IndexType.h" + +namespace milvus { +namespace knowhere { + +class FaissBaseIndex { + protected: + explicit FaissBaseIndex(std::shared_ptr index) : index_(std::move(index)) { + } + + virtual BinarySet + SerializeImpl(const IndexType& type); + + virtual void + LoadImpl(const BinarySet&, const IndexType& type); + + virtual void + SealImpl(); + + public: + std::shared_ptr index_ = nullptr; +}; + +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/vector_index/IndexAnnoy.cpp b/core/src/index/knowhere/knowhere/index/vector_index/IndexAnnoy.cpp new file mode 100644 index 0000000000..d526104383 --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/vector_index/IndexAnnoy.cpp @@ -0,0 +1,172 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include "knowhere/index/vector_index/IndexAnnoy.h" + +#include +#include +#include +#include +#include +#include + +#include "knowhere/common/Exception.h" +#include "knowhere/common/Log.h" +#include "knowhere/index/vector_index/adapter/VectorAdapter.h" +#include "knowhere/index/vector_index/helpers/FaissIO.h" + +namespace milvus { +namespace knowhere { + +BinarySet +IndexAnnoy::Serialize(const Config& config) { + if (!index_) { + KNOWHERE_THROW_MSG("index not initialize or trained"); + } + + auto metric_type_length = metric_type_.length(); + std::shared_ptr metric_type(new uint8_t[metric_type_length]); + memcpy(metric_type.get(), metric_type_.data(), metric_type_.length()); + + auto dim = Dim(); + std::shared_ptr dim_data(new uint8_t[sizeof(uint64_t)]); + memcpy(dim_data.get(), &dim, sizeof(uint64_t)); + + size_t index_length = index_->get_index_length(); + std::shared_ptr index_data(new uint8_t[index_length]); + memcpy(index_data.get(), index_->get_index(), index_length); + + BinarySet res_set; + res_set.Append("annoy_metric_type", metric_type, metric_type_length); + res_set.Append("annoy_dim", dim_data, sizeof(uint64_t)); + res_set.Append("annoy_index_data", index_data, index_length); + return res_set; +} + +void +IndexAnnoy::Load(const BinarySet& index_binary) { + auto metric_type = index_binary.GetByName("annoy_metric_type"); + metric_type_.resize(static_cast(metric_type->size)); + memcpy(metric_type_.data(), metric_type->data.get(), static_cast(metric_type->size)); + + auto dim_data = index_binary.GetByName("annoy_dim"); + uint64_t dim; + memcpy(&dim, dim_data->data.get(), static_cast(dim_data->size)); + + if (metric_type_ == Metric::L2) { + index_ = std::make_shared>(dim); + } else if (metric_type_ == Metric::IP) { + index_ = std::make_shared>(dim); + } else { + KNOWHERE_THROW_MSG("metric not supported " + metric_type_); + } + + auto index_data = index_binary.GetByName("annoy_index_data"); + char* p = nullptr; + if (!index_->load_index(reinterpret_cast(index_data->data.get()), index_data->size, &p)) { + std::string error_msg(p); + free(p); + KNOWHERE_THROW_MSG(error_msg); + } +} + +void +IndexAnnoy::BuildAll(const DatasetPtr& dataset_ptr, const Config& config) { + if (index_) { + // it is builded all + LOG_KNOWHERE_DEBUG_ << "IndexAnnoy::BuildAll: index_ has been built!"; + return; + } + + GET_TENSOR(dataset_ptr) + + metric_type_ = config[Metric::TYPE]; + if (metric_type_ == Metric::L2) { + index_ = std::make_shared>(dim); + } else if (metric_type_ == Metric::IP) { + index_ = std::make_shared>(dim); + } else { + KNOWHERE_THROW_MSG("metric not supported " + metric_type_); + } + + for (int i = 0; i < rows; ++i) { + index_->add_item(p_ids[i], static_cast(p_data) + dim * i); + } + + index_->build(config[IndexParams::n_trees].get()); +} + +DatasetPtr +IndexAnnoy::Query(const DatasetPtr& dataset_ptr, const Config& config) { + if (!index_) { + KNOWHERE_THROW_MSG("index not initialize or trained"); + } + + GET_TENSOR_DATA_DIM(dataset_ptr) + auto k = config[meta::TOPK].get(); + auto search_k = config[IndexParams::search_k].get(); + auto all_num = rows * k; + auto p_id = static_cast(malloc(all_num * sizeof(int64_t))); + auto p_dist = static_cast(malloc(all_num * sizeof(float))); + faiss::ConcurrentBitsetPtr blacklist = GetBlacklist(); + +#pragma omp parallel for + for (unsigned int i = 0; i < rows; ++i) { + std::vector result; + result.reserve(k); + std::vector distances; + distances.reserve(k); + index_->get_nns_by_vector(static_cast(p_data) + i * dim, k, search_k, &result, &distances, + blacklist); + + int64_t result_num = result.size(); + auto local_p_id = p_id + k * i; + auto local_p_dist = p_dist + k * i; + memcpy(local_p_id, result.data(), result_num * sizeof(int64_t)); + memcpy(local_p_dist, distances.data(), result_num * sizeof(float)); + for (; result_num < k; result_num++) { + local_p_id[result_num] = -1; + local_p_dist[result_num] = 1.0 / 0.0; + } + } + + auto ret_ds = std::make_shared(); + ret_ds->Set(meta::IDS, p_id); + ret_ds->Set(meta::DISTANCE, p_dist); + return ret_ds; +} + +int64_t +IndexAnnoy::Count() { + if (!index_) { + KNOWHERE_THROW_MSG("index not initialize"); + } + return index_->get_n_items(); +} + +int64_t +IndexAnnoy::Dim() { + if (!index_) { + KNOWHERE_THROW_MSG("index not initialize"); + } + return index_->get_dim(); +} + +void +IndexAnnoy::UpdateIndexSize() { + if (!index_) { + KNOWHERE_THROW_MSG("index not initialize"); + } + index_size_ = index_->cal_size(); +} + +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/vector_index/IndexAnnoy.h b/core/src/index/knowhere/knowhere/index/vector_index/IndexAnnoy.h new file mode 100644 index 0000000000..2881203c79 --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/vector_index/IndexAnnoy.h @@ -0,0 +1,74 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#pragma once + +#include +#include + +#include "annoy/src/annoylib.h" +#include "annoy/src/kissrandom.h" + +#include "knowhere/common/Exception.h" +#include "knowhere/index/vector_index/VecIndex.h" + +namespace milvus { +namespace knowhere { + +class IndexAnnoy : public VecIndex { + public: + IndexAnnoy() { + index_type_ = IndexEnum::INDEX_ANNOY; + } + + BinarySet + Serialize(const Config& config) override; + + void + Load(const BinarySet& index_binary) override; + + void + BuildAll(const DatasetPtr& dataset_ptr, const Config& config) override; + + void + Train(const DatasetPtr& dataset_ptr, const Config& config) override { + KNOWHERE_THROW_MSG("Annoy not support build item dynamically, please invoke BuildAll interface."); + } + + void + Add(const DatasetPtr& dataset_ptr, const Config& config) override { + KNOWHERE_THROW_MSG("Annoy not support add item dynamically, please invoke BuildAll interface."); + } + + void + AddWithoutIds(const DatasetPtr&, const Config&) override { + KNOWHERE_THROW_MSG("Incremental index is not supported"); + } + + DatasetPtr + Query(const DatasetPtr& dataset_ptr, const Config& config) override; + + int64_t + Count() override; + + int64_t + Dim() override; + + void + UpdateIndexSize() override; + + private: + MetricType metric_type_; + std::shared_ptr> index_ = nullptr; +}; + +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/vector_index/IndexBinaryIDMAP.cpp b/core/src/index/knowhere/knowhere/index/vector_index/IndexBinaryIDMAP.cpp new file mode 100644 index 0000000000..4462121dda --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/vector_index/IndexBinaryIDMAP.cpp @@ -0,0 +1,163 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include "knowhere/index/vector_index/IndexBinaryIDMAP.h" + +#include +#include +#include + +#include + +#include "knowhere/common/Exception.h" +#include "knowhere/index/vector_index/adapter/VectorAdapter.h" + +namespace milvus { +namespace knowhere { + +BinarySet +BinaryIDMAP::Serialize(const Config& config) { + if (!index_) { + KNOWHERE_THROW_MSG("index not initialize"); + } + + std::lock_guard lk(mutex_); + return SerializeImpl(index_type_); +} + +void +BinaryIDMAP::Load(const BinarySet& index_binary) { + std::lock_guard lk(mutex_); + LoadImpl(index_binary, index_type_); +} + +DatasetPtr +BinaryIDMAP::Query(const DatasetPtr& dataset_ptr, const Config& config) { + if (!index_) { + KNOWHERE_THROW_MSG("index not initialize"); + } + GET_TENSOR_DATA(dataset_ptr) + + auto k = config[meta::TOPK].get(); + auto elems = rows * k; + size_t p_id_size = sizeof(int64_t) * elems; + size_t p_dist_size = sizeof(float) * elems; + auto p_id = static_cast(malloc(p_id_size)); + auto p_dist = static_cast(malloc(p_dist_size)); + + QueryImpl(rows, reinterpret_cast(p_data), k, p_dist, p_id, config); + + auto ret_ds = std::make_shared(); + ret_ds->Set(meta::IDS, p_id); + ret_ds->Set(meta::DISTANCE, p_dist); + + return ret_ds; +} + +int64_t +BinaryIDMAP::Count() { + if (!index_) { + KNOWHERE_THROW_MSG("index not initialize"); + } + return index_->ntotal; +} + +int64_t +BinaryIDMAP::Dim() { + if (!index_) { + KNOWHERE_THROW_MSG("index not initialize"); + } + return index_->d; +} + +void +BinaryIDMAP::Add(const DatasetPtr& dataset_ptr, const Config& config) { + if (!index_) { + KNOWHERE_THROW_MSG("index not initialize"); + } + + std::lock_guard lk(mutex_); + GET_TENSOR_DATA_ID(dataset_ptr) + + index_->add_with_ids(rows, reinterpret_cast(p_data), p_ids); +} + +void +BinaryIDMAP::Train(const DatasetPtr& dataset_ptr, const Config& config) { + // users will assign the metric type when querying + // so we let Tanimoto be the default type + constexpr faiss::MetricType metric_type = faiss::METRIC_Tanimoto; + + const char* desc = "BFlat"; + auto dim = config[meta::DIM].get(); + auto index = faiss::index_binary_factory(dim, desc, metric_type); + index_.reset(index); +} + +const uint8_t* +BinaryIDMAP::GetRawVectors() { + try { + auto file_index = dynamic_cast(index_.get()); + auto flat_index = dynamic_cast(file_index->index); + return flat_index->xb.data(); + } catch (std::exception& e) { + KNOWHERE_THROW_MSG(e.what()); + } +} + +const int64_t* +BinaryIDMAP::GetRawIds() { + try { + auto file_index = dynamic_cast(index_.get()); + return file_index->id_map.data(); + } catch (std::exception& e) { + KNOWHERE_THROW_MSG(e.what()); + } +} + +void +BinaryIDMAP::AddWithoutIds(const DatasetPtr& dataset_ptr, const Config& config) { + if (!index_) { + KNOWHERE_THROW_MSG("index not initialize"); + } + + std::lock_guard lk(mutex_); + GET_TENSOR_DATA(dataset_ptr) + + std::vector new_ids(rows); + for (int i = 0; i < rows; ++i) { + new_ids[i] = i; + } + + index_->add_with_ids(rows, reinterpret_cast(p_data), new_ids.data()); +} + +void +BinaryIDMAP::QueryImpl(int64_t n, const uint8_t* data, int64_t k, float* distances, int64_t* labels, + const Config& config) { + // assign the metric type + auto bin_flat_index = dynamic_cast(index_.get())->index; + bin_flat_index->metric_type = GetMetricType(config[Metric::TYPE].get()); + + auto i_distances = reinterpret_cast(distances); + bin_flat_index->search(n, data, k, i_distances, labels, bitset_); + + // if hamming, it need transform int32 to float + if (bin_flat_index->metric_type == faiss::METRIC_Hamming) { + int64_t num = n * k; + for (int64_t i = 0; i < num; i++) { + distances[i] = static_cast(i_distances[i]); + } + } +} + +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/vector_index/IndexBinaryIDMAP.h b/core/src/index/knowhere/knowhere/index/vector_index/IndexBinaryIDMAP.h new file mode 100644 index 0000000000..db601b8e32 --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/vector_index/IndexBinaryIDMAP.h @@ -0,0 +1,81 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#pragma once + +#include +#include +#include +#include + +#include "knowhere/index/vector_index/FaissBaseBinaryIndex.h" +#include "knowhere/index/vector_index/VecIndex.h" + +namespace milvus { +namespace knowhere { + +class BinaryIDMAP : public VecIndex, public FaissBaseBinaryIndex { + public: + BinaryIDMAP() : FaissBaseBinaryIndex(nullptr) { + index_type_ = IndexEnum::INDEX_FAISS_BIN_IDMAP; + } + + explicit BinaryIDMAP(std::shared_ptr index) : FaissBaseBinaryIndex(std::move(index)) { + index_type_ = IndexEnum::INDEX_FAISS_BIN_IDMAP; + } + + BinarySet + Serialize(const Config&) override; + + void + Load(const BinarySet&) override; + + void + Train(const DatasetPtr&, const Config&) override; + + void + Add(const DatasetPtr&, const Config&) override; + + void + AddWithoutIds(const DatasetPtr&, const Config&) override; + + DatasetPtr + Query(const DatasetPtr&, const Config&) override; + + int64_t + Count() override; + + int64_t + Dim() override; + + int64_t + IndexSize() override { + return Count() * Dim() / 8; + } + + virtual const uint8_t* + GetRawVectors(); + + virtual const int64_t* + GetRawIds(); + + protected: + virtual void + QueryImpl(int64_t n, const uint8_t* data, int64_t k, float* distances, int64_t* labels, const Config& config); + + protected: + std::mutex mutex_; +}; + +using BinaryIDMAPPtr = std::shared_ptr; + +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/vector_index/IndexBinaryIVF.cpp b/core/src/index/knowhere/knowhere/index/vector_index/IndexBinaryIVF.cpp new file mode 100644 index 0000000000..2ed7e41047 --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/vector_index/IndexBinaryIVF.cpp @@ -0,0 +1,157 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include "knowhere/index/vector_index/IndexBinaryIVF.h" + +#include +#include + +#include +#include + +#include "knowhere/common/Exception.h" +#include "knowhere/common/Log.h" +#include "knowhere/index/vector_index/adapter/VectorAdapter.h" + +namespace milvus { +namespace knowhere { + +using stdclock = std::chrono::high_resolution_clock; + +BinarySet +BinaryIVF::Serialize(const Config& config) { + if (!index_ || !index_->is_trained) { + KNOWHERE_THROW_MSG("index not initialize or trained"); + } + + std::lock_guard lk(mutex_); + return SerializeImpl(index_type_); +} + +void +BinaryIVF::Load(const BinarySet& index_binary) { + std::lock_guard lk(mutex_); + LoadImpl(index_binary, index_type_); +} + +DatasetPtr +BinaryIVF::Query(const DatasetPtr& dataset_ptr, const Config& config) { + if (!index_ || !index_->is_trained) { + KNOWHERE_THROW_MSG("index not initialize or trained"); + } + + GET_TENSOR_DATA(dataset_ptr) + + try { + auto k = config[meta::TOPK].get(); + auto elems = rows * k; + + size_t p_id_size = sizeof(int64_t) * elems; + size_t p_dist_size = sizeof(float) * elems; + auto p_id = static_cast(malloc(p_id_size)); + auto p_dist = static_cast(malloc(p_dist_size)); + + QueryImpl(rows, reinterpret_cast(p_data), k, p_dist, p_id, config); + + auto ret_ds = std::make_shared(); + + ret_ds->Set(meta::IDS, p_id); + ret_ds->Set(meta::DISTANCE, p_dist); + + return ret_ds; + } catch (faiss::FaissException& e) { + KNOWHERE_THROW_MSG(e.what()); + } catch (std::exception& e) { + KNOWHERE_THROW_MSG(e.what()); + } +} + +int64_t +BinaryIVF::Count() { + if (!index_) { + KNOWHERE_THROW_MSG("index not initialize"); + } + return index_->ntotal; +} + +int64_t +BinaryIVF::Dim() { + if (!index_) { + KNOWHERE_THROW_MSG("index not initialize"); + } + return index_->d; +} + +void +BinaryIVF::UpdateIndexSize() { + if (!index_) { + KNOWHERE_THROW_MSG("index not initialize"); + } + auto bin_ivf_index = dynamic_cast(index_.get()); + auto nb = bin_ivf_index->invlists->compute_ntotal(); + auto nlist = bin_ivf_index->nlist; + auto code_size = bin_ivf_index->code_size; + + // binary ivf codes, ids and quantizer + index_size_ = nb * code_size + nb * sizeof(int64_t) + nlist * code_size; +} + +void +BinaryIVF::Train(const DatasetPtr& dataset_ptr, const Config& config) { + GET_TENSOR(dataset_ptr) + + int64_t nlist = config[IndexParams::nlist]; + faiss::MetricType metric_type = GetMetricType(config[Metric::TYPE].get()); + faiss::IndexBinary* coarse_quantizer = new faiss::IndexBinaryFlat(dim, metric_type); + auto index = std::make_shared(coarse_quantizer, dim, nlist, metric_type); + index->train(rows, static_cast(p_data)); + index->add_with_ids(rows, static_cast(p_data), p_ids); + index_ = index; +} + +std::shared_ptr +BinaryIVF::GenParams(const Config& config) { + auto params = std::make_shared(); + params->nprobe = config[IndexParams::nprobe]; + // params->max_codes = config["max_code"]; + return params; +} + +void +BinaryIVF::QueryImpl(int64_t n, const uint8_t* data, int64_t k, float* distances, int64_t* labels, + const Config& config) { + auto params = GenParams(config); + auto ivf_index = dynamic_cast(index_.get()); + ivf_index->nprobe = params->nprobe; + + stdclock::time_point before = stdclock::now(); + auto i_distances = reinterpret_cast(distances); + index_->search(n, data, k, i_distances, labels, bitset_); + + stdclock::time_point after = stdclock::now(); + double search_cost = (std::chrono::duration(after - before)).count(); + LOG_KNOWHERE_DEBUG_ << "IVF search cost: " << search_cost + << ", quantization cost: " << faiss::indexIVF_stats.quantization_time + << ", data search cost: " << faiss::indexIVF_stats.search_time; + faiss::indexIVF_stats.quantization_time = 0; + faiss::indexIVF_stats.search_time = 0; + + // if hamming, it need transform int32 to float + if (ivf_index->metric_type == faiss::METRIC_Hamming) { + int64_t num = n * k; + for (int64_t i = 0; i < num; i++) { + distances[i] = static_cast(i_distances[i]); + } + } +} + +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/vector_index/IndexBinaryIVF.h b/core/src/index/knowhere/knowhere/index/vector_index/IndexBinaryIVF.h new file mode 100644 index 0000000000..fe1dc94518 --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/vector_index/IndexBinaryIVF.h @@ -0,0 +1,88 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#pragma once + +#include +#include +#include +#include + +#include + +#include "knowhere/common/Exception.h" +#include "knowhere/index/vector_index/FaissBaseBinaryIndex.h" +#include "knowhere/index/vector_index/VecIndex.h" + +namespace milvus { +namespace knowhere { + +class BinaryIVF : public VecIndex, public FaissBaseBinaryIndex { + public: + BinaryIVF() : FaissBaseBinaryIndex(nullptr) { + index_type_ = IndexEnum::INDEX_FAISS_BIN_IVFFLAT; + } + + explicit BinaryIVF(std::shared_ptr index) : FaissBaseBinaryIndex(std::move(index)) { + index_type_ = IndexEnum::INDEX_FAISS_BIN_IVFFLAT; + } + + BinarySet + Serialize(const Config& config) override; + + void + BuildAll(const DatasetPtr& dataset_ptr, const Config& config) override { + Train(dataset_ptr, config); + } + + void + Load(const BinarySet& index_binary) override; + + void + Train(const DatasetPtr& dataset_ptr, const Config& config) override; + + void + Add(const DatasetPtr& dataset_ptr, const Config& config) override { + KNOWHERE_THROW_MSG("not support yet"); + } + + void + AddWithoutIds(const DatasetPtr&, const Config&) override { + KNOWHERE_THROW_MSG("AddWithoutIds is not supported"); + } + + DatasetPtr + Query(const DatasetPtr& dataset_ptr, const Config& config) override; + + int64_t + Count() override; + + int64_t + Dim() override; + + void + UpdateIndexSize() override; + + protected: + virtual std::shared_ptr + GenParams(const Config& config); + + virtual void + QueryImpl(int64_t n, const uint8_t* data, int64_t k, float* distances, int64_t* labels, const Config& config); + + protected: + std::mutex mutex_; +}; + +using BinaryIVFIndexPtr = std::shared_ptr; + +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/vector_index/IndexHNSW.cpp b/core/src/index/knowhere/knowhere/index/vector_index/IndexHNSW.cpp new file mode 100644 index 0000000000..601c3fb715 --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/vector_index/IndexHNSW.cpp @@ -0,0 +1,222 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include "knowhere/index/vector_index/IndexHNSW.h" + +#include +#include +#include +#include +#include +#include + +#include "faiss/BuilderSuspend.h" +#include "hnswlib/hnswalg.h" +#include "hnswlib/space_ip.h" +#include "hnswlib/space_l2.h" +#include "knowhere/common/Exception.h" +#include "knowhere/common/Log.h" +#include "knowhere/index/vector_index/adapter/VectorAdapter.h" +#include "knowhere/index/vector_index/helpers/FaissIO.h" + +namespace milvus { +namespace knowhere { + +// void +// normalize_vector(float* data, float* norm_array, size_t dim) { +// float norm = 0.0f; +// for (int i = 0; i < dim; i++) norm += data[i] * data[i]; +// norm = 1.0f / (sqrtf(norm) + 1e-30f); +// for (int i = 0; i < dim; i++) norm_array[i] = data[i] * norm; +// } + +BinarySet +IndexHNSW::Serialize(const Config& config) { + if (!index_) { + KNOWHERE_THROW_MSG("index not initialize or trained"); + } + + try { + MemoryIOWriter writer; + index_->saveIndex(writer); + std::shared_ptr data(writer.data_); + + BinarySet res_set; + res_set.Append("HNSW", data, writer.rp); + return res_set; + } catch (std::exception& e) { + KNOWHERE_THROW_MSG(e.what()); + } +} + +void +IndexHNSW::Load(const BinarySet& index_binary) { + try { + auto binary = index_binary.GetByName("HNSW"); + + MemoryIOReader reader; + reader.total = binary->size; + reader.data_ = binary->data.get(); + + hnswlib::SpaceInterface* space = nullptr; + index_ = std::make_shared>(space); + index_->loadIndex(reader); + + normalize = index_->metric_type_ == 1; // 1 == InnerProduct + } catch (std::exception& e) { + KNOWHERE_THROW_MSG(e.what()); + } +} + +void +IndexHNSW::Train(const DatasetPtr& dataset_ptr, const Config& config) { + try { + auto dim = dataset_ptr->Get(meta::DIM); + auto rows = dataset_ptr->Get(meta::ROWS); + + hnswlib::SpaceInterface* space; + std::string metric_type = config[Metric::TYPE]; + if (metric_type == Metric::L2) { + space = new hnswlib::L2Space(dim); + } else if (metric_type == Metric::IP) { + space = new hnswlib::InnerProductSpace(dim); + normalize = true; + } else { + KNOWHERE_THROW_MSG("Metric type not supported: " + metric_type); + } + index_ = std::make_shared>(space, rows, config[IndexParams::M].get(), + config[IndexParams::efConstruction].get()); + } catch (std::exception& e) { + KNOWHERE_THROW_MSG(e.what()); + } +} + +void +IndexHNSW::Add(const DatasetPtr& dataset_ptr, const Config& config) { + if (!index_) { + KNOWHERE_THROW_MSG("index not initialize"); + } + + std::lock_guard lk(mutex_); + + GET_TENSOR_DATA_ID(dataset_ptr) + + // if (normalize) { + // std::vector ep_norm_vector(Dim()); + // normalize_vector((float*)(p_data), ep_norm_vector.data(), Dim()); + // index_->addPoint((void*)(ep_norm_vector.data()), p_ids[0]); + // #pragma omp parallel for + // for (int i = 1; i < rows; ++i) { + // std::vector norm_vector(Dim()); + // normalize_vector((float*)(p_data + Dim() * i), norm_vector.data(), Dim()); + // index_->addPoint((void*)(norm_vector.data()), p_ids[i]); + // } + // } else { + // index_->addPoint((void*)(p_data), p_ids[0]); + // #pragma omp parallel for + // for (int i = 1; i < rows; ++i) { + // index_->addPoint((void*)(p_data + Dim() * i), p_ids[i]); + // } + // } + + index_->addPoint(p_data, p_ids[0]); +#pragma omp parallel for + for (int i = 1; i < rows; ++i) { + faiss::BuilderSuspend::check_wait(); + index_->addPoint((reinterpret_cast(p_data) + Dim() * i), p_ids[i]); + } +} + +DatasetPtr +IndexHNSW::Query(const DatasetPtr& dataset_ptr, const Config& config) { + if (!index_) { + KNOWHERE_THROW_MSG("index not initialize or trained"); + } + GET_TENSOR_DATA(dataset_ptr) + + size_t k = config[meta::TOPK].get(); + size_t id_size = sizeof(int64_t) * k; + size_t dist_size = sizeof(float) * k; + auto p_id = static_cast(malloc(id_size * rows)); + auto p_dist = static_cast(malloc(dist_size * rows)); + + index_->setEf(config[IndexParams::ef]); + + using P = std::pair; + auto compare = [](const P& v1, const P& v2) { return v1.first < v2.first; }; + + faiss::ConcurrentBitsetPtr blacklist = GetBlacklist(); +#pragma omp parallel for + for (unsigned int i = 0; i < rows; ++i) { + std::vector

ret; + const float* single_query = reinterpret_cast(p_data) + i * Dim(); + + // if (normalize) { + // std::vector norm_vector(Dim()); + // normalize_vector((float*)(single_query), norm_vector.data(), Dim()); + // ret = index_->searchKnn((float*)(norm_vector.data()), config[meta::TOPK].get(), compare); + // } else { + // ret = index_->searchKnn((float*)single_query, config[meta::TOPK].get(), compare); + // } + ret = index_->searchKnn(single_query, k, compare, blacklist); + + while (ret.size() < k) { + ret.emplace_back(std::make_pair(-1, -1)); + } + std::vector dist; + std::vector ids; + + if (normalize) { + std::transform(ret.begin(), ret.end(), std::back_inserter(dist), + [](const std::pair& e) { return float(1 - e.first); }); + } else { + std::transform(ret.begin(), ret.end(), std::back_inserter(dist), + [](const std::pair& e) { return e.first; }); + } + std::transform(ret.begin(), ret.end(), std::back_inserter(ids), + [](const std::pair& e) { return e.second; }); + + memcpy(p_dist + i * k, dist.data(), dist_size); + memcpy(p_id + i * k, ids.data(), id_size); + } + + auto ret_ds = std::make_shared(); + ret_ds->Set(meta::IDS, p_id); + ret_ds->Set(meta::DISTANCE, p_dist); + return ret_ds; +} + +int64_t +IndexHNSW::Count() { + if (!index_) { + KNOWHERE_THROW_MSG("index not initialize"); + } + return index_->cur_element_count; +} + +int64_t +IndexHNSW::Dim() { + if (!index_) { + KNOWHERE_THROW_MSG("index not initialize"); + } + return (*static_cast(index_->dist_func_param_)); +} + +void +IndexHNSW::UpdateIndexSize() { + if (!index_) { + KNOWHERE_THROW_MSG("index not initialize"); + } + index_size_ = index_->cal_size(); +} + +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/vector_index/IndexHNSW.h b/core/src/index/knowhere/knowhere/index/vector_index/IndexHNSW.h new file mode 100644 index 0000000000..d7b97bf468 --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/vector_index/IndexHNSW.h @@ -0,0 +1,67 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#pragma once + +#include +#include + +#include "hnswlib/hnswlib.h" + +#include "knowhere/common/Exception.h" +#include "knowhere/index/vector_index/VecIndex.h" + +namespace milvus { +namespace knowhere { + +class IndexHNSW : public VecIndex { + public: + IndexHNSW() { + index_type_ = IndexEnum::INDEX_HNSW; + } + + BinarySet + Serialize(const Config& config) override; + + void + Load(const BinarySet& index_binary) override; + + void + Train(const DatasetPtr& dataset_ptr, const Config& config) override; + + void + Add(const DatasetPtr& dataset_ptr, const Config& config) override; + + void + AddWithoutIds(const DatasetPtr&, const Config&) override { + KNOWHERE_THROW_MSG("Incremental index is not supported"); + } + + DatasetPtr + Query(const DatasetPtr& dataset_ptr, const Config& config) override; + + int64_t + Count() override; + + int64_t + Dim() override; + + void + UpdateIndexSize() override; + + private: + bool normalize = false; + std::mutex mutex_; + std::shared_ptr> index_; +}; + +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/vector_index/IndexIDMAP.cpp b/core/src/index/knowhere/knowhere/index/vector_index/IndexIDMAP.cpp new file mode 100644 index 0000000000..302dae78ae --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/vector_index/IndexIDMAP.cpp @@ -0,0 +1,234 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License + +#include "knowhere/index/vector_index/IndexIDMAP.h" + +#include +#include +#include +#include +#include +#include +#ifdef MILVUS_GPU_VERSION +#include +#endif + +#include +#include + +#include "knowhere/common/Exception.h" +#include "knowhere/index/vector_index/adapter/VectorAdapter.h" +#include "knowhere/index/vector_index/helpers/FaissIO.h" +#include "knowhere/index/vector_index/helpers/IndexParameter.h" +#ifdef MILVUS_GPU_VERSION +#include "knowhere/index/vector_index/gpu/IndexGPUIDMAP.h" +#include "knowhere/index/vector_index/helpers/FaissGpuResourceMgr.h" +#endif + +namespace milvus { +namespace knowhere { + +BinarySet +IDMAP::Serialize(const Config& config) { + if (!index_) { + KNOWHERE_THROW_MSG("index not initialize"); + } + + std::lock_guard lk(mutex_); + return SerializeImpl(index_type_); +} + +void +IDMAP::Load(const BinarySet& binary_set) { + std::lock_guard lk(mutex_); + LoadImpl(binary_set, index_type_); +} + +void +IDMAP::Train(const DatasetPtr& dataset_ptr, const Config& config) { + // users will assign the metric type when querying + // so we let L2 be the default type + constexpr faiss::MetricType metric_type = faiss::METRIC_L2; + + const char* desc = "IDMap,Flat"; + auto dim = config[meta::DIM].get(); + auto index = faiss::index_factory(dim, desc, metric_type); + index_.reset(index); +} + +void +IDMAP::Add(const DatasetPtr& dataset_ptr, const Config& config) { + if (!index_) { + KNOWHERE_THROW_MSG("index not initialize"); + } + + std::lock_guard lk(mutex_); + GET_TENSOR_DATA_ID(dataset_ptr) + index_->add_with_ids(rows, reinterpret_cast(p_data), p_ids); +} + +void +IDMAP::AddWithoutIds(const DatasetPtr& dataset_ptr, const Config& config) { + if (!index_) { + KNOWHERE_THROW_MSG("index not initialize"); + } + + std::lock_guard lk(mutex_); + auto rows = dataset_ptr->Get(meta::ROWS); + auto p_data = dataset_ptr->Get(meta::TENSOR); + + // TODO: caiyd need check + std::vector new_ids(rows); + for (int i = 0; i < rows; ++i) { + new_ids[i] = i; + } + + index_->add_with_ids(rows, reinterpret_cast(p_data), new_ids.data()); +} + +DatasetPtr +IDMAP::Query(const DatasetPtr& dataset_ptr, const Config& config) { + if (!index_) { + KNOWHERE_THROW_MSG("index not initialize"); + } + GET_TENSOR_DATA(dataset_ptr) + + auto k = config[meta::TOPK].get(); + auto elems = rows * k; + size_t p_id_size = sizeof(int64_t) * elems; + size_t p_dist_size = sizeof(float) * elems; + auto p_id = static_cast(malloc(p_id_size)); + auto p_dist = static_cast(malloc(p_dist_size)); + + QueryImpl(rows, reinterpret_cast(p_data), k, p_dist, p_id, config); + + auto ret_ds = std::make_shared(); + ret_ds->Set(meta::IDS, p_id); + ret_ds->Set(meta::DISTANCE, p_dist); + return ret_ds; +} + +#if 0 +DatasetPtr +IDMAP::QueryById(const DatasetPtr& dataset_ptr, const Config& config) { + if (!index_) { + KNOWHERE_THROW_MSG("index not initialize"); + } + // GETTENSOR(dataset) + auto rows = dataset_ptr->Get(meta::ROWS); + auto p_data = dataset_ptr->Get(meta::IDS); + + int64_t k = config[meta::TOPK].get(); + auto elems = rows * k; + size_t p_id_size = sizeof(int64_t) * elems; + size_t p_dist_size = sizeof(float) * elems; + auto p_id = (int64_t*)malloc(p_id_size); + auto p_dist = (float*)malloc(p_dist_size); + + // todo: enable search by id (zhiru) + // auto blacklist = dataset_ptr->Get("bitset"); + // index_->searchById(rows, (float*)p_data, config[meta::TOPK].get(), p_dist, p_id, blacklist); + index_->search_by_id(rows, p_data, k, p_dist, p_id, bitset_); + + auto ret_ds = std::make_shared(); + ret_ds->Set(meta::IDS, p_id); + ret_ds->Set(meta::DISTANCE, p_dist); + return ret_ds; +} +#endif + +int64_t +IDMAP::Count() { + if (!index_) { + KNOWHERE_THROW_MSG("index not initialize"); + } + return index_->ntotal; +} + +int64_t +IDMAP::Dim() { + if (!index_) { + KNOWHERE_THROW_MSG("index not initialize"); + } + return index_->d; +} + +VecIndexPtr +IDMAP::CopyCpuToGpu(const int64_t device_id, const Config& config) { +#ifdef MILVUS_GPU_VERSION + if (auto res = FaissGpuResourceMgr::GetInstance().GetRes(device_id)) { + ResScope rs(res, device_id, false); + auto gpu_index = faiss::gpu::index_cpu_to_gpu(res->faiss_res.get(), device_id, index_.get()); + + std::shared_ptr device_index; + device_index.reset(gpu_index); + return std::make_shared(device_index, device_id, res); + } else { + KNOWHERE_THROW_MSG("CopyCpuToGpu Error, can't get gpu_resource"); + } +#else + KNOWHERE_THROW_MSG("Calling IDMAP::CopyCpuToGpu when we are using CPU version"); +#endif +} + +const float* +IDMAP::GetRawVectors() { + try { + auto file_index = dynamic_cast(index_.get()); + auto flat_index = dynamic_cast(file_index->index); + return flat_index->xb.data(); + } catch (std::exception& e) { + KNOWHERE_THROW_MSG(e.what()); + } +} + +const int64_t* +IDMAP::GetRawIds() { + try { + auto file_index = dynamic_cast(index_.get()); + return file_index->id_map.data(); + } catch (std::exception& e) { + KNOWHERE_THROW_MSG(e.what()); + } +} + +#if 0 +DatasetPtr +IDMAP::GetVectorById(const DatasetPtr& dataset_ptr, const Config& config) { + if (!index_) { + KNOWHERE_THROW_MSG("index not initialize"); + } + // GETTENSOR(dataset) + // auto rows = dataset_ptr->Get(meta::ROWS); + auto p_data = dataset_ptr->Get(meta::IDS); + auto elems = dataset_ptr->Get(meta::DIM); + + size_t p_x_size = sizeof(float) * elems; + auto p_x = (float*)malloc(p_x_size); + + index_->get_vector_by_id(1, p_data, p_x, bitset_); + + auto ret_ds = std::make_shared(); + ret_ds->Set(meta::TENSOR, p_x); + return ret_ds; +} +#endif + +void +IDMAP::QueryImpl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels, const Config& config) { + // assign the metric type + auto flat_index = dynamic_cast(index_.get())->index; + flat_index->metric_type = GetMetricType(config[Metric::TYPE].get()); + index_->search(n, data, k, distances, labels, bitset_); +} + +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/vector_index/IndexIDMAP.h b/core/src/index/knowhere/knowhere/index/vector_index/IndexIDMAP.h new file mode 100644 index 0000000000..ece257e274 --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/vector_index/IndexIDMAP.h @@ -0,0 +1,92 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License + +#pragma once + +#include +#include + +#include "knowhere/index/vector_index/FaissBaseIndex.h" +#include "knowhere/index/vector_index/VecIndex.h" + +namespace milvus { +namespace knowhere { + +class IDMAP : public VecIndex, public FaissBaseIndex { + public: + IDMAP() : FaissBaseIndex(nullptr) { + index_type_ = IndexEnum::INDEX_FAISS_IDMAP; + } + + explicit IDMAP(std::shared_ptr index) : FaissBaseIndex(std::move(index)) { + index_type_ = IndexEnum::INDEX_FAISS_IDMAP; + } + + BinarySet + Serialize(const Config&) override; + + void + Load(const BinarySet&) override; + + void + Train(const DatasetPtr&, const Config&) override; + + void + Add(const DatasetPtr&, const Config&) override; + + void + AddWithoutIds(const DatasetPtr&, const Config&) override; + + DatasetPtr + Query(const DatasetPtr&, const Config&) override; + +#if 0 + DatasetPtr + QueryById(const DatasetPtr& dataset, const Config& config) override; +#endif + + int64_t + Count() override; + + int64_t + Dim() override; + + int64_t + IndexSize() override { + return Count() * Dim() * sizeof(FloatType); + } + +#if 0 + DatasetPtr + GetVectorById(const DatasetPtr& dataset, const Config& config) override; +#endif + + VecIndexPtr + CopyCpuToGpu(const int64_t, const Config&); + + virtual const float* + GetRawVectors(); + + virtual const int64_t* + GetRawIds(); + + protected: + virtual void + QueryImpl(int64_t, const float*, int64_t, float*, int64_t*, const Config&); + + protected: + std::mutex mutex_; +}; + +using IDMAPPtr = std::shared_ptr; + +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/vector_index/IndexIVF.cpp b/core/src/index/knowhere/knowhere/index/vector_index/IndexIVF.cpp new file mode 100644 index 0000000000..628aef4b49 --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/vector_index/IndexIVF.cpp @@ -0,0 +1,349 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#ifdef MILVUS_GPU_VERSION +#include +#include +#endif + +#include +#include +#include +#include +#include + +#include "faiss/BuilderSuspend.h" +#include "knowhere/common/Exception.h" +#include "knowhere/common/Log.h" +#include "knowhere/index/vector_index/IndexIVF.h" +#include "knowhere/index/vector_index/adapter/VectorAdapter.h" +#include "knowhere/index/vector_index/helpers/IndexParameter.h" +#ifdef MILVUS_GPU_VERSION +#include "knowhere/index/vector_index/gpu/IndexGPUIVF.h" +#include "knowhere/index/vector_index/helpers/FaissGpuResourceMgr.h" +#endif + +namespace milvus { +namespace knowhere { + +using stdclock = std::chrono::high_resolution_clock; + +BinarySet +IVF::Serialize(const Config& config) { + if (!index_ || !index_->is_trained) { + KNOWHERE_THROW_MSG("index not initialize or trained"); + } + + std::lock_guard lk(mutex_); + return SerializeImpl(index_type_); +} + +void +IVF::Load(const BinarySet& binary_set) { + std::lock_guard lk(mutex_); + LoadImpl(binary_set, index_type_); +} + +void +IVF::Train(const DatasetPtr& dataset_ptr, const Config& config) { + GET_TENSOR_DATA_DIM(dataset_ptr) + + faiss::MetricType metric_type = GetMetricType(config[Metric::TYPE].get()); + faiss::Index* coarse_quantizer = new faiss::IndexFlat(dim, metric_type); + auto nlist = config[IndexParams::nlist].get(); + index_ = std::shared_ptr(new faiss::IndexIVFFlat(coarse_quantizer, dim, nlist, metric_type)); + index_->train(rows, reinterpret_cast(p_data)); +} + +void +IVF::Add(const DatasetPtr& dataset_ptr, const Config& config) { + if (!index_ || !index_->is_trained) { + KNOWHERE_THROW_MSG("index not initialize or trained"); + } + + std::lock_guard lk(mutex_); + GET_TENSOR_DATA_ID(dataset_ptr) + index_->add_with_ids(rows, reinterpret_cast(p_data), p_ids); +} + +void +IVF::AddWithoutIds(const DatasetPtr& dataset_ptr, const Config& config) { + if (!index_ || !index_->is_trained) { + KNOWHERE_THROW_MSG("index not initialize or trained"); + } + + std::lock_guard lk(mutex_); + GET_TENSOR_DATA(dataset_ptr) + index_->add(rows, reinterpret_cast(p_data)); +} + +DatasetPtr +IVF::Query(const DatasetPtr& dataset_ptr, const Config& config) { + if (!index_ || !index_->is_trained) { + KNOWHERE_THROW_MSG("index not initialize or trained"); + } + + GET_TENSOR_DATA(dataset_ptr) + + try { + auto k = config[meta::TOPK].get(); + auto elems = rows * k; + + size_t p_id_size = sizeof(int64_t) * elems; + size_t p_dist_size = sizeof(float) * elems; + auto p_id = static_cast(malloc(p_id_size)); + auto p_dist = static_cast(malloc(p_dist_size)); + + QueryImpl(rows, reinterpret_cast(p_data), k, p_dist, p_id, config); + + // std::stringstream ss_res_id, ss_res_dist; + // for (int i = 0; i < 10; ++i) { + // printf("%llu", p_id[i]); + // printf("\n"); + // printf("%.6f", p_dist[i]); + // printf("\n"); + // ss_res_id << p_id[i] << " "; + // ss_res_dist << p_dist[i] << " "; + // } + // std::cout << std::endl << "after search: " << std::endl; + // std::cout << ss_res_id.str() << std::endl; + // std::cout << ss_res_dist.str() << std::endl << std::endl; + + auto ret_ds = std::make_shared(); + ret_ds->Set(meta::IDS, p_id); + ret_ds->Set(meta::DISTANCE, p_dist); + return ret_ds; + } catch (faiss::FaissException& e) { + KNOWHERE_THROW_MSG(e.what()); + } catch (std::exception& e) { + KNOWHERE_THROW_MSG(e.what()); + } +} + +#if 0 +DatasetPtr +IVF::QueryById(const DatasetPtr& dataset_ptr, const Config& config) { + if (!index_ || !index_->is_trained) { + KNOWHERE_THROW_MSG("index not initialize or trained"); + } + + auto rows = dataset_ptr->Get(meta::ROWS); + auto p_data = dataset_ptr->Get(meta::IDS); + + try { + int64_t k = config[meta::TOPK].get(); + auto elems = rows * k; + + size_t p_id_size = sizeof(int64_t) * elems; + size_t p_dist_size = sizeof(float) * elems; + auto p_id = (int64_t*)malloc(p_id_size); + auto p_dist = (float*)malloc(p_dist_size); + + // todo: enable search by id (zhiru) + // auto blacklist = dataset_ptr->Get("bitset"); + auto index_ivf = std::static_pointer_cast(index_); + index_ivf->search_by_id(rows, p_data, k, p_dist, p_id, bitset_); + + // std::stringstream ss_res_id, ss_res_dist; + // for (int i = 0; i < 10; ++i) { + // printf("%llu", res_ids[i]); + // printf("\n"); + // printf("%.6f", res_dis[i]); + // printf("\n"); + // ss_res_id << res_ids[i] << " "; + // ss_res_dist << res_dis[i] << " "; + // } + // std::cout << std::endl << "after search: " << std::endl; + // std::cout << ss_res_id.str() << std::endl; + // std::cout << ss_res_dist.str() << std::endl << std::endl; + + auto ret_ds = std::make_shared(); + ret_ds->Set(meta::IDS, p_id); + ret_ds->Set(meta::DISTANCE, p_dist); + return ret_ds; + } catch (faiss::FaissException& e) { + KNOWHERE_THROW_MSG(e.what()); + } catch (std::exception& e) { + KNOWHERE_THROW_MSG(e.what()); + } +} + +DatasetPtr +IVF::GetVectorById(const DatasetPtr& dataset_ptr, const Config& config) { + if (!index_ || !index_->is_trained) { + KNOWHERE_THROW_MSG("index not initialize or trained"); + } + + auto p_data = dataset_ptr->Get(meta::IDS); + auto elems = dataset_ptr->Get(meta::DIM); + + try { + size_t p_x_size = sizeof(float) * elems; + auto p_x = (float*)malloc(p_x_size); + + auto index_ivf = std::static_pointer_cast(index_); + index_ivf->get_vector_by_id(1, p_data, p_x, bitset_); + + auto ret_ds = std::make_shared(); + ret_ds->Set(meta::TENSOR, p_x); + return ret_ds; + } catch (faiss::FaissException& e) { + KNOWHERE_THROW_MSG(e.what()); + } catch (std::exception& e) { + KNOWHERE_THROW_MSG(e.what()); + } +} +#endif + +int64_t +IVF::Count() { + if (!index_) { + KNOWHERE_THROW_MSG("index not initialize"); + } + return index_->ntotal; +} + +int64_t +IVF::Dim() { + if (!index_) { + KNOWHERE_THROW_MSG("index not initialize"); + } + return index_->d; +} + +void +IVF::Seal() { + if (!index_ || !index_->is_trained) { + KNOWHERE_THROW_MSG("index not initialize or trained"); + } + SealImpl(); +} + +void +IVF::UpdateIndexSize() { + if (!index_) { + KNOWHERE_THROW_MSG("index not initialize"); + } + auto ivf_index = dynamic_cast(index_.get()); + auto nb = ivf_index->invlists->compute_ntotal(); + auto nlist = ivf_index->nlist; + auto code_size = ivf_index->code_size; + // ivf codes, ivf ids and quantizer + index_size_ = nb * code_size + nb * sizeof(int64_t) + nlist * code_size; +} + +VecIndexPtr +IVF::CopyCpuToGpu(const int64_t device_id, const Config& config) { +#ifdef MILVUS_GPU_VERSION + if (auto res = FaissGpuResourceMgr::GetInstance().GetRes(device_id)) { + ResScope rs(res, device_id, false); + auto gpu_index = faiss::gpu::index_cpu_to_gpu(res->faiss_res.get(), device_id, index_.get()); + + std::shared_ptr device_index; + device_index.reset(gpu_index); + return std::make_shared(device_index, device_id, res); + } else { + KNOWHERE_THROW_MSG("CopyCpuToGpu Error, can't get gpu_resource"); + } + +#else + KNOWHERE_THROW_MSG("Calling IVF::CopyCpuToGpu when we are using CPU version"); +#endif +} + +void +IVF::GenGraph(const float* data, const int64_t k, GraphType& graph, const Config& config) { + int64_t K = k + 1; + auto ntotal = Count(); + + size_t dim = config[meta::DIM]; + auto batch_size = 1000; + auto tail_batch_size = ntotal % batch_size; + auto batch_search_count = ntotal / batch_size; + auto total_search_count = tail_batch_size == 0 ? batch_search_count : batch_search_count + 1; + + std::vector res_dis(K * batch_size); + graph.resize(ntotal); + GraphType res_vec(total_search_count); + for (int i = 0; i < total_search_count; ++i) { + // it is usually used in NSG::train, to check BuilderSuspend + faiss::BuilderSuspend::check_wait(); + + auto b_size = (i == (total_search_count - 1)) && tail_batch_size != 0 ? tail_batch_size : batch_size; + + auto& res = res_vec[i]; + res.resize(K * b_size); + + const float* xq = data + batch_size * dim * i; + QueryImpl(b_size, xq, K, res_dis.data(), res.data(), config); + + for (int j = 0; j < b_size; ++j) { + auto& node = graph[batch_size * i + j]; + node.resize(k); + auto start_pos = j * K + 1; + for (int m = 0, cursor = start_pos; m < k && cursor < start_pos + k; ++m, ++cursor) { + node[m] = res[cursor]; + } + } + } +} + +std::shared_ptr +IVF::GenParams(const Config& config) { + auto params = std::make_shared(); + params->nprobe = config[IndexParams::nprobe]; + // params->max_codes = config["max_codes"]; + return params; +} + +void +IVF::QueryImpl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels, const Config& config) { + auto params = GenParams(config); + auto ivf_index = dynamic_cast(index_.get()); + ivf_index->nprobe = params->nprobe; + stdclock::time_point before = stdclock::now(); + if (params->nprobe > 1 && n <= 4) { + ivf_index->parallel_mode = 1; + } else { + ivf_index->parallel_mode = 0; + } + ivf_index->search(n, data, k, distances, labels, bitset_); + stdclock::time_point after = stdclock::now(); + double search_cost = (std::chrono::duration(after - before)).count(); + LOG_KNOWHERE_DEBUG_ << "IVF search cost: " << search_cost + << ", quantization cost: " << faiss::indexIVF_stats.quantization_time + << ", data search cost: " << faiss::indexIVF_stats.search_time; + faiss::indexIVF_stats.quantization_time = 0; + faiss::indexIVF_stats.search_time = 0; +} + +void +IVF::SealImpl() { +#ifdef MILVUS_GPU_VERSION + faiss::Index* index = index_.get(); + auto idx = dynamic_cast(index); + if (idx != nullptr) { + idx->to_readonly(); + } +#endif +} + +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/vector_index/IndexIVF.h b/core/src/index/knowhere/knowhere/index/vector_index/IndexIVF.h new file mode 100644 index 0000000000..ccb49aaa8e --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/vector_index/IndexIVF.h @@ -0,0 +1,101 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License + +#pragma once + +#include +#include +#include +#include + +#include + +#include "knowhere/common/Typedef.h" +#include "knowhere/index/vector_index/FaissBaseIndex.h" +#include "knowhere/index/vector_index/VecIndex.h" + +namespace milvus { +namespace knowhere { + +class IVF : public VecIndex, public FaissBaseIndex { + public: + IVF() : FaissBaseIndex(nullptr) { + index_type_ = IndexEnum::INDEX_FAISS_IVFFLAT; + } + + explicit IVF(std::shared_ptr index) : FaissBaseIndex(std::move(index)) { + index_type_ = IndexEnum::INDEX_FAISS_IVFFLAT; + } + + BinarySet + Serialize(const Config&) override; + + void + Load(const BinarySet&) override; + + void + Train(const DatasetPtr&, const Config&) override; + + void + Add(const DatasetPtr&, const Config&) override; + + void + AddWithoutIds(const DatasetPtr&, const Config&) override; + + DatasetPtr + Query(const DatasetPtr&, const Config&) override; + +#if 0 + DatasetPtr + QueryById(const DatasetPtr& dataset, const Config& config) override; +#endif + + int64_t + Count() override; + + int64_t + Dim() override; + + void + UpdateIndexSize() override; + +#if 0 + DatasetPtr + GetVectorById(const DatasetPtr& dataset, const Config& config) override; +#endif + + virtual void + Seal(); + + virtual VecIndexPtr + CopyCpuToGpu(const int64_t, const Config&); + + virtual void + GenGraph(const float* data, const int64_t k, GraphType& graph, const Config& config); + + protected: + virtual std::shared_ptr + GenParams(const Config&); + + virtual void + QueryImpl(int64_t, const float*, int64_t, float*, int64_t*, const Config&); + + void + SealImpl() override; + + protected: + std::mutex mutex_; +}; + +using IVFPtr = std::shared_ptr; + +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/vector_index/IndexIVFPQ.cpp b/core/src/index/knowhere/knowhere/index/vector_index/IndexIVFPQ.cpp new file mode 100644 index 0000000000..4865a85b98 --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/vector_index/IndexIVFPQ.cpp @@ -0,0 +1,100 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License + +#include + +#include +#include +#include +#ifdef MILVUS_GPU_VERSION +#include +#endif + +#include "knowhere/common/Exception.h" +#include "knowhere/common/Log.h" +#include "knowhere/index/vector_index/IndexIVFPQ.h" +#include "knowhere/index/vector_index/adapter/VectorAdapter.h" +#include "knowhere/index/vector_index/helpers/IndexParameter.h" +#ifdef MILVUS_GPU_VERSION +#include "knowhere/index/vector_index/gpu/IndexGPUIVF.h" +#include "knowhere/index/vector_index/gpu/IndexGPUIVFPQ.h" +#endif + +namespace milvus { +namespace knowhere { + +void +IVFPQ::Train(const DatasetPtr& dataset_ptr, const Config& config) { + GET_TENSOR_DATA_DIM(dataset_ptr) + + faiss::MetricType metric_type = GetMetricType(config[Metric::TYPE].get()); + faiss::Index* coarse_quantizer = new faiss::IndexFlat(dim, metric_type); + index_ = std::shared_ptr(new faiss::IndexIVFPQ( + coarse_quantizer, dim, config[IndexParams::nlist].get(), config[IndexParams::m].get(), + config[IndexParams::nbits].get(), metric_type)); + + index_->train(rows, reinterpret_cast(p_data)); +} + +VecIndexPtr +IVFPQ::CopyCpuToGpu(const int64_t device_id, const Config& config) { +#ifdef MILVUS_GPU_VERSION + if (auto res = FaissGpuResourceMgr::GetInstance().GetRes(device_id)) { + ResScope rs(res, device_id, false); + auto gpu_index = faiss::gpu::index_cpu_to_gpu(res->faiss_res.get(), device_id, index_.get()); + + std::shared_ptr device_index; + device_index.reset(gpu_index); + return std::make_shared(device_index, device_id, res); + } else { + KNOWHERE_THROW_MSG("CopyCpuToGpu Error, can't get gpu_resource"); + } +#else + KNOWHERE_THROW_MSG("Calling IVFPQ::CopyCpuToGpu when we are using CPU version"); +#endif +} + +std::shared_ptr +IVFPQ::GenParams(const Config& config) { + auto params = std::make_shared(); + params->nprobe = config[IndexParams::nprobe]; + // params->scan_table_threshold = config["scan_table_threhold"] + // params->polysemous_ht = config["polysemous_ht"] + // params->max_codes = config["max_codes"] + + return params; +} + +void +IVFPQ::UpdateIndexSize() { + if (!index_) { + KNOWHERE_THROW_MSG("index not initialize"); + } + auto ivfpq_index = dynamic_cast(index_.get()); + auto nb = ivfpq_index->invlists->compute_ntotal(); + auto code_size = ivfpq_index->code_size; + auto pq = ivfpq_index->pq; + auto nlist = ivfpq_index->nlist; + auto d = ivfpq_index->d; + + // ivf codes, ivf ids and quantizer + auto capacity = nb * code_size + nb * sizeof(int64_t) + nlist * d * sizeof(float); + auto centroid_table = pq.M * pq.ksub * pq.dsub * sizeof(float); + auto precomputed_table = nlist * pq.M * pq.ksub * sizeof(float); + if (precomputed_table > ivfpq_index->precomputed_table_max_bytes) { + // will not precompute table + precomputed_table = 0; + } + index_size_ = capacity + centroid_table + precomputed_table; +} + +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/vector_index/IndexIVFPQ.h b/core/src/index/knowhere/knowhere/index/vector_index/IndexIVFPQ.h new file mode 100644 index 0000000000..aed4072099 --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/vector_index/IndexIVFPQ.h @@ -0,0 +1,49 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License + +#pragma once + +#include +#include + +#include "knowhere/index/vector_index/IndexIVF.h" + +namespace milvus { +namespace knowhere { + +class IVFPQ : public IVF { + public: + IVFPQ() : IVF() { + index_type_ = IndexEnum::INDEX_FAISS_IVFPQ; + } + + explicit IVFPQ(std::shared_ptr index) : IVF(std::move(index)) { + index_type_ = IndexEnum::INDEX_FAISS_IVFPQ; + } + + void + Train(const DatasetPtr&, const Config&) override; + + VecIndexPtr + CopyCpuToGpu(const int64_t, const Config&) override; + + void + UpdateIndexSize() override; + + protected: + std::shared_ptr + GenParams(const Config& config) override; +}; + +using IVFPQPtr = std::shared_ptr; + +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/vector_index/IndexIVFSQ.cpp b/core/src/index/knowhere/knowhere/index/vector_index/IndexIVFSQ.cpp new file mode 100644 index 0000000000..7d81f21009 --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/vector_index/IndexIVFSQ.cpp @@ -0,0 +1,88 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License + +#include +#include + +#ifdef MILVUS_GPU_VERSION +#include +#include +#endif +#include +#include +#include +#include + +#include "knowhere/common/Exception.h" +#include "knowhere/index/vector_index/IndexIVFSQ.h" +#include "knowhere/index/vector_index/adapter/VectorAdapter.h" +#include "knowhere/index/vector_index/helpers/IndexParameter.h" +#ifdef MILVUS_GPU_VERSION +#include "knowhere/index/vector_index/gpu/IndexGPUIVFSQ.h" +#include "knowhere/index/vector_index/helpers/FaissGpuResourceMgr.h" +#endif + +namespace milvus { +namespace knowhere { + +void +IVFSQ::Train(const DatasetPtr& dataset_ptr, const Config& config) { + GET_TENSOR_DATA_DIM(dataset_ptr) + + // std::stringstream index_type; + // index_type << "IVF" << config[IndexParams::nlist] << "," + // << "SQ" << config[IndexParams::nbits]; + // index_ = std::shared_ptr( + // faiss::index_factory(dim, index_type.str().c_str(), GetMetricType(config[Metric::TYPE].get()))); + + faiss::MetricType metric_type = GetMetricType(config[Metric::TYPE].get()); + faiss::Index* coarse_quantizer = new faiss::IndexFlat(dim, metric_type); + index_ = std::shared_ptr(new faiss::IndexIVFScalarQuantizer( + coarse_quantizer, dim, config[IndexParams::nlist].get(), faiss::QuantizerType::QT_8bit, metric_type)); + + index_->train(rows, reinterpret_cast(p_data)); +} + +VecIndexPtr +IVFSQ::CopyCpuToGpu(const int64_t device_id, const Config& config) { +#ifdef MILVUS_GPU_VERSION + if (auto res = FaissGpuResourceMgr::GetInstance().GetRes(device_id)) { + ResScope rs(res, device_id, false); + + auto gpu_index = faiss::gpu::index_cpu_to_gpu(res->faiss_res.get(), device_id, index_.get()); + + std::shared_ptr device_index; + device_index.reset(gpu_index); + return std::make_shared(device_index, device_id, res); + } else { + KNOWHERE_THROW_MSG("CopyCpuToGpu Error, can't get gpu_resource"); + } +#else + KNOWHERE_THROW_MSG("Calling IVFSQ::CopyCpuToGpu when we are using CPU version"); +#endif +} + +void +IVFSQ::UpdateIndexSize() { + if (!index_) { + KNOWHERE_THROW_MSG("index not initialize"); + } + auto ivfsq_index = dynamic_cast(index_.get()); + auto nb = ivfsq_index->invlists->compute_ntotal(); + auto code_size = ivfsq_index->code_size; + auto nlist = ivfsq_index->nlist; + auto d = ivfsq_index->d; + // ivf codes, ivf ids, sq trained vectors and quantizer + index_size_ = nb * code_size + nb * sizeof(int64_t) + 2 * d * sizeof(float) + nlist * d * sizeof(float); +} + +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/vector_index/IndexIVFSQ.h b/core/src/index/knowhere/knowhere/index/vector_index/IndexIVFSQ.h new file mode 100644 index 0000000000..0c33eda569 --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/vector_index/IndexIVFSQ.h @@ -0,0 +1,45 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License + +#pragma once + +#include +#include + +#include "knowhere/index/vector_index/IndexIVF.h" + +namespace milvus { +namespace knowhere { + +class IVFSQ : public IVF { + public: + IVFSQ() : IVF() { + index_type_ = IndexEnum::INDEX_FAISS_IVFSQ8; + } + + explicit IVFSQ(std::shared_ptr index) : IVF(std::move(index)) { + index_type_ = IndexEnum::INDEX_FAISS_IVFSQ8; + } + + void + Train(const DatasetPtr&, const Config&) override; + + VecIndexPtr + CopyCpuToGpu(const int64_t, const Config&) override; + + void + UpdateIndexSize() override; +}; + +using IVFSQPtr = std::shared_ptr; + +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/vector_index/IndexNSG.cpp b/core/src/index/knowhere/knowhere/index/vector_index/IndexNSG.cpp new file mode 100644 index 0000000000..90ba063eeb --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/vector_index/IndexNSG.cpp @@ -0,0 +1,181 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License + +#include + +#include "knowhere/common/Exception.h" +#include "knowhere/common/Timer.h" +#include "knowhere/index/IndexType.h" +#include "knowhere/index/vector_index/IndexIDMAP.h" +#include "knowhere/index/vector_index/IndexIVF.h" +#include "knowhere/index/vector_index/IndexNSG.h" +#include "knowhere/index/vector_index/adapter/VectorAdapter.h" +#include "knowhere/index/vector_index/impl/nsg/NSG.h" +#include "knowhere/index/vector_index/impl/nsg/NSGIO.h" + +#ifdef MILVUS_GPU_VERSION +#include "knowhere/index/vector_index/gpu/IndexGPUIDMAP.h" +#include "knowhere/index/vector_index/gpu/IndexGPUIVF.h" +#include "knowhere/index/vector_index/helpers/Cloner.h" +#endif + +namespace milvus { +namespace knowhere { + +BinarySet +NSG::Serialize(const Config& config) { + if (!index_ || !index_->is_trained) { + KNOWHERE_THROW_MSG("index not initialize or trained"); + } + + try { + std::lock_guard lk(mutex_); + impl::NsgIndex* index = index_.get(); + + MemoryIOWriter writer; + impl::write_index(index, writer); + std::shared_ptr data(writer.data_); + + BinarySet res_set; + res_set.Append("NSG", data, writer.rp); + return res_set; + } catch (std::exception& e) { + KNOWHERE_THROW_MSG(e.what()); + } +} + +void +NSG::Load(const BinarySet& index_binary) { + try { + std::lock_guard lk(mutex_); + auto binary = index_binary.GetByName("NSG"); + + MemoryIOReader reader; + reader.total = binary->size; + reader.data_ = binary->data.get(); + + auto index = impl::read_index(reader); + index_.reset(index); + } catch (std::exception& e) { + KNOWHERE_THROW_MSG(e.what()); + } +} + +DatasetPtr +NSG::Query(const DatasetPtr& dataset_ptr, const Config& config) { + if (!index_ || !index_->is_trained) { + KNOWHERE_THROW_MSG("index not initialize or trained"); + } + + GET_TENSOR_DATA_DIM(dataset_ptr) + + try { + auto elems = rows * config[meta::TOPK].get(); + size_t p_id_size = sizeof(int64_t) * elems; + size_t p_dist_size = sizeof(float) * elems; + auto p_id = (int64_t*)malloc(p_id_size); + auto p_dist = (float*)malloc(p_dist_size); + + faiss::ConcurrentBitsetPtr blacklist = GetBlacklist(); + + impl::SearchParams s_params; + s_params.search_length = config[IndexParams::search_length]; + s_params.k = config[meta::TOPK]; + { + std::lock_guard lk(mutex_); + index_->Search((float*)p_data, nullptr, rows, dim, config[meta::TOPK].get(), p_dist, p_id, + s_params, blacklist); + } + + auto ret_ds = std::make_shared(); + ret_ds->Set(meta::IDS, p_id); + ret_ds->Set(meta::DISTANCE, p_dist); + return ret_ds; + } catch (std::exception& e) { + KNOWHERE_THROW_MSG(e.what()); + } +} + +void +NSG::Train(const DatasetPtr& dataset_ptr, const Config& config) { + auto idmap = std::make_shared(); + idmap->Train(dataset_ptr, config); + idmap->AddWithoutIds(dataset_ptr, config); + impl::Graph knng; + const float* raw_data = idmap->GetRawVectors(); + const int64_t k = config[IndexParams::knng].get(); +#ifdef MILVUS_GPU_VERSION + const int64_t device_id = config[knowhere::meta::DEVICEID].get(); + if (device_id == -1) { + auto preprocess_index = std::make_shared(); + preprocess_index->Train(dataset_ptr, config); + preprocess_index->AddWithoutIds(dataset_ptr, config); + preprocess_index->GenGraph(raw_data, k, knng, config); + } else { + auto gpu_idx = cloner::CopyCpuToGpu(idmap, device_id, config); + auto gpu_idmap = std::dynamic_pointer_cast(gpu_idx); + gpu_idmap->GenGraph(raw_data, k, knng, config); + } +#else + auto preprocess_index = std::make_shared(); + preprocess_index->Train(dataset_ptr, config); + preprocess_index->AddWithoutIds(dataset_ptr, config); + preprocess_index->GenGraph(raw_data, k, knng, config); +#endif + + impl::BuildParams b_params; + b_params.candidate_pool_size = config[IndexParams::candidate]; + b_params.out_degree = config[IndexParams::out_degree]; + b_params.search_length = config[IndexParams::search_length]; + + GET_TENSOR(dataset_ptr) + + impl::NsgIndex::Metric_Type metric; + auto metric_str = config[Metric::TYPE].get(); + if (metric_str == knowhere::Metric::IP) { + metric = impl::NsgIndex::Metric_Type::Metric_Type_IP; + } else if (metric_str == knowhere::Metric::L2) { + metric = impl::NsgIndex::Metric_Type::Metric_Type_L2; + } else { + KNOWHERE_THROW_MSG("Metric is not supported"); + } + + index_ = std::make_shared(dim, rows, metric); + index_->SetKnnGraph(knng); + index_->Build_with_ids(rows, (float*)p_data, (int64_t*)p_ids, b_params); +} + +int64_t +NSG::Count() { + if (!index_) { + KNOWHERE_THROW_MSG("index not initialize"); + } + return index_->ntotal; +} + +int64_t +NSG::Dim() { + if (!index_) { + KNOWHERE_THROW_MSG("index not initialize"); + } + return index_->dimension; +} + +void +NSG::UpdateIndexSize() { + if (!index_) { + KNOWHERE_THROW_MSG("index not initialize"); + } + index_size_ = index_->GetSize(); +} + +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/vector_index/IndexNSG.h b/core/src/index/knowhere/knowhere/index/vector_index/IndexNSG.h new file mode 100644 index 0000000000..9248184993 --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/vector_index/IndexNSG.h @@ -0,0 +1,79 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License + +#pragma once + +#include +#include + +#include "knowhere/common/Exception.h" +#include "knowhere/common/Log.h" +#include "knowhere/index/vector_index/VecIndex.h" + +namespace milvus { +namespace knowhere { + +namespace impl { +class NsgIndex; +} + +class NSG : public VecIndex { + public: + explicit NSG(const int64_t gpu_num = -1) : gpu_(gpu_num) { + if (gpu_ >= 0) { + index_mode_ = IndexMode::MODE_GPU; + } + index_type_ = IndexEnum::INDEX_NSG; + } + + BinarySet + Serialize(const Config&) override; + + void + Load(const BinarySet&) override; + + void + BuildAll(const DatasetPtr& dataset_ptr, const Config& config) override { + Train(dataset_ptr, config); + } + + void + Train(const DatasetPtr&, const Config&) override; + + void + Add(const DatasetPtr&, const Config&) override { + KNOWHERE_THROW_MSG("Incremental index is not supported"); + } + + void + AddWithoutIds(const DatasetPtr&, const Config&) override { + KNOWHERE_THROW_MSG("Addwithoutids is not supported"); + } + + DatasetPtr + Query(const DatasetPtr&, const Config&) override; + + int64_t + Count() override; + + int64_t + Dim() override; + + private: + std::mutex mutex_; + int64_t gpu_; + std::shared_ptr index_; +}; + +using NSGIndexPtr = std::shared_ptr(); + +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/vector_index/IndexRHNSW.cpp b/core/src/index/knowhere/knowhere/index/vector_index/IndexRHNSW.cpp new file mode 100644 index 0000000000..b9c62d8e19 --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/vector_index/IndexRHNSW.cpp @@ -0,0 +1,148 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include "knowhere/index/vector_index/IndexRHNSW.h" + +#include +#include +#include +#include +#include + +#include "faiss/BuilderSuspend.h" +#include "knowhere/common/Exception.h" +#include "knowhere/common/Log.h" +#include "knowhere/index/vector_index/adapter/VectorAdapter.h" +#include "knowhere/index/vector_index/helpers/FaissIO.h" + +namespace milvus { +namespace knowhere { + +BinarySet +IndexRHNSW::Serialize(const Config& config) { + if (!index_) { + KNOWHERE_THROW_MSG("index not initialize or trained"); + } + + try { + MemoryIOWriter writer; + writer.name = this->index_type() + "_Index"; + faiss::write_index(index_.get(), &writer); + std::shared_ptr data(writer.data_); + + BinarySet res_set; + res_set.Append(writer.name, data, writer.rp); + return res_set; + } catch (std::exception& e) { + KNOWHERE_THROW_MSG(e.what()); + } +} + +void +IndexRHNSW::Load(const BinarySet& index_binary) { + try { + MemoryIOReader reader; + reader.name = this->index_type() + "_Index"; + auto binary = index_binary.GetByName(reader.name); + + reader.total = static_cast(binary->size); + reader.data_ = binary->data.get(); + + auto idx = faiss::read_index(&reader); + index_.reset(idx); + } catch (std::exception& e) { + KNOWHERE_THROW_MSG(e.what()); + } +} + +void +IndexRHNSW::Train(const DatasetPtr& dataset_ptr, const Config& config) { + KNOWHERE_THROW_MSG("IndexRHNSW has no implementation of Train, please use IndexRHNSW(Flat/SQ/PQ) instead!"); +} + +void +IndexRHNSW::Add(const DatasetPtr& dataset_ptr, const Config& config) { + if (!index_) { + KNOWHERE_THROW_MSG("index not initialize"); + } + GET_TENSOR_DATA(dataset_ptr) + + index_->add(rows, reinterpret_cast(p_data)); +} + +DatasetPtr +IndexRHNSW::Query(const DatasetPtr& dataset_ptr, const Config& config) { + if (!index_) { + KNOWHERE_THROW_MSG("index not initialize or trained"); + } + GET_TENSOR_DATA(dataset_ptr) + + auto k = config[meta::TOPK].get(); + int64_t id_size = sizeof(int64_t) * k; + int64_t dist_size = sizeof(float) * k; + auto p_id = static_cast(malloc(id_size * rows)); + auto p_dist = static_cast(malloc(dist_size * rows)); + for (auto i = 0; i < k * rows; ++i) { + p_id[i] = -1; + p_dist[i] = -1; + } + + auto real_index = dynamic_cast(index_.get()); + faiss::ConcurrentBitsetPtr blacklist = GetBlacklist(); + + real_index->hnsw.efSearch = (config[IndexParams::ef]); + real_index->search(rows, reinterpret_cast(p_data), k, p_dist, p_id, blacklist); + + auto ret_ds = std::make_shared(); + ret_ds->Set(meta::IDS, p_id); + ret_ds->Set(meta::DISTANCE, p_dist); + return ret_ds; +} + +int64_t +IndexRHNSW::Count() { + if (!index_) { + KNOWHERE_THROW_MSG("index not initialize"); + } + return index_->ntotal; +} + +int64_t +IndexRHNSW::Dim() { + if (!index_) { + KNOWHERE_THROW_MSG("index not initialize"); + } + return index_->d; +} + +void +IndexRHNSW::UpdateIndexSize() { + KNOWHERE_THROW_MSG( + "IndexRHNSW has no implementation of UpdateIndexSize, please use IndexRHNSW(Flat/SQ/PQ) instead!"); +} + +/* +BinarySet +IndexRHNSW::SerializeImpl(const milvus::knowhere::IndexType &type) { return BinarySet(); } + +void +IndexRHNSW::SealImpl() {} + +void +IndexRHNSW::LoadImpl(const milvus::knowhere::BinarySet &, const milvus::knowhere::IndexType &type) {} +*/ + +void +IndexRHNSW::AddWithoutIds(const milvus::knowhere::DatasetPtr& dataset, const milvus::knowhere::Config& config) { + KNOWHERE_THROW_MSG("IndexRHNSW has no implementation of AddWithoutIds, please use IndexRHNSW(Flat/SQ/PQ) instead!"); +} +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/vector_index/IndexRHNSW.h b/core/src/index/knowhere/knowhere/index/vector_index/IndexRHNSW.h new file mode 100644 index 0000000000..7c5a4a6eaf --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/vector_index/IndexRHNSW.h @@ -0,0 +1,67 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#pragma once + +#include +#include +#include + +#include "knowhere/common/Exception.h" +#include "knowhere/index/vector_index/FaissBaseIndex.h" +#include "knowhere/index/vector_index/VecIndex.h" +#include "knowhere/index/vector_index/adapter/VectorAdapter.h" + +#include +#include "faiss/IndexRHNSW.h" + +namespace milvus { +namespace knowhere { + +class IndexRHNSW : public VecIndex, public FaissBaseIndex { + public: + IndexRHNSW() : FaissBaseIndex(nullptr) { + index_type_ = IndexEnum::INVALID; + } + + explicit IndexRHNSW(std::shared_ptr index) : FaissBaseIndex(std::move(index)) { + index_type_ = IndexEnum::INVALID; + } + + BinarySet + Serialize(const Config& config) override; + + void + Load(const BinarySet& index_binary) override; + + void + Train(const DatasetPtr& dataset_ptr, const Config& config) override; + + void + Add(const DatasetPtr& dataset_ptr, const Config& config) override; + + void + AddWithoutIds(const DatasetPtr&, const Config&) override; + + DatasetPtr + Query(const DatasetPtr& dataset_ptr, const Config& config) override; + + int64_t + Count() override; + + int64_t + Dim() override; + + void + UpdateIndexSize() override; +}; +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/vector_index/IndexRHNSWFlat.cpp b/core/src/index/knowhere/knowhere/index/vector_index/IndexRHNSWFlat.cpp new file mode 100644 index 0000000000..ee81da95b4 --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/vector_index/IndexRHNSWFlat.cpp @@ -0,0 +1,107 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include "knowhere/index/vector_index/IndexRHNSWFlat.h" + +#include +#include +#include +#include +#include +#include + +#include "faiss/BuilderSuspend.h" +#include "knowhere/common/Exception.h" +#include "knowhere/common/Log.h" +#include "knowhere/index/vector_index/adapter/VectorAdapter.h" +#include "knowhere/index/vector_index/helpers/FaissIO.h" + +namespace milvus { +namespace knowhere { + +IndexRHNSWFlat::IndexRHNSWFlat(int d, int M, milvus::knowhere::MetricType metric) { + faiss::MetricType mt = + metric == Metric::L2 ? faiss::MetricType::METRIC_L2 : faiss::MetricType::METRIC_INNER_PRODUCT; + index_ = std::shared_ptr(new faiss::IndexRHNSWFlat(d, M, mt)); +} + +BinarySet +IndexRHNSWFlat::Serialize(const Config& config) { + if (!index_) { + KNOWHERE_THROW_MSG("index not initialize or trained"); + } + + try { + auto res_set = IndexRHNSW::Serialize(config); + MemoryIOWriter writer; + writer.name = this->index_type() + "_Data"; + auto real_idx = dynamic_cast(index_.get()); + if (real_idx == nullptr) { + KNOWHERE_THROW_MSG("dynamic_cast(index_) failed during Serialize!"); + } + auto storage_index = dynamic_cast(real_idx->storage); + faiss::write_index(storage_index, &writer); + std::shared_ptr data(writer.data_); + + res_set.Append(writer.name, data, writer.rp); + return res_set; + } catch (std::exception& e) { + KNOWHERE_THROW_MSG(e.what()); + } +} + +void +IndexRHNSWFlat::Load(const BinarySet& index_binary) { + try { + IndexRHNSW::Load(index_binary); + MemoryIOReader reader; + reader.name = this->index_type() + "_Data"; + auto binary = index_binary.GetByName(reader.name); + + reader.total = static_cast(binary->size); + reader.data_ = binary->data.get(); + + auto real_idx = dynamic_cast(index_.get()); + if (real_idx == nullptr) { + KNOWHERE_THROW_MSG("dynamic_cast(index_) failed during Load!"); + } + real_idx->storage = faiss::read_index(&reader); + real_idx->init_hnsw(); + } catch (std::exception& e) { + KNOWHERE_THROW_MSG(e.what()); + } +} + +void +IndexRHNSWFlat::Train(const DatasetPtr& dataset_ptr, const Config& config) { + try { + GET_TENSOR_DATA_DIM(dataset_ptr) + faiss::MetricType metric_type = GetMetricType(config[Metric::TYPE].get()); + + auto idx = new faiss::IndexRHNSWFlat(int(dim), config[IndexParams::M], metric_type); + idx->hnsw.efConstruction = config[IndexParams::efConstruction]; + index_ = std::shared_ptr(idx); + index_->train(rows, reinterpret_cast(p_data)); + } catch (std::exception& e) { + KNOWHERE_THROW_MSG(e.what()); + } +} + +void +IndexRHNSWFlat::UpdateIndexSize() { + if (!index_) { + KNOWHERE_THROW_MSG("index not initialize"); + } + index_size_ = dynamic_cast(index_.get())->cal_size(); +} + +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/vector_index/IndexRHNSWFlat.h b/core/src/index/knowhere/knowhere/index/vector_index/IndexRHNSWFlat.h new file mode 100644 index 0000000000..37dab79cdf --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/vector_index/IndexRHNSWFlat.h @@ -0,0 +1,51 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#pragma once + +#include +#include +#include + +#include "IndexRHNSW.h" +#include "knowhere/common/Exception.h" +#include "knowhere/index/IndexType.h" + +namespace milvus { +namespace knowhere { + +class IndexRHNSWFlat : public IndexRHNSW { + public: + IndexRHNSWFlat() : IndexRHNSW() { + index_type_ = IndexEnum::INDEX_RHNSWFlat; + } + + explicit IndexRHNSWFlat(std::shared_ptr index) : IndexRHNSW(std::move(index)) { + index_type_ = IndexEnum::INDEX_RHNSWFlat; + } + + IndexRHNSWFlat(int d, int M, MetricType metric = Metric::L2); + + BinarySet + Serialize(const Config& config) override; + + void + Load(const BinarySet& index_binary) override; + + void + Train(const DatasetPtr& dataset_ptr, const Config& config) override; + + void + UpdateIndexSize() override; +}; + +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/vector_index/IndexRHNSWPQ.cpp b/core/src/index/knowhere/knowhere/index/vector_index/IndexRHNSWPQ.cpp new file mode 100644 index 0000000000..cc2e8f020f --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/vector_index/IndexRHNSWPQ.cpp @@ -0,0 +1,102 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include "knowhere/index/vector_index/IndexRHNSWPQ.h" + +#include +#include +#include +#include +#include + +#include "faiss/BuilderSuspend.h" +#include "knowhere/common/Exception.h" +#include "knowhere/common/Log.h" +#include "knowhere/index/vector_index/adapter/VectorAdapter.h" +#include "knowhere/index/vector_index/helpers/FaissIO.h" + +namespace milvus { +namespace knowhere { + +IndexRHNSWPQ::IndexRHNSWPQ(int d, int pq_m, int M) { + index_ = std::shared_ptr(new faiss::IndexRHNSWPQ(d, pq_m, M)); +} + +BinarySet +IndexRHNSWPQ::Serialize(const Config& config) { + if (!index_) { + KNOWHERE_THROW_MSG("index not initialize or trained"); + } + + try { + auto res_set = IndexRHNSW::Serialize(config); + MemoryIOWriter writer; + writer.name = QUANTIZATION_DATA; + auto real_idx = dynamic_cast(index_.get()); + if (real_idx == nullptr) { + KNOWHERE_THROW_MSG("dynamic_cast(index_) failed during Serialize!"); + } + faiss::write_index(real_idx->storage, &writer); + std::shared_ptr data(writer.data_); + + res_set.Append(writer.name, data, writer.rp); + return res_set; + } catch (std::exception& e) { + KNOWHERE_THROW_MSG(e.what()); + } +} + +void +IndexRHNSWPQ::Load(const BinarySet& index_binary) { + try { + IndexRHNSW::Load(index_binary); + MemoryIOReader reader; + reader.name = QUANTIZATION_DATA; + auto binary = index_binary.GetByName(reader.name); + + reader.total = static_cast(binary->size); + reader.data_ = binary->data.get(); + + auto real_idx = dynamic_cast(index_.get()); + if (real_idx == nullptr) { + KNOWHERE_THROW_MSG("dynamic_cast(index_) failed during Load!"); + } + real_idx->storage = faiss::read_index(&reader); + real_idx->init_hnsw(); + } catch (std::exception& e) { + KNOWHERE_THROW_MSG(e.what()); + } +} + +void +IndexRHNSWPQ::Train(const DatasetPtr& dataset_ptr, const Config& config) { + try { + GET_TENSOR_DATA_DIM(dataset_ptr) + + auto idx = new faiss::IndexRHNSWPQ(int(dim), config[IndexParams::PQM], config[IndexParams::M]); + idx->hnsw.efConstruction = config[IndexParams::efConstruction]; + index_ = std::shared_ptr(idx); + index_->train(rows, reinterpret_cast(p_data)); + } catch (std::exception& e) { + KNOWHERE_THROW_MSG(e.what()); + } +} + +void +IndexRHNSWPQ::UpdateIndexSize() { + if (!index_) { + KNOWHERE_THROW_MSG("index not initialize"); + } + index_size_ = dynamic_cast(index_.get())->cal_size(); +} + +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/vector_index/IndexRHNSWPQ.h b/core/src/index/knowhere/knowhere/index/vector_index/IndexRHNSWPQ.h new file mode 100644 index 0000000000..5e98129489 --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/vector_index/IndexRHNSWPQ.h @@ -0,0 +1,52 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#pragma once + +#include +#include +#include + +#include "IndexRHNSW.h" +#include "knowhere/common/Exception.h" + +namespace milvus { +namespace knowhere { + +class IndexRHNSWPQ : public IndexRHNSW { + public: + IndexRHNSWPQ() : IndexRHNSW() { + index_type_ = IndexEnum::INDEX_RHNSWPQ; + } + + explicit IndexRHNSWPQ(std::shared_ptr index) : IndexRHNSW(std::move(index)) { + index_type_ = IndexEnum::INDEX_RHNSWPQ; + } + + IndexRHNSWPQ(int d, int pq_m, int M); + + BinarySet + Serialize(const Config& config) override; + + void + Load(const BinarySet& index_binary) override; + + void + Train(const DatasetPtr& dataset_ptr, const Config& config) override; + + void + UpdateIndexSize() override; + + private: +}; + +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/vector_index/IndexRHNSWSQ.cpp b/core/src/index/knowhere/knowhere/index/vector_index/IndexRHNSWSQ.cpp new file mode 100644 index 0000000000..e352d6fa48 --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/vector_index/IndexRHNSWSQ.cpp @@ -0,0 +1,107 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include "knowhere/index/vector_index/IndexRHNSWSQ.h" + +#include +#include +#include +#include +#include +#include + +#include "faiss/BuilderSuspend.h" +#include "knowhere/common/Exception.h" +#include "knowhere/common/Log.h" +#include "knowhere/index/vector_index/adapter/VectorAdapter.h" +#include "knowhere/index/vector_index/helpers/FaissIO.h" + +namespace milvus { +namespace knowhere { + +IndexRHNSWSQ::IndexRHNSWSQ(int d, faiss::QuantizerType qtype, int M, milvus::knowhere::MetricType metric) { + faiss::MetricType mt = + metric == Metric::L2 ? faiss::MetricType::METRIC_L2 : faiss::MetricType::METRIC_INNER_PRODUCT; + index_ = std::shared_ptr(new faiss::IndexRHNSWSQ(d, qtype, M, mt)); +} + +BinarySet +IndexRHNSWSQ::Serialize(const Config& config) { + if (!index_) { + KNOWHERE_THROW_MSG("index not initialize or trained"); + } + + try { + auto res_set = IndexRHNSW::Serialize(config); + MemoryIOWriter writer; + writer.name = QUANTIZATION_DATA; + auto real_idx = dynamic_cast(index_.get()); + if (real_idx == nullptr) { + KNOWHERE_THROW_MSG("dynamic_cast(index_) failed during Serialize!"); + } + faiss::write_index(real_idx->storage, &writer); + std::shared_ptr data(writer.data_); + + res_set.Append(writer.name, data, writer.rp); + return res_set; + } catch (std::exception& e) { + KNOWHERE_THROW_MSG(e.what()); + } +} + +void +IndexRHNSWSQ::Load(const BinarySet& index_binary) { + try { + IndexRHNSW::Load(index_binary); + MemoryIOReader reader; + reader.name = QUANTIZATION_DATA; + auto binary = index_binary.GetByName(reader.name); + + reader.total = static_cast(binary->size); + reader.data_ = binary->data.get(); + + auto real_idx = dynamic_cast(index_.get()); + if (real_idx == nullptr) { + KNOWHERE_THROW_MSG("dynamic_cast(index_) failed during Load!"); + } + real_idx->storage = faiss::read_index(&reader); + real_idx->init_hnsw(); + } catch (std::exception& e) { + KNOWHERE_THROW_MSG(e.what()); + } +} + +void +IndexRHNSWSQ::Train(const DatasetPtr& dataset_ptr, const Config& config) { + try { + GET_TENSOR_DATA_DIM(dataset_ptr) + faiss::MetricType metric_type = GetMetricType(config[Metric::TYPE].get()); + + auto idx = + new faiss::IndexRHNSWSQ(int(dim), faiss::QuantizerType::QT_8bit, config[IndexParams::M], metric_type); + idx->hnsw.efConstruction = config[IndexParams::efConstruction]; + index_ = std::shared_ptr(idx); + index_->train(rows, static_cast(p_data)); + } catch (std::exception& e) { + KNOWHERE_THROW_MSG(e.what()); + } +} + +void +IndexRHNSWSQ::UpdateIndexSize() { + if (!index_) { + KNOWHERE_THROW_MSG("index not initialize"); + } + index_size_ = dynamic_cast(index_.get())->cal_size(); +} + +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/vector_index/IndexRHNSWSQ.h b/core/src/index/knowhere/knowhere/index/vector_index/IndexRHNSWSQ.h new file mode 100644 index 0000000000..04a8fe15aa --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/vector_index/IndexRHNSWSQ.h @@ -0,0 +1,52 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#pragma once + +#include +#include +#include + +#include "IndexRHNSW.h" +#include "knowhere/common/Exception.h" + +namespace milvus { +namespace knowhere { + +class IndexRHNSWSQ : public IndexRHNSW { + public: + IndexRHNSWSQ() : IndexRHNSW() { + index_type_ = IndexEnum::INDEX_RHNSWSQ; + } + + explicit IndexRHNSWSQ(std::shared_ptr index) : IndexRHNSW(std::move(index)) { + index_type_ = IndexEnum::INDEX_RHNSWSQ; + } + + IndexRHNSWSQ(int d, faiss::QuantizerType qtype, int M, MetricType metric = Metric::L2); + + BinarySet + Serialize(const Config& config) override; + + void + Load(const BinarySet& index_binary) override; + + void + Train(const DatasetPtr& dataset_ptr, const Config& config) override; + + void + UpdateIndexSize() override; + + private: +}; + +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/vector_index/IndexSPTAG.cpp b/core/src/index/knowhere/knowhere/index/vector_index/IndexSPTAG.cpp new file mode 100644 index 0000000000..2dc86678f5 --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/vector_index/IndexSPTAG.cpp @@ -0,0 +1,241 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include +#include +#include + +#include +#include +#include + +#undef mkdir + +#include "knowhere/common/Exception.h" +#include "knowhere/index/vector_index/IndexSPTAG.h" +#include "knowhere/index/vector_index/adapter/SptagAdapter.h" +#include "knowhere/index/vector_index/adapter/VectorAdapter.h" +#include "knowhere/index/vector_index/helpers/SPTAGParameterMgr.h" + +namespace milvus { +namespace knowhere { + +CPUSPTAGRNG::CPUSPTAGRNG(const std::string& IndexType) { + if (IndexType == "KDT") { + index_ptr_ = SPTAG::VectorIndex::CreateInstance(SPTAG::IndexAlgoType::KDT, SPTAG::VectorValueType::Float); + index_ptr_->SetParameter("DistCalcMethod", "L2"); + index_type_ = IndexEnum::INDEX_SPTAG_KDT_RNT; + } else { + index_ptr_ = SPTAG::VectorIndex::CreateInstance(SPTAG::IndexAlgoType::BKT, SPTAG::VectorValueType::Float); + index_ptr_->SetParameter("DistCalcMethod", "L2"); + index_type_ = IndexEnum::INDEX_SPTAG_BKT_RNT; + } +} + +BinarySet +CPUSPTAGRNG::Serialize(const Config& config) { + std::string index_config; + std::vector index_blobs; + + std::shared_ptr> buffersize = index_ptr_->CalculateBufferSize(); + std::vector res(buffersize->size() + 1); + for (uint64_t i = 1; i < res.size(); i++) { + res[i] = new char[buffersize->at(i - 1)]; + auto ptr = &res[i][0]; + index_blobs.emplace_back(SPTAG::ByteArray((std::uint8_t*)ptr, buffersize->at(i - 1), false)); + } + + index_ptr_->SaveIndex(index_config, index_blobs); + + size_t length = index_config.length(); + char* cstr = new char[length]; + snprintf(cstr, length, "%s", index_config.c_str()); + + BinarySet binary_set; + std::shared_ptr sample; + sample.reset(static_cast(index_blobs[0].Data())); + std::shared_ptr tree; + tree.reset(static_cast(index_blobs[1].Data())); + std::shared_ptr graph; + graph.reset(static_cast(index_blobs[2].Data())); + std::shared_ptr deleteid; + deleteid.reset(static_cast(index_blobs[3].Data())); + std::shared_ptr metadata1; + metadata1.reset(static_cast(index_blobs[4].Data())); + std::shared_ptr metadata2; + metadata2.reset(static_cast(index_blobs[5].Data())); + std::shared_ptr x_cfg; + x_cfg.reset(static_cast((void*)cstr)); + + binary_set.Append("samples", sample, index_blobs[0].Length()); + binary_set.Append("tree", tree, index_blobs[1].Length()); + binary_set.Append("deleteid", deleteid, index_blobs[3].Length()); + binary_set.Append("metadata1", metadata1, index_blobs[4].Length()); + binary_set.Append("metadata2", metadata2, index_blobs[5].Length()); + binary_set.Append("config", x_cfg, length); + binary_set.Append("graph", graph, index_blobs[2].Length()); + + return binary_set; +} + +void +CPUSPTAGRNG::Load(const BinarySet& binary_set) { + std::string index_config; + std::vector index_blobs; + + auto samples = binary_set.GetByName("samples"); + index_blobs.push_back(SPTAG::ByteArray(samples->data.get(), samples->size, false)); + + auto tree = binary_set.GetByName("tree"); + index_blobs.push_back(SPTAG::ByteArray(tree->data.get(), tree->size, false)); + + auto graph = binary_set.GetByName("graph"); + index_blobs.push_back(SPTAG::ByteArray(graph->data.get(), graph->size, false)); + + auto deleteid = binary_set.GetByName("deleteid"); + index_blobs.push_back(SPTAG::ByteArray(deleteid->data.get(), deleteid->size, false)); + + auto metadata1 = binary_set.GetByName("metadata1"); + index_blobs.push_back(SPTAG::ByteArray(CopyBinary(metadata1), metadata1->size, true)); + + auto metadata2 = binary_set.GetByName("metadata2"); + index_blobs.push_back(SPTAG::ByteArray(metadata2->data.get(), metadata2->size, false)); + + auto config = binary_set.GetByName("config"); + index_config = reinterpret_cast(config->data.get()); + + index_ptr_->LoadIndex(index_config, index_blobs); +} + +void +CPUSPTAGRNG::Train(const DatasetPtr& origin, const Config& train_config) { + SetParameters(train_config); + + DatasetPtr dataset = origin; + + auto vectorset = ConvertToVectorSet(dataset); + auto metaset = ConvertToMetadataSet(dataset); + index_ptr_->BuildIndex(vectorset, metaset); +} + +void +CPUSPTAGRNG::SetParameters(const Config& config) { +#define Assign(param_name, str_name) \ + index_ptr_->SetParameter(str_name, std::to_string(build_cfg[param_name].get())) + + if (index_type_ == IndexEnum::INDEX_SPTAG_KDT_RNT) { + auto build_cfg = SPTAGParameterMgr::GetInstance().GetKDTParameters(); + + Assign("kdtnumber", "KDTNumber"); + Assign("numtopdimensionkdtsplit", "NumTopDimensionKDTSplit"); + Assign("samples", "Samples"); + Assign("tptnumber", "TPTNumber"); + Assign("tptleafsize", "TPTLeafSize"); + Assign("numtopdimensiontptsplit", "NumTopDimensionTPTSplit"); + Assign("neighborhoodsize", "NeighborhoodSize"); + Assign("graphneighborhoodscale", "GraphNeighborhoodScale"); + Assign("graphcefscale", "GraphCEFScale"); + Assign("refineiterations", "RefineIterations"); + Assign("cef", "CEF"); + Assign("maxcheckforrefinegraph", "MaxCheckForRefineGraph"); + Assign("numofthreads", "NumberOfThreads"); + Assign("maxcheck", "MaxCheck"); + Assign("thresholdofnumberofcontinuousnobetterpropagation", "ThresholdOfNumberOfContinuousNoBetterPropagation"); + Assign("numberofinitialdynamicpivots", "NumberOfInitialDynamicPivots"); + Assign("numberofotherdynamicpivots", "NumberOfOtherDynamicPivots"); + } else { + auto build_cfg = SPTAGParameterMgr::GetInstance().GetBKTParameters(); + + Assign("bktnumber", "BKTNumber"); + Assign("bktkmeansk", "BKTKMeansK"); + Assign("bktleafsize", "BKTLeafSize"); + Assign("samples", "Samples"); + Assign("tptnumber", "TPTNumber"); + Assign("tptleafsize", "TPTLeafSize"); + Assign("numtopdimensiontptsplit", "NumTopDimensionTPTSplit"); + Assign("neighborhoodsize", "NeighborhoodSize"); + Assign("graphneighborhoodscale", "GraphNeighborhoodScale"); + Assign("graphcefscale", "GraphCEFScale"); + Assign("refineiterations", "RefineIterations"); + Assign("cef", "CEF"); + Assign("maxcheckforrefinegraph", "MaxCheckForRefineGraph"); + Assign("numofthreads", "NumberOfThreads"); + Assign("maxcheck", "MaxCheck"); + Assign("thresholdofnumberofcontinuousnobetterpropagation", "ThresholdOfNumberOfContinuousNoBetterPropagation"); + Assign("numberofinitialdynamicpivots", "NumberOfInitialDynamicPivots"); + Assign("numberofotherdynamicpivots", "NumberOfOtherDynamicPivots"); + } +} + +DatasetPtr +CPUSPTAGRNG::Query(const DatasetPtr& dataset_ptr, const Config& config) { + SetParameters(config); + + float* p_data = (float*)dataset_ptr->Get(meta::TENSOR); + for (auto i = 0; i < 10; ++i) { + for (auto j = 0; j < 10; ++j) { + std::cout << p_data[i * 10 + j] << " "; + } + std::cout << std::endl; + } + std::vector query_results = ConvertToQueryResult(dataset_ptr, config); + +#pragma omp parallel for + for (auto i = 0; i < query_results.size(); ++i) { + auto target = (float*)query_results[i].GetTarget(); + std::cout << target[0] << ", " << target[1] << ", " << target[2] << std::endl; + index_ptr_->SearchIndex(query_results[i]); + } + + return ConvertToDataset(query_results); +} + +int64_t +CPUSPTAGRNG::Count() { + if (!index_ptr_) { + KNOWHERE_THROW_MSG("index not initialize"); + } + return index_ptr_->GetNumSamples(); +} + +int64_t +CPUSPTAGRNG::Dim() { + if (!index_ptr_) { + KNOWHERE_THROW_MSG("index not initialize"); + } + return index_ptr_->GetFeatureDim(); +} + +void +CPUSPTAGRNG::UpdateIndexSize() { + if (!index_ptr_) { + KNOWHERE_THROW_MSG("index not initialize"); + } + index_size_ = index_ptr_->GetIndexSize(); +} + +// void +// CPUSPTAGRNG::Add(const DatasetPtr& origin, const Config& add_config) { +// SetParameters(add_config); +// DatasetPtr dataset = origin->Clone(); + +// // if (index_ptr_->GetDistCalcMethod() == SPTAG::DistCalcMethod::Cosine +// // && preprocessor_) { +// // preprocessor_->Preprocess(dataset); +// //} + +// auto vectorset = ConvertToVectorSet(dataset); +// auto metaset = ConvertToMetadataSet(dataset); +// index_ptr_->AddIndex(vectorset, metaset); +// } + +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/vector_index/IndexSPTAG.h b/core/src/index/knowhere/knowhere/index/vector_index/IndexSPTAG.h new file mode 100644 index 0000000000..bfb5b8a5da --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/vector_index/IndexSPTAG.h @@ -0,0 +1,77 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#pragma once + +#include + +#include +#include +#include + +#include "knowhere/index/vector_index/VecIndex.h" + +namespace milvus { +namespace knowhere { + +class CPUSPTAGRNG : public VecIndex { + public: + explicit CPUSPTAGRNG(const std::string& IndexType); + + public: + BinarySet + Serialize(const Config& config) override; + + void + Load(const BinarySet& index_array) override; + + void + BuildAll(const DatasetPtr& dataset_ptr, const Config& config) override { + Train(dataset_ptr, config); + } + + void + Train(const DatasetPtr& dataset_ptr, const Config& config) override; + + void + Add(const DatasetPtr&, const Config&) override { + KNOWHERE_THROW_MSG("Incremental index is not supported"); + } + + void + AddWithoutIds(const DatasetPtr&, const Config&) override { + KNOWHERE_THROW_MSG("Incremental index is not supported"); + } + + DatasetPtr + Query(const DatasetPtr& dataset_ptr, const Config& config) override; + + int64_t + Count() override; + + int64_t + Dim() override; + + void + UpdateIndexSize() override; + + private: + void + SetParameters(const Config& config); + + private: + std::shared_ptr index_ptr_; +}; + +using CPUSPTAGRNGPtr = std::shared_ptr; + +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/vector_index/VecIndex.h b/core/src/index/knowhere/knowhere/index/vector_index/VecIndex.h new file mode 100644 index 0000000000..ab7440de9a --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/vector_index/VecIndex.h @@ -0,0 +1,155 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License + +#pragma once + +#include +#include +#include +#include + +#include "knowhere/common/Dataset.h" +#include "knowhere/common/Exception.h" +#include "knowhere/common/Typedef.h" +#include "knowhere/index/Index.h" +#include "knowhere/index/IndexType.h" + +namespace milvus { +namespace knowhere { + +#define RAW_DATA "RAW_DATA" +#define QUANTIZATION_DATA "QUANTIZATION_DATA" + +class VecIndex : public Index { + public: + virtual void + BuildAll(const DatasetPtr& dataset_ptr, const Config& config) { + Train(dataset_ptr, config); + Add(dataset_ptr, config); + } + + virtual void + Train(const DatasetPtr& dataset, const Config& config) = 0; + + virtual void + Add(const DatasetPtr& dataset, const Config& config) = 0; + + virtual void + AddWithoutIds(const DatasetPtr& dataset, const Config& config) = 0; + + virtual DatasetPtr + Query(const DatasetPtr& dataset, const Config& config) = 0; + +#if 0 + virtual DatasetPtr + QueryById(const DatasetPtr& dataset, const Config& config) { + return nullptr; + } +#endif + + // virtual DatasetPtr + // QueryByRange(const DatasetPtr&, const Config&) = 0; + // + // virtual MetricType + // metric_type() = 0; + + virtual int64_t + Dim() = 0; + + virtual int64_t + Count() = 0; + + virtual IndexType + index_type() const { + return index_type_; + } + + virtual IndexMode + index_mode() const { + return index_mode_; + } + +#if 0 + virtual DatasetPtr + GetVectorById(const DatasetPtr& dataset, const Config& config) { + return nullptr; + } +#endif + + faiss::ConcurrentBitsetPtr + GetBlacklist() { + return bitset_; + } + + void + SetBlacklist(faiss::ConcurrentBitsetPtr bitset_ptr) { + bitset_ = std::move(bitset_ptr); + } + + const std::vector& + GetUids() const { + return uids_; + } + + void + SetUids(std::vector& uids) { + uids_.clear(); + uids_.swap(uids); + } + + size_t + BlacklistSize() { + if (bitset_) { + return bitset_->u8size() * sizeof(uint8_t); + } else { + return 0; + } + } + + size_t + UidsSize() { + return uids_.size() * sizeof(IDType); + } + + virtual int64_t + IndexSize() { + if (index_size_ == -1) { + KNOWHERE_THROW_MSG("Index size not set"); + } + return index_size_; + } + + void + SetIndexSize(int64_t size) { + index_size_ = size; + } + + virtual void + UpdateIndexSize() { + } + + int64_t + Size() override { + return BlacklistSize() + UidsSize() + IndexSize(); + } + + protected: + IndexType index_type_ = ""; + IndexMode index_mode_ = IndexMode::MODE_CPU; + faiss::ConcurrentBitsetPtr bitset_ = nullptr; + std::vector uids_; + int64_t index_size_ = -1; +}; + +using VecIndexPtr = std::shared_ptr; + +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/vector_index/VecIndexFactory.cpp b/core/src/index/knowhere/knowhere/index/vector_index/VecIndexFactory.cpp new file mode 100644 index 0000000000..05674e9c9e --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/vector_index/VecIndexFactory.cpp @@ -0,0 +1,108 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License + +#include "knowhere/index/vector_index/VecIndexFactory.h" + +#include "knowhere/common/Exception.h" +#include "knowhere/common/Log.h" +#include "knowhere/index/vector_index/IndexAnnoy.h" +#include "knowhere/index/vector_index/IndexBinaryIDMAP.h" +#include "knowhere/index/vector_index/IndexBinaryIVF.h" +#include "knowhere/index/vector_index/IndexHNSW.h" +#include "knowhere/index/vector_index/IndexIDMAP.h" +#include "knowhere/index/vector_index/IndexIVF.h" +#include "knowhere/index/vector_index/IndexIVFPQ.h" +#include "knowhere/index/vector_index/IndexIVFSQ.h" +#include "knowhere/index/vector_index/IndexRHNSWFlat.h" +#include "knowhere/index/vector_index/IndexRHNSWPQ.h" +#include "knowhere/index/vector_index/IndexRHNSWSQ.h" +#include "knowhere/index/vector_offset_index/IndexIVF_NM.h" +#include "knowhere/index/vector_offset_index/IndexNSG_NM.h" + +#ifdef MILVUS_SUPPORT_SPTAG +#include "knowhere/index/vector_index/IndexSPTAG.h" +#endif + +#ifdef MILVUS_GPU_VERSION +#include +#include "knowhere/index/vector_index/gpu/IndexGPUIDMAP.h" +#include "knowhere/index/vector_index/gpu/IndexGPUIVF.h" +#include "knowhere/index/vector_index/gpu/IndexGPUIVFPQ.h" +#include "knowhere/index/vector_index/gpu/IndexGPUIVFSQ.h" +#include "knowhere/index/vector_index/gpu/IndexIVFSQHybrid.h" +#include "knowhere/index/vector_index/helpers/Cloner.h" +#include "knowhere/index/vector_offset_index/gpu/IndexGPUIVF_NM.h" +#endif + +namespace milvus { +namespace knowhere { + +VecIndexPtr +VecIndexFactory::CreateVecIndex(const IndexType& type, const IndexMode mode) { +#ifdef MILVUS_GPU_VERSION + auto gpu_device = -1; // TODO: remove hardcode here, get from invoker +#endif + if (type == IndexEnum::INDEX_FAISS_IDMAP) { + return std::make_shared(); + } else if (type == IndexEnum::INDEX_FAISS_IVFFLAT) { +#ifdef MILVUS_GPU_VERSION + if (mode == IndexMode::MODE_GPU) { + return std::make_shared(gpu_device); + } +#endif + return std::make_shared(); + } else if (type == IndexEnum::INDEX_FAISS_IVFPQ) { +#ifdef MILVUS_GPU_VERSION + if (mode == IndexMode::MODE_GPU) { + return std::make_shared(gpu_device); + } +#endif + return std::make_shared(); + } else if (type == IndexEnum::INDEX_FAISS_IVFSQ8) { +#ifdef MILVUS_GPU_VERSION + if (mode == IndexMode::MODE_GPU) { + return std::make_shared(gpu_device); + } +#endif + return std::make_shared(); +#ifdef MILVUS_GPU_VERSION + } else if (type == IndexEnum::INDEX_FAISS_IVFSQ8H) { + return std::make_shared(gpu_device); +#endif + } else if (type == IndexEnum::INDEX_FAISS_BIN_IDMAP) { + return std::make_shared(); + } else if (type == IndexEnum::INDEX_FAISS_BIN_IVFFLAT) { + return std::make_shared(); + } else if (type == IndexEnum::INDEX_NSG) { + return std::make_shared(-1); +#ifdef MILVUS_SUPPORT_SPTAG + } else if (type == IndexEnum::INDEX_SPTAG_KDT_RNT) { + return std::make_shared("KDT"); + } else if (type == IndexEnum::INDEX_SPTAG_BKT_RNT) { + return std::make_shared("BKT"); +#endif + } else if (type == IndexEnum::INDEX_HNSW) { + return std::make_shared(); + } else if (type == IndexEnum::INDEX_ANNOY) { + return std::make_shared(); + } else if (type == IndexEnum::INDEX_RHNSWFlat) { + return std::make_shared(); + } else if (type == IndexEnum::INDEX_RHNSWPQ) { + return std::make_shared(); + } else if (type == IndexEnum::INDEX_RHNSWSQ) { + return std::make_shared(); + } else { + return nullptr; + } +} + +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/vector_index/VecIndexFactory.h b/core/src/index/knowhere/knowhere/index/vector_index/VecIndexFactory.h new file mode 100644 index 0000000000..c96bd1dc7c --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/vector_index/VecIndexFactory.h @@ -0,0 +1,41 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License + +#pragma once + +#include + +#include "knowhere/index/IndexType.h" +#include "knowhere/index/vector_index/VecIndex.h" + +namespace milvus { +namespace knowhere { + +class VecIndexFactory { + private: + VecIndexFactory() = default; + VecIndexFactory(const VecIndexFactory&) = delete; + VecIndexFactory + operator=(const VecIndexFactory&) = delete; + + public: + static VecIndexFactory& + GetInstance() { + static VecIndexFactory inst; + return inst; + } + + knowhere::VecIndexPtr + CreateVecIndex(const IndexType& type, const IndexMode mode = IndexMode::MODE_CPU); +}; + +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/vector_index/adapter/SptagAdapter.cpp b/core/src/index/knowhere/knowhere/index/vector_index/adapter/SptagAdapter.cpp new file mode 100644 index 0000000000..cf07bd237d --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/vector_index/adapter/SptagAdapter.cpp @@ -0,0 +1,84 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include "knowhere/index/vector_index/adapter/SptagAdapter.h" +#include "knowhere/index/vector_index/adapter/VectorAdapter.h" + +namespace milvus { +namespace knowhere { + +std::shared_ptr +ConvertToMetadataSet(const DatasetPtr& dataset_ptr) { + auto elems = dataset_ptr->Get(meta::ROWS); + auto p_data = dataset_ptr->Get(meta::IDS); + + auto p_offset = (int64_t*)malloc(sizeof(int64_t) * (elems + 1)); + for (auto i = 0; i <= elems; ++i) p_offset[i] = i * 8; + + std::shared_ptr metaset( + new SPTAG::MemMetadataSet(SPTAG::ByteArray((std::uint8_t*)p_data, elems * sizeof(int64_t), false), + SPTAG::ByteArray((std::uint8_t*)p_offset, elems * sizeof(int64_t), true), elems)); + + return metaset; +} + +std::shared_ptr +ConvertToVectorSet(const DatasetPtr& dataset_ptr) { + GET_TENSOR_DATA_DIM(dataset_ptr) + size_t num_bytes = rows * dim * sizeof(float); + SPTAG::ByteArray byte_array((uint8_t*)p_data, num_bytes, false); + + auto vectorset = std::make_shared(byte_array, SPTAG::VectorValueType::Float, dim, rows); + return vectorset; +} + +std::vector +ConvertToQueryResult(const DatasetPtr& dataset_ptr, const Config& config) { + GET_TENSOR_DATA_DIM(dataset_ptr); + + int64_t k = config[meta::TOPK].get(); + std::vector query_results(rows, SPTAG::QueryResult(nullptr, k, true)); + for (auto i = 0; i < rows; ++i) { + query_results[i].SetTarget((float*)p_data + i * dim); + } + + return query_results; +} + +DatasetPtr +ConvertToDataset(std::vector query_results) { + auto k = query_results[0].GetResultNum(); + auto elems = query_results.size() * k; + + size_t p_id_size = sizeof(int64_t) * elems; + size_t p_dist_size = sizeof(float) * elems; + auto p_id = (int64_t*)malloc(p_id_size); + auto p_dist = (float*)malloc(p_dist_size); + +#pragma omp parallel for + for (auto i = 0; i < query_results.size(); ++i) { + auto results = query_results[i].GetResults(); + auto num_result = query_results[i].GetResultNum(); + for (auto j = 0; j < num_result; ++j) { + // p_id[i * k + j] = results[j].VID; + p_id[i * k + j] = *(int64_t*)query_results[i].GetMetadata(j).Data(); + p_dist[i * k + j] = results[j].Dist; + } + } + + auto ret_ds = std::make_shared(); + ret_ds->Set(meta::IDS, p_id); + ret_ds->Set(meta::DISTANCE, p_dist); + return ret_ds; +} + +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/vector_index/adapter/SptagAdapter.h b/core/src/index/knowhere/knowhere/index/vector_index/adapter/SptagAdapter.h new file mode 100644 index 0000000000..a8ff4eaf58 --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/vector_index/adapter/SptagAdapter.h @@ -0,0 +1,37 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#pragma once + +#include +#include +#include + +#include "knowhere/common/Config.h" +#include "knowhere/common/Dataset.h" + +namespace milvus { +namespace knowhere { + +std::shared_ptr +ConvertToVectorSet(const DatasetPtr& dataset_ptr); + +std::shared_ptr +ConvertToMetadataSet(const DatasetPtr& dataset_ptr); + +std::vector +ConvertToQueryResult(const DatasetPtr& dataset_ptr, const Config& config); + +DatasetPtr +ConvertToDataset(std::vector query_results); + +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/vector_index/adapter/VectorAdapter.cpp b/core/src/index/knowhere/knowhere/index/vector_index/adapter/VectorAdapter.cpp new file mode 100644 index 0000000000..a9ba05298a --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/vector_index/adapter/VectorAdapter.cpp @@ -0,0 +1,41 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include + +#include "knowhere/common/Dataset.h" +#include "knowhere/index/vector_index/adapter/VectorAdapter.h" +#include "knowhere/index/vector_index/helpers/IndexParameter.h" + +namespace milvus { +namespace knowhere { + +DatasetPtr +GenDatasetWithIds(const int64_t nb, const int64_t dim, const void* xb, const int64_t* ids) { + auto ret_ds = std::make_shared(); + ret_ds->Set(meta::ROWS, nb); + ret_ds->Set(meta::DIM, dim); + ret_ds->Set(meta::TENSOR, xb); + ret_ds->Set(meta::IDS, ids); + return ret_ds; +} + +DatasetPtr +GenDataset(const int64_t nb, const int64_t dim, const void* xb) { + auto ret_ds = std::make_shared(); + ret_ds->Set(meta::ROWS, nb); + ret_ds->Set(meta::DIM, dim); + ret_ds->Set(meta::TENSOR, xb); + return ret_ds; +} + +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/vector_index/adapter/VectorAdapter.h b/core/src/index/knowhere/knowhere/index/vector_index/adapter/VectorAdapter.h new file mode 100644 index 0000000000..9fe4e5b93b --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/vector_index/adapter/VectorAdapter.h @@ -0,0 +1,45 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#pragma once + +#include +#include "knowhere/common/Dataset.h" +#include "knowhere/index/vector_index/helpers/IndexParameter.h" + +namespace milvus { +namespace knowhere { + +#define GET_TENSOR_DATA(dataset_ptr) \ + int64_t rows = dataset_ptr->Get(meta::ROWS); \ + const void* p_data = dataset_ptr->Get(meta::TENSOR); + +#define GET_TENSOR_DATA_DIM(dataset_ptr) \ + GET_TENSOR_DATA(dataset_ptr) \ + int64_t dim = dataset_ptr->Get(meta::DIM); + +#define GET_TENSOR_DATA_ID(dataset_ptr) \ + GET_TENSOR_DATA(dataset_ptr) \ + const int64_t* p_ids = dataset_ptr->Get(meta::IDS); + +#define GET_TENSOR(dataset_ptr) \ + GET_TENSOR_DATA(dataset_ptr) \ + int64_t dim = dataset_ptr->Get(meta::DIM); \ + const int64_t* p_ids = dataset_ptr->Get(meta::IDS); + +extern DatasetPtr +GenDatasetWithIds(const int64_t nb, const int64_t dim, const void* xb, const int64_t* ids); + +extern DatasetPtr +GenDataset(const int64_t nb, const int64_t dim, const void* xb); + +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/vector_index/gpu/GPUIndex.h b/core/src/index/knowhere/knowhere/index/vector_index/gpu/GPUIndex.h new file mode 100644 index 0000000000..1e079efd7c --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/vector_index/gpu/GPUIndex.h @@ -0,0 +1,50 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License + +#pragma once + +#include "knowhere/index/vector_index/VecIndex.h" +#include "knowhere/index/vector_index/helpers/FaissGpuResourceMgr.h" + +namespace milvus { +namespace knowhere { + +class GPUIndex { + public: + explicit GPUIndex(const int& device_id) : gpu_id_(device_id) { + } + + GPUIndex(const int& device_id, const ResPtr& resource) : gpu_id_(device_id), res_(resource) { + } + + virtual VecIndexPtr + CopyGpuToCpu(const Config&) = 0; + + virtual VecIndexPtr + CopyGpuToGpu(const int64_t, const Config&) = 0; + + void + SetGpuDevice(const int& gpu_id) { + gpu_id_ = gpu_id; + } + + const int64_t + GetGpuDevice() { + return gpu_id_; + } + + protected: + int64_t gpu_id_; + ResWPtr res_; +}; + +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/vector_index/gpu/IndexGPUIDMAP.cpp b/core/src/index/knowhere/knowhere/index/vector_index/gpu/IndexGPUIDMAP.cpp new file mode 100644 index 0000000000..11b7ff6c49 --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/vector_index/gpu/IndexGPUIDMAP.cpp @@ -0,0 +1,149 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License + +#include +#include +#include +#include +#ifdef MILVUS_GPU_VERSION +#include +#endif +#include + +#include "knowhere/common/Exception.h" +#include "knowhere/index/IndexType.h" +#include "knowhere/index/vector_index/IndexIDMAP.h" +#include "knowhere/index/vector_index/adapter/VectorAdapter.h" +#include "knowhere/index/vector_index/gpu/IndexGPUIDMAP.h" +#include "knowhere/index/vector_index/helpers/FaissIO.h" + +namespace milvus { +namespace knowhere { + +VecIndexPtr +GPUIDMAP::CopyGpuToCpu(const Config& config) { + std::lock_guard lk(mutex_); + + faiss::Index* device_index = index_.get(); + faiss::Index* host_index = faiss::gpu::index_gpu_to_cpu(device_index); + + std::shared_ptr new_index; + new_index.reset(host_index); + return std::make_shared(new_index); +} + +BinarySet +GPUIDMAP::SerializeImpl(const IndexType& type) { + try { + MemoryIOWriter writer; + { + faiss::Index* index = index_.get(); + faiss::Index* host_index = faiss::gpu::index_gpu_to_cpu(index); + + faiss::write_index(host_index, &writer); + delete host_index; + } + std::shared_ptr data(writer.data_); + + BinarySet res_set; + res_set.Append("IVF", data, writer.rp); + + return res_set; + } catch (std::exception& e) { + KNOWHERE_THROW_MSG(e.what()); + } +} + +void +GPUIDMAP::LoadImpl(const BinarySet& index_binary, const IndexType& type) { + auto binary = index_binary.GetByName("IVF"); + MemoryIOReader reader; + { + reader.total = binary->size; + reader.data_ = binary->data.get(); + + faiss::Index* index = faiss::read_index(&reader); + + if (auto res = FaissGpuResourceMgr::GetInstance().GetRes(gpu_id_)) { + ResScope rs(res, gpu_id_, false); + auto device_index = faiss::gpu::index_cpu_to_gpu(res->faiss_res.get(), gpu_id_, index); + index_.reset(device_index); + res_ = res; + } else { + KNOWHERE_THROW_MSG("Load error, can't get gpu resource"); + } + + delete index; + } +} + +VecIndexPtr +GPUIDMAP::CopyGpuToGpu(const int64_t device_id, const Config& config) { + auto cpu_index = CopyGpuToCpu(config); + return std::static_pointer_cast(cpu_index)->CopyCpuToGpu(device_id, config); +} + +const float* +GPUIDMAP::GetRawVectors() { + KNOWHERE_THROW_MSG("Not support"); +} + +const int64_t* +GPUIDMAP::GetRawIds() { + KNOWHERE_THROW_MSG("Not support"); +} + +void +GPUIDMAP::QueryImpl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels, const Config& config) { + ResScope rs(res_, gpu_id_); + + // assign the metric type + auto flat_index = dynamic_cast(index_.get())->index; + flat_index->metric_type = GetMetricType(config[Metric::TYPE].get()); + index_->search(n, data, k, distances, labels, bitset_); +} + +void +GPUIDMAP::GenGraph(const float* data, const int64_t k, GraphType& graph, const Config& config) { + int64_t K = k + 1; + auto ntotal = Count(); + + size_t dim = config[meta::DIM]; + auto batch_size = 1000; + auto tail_batch_size = ntotal % batch_size; + auto batch_search_count = ntotal / batch_size; + auto total_search_count = tail_batch_size == 0 ? batch_search_count : batch_search_count + 1; + + std::vector res_dis(K * batch_size); + graph.resize(ntotal); + Graph res_vec(total_search_count); + for (int i = 0; i < total_search_count; ++i) { + auto b_size = (i == (total_search_count - 1)) && tail_batch_size != 0 ? tail_batch_size : batch_size; + + auto& res = res_vec[i]; + res.resize(K * b_size); + + const float* xq = data + batch_size * dim * i; + QueryImpl(b_size, xq, K, res_dis.data(), res.data(), config); + + for (int j = 0; j < b_size; ++j) { + auto& node = graph[batch_size * i + j]; + node.resize(k); + auto start_pos = j * K + 1; + for (int m = 0, cursor = start_pos; m < k && cursor < start_pos + k; ++m, ++cursor) { + node[m] = res[cursor]; + } + } + } +} + +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/vector_index/gpu/IndexGPUIDMAP.h b/core/src/index/knowhere/knowhere/index/vector_index/gpu/IndexGPUIDMAP.h new file mode 100644 index 0000000000..f9286ed991 --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/vector_index/gpu/IndexGPUIDMAP.h @@ -0,0 +1,64 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License + +#pragma once + +#include +#include +#include + +#include "knowhere/index/vector_index/IndexIDMAP.h" +#include "knowhere/index/vector_index/gpu/GPUIndex.h" + +namespace milvus { +namespace knowhere { + +using Graph = std::vector>; + +class GPUIDMAP : public IDMAP, public GPUIndex { + public: + explicit GPUIDMAP(std::shared_ptr index, const int64_t device_id, ResPtr& res) + : IDMAP(std::move(index)), GPUIndex(device_id, res) { + index_mode_ = IndexMode::MODE_GPU; + } + + VecIndexPtr + CopyGpuToCpu(const Config&) override; + + VecIndexPtr + CopyGpuToGpu(const int64_t, const Config&) override; + + const float* + GetRawVectors() override; + + const int64_t* + GetRawIds() override; + + void + GenGraph(const float*, const int64_t, GraphType&, const Config&); + + virtual ~GPUIDMAP() = default; + + protected: + BinarySet + SerializeImpl(const IndexType&) override; + + void + LoadImpl(const BinarySet&, const IndexType&) override; + + void + QueryImpl(int64_t, const float*, int64_t, float*, int64_t*, const Config&) override; +}; + +using GPUIDMAPPtr = std::shared_ptr; + +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/vector_index/gpu/IndexGPUIVF.cpp b/core/src/index/knowhere/knowhere/index/vector_index/gpu/IndexGPUIVF.cpp new file mode 100644 index 0000000000..6c01bc64ed --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/vector_index/gpu/IndexGPUIVF.cpp @@ -0,0 +1,159 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License + +#include + +#include +#include +#include +#include +#include + +#include "knowhere/common/Exception.h" +#include "knowhere/index/vector_index/adapter/VectorAdapter.h" +#include "knowhere/index/vector_index/gpu/IndexGPUIVF.h" +#include "knowhere/index/vector_index/helpers/Cloner.h" +#include "knowhere/index/vector_index/helpers/FaissIO.h" +#include "knowhere/index/vector_index/helpers/IndexParameter.h" + +namespace milvus { +namespace knowhere { + +void +GPUIVF::Train(const DatasetPtr& dataset_ptr, const Config& config) { + GET_TENSOR_DATA_DIM(dataset_ptr) + gpu_id_ = config[knowhere::meta::DEVICEID]; + + auto gpu_res = FaissGpuResourceMgr::GetInstance().GetRes(gpu_id_); + if (gpu_res != nullptr) { + ResScope rs(gpu_res, gpu_id_, true); + faiss::gpu::GpuIndexIVFFlatConfig idx_config; + idx_config.device = gpu_id_; + int32_t nlist = config[IndexParams::nlist]; + faiss::MetricType metric_type = GetMetricType(config[Metric::TYPE].get()); + auto device_index = + new faiss::gpu::GpuIndexIVFFlat(gpu_res->faiss_res.get(), dim, nlist, metric_type, idx_config); + device_index->train(rows, reinterpret_cast(p_data)); + + index_.reset(device_index); + res_ = gpu_res; + } else { + KNOWHERE_THROW_MSG("Build IVF can't get gpu resource"); + } +} + +void +GPUIVF::Add(const DatasetPtr& dataset_ptr, const Config& config) { + auto spt = res_.lock(); + if (spt != nullptr) { + ResScope rs(res_, gpu_id_); + IVF::Add(dataset_ptr, config); + } else { + KNOWHERE_THROW_MSG("Add IVF can't get gpu resource"); + } +} + +VecIndexPtr +GPUIVF::CopyGpuToCpu(const Config& config) { + std::lock_guard lk(mutex_); + + auto device_idx = std::dynamic_pointer_cast(index_); + if (device_idx != nullptr) { + faiss::Index* device_index = index_.get(); + faiss::Index* host_index = faiss::gpu::index_gpu_to_cpu(device_index); + + std::shared_ptr new_index; + new_index.reset(host_index); + return std::make_shared(new_index); + } else { + return std::make_shared(index_); + } +} + +VecIndexPtr +GPUIVF::CopyGpuToGpu(const int64_t device_id, const Config& config) { + auto host_index = CopyGpuToCpu(config); + return std::static_pointer_cast(host_index)->CopyCpuToGpu(device_id, config); +} + +BinarySet +GPUIVF::SerializeImpl(const IndexType& type) { + if (!index_ || !index_->is_trained) { + KNOWHERE_THROW_MSG("index not initialize or trained"); + } + + try { + MemoryIOWriter writer; + { + faiss::Index* index = index_.get(); + faiss::Index* host_index = faiss::gpu::index_gpu_to_cpu(index); + + faiss::write_index(host_index, &writer); + delete host_index; + } + std::shared_ptr data(writer.data_); + + BinarySet res_set; + res_set.Append("IVF", data, writer.rp); + + return res_set; + } catch (std::exception& e) { + KNOWHERE_THROW_MSG(e.what()); + } +} + +void +GPUIVF::LoadImpl(const BinarySet& binary_set, const IndexType& type) { + auto binary = binary_set.GetByName("IVF"); + MemoryIOReader reader; + { + reader.total = binary->size; + reader.data_ = binary->data.get(); + + faiss::Index* index = faiss::read_index(&reader); + + if (auto temp_res = FaissGpuResourceMgr::GetInstance().GetRes(gpu_id_)) { + ResScope rs(temp_res, gpu_id_, false); + auto device_index = faiss::gpu::index_cpu_to_gpu(temp_res->faiss_res.get(), gpu_id_, index); + index_.reset(device_index); + res_ = temp_res; + } else { + KNOWHERE_THROW_MSG("Load error, can't get gpu resource"); + } + + delete index; + } +} + +void +GPUIVF::QueryImpl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels, const Config& config) { + std::lock_guard lk(mutex_); + + auto device_index = std::dynamic_pointer_cast(index_); + if (device_index) { + device_index->nprobe = config[IndexParams::nprobe]; + ResScope rs(res_, gpu_id_); + + // if query size > 2048 we search by blocks to avoid malloc issue + const int64_t block_size = 2048; + int64_t dim = device_index->d; + for (int64_t i = 0; i < n; i += block_size) { + int64_t search_size = (n - i > block_size) ? block_size : (n - i); + device_index->search(search_size, reinterpret_cast(data) + i * dim, k, distances + i * k, + labels + i * k, bitset_); + } + } else { + KNOWHERE_THROW_MSG("Not a GpuIndexIVF type."); + } +} + +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/vector_index/gpu/IndexGPUIVF.h b/core/src/index/knowhere/knowhere/index/vector_index/gpu/IndexGPUIVF.h new file mode 100644 index 0000000000..49d1b3eef0 --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/vector_index/gpu/IndexGPUIVF.h @@ -0,0 +1,60 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License + +#pragma once + +#include +#include + +#include "knowhere/index/vector_index/IndexIVF.h" +#include "knowhere/index/vector_index/gpu/GPUIndex.h" + +namespace milvus { +namespace knowhere { + +class GPUIVF : public IVF, public GPUIndex { + public: + explicit GPUIVF(const int& device_id) : IVF(), GPUIndex(device_id) { + index_mode_ = IndexMode::MODE_GPU; + } + + explicit GPUIVF(std::shared_ptr index, const int64_t device_id, ResPtr& res) + : IVF(std::move(index)), GPUIndex(device_id, res) { + index_mode_ = IndexMode::MODE_GPU; + } + + void + Train(const DatasetPtr&, const Config&) override; + + void + Add(const DatasetPtr&, const Config&) override; + + VecIndexPtr + CopyGpuToCpu(const Config&) override; + + VecIndexPtr + CopyGpuToGpu(const int64_t, const Config&) override; + + protected: + BinarySet + SerializeImpl(const IndexType&) override; + + void + LoadImpl(const BinarySet&, const IndexType&) override; + + void + QueryImpl(int64_t, const float*, int64_t, float*, int64_t*, const Config&) override; +}; + +using GPUIVFPtr = std::shared_ptr; + +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/vector_index/gpu/IndexGPUIVFPQ.cpp b/core/src/index/knowhere/knowhere/index/vector_index/gpu/IndexGPUIVFPQ.cpp new file mode 100644 index 0000000000..03e08700c2 --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/vector_index/gpu/IndexGPUIVFPQ.cpp @@ -0,0 +1,73 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License + +#include + +#include +#include +#include +#include + +#include "knowhere/common/Exception.h" +#include "knowhere/index/vector_index/IndexIVFPQ.h" +#include "knowhere/index/vector_index/adapter/VectorAdapter.h" +#include "knowhere/index/vector_index/gpu/IndexGPUIVFPQ.h" +#include "knowhere/index/vector_index/helpers/IndexParameter.h" + +namespace milvus { +namespace knowhere { + +void +GPUIVFPQ::Train(const DatasetPtr& dataset_ptr, const Config& config) { + GET_TENSOR_DATA_DIM(dataset_ptr) + gpu_id_ = config[knowhere::meta::DEVICEID]; + + auto gpu_res = FaissGpuResourceMgr::GetInstance().GetRes(gpu_id_); + if (gpu_res != nullptr) { + ResScope rs(gpu_res, gpu_id_, true); + auto device_index = + new faiss::gpu::GpuIndexIVFPQ(gpu_res->faiss_res.get(), dim, config[IndexParams::nlist].get(), + config[IndexParams::m], config[IndexParams::nbits], + GetMetricType(config[Metric::TYPE].get())); // IP not support + device_index->train(rows, reinterpret_cast(p_data)); + + index_.reset(device_index); + res_ = gpu_res; + } else { + KNOWHERE_THROW_MSG("Build IVFPQ can't get gpu resource"); + } +} + +VecIndexPtr +GPUIVFPQ::CopyGpuToCpu(const Config& config) { + std::lock_guard lk(mutex_); + + faiss::Index* device_index = index_.get(); + faiss::Index* host_index = faiss::gpu::index_gpu_to_cpu(device_index); + + std::shared_ptr new_index; + new_index.reset(host_index); + return std::make_shared(new_index); +} + +std::shared_ptr +GPUIVFPQ::GenParams(const Config& config) { + auto params = std::make_shared(); + params->nprobe = config[IndexParams::nprobe]; + // params->scan_table_threshold = config["scan_table_threhold"] + // params->polysemous_ht = config["polysemous_ht"] + // params->max_codes = config["max_codes"] + + return params; +} + +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/vector_index/gpu/IndexGPUIVFPQ.h b/core/src/index/knowhere/knowhere/index/vector_index/gpu/IndexGPUIVFPQ.h new file mode 100644 index 0000000000..c93da0f5c8 --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/vector_index/gpu/IndexGPUIVFPQ.h @@ -0,0 +1,49 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License + +#pragma once + +#include +#include + +#include "knowhere/index/vector_index/gpu/IndexGPUIVF.h" + +namespace milvus { +namespace knowhere { + +class GPUIVFPQ : public GPUIVF { + public: + explicit GPUIVFPQ(const int& device_id) : GPUIVF(device_id) { + index_type_ = IndexEnum::INDEX_FAISS_IVFPQ; + } + + GPUIVFPQ(std::shared_ptr index, const int64_t device_id, ResPtr& res) + : GPUIVF(std::move(index), device_id, res) { + index_type_ = IndexEnum::INDEX_FAISS_IVFPQ; + } + + void + Train(const DatasetPtr&, const Config&) override; + + VecIndexPtr + CopyGpuToCpu(const Config&) override; + + virtual ~GPUIVFPQ() = default; + + protected: + std::shared_ptr + GenParams(const Config& config) override; +}; + +using GPUIVFPQPtr = std::shared_ptr; + +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/vector_index/gpu/IndexGPUIVFSQ.cpp b/core/src/index/knowhere/knowhere/index/vector_index/gpu/IndexGPUIVFSQ.cpp new file mode 100644 index 0000000000..713177630e --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/vector_index/gpu/IndexGPUIVFSQ.cpp @@ -0,0 +1,71 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License + +#include +#include +#include +#include + +#include +#include + +#include "knowhere/common/Exception.h" +#include "knowhere/index/vector_index/IndexIVFSQ.h" +#include "knowhere/index/vector_index/adapter/VectorAdapter.h" +#include "knowhere/index/vector_index/gpu/IndexGPUIVFSQ.h" +#include "knowhere/index/vector_index/helpers/IndexParameter.h" + +namespace milvus { +namespace knowhere { + +void +GPUIVFSQ::Train(const DatasetPtr& dataset_ptr, const Config& config) { + GET_TENSOR_DATA_DIM(dataset_ptr) + gpu_id_ = config[knowhere::meta::DEVICEID]; + + // std::stringstream index_type; + // index_type << "IVF" << config[IndexParams::nlist] << "," + // << "SQ" << config[IndexParams::nbits]; + // faiss::MetricType metric_type = GetMetricType(config[Metric::TYPE].get()); + // auto build_index = faiss::index_factory(dim, index_type.str().c_str(), metric_type); + + faiss::MetricType metric_type = GetMetricType(config[Metric::TYPE].get()); + faiss::Index* coarse_quantizer = new faiss::IndexFlat(dim, metric_type); + auto build_index = new faiss::IndexIVFScalarQuantizer( + coarse_quantizer, dim, config[IndexParams::nlist].get(), faiss::QuantizerType::QT_8bit, metric_type); + + auto gpu_res = FaissGpuResourceMgr::GetInstance().GetRes(gpu_id_); + if (gpu_res != nullptr) { + ResScope rs(gpu_res, gpu_id_, true); + auto device_index = faiss::gpu::index_cpu_to_gpu(gpu_res->faiss_res.get(), gpu_id_, build_index); + device_index->train(rows, reinterpret_cast(p_data)); + + index_.reset(device_index); + res_ = gpu_res; + } else { + KNOWHERE_THROW_MSG("Build IVFSQ can't get gpu resource"); + } +} + +VecIndexPtr +GPUIVFSQ::CopyGpuToCpu(const Config& config) { + std::lock_guard lk(mutex_); + + faiss::Index* device_index = index_.get(); + faiss::Index* host_index = faiss::gpu::index_gpu_to_cpu(device_index); + + std::shared_ptr new_index; + new_index.reset(host_index); + return std::make_shared(new_index); +} + +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/vector_index/gpu/IndexGPUIVFSQ.h b/core/src/index/knowhere/knowhere/index/vector_index/gpu/IndexGPUIVFSQ.h new file mode 100644 index 0000000000..c3afcc2412 --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/vector_index/gpu/IndexGPUIVFSQ.h @@ -0,0 +1,45 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License + +#pragma once + +#include +#include + +#include "knowhere/index/vector_index/gpu/IndexGPUIVF.h" + +namespace milvus { +namespace knowhere { + +class GPUIVFSQ : public GPUIVF { + public: + explicit GPUIVFSQ(const int& device_id) : GPUIVF(device_id) { + index_type_ = IndexEnum::INDEX_FAISS_IVFSQ8; + } + + explicit GPUIVFSQ(std::shared_ptr index, const int64_t device_id, ResPtr& res) + : GPUIVF(std::move(index), device_id, res) { + index_type_ = IndexEnum::INDEX_FAISS_IVFSQ8; + } + + void + Train(const DatasetPtr&, const Config&) override; + + VecIndexPtr + CopyGpuToCpu(const Config&) override; + + virtual ~GPUIVFSQ() = default; +}; + +using GPUIVFSQPtr = std::shared_ptr; + +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/vector_index/gpu/IndexIVFSQHybrid.cpp b/core/src/index/knowhere/knowhere/index/vector_index/gpu/IndexIVFSQHybrid.cpp new file mode 100644 index 0000000000..dc4fa528b9 --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/vector_index/gpu/IndexIVFSQHybrid.cpp @@ -0,0 +1,287 @@ +// +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License + +#include +#include +#include +#include +#include +#include + +#include "knowhere/common/Exception.h" +#include "knowhere/index/vector_index/adapter/VectorAdapter.h" +#include "knowhere/index/vector_index/gpu/IndexIVFSQHybrid.h" +#include "knowhere/index/vector_index/helpers/FaissIO.h" +#include "knowhere/index/vector_index/helpers/IndexParameter.h" + +namespace milvus { +namespace knowhere { + +#ifdef MILVUS_GPU_VERSION + +void +IVFSQHybrid::Train(const DatasetPtr& dataset_ptr, const Config& config) { + GET_TENSOR_DATA_DIM(dataset_ptr) + gpu_id_ = config[knowhere::meta::DEVICEID]; + + std::stringstream index_type; + index_type << "IVF" << config[IndexParams::nlist] << "," + << "SQ8Hybrid"; + auto build_index = + faiss::index_factory(dim, index_type.str().c_str(), GetMetricType(config[Metric::TYPE].get())); + + auto gpu_res = FaissGpuResourceMgr::GetInstance().GetRes(gpu_id_); + if (gpu_res != nullptr) { + ResScope rs(gpu_res, gpu_id_, true); + auto device_index = faiss::gpu::index_cpu_to_gpu(gpu_res->faiss_res.get(), gpu_id_, build_index); + device_index->train(rows, reinterpret_cast(p_data)); + + index_.reset(device_index); + res_ = gpu_res; + gpu_mode_ = 2; + } else { + KNOWHERE_THROW_MSG("Build IVFSQHybrid can't get gpu resource"); + } + + delete build_index; +} + +VecIndexPtr +IVFSQHybrid::CopyGpuToCpu(const Config& config) { + if (gpu_mode_ == 0) { + return std::make_shared(index_); + } + std::lock_guard lk(mutex_); + + faiss::Index* device_index = index_.get(); + faiss::Index* host_index = faiss::gpu::index_gpu_to_cpu(device_index); + + if (auto* ivf_index = dynamic_cast(host_index)) { + ivf_index->to_readonly(); + ivf_index->backup_quantizer(); + } + + std::shared_ptr new_index; + new_index.reset(host_index); + return std::make_shared(new_index); +} + +VecIndexPtr +IVFSQHybrid::CopyCpuToGpu(const int64_t device_id, const Config& config) { + if (auto res = FaissGpuResourceMgr::GetInstance().GetRes(device_id)) { + ResScope rs(res, device_id, false); + faiss::gpu::GpuClonerOptions option; + option.allInGpu = true; + + auto idx = dynamic_cast(index_.get()); + idx->restore_quantizer(); + auto gpu_index = faiss::gpu::index_cpu_to_gpu(res->faiss_res.get(), device_id, index_.get(), &option); + std::shared_ptr device_index = std::shared_ptr(gpu_index); + auto new_idx = std::make_shared(device_index, device_id, res); + return new_idx; + } else { + KNOWHERE_THROW_MSG("CopyCpuToGpu Error, can't get gpu: " + std::to_string(gpu_id_) + "resource"); + } +} + +std::pair +IVFSQHybrid::CopyCpuToGpuWithQuantizer(const int64_t device_id, const Config& config) { + if (auto res = FaissGpuResourceMgr::GetInstance().GetRes(device_id)) { + ResScope rs(res, device_id, false); + faiss::gpu::GpuClonerOptions option; + option.allInGpu = true; + + faiss::IndexComposition index_composition; + index_composition.index = index_.get(); + index_composition.quantizer = nullptr; + index_composition.mode = 0; // copy all + + auto gpu_index = faiss::gpu::index_cpu_to_gpu(res->faiss_res.get(), device_id, &index_composition, &option); + + std::shared_ptr device_index; + device_index.reset(gpu_index); + auto new_idx = std::make_shared(device_index, device_id, res); + + auto q = std::make_shared(); + q->quantizer = index_composition.quantizer; + q->size = index_composition.quantizer->d * index_composition.quantizer->getNumVecs() * sizeof(float); + q->gpu_id = device_id; + return std::make_pair(new_idx, q); + } else { + KNOWHERE_THROW_MSG("CopyCpuToGpu Error, can't get gpu: " + std::to_string(gpu_id_) + "resource"); + } +} + +VecIndexPtr +IVFSQHybrid::LoadData(const knowhere::QuantizerPtr& quantizer_ptr, const Config& config) { + int64_t gpu_id = config[knowhere::meta::DEVICEID]; + + if (auto res = FaissGpuResourceMgr::GetInstance().GetRes(gpu_id)) { + ResScope rs(res, gpu_id, false); + faiss::gpu::GpuClonerOptions option; + option.allInGpu = true; + + auto ivf_quantizer = std::dynamic_pointer_cast(quantizer_ptr); + if (ivf_quantizer == nullptr) { + KNOWHERE_THROW_MSG("quantizer type not faissivfquantizer"); + } + + auto index_composition = new faiss::IndexComposition; + index_composition->index = index_.get(); + index_composition->quantizer = ivf_quantizer->quantizer; + index_composition->mode = 2; // only 2 + + auto gpu_index = faiss::gpu::index_cpu_to_gpu(res->faiss_res.get(), gpu_id, index_composition, &option); + std::shared_ptr new_idx; + new_idx.reset(gpu_index); + auto sq_idx = std::make_shared(new_idx, gpu_id, res); + return sq_idx; + } else { + KNOWHERE_THROW_MSG("CopyCpuToGpu Error, can't get gpu: " + std::to_string(gpu_id) + "resource"); + } +} + +QuantizerPtr +IVFSQHybrid::LoadQuantizer(const Config& config) { + auto gpu_id = config[knowhere::meta::DEVICEID].get(); + + if (auto res = FaissGpuResourceMgr::GetInstance().GetRes(gpu_id)) { + ResScope rs(res, gpu_id, false); + faiss::gpu::GpuClonerOptions option; + option.allInGpu = true; + + auto index_composition = new faiss::IndexComposition; + index_composition->index = index_.get(); + index_composition->quantizer = nullptr; + index_composition->mode = 1; // only 1 + + auto gpu_index = faiss::gpu::index_cpu_to_gpu(res->faiss_res.get(), gpu_id, index_composition, &option); + delete gpu_index; + + auto q = std::make_shared(); + + auto& q_ptr = index_composition->quantizer; + q->size = q_ptr->d * q_ptr->getNumVecs() * sizeof(float); + q->quantizer = q_ptr; + q->gpu_id = gpu_id; + res_ = res; + gpu_mode_ = 1; + return q; + } else { + KNOWHERE_THROW_MSG("CopyCpuToGpu Error, can't get gpu: " + std::to_string(gpu_id) + "resource"); + } +} + +void +IVFSQHybrid::SetQuantizer(const QuantizerPtr& quantizer_ptr) { + auto ivf_quantizer = std::dynamic_pointer_cast(quantizer_ptr); + if (ivf_quantizer == nullptr) { + KNOWHERE_THROW_MSG("Quantizer type error"); + } + + auto ivf_index = dynamic_cast(index_.get()); + + auto is_gpu_flat_index = dynamic_cast(ivf_index->quantizer); + if (is_gpu_flat_index == nullptr) { + // delete ivf_index->quantizer; + ivf_index->quantizer = ivf_quantizer->quantizer; + } + quantizer_gpu_id_ = ivf_quantizer->gpu_id; + gpu_mode_ = 1; +} + +void +IVFSQHybrid::UnsetQuantizer() { + auto* ivf_index = dynamic_cast(index_.get()); + if (ivf_index == nullptr) { + KNOWHERE_THROW_MSG("Index type error"); + } + + ivf_index->quantizer = nullptr; + quantizer_gpu_id_ = -1; +} + +BinarySet +IVFSQHybrid::SerializeImpl(const IndexType& type) { + if (!index_ || !index_->is_trained) { + KNOWHERE_THROW_MSG("index not initialize or trained"); + } + + if (gpu_mode_ == 0) { + MemoryIOWriter writer; + faiss::write_index(index_.get(), &writer); + + std::shared_ptr data(writer.data_); + + BinarySet res_set; + res_set.Append("IVF", data, writer.rp); + + return res_set; + } else if (gpu_mode_ == 2) { + return GPUIVF::SerializeImpl(type); + } else { + KNOWHERE_THROW_MSG("Can't serialize IVFSQ8Hybrid"); + } +} + +void +IVFSQHybrid::LoadImpl(const BinarySet& binary_set, const IndexType& type) { + FaissBaseIndex::LoadImpl(binary_set, index_type_); // load on cpu + auto* ivf_index = dynamic_cast(index_.get()); + ivf_index->backup_quantizer(); + gpu_mode_ = 0; +} + +void +IVFSQHybrid::QueryImpl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels, + const Config& config) { + if (gpu_mode_ == 2) { + GPUIVF::QueryImpl(n, data, k, distances, labels, config); + // index_->search(n, (float*)data, k, distances, labels); + } else if (gpu_mode_ == 1) { // hybrid + if (auto res = FaissGpuResourceMgr::GetInstance().GetRes(quantizer_gpu_id_)) { + ResScope rs(res, quantizer_gpu_id_, true); + IVF::QueryImpl(n, data, k, distances, labels, config); + } else { + KNOWHERE_THROW_MSG("Hybrid Search Error, can't get gpu: " + std::to_string(quantizer_gpu_id_) + "resource"); + } + } else if (gpu_mode_ == 0) { + IVF::QueryImpl(n, data, k, distances, labels, config); + } +} + +void +IVFSQHybrid::UpdateIndexSize() { + if (!index_) { + KNOWHERE_THROW_MSG("index not initialize"); + } + auto ivfsqh_index = dynamic_cast(index_.get()); + auto nb = ivfsqh_index->invlists->compute_ntotal(); + auto code_size = ivfsqh_index->code_size; + auto nlist = ivfsqh_index->nlist; + auto d = ivfsqh_index->d; + // ivf codes, ivf ids, sq trained vectors and quantizer + index_size_ = nb * code_size + nb * sizeof(int64_t) + 2 * d * sizeof(float) + nlist * d * sizeof(float); +} + +FaissIVFQuantizer::~FaissIVFQuantizer() { + if (quantizer != nullptr) { + delete quantizer; + quantizer = nullptr; + } + // else do nothing +} + +#endif + +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/vector_index/gpu/IndexIVFSQHybrid.h b/core/src/index/knowhere/knowhere/index/vector_index/gpu/IndexIVFSQHybrid.h new file mode 100644 index 0000000000..4aeb7f6867 --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/vector_index/gpu/IndexIVFSQHybrid.h @@ -0,0 +1,103 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License + +#pragma once + +#include +#include + +#include +#include + +#include "knowhere/index/vector_index/gpu/IndexGPUIVFSQ.h" +#include "knowhere/index/vector_index/gpu/Quantizer.h" + +namespace milvus { +namespace knowhere { + +#ifdef MILVUS_GPU_VERSION + +struct FaissIVFQuantizer : public Quantizer { + faiss::gpu::GpuIndexFlat* quantizer = nullptr; + int64_t gpu_id; + + ~FaissIVFQuantizer() override; +}; +using FaissIVFQuantizerPtr = std::shared_ptr; + +class IVFSQHybrid : public GPUIVFSQ { + public: + explicit IVFSQHybrid(const int& device_id) : GPUIVFSQ(device_id) { + index_type_ = IndexEnum::INDEX_FAISS_IVFSQ8H; + gpu_mode_ = 0; + } + + explicit IVFSQHybrid(std::shared_ptr index) : GPUIVFSQ(-1) { + index_type_ = IndexEnum::INDEX_FAISS_IVFSQ8H; + index_ = index; + gpu_mode_ = 0; + } + + explicit IVFSQHybrid(std::shared_ptr index, const int64_t device_id, ResPtr& resource) + : GPUIVFSQ(index, device_id, resource) { + index_type_ = IndexEnum::INDEX_FAISS_IVFSQ8H; + gpu_mode_ = 2; + } + + public: + void + Train(const DatasetPtr&, const Config&) override; + + VecIndexPtr + CopyGpuToCpu(const Config&) override; + + VecIndexPtr + CopyCpuToGpu(const int64_t, const Config&) override; + + std::pair + CopyCpuToGpuWithQuantizer(const int64_t, const Config&); + + VecIndexPtr + LoadData(const knowhere::QuantizerPtr&, const Config&); + + QuantizerPtr + LoadQuantizer(const Config& conf); + + void + SetQuantizer(const QuantizerPtr& q); + + void + UnsetQuantizer(); + + void + UpdateIndexSize() override; + + protected: + BinarySet + SerializeImpl(const IndexType&) override; + + void + LoadImpl(const BinarySet&, const IndexType&) override; + + void + QueryImpl(int64_t, const float*, int64_t, float*, int64_t*, const Config&) override; + + protected: + int64_t gpu_mode_ = 0; // 0,1,2 + int64_t quantizer_gpu_id_ = -1; +}; + +using IVFSQHybridPtr = std::shared_ptr; + +#endif + +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/vector_index/gpu/Quantizer.h b/core/src/index/knowhere/knowhere/index/vector_index/gpu/Quantizer.h new file mode 100644 index 0000000000..89f1e03d79 --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/vector_index/gpu/Quantizer.h @@ -0,0 +1,33 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License + +#pragma once + +#include +#include "knowhere/common/Config.h" + +namespace milvus { +namespace knowhere { + +struct Quantizer { + virtual ~Quantizer() = default; + + int64_t size = -1; +}; +using QuantizerPtr = std::shared_ptr; + +// struct QuantizerCfg : Cfg { +// int64_t mode = -1; // 0: all data, 1: copy quantizer, 2: copy data +// }; +// using QuantizerConfig = std::shared_ptr; + +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/vector_index/helpers/BuilderSuspend.h b/core/src/index/knowhere/knowhere/index/vector_index/helpers/BuilderSuspend.h new file mode 100644 index 0000000000..d77d9a9dea --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/vector_index/helpers/BuilderSuspend.h @@ -0,0 +1,30 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License + +#pragma once + +#include "faiss/BuilderSuspend.h" + +namespace milvus { +namespace knowhere { + +inline void +BuilderSuspend() { + faiss::BuilderSuspend::suspend(); +} + +inline void +BuildResume() { + faiss::BuilderSuspend::resume(); +} + +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/vector_index/helpers/Cloner.cpp b/core/src/index/knowhere/knowhere/index/vector_index/helpers/Cloner.cpp new file mode 100644 index 0000000000..0b343c01ac --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/vector_index/helpers/Cloner.cpp @@ -0,0 +1,76 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License + +#ifdef MILVUS_GPU_VERSION +#include "knowhere/index/vector_index/helpers/Cloner.h" +#include "knowhere/common/Exception.h" +#include "knowhere/index/vector_index/IndexIDMAP.h" +#include "knowhere/index/vector_index/IndexIVF.h" +#include "knowhere/index/vector_index/IndexIVFPQ.h" +#include "knowhere/index/vector_index/IndexIVFSQ.h" +#include "knowhere/index/vector_index/gpu/GPUIndex.h" +#include "knowhere/index/vector_index/gpu/IndexGPUIVF.h" +#include "knowhere/index/vector_index/gpu/IndexIVFSQHybrid.h" +#include "knowhere/index/vector_offset_index/IndexIVF_NM.h" + +namespace milvus { +namespace knowhere { +namespace cloner { + +void +CopyIndexData(const VecIndexPtr& dst_index, const VecIndexPtr& src_index) { + /* do real copy */ + auto uids = src_index->GetUids(); + dst_index->SetUids(uids); + dst_index->SetBlacklist(src_index->GetBlacklist()); + dst_index->SetIndexSize(src_index->IndexSize()); +} + +VecIndexPtr +CopyGpuToCpu(const VecIndexPtr& index, const Config& config) { + if (auto device_index = std::dynamic_pointer_cast(index)) { + VecIndexPtr result = device_index->CopyGpuToCpu(config); + CopyIndexData(result, index); + return result; + } else { + KNOWHERE_THROW_MSG("index type is not gpuindex"); + } +} + +VecIndexPtr +CopyCpuToGpu(const VecIndexPtr& index, const int64_t device_id, const Config& config) { + VecIndexPtr result; + if (auto device_index = std::dynamic_pointer_cast(index)) { + result = device_index->CopyCpuToGpu(device_id, config); + } else if (auto cpu_index = std::dynamic_pointer_cast(index)) { + result = cpu_index->CopyCpuToGpu(device_id, config); + } else if (auto device_index = std::dynamic_pointer_cast(index)) { + result = device_index->CopyGpuToGpu(device_id, config); + } else if (auto cpu_index = std::dynamic_pointer_cast(index)) { + result = cpu_index->CopyCpuToGpu(device_id, config); + } else if (auto cpu_index = std::dynamic_pointer_cast(index)) { + result = cpu_index->CopyCpuToGpu(device_id, config); + } else if (auto cpu_index = std::dynamic_pointer_cast(index)) { + result = cpu_index->CopyCpuToGpu(device_id, config); + } else if (auto cpu_index = std::dynamic_pointer_cast(index)) { + result = cpu_index->CopyCpuToGpu(device_id, config); + } else { + KNOWHERE_THROW_MSG("this index type not support transfer to gpu"); + } + + CopyIndexData(result, index); + return result; +} + +} // namespace cloner +} // namespace knowhere +} // namespace milvus +#endif diff --git a/core/src/index/knowhere/knowhere/index/vector_index/helpers/Cloner.h b/core/src/index/knowhere/knowhere/index/vector_index/helpers/Cloner.h new file mode 100644 index 0000000000..8252d27d4a --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/vector_index/helpers/Cloner.h @@ -0,0 +1,28 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License + +#pragma once + +#include "knowhere/index/vector_index/VecIndex.h" + +namespace milvus { +namespace knowhere { +namespace cloner { + +extern VecIndexPtr +CopyCpuToGpu(const VecIndexPtr& index, const int64_t device_id, const Config& config); + +extern VecIndexPtr +CopyGpuToCpu(const VecIndexPtr& index, const Config& config); + +} // namespace cloner +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/vector_index/helpers/FaissGpuResourceMgr.cpp b/core/src/index/knowhere/knowhere/index/vector_index/helpers/FaissGpuResourceMgr.cpp new file mode 100644 index 0000000000..1e3a837f6c --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/vector_index/helpers/FaissGpuResourceMgr.cpp @@ -0,0 +1,132 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License + +#include "knowhere/index/vector_index/helpers/FaissGpuResourceMgr.h" +#include "knowhere/common/Log.h" + +#include + +namespace milvus { +namespace knowhere { + +constexpr int64_t MB = 1LL << 20; + +FaissGpuResourceMgr& +FaissGpuResourceMgr::GetInstance() { + static FaissGpuResourceMgr instance; + return instance; +} + +void +FaissGpuResourceMgr::AllocateTempMem(ResPtr& resource, const int64_t device_id, const int64_t size) { + if (size) { + resource->faiss_res->setTempMemory(size); + } else { + auto search = devices_params_.find(device_id); + if (search != devices_params_.end()) { + resource->faiss_res->setTempMemory(search->second.temp_mem_size); + } + // else do nothing. allocate when use. + } +} + +void +FaissGpuResourceMgr::InitDevice(int64_t device_id, int64_t pin_mem_size, int64_t temp_mem_size, int64_t res_num) { + DeviceParams params; + params.pinned_mem_size = pin_mem_size; + params.temp_mem_size = temp_mem_size; + params.resource_num = res_num; + + devices_params_.emplace(device_id, params); + LOG_KNOWHERE_DEBUG_ << "DEVICEID " << device_id << ", pin_mem_size " << pin_mem_size / MB << "MB, temp_mem_size " + << temp_mem_size / MB << "MB, resource count " << res_num; +} + +void +FaissGpuResourceMgr::InitResource() { + if (!initialized_) { + std::lock_guard lock(init_mutex_); + + if (!initialized_) { + for (auto& device : devices_params_) { + auto& device_id = device.first; + + mutex_cache_.emplace(device_id, std::make_unique()); + + auto& device_param = device.second; + auto& bq = idle_map_[device_id]; + + for (int64_t i = 0; i < device_param.resource_num; ++i) { + auto raw_resource = std::make_shared(); + + // TODO(linxj): enable set pinned memory + auto res_wrapper = std::make_shared(raw_resource); + AllocateTempMem(res_wrapper, device_id, 0); + + bq.Put(res_wrapper); + } + LOG_KNOWHERE_DEBUG_ << "DEVICEID " << device_id << ", resource count " << bq.Size(); + } + initialized_ = true; + } + } +} + +ResPtr +FaissGpuResourceMgr::GetRes(const int64_t device_id, const int64_t alloc_size) { + InitResource(); + + auto finder = idle_map_.find(device_id); + if (finder != idle_map_.end()) { + auto& bq = finder->second; + auto&& resource = bq.Take(); + AllocateTempMem(resource, device_id, alloc_size); + return resource; + } else { + LOG_KNOWHERE_ERROR_ << "GPU device " << device_id << " not initialized"; + for (auto& item : idle_map_) { + auto& bq = item.second; + LOG_KNOWHERE_ERROR_ << "DEVICEID " << item.first << ", resource count " << bq.Size(); + } + return nullptr; + } +} + +void +FaissGpuResourceMgr::MoveToIdle(const int64_t device_id, const ResPtr& res) { + auto finder = idle_map_.find(device_id); + if (finder != idle_map_.end()) { + auto& bq = finder->second; + bq.Put(res); + } +} + +void +FaissGpuResourceMgr::Free() { + for (auto& item : idle_map_) { + auto& bq = item.second; + while (!bq.Empty()) { + bq.Take(); + } + } + initialized_ = false; +} + +void +FaissGpuResourceMgr::Dump() { + for (auto& item : idle_map_) { + auto& bq = item.second; + LOG_KNOWHERE_DEBUG_ << "DEVICEID: " << item.first << ", resource count:" << bq.Size(); + } +} + +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/vector_index/helpers/FaissGpuResourceMgr.h b/core/src/index/knowhere/knowhere/index/vector_index/helpers/FaissGpuResourceMgr.h new file mode 100644 index 0000000000..48eb11f26f --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/vector_index/helpers/FaissGpuResourceMgr.h @@ -0,0 +1,130 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License + +#pragma once + +#include +#include +#include +#include + +#include + +#include "src/utils/BlockingQueue.h" + +namespace milvus { +namespace knowhere { + +struct Resource { + explicit Resource(std::shared_ptr& r) : faiss_res(r) { + static int64_t global_id = 0; + id = global_id++; + } + + std::shared_ptr faiss_res; + int64_t id; + std::mutex mutex; +}; +using ResPtr = std::shared_ptr; +using ResWPtr = std::weak_ptr; + +class FaissGpuResourceMgr { + public: + friend class ResScope; + using ResBQ = BlockingQueue; + + public: + struct DeviceParams { + int64_t temp_mem_size = 0; + int64_t pinned_mem_size = 0; + int64_t resource_num = 2; + }; + + public: + static FaissGpuResourceMgr& + GetInstance(); + + // Free gpu resource, avoid cudaGetDevice error when deallocate. + // this func should be invoke before main return + void + Free(); + + void + AllocateTempMem(ResPtr& resource, const int64_t device_id, const int64_t size); + + void + InitDevice(int64_t device_id, int64_t pin_mem_size = 0, int64_t temp_mem_size = 0, int64_t res_num = 2); + + void + InitResource(); + + // allocate gpu memory invoke by build or copy_to_gpu + ResPtr + GetRes(const int64_t device_id, const int64_t alloc_size = 0); + + void + MoveToIdle(const int64_t device_id, const ResPtr& res); + + void + Dump(); + + protected: + bool initialized_ = false; + std::mutex init_mutex_; + + std::map> mutex_cache_; + std::map devices_params_; + std::map idle_map_; +}; + +class ResScope { + public: + ResScope(ResPtr& res, const int64_t device_id, const bool isown) + : resource(res), device_id(device_id), move(true), own(isown) { + Lock(); + } + + ResScope(ResWPtr& res, const int64_t device_id, const bool isown) + : resource(res), device_id(device_id), move(true), own(isown) { + Lock(); + } + + // specif for search + // get the ownership of gpuresource and gpu + ResScope(ResWPtr& res, const int64_t device_id) : device_id(device_id), move(false), own(true) { + resource = res.lock(); + Lock(); + } + + void + Lock() { + if (own) + FaissGpuResourceMgr::GetInstance().mutex_cache_[device_id]->lock(); + resource->mutex.lock(); + } + + ~ResScope() { + if (own) + FaissGpuResourceMgr::GetInstance().mutex_cache_[device_id]->unlock(); + if (move) + FaissGpuResourceMgr::GetInstance().MoveToIdle(device_id, resource); + resource->mutex.unlock(); + } + + private: + ResPtr resource; // hold resource until deconstruct + int64_t device_id; + bool move = true; + bool own = false; +}; + +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/vector_index/helpers/FaissIO.cpp b/core/src/index/knowhere/knowhere/index/vector_index/helpers/FaissIO.cpp new file mode 100644 index 0000000000..1782956520 --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/vector_index/helpers/FaissIO.cpp @@ -0,0 +1,72 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License + +#include + +#include "knowhere/common/Log.h" +#include "knowhere/index/vector_index/helpers/FaissIO.h" + +namespace milvus { +namespace knowhere { + +// TODO(linxj): Get From Config File +static size_t magic_num = 2; + +size_t +MemoryIOWriter::operator()(const void* ptr, size_t size, size_t nitems) { + auto total_need = size * nitems + rp; + + if (!data_) { // data == nullptr + total = total_need * magic_num; + rp = size * nitems; + data_ = new uint8_t[total]; + memcpy(data_, ptr, rp); + return nitems; + } + + if (total_need > total) { + total = total_need * magic_num; + auto new_data = new uint8_t[total]; + memcpy(new_data, data_, rp); + delete[] data_; + data_ = new_data; + + memcpy((data_ + rp), ptr, size * nitems); + rp = total_need; + } else { + memcpy((data_ + rp), ptr, size * nitems); + rp = total_need; + } + + return nitems; +} + +size_t +MemoryIOReader::operator()(void* ptr, size_t size, size_t nitems) { + if (rp >= total) { + return 0; + } + size_t nremain = (total - rp) / size; + if (nremain < nitems) { + nitems = nremain; + } + memcpy(ptr, (data_ + rp), size * nitems); + rp += size * nitems; + return nitems; +} + +void +enable_faiss_logging() { + faiss::LOG_DEBUG_ = &log_debug_; +} + +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/vector_index/helpers/FaissIO.h b/core/src/index/knowhere/knowhere/index/vector_index/helpers/FaissIO.h new file mode 100644 index 0000000000..e777c6746e --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/vector_index/helpers/FaissIO.h @@ -0,0 +1,54 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License + +#pragma once + +#include +#include + +namespace milvus { +namespace knowhere { + +struct MemoryIOWriter : public faiss::IOWriter { + uint8_t* data_ = nullptr; + size_t total = 0; + size_t rp = 0; + + size_t + operator()(const void* ptr, size_t size, size_t nitems) override; + + template + size_t + write(T* ptr, size_t size, size_t nitems = 1) { + return operator()((const void*)ptr, size, nitems); + } +}; + +struct MemoryIOReader : public faiss::IOReader { + uint8_t* data_; + size_t rp = 0; + size_t total = 0; + + size_t + operator()(void* ptr, size_t size, size_t nitems) override; + + template + size_t + read(T* ptr, size_t size, size_t nitems = 1) { + return operator()((void*)ptr, size, nitems); + } +}; + +void +enable_faiss_logging(); + +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/vector_index/helpers/IndexParameter.cpp b/core/src/index/knowhere/knowhere/index/vector_index/helpers/IndexParameter.cpp new file mode 100644 index 0000000000..2d63a3f4a0 --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/vector_index/helpers/IndexParameter.cpp @@ -0,0 +1,48 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include "knowhere/index/vector_index/helpers/IndexParameter.h" +#include "knowhere/common/Exception.h" + +#include + +namespace milvus { +namespace knowhere { + +faiss::MetricType +GetMetricType(const std::string& type) { + if (type == Metric::L2) { + return faiss::METRIC_L2; + } + if (type == Metric::IP) { + return faiss::METRIC_INNER_PRODUCT; + } + if (type == Metric::JACCARD) { + return faiss::METRIC_Jaccard; + } + if (type == Metric::TANIMOTO) { + return faiss::METRIC_Tanimoto; + } + if (type == Metric::HAMMING) { + return faiss::METRIC_Hamming; + } + if (type == Metric::SUBSTRUCTURE) { + return faiss::METRIC_Substructure; + } + if (type == Metric::SUPERSTRUCTURE) { + return faiss::METRIC_Superstructure; + } + + KNOWHERE_THROW_MSG("Metric type is invalid"); +} + +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/vector_index/helpers/IndexParameter.h b/core/src/index/knowhere/knowhere/index/vector_index/helpers/IndexParameter.h new file mode 100644 index 0000000000..69c84f5279 --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/vector_index/helpers/IndexParameter.h @@ -0,0 +1,71 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#pragma once + +#include +#include + +namespace milvus { +namespace knowhere { + +namespace meta { +constexpr const char* DIM = "dim"; +constexpr const char* TENSOR = "tensor"; +constexpr const char* ROWS = "rows"; +constexpr const char* IDS = "ids"; +constexpr const char* DISTANCE = "distance"; +constexpr const char* TOPK = "k"; +constexpr const char* DEVICEID = "gpu_id"; +}; // namespace meta + +namespace IndexParams { +// IVF Params +constexpr const char* nprobe = "nprobe"; +constexpr const char* nlist = "nlist"; +constexpr const char* m = "m"; // PQ +constexpr const char* nbits = "nbits"; // PQ/SQ + +// NSG Params +constexpr const char* knng = "knng"; +constexpr const char* search_length = "search_length"; +constexpr const char* out_degree = "out_degree"; +constexpr const char* candidate = "candidate_pool_size"; + +// HNSW Params +constexpr const char* efConstruction = "efConstruction"; +constexpr const char* M = "M"; +constexpr const char* ef = "ef"; + +// Annoy Params +constexpr const char* n_trees = "n_trees"; +constexpr const char* search_k = "search_k"; + +// PQ Params +constexpr const char* PQM = "PQM"; +} // namespace IndexParams + +namespace Metric { +constexpr const char* TYPE = "metric_type"; +constexpr const char* IP = "IP"; +constexpr const char* L2 = "L2"; +constexpr const char* HAMMING = "HAMMING"; +constexpr const char* JACCARD = "JACCARD"; +constexpr const char* TANIMOTO = "TANIMOTO"; +constexpr const char* SUBSTRUCTURE = "SUBSTRUCTURE"; +constexpr const char* SUPERSTRUCTURE = "SUPERSTRUCTURE"; +} // namespace Metric + +extern faiss::MetricType +GetMetricType(const std::string& type); + +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/vector_index/helpers/SPTAGParameterMgr.cpp b/core/src/index/knowhere/knowhere/index/vector_index/helpers/SPTAGParameterMgr.cpp new file mode 100644 index 0000000000..968002f389 --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/vector_index/helpers/SPTAGParameterMgr.cpp @@ -0,0 +1,69 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include + +#include "knowhere/index/vector_index/helpers/SPTAGParameterMgr.h" + +namespace milvus { +namespace knowhere { + +const Config& +SPTAGParameterMgr::GetKDTParameters() { + return kdt_config_; +} + +const Config& +SPTAGParameterMgr::GetBKTParameters() { + return bkt_config_; +} + +SPTAGParameterMgr::SPTAGParameterMgr() { + kdt_config_["kdtnumber"] = 1; + kdt_config_["numtopdimensionkdtsplit"] = 5; + kdt_config_["samples"] = 100; + kdt_config_["tptnumber"] = 1; + kdt_config_["tptleafsize"] = 2000; + kdt_config_["numtopdimensiontptsplit"] = 5; + kdt_config_["neighborhoodsize"] = 32; + kdt_config_["graphneighborhoodscale"] = 2; + kdt_config_["graphcefscale"] = 2; + kdt_config_["refineiterations"] = 0; + kdt_config_["cef"] = 1000; + kdt_config_["maxcheckforrefinegraph"] = 10000; + kdt_config_["numofthreads"] = 1; + kdt_config_["maxcheck"] = 8192; + kdt_config_["thresholdofnumberofcontinuousnobetterpropagation"] = 3; + kdt_config_["numberofinitialdynamicpivots"] = 50; + kdt_config_["numberofotherdynamicpivots"] = 4; + + bkt_config_["bktnumber"] = 1; + bkt_config_["bktkmeansk"] = 32; + bkt_config_["bktleafsize"] = 8; + bkt_config_["samples"] = 100; + bkt_config_["tptnumber"] = 1; + bkt_config_["tptleafsize"] = 2000; + bkt_config_["numtopdimensiontptsplit"] = 5; + bkt_config_["neighborhoodsize"] = 32; + bkt_config_["graphneighborhoodscale"] = 2; + bkt_config_["graphcefscale"] = 2; + bkt_config_["refineiterations"] = 0; + bkt_config_["cef"] = 1000; + bkt_config_["maxcheckforrefinegraph"] = 10000; + bkt_config_["numofthreads"] = 1; + bkt_config_["maxcheck"] = 8192; + bkt_config_["thresholdofnumberofcontinuousnobetterpropagation"] = 3; + bkt_config_["numberofinitialdynamicpivots"] = 50; + bkt_config_["numberofotherdynamicpivots"] = 4; +} + +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/vector_index/helpers/SPTAGParameterMgr.h b/core/src/index/knowhere/knowhere/index/vector_index/helpers/SPTAGParameterMgr.h new file mode 100644 index 0000000000..300ddc7db6 --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/vector_index/helpers/SPTAGParameterMgr.h @@ -0,0 +1,55 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#pragma once + +#include +#include +#include +#include + +#include +#include "IndexParameter.h" +#include "knowhere/common/Config.h" + +namespace milvus { +namespace knowhere { + +class SPTAGParameterMgr { + public: + const Config& + GetKDTParameters(); + + const Config& + GetBKTParameters(); + + public: + static SPTAGParameterMgr& + GetInstance() { + static SPTAGParameterMgr instance; + return instance; + } + + SPTAGParameterMgr(const SPTAGParameterMgr&) = delete; + + SPTAGParameterMgr& + operator=(const SPTAGParameterMgr&) = delete; + + private: + SPTAGParameterMgr(); + + private: + Config kdt_config_; + Config bkt_config_; +}; + +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/vector_index/impl/nsg/Distance.cpp b/core/src/index/knowhere/knowhere/index/vector_index/impl/nsg/Distance.cpp new file mode 100644 index 0000000000..95515e1162 --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/vector_index/impl/nsg/Distance.cpp @@ -0,0 +1,247 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License + +#include +#include + +#include "knowhere/index/vector_index/impl/nsg/Distance.h" + +namespace milvus { +namespace knowhere { +namespace impl { + +#if 0 /* use FAISS distance calculation algorithm instead */ + +float +DistanceL2::Compare(const float* a, const float* b, unsigned size) const { + float result = 0; + +#ifdef __GNUC__ +#ifdef __AVX__ + +#define AVX_L2SQR(addr1, addr2, dest, tmp1, tmp2) \ + tmp1 = _mm256_loadu_ps(addr1); \ + tmp2 = _mm256_loadu_ps(addr2); \ + tmp1 = _mm256_sub_ps(tmp1, tmp2); \ + tmp1 = _mm256_mul_ps(tmp1, tmp1); \ + dest = _mm256_add_ps(dest, tmp1); + + __m256 sum; + __m256 l0, l1; + __m256 r0, r1; + unsigned D = (size + 7) & ~7U; + unsigned DR = D % 16; + unsigned DD = D - DR; + const float* l = a; + const float* r = b; + const float* e_l = l + DD; + const float* e_r = r + DD; + float unpack[8] __attribute__((aligned(32))) = {0, 0, 0, 0, 0, 0, 0, 0}; + + sum = _mm256_loadu_ps(unpack); + if (DR) { + AVX_L2SQR(e_l, e_r, sum, l0, r0); + } + + for (unsigned i = 0; i < DD; i += 16, l += 16, r += 16) { + AVX_L2SQR(l, r, sum, l0, r0); + AVX_L2SQR(l + 8, r + 8, sum, l1, r1); + } + _mm256_storeu_ps(unpack, sum); + result = unpack[0] + unpack[1] + unpack[2] + unpack[3] + unpack[4] + unpack[5] + unpack[6] + unpack[7]; + +#else +#ifdef __SSE2__ +#define SSE_L2SQR(addr1, addr2, dest, tmp1, tmp2) \ + tmp1 = _mm_load_ps(addr1); \ + tmp2 = _mm_load_ps(addr2); \ + tmp1 = _mm_sub_ps(tmp1, tmp2); \ + tmp1 = _mm_mul_ps(tmp1, tmp1); \ + dest = _mm_add_ps(dest, tmp1); + + __m128 sum; + __m128 l0, l1, l2, l3; + __m128 r0, r1, r2, r3; + unsigned D = (size + 3) & ~3U; + unsigned DR = D % 16; + unsigned DD = D - DR; + const float* l = a; + const float* r = b; + const float* e_l = l + DD; + const float* e_r = r + DD; + float unpack[4] __attribute__((aligned(16))) = {0, 0, 0, 0}; + + sum = _mm_load_ps(unpack); + switch (DR) { + case 12: + SSE_L2SQR(e_l + 8, e_r + 8, sum, l2, r2); + case 8: + SSE_L2SQR(e_l + 4, e_r + 4, sum, l1, r1); + case 4: + SSE_L2SQR(e_l, e_r, sum, l0, r0); + default: + break; + } + for (unsigned i = 0; i < DD; i += 16, l += 16, r += 16) { + SSE_L2SQR(l, r, sum, l0, r0); + SSE_L2SQR(l + 4, r + 4, sum, l1, r1); + SSE_L2SQR(l + 8, r + 8, sum, l2, r2); + SSE_L2SQR(l + 12, r + 12, sum, l3, r3); + } + _mm_storeu_ps(unpack, sum); + result += unpack[0] + unpack[1] + unpack[2] + unpack[3]; + +// nomal distance +#else + + float diff0, diff1, diff2, diff3; + const float* last = a + size; + const float* unroll_group = last - 3; + + /* Process 4 items with each loop for efficiency. */ + while (a < unroll_group) { + diff0 = a[0] - b[0]; + diff1 = a[1] - b[1]; + diff2 = a[2] - b[2]; + diff3 = a[3] - b[3]; + result += diff0 * diff0 + diff1 * diff1 + diff2 * diff2 + diff3 * diff3; + a += 4; + b += 4; + } + /* Process last 0-3 pixels. Not needed for standard vector lengths. */ + while (a < last) { + diff0 = *a++ - *b++; + result += diff0 * diff0; + } +#endif +#endif +#endif + + return result; +} + +float +DistanceIP::Compare(const float* a, const float* b, unsigned size) const { + float result = 0; + +#ifdef __GNUC__ +#ifdef __AVX__ +#define AVX_DOT(addr1, addr2, dest, tmp1, tmp2) \ + tmp1 = _mm256_loadu_ps(addr1); \ + tmp2 = _mm256_loadu_ps(addr2); \ + tmp1 = _mm256_mul_ps(tmp1, tmp2); \ + dest = _mm256_add_ps(dest, tmp1); + + __m256 sum; + __m256 l0, l1; + __m256 r0, r1; + unsigned D = (size + 7) & ~7U; + unsigned DR = D % 16; + unsigned DD = D - DR; + const float* l = a; + const float* r = b; + const float* e_l = l + DD; + const float* e_r = r + DD; + float unpack[8] __attribute__((aligned(32))) = {0, 0, 0, 0, 0, 0, 0, 0}; + + sum = _mm256_loadu_ps(unpack); + if (DR) { + AVX_DOT(e_l, e_r, sum, l0, r0); + } + + for (unsigned i = 0; i < DD; i += 16, l += 16, r += 16) { + AVX_DOT(l, r, sum, l0, r0); + AVX_DOT(l + 8, r + 8, sum, l1, r1); + } + _mm256_storeu_ps(unpack, sum); + result = unpack[0] + unpack[1] + unpack[2] + unpack[3] + unpack[4] + unpack[5] + unpack[6] + unpack[7]; + +#else +#ifdef __SSE2__ +#define SSE_DOT(addr1, addr2, dest, tmp1, tmp2) \ + tmp1 = _mm128_loadu_ps(addr1); \ + tmp2 = _mm128_loadu_ps(addr2); \ + tmp1 = _mm128_mul_ps(tmp1, tmp2); \ + dest = _mm128_add_ps(dest, tmp1); + __m128 sum; + __m128 l0, l1, l2, l3; + __m128 r0, r1, r2, r3; + unsigned D = (size + 3) & ~3U; + unsigned DR = D % 16; + unsigned DD = D - DR; + const float* l = a; + const float* r = b; + const float* e_l = l + DD; + const float* e_r = r + DD; + float unpack[4] __attribute__((aligned(16))) = {0, 0, 0, 0}; + + sum = _mm_load_ps(unpack); + switch (DR) { + case 12: + SSE_DOT(e_l + 8, e_r + 8, sum, l2, r2); + case 8: + SSE_DOT(e_l + 4, e_r + 4, sum, l1, r1); + case 4: + SSE_DOT(e_l, e_r, sum, l0, r0); + default: + break; + } + for (unsigned i = 0; i < DD; i += 16, l += 16, r += 16) { + SSE_DOT(l, r, sum, l0, r0); + SSE_DOT(l + 4, r + 4, sum, l1, r1); + SSE_DOT(l + 8, r + 8, sum, l2, r2); + SSE_DOT(l + 12, r + 12, sum, l3, r3); + } + _mm_storeu_ps(unpack, sum); + result += unpack[0] + unpack[1] + unpack[2] + unpack[3]; +#else + + float dot0, dot1, dot2, dot3; + const float* last = a + size; + const float* unroll_group = last - 3; + + /* Process 4 items with each loop for efficiency. */ + while (a < unroll_group) { + dot0 = a[0] * b[0]; + dot1 = a[1] * b[1]; + dot2 = a[2] * b[2]; + dot3 = a[3] * b[3]; + result += dot0 + dot1 + dot2 + dot3; + a += 4; + b += 4; + } + /* Process last 0-3 pixels. Not needed for standard vector lengths. */ + while (a < last) { + result += *a++ * *b++; + } +#endif +#endif +#endif + return result; +} + +#else + +float +DistanceL2::Compare(const float* a, const float* b, unsigned size) const { + return faiss::fvec_L2sqr(a, b, static_cast(size)); +} + +float +DistanceIP::Compare(const float* a, const float* b, unsigned size) const { + return -(faiss::fvec_inner_product(a, b, static_cast(size))); +} + +#endif + +} // namespace impl +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/vector_index/impl/nsg/Distance.h b/core/src/index/knowhere/knowhere/index/vector_index/impl/nsg/Distance.h new file mode 100644 index 0000000000..ac8d20c559 --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/vector_index/impl/nsg/Distance.h @@ -0,0 +1,36 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License + +#pragma once + +namespace milvus { +namespace knowhere { +namespace impl { + +struct Distance { + virtual ~Distance() = default; + virtual float + Compare(const float* a, const float* b, unsigned size) const = 0; +}; + +struct DistanceL2 : public Distance { + float + Compare(const float* a, const float* b, unsigned size) const override; +}; + +struct DistanceIP : public Distance { + float + Compare(const float* a, const float* b, unsigned size) const override; +}; + +} // namespace impl +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/vector_index/impl/nsg/NSG.cpp b/core/src/index/knowhere/knowhere/index/vector_index/impl/nsg/NSG.cpp new file mode 100644 index 0000000000..226f3ac9b8 --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/vector_index/impl/nsg/NSG.cpp @@ -0,0 +1,914 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License + +#include "knowhere/index/vector_index/impl/nsg/NSG.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "faiss/BuilderSuspend.h" +#include "knowhere/common/Exception.h" +#include "knowhere/common/Log.h" +#include "knowhere/common/Timer.h" +#include "knowhere/index/vector_index/impl/nsg/NSGHelper.h" + +namespace milvus { +namespace knowhere { +namespace impl { + +unsigned int seed = 100; + +NsgIndex::NsgIndex(const size_t& dimension, const size_t& n, Metric_Type metric) + : dimension(dimension), ntotal(n), metric_type(metric) { + if (metric == Metric_Type::Metric_Type_L2) { + distance_ = new DistanceL2; + } else if (metric == Metric_Type::Metric_Type_IP) { + distance_ = new DistanceIP; + } +} + +NsgIndex::~NsgIndex() { + // delete[] ori_data_; + delete[] ids_; + delete distance_; +} + +void +NsgIndex::Build_with_ids(size_t nb, float* data, const int64_t* ids, const BuildParams& parameters) { + ntotal = nb; + // ori_data_ = new float[ntotal * dimension]; + ids_ = new int64_t[ntotal]; + // memcpy((void*)ori_data_, (void*)data, sizeof(float) * ntotal * dimension); + memcpy(ids_, ids, sizeof(int64_t) * ntotal); + + search_length = parameters.search_length; + out_degree = parameters.out_degree; + candidate_pool_size = parameters.candidate_pool_size; + + TimeRecorder rc("NSG", 1); + InitNavigationPoint(data); + rc.RecordSection("init"); + + Link(data); + rc.RecordSection("Link"); + + CheckConnectivity(data); + rc.RecordSection("Connect"); + rc.ElapseFromBegin("finish"); + + is_trained = true; + + int total_degree = 0; + for (size_t i = 0; i < ntotal; ++i) { + total_degree += nsg[i].size(); + } + LOG_KNOWHERE_DEBUG_ << "Graph physical size: " << total_degree * sizeof(node_t) / 1024 / 1024 << "m"; + LOG_KNOWHERE_DEBUG_ << "Average degree: " << total_degree / ntotal; + + // Debug code + // for (size_t i = 0; i < ntotal; i++) { + // auto& x = nsg[i]; + // for (size_t j = 0; j < x.size(); j++) { + // std::cout << "id: " << x[j] << std::endl; + // } + // std::cout << std::endl; + // } +} + +void +NsgIndex::InitNavigationPoint(float* data) { + // calculate the center of vectors + auto center = new float[dimension]; + memset(center, 0, sizeof(float) * dimension); + + for (size_t i = 0; i < ntotal; i++) { + for (size_t j = 0; j < dimension; j++) { + center[j] += data[i * dimension + j]; + } + } + for (size_t j = 0; j < dimension; j++) { + center[j] /= ntotal; + } + + // select navigation point + std::vector resset; + navigation_point = rand_r(&seed) % ntotal; // random initialize navigating point + GetNeighbors(center, data, resset, knng); + navigation_point = resset[0].id; + + // Debug code + // std::cout << "ep: " << navigation_point << std::endl; + // for (int k = 0; k < resset.size(); ++k) { + // std::cout << "id: " << resset[k].id << ", dis: " << resset[k].distance << std::endl; + // } + // std::cout << std::endl; + // + // std::cout << "ep: " << navigation_point << std::endl; + // + // float r1 = distance_->Compare(center, ori_data_ + navigation_point * dimension, dimension); + // assert(r1 == resset[0].distance); +} + +// Specify Link +void +NsgIndex::GetNeighbors(const float* query, float* data, std::vector& resset, std::vector& fullset, + boost::dynamic_bitset<>& has_calculated_dist) { + auto& graph = knng; + size_t buffer_size = search_length; + + if (buffer_size > ntotal) { + KNOWHERE_THROW_MSG("Build Error, search_length > ntotal"); + } + + resset.resize(search_length); + std::vector init_ids(buffer_size); + // std::vector init_ids; + + { + /* + * copy navigation-point neighbor, pick random node if less than buffer size + */ + size_t count = 0; + + // Get all neighbors + for (size_t i = 0; i < init_ids.size() && i < graph[navigation_point].size(); ++i) { + // for (size_t i = 0; i < graph[navigation_point].size(); ++i) { + // init_ids.push_back(graph[navigation_point][i]); + init_ids[i] = graph[navigation_point][i]; + has_calculated_dist[init_ids[i]] = true; + ++count; + } + while (count < buffer_size) { + node_t id = rand_r(&seed) % ntotal; + if (has_calculated_dist[id]) { + continue; // duplicate id + } + // init_ids.push_back(id); + init_ids[count] = id; + ++count; + has_calculated_dist[id] = true; + } + } + + { + // resset.resize(init_ids.size()); + + // init resset and sort by distance + for (size_t i = 0; i < init_ids.size(); ++i) { + node_t id = init_ids[i]; + + if (id >= static_cast(ntotal)) { + KNOWHERE_THROW_MSG("Build Index Error, id > ntotal"); + continue; + } + + float dist = distance_->Compare(data + dimension * id, query, dimension); + resset[i] = Neighbor(id, dist, false); + + //// difference from other GetNeighbors + fullset.push_back(resset[i]); + /////////////////////////////////////// + } + std::sort(resset.begin(), resset.end()); // sort by distance + + // search nearest neighbor + size_t cursor = 0; + while (cursor < buffer_size) { + size_t nearest_updated_pos = buffer_size; + + if (!resset[cursor].has_explored) { + resset[cursor].has_explored = true; + + node_t start_pos = resset[cursor].id; + auto& wait_for_search_node_vec = graph[start_pos]; + for (node_t id : wait_for_search_node_vec) { + if (has_calculated_dist[id]) { + continue; + } + has_calculated_dist[id] = true; + + float dist = distance_->Compare(query, data + dimension * id, dimension); + Neighbor nn(id, dist, false); + fullset.push_back(nn); + + if (dist >= resset[buffer_size - 1].distance) { + continue; + } + + size_t pos = InsertIntoPool(resset.data(), buffer_size, nn); // replace with a closer node + if (pos < nearest_updated_pos) { + nearest_updated_pos = pos; + } + + // assert(buffer_size + 1 >= resset.size()); + if (buffer_size + 1 < resset.size()) { + ++buffer_size; + } + } + } + if (cursor >= nearest_updated_pos) { + cursor = nearest_updated_pos; // re-search from new pos + } else { + ++cursor; + } + } + } +} + +// FindUnconnectedNode +void +NsgIndex::GetNeighbors(const float* query, float* data, std::vector& resset, std::vector& fullset) { + auto& graph = nsg; + size_t buffer_size = search_length; + + if (buffer_size > ntotal) { + KNOWHERE_THROW_MSG("Build Error, search_length > ntotal"); + } + + // std::vector init_ids; + std::vector init_ids(buffer_size); + resset.resize(buffer_size); + boost::dynamic_bitset<> has_calculated_dist{ntotal, 0}; + + { + /* + * copy navigation-point neighbor, pick random node if less than buffer size + */ + size_t count = 0; + + // Get all neighbors + for (size_t i = 0; i < init_ids.size() && i < graph[navigation_point].size(); ++i) { + // for (size_t i = 0; i < graph[navigation_point].size(); ++i) { + // init_ids.push_back(graph[navigation_point][i]); + init_ids[i] = graph[navigation_point][i]; + has_calculated_dist[init_ids[i]] = true; + ++count; + } + while (count < buffer_size) { + node_t id = rand_r(&seed) % ntotal; + if (has_calculated_dist[id]) { + continue; // duplicate id + } + // init_ids.push_back(id); + init_ids[count] = id; + ++count; + has_calculated_dist[id] = true; + } + } + + { + // resset.resize(init_ids.size()); + + // init resset and sort by distance + for (size_t i = 0; i < init_ids.size(); ++i) { + node_t id = init_ids[i]; + + if (id >= static_cast(ntotal)) { + KNOWHERE_THROW_MSG("Build Index Error, id > ntotal"); + continue; + } + + float dist = distance_->Compare(data + id * dimension, query, dimension); + resset[i] = Neighbor(id, dist, false); + } + std::sort(resset.begin(), resset.end()); // sort by distance + + // search nearest neighbor + size_t cursor = 0; + while (cursor < buffer_size) { + size_t nearest_updated_pos = buffer_size; + + if (!resset[cursor].has_explored) { + resset[cursor].has_explored = true; + + node_t start_pos = resset[cursor].id; + auto& wait_for_search_node_vec = graph[start_pos]; + for (node_t id : wait_for_search_node_vec) { + if (has_calculated_dist[id]) { + continue; + } + has_calculated_dist[id] = true; + + float dist = distance_->Compare(data + dimension * id, query, dimension); + Neighbor nn(id, dist, false); + fullset.push_back(nn); + + if (dist >= resset[buffer_size - 1].distance) { + continue; + } + + size_t pos = InsertIntoPool(resset.data(), buffer_size, nn); // replace with a closer node + if (pos < nearest_updated_pos) { + nearest_updated_pos = pos; + } + + // assert(buffer_size + 1 >= resset.size()); + if (buffer_size + 1 < resset.size()) { + ++buffer_size; // trick + } + } + } + if (cursor >= nearest_updated_pos) { + cursor = nearest_updated_pos; // re-search from new pos + } else { + ++cursor; + } + } + } +} + +void +NsgIndex::GetNeighbors(const float* query, float* data, std::vector& resset, Graph& graph, + SearchParams* params) { + size_t buffer_size = params ? params->search_length : search_length; + + if (buffer_size > ntotal) { + KNOWHERE_THROW_MSG("Build Error, search_length > ntotal"); + } + + std::vector init_ids(buffer_size); + resset.resize(buffer_size); + boost::dynamic_bitset<> has_calculated_dist{ntotal, 0}; + + { + /* + * copy navigation-point neighbor, pick random node if less than buffer size + */ + size_t count = 0; + + // Get all neighbors + for (size_t i = 0; i < init_ids.size() && i < graph[navigation_point].size(); ++i) { + init_ids[i] = graph[navigation_point][i]; + has_calculated_dist[init_ids[i]] = true; + ++count; + } + while (count < buffer_size) { + node_t id = rand_r(&seed) % ntotal; + if (has_calculated_dist[id]) { + continue; // duplicate id + } + init_ids[count] = id; + ++count; + has_calculated_dist[id] = true; + } + } + + { + // resset.resize(init_ids.size()); + + // init resset and sort by distance + for (size_t i = 0; i < init_ids.size(); ++i) { + node_t id = init_ids[i]; + + if (id >= static_cast(ntotal)) { + KNOWHERE_THROW_MSG("Build Index Error, id > ntotal"); + } + + float dist = distance_->Compare(data + id * dimension, query, dimension); + resset[i] = Neighbor(id, dist, false); + } + std::sort(resset.begin(), resset.end()); // sort by distance + + // search nearest neighbor + size_t cursor = 0; + while (cursor < buffer_size) { + size_t nearest_updated_pos = buffer_size; + + if (!resset[cursor].has_explored) { + resset[cursor].has_explored = true; + + node_t start_pos = resset[cursor].id; + auto& wait_for_search_node_vec = graph[start_pos]; + for (node_t id : wait_for_search_node_vec) { + if (has_calculated_dist[id]) { + continue; + } + has_calculated_dist[id] = true; + + float dist = distance_->Compare(query, data + dimension * id, dimension); + + if (dist >= resset[buffer_size - 1].distance) { + continue; + } + + //// difference from other GetNeighbors + Neighbor nn(id, dist, false); + /////////////////////////////////////// + + size_t pos = InsertIntoPool(resset.data(), buffer_size, nn); // replace with a closer node + if (pos < nearest_updated_pos) { + nearest_updated_pos = pos; + } + + //>> Debug code + ///// + // std::cout << "pos: " << pos << ", nn: " << nn.id << ":" << nn.distance << ", nup: " << + // nearest_updated_pos << std::endl; + ///// + // trick: avoid search query search_length < init_ids.size() ... + if (buffer_size + 1 < resset.size()) { + ++buffer_size; + } + } + } + if (cursor >= nearest_updated_pos) { + cursor = nearest_updated_pos; // re-search from new pos + } else { + ++cursor; + } + } + } +} + +void +NsgIndex::Link(float* data) { + auto cut_graph_dist = new float[ntotal * out_degree]; + nsg.resize(ntotal); + +#pragma omp parallel + { + std::vector fullset; + std::vector temp; + boost::dynamic_bitset<> flags{ntotal, 0}; +#pragma omp for schedule(dynamic, 100) + for (size_t n = 0; n < ntotal; ++n) { + faiss::BuilderSuspend::check_wait(); + fullset.clear(); + temp.clear(); + flags.reset(); + GetNeighbors(data + dimension * n, data, temp, fullset, flags); + SyncPrune(data, n, fullset, flags, cut_graph_dist); + } + + // Debug code + // std::cout << "ep: " << 0 << std::endl; + // for (int k = 0; k < fullset.size(); ++k) { + // std::cout << "id: " << fullset[k].id << ", dis: " << fullset[k].distance << std::endl; + // } + } + knng.clear(); + + // Debug code + // for (size_t i = 0; i < ntotal; i++) + // { + // auto& x = nsg[i]; + // for (size_t j=0; j < x.size(); j++) + // { + // std::cout << "id: " << x[j] << std::endl; + // } + // std::cout << std::endl; + // } + + std::vector mutex_vec(ntotal); +#pragma omp for schedule(dynamic, 100) + for (unsigned n = 0; n < ntotal; ++n) { + faiss::BuilderSuspend::check_wait(); + InterInsert(data, n, mutex_vec, cut_graph_dist); + } + delete[] cut_graph_dist; +} + +void +NsgIndex::SyncPrune(float* data, size_t n, std::vector& pool, boost::dynamic_bitset<>& has_calculated, + float* cut_graph_dist) { + // avoid lose nearest neighbor in knng + for (size_t i = 0; i < knng[n].size(); ++i) { + auto id = knng[n][i]; + if (has_calculated[id]) { + continue; + } + float dist = distance_->Compare(data + dimension * n, data + dimension * id, dimension); + pool.emplace_back(Neighbor(id, dist, true)); + } + + // sort and find closest node + unsigned cursor = 0; + std::sort(pool.begin(), pool.end()); + std::vector result; + if (pool[cursor].id == static_cast(n)) { + cursor++; + } + result.push_back(pool[cursor]); // init result with nearest neighbor + + SelectEdge(data, cursor, pool, result, true); + + // filling the cut_graph + auto& des_id_pool = nsg[n]; + float* des_dist_pool = cut_graph_dist + n * out_degree; + for (size_t i = 0; i < result.size(); ++i) { + des_id_pool.push_back(result[i].id); + des_dist_pool[i] = result[i].distance; + } + if (result.size() < out_degree) { + des_dist_pool[result.size()] = -1; + } + //>> Optimize: reserve id_pool capacity +} + +//>> Optimize: remove read-lock +void +NsgIndex::InterInsert(float* data, unsigned n, std::vector& mutex_vec, float* cut_graph_dist) { + auto& current = n; + + auto& neighbor_id_pool = nsg[current]; + float* neighbor_dist_pool = cut_graph_dist + current * out_degree; + for (size_t i = 0; i < out_degree; ++i) { + if (neighbor_dist_pool[i] == -1) { + break; + } + + size_t current_neighbor = neighbor_id_pool[i]; // center's neighbor id + auto& nsn_id_pool = nsg[current_neighbor]; // nsn => neighbor's neighbor + float* nsn_dist_pool = cut_graph_dist + current_neighbor * out_degree; + + std::vector wait_for_link_pool; // maintain candidate neighbor of the current neighbor. + int duplicate = false; + { + LockGuard lk(mutex_vec[current_neighbor]); + for (size_t j = 0; j < out_degree; ++j) { + if (nsn_dist_pool[j] == -1) { + break; + } + + // At least one edge can be connected back + if (n == nsn_id_pool[j]) { + duplicate = true; + break; + } + + Neighbor nsn(nsn_id_pool[j], nsn_dist_pool[j]); + wait_for_link_pool.push_back(nsn); + } + } + if (duplicate) { + continue; + } + + // original: (neighbor) <------- (current) + // after: (neighbor) -------> (current) + // current node as a neighbor of its neighbor + Neighbor current_as_neighbor(n, neighbor_dist_pool[i]); + wait_for_link_pool.push_back(current_as_neighbor); + + // re-selectEdge if candidate neighbor num > out_degree + if (wait_for_link_pool.size() > out_degree) { + std::vector result; + + unsigned start = 0; + std::sort(wait_for_link_pool.begin(), wait_for_link_pool.end()); + result.push_back(wait_for_link_pool[start]); + + SelectEdge(data, start, wait_for_link_pool, result); + + { + LockGuard lk(mutex_vec[current_neighbor]); + for (size_t j = 0; j < result.size(); ++j) { + nsn_id_pool[j] = result[j].id; + nsn_dist_pool[j] = result[j].distance; + } + } + } else { + LockGuard lk(mutex_vec[current_neighbor]); + for (size_t j = 0; j < out_degree; ++j) { + if (nsn_dist_pool[j] == -1) { + nsn_id_pool.push_back(current_as_neighbor.id); + nsn_dist_pool[j] = current_as_neighbor.distance; + if (j + 1 < out_degree) { + nsn_dist_pool[j + 1] = -1; + } + break; + } + } + } + } +} + +void +NsgIndex::SelectEdge(float* data, unsigned& cursor, std::vector& sort_pool, std::vector& result, + bool limit) { + auto& pool = sort_pool; + + /* + * edge selection + * + * search in pool and search deepth is under candidate_pool_size + * max result size equal to out_degress + */ + size_t search_deepth = limit ? candidate_pool_size : pool.size(); + while (result.size() < out_degree && cursor < search_deepth && (++cursor) < pool.size()) { + auto& p = pool[cursor]; + bool should_link = true; + for (auto& t : result) { + float dist = distance_->Compare(data + dimension * t.id, data + dimension * p.id, dimension); + if (dist < p.distance) { + should_link = false; + break; + } + } + if (should_link) { + result.push_back(p); + } + } +} + +void +NsgIndex::CheckConnectivity(float* data) { + auto root = navigation_point; + boost::dynamic_bitset<> has_linked{ntotal, 0}; + int64_t linked_count = 0; + + while (linked_count < static_cast(ntotal)) { + faiss::BuilderSuspend::check_wait(); + DFS(root, has_linked, linked_count); + if (linked_count >= static_cast(ntotal)) { + break; + } + FindUnconnectedNode(data, has_linked, root); + } +} + +void +NsgIndex::DFS(size_t root, boost::dynamic_bitset<>& has_linked, int64_t& linked_count) { + size_t start = root; + std::stack s; + s.push(root); + if (!has_linked[root]) { + linked_count++; // not link + } + has_linked[root] = true; // link start... + + while (!s.empty()) { + size_t next = ntotal + 1; + + for (auto i : nsg[start]) { + if (has_linked[i] == false) { // if not link + next = i; + break; + } + } + if (next == (ntotal + 1)) { + s.pop(); + if (s.empty()) { + break; + } + start = s.top(); + continue; + } + start = next; + has_linked[start] = true; + s.push(start); + ++linked_count; + } +} + +void +NsgIndex::FindUnconnectedNode(float* data, boost::dynamic_bitset<>& has_linked, int64_t& root) { + // find any of unlinked-node + size_t id = ntotal; + for (size_t i = 0; i < ntotal; i++) { // find not link + if (has_linked[i] == false) { + id = i; + break; + } + } + + if (id == ntotal) { + return; // No Unlinked Node + } + + // search unlinked-node's neighbor + std::vector tmp, pool; + GetNeighbors(data + dimension * id, data, tmp, pool); + std::sort(pool.begin(), pool.end()); + + size_t found = 0; + for (auto node : pool) { // find nearest neighbor and add unlinked-node as its neighbor + if (has_linked[node.id]) { + root = node.id; + found = 1; + break; + } + } + if (found == 0) { + while (true) { // random a linked-node and add unlinked-node as its neighbor + size_t rid = rand_r(&seed) % ntotal; + if (has_linked[rid]) { + root = rid; + break; + } + } + } + nsg[root].push_back(id); +} + +// void +// NsgIndex::GetNeighbors(const float* query, node_t* I, float* D, SearchParams* params) { +// size_t buffer_size = params ? params->search_length : search_length; + +// if (buffer_size > ntotal) { +// KNOWHERE_THROW_MSG("Search Error, search_length > ntotal"); +// } + +// std::vector resset(buffer_size); +// std::vector init_ids(buffer_size); +// boost::dynamic_bitset<> has_calculated_dist{ntotal, 0}; + +// { +// /* +// * copy navigation-point neighbor, pick random node if less than buffer size +// */ +// size_t count = 0; + +// // Get all neighbors +// for (size_t i = 0; i < init_ids.size() && i < nsg[navigation_point].size(); ++i) { +// init_ids[i] = nsg[navigation_point][i]; +// has_calculated_dist[init_ids[i]] = true; +// ++count; +// } +// while (count < buffer_size) { +// node_t id = rand_r(&seed) % ntotal; +// if (has_calculated_dist[id]) +// continue; // duplicate id +// init_ids[count] = id; +// ++count; +// has_calculated_dist[id] = true; +// } +// } + +// { +// // init resset and sort by distance +// for (size_t i = 0; i < init_ids.size(); ++i) { +// node_t id = init_ids[i]; + +// if (id >= static_cast(ntotal)) { +// KNOWHERE_THROW_MSG("Search Error, id > ntotal"); +// } + +// float dist = distance_->Compare(ori_data_ + id * dimension, query, dimension); +// resset[i] = Neighbor(id, dist, false); +// } +// std::sort(resset.begin(), resset.end()); // sort by distance + +// // search nearest neighbor +// size_t cursor = 0; +// while (cursor < buffer_size) { +// size_t nearest_updated_pos = buffer_size; + +// if (!resset[cursor].has_explored) { +// resset[cursor].has_explored = true; + +// node_t start_pos = resset[cursor].id; +// auto& wait_for_search_node_vec = nsg[start_pos]; +// for (size_t i = 0; i < wait_for_search_node_vec.size(); ++i) { +// node_t id = wait_for_search_node_vec[i]; +// if (has_calculated_dist[id]) +// continue; +// has_calculated_dist[id] = true; + +// float dist = distance_->Compare(query, ori_data_ + dimension * id, dimension); + +// if (dist >= resset[buffer_size - 1].distance) +// continue; + +// //// difference from other GetNeighbors +// Neighbor nn(id, dist, false); +// /////////////////////////////////////// + +// size_t pos = InsertIntoPool(resset.data(), buffer_size, nn); // replace with a closer node +// if (pos < nearest_updated_pos) +// nearest_updated_pos = pos; + +// //>> Debug code +// ///// +// // std::cout << "pos: " << pos << ", nn: " << nn.id << ":" << nn.distance << ", nup: " << +// // nearest_updated_pos << std::endl; +// ///// + +// // trick: avoid search query search_length < init_ids.size() ... +// if (buffer_size + 1 < resset.size()) +// ++buffer_size; +// } +// } +// if (cursor >= nearest_updated_pos) { +// cursor = nearest_updated_pos; // re-search from new pos +// } else { +// ++cursor; +// } +// } +// } + +// if ((resset.size() - params->k) >= 0) { +// for (size_t i = 0; i < params->k; ++i) { +// I[i] = resset[i].id; +// D[i] = resset[i].distance; +// } +// } else { +// size_t i = 0; +// for (; i < resset.size(); ++i) { +// I[i] = resset[i].id; +// D[i] = resset[i].distance; +// } +// for (; i < params->k; ++i) { +// I[i] = -1; +// D[i] = -1; +// } +// } +// } + +// void +// NsgIndex::Search(const float* query, const unsigned& nq, const unsigned& dim, const unsigned& k, float* dist, +// int64_t* ids, SearchParams& params) { +// // if (k >= 45) { +// // params.search_length = k; +// // } + +// TimeRecorder rc("nsgsearch", 1); + +// if (nq == 1) { +// GetNeighbors(query, ids, dist, ¶ms); +// } else { +// #pragma omp parallel for +// for (unsigned int i = 0; i < nq; ++i) { +// const float* single_query = query + i * dim; +// GetNeighbors(single_query, ids + i * k, dist + i * k, ¶ms); +// } +// } +// rc.ElapseFromBegin("seach finish"); +// } + +void +NsgIndex::Search(const float* query, float* data, const unsigned& nq, const unsigned& dim, const unsigned& k, + float* dist, int64_t* ids, SearchParams& params, faiss::ConcurrentBitsetPtr bitset) { + std::vector> resset(nq); + + TimeRecorder rc("NsgIndex::search", 1); + if (nq == 1) { + GetNeighbors(query, data, resset[0], nsg, ¶ms); + } else { +#pragma omp parallel for + for (unsigned int i = 0; i < nq; ++i) { + const float* single_query = query + i * dim; + GetNeighbors(single_query, data, resset[i], nsg, ¶ms); + } + } + rc.RecordSection("search"); + + bool is_ip = (metric_type == Metric_Type::Metric_Type_IP); + for (unsigned int i = 0; i < nq; ++i) { + unsigned int pos = 0; + for (auto node : resset[i]) { + if (pos >= k) { + break; // already top k + } + if (!bitset || !bitset->test(node.id)) { + ids[i * k + pos] = ids_[node.id]; + dist[i * k + pos] = is_ip ? -node.distance : node.distance; + ++pos; + } + } + // fill with -1 + for (unsigned int j = pos; j < k; ++j) { + ids[i * k + j] = -1; + dist[i * k + j] = -1; + } + } + rc.RecordSection("merge"); +} + +void +NsgIndex::SetKnnGraph(Graph& g) { + knng = std::move(g); +} + +int64_t +NsgIndex::GetSize() { + int64_t ret = 0; + ret += sizeof(*this); + ret += ntotal * dimension * sizeof(float); + ret += ntotal * sizeof(int64_t); + ret += sizeof(*distance_); + for (auto& v : nsg) { + ret += v.size() * sizeof(node_t); + } + for (auto& v : knng) { + ret += v.size() * sizeof(node_t); + } + return ret; +} + +} // namespace impl +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/vector_index/impl/nsg/NSG.h b/core/src/index/knowhere/knowhere/index/vector_index/impl/nsg/NSG.h new file mode 100644 index 0000000000..af7f608dc7 --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/vector_index/impl/nsg/NSG.h @@ -0,0 +1,153 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License + +#pragma once + +#include +#include +#include +#include +#include + +#include "Distance.h" +#include "Neighbor.h" +#include "knowhere/common/Config.h" +#include "knowhere/index/vector_index/helpers/IndexParameter.h" + +namespace milvus { +namespace knowhere { +namespace impl { + +using node_t = int64_t; + +struct BuildParams { + size_t search_length; + size_t out_degree; + size_t candidate_pool_size; +}; + +struct SearchParams { + size_t search_length; + size_t k; +}; + +using Graph = std::vector>; + +class NsgIndex { + public: + enum Metric_Type { + Metric_Type_L2 = 0, + Metric_Type_IP, + }; + + size_t dimension; + size_t ntotal; // totabl nb of indexed vectors + int32_t metric_type; // enum Metric_Type + Distance* distance_; + + // float* ori_data_; + int64_t* ids_; + Graph nsg; // final graph + Graph knng; // reset after build + + node_t navigation_point; // offset of node in origin data + + bool is_trained = false; + + /* + * build and search parameter + */ + size_t search_length; + size_t candidate_pool_size; // search deepth in fullset + size_t out_degree; + + public: + explicit NsgIndex(const size_t& dimension, const size_t& n, Metric_Type metric); + + NsgIndex() = default; + + virtual ~NsgIndex(); + + void + SetKnnGraph(Graph& knng); + + void + Build_with_ids(size_t nb, float* data, const int64_t* ids, const BuildParams& parameters); + + void + Search(const float* query, float* data, const unsigned& nq, const unsigned& dim, const unsigned& k, float* dist, + int64_t* ids, SearchParams& params, faiss::ConcurrentBitsetPtr bitset = nullptr); + + int64_t + GetSize(); + + // Not support yet. + // virtual void Add() = 0; + // virtual void Add_with_ids() = 0; + // virtual void Delete() = 0; + // virtual void Delete_with_ids() = 0; + // virtual void Rebuild(size_t nb, + // const float *data, + // const int64_t *ids, + // const Parameters ¶meters) = 0; + // virtual void Build(size_t nb, + // const float *data, + // const BuildParam ¶meters); + + protected: + void + InitNavigationPoint(float* data); + + // link specify + void + GetNeighbors(const float* query, float* data, std::vector& resset, std::vector& fullset, + boost::dynamic_bitset<>& has_calculated_dist); + + // FindUnconnectedNode + void + GetNeighbors(const float* query, float* data, std::vector& resset, std::vector& fullset); + + // navigation-point + void + GetNeighbors(const float* query, float* data, std::vector& resset, Graph& graph, + SearchParams* param = nullptr); + + // only for search + // void + // GetNeighbors(const float* query, node_t* I, float* D, SearchParams* params); + + void + Link(float* data); + + void + SyncPrune(float* data, size_t q, std::vector& pool, boost::dynamic_bitset<>& has_calculated, + float* cut_graph_dist); + + void + SelectEdge(float* data, unsigned& cursor, std::vector& sort_pool, std::vector& result, + bool limit = false); + + void + InterInsert(float* data, unsigned n, std::vector& mutex_vec, float* dist); + + void + CheckConnectivity(float* data); + + void + DFS(size_t root, boost::dynamic_bitset<>& flags, int64_t& count); + + void + FindUnconnectedNode(float* data, boost::dynamic_bitset<>& flags, int64_t& root); +}; + +} // namespace impl +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/vector_index/impl/nsg/NSGHelper.cpp b/core/src/index/knowhere/knowhere/index/vector_index/impl/nsg/NSGHelper.cpp new file mode 100644 index 0000000000..54d90d602c --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/vector_index/impl/nsg/NSGHelper.cpp @@ -0,0 +1,71 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License + +#include + +#include "knowhere/index/vector_index/impl/nsg/NSGHelper.h" + +namespace milvus { +namespace knowhere { +namespace impl { + +// TODO: impl search && insert && return insert pos. why not just find and swap? +int +InsertIntoPool(Neighbor* addr, unsigned K, Neighbor nn) { + //>> Fix: Add assert + // for (unsigned int i = 0; i < K; ++i) { + // assert(addr[i].id != nn.id); + // } + + // find the location to insert + int left = 0, right = K - 1; + if (addr[left].distance > nn.distance) { + //>> Fix: memmove overflow, dump when vector deconstruct + memmove(&addr[left + 1], &addr[left], (K - 1) * sizeof(Neighbor)); + addr[left] = nn; + return left; + } + if (addr[right].distance < nn.distance) { + addr[K] = nn; + return K; + } + while (left < right - 1) { + int mid = (left + right) / 2; + if (addr[mid].distance > nn.distance) { + right = mid; + } else { + left = mid; + } + } + // check equal ID + + while (left > 0) { + if (addr[left].distance < nn.distance) { // pos is right + break; + } + if (addr[left].id == nn.id) { + return K + 1; + } + left--; + } + if (addr[left].id == nn.id || addr[right].id == nn.id) { + return K + 1; + } + + //>> Fix: memmove overflow, dump when vector deconstruct + memmove(&addr[right + 1], &addr[right], (K - 1 - right) * sizeof(Neighbor)); + addr[right] = nn; + return right; +} + +} // namespace impl +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/vector_index/impl/nsg/NSGHelper.h b/core/src/index/knowhere/knowhere/index/vector_index/impl/nsg/NSGHelper.h new file mode 100644 index 0000000000..f4be91db37 --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/vector_index/impl/nsg/NSGHelper.h @@ -0,0 +1,25 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License + +#pragma once + +#include "Neighbor.h" + +namespace milvus { +namespace knowhere { +namespace impl { + +extern int +InsertIntoPool(Neighbor* addr, unsigned K, Neighbor nn); + +} // namespace impl +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/vector_index/impl/nsg/NSGIO.cpp b/core/src/index/knowhere/knowhere/index/vector_index/impl/nsg/NSGIO.cpp new file mode 100644 index 0000000000..16f16e4be6 --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/vector_index/impl/nsg/NSGIO.cpp @@ -0,0 +1,68 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License + +#include + +#include "knowhere/index/vector_index/impl/nsg/NSGIO.h" + +namespace milvus { +namespace knowhere { +namespace impl { + +void +write_index(NsgIndex* index, MemoryIOWriter& writer) { + writer(&index->metric_type, sizeof(int32_t), 1); + writer(&index->ntotal, sizeof(index->ntotal), 1); + writer(&index->dimension, sizeof(index->dimension), 1); + writer(&index->navigation_point, sizeof(index->navigation_point), 1); + // writer(index->ori_data_, sizeof(float) * index->ntotal * index->dimension, 1); + writer(index->ids_, sizeof(int64_t) * index->ntotal, 1); + + for (unsigned i = 0; i < index->ntotal; ++i) { + auto neighbor_num = static_cast(index->nsg[i].size()); + writer(&neighbor_num, sizeof(node_t), 1); + writer(index->nsg[i].data(), neighbor_num * sizeof(node_t), 1); + } +} + +NsgIndex* +read_index(MemoryIOReader& reader) { + size_t ntotal; + size_t dimension; + int32_t metric; + reader(&metric, sizeof(int32_t), 1); + reader(&ntotal, sizeof(size_t), 1); + reader(&dimension, sizeof(size_t), 1); + auto index = new NsgIndex(dimension, ntotal, static_cast(metric)); + reader(&index->navigation_point, sizeof(index->navigation_point), 1); + + // index->ori_data_ = new float[index->ntotal * index->dimension]; + index->ids_ = new int64_t[index->ntotal]; + // reader(index->ori_data_, sizeof(float) * index->ntotal * index->dimension, 1); + reader(index->ids_, sizeof(int64_t) * index->ntotal, 1); + + index->nsg.reserve(index->ntotal); + index->nsg.resize(index->ntotal); + node_t neighbor_num; + for (unsigned i = 0; i < index->ntotal; ++i) { + reader(&neighbor_num, sizeof(node_t), 1); + index->nsg[i].reserve(neighbor_num); + index->nsg[i].resize(neighbor_num); + reader(index->nsg[i].data(), neighbor_num * sizeof(node_t), 1); + } + + index->is_trained = true; + return index; +} + +} // namespace impl +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/vector_index/impl/nsg/NSGIO.h b/core/src/index/knowhere/knowhere/index/vector_index/impl/nsg/NSGIO.h new file mode 100644 index 0000000000..e40092d7c8 --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/vector_index/impl/nsg/NSGIO.h @@ -0,0 +1,29 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License + +#pragma once + +#include "NSG.h" +#include "knowhere/index/vector_index/helpers/FaissIO.h" + +namespace milvus { +namespace knowhere { +namespace impl { + +extern void +write_index(NsgIndex* index, MemoryIOWriter& writer); + +extern NsgIndex* +read_index(MemoryIOReader& reader); + +} // namespace impl +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/vector_index/impl/nsg/Neighbor.h b/core/src/index/knowhere/knowhere/index/vector_index/impl/nsg/Neighbor.h new file mode 100644 index 0000000000..6644f8f262 --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/vector_index/impl/nsg/Neighbor.h @@ -0,0 +1,46 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License + +#pragma once + +#include + +namespace milvus { +namespace knowhere { +namespace impl { + +using node_t = int64_t; + +// TODO: search use simple neighbor +struct Neighbor { + node_t id; // offset of node in origin data + float distance; + bool has_explored; + + Neighbor() = default; + + explicit Neighbor(node_t id, float distance, bool f) : id{id}, distance{distance}, has_explored(f) { + } + + explicit Neighbor(node_t id, float distance) : id{id}, distance{distance}, has_explored(false) { + } + + inline bool + operator<(const Neighbor& other) const { + return distance < other.distance; + } +}; + +typedef std::lock_guard LockGuard; + +} // namespace impl +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/vector_offset_index/IndexIVF_NM.cpp b/core/src/index/knowhere/knowhere/index/vector_offset_index/IndexIVF_NM.cpp new file mode 100644 index 0000000000..135f84c9eb --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/vector_offset_index/IndexIVF_NM.cpp @@ -0,0 +1,363 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#ifdef MILVUS_GPU_VERSION +#include +#include +#endif + +#include +#include +#include +#include +#include + +#include "faiss/BuilderSuspend.h" +#include "knowhere/common/Exception.h" +#include "knowhere/common/Log.h" +#include "knowhere/index/vector_index/adapter/VectorAdapter.h" +#include "knowhere/index/vector_index/helpers/IndexParameter.h" +#include "knowhere/index/vector_offset_index/IndexIVF_NM.h" +#ifdef MILVUS_GPU_VERSION +#include "knowhere/index/vector_index/gpu/IndexGPUIVF.h" +#include "knowhere/index/vector_index/helpers/FaissGpuResourceMgr.h" +#endif + +namespace milvus { +namespace knowhere { + +using stdclock = std::chrono::high_resolution_clock; + +BinarySet +IVF_NM::Serialize(const Config& config) { + if (!index_ || !index_->is_trained) { + KNOWHERE_THROW_MSG("index not initialize or trained"); + } + + std::lock_guard lk(mutex_); + return SerializeImpl(index_type_); +} + +void +IVF_NM::Load(const BinarySet& binary_set) { + std::lock_guard lk(mutex_); + LoadImpl(binary_set, index_type_); + + // Construct arranged data from original data + auto binary = binary_set.GetByName(RAW_DATA); + auto original_data = reinterpret_cast(binary->data.get()); + auto ivf_index = dynamic_cast(index_.get()); + auto invlists = ivf_index->invlists; + auto d = ivf_index->d; + size_t nb = binary->size / invlists->code_size; + auto arranged_data = new float[d * nb]; + prefix_sum.resize(invlists->nlist); + size_t curr_index = 0; + +#ifndef MILVUS_GPU_VERSION + auto ails = dynamic_cast(invlists); + for (size_t i = 0; i < invlists->nlist; i++) { + auto list_size = ails->ids[i].size(); + for (size_t j = 0; j < list_size; j++) { + memcpy(arranged_data + d * (curr_index + j), original_data + d * ails->ids[i][j], d * sizeof(float)); + } + prefix_sum[i] = curr_index; + curr_index += list_size; + } +#else + auto rol = dynamic_cast(invlists); + auto lengths = rol->readonly_length; + auto rol_ids = reinterpret_cast(rol->pin_readonly_ids->data); + for (size_t i = 0; i < invlists->nlist; i++) { + auto list_size = lengths[i]; + for (size_t j = 0; j < list_size; j++) { + memcpy(arranged_data + d * (curr_index + j), original_data + d * rol_ids[curr_index + j], + d * sizeof(float)); + } + prefix_sum[i] = curr_index; + curr_index += list_size; + } +#endif + data_ = std::shared_ptr(reinterpret_cast(arranged_data)); +} + +void +IVF_NM::Train(const DatasetPtr& dataset_ptr, const Config& config) { + GET_TENSOR_DATA_DIM(dataset_ptr) + + faiss::MetricType metric_type = GetMetricType(config[Metric::TYPE].get()); + faiss::Index* coarse_quantizer = new faiss::IndexFlat(dim, metric_type); + auto nlist = config[IndexParams::nlist].get(); + index_ = std::shared_ptr(new faiss::IndexIVFFlat(coarse_quantizer, dim, nlist, metric_type)); + index_->train(rows, reinterpret_cast(p_data)); +} + +void +IVF_NM::Add(const DatasetPtr& dataset_ptr, const Config& config) { + if (!index_ || !index_->is_trained) { + KNOWHERE_THROW_MSG("index not initialize or trained"); + } + + std::lock_guard lk(mutex_); + GET_TENSOR_DATA_ID(dataset_ptr) + index_->add_with_ids_without_codes(rows, reinterpret_cast(p_data), p_ids); +} + +void +IVF_NM::AddWithoutIds(const DatasetPtr& dataset_ptr, const Config& config) { + if (!index_ || !index_->is_trained) { + KNOWHERE_THROW_MSG("index not initialize or trained"); + } + + std::lock_guard lk(mutex_); + GET_TENSOR_DATA(dataset_ptr) + index_->add_without_codes(rows, reinterpret_cast(p_data)); +} + +DatasetPtr +IVF_NM::Query(const DatasetPtr& dataset_ptr, const Config& config) { + if (!index_ || !index_->is_trained) { + KNOWHERE_THROW_MSG("index not initialize or trained"); + } + + GET_TENSOR_DATA(dataset_ptr) + + try { + auto k = config[meta::TOPK].get(); + auto elems = rows * k; + + size_t p_id_size = sizeof(int64_t) * elems; + size_t p_dist_size = sizeof(float) * elems; + auto p_id = static_cast(malloc(p_id_size)); + auto p_dist = static_cast(malloc(p_dist_size)); + + QueryImpl(rows, reinterpret_cast(p_data), k, p_dist, p_id, config); + + auto ret_ds = std::make_shared(); + ret_ds->Set(meta::IDS, p_id); + ret_ds->Set(meta::DISTANCE, p_dist); + return ret_ds; + } catch (faiss::FaissException& e) { + KNOWHERE_THROW_MSG(e.what()); + } catch (std::exception& e) { + KNOWHERE_THROW_MSG(e.what()); + } +} + +#if 0 +DatasetPtr +IVF_NM::QueryById(const DatasetPtr& dataset_ptr, const Config& config) { + if (!index_ || !index_->is_trained) { + KNOWHERE_THROW_MSG("index not initialize or trained"); + } + + auto rows = dataset_ptr->Get(meta::ROWS); + auto p_data = dataset_ptr->Get(meta::IDS); + + try { + int64_t k = config[meta::TOPK].get(); + auto elems = rows * k; + + size_t p_id_size = sizeof(int64_t) * elems; + size_t p_dist_size = sizeof(float) * elems; + auto p_id = (int64_t*)malloc(p_id_size); + auto p_dist = (float*)malloc(p_dist_size); + + // todo: enable search by id (zhiru) + // auto blacklist = dataset_ptr->Get("bitset"); + auto index_ivf = std::static_pointer_cast(index_); + index_ivf->search_by_id(rows, p_data, k, p_dist, p_id, bitset_); + + auto ret_ds = std::make_shared(); + ret_ds->Set(meta::IDS, p_id); + ret_ds->Set(meta::DISTANCE, p_dist); + return ret_ds; + } catch (faiss::FaissException& e) { + KNOWHERE_THROW_MSG(e.what()); + } catch (std::exception& e) { + KNOWHERE_THROW_MSG(e.what()); + } +} + +DatasetPtr +IVF_NM::GetVectorById(const DatasetPtr& dataset_ptr, const Config& config) { + if (!index_ || !index_->is_trained) { + KNOWHERE_THROW_MSG("index not initialize or trained"); + } + + auto p_data = dataset_ptr->Get(meta::IDS); + auto elems = dataset_ptr->Get(meta::DIM); + + try { + size_t p_x_size = sizeof(float) * elems; + auto p_x = (float*)malloc(p_x_size); + + auto index_ivf = std::static_pointer_cast(index_); + index_ivf->get_vector_by_id(1, p_data, p_x, bitset_); + + auto ret_ds = std::make_shared(); + ret_ds->Set(meta::TENSOR, p_x); + return ret_ds; + } catch (faiss::FaissException& e) { + KNOWHERE_THROW_MSG(e.what()); + } catch (std::exception& e) { + KNOWHERE_THROW_MSG(e.what()); + } +} +#endif + +void +IVF_NM::Seal() { + if (!index_ || !index_->is_trained) { + KNOWHERE_THROW_MSG("index not initialize or trained"); + } + SealImpl(); +} + +VecIndexPtr +IVF_NM::CopyCpuToGpu(const int64_t device_id, const Config& config) { +#ifdef MILVUS_GPU_VERSION + if (auto res = FaissGpuResourceMgr::GetInstance().GetRes(device_id)) { + ResScope rs(res, device_id, false); + auto gpu_index = + faiss::gpu::index_cpu_to_gpu_without_codes(res->faiss_res.get(), device_id, index_.get(), data_.get()); + + std::shared_ptr device_index; + device_index.reset(gpu_index); + return std::make_shared(device_index, device_id, res); + } else { + KNOWHERE_THROW_MSG("CopyCpuToGpu Error, can't get gpu_resource"); + } + +#else + KNOWHERE_THROW_MSG("Calling IVF_NM::CopyCpuToGpu when we are using CPU version"); +#endif +} + +void +IVF_NM::GenGraph(const float* data, const int64_t k, GraphType& graph, const Config& config) { + int64_t K = k + 1; + auto ntotal = Count(); + + size_t dim = config[meta::DIM]; + auto batch_size = 1000; + auto tail_batch_size = ntotal % batch_size; + auto batch_search_count = ntotal / batch_size; + auto total_search_count = tail_batch_size == 0 ? batch_search_count : batch_search_count + 1; + + std::vector res_dis(K * batch_size); + graph.resize(ntotal); + GraphType res_vec(total_search_count); + for (int i = 0; i < total_search_count; ++i) { + // it is usually used in NSG::train, to check BuilderSuspend + faiss::BuilderSuspend::check_wait(); + + auto b_size = (i == (total_search_count - 1)) && tail_batch_size != 0 ? tail_batch_size : batch_size; + + auto& res = res_vec[i]; + res.resize(K * b_size); + + const float* xq = data + batch_size * dim * i; + QueryImpl(b_size, xq, K, res_dis.data(), res.data(), config); + + for (int j = 0; j < b_size; ++j) { + auto& node = graph[batch_size * i + j]; + node.resize(k); + auto start_pos = j * K + 1; + for (int m = 0, cursor = start_pos; m < k && cursor < start_pos + k; ++m, ++cursor) { + node[m] = res[cursor]; + } + } + } +} + +std::shared_ptr +IVF_NM::GenParams(const Config& config) { + auto params = std::make_shared(); + params->nprobe = config[IndexParams::nprobe]; + // params->max_codes = config["max_codes"]; + return params; +} + +void +IVF_NM::QueryImpl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels, const Config& config) { + auto params = GenParams(config); + auto ivf_index = dynamic_cast(index_.get()); + ivf_index->nprobe = params->nprobe; + stdclock::time_point before = stdclock::now(); + if (params->nprobe > 1 && n <= 4) { + ivf_index->parallel_mode = 1; + } else { + ivf_index->parallel_mode = 0; + } + bool is_sq8 = (index_type_ == IndexEnum::INDEX_FAISS_IVFSQ8) ? true : false; + ivf_index->search_without_codes(n, reinterpret_cast(data), data_.get(), prefix_sum, is_sq8, k, + distances, labels, bitset_); + stdclock::time_point after = stdclock::now(); + double search_cost = (std::chrono::duration(after - before)).count(); + LOG_KNOWHERE_DEBUG_ << "IVF_NM search cost: " << search_cost + << ", quantization cost: " << faiss::indexIVF_stats.quantization_time + << ", data search cost: " << faiss::indexIVF_stats.search_time; + faiss::indexIVF_stats.quantization_time = 0; + faiss::indexIVF_stats.search_time = 0; +} + +void +IVF_NM::SealImpl() { +#ifdef MILVUS_GPU_VERSION + faiss::Index* index = index_.get(); + auto idx = dynamic_cast(index); + if (idx != nullptr) { + idx->to_readonly_without_codes(); + } +#endif +} + +int64_t +IVF_NM::Count() { + if (!index_) { + KNOWHERE_THROW_MSG("index not initialize"); + } + return index_->ntotal; +} + +int64_t +IVF_NM::Dim() { + if (!index_) { + KNOWHERE_THROW_MSG("index not initialize"); + } + return index_->d; +} + +void +IVF_NM::UpdateIndexSize() { + if (!index_) { + KNOWHERE_THROW_MSG("index not initialize"); + } + auto ivf_index = dynamic_cast(index_.get()); + auto nb = ivf_index->invlists->compute_ntotal(); + auto nlist = ivf_index->nlist; + auto code_size = ivf_index->code_size; + // ivf codes, ivf ids and quantizer + index_size_ = nb * code_size + nb * sizeof(int64_t) + nlist * code_size; +} + +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/vector_offset_index/IndexIVF_NM.h b/core/src/index/knowhere/knowhere/index/vector_offset_index/IndexIVF_NM.h new file mode 100644 index 0000000000..4924dd19fb --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/vector_offset_index/IndexIVF_NM.h @@ -0,0 +1,103 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License + +#pragma once + +#include +#include +#include +#include + +#include + +#include "knowhere/common/Typedef.h" +#include "knowhere/index/vector_index/VecIndex.h" +#include "knowhere/index/vector_offset_index/OffsetBaseIndex.h" + +namespace milvus { +namespace knowhere { + +class IVF_NM : public VecIndex, public OffsetBaseIndex { + public: + IVF_NM() : OffsetBaseIndex(nullptr) { + index_type_ = IndexEnum::INDEX_FAISS_IVFFLAT; + } + + explicit IVF_NM(std::shared_ptr index) : OffsetBaseIndex(std::move(index)) { + index_type_ = IndexEnum::INDEX_FAISS_IVFFLAT; + } + + BinarySet + Serialize(const Config& config) override; + + void + Load(const BinarySet&) override; + + void + Train(const DatasetPtr&, const Config&) override; + + void + Add(const DatasetPtr&, const Config&) override; + + void + AddWithoutIds(const DatasetPtr&, const Config&) override; + + DatasetPtr + Query(const DatasetPtr&, const Config&) override; + +#if 0 + DatasetPtr + QueryById(const DatasetPtr& dataset, const Config& config) override; +#endif + + int64_t + Count() override; + + int64_t + Dim() override; + + void + UpdateIndexSize() override; + +#if 0 + DatasetPtr + GetVectorById(const DatasetPtr& dataset, const Config& config) override; +#endif + + virtual void + Seal(); + + virtual VecIndexPtr + CopyCpuToGpu(const int64_t, const Config&); + + virtual void + GenGraph(const float* data, const int64_t k, GraphType& graph, const Config& config); + + protected: + virtual std::shared_ptr + GenParams(const Config&); + + virtual void + QueryImpl(int64_t, const float*, int64_t, float*, int64_t*, const Config&); + + void + SealImpl() override; + + protected: + std::mutex mutex_; + std::shared_ptr data_ = nullptr; + std::vector prefix_sum; +}; + +using IVFNMPtr = std::shared_ptr; + +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/vector_offset_index/IndexNSG_NM.cpp b/core/src/index/knowhere/knowhere/index/vector_offset_index/IndexNSG_NM.cpp new file mode 100644 index 0000000000..32ef1cadba --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/vector_offset_index/IndexNSG_NM.cpp @@ -0,0 +1,184 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License + +#include + +#include "knowhere/common/Exception.h" +#include "knowhere/common/Timer.h" +#include "knowhere/index/IndexType.h" +#include "knowhere/index/vector_index/IndexIDMAP.h" +#include "knowhere/index/vector_index/IndexIVF.h" +#include "knowhere/index/vector_index/adapter/VectorAdapter.h" +#include "knowhere/index/vector_index/impl/nsg/NSGIO.h" +#include "knowhere/index/vector_offset_index/IndexNSG_NM.h" + +#ifdef MILVUS_GPU_VERSION +#include "knowhere/index/vector_index/gpu/IndexGPUIDMAP.h" +#include "knowhere/index/vector_index/gpu/IndexGPUIVF.h" +#include "knowhere/index/vector_index/helpers/Cloner.h" +#endif + +namespace milvus { +namespace knowhere { + +BinarySet +NSG_NM::Serialize(const Config& config) { + if (!index_ || !index_->is_trained) { + KNOWHERE_THROW_MSG("index not initialize or trained"); + } + + try { + std::lock_guard lk(mutex_); + impl::NsgIndex* index = index_.get(); + + MemoryIOWriter writer; + impl::write_index(index, writer); + std::shared_ptr data(writer.data_); + + BinarySet res_set; + res_set.Append("NSG_NM", data, writer.rp); + return res_set; + } catch (std::exception& e) { + KNOWHERE_THROW_MSG(e.what()); + } +} + +void +NSG_NM::Load(const BinarySet& index_binary) { + try { + std::lock_guard lk(mutex_); + auto binary = index_binary.GetByName("NSG_NM"); + + MemoryIOReader reader; + reader.total = binary->size; + reader.data_ = binary->data.get(); + + auto index = impl::read_index(reader); + index_.reset(index); + + data_ = index_binary.GetByName(RAW_DATA)->data; + } catch (std::exception& e) { + KNOWHERE_THROW_MSG(e.what()); + } +} + +DatasetPtr +NSG_NM::Query(const DatasetPtr& dataset_ptr, const Config& config) { + if (!index_ || !index_->is_trained) { + KNOWHERE_THROW_MSG("index not initialize or trained"); + } + + GET_TENSOR_DATA_DIM(dataset_ptr) + + try { + auto topK = config[meta::TOPK].get(); + auto elems = rows * topK; + size_t p_id_size = sizeof(int64_t) * elems; + size_t p_dist_size = sizeof(float) * elems; + auto p_id = static_cast(malloc(p_id_size)); + auto p_dist = static_cast(malloc(p_dist_size)); + + faiss::ConcurrentBitsetPtr blacklist = GetBlacklist(); + + impl::SearchParams s_params; + s_params.search_length = config[IndexParams::search_length]; + s_params.k = config[meta::TOPK]; + { + std::lock_guard lk(mutex_); + // index_->ori_data_ = (float*) data_.get(); + index_->Search(reinterpret_cast(p_data), reinterpret_cast(data_.get()), rows, dim, + topK, p_dist, p_id, s_params, blacklist); + } + + auto ret_ds = std::make_shared(); + ret_ds->Set(meta::IDS, p_id); + ret_ds->Set(meta::DISTANCE, p_dist); + return ret_ds; + } catch (std::exception& e) { + KNOWHERE_THROW_MSG(e.what()); + } +} + +void +NSG_NM::Train(const DatasetPtr& dataset_ptr, const Config& config) { + auto idmap = std::make_shared(); + idmap->Train(dataset_ptr, config); + idmap->AddWithoutIds(dataset_ptr, config); + impl::Graph knng; + const float* raw_data = idmap->GetRawVectors(); + auto k = config[IndexParams::knng].get(); +#ifdef MILVUS_GPU_VERSION + const auto device_id = config[knowhere::meta::DEVICEID].get(); + if (device_id == -1) { + auto preprocess_index = std::make_shared(); + preprocess_index->Train(dataset_ptr, config); + preprocess_index->AddWithoutIds(dataset_ptr, config); + preprocess_index->GenGraph(raw_data, k, knng, config); + } else { + auto gpu_idx = cloner::CopyCpuToGpu(idmap, device_id, config); + auto gpu_idmap = std::dynamic_pointer_cast(gpu_idx); + gpu_idmap->GenGraph(raw_data, k, knng, config); + } +#else + auto preprocess_index = std::make_shared(); + preprocess_index->Train(dataset_ptr, config); + preprocess_index->AddWithoutIds(dataset_ptr, config); + preprocess_index->GenGraph(raw_data, k, knng, config); +#endif + + impl::BuildParams b_params; + b_params.candidate_pool_size = config[IndexParams::candidate]; + b_params.out_degree = config[IndexParams::out_degree]; + b_params.search_length = config[IndexParams::search_length]; + + auto p_ids = dataset_ptr->Get(meta::IDS); + + GET_TENSOR_DATA_DIM(dataset_ptr) + impl::NsgIndex::Metric_Type metric_type_nsg; + if (config[Metric::TYPE].get() == "IP") { + metric_type_nsg = impl::NsgIndex::Metric_Type::Metric_Type_IP; + } else if (config[Metric::TYPE].get() == "L2") { + metric_type_nsg = impl::NsgIndex::Metric_Type::Metric_Type_L2; + } else { + KNOWHERE_THROW_MSG("either IP or L2"); + } + index_ = std::make_shared(dim, rows, metric_type_nsg); + index_->SetKnnGraph(knng); + index_->Build_with_ids(rows, reinterpret_cast(const_cast(p_data)), + reinterpret_cast(p_ids), b_params); +} + +int64_t +NSG_NM::Count() { + if (!index_) { + KNOWHERE_THROW_MSG("index not initialize"); + } + return index_->ntotal; +} + +int64_t +NSG_NM::Dim() { + if (!index_) { + KNOWHERE_THROW_MSG("index not initialize"); + } + return index_->dimension; +} + +void +NSG_NM::UpdateIndexSize() { + if (!index_) { + KNOWHERE_THROW_MSG("index not initialize"); + } + index_size_ = index_->GetSize() + Dim() * Count() * sizeof(float); +} + +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/vector_offset_index/IndexNSG_NM.h b/core/src/index/knowhere/knowhere/index/vector_offset_index/IndexNSG_NM.h new file mode 100644 index 0000000000..29e9abb7c9 --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/vector_offset_index/IndexNSG_NM.h @@ -0,0 +1,83 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License + +#pragma once + +#include +#include + +#include "knowhere/common/Exception.h" +#include "knowhere/common/Log.h" +#include "knowhere/index/vector_index/VecIndex.h" + +namespace milvus { +namespace knowhere { + +namespace impl { +class NsgIndex; +} + +class NSG_NM : public VecIndex { + public: + explicit NSG_NM(const int64_t gpu_num = -1) : gpu_(gpu_num) { + if (gpu_ >= 0) { + index_mode_ = IndexMode::MODE_GPU; + } + index_type_ = IndexEnum::INDEX_NSG; + } + + BinarySet + Serialize(const Config& config) override; + + void + Load(const BinarySet&) override; + + void + BuildAll(const DatasetPtr& dataset_ptr, const Config& config) override { + Train(dataset_ptr, config); + } + + void + Train(const DatasetPtr&, const Config&) override; + + void + Add(const DatasetPtr&, const Config&) override { + KNOWHERE_THROW_MSG("Incremental index is not supported"); + } + + void + AddWithoutIds(const DatasetPtr&, const Config&) override { + KNOWHERE_THROW_MSG("Addwithoutids is not supported"); + } + + DatasetPtr + Query(const DatasetPtr&, const Config&) override; + + int64_t + Count() override; + + int64_t + Dim() override; + + void + UpdateIndexSize() override; + + private: + std::mutex mutex_; + int64_t gpu_; + std::shared_ptr index_ = nullptr; + std::shared_ptr data_ = nullptr; +}; + +using NSG_NMIndexPtr = std::shared_ptr(); + +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/vector_offset_index/OffsetBaseIndex.cpp b/core/src/index/knowhere/knowhere/index/vector_offset_index/OffsetBaseIndex.cpp new file mode 100644 index 0000000000..8acff67ef0 --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/vector_offset_index/OffsetBaseIndex.cpp @@ -0,0 +1,54 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License + +#include + +#include "knowhere/common/Exception.h" +#include "knowhere/index/IndexType.h" +#include "knowhere/index/vector_index/helpers/FaissIO.h" +#include "knowhere/index/vector_offset_index/OffsetBaseIndex.h" + +namespace milvus { +namespace knowhere { + +BinarySet +OffsetBaseIndex::SerializeImpl(const IndexType& type) { + try { + faiss::Index* index = index_.get(); + + MemoryIOWriter writer; + faiss::write_index_nm(index, &writer); + std::shared_ptr data(writer.data_); + + BinarySet res_set; + res_set.Append("IVF", data, writer.rp); + return res_set; + } catch (std::exception& e) { + KNOWHERE_THROW_MSG(e.what()); + } +} + +void +OffsetBaseIndex::LoadImpl(const BinarySet& binary_set, const IndexType& type) { + auto binary = binary_set.GetByName("IVF"); + + MemoryIOReader reader; + reader.total = binary->size; + reader.data_ = binary->data.get(); + + faiss::Index* index = faiss::read_index_nm(&reader); + index_.reset(index); + + SealImpl(); +} + +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/vector_offset_index/OffsetBaseIndex.h b/core/src/index/knowhere/knowhere/index/vector_offset_index/OffsetBaseIndex.h new file mode 100644 index 0000000000..029d8f9f69 --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/vector_offset_index/OffsetBaseIndex.h @@ -0,0 +1,45 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License + +#pragma once + +#include +#include + +#include + +#include "knowhere/common/BinarySet.h" +#include "knowhere/index/IndexType.h" + +namespace milvus { +namespace knowhere { + +class OffsetBaseIndex { + protected: + explicit OffsetBaseIndex(std::shared_ptr index) : index_(std::move(index)) { + } + + virtual BinarySet + SerializeImpl(const IndexType& type); + + virtual void + LoadImpl(const BinarySet&, const IndexType& type); + + virtual void + SealImpl() { /* do nothing */ + } + + public: + std::shared_ptr index_ = nullptr; +}; + +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/vector_offset_index/gpu/IndexGPUIVF_NM.cpp b/core/src/index/knowhere/knowhere/index/vector_offset_index/gpu/IndexGPUIVF_NM.cpp new file mode 100644 index 0000000000..597e50e506 --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/vector_offset_index/gpu/IndexGPUIVF_NM.cpp @@ -0,0 +1,140 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License + +#include + +#include +#include +#include +#include +#include + +#include "knowhere/common/Exception.h" +#include "knowhere/index/vector_index/adapter/VectorAdapter.h" +#include "knowhere/index/vector_index/helpers/Cloner.h" +#include "knowhere/index/vector_index/helpers/FaissIO.h" +#include "knowhere/index/vector_index/helpers/IndexParameter.h" +#include "knowhere/index/vector_offset_index/IndexIVF_NM.h" +#include "knowhere/index/vector_offset_index/gpu/IndexGPUIVF_NM.h" + +namespace milvus { +namespace knowhere { + +void +GPUIVF_NM::Train(const DatasetPtr& dataset_ptr, const Config& config) { + GET_TENSOR_DATA_DIM(dataset_ptr) + gpu_id_ = config[knowhere::meta::DEVICEID]; + + auto gpu_res = FaissGpuResourceMgr::GetInstance().GetRes(gpu_id_); + if (gpu_res != nullptr) { + ResScope rs(gpu_res, gpu_id_, true); + faiss::gpu::GpuIndexIVFFlatConfig idx_config; + idx_config.device = gpu_id_; + int32_t nlist = config[IndexParams::nlist]; + faiss::MetricType metric_type = GetMetricType(config[Metric::TYPE].get()); + auto device_index = + new faiss::gpu::GpuIndexIVFFlat(gpu_res->faiss_res.get(), dim, nlist, metric_type, idx_config); + device_index->train(rows, reinterpret_cast(p_data)); + + index_.reset(device_index); + res_ = gpu_res; + } else { + KNOWHERE_THROW_MSG("Build IVF can't get gpu resource"); + } +} + +void +GPUIVF_NM::Add(const DatasetPtr& dataset_ptr, const Config& config) { + auto spt = res_.lock(); + if (spt != nullptr) { + ResScope rs(res_, gpu_id_); + IVF::Add(dataset_ptr, config); + } else { + KNOWHERE_THROW_MSG("Add IVF can't get gpu resource"); + } +} + +void +GPUIVF_NM::Load(const BinarySet& binary_set) { + // not supported +} + +VecIndexPtr +GPUIVF_NM::CopyGpuToCpu(const Config& config) { + std::lock_guard lk(mutex_); + + auto device_idx = std::dynamic_pointer_cast(index_); + if (device_idx != nullptr) { + faiss::Index* device_index = index_.get(); + faiss::Index* host_index = faiss::gpu::index_gpu_to_cpu_without_codes(device_index); + + std::shared_ptr new_index; + new_index.reset(host_index); + return std::make_shared(new_index); + } else { + return std::make_shared(index_); + } +} + +VecIndexPtr +GPUIVF_NM::CopyGpuToGpu(const int64_t device_id, const Config& config) { + auto host_index = CopyGpuToCpu(config); + return std::static_pointer_cast(host_index)->CopyCpuToGpu(device_id, config); +} + +BinarySet +GPUIVF_NM::SerializeImpl(const IndexType& type) { + if (!index_ || !index_->is_trained) { + KNOWHERE_THROW_MSG("index not initialize or trained"); + } + + try { + MemoryIOWriter writer; + { + faiss::Index* index = index_.get(); + faiss::Index* host_index = faiss::gpu::index_gpu_to_cpu_without_codes(index); + faiss::write_index_nm(host_index, &writer); + delete host_index; + } + std::shared_ptr data(writer.data_); + + BinarySet res_set; + res_set.Append("IVF", data, writer.rp); + + return res_set; + } catch (std::exception& e) { + KNOWHERE_THROW_MSG(e.what()); + } +} + +void +GPUIVF_NM::QueryImpl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels, const Config& config) { + std::lock_guard lk(mutex_); + + auto device_index = std::dynamic_pointer_cast(index_); + if (device_index) { + device_index->nprobe = config[IndexParams::nprobe]; + ResScope rs(res_, gpu_id_); + + // if query size > 2048 we search by blocks to avoid malloc issue + const int64_t block_size = 2048; + int64_t dim = device_index->d; + for (int64_t i = 0; i < n; i += block_size) { + int64_t search_size = (n - i > block_size) ? block_size : (n - i); + device_index->search(search_size, data + i * dim, k, distances + i * k, labels + i * k, bitset_); + } + } else { + KNOWHERE_THROW_MSG("Not a GpuIndexIVF type."); + } +} + +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/vector_offset_index/gpu/IndexGPUIVF_NM.h b/core/src/index/knowhere/knowhere/index/vector_offset_index/gpu/IndexGPUIVF_NM.h new file mode 100644 index 0000000000..7b4254f200 --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/vector_offset_index/gpu/IndexGPUIVF_NM.h @@ -0,0 +1,63 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License + +#pragma once + +#include +#include + +#include "knowhere/index/vector_index/IndexIVF.h" +#include "knowhere/index/vector_index/gpu/GPUIndex.h" + +namespace milvus { +namespace knowhere { + +class GPUIVF_NM : public IVF, public GPUIndex { + public: + explicit GPUIVF_NM(const int& device_id) : IVF(), GPUIndex(device_id) { + index_mode_ = IndexMode::MODE_GPU; + } + + explicit GPUIVF_NM(std::shared_ptr index, const int64_t device_id, ResPtr& res) + : IVF(std::move(index)), GPUIndex(device_id, res) { + index_mode_ = IndexMode::MODE_GPU; + } + + void + Train(const DatasetPtr&, const Config&) override; + + void + Add(const DatasetPtr&, const Config&) override; + + void + Load(const BinarySet&) override; + + VecIndexPtr + CopyGpuToCpu(const Config&) override; + + VecIndexPtr + CopyGpuToGpu(const int64_t, const Config&) override; + + protected: + BinarySet + SerializeImpl(const IndexType&) override; + + void + QueryImpl(int64_t, const float*, int64_t, float*, int64_t*, const Config&) override; + + protected: + uint8_t* arranged_data; +}; + +using GPUIVFNMPtr = std::shared_ptr; + +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/thirdparty/SPTAG/.github/ISSUE_TEMPLATE/bug_report.md b/core/src/index/thirdparty/SPTAG/.github/ISSUE_TEMPLATE/bug_report.md new file mode 100644 index 0000000000..b735373365 --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/.github/ISSUE_TEMPLATE/bug_report.md @@ -0,0 +1,35 @@ +--- +name: Bug report +about: Create a report to help us improve + +--- + +**Describe the bug** +A clear and concise description of what the bug is. + +**To Reproduce** +Steps to reproduce the behavior: +1. Go to '...' +2. Click on '....' +3. Scroll down to '....' +4. See error + +**Expected behavior** +A clear and concise description of what you expected to happen. + +**Screenshots** +If applicable, add screenshots to help explain your problem. + +**Desktop (please complete the following information):** + - OS: [e.g. iOS] + - Browser [e.g. chrome, safari] + - Version [e.g. 22] + +**Smartphone (please complete the following information):** + - Device: [e.g. iPhone6] + - OS: [e.g. iOS8.1] + - Browser [e.g. stock browser, safari] + - Version [e.g. 22] + +**Additional context** +Add any other context about the problem here. diff --git a/core/src/index/thirdparty/SPTAG/.github/ISSUE_TEMPLATE/feature_request.md b/core/src/index/thirdparty/SPTAG/.github/ISSUE_TEMPLATE/feature_request.md new file mode 100644 index 0000000000..066b2d920a --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/.github/ISSUE_TEMPLATE/feature_request.md @@ -0,0 +1,17 @@ +--- +name: Feature request +about: Suggest an idea for this project + +--- + +**Is your feature request related to a problem? Please describe.** +A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] + +**Describe the solution you'd like** +A clear and concise description of what you want to happen. + +**Describe alternatives you've considered** +A clear and concise description of any alternative solutions or features you've considered. + +**Additional context** +Add any other context or screenshots about the feature request here. diff --git a/core/src/index/thirdparty/SPTAG/.gitignore b/core/src/index/thirdparty/SPTAG/.gitignore new file mode 100644 index 0000000000..973785834c --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/.gitignore @@ -0,0 +1,91 @@ +# Prerequisites +*.d + +# Object files +*.o +*.ko +*.obj +*.elf + +# Linker output +*.ilk +*.map +*.exp + +# Precompiled Headers +*.gch +*.pch + +# Libraries +*.lib +*.a +*.la +*.lo + +# Shared objects (inc. Windows DLLs) +*.dll +*.so +*.so.* +*.dylib + +# Executables +*.exe +*.out +*.app +*.i*86 +*.x86_64 +*.hex + +# Debug files +*.dSYM/ +*.su +*.idb +*.pdb + +# Kernel Module Compile Results +*.mod* +*.cmd +.tmp_versions/ +modules.order +Module.symvers +Mkfile.old +dkms.conf +/.vs/SPTAG/v14 +/AnnService/x64 +/obj/x64_Release +/PythonWrapper/x64/Release +/x64/Release +/SPTAG.VC.db +/SPTAG.VC.VC.opendb +/AnnService/CoreLibrary.vcxproj.user +/AnnService/IndexBuilder.vcxproj.user +/AnnService/Server.vcxproj.user +/AnnService/SocketLib.vcxproj.user +/PythonWrapper/PythonCore.vcxproj.user +/build +/PythonWrapper/inc/ClientInterface_wrap.cxx +/PythonWrapper/inc/CoreInterface_wrap.cxx +/PythonWrapper/inc/SPTAG.py +/PythonWrapper/inc/SPTAGClient.py +/ipch/TEST-4fb66b42 +/obj/x64_Debug +/x64/Debug +/packages +/Search/Search.vcxproj.user +/AnnService/IndexSearcher.vcxproj.user +/Wrappers/inc/SWIGTYPE_p_RemoteSearchResult.java +/Wrappers/inc/SWIGTYPE_p_QueryResult.java +/Wrappers/inc/SPTAGJNI.java +/Wrappers/inc/SPTAGClientJNI.java +/Wrappers/inc/SPTAGClient.py +/Wrappers/inc/SPTAGClient.java +/Wrappers/inc/SPTAG.py +/Wrappers/inc/SPTAG.java +/Wrappers/inc/CoreInterface_pwrap.cpp +/Wrappers/inc/CoreInterface_jwrap.cpp +/Wrappers/inc/ClientInterface_pwrap.cpp +/Wrappers/inc/ClientInterface_jwrap.cpp +/Wrappers/inc/AnnIndex.java +/Wrappers/inc/AnnClient.java +/AnnService.users - Copy.props +/.vs diff --git a/core/src/index/thirdparty/SPTAG/AnnService.users.props b/core/src/index/thirdparty/SPTAG/AnnService.users.props new file mode 100644 index 0000000000..e65231357f --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/AnnService.users.props @@ -0,0 +1,18 @@ + + + + + + + $(SystemVersionDef) %(AdditionalOptions) + + + + + $(SolutionDir)\$(Platform)\$(Configuration)\ + $(SolutionDir)\$(Platform)\$(Configuration)\ + + + + + diff --git a/core/src/index/thirdparty/SPTAG/AnnService/Aggregator.vcxproj b/core/src/index/thirdparty/SPTAG/AnnService/Aggregator.vcxproj new file mode 100644 index 0000000000..4e5514f057 --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/AnnService/Aggregator.vcxproj @@ -0,0 +1,178 @@ + + + + + Debug + Win32 + + + Release + Win32 + + + Debug + x64 + + + Release + x64 + + + + 15.0 + {D7F09A63-BDCA-4F6C-A864-8551D1FE447A} + Aggregator + 8.1 + + + + + Application + true + v140 + MultiByte + + + Application + false + v140 + true + MultiByte + + + Application + true + v140 + MultiByte + + + Application + false + v140 + true + MultiByte + + + + + + + + + + + + + + + + + + + + + $(SolutionDir)obj\$(Platform)_$(Configuration)\$(ProjectName)\ + $(ProjectDir);$(IncludePath) + $(OutAppDir) + $(OutLibDir);$(LibraryPath) + + + false + + + + Level3 + Disabled + true + true + _MBCS;_SCL_SECURE_NO_WARNINGS;%(PreprocessorDefinitions) + Guard + ProgramDatabase + + + CoreLibrary.lib;SocketLib.lib;%(AdditionalDependencies) + /guard:cf + %(AdditionalOptions) + + + + + Level3 + Disabled + true + true + + + CoreLibrary.lib;SocketLib.lib;%(AdditionalDependencies) + + + + + Level3 + MaxSpeed + true + true + true + true + + + true + true + CoreLibrary.lib;SocketLib.lib;%(AdditionalDependencies) + + + + + Level3 + MaxSpeed + true + true + true + true + _MBCS;_SCL_SECURE_NO_WARNINGS;%(PreprocessorDefinitions) + + + true + true + CoreLibrary.lib;SocketLib.lib;%(AdditionalDependencies) + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + This project references NuGet package(s) that are missing on this computer. Use NuGet Package Restore to download them. For more information, see http://go.microsoft.com/fwlink/?LinkID=322105. The missing file is {0}. + + + + + + + + + + \ No newline at end of file diff --git a/core/src/index/thirdparty/SPTAG/AnnService/Aggregator.vcxproj.filters b/core/src/index/thirdparty/SPTAG/AnnService/Aggregator.vcxproj.filters new file mode 100644 index 0000000000..de08faca3f --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/AnnService/Aggregator.vcxproj.filters @@ -0,0 +1,44 @@ + + + + + {4FC737F1-C7A5-4376-A066-2A32D752A2FF} + cpp;c;cc;cxx;def;odl;idl;hpj;bat;asm;asmx + + + {93995380-89BD-4b04-88EB-625FBE52EBFB} + h;hh;hpp;hxx;hm;inl;inc;xsd + + + + + Header Files + + + Header Files + + + Header Files + + + Header Files + + + + + Source Files + + + Source Files + + + Source Files + + + Source Files + + + Source Files + + + \ No newline at end of file diff --git a/core/src/index/thirdparty/SPTAG/AnnService/CMakeLists.txt b/core/src/index/thirdparty/SPTAG/AnnService/CMakeLists.txt new file mode 100644 index 0000000000..fffc5ce426 --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/AnnService/CMakeLists.txt @@ -0,0 +1,41 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +file(GLOB HDR_FILES ${PROJECT_SOURCE_DIR}/AnnService/inc/Core/*.h ${PROJECT_SOURCE_DIR}/AnnService/inc/Core/Common/*.h ${PROJECT_SOURCE_DIR}/AnnService/inc/Core/BKT/*.h ${PROJECT_SOURCE_DIR}/AnnService/inc/Core/KDT/*.h ${PROJECT_SOURCE_DIR}/AnnService/inc/Helper/*.h ${PROJECT_SOURCE_DIR}/AnnService/inc/Helper/VectorSetReaders/*.h) +file(GLOB SRC_FILES ${PROJECT_SOURCE_DIR}/AnnService/src/Core/*.cpp ${PROJECT_SOURCE_DIR}/AnnService/src/Core/Common/*.cpp ${PROJECT_SOURCE_DIR}/AnnService/src/Core/BKT/*.cpp ${PROJECT_SOURCE_DIR}/AnnService/src/Core/KDT/*.cpp ${PROJECT_SOURCE_DIR}/AnnService/src/Helper/*.cpp ${PROJECT_SOURCE_DIR}/AnnService/src/Helper/VectorSetReaders/*.cpp) + +include_directories(${PROJECT_SOURCE_DIR}/AnnService) + +add_library (SPTAGLib SHARED ${SRC_FILES} ${HDR_FILES}) +target_link_libraries (SPTAGLib) +add_library (SPTAGLibStatic STATIC ${SRC_FILES} ${HDR_FILES}) +set_target_properties(SPTAGLibStatic PROPERTIES OUTPUT_NAME SPTAGLib) + +file(GLOB SERVER_HDR_FILES ${HDR_FILES} ${PROJECT_SOURCE_DIR}/AnnService/inc/Server/*.h ${PROJECT_SOURCE_DIR}/AnnService/inc/Socket/*.h) +file(GLOB SERVER_FILES ${SRC_FILES} ${PROJECT_SOURCE_DIR}/AnnService/src/Server/*.cpp ${PROJECT_SOURCE_DIR}/AnnService/src/Socket/*.cpp) +add_executable (server ${SERVER_FILES} ${SERVER_HDR_FILES}) +target_link_libraries(server ${Boost_LIBRARIES}) + +file(GLOB CLIENT_HDR_FILES ${HDR_FILES} ${PROJECT_SOURCE_DIR}/AnnService/inc/Client/*.h ${PROJECT_SOURCE_DIR}/AnnService/inc/Socket/*.h) +file(GLOB CLIENT_FILES ${SRC_FILES} ${PROJECT_SOURCE_DIR}/AnnService/src/Client/*.cpp ${PROJECT_SOURCE_DIR}/AnnService/src/Socket/*.cpp) +add_executable (client ${CLIENT_FILES} ${CLIENT_HDR_FILES}) +target_link_libraries(client ${Boost_LIBRARIES}) + +file(GLOB AGG_HDR_FILES ${HDR_FILES} ${PROJECT_SOURCE_DIR}/AnnService/inc/Aggregator/*.h ${PROJECT_SOURCE_DIR}/AnnService/inc/Socket/*.h) +file(GLOB AGG_FILES ${SRC_FILES} ${PROJECT_SOURCE_DIR}/AnnService/src/Aggregator/*.cpp ${PROJECT_SOURCE_DIR}/AnnService/src/Socket/*.cpp) +add_executable (aggregator ${AGG_FILES} ${AGG_HDR_FILES}) +target_link_libraries(aggregator ${Boost_LIBRARIES}) + +file(GLOB BUILDER_HDR_FILES ${HDR_FILES} ${PROJECT_SOURCE_DIR}/AnnService/inc/IndexBuilder/*.h) +file(GLOB BUILDER_FILES ${SRC_FILES} ${PROJECT_SOURCE_DIR}/AnnService/src/IndexBuilder/*.cpp) +add_executable (indexbuilder ${BUILDER_FILES} ${BUILDER_HDR_FILES}) +target_link_libraries(indexbuilder ${Boost_LIBRARIES}) + +file(GLOB SEARCHER_FILES ${SRC_FILES} ${PROJECT_SOURCE_DIR}/AnnService/src/IndexSearcher/*.cpp) +add_executable (indexsearcher ${SEARCHER_FILES} ${HDR_FILES}) +target_link_libraries(indexsearcher ${Boost_LIBRARIES}) + +install(TARGETS SPTAGLib SPTAGLibStatic server client aggregator indexbuilder indexsearcher + RUNTIME DESTINATION bin + ARCHIVE DESTINATION lib + LIBRARY DESTINATION lib) diff --git a/core/src/index/thirdparty/SPTAG/AnnService/Client.vcxproj b/core/src/index/thirdparty/SPTAG/AnnService/Client.vcxproj new file mode 100644 index 0000000000..34cb9def7a --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/AnnService/Client.vcxproj @@ -0,0 +1,145 @@ + + + + + Debug + Win32 + + + Release + Win32 + + + Debug + x64 + + + Release + x64 + + + + 15.0 + {A89D70C3-C53B-42DE-A5CE-9A472540F5CB} + Client + 8.1 + + + + + Application + true + v140 + MultiByte + + + Application + false + v140 + true + MultiByte + + + Application + true + v140 + MultiByte + + + Application + false + v140 + true + MultiByte + + + + + + + + + + + + + + + + + + + + + $(SolutionDir)obj\$(Platform)_$(Configuration)\$(ProjectName)\ + $(ProjectDir);$(IncludePath) + $(OutAppDir) + $(OutLibDir);$(LibraryPath) + + + false + + + + CoreLibrary.lib;SocketLib.lib;%(AdditionalDependencies) + + + + + CoreLibrary.lib;SocketLib.lib;%(AdditionalDependencies) + + + + + CoreLibrary.lib;SocketLib.lib;%(AdditionalDependencies) + /guard:cf %(AdditionalOptions) + + + Guard + ProgramDatabase + _MBCS;_SCL_SECURE_NO_WARNINGS;%(PreprocessorDefinitions) + + + + + CoreLibrary.lib;SocketLib.lib;%(AdditionalDependencies) + + + _MBCS;_SCL_SECURE_NO_WARNINGS;%(PreprocessorDefinitions) + + + + + + + + + + + + + + + + + + + + + + + + + + + This project references NuGet package(s) that are missing on this computer. Use NuGet Package Restore to download them. For more information, see http://go.microsoft.com/fwlink/?LinkID=322105. The missing file is {0}. + + + + + + + + + + \ No newline at end of file diff --git a/core/src/index/thirdparty/SPTAG/AnnService/Client.vcxproj.filters b/core/src/index/thirdparty/SPTAG/AnnService/Client.vcxproj.filters new file mode 100644 index 0000000000..31d84bd1d5 --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/AnnService/Client.vcxproj.filters @@ -0,0 +1,32 @@ + + + + + {4FC737F1-C7A5-4376-A066-2A32D752A2FF} + cpp;c;cc;cxx;def;odl;idl;hpj;bat;asm;asmx + + + {93995380-89BD-4b04-88EB-625FBE52EBFB} + h;hh;hpp;hxx;hm;inl;inc;xsd + + + + + Source Files + + + Source Files + + + Source Files + + + + + Header Files + + + Header Files + + + \ No newline at end of file diff --git a/core/src/index/thirdparty/SPTAG/AnnService/CoreLibrary.vcxproj b/core/src/index/thirdparty/SPTAG/AnnService/CoreLibrary.vcxproj new file mode 100644 index 0000000000..08921f2444 --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/AnnService/CoreLibrary.vcxproj @@ -0,0 +1,192 @@ + + + + + Debug + Win32 + + + Release + Win32 + + + Debug + x64 + + + Release + x64 + + + + 15.0 + {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9} + CoreLibrary + 8.1 + CoreLibrary + + + + + Application + true + v140 + MultiByte + + + Application + false + v140 + true + MultiByte + + + StaticLibrary + true + v140 + MultiByte + + + StaticLibrary + false + v140 + true + MultiByte + + + + + + + + + + + + + + + + + + + + + $(IncludePath);$(ProjectDir) + + + $(SolutionDir)obj\$(Platform)_$(Configuration)\$(ProjectName)\ + $(OutLibDir) + + + + Level3 + Disabled + true + true + /Zc:twoPhase- %(AdditionalOptions) + + + + + Level3 + Disabled + true + true + true + _MBCS;_SCL_SECURE_NO_WARNINGS;%(PreprocessorDefinitions) + Guard + ProgramDatabase + + + + + Level3 + MaxSpeed + true + true + true + true + /Zc:twoPhase- %(AdditionalOptions) + + + true + true + + + + + Level3 + MaxSpeed + true + true + true + true + true + _MBCS;_SCL_SECURE_NO_WARNINGS;%(PreprocessorDefinitions) + + + true + true + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/core/src/index/thirdparty/SPTAG/AnnService/CoreLibrary.vcxproj.filters b/core/src/index/thirdparty/SPTAG/AnnService/CoreLibrary.vcxproj.filters new file mode 100644 index 0000000000..94f27df3f1 --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/AnnService/CoreLibrary.vcxproj.filters @@ -0,0 +1,202 @@ + + + + + {c260e4c4-ec44-4d50-941f-078454da2a89} + + + {c7b9ab49-a99f-4eb4-b8ab-61f0730b9a89} + + + {a306a099-8e3f-433d-b065-9d99433f422e} + + + {33f272c5-907e-4848-bbd7-4340fe44f511} + + + {4FC737F1-C7A5-4376-A066-2A32D752A2FF} + cpp;c;cc;cxx;def;odl;idl;hpj;bat;asm;asmx + + + {47ec2958-e880-4c1a-b663-04fc48c799af} + + + {10d17e5e-9ad2-4000-96d4-83b616480b97} + + + {23fdeb31-9052-47d1-8edb-9b47c4b02707} + + + {6dff7b24-66ea-40b4-b408-d8fe264a6caa} + + + {b0f1e81d-ca05-426e-bffa-75513a52ca6b} + + + {774592a9-40aa-4342-a4af-b711a1cc4d52} + + + {8fb36afb-73ed-4c3d-8c9b-c3581d80c5d1} + + + {f7bc0bc7-1af5-4870-b8ee-fabdbabdb4c4} + + + {5c1449e0-38b7-4c82-976e-cbdc488d3fb5} + + + + + Header Files\Core + + + Header Files\Core + + + Header Files\Core + + + Header Files\Core + + + Header Files\Core + + + Header Files\Core + + + Header Files\Core\BKT + + + Header Files\Helper + + + Header Files\Helper + + + Header Files\Helper + + + Header Files\Core + + + Header Files\Core + + + Header Files\Helper + + + Header Files\Helper + + + Header Files\Helper + + + Header Files\Core\BKT + + + Header Files\Core\Common + + + Header Files\Core\Common + + + Header Files\Core\Common + + + Header Files\Core\Common + + + Header Files\Core\Common + + + Header Files\Core\Common + + + Header Files\Core\Common + + + Header Files\Core\Common + + + Header Files\Core\KDT + + + Header Files\Core\KDT + + + Header Files\Core\Common + + + Header Files\Core\Common + + + Header Files\Core\Common + + + Header Files\Core\Common + + + Header Files\Core\Common + + + Header Files\Helper + + + Header Files\Helper + + + Header Files\Helper\VectorSetReaders + + + Header Files\Helper + + + + + Source Files\Core + + + Source Files\Core\Common + + + Source Files\Helper + + + Source Files\Helper + + + Source Files\Core + + + Source Files\Core + + + Source Files\Helper + + + Source Files\Helper + + + Source Files\Helper + + + Source Files\Core\BKT + + + Source Files\Core\KDT + + + Source Files\Core\Common + + + Source Files\Helper\VectorSetReaders + + + Source Files\Helper + + + + + + \ No newline at end of file diff --git a/core/src/index/thirdparty/SPTAG/AnnService/IndexBuilder.vcxproj b/core/src/index/thirdparty/SPTAG/AnnService/IndexBuilder.vcxproj new file mode 100644 index 0000000000..a5d05fb473 --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/AnnService/IndexBuilder.vcxproj @@ -0,0 +1,173 @@ + + + + + Debug + Win32 + + + Release + Win32 + + + Debug + x64 + + + Release + x64 + + + + 15.0 + {F492F794-E78B-4B1F-A556-5E045B9163D5} + IndexBuilder + 8.1 + + + + + Application + true + v140 + MultiByte + + + Application + false + v140 + true + MultiByte + + + Application + true + v140 + MultiByte + + + Application + false + v140 + true + MultiByte + + + + + + + + + + + + + + + + + + + + $(SolutionDir)obj\$(Platform)_$(Configuration)\$(ProjectName)\ + $(ProjectDir);$(IncludePath) + $(OutAppDir) + $(OutLibDir);$(LibraryPath) + + + false + + + + CoreLibrary.lib;%(AdditionalDependencies) + + + + + Level3 + MaxSpeed + true + true + true + true + _MBCS;_SCL_SECURE_NO_WARNINGS;%(PreprocessorDefinitions) + true + + + true + true + + + + + Level3 + Disabled + true + true + + + + + Level3 + Disabled + true + true + _MBCS;_SCL_SECURE_NO_WARNINGS;%(PreprocessorDefinitions) + Guard + ProgramDatabase + true + + + /guard:cf %(AdditionalOptions) + + + + + Level3 + MaxSpeed + true + true + true + true + + + true + true + + + + + + + + + + + + + + + + + + + + + + + + + + + This project references NuGet package(s) that are missing on this computer. Use NuGet Package Restore to download them. For more information, see http://go.microsoft.com/fwlink/?LinkID=322105. The missing file is {0}. + + + + + + + + + + \ No newline at end of file diff --git a/core/src/index/thirdparty/SPTAG/AnnService/IndexBuilder.vcxproj.filters b/core/src/index/thirdparty/SPTAG/AnnService/IndexBuilder.vcxproj.filters new file mode 100644 index 0000000000..0733fae1c1 --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/AnnService/IndexBuilder.vcxproj.filters @@ -0,0 +1,32 @@ + + + + + {4FC737F1-C7A5-4376-A066-2A32D752A2FF} + cpp;c;cc;cxx;def;odl;idl;hpj;bat;asm;asmx + + + {93995380-89BD-4b04-88EB-625FBE52EBFB} + h;hh;hpp;hxx;hm;inl;inc;xsd + + + + + Header Files + + + Header Files + + + + + Source Files + + + Source Files + + + Source Files + + + \ No newline at end of file diff --git a/core/src/index/thirdparty/SPTAG/AnnService/IndexSearcher.vcxproj b/core/src/index/thirdparty/SPTAG/AnnService/IndexSearcher.vcxproj new file mode 100644 index 0000000000..266ac576b3 --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/AnnService/IndexSearcher.vcxproj @@ -0,0 +1,170 @@ + + + + + Debug + Win32 + + + Release + Win32 + + + Debug + x64 + + + Release + x64 + + + + 15.0 + {97615D3B-9FA0-469E-B229-95A91A5087E0} + IndexSearcher + 8.1 + IndexSearcher + + + + + Application + true + v140 + MultiByte + + + Application + false + v140 + true + MultiByte + + + Application + true + v140 + MultiByte + + + Application + false + v140 + true + MultiByte + + + + + + + + + + + + + + + + + + + + $(SolutionDir)obj\$(Platform)_$(Configuration)\$(ProjectName)\ + $(ProjectDir);$(SolutionDir)AnnService\;$(IncludePath) + $(OutLibDir);$(LibraryPath) + $(OutAppDir) + + + false + + + + CoreLibrary.lib;%(AdditionalDependencies) + + + + + Level3 + MaxSpeed + true + true + true + true + _MBCS;_SCL_SECURE_NO_WARNINGS;%(PreprocessorDefinitions) + true + + + true + true + + + + + Level3 + Disabled + true + true + + + + + Level3 + Disabled + true + true + _MBCS;_SCL_SECURE_NO_WARNINGS;%(PreprocessorDefinitions) + Guard + ProgramDatabase + true + + + /guard:cf %(AdditionalOptions) + + + + + Level3 + MaxSpeed + true + true + true + true + + + true + true + + + + + + + + Designer + + + + + + + + + + + + + + + This project references NuGet package(s) that are missing on this computer. Use NuGet Package Restore to download them. For more information, see http://go.microsoft.com/fwlink/?LinkID=322105. The missing file is {0}. + + + + + + + + + + \ No newline at end of file diff --git a/core/src/index/thirdparty/SPTAG/AnnService/IndexSearcher.vcxproj.filters b/core/src/index/thirdparty/SPTAG/AnnService/IndexSearcher.vcxproj.filters new file mode 100644 index 0000000000..82f7700c53 --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/AnnService/IndexSearcher.vcxproj.filters @@ -0,0 +1,25 @@ + + + + + {4FC737F1-C7A5-4376-A066-2A32D752A2FF} + cpp;c;cc;cxx;def;odl;idl;hpj;bat;asm;asmx + + + {93995380-89BD-4b04-88EB-625FBE52EBFB} + h;hh;hpp;hxx;hm;inl;inc;xsd + + + {67DA6AB6-F800-4c08-8B7A-83BB121AAD01} + rc;ico;cur;bmp;dlg;rc2;rct;bin;rgs;gif;jpg;jpeg;jpe;resx;tiff;tif;png;wav;mfcribbon-ms + + + + + Source Files + + + + + + \ No newline at end of file diff --git a/core/src/index/thirdparty/SPTAG/AnnService/Server.vcxproj b/core/src/index/thirdparty/SPTAG/AnnService/Server.vcxproj new file mode 100644 index 0000000000..d830f3bc0d --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/AnnService/Server.vcxproj @@ -0,0 +1,153 @@ + + + + + Debug + Win32 + + + Release + Win32 + + + Debug + x64 + + + Release + x64 + + + + 15.0 + {E28B1222-8BEA-4A92-8FE0-088EBDAA7FE0} + Server + 8.1 + + + + + Application + true + v140 + MultiByte + + + Application + false + v140 + true + MultiByte + + + Application + true + v140 + MultiByte + + + Application + false + v140 + true + MultiByte + + + + + + + + + + + + + + + + + + + + + $(SolutionDir)obj\$(Platform)_$(Configuration)\$(ProjectName)\ + $(ProjectDir);$(IncludePath) + $(OutAppDir) + $(OutLibDir);$(LibraryPath) + + + false + + + + CoreLibrary.lib;SocketLib.lib;%(AdditionalDependencies) + + + + + CoreLibrary.lib;SocketLib.lib;%(AdditionalDependencies) + + + + + CoreLibrary.lib;SocketLib.lib;%(AdditionalDependencies) + /guard:cf %(AdditionalOptions) + + + _MBCS;_SCL_SECURE_NO_WARNINGS;%(PreprocessorDefinitions) + Guard + ProgramDatabase + + + + + CoreLibrary.lib;SocketLib.lib;%(AdditionalDependencies) + + + _MBCS;_SCL_SECURE_NO_WARNINGS;%(PreprocessorDefinitions) + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + This project references NuGet package(s) that are missing on this computer. Use NuGet Package Restore to download them. For more information, see http://go.microsoft.com/fwlink/?LinkID=322105. The missing file is {0}. + + + + + + + + + + \ No newline at end of file diff --git a/core/src/index/thirdparty/SPTAG/AnnService/Server.vcxproj.filters b/core/src/index/thirdparty/SPTAG/AnnService/Server.vcxproj.filters new file mode 100644 index 0000000000..da95d13f40 --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/AnnService/Server.vcxproj.filters @@ -0,0 +1,56 @@ + + + + + {4FC737F1-C7A5-4376-A066-2A32D752A2FF} + cpp;c;cc;cxx;def;odl;idl;hpj;bat;asm;asmx + + + {93995380-89BD-4b04-88EB-625FBE52EBFB} + h;hh;hpp;hxx;hm;inl;inc;xsd + + + + + Header Files + + + Header Files + + + Header Files + + + Header Files + + + Header Files + + + Header Files + + + + + Source Files + + + Source Files + + + Source Files + + + Source Files + + + Source Files + + + Source Files + + + Source Files + + + \ No newline at end of file diff --git a/core/src/index/thirdparty/SPTAG/AnnService/SocketLib.vcxproj b/core/src/index/thirdparty/SPTAG/AnnService/SocketLib.vcxproj new file mode 100644 index 0000000000..84f968fa0a --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/AnnService/SocketLib.vcxproj @@ -0,0 +1,123 @@ + + + + + Debug + Win32 + + + Release + Win32 + + + Debug + x64 + + + Release + x64 + + + + 15.0 + {F9A72303-6381-4C80-86FF-606A2F6F7B96} + SocketLib + 8.1 + + + + StaticLibrary + true + v140 + MultiByte + + + StaticLibrary + false + v140 + true + MultiByte + + + StaticLibrary + true + v140 + MultiByte + + + StaticLibrary + false + v140 + true + MultiByte + + + + + + + + + + + + + + + + + + + + + + $(ProjectDir);$(IncludePath) + $(SolutionDir)obj\$(Platform)_$(Configuration)\$(ProjectName)\ + $(OutLibDir)\ + + + + Guard + ProgramDatabase + _MBCS;_SCL_SECURE_NO_WARNINGS;%(PreprocessorDefinitions) + + + + + _MBCS;_SCL_SECURE_NO_WARNINGS;%(PreprocessorDefinitions) + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + This project references NuGet package(s) that are missing on this computer. Use NuGet Package Restore to download them. For more information, see http://go.microsoft.com/fwlink/?LinkID=322105. The missing file is {0}. + + + + \ No newline at end of file diff --git a/core/src/index/thirdparty/SPTAG/AnnService/SocketLib.vcxproj.filters b/core/src/index/thirdparty/SPTAG/AnnService/SocketLib.vcxproj.filters new file mode 100644 index 0000000000..5aa20f3bbf --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/AnnService/SocketLib.vcxproj.filters @@ -0,0 +1,65 @@ + + + + + {4FC737F1-C7A5-4376-A066-2A32D752A2FF} + cpp;c;cc;cxx;def;odl;idl;hpj;bat;asm;asmx + + + {93995380-89BD-4b04-88EB-625FBE52EBFB} + h;hh;hpp;hxx;hm;inl;inc;xsd + + + + + Header Files + + + Header Files + + + Header Files + + + Header Files + + + Header Files + + + Header Files + + + Header Files + + + Header Files + + + Header Files + + + + + Source Files + + + Source Files + + + Source Files + + + Source Files + + + Source Files + + + Source Files + + + Source Files + + + \ No newline at end of file diff --git a/core/src/index/thirdparty/SPTAG/AnnService/inc/Aggregator/AggregatorContext.h b/core/src/index/thirdparty/SPTAG/AnnService/inc/Aggregator/AggregatorContext.h new file mode 100644 index 0000000000..60c86137be --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/AnnService/inc/Aggregator/AggregatorContext.h @@ -0,0 +1,67 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifndef _SPTAG_AGGREGATOR_AGGREGATORCONTEXT_H_ +#define _SPTAG_AGGREGATOR_AGGREGATORCONTEXT_H_ + +#include "inc/Socket/Common.h" +#include "AggregatorSettings.h" + +#include +#include +#include + +namespace SPTAG +{ +namespace Aggregator +{ + +enum RemoteMachineStatus : uint8_t +{ + Disconnected = 0, + + Connecting, + + Connected +}; + + +struct RemoteMachine +{ + RemoteMachine(); + + std::string m_address; + + std::string m_port; + + Socket::ConnectionID m_connectionID; + + std::atomic m_status; +}; + +class AggregatorContext +{ +public: + AggregatorContext(const std::string& p_filePath); + + ~AggregatorContext(); + + bool IsInitialized() const; + + const std::vector>& GetRemoteServers() const; + + const std::shared_ptr& GetSettings() const; + +private: + std::vector> m_remoteServers; + + std::shared_ptr m_settings; + + bool m_initialized; +}; + +} // namespace Aggregator +} // namespace AnnService + + +#endif // _SPTAG_AGGREGATOR_AGGREGATORCONTEXT_H_ diff --git a/core/src/index/thirdparty/SPTAG/AnnService/inc/Aggregator/AggregatorExecutionContext.h b/core/src/index/thirdparty/SPTAG/AnnService/inc/Aggregator/AggregatorExecutionContext.h new file mode 100644 index 0000000000..12948a218a --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/AnnService/inc/Aggregator/AggregatorExecutionContext.h @@ -0,0 +1,53 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifndef _SPTAG_AGGREGATOR_AGGREGATOREXECUTIONCONTEXT_H_ +#define _SPTAG_AGGREGATOR_AGGREGATOREXECUTIONCONTEXT_H_ + +#include "inc/Socket/RemoteSearchQuery.h" +#include "inc/Socket/Packet.h" + +#include +#include + +namespace SPTAG +{ +namespace Aggregator +{ + +typedef std::shared_ptr AggregatorResult; + +class AggregatorExecutionContext +{ +public: + AggregatorExecutionContext(std::size_t p_totalServerNumber, + Socket::PacketHeader p_requestHeader); + + ~AggregatorExecutionContext(); + + std::size_t GetServerNumber() const; + + AggregatorResult& GetResult(std::size_t p_num); + + const Socket::PacketHeader& GetRequestHeader() const; + + bool IsCompletedAfterFinsh(std::uint32_t p_finishedCount); + +private: + std::atomic m_unfinishedCount; + + std::vector m_results; + + Socket::PacketHeader m_requestHeader; + +}; + + + + +} // namespace Aggregator +} // namespace AnnService + + +#endif // _SPTAG_AGGREGATOR_AGGREGATOREXECUTIONCONTEXT_H_ + diff --git a/core/src/index/thirdparty/SPTAG/AnnService/inc/Aggregator/AggregatorService.h b/core/src/index/thirdparty/SPTAG/AnnService/inc/Aggregator/AggregatorService.h new file mode 100644 index 0000000000..4d864aa5f0 --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/AnnService/inc/Aggregator/AggregatorService.h @@ -0,0 +1,88 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifndef _SPTAG_AGGREGATOR_AGGREGATORSERVICE_H_ +#define _SPTAG_AGGREGATOR_AGGREGATORSERVICE_H_ + +#include "AggregatorContext.h" +#include "AggregatorExecutionContext.h" +#include "inc/Socket/Server.h" +#include "inc/Socket/Client.h" +#include "inc/Socket/ResourceManager.h" + +#include + +#include +#include +#include +#include + +namespace SPTAG +{ +namespace Aggregator +{ + +class AggregatorService +{ +public: + AggregatorService(); + + ~AggregatorService(); + + bool Initialize(); + + void Run(); + +private: + + void StartClient(); + + void StartListen(); + + void WaitForShutdown(); + + void ConnectToPendingServers(); + + void AddToPendingServers(std::shared_ptr p_remoteServer); + + void SearchRequestHanlder(Socket::ConnectionID p_localConnectionID, Socket::Packet p_packet); + + void SearchResponseHanlder(Socket::ConnectionID p_localConnectionID, Socket::Packet p_packet); + + void AggregateResults(std::shared_ptr p_exectionContext); + + std::shared_ptr GetContext(); + +private: + typedef std::function AggregatorCallback; + + std::shared_ptr m_aggregatorContext; + + std::shared_ptr m_socketServer; + + std::shared_ptr m_socketClient; + + bool m_initalized; + + std::unique_ptr m_threadPool; + + boost::asio::io_context m_ioContext; + + boost::asio::signal_set m_shutdownSignals; + + std::vector> m_pendingConnectServers; + + std::mutex m_pendingConnectServersMutex; + + boost::asio::deadline_timer m_pendingConnectServersTimer; + + Socket::ResourceManager m_aggregatorCallbackManager; +}; + + + +} // namespace Aggregator +} // namespace AnnService + + +#endif // _SPTAG_AGGREGATOR_AGGREGATORSERVICE_H_ diff --git a/core/src/index/thirdparty/SPTAG/AnnService/inc/Aggregator/AggregatorSettings.h b/core/src/index/thirdparty/SPTAG/AnnService/inc/Aggregator/AggregatorSettings.h new file mode 100644 index 0000000000..cb1e9fe7f7 --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/AnnService/inc/Aggregator/AggregatorSettings.h @@ -0,0 +1,39 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifndef _SPTAG_AGGREGATOR_AGGREGATORSETTINGS_H_ +#define _SPTAG_AGGREGATOR_AGGREGATORSETTINGS_H_ + +#include "../Core/Common.h" + +#include + +namespace SPTAG +{ +namespace Aggregator +{ + +struct AggregatorSettings +{ + AggregatorSettings(); + + std::string m_listenAddr; + + std::string m_listenPort; + + std::uint32_t m_searchTimeout; + + SizeType m_threadNum; + + SizeType m_socketThreadNum; +}; + + + + +} // namespace Aggregator +} // namespace AnnService + + +#endif // _SPTAG_AGGREGATOR_AGGREGATORSETTINGS_H_ + diff --git a/core/src/index/thirdparty/SPTAG/AnnService/inc/Client/ClientWrapper.h b/core/src/index/thirdparty/SPTAG/AnnService/inc/Client/ClientWrapper.h new file mode 100644 index 0000000000..d96a67061e --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/AnnService/inc/Client/ClientWrapper.h @@ -0,0 +1,80 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifndef _SPTAG_CLIENT_CLIENTWRAPPER_H_ +#define _SPTAG_CLIENT_CLIENTWRAPPER_H_ + +#include "inc/Socket/Client.h" +#include "inc/Socket/RemoteSearchQuery.h" +#include "inc/Socket/ResourceManager.h" +#include "Options.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace SPTAG +{ +namespace Client +{ + +class ClientWrapper +{ +public: + typedef std::function Callback; + + ClientWrapper(const ClientOptions& p_options); + + ~ClientWrapper(); + + void SendQueryAsync(const Socket::RemoteQuery& p_query, + Callback p_callback, + const ClientOptions& p_options); + + void WaitAllFinished(); + + bool IsAvailable() const; + +private: + typedef std::pair ConnectionPair; + + Socket::PacketHandlerMapPtr GetHandlerMap(); + + void DecreaseUnfnishedJobCount(); + + const ConnectionPair& GetConnection(); + + void SearchResponseHanlder(Socket::ConnectionID p_localConnectionID, Socket::Packet p_packet); + + void HandleDeadConnection(Socket::ConnectionID p_cid); + +private: + ClientOptions m_options; + + std::unique_ptr m_client; + + std::atomic m_unfinishedJobCount; + + std::atomic_bool m_isWaitingFinish; + + std::condition_variable m_waitingQueue; + + std::mutex m_waitingMutex; + + std::vector m_connections; + + std::atomic m_spinCountOfConnection; + + Socket::ResourceManager m_callbackManager; +}; + + +} // namespace Socket +} // namespace SPTAG + +#endif // _SPTAG_CLIENT_OPTIONS_H_ diff --git a/core/src/index/thirdparty/SPTAG/AnnService/inc/Client/Options.h b/core/src/index/thirdparty/SPTAG/AnnService/inc/Client/Options.h new file mode 100644 index 0000000000..062061f042 --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/AnnService/inc/Client/Options.h @@ -0,0 +1,42 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifndef _SPTAG_CLIENT_OPTIONS_H_ +#define _SPTAG_CLIENT_OPTIONS_H_ + +#include "inc/Helper/ArgumentsParser.h" + +#include +#include +#include + +namespace SPTAG +{ +namespace Client +{ + +class ClientOptions : public Helper::ArgumentsParser +{ +public: + ClientOptions(); + + virtual ~ClientOptions(); + + std::string m_serverAddr; + + std::string m_serverPort; + + // in milliseconds. + std::uint32_t m_searchTimeout; + + std::uint32_t m_threadNum; + + std::uint32_t m_socketThreadNum; + +}; + + +} // namespace Socket +} // namespace SPTAG + +#endif // _SPTAG_CLIENT_OPTIONS_H_ diff --git a/core/src/index/thirdparty/SPTAG/AnnService/inc/Core/BKT/Index.h b/core/src/index/thirdparty/SPTAG/AnnService/inc/Core/BKT/Index.h new file mode 100644 index 0000000000..e4c52586c7 --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/AnnService/inc/Core/BKT/Index.h @@ -0,0 +1,132 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifndef _SPTAG_BKT_INDEX_H_ +#define _SPTAG_BKT_INDEX_H_ + +#include "../Common.h" +#include "../VectorIndex.h" + +#include "../Common/CommonUtils.h" +#include "../Common/DistanceUtils.h" +#include "../Common/QueryResultSet.h" +#include "../Common/Dataset.h" +#include "../Common/WorkSpace.h" +#include "../Common/WorkSpacePool.h" +#include "../Common/RelativeNeighborhoodGraph.h" +#include "../Common/BKTree.h" +#include "inc/Helper/ConcurrentSet.h" +#include "inc/Helper/SimpleIniReader.h" +#include "inc/Helper/StringConvert.h" + +#include +#include + +namespace SPTAG +{ + + namespace Helper + { + class IniReader; + } + + namespace BKT + { + template + class Index : public VectorIndex + { + private: + // data points + COMMON::Dataset m_pSamples; + + // BKT structures. + COMMON::BKTree m_pTrees; + + // Graph structure + COMMON::RelativeNeighborhoodGraph m_pGraph; + + std::string m_sBKTFilename; + std::string m_sGraphFilename; + std::string m_sDataPointsFilename; + std::string m_sDeleteDataPointsFilename; + + std::mutex m_dataAddLock; // protect data and graph + Helper::Concurrent::ConcurrentSet m_deletedID; + float m_fDeletePercentageForRefine; + std::unique_ptr m_workSpacePool; + + int m_iNumberOfThreads; + DistCalcMethod m_iDistCalcMethod; + float(*m_fComputeDistance)(const T* pX, const T* pY, DimensionType length); + + int m_iMaxCheck; + int m_iThresholdOfNumberOfContinuousNoBetterPropagation; + int m_iNumberOfInitialDynamicPivots; + int m_iNumberOfOtherDynamicPivots; + public: + Index() + { +#define DefineBKTParameter(VarName, VarType, DefaultValue, RepresentStr) \ + VarName = DefaultValue; \ + +#include "inc/Core/BKT/ParameterDefinitionList.h" +#undef DefineBKTParameter + + m_pSamples.SetName("Vector"); + m_fComputeDistance = COMMON::DistanceCalcSelector(m_iDistCalcMethod); + } + + ~Index() {} + + inline SizeType GetNumSamples() const { return m_pSamples.R(); } + inline SizeType GetIndexSize() const { return sizeof(*this); } + inline DimensionType GetFeatureDim() const { return m_pSamples.C(); } + + inline int GetCurrMaxCheck() const { return m_iMaxCheck; } + inline int GetNumThreads() const { return m_iNumberOfThreads; } + inline DistCalcMethod GetDistCalcMethod() const { return m_iDistCalcMethod; } + inline IndexAlgoType GetIndexAlgoType() const { return IndexAlgoType::BKT; } + inline VectorValueType GetVectorValueType() const { return GetEnumValueType(); } + + inline float ComputeDistance(const void* pX, const void* pY) const { return m_fComputeDistance((const T*)pX, (const T*)pY, m_pSamples.C()); } + inline const void* GetSample(const SizeType idx) const { return (void*)m_pSamples[idx]; } + inline bool ContainSample(const SizeType idx) const { return !m_deletedID.contains(idx); } + inline bool NeedRefine() const { return m_deletedID.size() >= (size_t)(GetNumSamples() * m_fDeletePercentageForRefine); } + std::shared_ptr> BufferSize() const + { + std::shared_ptr> buffersize(new std::vector); + buffersize->push_back(m_pSamples.BufferSize()); + buffersize->push_back(m_pTrees.BufferSize()); + buffersize->push_back(m_pGraph.BufferSize()); + buffersize->push_back(m_deletedID.bufferSize()); + return std::move(buffersize); + } + + ErrorCode SaveConfig(std::ostream& p_configout) const; + ErrorCode SaveIndexData(const std::string& p_folderPath); + ErrorCode SaveIndexData(const std::vector& p_indexStreams); + + ErrorCode LoadConfig(Helper::IniReader& p_reader); + ErrorCode LoadIndexData(const std::string& p_folderPath); + ErrorCode LoadIndexDataFromMemory(const std::vector& p_indexBlobs); + + ErrorCode BuildIndex(const void* p_data, SizeType p_vectorNum, DimensionType p_dimension); + ErrorCode SearchIndex(QueryResult &p_query) const; + ErrorCode AddIndex(const void* p_vectors, SizeType p_vectorNum, DimensionType p_dimension, SizeType* p_start = nullptr); + ErrorCode DeleteIndex(const void* p_vectors, SizeType p_vectorNum); + ErrorCode DeleteIndex(const SizeType& p_id); + + ErrorCode SetParameter(const char* p_param, const char* p_value); + std::string GetParameter(const char* p_param) const; + + ErrorCode RefineIndex(const std::string& p_folderPath); + ErrorCode RefineIndex(const std::vector& p_indexStreams); + + private: + void SearchIndexWithDeleted(COMMON::QueryResultSet &p_query, COMMON::WorkSpace &p_space, const Helper::Concurrent::ConcurrentSet &p_deleted) const; + void SearchIndexWithoutDeleted(COMMON::QueryResultSet &p_query, COMMON::WorkSpace &p_space) const; + }; + } // namespace BKT +} // namespace SPTAG + +#endif // _SPTAG_BKT_INDEX_H_ diff --git a/core/src/index/thirdparty/SPTAG/AnnService/inc/Core/BKT/ParameterDefinitionList.h b/core/src/index/thirdparty/SPTAG/AnnService/inc/Core/BKT/ParameterDefinitionList.h new file mode 100644 index 0000000000..3f6f9e0222 --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/AnnService/inc/Core/BKT/ParameterDefinitionList.h @@ -0,0 +1,38 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifdef DefineBKTParameter + +// DefineBKTParameter(VarName, VarType, DefaultValue, RepresentStr) +DefineBKTParameter(m_sBKTFilename, std::string, std::string("tree.bin"), "TreeFilePath") +DefineBKTParameter(m_sGraphFilename, std::string, std::string("graph.bin"), "GraphFilePath") +DefineBKTParameter(m_sDataPointsFilename, std::string, std::string("vectors.bin"), "VectorFilePath") +DefineBKTParameter(m_sDeleteDataPointsFilename, std::string, std::string("deletes.bin"), "DeleteVectorFilePath") + +DefineBKTParameter(m_pTrees.m_iTreeNumber, int, 1L, "BKTNumber") +DefineBKTParameter(m_pTrees.m_iBKTKmeansK, int, 32L, "BKTKmeansK") +DefineBKTParameter(m_pTrees.m_iBKTLeafSize, int, 8L, "BKTLeafSize") +DefineBKTParameter(m_pTrees.m_iSamples, int, 1000L, "Samples") + + +DefineBKTParameter(m_pGraph.m_iTPTNumber, int, 32L, "TPTNumber") +DefineBKTParameter(m_pGraph.m_iTPTLeafSize, int, 2000L, "TPTLeafSize") +DefineBKTParameter(m_pGraph.m_numTopDimensionTPTSplit, int, 5L, "NumTopDimensionTpTreeSplit") + +DefineBKTParameter(m_pGraph.m_iNeighborhoodSize, DimensionType, 32L, "NeighborhoodSize") +DefineBKTParameter(m_pGraph.m_iNeighborhoodScale, int, 2L, "GraphNeighborhoodScale") +DefineBKTParameter(m_pGraph.m_iCEFScale, int, 2L, "GraphCEFScale") +DefineBKTParameter(m_pGraph.m_iRefineIter, int, 0L, "RefineIterations") +DefineBKTParameter(m_pGraph.m_iCEF, int, 1000L, "CEF") +DefineBKTParameter(m_pGraph.m_iMaxCheckForRefineGraph, int, 10000L, "MaxCheckForRefineGraph") + +DefineBKTParameter(m_iNumberOfThreads, int, 1L, "NumberOfThreads") +DefineBKTParameter(m_iDistCalcMethod, SPTAG::DistCalcMethod, SPTAG::DistCalcMethod::Cosine, "DistCalcMethod") + +DefineBKTParameter(m_fDeletePercentageForRefine, float, 0.4F, "DeletePercentageForRefine") +DefineBKTParameter(m_iMaxCheck, int, 8192L, "MaxCheck") +DefineBKTParameter(m_iThresholdOfNumberOfContinuousNoBetterPropagation, int, 3L, "ThresholdOfNumberOfContinuousNoBetterPropagation") +DefineBKTParameter(m_iNumberOfInitialDynamicPivots, int, 50L, "NumberOfInitialDynamicPivots") +DefineBKTParameter(m_iNumberOfOtherDynamicPivots, int, 4L, "NumberOfOtherDynamicPivots") + +#endif diff --git a/core/src/index/thirdparty/SPTAG/AnnService/inc/Core/Common.h b/core/src/index/thirdparty/SPTAG/AnnService/inc/Core/Common.h new file mode 100644 index 0000000000..02182a4bf2 --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/AnnService/inc/Core/Common.h @@ -0,0 +1,162 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifndef _SPTAG_CORE_COMMONDEFS_H_ +#define _SPTAG_CORE_COMMONDEFS_H_ + +#include +#include +#include +#include +#include +#include +#include + +#ifndef _MSC_VER +#include +#include +#define FolderSep '/' +#define mkdir(a) mkdir(a, ACCESSPERMS) +inline bool direxists(const char* path) { + struct stat info; + return stat(path, &info) == 0 && (info.st_mode & S_IFDIR); +} +inline bool fileexists(const char* path) { + struct stat info; + return stat(path, &info) == 0 && (info.st_mode & S_IFDIR) == 0; +} +template +inline T min(T a, T b) { + return a < b ? a : b; +} +template +inline T max(T a, T b) { + return a > b ? a : b; +} + +#ifndef _rotl +#define _rotl(x, n) (((x) << (n)) | ((x) >> (32-(n)))) +#endif + +#else +#define WIN32_LEAN_AND_MEAN +#include +#include +#define FolderSep '\\' +#define mkdir(a) CreateDirectory(a, NULL) +inline bool direxists(const char* path) { + auto dwAttr = GetFileAttributes((LPCSTR)path); + return (dwAttr != INVALID_FILE_ATTRIBUTES) && (dwAttr & FILE_ATTRIBUTE_DIRECTORY); +} +inline bool fileexists(const char* path) { + auto dwAttr = GetFileAttributes((LPCSTR)path); + return (dwAttr != INVALID_FILE_ATTRIBUTES) && (dwAttr & FILE_ATTRIBUTE_DIRECTORY) == 0; +} +#endif + +namespace SPTAG +{ +typedef std::int32_t SizeType; +typedef std::int32_t DimensionType; + +const SizeType MaxSize = (std::numeric_limits::max)(); +const float MinDist = (std::numeric_limits::min)(); +const float MaxDist = (std::numeric_limits::max)(); +const float Epsilon = 0.000000001f; + +class MyException : public std::exception +{ +private: + std::string Exp; +public: + MyException(std::string e) { Exp = e; } +#ifdef _MSC_VER + const char* what() const { return Exp.c_str(); } +#else + const char* what() const noexcept { return Exp.c_str(); } +#endif +}; + +enum class ErrorCode : std::uint16_t +{ +#define DefineErrorCode(Name, Value) Name = Value, +#include "DefinitionList.h" +#undef DefineErrorCode + + Undefined +}; +static_assert(static_cast(ErrorCode::Undefined) != 0, "Empty ErrorCode!"); + + +enum class DistCalcMethod : std::uint8_t +{ +#define DefineDistCalcMethod(Name) Name, +#include "DefinitionList.h" +#undef DefineDistCalcMethod + + Undefined +}; +static_assert(static_cast(DistCalcMethod::Undefined) != 0, "Empty DistCalcMethod!"); + + +enum class VectorValueType : std::uint8_t +{ +#define DefineVectorValueType(Name, Type) Name, +#include "DefinitionList.h" +#undef DefineVectorValueType + + Undefined +}; +static_assert(static_cast(VectorValueType::Undefined) != 0, "Empty VectorValueType!"); + + +enum class IndexAlgoType : std::uint8_t +{ +#define DefineIndexAlgo(Name) Name, +#include "DefinitionList.h" +#undef DefineIndexAlgo + + Undefined +}; +static_assert(static_cast(IndexAlgoType::Undefined) != 0, "Empty IndexAlgoType!"); + + +template +constexpr VectorValueType GetEnumValueType() +{ + return VectorValueType::Undefined; +} + + +#define DefineVectorValueType(Name, Type) \ +template<> \ +constexpr VectorValueType GetEnumValueType() \ +{ \ + return VectorValueType::Name; \ +} \ + +#include "DefinitionList.h" +#undef DefineVectorValueType + + +inline std::size_t GetValueTypeSize(VectorValueType p_valueType) +{ + switch (p_valueType) + { +#define DefineVectorValueType(Name, Type) \ + case VectorValueType::Name: \ + return sizeof(Type); \ + +#include "DefinitionList.h" +#undef DefineVectorValueType + + default: + break; + } + + return 0; +} + +} // namespace SPTAG + +#endif // _SPTAG_CORE_COMMONDEFS_H_ diff --git a/core/src/index/thirdparty/SPTAG/AnnService/inc/Core/Common/BKTree.h b/core/src/index/thirdparty/SPTAG/AnnService/inc/Core/Common/BKTree.h new file mode 100644 index 0000000000..56583be164 --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/AnnService/inc/Core/Common/BKTree.h @@ -0,0 +1,490 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifndef _SPTAG_COMMON_BKTREE_H_ +#define _SPTAG_COMMON_BKTREE_H_ + +#include +#include +#include +#include + +#include "../VectorIndex.h" + +#include "CommonUtils.h" +#include "QueryResultSet.h" +#include "WorkSpace.h" + +#pragma warning(disable:4996) // 'fopen': This function or variable may be unsafe. Consider using fopen_s instead. To disable deprecation, use _CRT_SECURE_NO_WARNINGS. See online help for details. + +namespace SPTAG +{ + namespace COMMON + { + // node type for storing BKT + struct BKTNode + { + SizeType centerid; + SizeType childStart; + SizeType childEnd; + + BKTNode(SizeType cid = -1) : centerid(cid), childStart(-1), childEnd(-1) {} + }; + + template + struct KmeansArgs { + int _K; + DimensionType _D; + int _T; + T* centers; + SizeType* counts; + float* newCenters; + SizeType* newCounts; + int* label; + SizeType* clusterIdx; + float* clusterDist; + T* newTCenters; + + KmeansArgs(int k, DimensionType dim, SizeType datasize, int threadnum) : _K(k), _D(dim), _T(threadnum) { + centers = new T[k * dim]; + counts = new SizeType[k]; + newCenters = new float[threadnum * k * dim]; + newCounts = new SizeType[threadnum * k]; + label = new int[datasize]; + clusterIdx = new SizeType[threadnum * k]; + clusterDist = new float[threadnum * k]; + newTCenters = new T[k * dim]; + } + + ~KmeansArgs() { + delete[] centers; + delete[] counts; + delete[] newCenters; + delete[] newCounts; + delete[] label; + delete[] clusterIdx; + delete[] clusterDist; + delete[] newTCenters; + } + + inline void ClearCounts() { + memset(newCounts, 0, sizeof(SizeType) * _T * _K); + } + + inline void ClearCenters() { + memset(newCenters, 0, sizeof(float) * _T * _K * _D); + } + + inline void ClearDists(float dist) { + for (int i = 0; i < _T * _K; i++) { + clusterIdx[i] = -1; + clusterDist[i] = dist; + } + } + + void Shuffle(std::vector& indices, SizeType first, SizeType last) { + SizeType* pos = new SizeType[_K]; + pos[0] = first; + for (int k = 1; k < _K; k++) pos[k] = pos[k - 1] + newCounts[k - 1]; + + for (int k = 0; k < _K; k++) { + if (newCounts[k] == 0) continue; + SizeType i = pos[k]; + while (newCounts[k] > 0) { + SizeType swapid = pos[label[i]] + newCounts[label[i]] - 1; + newCounts[label[i]]--; + std::swap(indices[i], indices[swapid]); + std::swap(label[i], label[swapid]); + } + while (indices[i] != clusterIdx[k]) i++; + std::swap(indices[i], indices[pos[k] + counts[k] - 1]); + } + delete[] pos; + } + }; + + class BKTree + { + public: + BKTree(): m_iTreeNumber(1), m_iBKTKmeansK(32), m_iBKTLeafSize(8), m_iSamples(1000) {} + + BKTree(BKTree& other): m_iTreeNumber(other.m_iTreeNumber), + m_iBKTKmeansK(other.m_iBKTKmeansK), + m_iBKTLeafSize(other.m_iBKTLeafSize), + m_iSamples(other.m_iSamples) {} + ~BKTree() {} + + inline const BKTNode& operator[](SizeType index) const { return m_pTreeRoots[index]; } + inline BKTNode& operator[](SizeType index) { return m_pTreeRoots[index]; } + + inline SizeType size() const { return (SizeType)m_pTreeRoots.size(); } + + inline const std::unordered_map& GetSampleMap() const { return m_pSampleCenterMap; } + + template + void BuildTrees(VectorIndex* index, std::vector* indices = nullptr) + { + struct BKTStackItem { + SizeType index, first, last; + BKTStackItem(SizeType index_, SizeType first_, SizeType last_) : index(index_), first(first_), last(last_) {} + }; + std::stack ss; + + std::vector localindices; + if (indices == nullptr) { + localindices.resize(index->GetNumSamples()); + for (SizeType i = 0; i < index->GetNumSamples(); i++) localindices[i] = i; + } + else { + localindices.assign(indices->begin(), indices->end()); + } + KmeansArgs args(m_iBKTKmeansK, index->GetFeatureDim(), (SizeType)localindices.size(), omp_get_num_threads()); + + m_pSampleCenterMap.clear(); + for (char i = 0; i < m_iTreeNumber; i++) + { + std::random_shuffle(localindices.begin(), localindices.end()); + + m_pTreeStart.push_back((SizeType)m_pTreeRoots.size()); + m_pTreeRoots.push_back(BKTNode((SizeType)localindices.size())); + std::cout << "Start to build BKTree " << i + 1 << std::endl; + + ss.push(BKTStackItem(m_pTreeStart[i], 0, (SizeType)localindices.size())); + while (!ss.empty()) { + BKTStackItem item = ss.top(); ss.pop(); + SizeType newBKTid = (SizeType)m_pTreeRoots.size(); + m_pTreeRoots[item.index].childStart = newBKTid; + if (item.last - item.first <= m_iBKTLeafSize) { + for (SizeType j = item.first; j < item.last; j++) { + m_pTreeRoots.push_back(BKTNode(localindices[j])); + } + } + else { // clustering the data into BKTKmeansK clusters + int numClusters = KmeansClustering(index, localindices, item.first, item.last, args); + if (numClusters <= 1) { + SizeType end = min(item.last + 1, (SizeType)localindices.size()); + std::sort(localindices.begin() + item.first, localindices.begin() + end); + m_pTreeRoots[item.index].centerid = localindices[item.first]; + m_pTreeRoots[item.index].childStart = -m_pTreeRoots[item.index].childStart; + for (SizeType j = item.first + 1; j < end; j++) { + m_pTreeRoots.push_back(BKTNode(localindices[j])); + m_pSampleCenterMap[localindices[j]] = m_pTreeRoots[item.index].centerid; + } + m_pSampleCenterMap[-1 - m_pTreeRoots[item.index].centerid] = item.index; + } + else { + for (int k = 0; k < m_iBKTKmeansK; k++) { + if (args.counts[k] == 0) continue; + m_pTreeRoots.push_back(BKTNode(localindices[item.first + args.counts[k] - 1])); + if (args.counts[k] > 1) ss.push(BKTStackItem(newBKTid++, item.first, item.first + args.counts[k] - 1)); + item.first += args.counts[k]; + } + } + } + m_pTreeRoots[item.index].childEnd = (SizeType)m_pTreeRoots.size(); + } + std::cout << i + 1 << " BKTree built, " << m_pTreeRoots.size() - m_pTreeStart[i] << " " << localindices.size() << std::endl; + } + } + + inline std::uint64_t BufferSize() const + { + return sizeof(int) + sizeof(SizeType) * m_iTreeNumber + + sizeof(SizeType) + sizeof(BKTNode) * m_pTreeRoots.size(); + } + + bool SaveTrees(std::ostream& p_outstream) const + { + p_outstream.write((char*)&m_iTreeNumber, sizeof(int)); + p_outstream.write((char*)m_pTreeStart.data(), sizeof(SizeType) * m_iTreeNumber); + SizeType treeNodeSize = (SizeType)m_pTreeRoots.size(); + p_outstream.write((char*)&treeNodeSize, sizeof(SizeType)); + p_outstream.write((char*)m_pTreeRoots.data(), sizeof(BKTNode) * treeNodeSize); + std::cout << "Save BKT (" << m_iTreeNumber << "," << treeNodeSize << ") Finish!" << std::endl; + return true; + } + + bool SaveTrees(std::string sTreeFileName) const + { + std::cout << "Save BKT to " << sTreeFileName << std::endl; + std::ofstream output(sTreeFileName, std::ios::binary); + if (!output.is_open()) return false; + SaveTrees(output); + output.close(); + return true; + } + + bool LoadTrees(char* pBKTMemFile) + { + m_iTreeNumber = *((int*)pBKTMemFile); + pBKTMemFile += sizeof(int); + m_pTreeStart.resize(m_iTreeNumber); + memcpy(m_pTreeStart.data(), pBKTMemFile, sizeof(SizeType) * m_iTreeNumber); + pBKTMemFile += sizeof(SizeType)*m_iTreeNumber; + + SizeType treeNodeSize = *((SizeType*)pBKTMemFile); + pBKTMemFile += sizeof(SizeType); + m_pTreeRoots.resize(treeNodeSize); + memcpy(m_pTreeRoots.data(), pBKTMemFile, sizeof(BKTNode) * treeNodeSize); + std::cout << "Load BKT (" << m_iTreeNumber << "," << treeNodeSize << ") Finish!" << std::endl; + return true; + } + + bool LoadTrees(std::string sTreeFileName) + { + std::cout << "Load BKT From " << sTreeFileName << std::endl; + std::ifstream input(sTreeFileName, std::ios::binary); + if (!input.is_open()) return false; + + input.read((char*)&m_iTreeNumber, sizeof(int)); + m_pTreeStart.resize(m_iTreeNumber); + input.read((char*)m_pTreeStart.data(), sizeof(SizeType) * m_iTreeNumber); + + SizeType treeNodeSize; + input.read((char*)&treeNodeSize, sizeof(SizeType)); + m_pTreeRoots.resize(treeNodeSize); + input.read((char*)m_pTreeRoots.data(), sizeof(BKTNode) * treeNodeSize); + input.close(); + std::cout << "Load BKT (" << m_iTreeNumber << "," << treeNodeSize << ") Finish!" << std::endl; + return true; + } + + template + void InitSearchTrees(const VectorIndex* p_index, const COMMON::QueryResultSet &p_query, COMMON::WorkSpace &p_space) const + { + for (char i = 0; i < m_iTreeNumber; i++) { + const BKTNode& node = m_pTreeRoots[m_pTreeStart[i]]; + if (node.childStart < 0) { + p_space.m_SPTQueue.insert(COMMON::HeapCell(m_pTreeStart[i], p_index->ComputeDistance((const void*)p_query.GetTarget(), p_index->GetSample(node.centerid)))); + } + else { + for (SizeType begin = node.childStart; begin < node.childEnd; begin++) { + SizeType index = m_pTreeRoots[begin].centerid; + p_space.m_SPTQueue.insert(COMMON::HeapCell(begin, p_index->ComputeDistance((const void*)p_query.GetTarget(), p_index->GetSample(index)))); + } + } + } + } + + template + void SearchTrees(const VectorIndex* p_index, const COMMON::QueryResultSet &p_query, + COMMON::WorkSpace &p_space, const int p_limits) const + { + do + { + COMMON::HeapCell bcell = p_space.m_SPTQueue.pop(); + const BKTNode& tnode = m_pTreeRoots[bcell.node]; + if (tnode.childStart < 0) { + if (!p_space.CheckAndSet(tnode.centerid)) { + p_space.m_iNumberOfCheckedLeaves++; + p_space.m_NGQueue.insert(COMMON::HeapCell(tnode.centerid, bcell.distance)); + } + if (p_space.m_iNumberOfCheckedLeaves >= p_limits) break; + } + else { + if (!p_space.CheckAndSet(tnode.centerid)) { + p_space.m_NGQueue.insert(COMMON::HeapCell(tnode.centerid, bcell.distance)); + } + for (SizeType begin = tnode.childStart; begin < tnode.childEnd; begin++) { + SizeType index = m_pTreeRoots[begin].centerid; + p_space.m_SPTQueue.insert(COMMON::HeapCell(begin, p_index->ComputeDistance((const void*)p_query.GetTarget(), p_index->GetSample(index)))); + } + } + } while (!p_space.m_SPTQueue.empty()); + } + + private: + + template + float KmeansAssign(VectorIndex* p_index, + std::vector& indices, + const SizeType first, const SizeType last, KmeansArgs& args, const bool updateCenters) const { + float currDist = 0; + int threads = omp_get_num_threads(); + float lambda = (updateCenters) ? COMMON::Utils::GetBase() * COMMON::Utils::GetBase() / (100.0f * (last - first)) : 0.0f; + SizeType subsize = (last - first - 1) / threads + 1; + +#pragma omp parallel for + for (int tid = 0; tid < threads; tid++) + { + SizeType istart = first + tid * subsize; + SizeType iend = min(first + (tid + 1) * subsize, last); + SizeType *inewCounts = args.newCounts + tid * m_iBKTKmeansK; + float *inewCenters = args.newCenters + tid * m_iBKTKmeansK * p_index->GetFeatureDim(); + SizeType * iclusterIdx = args.clusterIdx + tid * m_iBKTKmeansK; + float * iclusterDist = args.clusterDist + tid * m_iBKTKmeansK; + float idist = 0; + for (SizeType i = istart; i < iend; i++) { + int clusterid = 0; + float smallestDist = MaxDist; + for (int k = 0; k < m_iBKTKmeansK; k++) { + float dist = p_index->ComputeDistance(p_index->GetSample(indices[i]), (const void*)(args.centers + k*p_index->GetFeatureDim())) + lambda*args.counts[k]; + if (dist > -MaxDist && dist < smallestDist) { + clusterid = k; smallestDist = dist; + } + } + args.label[i] = clusterid; + inewCounts[clusterid]++; + idist += smallestDist; + if (updateCenters) { + const T* v = (const T*)p_index->GetSample(indices[i]); + float* center = inewCenters + clusterid*p_index->GetFeatureDim(); + for (DimensionType j = 0; j < p_index->GetFeatureDim(); j++) center[j] += v[j]; + if (smallestDist > iclusterDist[clusterid]) { + iclusterDist[clusterid] = smallestDist; + iclusterIdx[clusterid] = indices[i]; + } + } + else { + if (smallestDist <= iclusterDist[clusterid]) { + iclusterDist[clusterid] = smallestDist; + iclusterIdx[clusterid] = indices[i]; + } + } + } + COMMON::Utils::atomic_float_add(&currDist, idist); + } + + for (int i = 1; i < threads; i++) { + for (int k = 0; k < m_iBKTKmeansK; k++) + args.newCounts[k] += args.newCounts[i*m_iBKTKmeansK + k]; + } + + if (updateCenters) { + for (int i = 1; i < threads; i++) { + float* currCenter = args.newCenters + i*m_iBKTKmeansK*p_index->GetFeatureDim(); + for (size_t j = 0; j < ((size_t)m_iBKTKmeansK) * p_index->GetFeatureDim(); j++) args.newCenters[j] += currCenter[j]; + + for (int k = 0; k < m_iBKTKmeansK; k++) { + if (args.clusterIdx[i*m_iBKTKmeansK + k] != -1 && args.clusterDist[i*m_iBKTKmeansK + k] > args.clusterDist[k]) { + args.clusterDist[k] = args.clusterDist[i*m_iBKTKmeansK + k]; + args.clusterIdx[k] = args.clusterIdx[i*m_iBKTKmeansK + k]; + } + } + } + + int maxcluster = -1; + SizeType maxCount = 0; + for (int k = 0; k < m_iBKTKmeansK; k++) { + if (args.newCounts[k] > maxCount && DistanceUtils::ComputeL2Distance((T*)p_index->GetSample(args.clusterIdx[k]), args.centers + k * p_index->GetFeatureDim(), p_index->GetFeatureDim()) > 1e-6) + { + maxcluster = k; + maxCount = args.newCounts[k]; + } + } + + if (maxcluster != -1 && (args.clusterIdx[maxcluster] < 0 || args.clusterIdx[maxcluster] >= p_index->GetNumSamples())) + std::cout << "first:" << first << " last:" << last << " maxcluster:" << maxcluster << "(" << args.newCounts[maxcluster] << ") Error dist:" << args.clusterDist[maxcluster] << std::endl; + + for (int k = 0; k < m_iBKTKmeansK; k++) { + T* TCenter = args.newTCenters + k * p_index->GetFeatureDim(); + if (args.newCounts[k] == 0) { + if (maxcluster != -1) { + //int nextid = Utils::rand_int(last, first); + //while (args.label[nextid] != maxcluster) nextid = Utils::rand_int(last, first); + SizeType nextid = args.clusterIdx[maxcluster]; + std::memcpy(TCenter, p_index->GetSample(nextid), sizeof(T)*p_index->GetFeatureDim()); + } + else { + std::memcpy(TCenter, args.centers + k * p_index->GetFeatureDim(), sizeof(T)*p_index->GetFeatureDim()); + } + } + else { + float* currCenters = args.newCenters + k * p_index->GetFeatureDim(); + for (DimensionType j = 0; j < p_index->GetFeatureDim(); j++) currCenters[j] /= args.newCounts[k]; + + if (p_index->GetDistCalcMethod() == DistCalcMethod::Cosine) { + COMMON::Utils::Normalize(currCenters, p_index->GetFeatureDim(), COMMON::Utils::GetBase()); + } + for (DimensionType j = 0; j < p_index->GetFeatureDim(); j++) TCenter[j] = (T)(currCenters[j]); + } + } + } + else { + for (int i = 1; i < threads; i++) { + for (int k = 0; k < m_iBKTKmeansK; k++) { + if (args.clusterIdx[i*m_iBKTKmeansK + k] != -1 && args.clusterDist[i*m_iBKTKmeansK + k] <= args.clusterDist[k]) { + args.clusterDist[k] = args.clusterDist[i*m_iBKTKmeansK + k]; + args.clusterIdx[k] = args.clusterIdx[i*m_iBKTKmeansK + k]; + } + } + } + } + return currDist; + } + + template + int KmeansClustering(VectorIndex* p_index, + std::vector& indices, const SizeType first, const SizeType last, KmeansArgs& args) const { + int iterLimit = 100; + + SizeType batchEnd = min(first + m_iSamples, last); + float currDiff, currDist, minClusterDist = MaxDist; + for (int numKmeans = 0; numKmeans < 3; numKmeans++) { + for (int k = 0; k < m_iBKTKmeansK; k++) { + SizeType randid = COMMON::Utils::rand(last, first); + std::memcpy(args.centers + k*p_index->GetFeatureDim(), p_index->GetSample(indices[randid]), sizeof(T)*p_index->GetFeatureDim()); + } + args.ClearCounts(); + currDist = KmeansAssign(p_index, indices, first, batchEnd, args, false); + if (currDist < minClusterDist) { + minClusterDist = currDist; + memcpy(args.newTCenters, args.centers, sizeof(T)*m_iBKTKmeansK*p_index->GetFeatureDim()); + memcpy(args.counts, args.newCounts, sizeof(SizeType) * m_iBKTKmeansK); + } + } + + minClusterDist = MaxDist; + int noImprovement = 0; + for (int iter = 0; iter < iterLimit; iter++) { + std::memcpy(args.centers, args.newTCenters, sizeof(T)*m_iBKTKmeansK*p_index->GetFeatureDim()); + std::random_shuffle(indices.begin() + first, indices.begin() + last); + + args.ClearCenters(); + args.ClearCounts(); + args.ClearDists(-MaxDist); + currDist = KmeansAssign(p_index, indices, first, batchEnd, args, true); + memcpy(args.counts, args.newCounts, sizeof(SizeType) * m_iBKTKmeansK); + + currDiff = 0; + for (int k = 0; k < m_iBKTKmeansK; k++) { + currDiff += p_index->ComputeDistance((const void*)(args.centers + k*p_index->GetFeatureDim()), (const void*)(args.newTCenters + k*p_index->GetFeatureDim())); + } + + if (currDist < minClusterDist) { + noImprovement = 0; + minClusterDist = currDist; + } + else { + noImprovement++; + } + if (currDiff < 1e-3 || noImprovement >= 5) break; + } + + args.ClearCounts(); + args.ClearDists(MaxDist); + currDist = KmeansAssign(p_index, indices, first, last, args, false); + memcpy(args.counts, args.newCounts, sizeof(SizeType) * m_iBKTKmeansK); + + int numClusters = 0; + for (int i = 0; i < m_iBKTKmeansK; i++) if (args.counts[i] > 0) numClusters++; + + if (numClusters <= 1) { + //if (last - first > 1) std::cout << "large cluster:" << last - first << " dist:" << currDist << std::endl; + return numClusters; + } + args.Shuffle(indices, first, last); + return numClusters; + } + + private: + std::vector m_pTreeStart; + std::vector m_pTreeRoots; + std::unordered_map m_pSampleCenterMap; + + public: + int m_iTreeNumber, m_iBKTKmeansK, m_iBKTLeafSize, m_iSamples; + }; + } +} +#endif diff --git a/core/src/index/thirdparty/SPTAG/AnnService/inc/Core/Common/CommonUtils.h b/core/src/index/thirdparty/SPTAG/AnnService/inc/Core/Common/CommonUtils.h new file mode 100644 index 0000000000..96a8d0b4fa --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/AnnService/inc/Core/Common/CommonUtils.h @@ -0,0 +1,178 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifndef _SPTAG_COMMON_COMMONUTILS_H_ +#define _SPTAG_COMMON_COMMONUTILS_H_ + +#include "../Common.h" + +#include + +#include +#include +#include +#include + +#include +#include +#include + +#define PREFETCH + +#ifndef _MSC_VER +#include +#include +#include +#include + +#define InterlockedCompareExchange(a,b,c) __sync_val_compare_and_swap(a, c, b) +#define Sleep(a) usleep(a * 1000) +#define strtok_s(a, b, c) strtok_r(a, b, c) +#endif + +namespace SPTAG +{ + namespace COMMON + { + class Utils { + public: + static SizeType rand(SizeType high = MaxSize, SizeType low = 0) // Generates a random int value. + { + return low + (SizeType)(float(high - low)*(std::rand() / (RAND_MAX + 1.0))); + } + + static inline float atomic_float_add(volatile float* ptr, const float operand) + { + union { + volatile long iOld; + float fOld; + }; + union { + long iNew; + float fNew; + }; + + while (true) { + iOld = *(volatile long *)ptr; + fNew = fOld + operand; + if (InterlockedCompareExchange((long *)ptr, iNew, iOld) == iOld) { + return fNew; + } + } + } + + static double GetVector(char* cstr, const char* sep, std::vector& arr, DimensionType& NumDim) { + char* current; + char* context = nullptr; + + DimensionType i = 0; + double sum = 0; + arr.clear(); + current = strtok_s(cstr, sep, &context); + while (current != nullptr && (i < NumDim || NumDim < 0)) { + try { + float val = (float)atof(current); + arr.push_back(val); + } + catch (std::exception e) { + std::cout << "Exception:" << e.what() << std::endl; + return -2; + } + + sum += arr[i] * arr[i]; + current = strtok_s(nullptr, sep, &context); + i++; + } + + if (NumDim < 0) NumDim = i; + if (i < NumDim) return -2; + return std::sqrt(sum); + } + + template + static void Normalize(T* arr, DimensionType col, int base) { + double vecLen = 0; + for (DimensionType j = 0; j < col; j++) { + double val = arr[j]; + vecLen += val * val; + } + vecLen = std::sqrt(vecLen); + if (vecLen < 1e-6) { + T val = (T)(1.0 / std::sqrt((double)col) * base); + for (DimensionType j = 0; j < col; j++) arr[j] = val; + } + else { + for (DimensionType j = 0; j < col; j++) arr[j] = (T)(arr[j] / vecLen * base); + } + } + + static size_t ProcessLine(std::string& currentLine, std::vector& arr, DimensionType& D, int base, DistCalcMethod distCalcMethod) { + size_t index; + double vecLen; + if (currentLine.length() == 0 || (index = currentLine.find_last_of("\t")) == std::string::npos || (vecLen = GetVector(const_cast(currentLine.c_str() + index + 1), "|", arr, D)) < -1) { + std::cout << "Parse vector error: " + currentLine << std::endl; + //throw MyException("Error in parsing data " + currentLine); + return -1; + } + if (distCalcMethod == DistCalcMethod::Cosine) { + Normalize(arr.data(), D, base); + } + return index; + } + + template + static void PrepareQuerys(std::ifstream& inStream, std::vector& qString, std::vector>& Query, SizeType& NumQuery, DimensionType& NumDim, DistCalcMethod distCalcMethod, int base) { + std::string currentLine; + std::vector arr; + SizeType i = 0; + size_t index; + while ((NumQuery < 0 || i < NumQuery) && !inStream.eof()) { + std::getline(inStream, currentLine); + if (currentLine.length() <= 1 || (index = ProcessLine(currentLine, arr, NumDim, base, distCalcMethod)) < 0) { + continue; + } + qString.push_back(currentLine.substr(0, index)); + if ((SizeType)Query.size() < i + 1) Query.push_back(std::vector(NumDim, 0)); + + for (DimensionType j = 0; j < NumDim; j++) Query[i][j] = (T)arr[j]; + i++; + } + NumQuery = i; + std::cout << "Load data: (" << NumQuery << ", " << NumDim << ")" << std::endl; + } + + template + static inline int GetBase() { + if (GetEnumValueType() != VectorValueType::Float) { + return (int)(std::numeric_limits::max)(); + } + return 1; + } + + static inline void AddNeighbor(SizeType idx, float dist, SizeType *neighbors, float *dists, DimensionType size) + { + size--; + if (dist < dists[size] || (dist == dists[size] && idx < neighbors[size])) + { + DimensionType nb; + for (nb = 0; nb <= size && neighbors[nb] != idx; nb++); + + if (nb > size) + { + nb = size; + while (nb > 0 && (dist < dists[nb - 1] || (dist == dists[nb - 1] && idx < neighbors[nb - 1]))) + { + dists[nb] = dists[nb - 1]; + neighbors[nb] = neighbors[nb - 1]; + nb--; + } + dists[nb] = dist; + neighbors[nb] = idx; + } + } + } + }; + } +} + +#endif // _SPTAG_COMMON_COMMONUTILS_H_ diff --git a/core/src/index/thirdparty/SPTAG/AnnService/inc/Core/Common/DataUtils.h b/core/src/index/thirdparty/SPTAG/AnnService/inc/Core/Common/DataUtils.h new file mode 100644 index 0000000000..5d751c4c98 --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/AnnService/inc/Core/Common/DataUtils.h @@ -0,0 +1,116 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifndef _SPTAG_COMMON_DATAUTILS_H_ +#define _SPTAG_COMMON_DATAUTILS_H_ + +#include +#include +#include "CommonUtils.h" +#include "../../Helper/CommonHelper.h" + +namespace SPTAG +{ + namespace COMMON + { + const int bufsize = 1 << 30; + + class DataUtils { + public: + static bool MergeIndex(const std::string& p_vectorfile1, const std::string& p_metafile1, const std::string& p_metaindexfile1, + const std::string& p_vectorfile2, const std::string& p_metafile2, const std::string& p_metaindexfile2) { + std::ifstream inputStream1, inputStream2; + std::ofstream outputStream; + std::unique_ptr bufferHolder(new char[bufsize]); + char * buf = bufferHolder.get(); + SizeType R1, R2; + DimensionType C1, C2; + +#define MergeVector(inputStream, vectorFile, R, C) \ + inputStream.open(vectorFile, std::ifstream::binary); \ + if (!inputStream.is_open()) { \ + std::cout << "Cannot open vector file: " << vectorFile <<"!" << std::endl; \ + return false; \ + } \ + inputStream.read((char *)&(R), sizeof(SizeType)); \ + inputStream.read((char *)&(C), sizeof(DimensionType)); \ + + MergeVector(inputStream1, p_vectorfile1, R1, C1) + MergeVector(inputStream2, p_vectorfile2, R2, C2) +#undef MergeVector + if (C1 != C2) { + inputStream1.close(); inputStream2.close(); + std::cout << "Vector dimensions are not the same!" << std::endl; + return false; + } + R1 += R2; + outputStream.open(p_vectorfile1 + "_tmp", std::ofstream::binary); + outputStream.write((char *)&R1, sizeof(SizeType)); + outputStream.write((char *)&C1, sizeof(DimensionType)); + while (!inputStream1.eof()) { + inputStream1.read(buf, bufsize); + outputStream.write(buf, inputStream1.gcount()); + } + while (!inputStream2.eof()) { + inputStream2.read(buf, bufsize); + outputStream.write(buf, inputStream2.gcount()); + } + inputStream1.close(); inputStream2.close(); + outputStream.close(); + + if (p_metafile1 != "" && p_metafile2 != "") { + outputStream.open(p_metafile1 + "_tmp", std::ofstream::binary); +#define MergeMeta(inputStream, metaFile) \ + inputStream.open(metaFile, std::ifstream::binary); \ + if (!inputStream.is_open()) { \ + std::cout << "Cannot open meta file: " << metaFile << "!" << std::endl; \ + return false; \ + } \ + while (!inputStream.eof()) { \ + inputStream.read(buf, bufsize); \ + outputStream.write(buf, inputStream.gcount()); \ + } \ + inputStream.close(); \ + + MergeMeta(inputStream1, p_metafile1) + MergeMeta(inputStream2, p_metafile2) +#undef MergeMeta + outputStream.close(); + delete[] buf; + + std::uint64_t * offsets = reinterpret_cast(buf); + std::uint64_t lastoff = 0; + outputStream.open(p_metaindexfile1 + "_tmp", std::ofstream::binary); + outputStream.write((char *)&R1, sizeof(SizeType)); +#define MergeMetaIndex(inputStream, metaIndexFile) \ + inputStream.open(metaIndexFile, std::ifstream::binary); \ + if (!inputStream.is_open()) { \ + std::cout << "Cannot open meta index file: " << metaIndexFile << "!" << std::endl; \ + return false; \ + } \ + inputStream.read((char *)&R2, sizeof(SizeType)); \ + inputStream.read((char *)offsets, sizeof(std::uint64_t)*(R2 + 1)); \ + inputStream.close(); \ + for (SizeType j = 0; j < R2 + 1; j++) offsets[j] += lastoff; \ + outputStream.write((char *)offsets, sizeof(std::uint64_t)*R2); \ + lastoff = offsets[R2]; \ + + MergeMetaIndex(inputStream1, p_metaindexfile1) + MergeMetaIndex(inputStream2, p_metaindexfile2) +#undef MergeMetaIndex + outputStream.write((char *)&lastoff, sizeof(std::uint64_t)); + outputStream.close(); + + rename((p_metafile1 + "_tmp").c_str(), p_metafile1.c_str()); + rename((p_metaindexfile1 + "_tmp").c_str(), p_metaindexfile1.c_str()); + } + rename((p_vectorfile1 + "_tmp").c_str(), p_vectorfile1.c_str()); + + std::cout << "Merged -> numSamples:" << R1 << " D:" << C1 << std::endl; + return true; + } + }; + } +} + +#endif // _SPTAG_COMMON_DATAUTILS_H_ diff --git a/core/src/index/thirdparty/SPTAG/AnnService/inc/Core/Common/Dataset.h b/core/src/index/thirdparty/SPTAG/AnnService/inc/Core/Common/Dataset.h new file mode 100644 index 0000000000..0208f6d983 --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/AnnService/inc/Core/Common/Dataset.h @@ -0,0 +1,229 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifndef _SPTAG_COMMON_DATASET_H_ +#define _SPTAG_COMMON_DATASET_H_ + +#include + +#if defined(_MSC_VER) || defined(__INTEL_COMPILER) +#include +#else +#include +#endif // defined(__GNUC__) + +#define ALIGN 32 + +#define aligned_malloc(a, b) _mm_malloc(a, b) +#define aligned_free(a) _mm_free(a) + +#pragma warning(disable:4996) // 'fopen': This function or variable may be unsafe. Consider using fopen_s instead. To disable deprecation, use _CRT_SECURE_NO_WARNINGS. See online help for details. + +namespace SPTAG +{ + namespace COMMON + { + // structure to save Data and Graph + template + class Dataset + { + private: + std::string name = "Data"; + SizeType rows = 0; + DimensionType cols = 1; + bool ownData = false; + T* data = nullptr; + SizeType incRows = 0; + std::vector incBlocks; + static const SizeType rowsInBlock = 1024 * 1024; + public: + Dataset() + { + incBlocks.reserve(MaxSize / rowsInBlock + 1); + } + Dataset(SizeType rows_, DimensionType cols_, T* data_ = nullptr, bool transferOnwership_ = true) + { + Initialize(rows_, cols_, data_, transferOnwership_); + incBlocks.reserve(MaxSize / rowsInBlock + 1); + } + ~Dataset() + { + if (ownData) aligned_free(data); + for (T* ptr : incBlocks) aligned_free(ptr); + incBlocks.clear(); + } + void Initialize(SizeType rows_, DimensionType cols_, T* data_ = nullptr, bool transferOnwership_ = true) + { + rows = rows_; + cols = cols_; + data = data_; + if (data_ == nullptr || !transferOnwership_) + { + ownData = true; + data = (T*)aligned_malloc(((size_t)rows) * cols * sizeof(T), ALIGN); + if (data_ != nullptr) memcpy(data, data_, ((size_t)rows) * cols * sizeof(T)); + else std::memset(data, -1, ((size_t)rows) * cols * sizeof(T)); + } + } + void SetName(const std::string name_) { name = name_; } + void SetR(SizeType R_) + { + if (R_ >= rows) + incRows = R_ - rows; + else + { + rows = R_; + incRows = 0; + } + } + inline SizeType R() const { return rows + incRows; } + inline DimensionType C() const { return cols; } + inline std::uint64_t BufferSize() const { return sizeof(SizeType) + sizeof(DimensionType) + sizeof(T) * R() * C(); } + + inline const T* At(SizeType index) const + { + if (index >= rows) { + SizeType incIndex = index - rows; + return incBlocks[incIndex / rowsInBlock] + ((size_t)(incIndex % rowsInBlock)) * cols; + } + return data + ((size_t)index) * cols; + } + + T* operator[](SizeType index) + { + return (T*)At(index); + } + + const T* operator[](SizeType index) const + { + return At(index); + } + + ErrorCode AddBatch(const T* pData, SizeType num) + { + if (R() > MaxSize - num) return ErrorCode::MemoryOverFlow; + + SizeType written = 0; + while (written < num) { + SizeType curBlockIdx = (incRows + written) / rowsInBlock; + if (curBlockIdx >= (SizeType)incBlocks.size()) { + T* newBlock = (T*)aligned_malloc(((size_t)rowsInBlock) * cols * sizeof(T), ALIGN); + if (newBlock == nullptr) return ErrorCode::MemoryOverFlow; + incBlocks.push_back(newBlock); + } + SizeType curBlockPos = (incRows + written) % rowsInBlock; + SizeType toWrite = min(rowsInBlock - curBlockPos, num - written); + std::memcpy(incBlocks[curBlockIdx] + ((size_t)curBlockPos) * cols, pData + ((size_t)written) * cols, ((size_t)toWrite) * cols * sizeof(T)); + written += toWrite; + } + incRows += written; + return ErrorCode::Success; + } + + ErrorCode AddBatch(SizeType num) + { + if (R() > MaxSize - num) return ErrorCode::MemoryOverFlow; + + SizeType written = 0; + while (written < num) { + SizeType curBlockIdx = (incRows + written) / rowsInBlock; + if (curBlockIdx >= (SizeType)incBlocks.size()) { + T* newBlock = (T*)aligned_malloc(((size_t)rowsInBlock) * cols * sizeof(T), ALIGN); + if (newBlock == nullptr) return ErrorCode::MemoryOverFlow; + incBlocks.push_back(newBlock); + } + SizeType curBlockPos = (incRows + written) % rowsInBlock; + SizeType toWrite = min(rowsInBlock - curBlockPos, num - written); + std::memset(incBlocks[curBlockIdx] + ((size_t)curBlockPos) * cols, -1, ((size_t)toWrite) * cols * sizeof(T)); + written += toWrite; + } + incRows += written; + return ErrorCode::Success; + } + + bool Save(std::ostream& p_outstream) const + { + SizeType CR = R(); + p_outstream.write((char*)&CR, sizeof(SizeType)); + p_outstream.write((char*)&cols, sizeof(DimensionType)); + p_outstream.write((char*)data, sizeof(T) * cols * rows); + + SizeType blocks = incRows / rowsInBlock; + for (int i = 0; i < blocks; i++) + p_outstream.write((char*)incBlocks[i], sizeof(T) * cols * rowsInBlock); + + SizeType remain = incRows % rowsInBlock; + if (remain > 0) p_outstream.write((char*)incBlocks[blocks], sizeof(T) * cols * remain); + std::cout << "Save " << name << " (" << CR << ", " << cols << ") Finish!" << std::endl; + return true; + } + + bool Save(std::string sDataPointsFileName) const + { + std::cout << "Save " << name << " To " << sDataPointsFileName << std::endl; + std::ofstream output(sDataPointsFileName, std::ios::binary); + if (!output.is_open()) return false; + Save(output); + output.close(); + return true; + } + + bool Load(std::string sDataPointsFileName) + { + std::cout << "Load " << name << " From " << sDataPointsFileName << std::endl; + std::ifstream input(sDataPointsFileName, std::ios::binary); + if (!input.is_open()) return false; + + input.read((char*)&rows, sizeof(SizeType)); + input.read((char*)&cols, sizeof(DimensionType)); + + Initialize(rows, cols); + input.read((char*)data, sizeof(T) * cols * rows); + input.close(); + std::cout << "Load " << name << " (" << rows << ", " << cols << ") Finish!" << std::endl; + return true; + } + + // Functions for loading models from memory mapped files + bool Load(char* pDataPointsMemFile) + { + SizeType R; + DimensionType C; + R = *((SizeType*)pDataPointsMemFile); + pDataPointsMemFile += sizeof(SizeType); + + C = *((DimensionType*)pDataPointsMemFile); + pDataPointsMemFile += sizeof(DimensionType); + + Initialize(R, C, (T*)pDataPointsMemFile, false); + std::cout << "Load " << name << " (" << R << ", " << C << ") Finish!" << std::endl; + return true; + } + + bool Refine(const std::vector& indices, std::ostream& output) + { + SizeType R = (SizeType)(indices.size()); + output.write((char*)&R, sizeof(SizeType)); + output.write((char*)&cols, sizeof(DimensionType)); + + for (SizeType i = 0; i < R; i++) { + output.write((char*)At(indices[i]), sizeof(T) * cols); + } + std::cout << "Save Refine " << name << " (" << R << ", " << cols << ") Finish!" << std::endl; + return true; + } + + bool Refine(const std::vector& indices, std::string sDataPointsFileName) + { + std::cout << "Save Refine " << name << " To " << sDataPointsFileName << std::endl; + std::ofstream output(sDataPointsFileName, std::ios::binary); + if (!output.is_open()) return false; + Refine(indices, output); + output.close(); + return true; + } + }; + } +} + +#endif // _SPTAG_COMMON_DATASET_H_ diff --git a/core/src/index/thirdparty/SPTAG/AnnService/inc/Core/Common/DistanceUtils.h b/core/src/index/thirdparty/SPTAG/AnnService/inc/Core/Common/DistanceUtils.h new file mode 100644 index 0000000000..8e1d349245 --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/AnnService/inc/Core/Common/DistanceUtils.h @@ -0,0 +1,610 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifndef _SPTAG_COMMON_DISTANCEUTILS_H_ +#define _SPTAG_COMMON_DISTANCEUTILS_H_ + +#include +#include + +#include "CommonUtils.h" + +#define SSE + +#ifndef _MSC_VER +#define DIFF128 diff128 +#define DIFF256 diff256 +#else +#define DIFF128 diff128.m128_f32 +#define DIFF256 diff256.m256_f32 +#endif + +namespace SPTAG +{ + namespace COMMON + { + class DistanceUtils + { + public: + static inline __m128 _mm_mul_epi8(__m128i X, __m128i Y) + { + __m128i zero = _mm_setzero_si128(); + + __m128i sign_x = _mm_cmplt_epi8(X, zero); + __m128i sign_y = _mm_cmplt_epi8(Y, zero); + + __m128i xlo = _mm_unpacklo_epi8(X, sign_x); + __m128i xhi = _mm_unpackhi_epi8(X, sign_x); + __m128i ylo = _mm_unpacklo_epi8(Y, sign_y); + __m128i yhi = _mm_unpackhi_epi8(Y, sign_y); + + return _mm_cvtepi32_ps(_mm_add_epi32(_mm_madd_epi16(xlo, ylo), _mm_madd_epi16(xhi, yhi))); + } + + static inline __m128 _mm_sqdf_epi8(__m128i X, __m128i Y) + { + __m128i zero = _mm_setzero_si128(); + + __m128i sign_x = _mm_cmplt_epi8(X, zero); + __m128i sign_y = _mm_cmplt_epi8(Y, zero); + + __m128i xlo = _mm_unpacklo_epi8(X, sign_x); + __m128i xhi = _mm_unpackhi_epi8(X, sign_x); + __m128i ylo = _mm_unpacklo_epi8(Y, sign_y); + __m128i yhi = _mm_unpackhi_epi8(Y, sign_y); + + __m128i dlo = _mm_sub_epi16(xlo, ylo); + __m128i dhi = _mm_sub_epi16(xhi, yhi); + + return _mm_cvtepi32_ps(_mm_add_epi32(_mm_madd_epi16(dlo, dlo), _mm_madd_epi16(dhi, dhi))); + } + + static inline __m128 _mm_mul_epu8(__m128i X, __m128i Y) + { + __m128i zero = _mm_setzero_si128(); + + __m128i xlo = _mm_unpacklo_epi8(X, zero); + __m128i xhi = _mm_unpackhi_epi8(X, zero); + __m128i ylo = _mm_unpacklo_epi8(Y, zero); + __m128i yhi = _mm_unpackhi_epi8(Y, zero); + + return _mm_cvtepi32_ps(_mm_add_epi32(_mm_madd_epi16(xlo, ylo), _mm_madd_epi16(xhi, yhi))); + } + + static inline __m128 _mm_sqdf_epu8(__m128i X, __m128i Y) + { + __m128i zero = _mm_setzero_si128(); + + __m128i xlo = _mm_unpacklo_epi8(X, zero); + __m128i xhi = _mm_unpackhi_epi8(X, zero); + __m128i ylo = _mm_unpacklo_epi8(Y, zero); + __m128i yhi = _mm_unpackhi_epi8(Y, zero); + + __m128i dlo = _mm_sub_epi16(xlo, ylo); + __m128i dhi = _mm_sub_epi16(xhi, yhi); + + return _mm_cvtepi32_ps(_mm_add_epi32(_mm_madd_epi16(dlo, dlo), _mm_madd_epi16(dhi, dhi))); + } + + static inline __m128 _mm_mul_epi16(__m128i X, __m128i Y) + { + return _mm_cvtepi32_ps(_mm_madd_epi16(X, Y)); + } + + static inline __m128 _mm_sqdf_epi16(__m128i X, __m128i Y) + { + __m128i zero = _mm_setzero_si128(); + + __m128i sign_x = _mm_cmplt_epi16(X, zero); + __m128i sign_y = _mm_cmplt_epi16(Y, zero); + + __m128i xlo = _mm_unpacklo_epi16(X, sign_x); + __m128i xhi = _mm_unpackhi_epi16(X, sign_x); + __m128i ylo = _mm_unpacklo_epi16(Y, sign_y); + __m128i yhi = _mm_unpackhi_epi16(Y, sign_y); + + __m128 dlo = _mm_cvtepi32_ps(_mm_sub_epi32(xlo, ylo)); + __m128 dhi = _mm_cvtepi32_ps(_mm_sub_epi32(xhi, yhi)); + + return _mm_add_ps(_mm_mul_ps(dlo, dlo), _mm_mul_ps(dhi, dhi)); + } + static inline __m128 _mm_sqdf_ps(__m128 X, __m128 Y) + { + __m128 d = _mm_sub_ps(X, Y); + return _mm_mul_ps(d, d); + } +#if defined(AVX) + static inline __m256 _mm256_mul_epi8(__m256i X, __m256i Y) + { + __m256i zero = _mm256_setzero_si256(); + + __m256i sign_x = _mm256_cmpgt_epi8(zero, X); + __m256i sign_y = _mm256_cmpgt_epi8(zero, Y); + + __m256i xlo = _mm256_unpacklo_epi8(X, sign_x); + __m256i xhi = _mm256_unpackhi_epi8(X, sign_x); + __m256i ylo = _mm256_unpacklo_epi8(Y, sign_y); + __m256i yhi = _mm256_unpackhi_epi8(Y, sign_y); + + return _mm256_cvtepi32_ps(_mm256_add_epi32(_mm256_madd_epi16(xlo, ylo), _mm256_madd_epi16(xhi, yhi))); + } + static inline __m256 _mm256_sqdf_epi8(__m256i X, __m256i Y) + { + __m256i zero = _mm256_setzero_si256(); + + __m256i sign_x = _mm256_cmpgt_epi8(zero, X); + __m256i sign_y = _mm256_cmpgt_epi8(zero, Y); + + __m256i xlo = _mm256_unpacklo_epi8(X, sign_x); + __m256i xhi = _mm256_unpackhi_epi8(X, sign_x); + __m256i ylo = _mm256_unpacklo_epi8(Y, sign_y); + __m256i yhi = _mm256_unpackhi_epi8(Y, sign_y); + + __m256i dlo = _mm256_sub_epi16(xlo, ylo); + __m256i dhi = _mm256_sub_epi16(xhi, yhi); + + return _mm256_cvtepi32_ps(_mm256_add_epi32(_mm256_madd_epi16(dlo, dlo), _mm256_madd_epi16(dhi, dhi))); + } + static inline __m256 _mm256_mul_epu8(__m256i X, __m256i Y) + { + __m256i zero = _mm256_setzero_si256(); + + __m256i xlo = _mm256_unpacklo_epi8(X, zero); + __m256i xhi = _mm256_unpackhi_epi8(X, zero); + __m256i ylo = _mm256_unpacklo_epi8(Y, zero); + __m256i yhi = _mm256_unpackhi_epi8(Y, zero); + + return _mm256_cvtepi32_ps(_mm256_add_epi32(_mm256_madd_epi16(xlo, ylo), _mm256_madd_epi16(xhi, yhi))); + } + static inline __m256 _mm256_sqdf_epu8(__m256i X, __m256i Y) + { + __m256i zero = _mm256_setzero_si256(); + + __m256i xlo = _mm256_unpacklo_epi8(X, zero); + __m256i xhi = _mm256_unpackhi_epi8(X, zero); + __m256i ylo = _mm256_unpacklo_epi8(Y, zero); + __m256i yhi = _mm256_unpackhi_epi8(Y, zero); + + __m256i dlo = _mm256_sub_epi16(xlo, ylo); + __m256i dhi = _mm256_sub_epi16(xhi, yhi); + + return _mm256_cvtepi32_ps(_mm256_add_epi32(_mm256_madd_epi16(dlo, dlo), _mm256_madd_epi16(dhi, dhi))); + } + static inline __m256 _mm256_mul_epi16(__m256i X, __m256i Y) + { + return _mm256_cvtepi32_ps(_mm256_madd_epi16(X, Y)); + } + static inline __m256 _mm256_sqdf_epi16(__m256i X, __m256i Y) + { + __m256i zero = _mm256_setzero_si256(); + + __m256i sign_x = _mm256_cmpgt_epi16(zero, X); + __m256i sign_y = _mm256_cmpgt_epi16(zero, Y); + + __m256i xlo = _mm256_unpacklo_epi16(X, sign_x); + __m256i xhi = _mm256_unpackhi_epi16(X, sign_x); + __m256i ylo = _mm256_unpacklo_epi16(Y, sign_y); + __m256i yhi = _mm256_unpackhi_epi16(Y, sign_y); + + __m256 dlo = _mm256_cvtepi32_ps(_mm256_sub_epi32(xlo, ylo)); + __m256 dhi = _mm256_cvtepi32_ps(_mm256_sub_epi32(xhi, yhi)); + + return _mm256_add_ps(_mm256_mul_ps(dlo, dlo), _mm256_mul_ps(dhi, dhi)); + } + static inline __m256 _mm256_sqdf_ps(__m256 X, __m256 Y) + { + __m256 d = _mm256_sub_ps(X, Y); + return _mm256_mul_ps(d, d); + } +#endif +/* + template + static float ComputeL2Distance(const T *pX, const T *pY, DimensionType length) + { + float diff = 0; + const T* pEnd1 = pX + length; + while (pX < pEnd1) { + float c1 = ((float)(*pX++) - (float)(*pY++)); diff += c1 * c1; + } + return diff; + } +*/ +#define REPEAT(type, ctype, delta, load, exec, acc, result) \ + { \ + type c1 = load((ctype *)(pX)); \ + type c2 = load((ctype *)(pY)); \ + pX += delta; pY += delta; \ + result = acc(result, exec(c1, c2)); \ + } \ + + static float ComputeL2Distance(const std::int8_t *pX, const std::int8_t *pY, DimensionType length) + { + const std::int8_t* pEnd32 = pX + ((length >> 5) << 5); + const std::int8_t* pEnd16 = pX + ((length >> 4) << 4); + const std::int8_t* pEnd4 = pX + ((length >> 2) << 2); + const std::int8_t* pEnd1 = pX + length; +#if defined(SSE) + __m128 diff128 = _mm_setzero_ps(); + while (pX < pEnd32) { + REPEAT(__m128i, __m128i, 16, _mm_loadu_si128, _mm_sqdf_epi8, _mm_add_ps, diff128) + REPEAT(__m128i, __m128i, 16, _mm_loadu_si128, _mm_sqdf_epi8, _mm_add_ps, diff128) + } + while (pX < pEnd16) { + REPEAT(__m128i, __m128i, 16, _mm_loadu_si128, _mm_sqdf_epi8, _mm_add_ps, diff128) + } + float diff = DIFF128[0] + DIFF128[1] + DIFF128[2] + DIFF128[3]; +#elif defined(AVX) + __m256 diff256 = _mm256_setzero_ps(); + while (pX < pEnd32) { + REPEAT(__m256i, __m256i, 32, _mm256_loadu_si256, _mm256_sqdf_epi8, _mm256_add_ps, diff256) + } + __m128 diff128 = _mm_add_ps(_mm256_castps256_ps128(diff256), _mm256_extractf128_ps(diff256, 1)); + while (pX < pEnd16) { + REPEAT(__m128i, __m128i, 16, _mm_loadu_si128, _mm_sqdf_epi8, _mm_add_ps, diff128) + } + float diff = DIFF128[0] + DIFF128[1] + DIFF128[2] + DIFF128[3]; +#else + float diff = 0; +#endif + while (pX < pEnd4) { + float c1 = ((float)(*pX++) - (float)(*pY++)); diff += c1 * c1; + c1 = ((float)(*pX++) - (float)(*pY++)); diff += c1 * c1; + c1 = ((float)(*pX++) - (float)(*pY++)); diff += c1 * c1; + c1 = ((float)(*pX++) - (float)(*pY++)); diff += c1 * c1; + } + while (pX < pEnd1) { + float c1 = ((float)(*pX++) - (float)(*pY++)); diff += c1 * c1; + } + return diff; + } + + static float ComputeL2Distance(const std::uint8_t *pX, const std::uint8_t *pY, DimensionType length) + { + const std::uint8_t* pEnd32 = pX + ((length >> 5) << 5); + const std::uint8_t* pEnd16 = pX + ((length >> 4) << 4); + const std::uint8_t* pEnd4 = pX + ((length >> 2) << 2); + const std::uint8_t* pEnd1 = pX + length; +#if defined(SSE) + __m128 diff128 = _mm_setzero_ps(); + while (pX < pEnd32) { + REPEAT(__m128i, __m128i, 16, _mm_loadu_si128, _mm_sqdf_epu8, _mm_add_ps, diff128) + REPEAT(__m128i, __m128i, 16, _mm_loadu_si128, _mm_sqdf_epu8, _mm_add_ps, diff128) + } + while (pX < pEnd16) { + REPEAT(__m128i, __m128i, 16, _mm_loadu_si128, _mm_sqdf_epu8, _mm_add_ps, diff128) + } + float diff = DIFF128[0] + DIFF128[1] + DIFF128[2] + DIFF128[3]; +#elif defined(AVX) + __m256 diff256 = _mm256_setzero_ps(); + while (pX < pEnd32) { + REPEAT(__m256i, __m256i, 32, _mm256_loadu_si256, _mm256_sqdf_epu8, _mm256_add_ps, diff256) + } + __m128 diff128 = _mm_add_ps(_mm256_castps256_ps128(diff256), _mm256_extractf128_ps(diff256, 1)); + while (pX < pEnd16) { + REPEAT(__m128i, __m128i, 16, _mm_loadu_si128, _mm_sqdf_epu8, _mm_add_ps, diff128) + } + float diff = DIFF128[0] + DIFF128[1] + DIFF128[2] + DIFF128[3]; +#else + float diff = 0; +#endif + while (pX < pEnd4) { + float c1 = ((float)(*pX++) - (float)(*pY++)); diff += c1 * c1; + c1 = ((float)(*pX++) - (float)(*pY++)); diff += c1 * c1; + c1 = ((float)(*pX++) - (float)(*pY++)); diff += c1 * c1; + c1 = ((float)(*pX++) - (float)(*pY++)); diff += c1 * c1; + } + while (pX < pEnd1) { + float c1 = ((float)(*pX++) - (float)(*pY++)); diff += c1 * c1; + } + return diff; + } + + static float ComputeL2Distance(const std::int16_t *pX, const std::int16_t *pY, DimensionType length) + { + const std::int16_t* pEnd16 = pX + ((length >> 4) << 4); + const std::int16_t* pEnd8 = pX + ((length >> 3) << 3); + const std::int16_t* pEnd4 = pX + ((length >> 2) << 2); + const std::int16_t* pEnd1 = pX + length; +#if defined(SSE) + __m128 diff128 = _mm_setzero_ps(); + while (pX < pEnd16) { + REPEAT(__m128i, __m128i, 8, _mm_loadu_si128, _mm_sqdf_epi16, _mm_add_ps, diff128) + REPEAT(__m128i, __m128i, 8, _mm_loadu_si128, _mm_sqdf_epi16, _mm_add_ps, diff128) + } + while (pX < pEnd8) { + REPEAT(__m128i, __m128i, 8, _mm_loadu_si128, _mm_sqdf_epi16, _mm_add_ps, diff128) + } + float diff = DIFF128[0] + DIFF128[1] + DIFF128[2] + DIFF128[3]; +#elif defined(AVX) + __m256 diff256 = _mm256_setzero_ps(); + while (pX < pEnd16) { + REPEAT(__m256i, __m256i, 16, _mm256_loadu_si256, _mm256_sqdf_epi16, _mm256_add_ps, diff256) + } + __m128 diff128 = _mm_add_ps(_mm256_castps256_ps128(diff256), _mm256_extractf128_ps(diff256, 1)); + while (pX < pEnd8) { + REPEAT(__m128i, __m128i, 8, _mm_loadu_si128, _mm_sqdf_epi16, _mm_add_ps, diff128) + } + float diff = DIFF128[0] + DIFF128[1] + DIFF128[2] + DIFF128[3]; +#else + float diff = 0; +#endif + while (pX < pEnd4) { + float c1 = ((float)(*pX++) - (float)(*pY++)); diff += c1 * c1; + c1 = ((float)(*pX++) - (float)(*pY++)); diff += c1 * c1; + c1 = ((float)(*pX++) - (float)(*pY++)); diff += c1 * c1; + c1 = ((float)(*pX++) - (float)(*pY++)); diff += c1 * c1; + } + + while (pX < pEnd1) { + float c1 = ((float)(*pX++) - (float)(*pY++)); diff += c1 * c1; + } + return diff; + } + + static float ComputeL2Distance(const float *pX, const float *pY, DimensionType length) + { + const float* pEnd16 = pX + ((length >> 4) << 4); + const float* pEnd4 = pX + ((length >> 2) << 2); + const float* pEnd1 = pX + length; +#if defined(SSE) + __m128 diff128 = _mm_setzero_ps(); + while (pX < pEnd16) + { + REPEAT(__m128, const float, 4, _mm_loadu_ps, _mm_sqdf_ps, _mm_add_ps, diff128) + REPEAT(__m128, const float, 4, _mm_loadu_ps, _mm_sqdf_ps, _mm_add_ps, diff128) + REPEAT(__m128, const float, 4, _mm_loadu_ps, _mm_sqdf_ps, _mm_add_ps, diff128) + REPEAT(__m128, const float, 4, _mm_loadu_ps, _mm_sqdf_ps, _mm_add_ps, diff128) + } + while (pX < pEnd4) + { + REPEAT(__m128, const float, 4, _mm_loadu_ps, _mm_sqdf_ps, _mm_add_ps, diff128) + } + float diff = DIFF128[0] + DIFF128[1] + DIFF128[2] + DIFF128[3]; +#elif defined(AVX) + __m256 diff256 = _mm256_setzero_ps(); + while (pX < pEnd16) + { + REPEAT(__m256, const float, 8, _mm256_loadu_ps, _mm256_sqdf_ps, _mm256_add_ps, diff256) + REPEAT(__m256, const float, 8, _mm256_loadu_ps, _mm256_sqdf_ps, _mm256_add_ps, diff256) + } + __m128 diff128 = _mm_add_ps(_mm256_castps256_ps128(diff256), _mm256_extractf128_ps(diff256, 1)); + while (pX < pEnd4) + { + REPEAT(__m128, const float, 4, _mm_loadu_ps, _mm_sqdf_ps, _mm_add_ps, diff128) + } + float diff = DIFF128[0] + DIFF128[1] + DIFF128[2] + DIFF128[3]; +#else + float diff = 0; + while (pX < pEnd4) { + float c1 = (*pX++) - (*pY++); diff += c1 * c1; + c1 = (*pX++) - (*pY++); diff += c1 * c1; + c1 = (*pX++) - (*pY++); diff += c1 * c1; + c1 = (*pX++) - (*pY++); diff += c1 * c1; + } +#endif + while (pX < pEnd1) { + float c1 = (*pX++) - (*pY++); diff += c1 * c1; + } + return diff; + } +/* + template + static float ComputeCosineDistance(const T *pX, const T *pY, DimensionType length) { + float diff = 0; + const T* pEnd1 = pX + length; + while (pX < pEnd1) diff += (*pX++) * (*pY++); + return 1 - diff; + } +*/ + static float ComputeCosineDistance(const std::int8_t *pX, const std::int8_t *pY, DimensionType length) { + const std::int8_t* pEnd32 = pX + ((length >> 5) << 5); + const std::int8_t* pEnd16 = pX + ((length >> 4) << 4); + const std::int8_t* pEnd4 = pX + ((length >> 2) << 2); + const std::int8_t* pEnd1 = pX + length; +#if defined(SSE) + + __m128 diff128 = _mm_setzero_ps(); + while (pX < pEnd32) { + REPEAT(__m128i, __m128i, 16, _mm_loadu_si128, _mm_mul_epi8, _mm_add_ps, diff128) + REPEAT(__m128i, __m128i, 16, _mm_loadu_si128, _mm_mul_epi8, _mm_add_ps, diff128) + } + while (pX < pEnd16) { + REPEAT(__m128i, __m128i, 16, _mm_loadu_si128, _mm_mul_epi8, _mm_add_ps, diff128) + } + float diff = DIFF128[0] + DIFF128[1] + DIFF128[2] + DIFF128[3]; +#elif defined(AVX) + __m256 diff256 = _mm256_setzero_ps(); + while (pX < pEnd32) { + REPEAT(__m256i, __m256i, 32, _mm256_loadu_si256, _mm256_mul_epi8, _mm256_add_ps, diff256) + } + __m128 diff128 = _mm_add_ps(_mm256_castps256_ps128(diff256), _mm256_extractf128_ps(diff256, 1)); + while (pX < pEnd16) { + REPEAT(__m128i, __m128i, 16, _mm_loadu_si128, _mm_mul_epi8, _mm_add_ps, diff128) + } + float diff = DIFF128[0] + DIFF128[1] + DIFF128[2] + DIFF128[3]; +#else + float diff = 0; +#endif + while (pX < pEnd4) + { + float c1 = ((float)(*pX++) * (float)(*pY++)); diff += c1; + c1 = ((float)(*pX++) * (float)(*pY++)); diff += c1; + c1 = ((float)(*pX++) * (float)(*pY++)); diff += c1; + c1 = ((float)(*pX++) * (float)(*pY++)); diff += c1; + } + while (pX < pEnd1) diff += ((float)(*pX++) * (float)(*pY++)); + return 16129 - diff; + } + + static float ComputeCosineDistance(const std::uint8_t *pX, const std::uint8_t *pY, DimensionType length) { + const std::uint8_t* pEnd32 = pX + ((length >> 5) << 5); + const std::uint8_t* pEnd16 = pX + ((length >> 4) << 4); + const std::uint8_t* pEnd4 = pX + ((length >> 2) << 2); + const std::uint8_t* pEnd1 = pX + length; +#if defined(SSE) + + __m128 diff128 = _mm_setzero_ps(); + while (pX < pEnd32) { + REPEAT(__m128i, __m128i, 16, _mm_loadu_si128, _mm_mul_epu8, _mm_add_ps, diff128) + REPEAT(__m128i, __m128i, 16, _mm_loadu_si128, _mm_mul_epu8, _mm_add_ps, diff128) + } + while (pX < pEnd16) { + REPEAT(__m128i, __m128i, 16, _mm_loadu_si128, _mm_mul_epu8, _mm_add_ps, diff128) + } + float diff = DIFF128[0] + DIFF128[1] + DIFF128[2] + DIFF128[3]; +#elif defined(AVX) + __m256 diff256 = _mm256_setzero_ps(); + while (pX < pEnd32) { + REPEAT(__m256i, __m256i, 32, _mm256_loadu_si256, _mm256_mul_epu8, _mm256_add_ps, diff256) + } + __m128 diff128 = _mm_add_ps(_mm256_castps256_ps128(diff256), _mm256_extractf128_ps(diff256, 1)); + while (pX < pEnd16) { + REPEAT(__m128i, __m128i, 16, _mm_loadu_si128, _mm_mul_epu8, _mm_add_ps, diff128) + } + float diff = DIFF128[0] + DIFF128[1] + DIFF128[2] + DIFF128[3]; +#else + float diff = 0; +#endif + while (pX < pEnd4) + { + float c1 = ((float)(*pX++) * (float)(*pY++)); diff += c1; + c1 = ((float)(*pX++) * (float)(*pY++)); diff += c1; + c1 = ((float)(*pX++) * (float)(*pY++)); diff += c1; + c1 = ((float)(*pX++) * (float)(*pY++)); diff += c1; + } + while (pX < pEnd1) diff += ((float)(*pX++) * (float)(*pY++)); + return 65025 - diff; + } + + static float ComputeCosineDistance(const std::int16_t *pX, const std::int16_t *pY, DimensionType length) { + const std::int16_t* pEnd16 = pX + ((length >> 4) << 4); + const std::int16_t* pEnd8 = pX + ((length >> 3) << 3); + const std::int16_t* pEnd4 = pX + ((length >> 2) << 2); + const std::int16_t* pEnd1 = pX + length; +#if defined(SSE) + __m128 diff128 = _mm_setzero_ps(); + while (pX < pEnd16) { + REPEAT(__m128i, __m128i, 8, _mm_loadu_si128, _mm_mul_epi16, _mm_add_ps, diff128) + REPEAT(__m128i, __m128i, 8, _mm_loadu_si128, _mm_mul_epi16, _mm_add_ps, diff128) + } + while (pX < pEnd8) { + REPEAT(__m128i, __m128i, 8, _mm_loadu_si128, _mm_mul_epi16, _mm_add_ps, diff128) + } + float diff = DIFF128[0] + DIFF128[1] + DIFF128[2] + DIFF128[3]; + +#elif defined(AVX) + __m256 diff256 = _mm256_setzero_ps(); + while (pX < pEnd16) { + REPEAT(__m256i, __m256i, 16, _mm256_loadu_si256, _mm256_mul_epi16, _mm256_add_ps, diff256) + } + __m128 diff128 = _mm_add_ps(_mm256_castps256_ps128(diff256), _mm256_extractf128_ps(diff256, 1)); + while (pX < pEnd8) { + REPEAT(__m128i, __m128i, 8, _mm_loadu_si128, _mm_mul_epi16, _mm_add_ps, diff128) + } + float diff = DIFF128[0] + DIFF128[1] + DIFF128[2] + DIFF128[3]; +#else + float diff = 0; +#endif + while (pX < pEnd4) + { + float c1 = ((float)(*pX++) * (float)(*pY++)); diff += c1; + c1 = ((float)(*pX++) * (float)(*pY++)); diff += c1; + c1 = ((float)(*pX++) * (float)(*pY++)); diff += c1; + c1 = ((float)(*pX++) * (float)(*pY++)); diff += c1; + } + + while (pX < pEnd1) diff += ((float)(*pX++) * (float)(*pY++)); + return 1073676289 - diff; + } + + static float ComputeCosineDistance(const float *pX, const float *pY, DimensionType length) { + const float* pEnd16 = pX + ((length >> 4) << 4); + const float* pEnd4 = pX + ((length >> 2) << 2); + const float* pEnd1 = pX + length; +#if defined(SSE) + __m128 diff128 = _mm_setzero_ps(); + while (pX < pEnd16) + { + REPEAT(__m128, const float, 4, _mm_loadu_ps, _mm_mul_ps, _mm_add_ps, diff128) + REPEAT(__m128, const float, 4, _mm_loadu_ps, _mm_mul_ps, _mm_add_ps, diff128) + REPEAT(__m128, const float, 4, _mm_loadu_ps, _mm_mul_ps, _mm_add_ps, diff128) + REPEAT(__m128, const float, 4, _mm_loadu_ps, _mm_mul_ps, _mm_add_ps, diff128) + } + while (pX < pEnd4) + { + REPEAT(__m128, const float, 4, _mm_loadu_ps, _mm_mul_ps, _mm_add_ps, diff128) + } + float diff = DIFF128[0] + DIFF128[1] + DIFF128[2] + DIFF128[3]; + +#elif defined(AVX) + __m256 diff256 = _mm256_setzero_ps(); + while (pX < pEnd16) + { + REPEAT(__m256, const float, 8, _mm256_loadu_ps, _mm256_mul_ps, _mm256_add_ps, diff256) + REPEAT(__m256, const float, 8, _mm256_loadu_ps, _mm256_mul_ps, _mm256_add_ps, diff256) + } + __m128 diff128 = _mm_add_ps(_mm256_castps256_ps128(diff256), _mm256_extractf128_ps(diff256, 1)); + while (pX < pEnd4) + { + REPEAT(__m128, const float, 4, _mm_loadu_ps, _mm_mul_ps, _mm_add_ps, diff128) + } + float diff = DIFF128[0] + DIFF128[1] + DIFF128[2] + DIFF128[3]; +#else + float diff = 0; + while (pX < pEnd4) + { + float c1 = (*pX++) * (*pY++); diff += c1; + c1 = (*pX++) * (*pY++); diff += c1; + c1 = (*pX++) * (*pY++); diff += c1; + c1 = (*pX++) * (*pY++); diff += c1; + } +#endif + while (pX < pEnd1) diff += (*pX++) * (*pY++); + return 1 - diff; + } + + template + static inline float ComputeDistance(const T *p1, const T *p2, DimensionType length, SPTAG::DistCalcMethod distCalcMethod) + { + if (distCalcMethod == SPTAG::DistCalcMethod::L2) + return ComputeL2Distance(p1, p2, length); + + return ComputeCosineDistance(p1, p2, length); + } + + static inline float ConvertCosineSimilarityToDistance(float cs) + { + // Cosine similarity is in [-1, 1], the higher the value, the closer are the two vectors. + // However, the tree is built and searched based on "distance" between two vectors, that's >=0. The smaller the value, the closer are the two vectors. + // So we do a linear conversion from a cosine similarity to a distance value. + return 1 - cs; //[1, 3] + } + + static inline float ConvertDistanceBackToCosineSimilarity(float d) + { + return 1 - d; + } + }; + + + template + float (*DistanceCalcSelector(SPTAG::DistCalcMethod p_method)) (const T*, const T*, DimensionType) + { + switch (p_method) + { + case SPTAG::DistCalcMethod::Cosine: + return &(DistanceUtils::ComputeCosineDistance); + + case SPTAG::DistCalcMethod::L2: + return &(DistanceUtils::ComputeL2Distance); + + default: + break; + } + + return nullptr; + } + } +} + +#endif // _SPTAG_COMMON_DISTANCEUTILS_H_ diff --git a/core/src/index/thirdparty/SPTAG/AnnService/inc/Core/Common/FineGrainedLock.h b/core/src/index/thirdparty/SPTAG/AnnService/inc/Core/Common/FineGrainedLock.h new file mode 100644 index 0000000000..0de7ed8b36 --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/AnnService/inc/Core/Common/FineGrainedLock.h @@ -0,0 +1,51 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifndef _SPTAG_COMMON_FINEGRAINEDLOCK_H_ +#define _SPTAG_COMMON_FINEGRAINEDLOCK_H_ + +#include +#include +#include + +namespace SPTAG +{ + namespace COMMON + { + class FineGrainedLock { + public: + FineGrainedLock() {} + ~FineGrainedLock() { + for (size_t i = 0; i < locks.size(); i++) + locks[i].reset(); + locks.clear(); + } + + void resize(SizeType n) { + SizeType current = (SizeType)locks.size(); + if (current <= n) { + locks.resize(n); + for (SizeType i = current; i < n; i++) + locks[i].reset(new std::mutex); + } + else { + for (SizeType i = n; i < current; i++) + locks[i].reset(); + locks.resize(n); + } + } + + std::mutex& operator[](SizeType idx) { + return *locks[idx]; + } + + const std::mutex& operator[](SizeType idx) const { + return *locks[idx]; + } + private: + std::vector> locks; + }; + } +} + +#endif // _SPTAG_COMMON_FINEGRAINEDLOCK_H_ \ No newline at end of file diff --git a/core/src/index/thirdparty/SPTAG/AnnService/inc/Core/Common/Heap.h b/core/src/index/thirdparty/SPTAG/AnnService/inc/Core/Common/Heap.h new file mode 100644 index 0000000000..261aa498a6 --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/AnnService/inc/Core/Common/Heap.h @@ -0,0 +1,105 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifndef _SPTAG_COMMON_HEAP_H_ +#define _SPTAG_COMMON_HEAP_H_ + +namespace SPTAG +{ + namespace COMMON + { + + // priority queue + template + class Heap { + public: + Heap() : heap(nullptr), length(0), count(0) {} + + Heap(int size) { Resize(size); } + + void Resize(int size) + { + length = size; + heap.reset(new T[length + 1]); // heap uses 1-based indexing + count = 0; + lastlevel = int(pow(2.0, floor(log2(size)))); + } + ~Heap() {} + inline int size() { return count; } + inline bool empty() { return count == 0; } + inline void clear() { count = 0; } + inline T& Top() { if (count == 0) return heap[0]; else return heap[1]; } + + // Insert a new element in the heap. + void insert(T value) + { + /* If heap is full, then return without adding this element. */ + int loc; + if (count == length) { + int maxi = lastlevel; + for (int i = lastlevel + 1; i <= length; i++) + if (heap[maxi] < heap[i]) maxi = i; + if (value > heap[maxi]) return; + loc = maxi; + } + else { + loc = ++(count); /* Remember 1-based indexing. */ + } + /* Keep moving parents down until a place is found for this node. */ + int par = (loc >> 1); /* Location of parent. */ + while (par > 0 && value < heap[par]) { + heap[loc] = heap[par]; /* Move parent down to loc. */ + loc = par; + par >>= 1; + } + /* Insert the element at the determined location. */ + heap[loc] = value; + } + // Returns the node of minimum value from the heap (top of the heap). + bool pop(T& value) + { + if (count == 0) return false; + /* Switch first node with last. */ + value = heap[1]; + std::swap(heap[1], heap[count]); + count--; + heapify(); /* Move new node 1 to right position. */ + return true; /* Return old last node. */ + } + T& pop() + { + if (count == 0) return heap[0]; + /* Switch first node with last. */ + std::swap(heap[1], heap[count]); + count--; + heapify(); /* Move new node 1 to right position. */ + return heap[count + 1]; /* Return old last node. */ + } + private: + // Storage array for the heap. + // Type T must be comparable. + std::unique_ptr heap; + int length; + int count; // Number of element in the heap + int lastlevel; + // Reorganizes the heap (a parent is smaller than its children) starting with a node. + + void heapify() + { + int parent = 1, next = 2; + while (next < count) { + if (heap[next] > heap[next + 1]) next++; + if (heap[next] < heap[parent]) { + std::swap(heap[parent], heap[next]); + parent = next; + next <<= 1; + } + else break; + } + if (next == count && heap[next] < heap[parent]) std::swap(heap[parent], heap[next]); + } + }; + } +} + +#endif // _SPTAG_COMMON_HEAP_H_ diff --git a/core/src/index/thirdparty/SPTAG/AnnService/inc/Core/Common/KDTree.h b/core/src/index/thirdparty/SPTAG/AnnService/inc/Core/Common/KDTree.h new file mode 100644 index 0000000000..e46c133940 --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/AnnService/inc/Core/Common/KDTree.h @@ -0,0 +1,342 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifndef _SPTAG_COMMON_KDTREE_H_ +#define _SPTAG_COMMON_KDTREE_H_ + +#include +#include +#include + +#include "../VectorIndex.h" + +#include "CommonUtils.h" +#include "QueryResultSet.h" +#include "WorkSpace.h" + +#pragma warning(disable:4996) // 'fopen': This function or variable may be unsafe. Consider using fopen_s instead. To disable deprecation, use _CRT_SECURE_NO_WARNINGS. See online help for details. + +namespace SPTAG +{ + namespace COMMON + { + // node type for storing KDT + struct KDTNode + { + SizeType left; + SizeType right; + DimensionType split_dim; + float split_value; + }; + + class KDTree + { + public: + KDTree() : m_iTreeNumber(2), m_numTopDimensionKDTSplit(5), m_iSamples(1000) {} + + KDTree(KDTree& other) : m_iTreeNumber(other.m_iTreeNumber), + m_numTopDimensionKDTSplit(other.m_numTopDimensionKDTSplit), + m_iSamples(other.m_iSamples) {} + ~KDTree() {} + + inline const KDTNode& operator[](SizeType index) const { return m_pTreeRoots[index]; } + inline KDTNode& operator[](SizeType index) { return m_pTreeRoots[index]; } + + inline SizeType size() const { return (SizeType)m_pTreeRoots.size(); } + + template + void BuildTrees(VectorIndex* p_index, std::vector* indices = nullptr) + { + std::vector localindices; + if (indices == nullptr) { + localindices.resize(p_index->GetNumSamples()); + for (SizeType i = 0; i < p_index->GetNumSamples(); i++) localindices[i] = i; + } + else { + localindices.assign(indices->begin(), indices->end()); + } + + m_pTreeRoots.resize(m_iTreeNumber * localindices.size()); + m_pTreeStart.resize(m_iTreeNumber, 0); +#pragma omp parallel for + for (int i = 0; i < m_iTreeNumber; i++) + { + Sleep(i * 100); std::srand(clock()); + + std::vector pindices(localindices.begin(), localindices.end()); + std::random_shuffle(pindices.begin(), pindices.end()); + + m_pTreeStart[i] = i * (SizeType)pindices.size(); + std::cout << "Start to build KDTree " << i + 1 << std::endl; + SizeType iTreeSize = m_pTreeStart[i]; + DivideTree(p_index, pindices, 0, (SizeType)pindices.size() - 1, m_pTreeStart[i], iTreeSize); + std::cout << i + 1 << " KDTree built, " << iTreeSize - m_pTreeStart[i] << " " << pindices.size() << std::endl; + } + } + + inline std::uint64_t BufferSize() const + { + return sizeof(int) + sizeof(SizeType) * m_iTreeNumber + + sizeof(SizeType) + sizeof(KDTNode) * m_pTreeRoots.size(); + } + + bool SaveTrees(std::ostream& p_outstream) const + { + p_outstream.write((char*)&m_iTreeNumber, sizeof(int)); + p_outstream.write((char*)m_pTreeStart.data(), sizeof(SizeType) * m_iTreeNumber); + SizeType treeNodeSize = (SizeType)m_pTreeRoots.size(); + p_outstream.write((char*)&treeNodeSize, sizeof(SizeType)); + p_outstream.write((char*)m_pTreeRoots.data(), sizeof(KDTNode) * treeNodeSize); + std::cout << "Save KDT (" << m_iTreeNumber << "," << treeNodeSize << ") Finish!" << std::endl; + return true; + } + + bool SaveTrees(std::string sTreeFileName) const + { + std::cout << "Save KDT to " << sTreeFileName << std::endl; + std::ofstream output(sTreeFileName, std::ios::binary); + if (!output.is_open()) return false; + SaveTrees(output); + output.close(); + return true; + } + + bool LoadTrees(char* pKDTMemFile) + { + m_iTreeNumber = *((int*)pKDTMemFile); + pKDTMemFile += sizeof(int); + m_pTreeStart.resize(m_iTreeNumber); + memcpy(m_pTreeStart.data(), pKDTMemFile, sizeof(SizeType) * m_iTreeNumber); + pKDTMemFile += sizeof(SizeType)*m_iTreeNumber; + + SizeType treeNodeSize = *((SizeType*)pKDTMemFile); + pKDTMemFile += sizeof(SizeType); + m_pTreeRoots.resize(treeNodeSize); + memcpy(m_pTreeRoots.data(), pKDTMemFile, sizeof(KDTNode) * treeNodeSize); + std::cout << "Load KDT (" << m_iTreeNumber << "," << treeNodeSize << ") Finish!" << std::endl; + return true; + } + + bool LoadTrees(std::string sTreeFileName) + { + std::cout << "Load KDT From " << sTreeFileName << std::endl; + std::ifstream input(sTreeFileName, std::ios::binary); + if (!input.is_open()) return false; + + input.read((char*)&m_iTreeNumber, sizeof(int)); + m_pTreeStart.resize(m_iTreeNumber); + input.read((char*)m_pTreeStart.data(), sizeof(SizeType) * m_iTreeNumber); + + SizeType treeNodeSize; + input.read((char*)&treeNodeSize, sizeof(SizeType)); + m_pTreeRoots.resize(treeNodeSize); + input.read((char*)m_pTreeRoots.data(), sizeof(KDTNode) * treeNodeSize); + input.close(); + std::cout << "Load KDT (" << m_iTreeNumber << "," << treeNodeSize << ") Finish!" << std::endl; + return true; + } + + template + void InitSearchTrees(const VectorIndex* p_index, const COMMON::QueryResultSet &p_query, COMMON::WorkSpace &p_space, const int p_limits) const + { + for (int i = 0; i < m_iTreeNumber; i++) { + KDTSearch(p_index, p_query, p_space, m_pTreeStart[i], true, 0); + } + + while (!p_space.m_SPTQueue.empty() && p_space.m_iNumberOfCheckedLeaves < p_limits) + { + auto& tcell = p_space.m_SPTQueue.pop(); + if (p_query.worstDist() < tcell.distance) break; + KDTSearch(p_index, p_query, p_space, tcell.node, true, tcell.distance); + } + } + + template + void SearchTrees(const VectorIndex* p_index, const COMMON::QueryResultSet &p_query, COMMON::WorkSpace &p_space, const int p_limits) const + { + while (!p_space.m_SPTQueue.empty() && p_space.m_iNumberOfCheckedLeaves < p_limits) + { + auto& tcell = p_space.m_SPTQueue.pop(); + KDTSearch(p_index, p_query, p_space, tcell.node, false, tcell.distance); + } + } + + private: + + template + void KDTSearch(const VectorIndex* p_index, const COMMON::QueryResultSet &p_query, + COMMON::WorkSpace& p_space, const SizeType node, const bool isInit, const float distBound) const { + if (node < 0) + { + SizeType index = -node - 1; + if (index >= p_index->GetNumSamples()) return; +#ifdef PREFETCH + const char* data = (const char *)(p_index->GetSample(index)); + _mm_prefetch(data, _MM_HINT_T0); + _mm_prefetch(data + 64, _MM_HINT_T0); +#endif + if (p_space.CheckAndSet(index)) return; + + ++p_space.m_iNumberOfTreeCheckedLeaves; + ++p_space.m_iNumberOfCheckedLeaves; + p_space.m_NGQueue.insert(COMMON::HeapCell(index, p_index->ComputeDistance((const void*)p_query.GetTarget(), (const void*)data))); + return; + } + + auto& tnode = m_pTreeRoots[node]; + + float diff = (p_query.GetTarget())[tnode.split_dim] - tnode.split_value; + float distanceBound = distBound + diff * diff; + SizeType otherChild, bestChild; + if (diff < 0) + { + bestChild = tnode.left; + otherChild = tnode.right; + } + else + { + otherChild = tnode.left; + bestChild = tnode.right; + } + + if (!isInit || distanceBound < p_query.worstDist()) + { + p_space.m_SPTQueue.insert(COMMON::HeapCell(otherChild, distanceBound)); + } + KDTSearch(p_index, p_query, p_space, bestChild, isInit, distBound); + } + + + template + void DivideTree(VectorIndex* p_index, std::vector& indices, SizeType first, SizeType last, + SizeType index, SizeType &iTreeSize) { + ChooseDivision(p_index, m_pTreeRoots[index], indices, first, last); + SizeType i = Subdivide(p_index, m_pTreeRoots[index], indices, first, last); + if (i - 1 <= first) + { + m_pTreeRoots[index].left = -indices[first] - 1; + } + else + { + iTreeSize++; + m_pTreeRoots[index].left = iTreeSize; + DivideTree(p_index, indices, first, i - 1, iTreeSize, iTreeSize); + } + if (last == i) + { + m_pTreeRoots[index].right = -indices[last] - 1; + } + else + { + iTreeSize++; + m_pTreeRoots[index].right = iTreeSize; + DivideTree(p_index, indices, i, last, iTreeSize, iTreeSize); + } + } + + template + void ChooseDivision(VectorIndex* p_index, KDTNode& node, const std::vector& indices, const SizeType first, const SizeType last) + { + std::vector meanValues(p_index->GetFeatureDim(), 0); + std::vector varianceValues(p_index->GetFeatureDim(), 0); + SizeType end = min(first + m_iSamples, last); + SizeType count = end - first + 1; + // calculate the mean of each dimension + for (SizeType j = first; j <= end; j++) + { + const T* v = (const T*)p_index->GetSample(indices[j]); + for (DimensionType k = 0; k < p_index->GetFeatureDim(); k++) + { + meanValues[k] += v[k]; + } + } + for (DimensionType k = 0; k < p_index->GetFeatureDim(); k++) + { + meanValues[k] /= count; + } + // calculate the variance of each dimension + for (SizeType j = first; j <= end; j++) + { + const T* v = (const T*)p_index->GetSample(indices[j]); + for (DimensionType k = 0; k < p_index->GetFeatureDim(); k++) + { + float dist = v[k] - meanValues[k]; + varianceValues[k] += dist*dist; + } + } + // choose the split dimension as one of the dimension inside TOP_DIM maximum variance + node.split_dim = SelectDivisionDimension(varianceValues); + // determine the threshold + node.split_value = meanValues[node.split_dim]; + } + + DimensionType SelectDivisionDimension(const std::vector& varianceValues) const + { + // Record the top maximum variances + std::vector topind(m_numTopDimensionKDTSplit); + int num = 0; + // order the variances + for (DimensionType i = 0; i < (DimensionType)varianceValues.size(); i++) + { + if (num < m_numTopDimensionKDTSplit || varianceValues[i] > varianceValues[topind[num - 1]]) + { + if (num < m_numTopDimensionKDTSplit) + { + topind[num++] = i; + } + else + { + topind[num - 1] = i; + } + int j = num - 1; + // order the TOP_DIM variances + while (j > 0 && varianceValues[topind[j]] > varianceValues[topind[j - 1]]) + { + std::swap(topind[j], topind[j - 1]); + j--; + } + } + } + // randomly choose a dimension from TOP_DIM + return topind[COMMON::Utils::rand(num)]; + } + + template + SizeType Subdivide(VectorIndex* p_index, const KDTNode& node, std::vector& indices, const SizeType first, const SizeType last) const + { + SizeType i = first; + SizeType j = last; + // decide which child one point belongs + while (i <= j) + { + SizeType ind = indices[i]; + const T* v = (const T*)p_index->GetSample(ind); + float val = v[node.split_dim]; + if (val < node.split_value) + { + i++; + } + else + { + std::swap(indices[i], indices[j]); + j--; + } + } + // if all the points in the node are equal,equally split the node into 2 + if ((i == first) || (i == last + 1)) + { + i = (first + last + 1) / 2; + } + return i; + } + + private: + std::vector m_pTreeStart; + std::vector m_pTreeRoots; + + public: + int m_iTreeNumber, m_numTopDimensionKDTSplit, m_iSamples; + }; + } +} +#endif diff --git a/core/src/index/thirdparty/SPTAG/AnnService/inc/Core/Common/NeighborhoodGraph.h b/core/src/index/thirdparty/SPTAG/AnnService/inc/Core/Common/NeighborhoodGraph.h new file mode 100644 index 0000000000..ea47125c36 --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/AnnService/inc/Core/Common/NeighborhoodGraph.h @@ -0,0 +1,403 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifndef _SPTAG_COMMON_NG_H_ +#define _SPTAG_COMMON_NG_H_ + +#include "../VectorIndex.h" + +#include "CommonUtils.h" +#include "Dataset.h" +#include "FineGrainedLock.h" +#include "QueryResultSet.h" + +namespace SPTAG +{ + namespace COMMON + { + class NeighborhoodGraph + { + public: + NeighborhoodGraph(): m_iTPTNumber(32), + m_iTPTLeafSize(2000), + m_iSamples(1000), + m_numTopDimensionTPTSplit(5), + m_iNeighborhoodSize(32), + m_iNeighborhoodScale(2), + m_iCEFScale(2), + m_iRefineIter(0), + m_iCEF(1000), + m_iMaxCheckForRefineGraph(10000) + { + m_pNeighborhoodGraph.SetName("Graph"); + } + + ~NeighborhoodGraph() {} + + virtual void InsertNeighbors(VectorIndex* index, const SizeType node, SizeType insertNode, float insertDist) = 0; + + virtual void RebuildNeighbors(VectorIndex* index, const SizeType node, SizeType* nodes, const BasicResult* queryResults, const int numResults) = 0; + + virtual float GraphAccuracyEstimation(VectorIndex* index, const SizeType samples, const std::unordered_map* idmap = nullptr) = 0; + + template + void BuildGraph(VectorIndex* index, const std::unordered_map* idmap = nullptr) + { + std::cout << "build RNG graph!" << std::endl; + + m_iGraphSize = index->GetNumSamples(); + m_iNeighborhoodSize = m_iNeighborhoodSize * m_iNeighborhoodScale; + m_pNeighborhoodGraph.Initialize(m_iGraphSize, m_iNeighborhoodSize); + m_dataUpdateLock.resize(m_iGraphSize); + + if (m_iGraphSize < 1000) { + RefineGraph(index, idmap); + std::cout << "Build RNG Graph end!" << std::endl; + return; + } + + { + COMMON::Dataset NeighborhoodDists(m_iGraphSize, m_iNeighborhoodSize); + std::vector> TptreeDataIndices(m_iTPTNumber, std::vector(m_iGraphSize)); + std::vector>> TptreeLeafNodes(m_iTPTNumber, std::vector>()); + + for (SizeType i = 0; i < m_iGraphSize; i++) + for (DimensionType j = 0; j < m_iNeighborhoodSize; j++) + (NeighborhoodDists)[i][j] = MaxDist; + + std::cout << "Parallel TpTree Partition begin " << std::endl; +#pragma omp parallel for schedule(dynamic) + for (int i = 0; i < m_iTPTNumber; i++) + { + Sleep(i * 100); std::srand(clock()); + for (SizeType j = 0; j < m_iGraphSize; j++) TptreeDataIndices[i][j] = j; + std::random_shuffle(TptreeDataIndices[i].begin(), TptreeDataIndices[i].end()); + PartitionByTptree(index, TptreeDataIndices[i], 0, m_iGraphSize - 1, TptreeLeafNodes[i]); + std::cout << "Finish Getting Leaves for Tree " << i << std::endl; + } + std::cout << "Parallel TpTree Partition done" << std::endl; + + for (int i = 0; i < m_iTPTNumber; i++) + { +#pragma omp parallel for schedule(dynamic) + for (SizeType j = 0; j < (SizeType)TptreeLeafNodes[i].size(); j++) + { + SizeType start_index = TptreeLeafNodes[i][j].first; + SizeType end_index = TptreeLeafNodes[i][j].second; + if (omp_get_thread_num() == 0) std::cout << "\rProcessing Tree " << i << ' ' << j * 100 / TptreeLeafNodes[i].size() << '%'; + for (SizeType x = start_index; x < end_index; x++) + { + for (SizeType y = x + 1; y <= end_index; y++) + { + SizeType p1 = TptreeDataIndices[i][x]; + SizeType p2 = TptreeDataIndices[i][y]; + float dist = index->ComputeDistance(index->GetSample(p1), index->GetSample(p2)); + if (idmap != nullptr) { + p1 = (idmap->find(p1) == idmap->end()) ? p1 : idmap->at(p1); + p2 = (idmap->find(p2) == idmap->end()) ? p2 : idmap->at(p2); + } + COMMON::Utils::AddNeighbor(p2, dist, (m_pNeighborhoodGraph)[p1], (NeighborhoodDists)[p1], m_iNeighborhoodSize); + COMMON::Utils::AddNeighbor(p1, dist, (m_pNeighborhoodGraph)[p2], (NeighborhoodDists)[p2], m_iNeighborhoodSize); + } + } + } + TptreeDataIndices[i].clear(); + TptreeLeafNodes[i].clear(); + std::cout << std::endl; + } + TptreeDataIndices.clear(); + TptreeLeafNodes.clear(); + } + + if (m_iMaxCheckForRefineGraph > 0) { + RefineGraph(index, idmap); + } + } + + template + void RefineGraph(VectorIndex* index, const std::unordered_map* idmap = nullptr) + { + m_iCEF *= m_iCEFScale; + m_iMaxCheckForRefineGraph *= m_iCEFScale; + +#pragma omp parallel for schedule(dynamic) + for (SizeType i = 0; i < m_iGraphSize; i++) + { + RefineNode(index, i, false); + if (i % 1000 == 0) std::cout << "\rRefine 1 " << (i * 100 / m_iGraphSize) << "%"; + } + std::cout << "Refine RNG, graph acc:" << GraphAccuracyEstimation(index, 100, idmap) << std::endl; + + m_iCEF /= m_iCEFScale; + m_iMaxCheckForRefineGraph /= m_iCEFScale; + m_iNeighborhoodSize /= m_iNeighborhoodScale; + +#pragma omp parallel for schedule(dynamic) + for (SizeType i = 0; i < m_iGraphSize; i++) + { + RefineNode(index, i, false); + if (i % 1000 == 0) std::cout << "\rRefine 2 " << (i * 100 / m_iGraphSize) << "%"; + } + std::cout << "Refine RNG, graph acc:" << GraphAccuracyEstimation(index, 100, idmap) << std::endl; + + if (idmap != nullptr) { + for (auto iter = idmap->begin(); iter != idmap->end(); iter++) + if (iter->first < 0) + { + m_pNeighborhoodGraph[-1 - iter->first][m_iNeighborhoodSize - 1] = -2 - iter->second; + } + } + } + + template + ErrorCode RefineGraph(VectorIndex* index, std::vector& indices, std::vector& reverseIndices, + std::ostream& output, const std::unordered_map* idmap = nullptr) + { + SizeType R = (SizeType)indices.size(); + +#pragma omp parallel for schedule(dynamic) + for (SizeType i = 0; i < R; i++) + { + RefineNode(index, indices[i], false); + SizeType* nodes = m_pNeighborhoodGraph[indices[i]]; + for (DimensionType j = 0; j < m_iNeighborhoodSize; j++) + { + if (nodes[j] < 0) nodes[j] = -1; + else nodes[j] = reverseIndices[nodes[j]]; + } + if (idmap == nullptr || idmap->find(-1 - indices[i]) == idmap->end()) continue; + nodes[m_iNeighborhoodSize - 1] = -2 - idmap->at(-1 - indices[i]); + } + + m_pNeighborhoodGraph.Refine(indices, output); + return ErrorCode::Success; + } + + + template + void RefineNode(VectorIndex* index, const SizeType node, bool updateNeighbors) + { + COMMON::QueryResultSet query((const T*)index->GetSample(node), m_iCEF + 1); + index->SearchIndex(query); + RebuildNeighbors(index, node, m_pNeighborhoodGraph[node], query.GetResults(), m_iCEF + 1); + + if (updateNeighbors) { + // update neighbors + for (int j = 0; j <= m_iCEF; j++) + { + BasicResult* item = query.GetResult(j); + if (item->VID < 0) break; + if (item->VID == node) continue; + + std::lock_guard lock(m_dataUpdateLock[item->VID]); + InsertNeighbors(index, item->VID, node, item->Dist); + } + } + } + + template + void PartitionByTptree(VectorIndex* index, std::vector& indices, const SizeType first, const SizeType last, + std::vector> & leaves) + { + if (last - first <= m_iTPTLeafSize) + { + leaves.push_back(std::make_pair(first, last)); + } + else + { + std::vector Mean(index->GetFeatureDim(), 0); + + int iIteration = 100; + SizeType end = min(first + m_iSamples, last); + SizeType count = end - first + 1; + // calculate the mean of each dimension + for (SizeType j = first; j <= end; j++) + { + const T* v = (const T*)index->GetSample(indices[j]); + for (DimensionType k = 0; k < index->GetFeatureDim(); k++) + { + Mean[k] += v[k]; + } + } + for (DimensionType k = 0; k < index->GetFeatureDim(); k++) + { + Mean[k] /= count; + } + std::vector Variance; + Variance.reserve(index->GetFeatureDim()); + for (DimensionType j = 0; j < index->GetFeatureDim(); j++) + { + Variance.push_back(BasicResult(j, 0)); + } + // calculate the variance of each dimension + for (SizeType j = first; j <= end; j++) + { + const T* v = (const T*)index->GetSample(indices[j]); + for (DimensionType k = 0; k < index->GetFeatureDim(); k++) + { + float dist = v[k] - Mean[k]; + Variance[k].Dist += dist*dist; + } + } + std::sort(Variance.begin(), Variance.end(), COMMON::Compare); + std::vector indexs(m_numTopDimensionTPTSplit); + std::vector weight(m_numTopDimensionTPTSplit), bestweight(m_numTopDimensionTPTSplit); + float bestvariance = Variance[index->GetFeatureDim() - 1].Dist; + for (int i = 0; i < m_numTopDimensionTPTSplit; i++) + { + indexs[i] = Variance[index->GetFeatureDim() - 1 - i].VID; + bestweight[i] = 0; + } + bestweight[0] = 1; + float bestmean = Mean[indexs[0]]; + + std::vector Val(count); + for (int i = 0; i < iIteration; i++) + { + float sumweight = 0; + for (int j = 0; j < m_numTopDimensionTPTSplit; j++) + { + weight[j] = float(rand() % 10000) / 5000.0f - 1.0f; + sumweight += weight[j] * weight[j]; + } + sumweight = sqrt(sumweight); + for (int j = 0; j < m_numTopDimensionTPTSplit; j++) + { + weight[j] /= sumweight; + } + float mean = 0; + for (SizeType j = 0; j < count; j++) + { + Val[j] = 0; + const T* v = (const T*)index->GetSample(indices[first + j]); + for (int k = 0; k < m_numTopDimensionTPTSplit; k++) + { + Val[j] += weight[k] * v[indexs[k]]; + } + mean += Val[j]; + } + mean /= count; + float var = 0; + for (SizeType j = 0; j < count; j++) + { + float dist = Val[j] - mean; + var += dist * dist; + } + if (var > bestvariance) + { + bestvariance = var; + bestmean = mean; + for (int j = 0; j < m_numTopDimensionTPTSplit; j++) + { + bestweight[j] = weight[j]; + } + } + } + SizeType i = first; + SizeType j = last; + // decide which child one point belongs + while (i <= j) + { + float val = 0; + const T* v = (const T*)index->GetSample(indices[i]); + for (int k = 0; k < m_numTopDimensionTPTSplit; k++) + { + val += bestweight[k] * v[indexs[k]]; + } + if (val < bestmean) + { + i++; + } + else + { + std::swap(indices[i], indices[j]); + j--; + } + } + // if all the points in the node are equal,equally split the node into 2 + if ((i == first) || (i == last + 1)) + { + i = (first + last + 1) / 2; + } + + Mean.clear(); + Variance.clear(); + Val.clear(); + indexs.clear(); + weight.clear(); + bestweight.clear(); + + PartitionByTptree(index, indices, first, i - 1, leaves); + PartitionByTptree(index, indices, i, last, leaves); + } + } + + inline std::uint64_t BufferSize() const + { + return m_pNeighborhoodGraph.BufferSize(); + } + + bool LoadGraph(std::string sGraphFilename) + { + if (!m_pNeighborhoodGraph.Load(sGraphFilename)) return false; + + m_iGraphSize = m_pNeighborhoodGraph.R(); + m_iNeighborhoodSize = m_pNeighborhoodGraph.C(); + m_dataUpdateLock.resize(m_iGraphSize); + return true; + } + + bool LoadGraph(char* pGraphMemFile) + { + m_pNeighborhoodGraph.Load(pGraphMemFile); + + m_iGraphSize = m_pNeighborhoodGraph.R(); + m_iNeighborhoodSize = m_pNeighborhoodGraph.C(); + m_dataUpdateLock.resize(m_iGraphSize); + return true; + } + + bool SaveGraph(std::string sGraphFilename) const + { + return m_pNeighborhoodGraph.Save(sGraphFilename); + } + + bool SaveGraph(std::ostream& output) const + { + return m_pNeighborhoodGraph.Save(output); + } + + inline ErrorCode AddBatch(SizeType num) + { + ErrorCode ret = m_pNeighborhoodGraph.AddBatch(num); + if (ret != ErrorCode::Success) return ret; + + m_iGraphSize += num; + m_dataUpdateLock.resize(m_iGraphSize); + return ErrorCode::Success; + } + + inline SizeType* operator[](SizeType index) { return m_pNeighborhoodGraph[index]; } + + inline const SizeType* operator[](SizeType index) const { return m_pNeighborhoodGraph[index]; } + + inline void SetR(SizeType rows) { m_pNeighborhoodGraph.SetR(rows); m_iGraphSize = rows; m_dataUpdateLock.resize(m_iGraphSize); } + + inline SizeType R() const { return m_iGraphSize; } + + static std::shared_ptr CreateInstance(std::string type); + + protected: + // Graph structure + SizeType m_iGraphSize; + COMMON::Dataset m_pNeighborhoodGraph; + COMMON::FineGrainedLock m_dataUpdateLock; // protect one row of the graph + + public: + int m_iTPTNumber, m_iTPTLeafSize, m_iSamples, m_numTopDimensionTPTSplit; + DimensionType m_iNeighborhoodSize; + int m_iNeighborhoodScale, m_iCEFScale, m_iRefineIter, m_iCEF, m_iMaxCheckForRefineGraph; + }; + } +} +#endif diff --git a/core/src/index/thirdparty/SPTAG/AnnService/inc/Core/Common/QueryResultSet.h b/core/src/index/thirdparty/SPTAG/AnnService/inc/Core/Common/QueryResultSet.h new file mode 100644 index 0000000000..ff8fa14dfd --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/AnnService/inc/Core/Common/QueryResultSet.h @@ -0,0 +1,96 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifndef _SPTAG_COMMON_QUERYRESULTSET_H_ +#define _SPTAG_COMMON_QUERYRESULTSET_H_ + +#include "../SearchQuery.h" + +namespace SPTAG +{ +namespace COMMON +{ + +inline bool operator < (const BasicResult& lhs, const BasicResult& rhs) +{ + return ((lhs.Dist < rhs.Dist) || ((lhs.Dist == rhs.Dist) && (lhs.VID < rhs.VID))); +} + + +inline bool Compare(const BasicResult& lhs, const BasicResult& rhs) +{ + return ((lhs.Dist < rhs.Dist) || ((lhs.Dist == rhs.Dist) && (lhs.VID < rhs.VID))); +} + + +// Space to save temporary answer, similar with TopKCache +template +class QueryResultSet : public QueryResult +{ +public: + QueryResultSet(const T* _target, int _K) : QueryResult(_target, _K, false) + { + } + + QueryResultSet(const QueryResultSet& other) : QueryResult(other) + { + } + + inline void SetTarget(const T *p_target) + { + m_target = p_target; + } + + inline const T* GetTarget() const + { + return reinterpret_cast(m_target); + } + + inline float worstDist() const + { + return m_results[0].Dist; + } + + bool AddPoint(const SizeType index, float dist) + { + if (dist < m_results[0].Dist || (dist == m_results[0].Dist && index < m_results[0].VID)) + { + m_results[0].VID = index; + m_results[0].Dist = dist; + Heapify(m_resultNum); + return true; + } + return false; + } + + inline void SortResult() + { + for (int i = m_resultNum - 1; i >= 0; i--) + { + std::swap(m_results[0], m_results[i]); + Heapify(i); + } + } + +private: + void Heapify(int count) + { + int parent = 0, next = 1, maxidx = count - 1; + while (next < maxidx) + { + if (m_results[next] < m_results[next + 1]) next++; + if (m_results[parent] < m_results[next]) + { + std::swap(m_results[next], m_results[parent]); + parent = next; + next = (parent << 1) + 1; + } + else break; + } + if (next == maxidx && m_results[parent] < m_results[next]) std::swap(m_results[parent], m_results[next]); + } +}; +} +} + +#endif // _SPTAG_COMMON_QUERYRESULTSET_H_ diff --git a/core/src/index/thirdparty/SPTAG/AnnService/inc/Core/Common/RelativeNeighborhoodGraph.h b/core/src/index/thirdparty/SPTAG/AnnService/inc/Core/Common/RelativeNeighborhoodGraph.h new file mode 100644 index 0000000000..33ab01927b --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/AnnService/inc/Core/Common/RelativeNeighborhoodGraph.h @@ -0,0 +1,123 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifndef _SPTAG_COMMON_RNG_H_ +#define _SPTAG_COMMON_RNG_H_ + +#include "NeighborhoodGraph.h" + +namespace SPTAG +{ + namespace COMMON + { + class RelativeNeighborhoodGraph: public NeighborhoodGraph + { + public: + void RebuildNeighbors(VectorIndex* index, const SizeType node, SizeType* nodes, const BasicResult* queryResults, const int numResults) { + DimensionType count = 0; + for (int j = 0; j < numResults && count < m_iNeighborhoodSize; j++) { + const BasicResult& item = queryResults[j]; + if (item.VID < 0) break; + if (item.VID == node) continue; + + bool good = true; + for (DimensionType k = 0; k < count; k++) { + if (index->ComputeDistance(index->GetSample(nodes[k]), index->GetSample(item.VID)) <= item.Dist) { + good = false; + break; + } + } + if (good) nodes[count++] = item.VID; + } + for (DimensionType j = count; j < m_iNeighborhoodSize; j++) nodes[j] = -1; + } + + void InsertNeighbors(VectorIndex* index, const SizeType node, SizeType insertNode, float insertDist) + { + SizeType* nodes = m_pNeighborhoodGraph[node]; + for (DimensionType k = 0; k < m_iNeighborhoodSize; k++) + { + SizeType tmpNode = nodes[k]; + if (tmpNode < -1) continue; + + if (tmpNode < 0) + { + bool good = true; + for (DimensionType t = 0; t < k; t++) { + if (index->ComputeDistance(index->GetSample(insertNode), index->GetSample(nodes[t])) < insertDist) { + good = false; + break; + } + } + if (good) { + nodes[k] = insertNode; + } + break; + } + float tmpDist = index->ComputeDistance(index->GetSample(node), index->GetSample(tmpNode)); + if (insertDist < tmpDist || (insertDist == tmpDist && insertNode < tmpNode)) + { + bool good = true; + for (DimensionType t = 0; t < k; t++) { + if (index->ComputeDistance(index->GetSample(insertNode), index->GetSample(nodes[t])) < insertDist) { + good = false; + break; + } + } + if (good) { + nodes[k] = insertNode; + insertNode = tmpNode; + insertDist = tmpDist; + } + else { + break; + } + } + } + } + + float GraphAccuracyEstimation(VectorIndex* index, const SizeType samples, const std::unordered_map* idmap = nullptr) + { + DimensionType* correct = new DimensionType[samples]; + +#pragma omp parallel for schedule(dynamic) + for (SizeType i = 0; i < samples; i++) + { + SizeType x = COMMON::Utils::rand(m_iGraphSize); + //int x = i; + COMMON::QueryResultSet query(nullptr, m_iCEF); + for (SizeType y = 0; y < m_iGraphSize; y++) + { + if ((idmap != nullptr && idmap->find(y) != idmap->end())) continue; + float dist = index->ComputeDistance(index->GetSample(x), index->GetSample(y)); + query.AddPoint(y, dist); + } + query.SortResult(); + SizeType * exact_rng = new SizeType[m_iNeighborhoodSize]; + RebuildNeighbors(index, x, exact_rng, query.GetResults(), m_iCEF); + + correct[i] = 0; + for (DimensionType j = 0; j < m_iNeighborhoodSize; j++) { + if (exact_rng[j] == -1) { + correct[i] += m_iNeighborhoodSize - j; + break; + } + for (DimensionType k = 0; k < m_iNeighborhoodSize; k++) + if ((m_pNeighborhoodGraph)[x][k] == exact_rng[j]) { + correct[i]++; + break; + } + } + delete[] exact_rng; + } + float acc = 0; + for (SizeType i = 0; i < samples; i++) acc += float(correct[i]); + acc = acc / samples / m_iNeighborhoodSize; + delete[] correct; + return acc; + } + + }; + } +} +#endif \ No newline at end of file diff --git a/core/src/index/thirdparty/SPTAG/AnnService/inc/Core/Common/WorkSpace.h b/core/src/index/thirdparty/SPTAG/AnnService/inc/Core/Common/WorkSpace.h new file mode 100644 index 0000000000..c236d45a1c --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/AnnService/inc/Core/Common/WorkSpace.h @@ -0,0 +1,185 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifndef _SPTAG_COMMON_WORKSPACE_H_ +#define _SPTAG_COMMON_WORKSPACE_H_ + +#include "CommonUtils.h" +#include "Heap.h" + +namespace SPTAG +{ + namespace COMMON + { + // node type in the priority queue + struct HeapCell + { + SizeType node; + float distance; + + HeapCell(SizeType _node = -1, float _distance = MaxDist) : node(_node), distance(_distance) {} + + inline bool operator < (const HeapCell& rhs) + { + return distance < rhs.distance; + } + + inline bool operator > (const HeapCell& rhs) + { + return distance > rhs.distance; + } + }; + + class OptHashPosVector + { + protected: + // Max loop number in one hash block. + static const int m_maxLoop = 8; + + // Max pool size. + static const int m_poolSize = 8191; + + // Could we use the second hash block. + bool m_secondHash; + + // Record 2 hash tables. + // [0~m_poolSize + 1) is the first block. + // [m_poolSize + 1, 2*(m_poolSize + 1)) is the second block; + SizeType m_hashTable[(m_poolSize + 1) * 2]; + + + inline unsigned hash_func2(unsigned idx, int loop) + { + return (idx + loop) & m_poolSize; + } + + + inline unsigned hash_func(unsigned idx) + { + return ((unsigned)(idx * 99991) + _rotl(idx, 2) + 101) & m_poolSize; + } + + public: + OptHashPosVector() {} + + ~OptHashPosVector() {} + + + void Init(SizeType size) + { + m_secondHash = true; + clear(); + } + + void clear() + { + if (!m_secondHash) + { + // Clear first block. + memset(&m_hashTable[0], 0, sizeof(SizeType)*(m_poolSize + 1)); + } + else + { + // Clear all blocks. + memset(&m_hashTable[0], 0, 2 * sizeof(SizeType) * (m_poolSize + 1)); + m_secondHash = false; + } + } + + + inline bool CheckAndSet(SizeType idx) + { + // Inner Index is begin from 1 + return _CheckAndSet(&m_hashTable[0], idx + 1) == 0; + } + + + inline int _CheckAndSet(SizeType* hashTable, SizeType idx) + { + unsigned index; + + // Get first hash position. + index = hash_func((unsigned)idx); + for (int loop = 0; loop < m_maxLoop; ++loop) + { + if (!hashTable[index]) + { + // index first match and record it. + hashTable[index] = idx; + return 1; + } + if (hashTable[index] == idx) + { + // Hit this item in hash table. + return 0; + } + // Get next hash position. + index = hash_func2(index, loop); + } + + if (hashTable == &m_hashTable[0]) + { + // Use second hash block. + m_secondHash = true; + return _CheckAndSet(&m_hashTable[m_poolSize + 1], idx); + } + + // Do not include this item. + return -1; + } + }; + + // Variables for each single NN search + struct WorkSpace + { + void Initialize(int maxCheck, SizeType dataSize) + { + nodeCheckStatus.Init(dataSize); + m_SPTQueue.Resize(maxCheck * 10); + m_NGQueue.Resize(maxCheck * 30); + + m_iNumberOfTreeCheckedLeaves = 0; + m_iNumberOfCheckedLeaves = 0; + m_iContinuousLimit = maxCheck / 64; + m_iMaxCheck = maxCheck; + m_iNumOfContinuousNoBetterPropagation = 0; + } + + void Reset(int maxCheck) + { + nodeCheckStatus.clear(); + m_SPTQueue.clear(); + m_NGQueue.clear(); + + m_iNumberOfTreeCheckedLeaves = 0; + m_iNumberOfCheckedLeaves = 0; + m_iContinuousLimit = maxCheck / 64; + m_iMaxCheck = maxCheck; + m_iNumOfContinuousNoBetterPropagation = 0; + } + + inline bool CheckAndSet(SizeType idx) + { + return nodeCheckStatus.CheckAndSet(idx); + } + + OptHashPosVector nodeCheckStatus; + //OptHashPosVector nodeCheckStatus; + + // counter for dynamic pivoting + int m_iNumOfContinuousNoBetterPropagation; + int m_iContinuousLimit; + int m_iNumberOfTreeCheckedLeaves; + int m_iNumberOfCheckedLeaves; + int m_iMaxCheck; + + // Prioriy queue used for neighborhood graph + Heap m_NGQueue; + + // Priority queue Used for BKT-Tree + Heap m_SPTQueue; + }; + } +} + +#endif // _SPTAG_COMMON_WORKSPACE_H_ diff --git a/core/src/index/thirdparty/SPTAG/AnnService/inc/Core/Common/WorkSpacePool.h b/core/src/index/thirdparty/SPTAG/AnnService/inc/Core/Common/WorkSpacePool.h new file mode 100644 index 0000000000..a322f42af4 --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/AnnService/inc/Core/Common/WorkSpacePool.h @@ -0,0 +1,43 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifndef _SPTAG_COMMON_WORKSPACEPOOL_H_ +#define _SPTAG_COMMON_WORKSPACEPOOL_H_ + +#include "WorkSpace.h" + +#include +#include + +namespace SPTAG +{ +namespace COMMON +{ + +class WorkSpacePool +{ +public: + WorkSpacePool(int p_maxCheck, SizeType p_vectorCount); + + virtual ~WorkSpacePool(); + + std::shared_ptr Rent(); + + void Return(const std::shared_ptr& p_workSpace); + + void Init(int size); + +private: + std::list> m_workSpacePool; + + std::mutex m_workSpacePoolMutex; + + int m_maxCheck; + + SizeType m_vectorCount; +}; + +} +} + +#endif // _SPTAG_COMMON_WORKSPACEPOOL_H_ diff --git a/core/src/index/thirdparty/SPTAG/AnnService/inc/Core/CommonDataStructure.h b/core/src/index/thirdparty/SPTAG/AnnService/inc/Core/CommonDataStructure.h new file mode 100644 index 0000000000..c158fc8802 --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/AnnService/inc/Core/CommonDataStructure.h @@ -0,0 +1,226 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifndef _SPTAG_COMMONDATASTRUCTURE_H_ +#define _SPTAG_COMMONDATASTRUCTURE_H_ + +#include "inc/Core/Common.h" + +namespace SPTAG +{ + +template +class Array +{ +public: + Array(); + + Array(T* p_array, std::size_t p_length, bool p_transferOwnership); + + Array(T* p_array, std::size_t p_length, std::shared_ptr p_dataHolder); + + Array(Array&& p_right); + + Array(const Array& p_right); + + Array& operator= (Array&& p_right); + + Array& operator= (const Array& p_right); + + T& operator[] (std::size_t p_index); + + const T& operator[] (std::size_t p_index) const; + + ~Array(); + + T* Data() const; + + std::size_t Length() const; + + std::shared_ptr DataHolder() const; + + void Set(T* p_array, std::size_t p_length, bool p_transferOwnership); + + void Clear(); + + static Array Alloc(std::size_t p_length); + + const static Array c_empty; + +private: + T* m_data; + + std::size_t m_length; + + // Notice this is holding an array. Set correct deleter for this. + std::shared_ptr m_dataHolder; +}; + +template +const Array Array::c_empty; + + +template +Array::Array() + : m_data(nullptr), + m_length(0) +{ +} + +template +Array::Array(T* p_array, std::size_t p_length, bool p_transferOnwership) + + : m_data(p_array), + m_length(p_length) +{ + if (p_transferOnwership) + { + m_dataHolder.reset(m_data, std::default_delete()); + } +} + + +template +Array::Array(T* p_array, std::size_t p_length, std::shared_ptr p_dataHolder) + : m_data(p_array), + m_length(p_length), + m_dataHolder(std::move(p_dataHolder)) +{ +} + + +template +Array::Array(Array&& p_right) + : m_data(p_right.m_data), + m_length(p_right.m_length), + m_dataHolder(std::move(p_right.m_dataHolder)) +{ +} + + +template +Array::Array(const Array& p_right) + : m_data(p_right.m_data), + m_length(p_right.m_length), + m_dataHolder(p_right.m_dataHolder) +{ +} + + +template +Array& +Array::operator= (Array&& p_right) +{ + m_data = p_right.m_data; + m_length = p_right.m_length; + m_dataHolder = std::move(p_right.m_dataHolder); + + return *this; +} + + +template +Array& +Array::operator= (const Array& p_right) +{ + m_data = p_right.m_data; + m_length = p_right.m_length; + m_dataHolder = p_right.m_dataHolder; + + return *this; +} + + +template +T& +Array::operator[] (std::size_t p_index) +{ + return m_data[p_index]; +} + + +template +const T& +Array::operator[] (std::size_t p_index) const +{ + return m_data[p_index]; +} + + +template +Array::~Array() +{ +} + + +template +T* +Array::Data() const +{ + return m_data; +} + + +template +std::size_t +Array::Length() const +{ + return m_length; +} + + +template +std::shared_ptr +Array::DataHolder() const +{ + return m_dataHolder; +} + + +template +void +Array::Set(T* p_array, std::size_t p_length, bool p_transferOwnership) +{ + m_data = p_array; + m_length = p_length; + + if (p_transferOwnership) + { + m_dataHolder.reset(m_data, std::default_delete()); + } +} + + +template +void +Array::Clear() +{ + m_data = nullptr; + m_length = 0; + m_dataHolder.reset(); +} + + +template +Array +Array::Alloc(std::size_t p_length) +{ + Array arr; + if (0 == p_length) + { + return arr; + } + + arr.m_dataHolder.reset(new T[p_length], std::default_delete()); + + arr.m_length = p_length; + arr.m_data = arr.m_dataHolder.get(); + return arr; +} + + +typedef Array ByteArray; + +} // namespace SPTAG + +#endif // _SPTAG_COMMONDATASTRUCTURE_H_ diff --git a/core/src/index/thirdparty/SPTAG/AnnService/inc/Core/DefinitionList.h b/core/src/index/thirdparty/SPTAG/AnnService/inc/Core/DefinitionList.h new file mode 100644 index 0000000000..91014963c6 --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/AnnService/inc/Core/DefinitionList.h @@ -0,0 +1,59 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifdef DefineVectorValueType + +DefineVectorValueType(Int8, std::int8_t) +DefineVectorValueType(UInt8, std::uint8_t) +DefineVectorValueType(Int16, std::int16_t) +DefineVectorValueType(Float, float) + +#endif // DefineVectorValueType + + +#ifdef DefineDistCalcMethod + +DefineDistCalcMethod(L2) +DefineDistCalcMethod(Cosine) + +#endif // DefineDistCalcMethod + + +#ifdef DefineErrorCode + +// 0x0000 ~ 0x0FFF General Status +DefineErrorCode(Success, 0x0000) +DefineErrorCode(Fail, 0x0001) +DefineErrorCode(FailedOpenFile, 0x0002) +DefineErrorCode(FailedCreateFile, 0x0003) +DefineErrorCode(ParamNotFound, 0x0010) +DefineErrorCode(FailedParseValue, 0x0011) +DefineErrorCode(MemoryOverFlow, 0x0012) +DefineErrorCode(LackOfInputs, 0x0013) + +// 0x1000 ~ 0x1FFF Index Build Status + +// 0x2000 ~ 0x2FFF Index Serve Status + +// 0x3000 ~ 0x3FFF Helper Function Status +DefineErrorCode(ReadIni_FailedParseSection, 0x3000) +DefineErrorCode(ReadIni_FailedParseParam, 0x3001) +DefineErrorCode(ReadIni_DuplicatedSection, 0x3002) +DefineErrorCode(ReadIni_DuplicatedParam, 0x3003) + + +// 0x4000 ~ 0x4FFF Socket Library Status +DefineErrorCode(Socket_FailedResolveEndPoint, 0x4000) +DefineErrorCode(Socket_FailedConnectToEndPoint, 0x4001) + + +#endif // DefineErrorCode + + + +#ifdef DefineIndexAlgo + +DefineIndexAlgo(BKT) +DefineIndexAlgo(KDT) + +#endif // DefineIndexAlgo diff --git a/core/src/index/thirdparty/SPTAG/AnnService/inc/Core/KDT/Index.h b/core/src/index/thirdparty/SPTAG/AnnService/inc/Core/KDT/Index.h new file mode 100644 index 0000000000..f3240ebdb2 --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/AnnService/inc/Core/KDT/Index.h @@ -0,0 +1,132 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifndef _SPTAG_KDT_INDEX_H_ +#define _SPTAG_KDT_INDEX_H_ + +#include "../Common.h" +#include "../VectorIndex.h" + +#include "../Common/CommonUtils.h" +#include "../Common/DistanceUtils.h" +#include "../Common/QueryResultSet.h" +#include "../Common/Dataset.h" +#include "../Common/WorkSpace.h" +#include "../Common/WorkSpacePool.h" +#include "../Common/RelativeNeighborhoodGraph.h" +#include "../Common/KDTree.h" +#include "inc/Helper/ConcurrentSet.h" +#include "inc/Helper/StringConvert.h" +#include "inc/Helper/SimpleIniReader.h" + +#include +#include + +namespace SPTAG +{ + + namespace Helper + { + class IniReader; + } + + namespace KDT + { + template + class Index : public VectorIndex + { + private: + // data points + COMMON::Dataset m_pSamples; + + // KDT structures. + COMMON::KDTree m_pTrees; + + // Graph structure + COMMON::RelativeNeighborhoodGraph m_pGraph; + + std::string m_sKDTFilename; + std::string m_sGraphFilename; + std::string m_sDataPointsFilename; + std::string m_sDeleteDataPointsFilename; + + std::mutex m_dataAddLock; // protect data and graph + Helper::Concurrent::ConcurrentSet m_deletedID; + float m_fDeletePercentageForRefine; + std::unique_ptr m_workSpacePool; + + int m_iNumberOfThreads; + DistCalcMethod m_iDistCalcMethod; + float(*m_fComputeDistance)(const T* pX, const T* pY, DimensionType length); + + int m_iMaxCheck; + int m_iThresholdOfNumberOfContinuousNoBetterPropagation; + int m_iNumberOfInitialDynamicPivots; + int m_iNumberOfOtherDynamicPivots; + public: + Index() + { +#define DefineKDTParameter(VarName, VarType, DefaultValue, RepresentStr) \ + VarName = DefaultValue; \ + +#include "inc/Core/KDT/ParameterDefinitionList.h" +#undef DefineKDTParameter + + m_pSamples.SetName("Vector"); + m_fComputeDistance = COMMON::DistanceCalcSelector(m_iDistCalcMethod); + } + + ~Index() {} + + inline SizeType GetNumSamples() const { return m_pSamples.R(); } + inline SizeType GetIndexSize() const { return sizeof(*this); } + inline DimensionType GetFeatureDim() const { return m_pSamples.C(); } + + inline int GetCurrMaxCheck() const { return m_iMaxCheck; } + inline int GetNumThreads() const { return m_iNumberOfThreads; } + inline DistCalcMethod GetDistCalcMethod() const { return m_iDistCalcMethod; } + inline IndexAlgoType GetIndexAlgoType() const { return IndexAlgoType::KDT; } + inline VectorValueType GetVectorValueType() const { return GetEnumValueType(); } + + inline float ComputeDistance(const void* pX, const void* pY) const { return m_fComputeDistance((const T*)pX, (const T*)pY, m_pSamples.C()); } + inline const void* GetSample(const SizeType idx) const { return (void*)m_pSamples[idx]; } + inline bool ContainSample(const SizeType idx) const { return !m_deletedID.contains(idx); } + inline bool NeedRefine() const { return m_deletedID.size() >= (size_t)(GetNumSamples() * m_fDeletePercentageForRefine); } + std::shared_ptr> BufferSize() const + { + std::shared_ptr> buffersize(new std::vector); + buffersize->push_back(m_pSamples.BufferSize()); + buffersize->push_back(m_pTrees.BufferSize()); + buffersize->push_back(m_pGraph.BufferSize()); + buffersize->push_back(m_deletedID.bufferSize()); + return std::move(buffersize); + } + + ErrorCode SaveConfig(std::ostream& p_configout) const; + ErrorCode SaveIndexData(const std::string& p_folderPath); + ErrorCode SaveIndexData(const std::vector& p_indexStreams); + + ErrorCode LoadConfig(Helper::IniReader& p_reader); + ErrorCode LoadIndexData(const std::string& p_folderPath); + ErrorCode LoadIndexDataFromMemory(const std::vector& p_indexBlobs); + + ErrorCode BuildIndex(const void* p_data, SizeType p_vectorNum, DimensionType p_dimension); + ErrorCode SearchIndex(QueryResult &p_query) const; + ErrorCode AddIndex(const void* p_vectors, SizeType p_vectorNum, DimensionType p_dimension, SizeType* p_start = nullptr); + ErrorCode DeleteIndex(const void* p_vectors, SizeType p_vectorNum); + ErrorCode DeleteIndex(const SizeType& p_id); + + ErrorCode SetParameter(const char* p_param, const char* p_value); + std::string GetParameter(const char* p_param) const; + + ErrorCode RefineIndex(const std::string& p_folderPath); + ErrorCode RefineIndex(const std::vector& p_indexStreams); + + private: + void SearchIndexWithDeleted(COMMON::QueryResultSet &p_query, COMMON::WorkSpace &p_space, const Helper::Concurrent::ConcurrentSet &p_deleted) const; + void SearchIndexWithoutDeleted(COMMON::QueryResultSet &p_query, COMMON::WorkSpace &p_space) const; + }; + } // namespace KDT +} // namespace SPTAG + +#endif // _SPTAG_KDT_INDEX_H_ diff --git a/core/src/index/thirdparty/SPTAG/AnnService/inc/Core/KDT/ParameterDefinitionList.h b/core/src/index/thirdparty/SPTAG/AnnService/inc/Core/KDT/ParameterDefinitionList.h new file mode 100644 index 0000000000..c36cb178c1 --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/AnnService/inc/Core/KDT/ParameterDefinitionList.h @@ -0,0 +1,36 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifdef DefineKDTParameter + +// DefineKDTParameter(VarName, VarType, DefaultValue, RepresentStr) +DefineKDTParameter(m_sKDTFilename, std::string, std::string("tree.bin"), "TreeFilePath") +DefineKDTParameter(m_sGraphFilename, std::string, std::string("graph.bin"), "GraphFilePath") +DefineKDTParameter(m_sDataPointsFilename, std::string, std::string("vectors.bin"), "VectorFilePath") +DefineKDTParameter(m_sDeleteDataPointsFilename, std::string, std::string("deletes.bin"), "DeleteVectorFilePath") + +DefineKDTParameter(m_pTrees.m_iTreeNumber, int, 1L, "KDTNumber") +DefineKDTParameter(m_pTrees.m_numTopDimensionKDTSplit, int, 5L, "NumTopDimensionKDTSplit") +DefineKDTParameter(m_pTrees.m_iSamples, int, 100L, "Samples") + +DefineKDTParameter(m_pGraph.m_iTPTNumber, int, 32L, "TPTNumber") +DefineKDTParameter(m_pGraph.m_iTPTLeafSize, int, 2000L, "TPTLeafSize") +DefineKDTParameter(m_pGraph.m_numTopDimensionTPTSplit, int, 5L, "NumTopDimensionTPTSplit") + +DefineKDTParameter(m_pGraph.m_iNeighborhoodSize, DimensionType, 32L, "NeighborhoodSize") +DefineKDTParameter(m_pGraph.m_iNeighborhoodScale, int, 2L, "GraphNeighborhoodScale") +DefineKDTParameter(m_pGraph.m_iCEFScale, int, 2L, "GraphCEFScale") +DefineKDTParameter(m_pGraph.m_iRefineIter, int, 0L, "RefineIterations") +DefineKDTParameter(m_pGraph.m_iCEF, int, 1000L, "CEF") +DefineKDTParameter(m_pGraph.m_iMaxCheckForRefineGraph, int, 10000L, "MaxCheckForRefineGraph") + +DefineKDTParameter(m_iNumberOfThreads, int, 1L, "NumberOfThreads") +DefineKDTParameter(m_iDistCalcMethod, SPTAG::DistCalcMethod, SPTAG::DistCalcMethod::Cosine, "DistCalcMethod") + +DefineKDTParameter(m_fDeletePercentageForRefine, float, 0.4F, "DeletePercentageForRefine") +DefineKDTParameter(m_iMaxCheck, int, 8192L, "MaxCheck") +DefineKDTParameter(m_iThresholdOfNumberOfContinuousNoBetterPropagation, int, 3L, "ThresholdOfNumberOfContinuousNoBetterPropagation") +DefineKDTParameter(m_iNumberOfInitialDynamicPivots, int, 50L, "NumberOfInitialDynamicPivots") +DefineKDTParameter(m_iNumberOfOtherDynamicPivots, int, 4L, "NumberOfOtherDynamicPivots") + +#endif diff --git a/core/src/index/thirdparty/SPTAG/AnnService/inc/Core/MetadataSet.h b/core/src/index/thirdparty/SPTAG/AnnService/inc/Core/MetadataSet.h new file mode 100644 index 0000000000..37eba14491 --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/AnnService/inc/Core/MetadataSet.h @@ -0,0 +1,114 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifndef _SPTAG_METADATASET_H_ +#define _SPTAG_METADATASET_H_ + +#include "CommonDataStructure.h" + +#include +#include + +namespace SPTAG +{ + +class MetadataSet +{ +public: + MetadataSet(); + + virtual ~MetadataSet(); + + virtual ByteArray GetMetadata(SizeType p_vectorID) const = 0; + + virtual SizeType Count() const = 0; + + virtual bool Available() const = 0; + + virtual std::pair BufferSize() const = 0; + + virtual void AddBatch(MetadataSet& data) = 0; + + virtual ErrorCode SaveMetadata(std::ostream& p_metaOut, std::ostream& p_metaIndexOut) = 0; + + virtual ErrorCode SaveMetadata(const std::string& p_metaFile, const std::string& p_metaindexFile) = 0; + + virtual ErrorCode RefineMetadata(std::vector& indices, std::ostream& p_metaOut, std::ostream& p_metaIndexOut); + + virtual ErrorCode RefineMetadata(std::vector& indices, const std::string& p_metaFile, const std::string& p_metaindexFile); +}; + + +class FileMetadataSet : public MetadataSet +{ +public: + FileMetadataSet(const std::string& p_metaFile, const std::string& p_metaindexFile); + + ~FileMetadataSet(); + + ByteArray GetMetadata(SizeType p_vectorID) const; + + SizeType Count() const; + + bool Available() const; + + std::pair BufferSize() const; + + void AddBatch(MetadataSet& data); + + ErrorCode SaveMetadata(std::ostream& p_metaOut, std::ostream& p_metaIndexOut); + + ErrorCode SaveMetadata(const std::string& p_metaFile, const std::string& p_metaindexFile); + +private: + std::ifstream* m_fp = nullptr; + + std::vector m_pOffsets; + + SizeType m_count; + + std::string m_metaFile; + + std::string m_metaindexFile; + + std::vector m_newdata; +}; + + +class MemMetadataSet : public MetadataSet +{ +public: + MemMetadataSet(ByteArray p_metadata, ByteArray p_offsets, SizeType p_count); + + ~MemMetadataSet(); + + ByteArray GetMetadata(SizeType p_vectorID) const; + + SizeType Count() const; + + bool Available() const; + + std::pair BufferSize() const; + + void AddBatch(MetadataSet& data); + + ErrorCode SaveMetadata(std::ostream& p_metaOut, std::ostream& p_metaIndexOut); + + ErrorCode SaveMetadata(const std::string& p_metaFile, const std::string& p_metaindexFile); + +private: + std::vector m_offsets; + + SizeType m_count; + + ByteArray m_metadataHolder; + + ByteArray m_offsetHolder; + + std::vector m_newdata; +}; + + +} // namespace SPTAG + +#endif // _SPTAG_METADATASET_H_ diff --git a/core/src/index/thirdparty/SPTAG/AnnService/inc/Core/SearchQuery.h b/core/src/index/thirdparty/SPTAG/AnnService/inc/Core/SearchQuery.h new file mode 100644 index 0000000000..017b1e2e01 --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/AnnService/inc/Core/SearchQuery.h @@ -0,0 +1,193 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifndef _SPTAG_SEARCHQUERY_H_ +#define _SPTAG_SEARCHQUERY_H_ + +#include "SearchResult.h" + +#include + +namespace SPTAG +{ + +// Space to save temporary answer, similar with TopKCache +class QueryResult +{ +public: + typedef BasicResult* iterator; + typedef const BasicResult* const_iterator; + + QueryResult() + : m_target(nullptr), + m_resultNum(0), + m_withMeta(false) + { + } + + + QueryResult(const void* p_target, int p_resultNum, bool p_withMeta) + { + Init(p_target, p_resultNum, p_withMeta); + } + + + QueryResult(const void* p_target, int p_resultNum, bool p_withMeta, BasicResult* p_results) + : m_target(p_target), + m_resultNum(p_resultNum), + m_withMeta(p_withMeta) + { + m_results.Set(p_results, p_resultNum, false); + } + + + QueryResult(const QueryResult& p_other) + { + Init(p_other.m_target, p_other.m_resultNum, p_other.m_withMeta); + if (m_resultNum > 0) + { + std::copy(p_other.m_results.Data(), p_other.m_results.Data() + m_resultNum, m_results.Data()); + } + } + + + QueryResult& operator=(const QueryResult& p_other) + { + Init(p_other.m_target, p_other.m_resultNum, p_other.m_withMeta); + if (m_resultNum > 0) + { + std::copy(p_other.m_results.Data(), p_other.m_results.Data() + m_resultNum, m_results.Data()); + } + + return *this; + } + + + ~QueryResult() + { + } + + + inline void Init(const void* p_target, int p_resultNum, bool p_withMeta) + { + m_target = p_target; + m_resultNum = p_resultNum; + m_withMeta = p_withMeta; + + m_results = Array::Alloc(p_resultNum); + } + + + inline int GetResultNum() const + { + return m_resultNum; + } + + + inline const void* GetTarget() + { + return m_target; + } + + + inline void SetTarget(const void* p_target) + { + m_target = p_target; + } + + + inline BasicResult* GetResult(int i) const + { + return i < m_resultNum ? m_results.Data() + i : nullptr; + } + + + inline void SetResult(int p_index, SizeType p_VID, float p_dist) + { + if (p_index < m_resultNum) + { + m_results[p_index].VID = p_VID; + m_results[p_index].Dist = p_dist; + } + } + + + inline BasicResult* GetResults() const + { + return m_results.Data(); + } + + + inline bool WithMeta() const + { + return m_withMeta; + } + + + inline const ByteArray& GetMetadata(int p_index) const + { + if (p_index < m_resultNum && m_withMeta) + { + return m_results[p_index].Meta; + } + + return ByteArray::c_empty; + } + + + inline void SetMetadata(int p_index, ByteArray p_metadata) + { + if (p_index < m_resultNum && m_withMeta) + { + m_results[p_index].Meta = std::move(p_metadata); + } + } + + + inline void Reset() + { + for (int i = 0; i < m_resultNum; i++) + { + m_results[i].VID = -1; + m_results[i].Dist = MaxDist; + m_results[i].Meta.Clear(); + } + } + + + iterator begin() + { + return m_results.Data(); + } + + + iterator end() + { + return m_results.Data() + m_resultNum; + } + + + const_iterator begin() const + { + return m_results.Data(); + } + + + const_iterator end() const + { + return m_results.Data() + m_resultNum; + } + + +protected: + const void* m_target; + + int m_resultNum; + + bool m_withMeta; + + Array m_results; +}; +} // namespace SPTAG + +#endif // _SPTAG_SEARCHQUERY_H_ diff --git a/core/src/index/thirdparty/SPTAG/AnnService/inc/Core/SearchResult.h b/core/src/index/thirdparty/SPTAG/AnnService/inc/Core/SearchResult.h new file mode 100644 index 0000000000..64e173030b --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/AnnService/inc/Core/SearchResult.h @@ -0,0 +1,26 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifndef _SPTAG_SEARCHRESULT_H_ +#define _SPTAG_SEARCHRESULT_H_ + +#include "CommonDataStructure.h" + +namespace SPTAG +{ + struct BasicResult + { + SizeType VID; + float Dist; + ByteArray Meta; + + BasicResult() : VID(-1), Dist(MaxDist) {} + + BasicResult(SizeType p_vid, float p_dist) : VID(p_vid), Dist(p_dist) {} + + BasicResult(SizeType p_vid, float p_dist, ByteArray p_meta) : VID(p_vid), Dist(p_dist), Meta(p_meta) {} + }; + +} // namespace SPTAG + +#endif // _SPTAG_SEARCHRESULT_H_ diff --git a/core/src/index/thirdparty/SPTAG/AnnService/inc/Core/VectorIndex.h b/core/src/index/thirdparty/SPTAG/AnnService/inc/Core/VectorIndex.h new file mode 100644 index 0000000000..b93caf0a9e --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/AnnService/inc/Core/VectorIndex.h @@ -0,0 +1,129 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifndef _SPTAG_VECTORINDEX_H_ +#define _SPTAG_VECTORINDEX_H_ + +#include "Common.h" +#include "SearchQuery.h" +#include "VectorSet.h" +#include "MetadataSet.h" +#include "inc/Helper/SimpleIniReader.h" + +#include + +namespace SPTAG +{ + +class VectorIndex +{ +public: + VectorIndex(); + + virtual ~VectorIndex(); + + virtual ErrorCode BuildIndex(const void* p_data, SizeType p_vectorNum, DimensionType p_dimension) = 0; + + virtual ErrorCode AddIndex(const void* p_vectors, SizeType p_vectorNum, DimensionType p_dimension, SizeType* p_start = nullptr) = 0; + + virtual ErrorCode DeleteIndex(const void* p_vectors, SizeType p_vectorNum) = 0; + + virtual ErrorCode SearchIndex(QueryResult& p_results) const = 0; + + virtual float ComputeDistance(const void* pX, const void* pY) const = 0; + virtual const void* GetSample(const SizeType idx) const = 0; + virtual bool ContainSample(const SizeType idx) const = 0; + virtual bool NeedRefine() const = 0; + + virtual DimensionType GetFeatureDim() const = 0; + virtual SizeType GetNumSamples() const = 0; + virtual SizeType GetIndexSize() const = 0; + + virtual DistCalcMethod GetDistCalcMethod() const = 0; + virtual IndexAlgoType GetIndexAlgoType() const = 0; + virtual VectorValueType GetVectorValueType() const = 0; + + virtual std::string GetParameter(const char* p_param) const = 0; + virtual ErrorCode SetParameter(const char* p_param, const char* p_value) = 0; + + virtual std::shared_ptr> CalculateBufferSize() const; + + virtual ErrorCode LoadIndex(const std::string& p_config, const std::vector& p_indexBlobs); + + virtual ErrorCode LoadIndex(const std::string& p_folderPath); + + virtual ErrorCode SaveIndex(std::string& p_config, const std::vector& p_indexBlobs); + + virtual ErrorCode SaveIndex(const std::string& p_folderPath); + + virtual ErrorCode BuildIndex(std::shared_ptr p_vectorSet, std::shared_ptr p_metadataSet, bool p_withMetaIndex = false); + + virtual ErrorCode AddIndex(std::shared_ptr p_vectorSet, std::shared_ptr p_metadataSet); + + virtual ErrorCode DeleteIndex(ByteArray p_meta); + + virtual const void* GetSample(ByteArray p_meta); + + virtual ErrorCode SearchIndex(const void* p_vector, int p_neighborCount, bool p_withMeta, BasicResult* p_results) const; + + virtual std::string GetParameter(const std::string& p_param) const; + virtual ErrorCode SetParameter(const std::string& p_param, const std::string& p_value); + + virtual ByteArray GetMetadata(SizeType p_vectorID) const; + virtual void SetMetadata(const std::string& p_metadataFilePath, const std::string& p_metadataIndexPath); + + virtual std::string GetIndexName() const + { + if (m_sIndexName == "") return Helper::Convert::ConvertToString(GetIndexAlgoType()); + return m_sIndexName; + } + virtual void SetIndexName(std::string p_name) { m_sIndexName = p_name; } + + static std::shared_ptr CreateInstance(IndexAlgoType p_algo, VectorValueType p_valuetype); + + static ErrorCode MergeIndex(const char* p_indexFilePath1, const char* p_indexFilePath2); + + static ErrorCode LoadIndex(const std::string& p_loaderFilePath, std::shared_ptr& p_vectorIndex); + + static ErrorCode LoadIndex(const std::string& p_config, const std::vector& p_indexBlobs, std::shared_ptr& p_vectorIndex); + +protected: + virtual std::shared_ptr> BufferSize() const = 0; + + virtual ErrorCode SaveConfig(std::ostream& p_configout) const = 0; + + virtual ErrorCode SaveIndexData(const std::string& p_folderPath) = 0; + + virtual ErrorCode SaveIndexData(const std::vector& p_indexStreams) = 0; + + virtual ErrorCode LoadConfig(Helper::IniReader& p_reader) = 0; + + virtual ErrorCode LoadIndexData(const std::string& p_folderPath) = 0; + + virtual ErrorCode LoadIndexDataFromMemory(const std::vector& p_indexBlobs) = 0; + + virtual ErrorCode DeleteIndex(const SizeType& p_id) = 0; + + virtual ErrorCode RefineIndex(const std::string& p_folderPath) = 0; + + virtual ErrorCode RefineIndex(const std::vector& p_indexStreams) = 0; + +private: + void BuildMetaMapping(); + + ErrorCode LoadIndexConfig(Helper::IniReader& p_reader); + + ErrorCode SaveIndexConfig(std::ostream& p_configOut); + +protected: + std::string m_sIndexName; + std::string m_sMetadataFile = "metadata.bin"; + std::string m_sMetadataIndexFile = "metadataIndex.bin"; + std::shared_ptr m_pMetadata; + std::unique_ptr> m_pMetaToVec; +}; + + +} // namespace SPTAG + +#endif // _SPTAG_VECTORINDEX_H_ diff --git a/core/src/index/thirdparty/SPTAG/AnnService/inc/Core/VectorSet.h b/core/src/index/thirdparty/SPTAG/AnnService/inc/Core/VectorSet.h new file mode 100644 index 0000000000..c394c701ff --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/AnnService/inc/Core/VectorSet.h @@ -0,0 +1,73 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifndef _SPTAG_VECTORSET_H_ +#define _SPTAG_VECTORSET_H_ + +#include "CommonDataStructure.h" + +namespace SPTAG +{ + +class VectorSet +{ +public: + VectorSet(); + + virtual ~VectorSet(); + + virtual VectorValueType GetValueType() const = 0; + + virtual void* GetVector(SizeType p_vectorID) const = 0; + + virtual void* GetData() const = 0; + + virtual DimensionType Dimension() const = 0; + + virtual SizeType Count() const = 0; + + virtual bool Available() const = 0; + + virtual ErrorCode Save(const std::string& p_vectorFile) const = 0; +}; + + +class BasicVectorSet : public VectorSet +{ +public: + BasicVectorSet(const ByteArray& p_bytesArray, + VectorValueType p_valueType, + DimensionType p_dimension, + SizeType p_vectorCount); + + virtual ~BasicVectorSet(); + + virtual VectorValueType GetValueType() const; + + virtual void* GetVector(SizeType p_vectorID) const; + + virtual void* GetData() const; + + virtual DimensionType Dimension() const; + + virtual SizeType Count() const; + + virtual bool Available() const; + + virtual ErrorCode Save(const std::string& p_vectorFile) const; + +private: + ByteArray m_data; + + VectorValueType m_valueType; + + DimensionType m_dimension; + + SizeType m_vectorCount; + + SizeType m_perVectorDataSize; +}; + +} // namespace SPTAG + +#endif // _SPTAG_VECTORSET_H_ diff --git a/core/src/index/thirdparty/SPTAG/AnnService/inc/Helper/ArgumentsParser.h b/core/src/index/thirdparty/SPTAG/AnnService/inc/Helper/ArgumentsParser.h new file mode 100644 index 0000000000..0ae19b8e8f --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/AnnService/inc/Helper/ArgumentsParser.h @@ -0,0 +1,253 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifndef _SPTAG_HELPER_ARGUMENTSPARSER_H_ +#define _SPTAG_HELPER_ARGUMENTSPARSER_H_ + +#include "inc/Helper/StringConvert.h" + +#include +#include +#include +#include +#include + +namespace SPTAG +{ +namespace Helper +{ + +class ArgumentsParser +{ +public: + ArgumentsParser(); + + virtual ~ArgumentsParser(); + + virtual bool Parse(int p_argc, char** p_args); + + virtual void PrintHelp(); + +protected: + class IArgument + { + public: + IArgument(); + + virtual ~IArgument(); + + virtual bool ParseValue(int& p_restArgc, char** (&p_args)) = 0; + + virtual void PrintDescription(FILE* p_output) = 0; + + virtual bool IsRequiredButNotSet() const = 0; + }; + + + template + class ArgumentT : public IArgument + { + public: + ArgumentT(DataType& p_target, + const std::string& p_representStringShort, + const std::string& p_representString, + const std::string& p_description, + bool p_followedValue, + const DataType& p_switchAsValue, + bool p_isRequired) + : m_value(p_target), + m_representStringShort(p_representStringShort), + m_representString(p_representString), + m_description(p_description), + m_followedValue(p_followedValue), + c_switchAsValue(p_switchAsValue), + m_isRequired(p_isRequired), + m_isSet(false) + { + } + + virtual ~ArgumentT() + { + } + + + virtual bool ParseValue(int& p_restArgc, char** (&p_args)) + { + if (0 == p_restArgc) + { + return true; + } + + if (0 != strcmp(*p_args, m_representString.c_str()) + && 0 != strcmp(*p_args, m_representStringShort.c_str())) + { + return true; + } + + if (!m_followedValue) + { + m_value = c_switchAsValue; + --p_restArgc; + ++p_args; + m_isSet = true; + return true; + } + + if (p_restArgc < 2) + { + return false; + } + + DataType tmp; + if (!Helper::Convert::ConvertStringTo(p_args[1], tmp)) + { + return false; + } + + m_value = std::move(tmp); + + p_restArgc -= 2; + p_args += 2; + m_isSet = true; + return true; + } + + + virtual void PrintDescription(FILE* p_output) + { + std::size_t padding = 30; + if (!m_representStringShort.empty()) + { + fprintf(p_output, "%s", m_representStringShort.c_str()); + padding -= m_representStringShort.size(); + } + + if (!m_representString.empty()) + { + if (!m_representStringShort.empty()) + { + fprintf(p_output, ", "); + padding -= 2; + } + + fprintf(p_output, "%s", m_representString.c_str()); + padding -= m_representString.size(); + } + + if (m_followedValue) + { + fprintf(p_output, " "); + padding -= 8; + } + + while (padding-- > 0) + { + fputc(' ', p_output); + } + + fprintf(p_output, "%s", m_description.c_str()); + } + + + virtual bool IsRequiredButNotSet() const + { + return m_isRequired && !m_isSet; + } + + private: + DataType & m_value; + + std::string m_representStringShort; + + std::string m_representString; + + std::string m_description; + + bool m_followedValue; + + const DataType c_switchAsValue; + + bool m_isRequired; + + bool m_isSet; + }; + + + template + void AddRequiredOption(DataType& p_target, + const std::string& p_representStringShort, + const std::string& p_representString, + const std::string& p_description) + { + m_arguments.emplace_back(std::shared_ptr( + new ArgumentT(p_target, + p_representStringShort, + p_representString, + p_description, + true, + DataType(), + true))); + } + + + template + void AddOptionalOption(DataType& p_target, + const std::string& p_representStringShort, + const std::string& p_representString, + const std::string& p_description) + { + m_arguments.emplace_back(std::shared_ptr( + new ArgumentT(p_target, + p_representStringShort, + p_representString, + p_description, + true, + DataType(), + false))); + } + + + template + void AddRequiredSwitch(DataType& p_target, + const std::string& p_representStringShort, + const std::string& p_representString, + const std::string& p_description, + const DataType& p_switchAsValue) + { + m_arguments.emplace_back(std::shared_ptr( + new ArgumentT(p_target, + p_representStringShort, + p_representString, + p_description, + false, + p_switchAsValue, + true))); + } + + + template + void AddOptionalSwitch(DataType& p_target, + const std::string& p_representStringShort, + const std::string& p_representString, + const std::string& p_description, + const DataType& p_switchAsValue) + { + m_arguments.emplace_back(std::shared_ptr( + new ArgumentT(p_target, + p_representStringShort, + p_representString, + p_description, + false, + p_switchAsValue, + false))); + } + +private: + std::vector> m_arguments; +}; + + +} // namespace Helper +} // namespace SPTAG + +#endif // _SPTAG_HELPER_ARGUMENTSPARSER_H_ diff --git a/core/src/index/thirdparty/SPTAG/AnnService/inc/Helper/Base64Encode.h b/core/src/index/thirdparty/SPTAG/AnnService/inc/Helper/Base64Encode.h new file mode 100644 index 0000000000..8e7919345d --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/AnnService/inc/Helper/Base64Encode.h @@ -0,0 +1,33 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifndef _SPTAG_HELPER_BASE64ENCODE_H_ +#define _SPTAG_HELPER_BASE64ENCODE_H_ + +#include +#include +#include + +namespace SPTAG +{ +namespace Helper +{ +namespace Base64 +{ + +bool Encode(const std::uint8_t* p_in, std::size_t p_inLen, char* p_out, std::size_t& p_outLen); + +bool Encode(const std::uint8_t* p_in, std::size_t p_inLen, std::ostream& p_out, std::size_t& p_outLen); + +bool Decode(const char* p_in, std::size_t p_inLen, std::uint8_t* p_out, std::size_t& p_outLen); + +std::size_t CapacityForEncode(std::size_t p_inLen); + +std::size_t CapacityForDecode(std::size_t p_inLen); + + +} // namespace Base64 +} // namespace Helper +} // namespace SPTAG + +#endif // _SPTAG_HELPER_BASE64ENCODE_H_ diff --git a/core/src/index/thirdparty/SPTAG/AnnService/inc/Helper/BufferStream.h b/core/src/index/thirdparty/SPTAG/AnnService/inc/Helper/BufferStream.h new file mode 100644 index 0000000000..c97be04f12 --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/AnnService/inc/Helper/BufferStream.h @@ -0,0 +1,39 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifndef _SPTAG_HELPER_BUFFERSTREAM_H_ +#define _SPTAG_HELPER_BUFFERSTREAM_H_ + +#include +#include +#include + +namespace SPTAG +{ + namespace Helper + { + struct streambuf : public std::basic_streambuf + { + streambuf(char* buffer, size_t size) + { + setp(buffer, buffer + size); + } + }; + + class obufferstream : public std::ostream + { + public: + obufferstream(streambuf* buf, bool transferOwnership) : std::ostream(buf) + { + if (transferOwnership) + m_bufHolder.reset(buf, std::default_delete()); + } + + private: + std::shared_ptr m_bufHolder; + }; + } // namespace Helper +} // namespace SPTAG + +#endif // _SPTAG_HELPER_BUFFERSTREAM_H_ + diff --git a/core/src/index/thirdparty/SPTAG/AnnService/inc/Helper/CommonHelper.h b/core/src/index/thirdparty/SPTAG/AnnService/inc/Helper/CommonHelper.h new file mode 100644 index 0000000000..7f14784707 --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/AnnService/inc/Helper/CommonHelper.h @@ -0,0 +1,40 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifndef _SPTAG_HELPER_COMMONHELPER_H_ +#define _SPTAG_HELPER_COMMONHELPER_H_ + +#include "../Core/Common.h" + +#include +#include +#include +#include +#include +#include + + +namespace SPTAG +{ +namespace Helper +{ +namespace StrUtils +{ + +void ToLowerInPlace(std::string& p_str); + +std::vector SplitString(const std::string& p_str, const std::string& p_separator); + +std::pair FindTrimmedSegment(const char* p_begin, + const char* p_end, + const std::function& p_isSkippedChar); + +bool StartsWith(const char* p_str, const char* p_prefix); + +bool StrEqualIgnoreCase(const char* p_left, const char* p_right); + +} // namespace StrUtils +} // namespace Helper +} // namespace SPTAG + +#endif // _SPTAG_HELPER_COMMONHELPER_H_ diff --git a/core/src/index/thirdparty/SPTAG/AnnService/inc/Helper/Concurrent.h b/core/src/index/thirdparty/SPTAG/AnnService/inc/Helper/Concurrent.h new file mode 100644 index 0000000000..35c7cc93e4 --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/AnnService/inc/Helper/Concurrent.h @@ -0,0 +1,97 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifndef _SPTAG_HELPER_CONCURRENT_H_ +#define _SPTAG_HELPER_CONCURRENT_H_ + + +#include +#include +#include + + +namespace SPTAG +{ +namespace Helper +{ +namespace Concurrent +{ + +class SpinLock +{ +public: + SpinLock() = default; + + void Lock() noexcept + { + while (m_lock.test_and_set(std::memory_order_acquire)) + { + } + } + + void Unlock() noexcept + { + m_lock.clear(std::memory_order_release); + } + + SpinLock(const SpinLock&) = delete; + SpinLock& operator = (const SpinLock&) = delete; + +private: + std::atomic_flag m_lock = ATOMIC_FLAG_INIT; +}; + +template +class LockGuard { +public: + LockGuard(Lock& lock) noexcept + : m_lock(lock) { + lock.Lock(); + } + + LockGuard(Lock& lock, std::adopt_lock_t) noexcept + : m_lock(lock) {} + + ~LockGuard() { + m_lock.Unlock(); + } + + LockGuard(const LockGuard&) = delete; + LockGuard& operator=(const LockGuard&) = delete; + +private: + Lock& m_lock; +}; + + +class WaitSignal +{ +public: + WaitSignal(); + + WaitSignal(std::uint32_t p_unfinished); + + ~WaitSignal(); + + void Reset(std::uint32_t p_unfinished); + + void Wait(); + + void FinishOne(); + +private: + std::atomic m_unfinished; + + std::atomic_bool m_isWaiting; + + std::mutex m_mutex; + + std::condition_variable m_cv; +}; + + +} // namespace Base64 +} // namespace Helper +} // namespace SPTAG + +#endif // _SPTAG_HELPER_CONCURRENT_H_ diff --git a/core/src/index/thirdparty/SPTAG/AnnService/inc/Helper/ConcurrentSet.h b/core/src/index/thirdparty/SPTAG/AnnService/inc/Helper/ConcurrentSet.h new file mode 100644 index 0000000000..61254dc2eb --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/AnnService/inc/Helper/ConcurrentSet.h @@ -0,0 +1,148 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifndef _SPTAG_HELPER_CONCURRENTSET_H_ +#define _SPTAG_HELPER_CONCURRENTSET_H_ + +#include +#include + +namespace SPTAG +{ + namespace Helper + { + namespace Concurrent + { + template + class ConcurrentSet + { + public: + ConcurrentSet(); + + ~ConcurrentSet(); + + size_t size() const; + + bool contains(const T& key) const; + + void insert(const T& key); + + std::shared_timed_mutex& getLock(); + + bool save(std::ostream& output); + + bool save(std::string filename); + + bool load(std::string filename); + + bool load(char* pmemoryFile); + + std::uint64_t bufferSize() const; + + private: + std::unique_ptr m_lock; + std::unordered_set m_data; + }; + + template + ConcurrentSet::ConcurrentSet() + { + m_lock.reset(new std::shared_timed_mutex); + } + + template + ConcurrentSet::~ConcurrentSet() + { + } + + template + size_t ConcurrentSet::size() const + { + std::shared_lock lock(*m_lock); + return m_data.size(); + } + + template + bool ConcurrentSet::contains(const T& key) const + { + std::shared_lock lock(*m_lock); + return (m_data.find(key) != m_data.end()); + } + + template + void ConcurrentSet::insert(const T& key) + { + std::unique_lock lock(*m_lock); + m_data.insert(key); + } + + template + std::shared_timed_mutex& ConcurrentSet::getLock() + { + return *m_lock; + } + + template + std::uint64_t ConcurrentSet::bufferSize() const + { + return sizeof(SizeType) + sizeof(T) * m_data.size(); + } + + template + bool ConcurrentSet::save(std::ostream& output) + { + SizeType count = (SizeType)m_data.size(); + output.write((char*)&count, sizeof(SizeType)); + for (auto iter = m_data.begin(); iter != m_data.end(); iter++) + output.write((char*)&(*iter), sizeof(T)); + std::cout << "Save DeleteID (" << count << ") Finish!" << std::endl; + return true; + } + + template + bool ConcurrentSet::save(std::string filename) + { + std::cout << "Save DeleteID To " << filename << std::endl; + std::ofstream output(filename, std::ios::binary); + if (!output.is_open()) return false; + save(output); + output.close(); + return true; + } + + template + bool ConcurrentSet::load(std::string filename) + { + std::cout << "Load DeleteID From " << filename << std::endl; + std::ifstream input(filename, std::ios::binary); + if (!input.is_open()) return false; + + SizeType count; + T ID; + input.read((char*)&count, sizeof(SizeType)); + for (SizeType i = 0; i < count; i++) + { + input.read((char*)&ID, sizeof(T)); + m_data.insert(ID); + } + input.close(); + std::cout << "Load DeleteID (" << count << ") Finish!" << std::endl; + return true; + } + + template + bool ConcurrentSet::load(char* pmemoryFile) + { + SizeType count; + count = *((SizeType*)pmemoryFile); + pmemoryFile += sizeof(SizeType); + + m_data.insert((T*)pmemoryFile, ((T*)pmemoryFile) + count); + pmemoryFile += sizeof(T) * count; + std::cout << "Load DeleteID (" << count << ") Finish!" << std::endl; + return true; + } + } + } +} +#endif // _SPTAG_HELPER_CONCURRENTSET_H_ \ No newline at end of file diff --git a/core/src/index/thirdparty/SPTAG/AnnService/inc/Helper/SimpleIniReader.h b/core/src/index/thirdparty/SPTAG/AnnService/inc/Helper/SimpleIniReader.h new file mode 100644 index 0000000000..ad8d58f6f7 --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/AnnService/inc/Helper/SimpleIniReader.h @@ -0,0 +1,99 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifndef _SPTAG_HELPER_INIREADER_H_ +#define _SPTAG_HELPER_INIREADER_H_ + +#include "../Core/Common.h" +#include "StringConvert.h" + +#include +#include +#include +#include +#include + + +namespace SPTAG +{ +namespace Helper +{ + +// Simple INI Reader with basic functions. Case insensitive. +class IniReader +{ +public: + typedef std::map ParameterValueMap; + + IniReader(); + + ~IniReader(); + + ErrorCode LoadIniFile(const std::string& p_iniFilePath); + + ErrorCode LoadIni(std::istream& p_input); + + bool DoesSectionExist(const std::string& p_section) const; + + bool DoesParameterExist(const std::string& p_section, const std::string& p_param) const; + + const ParameterValueMap& GetParameters(const std::string& p_section) const; + + template + DataType GetParameter(const std::string& p_section, const std::string& p_param, const DataType& p_defaultVal) const; + + void SetParameter(const std::string& p_section, const std::string& p_param, const std::string& p_val); + +private: + bool GetRawValue(const std::string& p_section, const std::string& p_param, std::string& p_value) const; + + template + static inline DataType ConvertStringTo(std::string&& p_str, const DataType& p_defaultVal); + +private: + const static ParameterValueMap c_emptyParameters; + + std::map> m_parameters; +}; + + +template +DataType +IniReader::GetParameter(const std::string& p_section, const std::string& p_param, const DataType& p_defaultVal) const +{ + std::string value; + if (!GetRawValue(p_section, p_param, value)) + { + return p_defaultVal; + } + + return ConvertStringTo(std::move(value), p_defaultVal); +} + + +template +inline DataType +IniReader::ConvertStringTo(std::string&& p_str, const DataType& p_defaultVal) +{ + DataType value; + if (Convert::ConvertStringTo(p_str.c_str(), value)) + { + return value; + } + + return p_defaultVal; +} + + +template <> +inline std::string +IniReader::ConvertStringTo(std::string&& p_str, const std::string& p_defaultVal) +{ + return std::move(p_str); +} + + +} // namespace Helper +} // namespace SPTAG + +#endif // _SPTAG_HELPER_INIREADER_H_ diff --git a/core/src/index/thirdparty/SPTAG/AnnService/inc/Helper/StringConvert.h b/core/src/index/thirdparty/SPTAG/AnnService/inc/Helper/StringConvert.h new file mode 100644 index 0000000000..b6e53df785 --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/AnnService/inc/Helper/StringConvert.h @@ -0,0 +1,374 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifndef _SPTAG_HELPER_STRINGCONVERTHELPER_H_ +#define _SPTAG_HELPER_STRINGCONVERTHELPER_H_ + +#include "inc/Core/Common.h" +#include "CommonHelper.h" + +#include +#include +#include +#include +#include +#include + +namespace SPTAG +{ +namespace Helper +{ +namespace Convert +{ + +template +inline bool ConvertStringTo(const char* p_str, DataType& p_value) +{ + if (nullptr == p_str) + { + return false; + } + + std::istringstream sstream; + sstream.str(p_str); + if (p_str >> p_value) + { + return true; + } + + return false; +} + + +template +inline std::string ConvertToString(const DataType& p_value) +{ + return std::to_string(p_value); +} + + +// Specialization of ConvertStringTo<>(). + +template +inline bool ConvertStringToSignedInt(const char* p_str, DataType& p_value) +{ + static_assert(std::is_integral::value && std::is_signed::value, "type check"); + + if (nullptr == p_str) + { + return false; + } + + char* end = nullptr; + errno = 0; + auto val = std::strtoll(p_str, &end, 10); + if (errno == ERANGE || end == p_str || *end != '\0') + { + return false; + } + + if (val < (std::numeric_limits::min)() || val >(std::numeric_limits::max)()) + { + return false; + } + + p_value = static_cast(val); + return true; +} + + +template +inline bool ConvertStringToUnsignedInt(const char* p_str, DataType& p_value) +{ + static_assert(std::is_integral::value && std::is_unsigned::value, "type check"); + + if (nullptr == p_str) + { + return false; + } + + char* end = nullptr; + errno = 0; + auto val = std::strtoull(p_str, &end, 10); + if (errno == ERANGE || end == p_str || *end != '\0') + { + return false; + } + + if (val < (std::numeric_limits::min)() || val >(std::numeric_limits::max)()) + { + return false; + } + + p_value = static_cast(val); + return true; +} + + +template <> +inline bool ConvertStringTo(const char* p_str, std::string& p_value) +{ + if (nullptr == p_str) + { + return false; + } + + p_value = p_str; + return true; +} + + +template <> +inline bool ConvertStringTo(const char* p_str, float& p_value) +{ + if (nullptr == p_str) + { + return false; + } + + char* end = nullptr; + errno = 0; + p_value = std::strtof(p_str, &end); + return (errno != ERANGE && end != p_str && *end == '\0'); +} + + +template <> +inline bool ConvertStringTo(const char* p_str, double& p_value) +{ + if (nullptr == p_str) + { + return false; + } + + char* end = nullptr; + errno = 0; + p_value = std::strtod(p_str, &end); + return (errno != ERANGE && end != p_str && *end == '\0'); +} + + +template <> +inline bool ConvertStringTo(const char* p_str, std::int8_t& p_value) +{ + return ConvertStringToSignedInt(p_str, p_value); +} + + +template <> +inline bool ConvertStringTo(const char* p_str, std::int16_t& p_value) +{ + return ConvertStringToSignedInt(p_str, p_value); +} + + +template <> +inline bool ConvertStringTo(const char* p_str, std::int32_t& p_value) +{ + return ConvertStringToSignedInt(p_str, p_value); +} + + +template <> +inline bool ConvertStringTo(const char* p_str, std::int64_t& p_value) +{ + return ConvertStringToSignedInt(p_str, p_value); +} + + +template <> +inline bool ConvertStringTo(const char* p_str, std::uint8_t& p_value) +{ + return ConvertStringToUnsignedInt(p_str, p_value); +} + + +template <> +inline bool ConvertStringTo(const char* p_str, std::uint16_t& p_value) +{ + return ConvertStringToUnsignedInt(p_str, p_value); +} + + +template <> +inline bool ConvertStringTo(const char* p_str, std::uint32_t& p_value) +{ + return ConvertStringToUnsignedInt(p_str, p_value); +} + + +template <> +inline bool ConvertStringTo(const char* p_str, std::uint64_t& p_value) +{ + return ConvertStringToUnsignedInt(p_str, p_value); +} + + +template <> +inline bool ConvertStringTo(const char* p_str, bool& p_value) +{ + if (StrUtils::StrEqualIgnoreCase(p_str, "true")) + { + p_value = true; + + } + else if (StrUtils::StrEqualIgnoreCase(p_str, "false")) + { + p_value = false; + } + else + { + return false; + } + + return true; +} + + +template <> +inline bool ConvertStringTo(const char* p_str, IndexAlgoType& p_value) +{ + if (nullptr == p_str) + { + return false; + } + +#define DefineIndexAlgo(Name) \ + else if (StrUtils::StrEqualIgnoreCase(p_str, #Name)) \ + { \ + p_value = IndexAlgoType::Name; \ + return true; \ + } \ + +#include "inc/Core/DefinitionList.h" +#undef DefineIndexAlgo + + return false; +} + + +template <> +inline bool ConvertStringTo(const char* p_str, DistCalcMethod& p_value) +{ + if (nullptr == p_str) + { + return false; + } + +#define DefineDistCalcMethod(Name) \ + else if (StrUtils::StrEqualIgnoreCase(p_str, #Name)) \ + { \ + p_value = DistCalcMethod::Name; \ + return true; \ + } \ + +#include "inc/Core/DefinitionList.h" +#undef DefineDistCalcMethod + + return false; +} + + +template <> +inline bool ConvertStringTo(const char* p_str, VectorValueType& p_value) +{ + if (nullptr == p_str) + { + return false; + } + +#define DefineVectorValueType(Name, Type) \ + else if (StrUtils::StrEqualIgnoreCase(p_str, #Name)) \ + { \ + p_value = VectorValueType::Name; \ + return true; \ + } \ + +#include "inc/Core/DefinitionList.h" +#undef DefineVectorValueType + + return false; +} + + +// Specialization of ConvertToString<>(). + +template<> +inline std::string ConvertToString(const std::string& p_value) +{ + return p_value; +} + + +template<> +inline std::string ConvertToString(const bool& p_value) +{ + return p_value ? "true" : "false"; +} + + +template <> +inline std::string ConvertToString(const IndexAlgoType& p_value) +{ + switch (p_value) + { +#define DefineIndexAlgo(Name) \ + case IndexAlgoType::Name: \ + return #Name; \ + +#include "inc/Core/DefinitionList.h" +#undef DefineIndexAlgo + + default: + break; + } + + return "Undefined"; +} + + +template <> +inline std::string ConvertToString(const DistCalcMethod& p_value) +{ + switch (p_value) + { +#define DefineDistCalcMethod(Name) \ + case DistCalcMethod::Name: \ + return #Name; \ + +#include "inc/Core/DefinitionList.h" +#undef DefineDistCalcMethod + + default: + break; + } + + return "Undefined"; +} + + +template <> +inline std::string ConvertToString(const VectorValueType& p_value) +{ + switch (p_value) + { +#define DefineVectorValueType(Name, Type) \ + case VectorValueType::Name: \ + return #Name; \ + +#include "inc/Core/DefinitionList.h" +#undef DefineVectorValueType + + default: + break; + } + + return "Undefined"; +} + + +} // namespace Convert +} // namespace Helper +} // namespace SPTAG + +#endif // _SPTAG_HELPER_STRINGCONVERTHELPER_H_ diff --git a/core/src/index/thirdparty/SPTAG/AnnService/inc/Helper/VectorSetReader.h b/core/src/index/thirdparty/SPTAG/AnnService/inc/Helper/VectorSetReader.h new file mode 100644 index 0000000000..cd148c1d04 --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/AnnService/inc/Helper/VectorSetReader.h @@ -0,0 +1,59 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifndef _SPTAG_HELPER_VECTORSETREADER_H_ +#define _SPTAG_HELPER_VECTORSETREADER_H_ + +#include "inc/Core/Common.h" +#include "inc/Core/VectorSet.h" +#include "inc/Core/MetadataSet.h" +#include "inc/Helper/ArgumentsParser.h" + +#include + +namespace SPTAG +{ +namespace Helper +{ + +class ReaderOptions : public ArgumentsParser +{ +public: + ReaderOptions(VectorValueType p_valueType, DimensionType p_dimension, std::string p_vectorDelimiter = "|", std::uint32_t p_threadNum = 32); + + ~ReaderOptions(); + + std::uint32_t m_threadNum; + + DimensionType m_dimension; + + std::string m_vectorDelimiter; + + SPTAG::VectorValueType m_inputValueType; +}; + +class VectorSetReader +{ +public: + VectorSetReader(std::shared_ptr p_options); + + virtual ~VectorSetReader(); + + virtual ErrorCode LoadFile(const std::string& p_filePath) = 0; + + virtual std::shared_ptr GetVectorSet() const = 0; + + virtual std::shared_ptr GetMetadataSet() const = 0; + + static std::shared_ptr CreateInstance(std::shared_ptr p_options); + +protected: + std::shared_ptr m_options; +}; + + + +} // namespace Helper +} // namespace SPTAG + +#endif // _SPTAG_HELPER_VECTORSETREADER_H_ diff --git a/core/src/index/thirdparty/SPTAG/AnnService/inc/Helper/VectorSetReaders/DefaultReader.h b/core/src/index/thirdparty/SPTAG/AnnService/inc/Helper/VectorSetReaders/DefaultReader.h new file mode 100644 index 0000000000..52c8404caf --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/AnnService/inc/Helper/VectorSetReaders/DefaultReader.h @@ -0,0 +1,108 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifndef _SPTAG_HELPER_VECTORSETREADERS_DEFAULTREADER_H_ +#define _SPTAG_HELPER_VECTORSETREADERS_DEFAULTREADER_H_ + +#include "../VectorSetReader.h" +#include "inc/Helper/Concurrent.h" + +#include +#include +#include + +namespace SPTAG +{ +namespace Helper +{ + +class DefaultReader : public VectorSetReader +{ +public: + DefaultReader(std::shared_ptr p_options); + + virtual ~DefaultReader(); + + virtual ErrorCode LoadFile(const std::string& p_filePaths); + + virtual std::shared_ptr GetVectorSet() const; + + virtual std::shared_ptr GetMetadataSet() const; + +private: + typedef std::pair FileInfoPair; + + static std::vector GetFileSizes(const std::string& p_filePaths); + + void LoadFileInternal(const std::string& p_filePath, + std::uint32_t p_subtaskID, + std::uint32_t p_fileBlockID, + std::size_t p_fileBlockSize); + + void MergeData(); + + template + bool TranslateVector(char* p_str, DataType* p_vector) + { + DimensionType eleCount = 0; + char* next = p_str; + while ((*next) != '\0') + { + while ((*next) != '\0' && m_options->m_vectorDelimiter.find(*next) == std::string::npos) + { + ++next; + } + + bool reachEnd = ('\0' == (*next)); + *next = '\0'; + if (p_str != next) + { + if (eleCount >= m_options->m_dimension) + { + return false; + } + + if (!Helper::Convert::ConvertStringTo(p_str, p_vector[eleCount++])) + { + return false; + } + } + + if (reachEnd) + { + break; + } + + ++next; + p_str = next; + } + + return eleCount == m_options->m_dimension; + } + +private: + std::uint32_t m_subTaskCount; + + std::size_t m_subTaskBlocksize; + + std::atomic m_totalRecordCount; + + std::atomic m_totalRecordVectorBytes; + + std::vector m_subTaskRecordCount; + + std::string m_vectorOutput; + + std::string m_metadataConentOutput; + + std::string m_metadataIndexOutput; + + Helper::Concurrent::WaitSignal m_waitSignal; +}; + + + +} // namespace Helper +} // namespace SPTAG + +#endif // _SPTAG_HELPER_VECTORSETREADERS_DEFAULT_H_ diff --git a/core/src/index/thirdparty/SPTAG/AnnService/inc/IndexBuilder/Options.h b/core/src/index/thirdparty/SPTAG/AnnService/inc/IndexBuilder/Options.h new file mode 100644 index 0000000000..b3b3e21e58 --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/AnnService/inc/IndexBuilder/Options.h @@ -0,0 +1,39 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifndef _SPTAG_INDEXBUILDER_OPTIONS_H_ +#define _SPTAG_INDEXBUILDER_OPTIONS_H_ + +#include "inc/Core/Common.h" +#include "inc/Helper/VectorSetReader.h" + +#include +#include +#include + +namespace SPTAG +{ +namespace IndexBuilder +{ + +class BuilderOptions : public Helper::ReaderOptions +{ +public: + BuilderOptions(); + + ~BuilderOptions(); + + std::string m_inputFiles; + + std::string m_outputFolder; + + SPTAG::IndexAlgoType m_indexAlgoType; + + std::string m_builderConfigFile; +}; + + +} // namespace IndexBuilder +} // namespace SPTAG + +#endif // _SPTAG_INDEXBUILDER_OPTIONS_H_ diff --git a/core/src/index/thirdparty/SPTAG/AnnService/inc/IndexBuilder/ThreadPool.h b/core/src/index/thirdparty/SPTAG/AnnService/inc/IndexBuilder/ThreadPool.h new file mode 100644 index 0000000000..7256f71ae5 --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/AnnService/inc/IndexBuilder/ThreadPool.h @@ -0,0 +1,27 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifndef _SPTAG_INDEXBUILDER_THREADPOOL_H_ +#define _SPTAG_INDEXBUILDER_THREADPOOL_H_ + +#include +#include + +namespace SPTAG +{ +namespace IndexBuilder +{ +namespace ThreadPool +{ + +void Init(std::uint32_t p_threadNum); + +bool Queue(std::function p_workItem); + +std::uint32_t CurrentThreadNum(); + +} +} // namespace IndexBuilder +} // namespace SPTAG + +#endif // _SPTAG_INDEXBUILDER_THREADPOOL_H_ diff --git a/core/src/index/thirdparty/SPTAG/AnnService/inc/Server/QueryParser.h b/core/src/index/thirdparty/SPTAG/AnnService/inc/Server/QueryParser.h new file mode 100644 index 0000000000..9444e40862 --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/AnnService/inc/Server/QueryParser.h @@ -0,0 +1,56 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifndef _SPTAG_SERVER_QUERYPARSER_H_ +#define _SPTAG_SERVER_QUERYPARSER_H_ + +#include "../Core/Common.h" +#include "../Core/CommonDataStructure.h" + +#include + +namespace SPTAG +{ +namespace Service +{ + + +class QueryParser +{ +public: + typedef std::pair OptionPair; + + QueryParser(); + + ~QueryParser(); + + ErrorCode Parse(const std::string& p_query, const char* p_vectorSeparator); + + const std::vector& GetVectorElements() const; + + const std::vector& GetOptions() const; + + const char* GetVectorBase64() const; + + SizeType GetVectorBase64Length() const; + +private: + std::vector m_options; + + std::vector m_vectorElements; + + const char* m_vectorBase64; + + SizeType m_vectorBase64Length; + + ByteArray m_dataHolder; + + static const char* c_defaultVectorSeparator; +}; + + +} // namespace Server +} // namespace AnnService + + +#endif // _SPTAG_SERVER_QUERYPARSER_H_ diff --git a/core/src/index/thirdparty/SPTAG/AnnService/inc/Server/SearchExecutionContext.h b/core/src/index/thirdparty/SPTAG/AnnService/inc/Server/SearchExecutionContext.h new file mode 100644 index 0000000000..cba4df4651 --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/AnnService/inc/Server/SearchExecutionContext.h @@ -0,0 +1,80 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifndef _SPTAG_SERVER_SEARCHEXECUTIONCONTEXT_H_ +#define _SPTAG_SERVER_SEARCHEXECUTIONCONTEXT_H_ + +#include "inc/Core/VectorIndex.h" +#include "inc/Core/SearchQuery.h" +#include "inc/Socket/RemoteSearchQuery.h" +#include "ServiceSettings.h" +#include "QueryParser.h" + +#include +#include +#include + + +namespace SPTAG +{ +namespace Service +{ + +typedef Socket::IndexSearchResult SearchResult; + +class SearchExecutionContext +{ +public: + SearchExecutionContext(const std::shared_ptr& p_serviceSettings); + + ~SearchExecutionContext(); + + ErrorCode ParseQuery(const std::string& p_query); + + ErrorCode ExtractOption(); + + ErrorCode ExtractVector(VectorValueType p_targetType); + + void AddResults(std::string p_indexName, QueryResult& p_results); + + std::vector& GetResults(); + + const std::vector& GetResults() const; + + const ByteArray& GetVector() const; + + const std::vector& GetSelectedIndexNames() const; + + const SizeType GetVectorDimension() const; + + const std::vector& GetOptions() const; + + const SizeType GetResultNum() const; + + const bool GetExtractMetadata() const; + +private: + const std::shared_ptr c_serviceSettings; + + QueryParser m_queryParser; + + std::vector m_indexNames; + + ByteArray m_vector; + + SizeType m_vectorDimension; + + std::vector m_results; + + VectorValueType m_inputValueType; + + bool m_extractMetadata; + + SizeType m_resultNum; +}; + +} // namespace Server +} // namespace AnnService + + +#endif // _SPTAG_SERVER_SEARCHEXECUTIONCONTEXT_H_ diff --git a/core/src/index/thirdparty/SPTAG/AnnService/inc/Server/SearchExecutor.h b/core/src/index/thirdparty/SPTAG/AnnService/inc/Server/SearchExecutor.h new file mode 100644 index 0000000000..201832651b --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/AnnService/inc/Server/SearchExecutor.h @@ -0,0 +1,56 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifndef _SPTAG_SERVER_SEARCHEXECUTOR_H_ +#define _SPTAG_SERVER_SEARCHEXECUTOR_H_ + +#include "ServiceContext.h" +#include "ServiceSettings.h" +#include "SearchExecutionContext.h" +#include "QueryParser.h" + +#include +#include +#include + +namespace SPTAG +{ +namespace Service +{ + +class SearchExecutor +{ +public: + typedef std::function)> CallBack; + + SearchExecutor(std::string p_queryString, + std::shared_ptr p_serviceContext, + const CallBack& p_callback); + + ~SearchExecutor(); + + void Execute(); + +private: + void ExecuteInternal(); + + void SelectIndex(); + +private: + CallBack m_callback; + + const std::shared_ptr c_serviceContext; + + std::shared_ptr m_executionContext; + + std::string m_queryString; + + std::vector> m_selectedIndex; +}; + + +} // namespace Server +} // namespace AnnService + + +#endif // _SPTAG_SERVER_SEARCHEXECUTOR_H_ diff --git a/core/src/index/thirdparty/SPTAG/AnnService/inc/Server/SearchService.h b/core/src/index/thirdparty/SPTAG/AnnService/inc/Server/SearchService.h new file mode 100644 index 0000000000..34d0c6064c --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/AnnService/inc/Server/SearchService.h @@ -0,0 +1,73 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifndef _SPTAG_SERVER_SERVICE_H_ +#define _SPTAG_SERVER_SERVICE_H_ + +#include "ServiceContext.h" +#include "../Socket/Server.h" + +#include + +#include +#include +#include +#include + +namespace SPTAG +{ +namespace Service +{ + +class SearchExecutionContext; + +class SearchService +{ +public: + SearchService(); + + ~SearchService(); + + bool Initialize(int p_argNum, char* p_args[]); + + void Run(); + +private: + void RunSocketMode(); + + void RunInteractiveMode(); + + void SearchHanlder(Socket::ConnectionID p_localConnectionID, Socket::Packet p_packet); + + void SearchHanlderCallback(std::shared_ptr p_exeContext, + Socket::Packet p_srcPacket); + +private: + enum class ServeMode : std::uint8_t + { + Interactive, + + Socket + }; + + std::shared_ptr m_serviceContext; + + std::shared_ptr m_socketServer; + + bool m_initialized; + + ServeMode m_serveMode; + + std::unique_ptr m_threadPool; + + boost::asio::io_context m_ioContext; + + boost::asio::signal_set m_shutdownSignals; +}; + + +} // namespace Server +} // namespace AnnService + + +#endif // _SPTAG_SERVER_SERVICE_H_ diff --git a/core/src/index/thirdparty/SPTAG/AnnService/inc/Server/ServiceContext.h b/core/src/index/thirdparty/SPTAG/AnnService/inc/Server/ServiceContext.h new file mode 100644 index 0000000000..b1a7b84045 --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/AnnService/inc/Server/ServiceContext.h @@ -0,0 +1,44 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifndef _SPTAG_SERVER_SERVICECONTEX_H_ +#define _SPTAG_SERVER_SERVICECONTEX_H_ + +#include "inc/Core/VectorIndex.h" +#include "ServiceSettings.h" + +#include +#include + +namespace SPTAG +{ +namespace Service +{ + +class ServiceContext +{ +public: + ServiceContext(const std::string& p_configFilePath); + + ~ServiceContext(); + + const std::map>& GetIndexMap() const; + + const std::shared_ptr& GetServiceSettings() const; + + bool IsInitialized() const; + +private: + bool m_initialized; + + std::shared_ptr m_settings; + + std::map> m_fullIndexList; +}; + + +} // namespace Server +} // namespace AnnService + +#endif // _SPTAG_SERVER_SERVICECONTEX_H_ + diff --git a/core/src/index/thirdparty/SPTAG/AnnService/inc/Server/ServiceSettings.h b/core/src/index/thirdparty/SPTAG/AnnService/inc/Server/ServiceSettings.h new file mode 100644 index 0000000000..9077487355 --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/AnnService/inc/Server/ServiceSettings.h @@ -0,0 +1,41 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifndef _SPTAG_SERVER_SERVICESTTINGS_H_ +#define _SPTAG_SERVER_SERVICESTTINGS_H_ + +#include "../Core/Common.h" + +#include + +namespace SPTAG +{ +namespace Service +{ + +struct ServiceSettings +{ + ServiceSettings(); + + std::string m_vectorSeparator; + + std::string m_listenAddr; + + std::string m_listenPort; + + SizeType m_defaultMaxResultNumber; + + SizeType m_threadNum; + + SizeType m_socketThreadNum; +}; + + + + +} // namespace Server +} // namespace AnnService + + +#endif // _SPTAG_SERVER_SERVICESTTINGS_H_ + diff --git a/core/src/index/thirdparty/SPTAG/AnnService/inc/Socket/Client.h b/core/src/index/thirdparty/SPTAG/AnnService/inc/Socket/Client.h new file mode 100644 index 0000000000..a57465dfd7 --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/AnnService/inc/Socket/Client.h @@ -0,0 +1,68 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifndef _SPTAG_SOCKET_CLIENT_H_ +#define _SPTAG_SOCKET_CLIENT_H_ + +#include "inc/Core/Common.h" +#include "Connection.h" +#include "ConnectionManager.h" +#include "Packet.h" + +#include +#include +#include +#include + +namespace SPTAG +{ +namespace Socket +{ + +class Client +{ +public: + typedef std::function ConnectCallback; + + Client(const PacketHandlerMapPtr& p_handlerMap, + std::size_t p_threadNum, + std::uint32_t p_heartbeatIntervalSeconds); + + ~Client(); + + ConnectionID ConnectToServer(const std::string& p_address, + const std::string& p_port, + SPTAG::ErrorCode& p_ec); + + void AsyncConnectToServer(const std::string& p_address, + const std::string& p_port, + ConnectCallback p_callback); + + void SendPacket(ConnectionID p_connection, Packet p_packet, std::function p_callback); + + void SetEventOnConnectionClose(std::function p_event); + +private: + void KeepIoContext(); + +private: + std::atomic_bool m_stopped; + + std::uint32_t m_heartbeatIntervalSeconds; + + boost::asio::io_context m_ioContext; + + boost::asio::deadline_timer m_deadlineTimer; + + std::shared_ptr m_connectionManager; + + std::vector m_threadPool; + + const PacketHandlerMapPtr c_requestHandlerMap; +}; + + +} // namespace Socket +} // namespace SPTAG + +#endif // _SPTAG_SOCKET_CLIENT_H_ diff --git a/core/src/index/thirdparty/SPTAG/AnnService/inc/Socket/Common.h b/core/src/index/thirdparty/SPTAG/AnnService/inc/Socket/Common.h new file mode 100644 index 0000000000..dc06af1bb4 --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/AnnService/inc/Socket/Common.h @@ -0,0 +1,25 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifndef _SPTAG_SOCKET_COMMON_H_ +#define _SPTAG_SOCKET_COMMON_H_ + +#include + +namespace SPTAG +{ +namespace Socket +{ + +typedef std::uint32_t ConnectionID; + +typedef std::uint32_t ResourceID; + +extern const ConnectionID c_invalidConnectionID; + +extern const ResourceID c_invalidResourceID; + +} // namespace Socket +} // namespace SPTAG + +#endif // _SPTAG_SOCKET_COMMON_H_ diff --git a/core/src/index/thirdparty/SPTAG/AnnService/inc/Socket/Connection.h b/core/src/index/thirdparty/SPTAG/AnnService/inc/Socket/Connection.h new file mode 100644 index 0000000000..1d75d093b3 --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/AnnService/inc/Socket/Connection.h @@ -0,0 +1,100 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifndef _SPTAG_SOCKET_CONNECTION_H_ +#define _SPTAG_SOCKET_CONNECTION_H_ + +#include "Packet.h" + +#include +#include +#include + +#include +#include +#include + +namespace SPTAG +{ +namespace Socket +{ + +class ConnectionManager; + +class Connection : public std::enable_shared_from_this +{ +public: + typedef std::shared_ptr Ptr; + + Connection(ConnectionID p_connectionID, + boost::asio::ip::tcp::socket&& p_socket, + const PacketHandlerMapPtr& p_handlerMap, + std::weak_ptr p_connectionManager); + + void Start(); + + void Stop(); + + void StartHeartbeat(std::size_t p_intervalSeconds); + + void AsyncSend(Packet p_packet, std::function p_callback); + + ConnectionID GetConnectionID() const; + + ConnectionID GetRemoteConnectionID() const; + + Connection(const Connection&) = delete; + Connection& operator=(const Connection&) = delete; + +private: + void AsyncReadHeader(); + + void AsyncReadBody(); + + void HandleReadHeader(boost::system::error_code p_ec, std::size_t p_bytesTransferred); + + void HandleReadBody(boost::system::error_code p_ec, std::size_t p_bytesTransferred); + + void SendHeartbeat(std::size_t p_intervalSeconds); + + void SendRegister(); + + void HandleHeartbeatRequest(); + + void HandleRegisterRequest(); + + void HandleRegisterResponse(); + + void HandleNoHandlerResponse(); + + void OnConnectionFail(const boost::system::error_code& p_ec); + +private: + const ConnectionID c_connectionID; + + ConnectionID m_remoteConnectionID; + + const std::weak_ptr c_connectionManager; + + const PacketHandlerMapPtr c_handlerMap; + + boost::asio::ip::tcp::socket m_socket; + + boost::asio::io_context::strand m_strand; + + boost::asio::deadline_timer m_heartbeatTimer; + + std::uint8_t m_packetHeaderReadBuffer[PacketHeader::c_bufferSize]; + + Packet m_packetRead; + + std::atomic_bool m_stopped; + + std::atomic_bool m_heartbeatStarted; +}; + + +} // namespace Socket +} // namespace SPTAG + +#endif // _SPTAG_SOCKET_CONNECTION_H_ diff --git a/core/src/index/thirdparty/SPTAG/AnnService/inc/Socket/ConnectionManager.h b/core/src/index/thirdparty/SPTAG/AnnService/inc/Socket/ConnectionManager.h new file mode 100644 index 0000000000..e487c61053 --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/AnnService/inc/Socket/ConnectionManager.h @@ -0,0 +1,73 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifndef _SPTAG_SOCKET_CONNECTIONMANAGER_H_ +#define _SPTAG_SOCKET_CONNECTIONMANAGER_H_ + +#include "Connection.h" +#include "inc/Helper/Concurrent.h" + +#include +#include +#include +#include +#include + +#include + +namespace SPTAG +{ +namespace Socket +{ + +class ConnectionManager : public std::enable_shared_from_this +{ +public: + ConnectionManager(); + + ConnectionID AddConnection(boost::asio::ip::tcp::socket&& p_socket, + const PacketHandlerMapPtr& p_handlerMap, + std::uint32_t p_heartbeatIntervalSeconds); + + void RemoveConnection(ConnectionID p_connectionID); + + Connection::Ptr GetConnection(ConnectionID p_connectionID); + + void SetEventOnRemoving(std::function p_event); + + void StopAll(); + +private: + inline static std::uint32_t GetPosition(ConnectionID p_connectionID); + +private: + static constexpr std::uint32_t c_connectionPoolSize = 1 << 8; + + static constexpr std::uint32_t c_connectionPoolMask = c_connectionPoolSize - 1; + + struct ConnectionItem + { + ConnectionItem(); + + std::atomic_bool m_isEmpty; + + Connection::Ptr m_connection; + }; + + // Start from 1. 0 means not assigned. + std::atomic m_nextConnectionID; + + std::atomic m_connectionCount; + + std::array m_connections; + + Helper::Concurrent::SpinLock m_spinLock; + + std::function m_eventOnRemoving; +}; + + +} // namespace Socket +} // namespace SPTAG + +#endif // _SPTAG_SOCKET_CONNECTIONMANAGER_H_ diff --git a/core/src/index/thirdparty/SPTAG/AnnService/inc/Socket/Packet.h b/core/src/index/thirdparty/SPTAG/AnnService/inc/Socket/Packet.h new file mode 100644 index 0000000000..8c99b09fed --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/AnnService/inc/Socket/Packet.h @@ -0,0 +1,142 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifndef _SPTAG_SOCKET_PACKET_H_ +#define _SPTAG_SOCKET_PACKET_H_ + +#include "Common.h" + +#include +#include +#include +#include +#include + +namespace SPTAG +{ +namespace Socket +{ + +enum class PacketType : std::uint8_t +{ + Undefined = 0x00, + + HeartbeatRequest = 0x01, + + RegisterRequest = 0x02, + + SearchRequest = 0x03, + + ResponseMask = 0x80, + + HeartbeatResponse = ResponseMask | HeartbeatRequest, + + RegisterResponse = ResponseMask | RegisterRequest, + + SearchResponse = ResponseMask | SearchRequest +}; + + +enum class PacketProcessStatus : std::uint8_t +{ + Ok = 0x00, + + Timeout = 0x01, + + Dropped = 0x02, + + Failed = 0x03 +}; + + +struct PacketHeader +{ + static constexpr std::size_t c_bufferSize = 16; + + PacketHeader(); + PacketHeader(PacketHeader&& p_right); + PacketHeader(const PacketHeader& p_right); + + std::size_t WriteBuffer(std::uint8_t* p_buffer); + + void ReadBuffer(const std::uint8_t* p_buffer); + + PacketType m_packetType; + + PacketProcessStatus m_processStatus; + + std::uint32_t m_bodyLength; + + // Meaning of this is different with different PacketType. + // In most request case, it means connection expeced for response. + // In most response case, it means connection which handled request. + ConnectionID m_connectionID; + + ResourceID m_resourceID; +}; + + +static_assert(sizeof(PacketHeader) <= PacketHeader::c_bufferSize, ""); + + +class Packet +{ +public: + Packet(); + Packet(Packet&& p_right); + Packet(const Packet& p_right); + + PacketHeader& Header(); + + std::uint8_t* HeaderBuffer() const; + + std::uint8_t* Body() const; + + std::uint8_t* Buffer() const; + + std::uint32_t BufferLength() const; + + std::uint32_t BufferCapacity() const; + + void AllocateBuffer(std::uint32_t p_bodyCapacity); + +private: + PacketHeader m_header; + + std::shared_ptr m_buffer; + + std::uint32_t m_bufferCapacity; +}; + + +struct PacketTypeHash +{ + std::size_t operator()(const PacketType& p_val) const + { + return static_cast(p_val); + } +}; + + +typedef std::function PacketHandler; + +typedef std::unordered_map PacketHandlerMap; +typedef std::shared_ptr PacketHandlerMapPtr; + + +namespace PacketTypeHelper +{ + +bool IsRequestPacket(PacketType p_type); + +bool IsResponsePacket(PacketType p_type); + +PacketType GetCrosspondingResponseType(PacketType p_type); + +} + + +} // namespace SPTAG +} // namespace Socket + +#endif // _SPTAG_SOCKET_SOCKETSERVER_H_ diff --git a/core/src/index/thirdparty/SPTAG/AnnService/inc/Socket/RemoteSearchQuery.h b/core/src/index/thirdparty/SPTAG/AnnService/inc/Socket/RemoteSearchQuery.h new file mode 100644 index 0000000000..900aa6cb16 --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/AnnService/inc/Socket/RemoteSearchQuery.h @@ -0,0 +1,99 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifndef _SPTAG_SOCKET_REMOTESEARCHQUERY_H_ +#define _SPTAG_SOCKET_REMOTESEARCHQUERY_H_ + +#include "inc/Core/CommonDataStructure.h" +#include "inc/Core/SearchQuery.h" + +#include +#include +#include +#include +#include + +namespace SPTAG +{ +namespace Socket +{ + +// TODO: use Bond replace below structures. + +struct RemoteQuery +{ + static constexpr std::uint16_t MajorVersion() { return 1; } + static constexpr std::uint16_t MirrorVersion() { return 0; } + + enum class QueryType : std::uint8_t + { + String = 0 + }; + + RemoteQuery(); + + std::size_t EstimateBufferSize() const; + + std::uint8_t* Write(std::uint8_t* p_buffer) const; + + const std::uint8_t* Read(const std::uint8_t* p_buffer); + + + QueryType m_type; + + std::string m_queryString; +}; + + +struct IndexSearchResult +{ + std::string m_indexName; + + QueryResult m_results; +}; + + +struct RemoteSearchResult +{ + static constexpr std::uint16_t MajorVersion() { return 1; } + static constexpr std::uint16_t MirrorVersion() { return 0; } + + enum class ResultStatus : std::uint8_t + { + Success = 0, + + Timeout = 1, + + FailedNetwork = 2, + + FailedExecute = 3, + + Dropped = 4 + }; + + RemoteSearchResult(); + + RemoteSearchResult(const RemoteSearchResult& p_right); + + RemoteSearchResult(RemoteSearchResult&& p_right); + + RemoteSearchResult& operator=(RemoteSearchResult&& p_right); + + std::size_t EstimateBufferSize() const; + + std::uint8_t* Write(std::uint8_t* p_buffer) const; + + const std::uint8_t* Read(const std::uint8_t* p_buffer); + + + ResultStatus m_status; + + std::vector m_allIndexResults; +}; + + + +} // namespace SPTAG +} // namespace Socket + +#endif // _SPTAG_SOCKET_REMOTESEARCHQUERY_H_ diff --git a/core/src/index/thirdparty/SPTAG/AnnService/inc/Socket/ResourceManager.h b/core/src/index/thirdparty/SPTAG/AnnService/inc/Socket/ResourceManager.h new file mode 100644 index 0000000000..404cac830f --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/AnnService/inc/Socket/ResourceManager.h @@ -0,0 +1,190 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifndef _SPTAG_SOCKET_RESOURCEMANAGER_H_ +#define _SPTAG_SOCKET_RESOURCEMANAGER_H_ + +#include "Common.h" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace std +{ +typedef atomic atomic_uint32_t; +} + +namespace SPTAG +{ +namespace Socket +{ + +template +class ResourceManager : public std::enable_shared_from_this> +{ +public: + typedef std::function)> TimeoutCallback; + + ResourceManager() + : m_nextResourceID(1), + m_isStopped(false), + m_timeoutItemCount(0) + { + m_timeoutChecker = std::thread(&ResourceManager::StartCheckTimeout, this); + } + + + ~ResourceManager() + { + m_isStopped = true; + m_timeoutChecker.join(); + } + + + ResourceID Add(const std::shared_ptr& p_resource, + std::uint32_t p_timeoutMilliseconds, + TimeoutCallback p_timeoutCallback) + { + ResourceID rid = m_nextResourceID.fetch_add(1); + while (c_invalidResourceID == rid) + { + rid = m_nextResourceID.fetch_add(1); + } + + { + std::lock_guard guard(m_resourcesMutex); + m_resources.emplace(rid, p_resource); + } + + if (p_timeoutMilliseconds > 0) + { + std::unique_ptr item(new ResourceItem); + + item->m_resourceID = rid; + item->m_callback = std::move(p_timeoutCallback); + item->m_expireTime = m_clock.now() + std::chrono::milliseconds(p_timeoutMilliseconds); + + { + std::lock_guard guard(m_timeoutListMutex); + m_timeoutList.emplace_back(std::move(item)); + } + + ++m_timeoutItemCount; + } + + return rid; + } + + + std::shared_ptr GetAndRemove(ResourceID p_resourceID) + { + std::shared_ptr ret; + std::lock_guard guard(m_resourcesMutex); + auto iter = m_resources.find(p_resourceID); + if (iter != m_resources.end()) + { + ret = iter->second; + m_resources.erase(iter); + } + + return ret; + } + + + void Remove(ResourceID p_resourceID) + { + std::lock_guard guard(m_resourcesMutex); + auto iter = m_resources.find(p_resourceID); + if (iter != m_resources.end()) + { + m_resources.erase(iter); + } + } + +private: + void StartCheckTimeout() + { + std::vector> timeouted; + timeouted.reserve(1024); + while (!m_isStopped) + { + if (m_timeoutItemCount > 0) + { + std::lock_guard guard(m_timeoutListMutex); + while (!m_timeoutList.empty() + && m_timeoutList.front()->m_expireTime <= m_clock.now()) + { + timeouted.emplace_back(std::move(m_timeoutList.front())); + m_timeoutList.pop_front(); + --m_timeoutItemCount; + } + } + + if (timeouted.empty()) + { + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + } + else + { + for (auto& item : timeouted) + { + auto resource = GetAndRemove(item->m_resourceID); + if (nullptr != resource) + { + item->m_callback(std::move(resource)); + } + } + + timeouted.clear(); + } + } + } + + +private: + struct ResourceItem + { + ResourceItem() + : m_resourceID(c_invalidResourceID) + { + } + + ResourceID m_resourceID; + + TimeoutCallback m_callback; + + std::chrono::time_point m_expireTime; + }; + + std::deque> m_timeoutList; + + std::atomic m_timeoutItemCount; + + std::mutex m_timeoutListMutex; + + std::unordered_map> m_resources; + + std::atomic m_nextResourceID; + + std::mutex m_resourcesMutex; + + std::chrono::high_resolution_clock m_clock; + + std::thread m_timeoutChecker; + + bool m_isStopped; +}; + + +} // namespace Socket +} // namespace SPTAG + +#endif // _SPTAG_SOCKET_RESOURCEMANAGER_H_ diff --git a/core/src/index/thirdparty/SPTAG/AnnService/inc/Socket/Server.h b/core/src/index/thirdparty/SPTAG/AnnService/inc/Socket/Server.h new file mode 100644 index 0000000000..aac97bf84b --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/AnnService/inc/Socket/Server.h @@ -0,0 +1,55 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifndef _SPTAG_SOCKET_SERVER_H_ +#define _SPTAG_SOCKET_SERVER_H_ + +#include "Connection.h" +#include "ConnectionManager.h" +#include "Packet.h" + +#include +#include +#include + +namespace SPTAG +{ +namespace Socket +{ + +class Server +{ +public: + Server(const std::string& p_address, + const std::string& p_port, + const PacketHandlerMapPtr& p_handlerMap, + std::size_t p_threadNum); + + ~Server(); + + void StartListen(); + + void SendPacket(ConnectionID p_connection, Packet p_packet, std::function p_callback); + + void SetEventOnConnectionClose(std::function p_event); + +private: + void StartAccept(); + +private: + boost::asio::io_context m_ioContext; + + boost::asio::ip::tcp::acceptor m_acceptor; + + std::shared_ptr m_connectionManager; + + std::vector m_threadPool; + + const PacketHandlerMapPtr m_requestHandlerMap; +}; + + +} // namespace Socket +} // namespace SPTAG + +#endif // _SPTAG_SOCKET_SERVER_H_ diff --git a/core/src/index/thirdparty/SPTAG/AnnService/inc/Socket/SimpleSerialization.h b/core/src/index/thirdparty/SPTAG/AnnService/inc/Socket/SimpleSerialization.h new file mode 100644 index 0000000000..6da925625b --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/AnnService/inc/Socket/SimpleSerialization.h @@ -0,0 +1,174 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifndef _SPTAG_SOCKET_SIMPLESERIALIZATION_H_ +#define _SPTAG_SOCKET_SIMPLESERIALIZATION_H_ + +#include "inc/Core/CommonDataStructure.h" + +#include +#include +#include +#include + +namespace SPTAG +{ +namespace Socket +{ +namespace SimpleSerialization +{ + + template + inline std::uint8_t* + SimpleWriteBuffer(const T& p_val, std::uint8_t* p_buffer) + { + static_assert(std::is_fundamental::value || std::is_enum::value, + "Only applied for fundanmental type."); + + *(reinterpret_cast(p_buffer)) = p_val; + return p_buffer + sizeof(T); + } + + + template + inline const std::uint8_t* + SimpleReadBuffer(const std::uint8_t* p_buffer, T& p_val) + { + static_assert(std::is_fundamental::value || std::is_enum::value, + "Only applied for fundanmental type."); + + p_val = *(reinterpret_cast(p_buffer)); + return p_buffer + sizeof(T); + } + + + template + inline std::size_t + EstimateBufferSize(const T& p_val) + { + static_assert(std::is_fundamental::value || std::is_enum::value, + "Only applied for fundanmental type."); + + return sizeof(T); + } + + + template<> + inline std::uint8_t* + SimpleWriteBuffer(const std::string& p_val, std::uint8_t* p_buffer) + { + p_buffer = SimpleWriteBuffer(static_cast(p_val.size()), p_buffer); + + std::memcpy(p_buffer, p_val.c_str(), p_val.size()); + return p_buffer + p_val.size(); + } + + + template<> + inline const std::uint8_t* + SimpleReadBuffer(const std::uint8_t* p_buffer, std::string& p_val) + { + p_val.clear(); + std::uint32_t len = 0; + p_buffer = SimpleReadBuffer(p_buffer, len); + + if (len > 0) + { + p_val.reserve(len); + p_val.assign(reinterpret_cast(p_buffer), len); + } + + return p_buffer + len; + } + + + template<> + inline std::size_t + EstimateBufferSize(const std::string& p_val) + { + return sizeof(std::uint32_t) + p_val.size(); + } + + + template<> + inline std::uint8_t* + SimpleWriteBuffer(const ByteArray& p_val, std::uint8_t* p_buffer) + { + p_buffer = SimpleWriteBuffer(static_cast(p_val.Length()), p_buffer); + + std::memcpy(p_buffer, p_val.Data(), p_val.Length()); + return p_buffer + p_val.Length(); + } + + + template<> + inline const std::uint8_t* + SimpleReadBuffer(const std::uint8_t* p_buffer, ByteArray& p_val) + { + p_val.Clear(); + std::uint32_t len = 0; + p_buffer = SimpleReadBuffer(p_buffer, len); + + if (len > 0) + { + p_val = ByteArray::Alloc(len); + std::memcpy(p_val.Data(), p_buffer, len); + } + + return p_buffer + len; + } + + + template<> + inline std::size_t + EstimateBufferSize(const ByteArray& p_val) + { + return sizeof(std::uint32_t) + p_val.Length(); + } + + + template + inline std::uint8_t* + SimpleWriteSharedPtrBuffer(const std::shared_ptr& p_val, std::uint8_t* p_buffer) + { + if (nullptr == p_val) + { + return SimpleWriteBuffer(false, p_buffer); + } + + p_buffer = SimpleWriteBuffer(true, p_buffer); + p_buffer = SimpleWriteBuffer(*p_val, p_buffer); + return p_buffer; + } + + + template + inline const std::uint8_t* + SimpleReadSharedPtrBuffer(const std::uint8_t* p_buffer, std::shared_ptr& p_val) + { + p_val.reset(); + bool isNotNull = false; + p_buffer = SimpleReadBuffer(p_buffer, isNotNull); + + if (isNotNull) + { + p_val.reset(new T); + p_buffer = SimpleReadBuffer(p_buffer, *p_val); + } + + return p_buffer; + } + + + template + inline std::size_t + EstimateSharedPtrBufferSize(const std::shared_ptr& p_val) + { + return sizeof(bool) + (nullptr == p_val ? 0 : EstimateBufferSize(*p_val)); + } + +} // namespace SimpleSerialization +} // namespace SPTAG +} // namespace Socket + +#endif // _SPTAG_SOCKET_SIMPLESERIALIZATION_H_ diff --git a/core/src/index/thirdparty/SPTAG/AnnService/packages.config b/core/src/index/thirdparty/SPTAG/AnnService/packages.config new file mode 100644 index 0000000000..2dbed9b530 --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/AnnService/packages.config @@ -0,0 +1,10 @@ + + + + + + + + + + \ No newline at end of file diff --git a/core/src/index/thirdparty/SPTAG/AnnService/src/Aggregator/AggregatorContext.cpp b/core/src/index/thirdparty/SPTAG/AnnService/src/Aggregator/AggregatorContext.cpp new file mode 100644 index 0000000000..a36c2c61e9 --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/AnnService/src/Aggregator/AggregatorContext.cpp @@ -0,0 +1,86 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "inc/Aggregator/AggregatorContext.h" +#include "inc/Helper/SimpleIniReader.h" + +using namespace SPTAG; +using namespace SPTAG::Aggregator; + +RemoteMachine::RemoteMachine() + : m_connectionID(Socket::c_invalidConnectionID), + m_status(RemoteMachineStatus::Disconnected) +{ +} + + +AggregatorContext::AggregatorContext(const std::string& p_filePath) + : m_initialized(false) +{ + Helper::IniReader iniReader; + if (ErrorCode::Success != iniReader.LoadIniFile(p_filePath)) + { + return; + } + + m_settings.reset(new AggregatorSettings); + + m_settings->m_listenAddr = iniReader.GetParameter("Service", "ListenAddr", std::string("0.0.0.0")); + m_settings->m_listenPort = iniReader.GetParameter("Service", "ListenPort", std::string("8100")); + m_settings->m_threadNum = iniReader.GetParameter("Service", "ThreadNumber", static_cast(8)); + m_settings->m_socketThreadNum = iniReader.GetParameter("Service", "SocketThreadNumber", static_cast(8)); + + const std::string emptyStr; + + std::uint32_t serverNum = iniReader.GetParameter("Servers", "Number", static_cast(0)); + + for (std::uint32_t i = 0; i < serverNum; ++i) + { + std::string sectionName("Server_"); + sectionName += std::to_string(i); + if (!iniReader.DoesSectionExist(sectionName)) + { + continue; + } + + std::shared_ptr remoteMachine(new RemoteMachine); + + remoteMachine->m_address = iniReader.GetParameter(sectionName, "Address", emptyStr); + remoteMachine->m_port = iniReader.GetParameter(sectionName, "Port", emptyStr); + + if (remoteMachine->m_address.empty() || remoteMachine->m_port.empty()) + { + continue; + } + + m_remoteServers.push_back(std::move(remoteMachine)); + } + + m_initialized = true; +} + + +AggregatorContext::~AggregatorContext() +{ +} + + +bool +AggregatorContext::IsInitialized() const +{ + return m_initialized; +} + + +const std::vector>& +AggregatorContext::GetRemoteServers() const +{ + return m_remoteServers; +} + + +const std::shared_ptr& +AggregatorContext::GetSettings() const +{ + return m_settings; +} diff --git a/core/src/index/thirdparty/SPTAG/AnnService/src/Aggregator/AggregatorExecutionContext.cpp b/core/src/index/thirdparty/SPTAG/AnnService/src/Aggregator/AggregatorExecutionContext.cpp new file mode 100644 index 0000000000..8f7a28375a --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/AnnService/src/Aggregator/AggregatorExecutionContext.cpp @@ -0,0 +1,51 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "inc/Aggregator/AggregatorExecutionContext.h" + +using namespace SPTAG; +using namespace SPTAG::Aggregator; + +AggregatorExecutionContext::AggregatorExecutionContext(std::size_t p_totalServerNumber, + Socket::PacketHeader p_requestHeader) + : m_requestHeader(std::move(p_requestHeader)) +{ + m_results.clear(); + m_results.resize(p_totalServerNumber); + + m_unfinishedCount = static_cast(p_totalServerNumber); +} + + +AggregatorExecutionContext::~AggregatorExecutionContext() +{ +} + + +std::size_t +AggregatorExecutionContext::GetServerNumber() const +{ + return m_results.size(); +} + + +AggregatorResult& +AggregatorExecutionContext::GetResult(std::size_t p_num) +{ + return m_results[p_num]; +} + + +const Socket::PacketHeader& +AggregatorExecutionContext::GetRequestHeader() const +{ + return m_requestHeader; +} + + +bool +AggregatorExecutionContext::IsCompletedAfterFinsh(std::uint32_t p_finishedCount) +{ + auto lastCount = m_unfinishedCount.fetch_sub(p_finishedCount); + return lastCount <= p_finishedCount; +} diff --git a/core/src/index/thirdparty/SPTAG/AnnService/src/Aggregator/AggregatorService.cpp b/core/src/index/thirdparty/SPTAG/AnnService/src/Aggregator/AggregatorService.cpp new file mode 100644 index 0000000000..24c1672cde --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/AnnService/src/Aggregator/AggregatorService.cpp @@ -0,0 +1,366 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "inc/Aggregator/AggregatorService.h" + +using namespace SPTAG; +using namespace SPTAG::Aggregator; + +AggregatorService::AggregatorService() + : m_shutdownSignals(m_ioContext), + m_pendingConnectServersTimer(m_ioContext) +{ +} + + +AggregatorService::~AggregatorService() +{ +} + + +bool +AggregatorService::Initialize() +{ + std::string configFilePath = "Aggregator.ini"; + m_aggregatorContext.reset(new AggregatorContext(configFilePath)); + + m_initalized = m_aggregatorContext->IsInitialized(); + + return m_initalized; +} + + +void +AggregatorService::Run() +{ + auto threadNum = max((SPTAG::SizeType)1, GetContext()->GetSettings()->m_threadNum); + m_threadPool.reset(new boost::asio::thread_pool(threadNum)); + + StartClient(); + StartListen(); + WaitForShutdown(); +} + + +void +AggregatorService::StartClient() +{ + auto context = GetContext(); + Socket::PacketHandlerMapPtr handlerMap(new Socket::PacketHandlerMap); + handlerMap->emplace(Socket::PacketType::SearchResponse, + [this](Socket::ConnectionID p_srcID, Socket::Packet p_packet) + { + boost::asio::post(*m_threadPool, + std::bind(&AggregatorService::SearchResponseHanlder, + this, + p_srcID, + std::move(p_packet))); + }); + + + m_socketClient.reset(new Socket::Client(handlerMap, + context->GetSettings()->m_socketThreadNum, + 30)); + + m_socketClient->SetEventOnConnectionClose([this](Socket::ConnectionID p_cid) + { + auto context = this->GetContext(); + for (const auto& server : context->GetRemoteServers()) + { + if (nullptr != server && p_cid == server->m_connectionID) + { + server->m_status = RemoteMachineStatus::Disconnected; + this->AddToPendingServers(server); + } + } + }); + + { + std::lock_guard guard(m_pendingConnectServersMutex); + m_pendingConnectServers = context->GetRemoteServers(); + } + + ConnectToPendingServers(); +} + + +void +AggregatorService::StartListen() +{ + auto context = GetContext(); + Socket::PacketHandlerMapPtr handlerMap(new Socket::PacketHandlerMap); + handlerMap->emplace(Socket::PacketType::SearchRequest, + [this](Socket::ConnectionID p_srcID, Socket::Packet p_packet) + { + boost::asio::post(*m_threadPool, + std::bind(&AggregatorService::SearchRequestHanlder, + this, + p_srcID, + std::move(p_packet))); + }); + + m_socketServer.reset(new Socket::Server(context->GetSettings()->m_listenAddr, + context->GetSettings()->m_listenPort, + handlerMap, + context->GetSettings()->m_socketThreadNum)); + + fprintf(stderr, + "Start to listen %s:%s ...\n", + context->GetSettings()->m_listenAddr.c_str(), + context->GetSettings()->m_listenPort.c_str()); +} + + +void +AggregatorService::WaitForShutdown() +{ + m_shutdownSignals.add(SIGINT); + m_shutdownSignals.add(SIGTERM); +#ifdef SIGQUIT + m_shutdownSignals.add(SIGQUIT); +#endif + + m_shutdownSignals.async_wait([this](boost::system::error_code p_ec, int p_signal) + { + fprintf(stderr, "Received shutdown signals.\n"); + m_pendingConnectServersTimer.cancel(); + }); + + m_ioContext.run(); + fprintf(stderr, "Start shutdown procedure.\n"); + + m_socketServer.reset(); + m_threadPool->stop(); + m_threadPool->join(); +} + + +void +AggregatorService::ConnectToPendingServers() +{ + auto context = GetContext(); + std::vector> pendingList; + pendingList.reserve(context->GetRemoteServers().size()); + + { + std::lock_guard guard(m_pendingConnectServersMutex); + pendingList.swap(m_pendingConnectServers); + } + + for (auto& pendingServer : pendingList) + { + if (pendingServer->m_status != RemoteMachineStatus::Disconnected) + { + continue; + } + + pendingServer->m_status = RemoteMachineStatus::Connecting; + std::shared_ptr server = pendingServer; + auto runner = [server, this]() + { + ErrorCode errCode; + auto cid = m_socketClient->ConnectToServer(server->m_address, server->m_port, errCode); + if (Socket::c_invalidConnectionID == cid) + { + if (ErrorCode::Socket_FailedResolveEndPoint == errCode) + { + fprintf(stderr, + "[Error] Failed to resolve %s %s.\n", + server->m_address.c_str(), + server->m_port.c_str()); + } + else + { + this->AddToPendingServers(std::move(server)); + } + } + else + { + server->m_connectionID = cid; + server->m_status = RemoteMachineStatus::Connected; + } + }; + boost::asio::post(*m_threadPool, std::move(runner)); + } + + m_pendingConnectServersTimer.expires_from_now(boost::posix_time::seconds(30)); + m_pendingConnectServersTimer.async_wait([this](const boost::system::error_code& p_ec) + { + if (boost::asio::error::operation_aborted != p_ec) + { + ConnectToPendingServers(); + } + }); +} + + +void +AggregatorService::AddToPendingServers(std::shared_ptr p_remoteServer) +{ + std::lock_guard guard(m_pendingConnectServersMutex); + m_pendingConnectServers.emplace_back(std::move(p_remoteServer)); +} + + +void +AggregatorService::SearchRequestHanlder(Socket::ConnectionID p_localConnectionID, Socket::Packet p_packet) +{ + auto context = GetContext(); + std::vector remoteServers; + remoteServers.reserve(context->GetRemoteServers().size()); + + for (const auto& server : context->GetRemoteServers()) + { + if (RemoteMachineStatus::Connected != server->m_status) + { + continue; + } + + remoteServers.push_back(server->m_connectionID); + } + + Socket::PacketHeader requestHeader = p_packet.Header(); + if (Socket::c_invalidConnectionID == requestHeader.m_connectionID) + { + requestHeader.m_connectionID = p_localConnectionID; + } + + std::shared_ptr executionContext( + new AggregatorExecutionContext(remoteServers.size(), requestHeader)); + + for (std::uint32_t i = 0; i < remoteServers.size(); ++i) + { + AggregatorCallback callback = [this, executionContext, i](Socket::RemoteSearchResult p_result) + { + executionContext->GetResult(i).reset(new Socket::RemoteSearchResult(std::move(p_result))); + if (executionContext->IsCompletedAfterFinsh(1)) + { + this->AggregateResults(std::move(executionContext)); + } + }; + + auto timeoutCallback = [](std::shared_ptr p_callback) + { + if (nullptr != p_callback) + { + Socket::RemoteSearchResult result; + result.m_status = Socket::RemoteSearchResult::ResultStatus::Timeout; + + (*p_callback)(std::move(result)); + } + }; + + auto connectCallback = [callback](bool p_connectSucc) + { + if (!p_connectSucc) + { + Socket::RemoteSearchResult result; + result.m_status = Socket::RemoteSearchResult::ResultStatus::FailedNetwork; + + callback(std::move(result)); + } + }; + + Socket::Packet packet; + packet.Header().m_packetType = Socket::PacketType::SearchRequest; + packet.Header().m_processStatus = Socket::PacketProcessStatus::Ok; + packet.Header().m_bodyLength = p_packet.Header().m_bodyLength; + packet.Header().m_connectionID = Socket::c_invalidConnectionID; + packet.Header().m_resourceID = m_aggregatorCallbackManager.Add(std::make_shared(std::move(callback)), + context->GetSettings()->m_searchTimeout, + std::move(timeoutCallback)); + + packet.AllocateBuffer(packet.Header().m_bodyLength); + packet.Header().WriteBuffer(packet.HeaderBuffer()); + memcpy(packet.Body(), p_packet.Body(), packet.Header().m_bodyLength); + + m_socketClient->SendPacket(remoteServers[i], std::move(packet), connectCallback); + } +} + + +void +AggregatorService::SearchResponseHanlder(Socket::ConnectionID p_localConnectionID, Socket::Packet p_packet) +{ + auto callback = m_aggregatorCallbackManager.GetAndRemove(p_packet.Header().m_resourceID); + if (nullptr == callback) + { + return; + } + + if (p_packet.Header().m_processStatus != Socket::PacketProcessStatus::Ok || 0 == p_packet.Header().m_bodyLength) + { + Socket::RemoteSearchResult result; + result.m_status = Socket::RemoteSearchResult::ResultStatus::FailedExecute; + + (*callback)(std::move(result)); + } + else + { + Socket::RemoteSearchResult result; + result.Read(p_packet.Body()); + (*callback)(std::move(result)); + } +} + + +std::shared_ptr +AggregatorService::GetContext() +{ + // Add mutex if necessary. + return m_aggregatorContext; +} + + +void +AggregatorService::AggregateResults(std::shared_ptr p_exectionContext) +{ + if (nullptr == p_exectionContext) + { + return; + } + + Socket::Packet packet; + packet.Header().m_packetType = Socket::PacketType::SearchResponse; + packet.Header().m_processStatus = Socket::PacketProcessStatus::Ok; + packet.Header().m_resourceID = p_exectionContext->GetRequestHeader().m_resourceID; + + Socket::RemoteSearchResult remoteResult; + remoteResult.m_status = Socket::RemoteSearchResult::ResultStatus::Success; + + std::size_t resultNum = 0; + for (std::size_t i = 0; i < p_exectionContext->GetServerNumber(); ++i) + { + const auto& result = p_exectionContext->GetResult(i); + if (nullptr == result) + { + continue; + } + + resultNum += result->m_allIndexResults.size(); + } + + remoteResult.m_allIndexResults.reserve(resultNum); + for (std::size_t i = 0; i < p_exectionContext->GetServerNumber(); ++i) + { + const auto& result = p_exectionContext->GetResult(i); + if (nullptr == result) + { + continue; + } + + for (auto& indexRes : result->m_allIndexResults) + { + remoteResult.m_allIndexResults.emplace_back(std::move(indexRes)); + } + } + + std::uint32_t cap = static_cast(remoteResult.EstimateBufferSize()); + packet.AllocateBuffer(cap); + packet.Header().m_bodyLength = static_cast(remoteResult.Write(packet.Body()) - packet.Body()); + packet.Header().WriteBuffer(packet.HeaderBuffer()); + + m_socketServer->SendPacket(p_exectionContext->GetRequestHeader().m_connectionID, + std::move(packet), + nullptr); +} diff --git a/core/src/index/thirdparty/SPTAG/AnnService/src/Aggregator/AggregatorSettings.cpp b/core/src/index/thirdparty/SPTAG/AnnService/src/Aggregator/AggregatorSettings.cpp new file mode 100644 index 0000000000..a3e2bc6806 --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/AnnService/src/Aggregator/AggregatorSettings.cpp @@ -0,0 +1,14 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "inc/Aggregator/AggregatorSettings.h" + +using namespace SPTAG; +using namespace SPTAG::Aggregator; + +AggregatorSettings::AggregatorSettings() + : m_searchTimeout(100), + m_threadNum(8), + m_socketThreadNum(8) +{ +} diff --git a/core/src/index/thirdparty/SPTAG/AnnService/src/Aggregator/main.cpp b/core/src/index/thirdparty/SPTAG/AnnService/src/Aggregator/main.cpp new file mode 100644 index 0000000000..2a06025d5e --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/AnnService/src/Aggregator/main.cpp @@ -0,0 +1,19 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "inc/Aggregator/AggregatorService.h" + +SPTAG::Aggregator::AggregatorService g_service; + +int main(int argc, char* argv[]) +{ + if (!g_service.Initialize()) + { + return 1; + } + + g_service.Run(); + + return 0; +} + diff --git a/core/src/index/thirdparty/SPTAG/AnnService/src/Client/ClientWrapper.cpp b/core/src/index/thirdparty/SPTAG/AnnService/src/Client/ClientWrapper.cpp new file mode 100644 index 0000000000..7e91c63195 --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/AnnService/src/Client/ClientWrapper.cpp @@ -0,0 +1,235 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "inc/Client/ClientWrapper.h" + +using namespace SPTAG; +using namespace SPTAG::Socket; +using namespace SPTAG::Client; + +ClientWrapper::ClientWrapper(const ClientOptions& p_options) + : m_options(p_options), + m_unfinishedJobCount(0), + m_isWaitingFinish(false) +{ + m_client.reset(new SPTAG::Socket::Client(GetHandlerMap(), p_options.m_socketThreadNum, 30)); + m_client->SetEventOnConnectionClose(std::bind(&ClientWrapper::HandleDeadConnection, + this, + std::placeholders::_1)); + + m_connections.reserve(m_options.m_threadNum); + for (std::uint32_t i = 0; i < m_options.m_threadNum; ++i) + { + SPTAG::ErrorCode errCode; + ConnectionPair conn(c_invalidConnectionID, c_invalidConnectionID); + conn.first = m_client->ConnectToServer(p_options.m_serverAddr, p_options.m_serverPort, errCode); + if (SPTAG::ErrorCode::Socket_FailedResolveEndPoint == errCode) + { + fprintf(stderr, "Unable to resolve remote address.\n"); + return; + } + + if (c_invalidConnectionID != conn.first) + { + m_connections.emplace_back(std::move(conn)); + } + } +} + + +ClientWrapper::~ClientWrapper() +{ +} + + +void +ClientWrapper::SendQueryAsync(const Socket::RemoteQuery& p_query, + Callback p_callback, + const ClientOptions& p_options) +{ + if (!bool(p_callback)) + { + return; + } + + auto conn = GetConnection(); + + auto timeoutCallback = [this](std::shared_ptr p_callback) + { + DecreaseUnfnishedJobCount(); + if (nullptr != p_callback) + { + Socket::RemoteSearchResult result; + result.m_status = Socket::RemoteSearchResult::ResultStatus::Timeout; + + (*p_callback)(std::move(result)); + } + }; + + + auto connectCallback = [p_callback, this](bool p_connectSucc) + { + if (!p_connectSucc) + { + Socket::RemoteSearchResult result; + result.m_status = Socket::RemoteSearchResult::ResultStatus::FailedNetwork; + + p_callback(std::move(result)); + DecreaseUnfnishedJobCount(); + } + }; + + Socket::Packet packet; + packet.Header().m_connectionID = c_invalidConnectionID; + packet.Header().m_packetType = PacketType::SearchRequest; + packet.Header().m_processStatus = PacketProcessStatus::Ok; + packet.Header().m_resourceID = m_callbackManager.Add(std::make_shared(std::move(p_callback)), + p_options.m_searchTimeout, + std::move(timeoutCallback)); + + packet.Header().m_bodyLength = static_cast(p_query.EstimateBufferSize()); + packet.AllocateBuffer(packet.Header().m_bodyLength); + p_query.Write(packet.Body()); + packet.Header().WriteBuffer(packet.HeaderBuffer()); + + ++m_unfinishedJobCount; + m_client->SendPacket(conn.first, std::move(packet), connectCallback); +} + + +void +ClientWrapper::WaitAllFinished() +{ + if (m_unfinishedJobCount > 0) + { + std::unique_lock lock(m_waitingMutex); + if (m_unfinishedJobCount > 0) + { + m_isWaitingFinish = true; + m_waitingQueue.wait(lock); + } + } +} + + +PacketHandlerMapPtr +ClientWrapper::GetHandlerMap() +{ + PacketHandlerMapPtr handlerMap(new PacketHandlerMap); + handlerMap->emplace(PacketType::RegisterResponse, + [this](ConnectionID p_localConnectionID, Packet p_packet) -> void + { + for (auto& conn : m_connections) + { + if (conn.first == p_localConnectionID) + { + conn.second = p_packet.Header().m_connectionID; + return; + } + } + }); + + handlerMap->emplace(PacketType::SearchResponse, + std::bind(&ClientWrapper::SearchResponseHanlder, + this, + std::placeholders::_1, + std::placeholders::_2)); + + return handlerMap; +} + + +void +ClientWrapper::DecreaseUnfnishedJobCount() +{ + --m_unfinishedJobCount; + if (0 == m_unfinishedJobCount) + { + std::lock_guard guard(m_waitingMutex); + if (0 == m_unfinishedJobCount && m_isWaitingFinish) + { + m_waitingQueue.notify_all(); + m_isWaitingFinish = false; + } + } +} + + +const ClientWrapper::ConnectionPair& +ClientWrapper::GetConnection() +{ + if (m_connections.size() == 1) + { + return m_connections.front(); + } + + std::size_t triedCount = 0; + std::uint32_t pos = m_spinCountOfConnection.fetch_add(1) % m_connections.size(); + while (c_invalidConnectionID == m_connections[pos].first && triedCount < m_connections.size()) + { + pos = m_spinCountOfConnection.fetch_add(1) % m_connections.size(); + ++triedCount; + } + + return m_connections[pos]; +} + + +void +ClientWrapper::SearchResponseHanlder(Socket::ConnectionID p_localConnectionID, Socket::Packet p_packet) +{ + std::shared_ptr callback = m_callbackManager.GetAndRemove(p_packet.Header().m_resourceID); + if (nullptr == callback) + { + return; + } + + if (p_packet.Header().m_processStatus != PacketProcessStatus::Ok || 0 == p_packet.Header().m_bodyLength) + { + Socket::RemoteSearchResult result; + result.m_status = Socket::RemoteSearchResult::ResultStatus::FailedExecute; + + (*callback)(std::move(result)); + } + else + { + Socket::RemoteSearchResult result; + result.Read(p_packet.Body()); + (*callback)(std::move(result)); + } + + DecreaseUnfnishedJobCount(); +} + + +void +ClientWrapper::HandleDeadConnection(Socket::ConnectionID p_cid) +{ + for (auto& conn : m_connections) + { + if (conn.first == p_cid) + { + conn.first = c_invalidConnectionID; + conn.second = c_invalidConnectionID; + + SPTAG::ErrorCode errCode; + while (c_invalidConnectionID == conn.first) + { + conn.first = m_client->ConnectToServer(m_options.m_serverAddr, m_options.m_serverPort, errCode); + if (SPTAG::ErrorCode::Socket_FailedResolveEndPoint == errCode) + { + break; + } + } + + return; + } + } +} + + +bool +ClientWrapper::IsAvailable() const +{ + return !m_connections.empty(); +} diff --git a/core/src/index/thirdparty/SPTAG/AnnService/src/Client/Options.cpp b/core/src/index/thirdparty/SPTAG/AnnService/src/Client/Options.cpp new file mode 100644 index 0000000000..bb067d3d58 --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/AnnService/src/Client/Options.cpp @@ -0,0 +1,27 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "inc/Client/Options.h" +#include "inc/Helper/StringConvert.h" + +#include + +using namespace SPTAG; +using namespace SPTAG::Client; + +ClientOptions::ClientOptions() + : m_searchTimeout(9000), + m_threadNum(1), + m_socketThreadNum(2) +{ + AddRequiredOption(m_serverAddr, "-s", "--server", "Server address."); + AddRequiredOption(m_serverPort, "-p", "--port", "Server port."); + AddOptionalOption(m_searchTimeout, "-t", "", "Search timeout."); + AddOptionalOption(m_threadNum, "-cth", "", "Client Thread Number."); + AddOptionalOption(m_socketThreadNum, "-sth", "", "Socket Thread Number."); +} + + +ClientOptions::~ClientOptions() +{ +} diff --git a/core/src/index/thirdparty/SPTAG/AnnService/src/Client/main.cpp b/core/src/index/thirdparty/SPTAG/AnnService/src/Client/main.cpp new file mode 100644 index 0000000000..52888e3374 --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/AnnService/src/Client/main.cpp @@ -0,0 +1,78 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "inc/Client/Options.h" +#include "inc/Client/ClientWrapper.h" + +#include +#include +#include + +std::unique_ptr g_client; + +int main(int argc, char** argv) +{ + SPTAG::Client::ClientOptions options; + if (!options.Parse(argc - 1, argv + 1)) + { + return 1; + } + + g_client.reset(new SPTAG::Client::ClientWrapper(options)); + if (!g_client->IsAvailable()) + { + return 1; + } + + g_client->WaitAllFinished(); + fprintf(stdout, "connection done\n"); + + std::string line; + std::cout << "Query: " << std::flush; + while (std::getline(std::cin, line)) + { + if (line.empty()) + { + break; + } + + SPTAG::Socket::RemoteQuery query; + query.m_type = SPTAG::Socket::RemoteQuery::QueryType::String; + query.m_queryString = std::move(line); + + SPTAG::Socket::RemoteSearchResult result; + auto callback = [&result](SPTAG::Socket::RemoteSearchResult p_result) + { + result = std::move(p_result); + }; + + g_client->SendQueryAsync(query, callback, options); + g_client->WaitAllFinished(); + + std::cout << "Status: " << static_cast(result.m_status) << std::endl; + + for (const auto& indexRes : result.m_allIndexResults) + { + std::cout << "Index: " << indexRes.m_indexName << std::endl; + + int idx = 0; + for (const auto& res : indexRes.m_results) + { + std::cout << "------------------" << std::endl; + std::cout << "DocIndex: " << res.VID << " Distance: " << res.Dist; + if (indexRes.m_results.WithMeta()) + { + const auto& metadata = indexRes.m_results.GetMetadata(idx); + std::cout << " MetaData: " << std::string((char*)metadata.Data(), metadata.Length()); + } + std::cout << std::endl; + ++idx; + } + } + + std::cout << "Query: " << std::flush; + } + + return 0; +} + diff --git a/core/src/index/thirdparty/SPTAG/AnnService/src/Core/BKT/BKTIndex.cpp b/core/src/index/thirdparty/SPTAG/AnnService/src/Core/BKT/BKTIndex.cpp new file mode 100644 index 0000000000..e8928726f4 --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/AnnService/src/Core/BKT/BKTIndex.cpp @@ -0,0 +1,403 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "inc/Core/BKT/Index.h" + +#pragma warning(disable:4996) // 'fopen': This function or variable may be unsafe. Consider using fopen_s instead. To disable deprecation, use _CRT_SECURE_NO_WARNINGS. See online help for details. +#pragma warning(disable:4242) // '=' : conversion from 'int' to 'short', possible loss of data +#pragma warning(disable:4244) // '=' : conversion from 'int' to 'short', possible loss of data +#pragma warning(disable:4127) // conditional expression is constant + +namespace SPTAG +{ + namespace BKT + { + template + ErrorCode Index::LoadConfig(Helper::IniReader& p_reader) + { +#define DefineBKTParameter(VarName, VarType, DefaultValue, RepresentStr) \ + SetParameter(RepresentStr, \ + p_reader.GetParameter("Index", \ + RepresentStr, \ + std::string(#DefaultValue)).c_str()); \ + +#include "inc/Core/BKT/ParameterDefinitionList.h" +#undef DefineBKTParameter + return ErrorCode::Success; + } + + template + ErrorCode Index::LoadIndexDataFromMemory(const std::vector& p_indexBlobs) + { + if (p_indexBlobs.size() < 3) return ErrorCode::LackOfInputs; + + if (!m_pSamples.Load((char*)p_indexBlobs[0].Data())) return ErrorCode::FailedParseValue; + if (!m_pTrees.LoadTrees((char*)p_indexBlobs[1].Data())) return ErrorCode::FailedParseValue; + if (!m_pGraph.LoadGraph((char*)p_indexBlobs[2].Data())) return ErrorCode::FailedParseValue; + if (p_indexBlobs.size() > 3 && !m_deletedID.load((char*)p_indexBlobs[3].Data())) return ErrorCode::FailedParseValue; + + m_workSpacePool.reset(new COMMON::WorkSpacePool(m_iMaxCheck, GetNumSamples())); + m_workSpacePool->Init(m_iNumberOfThreads); + return ErrorCode::Success; + } + + template + ErrorCode Index::LoadIndexData(const std::string& p_folderPath) + { + if (!m_pSamples.Load(p_folderPath + m_sDataPointsFilename)) return ErrorCode::Fail; + if (!m_pTrees.LoadTrees(p_folderPath + m_sBKTFilename)) return ErrorCode::Fail; + if (!m_pGraph.LoadGraph(p_folderPath + m_sGraphFilename)) return ErrorCode::Fail; + if (!m_deletedID.load(p_folderPath + m_sDeleteDataPointsFilename)) return ErrorCode::Fail; + + m_workSpacePool.reset(new COMMON::WorkSpacePool(m_iMaxCheck, GetNumSamples())); + m_workSpacePool->Init(m_iNumberOfThreads); + return ErrorCode::Success; + } + + template + ErrorCode Index::SaveConfig(std::ostream& p_configOut) const + { +#define DefineBKTParameter(VarName, VarType, DefaultValue, RepresentStr) \ + p_configOut << RepresentStr << "=" << GetParameter(RepresentStr) << std::endl; + +#include "inc/Core/BKT/ParameterDefinitionList.h" +#undef DefineBKTParameter + p_configOut << std::endl; + return ErrorCode::Success; + } + + template + ErrorCode + Index::SaveIndexData(const std::string& p_folderPath) + { + std::lock_guard lock(m_dataAddLock); + std::shared_lock sharedlock(m_deletedID.getLock()); + + if (!m_pSamples.Save(p_folderPath + m_sDataPointsFilename)) return ErrorCode::Fail; + if (!m_pTrees.SaveTrees(p_folderPath + m_sBKTFilename)) return ErrorCode::Fail; + if (!m_pGraph.SaveGraph(p_folderPath + m_sGraphFilename)) return ErrorCode::Fail; + if (!m_deletedID.save(p_folderPath + m_sDeleteDataPointsFilename)) return ErrorCode::Fail; + return ErrorCode::Success; + } + + template + ErrorCode Index::SaveIndexData(const std::vector& p_indexStreams) + { + if (p_indexStreams.size() < 4) return ErrorCode::LackOfInputs; + + std::lock_guard lock(m_dataAddLock); + std::shared_lock sharedlock(m_deletedID.getLock()); + + if (!m_pSamples.Save(*p_indexStreams[0])) return ErrorCode::Fail; + if (!m_pTrees.SaveTrees(*p_indexStreams[1])) return ErrorCode::Fail; + if (!m_pGraph.SaveGraph(*p_indexStreams[2])) return ErrorCode::Fail; + if (!m_deletedID.save(*p_indexStreams[3])) return ErrorCode::Fail; + return ErrorCode::Success; + } + +#pragma region K-NN search + +#define Search(CheckDeleted1) \ + m_pTrees.InitSearchTrees(this, p_query, p_space); \ + const DimensionType checkPos = m_pGraph.m_iNeighborhoodSize - 1; \ + while (!p_space.m_SPTQueue.empty()) { \ + m_pTrees.SearchTrees(this, p_query, p_space, m_iNumberOfOtherDynamicPivots + p_space.m_iNumberOfCheckedLeaves); \ + while (!p_space.m_NGQueue.empty()) { \ + COMMON::HeapCell gnode = p_space.m_NGQueue.pop(); \ + const SizeType *node = m_pGraph[gnode.node]; \ + _mm_prefetch((const char *)node, _MM_HINT_T0); \ + CheckDeleted1 { \ + if (p_query.AddPoint(gnode.node, gnode.distance)) { \ + p_space.m_iNumOfContinuousNoBetterPropagation = 0; \ + SizeType checkNode = node[checkPos]; \ + if (checkNode < -1) { \ + const COMMON::BKTNode& tnode = m_pTrees[-2 - checkNode]; \ + for (SizeType i = -tnode.childStart; i < tnode.childEnd; i++) { \ + if (!p_query.AddPoint(m_pTrees[i].centerid, gnode.distance)) break; \ + } \ + } \ + } \ + else { \ + p_space.m_iNumOfContinuousNoBetterPropagation++; \ + if (p_space.m_iNumOfContinuousNoBetterPropagation > p_space.m_iContinuousLimit || p_space.m_iNumberOfCheckedLeaves > p_space.m_iMaxCheck) { \ + p_query.SortResult(); return; \ + } \ + } \ + } \ + for (DimensionType i = 0; i <= checkPos; i++) { \ + _mm_prefetch((const char *)(m_pSamples)[node[i]], _MM_HINT_T0); \ + } \ + for (DimensionType i = 0; i <= checkPos; i++) { \ + SizeType nn_index = node[i]; \ + if (nn_index < 0) break; \ + if (p_space.CheckAndSet(nn_index)) continue; \ + float distance2leaf = m_fComputeDistance(p_query.GetTarget(), (m_pSamples)[nn_index], GetFeatureDim()); \ + p_space.m_iNumberOfCheckedLeaves++; \ + p_space.m_NGQueue.insert(COMMON::HeapCell(nn_index, distance2leaf)); \ + } \ + if (p_space.m_NGQueue.Top().distance > p_space.m_SPTQueue.Top().distance) { \ + break; \ + } \ + } \ + } \ + p_query.SortResult(); \ + + template + void Index::SearchIndexWithDeleted(COMMON::QueryResultSet &p_query, COMMON::WorkSpace &p_space, const Helper::Concurrent::ConcurrentSet &p_deleted) const + { + Search(if (!p_deleted.contains(gnode.node))) + } + + template + void Index::SearchIndexWithoutDeleted(COMMON::QueryResultSet &p_query, COMMON::WorkSpace &p_space) const + { + Search(;) + } + + template + ErrorCode + Index::SearchIndex(QueryResult &p_query) const + { + auto workSpace = m_workSpacePool->Rent(); + workSpace->Reset(m_iMaxCheck); + + if (m_deletedID.size() > 0) + SearchIndexWithDeleted(*((COMMON::QueryResultSet*)&p_query), *workSpace, m_deletedID); + else + SearchIndexWithoutDeleted(*((COMMON::QueryResultSet*)&p_query), *workSpace); + + m_workSpacePool->Return(workSpace); + + if (p_query.WithMeta() && nullptr != m_pMetadata) + { + for (int i = 0; i < p_query.GetResultNum(); ++i) + { + SizeType result = p_query.GetResult(i)->VID; + p_query.SetMetadata(i, (result < 0) ? ByteArray::c_empty : m_pMetadata->GetMetadata(result)); + } + } + return ErrorCode::Success; + } +#pragma endregion + + template + ErrorCode Index::BuildIndex(const void* p_data, SizeType p_vectorNum, DimensionType p_dimension) + { + omp_set_num_threads(m_iNumberOfThreads); + + m_pSamples.Initialize(p_vectorNum, p_dimension, (T*)p_data, false); + + if (DistCalcMethod::Cosine == m_iDistCalcMethod) + { + int base = COMMON::Utils::GetBase(); +#pragma omp parallel for + for (SizeType i = 0; i < GetNumSamples(); i++) { + COMMON::Utils::Normalize(m_pSamples[i], GetFeatureDim(), base); + } + } + + m_workSpacePool.reset(new COMMON::WorkSpacePool(m_iMaxCheck, GetNumSamples())); + m_workSpacePool->Init(m_iNumberOfThreads); + + m_pTrees.BuildTrees(this); + m_pGraph.BuildGraph(this, &(m_pTrees.GetSampleMap())); + + return ErrorCode::Success; + } + + template + ErrorCode Index::RefineIndex(const std::vector& p_indexStreams) + { + std::lock_guard lock(m_dataAddLock); + std::shared_lock sharedlock(m_deletedID.getLock()); + + SizeType newR = GetNumSamples(); + + std::vector indices; + std::vector reverseIndices(newR); + for (SizeType i = 0; i < newR; i++) { + if (!m_deletedID.contains(i)) { + indices.push_back(i); + reverseIndices[i] = i; + } + else { + while (m_deletedID.contains(newR - 1) && newR > i) newR--; + if (newR == i) break; + indices.push_back(newR - 1); + reverseIndices[newR - 1] = i; + newR--; + } + } + + std::cout << "Refine... from " << GetNumSamples() << "->" << newR << std::endl; + + if (false == m_pSamples.Refine(indices, *p_indexStreams[0])) return ErrorCode::Fail; + if (nullptr != m_pMetadata && (p_indexStreams.size() < 6 || ErrorCode::Success != m_pMetadata->RefineMetadata(indices, *p_indexStreams[4], *p_indexStreams[5]))) return ErrorCode::Fail; + + COMMON::BKTree newTrees(m_pTrees); + newTrees.BuildTrees(this, &indices); +#pragma omp parallel for + for (SizeType i = 0; i < newTrees.size(); i++) { + newTrees[i].centerid = reverseIndices[newTrees[i].centerid]; + } + newTrees.SaveTrees(*p_indexStreams[1]); + + m_pGraph.RefineGraph(this, indices, reverseIndices, *p_indexStreams[2], &(newTrees.GetSampleMap())); + + Helper::Concurrent::ConcurrentSet newDeletedID; + newDeletedID.save(*p_indexStreams[3]); + return ErrorCode::Success; + } + + template + ErrorCode Index::RefineIndex(const std::string& p_folderPath) + { + std::string folderPath(p_folderPath); + if (!folderPath.empty() && *(folderPath.rbegin()) != FolderSep) + { + folderPath += FolderSep; + } + + if (!direxists(folderPath.c_str())) + { + mkdir(folderPath.c_str()); + } + + std::vector streams; + streams.push_back(new std::ofstream(folderPath + m_sDataPointsFilename, std::ios::binary)); + streams.push_back(new std::ofstream(folderPath + m_sBKTFilename, std::ios::binary)); + streams.push_back(new std::ofstream(folderPath + m_sGraphFilename, std::ios::binary)); + streams.push_back(new std::ofstream(folderPath + m_sDeleteDataPointsFilename, std::ios::binary)); + if (nullptr != m_pMetadata) + { + streams.push_back(new std::ofstream(folderPath + m_sMetadataFile, std::ios::binary)); + streams.push_back(new std::ofstream(folderPath + m_sMetadataIndexFile, std::ios::binary)); + } + + for (size_t i = 0; i < streams.size(); i++) + if (!(((std::ofstream*)streams[i])->is_open())) return ErrorCode::FailedCreateFile; + + ErrorCode ret = RefineIndex(streams); + + for (size_t i = 0; i < streams.size(); i++) + { + ((std::ofstream*)streams[i])->close(); + delete streams[i]; + } + return ret; + } + + template + ErrorCode Index::DeleteIndex(const void* p_vectors, SizeType p_vectorNum) { + const T* ptr_v = (const T*)p_vectors; +#pragma omp parallel for schedule(dynamic) + for (SizeType i = 0; i < p_vectorNum; i++) { + COMMON::QueryResultSet query(ptr_v + i * GetFeatureDim(), m_pGraph.m_iCEF); + SearchIndex(query); + + for (int i = 0; i < m_pGraph.m_iCEF; i++) { + if (query.GetResult(i)->Dist < 1e-6) { + m_deletedID.insert(query.GetResult(i)->VID); + } + } + } + return ErrorCode::Success; + } + + template + ErrorCode Index::DeleteIndex(const SizeType& p_id) { + m_deletedID.insert(p_id); + return ErrorCode::Success; + } + + template + ErrorCode Index::AddIndex(const void* p_vectors, SizeType p_vectorNum, DimensionType p_dimension, SizeType* p_start) + { + SizeType begin, end; + { + std::lock_guard lock(m_dataAddLock); + + begin = GetNumSamples(); + end = GetNumSamples() + p_vectorNum; + + if (p_start != nullptr) *p_start = begin; + + if (begin == 0) return BuildIndex(p_vectors, p_vectorNum, p_dimension); + + if (p_dimension != GetFeatureDim()) return ErrorCode::FailedParseValue; + + if (m_pSamples.AddBatch((const T*)p_vectors, p_vectorNum) != ErrorCode::Success || m_pGraph.AddBatch(p_vectorNum) != ErrorCode::Success) { + std::cout << "Memory Error: Cannot alloc space for vectors" << std::endl; + m_pSamples.SetR(begin); + m_pGraph.SetR(begin); + return ErrorCode::MemoryOverFlow; + } + if (DistCalcMethod::Cosine == m_iDistCalcMethod) + { + int base = COMMON::Utils::GetBase(); + for (SizeType i = begin; i < end; i++) { + COMMON::Utils::Normalize((T*)m_pSamples[i], GetFeatureDim(), base); + } + } + } + + for (SizeType node = begin; node < end; node++) + { + m_pGraph.RefineNode(this, node, true); + } + std::cout << "Add " << p_vectorNum << " vectors" << std::endl; + return ErrorCode::Success; + } + + template + ErrorCode + Index::SetParameter(const char* p_param, const char* p_value) + { + if (nullptr == p_param || nullptr == p_value) return ErrorCode::Fail; + +#define DefineBKTParameter(VarName, VarType, DefaultValue, RepresentStr) \ + else if (SPTAG::Helper::StrUtils::StrEqualIgnoreCase(p_param, RepresentStr)) \ + { \ + fprintf(stderr, "Setting %s with value %s\n", RepresentStr, p_value); \ + VarType tmp; \ + if (SPTAG::Helper::Convert::ConvertStringTo(p_value, tmp)) \ + { \ + VarName = tmp; \ + } \ + } \ + +#include "inc/Core/BKT/ParameterDefinitionList.h" +#undef DefineBKTParameter + + m_fComputeDistance = COMMON::DistanceCalcSelector(m_iDistCalcMethod); + return ErrorCode::Success; + } + + + template + std::string + Index::GetParameter(const char* p_param) const + { + if (nullptr == p_param) return std::string(); + +#define DefineBKTParameter(VarName, VarType, DefaultValue, RepresentStr) \ + else if (SPTAG::Helper::StrUtils::StrEqualIgnoreCase(p_param, RepresentStr)) \ + { \ + return SPTAG::Helper::Convert::ConvertToString(VarName); \ + } \ + +#include "inc/Core/BKT/ParameterDefinitionList.h" +#undef DefineBKTParameter + + return std::string(); + } + } +} + +#define DefineVectorValueType(Name, Type) \ +template class SPTAG::BKT::Index; \ + +#include "inc/Core/DefinitionList.h" +#undef DefineVectorValueType + + diff --git a/core/src/index/thirdparty/SPTAG/AnnService/src/Core/Common/NeighborhoodGraph.cpp b/core/src/index/thirdparty/SPTAG/AnnService/src/Core/Common/NeighborhoodGraph.cpp new file mode 100644 index 0000000000..94115dd0a3 --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/AnnService/src/Core/Common/NeighborhoodGraph.cpp @@ -0,0 +1,17 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "inc/Core/Common/NeighborhoodGraph.h" +#include "inc/Core/Common/RelativeNeighborhoodGraph.h" + +using namespace SPTAG::COMMON; + +std::shared_ptr NeighborhoodGraph::CreateInstance(std::string type) +{ + std::shared_ptr res; + if (type == "RNG") + { + res.reset(new RelativeNeighborhoodGraph); + } + return res; +} \ No newline at end of file diff --git a/core/src/index/thirdparty/SPTAG/AnnService/src/Core/Common/WorkSpacePool.cpp b/core/src/index/thirdparty/SPTAG/AnnService/src/Core/Common/WorkSpacePool.cpp new file mode 100644 index 0000000000..a88dbdb2d5 --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/AnnService/src/Core/Common/WorkSpacePool.cpp @@ -0,0 +1,66 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "inc/Core/Common/WorkSpacePool.h" + +using namespace SPTAG; +using namespace SPTAG::COMMON; + + +WorkSpacePool::WorkSpacePool(int p_maxCheck, SizeType p_vectorCount) + : m_maxCheck(p_maxCheck), + m_vectorCount(p_vectorCount) +{ +} + + +WorkSpacePool::~WorkSpacePool() +{ + for (auto& workSpace : m_workSpacePool) + workSpace.reset(); + m_workSpacePool.clear(); +} + + +std::shared_ptr +WorkSpacePool::Rent() +{ + std::shared_ptr workSpace; + + { + std::lock_guard lock(m_workSpacePoolMutex); + if (!m_workSpacePool.empty()) + { + workSpace = m_workSpacePool.front(); + m_workSpacePool.pop_front(); + } + else + { + workSpace.reset(new WorkSpace); + workSpace->Initialize(m_maxCheck, m_vectorCount); + } + } + return workSpace; +} + + +void +WorkSpacePool::Return(const std::shared_ptr& p_workSpace) +{ + { + std::lock_guard lock(m_workSpacePoolMutex); + m_workSpacePool.push_back(p_workSpace); + } +} + + +void +WorkSpacePool::Init(int size) +{ + for (int i = 0; i < size; i++) + { + std::shared_ptr workSpace(new WorkSpace); + workSpace->Initialize(m_maxCheck, m_vectorCount); + m_workSpacePool.push_back(std::move(workSpace)); + } +} \ No newline at end of file diff --git a/core/src/index/thirdparty/SPTAG/AnnService/src/Core/KDT/KDTIndex.cpp b/core/src/index/thirdparty/SPTAG/AnnService/src/Core/KDT/KDTIndex.cpp new file mode 100644 index 0000000000..da3c10e095 --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/AnnService/src/Core/KDT/KDTIndex.cpp @@ -0,0 +1,396 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "inc/Core/KDT/Index.h" + +#pragma warning(disable:4996) // 'fopen': This function or variable may be unsafe. Consider using fopen_s instead. To disable deprecation, use _CRT_SECURE_NO_WARNINGS. See online help for details. +#pragma warning(disable:4242) // '=' : conversion from 'int' to 'short', possible loss of data +#pragma warning(disable:4244) // '=' : conversion from 'int' to 'short', possible loss of data +#pragma warning(disable:4127) // conditional expression is constant + +namespace SPTAG +{ + namespace KDT + { + template + ErrorCode Index::LoadConfig(Helper::IniReader& p_reader) + { +#define DefineKDTParameter(VarName, VarType, DefaultValue, RepresentStr) \ + SetParameter(RepresentStr, \ + p_reader.GetParameter("Index", \ + RepresentStr, \ + std::string(#DefaultValue)).c_str()); \ + +#include "inc/Core/KDT/ParameterDefinitionList.h" +#undef DefineKDTParameter + return ErrorCode::Success; + } + + template + ErrorCode Index::LoadIndexDataFromMemory(const std::vector& p_indexBlobs) + { + if (p_indexBlobs.size() < 3) return ErrorCode::LackOfInputs; + + if (!m_pSamples.Load((char*)p_indexBlobs[0].Data())) return ErrorCode::FailedParseValue; + if (!m_pTrees.LoadTrees((char*)p_indexBlobs[1].Data())) return ErrorCode::FailedParseValue; + if (!m_pGraph.LoadGraph((char*)p_indexBlobs[2].Data())) return ErrorCode::FailedParseValue; + if (p_indexBlobs.size() > 3 && !m_deletedID.load((char*)p_indexBlobs[3].Data())) return ErrorCode::FailedParseValue; + + m_workSpacePool.reset(new COMMON::WorkSpacePool(m_iMaxCheck, GetNumSamples())); + m_workSpacePool->Init(m_iNumberOfThreads); + return ErrorCode::Success; + } + + template + ErrorCode Index::LoadIndexData(const std::string& p_folderPath) + { + if (!m_pSamples.Load(p_folderPath + m_sDataPointsFilename)) return ErrorCode::Fail; + if (!m_pTrees.LoadTrees(p_folderPath + m_sKDTFilename)) return ErrorCode::Fail; + if (!m_pGraph.LoadGraph(p_folderPath + m_sGraphFilename)) return ErrorCode::Fail; + if (!m_deletedID.load(p_folderPath + m_sDeleteDataPointsFilename)) return ErrorCode::Fail; + + m_workSpacePool.reset(new COMMON::WorkSpacePool(m_iMaxCheck, GetNumSamples())); + m_workSpacePool->Init(m_iNumberOfThreads); + return ErrorCode::Success; + } + + template + ErrorCode Index::SaveConfig(std::ostream& p_configOut) const + { +#define DefineKDTParameter(VarName, VarType, DefaultValue, RepresentStr) \ + p_configOut << RepresentStr << "=" << GetParameter(RepresentStr) << std::endl; + +#include "inc/Core/KDT/ParameterDefinitionList.h" +#undef DefineKDTParameter + p_configOut << std::endl; + return ErrorCode::Success; + } + + template + ErrorCode Index::SaveIndexData(const std::string& p_folderPath) + { + std::lock_guard lock(m_dataAddLock); + std::shared_lock sharedlock(m_deletedID.getLock()); + + if (!m_pSamples.Save(p_folderPath + m_sDataPointsFilename)) return ErrorCode::Fail; + if (!m_pTrees.SaveTrees(p_folderPath + m_sKDTFilename)) return ErrorCode::Fail; + if (!m_pGraph.SaveGraph(p_folderPath + m_sGraphFilename)) return ErrorCode::Fail; + if (!m_deletedID.save(p_folderPath + m_sDeleteDataPointsFilename)) return ErrorCode::Fail; + return ErrorCode::Success; + } + + template + ErrorCode Index::SaveIndexData(const std::vector& p_indexStreams) + { + if (p_indexStreams.size() < 4) return ErrorCode::LackOfInputs; + + std::lock_guard lock(m_dataAddLock); + std::shared_lock sharedlock(m_deletedID.getLock()); + + if (!m_pSamples.Save(*p_indexStreams[0])) return ErrorCode::Fail; + if (!m_pTrees.SaveTrees(*p_indexStreams[1])) return ErrorCode::Fail; + if (!m_pGraph.SaveGraph(*p_indexStreams[2])) return ErrorCode::Fail; + if (!m_deletedID.save(*p_indexStreams[3])) return ErrorCode::Fail; + return ErrorCode::Success; + } + +#pragma region K-NN search + +#define Search(CheckDeleted1) \ + m_pTrees.InitSearchTrees(this, p_query, p_space, m_iNumberOfInitialDynamicPivots); \ + while (!p_space.m_NGQueue.empty()) { \ + COMMON::HeapCell gnode = p_space.m_NGQueue.pop(); \ + const SizeType *node = m_pGraph[gnode.node]; \ + _mm_prefetch((const char *)node, _MM_HINT_T0); \ + CheckDeleted1 { \ + if (!p_query.AddPoint(gnode.node, gnode.distance) && p_space.m_iNumberOfCheckedLeaves > p_space.m_iMaxCheck) { \ + p_query.SortResult(); return; \ + } \ + } \ + for (DimensionType i = 0; i < m_pGraph.m_iNeighborhoodSize; i++) \ + _mm_prefetch((const char *)(m_pSamples)[node[i]], _MM_HINT_T0); \ + float upperBound = max(p_query.worstDist(), gnode.distance); \ + bool bLocalOpt = true; \ + for (DimensionType i = 0; i < m_pGraph.m_iNeighborhoodSize; i++) { \ + SizeType nn_index = node[i]; \ + if (nn_index < 0) break; \ + if (p_space.CheckAndSet(nn_index)) continue; \ + float distance2leaf = m_fComputeDistance(p_query.GetTarget(), (m_pSamples)[nn_index], GetFeatureDim()); \ + if (distance2leaf <= upperBound) bLocalOpt = false; \ + p_space.m_iNumberOfCheckedLeaves++; \ + p_space.m_NGQueue.insert(COMMON::HeapCell(nn_index, distance2leaf)); \ + } \ + if (bLocalOpt) p_space.m_iNumOfContinuousNoBetterPropagation++; \ + else p_space.m_iNumOfContinuousNoBetterPropagation = 0; \ + if (p_space.m_iNumOfContinuousNoBetterPropagation > m_iThresholdOfNumberOfContinuousNoBetterPropagation) { \ + if (p_space.m_iNumberOfTreeCheckedLeaves <= p_space.m_iNumberOfCheckedLeaves / 10) { \ + m_pTrees.SearchTrees(this, p_query, p_space, m_iNumberOfOtherDynamicPivots + p_space.m_iNumberOfCheckedLeaves); \ + } else if (gnode.distance > p_query.worstDist()) { \ + break; \ + } \ + } \ + } \ + p_query.SortResult(); \ + + template + void Index::SearchIndexWithDeleted(COMMON::QueryResultSet &p_query, COMMON::WorkSpace &p_space, const Helper::Concurrent::ConcurrentSet &p_deleted) const + { + Search(if (!p_deleted.contains(gnode.node))) + } + + template + void Index::SearchIndexWithoutDeleted(COMMON::QueryResultSet &p_query, COMMON::WorkSpace &p_space) const + { + Search(;) + } + + template + ErrorCode + Index::SearchIndex(QueryResult &p_query) const + { + auto workSpace = m_workSpacePool->Rent(); + workSpace->Reset(m_iMaxCheck); + + if (m_deletedID.size() > 0) + SearchIndexWithDeleted(*((COMMON::QueryResultSet*)&p_query), *workSpace, m_deletedID); + else + SearchIndexWithoutDeleted(*((COMMON::QueryResultSet*)&p_query), *workSpace); + + m_workSpacePool->Return(workSpace); + + if (p_query.WithMeta() && nullptr != m_pMetadata) + { + for (int i = 0; i < p_query.GetResultNum(); ++i) + { + SizeType result = p_query.GetResult(i)->VID; + p_query.SetMetadata(i, (result < 0) ? ByteArray::c_empty : m_pMetadata->GetMetadata(result)); + } + } + return ErrorCode::Success; + } +#pragma endregion + + template + ErrorCode Index::BuildIndex(const void* p_data, SizeType p_vectorNum, DimensionType p_dimension) + { + omp_set_num_threads(m_iNumberOfThreads); + + m_pSamples.Initialize(p_vectorNum, p_dimension, (T*)p_data, false); + + if (DistCalcMethod::Cosine == m_iDistCalcMethod) + { + int base = COMMON::Utils::GetBase(); +#pragma omp parallel for + for (SizeType i = 0; i < GetNumSamples(); i++) { + COMMON::Utils::Normalize(m_pSamples[i], GetFeatureDim(), base); + } + } + + m_workSpacePool.reset(new COMMON::WorkSpacePool(m_iMaxCheck, GetNumSamples())); + m_workSpacePool->Init(m_iNumberOfThreads); + + m_pTrees.BuildTrees(this); + m_pGraph.BuildGraph(this); + + return ErrorCode::Success; + } + + template + ErrorCode Index::RefineIndex(const std::vector& p_indexStreams) + { + std::lock_guard lock(m_dataAddLock); + std::shared_lock sharedlock(m_deletedID.getLock()); + + SizeType newR = GetNumSamples(); + + std::vector indices; + std::vector reverseIndices(newR); + for (SizeType i = 0; i < newR; i++) { + if (!m_deletedID.contains(i)) { + indices.push_back(i); + reverseIndices[i] = i; + } + else { + while (m_deletedID.contains(newR - 1) && newR > i) newR--; + if (newR == i) break; + indices.push_back(newR - 1); + reverseIndices[newR - 1] = i; + newR--; + } + } + + std::cout << "Refine... from " << GetNumSamples() << "->" << newR << std::endl; + + if (false == m_pSamples.Refine(indices, *p_indexStreams[0])) return ErrorCode::Fail; + if (nullptr != m_pMetadata && (p_indexStreams.size() < 6 || ErrorCode::Success != m_pMetadata->RefineMetadata(indices, *p_indexStreams[4], *p_indexStreams[5]))) return ErrorCode::Fail; + + m_pGraph.RefineGraph(this, indices, reverseIndices, *p_indexStreams[2]); + + COMMON::KDTree newTrees(m_pTrees); + newTrees.BuildTrees(this, &indices); +#pragma omp parallel for + for (SizeType i = 0; i < newTrees.size(); i++) { + if (newTrees[i].left < 0) + newTrees[i].left = -reverseIndices[-newTrees[i].left - 1] - 1; + if (newTrees[i].right < 0) + newTrees[i].right = -reverseIndices[-newTrees[i].right - 1] - 1; + } + newTrees.SaveTrees(*p_indexStreams[1]); + + Helper::Concurrent::ConcurrentSet newDeletedID; + newDeletedID.save(*p_indexStreams[3]); + return ErrorCode::Success; + } + + template + ErrorCode Index::RefineIndex(const std::string& p_folderPath) + { + std::string folderPath(p_folderPath); + if (!folderPath.empty() && *(folderPath.rbegin()) != FolderSep) + { + folderPath += FolderSep; + } + + if (!direxists(folderPath.c_str())) + { + mkdir(folderPath.c_str()); + } + + std::vector streams; + streams.push_back(new std::ofstream(folderPath + m_sDataPointsFilename, std::ios::binary)); + streams.push_back(new std::ofstream(folderPath + m_sKDTFilename, std::ios::binary)); + streams.push_back(new std::ofstream(folderPath + m_sGraphFilename, std::ios::binary)); + streams.push_back(new std::ofstream(folderPath + m_sDeleteDataPointsFilename, std::ios::binary)); + if (nullptr != m_pMetadata) + { + streams.push_back(new std::ofstream(folderPath + m_sMetadataFile, std::ios::binary)); + streams.push_back(new std::ofstream(folderPath + m_sMetadataIndexFile, std::ios::binary)); + } + + for (size_t i = 0; i < streams.size(); i++) + if (!(((std::ofstream*)streams[i])->is_open())) return ErrorCode::FailedCreateFile; + + ErrorCode ret = RefineIndex(streams); + + for (size_t i = 0; i < streams.size(); i++) + { + ((std::ofstream*)streams[i])->close(); + delete streams[i]; + } + return ret; + } + + template + ErrorCode Index::DeleteIndex(const void* p_vectors, SizeType p_vectorNum) { + const T* ptr_v = (const T*)p_vectors; +#pragma omp parallel for schedule(dynamic) + for (SizeType i = 0; i < p_vectorNum; i++) { + COMMON::QueryResultSet query(ptr_v + i * GetFeatureDim(), m_pGraph.m_iCEF); + SearchIndex(query); + + for (int i = 0; i < m_pGraph.m_iCEF; i++) { + if (query.GetResult(i)->Dist < 1e-6) { + m_deletedID.insert(query.GetResult(i)->VID); + } + } + } + return ErrorCode::Success; + } + + template + ErrorCode Index::DeleteIndex(const SizeType& p_id) { + m_deletedID.insert(p_id); + return ErrorCode::Success; + } + + template + ErrorCode Index::AddIndex(const void* p_vectors, SizeType p_vectorNum, DimensionType p_dimension, SizeType* p_start) + { + SizeType begin, end; + { + std::lock_guard lock(m_dataAddLock); + + begin = GetNumSamples(); + end = GetNumSamples() + p_vectorNum; + + if (p_start != nullptr) *p_start = begin; + + if (begin == 0) return BuildIndex(p_vectors, p_vectorNum, p_dimension); + + if (p_dimension != GetFeatureDim()) return ErrorCode::FailedParseValue; + + if (m_pSamples.AddBatch((const T*)p_vectors, p_vectorNum) != ErrorCode::Success || m_pGraph.AddBatch(p_vectorNum) != ErrorCode::Success) { + std::cout << "Memory Error: Cannot alloc space for vectors" << std::endl; + m_pSamples.SetR(begin); + m_pGraph.SetR(begin); + return ErrorCode::MemoryOverFlow; + } + if (DistCalcMethod::Cosine == m_iDistCalcMethod) + { + int base = COMMON::Utils::GetBase(); + for (SizeType i = begin; i < end; i++) { + COMMON::Utils::Normalize((T*)m_pSamples[i], GetFeatureDim(), base); + } + } + } + + for (SizeType node = begin; node < end; node++) + { + m_pGraph.RefineNode(this, node, true); + } + std::cout << "Add " << p_vectorNum << " vectors" << std::endl; + return ErrorCode::Success; + } + + template + ErrorCode + Index::SetParameter(const char* p_param, const char* p_value) + { + if (nullptr == p_param || nullptr == p_value) return ErrorCode::Fail; + +#define DefineKDTParameter(VarName, VarType, DefaultValue, RepresentStr) \ + else if (SPTAG::Helper::StrUtils::StrEqualIgnoreCase(p_param, RepresentStr)) \ + { \ + fprintf(stderr, "Setting %s with value %s\n", RepresentStr, p_value); \ + VarType tmp; \ + if (SPTAG::Helper::Convert::ConvertStringTo(p_value, tmp)) \ + { \ + VarName = tmp; \ + } \ + } \ + +#include "inc/Core/KDT/ParameterDefinitionList.h" +#undef DefineKDTParameter + + m_fComputeDistance = COMMON::DistanceCalcSelector(m_iDistCalcMethod); + return ErrorCode::Success; + } + + + template + std::string + Index::GetParameter(const char* p_param) const + { + if (nullptr == p_param) return std::string(); + +#define DefineKDTParameter(VarName, VarType, DefaultValue, RepresentStr) \ + else if (SPTAG::Helper::StrUtils::StrEqualIgnoreCase(p_param, RepresentStr)) \ + { \ + return SPTAG::Helper::Convert::ConvertToString(VarName); \ + } \ + +#include "inc/Core/KDT/ParameterDefinitionList.h" +#undef DefineKDTParameter + + return std::string(); + } + } +} + +#define DefineVectorValueType(Name, Type) \ +template class SPTAG::KDT::Index; \ + +#include "inc/Core/DefinitionList.h" +#undef DefineVectorValueType + + diff --git a/core/src/index/thirdparty/SPTAG/AnnService/src/Core/MetadataSet.cpp b/core/src/index/thirdparty/SPTAG/AnnService/src/Core/MetadataSet.cpp new file mode 100644 index 0000000000..137eb5d13a --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/AnnService/src/Core/MetadataSet.cpp @@ -0,0 +1,280 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "inc/Core/MetadataSet.h" + +#include +#include + +using namespace SPTAG; + +ErrorCode +MetadataSet::RefineMetadata(std::vector& indices, std::ostream& p_metaOut, std::ostream& p_metaIndexOut) +{ + SizeType R = (SizeType)indices.size(); + p_metaIndexOut.write((char*)&R, sizeof(SizeType)); + std::uint64_t offset = 0; + for (SizeType i = 0; i < R; i++) { + p_metaIndexOut.write((char*)&offset, sizeof(std::uint64_t)); + ByteArray meta = GetMetadata(indices[i]); + p_metaOut.write((char*)meta.Data(), sizeof(uint8_t)*meta.Length()); + offset += meta.Length(); + } + p_metaIndexOut.write((char*)&offset, sizeof(std::uint64_t)); + return ErrorCode::Success; +} + + +ErrorCode +MetadataSet::RefineMetadata(std::vector& indices, const std::string& p_metaFile, const std::string& p_metaindexFile) +{ + std::ofstream metaOut(p_metaFile + "_tmp", std::ios::binary); + std::ofstream metaIndexOut(p_metaindexFile, std::ios::binary); + if (!metaOut.is_open() || !metaIndexOut.is_open()) return ErrorCode::FailedCreateFile; + + RefineMetadata(indices, metaOut, metaIndexOut); + metaOut.close(); + metaIndexOut.close(); + + if (fileexists(p_metaFile.c_str())) std::remove(p_metaFile.c_str()); + std::rename((p_metaFile + "_tmp").c_str(), p_metaFile.c_str()); + return ErrorCode::Success; +} + + +MetadataSet::MetadataSet() +{ +} + + +MetadataSet:: ~MetadataSet() +{ +} + + +FileMetadataSet::FileMetadataSet(const std::string& p_metafile, const std::string& p_metaindexfile) + : m_metaFile(p_metafile), + m_metaindexFile(p_metaindexfile) +{ + m_fp = new std::ifstream(p_metafile, std::ifstream::binary); + std::ifstream fpidx(p_metaindexfile, std::ifstream::binary); + if (!m_fp->is_open() || !fpidx.is_open()) + { + std::cerr << "ERROR: Cannot open meta files " << p_metafile << " and " << p_metaindexfile << "!" << std::endl; + return; + } + + fpidx.read((char *)&m_count, sizeof(m_count)); + m_pOffsets.resize(m_count + 1); + fpidx.read((char *)m_pOffsets.data(), sizeof(std::uint64_t) * (m_count + 1)); + fpidx.close(); +} + + +FileMetadataSet::~FileMetadataSet() +{ + if (m_fp) + { + m_fp->close(); + delete m_fp; + } +} + + +ByteArray +FileMetadataSet::GetMetadata(SizeType p_vectorID) const +{ + std::uint64_t startoff = m_pOffsets[p_vectorID]; + std::uint64_t bytes = m_pOffsets[p_vectorID + 1] - startoff; + if (p_vectorID < m_count) { + m_fp->seekg(startoff, std::ios_base::beg); + ByteArray b = ByteArray::Alloc(bytes); + m_fp->read((char*)b.Data(), bytes); + return b; + } + else { + startoff -= m_pOffsets[m_count]; + return ByteArray((std::uint8_t*)m_newdata.data() + startoff, bytes, false); + } +} + + +SizeType +FileMetadataSet::Count() const +{ + return static_cast(m_pOffsets.size() - 1); +} + + +bool +FileMetadataSet::Available() const +{ + return m_fp && m_fp->is_open() && m_pOffsets.size() > 1; +} + + +std::pair +FileMetadataSet::BufferSize() const +{ + return std::make_pair(m_pOffsets[m_pOffsets.size() - 1], + sizeof(SizeType) + sizeof(std::uint64_t) * m_pOffsets.size()); +} + + +void +FileMetadataSet::AddBatch(MetadataSet& data) +{ + for (SizeType i = 0; i < data.Count(); i++) + { + ByteArray newdata = data.GetMetadata(i); + m_newdata.insert(m_newdata.end(), newdata.Data(), newdata.Data() + newdata.Length()); + m_pOffsets.push_back(m_pOffsets[m_pOffsets.size() - 1] + newdata.Length()); + } +} + + + +ErrorCode +FileMetadataSet::SaveMetadata(std::ostream& p_metaOut, std::ostream& p_metaIndexOut) +{ + m_fp->seekg(0, std::ios_base::beg); + + int bufsize = 1000000; + char* buf = new char[bufsize]; + while (!m_fp->eof()) { + m_fp->read(buf, bufsize); + p_metaOut.write(buf, m_fp->gcount()); + } + delete[] buf; + + if (m_newdata.size() > 0) { + p_metaOut.write((char*)m_newdata.data(), m_newdata.size()); + } + + SizeType count = Count(); + p_metaIndexOut.write((char*)&count, sizeof(SizeType)); + p_metaIndexOut.write((char*)m_pOffsets.data(), sizeof(std::uint64_t) * m_pOffsets.size()); + return ErrorCode::Success; +} + + +ErrorCode +FileMetadataSet::SaveMetadata(const std::string& p_metaFile, const std::string& p_metaindexFile) +{ + std::ofstream metaOut(p_metaFile + "_tmp", std::ios::binary); + std::ofstream metaIndexOut(p_metaindexFile, std::ios::binary); + if (!metaOut.is_open() || !metaIndexOut.is_open()) return ErrorCode::FailedCreateFile; + + SaveMetadata(metaOut, metaIndexOut); + metaOut.close(); + metaIndexOut.close(); + + m_fp->close(); + if (fileexists(p_metaFile.c_str())) std::remove(p_metaFile.c_str()); + std::rename((p_metaFile + "_tmp").c_str(), p_metaFile.c_str()); + m_fp->open(p_metaFile, std::ifstream::binary); + m_count = Count(); + m_newdata.clear(); + return ErrorCode::Success; +} + + +MemMetadataSet::MemMetadataSet(ByteArray p_metadata, ByteArray p_offsets, SizeType p_count) + : m_metadataHolder(std::move(p_metadata)), + m_offsetHolder(std::move(p_offsets)), + m_count(p_count) +{ + const std::uint64_t* newdata = reinterpret_cast(m_offsetHolder.Data()); + m_offsets.insert(m_offsets.end(), newdata, newdata + p_count + 1); +} + + +MemMetadataSet::~MemMetadataSet() +{ +} + + +ByteArray +MemMetadataSet::GetMetadata(SizeType p_vectorID) const +{ + if (p_vectorID < m_count) + { + return ByteArray(m_metadataHolder.Data() + m_offsets[p_vectorID], + m_offsets[p_vectorID + 1] - m_offsets[p_vectorID], + false); + } + else if (p_vectorID < (SizeType)(m_offsets.size() - 1)) { + return ByteArray((std::uint8_t*)m_newdata.data() + m_offsets[p_vectorID] - m_offsets[m_count], + m_offsets[p_vectorID + 1] - m_offsets[p_vectorID], + false); + } + + return ByteArray::c_empty; +} + + +SizeType +MemMetadataSet::Count() const +{ + return static_cast(m_offsets.size() - 1); +} + + +bool +MemMetadataSet::Available() const +{ + return m_metadataHolder.Length() > 0 && m_offsetHolder.Length() > 0; +} + + +std::pair +MemMetadataSet::BufferSize() const +{ + return std::make_pair(m_offsets[m_offsets.size() - 1], + sizeof(SizeType) + sizeof(std::uint64_t) * m_offsets.size()); +} + +void +MemMetadataSet::AddBatch(MetadataSet& data) +{ + for (SizeType i = 0; i < data.Count(); i++) + { + ByteArray newdata = data.GetMetadata(i); + m_newdata.insert(m_newdata.end(), newdata.Data(), newdata.Data() + newdata.Length()); + m_offsets.push_back(m_offsets[m_offsets.size() - 1] + newdata.Length()); + } +} + + +ErrorCode +MemMetadataSet::SaveMetadata(std::ostream& p_metaOut, std::ostream& p_metaIndexOut) +{ + p_metaOut.write(reinterpret_cast(m_metadataHolder.Data()), m_metadataHolder.Length()); + if (m_newdata.size() > 0) { + p_metaOut.write((char*)m_newdata.data(), m_newdata.size()); + } + + SizeType count = Count(); + p_metaIndexOut.write((char*)&count, sizeof(SizeType)); + p_metaIndexOut.write((char*)m_offsets.data(), sizeof(std::uint64_t) * m_offsets.size()); + return ErrorCode::Success; +} + + + +ErrorCode +MemMetadataSet::SaveMetadata(const std::string& p_metaFile, const std::string& p_metaindexFile) +{ + std::ofstream metaOut(p_metaFile + "_tmp", std::ios::binary); + std::ofstream metaIndexOut(p_metaindexFile, std::ios::binary); + if (!metaOut.is_open() || !metaIndexOut.is_open()) return ErrorCode::FailedCreateFile; + + SaveMetadata(metaOut, metaIndexOut); + metaOut.close(); + metaIndexOut.close(); + + if (fileexists(p_metaFile.c_str())) std::remove(p_metaFile.c_str()); + std::rename((p_metaFile + "_tmp").c_str(), p_metaFile.c_str()); + return ErrorCode::Success; +} + diff --git a/core/src/index/thirdparty/SPTAG/AnnService/src/Core/VectorIndex.cpp b/core/src/index/thirdparty/SPTAG/AnnService/src/Core/VectorIndex.cpp new file mode 100644 index 0000000000..9c7ccf5492 --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/AnnService/src/Core/VectorIndex.cpp @@ -0,0 +1,428 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "inc/Core/VectorIndex.h" +#include "inc/Core/Common/DataUtils.h" +#include "inc/Helper/CommonHelper.h" +#include "inc/Helper/StringConvert.h" +#include "inc/Helper/SimpleIniReader.h" +#include "inc/Helper/BufferStream.h" + +#include "inc/Core/BKT/Index.h" +#include "inc/Core/KDT/Index.h" +#include + + +using namespace SPTAG; + + +VectorIndex::VectorIndex() +{ +} + + +VectorIndex::~VectorIndex() +{ +} + + +std::string +VectorIndex::GetParameter(const std::string& p_param) const +{ + return GetParameter(p_param.c_str()); +} + + +ErrorCode +VectorIndex::SetParameter(const std::string& p_param, const std::string& p_value) +{ + return SetParameter(p_param.c_str(), p_value.c_str()); +} + + +void +VectorIndex::SetMetadata(const std::string& p_metadataFilePath, const std::string& p_metadataIndexPath) { + m_pMetadata.reset(new FileMetadataSet(p_metadataFilePath, p_metadataIndexPath)); +} + + +ByteArray +VectorIndex::GetMetadata(SizeType p_vectorID) const { + if (nullptr != m_pMetadata) + { + return m_pMetadata->GetMetadata(p_vectorID); + } + return ByteArray::c_empty; +} + + +std::shared_ptr> VectorIndex::CalculateBufferSize() const +{ + std::shared_ptr> ret = BufferSize(); + if (m_pMetadata != nullptr) + { + auto metasize = m_pMetadata->BufferSize(); + ret->push_back(metasize.first); + ret->push_back(metasize.second); + } + return std::move(ret); +} + + +ErrorCode +VectorIndex::LoadIndexConfig(Helper::IniReader& p_reader) +{ + std::string metadataSection("MetaData"); + if (p_reader.DoesSectionExist(metadataSection)) + { + m_sMetadataFile = p_reader.GetParameter(metadataSection, "MetaDataFilePath", std::string()); + m_sMetadataIndexFile = p_reader.GetParameter(metadataSection, "MetaDataIndexPath", std::string()); + } + + if (DistCalcMethod::Undefined == p_reader.GetParameter("Index", "DistCalcMethod", DistCalcMethod::Undefined)) + { + std::cerr << "Error: Failed to load parameter DistCalcMethod." << std::endl; + return ErrorCode::Fail; + } + return LoadConfig(p_reader); +} + + +ErrorCode +VectorIndex::SaveIndexConfig(std::ostream& p_configOut) +{ + if (nullptr != m_pMetadata) + { + p_configOut << "[MetaData]" << std::endl; + p_configOut << "MetaDataFilePath=" << m_sMetadataFile << std::endl; + p_configOut << "MetaDataIndexPath=" << m_sMetadataIndexFile << std::endl; + if (nullptr != m_pMetaToVec) p_configOut << "MetaDataToVectorIndex=true" << std::endl; + p_configOut << std::endl; + } + + p_configOut << "[Index]" << std::endl; + p_configOut << "IndexAlgoType=" << Helper::Convert::ConvertToString(GetIndexAlgoType()) << std::endl; + p_configOut << "ValueType=" << Helper::Convert::ConvertToString(GetVectorValueType()) << std::endl; + p_configOut << std::endl; + + return SaveConfig(p_configOut); +} + + +void +VectorIndex::BuildMetaMapping() +{ + m_pMetaToVec.reset(new std::unordered_map); + for (SizeType i = 0; i < m_pMetadata->Count(); i++) { + ByteArray meta = m_pMetadata->GetMetadata(i); + m_pMetaToVec->emplace(std::string((char*)meta.Data(), meta.Length()), i); + } +} + + +ErrorCode +VectorIndex::LoadIndex(const std::string& p_config, const std::vector& p_indexBlobs) +{ + SPTAG::Helper::IniReader p_reader; + std::istringstream p_configin(p_config); + if (SPTAG::ErrorCode::Success != p_reader.LoadIni(p_configin)) return ErrorCode::FailedParseValue; + LoadIndexConfig(p_reader); + + if (p_reader.DoesSectionExist("MetaData") && p_indexBlobs.size() > 4) + { + ByteArray pMetaIndex = p_indexBlobs[p_indexBlobs.size() - 1]; + m_pMetadata.reset(new MemMetadataSet(p_indexBlobs[p_indexBlobs.size() - 2], + ByteArray(pMetaIndex.Data() + sizeof(SizeType), pMetaIndex.Length() - sizeof(SizeType), false), + *((SizeType*)pMetaIndex.Data()))); + + if (!m_pMetadata->Available()) + { + std::cerr << "Error: Failed to load metadata." << std::endl; + return ErrorCode::Fail; + } + + if (p_reader.GetParameter("MetaData", "MetaDataToVectorIndex", std::string()) == "true") + { + BuildMetaMapping(); + } + } + return LoadIndexDataFromMemory(p_indexBlobs); +} + + +ErrorCode +VectorIndex::LoadIndex(const std::string& p_folderPath) +{ + std::string folderPath(p_folderPath); + if (!folderPath.empty() && *(folderPath.rbegin()) != FolderSep) + { + folderPath += FolderSep; + } + + Helper::IniReader p_configReader; + if (ErrorCode::Success != p_configReader.LoadIniFile(folderPath + "/indexloader.ini")) return ErrorCode::FailedOpenFile; + LoadIndexConfig(p_configReader); + + if (p_configReader.DoesSectionExist("MetaData")) + { + m_pMetadata.reset(new FileMetadataSet(folderPath + m_sMetadataFile, folderPath + m_sMetadataIndexFile)); + + if (!m_pMetadata->Available()) + { + std::cerr << "Error: Failed to load metadata." << std::endl; + return ErrorCode::Fail; + } + + if (p_configReader.GetParameter("MetaData", "MetaDataToVectorIndex", std::string()) == "true") + { + BuildMetaMapping(); + } + } + return LoadIndexData(folderPath); +} + + +ErrorCode +VectorIndex::SaveIndex(std::string& p_config, const std::vector& p_indexBlobs) +{ + std::ostringstream p_configStream; + SaveIndexConfig(p_configStream); + p_config = p_configStream.str(); + + std::vector p_indexStreams; + for (size_t i = 0; i < p_indexBlobs.size(); i++) + { + p_indexStreams.push_back(new Helper::obufferstream(new Helper::streambuf((char*)p_indexBlobs[i].Data(), p_indexBlobs[i].Length()), true)); + } + + ErrorCode ret = ErrorCode::Success; + if (NeedRefine()) + { + ret = RefineIndex(p_indexStreams); + } + else + { + if (m_pMetadata != nullptr && p_indexStreams.size() > 5) + { + ret = m_pMetadata->SaveMetadata(*p_indexStreams[p_indexStreams.size() - 2], *p_indexStreams[p_indexStreams.size() - 1]); + } + if (ErrorCode::Success == ret) ret = SaveIndexData(p_indexStreams); + } + for (size_t i = 0; i < p_indexStreams.size(); i++) + { + delete p_indexStreams[i]; + } + return ret; +} + + +ErrorCode +VectorIndex::SaveIndex(const std::string& p_folderPath) +{ + std::string folderPath(p_folderPath); + if (!folderPath.empty() && *(folderPath.rbegin()) != FolderSep) + { + folderPath += FolderSep; + } + + if (!direxists(folderPath.c_str())) + { + mkdir(folderPath.c_str()); + } + + std::ofstream configFile(folderPath + "indexloader.ini"); + if (!configFile.is_open()) return ErrorCode::FailedCreateFile; + SaveIndexConfig(configFile); + configFile.close(); + + if (NeedRefine()) return RefineIndex(p_folderPath); + + if (m_pMetadata != nullptr) + { + ErrorCode ret = m_pMetadata->SaveMetadata(folderPath + m_sMetadataFile, folderPath + m_sMetadataIndexFile); + if (ErrorCode::Success != ret) return ret; + } + return SaveIndexData(folderPath); +} + +ErrorCode +VectorIndex::BuildIndex(std::shared_ptr p_vectorSet, + std::shared_ptr p_metadataSet, bool p_withMetaIndex) +{ + if (nullptr == p_vectorSet || p_vectorSet->Count() == 0 || p_vectorSet->Dimension() == 0 || p_vectorSet->GetValueType() != GetVectorValueType()) + { + return ErrorCode::Fail; + } + + BuildIndex(p_vectorSet->GetData(), p_vectorSet->Count(), p_vectorSet->Dimension()); + m_pMetadata = std::move(p_metadataSet); + if (p_withMetaIndex && m_pMetadata != nullptr) + { + BuildMetaMapping(); + } + return ErrorCode::Success; +} + + +ErrorCode +VectorIndex::SearchIndex(const void* p_vector, int p_neighborCount, bool p_withMeta, BasicResult* p_results) const { + QueryResult res(p_vector, p_neighborCount, p_withMeta, p_results); + SearchIndex(res); + return ErrorCode::Success; +} + + +ErrorCode +VectorIndex::AddIndex(std::shared_ptr p_vectorSet, std::shared_ptr p_metadataSet) { + if (nullptr == p_vectorSet || p_vectorSet->Count() == 0 || p_vectorSet->Dimension() == 0 || p_vectorSet->GetValueType() != GetVectorValueType()) + { + return ErrorCode::Fail; + } + + SizeType currStart; + ErrorCode ret = AddIndex(p_vectorSet->GetData(), p_vectorSet->Count(), p_vectorSet->Dimension(), &currStart); + if (ret != ErrorCode::Success) return ret; + + if (m_pMetadata == nullptr) { + if (currStart == 0) + m_pMetadata = std::move(p_metadataSet); + else + return ErrorCode::Success; + } + else { + m_pMetadata->AddBatch(*p_metadataSet); + } + + if (m_pMetaToVec != nullptr) { + for (SizeType i = 0; i < p_vectorSet->Count(); i++) { + ByteArray meta = m_pMetadata->GetMetadata(currStart + i); + DeleteIndex(meta); + m_pMetaToVec->emplace(std::string((char*)meta.Data(), meta.Length()), currStart + i); + } + } + return ErrorCode::Success; +} + + +ErrorCode +VectorIndex::DeleteIndex(ByteArray p_meta) { + if (m_pMetaToVec == nullptr) return ErrorCode::Fail; + + std::string meta((char*)p_meta.Data(), p_meta.Length()); + auto iter = m_pMetaToVec->find(meta); + if (iter != m_pMetaToVec->end()) DeleteIndex(iter->second); + return ErrorCode::Success; +} + + +const void* VectorIndex::GetSample(ByteArray p_meta) +{ + if (m_pMetaToVec == nullptr) return nullptr; + + std::string meta((char*)p_meta.Data(), p_meta.Length()); + auto iter = m_pMetaToVec->find(meta); + if (iter != m_pMetaToVec->end()) return GetSample(iter->second); + return nullptr; +} + + +std::shared_ptr +VectorIndex::CreateInstance(IndexAlgoType p_algo, VectorValueType p_valuetype) +{ + if (IndexAlgoType::Undefined == p_algo || VectorValueType::Undefined == p_valuetype) + { + return nullptr; + } + + if (p_algo == IndexAlgoType::BKT) { + switch (p_valuetype) + { +#define DefineVectorValueType(Name, Type) \ + case VectorValueType::Name: \ + return std::shared_ptr(new BKT::Index); \ + +#include "inc/Core/DefinitionList.h" +#undef DefineVectorValueType + + default: break; + } + } + else if (p_algo == IndexAlgoType::KDT) { + switch (p_valuetype) + { +#define DefineVectorValueType(Name, Type) \ + case VectorValueType::Name: \ + return std::shared_ptr(new KDT::Index); \ + +#include "inc/Core/DefinitionList.h" +#undef DefineVectorValueType + + default: break; + } + } + return nullptr; +} + + +ErrorCode +VectorIndex::LoadIndex(const std::string& p_loaderFilePath, std::shared_ptr& p_vectorIndex) +{ + Helper::IniReader iniReader; + if (ErrorCode::Success != iniReader.LoadIniFile(p_loaderFilePath + "/indexloader.ini")) return ErrorCode::FailedOpenFile; + + IndexAlgoType algoType = iniReader.GetParameter("Index", "IndexAlgoType", IndexAlgoType::Undefined); + VectorValueType valueType = iniReader.GetParameter("Index", "ValueType", VectorValueType::Undefined); + + p_vectorIndex = CreateInstance(algoType, valueType); + if (p_vectorIndex == nullptr) return ErrorCode::FailedParseValue; + + return p_vectorIndex->LoadIndex(p_loaderFilePath); +} + + + +ErrorCode +VectorIndex::LoadIndex(const std::string& p_config, const std::vector& p_indexBlobs, std::shared_ptr& p_vectorIndex) +{ + SPTAG::Helper::IniReader iniReader; + std::istringstream p_configin(p_config); + if (SPTAG::ErrorCode::Success != iniReader.LoadIni(p_configin)) return ErrorCode::FailedParseValue; + + IndexAlgoType algoType = iniReader.GetParameter("Index", "IndexAlgoType", IndexAlgoType::Undefined); + VectorValueType valueType = iniReader.GetParameter("Index", "ValueType", VectorValueType::Undefined); + + p_vectorIndex = CreateInstance(algoType, valueType); + if (p_vectorIndex == nullptr) return ErrorCode::FailedParseValue; + + return p_vectorIndex->LoadIndex(p_config, p_indexBlobs); +} + + +ErrorCode +VectorIndex::MergeIndex(const char* p_indexFilePath1, const char* p_indexFilePath2) +{ + std::string folderPath1(p_indexFilePath1), folderPath2(p_indexFilePath2); + + std::shared_ptr index1, index2; + LoadIndex(folderPath1, index1); + LoadIndex(folderPath2, index2); + + std::shared_ptr p_vectorSet; + std::shared_ptr p_metaSet; + size_t vectorSize = GetValueTypeSize(index2->GetVectorValueType()) * index2->GetFeatureDim(); + std::uint64_t offsets[2] = { 0 }; + ByteArray metaoffset((std::uint8_t*)offsets, 2 * sizeof(std::uint64_t), false); + for (SizeType i = 0; i < index2->GetNumSamples(); i++) + if (index2->ContainSample(i)) + { + p_vectorSet.reset(new BasicVectorSet(ByteArray((std::uint8_t*)index2->GetSample(i), vectorSize, false), + index2->GetVectorValueType(), index2->GetFeatureDim(), 1)); + ByteArray meta = index2->GetMetadata(i); + offsets[1] = meta.Length(); + p_metaSet.reset(new MemMetadataSet(meta, metaoffset, 1)); + index1->AddIndex(p_vectorSet, p_metaSet); + } + + index1->SaveIndex(folderPath1); + return ErrorCode::Success; +} diff --git a/core/src/index/thirdparty/SPTAG/AnnService/src/Core/VectorSet.cpp b/core/src/index/thirdparty/SPTAG/AnnService/src/Core/VectorSet.cpp new file mode 100644 index 0000000000..45dd74dd78 --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/AnnService/src/Core/VectorSet.cpp @@ -0,0 +1,96 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "inc/Core/VectorSet.h" + +using namespace SPTAG; + +#pragma warning(disable:4996) // 'fopen': This function or variable may be unsafe. Consider using fopen_s instead. To disable deprecation, use _CRT_SECURE_NO_WARNINGS. See online help for details. + +VectorSet::VectorSet() +{ +} + + +VectorSet::~VectorSet() +{ +} + + +BasicVectorSet::BasicVectorSet(const ByteArray& p_bytesArray, + VectorValueType p_valueType, + DimensionType p_dimension, + SizeType p_vectorCount) + : m_data(p_bytesArray), + m_valueType(p_valueType), + m_dimension(p_dimension), + m_vectorCount(p_vectorCount), + m_perVectorDataSize(static_cast(p_dimension * GetValueTypeSize(p_valueType))) +{ +} + + +BasicVectorSet::~BasicVectorSet() +{ +} + + +VectorValueType +BasicVectorSet::GetValueType() const +{ + return m_valueType; +} + + +void* +BasicVectorSet::GetVector(SizeType p_vectorID) const +{ + if (p_vectorID < 0 || p_vectorID >= m_vectorCount) + { + return nullptr; + } + + return reinterpret_cast(m_data.Data() + ((size_t)p_vectorID) * m_perVectorDataSize); +} + + +void* +BasicVectorSet::GetData() const +{ + return reinterpret_cast(m_data.Data()); +} + +DimensionType +BasicVectorSet::Dimension() const +{ + return m_dimension; +} + + +SizeType +BasicVectorSet::Count() const +{ + return m_vectorCount; +} + + +bool +BasicVectorSet::Available() const +{ + return m_data.Data() != nullptr; +} + + +ErrorCode +BasicVectorSet::Save(const std::string& p_vectorFile) const +{ + FILE * fp = fopen(p_vectorFile.c_str(), "wb"); + if (fp == NULL) return ErrorCode::FailedOpenFile; + + fwrite(&m_vectorCount, sizeof(SizeType), 1, fp); + fwrite(&m_dimension, sizeof(DimensionType), 1, fp); + + fwrite((const void*)(m_data.Data()), m_data.Length(), 1, fp); + fclose(fp); + return ErrorCode::Success; +} diff --git a/core/src/index/thirdparty/SPTAG/AnnService/src/Helper/ArgumentsParser.cpp b/core/src/index/thirdparty/SPTAG/AnnService/src/Helper/ArgumentsParser.cpp new file mode 100644 index 0000000000..4f630ec01c --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/AnnService/src/Helper/ArgumentsParser.cpp @@ -0,0 +1,86 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "inc/Helper/ArgumentsParser.h" + +using namespace SPTAG::Helper; + + +ArgumentsParser::IArgument::IArgument() +{ +} + + +ArgumentsParser::IArgument::~IArgument() +{ +} + + +ArgumentsParser::ArgumentsParser() +{ +} + + +ArgumentsParser::~ArgumentsParser() +{ +} + + +bool +ArgumentsParser::Parse(int p_argc, char** p_args) +{ + while (p_argc > 0) + { + int last = p_argc; + for (auto& option : m_arguments) + { + if (!option->ParseValue(p_argc, p_args)) + { + fprintf(stderr, "Failed to parse args around \"%s\"\n", *p_args); + PrintHelp(); + return false; + } + } + + if (last == p_argc) + { + p_argc -= 1; + p_args += 1; + } + } + + bool isValid = true; + for (auto& option : m_arguments) + { + if (option->IsRequiredButNotSet()) + { + fprintf(stderr, "Required option not set:\n "); + option->PrintDescription(stderr); + fprintf(stderr, "\n"); + isValid = false; + } + } + + if (!isValid) + { + fprintf(stderr, "\n"); + PrintHelp(); + return false; + } + + return true; +} + + +void +ArgumentsParser::PrintHelp() +{ + fprintf(stderr, "Usage: "); + for (auto& option : m_arguments) + { + fprintf(stderr, "\n "); + option->PrintDescription(stderr); + } + + fprintf(stderr, "\n\n"); +} diff --git a/core/src/index/thirdparty/SPTAG/AnnService/src/Helper/Base64Encode.cpp b/core/src/index/thirdparty/SPTAG/AnnService/src/Helper/Base64Encode.cpp new file mode 100644 index 0000000000..5992fa5a31 --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/AnnService/src/Helper/Base64Encode.cpp @@ -0,0 +1,240 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "inc/Helper/Base64Encode.h" + +using namespace SPTAG; +using namespace SPTAG::Helper; + +namespace +{ +namespace Local +{ +const char c_encTable[] = +{ + 'A','B','C','D','E','F','G','H','I','J','K','L','M','N','O','P','Q','R','S','T','U','V','W','X','Y','Z', + 'a','b','c','d','e','f','g','h','i','j','k','l','m','n','o','p','q','r','s','t','u','v','w','x','y','z', + '0','1','2','3','4','5','6','7','8','9','+','/' +}; + + +const std::uint8_t c_decTable[] = +{ + 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, // 0x00 - 0x0f + 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, // 0x10 - 0x1f + 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 62, 64, 64, 64, 63, // 0x20 - 0x2f + 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 64, 64, 64, 64, 64, 64, // 0x30 - 0x3f + 64, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, // 0x40 - 0x4f + 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 64, 64, 64, 64, 64, // 0x50 - 0x5f + 64, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, // 0x60 - 0x6f + 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 64, 64, 64, 64, 64, // 0x70 - 0x7f + 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, // 0x80 - 0x8f + 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, // 0x90 - 0x9f + 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, // 0xa0 - 0xaf + 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, // 0xb0 - 0xbf + 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, // 0xc0 - 0xcf + 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, // 0xd0 - 0xdf + 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, // 0xe0 - 0xef + 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, // 0xf0 - 0xff +}; + + +const char c_paddingChar = '='; +} +} + + +bool +Base64::Encode(const std::uint8_t* p_in, std::size_t p_inLen, char* p_out, std::size_t& p_outLen) +{ + using namespace Local; + + p_outLen = 0; + while (p_inLen >= 3) + { + p_out[0] = c_encTable[p_in[0] >> 2]; + p_out[1] = c_encTable[((p_in[0] & 0x03) << 4) | ((p_in[1] & 0xf0) >> 4)]; + p_out[2] = c_encTable[((p_in[1] & 0x0f) << 2) | ((p_in[2] & 0xc0) >> 6)]; + p_out[3] = c_encTable[p_in[2] & 0x3f]; + + p_in += 3; + p_inLen -= 3; + p_out += 4; + p_outLen += 4; + } + + switch (p_inLen) + { + case 1: + p_out[0] = c_encTable[p_in[0] >> 2]; + p_out[1] = c_encTable[(p_in[0] & 0x03) << 4]; + p_out[2] = c_paddingChar; + p_out[3] = c_paddingChar; + + p_outLen += 4; + break; + + case 2: + p_out[0] = c_encTable[p_in[0] >> 2]; + p_out[1] = c_encTable[((p_in[0] & 0x03) << 4) | ((p_in[1] & 0xf0) >> 4)]; + p_out[2] = c_encTable[(p_in[1] & 0x0f) << 2]; + p_out[3] = c_paddingChar; + + p_outLen += 4; + break; + } + + return true; +} + + +bool +Base64::Encode(const std::uint8_t* p_in, std::size_t p_inLen, std::ostream& p_out, std::size_t& p_outLen) +{ + using namespace Local; + + p_outLen = 0; + while (p_inLen >= 3) + { + p_out << c_encTable[p_in[0] >> 2]; + p_out << c_encTable[((p_in[0] & 0x03) << 4) | ((p_in[1] & 0xf0) >> 4)]; + p_out << c_encTable[((p_in[1] & 0x0f) << 2) | ((p_in[2] & 0xc0) >> 6)]; + p_out << c_encTable[p_in[2] & 0x3f]; + + p_in += 3; + p_inLen -= 3; + p_outLen += 4; + } + + switch (p_inLen) + { + case 1: + p_out << c_encTable[p_in[0] >> 2]; + p_out << c_encTable[(p_in[0] & 0x03) << 4]; + p_out << c_paddingChar; + p_out << c_paddingChar; + + p_outLen += 4; + break; + + case 2: + p_out << c_encTable[p_in[0] >> 2]; + p_out << c_encTable[((p_in[0] & 0x03) << 4) | ((p_in[1] & 0xf0) >> 4)]; + p_out << c_encTable[(p_in[1] & 0x0f) << 2]; + p_out << c_paddingChar; + + p_outLen += 4; + break; + + default: + break; + } + + return true; +} + + +bool +Base64::Decode(const char* p_in, std::size_t p_inLen, std::uint8_t* p_out, std::size_t& p_outLen) +{ + using namespace Local; + + // Should always be padding. + if ((p_inLen & 0x03) != 0) + { + return false; + } + + std::uint8_t u0 = 0; + std::uint8_t u1 = 0; + std::uint8_t u2 = 0; + std::uint8_t u3 = 0; + + p_outLen = 0; + while (p_inLen > 4) + { + u0 = c_decTable[static_cast(p_in[0])]; + u1 = c_decTable[static_cast(p_in[1])]; + u2 = c_decTable[static_cast(p_in[2])]; + u3 = c_decTable[static_cast(p_in[3])]; + + if (u0 > 63 || u1 > 63 || u2 > 63 || u3 > 63) + { + return false; + } + + p_out[0] = (u0 << 2) | (u1 >> 4); + p_out[1] = (u1 << 4) | (u2 >> 2); + p_out[2] = (u2 << 6) | u3; + + p_inLen -= 4; + p_in += 4; + p_out += 3; + p_outLen += 3; + } + + u0 = c_decTable[static_cast(p_in[0])]; + u1 = c_decTable[static_cast(p_in[1])]; + u2 = c_decTable[static_cast(p_in[2])]; + u3 = c_decTable[static_cast(p_in[3])]; + + if (u0 > 63 || u1 > 63 || (c_paddingChar == p_in[2] && c_paddingChar != p_in[3])) + { + return false; + } + + if (u2 > 63 && c_paddingChar != p_in[2]) + { + return false; + } + + if (u3 > 63 && c_paddingChar != p_in[3]) + { + return false; + } + + + p_out[0] = (u0 << 2) | (u1 >> 4); + ++p_outLen; + if (c_paddingChar == p_in[2]) + { + if ((u1 & 0x0F) != 0) + { + return false; + } + } + else + { + p_out[1] = (u1 << 4) | (u2 >> 2); + ++p_outLen; + if (c_paddingChar == p_in[3]) + { + if ((u3 & 0x03) != 0) + { + return false; + } + } + else + { + p_out[2] = (u2 << 6) | u3; + ++p_outLen; + } + } + + return true; +} + + +std::size_t +Base64::CapacityForEncode(std::size_t p_inLen) +{ + return ((p_inLen + 2) / 3) * 4; +} + + +std::size_t +Base64::CapacityForDecode(std::size_t p_inLen) +{ + return (p_inLen / 4) * 3 + ((p_inLen % 4) * 2) / 3; +} + diff --git a/core/src/index/thirdparty/SPTAG/AnnService/src/Helper/CommonHelper.cpp b/core/src/index/thirdparty/SPTAG/AnnService/src/Helper/CommonHelper.cpp new file mode 100644 index 0000000000..2d4dc0de5c --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/AnnService/src/Helper/CommonHelper.cpp @@ -0,0 +1,144 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "inc/Helper/CommonHelper.h" + +#include +#include +#include +#include + +using namespace SPTAG; +using namespace SPTAG::Helper; + +void +StrUtils::ToLowerInPlace(std::string& p_str) +{ + for (char& ch : p_str) + { + if (std::isupper(ch)) + { + ch = ch | 0x20; + } + } +} + + +std::vector +StrUtils::SplitString(const std::string& p_str, const std::string& p_separator) +{ + std::vector ret; + + std::size_t begin = p_str.find_first_not_of(p_separator); + while (std::string::npos != begin) + { + std::size_t end = p_str.find_first_of(p_separator, begin); + if (std::string::npos == end) + { + ret.emplace_back(p_str.substr(begin, p_str.size() - begin)); + break; + } + else + { + ret.emplace_back(p_str.substr(begin, end - begin)); + } + + begin = p_str.find_first_not_of(p_separator, end); + } + + return ret; +} + + +std::pair +StrUtils::FindTrimmedSegment(const char* p_begin, + const char* p_end, + const std::function& p_isSkippedChar) +{ + while (p_begin < p_end) + { + if (!p_isSkippedChar(*p_begin)) + { + break; + } + + ++p_begin; + } + + while (p_end > p_begin) + { + if (!p_isSkippedChar(*(p_end - 1))) + { + break; + } + + --p_end; + } + + return std::make_pair(p_begin, p_end); +} + + +bool +StrUtils::StartsWith(const char* p_str, const char* p_prefix) +{ + if (nullptr == p_prefix) + { + return true; + } + + if (nullptr == p_str) + { + return false; + } + + while ('\0' != (*p_prefix) && '\0' != (*p_str)) + { + if (*p_prefix != *p_str) + { + return false; + } + ++p_prefix; + ++p_str; + } + + return '\0' == *p_prefix; +} + + +bool +StrUtils::StrEqualIgnoreCase(const char* p_left, const char* p_right) +{ + if (p_left == p_right) + { + return true; + } + + if (p_left == nullptr || p_right == nullptr) + { + return false; + } + + auto tryConv = [](char p_ch) -> char + { + if ('a' <= p_ch && p_ch <= 'z') + { + return p_ch - 32; + } + + return p_ch; + }; + + while (*p_left != '\0' && *p_right != '\0') + { + if (tryConv(*p_left) != tryConv(*p_right)) + { + return false; + } + + ++p_left; + ++p_right; + } + + return *p_left == *p_right; +} diff --git a/core/src/index/thirdparty/SPTAG/AnnService/src/Helper/Concurrent.cpp b/core/src/index/thirdparty/SPTAG/AnnService/src/Helper/Concurrent.cpp new file mode 100644 index 0000000000..cbb1bdb643 --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/AnnService/src/Helper/Concurrent.cpp @@ -0,0 +1,71 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "inc/Helper/Concurrent.h" + +using namespace SPTAG; +using namespace SPTAG::Helper::Concurrent; + +WaitSignal::WaitSignal() + : m_isWaiting(false), + m_unfinished(0) +{ +} + + +WaitSignal::WaitSignal(std::uint32_t p_unfinished) + : m_isWaiting(false), + m_unfinished(p_unfinished) +{ +} + + +WaitSignal::~WaitSignal() +{ + std::lock_guard guard(m_mutex); + if (m_isWaiting) + { + m_cv.notify_all(); + } +} + + +void +WaitSignal::Reset(std::uint32_t p_unfinished) +{ + std::lock_guard guard(m_mutex); + if (m_isWaiting) + { + m_cv.notify_all(); + } + + m_isWaiting = false; + m_unfinished = p_unfinished; +} + + +void +WaitSignal::Wait() +{ + std::unique_lock lock(m_mutex); + if (m_unfinished > 0) + { + m_isWaiting = true; + m_cv.wait(lock); + } +} + + +void +WaitSignal::FinishOne() +{ + if (1 == m_unfinished.fetch_sub(1)) + { + std::lock_guard guard(m_mutex); + if (m_isWaiting) + { + m_isWaiting = false; + m_cv.notify_all(); + } + } +} diff --git a/core/src/index/thirdparty/SPTAG/AnnService/src/Helper/SimpleIniReader.cpp b/core/src/index/thirdparty/SPTAG/AnnService/src/Helper/SimpleIniReader.cpp new file mode 100644 index 0000000000..28610dbe19 --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/AnnService/src/Helper/SimpleIniReader.cpp @@ -0,0 +1,242 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "inc/Helper/SimpleIniReader.h" +#include "inc/Helper/CommonHelper.h" + +#include +#include +#include +#include + +using namespace SPTAG; +using namespace SPTAG::Helper; + +const IniReader::ParameterValueMap IniReader::c_emptyParameters; + + +IniReader::IniReader() +{ +} + + +IniReader::~IniReader() +{ +} + + +ErrorCode IniReader::LoadIni(std::istream& p_input) +{ + const std::size_t c_bufferSize = 1 << 16; + + std::unique_ptr line(new char[c_bufferSize]); + + std::string currSection; + std::shared_ptr currParamMap(new ParameterValueMap); + + if (m_parameters.count(currSection) == 0) + { + m_parameters.emplace(currSection, currParamMap); + } + + auto isSpace = [](char p_ch) -> bool + { + return std::isspace(p_ch) != 0; + }; + + while (!p_input.eof()) + { + if (!p_input.getline(line.get(), c_bufferSize)) + { + break; + } + + std::size_t len = 0; + while (len < c_bufferSize && line[len] != '\0') + { + ++len; + } + + auto nonSpaceSeg = StrUtils::FindTrimmedSegment(line.get(), line.get() + len, isSpace); + + if (nonSpaceSeg.second <= nonSpaceSeg.first) + { + // Blank line. + continue; + } + + if (';' == *nonSpaceSeg.first) + { + // Comments. + continue; + } + else if ('[' == *nonSpaceSeg.first) + { + // Parse Section + if (']' != *(nonSpaceSeg.second - 1)) + { + return ErrorCode::ReadIni_FailedParseSection; + } + + auto sectionSeg = StrUtils::FindTrimmedSegment(nonSpaceSeg.first + 1, nonSpaceSeg.second - 1, isSpace); + + if (sectionSeg.second <= sectionSeg.first) + { + // Empty section name. + return ErrorCode::ReadIni_FailedParseSection; + } + + currSection.assign(sectionSeg.first, sectionSeg.second); + StrUtils::ToLowerInPlace(currSection); + + if (m_parameters.count(currSection) == 0) + { + currParamMap.reset(new ParameterValueMap); + m_parameters.emplace(currSection, currParamMap); + } + else + { + return ErrorCode::ReadIni_DuplicatedSection; + } + } + else + { + // Parameter Value Pair. + const char* equalSignLoc = nonSpaceSeg.first; + while (equalSignLoc < nonSpaceSeg.second && '=' != *equalSignLoc) + { + ++equalSignLoc; + } + + if (equalSignLoc >= nonSpaceSeg.second) + { + return ErrorCode::ReadIni_FailedParseParam; + } + + auto paramSeg = StrUtils::FindTrimmedSegment(nonSpaceSeg.first, equalSignLoc, isSpace); + + if (paramSeg.second <= paramSeg.first) + { + // Empty parameter name. + return ErrorCode::ReadIni_FailedParseParam; + } + + std::string paramName(paramSeg.first, paramSeg.second); + StrUtils::ToLowerInPlace(paramName); + + if (currParamMap->count(paramName) == 0) + { + currParamMap->emplace(std::move(paramName), std::string(equalSignLoc + 1, nonSpaceSeg.second)); + } + else + { + return ErrorCode::ReadIni_DuplicatedParam; + } + } + } + return ErrorCode::Success; +} + + +ErrorCode +IniReader::LoadIniFile(const std::string& p_iniFilePath) +{ + std::ifstream input(p_iniFilePath); + if (!input.is_open()) return ErrorCode::FailedOpenFile; + ErrorCode ret = LoadIni(input); + input.close(); + return ret; +} + + +bool +IniReader::DoesSectionExist(const std::string& p_section) const +{ + std::string section(p_section); + StrUtils::ToLowerInPlace(section); + return m_parameters.count(section) != 0; +} + + +bool +IniReader::DoesParameterExist(const std::string& p_section, const std::string& p_param) const +{ + std::string name(p_section); + StrUtils::ToLowerInPlace(name); + auto iter = m_parameters.find(name); + if (iter == m_parameters.cend()) + { + return false; + } + + const auto& paramMap = iter->second; + if (paramMap == nullptr) + { + return false; + } + + name = p_param; + StrUtils::ToLowerInPlace(name); + return paramMap->count(name) != 0; +} + + +bool +IniReader::GetRawValue(const std::string& p_section, const std::string& p_param, std::string& p_value) const +{ + std::string name(p_section); + StrUtils::ToLowerInPlace(name); + auto sectionIter = m_parameters.find(name); + if (sectionIter == m_parameters.cend()) + { + return false; + } + + const auto& paramMap = sectionIter->second; + if (paramMap == nullptr) + { + return false; + } + + name = p_param; + StrUtils::ToLowerInPlace(name); + auto paramIter = paramMap->find(name); + if (paramIter == paramMap->cend()) + { + return false; + } + + p_value = paramIter->second; + return true; +} + + +const IniReader::ParameterValueMap& +IniReader::GetParameters(const std::string& p_section) const +{ + std::string name(p_section); + StrUtils::ToLowerInPlace(name); + auto sectionIter = m_parameters.find(name); + if (sectionIter == m_parameters.cend() || nullptr == sectionIter->second) + { + return c_emptyParameters; + } + + return *(sectionIter->second); +} + +void +IniReader::SetParameter(const std::string& p_section, const std::string& p_param, const std::string& p_val) +{ + std::string name(p_section); + StrUtils::ToLowerInPlace(name); + auto sectionIter = m_parameters.find(name); + if (sectionIter == m_parameters.cend() || sectionIter->second == nullptr) + { + m_parameters[name] = std::shared_ptr(new ParameterValueMap); + } + + std::string param(p_param); + StrUtils::ToLowerInPlace(param); + (*m_parameters[name])[param] = p_val; +} diff --git a/core/src/index/thirdparty/SPTAG/AnnService/src/Helper/VectorSetReader.cpp b/core/src/index/thirdparty/SPTAG/AnnService/src/Helper/VectorSetReader.cpp new file mode 100644 index 0000000000..44371ae242 --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/AnnService/src/Helper/VectorSetReader.cpp @@ -0,0 +1,44 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "inc/Helper/VectorSetReader.h" +#include "inc/Helper/VectorSetReaders/DefaultReader.h" + + +using namespace SPTAG; +using namespace SPTAG::Helper; + + +ReaderOptions::ReaderOptions(VectorValueType p_valueType, DimensionType p_dimension, std::string p_vectorDelimiter, std::uint32_t p_threadNum) + : m_threadNum(p_threadNum), m_dimension(p_dimension), m_vectorDelimiter(p_vectorDelimiter), m_inputValueType(p_valueType) +{ + AddOptionalOption(m_threadNum, "-t", "--thread", "Thread Number."); + AddOptionalOption(m_vectorDelimiter, "", "--delimiter", "Vector delimiter."); + AddRequiredOption(m_dimension, "-d", "--dimension", "Dimension of vector."); + AddRequiredOption(m_inputValueType, "-v", "--vectortype", "Input vector data type. Default is float."); +} + + +ReaderOptions::~ReaderOptions() +{ +} + + +VectorSetReader::VectorSetReader(std::shared_ptr p_options) + : m_options(p_options) +{ +} + + +VectorSetReader:: ~VectorSetReader() +{ +} + + +std::shared_ptr +VectorSetReader::CreateInstance(std::shared_ptr p_options) +{ + return std::shared_ptr(new DefaultReader(std::move(p_options))); +} + + diff --git a/core/src/index/thirdparty/SPTAG/AnnService/src/Helper/VectorSetReaders/DefaultReader.cpp b/core/src/index/thirdparty/SPTAG/AnnService/src/Helper/VectorSetReaders/DefaultReader.cpp new file mode 100644 index 0000000000..4d775f4a50 --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/AnnService/src/Helper/VectorSetReaders/DefaultReader.cpp @@ -0,0 +1,514 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "inc/Helper/VectorSetReaders/DefaultReader.h" +#include "inc/Helper/StringConvert.h" +#include "inc/Helper/CommonHelper.h" + +#include +#include +#include +#include + +using namespace SPTAG; +using namespace SPTAG::Helper; + +namespace +{ +namespace Local +{ + +class BinaryLineReader +{ +public: + BinaryLineReader(std::istream& p_inStream) + : m_inStream(p_inStream) + { + m_buffer.reset(new char[c_bufferSize]); + } + + + bool Eof() + { + return m_inStream.eof() && (m_curOffset == m_curTotal); + } + + + std::size_t GetLine(std::unique_ptr& p_buffer, std::size_t& p_bufferSize, std::size_t& p_length) + { + std::size_t consumedCount = 0; + p_length = 0; + while (true) + { + while (m_curOffset < m_curTotal) + { + if (p_bufferSize > p_length) + { + ++consumedCount; + if (!IsDelimiter(m_buffer[m_curOffset])) + { + p_buffer[p_length++] = m_buffer[m_curOffset++]; + } + else + { + ++m_curOffset; + p_buffer[p_length] = '\0'; + return consumedCount + MoveToNextValid(); + } + } + else + { + p_bufferSize *= 2; + std::unique_ptr newBuffer(new char[p_bufferSize]); + memcpy(newBuffer.get(), p_buffer.get(), p_length); + p_buffer.swap(newBuffer); + } + } + + if (m_inStream.eof()) + { + break; + } + + m_inStream.read(m_buffer.get(), c_bufferSize); + m_curTotal = m_inStream.gcount(); + m_curOffset = 0; + } + + if (p_bufferSize <= p_length) + { + p_bufferSize *= 2; + std::unique_ptr newBuffer(new char[p_bufferSize]); + memcpy(newBuffer.get(), p_buffer.get(), p_length); + p_buffer.swap(newBuffer); + } + + p_buffer[p_length] = '\0'; + return consumedCount; + } + + +private: + std::size_t MoveToNextValid() + { + std::size_t skipped = 0; + while (true) + { + while (m_curOffset < m_curTotal) + { + if (IsDelimiter(m_buffer[m_curOffset])) + { + ++skipped; + ++m_curOffset; + } + else + { + return skipped; + } + } + + if (m_inStream.eof()) + { + break; + } + + m_inStream.read(m_buffer.get(), c_bufferSize); + m_curTotal = m_inStream.gcount(); + m_curOffset = 0; + } + + return skipped; + } + + bool IsDelimiter(char p_ch) + { + return p_ch == '\r' || p_ch == '\n'; + } + + static const std::size_t c_bufferSize = 1 << 10; + + std::unique_ptr m_buffer; + + std::istream& m_inStream; + + std::size_t m_curOffset; + + std::size_t m_curTotal; +}; + +} // namespace Local +} // namespace + + +DefaultReader::DefaultReader(std::shared_ptr p_options) + : VectorSetReader(std::move(p_options)), + m_subTaskBlocksize(0) +{ + omp_set_num_threads(m_options->m_threadNum); + + std::string tempFolder("tempfolder"); + if (!direxists(tempFolder.c_str())) + { + mkdir(tempFolder.c_str()); + } + + tempFolder += FolderSep; + m_vectorOutput = tempFolder + "vectorset.bin"; + m_metadataConentOutput = tempFolder + "metadata.bin"; + m_metadataIndexOutput = tempFolder + "metadataindex.bin"; +} + + +DefaultReader::~DefaultReader() +{ + if (fileexists(m_vectorOutput.c_str())) + { + remove(m_vectorOutput.c_str()); + } + + if (fileexists(m_metadataIndexOutput.c_str())) + { + remove(m_metadataIndexOutput.c_str()); + } + + if (fileexists(m_metadataConentOutput.c_str())) + { + remove(m_metadataConentOutput.c_str()); + } +} + + +ErrorCode +DefaultReader::LoadFile(const std::string& p_filePaths) +{ + const auto& files = GetFileSizes(p_filePaths); + std::vector> subWorks; + subWorks.reserve(files.size() * m_options->m_threadNum); + + m_subTaskCount = 0; + for (const auto& fileInfo : files) + { + if (fileInfo.second == (std::numeric_limits::max)()) + { + std::stringstream msg; + msg << "File " << fileInfo.first << " not exists or can't access."; + std::cerr << msg.str() << std::endl; + exit(1); + } + + std::uint32_t fileTaskCount = 0; + std::size_t blockSize = m_subTaskBlocksize; + if (0 == blockSize) + { + fileTaskCount = m_options->m_threadNum; + blockSize = (fileInfo.second + fileTaskCount - 1) / fileTaskCount; + } + else + { + fileTaskCount = static_cast((fileInfo.second + blockSize - 1) / blockSize); + } + + for (std::uint32_t i = 0; i < fileTaskCount; ++i) + { + subWorks.emplace_back(std::bind(&DefaultReader::LoadFileInternal, + this, + fileInfo.first, + m_subTaskCount++, + i, + blockSize)); + } + } + + m_totalRecordCount = 0; + m_totalRecordVectorBytes = 0; + m_subTaskRecordCount.clear(); + m_subTaskRecordCount.resize(m_subTaskCount, 0); + + m_waitSignal.Reset(m_subTaskCount); + +#pragma omp parallel for schedule(dynamic) + for (int64_t i = 0; i < (int64_t)subWorks.size(); i++) + { + subWorks[i](); + } + + m_waitSignal.Wait(); + + MergeData(); + + return ErrorCode::Success; +} + + +std::shared_ptr +DefaultReader::GetVectorSet() const +{ + ByteArray vectorSet = ByteArray::Alloc(m_totalRecordVectorBytes); + char* vecBuf = reinterpret_cast(vectorSet.Data()); + + std::ifstream inputStream; + inputStream.open(m_vectorOutput, std::ifstream::binary); + inputStream.seekg(sizeof(SizeType) + sizeof(DimensionType), std::ifstream::beg); + inputStream.read(vecBuf, m_totalRecordVectorBytes); + inputStream.close(); + + return std::shared_ptr(new BasicVectorSet(vectorSet, + m_options->m_inputValueType, + m_options->m_dimension, + m_totalRecordCount)); +} + + +std::shared_ptr +DefaultReader::GetMetadataSet() const +{ + return std::shared_ptr(new FileMetadataSet(m_metadataConentOutput, m_metadataIndexOutput)); +} + + +void +DefaultReader::LoadFileInternal(const std::string& p_filePath, + std::uint32_t p_subTaskID, + std::uint32_t p_fileBlockID, + std::size_t p_fileBlockSize) +{ + std::size_t lineBufferSize = 1 << 16; + std::unique_ptr currentLine(new char[lineBufferSize]); + + std::ifstream inputStream; + std::ofstream outputStream; + std::ofstream metaStreamContent; + std::ofstream metaStreamIndex; + + SizeType recordCount = 0; + std::uint64_t metaOffset = 0; + std::size_t totalRead = 0; + std::streamoff startpos = p_fileBlockID * p_fileBlockSize; + + inputStream.open(p_filePath, std::ios_base::in | std::ios_base::binary); + if (inputStream.is_open() == false) + { + std::stringstream msg; + msg << "Unable to open file: " << p_filePath << std::endl; + const auto& msgStr = msg.str(); + std::cerr << msgStr; + throw MyException(msgStr); + exit(1); + } + + { + std::stringstream msg; + msg << "Begin Subtask: " << p_subTaskID << ", start offset position:" << startpos << std::endl; + std::cout << msg.str(); + } + + std::string subFileSuffix("_"); + subFileSuffix += std::to_string(p_subTaskID); + subFileSuffix += ".tmp"; + + outputStream.open(m_vectorOutput + subFileSuffix, std::ofstream::binary); + metaStreamContent.open(m_metadataConentOutput + subFileSuffix, std::ofstream::binary); + metaStreamIndex.open(m_metadataIndexOutput + subFileSuffix, std::ofstream::binary); + + inputStream.seekg(startpos, std::ifstream::beg); + + Local::BinaryLineReader lineReader(inputStream); + + std::size_t lineLength; + if (p_fileBlockID != 0) + { + totalRead += lineReader.GetLine(currentLine, lineBufferSize, lineLength); + } + + std::size_t vectorByteSize = GetValueTypeSize(m_options->m_inputValueType) * m_options->m_dimension; + std::unique_ptr vector; + vector.reset(new std::uint8_t[vectorByteSize]); + + while (!lineReader.Eof() && totalRead <= p_fileBlockSize) + { + totalRead += lineReader.GetLine(currentLine, lineBufferSize, lineLength); + if (0 == lineLength) + { + continue; + } + + std::size_t tabIndex = lineLength - 1; + while (tabIndex > 0 && currentLine[tabIndex] != '\t') + { + --tabIndex; + } + + if (0 == tabIndex && currentLine[tabIndex] != '\t') + { + std::stringstream msg; + msg << "Subtask: " << p_subTaskID << " cannot parsing line:" << currentLine.get() << std::endl; + std::cout << msg.str(); + exit(1); + } + + bool parseSuccess = false; + switch (m_options->m_inputValueType) + { +#define DefineVectorValueType(Name, Type) \ + case VectorValueType::Name: \ + parseSuccess = TranslateVector(currentLine.get() + tabIndex + 1, reinterpret_cast(vector.get())); \ + break; \ + +#include "inc/Core/DefinitionList.h" +#undef DefineVectorValueType + + default: + parseSuccess = false; + break; + } + + if (!parseSuccess) + { + std::stringstream msg; + msg << "Subtask: " << p_subTaskID << " cannot parsing vector:" << (currentLine.get() + tabIndex + 1) << std::endl; + std::cout << msg.str(); + exit(1); + } + + ++recordCount; + outputStream.write(reinterpret_cast(vector.get()), vectorByteSize); + metaStreamContent.write(currentLine.get(), tabIndex); + metaStreamIndex.write(reinterpret_cast(&metaOffset), sizeof(metaOffset)); + + metaOffset += tabIndex; + } + + metaStreamIndex.write(reinterpret_cast(&metaOffset), sizeof(metaOffset)); + + inputStream.close(); + outputStream.close(); + metaStreamContent.close(); + metaStreamIndex.close(); + + m_totalRecordCount += recordCount; + m_subTaskRecordCount[p_subTaskID] = recordCount; + m_totalRecordVectorBytes += recordCount * vectorByteSize; + + m_waitSignal.FinishOne(); +} + + +void +DefaultReader::MergeData() +{ + const std::size_t bufferSize = 1 << 30; + const std::size_t bufferSizeTrim64 = (bufferSize / sizeof(std::uint64_t)) * sizeof(std::uint64_t); + std::ifstream inputStream; + std::ofstream outputStream; + + std::unique_ptr bufferHolder(new char[bufferSize]); + char* buf = bufferHolder.get(); + + SizeType totalRecordCount = m_totalRecordCount; + + outputStream.open(m_vectorOutput, std::ofstream::binary); + + outputStream.write(reinterpret_cast(&totalRecordCount), sizeof(totalRecordCount)); + outputStream.write(reinterpret_cast(&(m_options->m_dimension)), sizeof(m_options->m_dimension)); + + for (std::uint32_t i = 0; i < m_subTaskCount; ++i) + { + std::string file = m_vectorOutput; + file += "_"; + file += std::to_string(i); + file += ".tmp"; + + inputStream.open(file, std::ifstream::binary); + outputStream << inputStream.rdbuf(); + + inputStream.close(); + remove(file.c_str()); + } + + outputStream.close(); + + outputStream.open(m_metadataConentOutput, std::ofstream::binary); + for (std::uint32_t i = 0; i < m_subTaskCount; ++i) + { + std::string file = m_metadataConentOutput; + file += "_"; + file += std::to_string(i); + file += ".tmp"; + + inputStream.open(file, std::ifstream::binary); + outputStream << inputStream.rdbuf(); + + inputStream.close(); + remove(file.c_str()); + } + + outputStream.close(); + + outputStream.open(m_metadataIndexOutput, std::ofstream::binary); + + outputStream.write(reinterpret_cast(&totalRecordCount), sizeof(totalRecordCount)); + + std::uint64_t totalOffset = 0; + for (std::uint32_t i = 0; i < m_subTaskCount; ++i) + { + std::string file = m_metadataIndexOutput; + file += "_"; + file += std::to_string(i); + file += ".tmp"; + + inputStream.open(file, std::ifstream::binary); + for (SizeType remains = m_subTaskRecordCount[i]; remains > 0;) + { + std::size_t readBytesCount = min(remains * sizeof(std::uint64_t), bufferSizeTrim64); + inputStream.read(buf, readBytesCount); + std::uint64_t* offset = reinterpret_cast(buf); + for (std::uint64_t i = 0; i < readBytesCount / sizeof(std::uint64_t); ++i) + { + offset[i] += totalOffset; + } + + outputStream.write(buf, readBytesCount); + remains -= static_cast(readBytesCount / sizeof(std::uint64_t)); + } + + inputStream.read(buf, sizeof(std::uint64_t)); + totalOffset += *(reinterpret_cast(buf)); + + inputStream.close(); + remove(file.c_str()); + } + + outputStream.write(reinterpret_cast(&totalOffset), sizeof(totalOffset)); + outputStream.close(); +} + + +std::vector +DefaultReader::GetFileSizes(const std::string& p_filePaths) +{ + const auto& files = Helper::StrUtils::SplitString(p_filePaths, ","); + std::vector res; + res.reserve(files.size()); + + for (const auto& filePath : files) + { + if (!fileexists(filePath.c_str())) + { + res.emplace_back(filePath, (std::numeric_limits::max)()); + continue; + } +#ifndef _MSC_VER + struct stat stat_buf; + stat(filePath.c_str(), &stat_buf); +#else + struct _stat64 stat_buf; + _stat64(filePath.c_str(), &stat_buf); +#endif + std::size_t fileSize = stat_buf.st_size; + res.emplace_back(filePath, static_cast(fileSize)); + } + + return res; +} + + diff --git a/core/src/index/thirdparty/SPTAG/AnnService/src/IndexBuilder/Options.cpp b/core/src/index/thirdparty/SPTAG/AnnService/src/IndexBuilder/Options.cpp new file mode 100644 index 0000000000..6360b73c2a --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/AnnService/src/IndexBuilder/Options.cpp @@ -0,0 +1,25 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "inc/IndexBuilder/Options.h" +#include "inc/Helper/StringConvert.h" + +#include + +using namespace SPTAG; +using namespace SPTAG::IndexBuilder; + + +BuilderOptions::BuilderOptions() + : Helper::ReaderOptions(VectorValueType::Float, 0, "|", 32) +{ + AddRequiredOption(m_inputFiles, "-i", "--input", "Input raw data."); + AddRequiredOption(m_outputFolder, "-o", "--outputfolder", "Output folder."); + AddRequiredOption(m_indexAlgoType, "-a", "--algo", "Index Algorithm type."); + AddOptionalOption(m_builderConfigFile, "-c", "--config", "Config file for builder."); +} + + +BuilderOptions::~BuilderOptions() +{ +} diff --git a/core/src/index/thirdparty/SPTAG/AnnService/src/IndexBuilder/ThreadPool.cpp b/core/src/index/thirdparty/SPTAG/AnnService/src/IndexBuilder/ThreadPool.cpp new file mode 100644 index 0000000000..0ecddc1279 --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/AnnService/src/IndexBuilder/ThreadPool.cpp @@ -0,0 +1,54 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "inc/IndexBuilder/ThreadPool.h" + +#include + +#include +#include + +using namespace SPTAG::IndexBuilder; + +namespace Local +{ +std::unique_ptr g_threadPool; + +std::atomic_bool g_initialized(false); + +std::uint32_t g_threadNum = 1; +} + + +void +ThreadPool::Init(std::uint32_t p_threadNum) +{ + if (Local::g_initialized.exchange(true)) + { + return; + } + + Local::g_threadNum = std::max((std::uint32_t)1, p_threadNum); + + Local::g_threadPool.reset(new boost::asio::thread_pool(Local::g_threadNum)); +} + + +bool +ThreadPool::Queue(std::function p_workItem) +{ + if (nullptr == Local::g_threadPool) + { + return false; + } + + boost::asio::post(*Local::g_threadPool, std::move(p_workItem)); + return true; +} + + +std::uint32_t +ThreadPool::CurrentThreadNum() +{ + return Local::g_threadNum; +} diff --git a/core/src/index/thirdparty/SPTAG/AnnService/src/IndexBuilder/main.cpp b/core/src/index/thirdparty/SPTAG/AnnService/src/IndexBuilder/main.cpp new file mode 100644 index 0000000000..040703c3ca --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/AnnService/src/IndexBuilder/main.cpp @@ -0,0 +1,100 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "inc/IndexBuilder/Options.h" +#include "inc/Helper/VectorSetReader.h" +#include "inc/Core/VectorIndex.h" +#include "inc/Core/Common.h" +#include "inc/Helper/SimpleIniReader.h" + +#include +#include + +using namespace SPTAG; + +int main(int argc, char* argv[]) +{ + std::shared_ptr options(new IndexBuilder::BuilderOptions); + if (!options->Parse(argc - 1, argv + 1)) + { + exit(1); + } + + auto indexBuilder = VectorIndex::CreateInstance(options->m_indexAlgoType, options->m_inputValueType); + + Helper::IniReader iniReader; + if (!options->m_builderConfigFile.empty()) + { + iniReader.LoadIniFile(options->m_builderConfigFile); + } + + for (int i = 1; i < argc; i++) + { + std::string param(argv[i]); + size_t idx = param.find("="); + if (idx == std::string::npos) continue; + + std::string paramName = param.substr(0, idx); + std::string paramVal = param.substr(idx + 1); + std::string sectionName; + idx = paramName.find("."); + if (idx != std::string::npos) { + sectionName = paramName.substr(0, idx); + paramName = paramName.substr(idx + 1); + } + iniReader.SetParameter(sectionName, paramName, paramVal); + std::cout << "Set [" << sectionName << "]" << paramName << " = " << paramVal << std::endl; + } + + if (!iniReader.DoesParameterExist("Index", "NumberOfThreads")) { + iniReader.SetParameter("Index", "NumberOfThreads", std::to_string(options->m_threadNum)); + } + for (const auto& iter : iniReader.GetParameters("Index")) + { + indexBuilder->SetParameter(iter.first.c_str(), iter.second.c_str()); + } + + ErrorCode code; + if (options->m_inputFiles.find("BIN:") == 0) { + std::vector files = SPTAG::Helper::StrUtils::SplitString(options->m_inputFiles.substr(4), ","); + std::ifstream inputStream(files[0], std::ifstream::binary); + if (!inputStream.is_open()) { + fprintf(stderr, "Failed to read input file.\n"); + exit(1); + } + SizeType row; + DimensionType col; + inputStream.read((char*)&row, sizeof(SizeType)); + inputStream.read((char*)&col, sizeof(DimensionType)); + std::uint64_t totalRecordVectorBytes = ((std::uint64_t)GetValueTypeSize(options->m_inputValueType)) * row * col; + ByteArray vectorSet = ByteArray::Alloc(totalRecordVectorBytes); + char* vecBuf = reinterpret_cast(vectorSet.Data()); + inputStream.read(vecBuf, totalRecordVectorBytes); + inputStream.close(); + std::shared_ptr p_vectorSet(new BasicVectorSet(vectorSet, options->m_inputValueType, col, row)); + + std::shared_ptr p_metaSet = nullptr; + if (files.size() >= 3) { + p_metaSet.reset(new FileMetadataSet(files[1], files[2])); + } + code = indexBuilder->BuildIndex(p_vectorSet, p_metaSet); + indexBuilder->SaveIndex(options->m_outputFolder); + } + else { + auto vectorReader = Helper::VectorSetReader::CreateInstance(options); + if (ErrorCode::Success != vectorReader->LoadFile(options->m_inputFiles)) + { + fprintf(stderr, "Failed to read input file.\n"); + exit(1); + } + code = indexBuilder->BuildIndex(vectorReader->GetVectorSet(), vectorReader->GetMetadataSet()); + indexBuilder->SaveIndex(options->m_outputFolder); + } + + if (ErrorCode::Success != code) + { + fprintf(stderr, "Failed to build index.\n"); + exit(1); + } + return 0; +} diff --git a/core/src/index/thirdparty/SPTAG/AnnService/src/IndexSearcher/main.cpp b/core/src/index/thirdparty/SPTAG/AnnService/src/IndexSearcher/main.cpp new file mode 100644 index 0000000000..0a8c84c2e3 --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/AnnService/src/IndexSearcher/main.cpp @@ -0,0 +1,278 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "inc/Helper/SimpleIniReader.h" +#include "inc/Helper/CommonHelper.h" +#include "inc/Core/Common.h" +#include "inc/Core/MetadataSet.h" +#include "inc/Core/VectorIndex.h" +#include "inc/Core/SearchQuery.h" +#include "inc/Core/Common/WorkSpace.h" +#include "inc/Core/Common/DataUtils.h" +#include +#include + +using namespace SPTAG; + +template +float CalcRecall(std::vector &results, const std::vector> &truth, SizeType NumQuerys, int K, std::ofstream& log) +{ + float meanrecall = 0, minrecall = MaxDist, maxrecall = 0, stdrecall = 0; + std::vector thisrecall(NumQuerys, 0); + for (SizeType i = 0; i < NumQuerys; i++) + { + for (SizeType id : truth[i]) + { + for (int j = 0; j < K; j++) + { + if (results[i].GetResult(j)->VID == id) + { + thisrecall[i] += 1; + break; + } + } + } + thisrecall[i] /= K; + meanrecall += thisrecall[i]; + if (thisrecall[i] < minrecall) minrecall = thisrecall[i]; + if (thisrecall[i] > maxrecall) maxrecall = thisrecall[i]; + } + meanrecall /= NumQuerys; + for (SizeType i = 0; i < NumQuerys; i++) + { + stdrecall += (thisrecall[i] - meanrecall) * (thisrecall[i] - meanrecall); + } + stdrecall = std::sqrt(stdrecall / NumQuerys); + log << meanrecall << " " << stdrecall << " " << minrecall << " " << maxrecall << std::endl; + return meanrecall; +} + +void LoadTruth(std::ifstream& fp, std::vector>& truth, SizeType NumQuerys, int K) +{ + SizeType get; + std::string line; + for (SizeType i = 0; i < NumQuerys; ++i) + { + truth[i].clear(); + for (int j = 0; j < K; ++j) + { + fp >> get; + truth[i].insert(get); + } + std::getline(fp, line); + } +} + +template +int Process(Helper::IniReader& reader, VectorIndex& index) +{ + std::string queryFile = reader.GetParameter("Index", "QueryFile", std::string("querys.bin")); + std::string truthFile = reader.GetParameter("Index", "TruthFile", std::string("truth.txt")); + std::string outputFile = reader.GetParameter("Index", "ResultFile", std::string("")); + + SizeType numBatchQuerys = reader.GetParameter("Index", "NumBatchQuerys", (SizeType)10000); + SizeType numDebugQuerys = reader.GetParameter("Index", "NumDebugQuerys", (SizeType)-1); + int K = reader.GetParameter("Index", "K", 32); + + std::vector maxCheck = Helper::StrUtils::SplitString(reader.GetParameter("Index", "MaxCheck", std::string("2048")), "#"); + + std::ifstream inStream(queryFile); + std::ifstream ftruth(truthFile); + std::ofstream fp; + if (!inStream.is_open()) + { + std::cout << "ERROR: Cannot Load Query file " << queryFile << "!" << std::endl; + return -1; + } + if (outputFile != "") + { + fp.open(outputFile); + if (!fp.is_open()) + { + std::cout << "ERROR: Cannot open " << outputFile << " for write!" << std::endl; + } + } + + std::ofstream log(index.GetIndexName() + "_" + std::to_string(K) + ".txt"); + if (!log.is_open()) + { + std::cout << "ERROR: Cannot open logging file!" << std::endl; + return -1; + } + + SizeType numQuerys = (numDebugQuerys >= 0) ? numDebugQuerys : numBatchQuerys; + + std::vector> Query(numQuerys, std::vector(index.GetFeatureDim(), 0)); + std::vector> truth(numQuerys); + std::vector results(numQuerys, QueryResult(NULL, K, 0)); + + clock_t * latencies = new clock_t[numQuerys + 1]; + + int base = 1; + if (index.GetDistCalcMethod() == DistCalcMethod::Cosine) { + base = COMMON::Utils::GetBase(); + } + int basesquare = base * base; + + DimensionType dims = index.GetFeatureDim(); + std::vector QStrings; + while (!inStream.eof()) + { + QStrings.clear(); + COMMON::Utils::PrepareQuerys(inStream, QStrings, Query, numQuerys, dims, index.GetDistCalcMethod(), base); + if (numQuerys == 0) break; + + for (SizeType i = 0; i < numQuerys; i++) results[i].SetTarget(Query[i].data()); + if (ftruth.is_open()) LoadTruth(ftruth, truth, numQuerys, K); + + std::cout << " \t[avg] \t[99%] \t[95%] \t[recall] \t[mem]" << std::endl; + + SizeType subSize = (numQuerys - 1) / omp_get_num_threads() + 1; + for (std::string& mc : maxCheck) + { + index.SetParameter("MaxCheck", mc.c_str()); + for (SizeType i = 0; i < numQuerys; i++) results[i].Reset(); + +#pragma omp parallel for + for (int tid = 0; tid < omp_get_num_threads(); tid++) + { + SizeType start = tid * subSize; + SizeType end = min((tid + 1) * subSize, numQuerys); + for (SizeType i = start; i < end; i++) + { + latencies[i] = clock(); + index.SearchIndex(results[i]); + } + } + + latencies[numQuerys] = clock(); + + float timeMean = 0, timeMin = MaxDist, timeMax = 0, timeStd = 0; + for (SizeType i = 0; i < numQuerys; i++) + { + if (latencies[i + 1] >= latencies[i]) + latencies[i] = latencies[i + 1] - latencies[i]; + else + latencies[i] = latencies[numQuerys] - latencies[i]; + timeMean += latencies[i]; + if (latencies[i] > timeMax) timeMax = (float)latencies[i]; + if (latencies[i] < timeMin) timeMin = (float)latencies[i]; + } + timeMean /= numQuerys; + for (SizeType i = 0; i < numQuerys; i++) timeStd += ((float)latencies[i] - timeMean) * ((float)latencies[i] - timeMean); + timeStd = std::sqrt(timeStd / numQuerys); + log << timeMean << " " << timeStd << " " << timeMin << " " << timeMax << " "; + + std::sort(latencies, latencies + numQuerys, [](clock_t x, clock_t y) + { + return x < y; + }); + float l99 = float(latencies[SizeType(numQuerys * 0.99)]) / CLOCKS_PER_SEC; + float l95 = float(latencies[SizeType(numQuerys * 0.95)]) / CLOCKS_PER_SEC; + + float recall = 0; + if (ftruth.is_open()) + { + recall = CalcRecall(results, truth, numQuerys, K, log); + } + +#ifndef _MSC_VER + struct rusage rusage; + getrusage(RUSAGE_SELF, &rusage); + unsigned long long peakWSS = rusage.ru_maxrss * 1024 / 1000000000; +#else + PROCESS_MEMORY_COUNTERS pmc; + GetProcessMemoryInfo(GetCurrentProcess(), &pmc, sizeof(pmc)); + unsigned long long peakWSS = pmc.PeakWorkingSetSize / 1000000000; +#endif + std::cout << mc << "\t" << std::fixed << std::setprecision(6) << (timeMean / CLOCKS_PER_SEC) << "\t" << std::setprecision(4) << l99 << "\t" << l95 << "\t" << recall << "\t\t" << peakWSS << "GB" << std::endl; + + } + + if (fp.is_open()) + { + fp << std::setprecision(3) << std::fixed; + for (SizeType i = 0; i < numQuerys; i++) + { + fp << QStrings[i] << ":"; + for (int j = 0; j < K; j++) + { + if (results[i].GetResult(j)->VID < 0) { + fp << results[i].GetResult(j)->Dist << "@" << results[i].GetResult(j)->VID << std::endl; + } + else { + ByteArray vm = index.GetMetadata(results[i].GetResult(j)->VID); + fp << (results[i].GetResult(j)->Dist / basesquare) << "@"; + fp.write((const char*)vm.Data(), vm.Length()); + } + fp << "|"; + } + fp << std::endl; + } + } + + if (numQuerys < numBatchQuerys || numDebugQuerys >= 0) break; + } + std::cout << "Output results finish!" << std::endl; + + inStream.close(); + fp.close(); + log.close(); + ftruth.close(); + delete[] latencies; + + QStrings.clear(); + results.clear(); + + return 0; +} + +int main(int argc, char** argv) +{ + if (argc < 2) + { + std::cerr << "IndexSearcher.exe folder" << std::endl; + return -1; + } + + std::shared_ptr vecIndex; + auto ret = SPTAG::VectorIndex::LoadIndex(argv[1], vecIndex); + if (SPTAG::ErrorCode::Success != ret || nullptr == vecIndex) + { + std::cerr << "Cannot open configure file!" << std::endl; + return -1; + } + + Helper::IniReader iniReader; + for (int i = 1; i < argc; i++) + { + std::string param(argv[i]); + size_t idx = param.find("="); + if (idx == std::string::npos) continue; + + std::string paramName = param.substr(0, idx); + std::string paramVal = param.substr(idx + 1); + std::string sectionName; + idx = paramName.find("."); + if (idx != std::string::npos) { + sectionName = paramName.substr(0, idx); + paramName = paramName.substr(idx + 1); + } + iniReader.SetParameter(sectionName, paramName, paramVal); + std::cout << "Set [" << sectionName << "]" << paramName << " = " << paramVal << std::endl; + } + + switch (vecIndex->GetVectorValueType()) + { +#define DefineVectorValueType(Name, Type) \ + case VectorValueType::Name: \ + Process(iniReader, *(vecIndex.get())); \ + break; \ + +#include "inc/Core/DefinitionList.h" +#undef DefineVectorValueType + + default: break; + } + return 0; +} diff --git a/core/src/index/thirdparty/SPTAG/AnnService/src/Server/QueryParser.cpp b/core/src/index/thirdparty/SPTAG/AnnService/src/Server/QueryParser.cpp new file mode 100644 index 0000000000..0fb47e9390 --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/AnnService/src/Server/QueryParser.cpp @@ -0,0 +1,209 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "inc/Server/QueryParser.h" + +#include +#include + +using namespace SPTAG; +using namespace SPTAG::Service; + + +const char* QueryParser::c_defaultVectorSeparator = "|"; + + +QueryParser::QueryParser() + : m_vectorBase64(nullptr), + m_vectorBase64Length(0) +{ +} + + +QueryParser::~QueryParser() +{ +} + + +ErrorCode +QueryParser::Parse(const std::string& p_query, const char* p_vectorSeparator) +{ + if (p_vectorSeparator == nullptr) + { + p_vectorSeparator = c_defaultVectorSeparator; + } + + m_vectorElements.clear(); + m_options.clear(); + + m_dataHolder = ByteArray::Alloc(p_query.size() + 1); + memcpy(m_dataHolder.Data(), p_query.c_str(), p_query.size() + 1); + + enum class State : uint8_t + { + OptionNameBegin, + OptionName, + OptionValueBegin, + OptionValue, + Vector, + VectorBase64, + None + }; + + State currState = State::None; + + char* optionName = nullptr; + char* vectorStrBegin = nullptr; + char* vectorStrEnd = nullptr; + SizeType estDimension = 0; + + char* iter = nullptr; + + for (iter = reinterpret_cast(m_dataHolder.Data()); *iter != '\0'; ++iter) + { + if (std::isspace(*iter)) + { + *iter = '\0'; + if (State::Vector == currState) + { + ++estDimension; + vectorStrEnd = iter; + } + else if (State::VectorBase64 == currState) + { + m_vectorBase64Length = iter - m_vectorBase64; + } + + currState = State::None; + continue; + } + + switch (currState) + { + case State::None: + if ('$' == *iter) + { + currState = State::OptionNameBegin; + } + else if ('#' == *iter) + { + currState = State::VectorBase64; + m_vectorBase64 = iter + 1; + } + else + { + currState = State::Vector; + vectorStrBegin = iter; + } + + break; + + case State::OptionNameBegin: + optionName = iter; + currState = State::OptionName; + break; + + case State::OptionName: + if (':' == *iter || '=' == *iter) + { + *iter = '\0'; + currState = State::OptionValueBegin; + } + else if (std::isupper(*iter)) + { + // Convert OptionName to lowercase. + *iter = (*iter) | 0x20; + } + + break; + + case State::OptionValueBegin: + currState = State::OptionValue; + m_options.emplace_back(optionName, iter); + break; + + case State::Vector: + if (std::strchr(p_vectorSeparator, *iter) != nullptr) + { + ++estDimension; + *iter = '\0'; + } + + break; + + default: + break; + } + } + + if (State::Vector == currState) + { + ++estDimension; + vectorStrEnd = iter; + } + else if (State::VectorBase64 == currState) + { + m_vectorBase64Length = iter - m_vectorBase64; + } + + if (vectorStrBegin == nullptr || 0 == estDimension) + { + return ErrorCode::Fail; + } + + m_vectorElements.reserve(estDimension); + while (vectorStrBegin < vectorStrEnd) + { + while (vectorStrBegin < vectorStrEnd && '\0' == *vectorStrBegin) + { + ++vectorStrBegin; + } + + if (vectorStrBegin >= vectorStrEnd) + { + break; + } + + m_vectorElements.push_back(vectorStrBegin); + + while (vectorStrBegin < vectorStrEnd && '\0' != *vectorStrBegin) + { + ++vectorStrBegin; + } + } + + if (m_vectorElements.empty()) + { + return ErrorCode::Fail; + } + + return ErrorCode::Success; +} + + +const std::vector& +QueryParser::GetVectorElements() const +{ + return m_vectorElements; +} + + +const std::vector& +QueryParser::GetOptions() const +{ + return m_options; +} + + +const char* +QueryParser::GetVectorBase64() const +{ + return m_vectorBase64; +} + + +SizeType +QueryParser::GetVectorBase64Length() const +{ + return m_vectorBase64Length; +} diff --git a/core/src/index/thirdparty/SPTAG/AnnService/src/Server/SearchExecutionContext.cpp b/core/src/index/thirdparty/SPTAG/AnnService/src/Server/SearchExecutionContext.cpp new file mode 100644 index 0000000000..36ff082404 --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/AnnService/src/Server/SearchExecutionContext.cpp @@ -0,0 +1,220 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "inc/Server/SearchExecutionContext.h" +#include "inc/Helper/StringConvert.h" +#include "inc/Helper/CommonHelper.h" +#include "inc/Helper/Base64Encode.h" + +using namespace SPTAG; +using namespace SPTAG::Service; + +namespace +{ +namespace Local +{ + +template +ErrorCode +ConvertVectorFromString(const std::vector& p_source, ByteArray& p_dest, SizeType& p_dimension) +{ + p_dimension = 0; + p_dest = ByteArray::Alloc(p_source.size() * sizeof(ValueType)); + ValueType* arr = reinterpret_cast(p_dest.Data()); + for (std::size_t i = 0; i < p_source.size(); ++i) + { + if (!Helper::Convert::ConvertStringTo(p_source[i], arr[i])) + { + p_dest.Clear(); + p_dimension = 0; + return ErrorCode::Fail; + } + + ++p_dimension; + } + + return ErrorCode::Success; +} + +} +} + + +SearchExecutionContext::SearchExecutionContext(const std::shared_ptr& p_serviceSettings) + : c_serviceSettings(p_serviceSettings), + m_vectorDimension(0), + m_inputValueType(VectorValueType::Undefined), + m_extractMetadata(false), + m_resultNum(p_serviceSettings->m_defaultMaxResultNumber) +{ +} + + +SearchExecutionContext::~SearchExecutionContext() +{ + m_results.clear(); +} + + +ErrorCode +SearchExecutionContext::ParseQuery(const std::string& p_query) +{ + return m_queryParser.Parse(p_query, c_serviceSettings->m_vectorSeparator.c_str()); +} + + +ErrorCode +SearchExecutionContext::ExtractOption() +{ + for (const auto& optionPair : m_queryParser.GetOptions()) + { + if (Helper::StrUtils::StrEqualIgnoreCase(optionPair.first, "indexname")) + { + const char* begin = optionPair.second; + const char* end = optionPair.second; + while (*end != '\0') + { + while (*end != '\0' && *end != ',') + { + ++end; + } + + if (end != begin) + { + m_indexNames.emplace_back(begin, end - begin); + } + + if (*end != '\0') + { + ++end; + begin = end; + } + } + } + else if (Helper::StrUtils::StrEqualIgnoreCase(optionPair.first, "datatype")) + { + Helper::Convert::ConvertStringTo(optionPair.second, m_inputValueType); + } + else if (Helper::StrUtils::StrEqualIgnoreCase(optionPair.first, "extractmetadata")) + { + Helper::Convert::ConvertStringTo(optionPair.second, m_extractMetadata); + } + else if (Helper::StrUtils::StrEqualIgnoreCase(optionPair.first, "resultnum")) + { + Helper::Convert::ConvertStringTo(optionPair.second, m_resultNum); + } + } + + return ErrorCode::Success; +} + + +ErrorCode +SearchExecutionContext::ExtractVector(VectorValueType p_targetType) +{ + if (!m_queryParser.GetVectorElements().empty()) + { + switch (p_targetType) + { +#define DefineVectorValueType(Name, Type) \ + case VectorValueType::Name: \ + return Local::ConvertVectorFromString( \ + m_queryParser.GetVectorElements(), m_vector, m_vectorDimension); \ + break; \ + +#include "inc/Core/DefinitionList.h" +#undef DefineVectorValueType + + default: + break; + } + } + else if (m_queryParser.GetVectorBase64() != nullptr + && m_queryParser.GetVectorBase64Length() != 0) + { + SizeType estLen = m_queryParser.GetVectorBase64Length(); + auto temp = ByteArray::Alloc(Helper::Base64::CapacityForDecode(estLen)); + std::size_t outLen = 0; + if (!Helper::Base64::Decode(m_queryParser.GetVectorBase64(), estLen, temp.Data(), outLen)) + { + return ErrorCode::Fail; + } + + if (outLen % GetValueTypeSize(p_targetType) != 0) + { + return ErrorCode::Fail; + } + + m_vectorDimension = outLen / GetValueTypeSize(p_targetType); + m_vector = ByteArray(temp.Data(), outLen, temp.DataHolder()); + + return ErrorCode::Success; + } + + return ErrorCode::Fail; +} + + +const std::vector& +SearchExecutionContext::GetSelectedIndexNames() const +{ + return m_indexNames; +} + + +void +SearchExecutionContext::AddResults(std::string p_indexName, QueryResult& p_results) +{ + m_results.emplace_back(); + m_results.back().m_indexName.swap(p_indexName); + m_results.back().m_results = p_results; +} + + +std::vector& +SearchExecutionContext::GetResults() +{ + return m_results; +} + + +const std::vector& +SearchExecutionContext::GetResults() const +{ + return m_results; +} + + +const ByteArray& +SearchExecutionContext::GetVector() const +{ + return m_vector; +} + + +const SizeType +SearchExecutionContext::GetVectorDimension() const +{ + return m_vectorDimension; +} + + +const std::vector& +SearchExecutionContext::GetOptions() const +{ + return m_queryParser.GetOptions(); +} + + +const SizeType +SearchExecutionContext::GetResultNum() const +{ + return m_resultNum; +} + + +const bool +SearchExecutionContext::GetExtractMetadata() const +{ + return m_extractMetadata; +} diff --git a/core/src/index/thirdparty/SPTAG/AnnService/src/Server/SearchExecutor.cpp b/core/src/index/thirdparty/SPTAG/AnnService/src/Server/SearchExecutor.cpp new file mode 100644 index 0000000000..2bc3832d88 --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/AnnService/src/Server/SearchExecutor.cpp @@ -0,0 +1,113 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "inc/Server/SearchExecutor.h" + +using namespace SPTAG; +using namespace SPTAG::Service; + + +SearchExecutor::SearchExecutor(std::string p_queryString, + std::shared_ptr p_serviceContext, + const CallBack& p_callback) + : m_callback(p_callback), + c_serviceContext(std::move(p_serviceContext)), + m_queryString(std::move(p_queryString)) +{ +} + + +SearchExecutor::~SearchExecutor() +{ +} + + +void +SearchExecutor::Execute() +{ + ExecuteInternal(); + if (bool(m_callback)) + { + m_callback(std::move(m_executionContext)); + } +} + + +void +SearchExecutor::ExecuteInternal() +{ + m_executionContext.reset(new SearchExecutionContext(c_serviceContext->GetServiceSettings())); + + m_executionContext->ParseQuery(m_queryString); + m_executionContext->ExtractOption(); + + SelectIndex(); + + if (m_selectedIndex.empty()) + { + return; + } + + const auto& firstIndex = m_selectedIndex.front(); + + if (ErrorCode::Success != m_executionContext->ExtractVector(firstIndex->GetVectorValueType())) + { + return; + } + + if (m_executionContext->GetVectorDimension() != firstIndex->GetFeatureDim()) + { + return; + } + + QueryResult query(m_executionContext->GetVector().Data(), + m_executionContext->GetResultNum(), + m_executionContext->GetExtractMetadata()); + + for (const auto& vectorIndex : m_selectedIndex) + { + if (vectorIndex->GetVectorValueType() != firstIndex->GetVectorValueType() + || vectorIndex->GetFeatureDim() != firstIndex->GetFeatureDim()) + { + continue; + } + + query.Reset(); + if (ErrorCode::Success == vectorIndex->SearchIndex(query)) + { + m_executionContext->AddResults(vectorIndex->GetIndexName(), query); + } + } +} + + +void +SearchExecutor::SelectIndex() +{ + const auto& indexNames = m_executionContext->GetSelectedIndexNames(); + const auto& indexMap = c_serviceContext->GetIndexMap(); + if (indexMap.empty()) + { + return; + } + + if (indexNames.empty()) + { + if (indexMap.size() == 1) + { + m_selectedIndex.push_back(indexMap.begin()->second); + } + } + else + { + for (const auto& indexName : indexNames) + { + auto iter = indexMap.find(indexName); + if (iter != indexMap.cend()) + { + m_selectedIndex.push_back(iter->second); + } + } + } +} + diff --git a/core/src/index/thirdparty/SPTAG/AnnService/src/Server/SearchService.cpp b/core/src/index/thirdparty/SPTAG/AnnService/src/Server/SearchService.cpp new file mode 100644 index 0000000000..83096fbcde --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/AnnService/src/Server/SearchService.cpp @@ -0,0 +1,262 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "inc/Server/SearchService.h" +#include "inc/Server/SearchExecutor.h" +#include "inc/Socket/RemoteSearchQuery.h" +#include "inc/Helper/CommonHelper.h" +#include "inc/Helper/ArgumentsParser.h" + +#include + +using namespace SPTAG; +using namespace SPTAG::Service; + + +namespace +{ +namespace Local +{ + +class SerivceCmdOptions : public Helper::ArgumentsParser +{ +public: + SerivceCmdOptions() + : m_serveMode("interactive"), + m_configFile("AnnService.ini") + { + AddOptionalOption(m_serveMode, "-m", "--mode", "Service mode, interactive or socket."); + AddOptionalOption(m_configFile, "-c", "--config", "Service config file path."); + } + + virtual ~SerivceCmdOptions() + { + } + + std::string m_serveMode; + + std::string m_configFile; +}; + +} + +} // namespace + + +SearchService::SearchService() + : m_initialized(false), + m_shutdownSignals(m_ioContext), + m_serveMode(ServeMode::Interactive) +{ +} + + +SearchService::~SearchService() +{ +} + + +bool +SearchService::Initialize(int p_argNum, char* p_args[]) +{ + Local::SerivceCmdOptions cmdOptions; + if (!cmdOptions.Parse(p_argNum - 1, p_args + 1)) + { + return false; + } + + if (Helper::StrUtils::StrEqualIgnoreCase(cmdOptions.m_serveMode.c_str(), "interactive")) + { + m_serveMode = ServeMode::Interactive; + } + else if (Helper::StrUtils::StrEqualIgnoreCase(cmdOptions.m_serveMode.c_str(), "socket")) + { + m_serveMode = ServeMode::Socket; + } + else + { + fprintf(stderr, "Failed parse Serve Mode!\n"); + return false; + } + + m_serviceContext.reset(new ServiceContext(cmdOptions.m_configFile)); + + m_initialized = m_serviceContext->IsInitialized(); + + return m_initialized; +} + + +void +SearchService::Run() +{ + if (!m_initialized) + { + return; + } + + switch (m_serveMode) + { + case ServeMode::Interactive: + RunInteractiveMode(); + break; + + case ServeMode::Socket: + RunSocketMode(); + break; + + default: + break; + } +} + + +void +SearchService::RunSocketMode() +{ + auto threadNum = max((SizeType)1, m_serviceContext->GetServiceSettings()->m_threadNum); + m_threadPool.reset(new boost::asio::thread_pool(threadNum)); + + Socket::PacketHandlerMapPtr handlerMap(new Socket::PacketHandlerMap); + handlerMap->emplace(Socket::PacketType::SearchRequest, + [this](Socket::ConnectionID p_srcID, Socket::Packet p_packet) + { + boost::asio::post(*m_threadPool, std::bind(&SearchService::SearchHanlder, this, p_srcID, std::move(p_packet))); + }); + + m_socketServer.reset(new Socket::Server(m_serviceContext->GetServiceSettings()->m_listenAddr, + m_serviceContext->GetServiceSettings()->m_listenPort, + handlerMap, + m_serviceContext->GetServiceSettings()->m_socketThreadNum)); + + fprintf(stderr, + "Start to listen %s:%s ...\n", + m_serviceContext->GetServiceSettings()->m_listenAddr.c_str(), + m_serviceContext->GetServiceSettings()->m_listenPort.c_str()); + + m_shutdownSignals.add(SIGINT); + m_shutdownSignals.add(SIGTERM); +#ifdef SIGQUIT + m_shutdownSignals.add(SIGQUIT); +#endif + + m_shutdownSignals.async_wait([this](boost::system::error_code p_ec, int p_signal) + { + fprintf(stderr, "Received shutdown signals.\n"); + }); + + m_ioContext.run(); + fprintf(stderr, "Start shutdown procedure.\n"); + + m_socketServer.reset(); + m_threadPool->stop(); + m_threadPool->join(); +} + + +void +SearchService::RunInteractiveMode() +{ + const std::size_t bufferSize = 1 << 16; + std::unique_ptr inputBuffer(new char[bufferSize]); + while (true) + { + std::cout << "Query: "; + if (!fgets(inputBuffer.get(), bufferSize, stdin)) + { + break; + } + + auto callback = [](std::shared_ptr p_exeContext) + { + std::cout << "Result:" << std::endl; + if (nullptr == p_exeContext) + { + std::cout << "Not Executed." << std::endl; + return; + } + + const auto& results = p_exeContext->GetResults(); + for (const auto& result : results) + { + std::cout << "Index: " << result.m_indexName << std::endl; + int idx = 0; + for (const auto& res : result.m_results) + { + std::cout << "------------------" << std::endl; + std::cout << "DocIndex: " << res.VID << " Distance: " << res.Dist; + if (result.m_results.WithMeta()) + { + const auto& metadata = result.m_results.GetMetadata(idx); + std::cout << " MetaData: " << std::string((char*)metadata.Data(), metadata.Length()); + } + std::cout << std::endl; + ++idx; + } + } + }; + + SearchExecutor executor(inputBuffer.get(), m_serviceContext, callback); + executor.Execute(); + } +} + + +void +SearchService::SearchHanlder(Socket::ConnectionID p_localConnectionID, Socket::Packet p_packet) +{ + if (p_packet.Header().m_bodyLength == 0) + { + return; + } + + if (Socket::c_invalidConnectionID == p_packet.Header().m_connectionID) + { + p_packet.Header().m_connectionID = p_localConnectionID; + } + + Socket::RemoteQuery remoteQuery; + remoteQuery.Read(p_packet.Body()); + + auto callback = std::bind(&SearchService::SearchHanlderCallback, + this, + std::placeholders::_1, + std::move(p_packet)); + + SearchExecutor executor(std::move(remoteQuery.m_queryString), + m_serviceContext, + callback); + executor.Execute(); +} + + +void +SearchService::SearchHanlderCallback(std::shared_ptr p_exeContext, + Socket::Packet p_srcPacket) +{ + Socket::Packet ret; + ret.Header().m_packetType = Socket::PacketType::SearchResponse; + ret.Header().m_processStatus = Socket::PacketProcessStatus::Ok; + ret.Header().m_connectionID = p_srcPacket.Header().m_connectionID; + ret.Header().m_resourceID = p_srcPacket.Header().m_resourceID; + + if (nullptr == p_exeContext) + { + ret.Header().m_processStatus = Socket::PacketProcessStatus::Failed; + ret.AllocateBuffer(0); + ret.Header().WriteBuffer(ret.HeaderBuffer()); + } + else + { + Socket::RemoteSearchResult remoteResult; + remoteResult.m_status = Socket::RemoteSearchResult::ResultStatus::Success; + remoteResult.m_allIndexResults.swap(p_exeContext->GetResults()); + ret.AllocateBuffer(static_cast(remoteResult.EstimateBufferSize())); + auto bodyEnd = remoteResult.Write(ret.Body()); + + ret.Header().m_bodyLength = static_cast(bodyEnd - ret.Body()); + ret.Header().WriteBuffer(ret.HeaderBuffer()); + } + + m_socketServer->SendPacket(p_srcPacket.Header().m_connectionID, std::move(ret), nullptr); +} diff --git a/core/src/index/thirdparty/SPTAG/AnnService/src/Server/ServiceContext.cpp b/core/src/index/thirdparty/SPTAG/AnnService/src/Server/ServiceContext.cpp new file mode 100644 index 0000000000..8d62b2c7af --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/AnnService/src/Server/ServiceContext.cpp @@ -0,0 +1,88 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "inc/Server/ServiceContext.h" +#include "inc/Helper/SimpleIniReader.h" +#include "inc/Helper/CommonHelper.h" +#include "inc/Helper/StringConvert.h" + +using namespace SPTAG; +using namespace SPTAG::Service; + + +ServiceContext::ServiceContext(const std::string& p_configFilePath) + : m_initialized(false) +{ + Helper::IniReader iniReader; + if (ErrorCode::Success != iniReader.LoadIniFile(p_configFilePath)) + { + return; + } + + m_settings.reset(new ServiceSettings); + + m_settings->m_listenAddr = iniReader.GetParameter("Service", "ListenAddr", std::string("0.0.0.0")); + m_settings->m_listenPort = iniReader.GetParameter("Service", "ListenPort", std::string("8000")); + m_settings->m_threadNum = iniReader.GetParameter("Service", "ThreadNumber", static_cast(8)); + m_settings->m_socketThreadNum = iniReader.GetParameter("Service", "SocketThreadNumber", static_cast(8)); + + m_settings->m_defaultMaxResultNumber = iniReader.GetParameter("QueryConfig", "DefaultMaxResultNumber", static_cast(10)); + m_settings->m_vectorSeparator = iniReader.GetParameter("QueryConfig", "DefaultSeparator", std::string("|")); + + const std::string emptyStr; + + std::string indexListStr = iniReader.GetParameter("Index", "List", emptyStr); + const auto& indexList = Helper::StrUtils::SplitString(indexListStr, ","); + + for (const auto& indexName : indexList) + { + std::string sectionName("Index_"); + sectionName += indexName.c_str(); + if (!iniReader.DoesParameterExist(sectionName, "IndexFolder")) + { + continue; + } + + std::string indexFolder = iniReader.GetParameter(sectionName, "IndexFolder", emptyStr); + + std::shared_ptr vectorIndex; + if (ErrorCode::Success == VectorIndex::LoadIndex(indexFolder, vectorIndex)) + { + vectorIndex->SetIndexName(indexName); + m_fullIndexList.emplace(indexName, vectorIndex); + } + else + { + fprintf(stderr, "Failed loading index: %s\n", indexName.c_str()); + } + } + + m_initialized = true; +} + + +ServiceContext::~ServiceContext() +{ + +} + + +const std::map>& +ServiceContext::GetIndexMap() const +{ + return m_fullIndexList; +} + + +const std::shared_ptr& +ServiceContext::GetServiceSettings() const +{ + return m_settings; +} + + +bool +ServiceContext::IsInitialized() const +{ + return m_initialized; +} diff --git a/core/src/index/thirdparty/SPTAG/AnnService/src/Server/ServiceSettings.cpp b/core/src/index/thirdparty/SPTAG/AnnService/src/Server/ServiceSettings.cpp new file mode 100644 index 0000000000..d51153195b --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/AnnService/src/Server/ServiceSettings.cpp @@ -0,0 +1,14 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "inc/Server/ServiceSettings.h" + +using namespace SPTAG; +using namespace SPTAG::Service; + + +ServiceSettings::ServiceSettings() + : m_defaultMaxResultNumber(10), + m_threadNum(12) +{ +} diff --git a/core/src/index/thirdparty/SPTAG/AnnService/src/Server/main.cpp b/core/src/index/thirdparty/SPTAG/AnnService/src/Server/main.cpp new file mode 100644 index 0000000000..5aa5dc1e59 --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/AnnService/src/Server/main.cpp @@ -0,0 +1,19 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "inc/Server/SearchService.h" + +SPTAG::Service::SearchService g_service; + +int main(int argc, char* argv[]) +{ + if (!g_service.Initialize(argc, argv)) + { + return 1; + } + + g_service.Run(); + + return 0; +} + diff --git a/core/src/index/thirdparty/SPTAG/AnnService/src/Socket/Client.cpp b/core/src/index/thirdparty/SPTAG/AnnService/src/Socket/Client.cpp new file mode 100644 index 0000000000..9c4101e4f4 --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/AnnService/src/Socket/Client.cpp @@ -0,0 +1,141 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "inc/Socket/Client.h" + +#include + +using namespace SPTAG::Socket; + + +Client::Client(const PacketHandlerMapPtr& p_handlerMap, + std::size_t p_threadNum, + std::uint32_t p_heartbeatIntervalSeconds) + : c_requestHandlerMap(p_handlerMap), + m_connectionManager(new ConnectionManager), + m_deadlineTimer(m_ioContext), + m_heartbeatIntervalSeconds(p_heartbeatIntervalSeconds), + m_stopped(false) +{ + KeepIoContext(); + m_threadPool.reserve(p_threadNum); + for (std::size_t i = 0; i < p_threadNum; ++i) + { + m_threadPool.emplace_back(std::move(std::thread([this]() { m_ioContext.run(); }))); + } +} + + +Client::~Client() +{ + m_stopped = true; + + m_deadlineTimer.cancel(); + m_connectionManager->StopAll(); + while (!m_ioContext.stopped()) + { + m_ioContext.stop(); + } + + for (auto& t : m_threadPool) + { + t.join(); + } +} + + +ConnectionID +Client::ConnectToServer(const std::string& p_address, + const std::string& p_port, + SPTAG::ErrorCode& p_ec) +{ + boost::asio::ip::tcp::resolver resolver(m_ioContext); + + boost::system::error_code errCode; + auto endPoints = resolver.resolve(p_address, p_port, errCode); + if (errCode || endPoints.empty()) + { + p_ec = ErrorCode::Socket_FailedResolveEndPoint; + return c_invalidConnectionID; + } + + boost::asio::ip::tcp::socket socket(m_ioContext); + for (const auto ep : endPoints) + { + errCode.clear(); + socket.connect(ep, errCode); + if (!errCode) + { + break; + } + + socket.close(errCode); + } + + if (socket.is_open()) + { + p_ec = ErrorCode::Success; + return m_connectionManager->AddConnection(std::move(socket), + c_requestHandlerMap, + m_heartbeatIntervalSeconds); + } + + p_ec = ErrorCode::Socket_FailedConnectToEndPoint; + return c_invalidConnectionID; +} + + +void +Client::AsyncConnectToServer(const std::string& p_address, + const std::string& p_port, + ConnectCallback p_callback) +{ + boost::asio::post(m_ioContext, + [this, p_address, p_port, p_callback]() + { + SPTAG::ErrorCode errCode; + auto connID = ConnectToServer(p_address, p_port, errCode); + if (bool(p_callback)) + { + p_callback(connID, errCode); + } + }); +} + + +void +Client::SendPacket(ConnectionID p_connection, Packet p_packet, std::function p_callback) +{ + auto connection = m_connectionManager->GetConnection(p_connection); + if (nullptr != connection) + { + connection->AsyncSend(std::move(p_packet), std::move(p_callback)); + } + else if (bool(p_callback)) + { + p_callback(false); + } +} + + +void +Client::SetEventOnConnectionClose(std::function p_event) +{ + m_connectionManager->SetEventOnRemoving(std::move(p_event)); +} + + +void +Client::KeepIoContext() +{ + if (m_stopped) + { + return; + } + + m_deadlineTimer.expires_from_now(boost::posix_time::hours(24)); + m_deadlineTimer.async_wait([this](boost::system::error_code p_ec) + { + this->KeepIoContext(); + }); +} diff --git a/core/src/index/thirdparty/SPTAG/AnnService/src/Socket/Common.cpp b/core/src/index/thirdparty/SPTAG/AnnService/src/Socket/Common.cpp new file mode 100644 index 0000000000..2cfc1178ed --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/AnnService/src/Socket/Common.cpp @@ -0,0 +1,11 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "inc/Socket/Common.h" + + +using namespace SPTAG::Socket; + +const ConnectionID SPTAG::Socket::c_invalidConnectionID = 0; + +const ResourceID SPTAG::Socket::c_invalidResourceID = 0; diff --git a/core/src/index/thirdparty/SPTAG/AnnService/src/Socket/Connection.cpp b/core/src/index/thirdparty/SPTAG/AnnService/src/Socket/Connection.cpp new file mode 100644 index 0000000000..6e536cbfcf --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/AnnService/src/Socket/Connection.cpp @@ -0,0 +1,403 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "inc/Socket/Connection.h" +#include "inc/Socket/ConnectionManager.h" + +#include +#include +#include +#include +#include + +#include + +using namespace SPTAG::Socket; + +Connection::Connection(ConnectionID p_connectionID, + boost::asio::ip::tcp::socket&& p_socket, + const PacketHandlerMapPtr& p_handlerMap, + std::weak_ptr p_connectionManager) + : c_connectionID(p_connectionID), + c_handlerMap(p_handlerMap), + c_connectionManager(std::move(p_connectionManager)), + m_socket(std::move(p_socket)), + m_strand(p_socket.get_executor().context()), + m_heartbeatTimer(p_socket.get_executor().context()), + m_remoteConnectionID(c_invalidConnectionID), + m_stopped(true), + m_heartbeatStarted(false) +{ +} + + +void +Connection::Start() +{ +#ifdef _DEBUG + fprintf(stderr, "Connection Start, local: %u, remote: %s:%u\n", + static_cast(m_socket.local_endpoint().port()), + m_socket.remote_endpoint().address().to_string().c_str(), + static_cast(m_socket.remote_endpoint().port())); +#endif + + if (!m_stopped.exchange(false)) + { + return; + } + + SendRegister(); + AsyncReadHeader(); +} + + +void +Connection::Stop() +{ +#ifdef _DEBUG + fprintf(stderr, "Connection Stop, local: %u, remote: %s:%u\n", + static_cast(m_socket.local_endpoint().port()), + m_socket.remote_endpoint().address().to_string().c_str(), + static_cast(m_socket.remote_endpoint().port())); +#endif + + if (m_stopped.exchange(true)) + { + return; + } + + boost::system::error_code errCode; + if (m_heartbeatStarted.exchange(false)) + { + m_heartbeatTimer.cancel(errCode); + } + + m_socket.shutdown(boost::asio::ip::tcp::socket::shutdown_both, errCode); + m_socket.close(errCode); +} + + +void +Connection::StartHeartbeat(std::size_t p_intervalSeconds) +{ + if (m_stopped || m_heartbeatStarted.exchange(true)) + { + return; + } + + SendHeartbeat(p_intervalSeconds); +} + + +ConnectionID +Connection::GetConnectionID() const +{ + return c_connectionID; +} + + +ConnectionID +Connection::GetRemoteConnectionID() const +{ + return m_remoteConnectionID; +} + + +void +Connection::AsyncSend(Packet p_packet, std::function p_callback) +{ + if (m_stopped) + { + if (bool(p_callback)) + { + p_callback(false); + } + + return; + } + + auto sharedThis = shared_from_this(); + boost::asio::post(m_strand, + [sharedThis, p_packet, p_callback]() + { + auto handler = [p_callback, p_packet, sharedThis](boost::system::error_code p_ec, + std::size_t p_bytesTransferred) + { + if (p_ec && boost::asio::error::operation_aborted != p_ec) + { + sharedThis->OnConnectionFail(p_ec); + } + + if (bool(p_callback)) + { + p_callback(!p_ec); + } + }; + + boost::asio::async_write(sharedThis->m_socket, + boost::asio::buffer(p_packet.Buffer(), + p_packet.BufferLength()), + std::move(handler)); + }); +} + + +void +Connection::AsyncReadHeader() +{ + if (m_stopped) + { + return; + } + + auto sharedThis = shared_from_this(); + boost::asio::post(m_strand, + [sharedThis]() + { + auto handler = boost::bind(&Connection::HandleReadHeader, + sharedThis, + boost::asio::placeholders::error, + boost::asio::placeholders::bytes_transferred); + + boost::asio::async_read(sharedThis->m_socket, + boost::asio::buffer(sharedThis->m_packetHeaderReadBuffer), + std::move(handler)); + }); +} + + +void +Connection::AsyncReadBody() +{ + if (m_stopped) + { + return; + } + + auto sharedThis = shared_from_this(); + boost::asio::post(m_strand, + [sharedThis]() + { + auto handler = boost::bind(&Connection::HandleReadBody, + sharedThis, + boost::asio::placeholders::error, + boost::asio::placeholders::bytes_transferred); + + boost::asio::async_read(sharedThis->m_socket, + boost::asio::buffer(sharedThis->m_packetRead.Body(), + sharedThis->m_packetRead.Header().m_bodyLength), + std::move(handler)); + }); +} + + +void +Connection::HandleReadHeader(boost::system::error_code p_ec, std::size_t p_bytesTransferred) +{ + if (!p_ec) + { + m_packetRead.Header().ReadBuffer(m_packetHeaderReadBuffer); + if (m_packetRead.Header().m_bodyLength > 0) + { + m_packetRead.AllocateBuffer(m_packetRead.Header().m_bodyLength); + AsyncReadBody(); + } + else + { + HandleReadBody(p_ec, p_bytesTransferred); + } + + return; + } + else if (boost::asio::error::operation_aborted != p_ec) + { + OnConnectionFail(p_ec); + return; + } + + AsyncReadHeader(); +} + + +void +Connection::HandleReadBody(boost::system::error_code p_ec, std::size_t p_bytesTransferred) +{ + if (!p_ec) + { + bool foundHanlder = true; + switch (m_packetRead.Header().m_packetType) + { + case PacketType::HeartbeatRequest: + HandleHeartbeatRequest(); + break; + + case PacketType::HeartbeatResponse: + break; + + case PacketType::RegisterRequest: + HandleRegisterRequest(); + break; + + case PacketType::RegisterResponse: + HandleRegisterResponse(); + break; + + default: + foundHanlder = false; + break; + } + + if (nullptr != c_handlerMap) + { + auto iter = c_handlerMap->find(m_packetRead.Header().m_packetType); + if (c_handlerMap->cend() != iter && bool(iter->second)) + { + (iter->second)(c_connectionID, std::move(m_packetRead)); + foundHanlder = true; + } + } + + if (!foundHanlder) + { + HandleNoHandlerResponse(); + } + } + else if (boost::asio::error::operation_aborted != p_ec) + { + OnConnectionFail(p_ec); + return; + } + + AsyncReadHeader(); +} + + +void +Connection::SendHeartbeat(std::size_t p_intervalSeconds) +{ + if (m_stopped) + { + return; + } + + Packet msg; + msg.Header().m_packetType = PacketType::HeartbeatRequest; + msg.Header().m_processStatus = PacketProcessStatus::Ok; + msg.Header().m_connectionID = 0; + + msg.AllocateBuffer(0); + msg.Header().WriteBuffer(msg.HeaderBuffer()); + + AsyncSend(std::move(msg), nullptr); + + m_heartbeatTimer.expires_from_now(boost::posix_time::seconds(p_intervalSeconds)); + m_heartbeatTimer.async_wait(boost::bind(&Connection::SendHeartbeat, + shared_from_this(), + p_intervalSeconds)); +} + + +void +Connection::SendRegister() +{ + Packet msg; + msg.Header().m_packetType = PacketType::RegisterRequest; + msg.Header().m_processStatus = PacketProcessStatus::Ok; + msg.Header().m_connectionID = 0; + + msg.AllocateBuffer(0); + msg.Header().WriteBuffer(msg.HeaderBuffer()); + + AsyncSend(std::move(msg), nullptr); +} + + +void +Connection::HandleHeartbeatRequest() +{ + Packet msg; + msg.Header().m_packetType = PacketType::HeartbeatResponse; + msg.Header().m_processStatus = PacketProcessStatus::Ok; + + msg.AllocateBuffer(0); + + if (0 == m_packetRead.Header().m_connectionID + || c_connectionID == m_packetRead.Header().m_connectionID) + { + m_packetRead.Header().m_connectionID; + msg.Header().WriteBuffer(msg.HeaderBuffer()); + + AsyncSend(std::move(msg), nullptr); + } + else + { + msg.Header().m_connectionID = m_packetRead.Header().m_connectionID; + msg.Header().WriteBuffer(msg.HeaderBuffer()); + + auto mgr = c_connectionManager.lock(); + if (nullptr != mgr) + { + auto con = mgr->GetConnection(m_packetRead.Header().m_connectionID); + if (nullptr != con) + { + con->AsyncSend(std::move(msg), nullptr); + } + } + } +} + + +void +Connection::HandleRegisterRequest() +{ + Packet msg; + msg.Header().m_packetType = PacketType::RegisterResponse; + msg.Header().m_processStatus = PacketProcessStatus::Ok; + msg.Header().m_connectionID = c_connectionID; + msg.Header().m_resourceID = m_packetRead.Header().m_resourceID; + + msg.AllocateBuffer(0); + msg.Header().WriteBuffer(msg.HeaderBuffer()); + + AsyncSend(std::move(msg), nullptr); +} + + +void +Connection::HandleRegisterResponse() +{ + m_remoteConnectionID = m_packetRead.Header().m_connectionID; +} + + +void +Connection::HandleNoHandlerResponse() +{ + auto packetType = m_packetRead.Header().m_packetType; + if (!PacketTypeHelper::IsRequestPacket(packetType)) + { + return; + } + + Packet msg; + msg.Header().m_packetType = PacketTypeHelper::GetCrosspondingResponseType(packetType); + msg.Header().m_processStatus = PacketProcessStatus::Dropped; + msg.Header().m_connectionID = c_connectionID; + msg.Header().m_resourceID = m_packetRead.Header().m_resourceID; + + msg.AllocateBuffer(0); + msg.Header().WriteBuffer(msg.HeaderBuffer()); + + AsyncSend(std::move(msg), nullptr); +} + + +void +Connection::OnConnectionFail(const boost::system::error_code& p_ec) +{ + auto mgr = c_connectionManager.lock(); + if (nullptr != mgr) + { + mgr->RemoveConnection(c_connectionID); + } +} diff --git a/core/src/index/thirdparty/SPTAG/AnnService/src/Socket/ConnectionManager.cpp b/core/src/index/thirdparty/SPTAG/AnnService/src/Socket/ConnectionManager.cpp new file mode 100644 index 0000000000..9d52dbc8b1 --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/AnnService/src/Socket/ConnectionManager.cpp @@ -0,0 +1,133 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "inc/Socket/ConnectionManager.h" + +using namespace SPTAG::Socket; + + +ConnectionManager::ConnectionItem::ConnectionItem() + : m_isEmpty(true) +{ +} + + +ConnectionManager::ConnectionManager() + : m_nextConnectionID(1), + m_connectionCount(0) +{ +} + + +ConnectionID +ConnectionManager::AddConnection(boost::asio::ip::tcp::socket&& p_socket, + const PacketHandlerMapPtr& p_handler, + std::uint32_t p_heartbeatIntervalSeconds) +{ + ConnectionID currID = m_nextConnectionID.fetch_add(1); + while (c_invalidConnectionID == currID || !m_connections[GetPosition(currID)].m_isEmpty.exchange(false)) + { + if (m_connectionCount >= c_connectionPoolSize) + { + return c_invalidConnectionID; + } + + currID = m_nextConnectionID.fetch_add(1); + } + + ++m_connectionCount; + + auto connection = std::make_shared(currID, + std::move(p_socket), + p_handler, + std::weak_ptr(shared_from_this())); + + { + Helper::Concurrent::LockGuard guard(m_spinLock); + m_connections[GetPosition(currID)].m_connection = connection; + } + + connection->Start(); + if (p_heartbeatIntervalSeconds > 0) + { + connection->StartHeartbeat(p_heartbeatIntervalSeconds); + } + + return currID; +} + + +void +ConnectionManager::RemoveConnection(ConnectionID p_connectionID) +{ + auto position = GetPosition(p_connectionID); + if (m_connections[position].m_isEmpty.exchange(true)) + { + return; + } + + Connection::Ptr conn; + + { + Helper::Concurrent::LockGuard guard(m_spinLock); + conn = std::move(m_connections[position].m_connection); + } + + --m_connectionCount; + + conn->Stop(); + conn.reset(); + + if (bool(m_eventOnRemoving)) + { + m_eventOnRemoving(p_connectionID); + } +} + + +Connection::Ptr +ConnectionManager::GetConnection(ConnectionID p_connectionID) +{ + auto position = GetPosition(p_connectionID); + Connection::Ptr ret; + + { + Helper::Concurrent::LockGuard guard(m_spinLock); + ret = m_connections[position].m_connection; + } + + if (nullptr == ret || ret->GetConnectionID() != p_connectionID) + { + return nullptr; + } + + return ret; +} + + +void +ConnectionManager::SetEventOnRemoving(std::function p_event) +{ + m_eventOnRemoving = std::move(p_event); +} + + +void +ConnectionManager::StopAll() +{ + Helper::Concurrent::LockGuard guard(m_spinLock); + for (auto& connection : m_connections) + { + if (nullptr != connection.m_connection) + { + connection.m_connection->Stop(); + } + } +} + + +std::uint32_t +ConnectionManager::GetPosition(ConnectionID p_connectionID) +{ + return static_cast(p_connectionID) & c_connectionPoolMask; +} diff --git a/core/src/index/thirdparty/SPTAG/AnnService/src/Socket/Packet.cpp b/core/src/index/thirdparty/SPTAG/AnnService/src/Socket/Packet.cpp new file mode 100644 index 0000000000..335400bbd8 --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/AnnService/src/Socket/Packet.cpp @@ -0,0 +1,177 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "inc/Socket/Packet.h" +#include "inc/Socket/SimpleSerialization.h" + +#include + +using namespace SPTAG::Socket; + +PacketHeader::PacketHeader() + : m_packetType(PacketType::Undefined), + m_processStatus(PacketProcessStatus::Ok), + m_bodyLength(0), + m_connectionID(c_invalidConnectionID), + m_resourceID(c_invalidResourceID) +{ +} + + +PacketHeader::PacketHeader(PacketHeader&& p_right) + : m_packetType(std::move(p_right.m_packetType)), + m_processStatus(std::move(p_right.m_processStatus)), + m_bodyLength(std::move(p_right.m_bodyLength)), + m_connectionID(std::move(p_right.m_connectionID)), + m_resourceID(std::move(p_right.m_resourceID)) +{ +} + + +PacketHeader::PacketHeader(const PacketHeader& p_right) + : m_packetType(p_right.m_packetType), + m_processStatus(p_right.m_processStatus), + m_bodyLength(p_right.m_bodyLength), + m_connectionID(p_right.m_connectionID), + m_resourceID(p_right.m_resourceID) +{ +} + + +std::size_t +PacketHeader::WriteBuffer(std::uint8_t* p_buffer) +{ + std::uint8_t* buff = p_buffer; + buff = SimpleSerialization::SimpleWriteBuffer(m_packetType, buff); + buff = SimpleSerialization::SimpleWriteBuffer(m_processStatus, buff); + buff = SimpleSerialization::SimpleWriteBuffer(m_bodyLength, buff); + buff = SimpleSerialization::SimpleWriteBuffer(m_connectionID, buff); + buff = SimpleSerialization::SimpleWriteBuffer(m_resourceID, buff); + + return p_buffer - buff; +} + + +void +PacketHeader::ReadBuffer(const std::uint8_t* p_buffer) +{ + const std::uint8_t* buff = p_buffer; + buff = SimpleSerialization::SimpleReadBuffer(buff, m_packetType); + buff = SimpleSerialization::SimpleReadBuffer(buff, m_processStatus); + buff = SimpleSerialization::SimpleReadBuffer(buff, m_bodyLength); + buff = SimpleSerialization::SimpleReadBuffer(buff, m_connectionID); + buff = SimpleSerialization::SimpleReadBuffer(buff, m_resourceID); +} + + +Packet::Packet() +{ +} + + +Packet::Packet(Packet&& p_right) + : m_header(std::move(p_right.m_header)), + m_buffer(std::move(p_right.m_buffer)), + m_bufferCapacity(std::move(p_right.m_bufferCapacity)) +{ +} + + +Packet::Packet(const Packet& p_right) + : m_header(p_right.m_header), + m_buffer(p_right.m_buffer), + m_bufferCapacity(p_right.m_bufferCapacity) +{ +} + + +PacketHeader& +Packet::Header() +{ + return m_header; +} + + +std::uint8_t* +Packet::HeaderBuffer() const +{ + return m_buffer.get(); +} + + +std::uint8_t* +Packet::Body() const +{ + if (nullptr != m_buffer && PacketHeader::c_bufferSize < m_bufferCapacity) + { + return m_buffer.get() + PacketHeader::c_bufferSize; + } + + return nullptr; +} + + +std::uint8_t* +Packet::Buffer() const +{ + return m_buffer.get(); +} + + +std::uint32_t +Packet::BufferLength() const +{ + return PacketHeader::c_bufferSize + m_header.m_bodyLength; +} + + +std::uint32_t +Packet::BufferCapacity() const +{ + return m_bufferCapacity; +} + + +void +Packet::AllocateBuffer(std::uint32_t p_bodyCapacity) +{ + m_bufferCapacity = PacketHeader::c_bufferSize + p_bodyCapacity; + m_buffer.reset(new std::uint8_t[m_bufferCapacity], std::default_delete()); +} + + +bool +PacketTypeHelper::IsRequestPacket(PacketType p_type) +{ + if (PacketType::Undefined == p_type || PacketType::ResponseMask == p_type) + { + return false; + } + + return (static_cast(p_type) & static_cast(PacketType::ResponseMask)) == 0; +} + + +bool +PacketTypeHelper::IsResponsePacket(PacketType p_type) +{ + if (PacketType::Undefined == p_type || PacketType::ResponseMask == p_type) + { + return false; + } + + return (static_cast(p_type) & static_cast(PacketType::ResponseMask)) != 0; +} + + +PacketType +PacketTypeHelper::GetCrosspondingResponseType(PacketType p_type) +{ + if (PacketType::Undefined == p_type || PacketType::ResponseMask == p_type) + { + return PacketType::Undefined; + } + + auto ret = static_cast(p_type) | static_cast(PacketType::ResponseMask); + return static_cast(ret); +} diff --git a/core/src/index/thirdparty/SPTAG/AnnService/src/Socket/RemoteSearchQuery.cpp b/core/src/index/thirdparty/SPTAG/AnnService/src/Socket/RemoteSearchQuery.cpp new file mode 100644 index 0000000000..2cb450328e --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/AnnService/src/Socket/RemoteSearchQuery.cpp @@ -0,0 +1,210 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "inc/Socket/RemoteSearchQuery.h" +#include "inc/Socket/SimpleSerialization.h" + +using namespace SPTAG; +using namespace SPTAG::Socket; + + +RemoteQuery::RemoteQuery() + : m_type(QueryType::String) +{ +} + + +std::size_t +RemoteQuery::EstimateBufferSize() const +{ + std::size_t sum = 0; + sum += SimpleSerialization::EstimateBufferSize(MajorVersion()); + sum += SimpleSerialization::EstimateBufferSize(MirrorVersion()); + sum += SimpleSerialization::EstimateBufferSize(m_type); + sum += SimpleSerialization::EstimateBufferSize(m_queryString); + + return sum; +} + + +std::uint8_t* +RemoteQuery::Write(std::uint8_t* p_buffer) const +{ + p_buffer = SimpleSerialization::SimpleWriteBuffer(MajorVersion(), p_buffer); + p_buffer = SimpleSerialization::SimpleWriteBuffer(MirrorVersion(), p_buffer); + + p_buffer = SimpleSerialization::SimpleWriteBuffer(m_type, p_buffer); + p_buffer = SimpleSerialization::SimpleWriteBuffer(m_queryString, p_buffer); + + return p_buffer; +} + + +const std::uint8_t* +RemoteQuery::Read(const std::uint8_t* p_buffer) +{ + decltype(MajorVersion()) majorVer = 0; + decltype(MirrorVersion()) mirrorVer = 0; + + p_buffer = SimpleSerialization::SimpleReadBuffer(p_buffer, majorVer); + p_buffer = SimpleSerialization::SimpleReadBuffer(p_buffer, mirrorVer); + if (majorVer != MajorVersion()) + { + return nullptr; + } + + p_buffer = SimpleSerialization::SimpleReadBuffer(p_buffer, m_type); + p_buffer = SimpleSerialization::SimpleReadBuffer(p_buffer, m_queryString); + + return p_buffer; +} + + +RemoteSearchResult::RemoteSearchResult() + : m_status(ResultStatus::Timeout) +{ +} + + +RemoteSearchResult::RemoteSearchResult(const RemoteSearchResult& p_right) + : m_status(p_right.m_status), + m_allIndexResults(p_right.m_allIndexResults) +{ +} + + +RemoteSearchResult::RemoteSearchResult(RemoteSearchResult&& p_right) + : m_status(std::move(p_right.m_status)), + m_allIndexResults(std::move(p_right.m_allIndexResults)) +{ +} + + +RemoteSearchResult& +RemoteSearchResult::operator=(RemoteSearchResult&& p_right) +{ + m_status = p_right.m_status; + m_allIndexResults = std::move(p_right.m_allIndexResults); + + return *this; +} + + +std::size_t +RemoteSearchResult::EstimateBufferSize() const +{ + std::size_t sum = 0; + sum += SimpleSerialization::EstimateBufferSize(MajorVersion()); + sum += SimpleSerialization::EstimateBufferSize(MirrorVersion()); + + sum += SimpleSerialization::EstimateBufferSize(m_status); + + sum += sizeof(std::uint32_t); + for (const auto& indexRes : m_allIndexResults) + { + sum += SimpleSerialization::EstimateBufferSize(indexRes.m_indexName); + sum += sizeof(std::uint32_t); + sum += sizeof(bool); + + for (const auto& res : indexRes.m_results) + { + sum += SimpleSerialization::EstimateBufferSize(res.VID); + sum += SimpleSerialization::EstimateBufferSize(res.Dist); + } + + if (indexRes.m_results.WithMeta()) + { + for (int i = 0; i < indexRes.m_results.GetResultNum(); ++i) + { + sum += SimpleSerialization::EstimateBufferSize(indexRes.m_results.GetMetadata(i)); + } + } + } + + return sum; +} + + +std::uint8_t* +RemoteSearchResult::Write(std::uint8_t* p_buffer) const +{ + p_buffer = SimpleSerialization::SimpleWriteBuffer(MajorVersion(), p_buffer); + p_buffer = SimpleSerialization::SimpleWriteBuffer(MirrorVersion(), p_buffer); + + p_buffer = SimpleSerialization::SimpleWriteBuffer(m_status, p_buffer); + p_buffer = SimpleSerialization::SimpleWriteBuffer(static_cast(m_allIndexResults.size()), p_buffer); + for (const auto& indexRes : m_allIndexResults) + { + p_buffer = SimpleSerialization::SimpleWriteBuffer(indexRes.m_indexName, p_buffer); + + p_buffer = SimpleSerialization::SimpleWriteBuffer(static_cast(indexRes.m_results.GetResultNum()), p_buffer); + p_buffer = SimpleSerialization::SimpleWriteBuffer(indexRes.m_results.WithMeta(), p_buffer); + + for (const auto& res : indexRes.m_results) + { + p_buffer = SimpleSerialization::SimpleWriteBuffer(res.VID, p_buffer); + p_buffer = SimpleSerialization::SimpleWriteBuffer(res.Dist, p_buffer); + } + + if (indexRes.m_results.WithMeta()) + { + for (int i = 0; i < indexRes.m_results.GetResultNum(); ++i) + { + p_buffer = SimpleSerialization::SimpleWriteBuffer(indexRes.m_results.GetMetadata(i), p_buffer); + } + } + } + + return p_buffer; +} + + +const std::uint8_t* +RemoteSearchResult::Read(const std::uint8_t* p_buffer) +{ + decltype(MajorVersion()) majorVer = 0; + decltype(MirrorVersion()) mirrorVer = 0; + + p_buffer = SimpleSerialization::SimpleReadBuffer(p_buffer, majorVer); + p_buffer = SimpleSerialization::SimpleReadBuffer(p_buffer, mirrorVer); + if (majorVer != MajorVersion()) + { + return nullptr; + } + + p_buffer = SimpleSerialization::SimpleReadBuffer(p_buffer, m_status); + + std::uint32_t len = 0; + p_buffer = SimpleSerialization::SimpleReadBuffer(p_buffer, len); + m_allIndexResults.resize(len); + + for (auto& indexRes : m_allIndexResults) + { + p_buffer = SimpleSerialization::SimpleReadBuffer(p_buffer, indexRes.m_indexName); + + std::uint32_t resNum = 0; + p_buffer = SimpleSerialization::SimpleReadBuffer(p_buffer, resNum); + + bool withMeta = false; + p_buffer = SimpleSerialization::SimpleReadBuffer(p_buffer, withMeta); + + indexRes.m_results.Init(nullptr, resNum, withMeta); + for (auto& res : indexRes.m_results) + { + p_buffer = SimpleSerialization::SimpleReadBuffer(p_buffer, res.VID); + p_buffer = SimpleSerialization::SimpleReadBuffer(p_buffer, res.Dist); + } + + if (withMeta) + { + for (int i = 0; i < indexRes.m_results.GetResultNum(); ++i) + { + ByteArray meta; + p_buffer = SimpleSerialization::SimpleReadBuffer(p_buffer, meta); + indexRes.m_results.SetMetadata(i, std::move(meta)); + } + } + } + + return p_buffer; +} diff --git a/core/src/index/thirdparty/SPTAG/AnnService/src/Socket/Server.cpp b/core/src/index/thirdparty/SPTAG/AnnService/src/Socket/Server.cpp new file mode 100644 index 0000000000..86d60040bf --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/AnnService/src/Socket/Server.cpp @@ -0,0 +1,136 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "inc/Socket/Server.h" + +#include + +using namespace SPTAG::Socket; + +Server::Server(const std::string& p_address, + const std::string& p_port, + const PacketHandlerMapPtr& p_handlerMap, + std::size_t p_threadNum) + : m_requestHandlerMap(p_handlerMap), + m_connectionManager(new ConnectionManager), + m_acceptor(m_ioContext) +{ + boost::asio::ip::tcp::resolver resolver(m_ioContext); + + boost::system::error_code errCode; + auto endPoints = resolver.resolve(p_address, p_port, errCode); + if (errCode) + { + fprintf(stderr, + "Failed to resolve %s %s, error: %s", + p_address.c_str(), + p_port.c_str(), + errCode.message().c_str()); + + throw std::runtime_error("Failed to resolve address."); + } + + boost::asio::ip::tcp::endpoint endpoint = *(endPoints.begin()); + m_acceptor.open(endpoint.protocol()); + m_acceptor.set_option(boost::asio::ip::tcp::acceptor::reuse_address(false)); + + m_acceptor.bind(endpoint, errCode); + if (errCode) + { + fprintf(stderr, + "Failed to bind %s %s, error: %s", + p_address.c_str(), + p_port.c_str(), + errCode.message().c_str()); + + throw std::runtime_error("Failed to bind port."); + } + + m_acceptor.listen(boost::asio::socket_base::max_listen_connections, errCode); + if (errCode) + { + fprintf(stderr, + "Failed to listen %s %s, error: %s", + p_address.c_str(), + p_port.c_str(), + errCode.message().c_str()); + + throw std::runtime_error("Failed to listen port."); + } + + StartAccept(); + + m_threadPool.reserve(p_threadNum); + for (std::size_t i = 0; i < p_threadNum; ++i) + { + m_threadPool.emplace_back(std::move(std::thread([this]() { StartListen(); }))); + } +} + + +Server::~Server() +{ + m_acceptor.close(); + m_connectionManager->StopAll(); + while (!m_ioContext.stopped()) + { + m_ioContext.stop(); + } + + for (auto& t : m_threadPool) + { + t.join(); + } +} + + +void +Server::SetEventOnConnectionClose(std::function p_event) +{ + m_connectionManager->SetEventOnRemoving(std::move(p_event)); +} + + +void +Server::StartAccept() +{ + m_acceptor.async_accept([this](boost::system::error_code p_ec, + boost::asio::ip::tcp::socket p_socket) + { + if (!m_acceptor.is_open()) + { + return; + } + + if (!p_ec) + { + m_connectionManager->AddConnection(std::move(p_socket), + m_requestHandlerMap, + 0); + } + + StartAccept(); + }); +} + + +void +Server::StartListen() +{ + m_ioContext.run(); +} + + +void +Server::SendPacket(ConnectionID p_connection, Packet p_packet, std::function p_callback) +{ + auto connection = m_connectionManager->GetConnection(p_connection); + if (nullptr != connection) + { + connection->AsyncSend(std::move(p_packet), std::move(p_callback)); + } + else if (bool(p_callback)) + { + p_callback(false); + } +} diff --git a/core/src/index/thirdparty/SPTAG/CMakeLists.txt b/core/src/index/thirdparty/SPTAG/CMakeLists.txt new file mode 100644 index 0000000000..44544bf7e9 --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/CMakeLists.txt @@ -0,0 +1,91 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +cmake_minimum_required (VERSION 3.12) + +project (SPTAGLib) + +function(CXX_COMPILER_DUMPVERSION _OUTPUT_VERSION) + exec_program(${CMAKE_CXX_COMPILER} + ARGS ${CMAKE_CXX_COMPILER_ARG1} -dumpversion + OUTPUT_VARIABLE COMPILER_VERSION + ) + + set(${_OUTPUT_VERSION} ${COMPILER_VERSION} PARENT_SCOPE) +endfunction() + +if(NOT WIN32) + CXX_COMPILER_DUMPVERSION(CXX_COMPILER_VERSION) +endif() + +if(${CMAKE_CXX_COMPILER_ID} STREQUAL "GNU") + # require at least gcc 5.0 + if (CXX_COMPILER_VERSION VERSION_LESS 5.0) + message(FATAL_ERROR "GCC version must be at least 5.0!") + endif() + set (CMAKE_CXX_FLAGS_RELEASE "-Wall -Wunreachable-code -Wno-reorder -Wno-sign-compare -Wno-unknown-pragmas -Wcast-align -lm -lrt -DNDEBUG -std=c++14 -fopenmp -march=native") + set (CMAKE_CXX_FLAGS_DEBUG "-Wall -Wunreachable-code -Wno-reorder -Wno-sign-compare -Wno-unknown-pragmas -Wcast-align -ggdb -lm -lrt -DNDEBUG -std=c++14 -fopenmp -march=native") +elseif(WIN32) + if(NOT MSVC14) + message(FATAL_ERROR "On Windows, only MSVC version 14 are supported!") + endif() +else () + message(FATAL_ERROR "Unrecognized compiler (use GCC or MSVC)!") +endif() + +if (NOT CMAKE_BUILD_TYPE) + set (CMAKE_BUILD_TYPE Release CACHE STRING "Build types: Release Debug" FORCE) +endif() +message (STATUS "Build type: ${CMAKE_BUILD_TYPE}") + +if (${CMAKE_SIZEOF_VOID_P} EQUAL "8") + set (PROJECTNAME_ARCHITECTURE "x64") +else () + set (PROJECTNAME_ARCHITECTURE "x86") +endif () +message (STATUS "Platform type: ${PROJECTNAME_ARCHITECTURE}") + +set(Boost_USE_MULTITHREADED ON) + +if (WIN32) + set(Boost_USE_STATIC_LIBS ON) + + set(CMAKE_CONFIGURATION_TYPES ${CMAKE_BUILD_TYPE}) + set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}) + set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}) + set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}) + + set (LIBRARY_OUTPUT_PATH ${CMAKE_BINARY_DIR}/${CMAKE_CFG_INTDIR}) + set (EXECUTABLE_OUTPUT_PATH ${CMAKE_BINARY_DIR}/${CMAKE_CFG_INTDIR}) +else() + set (LIBRARY_OUTPUT_PATH "${PROJECT_SOURCE_DIR}/${CMAKE_BUILD_TYPE}/") + set (EXECUTABLE_OUTPUT_PATH "${PROJECT_SOURCE_DIR}/${CMAKE_BUILD_TYPE}/") +endif() + +set (CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}) + +find_package(OpenMP) +if (OpenMP_FOUND) + set (CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}") + set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}") + set (CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} ${OpenMP_EXE_LINKER_FLAGS}") + message (STATUS "Found openmp.") +else() + message (FATAL_ERROR "Could no find openmp!") +endif() + +find_package(Boost 1.67 COMPONENTS system thread serialization wserialization regex) +if (Boost_FOUND) + include_directories (${Boost_INCLUDE_DIR}) + link_directories (${Boost_LIBRARY_DIR} "/usr/lib") + message (STATUS "Found Boost.") + message (STATUS "Include Path: ${Boost_INCLUDE_DIRS}") + message (STATUS "Library Path: ${Boost_LIBRARY_DIRS}") + message (STATUS "Library: ${Boost_LIBRARIES}") +else() + message (FATAL_ERROR "Could not find Boost 1.67!") +endif() + +add_subdirectory (AnnService) +add_subdirectory (Wrappers) +add_subdirectory (Test) diff --git a/core/src/index/thirdparty/SPTAG/Dockerfile b/core/src/index/thirdparty/SPTAG/Dockerfile new file mode 100644 index 0000000000..59c8c70166 --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/Dockerfile @@ -0,0 +1,35 @@ +FROM ubuntu:18.04 + +WORKDIR /app +COPY CMakeLists.txt ./ +COPY AnnService ./AnnService/ +COPY Test ./Test/ +COPY Wrappers ./Wrappers/ + +SHELL ["/bin/bash", "-o", "pipefail", "-c"] + +RUN apt-get update && apt-get install -y --no-install-recommends wget build-essential \ + # remove the following if you don't want to build the wrappers + openjdk-8-jdk python3-pip swig && \ + apt-get remove --purge -y && \ + rm -rf /var/lib/apt/lists/* + +# cmake >= 3.12 is required +RUN wget "https://github.com/Kitware/CMake/releases/download/v3.14.4/cmake-3.14.4-Linux-x86_64.tar.gz" -q -O - \ + | tar -xz --strip-components=1 -C /usr/local + +# specific version of boost +RUN wget "https://dl.bintray.com/boostorg/release/1.67.0/source/boost_1_67_0.tar.gz" -q -O - \ + | tar -xz && \ + cd boost_1_67_0 && \ + ./bootstrap.sh && \ + ./b2 install && \ + # update ld cache so it finds boost in /usr/local/lib + ldconfig && \ + cd .. && rm -rf boost_1_67_0 + +# build +RUN mkdir build && cd build && cmake .. && make && cd .. + +# so python can find the SPTAG module +ENV PYTHONPATH=/app/Release \ No newline at end of file diff --git a/core/src/index/thirdparty/SPTAG/LICENSE b/core/src/index/thirdparty/SPTAG/LICENSE new file mode 100644 index 0000000000..d1ca00f20a --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/LICENSE @@ -0,0 +1,21 @@ + MIT License + + Copyright (c) Microsoft Corporation. All rights reserved. + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all + copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE \ No newline at end of file diff --git a/core/src/index/thirdparty/SPTAG/README.md b/core/src/index/thirdparty/SPTAG/README.md new file mode 100644 index 0000000000..ae4f0aab9b --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/README.md @@ -0,0 +1,144 @@ +# SPTAG: A library for fast approximate nearest neighbor search + +[![MIT licensed](https://img.shields.io/badge/license-MIT-yellow.svg)](https://github.com/Microsoft/SPTAG/blob/master/LICENSE) +[![Build status](https://sysdnn.visualstudio.com/SPTAG/_apis/build/status/SPTAG-GITHUB)](https://sysdnn.visualstudio.com/SPTAG/_build/latest?definitionId=2) + +## **SPTAG** + SPTAG (Space Partition Tree And Graph) is a library for large scale vector approximate nearest neighbor search scenario released by [Microsoft Research (MSR)](https://www.msra.cn/) and [Microsoft Bing](http://bing.com). + +

+ architecture +

+ + + +## **Introduction** + +This library assumes that the samples are represented as vectors and that the vectors can be compared by L2 distances or cosine distances. +Vectors returned for a query vector are the vectors that have smallest L2 distance or cosine distances with the query vector. + +SPTAG provides two methods: kd-tree and relative neighborhood graph (SPTAG-KDT) +and balanced k-means tree and relative neighborhood graph (SPTAG-BKT). +SPTAG-KDT is advantageous in index building cost, and SPTAG-BKT is advantageous in search accuracy in very high-dimensional data. + + + +## **How it works** + +SPTAG is inspired by the NGS approach [[WangL12](#References)]. It contains two basic modules: index builder and searcher. +The RNG is built on the k-nearest neighborhood graph [[WangWZTG12](#References), [WangWJLZZH14](#References)] +for boosting the connectivity. Balanced k-means trees are used to replace kd-trees to avoid the inaccurate distance bound estimation in kd-trees for very high-dimensional vectors. +The search begins with the search in the space partition trees for +finding several seeds to start the search in the RNG. +The searches in the trees and the graph are iteratively conducted. + + ## **Highlights** + * Fresh update: Support online vector deletion and insertion + * Distributed serving: Search over multiple machines + + ## **Build** + +### **Requirements** + +* swig >= 3.0 +* cmake >= 3.12.0 +* boost >= 1.67.0 + +### **Install** + +> For Linux: +```bash +mkdir build +cd build && cmake .. && make +``` +It will generate a Release folder in the code directory which contains all the build targets. + +> For Windows: +```bash +mkdir build +cd build && cmake -A x64 .. +``` +It will generate a SPTAGLib.sln in the build directory. +Compiling the ALL_BUILD project in the Visual Studio (at least 2015) will generate a Release directory which contains all the build targets. + +> Using Docker: +```bash +docker build -t sptag . +``` +Will build a docker container with binaries in `/app/Release/`. + +### **Verify** + +Run the test (or Test.exe) in the Release folder to verify all the tests have passed. + +### **Usage** + +The detailed usage can be found in [Get started](docs/GettingStart.md). +The detailed parameters tunning can be found in [Parameters](docs/Parameters.md). + +## **References** +Please cite SPTAG in your publications if it helps your research: +``` +@manual{ChenW18, + author = {Qi Chen and + Haidong Wang and + Mingqin Li and + Gang Ren and + Scarlett Li and + Jeffery Zhu and + Jason Li and + Chuanjie Liu and + Lintao Zhang and + Jingdong Wang}, + title = {SPTAG: A library for fast approximate nearest neighbor search}, + url = {https://github.com/Microsoft/SPTAG}, + year = {2018} +} + +@inproceedings{WangL12, + author = {Jingdong Wang and + Shipeng Li}, + title = {Query-driven iterated neighborhood graph search for large scale indexing}, + booktitle = {ACM Multimedia 2012}, + pages = {179--188}, + year = {2012} +} + +@inproceedings{WangWZTGL12, + author = {Jing Wang and + Jingdong Wang and + Gang Zeng and + Zhuowen Tu and + Rui Gan and + Shipeng Li}, + title = {Scalable k-NN graph construction for visual descriptors}, + booktitle = {CVPR 2012}, + pages = {1106--1113}, + year = {2012} +} + +@article{WangWJLZZH14, + author = {Jingdong Wang and + Naiyan Wang and + You Jia and + Jian Li and + Gang Zeng and + Hongbin Zha and + Xian{-}Sheng Hua}, + title = {Trinary-Projection Trees for Approximate Nearest Neighbor Search}, + journal = {{IEEE} Trans. Pattern Anal. Mach. Intell.}, + volume = {36}, + number = {2}, + pages = {388--403}, + year = {2014 +} +``` + +## **Contribute** + +This project welcomes contributions and suggestions from all the users. + +We use [GitHub issues](https://github.com/Microsoft/SPTAG/issues) for tracking suggestions and bugs. + +## **License** +The entire codebase is under [MIT license](https://github.com/Microsoft/SPTAG/blob/master/LICENSE) diff --git a/core/src/index/thirdparty/SPTAG/SPTAG.sdf b/core/src/index/thirdparty/SPTAG/SPTAG.sdf new file mode 100644 index 0000000000..254a44ece6 Binary files /dev/null and b/core/src/index/thirdparty/SPTAG/SPTAG.sdf differ diff --git a/core/src/index/thirdparty/SPTAG/SPTAG.sln b/core/src/index/thirdparty/SPTAG/SPTAG.sln new file mode 100644 index 0000000000..5fdfd0297c --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/SPTAG.sln @@ -0,0 +1,211 @@ +Microsoft Visual Studio Solution File, Format Version 12.00 +# Visual Studio 14 +VisualStudioVersion = 14.0.25420.1 +MinimumVisualStudioVersion = 10.0.40219.1 +Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "CoreLibrary", "AnnService\CoreLibrary.vcxproj", "{C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9}" +EndProject +Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "Server", "AnnService\Server.vcxproj", "{E28B1222-8BEA-4A92-8FE0-088EBDAA7FE0}" + ProjectSection(ProjectDependencies) = postProject + {F9A72303-6381-4C80-86FF-606A2F6F7B96} = {F9A72303-6381-4C80-86FF-606A2F6F7B96} + {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9} = {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9} + EndProjectSection +EndProject +Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "PythonCore", "Wrappers\PythonCore.vcxproj", "{AF31947C-0495-42FE-A1AD-8F0DA2A679C7}" + ProjectSection(ProjectDependencies) = postProject + {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9} = {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9} + EndProjectSection +EndProject +Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "SocketLib", "AnnService\SocketLib.vcxproj", "{F9A72303-6381-4C80-86FF-606A2F6F7B96}" + ProjectSection(ProjectDependencies) = postProject + {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9} = {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9} + EndProjectSection +EndProject +Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "Client", "AnnService\Client.vcxproj", "{A89D70C3-C53B-42DE-A5CE-9A472540F5CB}" + ProjectSection(ProjectDependencies) = postProject + {F9A72303-6381-4C80-86FF-606A2F6F7B96} = {F9A72303-6381-4C80-86FF-606A2F6F7B96} + {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9} = {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9} + EndProjectSection +EndProject +Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "Aggregator", "AnnService\Aggregator.vcxproj", "{D7F09A63-BDCA-4F6C-A864-8551D1FE447A}" + ProjectSection(ProjectDependencies) = postProject + {F9A72303-6381-4C80-86FF-606A2F6F7B96} = {F9A72303-6381-4C80-86FF-606A2F6F7B96} + {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9} = {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9} + EndProjectSection +EndProject +Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "PythonClient", "Wrappers\PythonClient.vcxproj", "{9B014CF6-E3FB-4BD4-B3B1-D26297BB31AA}" + ProjectSection(ProjectDependencies) = postProject + {F9A72303-6381-4C80-86FF-606A2F6F7B96} = {F9A72303-6381-4C80-86FF-606A2F6F7B96} + {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9} = {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9} + EndProjectSection +EndProject +Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "IndexBuilder", "AnnService\IndexBuilder.vcxproj", "{F492F794-E78B-4B1F-A556-5E045B9163D5}" + ProjectSection(ProjectDependencies) = postProject + {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9} = {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9} + EndProjectSection +EndProject +Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "IndexSearcher", "AnnService\IndexSearcher.vcxproj", "{97615D3B-9FA0-469E-B229-95A91A5087E0}" + ProjectSection(ProjectDependencies) = postProject + {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9} = {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9} + EndProjectSection +EndProject +Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "Test", "Test\Test.vcxproj", "{29A25655-CCF2-47F8-8BC8-DFE1B5CF993C}" + ProjectSection(ProjectDependencies) = postProject + {F9A72303-6381-4C80-86FF-606A2F6F7B96} = {F9A72303-6381-4C80-86FF-606A2F6F7B96} + {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9} = {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9} + EndProjectSection +EndProject +Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "JavaCore", "Wrappers\JavaCore.vcxproj", "{93FEB26B-965E-4157-8BE5-052F5CA112BB}" + ProjectSection(ProjectDependencies) = postProject + {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9} = {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9} + EndProjectSection +EndProject +Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "JavaClient", "Wrappers\JavaClient.vcxproj", "{8866BF98-AA2E-450F-9F33-083E007CCA74}" + ProjectSection(ProjectDependencies) = postProject + {F9A72303-6381-4C80-86FF-606A2F6F7B96} = {F9A72303-6381-4C80-86FF-606A2F6F7B96} + {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9} = {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9} + EndProjectSection +EndProject +Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "CsharpCore", "Wrappers\CsharpCore.vcxproj", "{1896C009-AD46-4A70-B83C-4652A7F37503}" + ProjectSection(ProjectDependencies) = postProject + {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9} = {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9} + EndProjectSection +EndProject +Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "CsharpClient", "Wrappers\CsharpClient.vcxproj", "{363BA3BB-75C4-4CC7-AECB-28C7534B3710}" + ProjectSection(ProjectDependencies) = postProject + {F9A72303-6381-4C80-86FF-606A2F6F7B96} = {F9A72303-6381-4C80-86FF-606A2F6F7B96} + {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9} = {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9} + EndProjectSection +EndProject +Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "CLRCore", "Wrappers\CLRCore.vcxproj", "{38ACBA6C-2E50-44D4-9A6D-DC735B56E38F}" + ProjectSection(ProjectDependencies) = postProject + {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9} = {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9} + EndProjectSection +EndProject +Global + GlobalSection(SolutionConfigurationPlatforms) = preSolution + Debug|x64 = Debug|x64 + Debug|x86 = Debug|x86 + Release|x64 = Release|x64 + Release|x86 = Release|x86 + EndGlobalSection + GlobalSection(ProjectConfigurationPlatforms) = postSolution + {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9}.Debug|x64.ActiveCfg = Debug|x64 + {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9}.Debug|x64.Build.0 = Debug|x64 + {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9}.Debug|x86.ActiveCfg = Debug|x64 + {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9}.Debug|x86.Build.0 = Debug|x64 + {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9}.Release|x64.ActiveCfg = Release|x64 + {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9}.Release|x64.Build.0 = Release|x64 + {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9}.Release|x86.ActiveCfg = Debug|x64 + {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9}.Release|x86.Build.0 = Debug|x64 + {E28B1222-8BEA-4A92-8FE0-088EBDAA7FE0}.Debug|x64.ActiveCfg = Debug|x64 + {E28B1222-8BEA-4A92-8FE0-088EBDAA7FE0}.Debug|x64.Build.0 = Debug|x64 + {E28B1222-8BEA-4A92-8FE0-088EBDAA7FE0}.Debug|x86.ActiveCfg = Debug|x64 + {E28B1222-8BEA-4A92-8FE0-088EBDAA7FE0}.Debug|x86.Build.0 = Debug|x64 + {E28B1222-8BEA-4A92-8FE0-088EBDAA7FE0}.Release|x64.ActiveCfg = Release|x64 + {E28B1222-8BEA-4A92-8FE0-088EBDAA7FE0}.Release|x64.Build.0 = Release|x64 + {E28B1222-8BEA-4A92-8FE0-088EBDAA7FE0}.Release|x86.ActiveCfg = Debug|x64 + {E28B1222-8BEA-4A92-8FE0-088EBDAA7FE0}.Release|x86.Build.0 = Debug|x64 + {AF31947C-0495-42FE-A1AD-8F0DA2A679C7}.Debug|x64.ActiveCfg = Debug|x64 + {AF31947C-0495-42FE-A1AD-8F0DA2A679C7}.Debug|x64.Build.0 = Debug|x64 + {AF31947C-0495-42FE-A1AD-8F0DA2A679C7}.Debug|x86.ActiveCfg = Debug|x64 + {AF31947C-0495-42FE-A1AD-8F0DA2A679C7}.Debug|x86.Build.0 = Debug|x64 + {AF31947C-0495-42FE-A1AD-8F0DA2A679C7}.Release|x64.ActiveCfg = Release|x64 + {AF31947C-0495-42FE-A1AD-8F0DA2A679C7}.Release|x64.Build.0 = Release|x64 + {AF31947C-0495-42FE-A1AD-8F0DA2A679C7}.Release|x86.ActiveCfg = Debug|x64 + {AF31947C-0495-42FE-A1AD-8F0DA2A679C7}.Release|x86.Build.0 = Debug|x64 + {F9A72303-6381-4C80-86FF-606A2F6F7B96}.Debug|x64.ActiveCfg = Debug|x64 + {F9A72303-6381-4C80-86FF-606A2F6F7B96}.Debug|x64.Build.0 = Debug|x64 + {F9A72303-6381-4C80-86FF-606A2F6F7B96}.Debug|x86.ActiveCfg = Debug|x64 + {F9A72303-6381-4C80-86FF-606A2F6F7B96}.Debug|x86.Build.0 = Debug|x64 + {F9A72303-6381-4C80-86FF-606A2F6F7B96}.Release|x64.ActiveCfg = Release|x64 + {F9A72303-6381-4C80-86FF-606A2F6F7B96}.Release|x64.Build.0 = Release|x64 + {F9A72303-6381-4C80-86FF-606A2F6F7B96}.Release|x86.ActiveCfg = Debug|x64 + {F9A72303-6381-4C80-86FF-606A2F6F7B96}.Release|x86.Build.0 = Debug|x64 + {A89D70C3-C53B-42DE-A5CE-9A472540F5CB}.Debug|x64.ActiveCfg = Debug|x64 + {A89D70C3-C53B-42DE-A5CE-9A472540F5CB}.Debug|x64.Build.0 = Debug|x64 + {A89D70C3-C53B-42DE-A5CE-9A472540F5CB}.Debug|x86.ActiveCfg = Debug|x64 + {A89D70C3-C53B-42DE-A5CE-9A472540F5CB}.Debug|x86.Build.0 = Debug|x64 + {A89D70C3-C53B-42DE-A5CE-9A472540F5CB}.Release|x64.ActiveCfg = Release|x64 + {A89D70C3-C53B-42DE-A5CE-9A472540F5CB}.Release|x64.Build.0 = Release|x64 + {A89D70C3-C53B-42DE-A5CE-9A472540F5CB}.Release|x86.ActiveCfg = Debug|x64 + {A89D70C3-C53B-42DE-A5CE-9A472540F5CB}.Release|x86.Build.0 = Debug|x64 + {D7F09A63-BDCA-4F6C-A864-8551D1FE447A}.Debug|x64.ActiveCfg = Debug|x64 + {D7F09A63-BDCA-4F6C-A864-8551D1FE447A}.Debug|x64.Build.0 = Debug|x64 + {D7F09A63-BDCA-4F6C-A864-8551D1FE447A}.Debug|x86.ActiveCfg = Debug|x64 + {D7F09A63-BDCA-4F6C-A864-8551D1FE447A}.Debug|x86.Build.0 = Debug|x64 + {D7F09A63-BDCA-4F6C-A864-8551D1FE447A}.Release|x64.ActiveCfg = Release|x64 + {D7F09A63-BDCA-4F6C-A864-8551D1FE447A}.Release|x64.Build.0 = Release|x64 + {D7F09A63-BDCA-4F6C-A864-8551D1FE447A}.Release|x86.ActiveCfg = Debug|x64 + {D7F09A63-BDCA-4F6C-A864-8551D1FE447A}.Release|x86.Build.0 = Debug|x64 + {9B014CF6-E3FB-4BD4-B3B1-D26297BB31AA}.Debug|x64.ActiveCfg = Debug|x64 + {9B014CF6-E3FB-4BD4-B3B1-D26297BB31AA}.Debug|x64.Build.0 = Debug|x64 + {9B014CF6-E3FB-4BD4-B3B1-D26297BB31AA}.Debug|x86.ActiveCfg = Debug|Win32 + {9B014CF6-E3FB-4BD4-B3B1-D26297BB31AA}.Debug|x86.Build.0 = Debug|Win32 + {9B014CF6-E3FB-4BD4-B3B1-D26297BB31AA}.Release|x64.ActiveCfg = Release|x64 + {9B014CF6-E3FB-4BD4-B3B1-D26297BB31AA}.Release|x64.Build.0 = Release|x64 + {9B014CF6-E3FB-4BD4-B3B1-D26297BB31AA}.Release|x86.ActiveCfg = Release|Win32 + {9B014CF6-E3FB-4BD4-B3B1-D26297BB31AA}.Release|x86.Build.0 = Release|Win32 + {F492F794-E78B-4B1F-A556-5E045B9163D5}.Debug|x64.ActiveCfg = Debug|x64 + {F492F794-E78B-4B1F-A556-5E045B9163D5}.Debug|x64.Build.0 = Debug|x64 + {F492F794-E78B-4B1F-A556-5E045B9163D5}.Debug|x86.ActiveCfg = Debug|Win32 + {F492F794-E78B-4B1F-A556-5E045B9163D5}.Debug|x86.Build.0 = Debug|Win32 + {F492F794-E78B-4B1F-A556-5E045B9163D5}.Release|x64.ActiveCfg = Release|x64 + {F492F794-E78B-4B1F-A556-5E045B9163D5}.Release|x64.Build.0 = Release|x64 + {F492F794-E78B-4B1F-A556-5E045B9163D5}.Release|x86.ActiveCfg = Release|Win32 + {F492F794-E78B-4B1F-A556-5E045B9163D5}.Release|x86.Build.0 = Release|Win32 + {97615D3B-9FA0-469E-B229-95A91A5087E0}.Debug|x64.ActiveCfg = Debug|x64 + {97615D3B-9FA0-469E-B229-95A91A5087E0}.Debug|x64.Build.0 = Debug|x64 + {97615D3B-9FA0-469E-B229-95A91A5087E0}.Debug|x86.ActiveCfg = Debug|Win32 + {97615D3B-9FA0-469E-B229-95A91A5087E0}.Debug|x86.Build.0 = Debug|Win32 + {97615D3B-9FA0-469E-B229-95A91A5087E0}.Release|x64.ActiveCfg = Release|x64 + {97615D3B-9FA0-469E-B229-95A91A5087E0}.Release|x64.Build.0 = Release|x64 + {97615D3B-9FA0-469E-B229-95A91A5087E0}.Release|x86.ActiveCfg = Release|Win32 + {97615D3B-9FA0-469E-B229-95A91A5087E0}.Release|x86.Build.0 = Release|Win32 + {29A25655-CCF2-47F8-8BC8-DFE1B5CF993C}.Debug|x64.ActiveCfg = Debug|x64 + {29A25655-CCF2-47F8-8BC8-DFE1B5CF993C}.Debug|x64.Build.0 = Debug|x64 + {29A25655-CCF2-47F8-8BC8-DFE1B5CF993C}.Debug|x86.ActiveCfg = Debug|Win32 + {29A25655-CCF2-47F8-8BC8-DFE1B5CF993C}.Debug|x86.Build.0 = Debug|Win32 + {29A25655-CCF2-47F8-8BC8-DFE1B5CF993C}.Release|x64.ActiveCfg = Release|x64 + {29A25655-CCF2-47F8-8BC8-DFE1B5CF993C}.Release|x64.Build.0 = Release|x64 + {29A25655-CCF2-47F8-8BC8-DFE1B5CF993C}.Release|x86.ActiveCfg = Release|Win32 + {29A25655-CCF2-47F8-8BC8-DFE1B5CF993C}.Release|x86.Build.0 = Release|Win32 + {93FEB26B-965E-4157-8BE5-052F5CA112BB}.Debug|x64.ActiveCfg = Debug|x64 + {93FEB26B-965E-4157-8BE5-052F5CA112BB}.Debug|x86.ActiveCfg = Debug|Win32 + {93FEB26B-965E-4157-8BE5-052F5CA112BB}.Release|x64.ActiveCfg = Release|x64 + {93FEB26B-965E-4157-8BE5-052F5CA112BB}.Release|x86.ActiveCfg = Release|Win32 + {8866BF98-AA2E-450F-9F33-083E007CCA74}.Debug|x64.ActiveCfg = Debug|x64 + {8866BF98-AA2E-450F-9F33-083E007CCA74}.Debug|x86.ActiveCfg = Debug|Win32 + {8866BF98-AA2E-450F-9F33-083E007CCA74}.Release|x64.ActiveCfg = Release|x64 + {8866BF98-AA2E-450F-9F33-083E007CCA74}.Release|x86.ActiveCfg = Release|Win32 + {1896C009-AD46-4A70-B83C-4652A7F37503}.Debug|x64.ActiveCfg = Debug|x64 + {1896C009-AD46-4A70-B83C-4652A7F37503}.Debug|x64.Build.0 = Debug|x64 + {1896C009-AD46-4A70-B83C-4652A7F37503}.Debug|x86.ActiveCfg = Debug|Win32 + {1896C009-AD46-4A70-B83C-4652A7F37503}.Debug|x86.Build.0 = Debug|Win32 + {1896C009-AD46-4A70-B83C-4652A7F37503}.Release|x64.ActiveCfg = Release|x64 + {1896C009-AD46-4A70-B83C-4652A7F37503}.Release|x64.Build.0 = Release|x64 + {1896C009-AD46-4A70-B83C-4652A7F37503}.Release|x86.ActiveCfg = Release|Win32 + {1896C009-AD46-4A70-B83C-4652A7F37503}.Release|x86.Build.0 = Release|Win32 + {363BA3BB-75C4-4CC7-AECB-28C7534B3710}.Debug|x64.ActiveCfg = Debug|x64 + {363BA3BB-75C4-4CC7-AECB-28C7534B3710}.Debug|x64.Build.0 = Debug|x64 + {363BA3BB-75C4-4CC7-AECB-28C7534B3710}.Debug|x86.ActiveCfg = Debug|Win32 + {363BA3BB-75C4-4CC7-AECB-28C7534B3710}.Debug|x86.Build.0 = Debug|Win32 + {363BA3BB-75C4-4CC7-AECB-28C7534B3710}.Release|x64.ActiveCfg = Release|x64 + {363BA3BB-75C4-4CC7-AECB-28C7534B3710}.Release|x64.Build.0 = Release|x64 + {363BA3BB-75C4-4CC7-AECB-28C7534B3710}.Release|x86.ActiveCfg = Release|Win32 + {363BA3BB-75C4-4CC7-AECB-28C7534B3710}.Release|x86.Build.0 = Release|Win32 + {38ACBA6C-2E50-44D4-9A6D-DC735B56E38F}.Debug|x64.ActiveCfg = Debug|x64 + {38ACBA6C-2E50-44D4-9A6D-DC735B56E38F}.Debug|x64.Build.0 = Debug|x64 + {38ACBA6C-2E50-44D4-9A6D-DC735B56E38F}.Debug|x86.ActiveCfg = Debug|Win32 + {38ACBA6C-2E50-44D4-9A6D-DC735B56E38F}.Debug|x86.Build.0 = Debug|Win32 + {38ACBA6C-2E50-44D4-9A6D-DC735B56E38F}.Release|x64.ActiveCfg = Release|x64 + {38ACBA6C-2E50-44D4-9A6D-DC735B56E38F}.Release|x64.Build.0 = Release|x64 + {38ACBA6C-2E50-44D4-9A6D-DC735B56E38F}.Release|x86.ActiveCfg = Release|Win32 + {38ACBA6C-2E50-44D4-9A6D-DC735B56E38F}.Release|x86.Build.0 = Release|Win32 + EndGlobalSection + GlobalSection(SolutionProperties) = preSolution + HideSolutionNode = FALSE + EndGlobalSection + GlobalSection(ExtensibilityGlobals) = postSolution + SolutionGuid = {38BDFF12-6FEC-4B67-A7BD-436D9E2544FD} + EndGlobalSection +EndGlobal diff --git a/core/src/index/thirdparty/SPTAG/Test/CMakeLists.txt b/core/src/index/thirdparty/SPTAG/Test/CMakeLists.txt new file mode 100644 index 0000000000..39166b32a9 --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/Test/CMakeLists.txt @@ -0,0 +1,32 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +if(NOT WIN32) + ADD_DEFINITIONS(-DBOOST_TEST_DYN_LINK) + message (STATUS "BOOST_TEST_DYN_LINK") +endif() + +find_package(Boost 1.67 COMPONENTS system thread serialization wserialization regex filesystem unit_test_framework) +if (Boost_FOUND) + include_directories (${Boost_INCLUDE_DIR}) + link_directories (${Boost_LIBRARY_DIR}) + message (STATUS "Found Boost.") + message (STATUS "Include Path: ${Boost_INCLUDE_DIRS}") + message (STATUS "Library Path: ${Boost_LIBRARY_DIRS}") + message (STATUS "Library: ${Boost_LIBRARIES}") +else() + message (FATAL_ERROR "Could not find Boost 1.67!") +endif() + +include_directories(${PYTHON_INCLUDE_PATH} ${PROJECT_SOURCE_DIR}/AnnService ${PROJECT_SOURCE_DIR}/PythonWrapper ${PROJECT_SOURCE_DIR}/Test) + +file(GLOB TEST_HDR_FILES ${PROJECT_SOURCE_DIR}/Test/inc/Test.h) +file(GLOB TEST_SRC_FILES ${PROJECT_SOURCE_DIR}/Test/src/*.cpp) +add_executable (test ${TEST_SRC_FILES} ${TEST_HDR_FILES}) +target_link_libraries(test SPTAGLib ${Boost_LIBRARIES}) + +install(TARGETS test + RUNTIME DESTINATION bin + ARCHIVE DESTINATION lib + LIBRARY DESTINATION lib) + diff --git a/core/src/index/thirdparty/SPTAG/Test/Test.vcxproj b/core/src/index/thirdparty/SPTAG/Test/Test.vcxproj new file mode 100644 index 0000000000..c479ae5be1 --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/Test/Test.vcxproj @@ -0,0 +1,183 @@ + + + + + Debug + Win32 + + + Release + Win32 + + + Debug + x64 + + + Release + x64 + + + + 15.0 + {29A25655-CCF2-47F8-8BC8-DFE1B5CF993C} + Test + 8.1 + + + + + Application + true + v140 + MultiByte + + + Application + false + v140 + true + MultiByte + + + Application + true + v140 + MultiByte + + + Application + false + v140 + true + MultiByte + + + + + + + + + + + + + + + + + + + + + $(SolutionDir)obj\$(Platform)_$(Configuration)\$(ProjectName)\ + $(ProjectDir);$(SolutionDir)AnnService\;$(IncludePath) + $(OutAppDir) + $(OutLibDir);$(LibraryPath) + + + false + + + + CoreLibrary.lib;%(AdditionalDependencies) + + + + + Level3 + MaxSpeed + true + true + true + true + _MBCS;_SCL_SECURE_NO_WARNINGS;%(PreprocessorDefinitions) + + + true + true + + + Console + + + + + Level3 + Disabled + true + true + + + + + Level3 + Disabled + true + true + _MBCS;_SCL_SECURE_NO_WARNINGS;%(PreprocessorDefinitions) + Guard + ProgramDatabase + + + Console + /guard:cf %(AdditionalOptions) + + + + + Level3 + MaxSpeed + true + true + true + true + + + true + true + + + + + + + + + + + + + + + + + Designer + + + + + + + + + + + + + + + + This project references NuGet package(s) that are missing on this computer. Use NuGet Package Restore to download them. For more information, see http://go.microsoft.com/fwlink/?LinkID=322105. The missing file is {0}. + + + + + + + + + + + \ No newline at end of file diff --git a/core/src/index/thirdparty/SPTAG/Test/Test.vcxproj.filters b/core/src/index/thirdparty/SPTAG/Test/Test.vcxproj.filters new file mode 100644 index 0000000000..a814c3ec3f --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/Test/Test.vcxproj.filters @@ -0,0 +1,48 @@ + + + + + {4FC737F1-C7A5-4376-A066-2A32D752A2FF} + cpp;c;cc;cxx;def;odl;idl;hpj;bat;asm;asmx + + + {93995380-89BD-4b04-88EB-625FBE52EBFB} + h;hh;hpp;hxx;hm;inl;inc;xsd + + + {67DA6AB6-F800-4c08-8B7A-83BB121AAD01} + rc;ico;cur;bmp;dlg;rc2;rct;bin;rgs;gif;jpg;jpeg;jpe;resx;tiff;tif;png;wav;mfcribbon-ms + + + + + Source Files + + + Source Files + + + Source Files + + + Source Files + + + Source Files + + + Source Files + + + Source Files + + + + + Header Files + + + + + + \ No newline at end of file diff --git a/core/src/index/thirdparty/SPTAG/Test/Test.vcxproj.user b/core/src/index/thirdparty/SPTAG/Test/Test.vcxproj.user new file mode 100644 index 0000000000..10f0fcf2d9 --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/Test/Test.vcxproj.user @@ -0,0 +1,7 @@ + + + + $(OutLibDir) + WindowsLocalDebugger + + \ No newline at end of file diff --git a/core/src/index/thirdparty/SPTAG/Test/inc/Test.h b/core/src/index/thirdparty/SPTAG/Test/inc/Test.h new file mode 100644 index 0000000000..da6c096ba2 --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/Test/inc/Test.h @@ -0,0 +1,7 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include diff --git a/core/src/index/thirdparty/SPTAG/Test/packages.config b/core/src/index/thirdparty/SPTAG/Test/packages.config new file mode 100644 index 0000000000..651c754779 --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/Test/packages.config @@ -0,0 +1,5 @@ + + + + + \ No newline at end of file diff --git a/core/src/index/thirdparty/SPTAG/Test/src/AlgoTest.cpp b/core/src/index/thirdparty/SPTAG/Test/src/AlgoTest.cpp new file mode 100644 index 0000000000..a93cd38bed --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/Test/src/AlgoTest.cpp @@ -0,0 +1,159 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "inc/Test.h" +#include "inc/Helper/SimpleIniReader.h" +#include "inc/Core/VectorIndex.h" + +#include + +template +void Build(SPTAG::IndexAlgoType algo, std::string distCalcMethod, std::shared_ptr& vec, std::shared_ptr& meta, const std::string out) +{ + + std::shared_ptr vecIndex = SPTAG::VectorIndex::CreateInstance(algo, SPTAG::GetEnumValueType()); + BOOST_CHECK(nullptr != vecIndex); + + vecIndex->SetParameter("DistCalcMethod", distCalcMethod); + + BOOST_CHECK(SPTAG::ErrorCode::Success == vecIndex->BuildIndex(vec, meta)); + BOOST_CHECK(SPTAG::ErrorCode::Success == vecIndex->SaveIndex(out)); +} + +template +void BuildWithMetaMapping(SPTAG::IndexAlgoType algo, std::string distCalcMethod, std::shared_ptr& vec, std::shared_ptr& meta, const std::string out) +{ + + std::shared_ptr vecIndex = SPTAG::VectorIndex::CreateInstance(algo, SPTAG::GetEnumValueType()); + BOOST_CHECK(nullptr != vecIndex); + + vecIndex->SetParameter("DistCalcMethod", distCalcMethod); + + BOOST_CHECK(SPTAG::ErrorCode::Success == vecIndex->BuildIndex(vec, meta, true)); + BOOST_CHECK(SPTAG::ErrorCode::Success == vecIndex->SaveIndex(out)); +} + +template +void Search(const std::string folder, T* vec, SPTAG::SizeType n, int k, std::string* truthmeta) +{ + std::shared_ptr vecIndex; + BOOST_CHECK(SPTAG::ErrorCode::Success == SPTAG::VectorIndex::LoadIndex(folder, vecIndex)); + BOOST_CHECK(nullptr != vecIndex); + + for (SPTAG::SizeType i = 0; i < n; i++) + { + SPTAG::QueryResult res(vec, k, true); + vecIndex->SearchIndex(res); + std::unordered_set resmeta; + for (int j = 0; j < k; j++) + { + resmeta.insert(std::string((char*)res.GetMetadata(j).Data(), res.GetMetadata(j).Length())); + std::cout << res.GetResult(j)->Dist << "@(" << res.GetResult(j)->VID << "," << std::string((char*)res.GetMetadata(j).Data(), res.GetMetadata(j).Length()) << ") "; + } + std::cout << std::endl; + for (int j = 0; j < k; j++) + { + BOOST_CHECK(resmeta.find(truthmeta[i * k + j]) != resmeta.end()); + } + vec += vecIndex->GetFeatureDim(); + } + vecIndex.reset(); +} + +template +void Add(const std::string folder, std::shared_ptr& vec, std::shared_ptr& meta, const std::string out) +{ + std::shared_ptr vecIndex; + BOOST_CHECK(SPTAG::ErrorCode::Success == SPTAG::VectorIndex::LoadIndex(folder, vecIndex)); + BOOST_CHECK(nullptr != vecIndex); + + BOOST_CHECK(SPTAG::ErrorCode::Success == vecIndex->AddIndex(vec, meta)); + BOOST_CHECK(SPTAG::ErrorCode::Success == vecIndex->SaveIndex(out)); + vecIndex.reset(); +} + +template +void Delete(const std::string folder, T* vec, SPTAG::SizeType n, const std::string out) +{ + std::shared_ptr vecIndex; + BOOST_CHECK(SPTAG::ErrorCode::Success == SPTAG::VectorIndex::LoadIndex(folder, vecIndex)); + BOOST_CHECK(nullptr != vecIndex); + + BOOST_CHECK(SPTAG::ErrorCode::Success == vecIndex->DeleteIndex((const void*)vec, n)); + BOOST_CHECK(SPTAG::ErrorCode::Success == vecIndex->SaveIndex(out)); + vecIndex.reset(); +} + +template +void Test(SPTAG::IndexAlgoType algo, std::string distCalcMethod) +{ + SPTAG::SizeType n = 100, q = 3; + SPTAG::DimensionType m = 10; + int k = 3; + std::vector vec; + for (SPTAG::SizeType i = 0; i < n; i++) { + for (SPTAG::DimensionType j = 0; j < m; j++) { + vec.push_back((T)i); + } + } + + std::vector query; + for (SPTAG::SizeType i = 0; i < q; i++) { + for (SPTAG::DimensionType j = 0; j < m; j++) { + query.push_back((T)i*2); + } + } + + std::vector meta; + std::vector metaoffset; + for (SPTAG::SizeType i = 0; i < n; i++) { + metaoffset.push_back((std::uint64_t)meta.size()); + std::string a = std::to_string(i); + for (size_t j = 0; j < a.length(); j++) + meta.push_back(a[j]); + } + metaoffset.push_back((std::uint64_t)meta.size()); + + std::shared_ptr vecset(new SPTAG::BasicVectorSet( + SPTAG::ByteArray((std::uint8_t*)vec.data(), sizeof(T) * n * m, false), + SPTAG::GetEnumValueType(), m, n)); + + std::shared_ptr metaset(new SPTAG::MemMetadataSet( + SPTAG::ByteArray((std::uint8_t*)meta.data(), meta.size() * sizeof(char), false), + SPTAG::ByteArray((std::uint8_t*)metaoffset.data(), metaoffset.size() * sizeof(std::uint64_t), false), + n)); + + Build(algo, distCalcMethod, vecset, metaset, "testindices"); + std::string truthmeta1[] = { "0", "1", "2", "2", "1", "3", "4", "3", "5" }; + Search("testindices", query.data(), q, k, truthmeta1); + + Add("testindices", vecset, metaset, "testindices"); + std::string truthmeta2[] = { "0", "0", "1", "2", "2", "1", "4", "4", "3" }; + Search("testindices", query.data(), q, k, truthmeta2); + + Delete("testindices", query.data(), q, "testindices"); + std::string truthmeta3[] = { "1", "1", "3", "1", "3", "1", "3", "5", "3" }; + Search("testindices", query.data(), q, k, truthmeta3); + + BuildWithMetaMapping(algo, distCalcMethod, vecset, metaset, "testindices"); + std::string truthmeta4[] = { "0", "1", "2", "2", "1", "3", "4", "3", "5" }; + Search("testindices", query.data(), q, k, truthmeta4); + + Add("testindices", vecset, metaset, "testindices"); + std::string truthmeta5[] = { "0", "1", "2", "2", "1", "3", "4", "3", "5" }; + Search("testindices", query.data(), q, k, truthmeta5); +} + +BOOST_AUTO_TEST_SUITE (AlgoTest) + +BOOST_AUTO_TEST_CASE(KDTTest) +{ + Test(SPTAG::IndexAlgoType::KDT, "L2"); +} + +BOOST_AUTO_TEST_CASE(BKTTest) +{ + Test(SPTAG::IndexAlgoType::BKT, "L2"); +} + +BOOST_AUTO_TEST_SUITE_END() diff --git a/core/src/index/thirdparty/SPTAG/Test/src/Base64HelperTest.cpp b/core/src/index/thirdparty/SPTAG/Test/src/Base64HelperTest.cpp new file mode 100644 index 0000000000..2ead4753e5 --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/Test/src/Base64HelperTest.cpp @@ -0,0 +1,43 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "inc/Test.h" +#include "inc/Helper/Base64Encode.h" + +#include + +BOOST_AUTO_TEST_SUITE(Base64Test) + +BOOST_AUTO_TEST_CASE(Base64EncDec) +{ + using namespace SPTAG::Helper::Base64; + + const size_t bufferSize = 1 << 10; + std::unique_ptr rawBuffer(new uint8_t[bufferSize]); + std::unique_ptr encBuffer(new char[bufferSize]); + std::unique_ptr rawBuffer2(new uint8_t[bufferSize]); + + for (size_t inputSize = 1; inputSize < 128; ++inputSize) + { + for (size_t i = 0; i < inputSize; ++i) + { + rawBuffer[i] = static_cast(i); + } + + size_t encBufLen = CapacityForEncode(inputSize); + BOOST_CHECK(encBufLen < bufferSize); + + size_t encOutLen = 0; + BOOST_CHECK(Encode(rawBuffer.get(), inputSize, encBuffer.get(), encOutLen)); + BOOST_CHECK(encBufLen >= encOutLen); + + size_t decBufLen = CapacityForDecode(encOutLen); + BOOST_CHECK(decBufLen < bufferSize); + + size_t decOutLen = 0; + BOOST_CHECK(Decode(encBuffer.get(), encOutLen, rawBuffer.get(), decOutLen)); + BOOST_CHECK(decBufLen >= decOutLen); + } +} + +BOOST_AUTO_TEST_SUITE_END() \ No newline at end of file diff --git a/core/src/index/thirdparty/SPTAG/Test/src/CommonHelperTest.cpp b/core/src/index/thirdparty/SPTAG/Test/src/CommonHelperTest.cpp new file mode 100644 index 0000000000..17015642cc --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/Test/src/CommonHelperTest.cpp @@ -0,0 +1,93 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "inc/Test.h" +#include "inc/Helper/CommonHelper.h" + +#include + +BOOST_AUTO_TEST_SUITE(CommonHelperTest) + + +BOOST_AUTO_TEST_CASE(ToLowerInPlaceTest) +{ + auto runTestCase = [](std::string p_input, const std::string& p_expected) + { + SPTAG::Helper::StrUtils::ToLowerInPlace(p_input); + BOOST_CHECK(p_input == p_expected); + }; + + runTestCase("abc", "abc"); + runTestCase("ABC", "abc"); + runTestCase("abC", "abc"); + runTestCase("Upper-Case", "upper-case"); + runTestCase("123!-=aBc", "123!-=abc"); +} + + +BOOST_AUTO_TEST_CASE(SplitStringTest) +{ + std::string input("seg1 seg2 seg3 seg4"); + + const auto& segs = SPTAG::Helper::StrUtils::SplitString(input, " "); + BOOST_CHECK(segs.size() == 4); + BOOST_CHECK(segs[0] == "seg1"); + BOOST_CHECK(segs[1] == "seg2"); + BOOST_CHECK(segs[2] == "seg3"); + BOOST_CHECK(segs[3] == "seg4"); +} + + +BOOST_AUTO_TEST_CASE(FindTrimmedSegmentTest) +{ + using namespace SPTAG::Helper::StrUtils; + std::string input("\t Space End \r\n\t"); + + const auto& pos = FindTrimmedSegment(input.c_str(), + input.c_str() + input.size(), + [](char p_val)->bool + { + return std::isspace(p_val) > 0; + }); + + BOOST_CHECK(pos.first == input.c_str() + 2); + BOOST_CHECK(pos.second == input.c_str() + 13); +} + + +BOOST_AUTO_TEST_CASE(StartsWithTest) +{ + using namespace SPTAG::Helper::StrUtils; + + BOOST_CHECK(StartsWith("Abcd", "A")); + BOOST_CHECK(StartsWith("Abcd", "Ab")); + BOOST_CHECK(StartsWith("Abcd", "Abc")); + BOOST_CHECK(StartsWith("Abcd", "Abcd")); + + BOOST_CHECK(!StartsWith("Abcd", "a")); + BOOST_CHECK(!StartsWith("Abcd", "F")); + BOOST_CHECK(!StartsWith("Abcd", "AF")); + BOOST_CHECK(!StartsWith("Abcd", "AbF")); + BOOST_CHECK(!StartsWith("Abcd", "AbcF")); + BOOST_CHECK(!StartsWith("Abcd", "Abcde")); +} + + +BOOST_AUTO_TEST_CASE(StrEqualIgnoreCaseTest) +{ + using namespace SPTAG::Helper::StrUtils; + + BOOST_CHECK(StrEqualIgnoreCase("Abcd", "Abcd")); + BOOST_CHECK(StrEqualIgnoreCase("Abcd", "abcd")); + BOOST_CHECK(StrEqualIgnoreCase("Abcd", "abCD")); + BOOST_CHECK(StrEqualIgnoreCase("Abcd-123", "abcd-123")); + BOOST_CHECK(StrEqualIgnoreCase(" ZZZ", " zzz")); + + BOOST_CHECK(!StrEqualIgnoreCase("abcd", "abcd1")); + BOOST_CHECK(!StrEqualIgnoreCase("Abcd", " abcd")); + BOOST_CHECK(!StrEqualIgnoreCase("000", "OOO")); +} + + + +BOOST_AUTO_TEST_SUITE_END() \ No newline at end of file diff --git a/core/src/index/thirdparty/SPTAG/Test/src/DistanceTest.cpp b/core/src/index/thirdparty/SPTAG/Test/src/DistanceTest.cpp new file mode 100644 index 0000000000..97602a2a8d --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/Test/src/DistanceTest.cpp @@ -0,0 +1,58 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include "inc/Test.h" +#include "inc/Core/Common/DistanceUtils.h" + +template +static float ComputeCosineDistance(const T *pX, const T *pY, SPTAG::DimensionType length) { + float diff = 0; + const T* pEnd1 = pX + length; + while (pX < pEnd1) diff += (*pX++) * (*pY++); + return diff; +} + +template +static float ComputeL2Distance(const T *pX, const T *pY, SPTAG::DimensionType length) +{ + float diff = 0; + const T* pEnd1 = pX + length; + while (pX < pEnd1) { + float c1 = ((float)(*pX++) - (float)(*pY++)); diff += c1 * c1; + } + return diff; +} + +template +T random(int high = RAND_MAX, int low = 0) // Generates a random value. +{ + return (T)(low + float(high - low)*(std::rand()/static_cast(RAND_MAX + 1.0))); +} + +template +void test(int high) { + SPTAG::DimensionType dimension = random(256, 2); + T *X = new T[dimension], *Y = new T[dimension]; + BOOST_ASSERT(X != nullptr && Y != nullptr); + for (SPTAG::DimensionType i = 0; i < dimension; i++) { + X[i] = random(high, -high); + Y[i] = random(high, -high); + } + BOOST_CHECK_CLOSE_FRACTION(ComputeL2Distance(X, Y, dimension), SPTAG::COMMON::DistanceUtils::ComputeL2Distance(X, Y, dimension), 1e-5); + BOOST_CHECK_CLOSE_FRACTION(high*high - ComputeCosineDistance(X, Y, dimension), SPTAG::COMMON::DistanceUtils::ComputeCosineDistance(X, Y, dimension), 1e-5); + + delete[] X; + delete[] Y; +} + +BOOST_AUTO_TEST_SUITE(DistanceTest) + +BOOST_AUTO_TEST_CASE(TestDistanceComputation) +{ + test(1); + test(127); + test(32767); +} + +BOOST_AUTO_TEST_SUITE_END() \ No newline at end of file diff --git a/core/src/index/thirdparty/SPTAG/Test/src/IniReaderTest.cpp b/core/src/index/thirdparty/SPTAG/Test/src/IniReaderTest.cpp new file mode 100644 index 0000000000..c5dd0baaf5 --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/Test/src/IniReaderTest.cpp @@ -0,0 +1,40 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "inc/Test.h" +#include "inc/Helper/SimpleIniReader.h" + +#include + +BOOST_AUTO_TEST_SUITE(IniReaderTest) + +BOOST_AUTO_TEST_CASE(IniReaderLoadTest) +{ + std::ofstream tmpIni("temp.ini"); + tmpIni << "[Common]" << std::endl; + tmpIni << "; Comment " << std::endl; + tmpIni << "Param1=1" << std::endl; + tmpIni << "Param2=Exp=2" << std::endl; + + tmpIni.close(); + + SPTAG::Helper::IniReader reader; + BOOST_CHECK(SPTAG::ErrorCode::Success == reader.LoadIniFile("temp.ini")); + + BOOST_CHECK(reader.DoesSectionExist("Common")); + BOOST_CHECK(reader.DoesParameterExist("Common", "Param1")); + BOOST_CHECK(reader.DoesParameterExist("Common", "Param2")); + + BOOST_CHECK(!reader.DoesSectionExist("NotExist")); + BOOST_CHECK(!reader.DoesParameterExist("NotExist", "Param1")); + BOOST_CHECK(!reader.DoesParameterExist("Common", "ParamNotExist")); + + BOOST_CHECK(1 == reader.GetParameter("Common", "Param1", 0)); + BOOST_CHECK(0 == reader.GetParameter("Common", "ParamNotExist", 0)); + + BOOST_CHECK(std::string("Exp=2") == reader.GetParameter("Common", "Param2", std::string())); + BOOST_CHECK(std::string("1") == reader.GetParameter("Common", "Param1", std::string())); + BOOST_CHECK(std::string() == reader.GetParameter("Common", "ParamNotExist", std::string())); +} + +BOOST_AUTO_TEST_SUITE_END() \ No newline at end of file diff --git a/core/src/index/thirdparty/SPTAG/Test/src/StringConvertTest.cpp b/core/src/index/thirdparty/SPTAG/Test/src/StringConvertTest.cpp new file mode 100644 index 0000000000..fa457debe2 --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/Test/src/StringConvertTest.cpp @@ -0,0 +1,128 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "inc/Test.h" +#include "inc/Helper/StringConvert.h" + +namespace +{ + namespace Local + { + + template + void TestConvertSuccCase(ValueType p_val, const char* p_valStr) + { + using namespace SPTAG::Helper::Convert; + + std::string str = ConvertToString(p_val); + if (nullptr != p_valStr) + { + BOOST_CHECK(str == p_valStr); + } + + ValueType val; + BOOST_CHECK(ConvertStringTo(str.c_str(), val)); + BOOST_CHECK(val == p_val); + } + + } +} + +BOOST_AUTO_TEST_SUITE(StringConvertTest) + +BOOST_AUTO_TEST_CASE(ConvertInt8) +{ + Local::TestConvertSuccCase(static_cast(-1), "-1"); + Local::TestConvertSuccCase(static_cast(0), "0"); + Local::TestConvertSuccCase(static_cast(3), "3"); + Local::TestConvertSuccCase(static_cast(100), "100"); +} + +BOOST_AUTO_TEST_CASE(ConvertInt16) +{ + Local::TestConvertSuccCase(static_cast(-1), "-1"); + Local::TestConvertSuccCase(static_cast(0), "0"); + Local::TestConvertSuccCase(static_cast(3), "3"); + Local::TestConvertSuccCase(static_cast(100), "100"); +} + +BOOST_AUTO_TEST_CASE(ConvertInt32) +{ + Local::TestConvertSuccCase(static_cast(-1), "-1"); + Local::TestConvertSuccCase(static_cast(0), "0"); + Local::TestConvertSuccCase(static_cast(3), "3"); + Local::TestConvertSuccCase(static_cast(100), "100"); +} + +BOOST_AUTO_TEST_CASE(ConvertInt64) +{ + Local::TestConvertSuccCase(static_cast(-1), "-1"); + Local::TestConvertSuccCase(static_cast(0), "0"); + Local::TestConvertSuccCase(static_cast(3), "3"); + Local::TestConvertSuccCase(static_cast(100), "100"); +} + +BOOST_AUTO_TEST_CASE(ConvertUInt8) +{ + Local::TestConvertSuccCase(static_cast(0), "0"); + Local::TestConvertSuccCase(static_cast(3), "3"); + Local::TestConvertSuccCase(static_cast(100), "100"); +} + +BOOST_AUTO_TEST_CASE(ConvertUInt16) +{ + Local::TestConvertSuccCase(static_cast(0), "0"); + Local::TestConvertSuccCase(static_cast(3), "3"); + Local::TestConvertSuccCase(static_cast(100), "100"); +} + +BOOST_AUTO_TEST_CASE(ConvertUInt32) +{ + Local::TestConvertSuccCase(static_cast(0), "0"); + Local::TestConvertSuccCase(static_cast(3), "3"); + Local::TestConvertSuccCase(static_cast(100), "100"); +} + +BOOST_AUTO_TEST_CASE(ConvertUInt64) +{ + Local::TestConvertSuccCase(static_cast(0), "0"); + Local::TestConvertSuccCase(static_cast(3), "3"); + Local::TestConvertSuccCase(static_cast(100), "100"); +} + +BOOST_AUTO_TEST_CASE(ConvertFloat) +{ + Local::TestConvertSuccCase(static_cast(-1), nullptr); + Local::TestConvertSuccCase(static_cast(0), nullptr); + Local::TestConvertSuccCase(static_cast(3), nullptr); + Local::TestConvertSuccCase(static_cast(100), nullptr); +} + +BOOST_AUTO_TEST_CASE(ConvertDouble) +{ + Local::TestConvertSuccCase(static_cast(-1), nullptr); + Local::TestConvertSuccCase(static_cast(0), nullptr); + Local::TestConvertSuccCase(static_cast(3), nullptr); + Local::TestConvertSuccCase(static_cast(100), nullptr); +} + +BOOST_AUTO_TEST_CASE(ConvertIndexAlgoType) +{ + Local::TestConvertSuccCase(SPTAG::IndexAlgoType::BKT, "BKT"); + Local::TestConvertSuccCase(SPTAG::IndexAlgoType::KDT, "KDT"); +} + +BOOST_AUTO_TEST_CASE(ConvertVectorValueType) +{ + Local::TestConvertSuccCase(SPTAG::VectorValueType::Float, "Float"); + Local::TestConvertSuccCase(SPTAG::VectorValueType::Int8, "Int8"); + Local::TestConvertSuccCase(SPTAG::VectorValueType::Int16, "Int16"); +} + +BOOST_AUTO_TEST_CASE(ConvertDistCalcMethod) +{ + Local::TestConvertSuccCase(SPTAG::DistCalcMethod::Cosine, "Cosine"); + Local::TestConvertSuccCase(SPTAG::DistCalcMethod::L2, "L2"); +} + +BOOST_AUTO_TEST_SUITE_END() \ No newline at end of file diff --git a/core/src/index/thirdparty/SPTAG/Test/src/main.cpp b/core/src/index/thirdparty/SPTAG/Test/src/main.cpp new file mode 100644 index 0000000000..7bf61ea119 --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/Test/src/main.cpp @@ -0,0 +1,41 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#define BOOST_TEST_MAIN +#define BOOST_TEST_MODULE Main +#include "inc/Test.h" + +#include +#include + +using namespace boost::unit_test; + +class SPTAGVisitor : public test_tree_visitor +{ +public: + void visit(test_case const& test) + { + std::string prefix(2, '\t'); + std::cout << prefix << "Case: " << test.p_name << std::endl; + } + + bool test_suite_start(test_suite const& suite) + { + std::string prefix(1, '\t'); + std::cout << prefix << "Suite: " << suite.p_name << std::endl; + return true; + } +}; + +struct GlobalFixture +{ + GlobalFixture() + { + SPTAGVisitor visitor; + traverse_test_tree(framework::master_test_suite(), visitor, false); + } + +}; + +BOOST_TEST_GLOBAL_FIXTURE(GlobalFixture); + diff --git a/core/src/index/thirdparty/SPTAG/Wrappers/CLRCore.vcxproj b/core/src/index/thirdparty/SPTAG/Wrappers/CLRCore.vcxproj new file mode 100644 index 0000000000..efb4d0f259 --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/Wrappers/CLRCore.vcxproj @@ -0,0 +1,141 @@ + + + + + Debug + Win32 + + + Release + Win32 + + + Debug + x64 + + + Release + x64 + + + + 15.0 + {38ACBA6C-2E50-44D4-9A6D-DC735B56E38F} + v4.5.2 + ManagedCProj + CLRCore + 8.1 + + + + + DynamicLibrary + true + v140 + true + Unicode + + + DynamicLibrary + false + v140 + true + Unicode + + + DynamicLibrary + true + v140 + true + MultiByte + + + DynamicLibrary + false + v140 + true + MultiByte + + + + + + + + + + + + + + + + + + + + Microsoft.ANN.SPTAGManaged + .dll + $(SolutionDir)obj\$(Platform)_$(Configuration)\$(ProjectName)\ + $(ProjectDir);$(SolutionDir)AnnService\;$(IncludePath) + $(OutLibDir);$(LibraryPath) + $(OutAppDir) + + + + true + + + true + + + false + + + false + + + + Level3 + Disabled + _DEBUG;_SCL_SECURE_NO_WARNINGS;%(PreprocessorDefinitions) + NotUsing + true + + + CoreLibrary.lib;%(AdditionalDependencies) + + + + + Level3 + NDEBUG;_SCL_SECURE_NO_WARNINGS;%(PreprocessorDefinitions) + NotUsing + true + + + CoreLibrary.lib;%(AdditionalDependencies) + + + + + + + + + + + + + + + + + + {8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942} + + + + + + \ No newline at end of file diff --git a/core/src/index/thirdparty/SPTAG/Wrappers/CLRCore.vcxproj.filters b/core/src/index/thirdparty/SPTAG/Wrappers/CLRCore.vcxproj.filters new file mode 100644 index 0000000000..c0c35e9683 --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/Wrappers/CLRCore.vcxproj.filters @@ -0,0 +1,32 @@ + + + + + {4FC737F1-C7A5-4376-A066-2A32D752A2FF} + cpp;c;cc;cxx;def;odl;idl;hpj;bat;asm;asmx + + + {93995380-89BD-4b04-88EB-625FBE52EBFB} + h;hh;hpp;hxx;hm;inl;inc;xsd + + + {ba4289c4-f872-4dbc-a57f-7b415614afb3} + + + + + Header Files + + + Header Files + + + + + Source Files + + + Source Files + + + \ No newline at end of file diff --git a/core/src/index/thirdparty/SPTAG/Wrappers/CMakeLists.txt b/core/src/index/thirdparty/SPTAG/Wrappers/CMakeLists.txt new file mode 100644 index 0000000000..514367978e --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/Wrappers/CMakeLists.txt @@ -0,0 +1,171 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +find_package(Python2 COMPONENTS Development) +if (Python2_FOUND) + include_directories (${Python2_INCLUDE_DIRS}) + link_directories (${Python2_LIBRARY_DIRS}) + set (Python_INCLUDE_DIRS ${Python2_INCLUDE_DIRS}) + set (Python_LIBRARIES ${Python2_LIBRARIES}) + set (Python_FOUND true) +else() + find_package(Python3 COMPONENTS Development) + if (Python3_FOUND) + include_directories (${Python3_INCLUDE_DIRS}) + link_directories (${Python3_LIBRARY_DIRS}) + set (Python_INCLUDE_DIRS ${Python3_INCLUDE_DIRS}) + set (Python_LIBRARIES ${Python3_LIBRARIES}) + set (Python_FOUND true) + endif() +endif() + +if (Python_FOUND) + message (STATUS "Found Python.") + message (STATUS "Include Path: ${Python_INCLUDE_DIRS}") + message (STATUS "Library Path: ${Python_LIBRARIES}") + + if (WIN32) + set(PY_SUFFIX .pyd) + else() + set(PY_SUFFIX .so) + endif() + + execute_process(COMMAND swig -python -c++ -I${PROJECT_SOURCE_DIR}/Wrappers/inc -o ${PROJECT_SOURCE_DIR}/Wrappers/inc/CoreInterface_pwrap.cpp ${PROJECT_SOURCE_DIR}/Wrappers/inc/PythonCore.i) + execute_process(COMMAND swig -python -c++ -I${PROJECT_SOURCE_DIR}/Wrappers/inc -o ${PROJECT_SOURCE_DIR}/Wrappers/inc/ClientInterface_pwrap.cpp ${PROJECT_SOURCE_DIR}/Wrappers/inc/PythonClient.i) + + include_directories(${PYTHON_INCLUDE_PATH} ${PROJECT_SOURCE_DIR}/AnnService ${PROJECT_SOURCE_DIR}/Wrappers) + + file(GLOB CORE_HDR_FILES ${PROJECT_SOURCE_DIR}/Wrappers/inc/CoreInterface.h) + file(GLOB CORE_SRC_FILES ${PROJECT_SOURCE_DIR}/Wrappers/src/CoreInterface.cpp ${PROJECT_SOURCE_DIR}/Wrappers/inc/CoreInterface_pwrap.cpp) + add_library (_SPTAG SHARED ${CORE_SRC_FILES} ${CORE_HDR_FILES}) + set_target_properties(_SPTAG PROPERTIES PREFIX "" SUFFIX ${PY_SUFFIX}) + target_link_libraries(_SPTAG SPTAGLib ${Python_LIBRARIES}) + add_custom_command(TARGET _SPTAG POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy ${PROJECT_SOURCE_DIR}/Wrappers/inc/SPTAG.py ${EXECUTABLE_OUTPUT_PATH}) + + file(GLOB CLIENT_HDR_FILES ${PROJECT_SOURCE_DIR}/Wrappers/inc/ClientInterface.h ${PROJECT_SOURCE_DIR}/AnnService/inc/Socket/*.h ${PROJECT_SOURCE_DIR}/AnnService/inc/Client/*.h) + file(GLOB CLIENT_SRC_FILES ${PROJECT_SOURCE_DIR}/Wrappers/src/ClientInterface.cpp ${PROJECT_SOURCE_DIR}/AnnService/src/Socket/*.cpp ${PROJECT_SOURCE_DIR}/AnnService/src/Client/*.cpp ${PROJECT_SOURCE_DIR}/Wrappers/inc/ClientInterface_pwrap.cpp) + add_library (_SPTAGClient SHARED ${CLIENT_SRC_FILES} ${CLIENT_HDR_FILES}) + set_target_properties(_SPTAGClient PROPERTIES PREFIX "" SUFFIX ${PY_SUFFIX}) + target_link_libraries(_SPTAGClient SPTAGLib ${Python_LIBRARIES} ${Boost_LIBRARIES}) + add_custom_command(TARGET _SPTAGClient POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy ${PROJECT_SOURCE_DIR}/Wrappers/inc/SPTAGClient.py ${EXECUTABLE_OUTPUT_PATH}) + + install(TARGETS _SPTAG _SPTAGClient + RUNTIME DESTINATION bin + ARCHIVE DESTINATION lib + LIBRARY DESTINATION lib) + install(FILES ${PROJECT_SOURCE_DIR}/Wrappers/src/SPTAG.py ${PROJECT_SOURCE_DIR}/Wrappers/inc/SPTAGClient.py DESTINATION bin) +else() + message (STATUS "Could not find Python.") +endif() + +find_package(JNI) +if (JNI_FOUND) + include_directories (${JNI_INCLUDE_DIRS}) + link_directories (${JNI_LIBRARY_DIRS}) + message (STATUS "Found JNI.") + message (STATUS "Include Path: ${JNI_INCLUDE_DIRS}") + message (STATUS "Library Path: ${JNI_LIBRARIES}") + + if (WIN32) + set (JAVA_SUFFIX .dll) + else() + set (JAVA_SUFFIX .so) + endif() + + execute_process(COMMAND swig -java -c++ -I${PROJECT_SOURCE_DIR}/Wrappers/inc -o ${PROJECT_SOURCE_DIR}/Wrappers/inc/CoreInterface_jwrap.cpp ${PROJECT_SOURCE_DIR}/Wrappers/inc/JavaCore.i) + execute_process(COMMAND swig -java -c++ -I${PROJECT_SOURCE_DIR}/Wrappers/inc -o ${PROJECT_SOURCE_DIR}/Wrappers/inc/ClientInterface_jwrap.cpp ${PROJECT_SOURCE_DIR}/Wrappers/inc/JavaClient.i) + + include_directories(${JNI_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR}/AnnService ${PROJECT_SOURCE_DIR}/Wrappers) + + file(GLOB CORE_HDR_FILES ${PROJECT_SOURCE_DIR}/Wrappers/inc/CoreInterface.h) + file(GLOB CORE_SRC_FILES ${PROJECT_SOURCE_DIR}/Wrappers/src/CoreInterface.cpp ${PROJECT_SOURCE_DIR}/Wrappers/inc/CoreInterface_jwrap.cpp) + add_library (JAVASPTAG SHARED ${CORE_SRC_FILES} ${CORE_HDR_FILES}) + set_target_properties(JAVASPTAG PROPERTIES SUFFIX ${JAVA_SUFFIX}) + target_link_libraries(JAVASPTAG SPTAGLib ${JNI_LIBRARIES}) + + file(GLOB CLIENT_HDR_FILES ${PROJECT_SOURCE_DIR}/Wrappers/inc/ClientInterface.h ${PROJECT_SOURCE_DIR}/AnnService/inc/Socket/*.h ${PROJECT_SOURCE_DIR}/AnnService/inc/Client/*.h) + file(GLOB CLIENT_SRC_FILES ${PROJECT_SOURCE_DIR}/Wrappers/src/ClientInterface.cpp ${PROJECT_SOURCE_DIR}/AnnService/src/Socket/*.cpp ${PROJECT_SOURCE_DIR}/AnnService/src/Client/*.cpp ${PROJECT_SOURCE_DIR}/Wrappers/inc/ClientInterface_jwrap.cpp) + add_library (JAVASPTAGClient SHARED ${CLIENT_SRC_FILES} ${CLIENT_HDR_FILES}) + set_target_properties(JAVASPTAGClient PROPERTIES SUFFIX ${JAVA_SUFFIX}) + target_link_libraries(JAVASPTAGClient SPTAGLib ${JNI_LIBRARIES} ${Boost_LIBRARIES}) + + file(GLOB JAVA_FILES ${PROJECT_SOURCE_DIR}/Wrappers/inc/*.java) + foreach(JAVA_FILE ${JAVA_FILES}) + message (STATUS "Add copy post-command for file " ${JAVA_FILE}) + add_custom_command(TARGET JAVASPTAGClient POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy ${JAVA_FILE} ${EXECUTABLE_OUTPUT_PATH}) + endforeach(JAVA_FILE) + + install(TARGETS JAVASPTAG JAVASPTAGClient + RUNTIME DESTINATION bin + ARCHIVE DESTINATION lib + LIBRARY DESTINATION lib) + install(FILES ${PROJECT_SOURCE_DIR}/Wrappers/inc/*.java DESTINATION bin) +else() + message (STATUS "Could not find JNI.") +endif() + +if (WIN32) + if (${PROJECTNAME_ARCHITECTURE} MATCHES "x64") + set (csharp_dotnet_framework_hints "$ENV{windir}\\Microsoft.NET\\Framework64") + else() + set (csharp_dotnet_framework_hints "$ENV{windir}\\Microsoft.NET\\Framework") + endif() + + file(GLOB_RECURSE csharp_dotnet_executables ${csharp_dotnet_framework_hints}/csc.exe) + list(SORT csharp_dotnet_executables) + list(REVERSE csharp_dotnet_executables) + foreach (csharp_dotnet_executable ${csharp_dotnet_executables}) + if (NOT DEFINED DOTNET_FOUND) + string(REPLACE "${csharp_dotnet_framework_hints}/" "" csharp_dotnet_version_temp ${csharp_dotnet_executable}) + string(REPLACE "/csc.exe" "" csharp_dotnet_version_temp ${csharp_dotnet_version_temp}) + + set (DOTNET_EXECUTABLE_VERSION "${csharp_dotnet_version_temp}" CACHE STRING "C# .NET compiler version" FORCE) + set (DOTNET_FOUND ${csharp_dotnet_executable}) + endif() + endforeach(csharp_dotnet_executable) +else() + FIND_PROGRAM(DOTNET_FOUND dotnet) +endif() + +if (DOTNET_FOUND) + message (STATUS "Found dotnet.") + message (STATUS "DOTNET_EXECUTABLE: " ${DOTNET_FOUND}) + + if (WIN32) + set (CSHARP_SUFFIX .dll) + else() + set (CSHARP_SUFFIX .so) + endif() + + execute_process(COMMAND swig -csharp -c++ -I${PROJECT_SOURCE_DIR}/Wrappers/inc -o ${PROJECT_SOURCE_DIR}/Wrappers/inc/CoreInterface_cwrap.cpp ${PROJECT_SOURCE_DIR}/Wrappers/inc/CsharpCore.i) + execute_process(COMMAND swig -csharp -c++ -I${PROJECT_SOURCE_DIR}/Wrappers/inc -o ${PROJECT_SOURCE_DIR}/Wrappers/inc/ClientInterface_cwrap.cpp ${PROJECT_SOURCE_DIR}/Wrappers/inc/CsharpClient.i) + + include_directories(${PROJECT_SOURCE_DIR}/AnnService ${PROJECT_SOURCE_DIR}/Wrappers) + + file(GLOB CORE_HDR_FILES ${PROJECT_SOURCE_DIR}/Wrappers/inc/CoreInterface.h) + file(GLOB CORE_SRC_FILES ${PROJECT_SOURCE_DIR}/Wrappers/src/CoreInterface.cpp ${PROJECT_SOURCE_DIR}/Wrappers/inc/CoreInterface_cwrap.cpp) + add_library (CSHARPSPTAG SHARED ${CORE_SRC_FILES} ${CORE_HDR_FILES}) + set_target_properties(CSHARPSPTAG PROPERTIES SUFFIX ${CSHARP_SUFFIX}) + target_link_libraries(CSHARPSPTAG SPTAGLib) + + file(GLOB CLIENT_HDR_FILES ${PROJECT_SOURCE_DIR}/Wrappers/inc/ClientInterface.h ${PROJECT_SOURCE_DIR}/AnnService/inc/Socket/*.h ${PROJECT_SOURCE_DIR}/AnnService/inc/Client/*.h) + file(GLOB CLIENT_SRC_FILES ${PROJECT_SOURCE_DIR}/Wrappers/src/ClientInterface.cpp ${PROJECT_SOURCE_DIR}/AnnService/src/Socket/*.cpp ${PROJECT_SOURCE_DIR}/AnnService/src/Client/*.cpp ${PROJECT_SOURCE_DIR}/Wrappers/inc/ClientInterface_cwrap.cpp) + add_library (CSHARPSPTAGClient SHARED ${CLIENT_SRC_FILES} ${CLIENT_HDR_FILES}) + set_target_properties(CSHARPSPTAGClient PROPERTIES SUFFIX ${CSHARP_SUFFIX}) + target_link_libraries(CSHARPSPTAGClient SPTAGLib ${Boost_LIBRARIES}) + + file(GLOB CSHARP_FILES ${PROJECT_SOURCE_DIR}/Wrappers/inc/*.cs) + foreach(CSHARP_FILE ${CSHARP_FILES}) + message (STATUS "Add copy post-command for file " ${CSHARP_FILE}) + add_custom_command(TARGET CSHARPSPTAGClient POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy ${CSHARP_FILE} ${EXECUTABLE_OUTPUT_PATH}) + endforeach(CSHARP_FILE) + + install(TARGETS CSHARPSPTAG CSHARPSPTAGClient + RUNTIME DESTINATION bin + ARCHIVE DESTINATION lib + LIBRARY DESTINATION lib) + install(FILES ${PROJECT_SOURCE_DIR}/Wrappers/inc/*.cs DESTINATION bin) +else() + message (STATUS "Could not find C#.") +endif() + diff --git a/core/src/index/thirdparty/SPTAG/Wrappers/CsharpClient.vcxproj b/core/src/index/thirdparty/SPTAG/Wrappers/CsharpClient.vcxproj new file mode 100644 index 0000000000..d7d17102d8 --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/Wrappers/CsharpClient.vcxproj @@ -0,0 +1,191 @@ + + + + + Debug + Win32 + + + Release + Win32 + + + Debug + x64 + + + Release + x64 + + + + 15.0 + {363BA3BB-75C4-4CC7-AECB-28C7534B3710} + CsharpClient + 8.1 + + + + + DynamicLibrary + true + v140 + MultiByte + + + DynamicLibrary + false + v140 + true + MultiByte + + + DynamicLibrary + true + v140 + MultiByte + + + DynamicLibrary + false + v140 + true + MultiByte + + + + + + + + + + + + + + + + + + + + CSHARPSPTAGClient + .dll + $(SolutionDir)obj\$(Platform)_$(Configuration)\$(ProjectName)\ + $(ProjectDir);$(SolutionDir)AnnService\;$(IncludePath) + $(OutLibDir);$(LibraryPath) + $(OutAppDir) + + + false + + + + CoreLibrary.lib;SocketLib.lib;%(AdditionalDependencies) + + + + + %(AdditionalIncludeDirectories) + + + + + Level3 + MaxSpeed + true + true + true + true + _WINDLL;_SCL_SECURE_NO_WARNINGS;%(PreprocessorDefinitions) + + + true + true + + + + + Level3 + Disabled + true + true + + + + + Level3 + Disabled + true + true + _WINDLL;_SCL_SECURE_NO_WARNINGS;%(PreprocessorDefinitions) + Guard + ProgramDatabase + + + /guard:cf %(AdditionalOptions) + + + + + Level3 + MaxSpeed + true + true + true + true + + + true + true + + + + + + + + + + + + + false + false + false + false + + + + + + + + + + + + + + + + + + + + + + + + + + This project references NuGet package(s) that are missing on this computer. Use NuGet Package Restore to download them. For more information, see http://go.microsoft.com/fwlink/?LinkID=322105. The missing file is {0}. + + + + + + + + \ No newline at end of file diff --git a/core/src/index/thirdparty/SPTAG/Wrappers/CsharpClient.vcxproj.filters b/core/src/index/thirdparty/SPTAG/Wrappers/CsharpClient.vcxproj.filters new file mode 100644 index 0000000000..589c50014d --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/Wrappers/CsharpClient.vcxproj.filters @@ -0,0 +1,41 @@ + + + + + {4FC737F1-C7A5-4376-A066-2A32D752A2FF} + cpp;c;cc;cxx;def;odl;idl;hpj;bat;asm;asmx + + + {93995380-89BD-4b04-88EB-625FBE52EBFB} + h;hh;hpp;hxx;hm;inl;inc;xsd + + + {67DA6AB6-F800-4c08-8B7A-83BB121AAD01} + rc;ico;cur;bmp;dlg;rc2;rct;bin;rgs;gif;jpg;jpeg;jpe;resx;tiff;tif;png;wav;mfcribbon-ms + + + + + Header Files + + + Header Files + + + + + Source Files + + + Source Files + + + + + Resource Files + + + Resource Files + + + \ No newline at end of file diff --git a/core/src/index/thirdparty/SPTAG/Wrappers/CsharpCore.vcxproj b/core/src/index/thirdparty/SPTAG/Wrappers/CsharpCore.vcxproj new file mode 100644 index 0000000000..e809d8b901 --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/Wrappers/CsharpCore.vcxproj @@ -0,0 +1,134 @@ + + + + + Debug + Win32 + + + Release + Win32 + + + Debug + x64 + + + Release + x64 + + + + 15.0 + {1896C009-AD46-4A70-B83C-4652A7F37503} + CsharpCore + 8.1 + + + + + Application + true + v140 + MultiByte + + + Application + false + v140 + true + MultiByte + + + DynamicLibrary + true + v140 + MultiByte + + + DynamicLibrary + false + v140 + true + MultiByte + + + + + + + + + + + + + + + + + + + + CSHARPSPTAG + .dll + $(SolutionDir)obj\$(Platform)_$(Configuration)\$(ProjectName)\ + $(ProjectDir);$(SolutionDir)AnnService\;$(IncludePath) + $(OutLibDir);$(LibraryPath) + $(OutAppDir) + + + false + + + + CoreLibrary.lib;%(AdditionalDependencies) + + + + + _WINDLL;_SCL_SECURE_NO_WARNINGS;%(PreprocessorDefinitions) + %(AdditionalIncludeDirectories) + Guard + ProgramDatabase + _WINDLL;_SCL_SECURE_NO_WARNINGS;%(PreprocessorDefinitions) + + + /guard:cf %(AdditionalOptions) + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + This project references NuGet package(s) that are missing on this computer. Use NuGet Package Restore to download them. For more information, see http://go.microsoft.com/fwlink/?LinkID=322105. The missing file is {0}. + + + + \ No newline at end of file diff --git a/core/src/index/thirdparty/SPTAG/Wrappers/CsharpCore.vcxproj.filters b/core/src/index/thirdparty/SPTAG/Wrappers/CsharpCore.vcxproj.filters new file mode 100644 index 0000000000..51b1ec0ce6 --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/Wrappers/CsharpCore.vcxproj.filters @@ -0,0 +1,40 @@ + + + + + {4FC737F1-C7A5-4376-A066-2A32D752A2FF} + cpp;c;cc;cxx;def;odl;idl;hpj;bat;asm;asmx + + + {93995380-89BD-4b04-88EB-625FBE52EBFB} + h;hh;hpp;hxx;hm;inl;inc;xsd + + + {ba4289c4-f872-4dbc-a57f-7b415614afb3} + + + + + Header Files + + + Header Files + + + + + Source Files + + + Source Files + + + + + Resources + + + Resources + + + \ No newline at end of file diff --git a/core/src/index/thirdparty/SPTAG/Wrappers/JavaClient.vcxproj b/core/src/index/thirdparty/SPTAG/Wrappers/JavaClient.vcxproj new file mode 100644 index 0000000000..2ee36ac620 --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/Wrappers/JavaClient.vcxproj @@ -0,0 +1,191 @@ + + + + + Debug + Win32 + + + Release + Win32 + + + Debug + x64 + + + Release + x64 + + + + 15.0 + {8866BF98-AA2E-450F-9F33-083E007CCA74} + JavaClient + 8.1 + + + + + DynamicLibrary + true + v140 + MultiByte + + + DynamicLibrary + false + v140 + true + MultiByte + + + DynamicLibrary + true + v140 + MultiByte + + + DynamicLibrary + false + v140 + true + MultiByte + + + + + + + + + + + + + + + + + + + + JAVASPTAGClient + .dll + $(SolutionDir)obj\$(Platform)_$(Configuration)\$(ProjectName)\ + $(ProjectDir);$(SolutionDir)AnnService\;$(IncludePath) + $(OutLibDir);$(LibraryPath) + $(OutAppDir) + + + false + + + + $(JavaLib);CoreLibrary.lib;SocketLib.lib;%(AdditionalDependencies) + + + + + $(JavaIncDir);%(AdditionalIncludeDirectories) + + + + + Level3 + MaxSpeed + true + true + true + true + _WINDLL;_SCL_SECURE_NO_WARNINGS;%(PreprocessorDefinitions) + + + true + true + + + + + Level3 + Disabled + true + true + + + + + Level3 + Disabled + true + true + _WINDLL;_SCL_SECURE_NO_WARNINGS;SWIG_JAVA_INTERPRETER_NO_DEBUG;%(PreprocessorDefinitions) + Guard + ProgramDatabase + + + /guard:cf %(AdditionalOptions) + + + + + Level3 + MaxSpeed + true + true + true + true + + + true + true + + + + + + + + + + + + + false + false + false + false + + + + + + + + + + + + + + + + + + + + + + + + + + This project references NuGet package(s) that are missing on this computer. Use NuGet Package Restore to download them. For more information, see http://go.microsoft.com/fwlink/?LinkID=322105. The missing file is {0}. + + + + + + + + \ No newline at end of file diff --git a/core/src/index/thirdparty/SPTAG/Wrappers/JavaClient.vcxproj.filters b/core/src/index/thirdparty/SPTAG/Wrappers/JavaClient.vcxproj.filters new file mode 100644 index 0000000000..0d047923aa --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/Wrappers/JavaClient.vcxproj.filters @@ -0,0 +1,41 @@ + + + + + {4FC737F1-C7A5-4376-A066-2A32D752A2FF} + cpp;c;cc;cxx;def;odl;idl;hpj;bat;asm;asmx + + + {93995380-89BD-4b04-88EB-625FBE52EBFB} + h;hh;hpp;hxx;hm;inl;inc;xsd + + + {67DA6AB6-F800-4c08-8B7A-83BB121AAD01} + rc;ico;cur;bmp;dlg;rc2;rct;bin;rgs;gif;jpg;jpeg;jpe;resx;tiff;tif;png;wav;mfcribbon-ms + + + + + Header Files + + + Header Files + + + + + Source Files + + + Source Files + + + + + Resource Files + + + Resource Files + + + \ No newline at end of file diff --git a/core/src/index/thirdparty/SPTAG/Wrappers/JavaCore.vcxproj b/core/src/index/thirdparty/SPTAG/Wrappers/JavaCore.vcxproj new file mode 100644 index 0000000000..f15c0e005f --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/Wrappers/JavaCore.vcxproj @@ -0,0 +1,134 @@ + + + + + Debug + Win32 + + + Release + Win32 + + + Debug + x64 + + + Release + x64 + + + + 15.0 + {93FEB26B-965E-4157-8BE5-052F5CA112BB} + JavaCore + 8.1 + + + + + Application + true + v140 + MultiByte + + + Application + false + v140 + true + MultiByte + + + DynamicLibrary + true + v140 + MultiByte + + + DynamicLibrary + false + v140 + true + MultiByte + + + + + + + + + + + + + + + + + + + + JAVASPTAG + .dll + $(SolutionDir)obj\$(Platform)_$(Configuration)\$(ProjectName)\ + $(ProjectDir);$(SolutionDir)AnnService\;$(IncludePath) + $(OutLibDir);$(LibraryPath) + $(OutAppDir) + + + false + + + + $(JavaLib);CoreLibrary.lib;%(AdditionalDependencies) + + + + + _WINDLL;SWIG_JAVA_INTERPRETER_NO_DEBUG;_SCL_SECURE_NO_WARNINGS;%(PreprocessorDefinitions) + $(JavaIncDir);%(AdditionalIncludeDirectories) + Guard + ProgramDatabase + _WINDLL;_SCL_SECURE_NO_WARNINGS;%(PreprocessorDefinitions) + + + /guard:cf %(AdditionalOptions) + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + This project references NuGet package(s) that are missing on this computer. Use NuGet Package Restore to download them. For more information, see http://go.microsoft.com/fwlink/?LinkID=322105. The missing file is {0}. + + + + \ No newline at end of file diff --git a/core/src/index/thirdparty/SPTAG/Wrappers/JavaCore.vcxproj.filters b/core/src/index/thirdparty/SPTAG/Wrappers/JavaCore.vcxproj.filters new file mode 100644 index 0000000000..851552684d --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/Wrappers/JavaCore.vcxproj.filters @@ -0,0 +1,40 @@ + + + + + {4FC737F1-C7A5-4376-A066-2A32D752A2FF} + cpp;c;cc;cxx;def;odl;idl;hpj;bat;asm;asmx + + + {93995380-89BD-4b04-88EB-625FBE52EBFB} + h;hh;hpp;hxx;hm;inl;inc;xsd + + + {ba4289c4-f872-4dbc-a57f-7b415614afb3} + + + + + Header Files + + + Header Files + + + + + Source Files + + + Source Files + + + + + Resources + + + Resources + + + \ No newline at end of file diff --git a/core/src/index/thirdparty/SPTAG/Wrappers/PythonClient.vcxproj b/core/src/index/thirdparty/SPTAG/Wrappers/PythonClient.vcxproj new file mode 100644 index 0000000000..5cf2c2a9cf --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/Wrappers/PythonClient.vcxproj @@ -0,0 +1,189 @@ + + + + + Debug + Win32 + + + Release + Win32 + + + Debug + x64 + + + Release + x64 + + + + 15.0 + {9B014CF6-E3FB-4BD4-B3B1-D26297BB31AA} + PythonClient + 8.1 + + + + + DynamicLibrary + true + v140 + MultiByte + + + DynamicLibrary + false + v140 + true + MultiByte + + + DynamicLibrary + true + v140 + MultiByte + + + DynamicLibrary + false + v140 + true + MultiByte + + + + + + + + + + + + + + + + + + + + _SPTAGClient + .pyd + $(SolutionDir)obj\$(Platform)_$(Configuration)\$(ProjectName)\ + $(ProjectDir);$(SolutionDir)AnnService\;$(IncludePath) + $(OutLibDir);$(LibraryPath) + $(OutAppDir) + + + false + + + + CoreLibrary.lib;SocketLib.lib;$(SolutionDir)packages\python2.2.7.15\tools\libs\python27.lib;%(AdditionalDependencies) + + + + + $(SolutionDir)packages\python2.2.7.15\tools\include\;%(AdditionalIncludeDirectories) + + + + + Level3 + MaxSpeed + true + true + true + true + _WINDLL;_SCL_SECURE_NO_WARNINGS;%(PreprocessorDefinitions) + + + true + true + + + + + Level3 + Disabled + true + true + + + + + Level3 + Disabled + true + true + _WINDLL;_SCL_SECURE_NO_WARNINGS;SWIG_PYTHON_INTERPRETER_NO_DEBUG;%(PreprocessorDefinitions) + Guard + ProgramDatabase + + + /guard:cf %(AdditionalOptions) + + + + + Level3 + MaxSpeed + true + true + true + true + + + true + true + + + + + + + + + + + + + false + false + false + false + + + + + + + + + + + + + + + + + + + + + + + This project references NuGet package(s) that are missing on this computer. Use NuGet Package Restore to download them. For more information, see http://go.microsoft.com/fwlink/?LinkID=322105. The missing file is {0}. + + + + + + + + + \ No newline at end of file diff --git a/core/src/index/thirdparty/SPTAG/Wrappers/PythonClient.vcxproj.filters b/core/src/index/thirdparty/SPTAG/Wrappers/PythonClient.vcxproj.filters new file mode 100644 index 0000000000..84c71f0977 --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/Wrappers/PythonClient.vcxproj.filters @@ -0,0 +1,41 @@ + + + + + {4FC737F1-C7A5-4376-A066-2A32D752A2FF} + cpp;c;cc;cxx;def;odl;idl;hpj;bat;asm;asmx + + + {93995380-89BD-4b04-88EB-625FBE52EBFB} + h;hh;hpp;hxx;hm;inl;inc;xsd + + + {67DA6AB6-F800-4c08-8B7A-83BB121AAD01} + rc;ico;cur;bmp;dlg;rc2;rct;bin;rgs;gif;jpg;jpeg;jpe;resx;tiff;tif;png;wav;mfcribbon-ms + + + + + Header Files + + + Header Files + + + + + Source Files + + + Source Files + + + + + Resource Files + + + Resource Files + + + \ No newline at end of file diff --git a/core/src/index/thirdparty/SPTAG/Wrappers/PythonCore.vcxproj b/core/src/index/thirdparty/SPTAG/Wrappers/PythonCore.vcxproj new file mode 100644 index 0000000000..7555ba97f4 --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/Wrappers/PythonCore.vcxproj @@ -0,0 +1,132 @@ + + + + + Debug + Win32 + + + Release + Win32 + + + Debug + x64 + + + Release + x64 + + + + 15.0 + {AF31947C-0495-42FE-A1AD-8F0DA2A679C7} + PythonCore + 8.1 + + + + + Application + true + v140 + MultiByte + + + Application + false + v140 + true + MultiByte + + + DynamicLibrary + true + v140 + MultiByte + + + DynamicLibrary + false + v140 + true + MultiByte + + + + + + + + + + + + + + + + + + + + _SPTAG + .pyd + $(SolutionDir)obj\$(Platform)_$(Configuration)\$(ProjectName)\ + $(ProjectDir);$(SolutionDir)AnnService\;$(IncludePath) + $(OutLibDir);$(LibraryPath) + $(OutAppDir) + + + false + + + + CoreLibrary.lib;$(SolutionDir)packages\python2.2.7.15\tools\libs\python27.lib;%(AdditionalDependencies) + + + + + _WINDLL;SWIG_PYTHON_INTERPRETER_NO_DEBUG;_SCL_SECURE_NO_WARNINGS;%(PreprocessorDefinitions) + $(SolutionDir)packages\python2.2.7.15\tools\include\;%(AdditionalIncludeDirectories) + Guard + ProgramDatabase + _WINDLL;_SCL_SECURE_NO_WARNINGS;%(PreprocessorDefinitions) + + + /guard:cf %(AdditionalOptions) + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + This project references NuGet package(s) that are missing on this computer. Use NuGet Package Restore to download them. For more information, see http://go.microsoft.com/fwlink/?LinkID=322105. The missing file is {0}. + + + + + \ No newline at end of file diff --git a/core/src/index/thirdparty/SPTAG/Wrappers/PythonCore.vcxproj.filters b/core/src/index/thirdparty/SPTAG/Wrappers/PythonCore.vcxproj.filters new file mode 100644 index 0000000000..8d0ee1d7b9 --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/Wrappers/PythonCore.vcxproj.filters @@ -0,0 +1,40 @@ + + + + + {4FC737F1-C7A5-4376-A066-2A32D752A2FF} + cpp;c;cc;cxx;def;odl;idl;hpj;bat;asm;asmx + + + {93995380-89BD-4b04-88EB-625FBE52EBFB} + h;hh;hpp;hxx;hm;inl;inc;xsd + + + {ba4289c4-f872-4dbc-a57f-7b415614afb3} + + + + + Header Files + + + Header Files + + + + + Source Files + + + Source Files + + + + + Resources + + + Resources + + + \ No newline at end of file diff --git a/core/src/index/thirdparty/SPTAG/Wrappers/PythonCore.vcxproj.user b/core/src/index/thirdparty/SPTAG/Wrappers/PythonCore.vcxproj.user new file mode 100644 index 0000000000..abe8dd8961 --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/Wrappers/PythonCore.vcxproj.user @@ -0,0 +1,4 @@ + + + + \ No newline at end of file diff --git a/core/src/index/thirdparty/SPTAG/Wrappers/inc/CLRCoreInterface.h b/core/src/index/thirdparty/SPTAG/Wrappers/inc/CLRCoreInterface.h new file mode 100644 index 0000000000..1a273ba8d7 --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/Wrappers/inc/CLRCoreInterface.h @@ -0,0 +1,113 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "ManagedObject.h" +#include "inc/Core/VectorIndex.h" + +using namespace System; + +namespace Microsoft +{ + namespace ANN + { + namespace SPTAGManaged + { + + public ref class BasicResult : + public ManagedObject + { + public: + BasicResult(SPTAG::BasicResult* p_instance) : ManagedObject(p_instance) + { + } + + property int VID + { + public: + int get() + { + return m_Instance->VID; + } + private: + void set(int p_vid) + { + } + } + + property float Dist + { + public: + float get() + { + return m_Instance->Dist; + } + private: + void set(float p_dist) + { + } + } + + property array^ Meta + { + public: + array^ get() + { + array^ buf = gcnew array(m_Instance->Meta.Length()); + Marshal::Copy((IntPtr)m_Instance->Meta.Data(), buf, 0, (int)m_Instance->Meta.Length()); + return buf; + } + private: + void set(array^ p_meta) + { + } + } + }; + + public ref class AnnIndex : + public ManagedObject> + { + public: + AnnIndex(std::shared_ptr p_index); + + AnnIndex(String^ p_algoType, String^ p_valueType, int p_dimension); + + void SetBuildParam(String^ p_name, String^ p_value); + + void SetSearchParam(String^ p_name, String^ p_value); + + bool Build(array^ p_data, int p_num); + + bool BuildWithMetaData(array^ p_data, array^ p_meta, int p_num, bool p_withMetaIndex); + + array^ Search(array^ p_data, int p_resultNum); + + array^ SearchWithMetaData(array^ p_data, int p_resultNum); + + bool Save(String^ p_saveFile); + + array^>^ Dump(); + + bool Add(array^ p_data, int p_num); + + bool AddWithMetaData(array^ p_data, array^ p_meta, int p_num); + + bool Delete(array^ p_data, int p_num); + + bool DeleteByMetaData(array^ p_meta); + + static AnnIndex^ Load(String^ p_loaderFile); + + static AnnIndex^ Load(array^>^ p_index); + + static bool Merge(String^ p_indexFilePath1, String^ p_indexFilePath2); + + private: + + int m_dimension; + + size_t m_inputVectorSize; + }; + } + } +} diff --git a/core/src/index/thirdparty/SPTAG/Wrappers/inc/ClientInterface.h b/core/src/index/thirdparty/SPTAG/Wrappers/inc/ClientInterface.h new file mode 100644 index 0000000000..94d46bcad3 --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/Wrappers/inc/ClientInterface.h @@ -0,0 +1,63 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifndef _SPTAG_PW_CLIENTINTERFACE_H_ +#define _SPTAG_PW_CLIENTINTERFACE_H_ + +#include "TransferDataType.h" +#include "inc/Socket/Client.h" +#include "inc/Socket/ResourceManager.h" + +#include +#include +#include + +class AnnClient +{ +public: + AnnClient(const char* p_serverAddr, const char* p_serverPort); + + ~AnnClient(); + + void SetTimeoutMilliseconds(int p_timeout); + + void SetSearchParam(const char* p_name, const char* p_value); + + void ClearSearchParam(); + + std::shared_ptr Search(ByteArray p_data, int p_resultNum, const char* p_valueType, bool p_withMetaData); + + bool IsConnected() const; + +private: + std::string CreateSearchQuery(const ByteArray& p_data, + int p_resultNum, + bool p_extractMetadata, + SPTAG::VectorValueType p_valueType); + + SPTAG::Socket::PacketHandlerMapPtr GetHandlerMap(); + + void SearchResponseHanlder(SPTAG::Socket::ConnectionID p_localConnectionID, + SPTAG::Socket::Packet p_packet); + +private: + typedef std::function Callback; + + std::uint32_t m_timeoutInMilliseconds; + + std::string m_server; + + std::string m_port; + + std::unique_ptr m_socketClient; + + std::atomic m_connectionID; + + SPTAG::Socket::ResourceManager m_callbackManager; + + std::unordered_map m_params; + + std::mutex m_paramMutex; +}; + +#endif // _SPTAG_PW_CLIENTINTERFACE_H_ diff --git a/core/src/index/thirdparty/SPTAG/Wrappers/inc/CoreInterface.h b/core/src/index/thirdparty/SPTAG/Wrappers/inc/CoreInterface.h new file mode 100644 index 0000000000..bc69874746 --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/Wrappers/inc/CoreInterface.h @@ -0,0 +1,65 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifndef _SPTAG_PW_COREINTERFACE_H_ +#define _SPTAG_PW_COREINTERFACE_H_ + +#include "TransferDataType.h" +#include "inc/Core/Common.h" +#include "inc/Core/VectorIndex.h" + +typedef int SizeType; +typedef int DimensionType; + +class AnnIndex +{ +public: + AnnIndex(DimensionType p_dimension); + + AnnIndex(const char* p_algoType, const char* p_valueType, DimensionType p_dimension); + + ~AnnIndex(); + + void SetBuildParam(const char* p_name, const char* p_value); + + void SetSearchParam(const char* p_name, const char* p_value); + + bool Build(ByteArray p_data, SizeType p_num); + + bool BuildWithMetaData(ByteArray p_data, ByteArray p_meta, SizeType p_num, bool p_withMetaIndex); + + std::shared_ptr Search(ByteArray p_data, int p_resultNum); + + std::shared_ptr SearchWithMetaData(ByteArray p_data, int p_resultNum); + + bool ReadyToServe() const; + + bool Save(const char* p_saveFile) const; + + bool Add(ByteArray p_data, SizeType p_num); + + bool AddWithMetaData(ByteArray p_data, ByteArray p_meta, SizeType p_num); + + bool Delete(ByteArray p_data, SizeType p_num); + + bool DeleteByMetaData(ByteArray p_meta); + + static AnnIndex Load(const char* p_loaderFile); + + static bool Merge(const char* p_indexFilePath1, const char* p_indexFilePath2); + +private: + AnnIndex(const std::shared_ptr& p_index); + + std::shared_ptr m_index; + + size_t m_inputVectorSize; + + DimensionType m_dimension; + + SPTAG::IndexAlgoType m_algoType; + + SPTAG::VectorValueType m_inputValueType; +}; + +#endif // _SPTAG_PW_COREINTERFACE_H_ diff --git a/core/src/index/thirdparty/SPTAG/Wrappers/inc/CsharpClient.i b/core/src/index/thirdparty/SPTAG/Wrappers/inc/CsharpClient.i new file mode 100644 index 0000000000..481627a97f --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/Wrappers/inc/CsharpClient.i @@ -0,0 +1,16 @@ +%module CSHARPSPTAGClient + +%{ +#include "inc/ClientInterface.h" +%} + +%include +%shared_ptr(AnnClient) +%shared_ptr(RemoteSearchResult) +%include "CsharpCommon.i" + +%{ +#define SWIG_FILE_WITH_INIT +%} + +%include "ClientInterface.h" diff --git a/core/src/index/thirdparty/SPTAG/Wrappers/inc/CsharpCommon.i b/core/src/index/thirdparty/SPTAG/Wrappers/inc/CsharpCommon.i new file mode 100644 index 0000000000..6251d6f245 --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/Wrappers/inc/CsharpCommon.i @@ -0,0 +1,125 @@ +#ifdef SWIGCSHARP + +%{ + struct WrapperArray + { + void * _data; + size_t _size; + }; + + void deleteArrayOfWrapperArray(void* ptr) { + delete[] (WrapperArray*)ptr; + } +%} + +%pragma(csharp) imclasscode=%{ + [System.Runtime.InteropServices.StructLayout(System.Runtime.InteropServices.LayoutKind.Sequential)] + public struct WrapperArray + { + public System.IntPtr _data; + public ulong _size; + public WrapperArray(System.IntPtr in_data, ulong in_size) { _data = in_data; _size = in_size; } + } +%} + +%apply void *VOID_INT_PTR { void * } +void deleteArrayOfWrapperArray(void* ptr); + +%typemap(ctype) ByteArray "WrapperArray" +%typemap(imtype) ByteArray "WrapperArray" +%typemap(cstype) ByteArray "byte[]" +%typemap(in) ByteArray { + $1.Set((std::uint8_t*)$input._data, $input._size, false); +} +%typemap(out) ByteArray { + $result._data = $1.Data(); + $result._size = $1.Length(); +} +%typemap(csin, + pre="unsafe { fixed(byte* ptr$csinput = $csinput) { $modulePINVOKE.WrapperArray temp$csinput = new $modulePINVOKE.WrapperArray( (System.IntPtr)ptr$csinput, (ulong)$csinput.LongLength );", + terminator="} }" + ) ByteArray %{ temp$csinput %} + +%typemap(csvarin) ByteArray %{ + set { + unsafe { fixed(byte* ptr$csinput = $csinput) + { + $modulePINVOKE.WrapperArray temp$csinput = new $modulePINVOKE.WrapperArray( (System.IntPtr)ptr$csinput, (ulong)$csinput.LongLength ); + $imcall; + } + } + } +%} + +%typemap(csout, excode=SWIGEXCODE) ByteArray %{ + $modulePINVOKE.WrapperArray data = $imcall;$excode + byte[] ret = new byte[data._size]; + System.Runtime.InteropServices.Marshal.Copy(data._data, ret, 0, (int)data._size); + return ret; +%} + +%typemap(csvarout) ByteArray %{ + get { + $modulePINVOKE.WrapperArray data = $imcall; + byte[] ret = new byte[data._size]; + System.Runtime.InteropServices.Marshal.Copy(data._data, ret, 0, (int)data._size); + return ret; + } +%} + +%typemap(ctype) std::shared_ptr "WrapperArray" +%typemap(imtype) std::shared_ptr "WrapperArray" +%typemap(cstype) std::shared_ptr "BasicResult[]" +%typemap(out) std::shared_ptr { + $result._data = new WrapperArray[$1->GetResultNum()]; + $result._size = $1->GetResultNum(); + for (int i = 0; i < $1->GetResultNum(); i++) + (((WrapperArray*)$result._data) + i)->_data = new BasicResult(*($1->GetResult(i))); +} +%typemap(csout, excode=SWIGEXCODE) std::shared_ptr { + $modulePINVOKE.WrapperArray data = $imcall; + BasicResult[] ret = new BasicResult[data._size]; + System.IntPtr ptr = data._data; + for (ulong i = 0; i < data._size; i++) { + $modulePINVOKE.WrapperArray arr = ($modulePINVOKE.WrapperArray)System.Runtime.InteropServices.Marshal.PtrToStructure(ptr, typeof($modulePINVOKE.WrapperArray)); + ret[i] = new BasicResult(arr._data, true); + ptr += sizeof($modulePINVOKE.WrapperArray); + } + $modulePINVOKE.deleteArrayOfWrapperArray(data._data); + $excode + return ret; +} + +%typemap(ctype) std::shared_ptr "WrapperArray" +%typemap(imtype) std::shared_ptr "WrapperArray" +%typemap(cstype) std::shared_ptr "BasicResult[]" +%typemap(out) std::shared_ptr { + int combinelen = 0; + int nodelen = (int)(($1->m_allIndexResults).size()); + for (int i = 0; i < nodelen; i++) { + combinelen += $1->m_allIndexResults[i].m_results.GetResultNum(); + } + $result._data = new WrapperArray[combinelen]; + $result._size = combinelen; + size_t copyed = 0; + for (int i = 0; i < nodelen; i++) { + auto& queryResult = $1->m_allIndexResults[i].m_results; + for (int j = 0; j < queryResult.GetResultNum(); j++) + (((WrapperArray*)$result._data) + copyed + j)->_data = new BasicResult(*(queryResult.GetResult(j))); + copyed += queryResult.GetResultNum(); + } +} +%typemap(csout, excode=SWIGEXCODE) std::shared_ptr { + $modulePINVOKE.WrapperArray data = $imcall; + BasicResult[] ret = new BasicResult[data._size]; + System.IntPtr ptr = data._data; + for (ulong i = 0; i < data._size; i++) { + $modulePINVOKE.WrapperArray arr = ($modulePINVOKE.WrapperArray)System.Runtime.InteropServices.Marshal.PtrToStructure(ptr, typeof($modulePINVOKE.WrapperArray)); + ret[i] = new BasicResult(arr._data, true); + ptr += sizeof($modulePINVOKE.WrapperArray); + } + $modulePINVOKE.deleteArrayOfWrapperArray(data._data); + $excode + return ret; +} +#endif diff --git a/core/src/index/thirdparty/SPTAG/Wrappers/inc/CsharpCore.i b/core/src/index/thirdparty/SPTAG/Wrappers/inc/CsharpCore.i new file mode 100644 index 0000000000..6434239b90 --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/Wrappers/inc/CsharpCore.i @@ -0,0 +1,17 @@ +%module CSHARPSPTAG + +%{ +#include "inc/CoreInterface.h" +%} + +%include +%shared_ptr(AnnIndex) +%shared_ptr(QueryResult) +%include "CsharpCommon.i" + +%{ +#define SWIG_FILE_WITH_INIT +%} + +%include "CoreInterface.h" +%include "../../AnnService/inc/Core/SearchResult.h" diff --git a/core/src/index/thirdparty/SPTAG/Wrappers/inc/JavaClient.i b/core/src/index/thirdparty/SPTAG/Wrappers/inc/JavaClient.i new file mode 100644 index 0000000000..62a274e51a --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/Wrappers/inc/JavaClient.i @@ -0,0 +1,16 @@ +%module JAVASPTAGClient + +%{ +#include "inc/ClientInterface.h" +%} + +%include +%shared_ptr(AnnClient) +%shared_ptr(RemoteSearchResult) +%include "JavaCommon.i" + +%{ +#define SWIG_FILE_WITH_INIT +%} + +%include "ClientInterface.h" diff --git a/core/src/index/thirdparty/SPTAG/Wrappers/inc/JavaCommon.i b/core/src/index/thirdparty/SPTAG/Wrappers/inc/JavaCommon.i new file mode 100644 index 0000000000..366052d4f9 --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/Wrappers/inc/JavaCommon.i @@ -0,0 +1,60 @@ +#ifdef SWIGJAVA + +%typemap(jni) ByteArray "jbyteArray" +%typemap(jtype) ByteArray "byte[]" +%typemap(jstype) ByteArray "byte[]" +%typemap(in) ByteArray { + $1.Set((std::uint8_t*)JCALL2(GetByteArrayElements, jenv, $input, 0), + JCALL1(GetArrayLength, jenv, $input), false); +} +%typemap(out) ByteArray { + $result = JCALL1(NewByteArray, jenv, $1.Length()); + JCALL4(SetByteArrayRegion, jenv, $result, 0, $1.Length(), (jbyte *)$1.Data()); +} +%typemap(javain) ByteArray "$javainput" +%typemap(javaout) ByteArray { return $jnicall; } + +%typemap(jni) std::shared_ptr "jobjectArray" +%typemap(jtype) std::shared_ptr "BasicResult[]" +%typemap(jstype) std::shared_ptr "BasicResult[]" +%typemap(out) std::shared_ptr { + jclass retClass = jenv->FindClass("BasicResult"); + int len = $1->GetResultNum(); + $result = jenv->NewObjectArray(len, retClass, NULL); + for (int i = 0; i < len; i++) { + auto& meta = $1->GetMetadata(i); + jbyteArray bptr = jenv->NewByteArray(meta.Length()); + jenv->SetByteArrayRegion(bptr, 0, meta.Length(), (jbyte *)meta.Data()); + jenv->SetObjectArrayElement(jresult, i, jenv->NewObject(retClass, jenv->GetMethodID(retClass, "", "(IF[B)V"), (jint)($1->GetResult(i)->VID), (jfloat)($1->GetResult(i)->Dist), bptr)); + } +} +%typemap(javaout) std::shared_ptr { return $jnicall; } + +%typemap(jni) std::shared_ptr "jobjectArray" +%typemap(jtype) std::shared_ptr "BasicResult[]" +%typemap(jstype) std::shared_ptr "BasicResult[]" +%typemap(out) std::shared_ptr { + int combinelen = 0; + int nodelen = (int)(($1->m_allIndexResults).size()); + for (int i = 0; i < nodelen; i++) { + combinelen += $1->m_allIndexResults[i].m_results.GetResultNum(); + } + jclass retClass = jenv->FindClass("BasicResult"); + $result = jenv->NewObjectArray(combinelen, retClass, NULL); + int id = 0; + for (int i = 0; i < nodelen; i++) { + for (int j = 0; j < $1->m_allIndexResults[i].m_results.GetResultNum(); j++) { + auto& ptr = $1->m_allIndexResults[i].m_results; + auto& meta = ptr.GetMetadata(j); + jbyteArray bptr = jenv->NewByteArray(meta.Length()); + jenv->SetByteArrayRegion(bptr, 0, meta.Length(), (jbyte *)meta.Data()); + jenv->SetObjectArrayElement(jresult, id, jenv->NewObject(retClass, jenv->GetMethodID(retClass, "", "(IF[B)V"), (jint)(ptr.GetResult(j)->VID), (jfloat)(ptr.GetResult(j)->Dist), bptr)); + id++; + } + } +} +%typemap(javaout) std::shared_ptr { + return $jnicall; +} + +#endif diff --git a/core/src/index/thirdparty/SPTAG/Wrappers/inc/JavaCore.i b/core/src/index/thirdparty/SPTAG/Wrappers/inc/JavaCore.i new file mode 100644 index 0000000000..78d9dd72e3 --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/Wrappers/inc/JavaCore.i @@ -0,0 +1,17 @@ +%module JAVASPTAG + +%{ +#include "inc/CoreInterface.h" +%} + +%include +%shared_ptr(AnnIndex) +%shared_ptr(QueryResult) +%include "JavaCommon.i" + +%{ +#define SWIG_FILE_WITH_INIT +%} + +%include "CoreInterface.h" +%include "../../AnnService/inc/Core/SearchResult.h" diff --git a/core/src/index/thirdparty/SPTAG/Wrappers/inc/ManagedObject.h b/core/src/index/thirdparty/SPTAG/Wrappers/inc/ManagedObject.h new file mode 100644 index 0000000000..266d84b440 --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/Wrappers/inc/ManagedObject.h @@ -0,0 +1,80 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "inc/Helper/StringConvert.h" + +using namespace System; +using namespace System::Runtime::InteropServices; + +namespace Microsoft +{ + namespace ANN + { + namespace SPTAGManaged + { + /// + /// hold a pointer to an umnanaged object from the core project + /// + template + public ref class ManagedObject + { + protected: + T* m_Instance; + + public: + ManagedObject(T* instance) + :m_Instance(instance) + { + } + + ManagedObject(T& instance) + { + m_Instance = new T(instance); + } + + /// + /// destructor, which is called whenever delete an object with delete keyword + /// + virtual ~ManagedObject() + { + if (m_Instance != nullptr) + { + delete m_Instance; + } + } + + /// + /// finalizer which is called by Garbage Collector whenever it destroys the wrapper object. + /// + !ManagedObject() + { + if (m_Instance != nullptr) + { + delete m_Instance; + } + } + + T* GetInstance() + { + return m_Instance; + } + + static const char* string_to_char_array(String^ string) + { + const char* str = (const char*)(Marshal::StringToHGlobalAnsi(string)).ToPointer(); + return str; + } + + template + static T string_to(String^ string) + { + T data; + SPTAG::Helper::Convert::ConvertStringTo(string_to_char_array(string), data); + return data; + } + }; + } + } +} + diff --git a/core/src/index/thirdparty/SPTAG/Wrappers/inc/PythonClient.i b/core/src/index/thirdparty/SPTAG/Wrappers/inc/PythonClient.i new file mode 100644 index 0000000000..a70e2fdeb7 --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/Wrappers/inc/PythonClient.i @@ -0,0 +1,16 @@ +%module SPTAGClient + +%{ +#include "inc/ClientInterface.h" +%} + +%include +%shared_ptr(AnnClient) +%shared_ptr(RemoteSearchResult) +%include "PythonCommon.i" + +%{ +#define SWIG_FILE_WITH_INIT +%} + +%include "ClientInterface.h" \ No newline at end of file diff --git a/core/src/index/thirdparty/SPTAG/Wrappers/inc/PythonCommon.i b/core/src/index/thirdparty/SPTAG/Wrappers/inc/PythonCommon.i new file mode 100644 index 0000000000..7b10d50b99 --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/Wrappers/inc/PythonCommon.i @@ -0,0 +1,117 @@ +#ifdef SWIGPYTHON + +%typemap(out) std::shared_ptr +%{ + { + $result = PyTuple_New(3); + int resNum = $1->GetResultNum(); + auto dstVecIDs = PyList_New(resNum); + auto dstVecDists = PyList_New(resNum); + auto dstMetadata = PyList_New(resNum); + int i = 0; + for (const auto& res : *($1)) + { + PyList_SetItem(dstVecIDs, i, PyInt_FromLong(res.VID)); + PyList_SetItem(dstVecDists, i, PyFloat_FromDouble(res.Dist)); + i++; + } + + if ($1->WithMeta()) + { + for (i = 0; i < resNum; ++i) + { + const auto& metadata = $1->GetMetadata(i); + PyList_SetItem(dstMetadata, i, PyBytes_FromStringAndSize(reinterpret_cast(metadata.Data()), + metadata.Length())); + } + } + + PyTuple_SetItem($result, 0, dstVecIDs); + PyTuple_SetItem($result, 1, dstVecDists); + PyTuple_SetItem($result, 2, dstMetadata); + } +%} + +%typemap(out) std::shared_ptr +%{ + { + $result = PyTuple_New(3); + auto dstVecIDs = PyList_New(0); + auto dstVecDists = PyList_New(0); + auto dstMetadata = PyList_New(0); + for (const auto& indexRes : $1->m_allIndexResults) + { + for (const auto& res : indexRes.m_results) + { + PyList_Append(dstVecIDs, PyInt_FromLong(res.VID)); + PyList_Append(dstVecDists, PyFloat_FromDouble(res.Dist)); + } + + if (indexRes.m_results.WithMeta()) + { + for (int i = 0; i < indexRes.m_results.GetResultNum(); ++i) + { + const auto& metadata = indexRes.m_results.GetMetadata(i); + PyList_Append(dstMetadata, PyBytes_FromStringAndSize(reinterpret_cast(metadata.Data()), + metadata.Length())); + } + } + } + PyTuple_SetItem($result, 0, dstVecIDs); + PyTuple_SetItem($result, 1, dstVecDists); + PyTuple_SetItem($result, 2, dstMetadata); + } +%} + + +%{ +struct PyBufferHolder +{ + PyBufferHolder() : shouldRelease(false) { } + + ~PyBufferHolder() + { + if (shouldRelease) + { + PyBuffer_Release(&buff); + } + } + + Py_buffer buff; + + bool shouldRelease; +}; +%} + +%typemap(in) ByteArray (PyBufferHolder bufferHolder) +%{ + if (PyBytes_Check($input)) + { + $1 = SPTAG::ByteArray((std::uint8_t*)PyBytes_AsString($input), PyBytes_Size($input), false); + } + else if (PyObject_CheckBuffer($input)) + { + if (PyObject_GetBuffer($input, &bufferHolder.buff, PyBUF_SIMPLE | PyBUF_C_CONTIGUOUS) == -1) + { + PyErr_SetString(PyExc_ValueError, "Failed get buffer."); + return NULL; + } + + bufferHolder.shouldRelease = true; + $1 = SPTAG::ByteArray((std::uint8_t*)bufferHolder.buff.buf, bufferHolder.buff.len, false); + } +#if (PY_VERSION_HEX >= 0x03030000) + else if (PyUnicode_Check($input)) + { + $1 = SPTAG::ByteArray((std::uint8_t*)PyUnicode_DATA($input), PyUnicode_GET_LENGTH($input), false); + } +#endif + + if (nullptr == $1.Data()) + { + PyErr_SetString(PyExc_ValueError, "Expected Bytes, Data Structure with Buffer Protocol, or Unicode String after Python 3.3 ."); + return NULL; + } +%} + +#endif diff --git a/core/src/index/thirdparty/SPTAG/Wrappers/inc/PythonCore.i b/core/src/index/thirdparty/SPTAG/Wrappers/inc/PythonCore.i new file mode 100644 index 0000000000..d2f38ca856 --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/Wrappers/inc/PythonCore.i @@ -0,0 +1,16 @@ +%module SPTAG + +%{ +#include "inc/CoreInterface.h" +%} + +%include +%shared_ptr(AnnIndex) +%shared_ptr(QueryResult) +%include "PythonCommon.i" + +%{ +#define SWIG_FILE_WITH_INIT +%} + +%include "CoreInterface.h" \ No newline at end of file diff --git a/core/src/index/thirdparty/SPTAG/Wrappers/inc/TransferDataType.h b/core/src/index/thirdparty/SPTAG/Wrappers/inc/TransferDataType.h new file mode 100644 index 0000000000..51ef9614ab --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/Wrappers/inc/TransferDataType.h @@ -0,0 +1,19 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifndef _SPTAG_PW_TRANSFERDATATYPE_H_ +#define _SPTAG_PW_TRANSFERDATATYPE_H_ + +#include "inc/Core/CommonDataStructure.h" +#include "inc/Core/SearchQuery.h" +#include "inc/Socket/RemoteSearchQuery.h" + +typedef SPTAG::ByteArray ByteArray; + +typedef SPTAG::QueryResult QueryResult; + +typedef SPTAG::BasicResult BasicResult; + +typedef SPTAG::Socket::RemoteSearchResult RemoteSearchResult; + +#endif // _SPTAG_PW_TRANSFERDATATYPE_H_ diff --git a/core/src/index/thirdparty/SPTAG/Wrappers/packages.config b/core/src/index/thirdparty/SPTAG/Wrappers/packages.config new file mode 100644 index 0000000000..d780ec4a8e --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/Wrappers/packages.config @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/core/src/index/thirdparty/SPTAG/Wrappers/src/AssemblyInfo.cpp b/core/src/index/thirdparty/SPTAG/Wrappers/src/AssemblyInfo.cpp new file mode 100644 index 0000000000..43759a83ef --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/Wrappers/src/AssemblyInfo.cpp @@ -0,0 +1,36 @@ +using namespace System; +using namespace System::Reflection; +using namespace System::Runtime::CompilerServices; +using namespace System::Runtime::InteropServices; +using namespace System::Security::Permissions; + +// +// General Information about an assembly is controlled through the following +// set of attributes. Change these attribute values to modify the information +// associated with an assembly. +// +[assembly:AssemblyTitleAttribute(L"CLRCore")]; +[assembly:AssemblyDescriptionAttribute(L"")]; +[assembly:AssemblyConfigurationAttribute(L"")]; +[assembly:AssemblyCompanyAttribute(L"")]; +[assembly:AssemblyProductAttribute(L"CLRCore")]; +[assembly:AssemblyCopyrightAttribute(L"Copyright (c) 2019")]; +[assembly:AssemblyTrademarkAttribute(L"")]; +[assembly:AssemblyCultureAttribute(L"")]; + +// +// Version information for an assembly consists of the following four values: +// +// Major Version +// Minor Version +// Build Number +// Revision +// +// You can specify all the value or you can default the Revision and Build Numbers +// by using the '*' as shown below: + +[assembly:AssemblyVersionAttribute("1.0.*")]; + +[assembly:ComVisible(false)]; + +[assembly:CLSCompliantAttribute(true)]; \ No newline at end of file diff --git a/core/src/index/thirdparty/SPTAG/Wrappers/src/CLRCoreInterface.cpp b/core/src/index/thirdparty/SPTAG/Wrappers/src/CLRCoreInterface.cpp new file mode 100644 index 0000000000..39e62baf44 --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/Wrappers/src/CLRCoreInterface.cpp @@ -0,0 +1,212 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "inc/CLRCoreInterface.h" + + +namespace Microsoft +{ + namespace ANN + { + namespace SPTAGManaged + { + AnnIndex::AnnIndex(std::shared_ptr p_index) : + ManagedObject(p_index) + { + m_dimension = p_index->GetFeatureDim(); + m_inputVectorSize = SPTAG::GetValueTypeSize(p_index->GetVectorValueType()) * m_dimension; + } + + AnnIndex::AnnIndex(String^ p_algoType, String^ p_valueType, int p_dimension) : + ManagedObject(SPTAG::VectorIndex::CreateInstance(string_to(p_algoType), string_to(p_valueType))) + { + m_dimension = p_dimension; + m_inputVectorSize = SPTAG::GetValueTypeSize((*m_Instance)->GetVectorValueType()) * m_dimension; + } + + void AnnIndex::SetBuildParam(String^ p_name, String^ p_value) + { + if (m_Instance != nullptr) + (*m_Instance)->SetParameter(string_to_char_array(p_name), string_to_char_array(p_value)); + } + + void AnnIndex::SetSearchParam(String^ p_name, String^ p_value) + { + if (m_Instance != nullptr) + (*m_Instance)->SetParameter(string_to_char_array(p_name), string_to_char_array(p_value)); + } + + bool AnnIndex::Build(array^ p_data, int p_num) + { + if (m_Instance == nullptr || p_num == 0 || m_dimension == 0 || p_data->LongLength != p_num * m_inputVectorSize) + return false; + + pin_ptr ptr = &p_data[0]; + return (SPTAG::ErrorCode::Success == (*m_Instance)->BuildIndex(ptr, p_num, m_dimension)); + } + + bool AnnIndex::BuildWithMetaData(array^ p_data, array^ p_meta, int p_num, bool p_withMetaIndex) + { + if (m_Instance == nullptr || p_num == 0 || m_dimension == 0 || p_data->LongLength != p_num * m_inputVectorSize) + return false; + + pin_ptr dataptr = &p_data[0]; + std::shared_ptr vectors(new SPTAG::BasicVectorSet(SPTAG::ByteArray(dataptr, p_data->LongLength, false), (*m_Instance)->GetVectorValueType(), m_dimension, p_num)); + + pin_ptr metaptr = &p_meta[0]; + std::uint64_t* offsets = new std::uint64_t[p_num + 1]{ 0 }; + int current = 0; + for (long long i = 0; i < p_meta->LongLength; i++) { + if (((char)metaptr[i]) == '\n') + offsets[++current] = (std::uint64_t)(i + 1); + } + std::shared_ptr meta(new SPTAG::MemMetadataSet(SPTAG::ByteArray(metaptr, p_meta->LongLength, false), SPTAG::ByteArray((std::uint8_t*)offsets, (p_num + 1) * sizeof(std::uint64_t), true), p_num)); + return (SPTAG::ErrorCode::Success == (*m_Instance)->BuildIndex(vectors, meta, p_withMetaIndex)); + } + + array^ AnnIndex::Search(array^ p_data, int p_resultNum) + { + array^ res; + if (m_Instance == nullptr || m_dimension == 0 || p_data->LongLength != m_inputVectorSize) + return res; + + pin_ptr ptr = &p_data[0]; + SPTAG::QueryResult results(ptr, p_resultNum, false); + (*m_Instance)->SearchIndex(results); + + res = gcnew array(p_resultNum); + for (int i = 0; i < p_resultNum; i++) + res[i] = gcnew BasicResult(new SPTAG::BasicResult(*(results.GetResult(i)))); + + return res; + } + + array^ AnnIndex::SearchWithMetaData(array^ p_data, int p_resultNum) + { + array^ res; + if (m_Instance == nullptr || m_dimension == 0 || p_data->LongLength != m_inputVectorSize) + return res; + + pin_ptr ptr = &p_data[0]; + SPTAG::QueryResult results(ptr, p_resultNum, true); + (*m_Instance)->SearchIndex(results); + + res = gcnew array(p_resultNum); + for (int i = 0; i < p_resultNum; i++) + res[i] = gcnew BasicResult(new SPTAG::BasicResult(*(results.GetResult(i)))); + + return res; + } + + bool AnnIndex::Save(String^ p_saveFile) + { + return SPTAG::ErrorCode::Success == (*m_Instance)->SaveIndex(string_to_char_array(p_saveFile)); + } + + array^>^ AnnIndex::Dump() + { + std::shared_ptr> buffersize = (*m_Instance)->CalculateBufferSize(); + array^>^ res = gcnew array^>(buffersize->size() + 1); + std::vector indexBlobs; + for (int i = 1; i < res->Length; i++) + { + res[i] = gcnew array(buffersize->at(i-1)); + pin_ptr ptr = &res[i][0]; + indexBlobs.push_back(SPTAG::ByteArray((std::uint8_t*)ptr, res[i]->LongLength, false)); + } + std::string config; + if (SPTAG::ErrorCode::Success != (*m_Instance)->SaveIndex(config, indexBlobs)) + { + array^>^ null; + return null; + } + res[0] = gcnew array(config.size()); + Marshal::Copy(IntPtr(&config[0]), res[0], 0, config.size()); + return res; + } + + bool AnnIndex::Add(array^ p_data, int p_num) + { + if (m_Instance == nullptr || p_num == 0 || m_dimension == 0 || p_data->LongLength != p_num * m_inputVectorSize) + return false; + + pin_ptr ptr = &p_data[0]; + return (SPTAG::ErrorCode::Success == (*m_Instance)->AddIndex(ptr, p_num, m_dimension)); + } + + bool AnnIndex::AddWithMetaData(array^ p_data, array^ p_meta, int p_num) + { + if (m_Instance == nullptr || p_num == 0 || m_dimension == 0 || p_data->LongLength != p_num * m_inputVectorSize) + return false; + + pin_ptr dataptr = &p_data[0]; + std::shared_ptr vectors(new SPTAG::BasicVectorSet(SPTAG::ByteArray(dataptr, p_data->LongLength, false), (*m_Instance)->GetVectorValueType(), m_dimension, p_num)); + + pin_ptr metaptr = &p_meta[0]; + std::uint64_t* offsets = new std::uint64_t[p_num + 1]{ 0 }; + int current = 0; + for (long long i = 0; i < p_meta->LongLength; i++) { + if (((char)metaptr[i]) == '\n') + offsets[++current] = (std::uint64_t)(i + 1); + } + std::shared_ptr meta(new SPTAG::MemMetadataSet(SPTAG::ByteArray(metaptr, p_meta->LongLength, false), SPTAG::ByteArray((std::uint8_t*)offsets, (p_num + 1) * sizeof(std::uint64_t), true), p_num)); + return (SPTAG::ErrorCode::Success == (*m_Instance)->AddIndex(vectors, meta)); + } + + bool AnnIndex::Delete(array^ p_data, int p_num) + { + if (m_Instance == nullptr || p_num == 0 || m_dimension == 0 || p_data->LongLength != p_num * m_inputVectorSize) + return false; + + pin_ptr ptr = &p_data[0]; + return (SPTAG::ErrorCode::Success == (*m_Instance)->DeleteIndex(ptr, p_num)); + } + + bool AnnIndex::DeleteByMetaData(array^ p_meta) + { + if (m_Instance == nullptr) + return false; + + pin_ptr metaptr = &p_meta[0]; + return (SPTAG::ErrorCode::Success == (*m_Instance)->DeleteIndex(SPTAG::ByteArray(metaptr, p_meta->LongLength, false))); + } + + AnnIndex^ AnnIndex::Load(String^ p_loaderFile) + { + std::shared_ptr vecIndex; + AnnIndex^ res; + if (SPTAG::ErrorCode::Success != SPTAG::VectorIndex::LoadIndex(string_to_char_array(p_loaderFile), vecIndex) || nullptr == vecIndex) + { + res = gcnew AnnIndex(nullptr); + } + else { + res = gcnew AnnIndex(vecIndex); + } + return res; + } + + AnnIndex^ AnnIndex::Load(array^>^ p_index) + { + std::vector p_indexBlobs; + for (int i = 1; i < p_index->Length; i++) + { + pin_ptr ptr = &p_index[i][0]; + p_indexBlobs.push_back(SPTAG::ByteArray((std::uint8_t*)ptr, p_index[i]->LongLength, false)); + } + pin_ptr configptr = &p_index[0][0]; + + std::shared_ptr vecIndex; + if (SPTAG::ErrorCode::Success != SPTAG::VectorIndex::LoadIndex(std::string((char*)configptr, p_index[0]->LongLength), p_indexBlobs, vecIndex) || nullptr == vecIndex) + { + return gcnew AnnIndex(nullptr); + } + return gcnew AnnIndex(vecIndex); + } + + bool AnnIndex::Merge(String^ p_indexFilePath1, String^ p_indexFilePath2) + { + return (SPTAG::ErrorCode::Success == SPTAG::VectorIndex::MergeIndex(string_to_char_array(p_indexFilePath1), string_to_char_array(p_indexFilePath2))); + } + } + } +} \ No newline at end of file diff --git a/core/src/index/thirdparty/SPTAG/Wrappers/src/ClientInterface.cpp b/core/src/index/thirdparty/SPTAG/Wrappers/src/ClientInterface.cpp new file mode 100644 index 0000000000..65a1d4cf17 --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/Wrappers/src/ClientInterface.cpp @@ -0,0 +1,250 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "inc/ClientInterface.h" +#include "inc/Helper/CommonHelper.h" +#include "inc/Helper/Concurrent.h" +#include "inc/Helper/Base64Encode.h" +#include "inc/Helper/StringConvert.h" + +#include + + +AnnClient::AnnClient(const char* p_serverAddr, const char* p_serverPort) + : m_connectionID(SPTAG::Socket::c_invalidConnectionID), + m_timeoutInMilliseconds(9000) +{ + using namespace SPTAG; + + m_socketClient.reset(new Socket::Client(GetHandlerMap(), 2, 30)); + + if (nullptr == p_serverAddr || nullptr == p_serverPort) + { + return; + } + + m_server = p_serverAddr; + m_port = p_serverPort; + + auto connectCallback = [this](Socket::ConnectionID p_cid, ErrorCode p_ec) + { + m_connectionID = p_cid; + + if (ErrorCode::Socket_FailedResolveEndPoint == p_ec) + { + return; + } + + while (Socket::c_invalidConnectionID == m_connectionID) + { + ErrorCode errCode; + std::this_thread::sleep_for(std::chrono::seconds(10)); + m_connectionID = m_socketClient->ConnectToServer(m_server, m_port, errCode); + } + }; + + m_socketClient->AsyncConnectToServer(m_server, m_port, connectCallback); + + m_socketClient->SetEventOnConnectionClose([this](Socket::ConnectionID p_cid) + { + ErrorCode errCode; + m_connectionID = Socket::c_invalidConnectionID; + while (Socket::c_invalidConnectionID == m_connectionID) + { + std::this_thread::sleep_for(std::chrono::seconds(10)); + m_connectionID = m_socketClient->ConnectToServer(m_server, m_port, errCode); + } + }); +} + + +AnnClient::~AnnClient() +{ +} + + +void +AnnClient::SetTimeoutMilliseconds(int p_timeout) +{ + m_timeoutInMilliseconds = p_timeout; +} + + +void +AnnClient::SetSearchParam(const char* p_name, const char* p_value) +{ + std::lock_guard guard(m_paramMutex); + + if (nullptr == p_name || '\0' == *p_name) + { + return; + } + + std::string name(p_name); + SPTAG::Helper::StrUtils::ToLowerInPlace(name); + + if (nullptr == p_value || '\0' == *p_value) + { + m_params.erase(name); + return; + } + + m_params[name] = p_value; +} + + +void +AnnClient::ClearSearchParam() +{ + std::lock_guard guard(m_paramMutex); + m_params.clear(); +} + + +std::shared_ptr +AnnClient::Search(ByteArray p_data, int p_resultNum, const char* p_valueType, bool p_withMetaData) +{ + using namespace SPTAG; + + SPTAG::Socket::RemoteSearchResult ret; + if (Socket::c_invalidConnectionID != m_connectionID) + { + + auto signal = std::make_shared(1); + + auto callback = [&ret, signal](RemoteSearchResult p_result) + { + if (RemoteSearchResult::ResultStatus::Success == p_result.m_status) + { + ret = std::move(p_result); + } + + signal->FinishOne(); + }; + + auto timeoutCallback = [this](std::shared_ptr p_callback) + { + if (nullptr != p_callback) + { + RemoteSearchResult result; + result.m_status = RemoteSearchResult::ResultStatus::Timeout; + + (*p_callback)(std::move(result)); + } + }; + + auto connectCallback = [callback, this](bool p_connectSucc) + { + if (!p_connectSucc) + { + RemoteSearchResult result; + result.m_status = RemoteSearchResult::ResultStatus::FailedNetwork; + + callback(std::move(result)); + } + }; + + Socket::Packet packet; + packet.Header().m_connectionID = Socket::c_invalidConnectionID; + packet.Header().m_packetType = Socket::PacketType::SearchRequest; + packet.Header().m_processStatus = Socket::PacketProcessStatus::Ok; + packet.Header().m_resourceID = m_callbackManager.Add(std::make_shared(std::move(callback)), + m_timeoutInMilliseconds, + std::move(timeoutCallback)); + + Socket::RemoteQuery query; + SPTAG::VectorValueType valueType; + SPTAG::Helper::Convert::ConvertStringTo(p_valueType, valueType); + query.m_queryString = CreateSearchQuery(p_data, p_resultNum, p_withMetaData, valueType); + + packet.Header().m_bodyLength = static_cast(query.EstimateBufferSize()); + packet.AllocateBuffer(packet.Header().m_bodyLength); + query.Write(packet.Body()); + packet.Header().WriteBuffer(packet.HeaderBuffer()); + + m_socketClient->SendPacket(m_connectionID, std::move(packet), connectCallback); + + signal->Wait(); + } + return std::make_shared(ret); +} + + +bool +AnnClient::IsConnected() const +{ + return m_connectionID != SPTAG::Socket::c_invalidConnectionID; +} + + +SPTAG::Socket::PacketHandlerMapPtr +AnnClient::GetHandlerMap() +{ + using namespace SPTAG; + + Socket::PacketHandlerMapPtr handlerMap(new Socket::PacketHandlerMap); + handlerMap->emplace(Socket::PacketType::SearchResponse, + std::bind(&AnnClient::SearchResponseHanlder, + this, + std::placeholders::_1, + std::placeholders::_2)); + + return handlerMap; +} + + +void +AnnClient::SearchResponseHanlder(SPTAG::Socket::ConnectionID p_localConnectionID, + SPTAG::Socket::Packet p_packet) +{ + using namespace SPTAG; + + std::shared_ptr callback = m_callbackManager.GetAndRemove(p_packet.Header().m_resourceID); + if (nullptr == callback) + { + return; + } + + if (p_packet.Header().m_processStatus != Socket::PacketProcessStatus::Ok || 0 == p_packet.Header().m_bodyLength) + { + Socket::RemoteSearchResult result; + result.m_status = Socket::RemoteSearchResult::ResultStatus::FailedExecute; + + (*callback)(std::move(result)); + } + else + { + Socket::RemoteSearchResult result; + result.Read(p_packet.Body()); + (*callback)(std::move(result)); + } +} + + +std::string +AnnClient::CreateSearchQuery(const ByteArray& p_data, + int p_resultNum, + bool p_extractMetadata, + SPTAG::VectorValueType p_valueType) +{ + std::stringstream out; + + out << "#"; + std::size_t encLen; + SPTAG::Helper::Base64::Encode(p_data.Data(), p_data.Length(), out, encLen); + + out << " $datatype:" << SPTAG::Helper::Convert::ConvertToString(p_valueType); + out << " $resultnum:" << std::to_string(p_resultNum); + out << " $extractmetadata:" << (p_extractMetadata ? "true" : "false"); + + { + std::lock_guard guard(m_paramMutex); + for (const auto& param : m_params) + { + out << " $" << param.first << ":" << param.second; + } + } + + return out.str(); +} + diff --git a/core/src/index/thirdparty/SPTAG/Wrappers/src/CoreInterface.cpp b/core/src/index/thirdparty/SPTAG/Wrappers/src/CoreInterface.cpp new file mode 100644 index 0000000000..5a62fe0315 --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/Wrappers/src/CoreInterface.cpp @@ -0,0 +1,230 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "inc/CoreInterface.h" +#include "inc/Helper/StringConvert.h" + + +AnnIndex::AnnIndex(DimensionType p_dimension) + : m_algoType(SPTAG::IndexAlgoType::BKT), + m_inputValueType(SPTAG::VectorValueType::Float), + m_dimension(p_dimension) +{ + m_inputVectorSize = SPTAG::GetValueTypeSize(m_inputValueType) * m_dimension; +} + + +AnnIndex::AnnIndex(const char* p_algoType, const char* p_valueType, DimensionType p_dimension) + : m_algoType(SPTAG::IndexAlgoType::Undefined), + m_inputValueType(SPTAG::VectorValueType::Undefined), + m_dimension(p_dimension) +{ + SPTAG::Helper::Convert::ConvertStringTo(p_algoType, m_algoType); + SPTAG::Helper::Convert::ConvertStringTo(p_valueType, m_inputValueType); + m_inputVectorSize = SPTAG::GetValueTypeSize(m_inputValueType) * m_dimension; +} + + +AnnIndex::AnnIndex(const std::shared_ptr& p_index) + : m_algoType(p_index->GetIndexAlgoType()), + m_inputValueType(p_index->GetVectorValueType()), + m_dimension(p_index->GetFeatureDim()), + m_index(p_index) +{ + m_inputVectorSize = SPTAG::GetValueTypeSize(m_inputValueType) * m_dimension; +} + + +AnnIndex::~AnnIndex() +{ +} + + +bool +AnnIndex::Build(ByteArray p_data, SizeType p_num) +{ + if (nullptr == m_index) + { + m_index = SPTAG::VectorIndex::CreateInstance(m_algoType, m_inputValueType); + } + if (nullptr == m_index || p_num == 0 || m_dimension == 0 || p_data.Length() != p_num * m_inputVectorSize) + { + return false; + } + return (SPTAG::ErrorCode::Success == m_index->BuildIndex(p_data.Data(), (SPTAG::SizeType)p_num, (SPTAG::DimensionType)m_dimension)); +} + + +bool +AnnIndex::BuildWithMetaData(ByteArray p_data, ByteArray p_meta, SizeType p_num, bool p_withMetaIndex) +{ + if (nullptr == m_index) + { + m_index = SPTAG::VectorIndex::CreateInstance(m_algoType, m_inputValueType); + } + if (nullptr == m_index || p_num == 0 || m_dimension == 0 || p_data.Length() != p_num * m_inputVectorSize) + { + return false; + } + + std::shared_ptr vectors(new SPTAG::BasicVectorSet(p_data, + m_inputValueType, + static_cast(m_dimension), + static_cast(p_num))); + + std::uint64_t* offsets = new std::uint64_t[p_num + 1]{ 0 }; + SizeType current = 1; + for (size_t i = 0; i < p_meta.Length(); i++) { + if (((char)p_meta.Data()[i]) == '\n') + offsets[current++] = (std::uint64_t)(i + 1); + } + std::shared_ptr meta(new SPTAG::MemMetadataSet(p_meta, ByteArray((std::uint8_t*)offsets, (p_num + 1) * sizeof(std::uint64_t), true), (SPTAG::SizeType)p_num)); + return (SPTAG::ErrorCode::Success == m_index->BuildIndex(vectors, meta, p_withMetaIndex)); +} + + +void +AnnIndex::SetBuildParam(const char* p_name, const char* p_value) +{ + if (nullptr == m_index) + { + if (SPTAG::IndexAlgoType::Undefined == m_algoType || + SPTAG::VectorValueType::Undefined == m_inputValueType) + { + return; + } + m_index = SPTAG::VectorIndex::CreateInstance(m_algoType, m_inputValueType); + + } + m_index->SetParameter(p_name, p_value); +} + + +void +AnnIndex::SetSearchParam(const char* p_name, const char* p_value) +{ + if (nullptr != m_index) m_index->SetParameter(p_name, p_value); +} + + +std::shared_ptr +AnnIndex::Search(ByteArray p_data, int p_resultNum) +{ + std::shared_ptr results = std::make_shared(p_data.Data(), p_resultNum, false); + + if (nullptr != m_index && p_data.Length() == m_inputVectorSize) + { + m_index->SearchIndex(*results); + } + return std::move(results); +} + +std::shared_ptr +AnnIndex::SearchWithMetaData(ByteArray p_data, int p_resultNum) +{ + std::shared_ptr results = std::make_shared(p_data.Data(), p_resultNum, true); + + if (nullptr != m_index && p_data.Length() == m_inputVectorSize) + { + m_index->SearchIndex(*results); + } + return std::move(results); +} + +bool +AnnIndex::ReadyToServe() const +{ + return m_index != nullptr; +} + + +bool +AnnIndex::Save(const char* p_savefile) const +{ + return SPTAG::ErrorCode::Success == m_index->SaveIndex(p_savefile); +} + + +AnnIndex +AnnIndex::Load(const char* p_loaderFile) +{ + std::shared_ptr vecIndex; + auto ret = SPTAG::VectorIndex::LoadIndex(p_loaderFile, vecIndex); + if (SPTAG::ErrorCode::Success != ret || nullptr == vecIndex) + { + return AnnIndex(0); + } + + return AnnIndex(vecIndex); +} + + +bool +AnnIndex::Add(ByteArray p_data, SizeType p_num) +{ + if (nullptr == m_index) + { + m_index = SPTAG::VectorIndex::CreateInstance(m_algoType, m_inputValueType); + } + if (nullptr == m_index || p_num == 0 || m_dimension == 0 || p_data.Length() != p_num * m_inputVectorSize) + { + return false; + } + return (SPTAG::ErrorCode::Success == m_index->AddIndex(p_data.Data(), (SPTAG::SizeType)p_num, (SPTAG::DimensionType)m_dimension)); +} + + +bool +AnnIndex::AddWithMetaData(ByteArray p_data, ByteArray p_meta, SizeType p_num) +{ + if (nullptr == m_index) + { + m_index = SPTAG::VectorIndex::CreateInstance(m_algoType, m_inputValueType); + } + if (nullptr == m_index || p_num == 0 || m_dimension == 0 || p_data.Length() != p_num * m_inputVectorSize) + { + return false; + } + + std::shared_ptr vectors(new SPTAG::BasicVectorSet(p_data, + m_inputValueType, + static_cast(m_dimension), + static_cast(p_num))); + + std::uint64_t* offsets = new std::uint64_t[p_num + 1]{ 0 }; + SizeType current = 1; + for (size_t i = 0; i < p_meta.Length(); i++) { + if (((char)p_meta.Data()[i]) == '\n') + offsets[current++] = (std::uint64_t)(i + 1); + } + std::shared_ptr meta(new SPTAG::MemMetadataSet(p_meta, ByteArray((std::uint8_t*)offsets, (p_num + 1) * sizeof(std::uint64_t), true), (SPTAG::SizeType)p_num)); + return (SPTAG::ErrorCode::Success == m_index->AddIndex(vectors, meta)); +} + + +bool +AnnIndex::Delete(ByteArray p_data, SizeType p_num) +{ + if (nullptr == m_index || p_num == 0 || m_dimension == 0 || p_data.Length() != p_num * m_inputVectorSize) + { + return false; + } + + return (SPTAG::ErrorCode::Success == m_index->DeleteIndex(p_data.Data(), (SPTAG::SizeType)p_num)); +} + + +bool +AnnIndex::DeleteByMetaData(ByteArray p_meta) +{ + if (nullptr == m_index) return false; + + return (SPTAG::ErrorCode::Success == m_index->DeleteIndex(p_meta)); +} + + +bool +AnnIndex::Merge(const char* p_indexFilePath1, const char* p_indexFilePath2) +{ + return (SPTAG::ErrorCode::Success == SPTAG::VectorIndex::MergeIndex(p_indexFilePath1, p_indexFilePath2)); +} \ No newline at end of file diff --git a/core/src/index/thirdparty/SPTAG/azure-pipelines.yml b/core/src/index/thirdparty/SPTAG/azure-pipelines.yml new file mode 100644 index 0000000000..22f697bdd4 --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/azure-pipelines.yml @@ -0,0 +1,230 @@ +resources: + +- repo: self + +phases: + +- phase: Phase_1 + + displayName: Agent job + + + + condition: succeeded() + + queue: + + name: SPTAGBuild + + steps: + + - script: | + mkdir build + cd build + cmake .. + make + cd ../Release + ./test + + displayName: 'Command Line Script' + + + + + +- phase: Phase_2 + + displayName: Agent job + + + + condition: succeeded() + + queue: + + name: Hosted + + demands: + + - msbuild + + - visualstudio + + + + steps: + + - task: NuGetToolInstaller@0 + + displayName: 'Use NuGet 4.3.0' + + + + + + - task: NuGetCommand@2 + + displayName: 'NuGet restore' + + inputs: + + restoreSolution: SPTAG.sln + + + + vstsFeed: 'bae5097f-8d64-4f8f-913e-24a4eb8302c3' + + + + + + - task: VSBuild@1 + + displayName: 'Build solution SPTAG.sln' + + inputs: + + solution: SPTAG.sln + + + + vsVersion: 14.0 + + + + platform: x64 + + + + configuration: debug + + + + msbuildArchitecture: x64 + + + + createLogFile: true + + + + + + - script: '.\x64\Debug\Test.exe' + + displayName: 'Command Line Script' + + + + - task: CopyFiles@2 + + displayName: 'Copy Files to: $(Build.ArtifactStagingDirectory)' + + inputs: + + SourceFolder: x64/Debug/ + + + + Contents: '*' + + + + TargetFolder: '$(Build.ArtifactStagingDirectory)' + + + + + + - task: PublishBuildArtifacts@1 + + displayName: 'Publish Artifact: drop' + + + + + + - task: securedevelopmentteam.vss-secure-development-tools.build-task-binskim.BinSkim@3 + + displayName: 'Run BinSkim ' + + inputs: + + InputType: Basic + + + + + + - task: securedevelopmentteam.vss-secure-development-tools.build-task-credscan.CredScan@2 + + displayName: 'Run CredScan' + + inputs: + + scanFolder: AnnService + + + + suppressAsError: true + + + + verboseOutput: true + + + + debugMode: false + + + + + + - task: securedevelopmentteam.vss-secure-development-tools.build-task-autoapplicability.AutoApplicability@1 + + displayName: 'Run AutoApplicability' + + inputs: + + IsSoftware: true + + + + + + - task: securedevelopmentteam.vss-secure-development-tools.build-task-publishsecurityanalysislogs.PublishSecurityAnalysisLogs@2 + + displayName: 'Publish Security Analysis Logs' + + + + + + - task: securedevelopmentteam.vss-secure-development-tools.build-task-report.SdtReport@1 + + displayName: 'Create Security Analysis Report' + + inputs: + + AllTools: true + + + + + + - task: securedevelopmentteam.vss-secure-development-tools.build-task-postanalysis.PostAnalysis@1 + + displayName: 'Post Analysis' + + inputs: + + AllTools: true + + + + + + - task: ms.vss-governance-buildtask.governance-build-task-component-detection.ComponentGovernanceComponentDetection@0 + + displayName: 'Component Detection' + diff --git a/core/src/index/thirdparty/SPTAG/docs/GettingStart.md b/core/src/index/thirdparty/SPTAG/docs/GettingStart.md new file mode 100644 index 0000000000..9f82b680ac --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/docs/GettingStart.md @@ -0,0 +1,286 @@ +## **Quick start** + +### **Index Build** + ```bash + Usage: + ./IndexBuiler [options] + Options: + -d, --dimension Dimension of vector, required. + -v, --vectortype Input vector data type (e.g. Float, Int8, Int16), required. + -i, --input Input raw data, required. + -o, --outputfolder Output folder, required. + -a, --algo Index Algorithm type (e.g. BKT, KDT), required. + + -t, --thread Thread Number, default is 32. + --delimiter Vector delimiter, default is |. + Index.= Set the algorithm parameter ArgName with value ArgValue. + ``` + + ### **Index Search** + ```bash + Usage: + ./IndexSearcher [options] + Options + Index.QueryFile=XXX Input Query file + Index.ResultFile=XXX Output result file + Index.TruthFile=XXX Truth file that can help to calculate the recall + Index.K=XXX How many nearest neighbors return + Index.MaxCheck=XXX The maxcheck of the search + ``` + +### ** Input File format ** +> Input raw data for index build and input query file for index search (suppose vector dimension is 3): +``` +\t||| +\t||| +... +``` +where each line represents a vector with its metadata and its value separated by a tab space. Each dimension of a vector is separated by | or use --delimiter to define the separator. + +> Truth file to calculate recall (suppose K is 2): +``` + + +... +``` +where each line represents the K nearest neighbors of a query separated by a blank space. Each neighbor is given by its vector id. + +### **Server** +```bash +Usage: +./Server [options] +Options: + -m, --mode Service mode, interactive or socket. + -c, --config Configure file of the index + +Write a server configuration file service.ini as follows: + +[Service] +ListenAddr=0.0.0.0 +ListenPort=8000 +ThreadNumber=8 +SocketThreadNumber=8 + +[QueryConfig] +DefaultMaxResultNumber=6 +DefaultSeparator=| + +[Index] +List=BKT + +[Index_BKT] +IndexFolder=BKT_gist +``` + +### **Client** +```bash +Usage: +./Client [options] +Options: +-s, --server Server address +-p, --port Server port +-t, Search timeout +-cth, Client Thread Number +-sth Socket Thread Number +``` + +### **Aggregator** +```bash +Usage: +./Aggregator + +Write Aggregator.ini as follows: + +[Service] +ListenAddr=0.0.0.0 +ListenPort=8100 +ThreadNumber=8 +SocketThreadNumber=8 + +[Servers] +Number=2 + +[Server_0] +Address=127.0.0.1 +Port=8000 + +[Server_1] +Address=127.0.0.1 +Port=8010 +``` + +### **Python Support** +> Singlebox PythonWrapper + ```python + +import SPTAG +import numpy as np + +n = 100 +k = 3 +r = 3 + +def testBuild(algo, distmethod, x, out): + i = SPTAG.AnnIndex(algo, 'Float', x.shape[1]) + i.SetBuildParam("NumberOfThreads", '4') + i.SetBuildParam("DistCalcMethod", distmethod) + ret = i.Build(x, x.shape[0]) + i.Save(out) + +def testBuildWithMetaData(algo, distmethod, x, s, out): + i = SPTAG.AnnIndex(algo, 'Float', x.shape[1]) + i.SetBuildParam("NumberOfThreads", '4') + i.SetBuildParam("DistCalcMethod", distmethod) + if i.BuildWithMetaData(x, s, x.shape[0]): + i.Save(out) + +def testSearch(index, q, k): + j = SPTAG.AnnIndex.Load(index) + for t in range(q.shape[0]): + result = j.Search(q[t], k) + print (result[0]) # ids + print (result[1]) # distances + +def testSearchWithMetaData(index, q, k): + j = SPTAG.AnnIndex.Load(index) + j.SetSearchParam("MaxCheck", '1024') + for t in range(q.shape[0]): + result = j.SearchWithMetaData(q[t], k) + print (result[0]) # ids + print (result[1]) # distances + print (result[2]) # metadata + +def testAdd(index, x, out, algo, distmethod): + if index != None: + i = SPTAG.AnnIndex.Load(index) + else: + i = SPTAG.AnnIndex(algo, 'Float', x.shape[1]) + i.SetBuildParam("NumberOfThreads", '4') + i.SetBuildParam("DistCalcMethod", distmethod) + if i.Add(x, x.shape[0]): + i.Save(out) + +def testAddWithMetaData(index, x, s, out, algo, distmethod): + if index != None: + i = SPTAG.AnnIndex.Load(index) + else: + i = SPTAG.AnnIndex(algo, 'Float', x.shape[1]) + i = SPTAG.AnnIndex(algo, 'Float', x.shape[1]) + i.SetBuildParam("NumberOfThreads", '4') + i.SetBuildParam("DistCalcMethod", distmethod) + if i.AddWithMetaData(x, s, x.shape[0]): + i.Save(out) + +def testDelete(index, x, out): + i = SPTAG.AnnIndex.Load(index) + ret = i.Delete(x, x.shape[0]) + print (ret) + i.Save(out) + +def Test(algo, distmethod): + x = np.ones((n, 10), dtype=np.float32) * np.reshape(np.arange(n, dtype=np.float32), (n, 1)) + q = np.ones((r, 10), dtype=np.float32) * np.reshape(np.arange(r, dtype=np.float32), (r, 1)) * 2 + m = '' + for i in range(n): + m += str(i) + '\n' + + m = m.encode() + + print ("Build.............................") + testBuild(algo, distmethod, x, 'testindices') + testSearch('testindices', q, k) + print ("Add.............................") + testAdd('testindices', x, 'testindices', algo, distmethod) + testSearch('testindices', q, k) + print ("Delete.............................") + testDelete('testindices', q, 'testindices') + testSearch('testindices', q, k) + + print ("AddWithMetaData.............................") + testAddWithMetaData(None, x, m, 'testindices', algo, distmethod) + print ("Delete.............................") + testSearchWithMetaData('testindices', q, k) + testDelete('testindices', q, 'testindices') + testSearchWithMetaData('testindices', q, k) + +if __name__ == '__main__': + Test('BKT', 'L2') + Test('KDT', 'L2') + + ``` + + > Python Client Wrapper, Suppose there is a sever run at 127.0.0.1:8000 serving ten-dimensional vector datasets: + ```python +import SPTAGClient +import numpy as np +import time + +def testSPTAGClient(): + index = SPTAGClient.AnnClient('127.0.0.1', '8100') + while not index.IsConnected(): + time.sleep(1) + index.SetTimeoutMilliseconds(18000) + + q = np.ones((10, 10), dtype=np.float32) + for t in range(q.shape[0]): + result = index.Search(q[t], 6, 'Float', False) + print (result[0]) + print (result[1]) + +if __name__ == '__main__': + testSPTAGClient() + + ``` + + ### **C# Support** +> Singlebox CsharpWrapper + ```C# +using System; +using System.Text; + +public class test +{ + static int dimension = 10; + static int n = 10; + static int k = 3; + + static byte[] createFloatArray(int n) + { + byte[] data = new byte[n * dimension * sizeof(float)]; + for (int i = 0; i < n; i++) + for (int j = 0; j < dimension; j++) + Array.Copy(BitConverter.GetBytes((float)i), 0, data, (i * dimension + j) * sizeof(float), 4); + return data; + } + + static byte[] createMetadata(int n) + { + StringBuilder sb = new StringBuilder(); + for (int i = 0; i < n; i++) + sb.Append(i.ToString() + '\n'); + return Encoding.ASCII.GetBytes(sb.ToString()); + } + + static void Main() + { + { + AnnIndex idx = new AnnIndex("BKT", "Float", dimension); + idx.SetBuildParam("DistCalcMethod", "L2"); + byte[] data = createFloatArray(n); + byte[] meta = createMetadata(n); + idx.BuildWithMetaData(data, meta, n); + idx.Save("testcsharp"); + } + + AnnIndex index = AnnIndex.Load("testcsharp"); + BasicResult[] res = index.SearchWithMetaData(createFloatArray(1), k); + for (int i = 0; i < res.Length; i++) + Console.WriteLine("result " + i.ToString() + ":" + res[i].Dist.ToString() + "@(" + res[i].VID.ToString() + "," + Encoding.ASCII.GetString(res[i].Meta) + ")"); + Console.WriteLine("test finish!"); + } +} + + ``` + + + \ No newline at end of file diff --git a/core/src/index/thirdparty/SPTAG/docs/Parameters.md b/core/src/index/thirdparty/SPTAG/docs/Parameters.md new file mode 100644 index 0000000000..9e2fa93715 --- /dev/null +++ b/core/src/index/thirdparty/SPTAG/docs/Parameters.md @@ -0,0 +1,159 @@ +## **Parameters** + +> Common Parameters + +| ParametersName | type | default | definition| +|---|---|---|---| +| Samples | int | 1000 | how many points will be sampled to do tree node split | +|TPTNumber | int | 32 | number of TPT trees to help with graph construction | +|TPTLeafSize | int | 2000 | TPT tree leaf size | +NeighborhoodSize | int | 32 | number of neighbors each node has in the neighborhood graph | +|GraphNeighborhoodScale | int | 2 | number of neighborhood size scale in the build stage | +|CEF | int | 1000 | number of results used to construct RNG | +|MaxCheckForRefineGraph| int | 10000 | how many nodes each node will visit during graph refine in the build stage | +|NumberOfThreads | int | 1 | number of threads to uses for speed up the build | +|DistCalcMethod | string | Cosine | choose from Cosine and L2 | +|MaxCheck | int | 8192 | how many nodes will be visited for a query in the search stage + +> BKT + +| ParametersName | type | default | definition| +|---|---|---|---| +| BKTNumber | int | 1 | number of BKT trees | +| BKTKMeansK | int | 32 | how many childs each tree node has | + +> KDT + +| ParametersName | type | default | definition| +|---|---|---|---| +| KDTNumber | int | 1 | number of KDT trees | + +> Parameters that will affect the index size +* NeighborhoodSize +* BKTNumber +* KDTNumber + +> Parameters that will affect the index build time +* NumberOfThreads +* TPTNumber +* TPTLeafSize +* GraphNeighborhoodScale +* CEF +* MaxCheckForRefineGraph + +> Parameters that will affect the index quality +* TPTNumber +* TPTLeafSize +* GraphNeighborhoodScale +* CEF +* MaxCheckForRefineGraph +* NeighborhoodSize +* KDTNumber + +> Parameters that will affect search latency and recall +* MaxCheck + +## **NNI for parameters tuning** + +Prepare vector data file **data.tsv**, query data file **query.tsv**, and truth file **truth.txt** following the format introduced in the [Get Started](GettingStart.md). + +Install [microsoft nni](https://github.com/microsoft/nni) and write the following python code (nni_sptag.py), parameter search space configuration (search_space.json) and nni environment configuration (config.yml). + +> nni_sptag.py + +```Python +import nni +import os + +vector_dimension = 10 +vector_type = 'Float' +index_algo = 'BKT' +threads = 32 +k = 3 + +def main(): + para = nni.get_next_parameter() + cmd_build = "./indexbuilder -d %d -v %s -i data.tsv -o index -a %s -t %d " % (vector_dimension, vector_type, index_algo, threads) + for p, v in para.items(): + cmd_build += "Index." + p + "=" + str(v) + cmd_test = "./indexsearcher index Index.QueryFile=query.tsv Index.TruthFile=truth.txt Index.K=%d" % (k) + os.system(cmd_build) + os.system(cmd_test + " > out.txt") + with open("out.txt", "r") as fd: + lines = fd.readlines() + res = lines[-2] + segs = res.split() + recall = float(segs[-2]) + avg_latency = float(segs[-5]) + score = recall + nni.report_final_result(score) + +if __name__ == '__main__': + main() +``` +> search_space.json + +```json +{ + "BKTKmeansK": {"_type": "choice", "_value": [2, 4, 8, 16, 32]}, + "GraphNeighborhoodScale": {"_type": "choice", "_value": [2, 4, 8, 16, 32]} +} + +``` + +> config.yml + +```yaml +authorName: default + +experimentName: example_sptag + +trialConcurrency: 1 + +maxExecDuration: 1h + +maxTrialNum: 10 + +#choice: local, remote, pai + +trainingServicePlatform: local + +searchSpacePath: search_space.json + +#choice: true, false + +useAnnotation: false + +tuner: + + #choice: TPE, Random, Anneal, Evolution, BatchTuner, MetisTuner + + #SMAC (SMAC should be installed through nnictl) + + builtinTunerName: TPE + + classArgs: + + #choice: maximize, minimize + + optimize_mode: maximize + +trial: + + command: python3 nni_sptag.py + + codeDir: . + + gpuNum: 0 + +``` + +Then start the tuning (tunning results can be found in the Web UI urls in the command output): +```bash +nnictl create --config config.yml +``` + +stop the tunning: +```bash +nnictl stop +``` \ No newline at end of file diff --git a/core/src/index/thirdparty/SPTAG/docs/img/sptag.png b/core/src/index/thirdparty/SPTAG/docs/img/sptag.png new file mode 100644 index 0000000000..dc21bdf3dc Binary files /dev/null and b/core/src/index/thirdparty/SPTAG/docs/img/sptag.png differ diff --git a/core/src/index/thirdparty/annoy/LICENSE b/core/src/index/thirdparty/annoy/LICENSE new file mode 100644 index 0000000000..d645695673 --- /dev/null +++ b/core/src/index/thirdparty/annoy/LICENSE @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/core/src/index/thirdparty/annoy/RELEASE.md b/core/src/index/thirdparty/annoy/RELEASE.md new file mode 100644 index 0000000000..c3a1147ce9 --- /dev/null +++ b/core/src/index/thirdparty/annoy/RELEASE.md @@ -0,0 +1,15 @@ +How to release +-------------- + +1. Make sure you're on master. `git checkout master && git fetch && git reset --hard origin/master` +1. Update `setup.py` to the newest version, `git add setup.py && git commit -m "version 1.2.3"` +1. `python setup.py sdist bdist_wheel` +1. `git tag -a v1.2.3 -m "version 1.2.3"` +1. `git push --tags origin master` to push the last version to Github +1. Go to https://github.com/spotify/annoy/releases and click "Draft a new release" +1. `twine upload dist/annoy-1.2.3*` + +TODO +---- + +* Wheel diff --git a/core/src/index/thirdparty/annoy/examples/mmap_test.py b/core/src/index/thirdparty/annoy/examples/mmap_test.py new file mode 100644 index 0000000000..4f86e86713 --- /dev/null +++ b/core/src/index/thirdparty/annoy/examples/mmap_test.py @@ -0,0 +1,14 @@ +from annoy import AnnoyIndex + +a = AnnoyIndex(3, 'angular') +a.add_item(0, [1, 0, 0]) +a.add_item(1, [0, 1, 0]) +a.add_item(2, [0, 0, 1]) +a.build(-1) +a.save('test.tree') + +b = AnnoyIndex(3) +b.load('test.tree') + +print(b.get_nns_by_item(0, 100)) +print(b.get_nns_by_vector([1.0, 0.5, 0.5], 100)) diff --git a/core/src/index/thirdparty/annoy/examples/precision_test.cpp b/core/src/index/thirdparty/annoy/examples/precision_test.cpp new file mode 100644 index 0000000000..2c006487c9 --- /dev/null +++ b/core/src/index/thirdparty/annoy/examples/precision_test.cpp @@ -0,0 +1,176 @@ +/* + * precision_test.cpp + + * + * Created on: Jul 13, 2016 + * Author: Claudio Sanhueza + * Contact: csanhuezalobos@gmail.com + */ + +#include +#include +#include "../src/kissrandom.h" +#include "../src/annoylib.h" +#include +#include +#include +#include + + +int precision(int f=40, int n=1000000){ + std::chrono::high_resolution_clock::time_point t_start, t_end; + + std::default_random_engine generator; + std::normal_distribution distribution(0.0, 1.0); + + //****************************************************** + //Building the tree + AnnoyIndex t = AnnoyIndex(f); + + std::cout << "Building index ... be patient !!" << std::endl; + std::cout << "\"Trees that are slow to grow bear the best fruit\" (Moliere)" << std::endl; + + + + for(int i=0; i( t_end - t_start ).count(); + std::cout << " Done in "<< duration << " secs." << std::endl; + + + std::cout << "Saving index ..."; + t.save("precision.tree"); + std::cout << " Done" << std::endl; + + + + //****************************************************** + std::vector limits = {10, 100, 1000, 10000}; + int K=10; + int prec_n = 1000; + + std::map prec_sum; + std::map time_sum; + std::vector closest; + + //init precision and timers map + for(std::vector::iterator it = limits.begin(); it!=limits.end(); ++it){ + prec_sum[(*it)] = 0.0; + time_sum[(*it)] = 0.0; + } + + // doing the work + for(int i=0; i toplist; + std::vector intersection; + + for(std::vector::iterator limit = limits.begin(); limit!=limits.end(); ++limit){ + + t_start = std::chrono::high_resolution_clock::now(); + t.get_nns_by_item(j, (*limit), (size_t) -1, &toplist, nullptr); //search_k defaults to "n_trees * n" if not provided. + t_end = std::chrono::high_resolution_clock::now(); + auto duration = std::chrono::duration_cast( t_end - t_start ).count(); + + //intersecting results + std::sort(closest.begin(), closest.end(), std::less()); + std::sort(toplist.begin(), toplist.end(), std::less()); + intersection.resize(std::max(closest.size(), toplist.size())); + std::vector::iterator it_set = std::set_intersection(closest.begin(), closest.end(), toplist.begin(), toplist.end(), intersection.begin()); + intersection.resize(it_set-intersection.begin()); + + // storing metrics + int found = intersection.size(); + double hitrate = found / (double) K; + prec_sum[(*limit)] += hitrate; + + time_sum[(*limit)] += duration; + + + //deallocate memory + vector().swap(intersection); + vector().swap(toplist); + } + + //print resulting metrics + for(std::vector::iterator limit = limits.begin(); limit!=limits.end(); ++limit){ + std::cout << "limit: " << (*limit) << "\tprecision: "<< std::fixed << std::setprecision(2) << (100.0 * prec_sum[(*limit)] / (i + 1)) << "% \tavg. time: "<< std::fixed<< std::setprecision(6) << (time_sum[(*limit)] / (i + 1)) * 1e-04 << "s" << std::endl; + } + + closest.clear(); vector().swap(closest); + + } + + std::cout << "\nDone" << std::endl; + return 0; +} + + +void help(){ + std::cout << "Annoy Precision C++ example" << std::endl; + std::cout << "Usage:" << std::endl; + std::cout << "(default) ./precision" << std::endl; + std::cout << "(using parameters) ./precision num_features num_nodes" << std::endl; + std::cout << std::endl; +} + +void feedback(int f, int n){ + std::cout<<"Runing precision example with:" << std::endl; + std::cout<<"num. features: "<< f << std::endl; + std::cout<<"num. nodes: "<< n << std::endl; + std::cout << std::endl; +} + + +int main(int argc, char **argv) { + int f, n; + + + if(argc == 1){ + f = 40; + n = 1000000; + + feedback(f,n); + + precision(40, 1000000); + } + else if(argc == 3){ + + f = atoi(argv[1]); + n = atoi(argv[2]); + + feedback(f,n); + + precision(f, n); + } + else { + help(); + return EXIT_FAILURE; + } + + + return EXIT_SUCCESS; +} diff --git a/core/src/index/thirdparty/annoy/examples/precision_test.py b/core/src/index/thirdparty/annoy/examples/precision_test.py new file mode 100644 index 0000000000..d179e6b9ba --- /dev/null +++ b/core/src/index/thirdparty/annoy/examples/precision_test.py @@ -0,0 +1,46 @@ +from __future__ import print_function +import random, time +from annoy import AnnoyIndex + +try: + xrange +except NameError: + # Python 3 compat + xrange = range + +n, f = 100000, 40 + +t = AnnoyIndex(f, 'angular') +for i in xrange(n): + v = [] + for z in xrange(f): + v.append(random.gauss(0, 1)) + t.add_item(i, v) + +t.build(2 * f) +t.save('test.tree') + +limits = [10, 100, 1000, 10000] +k = 10 +prec_sum = {} +prec_n = 1000 +time_sum = {} + +for i in xrange(prec_n): + j = random.randrange(0, n) + + closest = set(t.get_nns_by_item(j, k, n)) + for limit in limits: + t0 = time.time() + toplist = t.get_nns_by_item(j, k, limit) + T = time.time() - t0 + + found = len(closest.intersection(toplist)) + hitrate = 1.0 * found / k + prec_sum[limit] = prec_sum.get(limit, 0.0) + hitrate + time_sum[limit] = time_sum.get(limit, 0.0) + T + +for limit in limits: + print('limit: %-9d precision: %6.2f%% avg time: %.6fs' + % (limit, 100.0 * prec_sum[limit] / (i + 1), + time_sum[limit] / (i + 1))) diff --git a/core/src/index/thirdparty/annoy/examples/s_compile_cpp.sh b/core/src/index/thirdparty/annoy/examples/s_compile_cpp.sh new file mode 100755 index 0000000000..687a6082b2 --- /dev/null +++ b/core/src/index/thirdparty/annoy/examples/s_compile_cpp.sh @@ -0,0 +1,7 @@ +#!/bin/bash + + +echo "compiling precision example..." +cmd="g++ precision_test.cpp -o precision_test -std=c++11" +eval $cmd +echo "Done" diff --git a/core/src/index/thirdparty/annoy/examples/simple_test.py b/core/src/index/thirdparty/annoy/examples/simple_test.py new file mode 100644 index 0000000000..27e0343a26 --- /dev/null +++ b/core/src/index/thirdparty/annoy/examples/simple_test.py @@ -0,0 +1,10 @@ +from annoy import AnnoyIndex + +a = AnnoyIndex(3, 'angular') +a.add_item(0, [1, 0, 0]) +a.add_item(1, [0, 1, 0]) +a.add_item(2, [0, 0, 1]) +a.build(-1) + +print(a.get_nns_by_item(0, 100)) +print(a.get_nns_by_vector([1.0, 0.5, 0.5], 100)) diff --git a/core/src/index/thirdparty/annoy/src/annoygomodule.h b/core/src/index/thirdparty/annoy/src/annoygomodule.h new file mode 100644 index 0000000000..c5fb408419 --- /dev/null +++ b/core/src/index/thirdparty/annoy/src/annoygomodule.h @@ -0,0 +1,92 @@ +#include "annoylib.h" +#include "kissrandom.h" + +namespace GoAnnoy { + +class AnnoyIndex { + protected: + ::AnnoyIndexInterface *ptr; + + int f; + + public: + ~AnnoyIndex() { + delete ptr; + }; + void addItem(int item, const float* w) { + ptr->add_item(item, w); + }; + void build(int q) { + ptr->build(q); + }; + bool save(const char* filename, bool prefault) { + return ptr->save(filename, prefault); + }; + bool save(const char* filename) { + return ptr->save(filename, true); + }; + void unload() { + ptr->unload(); + }; + bool load(const char* filename, bool prefault) { + return ptr->load(filename, prefault); + }; + bool load(const char* filename) { + return ptr->load(filename, true); + }; + float getDistance(int i, int j) { + return ptr->get_distance(i, j); + }; + void getNnsByItem(int item, int n, int search_k, vector* result, vector* distances) { + ptr->get_nns_by_item(item, n, search_k, result, distances); + }; + void getNnsByVector(const float* w, int n, int search_k, vector* result, vector* distances) { + ptr->get_nns_by_vector(w, n, search_k, result, distances); + }; + void getNnsByItem(int item, int n, int search_k, vector* result) { + ptr->get_nns_by_item(item, n, search_k, result, nullptr); + }; + void getNnsByVector(const float* w, int n, int search_k, vector* result) { + ptr->get_nns_by_vector(w, n, search_k, result, nullptr); + }; + + int getNItems() { + return (int)ptr->get_n_items(); + }; + void verbose(bool v) { + ptr->verbose(v); + }; + void getItem(int item, vector *v) { + v->resize(this->f); + ptr->get_item(item, &v->front()); + }; + bool onDiskBuild(const char* filename) { + return ptr->on_disk_build(filename); + }; +}; + +class AnnoyIndexAngular : public AnnoyIndex +{ + public: + AnnoyIndexAngular(int f) { + ptr = new ::AnnoyIndex(f); + this->f = f; + } +}; + +class AnnoyIndexEuclidean : public AnnoyIndex { + public: + AnnoyIndexEuclidean(int f) { + ptr = new ::AnnoyIndex(f); + this->f = f; + } +}; + +class AnnoyIndexManhattan : public AnnoyIndex { + public: + AnnoyIndexManhattan(int f) { + ptr = new ::AnnoyIndex(f); + this->f = f; + } +}; +} diff --git a/core/src/index/thirdparty/annoy/src/annoygomodule.i b/core/src/index/thirdparty/annoy/src/annoygomodule.i new file mode 100644 index 0000000000..9882cbeb2c --- /dev/null +++ b/core/src/index/thirdparty/annoy/src/annoygomodule.i @@ -0,0 +1,96 @@ +%module annoyindex + +%{ +#include "annoygomodule.h" +%} + + +// const float * +%typemap(gotype) (const float *) "[]float32" + +%typemap(in) (const float *) +%{ + float *v; + vector w; + v = (float *)$input.array; + for (int i = 0; i < $input.len; i++) { + w.push_back(v[i]); + } + $1 = &w[0]; +%} + +// vector * +%typemap(gotype) (vector *) "*[]int" + +%typemap(in) (vector *) +%{ + $1 = new vector(); +%} + +%typemap(freearg) (vector *) +%{ + delete $1; +%} + +%typemap(argout) (vector *) +%{ + { + $input->len = $1->size(); + $input->cap = $1->size(); + $input->array = malloc($input->len * sizeof(intgo)); + for (int i = 0; i < $1->size(); i++) { + ((intgo *)$input->array)[i] = (intgo)(*$1)[i]; + } + } +%} + + +// vector * +%typemap(gotype) (vector *) "*[]float32" + +%typemap(in) (vector *) +%{ + $1 = new vector(); +%} + +%typemap(freearg) (vector *) +%{ + delete $1; +%} + +%typemap(argout) (vector *) +%{ + { + $input->len = $1->size(); + $input->cap = $1->size(); + $input->array = malloc($input->len * sizeof(float)); + for (int i = 0; i < $1->size(); i++) { + ((float *)$input->array)[i] = (float)(*$1)[i]; + } + } +%} + + +%typemap(gotype) (const char *) "string" + +%typemap(in) (const char *) +%{ + $1 = (char *)calloc((((_gostring_)$input).n + 1), sizeof(char)); + strncpy($1, (((_gostring_)$input).p), ((_gostring_)$input).n); +%} + +%typemap(freearg) (const char *) +%{ + free($1); +%} + + +/* Let's just grab the original header file here */ +%include "annoygomodule.h" + +%feature("notabstract") GoAnnoyIndexAngular; +%feature("notabstract") GoAnnoyIndexEuclidean; +%feature("notabstract") GoAnnoyIndexManhattan; + + + diff --git a/core/src/index/thirdparty/annoy/src/annoylib.h b/core/src/index/thirdparty/annoy/src/annoylib.h new file mode 100644 index 0000000000..3af171664b --- /dev/null +++ b/core/src/index/thirdparty/annoy/src/annoylib.h @@ -0,0 +1,1411 @@ +// Copyright (c) 2013 Spotify AB +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not +// use this file except in compliance with the License. You may obtain a copy of +// the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations under +// the License. + + +#ifndef ANNOYLIB_H +#define ANNOYLIB_H + +#include +#include +#ifndef _MSC_VER +#include +#endif +#include +#include +#include +#include +#include + +#if defined(_MSC_VER) && _MSC_VER == 1500 +typedef unsigned char uint8_t; +typedef signed __int32 int32_t; +typedef unsigned __int64 uint64_t; +typedef signed __int64 int64_t; +#else +#include +#endif + +#if defined(_MSC_VER) || defined(__MINGW32__) + // a bit hacky, but override some definitions to support 64 bit + #define off_t int64_t + #define lseek_getsize(fd) _lseeki64(fd, 0, SEEK_END) + #ifndef NOMINMAX + #define NOMINMAX + #endif + #include "mman.h" + #include +#else + #include + #define lseek_getsize(fd) lseek(fd, 0, SEEK_END) +#endif + +#include +#include +#include +#include +#include +#include +#include + +#ifdef _MSC_VER +// Needed for Visual Studio to disable runtime checks for mempcy +#pragma runtime_checks("s", off) +#endif + +// This allows others to supply their own logger / error printer without +// requiring Annoy to import their headers. See RcppAnnoy for a use case. +#ifndef __ERROR_PRINTER_OVERRIDE__ + #define showUpdate(...) { fprintf(stderr, __VA_ARGS__ ); } +#else + #define showUpdate(...) { __ERROR_PRINTER_OVERRIDE__( __VA_ARGS__ ); } +#endif + +// Portable alloc definition, cf Writing R Extensions, Section 1.6.4 +#ifdef __GNUC__ + // Includes GCC, clang and Intel compilers + # undef alloca + # define alloca(x) __builtin_alloca((x)) +#elif defined(__sun) || defined(_AIX) + // this is necessary (and sufficient) for Solaris 10 and AIX 6: + # include +#endif + +inline void set_error_from_errno(char **error, const char* msg) { + showUpdate("%s: %s (%d)\n", msg, strerror(errno), errno); + if (error) { + *error = (char *)malloc(256); // TODO: win doesn't support snprintf + sprintf(*error, "%s: %s (%d)", msg, strerror(errno), errno); + } +} + +inline void set_error_from_string(char **error, const char* msg) { + showUpdate("%s\n", msg); + if (error) { + *error = (char *)malloc(strlen(msg) + 1); + strcpy(*error, msg); + } +} + +// We let the v array in the Node struct take whatever space is needed, so this is a mostly insignificant number. +// Compilers need *some* size defined for the v array, and some memory checking tools will flag for buffer overruns if this is set too low. +#define V_ARRAY_SIZE 65536 + +#ifndef _MSC_VER +#define popcount __builtin_popcountll +#else // See #293, #358 +#define isnan(x) _isnan(x) +#define popcount cole_popcount +#endif + +#if !defined(NO_MANUAL_VECTORIZATION) && defined(__GNUC__) && (__GNUC__ >6) && defined(__AVX512F__) // See #402 +#define USE_AVX512 +#elif !defined(NO_MANUAL_VECTORIZATION) && defined(__AVX__) && defined (__SSE__) && defined(__SSE2__) && defined(__SSE3__) +#define USE_AVX +#else +#endif + +#if defined(USE_AVX) || defined(USE_AVX512) +#if defined(_MSC_VER) +#include +#elif defined(__GNUC__) +#include +#include + +#endif +#endif + +#include +#include + +using std::vector; +using std::pair; +using std::numeric_limits; +using std::make_pair; + +inline void* remap_memory(void* _ptr, int _fd, size_t old_size, size_t new_size) { +#ifdef __linux__ + _ptr = mremap(_ptr, old_size, new_size, MREMAP_MAYMOVE); +#else + munmap(_ptr, old_size); +#ifdef MAP_POPULATE + _ptr = mmap(_ptr, new_size, PROT_READ | PROT_WRITE, MAP_SHARED | MAP_POPULATE, _fd, 0); +#else + _ptr = mmap(_ptr, new_size, PROT_READ | PROT_WRITE, MAP_SHARED, _fd, 0); +#endif +#endif + return _ptr; +} + +namespace { + +template +inline Node* get_node_ptr(const void* _nodes, const size_t _s, const S i) { + return (Node*)((uint8_t *)_nodes + (_s * i)); +} + +template +inline T dot(const T* x, const T* y, int f) { + T s = 0; + for (int z = 0; z < f; z++) { + s += (*x) * (*y); + x++; + y++; + } + return s; +} + +template +inline T manhattan_distance(const T* x, const T* y, int f) { + T d = 0.0; + for (int i = 0; i < f; i++) + d += fabs(x[i] - y[i]); + return d; +} + +template +inline T euclidean_distance(const T* x, const T* y, int f) { + // Don't use dot-product: avoid catastrophic cancellation in #314. + T d = 0.0; + for (int i = 0; i < f; ++i) { + const T tmp=*x - *y; + d += tmp * tmp; + ++x; + ++y; + } + return d; +} + +//#ifdef USE_AVX +// Horizontal single sum of 256bit vector. +#if 0 /* use FAISS distance calculation algorithm instead */ +inline float hsum256_ps_avx(__m256 v) { + const __m128 x128 = _mm_add_ps(_mm256_extractf128_ps(v, 1), _mm256_castps256_ps128(v)); + const __m128 x64 = _mm_add_ps(x128, _mm_movehl_ps(x128, x128)); + const __m128 x32 = _mm_add_ss(x64, _mm_shuffle_ps(x64, x64, 0x55)); + return _mm_cvtss_f32(x32); +} +#endif + +template<> +inline float dot(const float* x, const float *y, int f) { +#if 0 /* use FAISS distance calculation algorithm instead */ + float result = 0; + if (f > 7) { + __m256 d = _mm256_setzero_ps(); + for (; f > 7; f -= 8) { + d = _mm256_add_ps(d, _mm256_mul_ps(_mm256_loadu_ps(x), _mm256_loadu_ps(y))); + x += 8; + y += 8; + } + // Sum all floats in dot register. + result += hsum256_ps_avx(d); + } + // Don't forget the remaining values. + for (; f > 0; f--) { + result += *x * *y; + x++; + y++; + } + return result; +#else + return faiss::fvec_inner_product(x, y, (size_t)f); +#endif +} + +template<> +inline float manhattan_distance(const float* x, const float* y, int f) { +#if 0 /* use FAISS distance calculation algorithm instead */ + float result = 0; + int i = f; + if (f > 7) { + __m256 manhattan = _mm256_setzero_ps(); + __m256 minus_zero = _mm256_set1_ps(-0.0f); + for (; i > 7; i -= 8) { + const __m256 x_minus_y = _mm256_sub_ps(_mm256_loadu_ps(x), _mm256_loadu_ps(y)); + const __m256 distance = _mm256_andnot_ps(minus_zero, x_minus_y); // Absolute value of x_minus_y (forces sign bit to zero) + manhattan = _mm256_add_ps(manhattan, distance); + x += 8; + y += 8; + } + // Sum all floats in manhattan register. + result = hsum256_ps_avx(manhattan); + } + // Don't forget the remaining values. + for (; i > 0; i--) { + result += fabsf(*x - *y); + x++; + y++; + } + return result; +#else + return faiss::fvec_L1(x, y, (size_t)f); +#endif +} + +template<> +inline float euclidean_distance(const float* x, const float* y, int f) { +#if 0 /* use FAISS distance calculation algorithm instead */ + float result=0; + if (f > 7) { + __m256 d = _mm256_setzero_ps(); + for (; f > 7; f -= 8) { + const __m256 diff = _mm256_sub_ps(_mm256_loadu_ps(x), _mm256_loadu_ps(y)); + d = _mm256_add_ps(d, _mm256_mul_ps(diff, diff)); // no support for fmadd in AVX... + x += 8; + y += 8; + } + // Sum all floats in dot register. + result = hsum256_ps_avx(d); + } + // Don't forget the remaining values. + for (; f > 0; f--) { + float tmp = *x - *y; + result += tmp * tmp; + x++; + y++; + } + return result; +#else + return faiss::fvec_L2sqr(x, y, (size_t)f); +#endif +} + +//#endif + +#if 0 /* use FAISS distance calculation algorithm instead */ +#ifdef USE_AVX512 +template<> +inline float dot(const float* x, const float *y, int f) { + float result = 0; + if (f > 15) { + __m512 d = _mm512_setzero_ps(); + for (; f > 15; f -= 16) { + //AVX512F includes FMA + d = _mm512_fmadd_ps(_mm512_loadu_ps(x), _mm512_loadu_ps(y), d); + x += 16; + y += 16; + } + // Sum all floats in dot register. + result += _mm512_reduce_add_ps(d); + } + // Don't forget the remaining values. + for (; f > 0; f--) { + result += *x * *y; + x++; + y++; + } + return result; +} + +template<> +inline float manhattan_distance(const float* x, const float* y, int f) { + float result = 0; + int i = f; + if (f > 15) { + __m512 manhattan = _mm512_setzero_ps(); + for (; i > 15; i -= 16) { + const __m512 x_minus_y = _mm512_sub_ps(_mm512_loadu_ps(x), _mm512_loadu_ps(y)); + manhattan = _mm512_add_ps(manhattan, _mm512_abs_ps(x_minus_y)); + x += 16; + y += 16; + } + // Sum all floats in manhattan register. + result = _mm512_reduce_add_ps(manhattan); + } + // Don't forget the remaining values. + for (; i > 0; i--) { + result += fabsf(*x - *y); + x++; + y++; + } + return result; +} + +template<> +inline float euclidean_distance(const float* x, const float* y, int f) { + float result=0; + if (f > 15) { + __m512 d = _mm512_setzero_ps(); + for (; f > 15; f -= 16) { + const __m512 diff = _mm512_sub_ps(_mm512_loadu_ps(x), _mm512_loadu_ps(y)); + d = _mm512_fmadd_ps(diff, diff, d); + x += 16; + y += 16; + } + // Sum all floats in dot register. + result = _mm512_reduce_add_ps(d); + } + // Don't forget the remaining values. + for (; f > 0; f--) { + float tmp = *x - *y; + result += tmp * tmp; + x++; + y++; + } + return result; +} + +#endif +#endif + + +template +inline T get_norm(T* v, int f) { + return sqrt(dot(v, v, f)); +} + +template +inline void two_means(const vector& nodes, int f, Random& random, bool cosine, Node* p, Node* q) { + /* + This algorithm is a huge heuristic. Empirically it works really well, but I + can't motivate it well. The basic idea is to keep two centroids and assign + points to either one of them. We weight each centroid by the number of points + assigned to it, so to balance it. + */ + static int iteration_steps = 200; + size_t count = nodes.size(); + + size_t i = random.index(count); + size_t j = random.index(count-1); + j += (j >= i); // ensure that i != j + + Distance::template copy_node(p, nodes[i], f); + Distance::template copy_node(q, nodes[j], f); + + if (cosine) { Distance::template normalize(p, f); Distance::template normalize(q, f); } + Distance::init_node(p, f); + Distance::init_node(q, f); + + int ic = 1, jc = 1; + for (int l = 0; l < iteration_steps; l++) { + size_t k = random.index(count); + T di = ic * Distance::distance(p, nodes[k], f), + dj = jc * Distance::distance(q, nodes[k], f); + T norm = cosine ? get_norm(nodes[k]->v, f) : 1; + if (!(norm > T(0))) { + continue; + } + if (di < dj) { + for (int z = 0; z < f; z++) + p->v[z] = (p->v[z] * ic + nodes[k]->v[z] / norm) / (ic + 1); + Distance::init_node(p, f); + ic++; + } else if (dj < di) { + for (int z = 0; z < f; z++) + q->v[z] = (q->v[z] * jc + nodes[k]->v[z] / norm) / (jc + 1); + Distance::init_node(q, f); + jc++; + } + } +} +} // namespace + +struct Base { + template + static inline void preprocess(void* nodes, size_t _s, const S node_count, const int f) { + // Override this in specific metric structs below if you need to do any pre-processing + // on the entire set of nodes passed into this index. + } + + template + static inline void zero_value(Node* dest) { + // Initialize any fields that require sane defaults within this node. + } + + template + static inline void copy_node(Node* dest, const Node* source, const int f) { + memcpy(dest->v, source->v, f * sizeof(T)); + } + + template + static inline void normalize(Node* node, int f) { + T norm = get_norm(node->v, f); + if (norm > 0) { + for (int z = 0; z < f; z++) + node->v[z] /= norm; + } + } +}; + +struct Angular : Base { + template + struct Node { + /* + * We store a binary tree where each node has two things + * - A vector associated with it + * - Two children + * All nodes occupy the same amount of memory + * All nodes with n_descendants == 1 are leaf nodes. + * A memory optimization is that for nodes with 2 <= n_descendants <= K, + * we skip the vector. Instead we store a list of all descendants. K is + * determined by the number of items that fits in the space of the vector. + * For nodes with n_descendants == 1 the vector is a data point. + * For nodes with n_descendants > K the vector is the normal of the split plane. + * Note that we can't really do sizeof(node) because we cheat and allocate + * more memory to be able to fit the vector outside + */ + S n_descendants; + union { + S children[2]; // Will possibly store more than 2 + T norm; + }; + T v[V_ARRAY_SIZE]; + }; + template + static inline T distance(const Node* x, const Node* y, int f) { + // want to calculate (a/|a| - b/|b|)^2 + // = a^2 / a^2 + b^2 / b^2 - 2ab/|a||b| + // = 2 - 2cos + T pp = x->norm ? x->norm : dot(x->v, x->v, f); // For backwards compatibility reasons, we need to fall back and compute the norm here + T qq = y->norm ? y->norm : dot(y->v, y->v, f); + T pq = dot(x->v, y->v, f); + T ppqq = pp * qq; + if (ppqq > 0) return 2.0 - 2.0 * pq / sqrt(ppqq); + else return 2.0; // cos is 0 + } + template + static inline T margin(const Node* n, const T* y, int f) { + return dot(n->v, y, f); + } + template + static inline bool side(const Node* n, const T* y, int f, Random& random) { + T dot = margin(n, y, f); + if (dot != 0) + return (dot > 0); + else + return (bool)random.flip(); + } + template + static inline void create_split(const vector*>& nodes, int f, size_t s, Random& random, Node* n) { + Node* p = (Node*)alloca(s); + Node* q = (Node*)alloca(s); + two_means >(nodes, f, random, true, p, q); + for (int z = 0; z < f; z++) + n->v[z] = p->v[z] - q->v[z]; + Base::normalize >(n, f); + } + template + static inline T normalized_distance(T distance) { + // Used when requesting distances from Python layer + // Turns out sometimes the squared distance is -0.0 + // so we have to make sure it's a positive number. + return sqrt(std::max(distance, T(0))); + } + template + static inline T pq_distance(T distance, T margin, int child_nr) { + if (child_nr == 0) + margin = -margin; + return std::min(distance, margin); + } + template + static inline T pq_initial_value() { + return numeric_limits::infinity(); + } + template + static inline void init_node(Node* n, int f) { + n->norm = dot(n->v, n->v, f); + } + static const char* name() { + return "angular"; + } +}; + + +struct DotProduct : Angular { + template + struct Node { + /* + * This is an extension of the Angular node with an extra attribute for the scaled norm. + */ + S n_descendants; + S children[2]; // Will possibly store more than 2 + T dot_factor; + T v[V_ARRAY_SIZE]; + }; + + static const char* name() { + return "dot"; + } + template + static inline T distance(const Node* x, const Node* y, int f) { + return -dot(x->v, y->v, f); + } + + template + static inline void zero_value(Node* dest) { + dest->dot_factor = 0; + } + + template + static inline void init_node(Node* n, int f) { + } + + template + static inline void copy_node(Node* dest, const Node* source, const int f) { + memcpy(dest->v, source->v, f * sizeof(T)); + dest->dot_factor = source->dot_factor; + } + + template + static inline void create_split(const vector*>& nodes, int f, size_t s, Random& random, Node* n) { + Node* p = (Node*)alloca(s); + Node* q = (Node*)alloca(s); + DotProduct::zero_value(p); + DotProduct::zero_value(q); + two_means >(nodes, f, random, true, p, q); + for (int z = 0; z < f; z++) + n->v[z] = p->v[z] - q->v[z]; + n->dot_factor = p->dot_factor - q->dot_factor; + DotProduct::normalize >(n, f); + } + + template + static inline void normalize(Node* node, int f) { + T norm = sqrt(dot(node->v, node->v, f) + pow(node->dot_factor, 2)); + if (norm > 0) { + for (int z = 0; z < f; z++) + node->v[z] /= norm; + node->dot_factor /= norm; + } + } + + template + static inline T margin(const Node* n, const T* y, int f) { + return dot(n->v, y, f) + (n->dot_factor * n->dot_factor); + } + + template + static inline bool side(const Node* n, const T* y, int f, Random& random) { + T dot = margin(n, y, f); + if (dot != 0) + return (dot > 0); + else + return (bool)random.flip(); + } + + template + static inline T normalized_distance(T distance) { + return -distance; + } + + template + static inline void preprocess(void* nodes, size_t _s, const S node_count, const int f) { + // This uses a method from Microsoft Research for transforming inner product spaces to cosine/angular-compatible spaces. + // (Bachrach et al., 2014, see https://www.microsoft.com/en-us/research/wp-content/uploads/2016/02/XboxInnerProduct.pdf) + + // Step one: compute the norm of each vector and store that in its extra dimension (f-1) + for (S i = 0; i < node_count; i++) { + Node* node = get_node_ptr(nodes, _s, i); + T norm = sqrt(dot(node->v, node->v, f)); + if (isnan(norm)) norm = 0; + node->dot_factor = norm; + } + + // Step two: find the maximum norm + T max_norm = 0; + for (S i = 0; i < node_count; i++) { + Node* node = get_node_ptr(nodes, _s, i); + if (node->dot_factor > max_norm) { + max_norm = node->dot_factor; + } + } + + // Step three: set each vector's extra dimension to sqrt(max_norm^2 - norm^2) + for (S i = 0; i < node_count; i++) { + Node* node = get_node_ptr(nodes, _s, i); + T node_norm = node->dot_factor; + + T dot_factor = sqrt(pow(max_norm, static_cast(2.0)) - pow(node_norm, static_cast(2.0))); + if (isnan(dot_factor)) dot_factor = 0; + + node->dot_factor = dot_factor; + } + } +}; + +struct Hamming : Base { + template + struct Node { + S n_descendants; + S children[2]; + T v[V_ARRAY_SIZE]; + }; + + static const size_t max_iterations = 20; + + template + static inline T pq_distance(T distance, T margin, int child_nr) { + return distance - (margin != (unsigned int) child_nr); + } + + template + static inline T pq_initial_value() { + return numeric_limits::max(); + } + template + static inline int cole_popcount(T v) { + // Note: Only used with MSVC 9, which lacks intrinsics and fails to + // calculate std::bitset::count for v > 32bit. Uses the generalized + // approach by Eric Cole. + // See https://graphics.stanford.edu/~seander/bithacks.html#CountBitsSet64 + v = v - ((v >> 1) & (T)~(T)0/3); + v = (v & (T)~(T)0/15*3) + ((v >> 2) & (T)~(T)0/15*3); + v = (v + (v >> 4)) & (T)~(T)0/255*15; + return (T)(v * ((T)~(T)0/255)) >> (sizeof(T) - 1) * 8; + } + template + static inline T distance(const Node* x, const Node* y, int f) { + size_t dist = 0; + for (int i = 0; i < f; i++) { + dist += popcount(x->v[i] ^ y->v[i]); + } + return dist; + } + template + static inline bool margin(const Node* n, const T* y, int f) { + static const size_t n_bits = sizeof(T) * 8; + T chunk = n->v[0] / n_bits; + return (y[chunk] & (static_cast(1) << (n_bits - 1 - (n->v[0] % n_bits)))) != 0; + } + template + static inline bool side(const Node* n, const T* y, int f, Random& random) { + return margin(n, y, f); + } + template + static inline void create_split(const vector*>& nodes, int f, size_t s, Random& random, Node* n) { + size_t cur_size = 0; + size_t i = 0; + int dim = f * 8 * sizeof(T); + for (; i < max_iterations; i++) { + // choose random position to split at + n->v[0] = random.index(dim); + cur_size = 0; + for (typename vector*>::const_iterator it = nodes.begin(); it != nodes.end(); ++it) { + if (margin(n, (*it)->v, f)) { + cur_size++; + } + } + if (cur_size > 0 && cur_size < nodes.size()) { + break; + } + } + // brute-force search for splitting coordinate + if (i == max_iterations) { + int j = 0; + for (; j < dim; j++) { + n->v[0] = j; + cur_size = 0; + for (typename vector*>::const_iterator it = nodes.begin(); it != nodes.end(); ++it) { + if (margin(n, (*it)->v, f)) { + cur_size++; + } + } + if (cur_size > 0 && cur_size < nodes.size()) { + break; + } + } + } + } + template + static inline T normalized_distance(T distance) { + return distance; + } + template + static inline void init_node(Node* n, int f) { + } + static const char* name() { + return "hamming"; + } +}; + + +struct Minkowski : Base { + template + struct Node { + S n_descendants; + T a; // need an extra constant term to determine the offset of the plane + S children[2]; + T v[V_ARRAY_SIZE]; + }; + template + static inline T margin(const Node* n, const T* y, int f) { + return n->a + dot(n->v, y, f); + } + template + static inline bool side(const Node* n, const T* y, int f, Random& random) { + T dot = margin(n, y, f); + if (dot != 0) + return (dot > 0); + else + return (bool)random.flip(); + } + template + static inline T pq_distance(T distance, T margin, int child_nr) { + if (child_nr == 0) + margin = -margin; + return std::min(distance, margin); + } + template + static inline T pq_initial_value() { + return numeric_limits::infinity(); + } +}; + + +struct Euclidean : Minkowski { + template + static inline T distance(const Node* x, const Node* y, int f) { + return euclidean_distance(x->v, y->v, f); + } + template + static inline void create_split(const vector*>& nodes, int f, size_t s, Random& random, Node* n) { + Node* p = (Node*)alloca(s); + Node* q = (Node*)alloca(s); + two_means >(nodes, f, random, false, p, q); + + for (int z = 0; z < f; z++) + n->v[z] = p->v[z] - q->v[z]; + Base::normalize >(n, f); + n->a = 0.0; + for (int z = 0; z < f; z++) + n->a += -n->v[z] * (p->v[z] + q->v[z]) / 2; + } + template + static inline T normalized_distance(T distance) { + return distance; + } + template + static inline void init_node(Node* n, int f) { + } + static const char* name() { + return "euclidean"; + } + +}; + +struct Manhattan : Minkowski { + template + static inline T distance(const Node* x, const Node* y, int f) { + return manhattan_distance(x->v, y->v, f); + } + template + static inline void create_split(const vector*>& nodes, int f, size_t s, Random& random, Node* n) { + Node* p = (Node*)alloca(s); + Node* q = (Node*)alloca(s); + two_means >(nodes, f, random, false, p, q); + + for (int z = 0; z < f; z++) + n->v[z] = p->v[z] - q->v[z]; + Base::normalize >(n, f); + n->a = 0.0; + for (int z = 0; z < f; z++) + n->a += -n->v[z] * (p->v[z] + q->v[z]) / 2; + } + template + static inline T normalized_distance(T distance) { + return std::max(distance, T(0)); + } + template + static inline void init_node(Node* n, int f) { + } + static const char* name() { + return "manhattan"; + } +}; + +template +class AnnoyIndexInterface { + public: + // Note that the methods with an **error argument will allocate memory and write the pointer to that string if error is non-nullptr + virtual ~AnnoyIndexInterface() {}; + virtual bool add_item(S item, const T* w, char** error=nullptr) = 0; + virtual bool build(int q, char** error=nullptr) = 0; + virtual bool unbuild(char** error=nullptr) = 0; + virtual bool save(const char* filename, bool prefault=false, char** error=nullptr) = 0; + virtual void unload() = 0; + virtual bool load(const char* filename, bool prefault=false, char** error=nullptr) = 0; + virtual bool load_index(void* index_data, const int64_t& index_size, char** error = nullptr) = 0; + virtual T get_distance(S i, S j) const = 0; + virtual void get_nns_by_item(S item, size_t n, int64_t search_k, vector* result, vector* distances, + const faiss::ConcurrentBitsetPtr& bitset = nullptr) const = 0; + virtual void get_nns_by_vector(const T* w, size_t n, int64_t search_k, vector* result, vector* distances, + const faiss::ConcurrentBitsetPtr& bitset = nullptr) const = 0; + virtual S get_n_items() const = 0; + virtual S get_dim() const = 0; + virtual S get_n_trees() const = 0; + virtual int64_t get_index_length() const = 0; + virtual void* get_index() const = 0; + virtual void verbose(bool v) = 0; + virtual void get_item(S item, T* v) const = 0; + virtual void set_seed(int q) = 0; + virtual bool on_disk_build(const char* filename, char** error=nullptr) = 0; + virtual int64_t cal_size() = 0; +}; + +template + class AnnoyIndex : public AnnoyIndexInterface { + /* + * We use random projection to build a forest of binary trees of all items. + * Basically just split the hyperspace into two sides by a hyperplane, + * then recursively split each of those subtrees etc. + * We create a tree like this q times. The default q is determined automatically + * in such a way that we at most use 2x as much memory as the vectors take. + */ +public: + typedef Distance D; + typedef typename D::template Node Node; + +protected: + const int _f; + size_t _s; + S _n_items; + Random _random; + void* _nodes; // Could either be mmapped, or point to a memory buffer that we reallocate + S _n_nodes; + S _nodes_size; + vector _roots; + S _K; + bool _loaded; + bool _verbose; + int _fd; + bool _on_disk; + bool _built; +public: + + AnnoyIndex(int f) : _f(f), _random() { + _s = offsetof(Node, v) + _f * sizeof(T); // Size of each node + _verbose = false; + _built = false; + _K = (S) (((size_t) (_s - offsetof(Node, children))) / sizeof(S)); // Max number of descendants to fit into node + reinitialize(); // Reset everything + } + ~AnnoyIndex() { + unload(); + } + + int get_f() const { + return _f; + } + + bool add_item(S item, const T* w, char** error=nullptr) { + return add_item_impl(item, w, error); + } + + template + bool add_item_impl(S item, const W& w, char** error=nullptr) { + if (_loaded) { + set_error_from_string(error, "You can't add an item to a loaded index"); + return false; + } + _allocate_size(item + 1); + Node* n = _get(item); + + D::zero_value(n); + + n->children[0] = 0; + n->children[1] = 0; + n->n_descendants = 1; + + for (int z = 0; z < _f; z++) + n->v[z] = w[z]; + + D::init_node(n, _f); + + if (item >= _n_items) + _n_items = item + 1; + + return true; + } + + bool on_disk_build(const char* file, char** error=nullptr) { + _on_disk = true; + _fd = open(file, O_RDWR | O_CREAT | O_TRUNC, (int) 0600); + if (_fd == -1) { + set_error_from_errno(error, "Unable to open"); + _fd = 0; + return false; + } + _nodes_size = 1; + if (ftruncate(_fd, _s * _nodes_size) == -1) { + set_error_from_errno(error, "Unable to truncate"); + return false; + } +#ifdef MAP_POPULATE + _nodes = (Node*) mmap(0, _s * _nodes_size, PROT_READ | PROT_WRITE, MAP_SHARED | MAP_POPULATE, _fd, 0); +#else + _nodes = (Node*) mmap(0, _s * _nodes_size, PROT_READ | PROT_WRITE, MAP_SHARED, _fd, 0); +#endif + return true; + } + + bool build(int q, char** error=nullptr) { + if (_loaded) { + set_error_from_string(error, "You can't build a loaded index"); + return false; + } + + if (_built) { + set_error_from_string(error, "You can't build a built index"); + return false; + } + + D::template preprocess(_nodes, _s, _n_items, _f); + + _n_nodes = _n_items; + while (1) { + if (q == -1 && _n_nodes >= _n_items * 2) + break; + if (q != -1 && _roots.size() >= (size_t)q) + break; + if (_verbose) showUpdate("pass %zd...\n", _roots.size()); + + vector indices; + for (S i = 0; i < _n_items; i++) { + if (_get(i)->n_descendants >= 1) // Issue #223 + indices.push_back(i); + } + + _roots.push_back(_make_tree(indices, true)); + } + + // Also, copy the roots into the last segment of the array + // This way we can load them faster without reading the whole file + _allocate_size(_n_nodes + (S)_roots.size()); + for (size_t i = 0; i < _roots.size(); i++) + memcpy(_get(_n_nodes + (S)i), _get(_roots[i]), _s); + _n_nodes += _roots.size(); + + if (_verbose) showUpdate("has %ld nodes\n", _n_nodes); + + if (_on_disk) { + _nodes = remap_memory(_nodes, _fd, _s * _nodes_size, _s * _n_nodes); + if (ftruncate(_fd, _s * _n_nodes)) { + // TODO: this probably creates an index in a corrupt state... not sure what to do + set_error_from_errno(error, "Unable to truncate"); + return false; + } + _nodes_size = _n_nodes; + } + _built = true; + return true; + } + + bool unbuild(char** error=nullptr) { + if (_loaded) { + set_error_from_string(error, "You can't unbuild a loaded index"); + return false; + } + + _roots.clear(); + _n_nodes = _n_items; + _built = false; + + return true; + } + + bool save(const char* filename, bool prefault=false, char** error=nullptr) { + if (!_built) { + set_error_from_string(error, "You can't save an index that hasn't been built"); + return false; + } + if (_on_disk) { + return true; + } else { + // Delete file if it already exists (See issue #335) + unlink(filename); + + FILE *f = fopen(filename, "wb"); + if (f == nullptr) { + set_error_from_errno(error, "Unable to open"); + return false; + } + + if (fwrite(_nodes, _s, _n_nodes, f) != (size_t) _n_nodes) { + set_error_from_errno(error, "Unable to write"); + return false; + } + + if (fclose(f) == EOF) { + set_error_from_errno(error, "Unable to close"); + return false; + } + + unload(); + return load(filename, prefault, error); + } + } + + void reinitialize() { + _fd = 0; + _nodes = nullptr; + _loaded = false; + _n_items = 0; + _n_nodes = 0; + _nodes_size = 0; + _on_disk = false; + _roots.clear(); + } + + void unload() { + if (_on_disk && _fd) { + close(_fd); + munmap(_nodes, _s * _nodes_size); + } else { + if (_fd) { + // we have mmapped data + close(_fd); + munmap(_nodes, _n_nodes * _s); + } else if (_nodes) { + // We have heap allocated data + free(_nodes); + } + } + reinitialize(); + if (_verbose) showUpdate("unloaded\n"); + } + + bool load(const char* filename, bool prefault=false, char** error=nullptr) { + _fd = open(filename, O_RDONLY, (int)0400); + if (_fd == -1) { + set_error_from_errno(error, "Unable to open"); + _fd = 0; + return false; + } + off_t size = lseek_getsize(_fd); + if (size == -1) { + set_error_from_errno(error, "Unable to get size"); + return false; + } else if (size == 0) { + set_error_from_errno(error, "Size of file is zero"); + return false; + } else if (size % _s) { + // Something is fishy with this index! + set_error_from_errno(error, "Index size is not a multiple of vector size"); + return false; + } + + int flags = MAP_SHARED; + if (prefault) { +#ifdef MAP_POPULATE + flags |= MAP_POPULATE; +#else + showUpdate("prefault is set to true, but MAP_POPULATE is not defined on this platform"); +#endif + } + _nodes = (Node*)mmap(0, size, PROT_READ, flags, _fd, 0); + _n_nodes = (S)(size / _s); + + // Find the roots by scanning the end of the file and taking the nodes with most descendants + _roots.clear(); + S m = -1; + for (S i = _n_nodes - 1; i >= 0; i--) { + S k = _get(i)->n_descendants; + if (m == -1 || k == m) { + _roots.push_back(i); + m = k; + } else { + break; + } + } + // hacky fix: since the last root precedes the copy of all roots, delete it + if (_roots.size() > 1 && _get(_roots.front())->children[0] == _get(_roots.back())->children[0]) + _roots.pop_back(); + _loaded = true; + _built = true; + _n_items = m; + if (_verbose) showUpdate("found %lu roots with degree %ld\n", _roots.size(), m); + return true; + } + + bool load_index(void* index_data, const int64_t& index_size, char** error) { + if (index_size == -1) { + set_error_from_errno(error, "Unable to get size"); + return false; + } else if (index_size == 0) { + set_error_from_errno(error, "Size of file is zero"); + return false; + } else if (index_size % _s) { + // Something is fishy with this index! + set_error_from_errno(error, "Index size is not a multiple of vector size"); + return false; + } + + _n_nodes = (S)(index_size / _s); +// _nodes = (Node*)malloc(_s * _n_nodes); + _nodes = (Node*)malloc((size_t)index_size); + if (_nodes == nullptr) { + set_error_from_errno(error, "alloc failed when load_index 4 annoy"); + return false; + } + memcpy(_nodes, index_data, (size_t)index_size); + + // Find the roots by scanning the end of the file and taking the nodes with most descendants + _roots.clear(); + S m = -1; + for (S i = _n_nodes - 1; i >= 0; i--) { + S k = _get(i)->n_descendants; + if (m == -1 || k == m) { + _roots.push_back(i); + m = k; + } else { + break; + } + } + // hacky fix: since the last root precedes the copy of all roots, delete it + if (_roots.size() > 1 && _get(_roots.front())->children[0] == _get(_roots.back())->children[0]) + _roots.pop_back(); + _loaded = true; + _built = true; + _n_items = m; + if (_verbose) showUpdate("found %lu roots with degree %ld\n", _roots.size(), m); + return true; + } + + T get_distance(S i, S j) const { + return D::normalized_distance(D::distance(_get(i), _get(j), _f)); + } + + void get_nns_by_item(S item, size_t n, int64_t search_k, vector* result, vector* distances, + const faiss::ConcurrentBitsetPtr& bitset) const { + // TODO: handle OOB + const Node* m = _get(item); + _get_all_nns(m->v, n, search_k, result, distances, bitset); + } + + void get_nns_by_vector(const T* w, size_t n, int64_t search_k, vector* result, vector* distances, + const faiss::ConcurrentBitsetPtr& bitset) const { + _get_all_nns(w, n, search_k, result, distances, bitset); + } + + S get_n_items() const { + return _n_items; + } + + S get_dim() const { + return _f; + } + + S get_n_trees() const { + return (S)_roots.size(); + } + + int64_t get_index_length() const { + return (int64_t)_s * _n_nodes; + } + + void* get_index() const { + return _nodes; + } + + void verbose(bool v) { + _verbose = v; + } + + void get_item(S item, T* v) const { + // TODO: handle OOB + Node* m = _get(item); + memcpy(v, m->v, (_f) * sizeof(T)); + } + + void set_seed(int seed) { + _random.set_seed(seed); + } + +protected: + void _allocate_size(S n) { + if (n > _nodes_size) { + const double reallocation_factor = 1.3; + S new_nodes_size = std::max(n, (S) ((_nodes_size + 1) * reallocation_factor)); + void *old = _nodes; + + if (_on_disk) { + int rc = ftruncate(_fd, _s * new_nodes_size); + if (_verbose && rc) showUpdate("File truncation error\n"); + _nodes = remap_memory(_nodes, _fd, _s * _nodes_size, _s * new_nodes_size); + } else { + _nodes = realloc(_nodes, _s * new_nodes_size); + memset((char *) _nodes + (_nodes_size * _s) / sizeof(char), 0, (new_nodes_size - _nodes_size) * _s); + } + + _nodes_size = new_nodes_size; + if (_verbose) showUpdate("Reallocating to %ld nodes: old_address=%p, new_address=%p\n", new_nodes_size, old, _nodes); + } + } + + inline Node* _get(const S i) const { + return get_node_ptr(_nodes, _s, i); + } + + S _make_tree(const vector& indices, bool is_root) { + // The basic rule is that if we have <= _K items, then it's a leaf node, otherwise it's a split node. + // There's some regrettable complications caused by the problem that root nodes have to be "special": + // 1. We identify root nodes by the arguable logic that _n_items == n->n_descendants, regardless of how many descendants they actually have + // 2. Root nodes with only 1 child need to be a "dummy" parent + // 3. Due to the _n_items "hack", we need to be careful with the cases where _n_items <= _K or _n_items > _K + if (indices.size() == 1 && !is_root) + return indices[0]; + + if (indices.size() <= (size_t)_K && (!is_root || (size_t)_n_items <= (size_t)_K || indices.size() == 1)) { + _allocate_size(_n_nodes + 1); + S item = _n_nodes++; + Node* m = _get(item); + m->n_descendants = is_root ? _n_items : (S)indices.size(); + + // Using std::copy instead of a loop seems to resolve issues #3 and #13, + // probably because gcc 4.8 goes overboard with optimizations. + // Using memcpy instead of std::copy for MSVC compatibility. #235 + // Only copy when necessary to avoid crash in MSVC 9. #293 + if (!indices.empty()) + memcpy(m->children, &indices[0], indices.size() * sizeof(S)); + return item; + } + + vector children; + for (size_t i = 0; i < indices.size(); i++) { + S j = indices[i]; + Node* n = _get(j); + if (n) + children.push_back(n); + } + + vector children_indices[2]; + Node* m = (Node*)alloca(_s); + D::create_split(children, _f, _s, _random, m); + faiss::BuilderSuspend::check_wait(); + + for (size_t i = 0; i < indices.size(); i++) { + S j = indices[i]; + Node* n = _get(j); + if (n) { + bool side = D::side(m, n->v, _f, _random); + children_indices[side].push_back(j); + } else { + showUpdate("No node for index %ld?\n", j); + } + } + + // If we didn't find a hyperplane, just randomize sides as a last option + while (children_indices[0].size() == 0 || children_indices[1].size() == 0) { + if (_verbose) + showUpdate("\tNo hyperplane found (left has %ld children, right has %ld children)\n", + children_indices[0].size(), children_indices[1].size()); + if (_verbose && indices.size() > 100000) + showUpdate("Failed splitting %lu items\n", indices.size()); + + children_indices[0].clear(); + children_indices[1].clear(); + + // Set the vector to 0.0 + for (int z = 0; z < _f; z++) + m->v[z] = 0; + + for (size_t i = 0; i < indices.size(); i++) { + S j = indices[i]; + // Just randomize... + children_indices[_random.flip()].push_back(j); + } + } + + int flip = (children_indices[0].size() > children_indices[1].size()); + + m->n_descendants = is_root ? _n_items : (S)indices.size(); + for (int side = 0; side < 2; side++) { + // run _make_tree for the smallest child first (for cache locality) + faiss::BuilderSuspend::check_wait(); + m->children[side^flip] = _make_tree(children_indices[side^flip], false); + } + + _allocate_size(_n_nodes + 1); + S item = _n_nodes++; + memcpy(_get(item), m, _s); + + return item; + } + + void _get_all_nns(const T* v, size_t n, int64_t search_k, vector* result, vector* distances, + const faiss::ConcurrentBitsetPtr& bitset) const { + Node* v_node = (Node *)alloca(_s); + D::template zero_value(v_node); + memcpy(v_node->v, v, sizeof(T) * _f); + D::init_node(v_node, _f); + + std::priority_queue > q; + + if (search_k <= 0) { + search_k = std::max(int64_t(n * _roots.size()), int64_t(_n_items * 5 / 100)); + } + + for (size_t i = 0; i < _roots.size(); i++) { + q.push(make_pair(Distance::template pq_initial_value(), _roots[i])); + } + + std::vector nns; + while (nns.size() < (size_t)search_k && !q.empty()) { + const pair& top = q.top(); + T d = top.first; + S i = top.second; + Node* nd = _get(i); + q.pop(); + if (nd->n_descendants == 1 && i < _n_items) { // raw data + if (bitset == nullptr || !bitset->test((faiss::ConcurrentBitset::id_type_t)i)) + nns.push_back(i); + } else if (nd->n_descendants <= _K) { + const S* dst = nd->children; + for (auto ii = 0; ii < nd->n_descendants; ++ ii) { + if (bitset == nullptr || !bitset->test((faiss::ConcurrentBitset::id_type_t)dst[ii])) + nns.push_back(dst[ii]); +// nns.insert(nns.end(), dst, &dst[nd->n_descendants]); + } + } else { + T margin = D::margin(nd, v, _f); + q.push(make_pair(D::pq_distance(d, margin, 1), static_cast(nd->children[1]))); + q.push(make_pair(D::pq_distance(d, margin, 0), static_cast(nd->children[0]))); + } + } + + // Get distances for all items + // To avoid calculating distance multiple times for any items, sort by id + std::sort(nns.begin(), nns.end()); + vector > nns_dist; + S last = -1; + for (size_t i = 0; i < nns.size(); i++) { + S j = nns[i]; + if (j == last) + continue; + last = j; + if (_get(j)->n_descendants == 1) // This is only to guard a really obscure case, #284 + nns_dist.push_back(make_pair(D::distance(v_node, _get(j), _f), j)); + } + + size_t m = nns_dist.size(); + size_t p = n < m ? n : m; // Return this many items + std::partial_sort(nns_dist.begin(), nns_dist.begin() + p, nns_dist.end()); + for (size_t i = 0; i < p; i++) { + if (distances) + distances->push_back(D::normalized_distance(nns_dist[i].first)); + result->push_back(nns_dist[i].second); + } + } + + int64_t cal_size() { + int64_t ret = 0; + ret += sizeof(*this); + ret += _roots.size() * sizeof(S); + ret += std::max(_n_nodes, _nodes_size) * _s; + return ret; + } +}; + +#endif +// vim: tabstop=2 shiftwidth=2 diff --git a/core/src/index/thirdparty/annoy/src/annoyluamodule.cc b/core/src/index/thirdparty/annoy/src/annoyluamodule.cc new file mode 100644 index 0000000000..4f483d2d3d --- /dev/null +++ b/core/src/index/thirdparty/annoy/src/annoyluamodule.cc @@ -0,0 +1,318 @@ +// Copyright (c) 2016 Boris Nagaev +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not +// use this file except in compliance with the License. You may obtain a copy of +// the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations under +// the License. + +#include +#include + +#include + +#include "annoylib.h" +#include "kissrandom.h" + +#if LUA_VERSION_NUM == 501 +#define compat_setfuncs(L, funcs) luaL_register(L, nullptr, funcs) +#define compat_rawlen lua_objlen +#else +#define compat_setfuncs(L, funcs) luaL_setfuncs(L, funcs, 0) +#define compat_rawlen lua_rawlen +#endif + +template +class LuaAnnoy { +public: + typedef int32_t AnnoyS; + typedef float AnnoyT; + typedef AnnoyIndex Impl; + typedef LuaAnnoy ThisClass; + + class LuaArrayProxy { + public: + LuaArrayProxy(lua_State* L, int object, int f) + : L_(L) + , object_(object) + { + luaL_checktype(L, object, LUA_TTABLE); + int v_len = compat_rawlen(L, object); + luaL_argcheck(L, v_len == f, object, "Length of v != f"); + } + + double operator[](int index) const { + lua_rawgeti(L_, object_, index + 1); + double result = lua_tonumber(L_, -1); + lua_pop(L_, 1); + return result; + } + + private: + lua_State* L_; + int object_; + }; + + static void toVector(lua_State* L, int object, int f, AnnoyT* dst) { + LuaArrayProxy proxy(L, object, f); + for (int i = 0; i < f; i++) { + dst[i] = proxy[i]; + } + } + + template + static void pushVector(lua_State* L, const Vector& v) { + lua_createtable(L, v.size(), 0); + for (int j = 0; j < v.size(); j++) { + lua_pushnumber(L, v[j]); + lua_rawseti(L, -2, j + 1); + } + } + + static const char* typeAsString() { + return typeid(Impl).name(); + } + + static Impl* getAnnoy(lua_State* L, int object) { + return reinterpret_cast( + luaL_checkudata(L, object, typeAsString()) + ); + } + + static int getItemIndex(lua_State* L, int object, int size = -1) { + int item = luaL_checkinteger(L, object); + luaL_argcheck(L, item >= 0, object, "Index must be >= 0"); + if (size != -1) { + luaL_argcheck(L, item < size, object, "Index must be < size"); + } + return item; + } + + static int gc(lua_State* L) { + Impl* self = getAnnoy(L, 1); + self->~Impl(); + return 0; + } + + static int tostring(lua_State* L) { + Impl* self = getAnnoy(L, 1); + lua_pushfstring( + L, + "annoy.AnnoyIndex object (%dx%d, %s distance)", + self->get_n_items(), self->get_f(), Distance::name() + ); + return 1; + } + + static int add_item(lua_State* L) { + Impl* self = getAnnoy(L, 1); + int item = getItemIndex(L, 2); + self->add_item_impl(item, LuaArrayProxy(L, 3, self->get_f())); + return 0; + } + + static int build(lua_State* L) { + Impl* self = getAnnoy(L, 1); + int n_trees = luaL_checkinteger(L, 2); + self->build(n_trees); + lua_pushboolean(L, true); + return 1; + } + + static int on_disk_build(lua_State* L) { + Impl* self = getAnnoy(L, 1); + const char* filename = luaL_checkstring(L, 2); + self->on_disk_build(filename); + lua_pushboolean(L, true); + return 1; + } + + static int save(lua_State* L) { + int nargs = lua_gettop(L); + Impl* self = getAnnoy(L, 1); + const char* filename = luaL_checkstring(L, 2); + bool prefault = true; + if (nargs >= 3) { + prefault = lua_toboolean(L, 3); + } + self->save(filename, prefault); + lua_pushboolean(L, true); + return 1; + } + + static int load(lua_State* L) { + Impl* self = getAnnoy(L, 1); + int nargs = lua_gettop(L); + const char* filename = luaL_checkstring(L, 2); + bool prefault = true; + if (nargs >= 3) { + prefault = lua_toboolean(L, 3); + } + if (!self->load(filename, prefault)) { + return luaL_error(L, "Can't load file: %s", filename); + } + lua_pushboolean(L, true); + return 1; + } + + static int unload(lua_State* L) { + Impl* self = getAnnoy(L, 1); + self->unload(); + lua_pushboolean(L, true); + return 1; + } + + struct Searcher { + std::vector result; + std::vector distances; + Impl* self; + int n; + int search_k; + bool include_distances; + + Searcher(lua_State* L) { + int nargs = lua_gettop(L); + self = getAnnoy(L, 1); + n = luaL_checkinteger(L, 3); + search_k = -1; + if (nargs >= 4) { + search_k = luaL_checkinteger(L, 4); + } + include_distances = false; + if (nargs >= 5) { + include_distances = lua_toboolean(L, 5); + } + } + + int pushResults(lua_State* L) { + pushVector(L, result); + if (include_distances) { + pushVector(L, distances); + } + return include_distances ? 2 : 1; + } + }; + + static int get_nns_by_item(lua_State* L) { + Searcher s(L); + int item = getItemIndex(L, 2, s.self->get_n_items()); + s.self->get_nns_by_item(item, s.n, s.search_k, &s.result, + s.include_distances ? &s.distances : nullptr); + return s.pushResults(L); + } + + static int get_nns_by_vector(lua_State* L) { + Searcher s(L); + std::vector _vec(s.self->get_f()); + AnnoyT* vec = &(_vec[0]); + toVector(L, 2, s.self->get_f(), vec); + s.self->get_nns_by_vector(vec, s.n, s.search_k, &s.result, + s.include_distances ? &s.distances : nullptr); + return s.pushResults(L); + } + + static int get_item_vector(lua_State* L) { + Impl* self = getAnnoy(L, 1); + int item = getItemIndex(L, 2, self->get_n_items()); + std::vector _vec(self->get_f()); + AnnoyT* vec = &(_vec[0]); + self->get_item(item, vec); + pushVector(L, _vec); + return 1; + } + + static int get_distance(lua_State* L) { + Impl* self = getAnnoy(L, 1); + int i = getItemIndex(L, 2, self->get_n_items()); + int j = getItemIndex(L, 3, self->get_n_items()); + AnnoyT distance = self->get_distance(i, j); + lua_pushnumber(L, distance); + return 1; + } + + static int get_n_items(lua_State* L) { + Impl* self = getAnnoy(L, 1); + lua_pushnumber(L, self->get_n_items()); + return 1; + } + + static const luaL_Reg* getMetatable() { + static const luaL_Reg funcs[] = { + {"__gc", &ThisClass::gc}, + {"__tostring", &ThisClass::tostring}, + {nullptr, nullptr}, + }; + return funcs; + } + + static const luaL_Reg* getMethods() { + static const luaL_Reg funcs[] = { + {"add_item", &ThisClass::add_item}, + {"build", &ThisClass::build}, + {"save", &ThisClass::save}, + {"load", &ThisClass::load}, + {"unload", &ThisClass::unload}, + {"get_nns_by_item", &ThisClass::get_nns_by_item}, + {"get_nns_by_vector", &ThisClass::get_nns_by_vector}, + {"get_item_vector", &ThisClass::get_item_vector}, + {"get_distance", &ThisClass::get_distance}, + {"get_n_items", &ThisClass::get_n_items}, + {"on_disk_build", &ThisClass::on_disk_build}, + {nullptr, nullptr}, + }; + return funcs; + } + + static void createNew(lua_State* L, int f) { + void* self = lua_newuserdata(L, sizeof(Impl)); + if (luaL_newmetatable(L, typeAsString())) { + compat_setfuncs(L, getMetatable()); + lua_newtable(L); + compat_setfuncs(L, getMethods()); + lua_setfield(L, -2, "__index"); + } + new (self) Impl(f); + lua_setmetatable(L, -2); + } +}; + +static int lua_an_make(lua_State* L) { + int f = luaL_checkinteger(L, 1); + const char* metric = "angular"; + if (lua_gettop(L) >= 2) { + metric = luaL_checkstring(L, 2); + } + if (strcmp(metric, "angular") == 0) { + LuaAnnoy::createNew(L, f); + return 1; + } else if (strcmp(metric, "euclidean") == 0) { + LuaAnnoy::createNew(L, f); + return 1; + } else if (strcmp(metric, "manhattan") == 0) { + LuaAnnoy::createNew(L, f); + return 1; + } else { + return luaL_error(L, "Unknown metric: %s", metric); + } +} + +static const luaL_Reg LUA_ANNOY_FUNCS[] = { + {"AnnoyIndex", lua_an_make}, + {nullptr, nullptr}, +}; + +extern "C" { +int luaopen_annoy(lua_State* L) { + lua_newtable(L); + compat_setfuncs(L, LUA_ANNOY_FUNCS); + return 1; +} +} + +// vim: tabstop=2 shiftwidth=2 diff --git a/core/src/index/thirdparty/annoy/src/annoymodule.cc b/core/src/index/thirdparty/annoy/src/annoymodule.cc new file mode 100644 index 0000000000..6121a2bc41 --- /dev/null +++ b/core/src/index/thirdparty/annoy/src/annoymodule.cc @@ -0,0 +1,632 @@ +// Copyright (c) 2013 Spotify AB +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not +// use this file except in compliance with the License. You may obtain a copy of +// the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations under +// the License. + +#include "annoylib.h" +#include "kissrandom.h" +#include "Python.h" +#include "structmember.h" +#include +#if defined(_MSC_VER) && _MSC_VER == 1500 +typedef signed __int32 int32_t; +#else +#include +#endif + + +#if defined(USE_AVX512) +#define AVX_INFO "Using 512-bit AVX instructions" +#elif defined(USE_AVX128) +#define AVX_INFO "Using 128-bit AVX instructions" +#else +#define AVX_INFO "Not using AVX instructions" +#endif + +#if defined(_MSC_VER) +#define COMPILER_INFO "Compiled using MSC" +#elif defined(__GNUC__) +#define COMPILER_INFO "Compiled on GCC" +#else +#define COMPILER_INFO "Compiled on unknown platform" +#endif + +#define ANNOY_DOC (COMPILER_INFO ". " AVX_INFO ".") + +#if PY_MAJOR_VERSION >= 3 +#define IS_PY3K +#endif + +#ifndef Py_TYPE + #define Py_TYPE(ob) (((PyObject*)(ob))->ob_type) +#endif + +#ifdef IS_PY3K + #define PyInt_FromLong PyLong_FromLong +#endif + + +template class AnnoyIndexInterface; + +class HammingWrapper : public AnnoyIndexInterface { + // Wrapper class for Hamming distance, using composition. + // This translates binary (float) vectors into packed uint64_t vectors. + // This is questionable from a performance point of view. Should reconsider this solution. +private: + int32_t _f_external, _f_internal; + AnnoyIndex _index; + void _pack(const float* src, uint64_t* dst) const { + for (int32_t i = 0; i < _f_internal; i++) { + dst[i] = 0; + for (int32_t j = 0; j < 64 && i*64+j < _f_external; j++) { + dst[i] |= (uint64_t)(src[i * 64 + j] > 0.5) << j; + } + } + }; + void _unpack(const uint64_t* src, float* dst) const { + for (int32_t i = 0; i < _f_external; i++) { + dst[i] = (src[i / 64] >> (i % 64)) & 1; + } + }; +public: + HammingWrapper(int f) : _f_external(f), _f_internal((f + 63) / 64), _index((f + 63) / 64) {}; + bool add_item(int32_t item, const float* w, char**error) { + vector w_internal(_f_internal, 0); + _pack(w, &w_internal[0]); + return _index.add_item(item, &w_internal[0], error); + }; + bool build(int q, char** error) { return _index.build(q, error); }; + bool unbuild(char** error) { return _index.unbuild(error); }; + bool save(const char* filename, bool prefault, char** error) { return _index.save(filename, prefault, error); }; + void unload() { _index.unload(); }; + bool load(const char* filename, bool prefault, char** error) { return _index.load(filename, prefault, error); }; + float get_distance(int32_t i, int32_t j) const { return _index.get_distance(i, j); }; + void get_nns_by_item(int32_t item, size_t n, int search_k, vector* result, vector* distances) const { + if (distances) { + vector distances_internal; + _index.get_nns_by_item(item, n, search_k, result, &distances_internal); + distances->insert(distances->begin(), distances_internal.begin(), distances_internal.end()); + } else { + _index.get_nns_by_item(item, n, search_k, result, nullptr); + } + }; + void get_nns_by_vector(const float* w, size_t n, int search_k, vector* result, vector* distances) const { + vector w_internal(_f_internal, 0); + _pack(w, &w_internal[0]); + if (distances) { + vector distances_internal; + _index.get_nns_by_vector(&w_internal[0], n, search_k, result, &distances_internal); + distances->insert(distances->begin(), distances_internal.begin(), distances_internal.end()); + } else { + _index.get_nns_by_vector(&w_internal[0], n, search_k, result, nullptr); + } + }; + int32_t get_n_items() const { return _index.get_n_items(); }; + int32_t get_n_trees() const { return _index.get_n_trees(); }; + void verbose(bool v) { _index.verbose(v); }; + void get_item(int32_t item, float* v) const { + vector v_internal(_f_internal, 0); + _index.get_item(item, &v_internal[0]); + _unpack(&v_internal[0], v); + }; + void set_seed(int q) { _index.set_seed(q); }; + bool on_disk_build(const char* filename, char** error) { return _index.on_disk_build(filename, error); }; +}; + +// annoy python object +typedef struct { + PyObject_HEAD + int f; + AnnoyIndexInterface* ptr; +} py_annoy; + + +static PyObject * +py_an_new(PyTypeObject *type, PyObject *args, PyObject *kwargs) { + py_annoy *self = (py_annoy *)type->tp_alloc(type, 0); + if (self == nullptr) { + return nullptr; + } + const char *metric = nullptr; + + static char const * kwlist[] = {"f", "metric", nullptr}; + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "i|s", (char**)kwlist, &self->f, &metric)) + return nullptr; + if (!metric) { + // This keeps coming up, see #368 etc + PyErr_WarnEx(PyExc_FutureWarning, "The default argument for metric will be removed " + "in future version of Annoy. Please pass metric='angular' explicitly.", 1); + self->ptr = new AnnoyIndex(self->f); + } else if (!strcmp(metric, "angular")) { + self->ptr = new AnnoyIndex(self->f); + } else if (!strcmp(metric, "euclidean")) { + self->ptr = new AnnoyIndex(self->f); + } else if (!strcmp(metric, "manhattan")) { + self->ptr = new AnnoyIndex(self->f); + } else if (!strcmp(metric, "hamming")) { + self->ptr = new HammingWrapper(self->f); + } else if (!strcmp(metric, "dot")) { + self->ptr = new AnnoyIndex(self->f); + } else { + PyErr_SetString(PyExc_ValueError, "No such metric"); + return nullptr; + } + + return (PyObject *)self; +} + + +static int +py_an_init(py_annoy *self, PyObject *args, PyObject *kwargs) { + // Seems to be needed for Python 3 + const char *metric = nullptr; + int f; + static char const * kwlist[] = {"f", "metric", nullptr}; + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "i|s", (char**)kwlist, &f, &metric)) + return (int) nullptr; + return 0; +} + + +static void +py_an_dealloc(py_annoy* self) { + delete self->ptr; + Py_TYPE(self)->tp_free((PyObject*)self); +} + + +static PyMemberDef py_annoy_members[] = { + {(char*)"f", T_INT, offsetof(py_annoy, f), 0, + (char*)""}, + {nullptr} /* Sentinel */ +}; + + +static PyObject * +py_an_load(py_annoy *self, PyObject *args, PyObject *kwargs) { + char *filename, *error; + bool prefault = false; + if (!self->ptr) + return nullptr; + static char const * kwlist[] = {"fn", "prefault", nullptr}; + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "s|b", (char**)kwlist, &filename, &prefault)) + return nullptr; + + if (!self->ptr->load(filename, prefault, &error)) { + PyErr_SetString(PyExc_IOError, error); + free(error); + return nullptr; + } + Py_RETURN_TRUE; +} + + +static PyObject * +py_an_save(py_annoy *self, PyObject *args, PyObject *kwargs) { + char *filename, *error; + bool prefault = false; + if (!self->ptr) + return nullptr; + static char const * kwlist[] = {"fn", "prefault", nullptr}; + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "s|b", (char**)kwlist, &filename, &prefault)) + return nullptr; + + if (!self->ptr->save(filename, prefault, &error)) { + PyErr_SetString(PyExc_IOError, error); + free(error); + return nullptr; + } + Py_RETURN_TRUE; +} + + +PyObject* +get_nns_to_python(const vector& result, const vector& distances, int include_distances) { + PyObject* l = PyList_New(result.size()); + for (size_t i = 0; i < result.size(); i++) + PyList_SetItem(l, i, PyInt_FromLong(result[i])); + if (!include_distances) + return l; + + PyObject* d = PyList_New(distances.size()); + for (size_t i = 0; i < distances.size(); i++) + PyList_SetItem(d, i, PyFloat_FromDouble(distances[i])); + + PyObject* t = PyTuple_New(2); + PyTuple_SetItem(t, 0, l); + PyTuple_SetItem(t, 1, d); + + return t; +} + + +bool check_constraints(py_annoy *self, int32_t item, bool building) { + if (item < 0) { + PyErr_SetString(PyExc_IndexError, "Item index can not be negative"); + return false; + } else if (!building && item >= self->ptr->get_n_items()) { + PyErr_SetString(PyExc_IndexError, "Item index larger than the largest item index"); + return false; + } else { + return true; + } +} + +static PyObject* +py_an_get_nns_by_item(py_annoy *self, PyObject *args, PyObject *kwargs) { + int32_t item, n, search_k=-1, include_distances=0; + if (!self->ptr) + return nullptr; + + static char const * kwlist[] = {"i", "n", "search_k", "include_distances", nullptr}; + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "ii|ii", (char**)kwlist, &item, &n, &search_k, &include_distances)) + return nullptr; + + if (!check_constraints(self, item, false)) { + return nullptr; + } + + vector result; + vector distances; + + Py_BEGIN_ALLOW_THREADS; + self->ptr->get_nns_by_item(item, n, search_k, &result, include_distances ? &distances : nullptr); + Py_END_ALLOW_THREADS; + + return get_nns_to_python(result, distances, include_distances); +} + + +bool +convert_list_to_vector(PyObject* v, int f, vector* w) { + if (PyObject_Size(v) == -1) { + char buf[256]; + snprintf(buf, 256, "Expected an iterable, got an object of type \"%s\"", v->ob_type->tp_name); + PyErr_SetString(PyExc_ValueError, buf); + return false; + } + if (PyObject_Size(v) != f) { + char buf[128]; + snprintf(buf, 128, "Vector has wrong length (expected %d, got %ld)", f, PyObject_Size(v)); + PyErr_SetString(PyExc_IndexError, buf); + return false; + } + for (int z = 0; z < f; z++) { + PyObject *key = PyInt_FromLong(z); + PyObject *pf = PyObject_GetItem(v, key); + (*w)[z] = PyFloat_AsDouble(pf); + Py_DECREF(key); + Py_DECREF(pf); + } + return true; +} + +static PyObject* +py_an_get_nns_by_vector(py_annoy *self, PyObject *args, PyObject *kwargs) { + PyObject* v; + int32_t n, search_k=-1, include_distances=0; + if (!self->ptr) + return nullptr; + + static char const * kwlist[] = {"vector", "n", "search_k", "include_distances", nullptr}; + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "Oi|ii", (char**)kwlist, &v, &n, &search_k, &include_distances)) + return nullptr; + + vector w(self->f); + if (!convert_list_to_vector(v, self->f, &w)) { + return nullptr; + } + + vector result; + vector distances; + + Py_BEGIN_ALLOW_THREADS; + self->ptr->get_nns_by_vector(&w[0], n, search_k, &result, include_distances ? &distances : nullptr); + Py_END_ALLOW_THREADS; + + return get_nns_to_python(result, distances, include_distances); +} + + +static PyObject* +py_an_get_item_vector(py_annoy *self, PyObject *args) { + int32_t item; + if (!self->ptr) + return nullptr; + if (!PyArg_ParseTuple(args, "i", &item)) + return nullptr; + + if (!check_constraints(self, item, false)) { + return nullptr; + } + + vector v(self->f); + self->ptr->get_item(item, &v[0]); + PyObject* l = PyList_New(self->f); + for (int z = 0; z < self->f; z++) { + PyList_SetItem(l, z, PyFloat_FromDouble(v[z])); + } + + return l; +} + + +static PyObject* +py_an_add_item(py_annoy *self, PyObject *args, PyObject* kwargs) { + PyObject* v; + int32_t item; + if (!self->ptr) + return nullptr; + static char const * kwlist[] = {"i", "vector", nullptr}; + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "iO", (char**)kwlist, &item, &v)) + return nullptr; + + if (!check_constraints(self, item, true)) { + return nullptr; + } + + vector w(self->f); + if (!convert_list_to_vector(v, self->f, &w)) { + return nullptr; + } + char* error; + if (!self->ptr->add_item(item, &w[0], &error)) { + PyErr_SetString(PyExc_Exception, error); + free(error); + return nullptr; + } + + Py_RETURN_NONE; +} + +static PyObject * +py_an_on_disk_build(py_annoy *self, PyObject *args, PyObject *kwargs) { + char *filename, *error; + if (!self->ptr) + return nullptr; + static char const * kwlist[] = {"fn", nullptr}; + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "s", (char**)kwlist, &filename)) + return nullptr; + + if (!self->ptr->on_disk_build(filename, &error)) { + PyErr_SetString(PyExc_IOError, error); + free(error); + return nullptr; + } + Py_RETURN_TRUE; +} + +static PyObject * +py_an_build(py_annoy *self, PyObject *args, PyObject *kwargs) { + int q; + if (!self->ptr) + return nullptr; + static char const * kwlist[] = {"n_trees", nullptr}; + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "i", (char**)kwlist, &q)) + return nullptr; + + bool res; + char* error; + Py_BEGIN_ALLOW_THREADS; + res = self->ptr->build(q, &error); + Py_END_ALLOW_THREADS; + if (!res) { + PyErr_SetString(PyExc_Exception, error); + free(error); + return nullptr; + } + + Py_RETURN_TRUE; +} + + +static PyObject * +py_an_unbuild(py_annoy *self) { + if (!self->ptr) + return nullptr; + + char* error; + if (!self->ptr->unbuild(&error)) { + PyErr_SetString(PyExc_Exception, error); + free(error); + return nullptr; + } + + Py_RETURN_TRUE; +} + + +static PyObject * +py_an_unload(py_annoy *self) { + if (!self->ptr) + return nullptr; + + self->ptr->unload(); + + Py_RETURN_TRUE; +} + + +static PyObject * +py_an_get_distance(py_annoy *self, PyObject *args) { + int32_t i, j; + if (!self->ptr) + return nullptr; + if (!PyArg_ParseTuple(args, "ii", &i, &j)) + return nullptr; + + if (!check_constraints(self, i, false) || !check_constraints(self, j, false)) { + return nullptr; + } + + double d = self->ptr->get_distance(i,j); + return PyFloat_FromDouble(d); +} + + +static PyObject * +py_an_get_n_items(py_annoy *self) { + if (!self->ptr) + return nullptr; + + int32_t n = self->ptr->get_n_items(); + return PyInt_FromLong(n); +} + +static PyObject * +py_an_get_n_trees(py_annoy *self) { + if (!self->ptr) + return nullptr; + + int32_t n = self->ptr->get_n_trees(); + return PyInt_FromLong(n); +} + +static PyObject * +py_an_verbose(py_annoy *self, PyObject *args) { + int verbose; + if (!self->ptr) + return nullptr; + if (!PyArg_ParseTuple(args, "i", &verbose)) + return nullptr; + + self->ptr->verbose((bool)verbose); + + Py_RETURN_TRUE; +} + + +static PyObject * +py_an_set_seed(py_annoy *self, PyObject *args) { + int q; + if (!self->ptr) + return nullptr; + if (!PyArg_ParseTuple(args, "i", &q)) + return nullptr; + + self->ptr->set_seed(q); + + Py_RETURN_NONE; +} + + +static PyMethodDef AnnoyMethods[] = { + {"load", (PyCFunction)py_an_load, METH_VARARGS | METH_KEYWORDS, "Loads (mmaps) an index from disk."}, + {"save", (PyCFunction)py_an_save, METH_VARARGS | METH_KEYWORDS, "Saves the index to disk."}, + {"get_nns_by_item",(PyCFunction)py_an_get_nns_by_item, METH_VARARGS | METH_KEYWORDS, "Returns the `n` closest items to item `i`.\n\n:param search_k: the query will inspect up to `search_k` nodes.\n`search_k` gives you a run-time tradeoff between better accuracy and speed.\n`search_k` defaults to `n_trees * n` if not provided.\n\n:param include_distances: If `True`, this function will return a\n2 element tuple of lists. The first list contains the `n` closest items.\nThe second list contains the corresponding distances."}, + {"get_nns_by_vector",(PyCFunction)py_an_get_nns_by_vector, METH_VARARGS | METH_KEYWORDS, "Returns the `n` closest items to vector `vector`.\n\n:param search_k: the query will inspect up to `search_k` nodes.\n`search_k` gives you a run-time tradeoff between better accuracy and speed.\n`search_k` defaults to `n_trees * n` if not provided.\n\n:param include_distances: If `True`, this function will return a\n2 element tuple of lists. The first list contains the `n` closest items.\nThe second list contains the corresponding distances."}, + {"get_item_vector",(PyCFunction)py_an_get_item_vector, METH_VARARGS, "Returns the vector for item `i` that was previously added."}, + {"add_item",(PyCFunction)py_an_add_item, METH_VARARGS | METH_KEYWORDS, "Adds item `i` (any nonnegative integer) with vector `v`.\n\nNote that it will allocate memory for `max(i)+1` items."}, + {"on_disk_build",(PyCFunction)py_an_on_disk_build, METH_VARARGS | METH_KEYWORDS, "Build will be performed with storage on disk instead of RAM."}, + {"build",(PyCFunction)py_an_build, METH_VARARGS | METH_KEYWORDS, "Builds a forest of `n_trees` trees.\n\nMore trees give higher precision when querying. After calling `build`,\nno more items can be added."}, + {"unbuild",(PyCFunction)py_an_unbuild, METH_NOARGS, "Unbuilds the tree in order to allows adding new items.\n\nbuild() has to be called again afterwards in order to\nrun queries."}, + {"unload",(PyCFunction)py_an_unload, METH_NOARGS, "Unloads an index from disk."}, + {"get_distance",(PyCFunction)py_an_get_distance, METH_VARARGS, "Returns the distance between items `i` and `j`."}, + {"get_n_items",(PyCFunction)py_an_get_n_items, METH_NOARGS, "Returns the number of items in the index."}, + {"get_n_trees",(PyCFunction)py_an_get_n_trees, METH_NOARGS, "Returns the number of trees in the index."}, + {"verbose",(PyCFunction)py_an_verbose, METH_VARARGS, ""}, + {"set_seed",(PyCFunction)py_an_set_seed, METH_VARARGS, "Sets the seed of Annoy's random number generator."}, + {nullptr, nullptr, 0, nullptr} /* Sentinel */ +}; + + +static PyTypeObject PyAnnoyType = { + PyVarObject_HEAD_INIT(nullptr, 0) + "annoy.Annoy", /*tp_name*/ + sizeof(py_annoy), /*tp_basicsize*/ + 0, /*tp_itemsize*/ + (destructor)py_an_dealloc, /*tp_dealloc*/ + 0, /*tp_print*/ + 0, /*tp_getattr*/ + 0, /*tp_setattr*/ + 0, /*tp_compare*/ + 0, /*tp_repr*/ + 0, /*tp_as_number*/ + 0, /*tp_as_sequence*/ + 0, /*tp_as_mapping*/ + 0, /*tp_hash */ + 0, /*tp_call*/ + 0, /*tp_str*/ + 0, /*tp_getattro*/ + 0, /*tp_setattro*/ + 0, /*tp_as_buffer*/ + Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /*tp_flags*/ + ANNOY_DOC, /* tp_doc */ + 0, /* tp_traverse */ + 0, /* tp_clear */ + 0, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + 0, /* tp_iter */ + 0, /* tp_iternext */ + AnnoyMethods, /* tp_methods */ + py_annoy_members, /* tp_members */ + 0, /* tp_getset */ + 0, /* tp_base */ + 0, /* tp_dict */ + 0, /* tp_descr_get */ + 0, /* tp_descr_set */ + 0, /* tp_dictoffset */ + (initproc)py_an_init, /* tp_init */ + 0, /* tp_alloc */ + py_an_new, /* tp_new */ +}; + +static PyMethodDef module_methods[] = { + {nullptr} /* Sentinel */ +}; + +#if PY_MAJOR_VERSION >= 3 + static struct PyModuleDef moduledef = { + PyModuleDef_HEAD_INIT, + "annoylib", /* m_name */ + ANNOY_DOC, /* m_doc */ + -1, /* m_size */ + module_methods, /* m_methods */ + nullptr, /* m_reload */ + nullptr, /* m_traverse */ + nullptr, /* m_clear */ + nullptr, /* m_free */ + }; +#endif + +PyObject *create_module(void) { + PyObject *m; + + if (PyType_Ready(&PyAnnoyType) < 0) + return nullptr; + +#if PY_MAJOR_VERSION >= 3 + m = PyModule_Create(&moduledef); +#else + m = Py_InitModule("annoylib", module_methods); +#endif + + if (m == nullptr) + return nullptr; + + Py_INCREF(&PyAnnoyType); + PyModule_AddObject(m, "Annoy", (PyObject *)&PyAnnoyType); + return m; +} + +#if PY_MAJOR_VERSION >= 3 + PyMODINIT_FUNC PyInit_annoylib(void) { + return create_module(); // it should return moudule object in py3 + } +#else + PyMODINIT_FUNC initannoylib(void) { + create_module(); + } +#endif + + +// vim: tabstop=2 shiftwidth=2 diff --git a/core/src/index/thirdparty/annoy/src/kissrandom.h b/core/src/index/thirdparty/annoy/src/kissrandom.h new file mode 100644 index 0000000000..9e40110f3e --- /dev/null +++ b/core/src/index/thirdparty/annoy/src/kissrandom.h @@ -0,0 +1,106 @@ +#ifndef KISSRANDOM_H +#define KISSRANDOM_H + +#if defined(_MSC_VER) && _MSC_VER == 1500 +typedef unsigned __int32 uint32_t; +typedef unsigned __int64 uint64_t; +#else +#include +#endif + +// KISS = "keep it simple, stupid", but high quality random number generator +// http://www0.cs.ucl.ac.uk/staff/d.jones/GoodPracticeRNG.pdf -> "Use a good RNG and build it into your code" +// http://mathforum.org/kb/message.jspa?messageID=6627731 +// https://de.wikipedia.org/wiki/KISS_(Zufallszahlengenerator) + +// 32 bit KISS +struct Kiss32Random { + uint32_t x; + uint32_t y; + uint32_t z; + uint32_t c; + + // seed must be != 0 + Kiss32Random(uint32_t seed = 123456789) { + x = seed; + y = 362436000; + z = 521288629; + c = 7654321; + } + + uint32_t kiss() { + // Linear congruence generator + x = 69069 * x + 12345; + + // Xor shift + y ^= y << 13; + y ^= y >> 17; + y ^= y << 5; + + // Multiply-with-carry + uint64_t t = 698769069ULL * z + c; + c = t >> 32; + z = (uint32_t) t; + + return x + y + z; + } + inline int flip() { + // Draw random 0 or 1 + return kiss() & 1; + } + inline size_t index(size_t n) { + // Draw random integer between 0 and n-1 where n is at most the number of data points you have + return kiss() % n; + } + inline void set_seed(uint32_t seed) { + x = seed; + } +}; + +// 64 bit KISS. Use this if you have more than about 2^24 data points ("big data" ;) ) +struct Kiss64Random { + uint64_t x; + uint64_t y; + uint64_t z; + uint64_t c; + + // seed must be != 0 + Kiss64Random(uint64_t seed = 1234567890987654321ULL) { + x = seed; + y = 362436362436362436ULL; + z = 1066149217761810ULL; + c = 123456123456123456ULL; + } + + uint64_t kiss() { + // Linear congruence generator + z = 6906969069LL*z+1234567; + + // Xor shift + y ^= (y<<13); + y ^= (y>>17); + y ^= (y<<43); + + // Multiply-with-carry (uint128_t t = (2^58 + 1) * x + c; c = t >> 64; x = (uint64_t) t) + uint64_t t = (x<<58)+c; + c = (x>>6); + x += t; + c += (x +#include +#include +#include + +#define PROT_NONE 0 +#define PROT_READ 1 +#define PROT_WRITE 2 +#define PROT_EXEC 4 + +#define MAP_FILE 0 +#define MAP_SHARED 1 +#define MAP_PRIVATE 2 +#define MAP_TYPE 0xf +#define MAP_FIXED 0x10 +#define MAP_ANONYMOUS 0x20 +#define MAP_ANON MAP_ANONYMOUS + +#define MAP_FAILED ((void *)-1) + +/* Flags for msync. */ +#define MS_ASYNC 1 +#define MS_SYNC 2 +#define MS_INVALIDATE 4 + +#ifndef FILE_MAP_EXECUTE +#define FILE_MAP_EXECUTE 0x0020 +#endif + +static int __map_mman_error(const DWORD err, const int deferr) +{ + if (err == 0) + return 0; + //TODO: implement + return err; +} + +static DWORD __map_mmap_prot_page(const int prot) +{ + DWORD protect = 0; + + if (prot == PROT_NONE) + return protect; + + if ((prot & PROT_EXEC) != 0) + { + protect = ((prot & PROT_WRITE) != 0) ? + PAGE_EXECUTE_READWRITE : PAGE_EXECUTE_READ; + } + else + { + protect = ((prot & PROT_WRITE) != 0) ? + PAGE_READWRITE : PAGE_READONLY; + } + + return protect; +} + +static DWORD __map_mmap_prot_file(const int prot) +{ + DWORD desiredAccess = 0; + + if (prot == PROT_NONE) + return desiredAccess; + + if ((prot & PROT_READ) != 0) + desiredAccess |= FILE_MAP_READ; + if ((prot & PROT_WRITE) != 0) + desiredAccess |= FILE_MAP_WRITE; + if ((prot & PROT_EXEC) != 0) + desiredAccess |= FILE_MAP_EXECUTE; + + return desiredAccess; +} + +inline void* mmap(void *addr, size_t len, int prot, int flags, int fildes, off_t off) +{ + HANDLE fm, h; + + void * map = MAP_FAILED; + +#ifdef _MSC_VER +#pragma warning(push) +#pragma warning(disable: 4293) +#endif + + const DWORD dwFileOffsetLow = (sizeof(off_t) <= sizeof(DWORD)) ? + (DWORD)off : (DWORD)(off & 0xFFFFFFFFL); + const DWORD dwFileOffsetHigh = (sizeof(off_t) <= sizeof(DWORD)) ? + (DWORD)0 : (DWORD)((off >> 32) & 0xFFFFFFFFL); + const DWORD protect = __map_mmap_prot_page(prot); + const DWORD desiredAccess = __map_mmap_prot_file(prot); + + const off_t maxSize = off + (off_t)len; + + const DWORD dwMaxSizeLow = (sizeof(off_t) <= sizeof(DWORD)) ? + (DWORD)maxSize : (DWORD)(maxSize & 0xFFFFFFFFL); + const DWORD dwMaxSizeHigh = (sizeof(off_t) <= sizeof(DWORD)) ? + (DWORD)0 : (DWORD)((maxSize >> 32) & 0xFFFFFFFFL); + +#ifdef _MSC_VER +#pragma warning(pop) +#endif + + errno = 0; + + if (len == 0 + /* Unsupported flag combinations */ + || (flags & MAP_FIXED) != 0 + /* Usupported protection combinations */ + || prot == PROT_EXEC) + { + errno = EINVAL; + return MAP_FAILED; + } + + h = ((flags & MAP_ANONYMOUS) == 0) ? + (HANDLE)_get_osfhandle(fildes) : INVALID_HANDLE_VALUE; + + if ((flags & MAP_ANONYMOUS) == 0 && h == INVALID_HANDLE_VALUE) + { + errno = EBADF; + return MAP_FAILED; + } + + fm = CreateFileMapping(h, NULL, protect, dwMaxSizeHigh, dwMaxSizeLow, NULL); + + if (fm == NULL) + { + errno = __map_mman_error(GetLastError(), EPERM); + return MAP_FAILED; + } + + map = MapViewOfFile(fm, desiredAccess, dwFileOffsetHigh, dwFileOffsetLow, len); + + CloseHandle(fm); + + if (map == NULL) + { + errno = __map_mman_error(GetLastError(), EPERM); + return MAP_FAILED; + } + + return map; +} + +inline int munmap(void *addr, size_t len) +{ + if (UnmapViewOfFile(addr)) + return 0; + + errno = __map_mman_error(GetLastError(), EPERM); + + return -1; +} + +inline int mprotect(void *addr, size_t len, int prot) +{ + DWORD newProtect = __map_mmap_prot_page(prot); + DWORD oldProtect = 0; + + if (VirtualProtect(addr, len, newProtect, &oldProtect)) + return 0; + + errno = __map_mman_error(GetLastError(), EPERM); + + return -1; +} + +inline int msync(void *addr, size_t len, int flags) +{ + if (FlushViewOfFile(addr, len)) + return 0; + + errno = __map_mman_error(GetLastError(), EPERM); + + return -1; +} + +inline int mlock(const void *addr, size_t len) +{ + if (VirtualLock((LPVOID)addr, len)) + return 0; + + errno = __map_mman_error(GetLastError(), EPERM); + + return -1; +} + +inline int munlock(const void *addr, size_t len) +{ + if (VirtualUnlock((LPVOID)addr, len)) + return 0; + + errno = __map_mman_error(GetLastError(), EPERM); + + return -1; +} + +#if !defined(__MINGW32__) +inline int ftruncate(int fd, unsigned int size) { + if (fd < 0) { + errno = EBADF; + return -1; + } + + HANDLE h = (HANDLE)_get_osfhandle(fd); + unsigned int cur = SetFilePointer(h, 0, NULL, FILE_CURRENT); + if (cur == ~0 || SetFilePointer(h, size, NULL, FILE_BEGIN) == ~0 || !SetEndOfFile(h)) { + int error = GetLastError(); + switch (GetLastError()) { + case ERROR_INVALID_HANDLE: + errno = EBADF; + break; + default: + errno = EIO; + break; + } + return -1; + } + + return 0; +} +#endif + +#endif diff --git a/core/src/index/thirdparty/build.sh b/core/src/index/thirdparty/build.sh new file mode 100755 index 0000000000..a0af3349d5 --- /dev/null +++ b/core/src/index/thirdparty/build.sh @@ -0,0 +1,42 @@ +#!/bin/bash + +BUILD_TYPE="Release" + +while getopts "p:t:d:h" arg +do + case $arg in + t) + BUILD_TYPE=$OPTARG # BUILD_TYPE + ;; + h) # help + echo " + +parameter: +-p: postgresql install path. +-t: build type + +usage: +./build.sh -t \${BUILD_TYPE} + " + exit 0 + ;; + ?) + echo "unknown argument" + exit 1 + ;; + esac +done + +if [[ -d build ]]; then + rm ./build -r +fi + +while IFS='' read -r line || [[ -n "$line" ]]; do + cd $line + ./build.sh -t ${BUILD_TYPE} + if [ $? -ne 0 ];then + exit 1 + fi + cd ../ +done < project.conf + diff --git a/core/src/index/thirdparty/faiss/.dockerignore b/core/src/index/thirdparty/faiss/.dockerignore new file mode 100644 index 0000000000..7763a51dc3 --- /dev/null +++ b/core/src/index/thirdparty/faiss/.dockerignore @@ -0,0 +1 @@ +sift1M \ No newline at end of file diff --git a/core/src/index/thirdparty/faiss/.gitignore b/core/src/index/thirdparty/faiss/.gitignore new file mode 100644 index 0000000000..a25bc7f112 --- /dev/null +++ b/core/src/index/thirdparty/faiss/.gitignore @@ -0,0 +1,21 @@ +*.swp +*.swo +*.o +*.a +*.dSYM +*.so +*.dylib +*.pyc +*~ +.DS_Store +depend +/config.* +/aclocal.m4 +/autom4te.cache/ +/makefile.inc +/bin/ +/c_api/bin/ +/c_api/gpu/bin/ +/tests/test +/tests/gtest/ +include/ diff --git a/core/src/index/thirdparty/faiss/AutoTune.cpp b/core/src/index/thirdparty/faiss/AutoTune.cpp new file mode 100644 index 0000000000..a90a6f53ea --- /dev/null +++ b/core/src/index/thirdparty/faiss/AutoTune.cpp @@ -0,0 +1,719 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +/* + * implementation of Hyper-parameter auto-tuning + */ + +#include + +#include + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace faiss { + + +AutoTuneCriterion::AutoTuneCriterion (idx_t nq, idx_t nnn): + nq (nq), nnn (nnn), gt_nnn (0) +{} + + +void AutoTuneCriterion::set_groundtruth ( + int gt_nnn, const float *gt_D_in, const idx_t *gt_I_in) +{ + this->gt_nnn = gt_nnn; + if (gt_D_in) { // allow null for this, as it is often not used + gt_D.resize (nq * gt_nnn); + memcpy (gt_D.data(), gt_D_in, sizeof (gt_D[0]) * nq * gt_nnn); + } + gt_I.resize (nq * gt_nnn); + memcpy (gt_I.data(), gt_I_in, sizeof (gt_I[0]) * nq * gt_nnn); +} + + + +OneRecallAtRCriterion::OneRecallAtRCriterion (idx_t nq, idx_t R): + AutoTuneCriterion(nq, R), R(R) +{} + +double OneRecallAtRCriterion::evaluate(const float* /*D*/, const idx_t* I) + const { + FAISS_THROW_IF_NOT_MSG( + (gt_I.size() == gt_nnn * nq && gt_nnn >= 1 && nnn >= R), + "ground truth not initialized"); + idx_t n_ok = 0; + for (idx_t q = 0; q < nq; q++) { + idx_t gt_nn = gt_I[q * gt_nnn]; + const idx_t* I_line = I + q * nnn; + for (int i = 0; i < R; i++) { + if (I_line[i] == gt_nn) { + n_ok++; + break; + } + } + } + return n_ok / double(nq); +} + + +IntersectionCriterion::IntersectionCriterion (idx_t nq, idx_t R): + AutoTuneCriterion(nq, R), R(R) +{} + +double IntersectionCriterion::evaluate(const float* /*D*/, const idx_t* I) + const { + FAISS_THROW_IF_NOT_MSG( + (gt_I.size() == gt_nnn * nq && gt_nnn >= R && nnn >= R), + "ground truth not initialized"); + int64_t n_ok = 0; +#pragma omp parallel for reduction(+: n_ok) + for (idx_t q = 0; q < nq; q++) { + n_ok += ranklist_intersection_size ( + R, >_I [q * gt_nnn], + R, I + q * nnn); + } + return n_ok / double (nq * R); +} + +/*************************************************************** + * OperatingPoints + ***************************************************************/ + +OperatingPoints::OperatingPoints () +{ + clear(); +} + +void OperatingPoints::clear () +{ + all_pts.clear(); + optimal_pts.clear(); + /// default point: doing nothing gives 0 performance and takes 0 time + OperatingPoint op = {0, 0, "", -1}; + optimal_pts.push_back(op); +} + +/// add a performance measure +bool OperatingPoints::add (double perf, double t, const std::string & key, + size_t cno) +{ + OperatingPoint op = {perf, t, key, int64_t(cno)}; + all_pts.push_back (op); + if (perf == 0) { + return false; // no method for 0 accuracy is faster than doing nothing + } + std::vector & a = optimal_pts; + if (perf > a.back().perf) { + // keep unconditionally + a.push_back (op); + } else if (perf == a.back().perf) { + if (t < a.back ().t) { + a.back() = op; + } else { + return false; + } + } else { + int i; + // stricto sensu this should be a bissection + for (i = 0; i < a.size(); i++) { + if (a[i].perf >= perf) break; + } + assert (i < a.size()); + if (t < a[i].t) { + if (a[i].perf == perf) { + a[i] = op; + } else { + a.insert (a.begin() + i, op); + } + } else { + return false; + } + } + { // remove non-optimal points from array + int i = a.size() - 1; + while (i > 0) { + if (a[i].t < a[i - 1].t) + a.erase (a.begin() + (i - 1)); + i--; + } + } + return true; +} + + +int OperatingPoints::merge_with (const OperatingPoints &other, + const std::string & prefix) +{ + int n_add = 0; + for (int i = 0; i < other.all_pts.size(); i++) { + const OperatingPoint & op = other.all_pts[i]; + if (add (op.perf, op.t, prefix + op.key, op.cno)) + n_add++; + } + return n_add; +} + + + +/// get time required to obtain a given performance measure +double OperatingPoints::t_for_perf (double perf) const +{ + const std::vector & a = optimal_pts; + if (perf > a.back().perf) return 1e50; + int i0 = -1, i1 = a.size() - 1; + while (i0 + 1 < i1) { + int imed = (i0 + i1 + 1) / 2; + if (a[imed].perf < perf) i0 = imed; + else i1 = imed; + } + return a[i1].t; +} + + +void OperatingPoints::all_to_gnuplot (const char *fname) const +{ + FILE *f = fopen(fname, "w"); + if (!f) { + fprintf (stderr, "cannot open %s", fname); + perror(""); + abort(); + } + for (int i = 0; i < all_pts.size(); i++) { + const OperatingPoint & op = all_pts[i]; + fprintf (f, "%g %g %s\n", op.perf, op.t, op.key.c_str()); + } + fclose(f); +} + +void OperatingPoints::optimal_to_gnuplot (const char *fname) const +{ + FILE *f = fopen(fname, "w"); + if (!f) { + fprintf (stderr, "cannot open %s", fname); + perror(""); + abort(); + } + double prev_perf = 0.0; + for (int i = 0; i < optimal_pts.size(); i++) { + const OperatingPoint & op = optimal_pts[i]; + fprintf (f, "%g %g\n", prev_perf, op.t); + fprintf (f, "%g %g %s\n", op.perf, op.t, op.key.c_str()); + prev_perf = op.perf; + } + fclose(f); +} + +void OperatingPoints::display (bool only_optimal) const +{ + const std::vector &pts = + only_optimal ? optimal_pts : all_pts; + printf("Tested %ld operating points, %ld ones are optimal:\n", + all_pts.size(), optimal_pts.size()); + + for (int i = 0; i < pts.size(); i++) { + const OperatingPoint & op = pts[i]; + const char *star = ""; + if (!only_optimal) { + for (int j = 0; j < optimal_pts.size(); j++) { + if (op.cno == optimal_pts[j].cno) { + star = "*"; + break; + } + } + } + printf ("cno=%ld key=%s perf=%.4f t=%.3f %s\n", + op.cno, op.key.c_str(), op.perf, op.t, star); + } + +} + +/*************************************************************** + * ParameterSpace + ***************************************************************/ + +ParameterSpace::ParameterSpace (): + verbose (1), n_experiments (500), + batchsize (1<<30), thread_over_batches (false), + min_test_duration (0) +{ +} + +/* not keeping this constructor as inheritors will call the parent + initialize() + */ + +#if 0 +ParameterSpace::ParameterSpace (Index *index): + verbose (1), n_experiments (500), + batchsize (1<<30), thread_over_batches (false) + +{ + initialize(index); +} +#endif + +size_t ParameterSpace::n_combinations () const +{ + size_t n = 1; + for (int i = 0; i < parameter_ranges.size(); i++) + n *= parameter_ranges[i].values.size(); + return n; +} + +/// get string representation of the combination +std::string ParameterSpace::combination_name (size_t cno) const { + char buf[1000], *wp = buf; + *wp = 0; + for (int i = 0; i < parameter_ranges.size(); i++) { + const ParameterRange & pr = parameter_ranges[i]; + size_t j = cno % pr.values.size(); + cno /= pr.values.size(); + wp += snprintf ( + wp, buf + 1000 - wp, "%s%s=%g", i == 0 ? "" : ",", + pr.name.c_str(), pr.values[j]); + } + return std::string (buf); +} + + +bool ParameterSpace::combination_ge (size_t c1, size_t c2) const +{ + for (int i = 0; i < parameter_ranges.size(); i++) { + int nval = parameter_ranges[i].values.size(); + size_t j1 = c1 % nval; + size_t j2 = c2 % nval; + if (!(j1 >= j2)) return false; + c1 /= nval; + c2 /= nval; + } + return true; +} + + + +#define DC(classname) \ + const classname *ix = dynamic_cast(index) + +static void init_pq_ParameterRange (const ProductQuantizer & pq, + ParameterRange & pr) +{ + if (pq.code_size % 4 == 0) { + // Polysemous not supported for code sizes that are not a + // multiple of 4 + for (int i = 2; i <= pq.code_size * 8 / 2; i+= 2) + pr.values.push_back(i); + } + pr.values.push_back (pq.code_size * 8); +} + +ParameterRange &ParameterSpace::add_range(const char * name) +{ + for (auto & pr : parameter_ranges) { + if (pr.name == name) { + return pr; + } + } + parameter_ranges.push_back (ParameterRange ()); + parameter_ranges.back ().name = name; + return parameter_ranges.back (); +} + + +/// initialize with reasonable parameters for the index +void ParameterSpace::initialize (const Index * index) +{ + if (DC (IndexPreTransform)) { + index = ix->index; + } + if (DC (IndexRefineFlat)) { + ParameterRange & pr = add_range("k_factor_rf"); + for (int i = 0; i <= 6; i++) { + pr.values.push_back (1 << i); + } + index = ix->base_index; + } + if (DC (IndexPreTransform)) { + index = ix->index; + } + + if (DC (IndexIVF)) { + { + ParameterRange & pr = add_range("nprobe"); + for (int i = 0; i < 13; i++) { + size_t nprobe = 1 << i; + if (nprobe >= ix->nlist) break; + pr.values.push_back (nprobe); + } + } + if (dynamic_cast(ix->quantizer)) { + ParameterRange & pr = add_range("efSearch"); + for (int i = 2; i <= 9; i++) { + pr.values.push_back (1 << i); + } + } + } + if (DC (IndexPQ)) { + ParameterRange & pr = add_range("ht"); + init_pq_ParameterRange (ix->pq, pr); + } + if (DC (IndexIVFPQ)) { + ParameterRange & pr = add_range("ht"); + init_pq_ParameterRange (ix->pq, pr); + } + + if (DC (IndexIVF)) { + const MultiIndexQuantizer *miq = + dynamic_cast (ix->quantizer); + if (miq) { + ParameterRange & pr_max_codes = add_range("max_codes"); + for (int i = 8; i < 20; i++) { + pr_max_codes.values.push_back (1 << i); + } + pr_max_codes.values.push_back ( + std::numeric_limits::infinity() + ); + } + } + if (DC (IndexIVFPQR)) { + ParameterRange & pr = add_range("k_factor"); + for (int i = 0; i <= 6; i++) { + pr.values.push_back (1 << i); + } + } + if (dynamic_cast(index)) { + ParameterRange & pr = add_range("efSearch"); + for (int i = 2; i <= 9; i++) { + pr.values.push_back (1 << i); + } + } +} + +#undef DC + +// non-const version +#define DC(classname) classname *ix = dynamic_cast(index) + + +/// set a combination of parameters on an index +void ParameterSpace::set_index_parameters (Index *index, size_t cno) const +{ + + for (int i = 0; i < parameter_ranges.size(); i++) { + const ParameterRange & pr = parameter_ranges[i]; + size_t j = cno % pr.values.size(); + cno /= pr.values.size(); + double val = pr.values [j]; + set_index_parameter (index, pr.name, val); + } +} + +/// set a combination of parameters on an index +void ParameterSpace::set_index_parameters ( + Index *index, const char *description_in) const +{ + char description[strlen(description_in) + 1]; + char *ptr; + memcpy (description, description_in, strlen(description_in) + 1); + + for (char *tok = strtok_r (description, " ,", &ptr); + tok; + tok = strtok_r (nullptr, " ,", &ptr)) { + char name[100]; + double val; + int ret = sscanf (tok, "%100[^=]=%lf", name, &val); + FAISS_THROW_IF_NOT_FMT ( + ret == 2, "could not interpret parameters %s", tok); + set_index_parameter (index, name, val); + } + +} + +void ParameterSpace::set_index_parameter ( + Index * index, const std::string & name, double val) const +{ + if (verbose > 1) + printf(" set %s=%g\n", name.c_str(), val); + + if (name == "verbose") { + index->verbose = int(val); + // and fall through to also enable it on sub-indexes + } + if (DC (IndexPreTransform)) { + set_index_parameter (ix->index, name, val); + return; + } + if (DC (IndexShards)) { + // call on all sub-indexes + auto fn = + [this, name, val](int, Index* subIndex) { + set_index_parameter(subIndex, name, val); + }; + + ix->runOnIndex(fn); + return; + } + if (DC (IndexReplicas)) { + // call on all sub-indexes + auto fn = + [this, name, val](int, Index* subIndex) { + set_index_parameter(subIndex, name, val); + }; + + ix->runOnIndex(fn); + return; + } + if (DC (IndexRefineFlat)) { + if (name == "k_factor_rf") { + ix->k_factor = int(val); + return; + } + // otherwise it is for the sub-index + set_index_parameter (&ix->refine_index, name, val); + return; + } + + if (name == "verbose") { + index->verbose = int(val); + return; // last verbose that we could find + } + + if (name == "nprobe") { + if (DC (IndexIDMap)) { + set_index_parameter (ix->index, name, val); + return; + } else if (DC (IndexIVF)) { + ix->nprobe = int(val); + return; + } + } + + if (name == "ht") { + if (DC (IndexPQ)) { + if (val >= ix->pq.code_size * 8) { + ix->search_type = IndexPQ::ST_PQ; + } else { + ix->search_type = IndexPQ::ST_polysemous; + ix->polysemous_ht = int(val); + } + return; + } else if (DC (IndexIVFPQ)) { + if (val >= ix->pq.code_size * 8) { + ix->polysemous_ht = 0; + } else { + ix->polysemous_ht = int(val); + } + return; + } + } + + if (name == "k_factor") { + if (DC (IndexIVFPQR)) { + ix->k_factor = val; + return; + } + } + if (name == "max_codes") { + if (DC (IndexIVF)) { + ix->max_codes = std::isfinite(val) ? size_t(val) : 0; + return; + } + } + + if (name == "efSearch") { + if (DC (IndexHNSW)) { + ix->hnsw.efSearch = int(val); + return; + } + if (DC (IndexIVF)) { + if (IndexHNSW *cq = + dynamic_cast(ix->quantizer)) { + cq->hnsw.efSearch = int(val); + return; + } + } + } + + FAISS_THROW_FMT ("ParameterSpace::set_index_parameter:" + "could not set parameter %s", + name.c_str()); +} + +void ParameterSpace::display () const +{ + printf ("ParameterSpace, %ld parameters, %ld combinations:\n", + parameter_ranges.size (), n_combinations ()); + for (int i = 0; i < parameter_ranges.size(); i++) { + const ParameterRange & pr = parameter_ranges[i]; + printf (" %s: ", pr.name.c_str ()); + char sep = '['; + for (int j = 0; j < pr.values.size(); j++) { + printf ("%c %g", sep, pr.values [j]); + sep = ','; + } + printf ("]\n"); + } +} + + + +void ParameterSpace::update_bounds (size_t cno, const OperatingPoint & op, + double *upper_bound_perf, + double *lower_bound_t) const +{ + if (combination_ge (cno, op.cno)) { + if (op.t > *lower_bound_t) *lower_bound_t = op.t; + } + if (combination_ge (op.cno, cno)) { + if (op.perf < *upper_bound_perf) *upper_bound_perf = op.perf; + } +} + + + +void ParameterSpace::explore (Index *index, + size_t nq, const float *xq, + const AutoTuneCriterion & crit, + OperatingPoints * ops) const +{ + FAISS_THROW_IF_NOT_MSG (nq == crit.nq, + "criterion does not have the same nb of queries"); + + size_t n_comb = n_combinations (); + + if (n_experiments == 0) { + + for (size_t cno = 0; cno < n_comb; cno++) { + set_index_parameters (index, cno); + std::vector I(nq * crit.nnn); + std::vector D(nq * crit.nnn); + + double t0 = getmillisecs (); + index->search (nq, xq, crit.nnn, D.data(), I.data()); + double t_search = (getmillisecs() - t0) / 1e3; + + double perf = crit.evaluate (D.data(), I.data()); + + bool keep = ops->add (perf, t_search, combination_name (cno), cno); + + if (verbose) + printf(" %ld/%ld: %s perf=%.3f t=%.3f s %s\n", cno, n_comb, + combination_name (cno).c_str(), perf, t_search, + keep ? "*" : ""); + } + return; + } + + int n_exp = n_experiments; + + if (n_exp > n_comb) n_exp = n_comb; + FAISS_THROW_IF_NOT (n_comb == 1 || n_exp > 2); + std::vector perm (n_comb); + // make sure the slowest and fastest experiment are run + perm[0] = 0; + if (n_comb > 1) { + perm[1] = n_comb - 1; + rand_perm (&perm[2], n_comb - 2, 1234); + for (int i = 2; i < perm.size(); i++) perm[i] ++; + } + + for (size_t xp = 0; xp < n_exp; xp++) { + size_t cno = perm[xp]; + + if (verbose) + printf(" %ld/%d: cno=%ld %s ", xp, n_exp, cno, + combination_name (cno).c_str()); + + { + double lower_bound_t = 0.0; + double upper_bound_perf = 1.0; + for (int i = 0; i < ops->all_pts.size(); i++) { + update_bounds (cno, ops->all_pts[i], + &upper_bound_perf, &lower_bound_t); + } + double best_t = ops->t_for_perf (upper_bound_perf); + if (verbose) + printf ("bounds [perf<=%.3f t>=%.3f] %s", + upper_bound_perf, lower_bound_t, + best_t <= lower_bound_t ? "skip\n" : ""); + if (best_t <= lower_bound_t) continue; + } + + set_index_parameters (index, cno); + std::vector I(nq * crit.nnn); + std::vector D(nq * crit.nnn); + + double t0 = getmillisecs (); + + int nrun = 0; + double t_search; + + do { + + if (thread_over_batches) { +#pragma omp parallel for + for (size_t q0 = 0; q0 < nq; q0 += batchsize) { + size_t q1 = q0 + batchsize; + if (q1 > nq) q1 = nq; + index->search (q1 - q0, xq + q0 * index->d, + crit.nnn, + D.data() + q0 * crit.nnn, + I.data() + q0 * crit.nnn); + } + } else { + for (size_t q0 = 0; q0 < nq; q0 += batchsize) { + size_t q1 = q0 + batchsize; + if (q1 > nq) q1 = nq; + index->search (q1 - q0, xq + q0 * index->d, + crit.nnn, + D.data() + q0 * crit.nnn, + I.data() + q0 * crit.nnn); + } + } + nrun ++; + t_search = (getmillisecs() - t0) / 1e3; + + } while (t_search < min_test_duration); + + t_search /= nrun; + + double perf = crit.evaluate (D.data(), I.data()); + + bool keep = ops->add (perf, t_search, combination_name (cno), cno); + + if (verbose) + printf(" perf %.3f t %.3f (%d runs) %s\n", + perf, t_search, nrun, + keep ? "*" : ""); + } +} + + + + +} // namespace faiss diff --git a/core/src/index/thirdparty/faiss/AutoTune.h b/core/src/index/thirdparty/faiss/AutoTune.h new file mode 100644 index 0000000000..d755844d6d --- /dev/null +++ b/core/src/index/thirdparty/faiss/AutoTune.h @@ -0,0 +1,212 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +#ifndef FAISS_AUTO_TUNE_H +#define FAISS_AUTO_TUNE_H + +#include +#include +#include + +#include +#include + +namespace faiss { + + +/** + * Evaluation criterion. Returns a performance measure in [0,1], + * higher is better. + */ +struct AutoTuneCriterion { + typedef Index::idx_t idx_t; + idx_t nq; ///< nb of queries this criterion is evaluated on + idx_t nnn; ///< nb of NNs that the query should request + idx_t gt_nnn; ///< nb of GT NNs required to evaluate criterion + + std::vector gt_D; ///< Ground-truth distances (size nq * gt_nnn) + std::vector gt_I; ///< Ground-truth indexes (size nq * gt_nnn) + + AutoTuneCriterion (idx_t nq, idx_t nnn); + + /** Intitializes the gt_D and gt_I vectors. Must be called before evaluating + * + * @param gt_D_in size nq * gt_nnn + * @param gt_I_in size nq * gt_nnn + */ + void set_groundtruth (int gt_nnn, const float *gt_D_in, + const idx_t *gt_I_in); + + /** Evaluate the criterion. + * + * @param D size nq * nnn + * @param I size nq * nnn + * @return the criterion, between 0 and 1. Larger is better. + */ + virtual double evaluate (const float *D, const idx_t *I) const = 0; + + virtual ~AutoTuneCriterion () {} + +}; + +struct OneRecallAtRCriterion: AutoTuneCriterion { + + idx_t R; + + OneRecallAtRCriterion (idx_t nq, idx_t R); + + double evaluate(const float* D, const idx_t* I) const override; + + ~OneRecallAtRCriterion() override {} +}; + + +struct IntersectionCriterion: AutoTuneCriterion { + + idx_t R; + + IntersectionCriterion (idx_t nq, idx_t R); + + double evaluate(const float* D, const idx_t* I) const override; + + ~IntersectionCriterion() override {} +}; + +/** + * Maintains a list of experimental results. Each operating point is a + * (perf, t, key) triplet, where higher perf and lower t is + * better. The key field is an arbitrary identifier for the operating point + */ + +struct OperatingPoint { + double perf; ///< performance measure (output of a Criterion) + double t; ///< corresponding execution time (ms) + std::string key; ///< key that identifies this op pt + int64_t cno; ///< integer identifer +}; + +struct OperatingPoints { + /// all operating points + std::vector all_pts; + + /// optimal operating points, sorted by perf + std::vector optimal_pts; + + // begins with a single operating point: t=0, perf=0 + OperatingPoints (); + + /// add operating points from other to this, with a prefix to the keys + int merge_with (const OperatingPoints &other, + const std::string & prefix = ""); + + void clear (); + + /// add a performance measure. Return whether it is an optimal point + bool add (double perf, double t, const std::string & key, size_t cno = 0); + + /// get time required to obtain a given performance measure + double t_for_perf (double perf) const; + + /// easy-to-read output + void display (bool only_optimal = true) const; + + /// output to a format easy to digest by gnuplot + void all_to_gnuplot (const char *fname) const; + void optimal_to_gnuplot (const char *fname) const; + +}; + +/// possible values of a parameter, sorted from least to most expensive/accurate +struct ParameterRange { + std::string name; + std::vector values; +}; + +/** Uses a-priori knowledge on the Faiss indexes to extract tunable parameters. + */ +struct ParameterSpace { + /// all tunable parameters + std::vector parameter_ranges; + + // exploration parameters + + /// verbosity during exploration + int verbose; + + /// nb of experiments during optimization (0 = try all combinations) + int n_experiments; + + /// maximum number of queries to submit at a time. + size_t batchsize; + + /// use multithreading over batches (useful to benchmark + /// independent single-searches) + bool thread_over_batches; + + /// run tests several times until they reach at least this + /// duration (to avoid jittering in MT mode) + double min_test_duration; + + ParameterSpace (); + + /// nb of combinations, = product of values sizes + size_t n_combinations () const; + + /// returns whether combinations c1 >= c2 in the tuple sense + bool combination_ge (size_t c1, size_t c2) const; + + /// get string representation of the combination + std::string combination_name (size_t cno) const; + + /// print a description on stdout + void display () const; + + /// add a new parameter (or return it if it exists) + ParameterRange &add_range(const char * name); + + /// initialize with reasonable parameters for the index + virtual void initialize (const Index * index); + + /// set a combination of parameters on an index + void set_index_parameters (Index *index, size_t cno) const; + + /// set a combination of parameters described by a string + void set_index_parameters (Index *index, const char *param_string) const; + + /// set one of the parameters + virtual void set_index_parameter ( + Index * index, const std::string & name, double val) const; + + /** find an upper bound on the performance and a lower bound on t + * for configuration cno given another operating point op */ + void update_bounds (size_t cno, const OperatingPoint & op, + double *upper_bound_perf, + double *lower_bound_t) const; + + /** explore operating points + * @param index index to run on + * @param xq query vectors (size nq * index.d) + * @param crit selection criterion + * @param ops resulting operating points + */ + void explore (Index *index, + size_t nq, const float *xq, + const AutoTuneCriterion & crit, + OperatingPoints * ops) const; + + virtual ~ParameterSpace () {} +}; + + + +} // namespace faiss + + + +#endif diff --git a/core/src/index/thirdparty/faiss/BuilderSuspend.cpp b/core/src/index/thirdparty/faiss/BuilderSuspend.cpp new file mode 100644 index 0000000000..dd32b630f7 --- /dev/null +++ b/core/src/index/thirdparty/faiss/BuilderSuspend.cpp @@ -0,0 +1,35 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include "BuilderSuspend.h" + +namespace faiss { + +std::atomic BuilderSuspend::suspend_flag_(false); +std::mutex BuilderSuspend::mutex_; +std::condition_variable BuilderSuspend::cv_; + +void BuilderSuspend::suspend() { + suspend_flag_ = true; +} + +void BuilderSuspend::resume() { + suspend_flag_ = false; +} + +void BuilderSuspend::check_wait() { + while (suspend_flag_) { + std::unique_lock lck(mutex_); + cv_.wait_for(lck, std::chrono::seconds(5)); + } +} + +} // namespace faiss diff --git a/core/src/index/thirdparty/faiss/BuilderSuspend.h b/core/src/index/thirdparty/faiss/BuilderSuspend.h new file mode 100644 index 0000000000..d5291a9628 --- /dev/null +++ b/core/src/index/thirdparty/faiss/BuilderSuspend.h @@ -0,0 +1,33 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#pragma once + +#include +#include +#include + +namespace faiss { + +class BuilderSuspend { +public: + static void suspend(); + static void resume(); + static void check_wait(); + +private: + static std::atomic suspend_flag_; + static std::mutex mutex_; + static std::condition_variable cv_; + +}; + +} // namespace faiss diff --git a/core/src/index/thirdparty/faiss/CODE_OF_CONDUCT.md b/core/src/index/thirdparty/faiss/CODE_OF_CONDUCT.md new file mode 100644 index 0000000000..ac27d8a51b --- /dev/null +++ b/core/src/index/thirdparty/faiss/CODE_OF_CONDUCT.md @@ -0,0 +1,2 @@ +# Code of Conduct +Facebook has adopted a Code of Conduct that we expect project participants to adhere to. Please [read the full text](https://code.fb.com/codeofconduct) so that you can understand what actions will and will not be tolerated. \ No newline at end of file diff --git a/core/src/index/thirdparty/faiss/CONTRIBUTING.md b/core/src/index/thirdparty/faiss/CONTRIBUTING.md new file mode 100644 index 0000000000..a93141be47 --- /dev/null +++ b/core/src/index/thirdparty/faiss/CONTRIBUTING.md @@ -0,0 +1,53 @@ +# Contributing to Faiss + +We want to make contributing to this project as easy and transparent as +possible. + +## Our Development Process + +We mainly develop Faiss within Facebook. Sometimes, we will sync the +github version of Faiss with the internal state. + +## Pull Requests + +We welcome pull requests that add significant value to Faiss. If you plan to do +a major development and contribute it back to Faiss, please contact us first before +putting too much effort into it. + +1. Fork the repo and create your branch from `master`. +2. If you've added code that should be tested, add tests. +3. If you've changed APIs, update the documentation. +4. Ensure the test suite passes. +5. Make sure your code lints. +6. If you haven't already, complete the Contributor License Agreement ("CLA"). + +There is a Facebook internal test suite for Faiss, and we need to run +all changes to Faiss through it. + +## Contributor License Agreement ("CLA") + +In order to accept your pull request, we need you to submit a CLA. You only need +to do this once to work on any of Facebook's open source projects. + +Complete your CLA here: + +## Issues + +We use GitHub issues to track public bugs. Please ensure your description is +clear and has sufficient instructions to be able to reproduce the issue. + +Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe +disclosure of security bugs. In those cases, please go through the process +outlined on that page and do not file a public issue. + +## Coding Style + +* 4 or 2 spaces for indentation in C++ (no tabs) +* 80 character line length (both for C++ and Python) +* C++ language level: C++11 + +## License + +By contributing to Faiss, you agree that your contributions will be licensed +under the LICENSE file in the root directory of this source tree. + diff --git a/core/src/index/thirdparty/faiss/Clustering.cpp b/core/src/index/thirdparty/faiss/Clustering.cpp new file mode 100644 index 0000000000..eba243d17d --- /dev/null +++ b/core/src/index/thirdparty/faiss/Clustering.cpp @@ -0,0 +1,526 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +#include +#include + +#include +#include +#include + +#include + +#include +#include +#include +#include +#include + +namespace faiss { + +ClusteringParameters::ClusteringParameters (): + niter(25), + nredo(1), + verbose(false), + spherical(false), + int_centroids(false), + update_index(false), + frozen_centroids(false), + min_points_per_centroid(39), + max_points_per_centroid(256), + seed(1234), + decode_block_size(32768) +{} +// 39 corresponds to 10000 / 256 -> to avoid warnings on PQ tests with randu10k + + +Clustering::Clustering (int d, int k): + d(d), k(k) {} + +Clustering::Clustering (int d, int k, const ClusteringParameters &cp): + ClusteringParameters (cp), d(d), k(k) {} + + + +static double imbalance_factor (int n, int k, int64_t *assign) { + std::vector hist(k, 0); + for (int i = 0; i < n; i++) + hist[assign[i]]++; + + double tot = 0, uf = 0; + + for (int i = 0 ; i < k ; i++) { + tot += hist[i]; + uf += hist[i] * (double) hist[i]; + } + uf = uf * k / (tot * tot); + + return uf; +} + +void Clustering::post_process_centroids () +{ + + if (spherical) { + fvec_renorm_L2 (d, k, centroids.data()); + } + + if (int_centroids) { + for (size_t i = 0; i < centroids.size(); i++) + centroids[i] = roundf (centroids[i]); + } +} + + +void Clustering::train (idx_t nx, const float *x_in, Index & index, + const float *weights) { + train_encoded (nx, reinterpret_cast(x_in), nullptr, + index, weights); +} + + +namespace { + +using idx_t = Clustering::idx_t; + +idx_t subsample_training_set( + const Clustering &clus, idx_t nx, const uint8_t *x, + size_t line_size, const float * weights, + uint8_t **x_out, + float **weights_out +) +{ + if (clus.verbose) { + printf("Sampling a subset of %ld / %ld for training\n", + clus.k * clus.max_points_per_centroid, nx); + } + std::vector perm (nx); + rand_perm (perm.data (), nx, clus.seed); + nx = clus.k * clus.max_points_per_centroid; + uint8_t * x_new = new uint8_t [nx * line_size]; + *x_out = x_new; + for (idx_t i = 0; i < nx; i++) { + memcpy (x_new + i * line_size, x + perm[i] * line_size, line_size); + } + if (weights) { + float *weights_new = new float[nx]; + for (idx_t i = 0; i < nx; i++) { + weights_new[i] = weights[perm[i]]; + } + *weights_out = weights_new; + } else { + *weights_out = nullptr; + } + return nx; +} + +/** compute centroids as (weighted) sum of training points + * + * @param x training vectors, size n * code_size (from codec) + * @param codec how to decode the vectors (if NULL then cast to float*) + * @param weights per-training vector weight, size n (or NULL) + * @param assign nearest centroid for each training vector, size n + * @param k_frozen do not update the k_frozen first centroids + * @param centroids centroid vectors (output only), size k * d + * @param hassign histogram of assignments per centroid (size k), + * should be 0 on input + * + */ + +void compute_centroids (size_t d, size_t k, size_t n, + size_t k_frozen, + const uint8_t * x, const Index *codec, + const int64_t * assign, + const float * weights, + float * hassign, + float * centroids) +{ + k -= k_frozen; + centroids += k_frozen * d; + + memset (centroids, 0, sizeof(*centroids) * d * k); + + size_t line_size = codec ? codec->sa_code_size() : d * sizeof (float); + +#pragma omp parallel + { + int nt = omp_get_num_threads(); + int rank = omp_get_thread_num(); + + // this thread is taking care of centroids c0:c1 + size_t c0 = (k * rank) / nt; + size_t c1 = (k * (rank + 1)) / nt; + std::vector decode_buffer (d); + + for (size_t i = 0; i < n; i++) { + int64_t ci = assign[i]; + assert (ci >= 0 && ci < k + k_frozen); + ci -= k_frozen; + if (ci >= c0 && ci < c1) { + float * c = centroids + ci * d; + const float * xi; + if (!codec) { + xi = reinterpret_cast(x + i * line_size); + } else { + float *xif = decode_buffer.data(); + codec->sa_decode (1, x + i * line_size, xif); + xi = xif; + } + if (weights) { + float w = weights[i]; + hassign[ci] += w; + for (size_t j = 0; j < d; j++) { + c[j] += xi[j] * w; + } + } else { + hassign[ci] += 1.0; + for (size_t j = 0; j < d; j++) { + c[j] += xi[j]; + } + } + } + } + + } + +#pragma omp parallel for + for (size_t ci = 0; ci < k; ci++) { + if (hassign[ci] == 0) { + continue; + } + float norm = 1 / hassign[ci]; + float * c = centroids + ci * d; + for (size_t j = 0; j < d; j++) { + c[j] *= norm; + } + } + +} + +// a bit above machine epsilon for float16 +#define EPS (1 / 1024.) + +/** Handle empty clusters by splitting larger ones. + * + * It works by slightly changing the centroids to make 2 clusters from + * a single one. Takes the same arguements as compute_centroids. + * + * @return nb of spliting operations (larger is worse) + */ +int split_clusters (size_t d, size_t k, size_t n, + size_t k_frozen, + float * hassign, + float * centroids) +{ + k -= k_frozen; + centroids += k_frozen * d; + + /* Take care of void clusters */ + size_t nsplit = 0; + RandomGenerator rng (1234); + for (size_t ci = 0; ci < k; ci++) { + if (hassign[ci] == 0) { /* need to redefine a centroid */ + size_t cj; + for (cj = 0; 1; cj = (cj + 1) % k) { + /* probability to pick this cluster for split */ + float p = (hassign[cj] - 1.0) / (float) (n - k); + float r = rng.rand_float (); + if (r < p) { + break; /* found our cluster to be split */ + } + } + memcpy (centroids+ci*d, centroids+cj*d, sizeof(*centroids) * d); + + /* small symmetric pertubation */ + for (size_t j = 0; j < d; j++) { + if (j % 2 == 0) { + centroids[ci * d + j] *= 1 + EPS; + centroids[cj * d + j] *= 1 - EPS; + } else { + centroids[ci * d + j] *= 1 - EPS; + centroids[cj * d + j] *= 1 + EPS; + } + } + + /* assume even split of the cluster */ + hassign[ci] = hassign[cj] / 2; + hassign[cj] -= hassign[ci]; + nsplit++; + } + } + + return nsplit; + +} + + + +}; + + +void Clustering::train_encoded (idx_t nx, const uint8_t *x_in, + const Index * codec, Index & index, + const float *weights) { + + FAISS_THROW_IF_NOT_FMT (nx >= k, + "Number of training points (%ld) should be at least " + "as large as number of clusters (%ld)", nx, k); + + FAISS_THROW_IF_NOT_FMT ((!codec || codec->d == d), + "Codec dimension %d not the same as data dimension %d", + int(codec->d), int(d)); + + FAISS_THROW_IF_NOT_FMT (index.d == d, + "Index dimension %d not the same as data dimension %d", + int(index.d), int(d)); + + double t0 = getmillisecs(); + + if (!codec) { + // Check for NaNs in input data. Normally it is the user's + // responsibility, but it may spare us some hard-to-debug + // reports. + const float *x = reinterpret_cast(x_in); + for (size_t i = 0; i < nx * d; i++) { + FAISS_THROW_IF_NOT_MSG (finite (x[i]), + "input contains NaN's or Inf's"); + } + } + + const uint8_t *x = x_in; + std::unique_ptr del1; + std::unique_ptr del3; + size_t line_size = codec ? codec->sa_code_size() : sizeof(float) * d; + + if (nx > k * max_points_per_centroid) { + uint8_t *x_new; + float *weights_new; + nx = subsample_training_set (*this, nx, x, line_size, weights, + &x_new, &weights_new); + del1.reset (x_new); x = x_new; + del3.reset (weights_new); weights = weights_new; + } else if (nx < k * min_points_per_centroid) { + fprintf (stderr, + "WARNING clustering %ld points to %ld centroids: " + "please provide at least %ld training points\n", + nx, k, idx_t(k) * min_points_per_centroid); + } + + if (nx == k) { + // this is a corner case, just copy training set to clusters + if (verbose) { + printf("Number of training points (%ld) same as number of " + "clusters, just copying\n", nx); + } + centroids.resize (d * k); + if (!codec) { + memcpy (centroids.data(), x_in, sizeof (float) * d * k); + } else { + codec->sa_decode (nx, x_in, centroids.data()); + } + + // one fake iteration... + ClusteringIterationStats stats = { 0.0, 0.0, 0.0, 1.0, 0 }; + iteration_stats.push_back (stats); + + index.reset(); + index.add(k, centroids.data()); + return; + } + + + if (verbose) { + printf("Clustering %d points in %ldD to %ld clusters, " + "redo %d times, %d iterations\n", + int(nx), d, k, nredo, niter); + if (codec) { + printf("Input data encoded in %ld bytes per vector\n", + codec->sa_code_size ()); + } + } + + std::unique_ptr assign(new idx_t[nx]); + std::unique_ptr dis(new float[nx]); + + // remember best iteration for redo + float best_err = HUGE_VALF; + std::vector best_obj; + std::vector best_centroids; + + // support input centroids + + FAISS_THROW_IF_NOT_MSG ( + centroids.size() % d == 0, + "size of provided input centroids not a multiple of dimension" + ); + + size_t n_input_centroids = centroids.size() / d; + + if (verbose && n_input_centroids > 0) { + printf (" Using %zd centroids provided as input (%sfrozen)\n", + n_input_centroids, frozen_centroids ? "" : "not "); + } + + double t_search_tot = 0; + if (verbose) { + printf(" Preprocessing in %.2f s\n", + (getmillisecs() - t0) / 1000.); + } + t0 = getmillisecs(); + + // temporary buffer to decode vectors during the optimization + std::vector decode_buffer + (codec ? d * decode_block_size : 0); + + for (int redo = 0; redo < nredo; redo++) { + + if (verbose && nredo > 1) { + printf("Outer iteration %d / %d\n", redo, nredo); + } + + // initialize (remaining) centroids with random points from the dataset + centroids.resize (d * k); + std::vector perm (nx); + + rand_perm (perm.data(), nx, seed + 1 + redo * 15486557L); + + if (!codec) { + for (int i = n_input_centroids; i < k ; i++) { + memcpy (¢roids[i * d], x + perm[i] * line_size, line_size); + } + } else { + for (int i = n_input_centroids; i < k ; i++) { + codec->sa_decode (1, x + perm[i] * line_size, ¢roids[i * d]); + } + } + + post_process_centroids (); + + // prepare the index + + if (index.ntotal != 0) { + index.reset(); + } + + if (!index.is_trained) { + index.train (k, centroids.data()); + } + + index.add (k, centroids.data()); + + // k-means iterations + + float err = 0; + for (int i = 0; i < niter; i++) { + double t0s = getmillisecs(); + + if (!codec) { + index.assign (nx, reinterpret_cast(x), + assign.get(), dis.get()); + } else { + // search by blocks of decode_block_size vectors + size_t code_size = codec->sa_code_size (); + for (size_t i0 = 0; i0 < nx; i0 += decode_block_size) { + size_t i1 = i0 + decode_block_size; + if (i1 > nx) { i1 = nx; } + codec->sa_decode (i1 - i0, x + code_size * i0, + decode_buffer.data ()); + index.search (i1 - i0, decode_buffer.data (), 1, + dis.get() + i0, assign.get() + i0); + } + } + + InterruptCallback::check(); + t_search_tot += getmillisecs() - t0s; + + // accumulate error + err = 0; + for (int j = 0; j < nx; j++) { + err += dis[j]; + } + + // update the centroids + std::vector hassign (k); + + size_t k_frozen = frozen_centroids ? n_input_centroids : 0; + compute_centroids ( + d, k, nx, k_frozen, + x, codec, assign.get(), weights, + hassign.data(), centroids.data() + ); + + int nsplit = split_clusters ( + d, k, nx, k_frozen, + hassign.data(), centroids.data() + ); + + // collect statistics + ClusteringIterationStats stats = + { err, (getmillisecs() - t0) / 1000.0, + t_search_tot / 1000, imbalance_factor (nx, k, assign.get()), + nsplit }; + iteration_stats.push_back(stats); + + if (verbose) { + printf (" Iteration %d (%.2f s, search %.2f s): " + "objective=%g imbalance=%.3f nsplit=%d \r", + i, stats.time, stats.time_search, stats.obj, + stats.imbalance_factor, nsplit); + fflush (stdout); + } + + post_process_centroids (); + + // add centroids to index for the next iteration (or for output) + + index.reset (); + if (update_index) { + index.train (k, centroids.data()); + } + + index.add (k, centroids.data()); + InterruptCallback::check (); + } + + if (verbose) printf("\n"); + if (nredo > 1) { + if (err < best_err) { + if (verbose) { + printf ("Objective improved: keep new clusters\n"); + } + best_centroids = centroids; + best_obj = iteration_stats; + best_err = err; + } + index.reset (); + } + } + if (nredo > 1) { + centroids = best_centroids; + iteration_stats = best_obj; + index.reset(); + index.add(k, best_centroids.data()); + } + +} + +float kmeans_clustering (size_t d, size_t n, size_t k, + const float *x, + float *centroids) +{ + Clustering clus (d, k); + clus.verbose = d * n * k > (1L << 30); + // display logs if > 1Gflop per iteration + IndexFlatL2 index (d); + clus.train (n, x, index); + memcpy(centroids, clus.centroids.data(), sizeof(*centroids) * d * k); + return clus.iteration_stats.back().obj; +} + +} // namespace faiss diff --git a/core/src/index/thirdparty/faiss/Clustering.h b/core/src/index/thirdparty/faiss/Clustering.h new file mode 100644 index 0000000000..46410af79f --- /dev/null +++ b/core/src/index/thirdparty/faiss/Clustering.h @@ -0,0 +1,129 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +#ifndef FAISS_CLUSTERING_H +#define FAISS_CLUSTERING_H +#include + +#include + +namespace faiss { + + +/** Class for the clustering parameters. Can be passed to the + * constructor of the Clustering object. + */ +struct ClusteringParameters { + int niter; ///< clustering iterations + int nredo; ///< redo clustering this many times and keep best + + bool verbose; + bool spherical; ///< do we want normalized centroids? + bool int_centroids; ///< round centroids coordinates to integer + bool update_index; ///< re-train index after each iteration? + bool frozen_centroids; ///< use the centroids provided as input and do not change them during iterations + + int min_points_per_centroid; ///< otherwise you get a warning + int max_points_per_centroid; ///< to limit size of dataset + + int seed; ///< seed for the random number generator + + size_t decode_block_size; ///< how many vectors at a time to decode + + /// sets reasonable defaults + ClusteringParameters (); +}; + + +struct ClusteringIterationStats { + float obj; ///< objective values (sum of distances reported by index) + double time; ///< seconds for iteration + double time_search; ///< seconds for just search + double imbalance_factor; ///< imbalance factor of iteration + int nsplit; ///< number of cluster splits +}; + + +/** K-means clustering based on assignment - centroid update iterations + * + * The clustering is based on an Index object that assigns training + * points to the centroids. Therefore, at each iteration the centroids + * are added to the index. + * + * On output, the centoids table is set to the latest version + * of the centroids and they are also added to the index. If the + * centroids table it is not empty on input, it is also used for + * initialization. + * + */ +struct Clustering: ClusteringParameters { + typedef Index::idx_t idx_t; + size_t d; ///< dimension of the vectors + size_t k; ///< nb of centroids + + /** centroids (k * d) + * if centroids are set on input to train, they will be used as initialization + */ + std::vector centroids; + + /// stats at every iteration of clustering + std::vector iteration_stats; + + Clustering (int d, int k); + Clustering (int d, int k, const ClusteringParameters &cp); + + /** run k-means training + * + * @param x training vectors, size n * d + * @param index index used for assignment + * @param x_weights weight associated to each vector: NULL or size n + */ + virtual void train (idx_t n, const float * x, faiss::Index & index, + const float *x_weights = nullptr); + + + /** run with encoded vectors + * + * win addition to train()'s parameters takes a codec as parameter + * to decode the input vectors. + * + * @param codec codec used to decode the vectors (nullptr = + * vectors are in fact floats) * + */ + void train_encoded (idx_t nx, const uint8_t *x_in, + const Index * codec, Index & index, + const float *weights = nullptr); + + /// Post-process the centroids after each centroid update. + /// includes optional L2 normalization and nearest integer rounding + void post_process_centroids (); + + virtual ~Clustering() {} +}; + + +/** simplified interface + * + * @param d dimension of the data + * @param n nb of training vectors + * @param k nb of output centroids + * @param x training set (size n * d) + * @param centroids output centroids (size k * d) + * @return final quantization error + */ +float kmeans_clustering (size_t d, size_t n, size_t k, + const float *x, + float *centroids); + + + +} + + +#endif diff --git a/core/src/index/thirdparty/faiss/DirectMap.cpp b/core/src/index/thirdparty/faiss/DirectMap.cpp new file mode 100644 index 0000000000..bd3cf5460f --- /dev/null +++ b/core/src/index/thirdparty/faiss/DirectMap.cpp @@ -0,0 +1,267 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +#include + +#include +#include + +#include +#include + +namespace faiss { + +DirectMap::DirectMap(): type(NoMap) +{} + +void DirectMap::set_type (Type new_type, const InvertedLists *invlists, size_t ntotal) { + + FAISS_THROW_IF_NOT (new_type == NoMap || new_type == Array || + new_type == Hashtable); + + if (new_type == type) { + // nothing to do + return; + } + + array.clear (); + hashtable.clear (); + type = new_type; + + if (new_type == NoMap) { + return; + } else if (new_type == Array) { + array.resize (ntotal, -1); + } else if (new_type == Hashtable) { + hashtable.reserve (ntotal); + } + + for (size_t key = 0; key < invlists->nlist; key++) { + size_t list_size = invlists->list_size (key); + InvertedLists::ScopedIds idlist (invlists, key); + + if (new_type == Array) { + for (long ofs = 0; ofs < list_size; ofs++) { + FAISS_THROW_IF_NOT_MSG ( + 0 <= idlist [ofs] && idlist[ofs] < ntotal, + "direct map supported only for seuquential ids"); + array [idlist [ofs]] = lo_build(key, ofs); + } + } else if (new_type == Hashtable) { + for (long ofs = 0; ofs < list_size; ofs++) { + hashtable [idlist [ofs]] = lo_build(key, ofs); + } + } + } +} + +void DirectMap::clear() +{ + array.clear (); + hashtable.clear (); +} + + +DirectMap::idx_t DirectMap::get (idx_t key) const +{ + if (type == Array) { + FAISS_THROW_IF_NOT_MSG ( + key >= 0 && key < array.size(), "invalid key" + ); + idx_t lo = array[key]; + FAISS_THROW_IF_NOT_MSG(lo >= 0, "-1 entry in direct_map"); + return lo; + } else if (type == Hashtable) { + auto res = hashtable.find (key); + FAISS_THROW_IF_NOT_MSG (res != hashtable.end(), "key not found"); + return res->second; + } else { + FAISS_THROW_MSG ("direct map not initialized"); + } +} + + + +void DirectMap::add_single_id (idx_t id, idx_t list_no, size_t offset) +{ + if (type == NoMap) return; + + if (type == Array) { + assert (id == array.size()); + if (list_no >= 0) { + array.push_back (lo_build (list_no, offset)); + } else { + array.push_back (-1); + } + } else if (type == Hashtable) { + if (list_no >= 0) { + hashtable[id] = lo_build (list_no, offset); + } + } + +} + +void DirectMap::check_can_add (const idx_t *ids) { + if (type == Array && ids) { + FAISS_THROW_MSG ("cannot have array direct map and add with ids"); + } +} + +/********************* DirectMapAdd implementation */ + + +DirectMapAdd::DirectMapAdd (DirectMap &direct_map, size_t n, const idx_t *xids): + direct_map(direct_map), type(direct_map.type), n(n), xids(xids) +{ + if (type == DirectMap::Array) { + FAISS_THROW_IF_NOT (xids == nullptr); + ntotal = direct_map.array.size(); + direct_map.array.resize (ntotal + n, -1); + } else if (type == DirectMap::Hashtable) { + // can't parallel update hashtable so use temp array + all_ofs.resize (n, -1); + } +} + + +void DirectMapAdd::add (size_t i, idx_t list_no, size_t ofs) +{ + if (type == DirectMap::Array) { + direct_map.array [ntotal + i] = lo_build (list_no, ofs); + } else if (type == DirectMap::Hashtable) { + all_ofs [i] = lo_build (list_no, ofs); + } +} + +DirectMapAdd::~DirectMapAdd () +{ + if (type == DirectMap::Hashtable) { + for (int i = 0; i < n; i++) { + idx_t id = xids ? xids[i] : ntotal + i; + direct_map.hashtable [id] = all_ofs [i]; + } + } +} + +/********************************************************/ + +using ScopedCodes = InvertedLists::ScopedCodes; +using ScopedIds = InvertedLists::ScopedIds; + + +size_t DirectMap::remove_ids(const IDSelector& sel, InvertedLists *invlists) +{ + size_t nlist = invlists->nlist; + std::vector toremove(nlist); + + size_t nremove = 0; + + if (type == NoMap) { + // exhaustive scan of IVF +#pragma omp parallel for + for (idx_t i = 0; i < nlist; i++) { + idx_t l0 = invlists->list_size (i), l = l0, j = 0; + ScopedIds idsi (invlists, i); + while (j < l) { + if (sel.is_member (idsi[j])) { + l--; + invlists->update_entry ( + i, j, + invlists->get_single_id (i, l), + ScopedCodes (invlists, i, l).get() + ); + } else { + j++; + } + } + toremove[i] = l0 - l; + } + // this will not run well in parallel on ondisk because of + // possible shrinks + for (idx_t i = 0; i < nlist; i++) { + if (toremove[i] > 0) { + nremove += toremove[i]; + invlists->resize(i, invlists->list_size(i) - toremove[i]); + } + } + } else if (type == Hashtable) { + const IDSelectorArray *sela = + dynamic_cast(&sel); + FAISS_THROW_IF_NOT_MSG ( + sela, + "remove with hashtable works only with IDSelectorArray" + ); + + for (idx_t i = 0; i < sela->n; i++) { + idx_t id = sela->ids[i]; + auto res = hashtable.find (id); + if (res != hashtable.end()) { + size_t list_no = lo_listno (res->second); + size_t offset = lo_offset (res->second); + idx_t last = invlists->list_size (list_no) - 1; + hashtable.erase (res); + if (offset < last) { + idx_t last_id = invlists->get_single_id (list_no, last); + invlists->update_entry ( + list_no, offset, + last_id, + ScopedCodes (invlists, list_no, last).get() + ); + // update hash entry for last element + hashtable [last_id] = list_no << 32 | offset; + } + invlists->resize(list_no, last); + nremove++; + } + } + + } else { + FAISS_THROW_MSG("remove not supported with this direct_map format"); + } + return nremove; +} + +void DirectMap::update_codes (InvertedLists *invlists, + int n, const idx_t *ids, + const idx_t *assign, + const uint8_t *codes) +{ + FAISS_THROW_IF_NOT (type == Array); + + size_t code_size = invlists->code_size; + + for (size_t i = 0; i < n; i++) { + idx_t id = ids[i]; + FAISS_THROW_IF_NOT_MSG (0 <= id && id < array.size(), + "id to update out of range"); + { // remove old one + idx_t dm = array [id]; + int64_t ofs = lo_offset (dm); + int64_t il = lo_listno (dm); + size_t l = invlists->list_size (il); + if (ofs != l - 1) { // move l - 1 to ofs + int64_t id2 = invlists->get_single_id (il, l - 1); + array[id2] = lo_build (il, ofs); + invlists->update_entry (il, ofs, id2, + invlists->get_single_code (il, l - 1)); + } + invlists->resize (il, l - 1); + } + { // insert new one + int64_t il = assign[i]; + size_t l = invlists->list_size (il); + idx_t dm = lo_build (il, l); + array [id] = dm; + invlists->add_entry (il, id, codes + i * code_size); + } + } +} + + +} diff --git a/core/src/index/thirdparty/faiss/DirectMap.h b/core/src/index/thirdparty/faiss/DirectMap.h new file mode 100644 index 0000000000..27ea1c7260 --- /dev/null +++ b/core/src/index/thirdparty/faiss/DirectMap.h @@ -0,0 +1,120 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +#ifndef FAISS_DIRECT_MAP_H +#define FAISS_DIRECT_MAP_H + +#include +#include + + +namespace faiss { + +// When offsets list id + offset are encoded in an uint64 +// we call this LO = list-offset + +inline uint64_t lo_build (uint64_t list_id, uint64_t offset) { + return list_id << 32 | offset; +} + +inline uint64_t lo_listno (uint64_t lo) { + return lo >> 32; +} + +inline uint64_t lo_offset (uint64_t lo) { + return lo & 0xffffffff; +} + +/** + * Direct map: a way to map back from ids to inverted lists + */ +struct DirectMap { + typedef Index::idx_t idx_t; + + enum Type { + NoMap = 0, // default + Array = 1, // sequential ids (only for add, no add_with_ids) + Hashtable = 2 // arbitrary ids + }; + Type type; + + /// map for direct access to the elements. Map ids to LO-encoded entries. + std::vector array; + std::unordered_map hashtable; + + DirectMap(); + + /// set type and initialize + void set_type (Type new_type, const InvertedLists *invlists, size_t ntotal); + + /// get an entry + idx_t get (idx_t id) const; + + /// for quick checks + bool no () const {return type == NoMap; } + + /** + * update the direct_map + */ + + /// throw if Array and ids is not NULL + void check_can_add (const idx_t *ids); + + /// non thread-safe version + void add_single_id (idx_t id, idx_t list_no, size_t offset); + + /// remove all entries + void clear(); + + /** + * operations on inverted lists that require translation with a DirectMap + */ + + /// remove ids from the InvertedLists, possibly using the direct map + size_t remove_ids(const IDSelector& sel, InvertedLists *invlists); + + /// update entries, using the direct map + void update_codes (InvertedLists *invlists, + int n, const idx_t *ids, + const idx_t *list_nos, + const uint8_t *codes); + + + +}; + +/// Thread-safe way of updating the direct_map +struct DirectMapAdd { + + typedef Index::idx_t idx_t; + + using Type = DirectMap::Type; + + DirectMap &direct_map; + DirectMap::Type type; + size_t ntotal; + size_t n; + const idx_t *xids; + + std::vector all_ofs; + + DirectMapAdd (DirectMap &direct_map, size_t n, const idx_t *xids); + + /// add vector i (with id xids[i]) at list_no and offset + void add (size_t i, idx_t list_no, size_t offset); + + ~DirectMapAdd (); +}; + + + +} + + +#endif diff --git a/core/src/index/thirdparty/faiss/Dockerfile b/core/src/index/thirdparty/faiss/Dockerfile new file mode 100644 index 0000000000..9da42ef70f --- /dev/null +++ b/core/src/index/thirdparty/faiss/Dockerfile @@ -0,0 +1,29 @@ +FROM nvidia/cuda:8.0-devel-centos7 + +# Install MKL +RUN yum-config-manager --add-repo https://yum.repos.intel.com/mkl/setup/intel-mkl.repo +RUN rpm --import https://yum.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS-2019.PUB +RUN yum install -y intel-mkl-2019.3-062 +ENV LD_LIBRARY_PATH /opt/intel/mkl/lib/intel64:$LD_LIBRARY_PATH +ENV LIBRARY_PATH /opt/intel/mkl/lib/intel64:$LIBRARY_PATH +ENV LD_PRELOAD /usr/lib64/libgomp.so.1:/opt/intel/mkl/lib/intel64/libmkl_def.so:\ +/opt/intel/mkl/lib/intel64/libmkl_avx2.so:/opt/intel/mkl/lib/intel64/libmkl_core.so:\ +/opt/intel/mkl/lib/intel64/libmkl_intel_lp64.so:/opt/intel/mkl/lib/intel64/libmkl_gnu_thread.so + +# Install necessary build tools +RUN yum install -y gcc-c++ make swig3 + +# Install necesary headers/libs +RUN yum install -y python-devel numpy + +COPY . /opt/faiss + +WORKDIR /opt/faiss + +# --with-cuda=/usr/local/cuda-8.0 +RUN ./configure --prefix=/usr --libdir=/usr/lib64 --without-cuda +RUN make -j "$(nproc)" +RUN make -C python +RUN make test +RUN make install +RUN make -C demos demo_ivfpq_indexing && ./demos/demo_ivfpq_indexing diff --git a/core/src/index/thirdparty/faiss/FaissHook.cpp b/core/src/index/thirdparty/faiss/FaissHook.cpp new file mode 100644 index 0000000000..e20ab37e55 --- /dev/null +++ b/core/src/index/thirdparty/faiss/FaissHook.cpp @@ -0,0 +1,109 @@ + +// -*- c++ -*- + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace faiss { + +bool faiss_use_avx512 = true; +bool faiss_use_avx2 = true; +bool faiss_use_sse = true; + +/* set default to AVX */ +fvec_func_ptr fvec_inner_product = fvec_inner_product_avx; +fvec_func_ptr fvec_L2sqr = fvec_L2sqr_avx; +fvec_func_ptr fvec_L1 = fvec_L1_avx; +fvec_func_ptr fvec_Linf = fvec_Linf_avx; + +sq_get_distance_computer_func_ptr sq_get_distance_computer = sq_get_distance_computer_avx; +sq_sel_quantizer_func_ptr sq_sel_quantizer = sq_select_quantizer_avx; +sq_sel_inv_list_scanner_func_ptr sq_sel_inv_list_scanner = sq_select_inverted_list_scanner_avx; + +/*****************************************************************************/ + +bool support_avx512() { + if (!faiss_use_avx512) return false; + + InstructionSet& instruction_set_inst = InstructionSet::GetInstance(); + return (instruction_set_inst.AVX512F() && + instruction_set_inst.AVX512DQ() && + instruction_set_inst.AVX512BW()); +} + +bool support_avx2() { + if (!faiss_use_avx2) return false; + + InstructionSet& instruction_set_inst = InstructionSet::GetInstance(); + return (instruction_set_inst.AVX2()); +} + +bool support_sse() { + if (!faiss_use_sse) return false; + + InstructionSet& instruction_set_inst = InstructionSet::GetInstance(); + return (instruction_set_inst.SSE42()); +} + +bool hook_init(std::string& cpu_flag) { + static std::mutex hook_mutex; + std::lock_guard lock(hook_mutex); + + if (support_avx512()) { + /* for IVFFLAT */ + fvec_inner_product = fvec_inner_product_avx512; + fvec_L2sqr = fvec_L2sqr_avx512; + fvec_L1 = fvec_L1_avx512; + fvec_Linf = fvec_Linf_avx512; + + /* for IVFSQ */ + sq_get_distance_computer = sq_get_distance_computer_avx512; + sq_sel_quantizer = sq_select_quantizer_avx512; + sq_sel_inv_list_scanner = sq_select_inverted_list_scanner_avx512; + + cpu_flag = "AVX512"; + } else if (support_avx2()) { + /* for IVFFLAT */ + fvec_inner_product = fvec_inner_product_avx; + fvec_L2sqr = fvec_L2sqr_avx; + fvec_L1 = fvec_L1_avx; + fvec_Linf = fvec_Linf_avx; + + /* for IVFSQ */ + sq_get_distance_computer = sq_get_distance_computer_avx; + sq_sel_quantizer = sq_select_quantizer_avx; + sq_sel_inv_list_scanner = sq_select_inverted_list_scanner_avx; + + cpu_flag = "AVX2"; + } else if (support_sse()) { + /* for IVFFLAT */ + fvec_inner_product = fvec_inner_product_sse; + fvec_L2sqr = fvec_L2sqr_sse; + fvec_L1 = fvec_L1_sse; + fvec_Linf = fvec_Linf_sse; + + /* for IVFSQ */ + sq_get_distance_computer = sq_get_distance_computer_ref; + sq_sel_quantizer = sq_select_quantizer_ref; + sq_sel_inv_list_scanner = sq_select_inverted_list_scanner_ref; + + cpu_flag = "SSE42"; + } else { + cpu_flag = "UNSUPPORTED"; + return false; + } + + return true; +} + +} // namespace faiss diff --git a/core/src/index/thirdparty/faiss/FaissHook.h b/core/src/index/thirdparty/faiss/FaissHook.h new file mode 100644 index 0000000000..f1aa98f606 --- /dev/null +++ b/core/src/index/thirdparty/faiss/FaissHook.h @@ -0,0 +1,40 @@ + +// -*- c++ -*- + +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace faiss { + +typedef float (*fvec_func_ptr)(const float*, const float*, size_t); + +typedef SQDistanceComputer* (*sq_get_distance_computer_func_ptr)(MetricType, QuantizerType, size_t, const std::vector&); +typedef Quantizer* (*sq_sel_quantizer_func_ptr)(QuantizerType, size_t, const std::vector&); +typedef InvertedListScanner* (*sq_sel_inv_list_scanner_func_ptr)(MetricType, const ScalarQuantizer*, const Index*, size_t, bool, bool); + +extern bool faiss_use_avx512; +extern bool faiss_use_avx2; +extern bool faiss_use_sse; + +extern fvec_func_ptr fvec_inner_product; +extern fvec_func_ptr fvec_L2sqr; +extern fvec_func_ptr fvec_L1; +extern fvec_func_ptr fvec_Linf; + +extern sq_get_distance_computer_func_ptr sq_get_distance_computer; +extern sq_sel_quantizer_func_ptr sq_sel_quantizer; +extern sq_sel_inv_list_scanner_func_ptr sq_sel_inv_list_scanner; + +extern bool support_avx512(); +extern bool support_avx2(); +extern bool support_sse(); + +extern bool hook_init(std::string& cpu_flag); + +} // namespace faiss diff --git a/core/src/index/thirdparty/faiss/INSTALL.md b/core/src/index/thirdparty/faiss/INSTALL.md new file mode 100644 index 0000000000..01f29e46e1 --- /dev/null +++ b/core/src/index/thirdparty/faiss/INSTALL.md @@ -0,0 +1,353 @@ +[//]: # "**********************************************************" +[//]: # "** INSTALL file for Faiss (Fair AI Similarity Search **" +[//]: # "**********************************************************" + +INSTALL file for Faiss (Fair AI Similarity Search) +================================================== + +Install via Conda +----------------- + +The easiest way to install FAISS is from Anaconda. We regularly push stable releases to the pytorch conda channel. + +Currently we support faiss-cpu both on Linux and OSX. We also provide faiss-gpu compiled with CUDA8/CUDA9/CUDA10 on Linux systems. + +You can easily install it by + +``` +# CPU version only +conda install faiss-cpu -c pytorch + +# GPU version +conda install faiss-gpu cudatoolkit=8.0 -c pytorch # For CUDA8 +conda install faiss-gpu cudatoolkit=9.0 -c pytorch # For CUDA9 +conda install faiss-gpu cudatoolkit=10.0 -c pytorch # For CUDA10 +``` + +Compile from source +------------------- + +The Faiss compilation works in 2 steps: + +1. compile the C++ core and examples + +2. compile the Python interface + +Steps 2 depends on 1. + +It is also possible to build a pure C interface. This optional process is +described separately (please see the [C interface installation file](c_api/INSTALL.md)) + +General compilation instructions +================================ + +TL;DR: `./configure && make (&& make install)` for the C++ library, and then `cd python; make && make install` for the python interface. + +1. `./configure` + +This generates the system-dependent configuration for the `Makefile`, stored in +a file called `makefile.inc`. + +A few useful options: +- `./configure --without-cuda` in order to build the CPU part only. +- `./configure --with-cuda=/path/to/cuda-10.1` in order to hint to the path of +the cudatoolkit. +- `./configure --with-cuda-arch="-gencode=arch=compute_75,code=sm_75 -gencode=arch=compute_72,code=sm_72"` for specifying which GPU architectures to build against. +- `./configure --with-python=/path/to/python3.7` in order to build a python +interface for a different python than the default one. +- `LDFLAGS=-L/path_to_mkl/lib/ ./configure` so that configure detects the MKL BLAS imeplementation. Note that this may require to set the LD_LIBRARY_PATH at runtime. + +2. `make` + +This builds the C++ library (the whole library if a suitable cuda toolkit was +found, or the CPU part only otherwise). + +3. `make install` (optional) + +This installs the headers and libraries. + +4. `make -C python` (or `make py`) + +This builds the python interface. + +5. `make -C python install` + +This installs the python library. + + +Faiss has been tested only on x86_64 machines on Linux and Mac OS. + +Faiss requires a C++ compiler that understands: +- the Intel intrinsics for SSE instructions, +- the GCC intrinsic for the popcount instruction, +- basic OpenMP. + +There are a few examples for makefile.inc in the example_makefiles/ +subdirectory. There are also indications for specific configurations in the +troubleshooting section of the wiki. + +https://github.com/facebookresearch/faiss/wiki/Troubleshooting + +Faiss comes as a .a archive, that can be linked with executables or +dynamic libraries (useful for the Python wrapper). + + +BLAS/Lapack +----------- + +The only variables that need to be configured for the C++ Faiss are +the BLAS/Lapack flags (a linear aglebra software package). It needs a +flag telling whether BLAS/Lapack uses 32 or 64 bit integers and the +linking flags. Faiss uses the Fortran 77 interface of BLAS/Lapack and +thus does not need an include path. + +There are several BLAS implementations, depending on the OS and +machine. To have reasonable performance, the BLAS library should be +multithreaded. See the example makefile.inc's for hints and examples +on how to set the flags, or simply run the configure script: + + `./configure` + +To check that the link flags are correct, and verify whether the +implementation uses 32 or 64 bit integers, you can + + `make misc/test_blas` + +and run + + `./misc/test_blas` + + +Testing Faiss +------------- + +A basic usage example is in + + `demos/demo_ivfpq_indexing` + +which you can build by calling + `make -C demos demo_ivfpq_indexing` + +It makes a small index, stores it and performs some searches. A normal +runtime is around 20s. With a fast machine and Intel MKL's BLAS it +runs in 2.5s. + +To run the whole test suite: + + `make test` (for the CPU part) + + `make test_gpu` (for the GPU part) + + +A real-life benchmark +--------------------- + +A bit longer example runs and evaluates Faiss on the SIFT1M +dataset. To run it, please download the ANN_SIFT1M dataset from + +http://corpus-texmex.irisa.fr/ + +and unzip it to the subdirectory `sift1M` at the root of the source +directory for this repository. + +Then compile and run the following (after ensuring you have installed faiss): + +``` +make demos +./demos/demo_sift1M +``` + +This is a demonstration of the high-level auto-tuning API. You can try +setting a different index_key to find the indexing structure that +gives the best performance. + + +The Python interface +====================================== + +The Python interface is compiled with + + `make -C python` (or `make py`) + +How it works +------------ + +The Python interface is provided via SWIG (Simple Wrapper and +Interface Generator) and an additional level of manual wrappers (in python/faiss.py). + +SWIG generates two wrapper files: a Python file (`python/swigfaiss.py`) and a +C++ file that must be compiled to a dynamic library (`python/_swigfaiss.so`). + +Testing the Python wrapper +-------------------------- + +Often, a successful compile does not mean that the library works, +because missing symbols are detected only at runtime. You should be +able to load the Faiss dynamic library: + + `python -c "import faiss"` + +In case of failure, it reports the first missing symbol. To see all +missing symbols (on Linux), use + + `ldd -r _swigfaiss.so` + +Sometimes, problems (eg with BLAS libraries) appear only when actually +calling a BLAS function. A simple way to check this + +```python +python -c "import faiss, numpy +faiss.Kmeans(10, 20).train(numpy.random.rand(1000, 10).astype('float32')) +``` + + +Real-life test +-------------- + +The following script extends the demo_sift1M test to several types of +indexes. This must be run from the root of the source directory for this +repository: + +``` +mkdir tmp # graphs of the output will be written here +PYTHONPATH=. python demos/demo_auto_tune.py +``` + +It will cycle through a few types of indexes and find optimal +operating points. You can play around with the types of indexes. + + +Step 3: Compiling the GPU implementation +======================================== + +The GPU version is a superset of the CPU version. In addition it +requires the cuda compiler and related libraries (Cublas) + +The nvcc-specific flags to pass to the compiler, based on your desired +compute capability can be customized by providing the `--with-cuda-arch` to +`./configure`. Only compute capability 3.5+ is supported. For example, we enable +by default: + +``` +-gencode=arch=compute_35,code=compute_35 +-gencode=arch=compute_52,code=compute_52 +-gencode=arch=compute_60,code=compute_60 +-gencode=arch=compute_61,code=compute_61 +-gencode=arch=compute_70,code=compute_70 +-gencode=arch=compute_75,code=compute_75 +``` + +However, look at https://developer.nvidia.com/cuda-gpus to determine +what compute capability you need to use, and replace our gencode +specifications with the one(s) you need. + +Most other flags are related to the C++11 compiler used by nvcc to +complile the actual C++ code. They are normally just transmitted by +nvcc, except some of them that are not recognized and that should be +escaped by prefixing them with -Xcompiler. Also link flags that are +prefixed with -Wl, should be passed with -Xlinker. + +You may want to add `-j 10` to use 10 threads during compile. + +Testing the GPU implementation +------------------------------ + +Compile the example with + + `make -C gpu/test demo_ivfpq_indexing_gpu` + +This produce the GPU code equivalent to the CPU +demo_ivfpq_indexing. It also shows how to translate indexed from/to +the GPU. + + +Python example with GPU support +------------------------------- + +The auto-tuning example above also runs on the GPU. Edit +`demos/demo_auto_tune.py` at line 100 with the values + +```python +keys_to_test = keys_gpu +use_gpu = True +``` + +and you can run + +``` +export PYTHONPATH=. +python demos/demo_auto_tune.py +``` + +to test the GPU code. + + +Docker instructions +=================== + +For using GPU capabilities of Faiss, you'll need to run "nvidia-docker" +rather than "docker". Make sure that docker +(https://docs.docker.com/engine/installation/) and nvidia-docker +(https://github.com/NVIDIA/nvidia-docker) are installed on your system + +To build the "faiss" image, run + + `nvidia-docker build -t faiss .` + +or if you don't want/need to clone the sources, just run + + `nvidia-docker build -t faiss github.com/facebookresearch/faiss` + +If you want to run the tests during the docker build, uncomment the +last 3 "RUN" steps in the Dockerfile. But you might want to run the +tests by yourself, so just run + + `nvidia-docker run -ti --name faiss faiss bash` + +and run what you want. If you need a dataset (like sift1M), download it +inside the created container, or better, mount a directory from the host + + nvidia-docker run -ti --name faiss -v /my/host/data/folder/ann_dataset/sift/:/opt/faiss/sift1M faiss bash + + +How to use Faiss in your own projects +===================================== + +C++ +--- + +The makefile generates a static and a dynamic library + +``` +libfaiss.a +libfaiss.so (or libfaiss.dylib) +``` + +the executable should be linked to one of these. If you use +the static version (.a), add the LDFLAGS used in the Makefile. + +For binary-only distributions, the headers should be under +a `faiss/` directory, so that they can be included as + +```c++ +#include +#include +``` + +Python +------ + +To import Faiss in your own Python project, you need the files + +``` +__init__.py +swigfaiss.py +_swigfaiss.so +``` +to be present in a `faiss/` directory visible in the PYTHONPATH or in the +current directory. +Then Faiss can be used in python with + +```python +import faiss +``` diff --git a/core/src/index/thirdparty/faiss/IVFlib.cpp b/core/src/index/thirdparty/faiss/IVFlib.cpp new file mode 100644 index 0000000000..098b729357 --- /dev/null +++ b/core/src/index/thirdparty/faiss/IVFlib.cpp @@ -0,0 +1,364 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +#include + +#include + +#include +#include +#include +#include + + +namespace faiss { namespace ivflib { + + +void check_compatible_for_merge (const Index * index0, + const Index * index1) +{ + + const faiss::IndexPreTransform *pt0 = + dynamic_cast(index0); + + if (pt0) { + const faiss::IndexPreTransform *pt1 = + dynamic_cast(index1); + FAISS_THROW_IF_NOT_MSG (pt1, "both indexes should be pretransforms"); + + FAISS_THROW_IF_NOT (pt0->chain.size() == pt1->chain.size()); + for (int i = 0; i < pt0->chain.size(); i++) { + FAISS_THROW_IF_NOT (typeid(pt0->chain[i]) == typeid(pt1->chain[i])); + } + + index0 = pt0->index; + index1 = pt1->index; + } + FAISS_THROW_IF_NOT (typeid(index0) == typeid(index1)); + FAISS_THROW_IF_NOT (index0->d == index1->d && + index0->metric_type == index1->metric_type); + + const faiss::IndexIVF *ivf0 = dynamic_cast(index0); + if (ivf0) { + const faiss::IndexIVF *ivf1 = + dynamic_cast(index1); + FAISS_THROW_IF_NOT (ivf1); + + ivf0->check_compatible_for_merge (*ivf1); + } + + // TODO: check as thoroughfully for other index types + +} + +const IndexIVF * try_extract_index_ivf (const Index * index) +{ + if (auto *pt = + dynamic_cast(index)) { + index = pt->index; + } + + if (auto *idmap = + dynamic_cast(index)) { + index = idmap->index; + } + if (auto *idmap = + dynamic_cast(index)) { + index = idmap->index; + } + + auto *ivf = dynamic_cast(index); + + return ivf; +} + +IndexIVF * try_extract_index_ivf (Index * index) { + return const_cast (try_extract_index_ivf ((const Index*)(index))); +} + +const IndexIVF * extract_index_ivf (const Index * index) +{ + const IndexIVF *ivf = try_extract_index_ivf (index); + FAISS_THROW_IF_NOT (ivf); + return ivf; +} + +IndexIVF * extract_index_ivf (Index * index) { + return const_cast (extract_index_ivf ((const Index*)(index))); +} + + +void merge_into(faiss::Index *index0, faiss::Index *index1, bool shift_ids) { + + check_compatible_for_merge (index0, index1); + IndexIVF * ivf0 = extract_index_ivf (index0); + IndexIVF * ivf1 = extract_index_ivf (index1); + + ivf0->merge_from (*ivf1, shift_ids ? ivf0->ntotal : 0); + + // useful for IndexPreTransform + index0->ntotal = ivf0->ntotal; + index1->ntotal = ivf1->ntotal; +} + + + +void search_centroid(faiss::Index *index, + const float* x, int n, + idx_t* centroid_ids) +{ + std::unique_ptr del; + if (auto index_pre = dynamic_cast(index)) { + x = index_pre->apply_chain(n, x); + del.reset((float*)x); + index = index_pre->index; + } + faiss::IndexIVF* index_ivf = dynamic_cast(index); + assert(index_ivf); + index_ivf->quantizer->assign(n, x, centroid_ids); +} + + + +void search_and_return_centroids(faiss::Index *index, + size_t n, + const float* xin, + long k, + float *distances, + idx_t* labels, + idx_t* query_centroid_ids, + idx_t* result_centroid_ids) +{ + const float *x = xin; + std::unique_ptr del; + if (auto index_pre = dynamic_cast(index)) { + x = index_pre->apply_chain(n, x); + del.reset((float*)x); + index = index_pre->index; + } + faiss::IndexIVF* index_ivf = dynamic_cast(index); + assert(index_ivf); + + size_t nprobe = index_ivf->nprobe; + std::vector cent_nos (n * nprobe); + std::vector cent_dis (n * nprobe); + index_ivf->quantizer->search( + n, x, nprobe, cent_dis.data(), cent_nos.data()); + + if (query_centroid_ids) { + for (size_t i = 0; i < n; i++) + query_centroid_ids[i] = cent_nos[i * nprobe]; + } + + index_ivf->search_preassigned (n, x, k, + cent_nos.data(), cent_dis.data(), + distances, labels, true); + + for (size_t i = 0; i < n * k; i++) { + idx_t label = labels[i]; + if (label < 0) { + if (result_centroid_ids) + result_centroid_ids[i] = -1; + } else { + long list_no = lo_listno (label); + long list_index = lo_offset (label); + if (result_centroid_ids) + result_centroid_ids[i] = list_no; + labels[i] = index_ivf->invlists->get_single_id(list_no, list_index); + } + } +} + + +SlidingIndexWindow::SlidingIndexWindow (Index *index): index (index) { + n_slice = 0; + IndexIVF* index_ivf = const_cast(extract_index_ivf (index)); + ils = dynamic_cast (index_ivf->invlists); + nlist = ils->nlist; + FAISS_THROW_IF_NOT_MSG (ils, + "only supports indexes with ArrayInvertedLists"); + sizes.resize(nlist); +} + +template +static void shift_and_add (std::vector & dst, + size_t remove, + const std::vector & src) +{ + if (remove > 0) + memmove (dst.data(), dst.data() + remove, + (dst.size() - remove) * sizeof (T)); + size_t insert_point = dst.size() - remove; + dst.resize (insert_point + src.size()); + memcpy (dst.data() + insert_point, src.data (), src.size() * sizeof(T)); +} + +template +static void remove_from_begin (std::vector & v, + size_t remove) +{ + if (remove > 0) + v.erase (v.begin(), v.begin() + remove); +} + +void SlidingIndexWindow::step(const Index *sub_index, bool remove_oldest) { + + FAISS_THROW_IF_NOT_MSG (!remove_oldest || n_slice > 0, + "cannot remove slice: there is none"); + + const ArrayInvertedLists *ils2 = nullptr; + if(sub_index) { + check_compatible_for_merge (index, sub_index); + ils2 = dynamic_cast( + extract_index_ivf (sub_index)->invlists); + FAISS_THROW_IF_NOT_MSG (ils2, "supports only ArrayInvertedLists"); + } + IndexIVF *index_ivf = extract_index_ivf (index); + + if (remove_oldest && ils2) { + for (int i = 0; i < nlist; i++) { + std::vector & sizesi = sizes[i]; + size_t amount_to_remove = sizesi[0]; + index_ivf->ntotal += ils2->ids[i].size() - amount_to_remove; + + shift_and_add (ils->ids[i], amount_to_remove, ils2->ids[i]); + shift_and_add (ils->codes[i], amount_to_remove * ils->code_size, + ils2->codes[i]); + for (int j = 0; j + 1 < n_slice; j++) { + sizesi[j] = sizesi[j + 1] - amount_to_remove; + } + sizesi[n_slice - 1] = ils->ids[i].size(); + } + } else if (ils2) { + for (int i = 0; i < nlist; i++) { + index_ivf->ntotal += ils2->ids[i].size(); + shift_and_add (ils->ids[i], 0, ils2->ids[i]); + shift_and_add (ils->codes[i], 0, ils2->codes[i]); + sizes[i].push_back(ils->ids[i].size()); + } + n_slice++; + } else if (remove_oldest) { + for (int i = 0; i < nlist; i++) { + size_t amount_to_remove = sizes[i][0]; + index_ivf->ntotal -= amount_to_remove; + remove_from_begin (ils->ids[i], amount_to_remove); + remove_from_begin (ils->codes[i], + amount_to_remove * ils->code_size); + for (int j = 0; j + 1 < n_slice; j++) { + sizes[i][j] = sizes[i][j + 1] - amount_to_remove; + } + sizes[i].pop_back (); + } + n_slice--; + } else { + FAISS_THROW_MSG ("nothing to do???"); + } + index->ntotal = index_ivf->ntotal; +} + + + +// Get a subset of inverted lists [i0, i1). Works on IndexIVF's and +// IndexIVF's embedded in a IndexPreTransform + +ArrayInvertedLists * +get_invlist_range (const Index *index, long i0, long i1) +{ + const IndexIVF *ivf = extract_index_ivf (index); + + FAISS_THROW_IF_NOT (0 <= i0 && i0 <= i1 && i1 <= ivf->nlist); + + const InvertedLists *src = ivf->invlists; + + ArrayInvertedLists * il = new ArrayInvertedLists(i1 - i0, src->code_size); + + for (long i = i0; i < i1; i++) { + il->add_entries(i - i0, src->list_size(i), + InvertedLists::ScopedIds (src, i).get(), + InvertedLists::ScopedCodes (src, i).get()); + } + return il; +} + + + +void set_invlist_range (Index *index, long i0, long i1, + ArrayInvertedLists * src) +{ + IndexIVF *ivf = extract_index_ivf (index); + + FAISS_THROW_IF_NOT (0 <= i0 && i0 <= i1 && i1 <= ivf->nlist); + + ArrayInvertedLists *dst = dynamic_cast(ivf->invlists); + FAISS_THROW_IF_NOT_MSG (dst, "only ArrayInvertedLists supported"); + FAISS_THROW_IF_NOT (src->nlist == i1 - i0 && + dst->code_size == src->code_size); + + size_t ntotal = index->ntotal; + for (long i = i0 ; i < i1; i++) { + ntotal -= dst->list_size (i); + ntotal += src->list_size (i - i0); + std::swap (src->codes[i - i0], dst->codes[i]); + std::swap (src->ids[i - i0], dst->ids[i]); + } + ivf->ntotal = index->ntotal = ntotal; +} + + +void search_with_parameters (const Index *index, + idx_t n, const float *x, idx_t k, + float *distances, idx_t *labels, + IVFSearchParameters *params, + size_t *nb_dis_ptr) +{ + FAISS_THROW_IF_NOT (params); + const float *prev_x = x; + ScopeDeleter del; + + if (auto ip = dynamic_cast (index)) { + x = ip->apply_chain (n, x); + if (x != prev_x) { + del.set(x); + } + index = ip->index; + } + + std::vector Iq(params->nprobe * n); + std::vector Dq(params->nprobe * n); + + const IndexIVF *index_ivf = dynamic_cast(index); + FAISS_THROW_IF_NOT (index_ivf); + + double t0 = getmillisecs(); + index_ivf->quantizer->search(n, x, params->nprobe, + Dq.data(), Iq.data()); + double t1 = getmillisecs(); + indexIVF_stats.quantization_time += t1 - t0; + + if (nb_dis_ptr) { + size_t nb_dis = 0; + const InvertedLists *il = index_ivf->invlists; + for (idx_t i = 0; i < n * params->nprobe; i++) { + if (Iq[i] >= 0) { + nb_dis += il->list_size(Iq[i]); + } + } + *nb_dis_ptr = nb_dis; + } + + index_ivf->search_preassigned(n, x, k, Iq.data(), Dq.data(), + distances, labels, + false, params); + double t2 = getmillisecs(); + indexIVF_stats.search_time += t2 - t1; +} + + + +} } // namespace faiss::ivflib diff --git a/core/src/index/thirdparty/faiss/IVFlib.h b/core/src/index/thirdparty/faiss/IVFlib.h new file mode 100644 index 0000000000..879fd19086 --- /dev/null +++ b/core/src/index/thirdparty/faiss/IVFlib.h @@ -0,0 +1,136 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +#ifndef FAISS_IVFLIB_H +#define FAISS_IVFLIB_H + +/** Since IVF (inverted file) indexes are of so much use for + * large-scale use cases, we group a few functions related to them in + * this small library. Most functions work both on IndexIVFs and + * IndexIVFs embedded within an IndexPreTransform. + */ + +#include +#include + +namespace faiss { namespace ivflib { + + +/** check if two indexes have the same parameters and are trained in + * the same way, otherwise throw. */ +void check_compatible_for_merge (const Index * index1, + const Index * index2); + +/** get an IndexIVF from an index. The index may be an IndexIVF or + * some wrapper class that encloses an IndexIVF + * + * throws an exception if this is not the case. + */ +const IndexIVF * extract_index_ivf (const Index * index); +IndexIVF * extract_index_ivf (Index * index); + +/// same as above but returns nullptr instead of throwing on failure +const IndexIVF * try_extract_index_ivf (const Index * index); +IndexIVF * try_extract_index_ivf (Index * index); + +/** Merge index1 into index0. Works on IndexIVF's and IndexIVF's + * embedded in a IndexPreTransform. On output, the index1 is empty. + * + * @param shift_ids: translate the ids from index1 to index0->prev_ntotal + */ +void merge_into(Index *index0, Index *index1, bool shift_ids); + +typedef Index::idx_t idx_t; + +/* Returns the cluster the embeddings belong to. + * + * @param index Index, which should be an IVF index + * (otherwise there are no clusters) + * @param embeddings object descriptors for which the centroids should be found, + * size num_objects * d + * @param centroid_ids + * cluster id each object belongs to, size num_objects + */ +void search_centroid(Index *index, + const float* x, int n, + idx_t* centroid_ids); + +/* Returns the cluster the embeddings belong to. + * + * @param index Index, which should be an IVF index + * (otherwise there are no clusters) + * @param query_centroid_ids + * centroid ids corresponding to the query vectors (size n) + * @param result_centroid_ids + * centroid ids corresponding to the results (size n * k) + * other arguments are the same as the standard search function + */ +void search_and_return_centroids(Index *index, + size_t n, + const float* xin, + long k, + float *distances, + idx_t* labels, + idx_t* query_centroid_ids, + idx_t* result_centroid_ids); + + +/** A set of IndexIVFs concatenated together in a FIFO fashion. + * at each "step", the oldest index slice is removed and a new index is added. + */ +struct SlidingIndexWindow { + /// common index that contains the sliding window + Index * index; + + /// InvertedLists of index + ArrayInvertedLists *ils; + + /// number of slices currently in index + int n_slice; + + /// same as index->nlist + size_t nlist; + + /// cumulative list sizes at each slice + std::vector > sizes; + + /// index should be initially empty and trained + SlidingIndexWindow (Index *index); + + /** Add one index to the current index and remove the oldest one. + * + * @param sub_index slice to swap in (can be NULL) + * @param remove_oldest if true, remove the oldest slices */ + void step(const Index *sub_index, bool remove_oldest); + +}; + + +/// Get a subset of inverted lists [i0, i1) +ArrayInvertedLists * get_invlist_range (const Index *index, + long i0, long i1); + +/// Set a subset of inverted lists +void set_invlist_range (Index *index, long i0, long i1, + ArrayInvertedLists * src); + +// search an IndexIVF, possibly embedded in an IndexPreTransform with +// given parameters. Optionally returns the number of distances +// computed +void search_with_parameters (const Index *index, + idx_t n, const float *x, idx_t k, + float *distances, idx_t *labels, + IVFSearchParameters *params, + size_t *nb_dis = nullptr); + + + +} } // namespace faiss::ivflib + +#endif diff --git a/core/src/index/thirdparty/faiss/Index.cpp b/core/src/index/thirdparty/faiss/Index.cpp new file mode 100644 index 0000000000..b11cfb2683 --- /dev/null +++ b/core/src/index/thirdparty/faiss/Index.cpp @@ -0,0 +1,195 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +#include + +#include +#include +#include +#include + +#include + + +namespace faiss { + +Index::~Index () +{ +} + + +void Index::train(idx_t /*n*/, const float* /*x*/) { + // does nothing by default +} + + +void Index::range_search (idx_t , const float *, float, + RangeSearchResult *, + ConcurrentBitsetPtr) const +{ + FAISS_THROW_MSG ("range search not implemented"); +} + +void Index::assign (idx_t n, const float* x, idx_t* labels, float* distance) +{ + float *dis_inner = (distance == nullptr) ? new float[n] : distance; + search (n, x, 1, dis_inner, labels); + if (distance == nullptr) { + delete[] dis_inner; + } +} + +void Index::add_with_ids( + idx_t /*n*/, + const float* /*x*/, + const idx_t* /*xids*/) { + FAISS_THROW_MSG ("add_with_ids not implemented for this type of index"); +} + + +void Index::add_without_codes(idx_t n, const float* x) { + FAISS_THROW_MSG ("add_without_codes not implemented for this type of index"); +} + +void Index::add_with_ids_without_codes(idx_t n, const float* x, const idx_t* xids) { + FAISS_THROW_MSG ("add_with_ids_without_codes not implemented for this type of index"); +} + +#if 0 +void Index::get_vector_by_id (idx_t n, const idx_t *xid, float *x, ConcurrentBitsetPtr bitset) { + FAISS_THROW_MSG ("get_vector_by_id not implemented for this type of index"); +} + +void Index::search_by_id (idx_t n, const idx_t *xid, idx_t k, float *distances, idx_t *labels, + ConcurrentBitsetPtr bitset) { + FAISS_THROW_MSG ("search_by_id not implemented for this type of index"); +} +#endif + +size_t Index::remove_ids(const IDSelector& /*sel*/) { + FAISS_THROW_MSG ("remove_ids not implemented for this type of index"); + return -1; +} + + +void Index::reconstruct (idx_t, float * ) const { + FAISS_THROW_MSG ("reconstruct not implemented for this type of index"); +} + + +void Index::reconstruct_n (idx_t i0, idx_t ni, float *recons) const { + for (idx_t i = 0; i < ni; i++) { + reconstruct (i0 + i, recons + i * d); + } +} + + +void Index::search_and_reconstruct (idx_t n, const float *x, idx_t k, + float *distances, idx_t *labels, + float *recons) const { + search (n, x, k, distances, labels); + for (idx_t i = 0; i < n; ++i) { + for (idx_t j = 0; j < k; ++j) { + idx_t ij = i * k + j; + idx_t key = labels[ij]; + float* reconstructed = recons + ij * d; + if (key < 0) { + // Fill with NaNs + memset(reconstructed, -1, sizeof(*reconstructed) * d); + } else { + reconstruct (key, reconstructed); + } + } + } +} + +void Index::compute_residual (const float * x, + float * residual, idx_t key) const { + reconstruct (key, residual); + for (size_t i = 0; i < d; i++) { + residual[i] = x[i] - residual[i]; + } +} + +void Index::compute_residual_n (idx_t n, const float* xs, + float* residuals, + const idx_t* keys) const { +#pragma omp parallel for + for (idx_t i = 0; i < n; ++i) { + compute_residual(&xs[i * d], &residuals[i * d], keys[i]); + } +} + + + +size_t Index::sa_code_size () const +{ + FAISS_THROW_MSG ("standalone codec not implemented for this type of index"); +} + +void Index::sa_encode (idx_t, const float *, + uint8_t *) const +{ + FAISS_THROW_MSG ("standalone codec not implemented for this type of index"); +} + +void Index::sa_decode (idx_t, const uint8_t *, + float *) const +{ + FAISS_THROW_MSG ("standalone codec not implemented for this type of index"); +} + + +namespace { + + +// storage that explicitly reconstructs vectors before computing distances +struct GenericDistanceComputer : DistanceComputer { + size_t d; + const Index& storage; + std::vector buf; + const float *q; + + explicit GenericDistanceComputer(const Index& storage) + : storage(storage) { + d = storage.d; + buf.resize(d * 2); + } + + float operator () (idx_t i) override { + storage.reconstruct(i, buf.data()); + return fvec_L2sqr(q, buf.data(), d); + } + + float symmetric_dis(idx_t i, idx_t j) override { + storage.reconstruct(i, buf.data()); + storage.reconstruct(j, buf.data() + d); + return fvec_L2sqr(buf.data() + d, buf.data(), d); + } + + void set_query(const float *x) override { + q = x; + } + +}; + + +} // namespace + + +DistanceComputer * Index::get_distance_computer() const { + if (metric_type == METRIC_L2) { + return new GenericDistanceComputer(*this); + } else { + FAISS_THROW_MSG ("get_distance_computer() not implemented"); + } +} + + +} diff --git a/core/src/index/thirdparty/faiss/Index.h b/core/src/index/thirdparty/faiss/Index.h new file mode 100644 index 0000000000..9e8d22dba4 --- /dev/null +++ b/core/src/index/thirdparty/faiss/Index.h @@ -0,0 +1,285 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +#ifndef FAISS_INDEX_H +#define FAISS_INDEX_H + +#include +#include +#include +#include +#include +#include + +#define FAISS_VERSION_MAJOR 1 +#define FAISS_VERSION_MINOR 6 +#define FAISS_VERSION_PATCH 3 + +/** + * @namespace faiss + * + * Throughout the library, vectors are provided as float * pointers. + * Most algorithms can be optimized when several vectors are processed + * (added/searched) together in a batch. In this case, they are passed + * in as a matrix. When n vectors of size d are provided as float * x, + * component j of vector i is + * + * x[ i * d + j ] + * + * where 0 <= i < n and 0 <= j < d. In other words, matrices are + * always compact. When specifying the size of the matrix, we call it + * an n*d matrix, which implies a row-major storage. + */ + + +namespace faiss { + +/// Forward declarations see AuxIndexStructures.h +struct IDSelector; +struct RangeSearchResult; +struct DistanceComputer; + +/** Abstract structure for an index, supports adding vectors and searching them. + * + * All vectors provided at add or search time are 32-bit float arrays, + * although the internal representation may vary. + */ +struct Index { + using idx_t = int64_t; ///< all indices are this type + using component_t = float; + using distance_t = float; + + int d; ///< vector dimension + idx_t ntotal; ///< total nb of indexed vectors + bool verbose; ///< verbosity level + + /// set if the Index does not require training, or if training is + /// done already + bool is_trained; + + /// type of metric this index uses for search + MetricType metric_type; + float metric_arg; ///< argument of the metric type + + explicit Index (idx_t d = 0, MetricType metric = METRIC_L2): + d(d), + ntotal(0), + verbose(false), + is_trained(true), + metric_type (metric), + metric_arg(0) {} + + virtual ~Index (); + + + /** Perform training on a representative set of vectors + * + * @param n nb of training vectors + * @param x training vecors, size n * d + */ + virtual void train(idx_t n, const float* x); + + /** Add n vectors of dimension d to the index. + * + * Vectors are implicitly assigned labels ntotal .. ntotal + n - 1 + * This function slices the input vectors in chuncks smaller than + * blocksize_add and calls add_core. + * @param x input matrix, size n * d + */ + virtual void add (idx_t n, const float *x) = 0; + + /** Same as add, but only add ids, not codes + * + * @param n nb of training vectors + * @param x training vecors, size n * d + */ + virtual void add_without_codes(idx_t n, const float* x); + + /** Same as add, but stores xids instead of sequential ids. + * + * The default implementation fails with an assertion, as it is + * not supported by all indexes. + * + * @param xids if non-null, ids to store for the vectors (size n) + */ + virtual void add_with_ids (idx_t n, const float * x, const idx_t *xids); + + /** Same as add_with_ids, but only add ids, not codes + * + * @param xids if non-null, ids to store for the vectors (size n) + */ + virtual void add_with_ids_without_codes(idx_t n, const float* x, const idx_t* xids); + + /** query n vectors of dimension d to the index. + * + * return at most k vectors. If there are not enough results for a + * query, the result array is padded with -1s. + * + * @param x input vectors to search, size n * d + * @param labels output labels of the NNs, size n*k + * @param distances output pairwise distances, size n*k + * @param bitset flags to check the validity of vectors + */ + virtual void search (idx_t n, const float *x, idx_t k, + float *distances, idx_t *labels, + ConcurrentBitsetPtr bitset = nullptr) const = 0; + +#if 0 + /** query n raw vectors from the index by ids. + * + * return n raw vectors. + * + * @param n input num of xid + * @param xid input labels of the NNs, size n + * @param x output raw vectors, size n * d + * @param bitset flags to check the validity of vectors + */ + virtual void get_vector_by_id (idx_t n, const idx_t *xid, float *x, ConcurrentBitsetPtr bitset = nullptr); + + /** query n vectors of dimension d to the index by ids. + * + * return at most k vectors. If there are not enough results for a + * query, the result array is padded with -1s. + * + * @param xid input ids to search, size n + * @param labels output labels of the NNs, size n*k + * @param distances output pairwise distances, size n*k + * @param bitset flags to check the validity of vectors + */ + virtual void search_by_id (idx_t n, const idx_t *xid, idx_t k, float *distances, idx_t *labels, + ConcurrentBitsetPtr bitset = nullptr); +#endif + + /** query n vectors of dimension d to the index. + * + * return all vectors with distance < radius. Note that many + * indexes do not implement the range_search (only the k-NN search + * is mandatory). + * + * @param x input vectors to search, size n * d + * @param radius search radius + * @param result result table + */ + virtual void range_search (idx_t n, const float *x, float radius, + RangeSearchResult *result, + ConcurrentBitsetPtr bitset = nullptr) const; + + /** return the indexes of the k vectors closest to the query x. + * + * This function is identical as search but only return labels of neighbors. + * @param x input vectors to search, size n * d + * @param labels output labels of the NNs, size n + */ + virtual void assign (idx_t n, const float* x, idx_t* labels, float* distance = nullptr); + + /// removes all elements from the database. + virtual void reset() = 0; + + /** removes IDs from the index. Not supported by all + * indexes. Returns the number of elements removed. + */ + virtual size_t remove_ids (const IDSelector & sel); + + /** Reconstruct a stored vector (or an approximation if lossy coding) + * + * this function may not be defined for some indexes + * @param key id of the vector to reconstruct + * @param recons reconstucted vector (size d) + */ + virtual void reconstruct (idx_t key, float * recons) const; + + /** Reconstruct vectors i0 to i0 + ni - 1 + * + * this function may not be defined for some indexes + * @param recons reconstucted vector (size ni * d) + */ + virtual void reconstruct_n (idx_t i0, idx_t ni, float *recons) const; + + /** Similar to search, but also reconstructs the stored vectors (or an + * approximation in the case of lossy coding) for the search results. + * + * If there are not enough results for a query, the resulting arrays + * is padded with -1s. + * + * @param recons reconstructed vectors size (n, k, d) + **/ + virtual void search_and_reconstruct (idx_t n, const float *x, idx_t k, + float *distances, idx_t *labels, + float *recons) const; + + /** Computes a residual vector after indexing encoding. + * + * The residual vector is the difference between a vector and the + * reconstruction that can be decoded from its representation in + * the index. The residual can be used for multiple-stage indexing + * methods, like IndexIVF's methods. + * + * @param x input vector, size d + * @param residual output residual vector, size d + * @param key encoded index, as returned by search and assign + */ + virtual void compute_residual (const float * x, + float * residual, idx_t key) const; + + /** Computes a residual vector after indexing encoding (batch form). + * Equivalent to calling compute_residual for each vector. + * + * The residual vector is the difference between a vector and the + * reconstruction that can be decoded from its representation in + * the index. The residual can be used for multiple-stage indexing + * methods, like IndexIVF's methods. + * + * @param n number of vectors + * @param xs input vectors, size (n x d) + * @param residuals output residual vectors, size (n x d) + * @param keys encoded index, as returned by search and assign + */ + virtual void compute_residual_n (idx_t n, const float* xs, + float* residuals, + const idx_t* keys) const; + + /** Get a DistanceComputer (defined in AuxIndexStructures) object + * for this kind of index. + * + * DistanceComputer is implemented for indexes that support random + * access of their vectors. + */ + virtual DistanceComputer * get_distance_computer() const; + + + /* The standalone codec interface */ + + /** size of the produced codes in bytes */ + virtual size_t sa_code_size () const; + + /** encode a set of vectors + * + * @param n number of vectors + * @param x input vectors, size n * d + * @param bytes output encoded vectors, size n * sa_code_size() + */ + virtual void sa_encode (idx_t n, const float *x, + uint8_t *bytes) const; + + /** encode a set of vectors + * + * @param n number of vectors + * @param bytes input encoded vectors, size n * sa_code_size() + * @param x output vectors, size n * d + */ + virtual void sa_decode (idx_t n, const uint8_t *bytes, + float *x) const; + + +}; + +} + + +#endif diff --git a/core/src/index/thirdparty/faiss/Index2Layer.cpp b/core/src/index/thirdparty/faiss/Index2Layer.cpp new file mode 100644 index 0000000000..cbdfd75426 --- /dev/null +++ b/core/src/index/thirdparty/faiss/Index2Layer.cpp @@ -0,0 +1,437 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +#include + +#include +#include +#include +#include + +#ifdef __SSE__ +#include +#endif + +#include + +#include + +#include +#include +#include +#include +#include +#include + +/* +#include + +#include + +#include + + +*/ + + +namespace faiss { + + +/************************************* + * Index2Layer implementation + *************************************/ + + +Index2Layer::Index2Layer (Index * quantizer, size_t nlist, + int M, int nbit, + MetricType metric): + Index (quantizer->d, metric), + q1 (quantizer, nlist), + pq (quantizer->d, M, nbit) +{ + is_trained = false; + for (int nbyte = 0; nbyte < 7; nbyte++) { + if ((1L << (8 * nbyte)) >= nlist) { + code_size_1 = nbyte; + break; + } + } + code_size_2 = pq.code_size; + code_size = code_size_1 + code_size_2; +} + +Index2Layer::Index2Layer () +{ + code_size = code_size_1 = code_size_2 = 0; +} + +Index2Layer::~Index2Layer () +{} + +void Index2Layer::train(idx_t n, const float* x) +{ + if (verbose) { + printf ("training level-1 quantizer %ld vectors in %dD\n", + n, d); + } + + q1.train_q1 (n, x, verbose, metric_type); + + if (verbose) { + printf("computing residuals\n"); + } + + const float * x_in = x; + + x = fvecs_maybe_subsample ( + d, (size_t*)&n, pq.cp.max_points_per_centroid * pq.ksub, + x, verbose, pq.cp.seed); + + ScopeDeleter del_x (x_in == x ? nullptr : x); + + std::vector assign(n); // assignement to coarse centroids + q1.quantizer->assign (n, x, assign.data()); + std::vector residuals(n * d); + for (idx_t i = 0; i < n; i++) { + q1.quantizer->compute_residual ( + x + i * d, residuals.data() + i * d, assign[i]); + } + + if (verbose) + printf ("training %zdx%zd product quantizer on %ld vectors in %dD\n", + pq.M, pq.ksub, n, d); + pq.verbose = verbose; + pq.train (n, residuals.data()); + + is_trained = true; +} + +void Index2Layer::add(idx_t n, const float* x) +{ + idx_t bs = 32768; + if (n > bs) { + for (idx_t i0 = 0; i0 < n; i0 += bs) { + idx_t i1 = std::min(i0 + bs, n); + if (verbose) { + printf("Index2Layer::add: adding %ld:%ld / %ld\n", + i0, i1, n); + } + add (i1 - i0, x + i0 * d); + } + return; + } + + std::vector codes1 (n); + q1.quantizer->assign (n, x, codes1.data()); + std::vector residuals(n * d); + for (idx_t i = 0; i < n; i++) { + q1.quantizer->compute_residual ( + x + i * d, residuals.data() + i * d, codes1[i]); + } + std::vector codes2 (n * code_size_2); + + pq.compute_codes (residuals.data(), codes2.data(), n); + + codes.resize ((ntotal + n) * code_size); + uint8_t *wp = &codes[ntotal * code_size]; + + { + int i = 0x11223344; + const char *ip = (char*)&i; + FAISS_THROW_IF_NOT_MSG (ip[0] == 0x44, + "works only on a little-endian CPU"); + } + + // copy to output table + for (idx_t i = 0; i < n; i++) { + memcpy (wp, &codes1[i], code_size_1); + wp += code_size_1; + memcpy (wp, &codes2[i * code_size_2], code_size_2); + wp += code_size_2; + } + + ntotal += n; + +} + +void Index2Layer::search( + idx_t /*n*/, + const float* /*x*/, + idx_t /*k*/, + float* /*distances*/, + idx_t* /*labels*/, + ConcurrentBitsetPtr) const { + FAISS_THROW_MSG("not implemented"); +} + + +void Index2Layer::reconstruct_n(idx_t i0, idx_t ni, float* recons) const +{ + float recons1[d]; + FAISS_THROW_IF_NOT (i0 >= 0 && i0 + ni <= ntotal); + const uint8_t *rp = &codes[i0 * code_size]; + + for (idx_t i = 0; i < ni; i++) { + idx_t key = 0; + memcpy (&key, rp, code_size_1); + q1.quantizer->reconstruct (key, recons1); + rp += code_size_1; + pq.decode (rp, recons); + for (idx_t j = 0; j < d; j++) { + recons[j] += recons1[j]; + } + rp += code_size_2; + recons += d; + } +} + +void Index2Layer::transfer_to_IVFPQ (IndexIVFPQ & other) const +{ + FAISS_THROW_IF_NOT (other.nlist == q1.nlist); + FAISS_THROW_IF_NOT (other.code_size == code_size_2); + FAISS_THROW_IF_NOT (other.ntotal == 0); + + const uint8_t *rp = codes.data(); + + for (idx_t i = 0; i < ntotal; i++) { + idx_t key = 0; + memcpy (&key, rp, code_size_1); + rp += code_size_1; + other.invlists->add_entry (key, i, rp); + rp += code_size_2; + } + + other.ntotal = ntotal; + +} + + + +void Index2Layer::reconstruct(idx_t key, float* recons) const +{ + reconstruct_n (key, 1, recons); +} + +void Index2Layer::reset() +{ + ntotal = 0; + codes.clear (); +} + + +namespace { + + +struct Distance2Level : DistanceComputer { + size_t d; + const Index2Layer& storage; + std::vector buf; + const float *q; + + const float *pq_l1_tab, *pq_l2_tab; + + explicit Distance2Level(const Index2Layer& storage) + : storage(storage) { + d = storage.d; + FAISS_ASSERT(storage.pq.dsub == 4); + pq_l2_tab = storage.pq.centroids.data(); + buf.resize(2 * d); + } + + float symmetric_dis(idx_t i, idx_t j) override { + storage.reconstruct(i, buf.data()); + storage.reconstruct(j, buf.data() + d); + return fvec_L2sqr(buf.data() + d, buf.data(), d); + } + + void set_query(const float *x) override { + q = x; + } +}; + +// well optimized for xNN+PQNN +struct DistanceXPQ4 : Distance2Level { + + int M, k; + + explicit DistanceXPQ4(const Index2Layer& storage) + : Distance2Level (storage) { + const IndexFlat *quantizer = + dynamic_cast (storage.q1.quantizer); + + FAISS_ASSERT(quantizer); + M = storage.pq.M; + pq_l1_tab = quantizer->xb.data(); + } + + float operator () (idx_t i) override { +#ifdef __SSE__ + const uint8_t *code = storage.codes.data() + i * storage.code_size; + long key = 0; + memcpy (&key, code, storage.code_size_1); + code += storage.code_size_1; + + // walking pointers + const float *qa = q; + const __m128 *l1_t = (const __m128 *)(pq_l1_tab + d * key); + const __m128 *pq_l2_t = (const __m128 *)pq_l2_tab; + __m128 accu = _mm_setzero_ps(); + + for (int m = 0; m < M; m++) { + __m128 qi = _mm_loadu_ps(qa); + __m128 recons = l1_t[m] + pq_l2_t[*code++]; + __m128 diff = qi - recons; + accu += diff * diff; + pq_l2_t += 256; + qa += 4; + } + + accu = _mm_hadd_ps (accu, accu); + accu = _mm_hadd_ps (accu, accu); + return _mm_cvtss_f32 (accu); +#else + FAISS_THROW_MSG("not implemented for non-x64 platforms"); +#endif + } + +}; + +// well optimized for 2xNN+PQNN +struct Distance2xXPQ4 : Distance2Level { + + int M_2, mi_nbits; + + explicit Distance2xXPQ4(const Index2Layer& storage) + : Distance2Level(storage) { + const MultiIndexQuantizer *mi = + dynamic_cast (storage.q1.quantizer); + + FAISS_ASSERT(mi); + FAISS_ASSERT(storage.pq.M % 2 == 0); + M_2 = storage.pq.M / 2; + mi_nbits = mi->pq.nbits; + pq_l1_tab = mi->pq.centroids.data(); + } + + float operator () (idx_t i) override { + const uint8_t *code = storage.codes.data() + i * storage.code_size; + long key01 = 0; + memcpy (&key01, code, storage.code_size_1); + code += storage.code_size_1; +#ifdef __SSE__ + + // walking pointers + const float *qa = q; + const __m128 *pq_l1_t = (const __m128 *)pq_l1_tab; + const __m128 *pq_l2_t = (const __m128 *)pq_l2_tab; + __m128 accu = _mm_setzero_ps(); + + for (int mi_m = 0; mi_m < 2; mi_m++) { + long l1_idx = key01 & ((1L << mi_nbits) - 1); + const __m128 * pq_l1 = pq_l1_t + M_2 * l1_idx; + + for (int m = 0; m < M_2; m++) { + __m128 qi = _mm_loadu_ps(qa); + __m128 recons = pq_l1[m] + pq_l2_t[*code++]; + __m128 diff = qi - recons; + accu += diff * diff; + pq_l2_t += 256; + qa += 4; + } + pq_l1_t += M_2 << mi_nbits; + key01 >>= mi_nbits; + } + accu = _mm_hadd_ps (accu, accu); + accu = _mm_hadd_ps (accu, accu); + return _mm_cvtss_f32 (accu); +#else + FAISS_THROW_MSG("not implemented for non-x64 platforms"); +#endif + } + +}; + + +} // namespace + + +DistanceComputer * Index2Layer::get_distance_computer() const { +#ifdef __SSE__ + const MultiIndexQuantizer *mi = + dynamic_cast (q1.quantizer); + + if (mi && pq.M % 2 == 0 && pq.dsub == 4) { + return new Distance2xXPQ4(*this); + } + + const IndexFlat *fl = + dynamic_cast (q1.quantizer); + + if (fl && pq.dsub == 4) { + return new DistanceXPQ4(*this); + } +#endif + + return Index::get_distance_computer(); +} + + +/* The standalone codec interface */ +size_t Index2Layer::sa_code_size () const +{ + return code_size; +} + +void Index2Layer::sa_encode (idx_t n, const float *x, uint8_t *bytes) const +{ + FAISS_THROW_IF_NOT (is_trained); + std::unique_ptr list_nos (new int64_t [n]); + q1.quantizer->assign (n, x, list_nos.get()); + std::vector residuals(n * d); + for (idx_t i = 0; i < n; i++) { + q1.quantizer->compute_residual ( + x + i * d, residuals.data() + i * d, list_nos[i]); + } + pq.compute_codes (residuals.data(), bytes, n); + + for (idx_t i = n - 1; i >= 0; i--) { + uint8_t * code = bytes + i * code_size; + memmove (code + code_size_1, + bytes + i * code_size_2, code_size_2); + q1.encode_listno (list_nos[i], code); + } + +} + +void Index2Layer::sa_decode (idx_t n, const uint8_t *bytes, float *x) const +{ + +#pragma omp parallel + { + std::vector residual (d); + +#pragma omp for + for (size_t i = 0; i < n; i++) { + const uint8_t *code = bytes + i * code_size; + int64_t list_no = q1.decode_listno (code); + float *xi = x + i * d; + pq.decode (code + code_size_1, xi); + q1.quantizer->reconstruct (list_no, residual.data()); + for (size_t j = 0; j < d; j++) { + xi[j] += residual[j]; + } + } + } + +} + + + + +} // namespace faiss diff --git a/core/src/index/thirdparty/faiss/Index2Layer.h b/core/src/index/thirdparty/faiss/Index2Layer.h new file mode 100644 index 0000000000..b7d8ccd1fa --- /dev/null +++ b/core/src/index/thirdparty/faiss/Index2Layer.h @@ -0,0 +1,87 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +#pragma once + +#include + +#include +#include + +namespace faiss { + +struct IndexIVFPQ; + + +/** Same as an IndexIVFPQ without the inverted lists: codes are stored sequentially + * + * The class is mainly inteded to store encoded vectors that can be + * accessed randomly, the search function is not implemented. + */ +struct Index2Layer: Index { + /// first level quantizer + Level1Quantizer q1; + + /// second level quantizer is always a PQ + ProductQuantizer pq; + + /// Codes. Size ntotal * code_size. + std::vector codes; + + /// size of the code for the first level (ceil(log8(q1.nlist))) + size_t code_size_1; + + /// size of the code for the second level + size_t code_size_2; + + /// code_size_1 + code_size_2 + size_t code_size; + + Index2Layer (Index * quantizer, size_t nlist, + int M, int nbit = 8, + MetricType metric = METRIC_L2); + + Index2Layer (); + ~Index2Layer (); + + void train(idx_t n, const float* x) override; + + void add(idx_t n, const float* x) override; + + /// not implemented + void search( + idx_t n, + const float* x, + idx_t k, + float* distances, + idx_t* labels, + ConcurrentBitsetPtr bitset = nullptr) const override; + + void reconstruct_n(idx_t i0, idx_t ni, float* recons) const override; + + void reconstruct(idx_t key, float* recons) const override; + + void reset() override; + + DistanceComputer * get_distance_computer() const override; + + /// transfer the flat codes to an IVFPQ index + void transfer_to_IVFPQ(IndexIVFPQ & other) const; + + + /* The standalone codec interface */ + size_t sa_code_size () const override; + void sa_encode (idx_t n, const float *x, uint8_t *bytes) const override; + void sa_decode (idx_t n, const uint8_t *bytes, float *x) const override; + + size_t cal_size() { return sizeof(*this) + codes.size() * sizeof(uint8_t) + pq.cal_size(); } +}; + + +} // namespace faiss diff --git a/core/src/index/thirdparty/faiss/IndexBinary.cpp b/core/src/index/thirdparty/faiss/IndexBinary.cpp new file mode 100644 index 0000000000..fc41fe481c --- /dev/null +++ b/core/src/index/thirdparty/faiss/IndexBinary.cpp @@ -0,0 +1,89 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +#include +#include + +#include + +namespace faiss { + +IndexBinary::~IndexBinary() {} + +void IndexBinary::train(idx_t, const uint8_t *) { + // Does nothing by default. +} + +void IndexBinary::range_search(idx_t, const uint8_t *, int, + RangeSearchResult *, + ConcurrentBitsetPtr) const { + FAISS_THROW_MSG("range search not implemented"); +} + +void IndexBinary::assign(idx_t n, const uint8_t *x, idx_t *labels, idx_t k) { + int *distances = new int[n * k]; + ScopeDeleter del(distances); + search(n, x, k, distances, labels); +} + +void IndexBinary::add_with_ids(idx_t, const uint8_t *, const idx_t *) { + FAISS_THROW_MSG("add_with_ids not implemented for this type of index"); +} + +#if 0 +void IndexBinary::get_vector_by_id (idx_t n, const idx_t *xid, uint8_t *x, ConcurrentBitsetPtr bitset) { + FAISS_THROW_MSG("get_vector_by_id not implemented for this type of index"); +} + +void IndexBinary::search_by_id (idx_t n, const idx_t *xid, idx_t k, int32_t *distances, idx_t *labels, + ConcurrentBitsetPtr bitset) { + FAISS_THROW_MSG("search_by_id not implemented for this type of index"); +} +#endif + +size_t IndexBinary::remove_ids(const IDSelector&) { + FAISS_THROW_MSG("remove_ids not implemented for this type of index"); + return 0; +} + +void IndexBinary::reconstruct(idx_t, uint8_t *) const { + FAISS_THROW_MSG("reconstruct not implemented for this type of index"); +} + +void IndexBinary::reconstruct_n(idx_t i0, idx_t ni, uint8_t *recons) const { + for (idx_t i = 0; i < ni; i++) { + reconstruct(i0 + i, recons + i * d); + } +} + +void IndexBinary::search_and_reconstruct(idx_t n, const uint8_t *x, idx_t k, + int32_t *distances, idx_t *labels, + uint8_t *recons) const { + search(n, x, k, distances, labels); + for (idx_t i = 0; i < n; ++i) { + for (idx_t j = 0; j < k; ++j) { + idx_t ij = i * k + j; + idx_t key = labels[ij]; + uint8_t *reconstructed = recons + ij * d; + if (key < 0) { + // Fill with NaNs + memset(reconstructed, -1, sizeof(*reconstructed) * d); + } else { + reconstruct(key, reconstructed); + } + } + } +} + +void IndexBinary::display() const { + printf("Index: %s -> %ld elements\n", typeid (*this).name(), ntotal); +} + + +} // namespace faiss diff --git a/core/src/index/thirdparty/faiss/IndexBinary.h b/core/src/index/thirdparty/faiss/IndexBinary.h new file mode 100644 index 0000000000..4141a7d63c --- /dev/null +++ b/core/src/index/thirdparty/faiss/IndexBinary.h @@ -0,0 +1,196 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +#ifndef FAISS_INDEX_BINARY_H +#define FAISS_INDEX_BINARY_H + +#include +#include +#include +#include + +#include +#include + + +namespace faiss { + + +/// Forward declarations see AuxIndexStructures.h +struct IDSelector; +struct RangeSearchResult; + +/** Abstract structure for a binary index. + * + * Supports adding vertices and searching them. + * + * All queries are symmetric because there is no distinction between codes and + * vectors. + */ +struct IndexBinary { + using idx_t = Index::idx_t; ///< all indices are this type + using component_t = uint8_t; + using distance_t = int32_t; + + int d; ///< vector dimension + int code_size; ///< number of bytes per vector ( = d / 8 ) + idx_t ntotal; ///< total nb of indexed vectors + bool verbose; ///< verbosity level + + /// set if the Index does not require training, or if training is done already + bool is_trained; + + /// type of metric this index uses for search + MetricType metric_type; + + explicit IndexBinary(idx_t d = 0, MetricType metric = METRIC_L2) + : d(d), + code_size(d / 8), + ntotal(0), + verbose(false), + is_trained(true), + metric_type(metric) { + FAISS_THROW_IF_NOT(d % 8 == 0); + } + + virtual ~IndexBinary(); + + + /** Perform training on a representative set of vectors. + * + * @param n nb of training vectors + * @param x training vecors, size n * d / 8 + */ + virtual void train(idx_t n, const uint8_t *x); + + /** Add n vectors of dimension d to the index. + * + * Vectors are implicitly assigned labels ntotal .. ntotal + n - 1 + * @param x input matrix, size n * d / 8 + */ + virtual void add(idx_t n, const uint8_t *x) = 0; + + /** Same as add, but stores xids instead of sequential ids. + * + * The default implementation fails with an assertion, as it is + * not supported by all indexes. + * + * @param xids if non-null, ids to store for the vectors (size n) + */ + virtual void add_with_ids(idx_t n, const uint8_t *x, const idx_t *xids); + + /** Query n vectors of dimension d to the index. + * + * return at most k vectors. If there are not enough results for a + * query, the result array is padded with -1s. + * + * @param x input vectors to search, size n * d / 8 + * @param labels output labels of the NNs, size n*k + * @param distances output pairwise distances, size n*k + * @param bitset flags to check the validity of vectors + */ + virtual void search(idx_t n, const uint8_t *x, idx_t k, + int32_t *distances, idx_t *labels, + ConcurrentBitsetPtr bitset = nullptr) const = 0; + +#if 0 + /** Query n raw vectors from the index by ids. + * + * return n raw vectors. + * + * @param n input num of xid + * @param xid input labels of the NNs, size n + * @param x output raw vectors, size n * d + * @param bitset flags to check the validity of vectors + */ + virtual void get_vector_by_id (idx_t n, const idx_t *xid, uint8_t *x, ConcurrentBitsetPtr bitset = nullptr); + + /** query n vectors of dimension d to the index by ids. + * + * return at most k vectors. If there are not enough results for a + * query, the result array is padded with -1s. + * + * @param xid input ids to search, size n + * @param labels output labels of the NNs, size n*k + * @param distances output pairwise distances, size n*k + * @param bitset flags to check the validity of vectors + */ + virtual void search_by_id (idx_t n, const idx_t *xid, idx_t k, int32_t *distances, idx_t *labels, + ConcurrentBitsetPtr bitset = nullptr); +#endif + + /** Query n vectors of dimension d to the index. + * + * return all vectors with distance < radius. Note that many indexes + * do not implement the range_search (only the k-NN search is + * mandatory). The distances are converted to float to reuse the + * RangeSearchResult structure, but they are integer. By convention, + * only distances < radius (strict comparison) are returned, + * ie. radius = 0 does not return any result and 1 returns only + * exact same vectors. + * + * @param x input vectors to search, size n * d / 8 + * @param radius search radius + * @param result result table + */ + virtual void range_search(idx_t n, const uint8_t *x, int radius, + RangeSearchResult *result, + ConcurrentBitsetPtr bitset = nullptr) const; + + /** Return the indexes of the k vectors closest to the query x. + * + * This function is identical to search but only returns labels of neighbors. + * @param x input vectors to search, size n * d / 8 + * @param labels output labels of the NNs, size n*k + */ + void assign(idx_t n, const uint8_t *x, idx_t *labels, idx_t k = 1); + + /// Removes all elements from the database. + virtual void reset() = 0; + + /** Removes IDs from the index. Not supported by all indexes. + */ + virtual size_t remove_ids(const IDSelector& sel); + + /** Reconstruct a stored vector. + * + * This function may not be defined for some indexes. + * @param key id of the vector to reconstruct + * @param recons reconstucted vector (size d / 8) + */ + virtual void reconstruct(idx_t key, uint8_t *recons) const; + + + /** Reconstruct vectors i0 to i0 + ni - 1. + * + * This function may not be defined for some indexes. + * @param recons reconstucted vectors (size ni * d / 8) + */ + virtual void reconstruct_n(idx_t i0, idx_t ni, uint8_t *recons) const; + + /** Similar to search, but also reconstructs the stored vectors (or an + * approximation in the case of lossy coding) for the search results. + * + * If there are not enough results for a query, the resulting array + * is padded with -1s. + * + * @param recons reconstructed vectors size (n, k, d) + **/ + virtual void search_and_reconstruct(idx_t n, const uint8_t *x, idx_t k, + int32_t *distances, idx_t *labels, + uint8_t *recons) const; + + /** Display the actual class name and some more info. */ + void display() const; +}; + + +} // namespace faiss + +#endif // FAISS_INDEX_BINARY_H diff --git a/core/src/index/thirdparty/faiss/IndexBinaryFlat.cpp b/core/src/index/thirdparty/faiss/IndexBinaryFlat.cpp new file mode 100644 index 0000000000..f301376cb7 --- /dev/null +++ b/core/src/index/thirdparty/faiss/IndexBinaryFlat.cpp @@ -0,0 +1,132 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace faiss { + +IndexBinaryFlat::IndexBinaryFlat(idx_t d) + : IndexBinary(d) {} + +IndexBinaryFlat::IndexBinaryFlat(idx_t d, MetricType metric) + : IndexBinary(d, metric) {} + +void IndexBinaryFlat::add(idx_t n, const uint8_t *x) { + xb.insert(xb.end(), x, x + n * code_size); + ntotal += n; +} + +void IndexBinaryFlat::reset() { + xb.clear(); + ntotal = 0; +} + +void IndexBinaryFlat::search(idx_t n, const uint8_t *x, idx_t k, + int32_t *distances, idx_t *labels, + ConcurrentBitsetPtr bitset) const { + const idx_t block_size = query_batch_size; + if (metric_type == METRIC_Jaccard || metric_type == METRIC_Tanimoto) { + float *D = reinterpret_cast(distances); + for (idx_t s = 0; s < n; s += block_size) { + idx_t nn = block_size; + if (s + block_size > n) { + nn = n - s; + } + + // We see the distances and labels as heaps. + float_maxheap_array_t res = { + size_t(nn), size_t(k), labels + s * k, D + s * k + }; + + binary_distence_knn_hc(metric_type, &res, x + s * code_size, xb.data(), ntotal, code_size, + /* ordered = */ true, bitset); + + } + if (metric_type == METRIC_Tanimoto) { + for (int i = 0; i < k * n; i++) { + D[i] = -log2(1-D[i]); + } + } + } else if (metric_type == METRIC_Substructure || metric_type == METRIC_Superstructure) { + float *D = reinterpret_cast(distances); + for (idx_t s = 0; s < n; s += block_size) { + idx_t nn = block_size; + if (s + block_size > n) { + nn = n - s; + } + + // only match ids will be chosed, not to use heap + binary_distence_knn_mc(metric_type, x + s * code_size, xb.data(), nn, ntotal, k, code_size, + D + s * k, labels + s * k, bitset); + } + } else { + for (idx_t s = 0; s < n; s += block_size) { + idx_t nn = block_size; + if (s + block_size > n) { + nn = n - s; + } + if (use_heap) { + // We see the distances and labels as heaps. + int_maxheap_array_t res = { + size_t(nn), size_t(k), labels + s * k, distances + s * k + }; + + hammings_knn_hc(&res, x + s * code_size, xb.data(), ntotal, code_size, + /* ordered = */ true, bitset); + } else { + hammings_knn_mc(x + s * code_size, xb.data(), nn, ntotal, k, code_size, + distances + s * k, labels + s * k, bitset); + } + } + } +} + +size_t IndexBinaryFlat::remove_ids(const IDSelector& sel) { + idx_t j = 0; + for (idx_t i = 0; i < ntotal; i++) { + if (sel.is_member(i)) { + // should be removed + } else { + if (i > j) { + memmove(&xb[code_size * j], &xb[code_size * i], sizeof(xb[0]) * code_size); + } + j++; + } + } + long nremove = ntotal - j; + if (nremove > 0) { + ntotal = j; + xb.resize(ntotal * code_size); + } + return nremove; +} + +void IndexBinaryFlat::reconstruct(idx_t key, uint8_t *recons) const { + memcpy(recons, &(xb[code_size * key]), sizeof(*recons) * code_size); +} + +void IndexBinaryFlat::range_search(idx_t n, const uint8_t *x, int radius, + RangeSearchResult *result, + ConcurrentBitsetPtr bitset) const +{ + hamming_range_search (x, xb.data(), n, ntotal, radius, code_size, result); +} + +} // namespace faiss diff --git a/core/src/index/thirdparty/faiss/IndexBinaryFlat.h b/core/src/index/thirdparty/faiss/IndexBinaryFlat.h new file mode 100644 index 0000000000..012b9b43f4 --- /dev/null +++ b/core/src/index/thirdparty/faiss/IndexBinaryFlat.h @@ -0,0 +1,61 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +#ifndef INDEX_BINARY_FLAT_H +#define INDEX_BINARY_FLAT_H + +#include + +#include + +namespace faiss { + + +/** Index that stores the full vectors and performs exhaustive search. */ +struct IndexBinaryFlat : IndexBinary { + /// database vectors, size ntotal * d / 8 + std::vector xb; + + /** Select between using a heap or counting to select the k smallest values + * when scanning inverted lists. + */ + bool use_heap = true; + + size_t query_batch_size = 32; + + explicit IndexBinaryFlat(idx_t d); + + IndexBinaryFlat(idx_t d, MetricType metric); + + void add(idx_t n, const uint8_t *x) override; + + void reset() override; + + void search(idx_t n, const uint8_t *x, idx_t k, + int32_t *distances, idx_t *labels, + ConcurrentBitsetPtr bitset = nullptr) const override; + + void range_search(idx_t n, const uint8_t *x, int radius, + RangeSearchResult *result, + ConcurrentBitsetPtr bitset = nullptr) const override; + + void reconstruct(idx_t key, uint8_t *recons) const override; + + /** Remove some ids. Note that because of the indexing structure, + * the semantics of this operation are different from the usual ones: + * the new ids are shifted. */ + size_t remove_ids(const IDSelector& sel) override; + + IndexBinaryFlat() {} +}; + + +} // namespace faiss + +#endif // INDEX_BINARY_FLAT_H diff --git a/core/src/index/thirdparty/faiss/IndexBinaryFromFloat.cpp b/core/src/index/thirdparty/faiss/IndexBinaryFromFloat.cpp new file mode 100644 index 0000000000..67bd9a28dc --- /dev/null +++ b/core/src/index/thirdparty/faiss/IndexBinaryFromFloat.cpp @@ -0,0 +1,79 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +#include + +#include +#include + +namespace faiss { + + +IndexBinaryFromFloat::IndexBinaryFromFloat() {} + +IndexBinaryFromFloat::IndexBinaryFromFloat(Index *index) + : IndexBinary(index->d), + index(index), + own_fields(false) { + is_trained = index->is_trained; + ntotal = index->ntotal; +} + +IndexBinaryFromFloat::~IndexBinaryFromFloat() { + if (own_fields) { + delete index; + } +} + +void IndexBinaryFromFloat::add(idx_t n, const uint8_t *x) { + constexpr idx_t bs = 32768; + std::unique_ptr xf(new float[bs * d]); + + for (idx_t b = 0; b < n; b += bs) { + idx_t bn = std::min(bs, n - b); + binary_to_real(bn * d, x + b * code_size, xf.get()); + + index->add(bn, xf.get()); + } + ntotal = index->ntotal; +} + +void IndexBinaryFromFloat::reset() { + index->reset(); + ntotal = index->ntotal; +} + +void IndexBinaryFromFloat::search(idx_t n, const uint8_t *x, idx_t k, + int32_t *distances, idx_t *labels, + ConcurrentBitsetPtr bitset) const { + constexpr idx_t bs = 32768; + std::unique_ptr xf(new float[bs * d]); + std::unique_ptr df(new float[bs * k]); + + for (idx_t b = 0; b < n; b += bs) { + idx_t bn = std::min(bs, n - b); + binary_to_real(bn * d, x + b * code_size, xf.get()); + + index->search(bn, xf.get(), k, df.get(), labels + b * k); + for (int i = 0; i < bn * k; ++i) { + distances[b * k + i] = int32_t(std::round(df[i] / 4.0)); + } + } +} + +void IndexBinaryFromFloat::train(idx_t n, const uint8_t *x) { + std::unique_ptr xf(new float[n * d]); + binary_to_real(n * d, x, xf.get()); + + index->train(n, xf.get()); + is_trained = true; + ntotal = index->ntotal; +} + +} // namespace faiss diff --git a/core/src/index/thirdparty/faiss/IndexBinaryFromFloat.h b/core/src/index/thirdparty/faiss/IndexBinaryFromFloat.h new file mode 100644 index 0000000000..b630c832e4 --- /dev/null +++ b/core/src/index/thirdparty/faiss/IndexBinaryFromFloat.h @@ -0,0 +1,53 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +#ifndef FAISS_INDEX_BINARY_FROM_FLOAT_H +#define FAISS_INDEX_BINARY_FROM_FLOAT_H + +#include + + +namespace faiss { + + +struct Index; + +/** IndexBinary backed by a float Index. + * + * Supports adding vertices and searching them. + * + * All queries are symmetric because there is no distinction between codes and + * vectors. + */ +struct IndexBinaryFromFloat : IndexBinary { + Index *index = nullptr; + + bool own_fields = false; ///< Whether object owns the index pointer. + + IndexBinaryFromFloat(); + + explicit IndexBinaryFromFloat(Index *index); + + ~IndexBinaryFromFloat(); + + void add(idx_t n, const uint8_t *x) override; + + void reset() override; + + void search(idx_t n, const uint8_t *x, idx_t k, + int32_t *distances, idx_t *labels, + ConcurrentBitsetPtr bitset = nullptr) const override; + + void train(idx_t n, const uint8_t *x) override; +}; + + +} // namespace faiss + +#endif // FAISS_INDEX_BINARY_FROM_FLOAT_H diff --git a/core/src/index/thirdparty/faiss/IndexBinaryHNSW.cpp b/core/src/index/thirdparty/faiss/IndexBinaryHNSW.cpp new file mode 100644 index 0000000000..87234e4aac --- /dev/null +++ b/core/src/index/thirdparty/faiss/IndexBinaryHNSW.cpp @@ -0,0 +1,326 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +#include + + +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +namespace faiss { + + +/************************************************************** + * add / search blocks of descriptors + **************************************************************/ + +namespace { + + +void hnsw_add_vertices(IndexBinaryHNSW& index_hnsw, + size_t n0, + size_t n, const uint8_t *x, + bool verbose, + bool preset_levels = false) { + HNSW& hnsw = index_hnsw.hnsw; + size_t ntotal = n0 + n; + double t0 = getmillisecs(); + if (verbose) { + printf("hnsw_add_vertices: adding %ld elements on top of %ld " + "(preset_levels=%d)\n", + n, n0, int(preset_levels)); + } + + int max_level = hnsw.prepare_level_tab(n, preset_levels); + + if (verbose) { + printf(" max_level = %d\n", max_level); + } + + std::vector locks(ntotal); + for(int i = 0; i < ntotal; i++) { + omp_init_lock(&locks[i]); + } + + // add vectors from highest to lowest level + std::vector hist; + std::vector order(n); + + { // make buckets with vectors of the same level + + // build histogram + for (int i = 0; i < n; i++) { + HNSW::storage_idx_t pt_id = i + n0; + int pt_level = hnsw.levels[pt_id] - 1; + while (pt_level >= hist.size()) { + hist.push_back(0); + } + hist[pt_level] ++; + } + + // accumulate + std::vector offsets(hist.size() + 1, 0); + for (int i = 0; i < hist.size() - 1; i++) { + offsets[i + 1] = offsets[i] + hist[i]; + } + + // bucket sort + for (int i = 0; i < n; i++) { + HNSW::storage_idx_t pt_id = i + n0; + int pt_level = hnsw.levels[pt_id] - 1; + order[offsets[pt_level]++] = pt_id; + } + } + + { // perform add + RandomGenerator rng2(789); + + int i1 = n; + + for (int pt_level = hist.size() - 1; pt_level >= 0; pt_level--) { + int i0 = i1 - hist[pt_level]; + + if (verbose) { + printf("Adding %d elements at level %d\n", + i1 - i0, pt_level); + } + + // random permutation to get rid of dataset order bias + for (int j = i0; j < i1; j++) { + std::swap(order[j], order[j + rng2.rand_int(i1 - j)]); + } + +#pragma omp parallel + { + VisitedTable vt (ntotal); + + std::unique_ptr dis( + index_hnsw.get_distance_computer() + ); + int prev_display = verbose && omp_get_thread_num() == 0 ? 0 : -1; + +#pragma omp for schedule(dynamic) + for (int i = i0; i < i1; i++) { + HNSW::storage_idx_t pt_id = order[i]; + dis->set_query((float *)(x + (pt_id - n0) * index_hnsw.code_size)); + + hnsw.add_with_locks(*dis, pt_level, pt_id, locks, vt); + + if (prev_display >= 0 && i - i0 > prev_display + 10000) { + prev_display = i - i0; + printf(" %d / %d\r", i - i0, i1 - i0); + fflush(stdout); + } + } + } + i1 = i0; + } + FAISS_ASSERT(i1 == 0); + } + if (verbose) { + printf("Done in %.3f ms\n", getmillisecs() - t0); + } + + for(int i = 0; i < ntotal; i++) + omp_destroy_lock(&locks[i]); +} + + +} // anonymous namespace + + +/************************************************************** + * IndexBinaryHNSW implementation + **************************************************************/ + +IndexBinaryHNSW::IndexBinaryHNSW() +{ + is_trained = true; +} + +IndexBinaryHNSW::IndexBinaryHNSW(int d, int M) + : IndexBinary(d), + hnsw(M), + own_fields(true), + storage(new IndexBinaryFlat(d)) +{ + is_trained = true; +} + +IndexBinaryHNSW::IndexBinaryHNSW(IndexBinary *storage, int M) + : IndexBinary(storage->d), + hnsw(M), + own_fields(false), + storage(storage) +{ + is_trained = true; +} + +IndexBinaryHNSW::~IndexBinaryHNSW() { + if (own_fields) { + delete storage; + } +} + +void IndexBinaryHNSW::train(idx_t n, const uint8_t *x) +{ + // hnsw structure does not require training + storage->train(n, x); + is_trained = true; +} + +void IndexBinaryHNSW::search(idx_t n, const uint8_t *x, idx_t k, + int32_t *distances, idx_t *labels, + ConcurrentBitsetPtr bitset) const +{ +#pragma omp parallel + { + VisitedTable vt(ntotal); + std::unique_ptr dis(get_distance_computer()); + +#pragma omp for + for(idx_t i = 0; i < n; i++) { + idx_t *idxi = labels + i * k; + float *simi = (float *)(distances + i * k); + + dis->set_query((float *)(x + i * code_size)); + + maxheap_heapify(k, simi, idxi); + hnsw.search(*dis, k, idxi, simi, vt); + maxheap_reorder(k, simi, idxi); + } + } + +#pragma omp parallel for + for (int i = 0; i < n * k; ++i) { + distances[i] = std::round(((float *)distances)[i]); + } +} + + +void IndexBinaryHNSW::add(idx_t n, const uint8_t *x) +{ + FAISS_THROW_IF_NOT(is_trained); + int n0 = ntotal; + storage->add(n, x); + ntotal = storage->ntotal; + + hnsw_add_vertices(*this, n0, n, x, verbose, + hnsw.levels.size() == ntotal); +} + +void IndexBinaryHNSW::reset() +{ + hnsw.reset(); + storage->reset(); + ntotal = 0; +} + +void IndexBinaryHNSW::reconstruct(idx_t key, uint8_t *recons) const +{ + storage->reconstruct(key, recons); +} + + +namespace { + + +template +struct FlatHammingDis : DistanceComputer { + const int code_size; + const uint8_t *b; + size_t ndis; + HammingComputer hc; + + float operator () (idx_t i) override { + ndis++; + return hc.hamming(b + i * code_size); + } + + float symmetric_dis(idx_t i, idx_t j) override { + return HammingComputerDefault(b + j * code_size, code_size) + .hamming(b + i * code_size); + } + + + explicit FlatHammingDis(const IndexBinaryFlat& storage) + : code_size(storage.code_size), + b(storage.xb.data()), + ndis(0), + hc() {} + + // NOTE: Pointers are cast from float in order to reuse the floating-point + // DistanceComputer. + void set_query(const float *x) override { + hc.set((uint8_t *)x, code_size); + } + + ~FlatHammingDis() override { +#pragma omp critical + { + hnsw_stats.ndis += ndis; + } + } +}; + + +} // namespace + + +DistanceComputer *IndexBinaryHNSW::get_distance_computer() const { + IndexBinaryFlat *flat_storage = dynamic_cast(storage); + + FAISS_ASSERT(flat_storage != nullptr); + + switch(code_size) { + case 4: + return new FlatHammingDis(*flat_storage); + case 8: + return new FlatHammingDis(*flat_storage); + case 16: + return new FlatHammingDis(*flat_storage); + case 20: + return new FlatHammingDis(*flat_storage); + case 32: + return new FlatHammingDis(*flat_storage); + case 64: + return new FlatHammingDis(*flat_storage); + default: + if (code_size % 8 == 0) { + return new FlatHammingDis(*flat_storage); + } else if (code_size % 4 == 0) { + return new FlatHammingDis(*flat_storage); + } + } + + return new FlatHammingDis(*flat_storage); +} + + +} // namespace faiss diff --git a/core/src/index/thirdparty/faiss/IndexBinaryHNSW.h b/core/src/index/thirdparty/faiss/IndexBinaryHNSW.h new file mode 100644 index 0000000000..be10fee692 --- /dev/null +++ b/core/src/index/thirdparty/faiss/IndexBinaryHNSW.h @@ -0,0 +1,57 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +#pragma once + +#include +#include +#include + + +namespace faiss { + + +/** The HNSW index is a normal random-access index with a HNSW + * link structure built on top */ + +struct IndexBinaryHNSW : IndexBinary { + typedef HNSW::storage_idx_t storage_idx_t; + + // the link strcuture + HNSW hnsw; + + // the sequential storage + bool own_fields; + IndexBinary *storage; + + explicit IndexBinaryHNSW(); + explicit IndexBinaryHNSW(int d, int M = 32); + explicit IndexBinaryHNSW(IndexBinary *storage, int M = 32); + + ~IndexBinaryHNSW() override; + + DistanceComputer *get_distance_computer() const; + + void add(idx_t n, const uint8_t *x) override; + + /// Trains the storage if needed + void train(idx_t n, const uint8_t* x) override; + + /// entry point for search + void search(idx_t n, const uint8_t *x, idx_t k, + int32_t *distances, idx_t *labels, + ConcurrentBitsetPtr bitset = nullptr) const override; + + void reconstruct(idx_t key, uint8_t* recons) const override; + + void reset() override; +}; + + +} // namespace faiss diff --git a/core/src/index/thirdparty/faiss/IndexBinaryHash.cpp b/core/src/index/thirdparty/faiss/IndexBinaryHash.cpp new file mode 100644 index 0000000000..008da09455 --- /dev/null +++ b/core/src/index/thirdparty/faiss/IndexBinaryHash.cpp @@ -0,0 +1,496 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// Copyright 2004-present Facebook. All Rights Reserved +// -*- c++ -*- + +#include + +#include +#include + +#include +#include + +#include +#include + + +namespace faiss { + +void IndexBinaryHash::InvertedList::add ( + idx_t id, size_t code_size, const uint8_t *code) +{ + ids.push_back(id); + vecs.insert(vecs.end(), code, code + code_size); +} + +IndexBinaryHash::IndexBinaryHash(int d, int b): + IndexBinary(d), b(b), nflip(0) +{ + is_trained = true; +} + +IndexBinaryHash::IndexBinaryHash(): b(0), nflip(0) +{ + is_trained = true; +} + +void IndexBinaryHash::reset() +{ + invlists.clear(); + ntotal = 0; +} + + +void IndexBinaryHash::add(idx_t n, const uint8_t *x) +{ + add_with_ids(n, x, nullptr); +} + +void IndexBinaryHash::add_with_ids(idx_t n, const uint8_t *x, const idx_t *xids) +{ + uint64_t mask = ((uint64_t)1 << b) - 1; + // simplistic add function. Cannot really be parallelized. + + for (idx_t i = 0; i < n; i++) { + idx_t id = xids ? xids[i] : ntotal + i; + const uint8_t * xi = x + i * code_size; + idx_t hash = *((uint64_t*)xi) & mask; + invlists[hash].add(id, code_size, xi); + } + ntotal += n; +} + +namespace { + + +/** Enumerate all bit vectors of size nbit with up to maxflip 1s + * test in P127257851 P127258235 + */ +struct FlipEnumerator { + int nbit, nflip, maxflip; + uint64_t mask, x; + + FlipEnumerator (int nbit, int maxflip): nbit(nbit), maxflip(maxflip) { + nflip = 0; + mask = 0; + x = 0; + } + + bool next() { + if (x == mask) { + if (nflip == maxflip) { + return false; + } + // increase Hamming radius + nflip++; + mask = (((uint64_t)1 << nflip) - 1); + x = mask << (nbit - nflip); + return true; + } + + int i = __builtin_ctzll(x); + + if (i > 0) { + x ^= (uint64_t)3 << (i - 1); + } else { + // nb of LSB 1s + int n1 = __builtin_ctzll(~x); + // clear them + x &= ((uint64_t)(-1) << n1); + int n2 = __builtin_ctzll(x); + x ^= (((uint64_t)1 << (n1 + 2)) - 1) << (n2 - n1 - 1); + } + return true; + } + +}; + +using idx_t = Index::idx_t; + + +struct RangeSearchResults { + int radius; + RangeQueryResult &qres; + + inline void add (float dis, idx_t id) { + if (dis < radius) { + qres.add (dis, id); + } + } + +}; + +struct KnnSearchResults { + // heap params + idx_t k; + int32_t * heap_sim; + idx_t * heap_ids; + + using C = CMax; + + inline void add (float dis, idx_t id) { + if (dis < heap_sim[0]) { + heap_pop (k, heap_sim, heap_ids); + heap_push (k, heap_sim, heap_ids, dis, id); + } + } + +}; + +template +void +search_single_query_template(const IndexBinaryHash & index, const uint8_t *q, + SearchResults &res, + size_t &n0, size_t &nlist, size_t &ndis) +{ + size_t code_size = index.code_size; + uint64_t mask = ((uint64_t)1 << index.b) - 1; + uint64_t qhash = *((uint64_t*)q) & mask; + HammingComputer hc (q, code_size); + FlipEnumerator fe(index.b, index.nflip); + + // loop over neighbors that are at most at nflip bits + do { + uint64_t hash = qhash ^ fe.x; + auto it = index.invlists.find (hash); + + if (it == index.invlists.end()) { + continue; + } + + const IndexBinaryHash::InvertedList &il = it->second; + + size_t nv = il.ids.size(); + + if (nv == 0) { + n0++; + } else { + const uint8_t *codes = il.vecs.data(); + for (size_t i = 0; i < nv; i++) { + int dis = hc.hamming (codes); + res.add(dis, il.ids[i]); + codes += code_size; + } + ndis += nv; + nlist++; + } + } while(fe.next()); +} + +template +void +search_single_query(const IndexBinaryHash & index, const uint8_t *q, + SearchResults &res, + size_t &n0, size_t &nlist, size_t &ndis) +{ +#define HC(name) search_single_query_template(index, q, res, n0, nlist, ndis); + switch(index.code_size) { + case 4: HC(HammingComputer4); break; + case 8: HC(HammingComputer8); break; + case 16: HC(HammingComputer16); break; + case 20: HC(HammingComputer20); break; + case 32: HC(HammingComputer32); break; + default: + if (index.code_size % 8 == 0) { + HC(HammingComputerM8); + } else { + HC(HammingComputerDefault); + } + } +#undef HC +} + + +} // anonymous namespace + + + +void IndexBinaryHash::range_search(idx_t n, const uint8_t *x, int radius, + RangeSearchResult *result, + ConcurrentBitsetPtr bitset) const +{ + + size_t nlist = 0, ndis = 0, n0 = 0; + +#pragma omp parallel if(n > 100) reduction(+: ndis, n0, nlist) + { + RangeSearchPartialResult pres (result); + +#pragma omp for + for (size_t i = 0; i < n; i++) { // loop queries + RangeQueryResult & qres = pres.new_result (i); + RangeSearchResults res = {radius, qres}; + const uint8_t *q = x + i * code_size; + + search_single_query (*this, q, res, n0, nlist, ndis); + + } + pres.finalize (); + } + indexBinaryHash_stats.nq += n; + indexBinaryHash_stats.n0 += n0; + indexBinaryHash_stats.nlist += nlist; + indexBinaryHash_stats.ndis += ndis; +} + +void IndexBinaryHash::search(idx_t n, const uint8_t *x, idx_t k, + int32_t *distances, idx_t *labels, + ConcurrentBitsetPtr bitset) const +{ + + using HeapForL2 = CMax; + size_t nlist = 0, ndis = 0, n0 = 0; + +#pragma omp parallel for if(n > 100) reduction(+: nlist, ndis, n0) + for (size_t i = 0; i < n; i++) { + int32_t * simi = distances + k * i; + idx_t * idxi = labels + k * i; + + heap_heapify (k, simi, idxi); + KnnSearchResults res = {k, simi, idxi}; + const uint8_t *q = x + i * code_size; + + search_single_query (*this, q, res, n0, nlist, ndis); + + } + indexBinaryHash_stats.nq += n; + indexBinaryHash_stats.n0 += n0; + indexBinaryHash_stats.nlist += nlist; + indexBinaryHash_stats.ndis += ndis; +} + +size_t IndexBinaryHash::hashtable_size() const +{ + return invlists.size(); +} + + +void IndexBinaryHash::display() const +{ + for (auto it = invlists.begin(); it != invlists.end(); ++it) { + printf("%ld: [", it->first); + const std::vector & v = it->second.ids; + for (auto x: v) { + printf("%ld ", 0 + x); + } + printf("]\n"); + + } +} + + +void IndexBinaryHashStats::reset() +{ + memset ((void*)this, 0, sizeof (*this)); +} + +IndexBinaryHashStats indexBinaryHash_stats; + +/******************************************************* + * IndexBinaryMultiHash implementation + ******************************************************/ + + +IndexBinaryMultiHash::IndexBinaryMultiHash(int d, int nhash, int b): + IndexBinary(d), + storage(new IndexBinaryFlat(d)), own_fields(true), + maps(nhash), nhash(nhash), b(b), nflip(0) +{ + FAISS_THROW_IF_NOT(nhash * b <= d); +} + +IndexBinaryMultiHash::IndexBinaryMultiHash(): + storage(nullptr), own_fields(true), + nhash(0), b(0), nflip(0) +{} + +IndexBinaryMultiHash::~IndexBinaryMultiHash() +{ + if (own_fields) { + delete storage; + } +} + + +void IndexBinaryMultiHash::reset() +{ + storage->reset(); + ntotal = 0; + for(auto map: maps) { + map.clear(); + } +} + +void IndexBinaryMultiHash::add(idx_t n, const uint8_t *x) +{ + storage->add(n, x); + // populate maps + uint64_t mask = ((uint64_t)1 << b) - 1; + + for(idx_t i = 0; i < n; i++) { + const uint8_t *xi = x + i * code_size; + int ho = 0; + for(int h = 0; h < nhash; h++) { + uint64_t hash = *(uint64_t*)(xi + (ho >> 3)) >> (ho & 7); + hash &= mask; + maps[h][hash].push_back(i + ntotal); + ho += b; + } + } + ntotal += n; +} + + +namespace { + +template +static +void verify_shortlist( + const IndexBinaryFlat & index, + const uint8_t * q, + const std::unordered_set & shortlist, + SearchResults &res) +{ + size_t code_size = index.code_size; + size_t nlist = 0, ndis = 0, n0 = 0; + + HammingComputer hc (q, code_size); + const uint8_t *codes = index.xb.data(); + + for (auto i: shortlist) { + int dis = hc.hamming (codes + i * code_size); + res.add(dis, i); + } +} + +template +void +search_1_query_multihash(const IndexBinaryMultiHash & index, const uint8_t *xi, + SearchResults &res, + size_t &n0, size_t &nlist, size_t &ndis) +{ + + std::unordered_set shortlist; + int b = index.b; + uint64_t mask = ((uint64_t)1 << b) - 1; + + int ho = 0; + for(int h = 0; h < index.nhash; h++) { + uint64_t qhash = *(uint64_t*)(xi + (ho >> 3)) >> (ho & 7); + qhash &= mask; + const IndexBinaryMultiHash::Map & map = index.maps[h]; + + FlipEnumerator fe(index.b, index.nflip); + // loop over neighbors that are at most at nflip bits + do { + uint64_t hash = qhash ^ fe.x; + auto it = map.find (hash); + + if (it != map.end()) { + const std::vector & v = it->second; + for (auto i: v) { + shortlist.insert(i); + } + nlist++; + } else { + n0++; + } + } while(fe.next()); + + ho += b; + } + ndis += shortlist.size(); + + // verify shortlist + +#define HC(name) verify_shortlist (*index.storage, xi, shortlist, res) + switch(index.code_size) { + case 4: HC(HammingComputer4); break; + case 8: HC(HammingComputer8); break; + case 16: HC(HammingComputer16); break; + case 20: HC(HammingComputer20); break; + case 32: HC(HammingComputer32); break; + default: + if (index.code_size % 8 == 0) { + HC(HammingComputerM8); + } else { + HC(HammingComputerDefault); + } + } +#undef HC +} + +} // anonymous namespace + +void IndexBinaryMultiHash::range_search(idx_t n, const uint8_t *x, int radius, + RangeSearchResult *result, + ConcurrentBitsetPtr bitset) const +{ + + size_t nlist = 0, ndis = 0, n0 = 0; + +#pragma omp parallel if(n > 100) reduction(+: ndis, n0, nlist) + { + RangeSearchPartialResult pres (result); + +#pragma omp for + for (size_t i = 0; i < n; i++) { // loop queries + RangeQueryResult & qres = pres.new_result (i); + RangeSearchResults res = {radius, qres}; + const uint8_t *q = x + i * code_size; + + search_1_query_multihash (*this, q, res, n0, nlist, ndis); + + } + pres.finalize (); + } + indexBinaryHash_stats.nq += n; + indexBinaryHash_stats.n0 += n0; + indexBinaryHash_stats.nlist += nlist; + indexBinaryHash_stats.ndis += ndis; +} + +void IndexBinaryMultiHash::search(idx_t n, const uint8_t *x, idx_t k, + int32_t *distances, idx_t *labels, + ConcurrentBitsetPtr bitset) const +{ + + using HeapForL2 = CMax; + size_t nlist = 0, ndis = 0, n0 = 0; + +#pragma omp parallel for if(n > 100) reduction(+: nlist, ndis, n0) + for (size_t i = 0; i < n; i++) { + int32_t * simi = distances + k * i; + idx_t * idxi = labels + k * i; + + heap_heapify (k, simi, idxi); + KnnSearchResults res = {k, simi, idxi}; + const uint8_t *q = x + i * code_size; + + search_1_query_multihash (*this, q, res, n0, nlist, ndis); + + } + indexBinaryHash_stats.nq += n; + indexBinaryHash_stats.n0 += n0; + indexBinaryHash_stats.nlist += nlist; + indexBinaryHash_stats.ndis += ndis; +} + +size_t IndexBinaryMultiHash::hashtable_size() const +{ + size_t tot = 0; + for (auto map: maps) { + tot += map.size(); + } + + return tot; +} + + +} diff --git a/core/src/index/thirdparty/faiss/IndexBinaryHash.h b/core/src/index/thirdparty/faiss/IndexBinaryHash.h new file mode 100644 index 0000000000..5dbcad626d --- /dev/null +++ b/core/src/index/thirdparty/faiss/IndexBinaryHash.h @@ -0,0 +1,120 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +#ifndef FAISS_BINARY_HASH_H +#define FAISS_BINARY_HASH_H + + + +#include +#include + +#include +#include +#include + + +namespace faiss { + +struct RangeSearchResult; + + +/** just uses the b first bits as a hash value */ +struct IndexBinaryHash : IndexBinary { + + struct InvertedList { + std::vector ids; + std::vector vecs; + + void add (idx_t id, size_t code_size, const uint8_t *code); + }; + + using InvertedListMap = std::unordered_map; + InvertedListMap invlists; + + int b, nflip; + + IndexBinaryHash(int d, int b); + + IndexBinaryHash(); + + void reset() override; + + void add(idx_t n, const uint8_t *x) override; + + void add_with_ids(idx_t n, const uint8_t *x, const idx_t *xids) override; + + void range_search(idx_t n, const uint8_t *x, int radius, + RangeSearchResult *result, + ConcurrentBitsetPtr bitset = nullptr) const override; + + void search(idx_t n, const uint8_t *x, idx_t k, + int32_t *distances, idx_t *labels, + ConcurrentBitsetPtr bitset = nullptr) const override; + + void display() const; + size_t hashtable_size() const; + +}; + +struct IndexBinaryHashStats { + size_t nq; // nb of queries run + size_t n0; // nb of empty lists + size_t nlist; // nb of non-empty inverted lists scanned + size_t ndis; // nb of distancs computed + + IndexBinaryHashStats () {reset (); } + void reset (); +}; + +extern IndexBinaryHashStats indexBinaryHash_stats; + + +/** just uses the b first bits as a hash value */ +struct IndexBinaryMultiHash: IndexBinary { + + // where the vectors are actually stored + IndexBinaryFlat *storage; + bool own_fields; + + // maps hash values to the ids that hash to them + using Map = std::unordered_map >; + + // the different hashes, size nhash + std::vector maps; + + int nhash; ///< nb of hash maps + int b; ///< nb bits per hash map + int nflip; ///< nb bit flips to use at search time + + IndexBinaryMultiHash(int d, int nhash, int b); + + IndexBinaryMultiHash(); + + ~IndexBinaryMultiHash(); + + void reset() override; + + void add(idx_t n, const uint8_t *x) override; + + void range_search(idx_t n, const uint8_t *x, int radius, + RangeSearchResult *result, + ConcurrentBitsetPtr bitset = nullptr) const override; + + void search(idx_t n, const uint8_t *x, idx_t k, + int32_t *distances, idx_t *labels, + ConcurrentBitsetPtr bitset = nullptr) const override; + + size_t hashtable_size() const; + +}; + +} + +#endif diff --git a/core/src/index/thirdparty/faiss/IndexBinaryIVF.cpp b/core/src/index/thirdparty/faiss/IndexBinaryIVF.cpp new file mode 100644 index 0000000000..775cf9d447 --- /dev/null +++ b/core/src/index/thirdparty/faiss/IndexBinaryIVF.cpp @@ -0,0 +1,976 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// Copyright 2004-present Facebook. All Rights Reserved +// -*- c++ -*- + +#include +#include +#include + +#include +#include +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace faiss { + +IndexBinaryIVF::IndexBinaryIVF(IndexBinary *quantizer, size_t d, size_t nlist) + : IndexBinary(d), + invlists(new ArrayInvertedLists(nlist, code_size)), + own_invlists(true), + nprobe(1), + max_codes(0), + quantizer(quantizer), + nlist(nlist), + own_fields(false), + clustering_index(nullptr) +{ + FAISS_THROW_IF_NOT (d == quantizer->d); + is_trained = quantizer->is_trained && (quantizer->ntotal == nlist); + + cp.niter = 10; +} + +IndexBinaryIVF::IndexBinaryIVF(IndexBinary *quantizer, size_t d, size_t nlist, MetricType metric) + : IndexBinary(d, metric), + invlists(new ArrayInvertedLists(nlist, code_size)), + own_invlists(true), + nprobe(1), + max_codes(0), + quantizer(quantizer), + nlist(nlist), + own_fields(false), + clustering_index(nullptr) +{ + FAISS_THROW_IF_NOT (d == quantizer->d); + is_trained = quantizer->is_trained && (quantizer->ntotal == nlist); + + cp.niter = 10; +} + +IndexBinaryIVF::IndexBinaryIVF() + : invlists(nullptr), + own_invlists(false), + nprobe(1), + max_codes(0), + quantizer(nullptr), + nlist(0), + own_fields(false), + clustering_index(nullptr) +{} + +void IndexBinaryIVF::add(idx_t n, const uint8_t *x) { + add_with_ids(n, x, nullptr); +} + +void IndexBinaryIVF::add_with_ids(idx_t n, const uint8_t *x, const idx_t *xids) { + add_core(n, x, xids, nullptr); +} + +void IndexBinaryIVF::add_core(idx_t n, const uint8_t *x, const idx_t *xids, + const idx_t *precomputed_idx) { + FAISS_THROW_IF_NOT(is_trained); + assert(invlists); + direct_map.check_can_add (xids); + + const idx_t * idx; + + std::unique_ptr scoped_idx; + + if (precomputed_idx) { + idx = precomputed_idx; + } else { + scoped_idx.reset(new idx_t[n]); + quantizer->assign(n, x, scoped_idx.get()); + idx = scoped_idx.get(); + } + + long n_add = 0; + for (size_t i = 0; i < n; i++) { + idx_t id = xids ? xids[i] : ntotal + i; + idx_t list_no = idx[i]; + + if (list_no < 0) { + direct_map.add_single_id (id, -1, 0); + } else { + const uint8_t *xi = x + i * code_size; + size_t offset = invlists->add_entry(list_no, id, xi); + + direct_map.add_single_id (id, list_no, offset); + } + + n_add++; + } + if (verbose) { + printf("IndexBinaryIVF::add_with_ids: added %ld / %ld vectors\n", + n_add, n); + } + ntotal += n_add; +} + +void IndexBinaryIVF::make_direct_map (bool b) +{ + if (b) { + direct_map.set_type (DirectMap::Array, invlists, ntotal); + } else { + direct_map.set_type (DirectMap::NoMap, invlists, ntotal); + } +} + +void IndexBinaryIVF::set_direct_map_type (DirectMap::Type type) +{ + direct_map.set_type (type, invlists, ntotal); +} + + +void IndexBinaryIVF::search(idx_t n, const uint8_t *x, idx_t k, + int32_t *distances, idx_t *labels, + ConcurrentBitsetPtr bitset) const { + std::unique_ptr idx(new idx_t[n * nprobe]); + std::unique_ptr coarse_dis(new int32_t[n * nprobe]); + + double t0 = getmillisecs(); + quantizer->search(n, x, nprobe, coarse_dis.get(), idx.get()); + indexIVF_stats.quantization_time += getmillisecs() - t0; + + t0 = getmillisecs(); + invlists->prefetch_lists(idx.get(), n * nprobe); + + search_preassigned(n, x, k, idx.get(), coarse_dis.get(), + distances, labels, false, nullptr, bitset); + indexIVF_stats.search_time += getmillisecs() - t0; +} + +#if 0 +void IndexBinaryIVF::get_vector_by_id(idx_t n, const idx_t *xid, uint8_t *x, ConcurrentBitsetPtr bitset) { + make_direct_map(true); + + /* only get vector by 1 id */ + FAISS_ASSERT(n == 1); + if (!bitset || !bitset->test(xid[0])) { + reconstruct(xid[0], x + 0 * d); + } else { + memset(x, UINT8_MAX, d * sizeof(uint8_t)); + } +} + +void IndexBinaryIVF::search_by_id (idx_t n, const idx_t *xid, idx_t k, int32_t *distances, idx_t *labels, + ConcurrentBitsetPtr bitset) { + make_direct_map(true); + + auto x = new uint8_t[n * d]; + for (idx_t i = 0; i < n; ++i) { + reconstruct(xid[i], x + i * d); + } + + search(n, x, k, distances, labels, bitset); + delete []x; +} +#endif + +void IndexBinaryIVF::reconstruct(idx_t key, uint8_t *recons) const { + idx_t lo = direct_map.get (key); + reconstruct_from_offset (lo_listno(lo), lo_offset(lo), recons); +} + +void IndexBinaryIVF::reconstruct_n(idx_t i0, idx_t ni, uint8_t *recons) const { + FAISS_THROW_IF_NOT(ni == 0 || (i0 >= 0 && i0 + ni <= ntotal)); + + for (idx_t list_no = 0; list_no < nlist; list_no++) { + size_t list_size = invlists->list_size(list_no); + const Index::idx_t *idlist = invlists->get_ids(list_no); + + for (idx_t offset = 0; offset < list_size; offset++) { + idx_t id = idlist[offset]; + if (!(id >= i0 && id < i0 + ni)) { + continue; + } + + uint8_t *reconstructed = recons + (id - i0) * d; + reconstruct_from_offset(list_no, offset, reconstructed); + } + } +} + +void IndexBinaryIVF::search_and_reconstruct(idx_t n, const uint8_t *x, idx_t k, + int32_t *distances, idx_t *labels, + uint8_t *recons) const { + std::unique_ptr idx(new idx_t[n * nprobe]); + std::unique_ptr coarse_dis(new int32_t[n * nprobe]); + + quantizer->search(n, x, nprobe, coarse_dis.get(), idx.get()); + + invlists->prefetch_lists(idx.get(), n * nprobe); + + // search_preassigned() with `store_pairs` enabled to obtain the list_no + // and offset into `codes` for reconstruction + search_preassigned(n, x, k, idx.get(), coarse_dis.get(), + distances, labels, /* store_pairs */true); + for (idx_t i = 0; i < n; ++i) { + for (idx_t j = 0; j < k; ++j) { + idx_t ij = i * k + j; + idx_t key = labels[ij]; + uint8_t *reconstructed = recons + ij * d; + if (key < 0) { + // Fill with NaNs + memset(reconstructed, -1, sizeof(*reconstructed) * d); + } else { + int list_no = key >> 32; + int offset = key & 0xffffffff; + + // Update label to the actual id + labels[ij] = invlists->get_single_id(list_no, offset); + + reconstruct_from_offset(list_no, offset, reconstructed); + } + } + } +} + +void IndexBinaryIVF::reconstruct_from_offset(idx_t list_no, idx_t offset, + uint8_t *recons) const { + memcpy(recons, invlists->get_single_code(list_no, offset), code_size); +} + +void IndexBinaryIVF::reset() { + direct_map.clear(); + invlists->reset(); + ntotal = 0; +} + +size_t IndexBinaryIVF::remove_ids(const IDSelector& sel) { + size_t nremove = direct_map.remove_ids (sel, invlists); + ntotal -= nremove; + return nremove; +} + +void IndexBinaryIVF::train(idx_t n, const uint8_t *x) { + if (verbose) { + printf("Training quantizer\n"); + } + + if (quantizer->is_trained && (quantizer->ntotal == nlist)) { + if (verbose) { + printf("IVF quantizer does not need training.\n"); + } + } else { + if (verbose) { + printf("Training quantizer on %ld vectors in %dD\n", n, d); + } + + Clustering clus(d, nlist, cp); + quantizer->reset(); + + IndexFlat index_tmp; + + if (metric_type == METRIC_Jaccard || metric_type == METRIC_Tanimoto) { + index_tmp = IndexFlat(d, METRIC_Jaccard); + } else if (metric_type == METRIC_Substructure || metric_type == METRIC_Superstructure) { + // unsupported + FAISS_THROW_MSG("IVF not to support Substructure and Superstructure."); + } else { + index_tmp = IndexFlat(d, METRIC_L2); + } + + if (clustering_index && verbose) { + printf("using clustering_index of dimension %d to do the clustering\n", + clustering_index->d); + } + + // LSH codec that is able to convert the binary vectors to floats. + IndexLSH codec(d, d, false, false); + + clus.train_encoded (n, x, &codec, clustering_index ? *clustering_index : index_tmp); + + // convert clusters to binary + std::unique_ptr x_b(new uint8_t[clus.k * code_size]); + real_to_binary(d * clus.k, clus.centroids.data(), x_b.get()); + + quantizer->add(clus.k, x_b.get()); + quantizer->is_trained = true; + } + + is_trained = true; +} + +void IndexBinaryIVF::merge_from(IndexBinaryIVF &other, idx_t add_id) { + // minimal sanity checks + FAISS_THROW_IF_NOT(other.d == d); + FAISS_THROW_IF_NOT(other.nlist == nlist); + FAISS_THROW_IF_NOT(other.code_size == code_size); + FAISS_THROW_IF_NOT_MSG(direct_map.no() && other.direct_map.no(), + "direct map copy not implemented"); + FAISS_THROW_IF_NOT_MSG(typeid (*this) == typeid (other), + "can only merge indexes of the same type"); + + invlists->merge_from (other.invlists, add_id); + + ntotal += other.ntotal; + other.ntotal = 0; +} + +void IndexBinaryIVF::replace_invlists(InvertedLists *il, bool own) { + FAISS_THROW_IF_NOT(il->nlist == nlist && + il->code_size == code_size); + if (own_invlists) { + delete invlists; + } + invlists = il; + own_invlists = own; +} + + +namespace { + +using idx_t = Index::idx_t; + + +template +struct IVFBinaryScannerL2: BinaryInvertedListScanner { + + HammingComputer hc; + size_t code_size; + bool store_pairs; + + IVFBinaryScannerL2 (size_t code_size, bool store_pairs): + code_size (code_size), store_pairs(store_pairs) + {} + + void set_query (const uint8_t *query_vector) override { + hc.set (query_vector, code_size); + } + + idx_t list_no; + void set_list (idx_t list_no, uint8_t /* coarse_dis */) override { + this->list_no = list_no; + } + + uint32_t distance_to_code (const uint8_t *code) const override { + return hc.hamming (code); + } + + size_t scan_codes (size_t n, + const uint8_t *codes, + const idx_t *ids, + int32_t *simi, idx_t *idxi, + size_t k, + ConcurrentBitsetPtr bitset) const override + { + using C = CMax; + + size_t nup = 0; + for (size_t j = 0; j < n; j++) { + if (!bitset || !bitset->test(ids[j])) { + uint32_t dis = hc.hamming (codes); + if (dis < simi[0]) { + idx_t id = store_pairs ? (list_no << 32 | j) : ids[j]; + heap_swap_top (k, simi, idxi, dis, id); + nup++; + } + } + codes += code_size; + } + return nup; + } + + void scan_codes_range (size_t n, + const uint8_t *codes, + const idx_t *ids, + int radius, + RangeQueryResult &result) const + { + size_t nup = 0; + for (size_t j = 0; j < n; j++) { + uint32_t dis = hc.hamming (codes); + if (dis < radius) { + int64_t id = store_pairs ? lo_build (list_no, j) : ids[j]; + result.add (dis, id); + } + codes += code_size; + } + } +}; + +template +struct IVFBinaryScannerJaccard: BinaryInvertedListScanner { + DistanceComputer hc; + size_t code_size; + + IVFBinaryScannerJaccard (size_t code_size): code_size (code_size) + {} + + void set_query (const uint8_t *query_vector) override { + hc.set (query_vector, code_size); + } + + idx_t list_no; + void set_list (idx_t list_no, uint8_t /* coarse_dis */) override { + this->list_no = list_no; + } + + uint32_t distance_to_code (const uint8_t *code) const override { + return 0; + } + + size_t scan_codes (size_t n, + const uint8_t *codes, + const idx_t *ids, + int32_t *simi, idx_t *idxi, + size_t k, + ConcurrentBitsetPtr bitset = nullptr) const override + { + using C = CMax; + float* psimi = (float*)simi; + size_t nup = 0; + for (size_t j = 0; j < n; j++) { + if(!bitset || !bitset->test(ids[j])){ + float dis = hc.compute (codes); + + if (dis < psimi[0]) { + idx_t id = store_pairs ? (list_no << 32 | j) : ids[j]; + heap_swap_top (k, psimi, idxi, dis, id); + nup++; + } + } + codes += code_size; + } + return nup; + } + + void scan_codes_range (size_t n, + const uint8_t *codes, + const idx_t *ids, + int radius, + RangeQueryResult &result) const override { + // not yet + } +}; + +template +BinaryInvertedListScanner *select_IVFBinaryScannerL2 (size_t code_size) { +#define HC(name) return new IVFBinaryScannerL2 (code_size, store_pairs) + switch (code_size) { + case 4: HC(HammingComputer4); + case 8: HC(HammingComputer8); + case 16: HC(HammingComputer16); + case 20: HC(HammingComputer20); + case 32: HC(HammingComputer32); + case 64: HC(HammingComputer64); + default: + if (code_size % 8 == 0) { + HC(HammingComputerM8); + } else if (code_size % 4 == 0) { + HC(HammingComputerM4); + } else { + HC(HammingComputerDefault); + } + } +#undef HC +} + +template +BinaryInvertedListScanner *select_IVFBinaryScannerJaccard (size_t code_size) { + switch (code_size) { +#define HANDLE_CS(cs) \ + case cs: \ + return new IVFBinaryScannerJaccard (cs); + HANDLE_CS(16) + HANDLE_CS(32) + HANDLE_CS(64) + HANDLE_CS(128) + HANDLE_CS(256) + HANDLE_CS(512) +#undef HANDLE_CS + default: + return new IVFBinaryScannerJaccard(code_size); + } +} + +void search_knn_hamming_heap(const IndexBinaryIVF& ivf, + size_t n, + const uint8_t *x, + idx_t k, + const idx_t *keys, + const int32_t * coarse_dis, + int32_t *distances, idx_t *labels, + bool store_pairs, + const IVFSearchParameters *params, + ConcurrentBitsetPtr bitset = nullptr) +{ + long nprobe = params ? params->nprobe : ivf.nprobe; + long max_codes = params ? params->max_codes : ivf.max_codes; + MetricType metric_type = ivf.metric_type; + + // almost verbatim copy from IndexIVF::search_preassigned + + size_t nlistv = 0, ndis = 0, nheap = 0; + using HeapForIP = CMin; + using HeapForL2 = CMax; + +#pragma omp parallel if(n > 1) reduction(+: nlistv, ndis, nheap) + { + std::unique_ptr scanner + (ivf.get_InvertedListScanner (store_pairs)); + +#pragma omp for + for (size_t i = 0; i < n; i++) { + const uint8_t *xi = x + i * ivf.code_size; + scanner->set_query(xi); + + const idx_t * keysi = keys + i * nprobe; + int32_t * simi = distances + k * i; + idx_t * idxi = labels + k * i; + + if (metric_type == METRIC_INNER_PRODUCT) { + heap_heapify (k, simi, idxi); + } else { + heap_heapify (k, simi, idxi); + } + + size_t nscan = 0; + + for (size_t ik = 0; ik < nprobe; ik++) { + idx_t key = keysi[ik]; /* select the list */ + if (key < 0) { + // not enough centroids for multiprobe + continue; + } + FAISS_THROW_IF_NOT_FMT + (key < (idx_t) ivf.nlist, + "Invalid key=%ld at ik=%ld nlist=%ld\n", + key, ik, ivf.nlist); + + scanner->set_list (key, coarse_dis[i * nprobe + ik]); + + nlistv++; + + size_t list_size = ivf.invlists->list_size(key); + InvertedLists::ScopedCodes scodes (ivf.invlists, key); + std::unique_ptr sids; + const Index::idx_t * ids = nullptr; + + if (!store_pairs) { + sids.reset (new InvertedLists::ScopedIds (ivf.invlists, key)); + ids = sids->get(); + } + + nheap += scanner->scan_codes (list_size, scodes.get(), + ids, simi, idxi, k, bitset); + + nscan += list_size; + if (max_codes && nscan >= max_codes) + break; + } + + ndis += nscan; + if (metric_type == METRIC_INNER_PRODUCT) { + heap_reorder (k, simi, idxi); + } else { + heap_reorder (k, simi, idxi); + } + + } // parallel for + } // parallel + + indexIVF_stats.nq += n; + indexIVF_stats.nlist += nlistv; + indexIVF_stats.ndis += ndis; + indexIVF_stats.nheap_updates += nheap; + +} + +void search_knn_binary_dis_heap(const IndexBinaryIVF& ivf, + size_t n, + const uint8_t *x, + idx_t k, + const idx_t *keys, + const float * coarse_dis, + float *distances, + idx_t *labels, + bool store_pairs, + const IVFSearchParameters *params, + ConcurrentBitsetPtr bitset = nullptr) +{ + long nprobe = params ? params->nprobe : ivf.nprobe; + long max_codes = params ? params->max_codes : ivf.max_codes; + MetricType metric_type = ivf.metric_type; + + // almost verbatim copy from IndexIVF::search_preassigned + + size_t nlistv = 0, ndis = 0, nheap = 0; + using HeapForJaccard = CMax; + +#pragma omp parallel if(n > 1) reduction(+: nlistv, ndis, nheap) + { + std::unique_ptr scanner + (ivf.get_InvertedListScanner(store_pairs)); + +#pragma omp for + for (size_t i = 0; i < n; i++) { + const uint8_t *xi = x + i * ivf.code_size; + scanner->set_query(xi); + + const idx_t * keysi = keys + i * nprobe; + float * simi = distances + k * i; + idx_t * idxi = labels + k * i; + + heap_heapify (k, simi, idxi); + + size_t nscan = 0; + + for (size_t ik = 0; ik < nprobe; ik++) { + idx_t key = keysi[ik]; /* select the list */ + if (key < 0) { + // not enough centroids for multiprobe + continue; + } + FAISS_THROW_IF_NOT_FMT + (key < (idx_t) ivf.nlist, + "Invalid key=%ld at ik=%ld nlist=%ld\n", + key, ik, ivf.nlist); + + scanner->set_list (key, (int32_t)coarse_dis[i * nprobe + ik]); + + nlistv++; + + size_t list_size = ivf.invlists->list_size(key); + InvertedLists::ScopedCodes scodes (ivf.invlists, key); + std::unique_ptr sids; + const Index::idx_t * ids = nullptr; + + if (!store_pairs) { + sids.reset (new InvertedLists::ScopedIds (ivf.invlists, key)); + ids = sids->get(); + } + + nheap += scanner->scan_codes (list_size, scodes.get(), + ids, (int32_t*)simi, idxi, k, bitset); + + nscan += list_size; + if (max_codes && nscan >= max_codes) + break; + } + + ndis += nscan; + heap_reorder (k, simi, idxi); + + } // parallel for + } // parallel + + indexIVF_stats.nq += n; + indexIVF_stats.nlist += nlistv; + indexIVF_stats.ndis += ndis; + indexIVF_stats.nheap_updates += nheap; +} + +template +void search_knn_hamming_count(const IndexBinaryIVF& ivf, + size_t nx, + const uint8_t *x, + const idx_t *keys, + int k, + int32_t *distances, + idx_t *labels, + const IVFSearchParameters *params, + ConcurrentBitsetPtr bitset = nullptr) { + const int nBuckets = ivf.d + 1; + std::vector all_counters(nx * nBuckets, 0); + std::unique_ptr all_ids_per_dis(new idx_t[nx * nBuckets * k]); + + long nprobe = params ? params->nprobe : ivf.nprobe; + long max_codes = params ? params->max_codes : ivf.max_codes; + + std::vector> cs; + for (size_t i = 0; i < nx; ++i) { + cs.push_back(HCounterState( + all_counters.data() + i * nBuckets, + all_ids_per_dis.get() + i * nBuckets * k, + x + i * ivf.code_size, + ivf.d, + k + )); + } + + size_t nlistv = 0, ndis = 0; + +#pragma omp parallel for reduction(+: nlistv, ndis) + for (size_t i = 0; i < nx; i++) { + const idx_t * keysi = keys + i * nprobe; + HCounterState& csi = cs[i]; + + size_t nscan = 0; + + for (size_t ik = 0; ik < nprobe; ik++) { + idx_t key = keysi[ik]; /* select the list */ + if (key < 0) { + // not enough centroids for multiprobe + continue; + } + FAISS_THROW_IF_NOT_FMT ( + key < (idx_t) ivf.nlist, + "Invalid key=%ld at ik=%ld nlist=%ld\n", + key, ik, ivf.nlist); + + nlistv++; + size_t list_size = ivf.invlists->list_size(key); + InvertedLists::ScopedCodes scodes (ivf.invlists, key); + const uint8_t *list_vecs = scodes.get(); + const Index::idx_t *ids = store_pairs + ? nullptr + : ivf.invlists->get_ids(key); + + for (size_t j = 0; j < list_size; j++) { + if (!bitset || !bitset->test(ids[j])) { + const uint8_t *yj = list_vecs + ivf.code_size * j; + idx_t id = store_pairs ? (key << 32 | j) : ids[j]; + csi.update_counter(yj, id); + } + } + if (ids) + ivf.invlists->release_ids (key, ids); + + nscan += list_size; + if (max_codes && nscan >= max_codes) + break; + } + ndis += nscan; + + int nres = 0; + for (int b = 0; b < nBuckets && nres < k; b++) { + for (int l = 0; l < csi.counters[b] && nres < k; l++) { + labels[i * k + nres] = csi.ids_per_dis[b * k + l]; + distances[i * k + nres] = b; + nres++; + } + } + while (nres < k) { + labels[i * k + nres] = -1; + distances[i * k + nres] = std::numeric_limits::max(); + ++nres; + } + } + + indexIVF_stats.nq += nx; + indexIVF_stats.nlist += nlistv; + indexIVF_stats.ndis += ndis; +} + + + +template +void search_knn_hamming_count_1 ( + const IndexBinaryIVF& ivf, + size_t nx, + const uint8_t *x, + const idx_t *keys, + int k, + int32_t *distances, + idx_t *labels, + const IVFSearchParameters *params, + ConcurrentBitsetPtr bitset = nullptr) { + switch (ivf.code_size) { +#define HANDLE_CS(cs) \ + case cs: \ + search_knn_hamming_count( \ + ivf, nx, x, keys, k, distances, labels, params, bitset); \ + break; + HANDLE_CS(4); + HANDLE_CS(8); + HANDLE_CS(16); + HANDLE_CS(20); + HANDLE_CS(32); + HANDLE_CS(64); +#undef HANDLE_CS + default: + if (ivf.code_size % 8 == 0) { + search_knn_hamming_count + (ivf, nx, x, keys, k, distances, labels, params, bitset); + } else if (ivf.code_size % 4 == 0) { + search_knn_hamming_count + (ivf, nx, x, keys, k, distances, labels, params, bitset); + } else { + search_knn_hamming_count + (ivf, nx, x, keys, k, distances, labels, params, bitset); + } + break; + } +} + +} // namespace + +BinaryInvertedListScanner *IndexBinaryIVF::get_InvertedListScanner + (bool store_pairs) const +{ + switch (metric_type) { + case METRIC_Jaccard: + case METRIC_Tanimoto: + if (store_pairs) { + return select_IVFBinaryScannerJaccard (code_size); + } else { + return select_IVFBinaryScannerJaccard (code_size); + } + case METRIC_Substructure: + case METRIC_Superstructure: + // unsupported + return nullptr; + default: + if (store_pairs) { + return select_IVFBinaryScannerL2(code_size); + } else { + return select_IVFBinaryScannerL2(code_size); + } + } +} + +void IndexBinaryIVF::search_preassigned(idx_t n, const uint8_t *x, idx_t k, + const idx_t *idx, + const int32_t * coarse_dis, + int32_t *distances, idx_t *labels, + bool store_pairs, + const IVFSearchParameters *params, + ConcurrentBitsetPtr bitset + ) const { + if (metric_type == METRIC_Jaccard || metric_type == METRIC_Tanimoto) { + if (use_heap) { + float *D = new float[k * n]; + float *c_dis = new float [n * nprobe]; + memcpy(c_dis, coarse_dis, sizeof(float) * n * nprobe); + search_knn_binary_dis_heap(*this, n, x, k, idx, c_dis , + D, labels, store_pairs, + params, bitset); + if (metric_type == METRIC_Tanimoto) { + for (int i = 0; i < k * n; i++) { + D[i] = -log2(1-D[i]); + } + } + memcpy(distances, D, sizeof(float) * n * k); + delete [] D; + delete [] c_dis; + } else { + //not implemented + } + } else if (metric_type == METRIC_Substructure || metric_type == METRIC_Superstructure) { + // unsupported + } else { + if (use_heap) { + search_knn_hamming_heap (*this, n, x, k, idx, coarse_dis, + distances, labels, store_pairs, + params, bitset); + } else { + if (store_pairs) { + search_knn_hamming_count_1 + (*this, n, x, idx, k, distances, labels, params, bitset); + } else { + search_knn_hamming_count_1 + (*this, n, x, idx, k, distances, labels, params, bitset); + } + } + } +} + +void IndexBinaryIVF::range_search( + idx_t n, const uint8_t *x, int radius, + RangeSearchResult *res, + ConcurrentBitsetPtr bitset) const +{ + std::unique_ptr idx(new idx_t[n * nprobe]); + std::unique_ptr coarse_dis(new int32_t[n * nprobe]); + + double t0 = getmillisecs(); + quantizer->search(n, x, nprobe, coarse_dis.get(), idx.get()); + indexIVF_stats.quantization_time += getmillisecs() - t0; + + t0 = getmillisecs(); + invlists->prefetch_lists(idx.get(), n * nprobe); + + bool store_pairs = false; + size_t nlistv = 0, ndis = 0; + + std::vector all_pres (omp_get_max_threads()); + +#pragma omp parallel reduction(+: nlistv, ndis) + { + RangeSearchPartialResult pres(res); + std::unique_ptr scanner + (get_InvertedListScanner(store_pairs)); + FAISS_THROW_IF_NOT (scanner.get ()); + + all_pres[omp_get_thread_num()] = &pres; + + auto scan_list_func = [&](size_t i, size_t ik, RangeQueryResult &qres) + { + + idx_t key = idx[i * nprobe + ik]; /* select the list */ + if (key < 0) return; + FAISS_THROW_IF_NOT_FMT ( + key < (idx_t) nlist, + "Invalid key=%ld at ik=%ld nlist=%ld\n", + key, ik, nlist); + const size_t list_size = invlists->list_size(key); + + if (list_size == 0) return; + + InvertedLists::ScopedCodes scodes (invlists, key); + InvertedLists::ScopedIds ids (invlists, key); + + scanner->set_list (key, coarse_dis[i * nprobe + ik]); + nlistv++; + ndis += list_size; + scanner->scan_codes_range (list_size, scodes.get(), + ids.get(), radius, qres); + }; + +#pragma omp for + for (size_t i = 0; i < n; i++) { + scanner->set_query (x + i * code_size); + + RangeQueryResult & qres = pres.new_result (i); + + for (size_t ik = 0; ik < nprobe; ik++) { + scan_list_func (i, ik, qres); + } + + } + + pres.finalize(); + + } + indexIVF_stats.nq += n; + indexIVF_stats.nlist += nlistv; + indexIVF_stats.ndis += ndis; + indexIVF_stats.search_time += getmillisecs() - t0; + +} + + + + +IndexBinaryIVF::~IndexBinaryIVF() { + if (own_invlists) { + delete invlists; + } + + if (own_fields) { + delete quantizer; + } +} + + +} // namespace faiss diff --git a/core/src/index/thirdparty/faiss/IndexBinaryIVF.h b/core/src/index/thirdparty/faiss/IndexBinaryIVF.h new file mode 100644 index 0000000000..c3cd7e7443 --- /dev/null +++ b/core/src/index/thirdparty/faiss/IndexBinaryIVF.h @@ -0,0 +1,234 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +#ifndef FAISS_INDEX_BINARY_IVF_H +#define FAISS_INDEX_BINARY_IVF_H + + +#include + +#include +#include +#include +#include + + +namespace faiss { + +struct BinaryInvertedListScanner; + +/** Index based on a inverted file (IVF) + * + * In the inverted file, the quantizer (an IndexBinary instance) provides a + * quantization index for each vector to be added. The quantization + * index maps to a list (aka inverted list or posting list), where the + * id of the vector is stored. + * + * Otherwise the object is similar to the IndexIVF + */ +struct IndexBinaryIVF : IndexBinary { + /// Acess to the actual data + InvertedLists *invlists; + bool own_invlists; + + size_t nprobe; ///< number of probes at query time + size_t max_codes; ///< max nb of codes to visit to do a query + + /** Select between using a heap or counting to select the k smallest values + * when scanning inverted lists. + */ + bool use_heap = true; + + /// map for direct access to the elements. Enables reconstruct(). + DirectMap direct_map; + + IndexBinary *quantizer; ///< quantizer that maps vectors to inverted lists + size_t nlist; ///< number of possible key values + + bool own_fields; ///< whether object owns the quantizer + + ClusteringParameters cp; ///< to override default clustering params + Index *clustering_index; ///< to override index used during clustering + + /** The Inverted file takes a quantizer (an IndexBinary) on input, + * which implements the function mapping a vector to a list + * identifier. The pointer is borrowed: the quantizer should not + * be deleted while the IndexBinaryIVF is in use. + */ + IndexBinaryIVF(IndexBinary *quantizer, size_t d, size_t nlist); + + IndexBinaryIVF(IndexBinary *quantizer, size_t d, size_t nlist, MetricType metric); + + IndexBinaryIVF(); + + ~IndexBinaryIVF() override; + + void reset() override; + + /// Trains the quantizer + void train(idx_t n, const uint8_t *x) override; + + void add(idx_t n, const uint8_t *x) override; + + void add_with_ids(idx_t n, const uint8_t *x, const idx_t *xids) override; + + /// same as add_with_ids, with precomputed coarse quantizer + void add_core (idx_t n, const uint8_t * x, const idx_t *xids, + const idx_t *precomputed_idx); + + /** Search a set of vectors, that are pre-quantized by the IVF + * quantizer. Fill in the corresponding heaps with the query + * results. search() calls this. + * + * @param n nb of vectors to query + * @param x query vectors, size nx * d + * @param assign coarse quantization indices, size nx * nprobe + * @param centroid_dis + * distances to coarse centroids, size nx * nprobe + * @param distance + * output distances, size n * k + * @param labels output labels, size n * k + * @param store_pairs store inv list index + inv list offset + * instead in upper/lower 32 bit of result, + * instead of ids (used for reranking). + * @param params used to override the object's search parameters + */ + void search_preassigned(idx_t n, const uint8_t *x, idx_t k, + const idx_t *assign, + const int32_t *centroid_dis, + int32_t *distances, idx_t *labels, + bool store_pairs, + const IVFSearchParameters *params=nullptr, + ConcurrentBitsetPtr bitset = nullptr + ) const; + + virtual BinaryInvertedListScanner *get_InvertedListScanner ( + bool store_pairs=false) const; + + /** assign the vectors, then call search_preassign */ + void search(idx_t n, const uint8_t *x, idx_t k, + int32_t *distances, idx_t *labels, ConcurrentBitsetPtr bitset = nullptr) const override; + +#if 0 + /** get raw vectors by ids */ + void get_vector_by_id(idx_t n, const idx_t *xid, uint8_t *x, ConcurrentBitsetPtr bitset = nullptr) override; + + void search_by_id (idx_t n, const idx_t *xid, idx_t k, int32_t *distances, idx_t *labels, + ConcurrentBitsetPtr bitset = nullptr) override; +#endif + + void range_search(idx_t n, const uint8_t *x, int radius, + RangeSearchResult *result, + ConcurrentBitsetPtr bitset = nullptr) const override; + + void reconstruct(idx_t key, uint8_t *recons) const override; + + /** Reconstruct a subset of the indexed vectors. + * + * Overrides default implementation to bypass reconstruct() which requires + * direct_map to be maintained. + * + * @param i0 first vector to reconstruct + * @param ni nb of vectors to reconstruct + * @param recons output array of reconstructed vectors, size ni * d / 8 + */ + void reconstruct_n(idx_t i0, idx_t ni, uint8_t *recons) const override; + + /** Similar to search, but also reconstructs the stored vectors (or an + * approximation in the case of lossy coding) for the search results. + * + * Overrides default implementation to avoid having to maintain direct_map + * and instead fetch the code offsets through the `store_pairs` flag in + * search_preassigned(). + * + * @param recons reconstructed vectors size (n, k, d / 8) + */ + void search_and_reconstruct(idx_t n, const uint8_t *x, idx_t k, + int32_t *distances, idx_t *labels, + uint8_t *recons) const override; + + /** Reconstruct a vector given the location in terms of (inv list index + + * inv list offset) instead of the id. + * + * Useful for reconstructing when the direct_map is not maintained and + * the inv list offset is computed by search_preassigned() with + * `store_pairs` set. + */ + virtual void reconstruct_from_offset(idx_t list_no, idx_t offset, + uint8_t* recons) const; + + + /// Dataset manipulation functions + size_t remove_ids(const IDSelector& sel) override; + + /** moves the entries from another dataset to self. On output, + * other is empty. add_id is added to all moved ids (for + * sequential ids, this would be this->ntotal */ + virtual void merge_from(IndexBinaryIVF& other, idx_t add_id); + + size_t get_list_size(size_t list_no) const + { return invlists->list_size(list_no); } + + /** intialize a direct map + * + * @param new_maintain_direct_map if true, create a direct map, + * else clear it + */ + void make_direct_map(bool new_maintain_direct_map=true); + + void set_direct_map_type (DirectMap::Type type); + + void replace_invlists(InvertedLists *il, bool own=false); +}; + + +struct BinaryInvertedListScanner { + + using idx_t = Index::idx_t; + + /// from now on we handle this query. + virtual void set_query (const uint8_t *query_vector) = 0; + + /// following codes come from this inverted list + virtual void set_list (idx_t list_no, uint8_t coarse_dis) = 0; + + /// compute a single query-to-code distance + virtual uint32_t distance_to_code (const uint8_t *code) const = 0; + + /** compute the distances to codes. (distances, labels) should be + * organized as a min- or max-heap + * + * @param n number of codes to scan + * @param codes codes to scan (n * code_size) + * @param ids corresponding ids (ignored if store_pairs) + * @param distances heap distances (size k) + * @param labels heap labels (size k) + * @param k heap size + */ + virtual size_t scan_codes (size_t n, + const uint8_t *codes, + const idx_t *ids, + int32_t *distances, idx_t *labels, + size_t k, + ConcurrentBitsetPtr bitset = nullptr) const = 0; + + virtual void scan_codes_range (size_t n, + const uint8_t *codes, + const idx_t *ids, + int radius, + RangeQueryResult &result) const = 0; + + virtual ~BinaryInvertedListScanner () {} + +}; + + +} // namespace faiss + +#endif // FAISS_INDEX_BINARY_IVF_H diff --git a/core/src/index/thirdparty/faiss/IndexFlat.cpp b/core/src/index/thirdparty/faiss/IndexFlat.cpp new file mode 100644 index 0000000000..7780650da3 --- /dev/null +++ b/core/src/index/thirdparty/faiss/IndexFlat.cpp @@ -0,0 +1,540 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace faiss { + +IndexFlat::IndexFlat (idx_t d, MetricType metric): + Index(d, metric) +{ +} + + + +void IndexFlat::add (idx_t n, const float *x) { + xb.insert(xb.end(), x, x + n * d); + ntotal += n; +} + + +void IndexFlat::reset() { + xb.clear(); + ntotal = 0; +} + + +void IndexFlat::search (idx_t n, const float *x, idx_t k, + float *distances, idx_t *labels, + ConcurrentBitsetPtr bitset) const +{ + // we see the distances and labels as heaps + + if (metric_type == METRIC_INNER_PRODUCT) { + float_minheap_array_t res = { + size_t(n), size_t(k), labels, distances}; + knn_inner_product (x, xb.data(), d, n, ntotal, &res, bitset); + } else if (metric_type == METRIC_L2) { + float_maxheap_array_t res = { + size_t(n), size_t(k), labels, distances}; + knn_L2sqr (x, xb.data(), d, n, ntotal, &res, bitset); + } else if (metric_type == METRIC_Jaccard) { + float_maxheap_array_t res = { + size_t(n), size_t(k), labels, distances}; + knn_jaccard(x, xb.data(), d, n, ntotal, &res, bitset); + } else { + float_maxheap_array_t res = { + size_t(n), size_t(k), labels, distances}; + knn_extra_metrics (x, xb.data(), d, n, ntotal, + metric_type, metric_arg, + &res, bitset); + } +} + +void IndexFlat::assign(idx_t n, const float * x, idx_t * labels, float* distances) +{ + // usually used in IVF k-means algorithm + float *dis_inner = (distances == nullptr) ? new float[n] : distances; + switch (metric_type) { + case METRIC_INNER_PRODUCT: + case METRIC_L2: { + // ignore the metric_type, both use L2 + elkan_L2_sse(x, xb.data(), d, n, ntotal, labels, dis_inner); + break; + } + default: { + // binary metrics + // There may be something wrong, but maintain the original logic now. + Index::assign(n, x, labels, dis_inner); + break; + } + } + if (distances == nullptr) { + delete[] dis_inner; + } +} + +void IndexFlat::range_search (idx_t n, const float *x, float radius, + RangeSearchResult *result, + ConcurrentBitsetPtr bitset) const +{ + switch (metric_type) { + case METRIC_INNER_PRODUCT: + range_search_inner_product (x, xb.data(), d, n, ntotal, + radius, result); + break; + case METRIC_L2: + range_search_L2sqr (x, xb.data(), d, n, ntotal, radius, result); + break; + default: + FAISS_THROW_MSG("metric type not supported"); + } +} + + +void IndexFlat::compute_distance_subset ( + idx_t n, + const float *x, + idx_t k, + float *distances, + const idx_t *labels) const +{ + switch (metric_type) { + case METRIC_INNER_PRODUCT: + fvec_inner_products_by_idx ( + distances, + x, xb.data(), labels, d, n, k); + break; + case METRIC_L2: + fvec_L2sqr_by_idx ( + distances, + x, xb.data(), labels, d, n, k); + break; + default: + FAISS_THROW_MSG("metric type not supported"); + } + +} + +size_t IndexFlat::remove_ids (const IDSelector & sel) +{ + idx_t j = 0; + for (idx_t i = 0; i < ntotal; i++) { + if (sel.is_member (i)) { + // should be removed + } else { + if (i > j) { + memmove (&xb[d * j], &xb[d * i], sizeof(xb[0]) * d); + } + j++; + } + } + size_t nremove = ntotal - j; + if (nremove > 0) { + ntotal = j; + xb.resize (ntotal * d); + } + return nremove; +} + + +namespace { + + +struct FlatL2Dis : DistanceComputer { + size_t d; + Index::idx_t nb; + const float *q; + const float *b; + size_t ndis; + + float operator () (idx_t i) override { + ndis++; + return fvec_L2sqr(q, b + i * d, d); + } + + float symmetric_dis(idx_t i, idx_t j) override { + return fvec_L2sqr(b + j * d, b + i * d, d); + } + + explicit FlatL2Dis(const IndexFlat& storage, const float *q = nullptr) + : d(storage.d), + nb(storage.ntotal), + q(q), + b(storage.xb.data()), + ndis(0) {} + + void set_query(const float *x) override { + q = x; + } +}; + +struct FlatIPDis : DistanceComputer { + size_t d; + Index::idx_t nb; + const float *q; + const float *b; + size_t ndis; + + float operator () (idx_t i) override { + ndis++; + return fvec_inner_product (q, b + i * d, d); + } + + float symmetric_dis(idx_t i, idx_t j) override { + return fvec_inner_product (b + j * d, b + i * d, d); + } + + explicit FlatIPDis(const IndexFlat& storage, const float *q = nullptr) + : d(storage.d), + nb(storage.ntotal), + q(q), + b(storage.xb.data()), + ndis(0) {} + + void set_query(const float *x) override { + q = x; + } +}; + + + + +} // namespace + + +DistanceComputer * IndexFlat::get_distance_computer() const { + if (metric_type == METRIC_L2) { + return new FlatL2Dis(*this); + } else if (metric_type == METRIC_INNER_PRODUCT) { + return new FlatIPDis(*this); + } else { + return get_extra_distance_computer (d, metric_type, metric_arg, + ntotal, xb.data()); + } +} + + +void IndexFlat::reconstruct (idx_t key, float * recons) const +{ + memcpy (recons, &(xb[key * d]), sizeof(*recons) * d); +} + + +/* The standalone codec interface */ +size_t IndexFlat::sa_code_size () const +{ + return sizeof(float) * d; +} + +void IndexFlat::sa_encode (idx_t n, const float *x, uint8_t *bytes) const +{ + memcpy (bytes, x, sizeof(float) * d * n); +} + +void IndexFlat::sa_decode (idx_t n, const uint8_t *bytes, float *x) const +{ + memcpy (x, bytes, sizeof(float) * d * n); +} + + + + +/*************************************************** + * IndexFlatL2BaseShift + ***************************************************/ + +IndexFlatL2BaseShift::IndexFlatL2BaseShift (idx_t d, size_t nshift, const float *shift): + IndexFlatL2 (d), shift (nshift) +{ + memcpy (this->shift.data(), shift, sizeof(float) * nshift); +} + +void IndexFlatL2BaseShift::search ( + idx_t n, + const float *x, + idx_t k, + float *distances, + idx_t *labels, + ConcurrentBitsetPtr bitset) const +{ + FAISS_THROW_IF_NOT (shift.size() == ntotal); + + float_maxheap_array_t res = { + size_t(n), size_t(k), labels, distances}; + knn_L2sqr_base_shift (x, xb.data(), d, n, ntotal, &res, shift.data()); +} + + + +/*************************************************** + * IndexRefineFlat + ***************************************************/ + +IndexRefineFlat::IndexRefineFlat (Index *base_index): + Index (base_index->d, base_index->metric_type), + refine_index (base_index->d, base_index->metric_type), + base_index (base_index), own_fields (false), + k_factor (1) +{ + is_trained = base_index->is_trained; + FAISS_THROW_IF_NOT_MSG (base_index->ntotal == 0, + "base_index should be empty in the beginning"); +} + +IndexRefineFlat::IndexRefineFlat () { + base_index = nullptr; + own_fields = false; + k_factor = 1; +} + + +void IndexRefineFlat::train (idx_t n, const float *x) +{ + base_index->train (n, x); + is_trained = true; +} + +void IndexRefineFlat::add (idx_t n, const float *x) { + FAISS_THROW_IF_NOT (is_trained); + base_index->add (n, x); + refine_index.add (n, x); + ntotal = refine_index.ntotal; +} + +void IndexRefineFlat::reset () +{ + base_index->reset (); + refine_index.reset (); + ntotal = 0; +} + +namespace { +typedef faiss::Index::idx_t idx_t; + +template +static void reorder_2_heaps ( + idx_t n, + idx_t k, idx_t *labels, float *distances, + idx_t k_base, const idx_t *base_labels, const float *base_distances) +{ +#pragma omp parallel for + for (idx_t i = 0; i < n; i++) { + idx_t *idxo = labels + i * k; + float *diso = distances + i * k; + const idx_t *idxi = base_labels + i * k_base; + const float *disi = base_distances + i * k_base; + + heap_heapify (k, diso, idxo, disi, idxi, k); + if (k_base != k) { // add remaining elements + heap_addn (k, diso, idxo, disi + k, idxi + k, k_base - k); + } + heap_reorder (k, diso, idxo); + } +} + + +} + + +void IndexRefineFlat::search ( + idx_t n, const float *x, idx_t k, + float *distances, idx_t *labels, + ConcurrentBitsetPtr bitset) const +{ + FAISS_THROW_IF_NOT (is_trained); + idx_t k_base = idx_t (k * k_factor); + idx_t * base_labels = labels; + float * base_distances = distances; + ScopeDeleter del1; + ScopeDeleter del2; + + + if (k != k_base) { + base_labels = new idx_t [n * k_base]; + del1.set (base_labels); + base_distances = new float [n * k_base]; + del2.set (base_distances); + } + + base_index->search (n, x, k_base, base_distances, base_labels); + + for (int i = 0; i < n * k_base; i++) + assert (base_labels[i] >= -1 && + base_labels[i] < ntotal); + + // compute refined distances + refine_index.compute_distance_subset ( + n, x, k_base, base_distances, base_labels); + + // sort and store result + if (metric_type == METRIC_L2) { + typedef CMax C; + reorder_2_heaps ( + n, k, labels, distances, + k_base, base_labels, base_distances); + + } else if (metric_type == METRIC_INNER_PRODUCT) { + typedef CMin C; + reorder_2_heaps ( + n, k, labels, distances, + k_base, base_labels, base_distances); + } else { + FAISS_THROW_MSG("Metric type not supported"); + } + +} + + + +IndexRefineFlat::~IndexRefineFlat () +{ + if (own_fields) delete base_index; +} + +/*************************************************** + * IndexFlat1D + ***************************************************/ + + +IndexFlat1D::IndexFlat1D (bool continuous_update): + IndexFlatL2 (1), + continuous_update (continuous_update) +{ +} + +/// if not continuous_update, call this between the last add and +/// the first search +void IndexFlat1D::update_permutation () +{ + perm.resize (ntotal); + if (ntotal < 1000000) { + fvec_argsort (ntotal, xb.data(), (size_t*)perm.data()); + } else { + fvec_argsort_parallel (ntotal, xb.data(), (size_t*)perm.data()); + } +} + +void IndexFlat1D::add (idx_t n, const float *x) +{ + IndexFlatL2::add (n, x); + if (continuous_update) + update_permutation(); +} + +void IndexFlat1D::reset() +{ + IndexFlatL2::reset(); + perm.clear(); +} + +void IndexFlat1D::search ( + idx_t n, + const float *x, + idx_t k, + float *distances, + idx_t *labels, + ConcurrentBitsetPtr bitset) const +{ + FAISS_THROW_IF_NOT_MSG (perm.size() == ntotal, + "Call update_permutation before search"); + +#pragma omp parallel for + for (idx_t i = 0; i < n; i++) { + + float q = x[i]; // query + float *D = distances + i * k; + idx_t *I = labels + i * k; + + // binary search + idx_t i0 = 0, i1 = ntotal; + idx_t wp = 0; + + if (xb[perm[i0]] > q) { + i1 = 0; + goto finish_right; + } + + if (xb[perm[i1 - 1]] <= q) { + i0 = i1 - 1; + goto finish_left; + } + + while (i0 + 1 < i1) { + idx_t imed = (i0 + i1) / 2; + if (xb[perm[imed]] <= q) i0 = imed; + else i1 = imed; + } + + // query is between xb[perm[i0]] and xb[perm[i1]] + // expand to nearest neighs + + while (wp < k) { + float xleft = xb[perm[i0]]; + float xright = xb[perm[i1]]; + + if (q - xleft < xright - q) { + D[wp] = q - xleft; + I[wp] = perm[i0]; + i0--; wp++; + if (i0 < 0) { goto finish_right; } + } else { + D[wp] = xright - q; + I[wp] = perm[i1]; + i1++; wp++; + if (i1 >= ntotal) { goto finish_left; } + } + } + goto done; + + finish_right: + // grow to the right from i1 + while (wp < k) { + if (i1 < ntotal) { + D[wp] = xb[perm[i1]] - q; + I[wp] = perm[i1]; + i1++; + } else { + D[wp] = std::numeric_limits::infinity(); + I[wp] = -1; + } + wp++; + } + goto done; + + finish_left: + // grow to the left from i0 + while (wp < k) { + if (i0 >= 0) { + D[wp] = q - xb[perm[i0]]; + I[wp] = perm[i0]; + i0--; + } else { + D[wp] = std::numeric_limits::infinity(); + I[wp] = -1; + } + wp++; + } + done: ; + } + +} + + + +} // namespace faiss diff --git a/core/src/index/thirdparty/faiss/IndexFlat.h b/core/src/index/thirdparty/faiss/IndexFlat.h new file mode 100644 index 0000000000..13f8829f5f --- /dev/null +++ b/core/src/index/thirdparty/faiss/IndexFlat.h @@ -0,0 +1,189 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +#ifndef INDEX_FLAT_H +#define INDEX_FLAT_H + +#include + +#include + + +namespace faiss { + +/** Index that stores the full vectors and performs exhaustive search */ +struct IndexFlat: Index { + + /// database vectors, size ntotal * d + std::vector xb; + + explicit IndexFlat (idx_t d, MetricType metric = METRIC_L2); + + void add(idx_t n, const float* x) override; + + void reset() override; + + void search( + idx_t n, + const float* x, + idx_t k, + float* distances, + idx_t* labels, + ConcurrentBitsetPtr bitset = nullptr) const override; + + void assign ( + idx_t n, + const float * x, + idx_t * labels, + float* distances = nullptr) override; + + void range_search( + idx_t n, + const float* x, + float radius, + RangeSearchResult* result, + ConcurrentBitsetPtr bitset = nullptr) const override; + + void reconstruct(idx_t key, float* recons) const override; + + /** compute distance with a subset of vectors + * + * @param x query vectors, size n * d + * @param labels indices of the vectors that should be compared + * for each query vector, size n * k + * @param distances + * corresponding output distances, size n * k + */ + void compute_distance_subset ( + idx_t n, + const float *x, + idx_t k, + float *distances, + const idx_t *labels) const; + + /** remove some ids. NB that Because of the structure of the + * indexing structure, the semantics of this operation are + * different from the usual ones: the new ids are shifted */ + size_t remove_ids(const IDSelector& sel) override; + + IndexFlat () {} + + DistanceComputer * get_distance_computer() const override; + + /* The stanadlone codec interface (just memcopies in this case) */ + size_t sa_code_size () const override; + + void sa_encode (idx_t n, const float *x, + uint8_t *bytes) const override; + + void sa_decode (idx_t n, const uint8_t *bytes, + float *x) const override; + + size_t cal_size() { return xb.size() * sizeof(float); } + +}; + + + +struct IndexFlatIP:IndexFlat { + explicit IndexFlatIP (idx_t d): IndexFlat (d, METRIC_INNER_PRODUCT) {} + IndexFlatIP () {} +}; + + +struct IndexFlatL2:IndexFlat { + explicit IndexFlatL2 (idx_t d): IndexFlat (d, METRIC_L2) {} + IndexFlatL2 () {} +}; + + +// same as an IndexFlatL2 but a value is subtracted from each distance +struct IndexFlatL2BaseShift: IndexFlatL2 { + std::vector shift; + + IndexFlatL2BaseShift (idx_t d, size_t nshift, const float *shift); + + void search( + idx_t n, + const float* x, + idx_t k, + float* distances, + idx_t* labels, + ConcurrentBitsetPtr bitset = nullptr) const override; +}; + + +/** Index that queries in a base_index (a fast one) and refines the + * results with an exact search, hopefully improving the results. + */ +struct IndexRefineFlat: Index { + + /// storage for full vectors + IndexFlat refine_index; + + /// faster index to pre-select the vectors that should be filtered + Index *base_index; + bool own_fields; ///< should the base index be deallocated? + + /// factor between k requested in search and the k requested from + /// the base_index (should be >= 1) + float k_factor; + + explicit IndexRefineFlat (Index *base_index); + + IndexRefineFlat (); + + void train(idx_t n, const float* x) override; + + void add(idx_t n, const float* x) override; + + void reset() override; + + void search( + idx_t n, + const float* x, + idx_t k, + float* distances, + idx_t* labels, + ConcurrentBitsetPtr bitset = nullptr) const override; + + ~IndexRefineFlat() override; +}; + + +/// optimized version for 1D "vectors". +struct IndexFlat1D:IndexFlatL2 { + bool continuous_update; ///< is the permutation updated continuously? + + std::vector perm; ///< sorted database indices + + explicit IndexFlat1D (bool continuous_update=true); + + /// if not continuous_update, call this between the last add and + /// the first search + void update_permutation (); + + void add(idx_t n, const float* x) override; + + void reset() override; + + /// Warn: the distances returned are L1 not L2 + void search( + idx_t n, + const float* x, + idx_t k, + float* distances, + idx_t* labels, + ConcurrentBitsetPtr bitset = nullptr) const override; +}; + + +} + +#endif diff --git a/core/src/index/thirdparty/faiss/IndexHNSW.cpp b/core/src/index/thirdparty/faiss/IndexHNSW.cpp new file mode 100644 index 0000000000..c06f9840e2 --- /dev/null +++ b/core/src/index/thirdparty/faiss/IndexHNSW.cpp @@ -0,0 +1,1142 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +#include + + +#include +#include +#include +#include +#include +#include + +#include +#include + +#include +#include +#include +#include + +#ifdef __SSE__ +#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +extern "C" { + +/* declare BLAS functions, see http://www.netlib.org/clapack/cblas/ */ + +int sgemm_ (const char *transa, const char *transb, FINTEGER *m, FINTEGER * + n, FINTEGER *k, const float *alpha, const float *a, + FINTEGER *lda, const float *b, FINTEGER * + ldb, float *beta, float *c, FINTEGER *ldc); + +} + +namespace faiss { + +using idx_t = Index::idx_t; +using MinimaxHeap = HNSW::MinimaxHeap; +using storage_idx_t = HNSW::storage_idx_t; +using NodeDistFarther = HNSW::NodeDistFarther; + +HNSWStats hnsw_stats; + +/************************************************************** + * add / search blocks of descriptors + **************************************************************/ + +namespace { + + +/* Wrap the distance computer into one that negates the + distances. This makes supporting INNER_PRODUCE search easier */ + +struct NegativeDistanceComputer: DistanceComputer { + + /// owned by this + DistanceComputer *basedis; + + explicit NegativeDistanceComputer(DistanceComputer *basedis): + basedis(basedis) + {} + + void set_query(const float *x) override { + basedis->set_query(x); + } + + /// compute distance of vector i to current query + float operator () (idx_t i) override { + return -(*basedis)(i); + } + + /// compute distance between two stored vectors + float symmetric_dis (idx_t i, idx_t j) override { + return -basedis->symmetric_dis(i, j); + } + + virtual ~NegativeDistanceComputer () + { + delete basedis; + } + +}; + +DistanceComputer *storage_distance_computer(const Index *storage) +{ + if (storage->metric_type == METRIC_INNER_PRODUCT) { + return new NegativeDistanceComputer(storage->get_distance_computer()); + } else { + return storage->get_distance_computer(); + } +} + + + +void hnsw_add_vertices(IndexHNSW &index_hnsw, + size_t n0, + size_t n, const float *x, + bool verbose, + bool preset_levels = false) { + size_t d = index_hnsw.d; + HNSW & hnsw = index_hnsw.hnsw; + size_t ntotal = n0 + n; + double t0 = getmillisecs(); + if (verbose) { + printf("hnsw_add_vertices: adding %ld elements on top of %ld " + "(preset_levels=%d)\n", + n, n0, int(preset_levels)); + } + + if (n == 0) { + return; + } + + int max_level = hnsw.prepare_level_tab(n, preset_levels); + + if (verbose) { + printf(" max_level = %d\n", max_level); + } + + std::vector locks(ntotal); + for(int i = 0; i < ntotal; i++) + omp_init_lock(&locks[i]); + + // add vectors from highest to lowest level + std::vector hist; + std::vector order(n); + + { // make buckets with vectors of the same level + + // build histogram + for (int i = 0; i < n; i++) { + storage_idx_t pt_id = i + n0; + int pt_level = hnsw.levels[pt_id] - 1; + while (pt_level >= hist.size()) + hist.push_back(0); + hist[pt_level] ++; + } + + // accumulate + std::vector offsets(hist.size() + 1, 0); + for (int i = 0; i < hist.size() - 1; i++) { + offsets[i + 1] = offsets[i] + hist[i]; + } + + // bucket sort + for (int i = 0; i < n; i++) { + storage_idx_t pt_id = i + n0; + int pt_level = hnsw.levels[pt_id] - 1; + order[offsets[pt_level]++] = pt_id; + } + } + + idx_t check_period = InterruptCallback::get_period_hint + (max_level * index_hnsw.d * hnsw.efConstruction); + + { // perform add + RandomGenerator rng2(789); + + int i1 = n; + + for (int pt_level = hist.size() - 1; pt_level >= 0; pt_level--) { + int i0 = i1 - hist[pt_level]; + + if (verbose) { + printf("Adding %d elements at level %d\n", + i1 - i0, pt_level); + } + + // random permutation to get rid of dataset order bias + for (int j = i0; j < i1; j++) + std::swap(order[j], order[j + rng2.rand_int(i1 - j)]); + + bool interrupt = false; + +#pragma omp parallel if(i1 > i0 + 100) + { + VisitedTable vt (ntotal); + + DistanceComputer *dis = + storage_distance_computer (index_hnsw.storage); + ScopeDeleter1 del(dis); + int prev_display = verbose && omp_get_thread_num() == 0 ? 0 : -1; + size_t counter = 0; + +#pragma omp for schedule(dynamic) + for (int i = i0; i < i1; i++) { + storage_idx_t pt_id = order[i]; + dis->set_query (x + (pt_id - n0) * d); + + // cannot break + if (interrupt) { + continue; + } + + hnsw.add_with_locks(*dis, pt_level, pt_id, locks, vt); + + if (prev_display >= 0 && i - i0 > prev_display + 10000) { + prev_display = i - i0; + printf(" %d / %d\r", i - i0, i1 - i0); + fflush(stdout); + } + + if (counter % check_period == 0) { + if (InterruptCallback::is_interrupted ()) { + interrupt = true; + } + } + counter++; + } + + } + if (interrupt) { + FAISS_THROW_MSG ("computation interrupted"); + } + i1 = i0; + } + FAISS_ASSERT(i1 == 0); + } + if (verbose) { + printf("Done in %.3f ms\n", getmillisecs() - t0); + } + + for(int i = 0; i < ntotal; i++) { + omp_destroy_lock(&locks[i]); + } +} + + +} // namespace + + + + +/************************************************************** + * IndexHNSW implementation + **************************************************************/ + +IndexHNSW::IndexHNSW(int d, int M, MetricType metric): + Index(d, metric), + hnsw(M), + own_fields(false), + storage(nullptr), + reconstruct_from_neighbors(nullptr) +{} + +IndexHNSW::IndexHNSW(Index *storage, int M): + Index(storage->d, storage->metric_type), + hnsw(M), + own_fields(false), + storage(storage), + reconstruct_from_neighbors(nullptr) +{} + +IndexHNSW::~IndexHNSW() { + if (own_fields) { + delete storage; + } +} + +void IndexHNSW::train(idx_t n, const float* x) +{ + FAISS_THROW_IF_NOT_MSG(storage, + "Please use IndexHSNWFlat (or variants) instead of IndexHNSW directly"); + // hnsw structure does not require training + storage->train (n, x); + is_trained = true; +} + +void IndexHNSW::search (idx_t n, const float *x, idx_t k, + float *distances, idx_t *labels, ConcurrentBitsetPtr bitset) const + +{ + FAISS_THROW_IF_NOT_MSG(storage, + "Please use IndexHSNWFlat (or variants) instead of IndexHNSW directly"); + size_t nreorder = 0; + + idx_t check_period = InterruptCallback::get_period_hint ( + hnsw.max_level * d * hnsw.efSearch); + + for (idx_t i0 = 0; i0 < n; i0 += check_period) { + idx_t i1 = std::min(i0 + check_period, n); + +#pragma omp parallel reduction(+ : nreorder) + { + VisitedTable vt (ntotal); + + DistanceComputer *dis = storage_distance_computer(storage); + ScopeDeleter1 del(dis); + +#pragma omp for + for(idx_t i = i0; i < i1; i++) { + idx_t * idxi = labels + i * k; + float * simi = distances + i * k; + dis->set_query(x + i * d); + + maxheap_heapify (k, simi, idxi); + hnsw.search(*dis, k, idxi, simi, vt); + + maxheap_reorder (k, simi, idxi); + + if (reconstruct_from_neighbors && + reconstruct_from_neighbors->k_reorder != 0) { + int k_reorder = reconstruct_from_neighbors->k_reorder; + if (k_reorder == -1 || k_reorder > k) k_reorder = k; + + nreorder += reconstruct_from_neighbors->compute_distances( + k_reorder, idxi, x + i * d, simi); + + // sort top k_reorder + maxheap_heapify (k_reorder, simi, idxi, simi, idxi, k_reorder); + maxheap_reorder (k_reorder, simi, idxi); + } + + } + + } + InterruptCallback::check (); + } + + if (metric_type == METRIC_INNER_PRODUCT) { + // we need to revert the negated distances + for (size_t i = 0; i < k * n; i++) { + distances[i] = -distances[i]; + } + } + + hnsw_stats.nreorder += nreorder; +} + + +void IndexHNSW::add(idx_t n, const float *x) +{ + FAISS_THROW_IF_NOT_MSG(storage, + "Please use IndexHSNWFlat (or variants) instead of IndexHNSW directly"); + FAISS_THROW_IF_NOT(is_trained); + int n0 = ntotal; + storage->add(n, x); + ntotal = storage->ntotal; + + hnsw_add_vertices (*this, n0, n, x, verbose, + hnsw.levels.size() == ntotal); +} + +void IndexHNSW::reset() +{ + hnsw.reset(); + storage->reset(); + ntotal = 0; +} + +void IndexHNSW::reconstruct (idx_t key, float* recons) const +{ + storage->reconstruct(key, recons); +} + +void IndexHNSW::shrink_level_0_neighbors(int new_size) +{ +#pragma omp parallel + { + DistanceComputer *dis = storage_distance_computer(storage); + ScopeDeleter1 del(dis); + +#pragma omp for + for (idx_t i = 0; i < ntotal; i++) { + + size_t begin, end; + hnsw.neighbor_range(i, 0, &begin, &end); + + std::priority_queue initial_list; + + for (size_t j = begin; j < end; j++) { + int v1 = hnsw.neighbors[j]; + if (v1 < 0) break; + initial_list.emplace(dis->symmetric_dis(i, v1), v1); + + // initial_list.emplace(qdis(v1), v1); + } + + std::vector shrunk_list; + HNSW::shrink_neighbor_list(*dis, initial_list, + shrunk_list, new_size); + + for (size_t j = begin; j < end; j++) { + if (j - begin < shrunk_list.size()) + hnsw.neighbors[j] = shrunk_list[j - begin].id; + else + hnsw.neighbors[j] = -1; + } + } + } + +} + +void IndexHNSW::search_level_0( + idx_t n, const float *x, idx_t k, + const storage_idx_t *nearest, const float *nearest_d, + float *distances, idx_t *labels, int nprobe, + int search_type) const +{ + + storage_idx_t ntotal = hnsw.levels.size(); +#pragma omp parallel + { + DistanceComputer *qdis = storage_distance_computer(storage); + ScopeDeleter1 del(qdis); + + VisitedTable vt (ntotal); + +#pragma omp for + for(idx_t i = 0; i < n; i++) { + idx_t * idxi = labels + i * k; + float * simi = distances + i * k; + + qdis->set_query(x + i * d); + maxheap_heapify (k, simi, idxi); + + if (search_type == 1) { + + int nres = 0; + + for(int j = 0; j < nprobe; j++) { + storage_idx_t cj = nearest[i * nprobe + j]; + + if (cj < 0) break; + + if (vt.get(cj)) continue; + + int candidates_size = std::max(hnsw.efSearch, int(k)); + MinimaxHeap candidates(candidates_size); + + candidates.push(cj, nearest_d[i * nprobe + j]); + + nres = hnsw.search_from_candidates( + *qdis, k, idxi, simi, + candidates, vt, 0, nres + ); + } + } else if (search_type == 2) { + + int candidates_size = std::max(hnsw.efSearch, int(k)); + candidates_size = std::max(candidates_size, nprobe); + + MinimaxHeap candidates(candidates_size); + for(int j = 0; j < nprobe; j++) { + storage_idx_t cj = nearest[i * nprobe + j]; + + if (cj < 0) break; + candidates.push(cj, nearest_d[i * nprobe + j]); + } + hnsw.search_from_candidates( + *qdis, k, idxi, simi, + candidates, vt, 0 + ); + + } + vt.advance(); + + maxheap_reorder (k, simi, idxi); + + } + } + + +} + +void IndexHNSW::init_level_0_from_knngraph( + int k, const float *D, const idx_t *I) +{ + int dest_size = hnsw.nb_neighbors (0); + +#pragma omp parallel for + for (idx_t i = 0; i < ntotal; i++) { + DistanceComputer *qdis = storage_distance_computer(storage); + float vec[d]; + storage->reconstruct(i, vec); + qdis->set_query(vec); + + std::priority_queue initial_list; + + for (size_t j = 0; j < k; j++) { + int v1 = I[i * k + j]; + if (v1 == i) continue; + if (v1 < 0) break; + initial_list.emplace(D[i * k + j], v1); + } + + std::vector shrunk_list; + HNSW::shrink_neighbor_list(*qdis, initial_list, shrunk_list, dest_size); + + size_t begin, end; + hnsw.neighbor_range(i, 0, &begin, &end); + + for (size_t j = begin; j < end; j++) { + if (j - begin < shrunk_list.size()) + hnsw.neighbors[j] = shrunk_list[j - begin].id; + else + hnsw.neighbors[j] = -1; + } + } +} + + + +void IndexHNSW::init_level_0_from_entry_points( + int n, const storage_idx_t *points, + const storage_idx_t *nearests) +{ + + std::vector locks(ntotal); + for(int i = 0; i < ntotal; i++) + omp_init_lock(&locks[i]); + +#pragma omp parallel + { + VisitedTable vt (ntotal); + + DistanceComputer *dis = storage_distance_computer(storage); + ScopeDeleter1 del(dis); + float vec[storage->d]; + +#pragma omp for schedule(dynamic) + for (int i = 0; i < n; i++) { + storage_idx_t pt_id = points[i]; + storage_idx_t nearest = nearests[i]; + storage->reconstruct (pt_id, vec); + dis->set_query (vec); + + hnsw.add_links_starting_from(*dis, pt_id, + nearest, (*dis)(nearest), + 0, locks.data(), vt); + + if (verbose && i % 10000 == 0) { + printf(" %d / %d\r", i, n); + fflush(stdout); + } + } + } + if (verbose) { + printf("\n"); + } + + for(int i = 0; i < ntotal; i++) + omp_destroy_lock(&locks[i]); +} + +void IndexHNSW::reorder_links() +{ + int M = hnsw.nb_neighbors(0); + +#pragma omp parallel + { + std::vector distances (M); + std::vector order (M); + std::vector tmp (M); + DistanceComputer *dis = storage_distance_computer(storage); + ScopeDeleter1 del(dis); + +#pragma omp for + for(storage_idx_t i = 0; i < ntotal; i++) { + + size_t begin, end; + hnsw.neighbor_range(i, 0, &begin, &end); + + for (size_t j = begin; j < end; j++) { + storage_idx_t nj = hnsw.neighbors[j]; + if (nj < 0) { + end = j; + break; + } + distances[j - begin] = dis->symmetric_dis(i, nj); + tmp [j - begin] = nj; + } + + fvec_argsort (end - begin, distances.data(), order.data()); + for (size_t j = begin; j < end; j++) { + hnsw.neighbors[j] = tmp[order[j - begin]]; + } + } + + } +} + + +void IndexHNSW::link_singletons() +{ + printf("search for singletons\n"); + + std::vector seen(ntotal); + + for (size_t i = 0; i < ntotal; i++) { + size_t begin, end; + hnsw.neighbor_range(i, 0, &begin, &end); + for (size_t j = begin; j < end; j++) { + storage_idx_t ni = hnsw.neighbors[j]; + if (ni >= 0) seen[ni] = true; + } + } + + int n_sing = 0, n_sing_l1 = 0; + std::vector singletons; + for (storage_idx_t i = 0; i < ntotal; i++) { + if (!seen[i]) { + singletons.push_back(i); + n_sing++; + if (hnsw.levels[i] > 1) + n_sing_l1++; + } + } + + printf(" Found %d / %ld singletons (%d appear in a level above)\n", + n_sing, ntotal, n_sing_l1); + + std::vectorrecons(singletons.size() * d); + for (int i = 0; i < singletons.size(); i++) { + + FAISS_ASSERT(!"not implemented"); + + } + + +} + + +/************************************************************** + * ReconstructFromNeighbors implementation + **************************************************************/ + + +ReconstructFromNeighbors::ReconstructFromNeighbors( + const IndexHNSW & index, size_t k, size_t nsq): + index(index), k(k), nsq(nsq) { + M = index.hnsw.nb_neighbors(0); + FAISS_ASSERT(k <= 256); + code_size = k == 1 ? 0 : nsq; + ntotal = 0; + d = index.d; + FAISS_ASSERT(d % nsq == 0); + dsub = d / nsq; + k_reorder = -1; +} + +void ReconstructFromNeighbors::reconstruct(storage_idx_t i, float *x, float *tmp) const +{ + + + const HNSW & hnsw = index.hnsw; + size_t begin, end; + hnsw.neighbor_range(i, 0, &begin, &end); + + if (k == 1 || nsq == 1) { + const float * beta; + if (k == 1) { + beta = codebook.data(); + } else { + int idx = codes[i]; + beta = codebook.data() + idx * (M + 1); + } + + float w0 = beta[0]; // weight of image itself + index.storage->reconstruct(i, tmp); + + for (int l = 0; l < d; l++) + x[l] = w0 * tmp[l]; + + for (size_t j = begin; j < end; j++) { + + storage_idx_t ji = hnsw.neighbors[j]; + if (ji < 0) ji = i; + float w = beta[j - begin + 1]; + index.storage->reconstruct(ji, tmp); + for (int l = 0; l < d; l++) + x[l] += w * tmp[l]; + } + } else if (nsq == 2) { + int idx0 = codes[2 * i]; + int idx1 = codes[2 * i + 1]; + + const float *beta0 = codebook.data() + idx0 * (M + 1); + const float *beta1 = codebook.data() + (idx1 + k) * (M + 1); + + index.storage->reconstruct(i, tmp); + + float w0; + + w0 = beta0[0]; + for (int l = 0; l < dsub; l++) + x[l] = w0 * tmp[l]; + + w0 = beta1[0]; + for (int l = dsub; l < d; l++) + x[l] = w0 * tmp[l]; + + for (size_t j = begin; j < end; j++) { + storage_idx_t ji = hnsw.neighbors[j]; + if (ji < 0) ji = i; + index.storage->reconstruct(ji, tmp); + float w; + w = beta0[j - begin + 1]; + for (int l = 0; l < dsub; l++) + x[l] += w * tmp[l]; + + w = beta1[j - begin + 1]; + for (int l = dsub; l < d; l++) + x[l] += w * tmp[l]; + } + } else { + const float *betas[nsq]; + { + const float *b = codebook.data(); + const uint8_t *c = &codes[i * code_size]; + for (int sq = 0; sq < nsq; sq++) { + betas[sq] = b + (*c++) * (M + 1); + b += (M + 1) * k; + } + } + + index.storage->reconstruct(i, tmp); + { + int d0 = 0; + for (int sq = 0; sq < nsq; sq++) { + float w = *(betas[sq]++); + int d1 = d0 + dsub; + for (int l = d0; l < d1; l++) { + x[l] = w * tmp[l]; + } + d0 = d1; + } + } + + for (size_t j = begin; j < end; j++) { + storage_idx_t ji = hnsw.neighbors[j]; + if (ji < 0) ji = i; + + index.storage->reconstruct(ji, tmp); + int d0 = 0; + for (int sq = 0; sq < nsq; sq++) { + float w = *(betas[sq]++); + int d1 = d0 + dsub; + for (int l = d0; l < d1; l++) { + x[l] += w * tmp[l]; + } + d0 = d1; + } + } + } +} + +void ReconstructFromNeighbors::reconstruct_n(storage_idx_t n0, + storage_idx_t ni, + float *x) const +{ +#pragma omp parallel + { + std::vector tmp(index.d); +#pragma omp for + for (storage_idx_t i = 0; i < ni; i++) { + reconstruct(n0 + i, x + i * index.d, tmp.data()); + } + } +} + +size_t ReconstructFromNeighbors::compute_distances( + size_t n, const idx_t *shortlist, + const float *query, float *distances) const +{ + std::vector tmp(2 * index.d); + size_t ncomp = 0; + for (int i = 0; i < n; i++) { + if (shortlist[i] < 0) break; + reconstruct(shortlist[i], tmp.data(), tmp.data() + index.d); + distances[i] = fvec_L2sqr(query, tmp.data(), index.d); + ncomp++; + } + return ncomp; +} + +void ReconstructFromNeighbors::get_neighbor_table(storage_idx_t i, float *tmp1) const +{ + const HNSW & hnsw = index.hnsw; + size_t begin, end; + hnsw.neighbor_range(i, 0, &begin, &end); + size_t d = index.d; + + index.storage->reconstruct(i, tmp1); + + for (size_t j = begin; j < end; j++) { + storage_idx_t ji = hnsw.neighbors[j]; + if (ji < 0) ji = i; + index.storage->reconstruct(ji, tmp1 + (j - begin + 1) * d); + } + +} + + +/// called by add_codes +void ReconstructFromNeighbors::estimate_code( + const float *x, storage_idx_t i, uint8_t *code) const +{ + + // fill in tmp table with the neighbor values + float *tmp1 = new float[d * (M + 1) + (d * k)]; + float *tmp2 = tmp1 + d * (M + 1); + ScopeDeleter del(tmp1); + + // collect coordinates of base + get_neighbor_table (i, tmp1); + + for (size_t sq = 0; sq < nsq; sq++) { + int d0 = sq * dsub; + + { + FINTEGER ki = k, di = d, m1 = M + 1; + FINTEGER dsubi = dsub; + float zero = 0, one = 1; + + sgemm_ ("N", "N", &dsubi, &ki, &m1, &one, + tmp1 + d0, &di, + codebook.data() + sq * (m1 * k), &m1, + &zero, tmp2, &dsubi); + } + + float min = HUGE_VAL; + int argmin = -1; + for (size_t j = 0; j < k; j++) { + float dis = fvec_L2sqr(x + d0, tmp2 + j * dsub, dsub); + if (dis < min) { + min = dis; + argmin = j; + } + } + code[sq] = argmin; + } + +} + +void ReconstructFromNeighbors::add_codes(size_t n, const float *x) +{ + if (k == 1) { // nothing to encode + ntotal += n; + return; + } + codes.resize(codes.size() + code_size * n); +#pragma omp parallel for + for (int i = 0; i < n; i++) { + estimate_code(x + i * index.d, ntotal + i, + codes.data() + (ntotal + i) * code_size); + } + ntotal += n; + FAISS_ASSERT (codes.size() == ntotal * code_size); +} + + +/************************************************************** + * IndexHNSWFlat implementation + **************************************************************/ + + +IndexHNSWFlat::IndexHNSWFlat() +{ + is_trained = true; +} + +IndexHNSWFlat::IndexHNSWFlat(int d, int M, MetricType metric): + IndexHNSW(new IndexFlat(d, metric), M) +{ + own_fields = true; + is_trained = true; +} + + +/************************************************************** + * IndexHNSWPQ implementation + **************************************************************/ + + +IndexHNSWPQ::IndexHNSWPQ() {} + +IndexHNSWPQ::IndexHNSWPQ(int d, int pq_m, int M): + IndexHNSW(new IndexPQ(d, pq_m, 8), M) +{ + own_fields = true; + is_trained = false; +} + +void IndexHNSWPQ::train(idx_t n, const float* x) +{ + IndexHNSW::train (n, x); + (dynamic_cast (storage))->pq.compute_sdc_table(); +} + + +/************************************************************** + * IndexHNSWSQ implementation + **************************************************************/ + + +IndexHNSWSQ::IndexHNSWSQ(int d, QuantizerType qtype, int M, + MetricType metric): + IndexHNSW (new IndexScalarQuantizer (d, qtype, metric), M) +{ + is_trained = false; + own_fields = true; +} + +IndexHNSWSQ::IndexHNSWSQ() {} + + +/************************************************************** + * IndexHNSW2Level implementation + **************************************************************/ + + +IndexHNSW2Level::IndexHNSW2Level(Index *quantizer, size_t nlist, int m_pq, int M): + IndexHNSW (new Index2Layer (quantizer, nlist, m_pq), M) +{ + own_fields = true; + is_trained = false; +} + +IndexHNSW2Level::IndexHNSW2Level() {} + + +namespace { + + +// same as search_from_candidates but uses v +// visno -> is in result list +// visno + 1 -> in result list + in candidates +int search_from_candidates_2(const HNSW & hnsw, + DistanceComputer & qdis, int k, + idx_t *I, float * D, + MinimaxHeap &candidates, + VisitedTable &vt, + int level, int nres_in = 0) +{ + int nres = nres_in; + int ndis = 0; + for (int i = 0; i < candidates.size(); i++) { + idx_t v1 = candidates.ids[i]; + FAISS_ASSERT(v1 >= 0); + vt.visited[v1] = vt.visno + 1; + } + + int nstep = 0; + + while (candidates.size() > 0) { + float d0 = 0; + int v0 = candidates.pop_min(&d0); + + size_t begin, end; + hnsw.neighbor_range(v0, level, &begin, &end); + + for (size_t j = begin; j < end; j++) { + int v1 = hnsw.neighbors[j]; + if (v1 < 0) break; + if (vt.visited[v1] == vt.visno + 1) { + // nothing to do + } else { + ndis++; + float d = qdis(v1); + candidates.push(v1, d); + + // never seen before --> add to heap + if (vt.visited[v1] < vt.visno) { + if (nres < k) { + faiss::maxheap_push (++nres, D, I, d, v1); + } else if (d < D[0]) { + faiss::maxheap_pop (nres--, D, I); + faiss::maxheap_push (++nres, D, I, d, v1); + } + } + vt.visited[v1] = vt.visno + 1; + } + } + + nstep++; + if (nstep > hnsw.efSearch) { + break; + } + } + + if (level == 0) { +#pragma omp critical + { + hnsw_stats.n1 ++; + if (candidates.size() == 0) + hnsw_stats.n2 ++; + } + } + + + return nres; +} + + +} // namespace + +void IndexHNSW2Level::search (idx_t n, const float *x, idx_t k, + float *distances, idx_t *labels, ConcurrentBitsetPtr bitset) const +{ + if (dynamic_cast(storage)) { + IndexHNSW::search (n, x, k, distances, labels); + + } else { // "mixed" search + + const IndexIVFPQ *index_ivfpq = + dynamic_cast(storage); + + int nprobe = index_ivfpq->nprobe; + + std::unique_ptr coarse_assign(new idx_t[n * nprobe]); + std::unique_ptr coarse_dis(new float[n * nprobe]); + + index_ivfpq->quantizer->search (n, x, nprobe, coarse_dis.get(), + coarse_assign.get()); + + index_ivfpq->search_preassigned (n, x, k, coarse_assign.get(), + coarse_dis.get(), distances, labels, + false); + +#pragma omp parallel + { + VisitedTable vt (ntotal); + DistanceComputer *dis = storage_distance_computer(storage); + ScopeDeleter1 del(dis); + + int candidates_size = hnsw.upper_beam; + MinimaxHeap candidates(candidates_size); + +#pragma omp for + for(idx_t i = 0; i < n; i++) { + idx_t * idxi = labels + i * k; + float * simi = distances + i * k; + dis->set_query(x + i * d); + + // mark all inverted list elements as visited + + for (int j = 0; j < nprobe; j++) { + idx_t key = coarse_assign[j + i * nprobe]; + if (key < 0) break; + size_t list_length = index_ivfpq->get_list_size (key); + const idx_t * ids = index_ivfpq->invlists->get_ids (key); + + for (int jj = 0; jj < list_length; jj++) { + vt.set (ids[jj]); + } + } + + candidates.clear(); + // copy the upper_beam elements to candidates list + + int search_policy = 2; + + if (search_policy == 1) { + + for (int j = 0 ; j < hnsw.upper_beam && j < k; j++) { + if (idxi[j] < 0) break; + candidates.push (idxi[j], simi[j]); + // search_from_candidates adds them back + idxi[j] = -1; + simi[j] = HUGE_VAL; + } + + // reorder from sorted to heap + maxheap_heapify (k, simi, idxi, simi, idxi, k); + + hnsw.search_from_candidates( + *dis, k, idxi, simi, + candidates, vt, 0, k + ); + + vt.advance(); + + } else if (search_policy == 2) { + + for (int j = 0 ; j < hnsw.upper_beam && j < k; j++) { + if (idxi[j] < 0) break; + candidates.push (idxi[j], simi[j]); + } + + // reorder from sorted to heap + maxheap_heapify (k, simi, idxi, simi, idxi, k); + + search_from_candidates_2 ( + hnsw, *dis, k, idxi, simi, + candidates, vt, 0, k); + vt.advance (); + vt.advance (); + + } + + maxheap_reorder (k, simi, idxi); + } + } + } + + +} + + +void IndexHNSW2Level::flip_to_ivf () +{ + Index2Layer *storage2l = + dynamic_cast(storage); + + FAISS_THROW_IF_NOT (storage2l); + + IndexIVFPQ * index_ivfpq = + new IndexIVFPQ (storage2l->q1.quantizer, + d, storage2l->q1.nlist, + storage2l->pq.M, 8); + index_ivfpq->pq = storage2l->pq; + index_ivfpq->is_trained = storage2l->is_trained; + index_ivfpq->precompute_table(); + index_ivfpq->own_fields = storage2l->q1.own_fields; + storage2l->transfer_to_IVFPQ(*index_ivfpq); + index_ivfpq->make_direct_map (true); + + storage = index_ivfpq; + delete storage2l; + +} + + +} // namespace faiss diff --git a/core/src/index/thirdparty/faiss/IndexHNSW.h b/core/src/index/thirdparty/faiss/IndexHNSW.h new file mode 100644 index 0000000000..a8cb10512f --- /dev/null +++ b/core/src/index/thirdparty/faiss/IndexHNSW.h @@ -0,0 +1,171 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +#pragma once + +#include + +#include +#include +#include +#include +#include + + +namespace faiss { + +struct IndexHNSW; + +struct ReconstructFromNeighbors { + typedef Index::idx_t idx_t; + typedef HNSW::storage_idx_t storage_idx_t; + + const IndexHNSW & index; + size_t M; // number of neighbors + size_t k; // number of codebook entries + size_t nsq; // number of subvectors + size_t code_size; + int k_reorder; // nb to reorder. -1 = all + + std::vector codebook; // size nsq * k * (M + 1) + + std::vector codes; // size ntotal * code_size + size_t ntotal; + size_t d, dsub; // derived values + + explicit ReconstructFromNeighbors(const IndexHNSW& index, + size_t k=256, size_t nsq=1); + + /// codes must be added in the correct order and the IndexHNSW + /// must be populated and sorted + void add_codes(size_t n, const float *x); + + size_t compute_distances(size_t n, const idx_t *shortlist, + const float *query, float *distances) const; + + /// called by add_codes + void estimate_code(const float *x, storage_idx_t i, uint8_t *code) const; + + /// called by compute_distances + void reconstruct(storage_idx_t i, float *x, float *tmp) const; + + void reconstruct_n(storage_idx_t n0, storage_idx_t ni, float *x) const; + + /// get the M+1 -by-d table for neighbor coordinates for vector i + void get_neighbor_table(storage_idx_t i, float *out) const; + +}; + + +/** The HNSW index is a normal random-access index with a HNSW + * link structure built on top */ + +struct IndexHNSW : Index { + + typedef HNSW::storage_idx_t storage_idx_t; + + // the link strcuture + HNSW hnsw; + + // the sequential storage + bool own_fields; + Index *storage; + + ReconstructFromNeighbors *reconstruct_from_neighbors; + + explicit IndexHNSW (int d = 0, int M = 32, MetricType metric = METRIC_L2); + explicit IndexHNSW (Index *storage, int M = 32); + + ~IndexHNSW() override; + + void add(idx_t n, const float *x) override; + + /// Trains the storage if needed + void train(idx_t n, const float* x) override; + + /// entry point for search + void search (idx_t n, const float *x, idx_t k, + float *distances, idx_t *labels, + ConcurrentBitsetPtr bitset = nullptr) const override; + + void reconstruct(idx_t key, float* recons) const override; + + void reset () override; + + void shrink_level_0_neighbors(int size); + + /** Perform search only on level 0, given the starting points for + * each vertex. + * + * @param search_type 1:perform one search per nprobe, 2: enqueue + * all entry points + */ + void search_level_0(idx_t n, const float *x, idx_t k, + const storage_idx_t *nearest, const float *nearest_d, + float *distances, idx_t *labels, int nprobe = 1, + int search_type = 1) const; + + /// alternative graph building + void init_level_0_from_knngraph( + int k, const float *D, const idx_t *I); + + /// alternative graph building + void init_level_0_from_entry_points( + int npt, const storage_idx_t *points, + const storage_idx_t *nearests); + + // reorder links from nearest to farthest + void reorder_links(); + + void link_singletons(); +}; + + +/** Flat index topped with with a HNSW structure to access elements + * more efficiently. + */ + +struct IndexHNSWFlat : IndexHNSW { + IndexHNSWFlat(); + IndexHNSWFlat(int d, int M, MetricType metric = METRIC_L2); +}; + +/** PQ index topped with with a HNSW structure to access elements + * more efficiently. + */ +struct IndexHNSWPQ : IndexHNSW { + IndexHNSWPQ(); + IndexHNSWPQ(int d, int pq_m, int M); + void train(idx_t n, const float* x) override; +}; + +/** SQ index topped with with a HNSW structure to access elements + * more efficiently. + */ +struct IndexHNSWSQ : IndexHNSW { + IndexHNSWSQ(); + IndexHNSWSQ(int d, QuantizerType qtype, int M, MetricType metric = METRIC_L2); +}; + +/** 2-level code structure with fast random access + */ +struct IndexHNSW2Level : IndexHNSW { + IndexHNSW2Level(); + IndexHNSW2Level(Index *quantizer, size_t nlist, int m_pq, int M); + + void flip_to_ivf(); + + /// entry point for search + void search (idx_t n, const float *x, idx_t k, + float *distances, idx_t *labels, + ConcurrentBitsetPtr bitset = nullptr) const override; +}; + + +} // namespace faiss diff --git a/core/src/index/thirdparty/faiss/IndexIVF.cpp b/core/src/index/thirdparty/faiss/IndexIVF.cpp new file mode 100644 index 0000000000..145a7e33a5 --- /dev/null +++ b/core/src/index/thirdparty/faiss/IndexIVF.cpp @@ -0,0 +1,1270 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +#include + + +#include + +#include +#include +#include + +#include +#include + +#include +#include +#include + +namespace faiss { + +using ScopedIds = InvertedLists::ScopedIds; +using ScopedCodes = InvertedLists::ScopedCodes; + +/***************************************** + * Level1Quantizer implementation + ******************************************/ + + +Level1Quantizer::Level1Quantizer (Index * quantizer, size_t nlist): + quantizer (quantizer), + nlist (nlist), + quantizer_trains_alone (0), + own_fields (false), + clustering_index (nullptr) +{ + // here we set a low # iterations because this is typically used + // for large clusterings (nb this is not used for the MultiIndex, + // for which quantizer_trains_alone = true) + cp.niter = 10; +} + +Level1Quantizer::Level1Quantizer (): + quantizer (nullptr), + nlist (0), + quantizer_trains_alone (0), own_fields (false), + clustering_index (nullptr) +{} + +Level1Quantizer::~Level1Quantizer () +{ + if (own_fields) { + if(quantizer == quantizer_backup) { + if(quantizer != nullptr) { + delete quantizer; + } + } else { + if(quantizer != nullptr) { + delete quantizer; + } + + if(quantizer_backup != nullptr) { + delete quantizer_backup; + } + } + quantizer = nullptr; + quantizer_backup = nullptr; + } +} + +void Level1Quantizer::train_q1 (size_t n, const float *x, bool verbose, MetricType metric_type) +{ + size_t d = quantizer->d; + if (quantizer->is_trained && (quantizer->ntotal == nlist)) { + if (verbose) + printf ("IVF quantizer does not need training.\n"); + } else if (quantizer_trains_alone == 1) { + if (verbose) + printf ("IVF quantizer trains alone...\n"); + quantizer->train (n, x); + quantizer->verbose = verbose; + FAISS_THROW_IF_NOT_MSG (quantizer->ntotal == nlist, + "nlist not consistent with quantizer size"); + } else if (quantizer_trains_alone == 0) { + if (verbose) + printf ("Training level-1 quantizer on %ld vectors in %ldD\n", + n, d); + + Clustering clus (d, nlist, cp); + quantizer->reset(); + if (clustering_index) { + clus.train (n, x, *clustering_index); + quantizer->add (nlist, clus.centroids.data()); + } else { + clus.train (n, x, *quantizer); + } + quantizer->is_trained = true; + } else if (quantizer_trains_alone == 2) { + if (verbose) + printf ( + "Training L2 quantizer on %ld vectors in %ldD%s\n", + n, d, + clustering_index ? "(user provided index)" : ""); + FAISS_THROW_IF_NOT (metric_type == METRIC_L2); + Clustering clus (d, nlist, cp); + if (!clustering_index) { + IndexFlatL2 assigner (d); + clus.train(n, x, assigner); + } else { + clus.train(n, x, *clustering_index); + } + if (verbose) + printf ("Adding centroids to quantizer\n"); + quantizer->add (nlist, clus.centroids.data()); + } +} + +size_t Level1Quantizer::coarse_code_size () const +{ + size_t nl = nlist - 1; + size_t nbyte = 0; + while (nl > 0) { + nbyte ++; + nl >>= 8; + } + return nbyte; +} + +void Level1Quantizer::encode_listno (Index::idx_t list_no, uint8_t *code) const +{ + // little endian + size_t nl = nlist - 1; + while (nl > 0) { + *code++ = list_no & 0xff; + list_no >>= 8; + nl >>= 8; + } +} + +Index::idx_t Level1Quantizer::decode_listno (const uint8_t *code) const +{ + size_t nl = nlist - 1; + int64_t list_no = 0; + int nbit = 0; + while (nl > 0) { + list_no |= int64_t(*code++) << nbit; + nbit += 8; + nl >>= 8; + } + FAISS_THROW_IF_NOT (list_no >= 0 && list_no < nlist); + return list_no; +} + + + +/***************************************** + * IndexIVF implementation + ******************************************/ + + +IndexIVF::IndexIVF (Index * quantizer, size_t d, + size_t nlist, size_t code_size, + MetricType metric): + Index (d, metric), + Level1Quantizer (quantizer, nlist), + invlists (new ArrayInvertedLists (nlist, code_size)), + own_invlists (true), + code_size (code_size), + nprobe (1), + max_codes (0), + parallel_mode (0) +{ + FAISS_THROW_IF_NOT (d == quantizer->d); + is_trained = quantizer->is_trained && (quantizer->ntotal == nlist); + // Spherical by default if the metric is inner_product + if (metric_type == METRIC_INNER_PRODUCT) { + cp.spherical = true; + } + +} + +IndexIVF::IndexIVF (): + invlists (nullptr), own_invlists (false), + code_size (0), + nprobe (1), max_codes (0), parallel_mode (0) +{} + +void IndexIVF::add (idx_t n, const float * x) +{ + add_with_ids (n, x, nullptr); +} + +void IndexIVF::add_without_codes (idx_t n, const float * x) +{ + add_with_ids_without_codes (n, x, nullptr); +} + +void IndexIVF::add_with_ids_without_codes (idx_t n, const float * x, const idx_t *xids) +{ + // will be overriden + FAISS_THROW_MSG ("add_with_ids_without_codes not implemented for this type of index"); +} + +void IndexIVF::add_with_ids (idx_t n, const float * x, const idx_t *xids) +{ + // do some blocking to avoid excessive allocs + idx_t bs = 65536; + if (n > bs) { + for (idx_t i0 = 0; i0 < n; i0 += bs) { + idx_t i1 = std::min (n, i0 + bs); + if (verbose) { + printf(" IndexIVF::add_with_ids %ld:%ld\n", i0, i1); + } + add_with_ids (i1 - i0, x + i0 * d, + xids ? xids + i0 : nullptr); + } + return; + } + + FAISS_THROW_IF_NOT (is_trained); + direct_map.check_can_add (xids); + + std::unique_ptr idx(new idx_t[n]); + quantizer->assign (n, x, idx.get()); + size_t nadd = 0, nminus1 = 0; + + for (size_t i = 0; i < n; i++) { + if (idx[i] < 0) nminus1++; + } + + std::unique_ptr flat_codes(new uint8_t [n * code_size]); + encode_vectors (n, x, idx.get(), flat_codes.get()); + + DirectMapAdd dm_adder(direct_map, n, xids); + +#pragma omp parallel reduction(+: nadd) + { + int nt = omp_get_num_threads(); + int rank = omp_get_thread_num(); + + // each thread takes care of a subset of lists + for (size_t i = 0; i < n; i++) { + idx_t list_no = idx [i]; + if (list_no >= 0 && list_no % nt == rank) { + idx_t id = xids ? xids[i] : ntotal + i; + size_t ofs = invlists->add_entry ( + list_no, id, + flat_codes.get() + i * code_size + ); + + dm_adder.add (i, list_no, ofs); + + nadd++; + } else if (rank == 0 && list_no == -1) { + dm_adder.add (i, -1, 0); + } + } + } + + + if (verbose) { + printf(" added %ld / %ld vectors (%ld -1s)\n", nadd, n, nminus1); + } + + ntotal += n; +} + +void IndexIVF::to_readonly() { + if (is_readonly()) return; + auto readonly_lists = this->invlists->to_readonly(); + if (!readonly_lists) return; + this->replace_invlists(readonly_lists, true); +} + +void IndexIVF::to_readonly_without_codes() { + if (is_readonly()) return; + auto readonly_lists = this->invlists->to_readonly_without_codes(); + if (!readonly_lists) return; + this->replace_invlists(readonly_lists, true); +} + +bool IndexIVF::is_readonly() const { + return this->invlists->is_readonly(); +} + +void IndexIVF::backup_quantizer() { + this->quantizer_backup = quantizer; +} + +void IndexIVF::restore_quantizer() { + if(this->quantizer_backup != nullptr) { + quantizer = this->quantizer_backup; + } +} + +void IndexIVF::make_direct_map (bool b) +{ + if (b) { + direct_map.set_type (DirectMap::Array, invlists, ntotal); + } else { + direct_map.set_type (DirectMap::NoMap, invlists, ntotal); + } +} + +void IndexIVF::set_direct_map_type (DirectMap::Type type) +{ + direct_map.set_type (type, invlists, ntotal); +} + + +void IndexIVF::search (idx_t n, const float *x, idx_t k, + float *distances, idx_t *labels, + ConcurrentBitsetPtr bitset) const +{ + std::unique_ptr idx(new idx_t[n * nprobe]); + std::unique_ptr coarse_dis(new float[n * nprobe]); + + double t0 = getmillisecs(); + quantizer->search (n, x, nprobe, coarse_dis.get(), idx.get()); + indexIVF_stats.quantization_time += getmillisecs() - t0; + + t0 = getmillisecs(); + invlists->prefetch_lists (idx.get(), n * nprobe); + + search_preassigned (n, x, k, idx.get(), coarse_dis.get(), + distances, labels, false, nullptr, bitset); + indexIVF_stats.search_time += getmillisecs() - t0; + + // nprobe logging + if (LOG_DEBUG_) { + auto ids = idx.get(); + for (size_t i = 0; i < n; i++) { + std::stringstream ss; + ss << "Query #" << i << ", nprobe list: "; + for (size_t j = 0; j < nprobe; j++) { + if (j != 0) { + ss << ","; + } + ss << ids[i * nprobe + j]; + } + (*LOG_DEBUG_)(ss.str()); + } + } +} + +void IndexIVF::search_without_codes (idx_t n, const float *x, + const uint8_t *arranged_codes, std::vector prefix_sum, + bool is_sq8, idx_t k, float *distances, idx_t *labels, + ConcurrentBitsetPtr bitset) +{ + std::unique_ptr idx(new idx_t[n * nprobe]); + std::unique_ptr coarse_dis(new float[n * nprobe]); + + double t0 = getmillisecs(); + quantizer->search (n, x, nprobe, coarse_dis.get(), idx.get()); + indexIVF_stats.quantization_time += getmillisecs() - t0; + + t0 = getmillisecs(); + invlists->prefetch_lists (idx.get(), n * nprobe); + + search_preassigned_without_codes (n, x, arranged_codes, prefix_sum, is_sq8, k, idx.get(), coarse_dis.get(), + distances, labels, false, nullptr, bitset); + indexIVF_stats.search_time += getmillisecs() - t0; + + // nprobe loggingss + if (LOG_DEBUG_) { + auto ids = idx.get(); + for (size_t i = 0; i < n; i++) { + std::stringstream ss; + ss << "Query #" << i << ", nprobe list: "; + for (size_t j = 0; j < nprobe; j++) { + if (j != 0) { + ss << ","; + } + ss << ids[i * nprobe + j]; + } + (*LOG_DEBUG_)(ss.str()); + } + } +} + +#if 0 +void IndexIVF::get_vector_by_id (idx_t n, const idx_t *xid, float *x, ConcurrentBitsetPtr bitset) { + make_direct_map(true); + + /* only get vector by 1 id */ + FAISS_ASSERT(n == 1); + if (!bitset || !bitset->test(xid[0])) { + reconstruct(xid[0], x + 0 * d); + } else { + memset(x, UINT8_MAX, d * sizeof(float)); + } +} + +void IndexIVF::search_by_id (idx_t n, const idx_t *xid, idx_t k, float *distances, idx_t *labels, + ConcurrentBitsetPtr bitset) { + make_direct_map(true); + + auto x = new float[n * d]; + for (idx_t i = 0; i < n; ++i) { + reconstruct(xid[i], x + i * d); + } + + search(n, x, k, distances, labels, bitset); + delete []x; +} +#endif + +void IndexIVF::search_preassigned (idx_t n, const float *x, idx_t k, + const idx_t *keys, + const float *coarse_dis , + float *distances, idx_t *labels, + bool store_pairs, + const IVFSearchParameters *params, + ConcurrentBitsetPtr bitset) const +{ + long nprobe = params ? params->nprobe : this->nprobe; + long max_codes = params ? params->max_codes : this->max_codes; + + size_t nlistv = 0, ndis = 0, nheap = 0; + + using HeapForIP = CMin; + using HeapForL2 = CMax; + + bool interrupt = false; + + int pmode = this->parallel_mode & ~PARALLEL_MODE_NO_HEAP_INIT; + bool do_heap_init = !(this->parallel_mode & PARALLEL_MODE_NO_HEAP_INIT); + + // don't start parallel section if single query + bool do_parallel = + pmode == 0 ? n > 1 : + pmode == 1 ? nprobe > 1 : + nprobe * n > 1; + +#pragma omp parallel if(do_parallel) reduction(+: nlistv, ndis, nheap) + { + InvertedListScanner *scanner = get_InvertedListScanner(store_pairs); + ScopeDeleter1 del(scanner); + + /***************************************************** + * Depending on parallel_mode, there are two possible ways + * to organize the search. Here we define local functions + * that are in common between the two + ******************************************************/ + + // intialize + reorder a result heap + + auto init_result = [&](float *simi, idx_t *idxi) { + if (!do_heap_init) return; + if (metric_type == METRIC_INNER_PRODUCT) { + heap_heapify (k, simi, idxi); + } else { + heap_heapify (k, simi, idxi); + } + }; + + auto reorder_result = [&] (float *simi, idx_t *idxi) { + if (!do_heap_init) return; + if (metric_type == METRIC_INNER_PRODUCT) { + heap_reorder (k, simi, idxi); + } else { + heap_reorder (k, simi, idxi); + } + }; + + // single list scan using the current scanner (with query + // set porperly) and storing results in simi and idxi + auto scan_one_list = [&] (idx_t key, float coarse_dis_i, + float *simi, idx_t *idxi, + ConcurrentBitsetPtr bitset) { + + if (key < 0) { + // not enough centroids for multiprobe + return (size_t)0; + } + FAISS_THROW_IF_NOT_FMT (key < (idx_t) nlist, + "Invalid key=%ld nlist=%ld\n", + key, nlist); + + size_t list_size = invlists->list_size(key); + + // don't waste time on empty lists + if (list_size == 0) { + return (size_t)0; + } + + scanner->set_list (key, coarse_dis_i); + + nlistv++; + + InvertedLists::ScopedCodes scodes (invlists, key); + + std::unique_ptr sids; + const Index::idx_t * ids = nullptr; + + if (!store_pairs) { + sids.reset (new InvertedLists::ScopedIds (invlists, key)); + ids = sids->get(); + } + + nheap += scanner->scan_codes (list_size, scodes.get(), + ids, simi, idxi, k, bitset); + + return list_size; + }; + + /**************************************************** + * Actual loops, depending on parallel_mode + ****************************************************/ + + if (pmode == 0) { + +#pragma omp for + for (size_t i = 0; i < n; i++) { + + if (interrupt) { + continue; + } + + // loop over queries + scanner->set_query (x + i * d); + float * simi = distances + i * k; + idx_t * idxi = labels + i * k; + + init_result (simi, idxi); + + long nscan = 0; + + // loop over probes + for (size_t ik = 0; ik < nprobe; ik++) { + + nscan += scan_one_list ( + keys [i * nprobe + ik], + coarse_dis[i * nprobe + ik], + simi, idxi, bitset + ); + + if (max_codes && nscan >= max_codes) { + break; + } + } + + ndis += nscan; + reorder_result (simi, idxi); + + if (InterruptCallback::is_interrupted ()) { + interrupt = true; + } + + } // parallel for + } else if (pmode == 1) { + std::vector local_idx (k); + std::vector local_dis (k); + + for (size_t i = 0; i < n; i++) { + scanner->set_query (x + i * d); + init_result (local_dis.data(), local_idx.data()); + +#pragma omp for schedule(dynamic) + for (size_t ik = 0; ik < nprobe; ik++) { + ndis += scan_one_list + (keys [i * nprobe + ik], + coarse_dis[i * nprobe + ik], + local_dis.data(), local_idx.data(), bitset); + + // can't do the test on max_codes + } + // merge thread-local results + + float * simi = distances + i * k; + idx_t * idxi = labels + i * k; +#pragma omp single + init_result (simi, idxi); + +#pragma omp barrier +#pragma omp critical + { + if (metric_type == METRIC_INNER_PRODUCT) { + heap_addn + (k, simi, idxi, + local_dis.data(), local_idx.data(), k); + } else { + heap_addn + (k, simi, idxi, + local_dis.data(), local_idx.data(), k); + } + } +#pragma omp barrier +#pragma omp single + reorder_result (simi, idxi); + } + } else { + FAISS_THROW_FMT ("parallel_mode %d not supported\n", + pmode); + } + } // parallel section + + if (interrupt) { + FAISS_THROW_MSG ("computation interrupted"); + } + + indexIVF_stats.nq += n; + indexIVF_stats.nlist += nlistv; + indexIVF_stats.ndis += ndis; + indexIVF_stats.nheap_updates += nheap; + +} + + +void IndexIVF::search_preassigned_without_codes (idx_t n, const float *x, + const uint8_t *arranged_codes, + std::vector prefix_sum, + bool is_sq8, idx_t k, + const idx_t *keys, + const float *coarse_dis , + float *distances, idx_t *labels, + bool store_pairs, + const IVFSearchParameters *params, + ConcurrentBitsetPtr bitset) +{ + long nprobe = params ? params->nprobe : this->nprobe; + long max_codes = params ? params->max_codes : this->max_codes; + + size_t nlistv = 0, ndis = 0, nheap = 0; + + using HeapForIP = CMin; + using HeapForL2 = CMax; + + bool interrupt = false; + + int pmode = this->parallel_mode & ~PARALLEL_MODE_NO_HEAP_INIT; + bool do_heap_init = !(this->parallel_mode & PARALLEL_MODE_NO_HEAP_INIT); + + // don't start parallel section if single query + bool do_parallel = + pmode == 0 ? n > 1 : + pmode == 1 ? nprobe > 1 : + nprobe * n > 1; + +#pragma omp parallel if(do_parallel) reduction(+: nlistv, ndis, nheap) + { + InvertedListScanner *scanner = get_InvertedListScanner(store_pairs); + ScopeDeleter1 del(scanner); + + /***************************************************** + * Depending on parallel_mode, there are two possible ways + * to organize the search. Here we define local functions + * that are in common between the two + ******************************************************/ + + // intialize + reorder a result heap + + auto init_result = [&](float *simi, idx_t *idxi) { + if (!do_heap_init) return; + if (metric_type == METRIC_INNER_PRODUCT) { + heap_heapify (k, simi, idxi); + } else { + heap_heapify (k, simi, idxi); + } + }; + + auto reorder_result = [&] (float *simi, idx_t *idxi) { + if (!do_heap_init) return; + if (metric_type == METRIC_INNER_PRODUCT) { + heap_reorder (k, simi, idxi); + } else { + heap_reorder (k, simi, idxi); + } + }; + + // single list scan using the current scanner (with query + // set porperly) and storing results in simi and idxi + auto scan_one_list = [&] (idx_t key, float coarse_dis_i, const uint8_t *arranged_codes, + float *simi, idx_t *idxi, ConcurrentBitsetPtr bitset) { + + if (key < 0) { + // not enough centroids for multiprobe + return (size_t)0; + } + FAISS_THROW_IF_NOT_FMT (key < (idx_t) nlist, + "Invalid key=%ld nlist=%ld\n", + key, nlist); + + size_t list_size = invlists->list_size(key); + size_t offset = prefix_sum[key]; + + // don't waste time on empty lists + if (list_size == 0) { + return (size_t)0; + } + + scanner->set_list (key, coarse_dis_i); + + nlistv++; + + InvertedLists::ScopedCodes scodes (invlists, key, arranged_codes); + + std::unique_ptr sids; + const Index::idx_t * ids = nullptr; + + if (!store_pairs) { + sids.reset (new InvertedLists::ScopedIds (invlists, key)); + ids = sids->get(); + } + + size_t size = is_sq8 ? sizeof(uint8_t) : sizeof(float); + nheap += scanner->scan_codes (list_size, (const uint8_t *) (scodes.get() + d * offset * size), + ids, simi, idxi, k, bitset); + + return list_size; + }; + + /**************************************************** + * Actual loops, depending on parallel_mode + ****************************************************/ + + if (pmode == 0) { + +#pragma omp for + for (size_t i = 0; i < n; i++) { + + if (interrupt) { + continue; + } + + // loop over queries + scanner->set_query (x + i * d); + float * simi = distances + i * k; + idx_t * idxi = labels + i * k; + + init_result (simi, idxi); + + long nscan = 0; + + // loop over probes + for (size_t ik = 0; ik < nprobe; ik++) { + + nscan += scan_one_list ( + keys [i * nprobe + ik], + coarse_dis[i * nprobe + ik], + arranged_codes, + simi, idxi, bitset + ); + + if (max_codes && nscan >= max_codes) { + break; + } + } + + ndis += nscan; + reorder_result (simi, idxi); + + if (InterruptCallback::is_interrupted ()) { + interrupt = true; + } + + } // parallel for + } else if (pmode == 1) { + std::vector local_idx (k); + std::vector local_dis (k); + + for (size_t i = 0; i < n; i++) { + scanner->set_query (x + i * d); + init_result (local_dis.data(), local_idx.data()); + +#pragma omp for schedule(dynamic) + for (size_t ik = 0; ik < nprobe; ik++) { + ndis += scan_one_list + (keys [i * nprobe + ik], + coarse_dis[i * nprobe + ik], + arranged_codes, + local_dis.data(), local_idx.data(), bitset); + + // can't do the test on max_codes + } + // merge thread-local results + + float * simi = distances + i * k; + idx_t * idxi = labels + i * k; +#pragma omp single + init_result (simi, idxi); + +#pragma omp barrier +#pragma omp critical + { + if (metric_type == METRIC_INNER_PRODUCT) { + heap_addn + (k, simi, idxi, + local_dis.data(), local_idx.data(), k); + } else { + heap_addn + (k, simi, idxi, + local_dis.data(), local_idx.data(), k); + } + } +#pragma omp barrier +#pragma omp single + reorder_result (simi, idxi); + } + } else { + FAISS_THROW_FMT ("parallel_mode %d not supported\n", + pmode); + } + } // parallel section + + if (interrupt) { + FAISS_THROW_MSG ("computation interrupted"); + } + + indexIVF_stats.nq += n; + indexIVF_stats.nlist += nlistv; + indexIVF_stats.ndis += ndis; + indexIVF_stats.nheap_updates += nheap; + +} + +void IndexIVF::range_search (idx_t nx, const float *x, float radius, + RangeSearchResult *result, + ConcurrentBitsetPtr bitset) const +{ + std::unique_ptr keys (new idx_t[nx * nprobe]); + std::unique_ptr coarse_dis (new float[nx * nprobe]); + + double t0 = getmillisecs(); + quantizer->search (nx, x, nprobe, coarse_dis.get (), keys.get ()); + indexIVF_stats.quantization_time += getmillisecs() - t0; + + t0 = getmillisecs(); + invlists->prefetch_lists (keys.get(), nx * nprobe); + + range_search_preassigned (nx, x, radius, keys.get (), coarse_dis.get (), + result, bitset); + + indexIVF_stats.search_time += getmillisecs() - t0; +} + +void IndexIVF::range_search_preassigned ( + idx_t nx, const float *x, float radius, + const idx_t *keys, const float *coarse_dis, + RangeSearchResult *result, + ConcurrentBitsetPtr bitset) const +{ + + size_t nlistv = 0, ndis = 0; + bool store_pairs = false; + + std::vector all_pres (omp_get_max_threads()); + +#pragma omp parallel reduction(+: nlistv, ndis) + { + RangeSearchPartialResult pres(result); + std::unique_ptr scanner + (get_InvertedListScanner(store_pairs)); + FAISS_THROW_IF_NOT (scanner.get ()); + all_pres[omp_get_thread_num()] = &pres; + + // prepare the list scanning function + + auto scan_list_func = [&](size_t i, size_t ik, RangeQueryResult &qres) { + + idx_t key = keys[i * nprobe + ik]; /* select the list */ + if (key < 0) return; + FAISS_THROW_IF_NOT_FMT ( + key < (idx_t) nlist, + "Invalid key=%ld at ik=%ld nlist=%ld\n", + key, ik, nlist); + const size_t list_size = invlists->list_size(key); + + if (list_size == 0) return; + + InvertedLists::ScopedCodes scodes (invlists, key); + InvertedLists::ScopedIds ids (invlists, key); + + scanner->set_list (key, coarse_dis[i * nprobe + ik]); + nlistv++; + ndis += list_size; + scanner->scan_codes_range (list_size, scodes.get(), + ids.get(), radius, qres, bitset); + }; + + if (parallel_mode == 0) { + +#pragma omp for + for (size_t i = 0; i < nx; i++) { + scanner->set_query (x + i * d); + + RangeQueryResult & qres = pres.new_result (i); + + for (size_t ik = 0; ik < nprobe; ik++) { + scan_list_func (i, ik, qres); + } + + } + + } else if (parallel_mode == 1) { + + for (size_t i = 0; i < nx; i++) { + scanner->set_query (x + i * d); + + RangeQueryResult & qres = pres.new_result (i); + +#pragma omp for schedule(dynamic) + for (size_t ik = 0; ik < nprobe; ik++) { + scan_list_func (i, ik, qres); + } + } + } else if (parallel_mode == 2) { + std::vector all_qres (nx); + RangeQueryResult *qres = nullptr; + +#pragma omp for schedule(dynamic) + for (size_t iik = 0; iik < nx * nprobe; iik++) { + size_t i = iik / nprobe; + size_t ik = iik % nprobe; + if (qres == nullptr || qres->qno != i) { + FAISS_ASSERT (!qres || i > qres->qno); + qres = &pres.new_result (i); + scanner->set_query (x + i * d); + } + scan_list_func (i, ik, *qres); + } + } else { + FAISS_THROW_FMT ("parallel_mode %d not supported\n", parallel_mode); + } + if (parallel_mode == 0) { + pres.finalize (); + } else { +#pragma omp barrier +#pragma omp single + RangeSearchPartialResult::merge (all_pres, false); +#pragma omp barrier + + } + } + indexIVF_stats.nq += nx; + indexIVF_stats.nlist += nlistv; + indexIVF_stats.ndis += ndis; +} + + +InvertedListScanner *IndexIVF::get_InvertedListScanner ( + bool /*store_pairs*/) const +{ + return nullptr; +} + +void IndexIVF::reconstruct (idx_t key, float* recons) const +{ + idx_t lo = direct_map.get (key); + reconstruct_from_offset (lo_listno(lo), lo_offset(lo), recons); +} + + +void IndexIVF::reconstruct_n (idx_t i0, idx_t ni, float* recons) const +{ + FAISS_THROW_IF_NOT (ni == 0 || (i0 >= 0 && i0 + ni <= ntotal)); + + for (idx_t list_no = 0; list_no < nlist; list_no++) { + size_t list_size = invlists->list_size (list_no); + ScopedIds idlist (invlists, list_no); + + for (idx_t offset = 0; offset < list_size; offset++) { + idx_t id = idlist[offset]; + if (!(id >= i0 && id < i0 + ni)) { + continue; + } + + float* reconstructed = recons + (id - i0) * d; + reconstruct_from_offset (list_no, offset, reconstructed); + } + } +} + + +/* standalone codec interface */ +size_t IndexIVF::sa_code_size () const +{ + size_t coarse_size = coarse_code_size(); + return code_size + coarse_size; +} + +void IndexIVF::sa_encode (idx_t n, const float *x, + uint8_t *bytes) const +{ + FAISS_THROW_IF_NOT (is_trained); + std::unique_ptr idx (new int64_t [n]); + quantizer->assign (n, x, idx.get()); + encode_vectors (n, x, idx.get(), bytes, true); +} + + +void IndexIVF::search_and_reconstruct (idx_t n, const float *x, idx_t k, + float *distances, idx_t *labels, + float *recons) const +{ + idx_t * idx = new idx_t [n * nprobe]; + ScopeDeleter del (idx); + float * coarse_dis = new float [n * nprobe]; + ScopeDeleter del2 (coarse_dis); + + quantizer->search (n, x, nprobe, coarse_dis, idx); + + invlists->prefetch_lists (idx, n * nprobe); + + // search_preassigned() with `store_pairs` enabled to obtain the list_no + // and offset into `codes` for reconstruction + search_preassigned (n, x, k, idx, coarse_dis, + distances, labels, true /* store_pairs */); + for (idx_t i = 0; i < n; ++i) { + for (idx_t j = 0; j < k; ++j) { + idx_t ij = i * k + j; + idx_t key = labels[ij]; + float* reconstructed = recons + ij * d; + if (key < 0) { + // Fill with NaNs + memset(reconstructed, -1, sizeof(*reconstructed) * d); + } else { + int list_no = lo_listno (key); + int offset = lo_offset (key); + + // Update label to the actual id + labels[ij] = invlists->get_single_id (list_no, offset); + + reconstruct_from_offset (list_no, offset, reconstructed); + } + } + } +} + +void IndexIVF::reconstruct_from_offset( + int64_t /*list_no*/, + int64_t /*offset*/, + float* /*recons*/) const { + FAISS_THROW_MSG ("reconstruct_from_offset not implemented"); +} + +void IndexIVF::reset () +{ + direct_map.clear (); + invlists->reset (); + ntotal = 0; +} + + +size_t IndexIVF::remove_ids (const IDSelector & sel) +{ + size_t nremove = direct_map.remove_ids (sel, invlists); + ntotal -= nremove; + return nremove; +} + + +void IndexIVF::update_vectors (int n, const idx_t *new_ids, const float *x) +{ + + if (direct_map.type == DirectMap::Hashtable) { + // just remove then add + IDSelectorArray sel(n, new_ids); + size_t nremove = remove_ids (sel); + FAISS_THROW_IF_NOT_MSG (nremove == n, + "did not find all entries to remove"); + add_with_ids (n, x, new_ids); + return; + } + + FAISS_THROW_IF_NOT (direct_map.type == DirectMap::Array); + // here it is more tricky because we don't want to introduce holes + // in continuous range of ids + + FAISS_THROW_IF_NOT (is_trained); + std::vector assign (n); + quantizer->assign (n, x, assign.data()); + + std::vector flat_codes (n * code_size); + encode_vectors (n, x, assign.data(), flat_codes.data()); + + direct_map.update_codes (invlists, n, new_ids, assign.data(), flat_codes.data()); + +} + + + + +void IndexIVF::train (idx_t n, const float *x) +{ + if (verbose) + printf ("Training level-1 quantizer\n"); + + train_q1 (n, x, verbose, metric_type); + + if (verbose) + printf ("Training IVF residual\n"); + + train_residual (n, x); + is_trained = true; + +} + +void IndexIVF::train_residual(idx_t /*n*/, const float* /*x*/) { + if (verbose) + printf("IndexIVF: no residual training\n"); + // does nothing by default +} + + +void IndexIVF::check_compatible_for_merge (const IndexIVF &other) const +{ + // minimal sanity checks + FAISS_THROW_IF_NOT (other.d == d); + FAISS_THROW_IF_NOT (other.nlist == nlist); + FAISS_THROW_IF_NOT (other.code_size == code_size); + FAISS_THROW_IF_NOT_MSG (typeid (*this) == typeid (other), + "can only merge indexes of the same type"); + FAISS_THROW_IF_NOT_MSG (this->direct_map.no() && other.direct_map.no(), + "merge direct_map not implemented"); +} + + +void IndexIVF::merge_from (IndexIVF &other, idx_t add_id) +{ + check_compatible_for_merge (other); + + invlists->merge_from (other.invlists, add_id); + + ntotal += other.ntotal; + other.ntotal = 0; +} + + +void IndexIVF::replace_invlists (InvertedLists *il, bool own) +{ + if (own_invlists) { + delete invlists; + } + // FAISS_THROW_IF_NOT (ntotal == 0); + if (il) { + FAISS_THROW_IF_NOT (il->nlist == nlist && + il->code_size == code_size); + } + invlists = il; + own_invlists = own; +} + + +void IndexIVF::copy_subset_to (IndexIVF & other, int subset_type, + idx_t a1, idx_t a2) const +{ + + FAISS_THROW_IF_NOT (nlist == other.nlist); + FAISS_THROW_IF_NOT (code_size == other.code_size); + FAISS_THROW_IF_NOT (other.direct_map.no()); + FAISS_THROW_IF_NOT_FMT ( + subset_type == 0 || subset_type == 1 || subset_type == 2, + "subset type %d not implemented", subset_type); + + size_t accu_n = 0; + size_t accu_a1 = 0; + size_t accu_a2 = 0; + + InvertedLists *oivf = other.invlists; + + for (idx_t list_no = 0; list_no < nlist; list_no++) { + size_t n = invlists->list_size (list_no); + ScopedIds ids_in (invlists, list_no); + + if (subset_type == 0) { + for (idx_t i = 0; i < n; i++) { + idx_t id = ids_in[i]; + if (a1 <= id && id < a2) { + oivf->add_entry (list_no, + invlists->get_single_id (list_no, i), + ScopedCodes (invlists, list_no, i).get()); + other.ntotal++; + } + } + } else if (subset_type == 1) { + for (idx_t i = 0; i < n; i++) { + idx_t id = ids_in[i]; + if (id % a1 == a2) { + oivf->add_entry (list_no, + invlists->get_single_id (list_no, i), + ScopedCodes (invlists, list_no, i).get()); + other.ntotal++; + } + } + } else if (subset_type == 2) { + // see what is allocated to a1 and to a2 + size_t next_accu_n = accu_n + n; + size_t next_accu_a1 = next_accu_n * a1 / ntotal; + size_t i1 = next_accu_a1 - accu_a1; + size_t next_accu_a2 = next_accu_n * a2 / ntotal; + size_t i2 = next_accu_a2 - accu_a2; + + for (idx_t i = i1; i < i2; i++) { + oivf->add_entry (list_no, + invlists->get_single_id (list_no, i), + ScopedCodes (invlists, list_no, i).get()); + } + + other.ntotal += i2 - i1; + accu_a1 = next_accu_a1; + accu_a2 = next_accu_a2; + } + accu_n += n; + } + FAISS_ASSERT(accu_n == ntotal); + +} + +void +IndexIVF::dump() { + for (auto i = 0; i < invlists->nlist; ++ i) { + auto numVecs = invlists->list_size(i); + auto ids = invlists->get_ids(i); + auto codes = invlists->get_codes(i); + int code_size = invlists->code_size; + + + std::cout << "Bucket ID: " << i << ", with code size: " << code_size << ", vectors number: " << numVecs << std::endl; + if(code_size == 8) { + // int8 types + for (auto j=0; j < numVecs; ++j) { + std::cout << *(ids+j) << ": " << std::endl; + for(int k = 0; k < this->d; ++ k) { + printf("%u ", (uint8_t)(codes[j * d + k])); + } + std::cout << std::endl; + } + } + std::cout << "Bucket End." << std::endl; + } +} + + +IndexIVF::~IndexIVF() +{ + if (own_invlists) { + delete invlists; + } +} + + +void IndexIVFStats::reset() +{ + memset ((void*)this, 0, sizeof (*this)); +} + + +IndexIVFStats indexIVF_stats; + +void InvertedListScanner::scan_codes_range (size_t , + const uint8_t *, + const idx_t *, + float , + RangeQueryResult &, + ConcurrentBitsetPtr) const +{ + FAISS_THROW_MSG ("scan_codes_range not implemented"); +} + + + +} // namespace faiss diff --git a/core/src/index/thirdparty/faiss/IndexIVF.h b/core/src/index/thirdparty/faiss/IndexIVF.h new file mode 100644 index 0000000000..a7d2af1f8a --- /dev/null +++ b/core/src/index/thirdparty/faiss/IndexIVF.h @@ -0,0 +1,421 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +#ifndef FAISS_INDEX_IVF_H +#define FAISS_INDEX_IVF_H + + +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +namespace faiss { + + +/** Encapsulates a quantizer object for the IndexIVF + * + * The class isolates the fields that are independent of the storage + * of the lists (especially training) + */ +struct Level1Quantizer { + Index * quantizer = nullptr; ///< quantizer that maps vectors to inverted lists + Index * quantizer_backup = nullptr; ///< quantizer for backup + size_t nlist; ///< number of possible key values + + /** + * = 0: use the quantizer as index in a kmeans training + * = 1: just pass on the training set to the train() of the quantizer + * = 2: kmeans training on a flat index + add the centroids to the quantizer + */ + char quantizer_trains_alone; + bool own_fields; ///< whether object owns the quantizer + + ClusteringParameters cp; ///< to override default clustering params + Index *clustering_index; ///< to override index used during clustering + + /// Trains the quantizer and calls train_residual to train sub-quantizers + void train_q1 (size_t n, const float *x, bool verbose, + MetricType metric_type); + + + /// compute the number of bytes required to store list ids + size_t coarse_code_size () const; + void encode_listno (Index::idx_t list_no, uint8_t *code) const; + Index::idx_t decode_listno (const uint8_t *code) const; + + Level1Quantizer (Index * quantizer, size_t nlist); + + Level1Quantizer (); + + ~Level1Quantizer (); + +}; + + + +struct IVFSearchParameters { + size_t nprobe; ///< number of probes at query time + size_t max_codes; ///< max nb of codes to visit to do a query + virtual ~IVFSearchParameters () {} +}; + + + +struct InvertedListScanner; + +/** Index based on a inverted file (IVF) + * + * In the inverted file, the quantizer (an Index instance) provides a + * quantization index for each vector to be added. The quantization + * index maps to a list (aka inverted list or posting list), where the + * id of the vector is stored. + * + * The inverted list object is required only after trainng. If none is + * set externally, an ArrayInvertedLists is used automatically. + * + * At search time, the vector to be searched is also quantized, and + * only the list corresponding to the quantization index is + * searched. This speeds up the search by making it + * non-exhaustive. This can be relaxed using multi-probe search: a few + * (nprobe) quantization indices are selected and several inverted + * lists are visited. + * + * Sub-classes implement a post-filtering of the index that refines + * the distance estimation from the query to databse vectors. + */ +struct IndexIVF: Index, Level1Quantizer { + /// Acess to the actual data + InvertedLists *invlists; + bool own_invlists; + + size_t code_size; ///< code size per vector in bytes + + size_t nprobe; ///< number of probes at query time + size_t max_codes; ///< max nb of codes to visit to do a query + + /** Parallel mode determines how queries are parallelized with OpenMP + * + * 0 (default): parallelize over queries + * 1: parallelize over inverted lists + * 2: parallelize over both + * + * PARALLEL_MODE_NO_HEAP_INIT: binary or with the previous to + * prevent the heap to be initialized and finalized + */ + int parallel_mode; + const int PARALLEL_MODE_NO_HEAP_INIT = 1024; + + /** optional map that maps back ids to invlist entries. This + * enables reconstruct() */ + DirectMap direct_map; + + /** The Inverted file takes a quantizer (an Index) on input, + * which implements the function mapping a vector to a list + * identifier. The pointer is borrowed: the quantizer should not + * be deleted while the IndexIVF is in use. + */ + IndexIVF (Index * quantizer, size_t d, + size_t nlist, size_t code_size, + MetricType metric = METRIC_L2); + + void reset() override; + + /// Trains the quantizer and calls train_residual to train sub-quantizers + void train(idx_t n, const float* x) override; + + /// Calls add_with_ids with NULL ids + void add(idx_t n, const float* x) override; + + /// Calls add_with_ids_without_codes + void add_without_codes(idx_t n, const float* x) override; + + /// default implementation that calls encode_vectors + void add_with_ids(idx_t n, const float* x, const idx_t* xids) override; + + /// Implementation for adding without original vector data + void add_with_ids_without_codes(idx_t n, const float* x, const idx_t* xids) override; + + /** Encodes a set of vectors as they would appear in the inverted lists + * + * @param list_nos inverted list ids as returned by the + * quantizer (size n). -1s are ignored. + * @param codes output codes, size n * code_size + * @param include_listno + * include the list ids in the code (in this case add + * ceil(log8(nlist)) to the code size) + */ + virtual void encode_vectors(idx_t n, const float* x, + const idx_t *list_nos, + uint8_t * codes, + bool include_listno = false) const = 0; + + /// Sub-classes that encode the residuals can train their encoders here + /// does nothing by default + virtual void train_residual (idx_t n, const float *x); + + /** search a set of vectors, that are pre-quantized by the IVF + * quantizer. Fill in the corresponding heaps with the query + * results. The default implementation uses InvertedListScanners + * to do the search. + * + * @param n nb of vectors to query + * @param x query vectors, size nx * d + * @param assign coarse quantization indices, size nx * nprobe + * @param centroid_dis + * distances to coarse centroids, size nx * nprobe + * @param distance + * output distances, size n * k + * @param labels output labels, size n * k + * @param store_pairs store inv list index + inv list offset + * instead in upper/lower 32 bit of result, + * instead of ids (used for reranking). + * @param params used to override the object's search parameters + */ + virtual void search_preassigned (idx_t n, const float *x, idx_t k, + const idx_t *assign, + const float *centroid_dis, + float *distances, idx_t *labels, + bool store_pairs, + const IVFSearchParameters *params=nullptr, + ConcurrentBitsetPtr bitset = nullptr + ) const; + + /** Similar to search_preassigned, but does not store codes **/ + virtual void search_preassigned_without_codes (idx_t n, const float *x, + const uint8_t *arranged_codes, + std::vector prefix_sum, + bool is_sq8, idx_t k, + const idx_t *assign, + const float *centroid_dis, + float *distances, idx_t *labels, + bool store_pairs, + const IVFSearchParameters *params = nullptr, + ConcurrentBitsetPtr bitset = nullptr); + + /** assign the vectors, then call search_preassign */ + void search (idx_t n, const float *x, idx_t k, + float *distances, idx_t *labels, + ConcurrentBitsetPtr bitset = nullptr) const override; + + /** Similar to search, but does not store codes **/ + void search_without_codes (idx_t n, const float *x, + const uint8_t *arranged_codes, std::vector prefix_sum, + bool is_sq8, idx_t k, float *distances, idx_t *labels, + ConcurrentBitsetPtr bitset = nullptr); + +#if 0 + /** get raw vectors by ids */ + void get_vector_by_id (idx_t n, const idx_t *xid, float *x, ConcurrentBitsetPtr bitset = nullptr) override; + + void search_by_id (idx_t n, const idx_t *xid, idx_t k, float *distances, idx_t *labels, + ConcurrentBitsetPtr bitset = nullptr) override; +#endif + + void range_search (idx_t n, const float* x, float radius, + RangeSearchResult* result, + ConcurrentBitsetPtr bitset = nullptr) const override; + + void range_search_preassigned(idx_t nx, const float *x, float radius, + const idx_t *keys, const float *coarse_dis, + RangeSearchResult *result, + ConcurrentBitsetPtr bitset = nullptr) const; + + /// get a scanner for this index (store_pairs means ignore labels) + virtual InvertedListScanner *get_InvertedListScanner ( + bool store_pairs=false) const; + + /** reconstruct a vector. Works only if maintain_direct_map is set to 1 or 2 */ + void reconstruct (idx_t key, float* recons) const override; + + /** Update a subset of vectors. + * + * The index must have a direct_map + * + * @param nv nb of vectors to update + * @param idx vector indices to update, size nv + * @param v vectors of new values, size nv*d + */ + virtual void update_vectors (int nv, const idx_t *idx, const float *v); + + /** Reconstruct a subset of the indexed vectors. + * + * Overrides default implementation to bypass reconstruct() which requires + * direct_map to be maintained. + * + * @param i0 first vector to reconstruct + * @param ni nb of vectors to reconstruct + * @param recons output array of reconstructed vectors, size ni * d + */ + void reconstruct_n(idx_t i0, idx_t ni, float* recons) const override; + + /** Similar to search, but also reconstructs the stored vectors (or an + * approximation in the case of lossy coding) for the search results. + * + * Overrides default implementation to avoid having to maintain direct_map + * and instead fetch the code offsets through the `store_pairs` flag in + * search_preassigned(). + * + * @param recons reconstructed vectors size (n, k, d) + */ + void search_and_reconstruct (idx_t n, const float *x, idx_t k, + float *distances, idx_t *labels, + float *recons) const override; + + /** Reconstruct a vector given the location in terms of (inv list index + + * inv list offset) instead of the id. + * + * Useful for reconstructing when the direct_map is not maintained and + * the inv list offset is computed by search_preassigned() with + * `store_pairs` set. + */ + virtual void reconstruct_from_offset (int64_t list_no, int64_t offset, + float* recons) const; + + + /// Dataset manipulation functions + + size_t remove_ids(const IDSelector& sel) override; + + /** check that the two indexes are compatible (ie, they are + * trained in the same way and have the same + * parameters). Otherwise throw. */ + void check_compatible_for_merge (const IndexIVF &other) const; + + /** moves the entries from another dataset to self. On output, + * other is empty. add_id is added to all moved ids (for + * sequential ids, this would be this->ntotal */ + virtual void merge_from (IndexIVF &other, idx_t add_id); + + /** copy a subset of the entries index to the other index + * + * if subset_type == 0: copies ids in [a1, a2) + * if subset_type == 1: copies ids if id % a1 == a2 + * if subset_type == 2: copies inverted lists such that a1 + * elements are left before and a2 elements are after + */ + virtual void copy_subset_to (IndexIVF & other, int subset_type, + idx_t a1, idx_t a2) const; + + virtual void to_readonly(); + virtual void to_readonly_without_codes(); + virtual bool is_readonly() const; + + virtual void backup_quantizer(); + + virtual void restore_quantizer(); + + ~IndexIVF() override; + + size_t get_list_size (size_t list_no) const + { return invlists->list_size(list_no); } + + /** intialize a direct map + * + * @param new_maintain_direct_map if true, create a direct map, + * else clear it + */ + void make_direct_map (bool new_maintain_direct_map=true); + + void set_direct_map_type (DirectMap::Type type); + + + /// replace the inverted lists, old one is deallocated if own_invlists + void replace_invlists (InvertedLists *il, bool own=false); + + /* The standalone codec interface (except sa_decode that is specific) */ + size_t sa_code_size () const override; + + void sa_encode (idx_t n, const float *x, + uint8_t *bytes) const override; + + void dump(); + + IndexIVF (); +}; + +struct RangeQueryResult; + +/** Object that handles a query. The inverted lists to scan are + * provided externally. The object has a lot of state, but + * distance_to_code and scan_codes can be called in multiple + * threads */ +struct InvertedListScanner { + + using idx_t = Index::idx_t; + + /// from now on we handle this query. + virtual void set_query (const float *query_vector) = 0; + + /// following codes come from this inverted list + virtual void set_list (idx_t list_no, float coarse_dis) = 0; + + /// compute a single query-to-code distance + virtual float distance_to_code (const uint8_t *code) const = 0; + + /** scan a set of codes, compute distances to current query and + * update heap of results if necessary. + * + * @param n number of codes to scan + * @param codes codes to scan (n * code_size) + * @param ids corresponding ids (ignored if store_pairs) + * @param distances heap distances (size k) + * @param labels heap labels (size k) + * @param k heap size + * @return number of heap updates performed + */ + virtual size_t scan_codes (size_t n, + const uint8_t *codes, + const idx_t *ids, + float *distances, idx_t *labels, + size_t k, + ConcurrentBitsetPtr bitset = nullptr) const = 0; + + /** scan a set of codes, compute distances to current query and + * update results if distances are below radius + * + * (default implementation fails) */ + virtual void scan_codes_range (size_t n, + const uint8_t *codes, + const idx_t *ids, + float radius, + RangeQueryResult &result, + ConcurrentBitsetPtr bitset = nullptr) const; + + virtual ~InvertedListScanner () {} + +}; + + +struct IndexIVFStats { + size_t nq; // nb of queries run + size_t nlist; // nb of inverted lists scanned + size_t ndis; // nb of distancs computed + size_t nheap_updates; // nb of times the heap was updated + double quantization_time; // time spent quantizing vectors (in ms) + double search_time; // time spent searching lists (in ms) + + IndexIVFStats () {reset (); } + void reset (); +}; + +// global var that collects them all +extern IndexIVFStats indexIVF_stats; + + +} // namespace faiss + + +#endif diff --git a/core/src/index/thirdparty/faiss/IndexIVFFlat.cpp b/core/src/index/thirdparty/faiss/IndexIVFFlat.cpp new file mode 100644 index 0000000000..147263750f --- /dev/null +++ b/core/src/index/thirdparty/faiss/IndexIVFFlat.cpp @@ -0,0 +1,506 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +#include + +#include + +#include + +#include +#include +#include +#include +#include + +namespace faiss { + + +/***************************************** + * IndexIVFFlat implementation + ******************************************/ + +IndexIVFFlat::IndexIVFFlat (Index * quantizer, + size_t d, size_t nlist, MetricType metric): + IndexIVF (quantizer, d, nlist, sizeof(float) * d, metric) +{ + code_size = sizeof(float) * d; +} + + +void IndexIVFFlat::add_with_ids (idx_t n, const float * x, const idx_t *xids) +{ + add_core (n, x, xids, nullptr); +} + +// Add ids only, vectors not added to Index. +void IndexIVFFlat::add_with_ids_without_codes(idx_t n, const float* x, const idx_t* xids) +{ + FAISS_THROW_IF_NOT (is_trained); + assert (invlists); + direct_map.check_can_add (xids); + const int64_t * idx; + ScopeDeleter del; + + int64_t * idx0 = new int64_t [n]; + del.set (idx0); + quantizer->assign (n, x, idx0); + idx = idx0; + + int64_t n_add = 0; + for (size_t i = 0; i < n; i++) { + idx_t id = xids ? xids[i] : ntotal + i; + idx_t list_no = idx [i]; + size_t offset; + + if (list_no >= 0) { + const float *xi = x + i * d; + offset = invlists->add_entry_without_codes ( + list_no, id); + n_add++; + } else { + offset = 0; + } + direct_map.add_single_id (id, list_no, offset); + } + + ntotal += n; +} + +void IndexIVFFlat::add_core (idx_t n, const float * x, const int64_t *xids, + const int64_t *precomputed_idx) + +{ + FAISS_THROW_IF_NOT (is_trained); + assert (invlists); + direct_map.check_can_add (xids); + const int64_t * idx; + ScopeDeleter del; + + if (precomputed_idx) { + idx = precomputed_idx; + } else { + int64_t * idx0 = new int64_t [n]; + del.set (idx0); + quantizer->assign (n, x, idx0); + idx = idx0; + } + int64_t n_add = 0; + for (size_t i = 0; i < n; i++) { + idx_t id = xids ? xids[i] : ntotal + i; + idx_t list_no = idx [i]; + size_t offset; + + if (list_no >= 0) { + const float *xi = x + i * d; + offset = invlists->add_entry ( + list_no, id, (const uint8_t*) xi); + n_add++; + } else { + offset = 0; + } + direct_map.add_single_id (id, list_no, offset); + } + + if (verbose) { + printf("IndexIVFFlat::add_core: added %ld / %ld vectors\n", + n_add, n); + } + ntotal += n; +} + +void IndexIVFFlat::encode_vectors(idx_t n, const float* x, + const idx_t * list_nos, + uint8_t * codes, + bool include_listnos) const +{ + if (!include_listnos) { + memcpy (codes, x, code_size * n); + } else { + size_t coarse_size = coarse_code_size (); + for (size_t i = 0; i < n; i++) { + int64_t list_no = list_nos [i]; + uint8_t *code = codes + i * (code_size + coarse_size); + const float *xi = x + i * d; + if (list_no >= 0) { + encode_listno (list_no, code); + memcpy (code + coarse_size, xi, code_size); + } else { + memset (code, 0, code_size + coarse_size); + } + + } + } +} + +void IndexIVFFlat::sa_decode (idx_t n, const uint8_t *bytes, + float *x) const +{ + size_t coarse_size = coarse_code_size (); + for (size_t i = 0; i < n; i++) { + const uint8_t *code = bytes + i * (code_size + coarse_size); + float *xi = x + i * d; + memcpy (xi, code + coarse_size, code_size); + } +} + + +namespace { + + +template +struct IVFFlatScanner: InvertedListScanner { + size_t d; + bool store_pairs; + + IVFFlatScanner(size_t d, bool store_pairs): + d(d), store_pairs(store_pairs) {} + + const float *xi; + void set_query (const float *query) override { + this->xi = query; + } + + idx_t list_no; + void set_list (idx_t list_no, float /* coarse_dis */) override { + this->list_no = list_no; + } + + float distance_to_code (const uint8_t *code) const override { + const float *yj = (float*)code; + float dis = metric == METRIC_INNER_PRODUCT ? + fvec_inner_product (xi, yj, d) : fvec_L2sqr (xi, yj, d); + return dis; + } + + size_t scan_codes (size_t list_size, + const uint8_t *codes, + const idx_t *ids, + float *simi, idx_t *idxi, + size_t k, + ConcurrentBitsetPtr bitset) const override + { + const float *list_vecs = (const float*)codes; + size_t nup = 0; + for (size_t j = 0; j < list_size; j++) { + if (!bitset || !bitset->test(ids[j])) { + const float * yj = list_vecs + d * j; + float dis = metric == METRIC_INNER_PRODUCT ? + fvec_inner_product (xi, yj, d) : fvec_L2sqr (xi, yj, d); + if (C::cmp (simi[0], dis)) { + int64_t id = store_pairs ? (list_no << 32 | j) : ids[j]; + heap_swap_top (k, simi, idxi, dis, id); + nup++; + } + } + } + return nup; + } + + void scan_codes_range (size_t list_size, + const uint8_t *codes, + const idx_t *ids, + float radius, + RangeQueryResult & res, + ConcurrentBitsetPtr bitset = nullptr) const override + { + const float *list_vecs = (const float*)codes; + for (size_t j = 0; j < list_size; j++) { + const float * yj = list_vecs + d * j; + float dis = metric == METRIC_INNER_PRODUCT ? + fvec_inner_product (xi, yj, d) : fvec_L2sqr (xi, yj, d); + if (C::cmp (radius, dis)) { + int64_t id = store_pairs ? lo_build (list_no, j) : ids[j]; + res.add (dis, id); + } + } + } + + +}; + + +} // anonymous namespace + + + +InvertedListScanner* IndexIVFFlat::get_InvertedListScanner + (bool store_pairs) const +{ + if (metric_type == METRIC_INNER_PRODUCT) { + return new IVFFlatScanner< + METRIC_INNER_PRODUCT, CMin > (d, store_pairs); + } else if (metric_type == METRIC_L2) { + return new IVFFlatScanner< + METRIC_L2, CMax >(d, store_pairs); + } else { + FAISS_THROW_MSG("metric type not supported"); + } + return nullptr; +} + + + + +void IndexIVFFlat::reconstruct_from_offset (int64_t list_no, int64_t offset, + float* recons) const +{ + memcpy (recons, invlists->get_single_code (list_no, offset), code_size); +} + +/***************************************** + * IndexIVFFlatDedup implementation + ******************************************/ + +IndexIVFFlatDedup::IndexIVFFlatDedup ( + Index * quantizer, size_t d, size_t nlist_, + MetricType metric_type): + IndexIVFFlat (quantizer, d, nlist_, metric_type) +{} + + +void IndexIVFFlatDedup::train(idx_t n, const float* x) +{ + std::unordered_map map; + float * x2 = new float [n * d]; + ScopeDeleter del (x2); + + int64_t n2 = 0; + for (int64_t i = 0; i < n; i++) { + uint64_t hash = hash_bytes((uint8_t *)(x + i * d), code_size); + if (map.count(hash) && + !memcmp (x2 + map[hash] * d, x + i * d, code_size)) { + // is duplicate, skip + } else { + map [hash] = n2; + memcpy (x2 + n2 * d, x + i * d, code_size); + n2 ++; + } + } + if (verbose) { + printf ("IndexIVFFlatDedup::train: train on %ld points after dedup " + "(was %ld points)\n", n2, n); + } + IndexIVFFlat::train (n2, x2); +} + + + +void IndexIVFFlatDedup::add_with_ids( + idx_t na, const float* x, const idx_t* xids) +{ + + FAISS_THROW_IF_NOT (is_trained); + assert (invlists); + FAISS_THROW_IF_NOT_MSG (direct_map.no(), + "IVFFlatDedup not implemented with direct_map"); + int64_t * idx = new int64_t [na]; + ScopeDeleter del (idx); + quantizer->assign (na, x, idx); + + int64_t n_add = 0, n_dup = 0; + // TODO make a omp loop with this + for (size_t i = 0; i < na; i++) { + idx_t id = xids ? xids[i] : ntotal + i; + int64_t list_no = idx [i]; + + if (list_no < 0) { + continue; + } + const float *xi = x + i * d; + + // search if there is already an entry with that id + InvertedLists::ScopedCodes codes (invlists, list_no); + + int64_t n = invlists->list_size (list_no); + int64_t offset = -1; + for (int64_t o = 0; o < n; o++) { + if (!memcmp (codes.get() + o * code_size, + xi, code_size)) { + offset = o; + break; + } + } + + if (offset == -1) { // not found + invlists->add_entry (list_no, id, (const uint8_t*) xi); + } else { + // mark equivalence + idx_t id2 = invlists->get_single_id (list_no, offset); + std::pair pair (id2, id); + instances.insert (pair); + n_dup ++; + } + n_add++; + } + if (verbose) { + printf("IndexIVFFlat::add_with_ids: added %ld / %ld vectors" + " (out of which %ld are duplicates)\n", + n_add, na, n_dup); + } + ntotal += n_add; +} + +void IndexIVFFlatDedup::search_preassigned ( + idx_t n, const float *x, idx_t k, + const idx_t *assign, + const float *centroid_dis, + float *distances, idx_t *labels, + bool store_pairs, + const IVFSearchParameters *params, + ConcurrentBitsetPtr bitset) const +{ + FAISS_THROW_IF_NOT_MSG ( + !store_pairs, "store_pairs not supported in IVFDedup"); + + IndexIVFFlat::search_preassigned (n, x, k, assign, centroid_dis, + distances, labels, false, + params); + + std::vector labels2 (k); + std::vector dis2 (k); + + for (int64_t i = 0; i < n; i++) { + idx_t *labels1 = labels + i * k; + float *dis1 = distances + i * k; + int64_t j = 0; + for (; j < k; j++) { + if (instances.find (labels1[j]) != instances.end ()) { + // a duplicate: special handling + break; + } + } + if (j < k) { + // there are duplicates, special handling + int64_t j0 = j; + int64_t rp = j; + while (j < k) { + auto range = instances.equal_range (labels1[rp]); + float dis = dis1[rp]; + labels2[j] = labels1[rp]; + dis2[j] = dis; + j ++; + for (auto it = range.first; j < k && it != range.second; ++it) { + labels2[j] = it->second; + dis2[j] = dis; + j++; + } + rp++; + } + memcpy (labels1 + j0, labels2.data() + j0, + sizeof(labels1[0]) * (k - j0)); + memcpy (dis1 + j0, dis2.data() + j0, + sizeof(dis2[0]) * (k - j0)); + } + } + +} + + +size_t IndexIVFFlatDedup::remove_ids(const IDSelector& sel) +{ + std::unordered_map replace; + std::vector > toadd; + for (auto it = instances.begin(); it != instances.end(); ) { + if (sel.is_member(it->first)) { + // then we erase this entry + if (!sel.is_member(it->second)) { + // if the second is not erased + if (replace.count(it->first) == 0) { + replace[it->first] = it->second; + } else { // remember we should add an element + std::pair new_entry ( + replace[it->first], it->second); + toadd.push_back(new_entry); + } + } + it = instances.erase(it); + } else { + if (sel.is_member(it->second)) { + it = instances.erase(it); + } else { + ++it; + } + } + } + + instances.insert (toadd.begin(), toadd.end()); + + // mostly copied from IndexIVF.cpp + + FAISS_THROW_IF_NOT_MSG (direct_map.no(), + "direct map remove not implemented"); + + std::vector toremove(nlist); + +#pragma omp parallel for + for (int64_t i = 0; i < nlist; i++) { + int64_t l0 = invlists->list_size (i), l = l0, j = 0; + InvertedLists::ScopedIds idsi (invlists, i); + while (j < l) { + if (sel.is_member (idsi[j])) { + if (replace.count(idsi[j]) == 0) { + l--; + invlists->update_entry ( + i, j, + invlists->get_single_id (i, l), + InvertedLists::ScopedCodes (invlists, i, l).get()); + } else { + invlists->update_entry ( + i, j, + replace[idsi[j]], + InvertedLists::ScopedCodes (invlists, i, j).get()); + j++; + } + } else { + j++; + } + } + toremove[i] = l0 - l; + } + // this will not run well in parallel on ondisk because of possible shrinks + int64_t nremove = 0; + for (int64_t i = 0; i < nlist; i++) { + if (toremove[i] > 0) { + nremove += toremove[i]; + invlists->resize( + i, invlists->list_size(i) - toremove[i]); + } + } + ntotal -= nremove; + return nremove; +} + + +void IndexIVFFlatDedup::range_search( + idx_t , + const float* , + float , + RangeSearchResult* , + ConcurrentBitsetPtr) const +{ + FAISS_THROW_MSG ("not implemented"); +} + +void IndexIVFFlatDedup::update_vectors (int , const idx_t *, const float *) +{ + FAISS_THROW_MSG ("not implemented"); +} + + +void IndexIVFFlatDedup::reconstruct_from_offset ( + int64_t , int64_t , float* ) const +{ + FAISS_THROW_MSG ("not implemented"); +} + + + + +} // namespace faiss diff --git a/core/src/index/thirdparty/faiss/IndexIVFFlat.h b/core/src/index/thirdparty/faiss/IndexIVFFlat.h new file mode 100644 index 0000000000..74b0b4c0ec --- /dev/null +++ b/core/src/index/thirdparty/faiss/IndexIVFFlat.h @@ -0,0 +1,113 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +#ifndef FAISS_INDEX_IVF_FLAT_H +#define FAISS_INDEX_IVF_FLAT_H + +#include +#include + +#include + + +namespace faiss { + +/** Inverted file with stored vectors. Here the inverted file + * pre-selects the vectors to be searched, but they are not otherwise + * encoded, the code array just contains the raw float entries. + */ +struct IndexIVFFlat: IndexIVF { + + IndexIVFFlat ( + Index * quantizer, size_t d, size_t nlist_, + MetricType = METRIC_L2); + + /// same as add_with_ids, with precomputed coarse quantizer + virtual void add_core (idx_t n, const float * x, const int64_t *xids, + const int64_t *precomputed_idx); + + /// implemented for all IndexIVF* classes + void add_with_ids(idx_t n, const float* x, const idx_t* xids) override; + + /// implemented for all IndexIVF* classes + void add_with_ids_without_codes(idx_t n, const float* x, const idx_t* xids) override; + + void encode_vectors(idx_t n, const float* x, + const idx_t *list_nos, + uint8_t * codes, + bool include_listnos=false) const override; + + + InvertedListScanner *get_InvertedListScanner (bool store_pairs) + const override; + + + void reconstruct_from_offset (int64_t list_no, int64_t offset, + float* recons) const override; + + void sa_decode (idx_t n, const uint8_t *bytes, + float *x) const override; + + IndexIVFFlat () {} +}; + + +struct IndexIVFFlatDedup: IndexIVFFlat { + + /** Maps ids stored in the index to the ids of vectors that are + * the same. When a vector is unique, it does not appear in the + * instances map */ + std::unordered_multimap instances; + + IndexIVFFlatDedup ( + Index * quantizer, size_t d, size_t nlist_, + MetricType = METRIC_L2); + + /// also dedups the training set + void train(idx_t n, const float* x) override; + + /// implemented for all IndexIVF* classes + void add_with_ids(idx_t n, const float* x, const idx_t* xids) override; + + void search_preassigned (idx_t n, const float *x, idx_t k, + const idx_t *assign, + const float *centroid_dis, + float *distances, idx_t *labels, + bool store_pairs, + const IVFSearchParameters *params=nullptr, + ConcurrentBitsetPtr bitset = nullptr + ) const override; + + size_t remove_ids(const IDSelector& sel) override; + + /// not implemented + void range_search( + idx_t n, + const float* x, + float radius, + RangeSearchResult* result, + ConcurrentBitsetPtr bitset = nullptr) const override; + + /// not implemented + void update_vectors (int nv, const idx_t *idx, const float *v) override; + + /// not implemented + void reconstruct_from_offset (int64_t list_no, int64_t offset, + float* recons) const override; + + IndexIVFFlatDedup () {} + + +}; + + + +} // namespace faiss + +#endif diff --git a/core/src/index/thirdparty/faiss/IndexIVFPQ.cpp b/core/src/index/thirdparty/faiss/IndexIVFPQ.cpp new file mode 100644 index 0000000000..fb786cc375 --- /dev/null +++ b/core/src/index/thirdparty/faiss/IndexIVFPQ.cpp @@ -0,0 +1,1239 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +#include + +#include +#include +#include +#include + +#include + +#include +#include +#include + +#include +#include +#include + +#include + +#include + +#include + +namespace faiss { + +/***************************************** + * IndexIVFPQ implementation + ******************************************/ + +IndexIVFPQ::IndexIVFPQ (Index * quantizer, size_t d, size_t nlist, + size_t M, size_t nbits_per_idx, MetricType metric): + IndexIVF (quantizer, d, nlist, 0, metric), + pq (d, M, nbits_per_idx) +{ + FAISS_THROW_IF_NOT (nbits_per_idx <= 8); + code_size = pq.code_size; + invlists->code_size = code_size; + is_trained = false; + by_residual = true; + use_precomputed_table = 0; + scan_table_threshold = 0; + + polysemous_training = nullptr; + do_polysemous_training = false; + polysemous_ht = 0; + +} + + +/**************************************************************** + * training */ + +void IndexIVFPQ::train_residual (idx_t n, const float *x) +{ + train_residual_o (n, x, nullptr); +} + + +void IndexIVFPQ::train_residual_o (idx_t n, const float *x, float *residuals_2) +{ + const float * x_in = x; + + x = fvecs_maybe_subsample ( + d, (size_t*)&n, pq.cp.max_points_per_centroid * pq.ksub, + x, verbose, pq.cp.seed); + + ScopeDeleter del_x (x_in == x ? nullptr : x); + + const float *trainset; + ScopeDeleter del_residuals; + if (by_residual) { + if(verbose) printf("computing residuals\n"); + idx_t * assign = new idx_t [n]; // assignement to coarse centroids + ScopeDeleter del (assign); + quantizer->assign (n, x, assign); + float *residuals = new float [n * d]; + del_residuals.set (residuals); + for (idx_t i = 0; i < n; i++) + quantizer->compute_residual (x + i * d, residuals+i*d, assign[i]); + + trainset = residuals; + } else { + trainset = x; + } + if (verbose) + printf ("training %zdx%zd product quantizer on %ld vectors in %dD\n", + pq.M, pq.ksub, n, d); + pq.verbose = verbose; + pq.train (n, trainset); + + if (do_polysemous_training) { + if (verbose) + printf("doing polysemous training for PQ\n"); + PolysemousTraining default_pt; + PolysemousTraining *pt = polysemous_training; + if (!pt) pt = &default_pt; + pt->optimize_pq_for_hamming (pq, n, trainset); + } + + // prepare second-level residuals for refine PQ + if (residuals_2) { + uint8_t *train_codes = new uint8_t [pq.code_size * n]; + ScopeDeleter del (train_codes); + pq.compute_codes (trainset, train_codes, n); + + for (idx_t i = 0; i < n; i++) { + const float *xx = trainset + i * d; + float * res = residuals_2 + i * d; + pq.decode (train_codes + i * pq.code_size, res); + for (int j = 0; j < d; j++) + res[j] = xx[j] - res[j]; + } + + } + + if (by_residual) { + precompute_table (); + } + +} + + + + + + +/**************************************************************** + * IVFPQ as codec */ + + +/* produce a binary signature based on the residual vector */ +void IndexIVFPQ::encode (idx_t key, const float * x, uint8_t * code) const +{ + if (by_residual) { + float residual_vec[d]; + quantizer->compute_residual (x, residual_vec, key); + pq.compute_code (residual_vec, code); + } + else pq.compute_code (x, code); +} + +void IndexIVFPQ::encode_multiple (size_t n, idx_t *keys, + const float * x, uint8_t * xcodes, + bool compute_keys) const +{ + if (compute_keys) + quantizer->assign (n, x, keys); + + encode_vectors (n, x, keys, xcodes); +} + +void IndexIVFPQ::decode_multiple (size_t n, const idx_t *keys, + const uint8_t * xcodes, float * x) const +{ + pq.decode (xcodes, x, n); + if (by_residual) { + std::vector centroid (d); + for (size_t i = 0; i < n; i++) { + quantizer->reconstruct (keys[i], centroid.data()); + float *xi = x + i * d; + for (size_t j = 0; j < d; j++) { + xi [j] += centroid [j]; + } + } + } +} + + + + +/**************************************************************** + * add */ + + +void IndexIVFPQ::add_with_ids (idx_t n, const float * x, const idx_t *xids) +{ + add_core_o (n, x, xids, nullptr); +} + + +static float * compute_residuals ( + const Index *quantizer, + Index::idx_t n, const float* x, + const Index::idx_t *list_nos) +{ + size_t d = quantizer->d; + float *residuals = new float [n * d]; + // TODO: parallelize? + for (size_t i = 0; i < n; i++) { + if (list_nos[i] < 0) + memset (residuals + i * d, 0, sizeof(*residuals) * d); + else + quantizer->compute_residual ( + x + i * d, residuals + i * d, list_nos[i]); + } + return residuals; +} + +void IndexIVFPQ::encode_vectors(idx_t n, const float* x, + const idx_t *list_nos, + uint8_t * codes, + bool include_listnos) const +{ + if (by_residual) { + float *to_encode = compute_residuals (quantizer, n, x, list_nos); + ScopeDeleter del (to_encode); + pq.compute_codes (to_encode, codes, n); + } else { + pq.compute_codes (x, codes, n); + } + + if (include_listnos) { + size_t coarse_size = coarse_code_size(); + for (idx_t i = n - 1; i >= 0; i--) { + uint8_t * code = codes + i * (coarse_size + code_size); + memmove (code + coarse_size, + codes + i * code_size, code_size); + encode_listno (list_nos[i], code); + } + } +} + + + +void IndexIVFPQ::sa_decode (idx_t n, const uint8_t *codes, + float *x) const +{ + size_t coarse_size = coarse_code_size (); + +#pragma omp parallel + { + std::vector residual (d); + +#pragma omp for + for (size_t i = 0; i < n; i++) { + const uint8_t *code = codes + i * (code_size + coarse_size); + int64_t list_no = decode_listno (code); + float *xi = x + i * d; + pq.decode (code + coarse_size, xi); + if (by_residual) { + quantizer->reconstruct (list_no, residual.data()); + for (size_t j = 0; j < d; j++) { + xi[j] += residual[j]; + } + } + } + } +} + + +void IndexIVFPQ::add_core_o (idx_t n, const float * x, const idx_t *xids, + float *residuals_2, const idx_t *precomputed_idx) +{ + + idx_t bs = 32768; + if (n > bs) { + for (idx_t i0 = 0; i0 < n; i0 += bs) { + idx_t i1 = std::min(i0 + bs, n); + if (verbose) { + printf("IndexIVFPQ::add_core_o: adding %ld:%ld / %ld\n", + i0, i1, n); + } + add_core_o (i1 - i0, x + i0 * d, + xids ? xids + i0 : nullptr, + residuals_2 ? residuals_2 + i0 * d : nullptr, + precomputed_idx ? precomputed_idx + i0 : nullptr); + } + return; + } + + InterruptCallback::check(); + + direct_map.check_can_add (xids); + + FAISS_THROW_IF_NOT (is_trained); + double t0 = getmillisecs (); + const idx_t * idx; + ScopeDeleter del_idx; + + if (precomputed_idx) { + idx = precomputed_idx; + } else { + idx_t * idx0 = new idx_t [n]; + del_idx.set (idx0); + quantizer->assign (n, x, idx0); + idx = idx0; + } + + double t1 = getmillisecs (); + uint8_t * xcodes = new uint8_t [n * code_size]; + ScopeDeleter del_xcodes (xcodes); + + const float *to_encode = nullptr; + ScopeDeleter del_to_encode; + + if (by_residual) { + to_encode = compute_residuals (quantizer, n, x, idx); + del_to_encode.set (to_encode); + } else { + to_encode = x; + } + pq.compute_codes (to_encode, xcodes, n); + + double t2 = getmillisecs (); + // TODO: parallelize? + size_t n_ignore = 0; + for (size_t i = 0; i < n; i++) { + idx_t key = idx[i]; + idx_t id = xids ? xids[i] : ntotal + i; + if (key < 0) { + direct_map.add_single_id (id, -1, 0); + n_ignore ++; + if (residuals_2) + memset (residuals_2, 0, sizeof(*residuals_2) * d); + continue; + } + + uint8_t *code = xcodes + i * code_size; + size_t offset = invlists->add_entry (key, id, code); + + if (residuals_2) { + float *res2 = residuals_2 + i * d; + const float *xi = to_encode + i * d; + pq.decode (code, res2); + for (int j = 0; j < d; j++) + res2[j] = xi[j] - res2[j]; + } + + direct_map.add_single_id (id, key, offset); + } + + double t3 = getmillisecs (); + if(verbose) { + char comment[100] = {0}; + if (n_ignore > 0) + snprintf (comment, 100, "(%ld vectors ignored)", n_ignore); + printf(" add_core times: %.3f %.3f %.3f %s\n", + t1 - t0, t2 - t1, t3 - t2, comment); + } + ntotal += n; +} + + +void IndexIVFPQ::reconstruct_from_offset (int64_t list_no, int64_t offset, + float* recons) const +{ + const uint8_t* code = invlists->get_single_code (list_no, offset); + + if (by_residual) { + std::vector centroid(d); + quantizer->reconstruct (list_no, centroid.data()); + + pq.decode (code, recons); + for (int i = 0; i < d; ++i) { + recons[i] += centroid[i]; + } + } else { + pq.decode (code, recons); + } +} + + + +/// 2G by default, accommodates tables up to PQ32 w/ 65536 centroids +size_t IndexIVFPQ::precomputed_table_max_bytes = ((size_t)1) << 31; + +/** Precomputed tables for residuals + * + * During IVFPQ search with by_residual, we compute + * + * d = || x - y_C - y_R ||^2 + * + * where x is the query vector, y_C the coarse centroid, y_R the + * refined PQ centroid. The expression can be decomposed as: + * + * d = || x - y_C ||^2 + || y_R ||^2 + 2 * (y_C|y_R) - 2 * (x|y_R) + * --------------- --------------------------- ------- + * term 1 term 2 term 3 + * + * When using multiprobe, we use the following decomposition: + * - term 1 is the distance to the coarse centroid, that is computed + * during the 1st stage search. + * - term 2 can be precomputed, as it does not involve x. However, + * because of the PQ, it needs nlist * M * ksub storage. This is why + * use_precomputed_table is off by default + * - term 3 is the classical non-residual distance table. + * + * Since y_R defined by a product quantizer, it is split across + * subvectors and stored separately for each subvector. If the coarse + * quantizer is a MultiIndexQuantizer then the table can be stored + * more compactly. + * + * At search time, the tables for term 2 and term 3 are added up. This + * is faster when the length of the lists is > ksub * M. + */ + +void IndexIVFPQ::precompute_table () +{ + if (use_precomputed_table == -1) + return; + + if (use_precomputed_table == 0) { // then choose the type of table + if (quantizer->metric_type == METRIC_INNER_PRODUCT) { + if (verbose) { + printf("IndexIVFPQ::precompute_table: precomputed " + "tables not needed for inner product quantizers\n"); + } + return; + } + const MultiIndexQuantizer *miq = + dynamic_cast (quantizer); + if (miq && pq.M % miq->pq.M == 0) + use_precomputed_table = 2; + else { + size_t table_size = pq.M * pq.ksub * nlist * sizeof(float); + if (table_size > precomputed_table_max_bytes) { + if (verbose) { + printf( + "IndexIVFPQ::precompute_table: not precomputing table, " + "it would be too big: %ld bytes (max %ld)\n", + table_size, precomputed_table_max_bytes); + use_precomputed_table = 0; + } + return; + } + use_precomputed_table = 1; + } + } // otherwise assume user has set appropriate flag on input + + if (verbose) { + printf ("precomputing IVFPQ tables type %d\n", + use_precomputed_table); + } + + // squared norms of the PQ centroids + std::vector r_norms (pq.M * pq.ksub, NAN); + for (int m = 0; m < pq.M; m++) + for (int j = 0; j < pq.ksub; j++) + r_norms [m * pq.ksub + j] = + fvec_norm_L2sqr (pq.get_centroids (m, j), pq.dsub); + + if (use_precomputed_table == 1) { + + precomputed_table.resize (nlist * pq.M * pq.ksub); + std::vector centroid (d); + + for (size_t i = 0; i < nlist; i++) { + quantizer->reconstruct (i, centroid.data()); + + float *tab = &precomputed_table[i * pq.M * pq.ksub]; + pq.compute_inner_prod_table (centroid.data(), tab); + fvec_madd (pq.M * pq.ksub, r_norms.data(), 2.0, tab, tab); + } + } else if (use_precomputed_table == 2) { + const MultiIndexQuantizer *miq = + dynamic_cast (quantizer); + FAISS_THROW_IF_NOT (miq); + const ProductQuantizer &cpq = miq->pq; + FAISS_THROW_IF_NOT (pq.M % cpq.M == 0); + + precomputed_table.resize(cpq.ksub * pq.M * pq.ksub); + + // reorder PQ centroid table + std::vector centroids (d * cpq.ksub, NAN); + + for (int m = 0; m < cpq.M; m++) { + for (size_t i = 0; i < cpq.ksub; i++) { + memcpy (centroids.data() + i * d + m * cpq.dsub, + cpq.get_centroids (m, i), + sizeof (*centroids.data()) * cpq.dsub); + } + } + + pq.compute_inner_prod_tables (cpq.ksub, centroids.data (), + precomputed_table.data ()); + + for (size_t i = 0; i < cpq.ksub; i++) { + float *tab = &precomputed_table[i * pq.M * pq.ksub]; + fvec_madd (pq.M * pq.ksub, r_norms.data(), 2.0, tab, tab); + } + + } + +} + +namespace { + +using idx_t = Index::idx_t; + + +#define TIC t0 = get_cycles() +#define TOC get_cycles () - t0 + + + +/** QueryTables manages the various ways of searching an + * IndexIVFPQ. The code contains a lot of branches, depending on: + * - metric_type: are we computing L2 or Inner product similarity? + * - by_residual: do we encode raw vectors or residuals? + * - use_precomputed_table: are x_R|x_C tables precomputed? + * - polysemous_ht: are we filtering with polysemous codes? + */ +struct QueryTables { + + /***************************************************** + * General data from the IVFPQ + *****************************************************/ + + const IndexIVFPQ & ivfpq; + const IVFSearchParameters *params; + + // copied from IndexIVFPQ for easier access + int d; + const ProductQuantizer & pq; + MetricType metric_type; + bool by_residual; + int use_precomputed_table; + int polysemous_ht; + + // pre-allocated data buffers + float * sim_table, * sim_table_2; + float * residual_vec, *decoded_vec; + + // single data buffer + std::vector mem; + + // for table pointers + std::vector sim_table_ptrs; + + explicit QueryTables (const IndexIVFPQ & ivfpq, + const IVFSearchParameters *params): + ivfpq(ivfpq), + d(ivfpq.d), + pq (ivfpq.pq), + metric_type (ivfpq.metric_type), + by_residual (ivfpq.by_residual), + use_precomputed_table (ivfpq.use_precomputed_table) + { + mem.resize (pq.ksub * pq.M * 2 + d * 2); + sim_table = mem.data (); + sim_table_2 = sim_table + pq.ksub * pq.M; + residual_vec = sim_table_2 + pq.ksub * pq.M; + decoded_vec = residual_vec + d; + + // for polysemous + polysemous_ht = ivfpq.polysemous_ht; + if (auto ivfpq_params = + dynamic_cast(params)) { + polysemous_ht = ivfpq_params->polysemous_ht; + } + if (polysemous_ht != 0) { + q_code.resize (pq.code_size); + } + init_list_cycles = 0; + sim_table_ptrs.resize (pq.M); + } + + /***************************************************** + * What we do when query is known + *****************************************************/ + + // field specific to query + const float * qi; + + // query-specific intialization + void init_query (const float * qi) { + this->qi = qi; + if (metric_type == METRIC_INNER_PRODUCT) + init_query_IP (); + else + init_query_L2 (); + if (!by_residual && polysemous_ht != 0) + pq.compute_code (qi, q_code.data()); + } + + void init_query_IP () { + // precompute some tables specific to the query qi + pq.compute_inner_prod_table (qi, sim_table); + } + + void init_query_L2 () { + if (!by_residual) { + pq.compute_distance_table (qi, sim_table); + } else if (use_precomputed_table) { + pq.compute_inner_prod_table (qi, sim_table_2); + } + } + + /***************************************************** + * When inverted list is known: prepare computations + *****************************************************/ + + // fields specific to list + Index::idx_t key; + float coarse_dis; + std::vector q_code; + + uint64_t init_list_cycles; + + /// once we know the query and the centroid, we can prepare the + /// sim_table that will be used for accumulation + /// and dis0, the initial value + float precompute_list_tables () { + float dis0 = 0; + uint64_t t0; TIC; + if (by_residual) { + if (metric_type == METRIC_INNER_PRODUCT) + dis0 = precompute_list_tables_IP (); + else + dis0 = precompute_list_tables_L2 (); + } + init_list_cycles += TOC; + return dis0; + } + + float precompute_list_table_pointers () { + float dis0 = 0; + uint64_t t0; TIC; + if (by_residual) { + if (metric_type == METRIC_INNER_PRODUCT) + FAISS_THROW_MSG ("not implemented"); + else + dis0 = precompute_list_table_pointers_L2 (); + } + init_list_cycles += TOC; + return dis0; + } + + /***************************************************** + * compute tables for inner prod + *****************************************************/ + + float precompute_list_tables_IP () + { + // prepare the sim_table that will be used for accumulation + // and dis0, the initial value + ivfpq.quantizer->reconstruct (key, decoded_vec); + // decoded_vec = centroid + float dis0 = fvec_inner_product (qi, decoded_vec, d); + + if (polysemous_ht) { + for (int i = 0; i < d; i++) { + residual_vec [i] = qi[i] - decoded_vec[i]; + } + pq.compute_code (residual_vec, q_code.data()); + } + return dis0; + } + + + /***************************************************** + * compute tables for L2 distance + *****************************************************/ + + float precompute_list_tables_L2 () + { + float dis0 = 0; + + if (use_precomputed_table == 0 || use_precomputed_table == -1) { + ivfpq.quantizer->compute_residual (qi, residual_vec, key); + pq.compute_distance_table (residual_vec, sim_table); + + if (polysemous_ht != 0) { + pq.compute_code (residual_vec, q_code.data()); + } + + } else if (use_precomputed_table == 1) { + dis0 = coarse_dis; + + fvec_madd (pq.M * pq.ksub, + &ivfpq.precomputed_table [key * pq.ksub * pq.M], + -2.0, sim_table_2, + sim_table); + + + if (polysemous_ht != 0) { + ivfpq.quantizer->compute_residual (qi, residual_vec, key); + pq.compute_code (residual_vec, q_code.data()); + } + + } else if (use_precomputed_table == 2) { + dis0 = coarse_dis; + + const MultiIndexQuantizer *miq = + dynamic_cast (ivfpq.quantizer); + FAISS_THROW_IF_NOT (miq); + const ProductQuantizer &cpq = miq->pq; + int Mf = pq.M / cpq.M; + + const float *qtab = sim_table_2; // query-specific table + float *ltab = sim_table; // (output) list-specific table + + long k = key; + for (int cm = 0; cm < cpq.M; cm++) { + // compute PQ index + int ki = k & ((uint64_t(1) << cpq.nbits) - 1); + k >>= cpq.nbits; + + // get corresponding table + const float *pc = &ivfpq.precomputed_table + [(ki * pq.M + cm * Mf) * pq.ksub]; + + if (polysemous_ht == 0) { + + // sum up with query-specific table + fvec_madd (Mf * pq.ksub, + pc, + -2.0, qtab, + ltab); + ltab += Mf * pq.ksub; + qtab += Mf * pq.ksub; + } else { + for (int m = cm * Mf; m < (cm + 1) * Mf; m++) { + q_code[m] = fvec_madd_and_argmin + (pq.ksub, pc, -2, qtab, ltab); + pc += pq.ksub; + ltab += pq.ksub; + qtab += pq.ksub; + } + } + + } + } + + return dis0; + } + + float precompute_list_table_pointers_L2 () + { + float dis0 = 0; + + if (use_precomputed_table == 1) { + dis0 = coarse_dis; + + const float * s = &ivfpq.precomputed_table [key * pq.ksub * pq.M]; + for (int m = 0; m < pq.M; m++) { + sim_table_ptrs [m] = s; + s += pq.ksub; + } + } else if (use_precomputed_table == 2) { + dis0 = coarse_dis; + + const MultiIndexQuantizer *miq = + dynamic_cast (ivfpq.quantizer); + FAISS_THROW_IF_NOT (miq); + const ProductQuantizer &cpq = miq->pq; + int Mf = pq.M / cpq.M; + + long k = key; + int m0 = 0; + for (int cm = 0; cm < cpq.M; cm++) { + int ki = k & ((uint64_t(1) << cpq.nbits) - 1); + k >>= cpq.nbits; + + const float *pc = &ivfpq.precomputed_table + [(ki * pq.M + cm * Mf) * pq.ksub]; + + for (int m = m0; m < m0 + Mf; m++) { + sim_table_ptrs [m] = pc; + pc += pq.ksub; + } + m0 += Mf; + } + } else { + FAISS_THROW_MSG ("need precomputed tables"); + } + + if (polysemous_ht) { + FAISS_THROW_MSG ("not implemented"); + // Not clear that it makes sense to implemente this, + // because it costs M * ksub, which is what we wanted to + // avoid with the tables pointers. + } + + return dis0; + } + + +}; + + + +template +struct KnnSearchResults { + idx_t key; + const idx_t *ids; + + // heap params + size_t k; + float * heap_sim; + idx_t * heap_ids; + + size_t nup; + + inline void add (idx_t j, float dis, ConcurrentBitsetPtr bitset = nullptr) { + if (C::cmp (heap_sim[0], dis)) { + idx_t id = ids ? ids[j] : lo_build (key, j); + if (bitset != nullptr && bitset->test((faiss::ConcurrentBitset::id_type_t)id)) + return; + heap_swap_top (k, heap_sim, heap_ids, dis, id); + nup++; + } + } + +}; + +template +struct RangeSearchResults { + idx_t key; + const idx_t *ids; + + // wrapped result structure + float radius; + RangeQueryResult & rres; + + inline void add (idx_t j, float dis, faiss::ConcurrentBitsetPtr bitset = nullptr) { + if (C::cmp (radius, dis)) { + idx_t id = ids ? ids[j] : lo_build (key, j); + rres.add (dis, id); + } + } +}; + + + +/***************************************************** + * Scaning the codes. + * The scanning functions call their favorite precompute_* + * function to precompute the tables they need. + *****************************************************/ +template +struct IVFPQScannerT: QueryTables { + + const uint8_t * list_codes; + const IDType * list_ids; + size_t list_size; + + IVFPQScannerT (const IndexIVFPQ & ivfpq, const IVFSearchParameters *params): + QueryTables (ivfpq, params) + { + assert(METRIC_TYPE == metric_type); + } + + float dis0; + + void init_list (idx_t list_no, float coarse_dis, + int mode) { + this->key = list_no; + this->coarse_dis = coarse_dis; + + if (mode == 2) { + dis0 = precompute_list_tables (); + } else if (mode == 1) { + dis0 = precompute_list_table_pointers (); + } + } + + /***************************************************** + * Scaning the codes: simple PQ scan. + *****************************************************/ + + /// version of the scan where we use precomputed tables + template + void scan_list_with_table (size_t ncode, const uint8_t *codes, + SearchResultType & res, + ConcurrentBitsetPtr bitset = nullptr) const + { + for (size_t j = 0; j < ncode; j++) { + PQDecoder decoder(codes, pq.nbits); + codes += pq.code_size; + float dis = dis0; + const float *tab = sim_table; + + for (size_t m = 0; m < pq.M; m++) { + dis += tab[decoder.decode()]; + tab += pq.ksub; + } + + res.add(j, dis, bitset); + } + } + + + /// tables are not precomputed, but pointers are provided to the + /// relevant X_c|x_r tables + template + void scan_list_with_pointer (size_t ncode, const uint8_t *codes, + SearchResultType & res, + faiss::ConcurrentBitsetPtr bitset = nullptr) const + { + for (size_t j = 0; j < ncode; j++) { + PQDecoder decoder(codes, pq.nbits); + codes += pq.code_size; + + float dis = dis0; + const float *tab = sim_table_2; + + for (size_t m = 0; m < pq.M; m++) { + int ci = decoder.decode(); + dis += sim_table_ptrs [m][ci] - 2 * tab [ci]; + tab += pq.ksub; + } + res.add (j, dis, bitset); + } + } + + + /// nothing is precomputed: access residuals on-the-fly + template + void scan_on_the_fly_dist (size_t ncode, const uint8_t *codes, + SearchResultType &res, + faiss::ConcurrentBitsetPtr bitset = nullptr) const + { + const float *dvec; + float dis0 = 0; + if (by_residual) { + if (METRIC_TYPE == METRIC_INNER_PRODUCT) { + ivfpq.quantizer->reconstruct (key, residual_vec); + dis0 = fvec_inner_product (residual_vec, qi, d); + } else { + ivfpq.quantizer->compute_residual (qi, residual_vec, key); + } + dvec = residual_vec; + } else { + dvec = qi; + dis0 = 0; + } + + for (size_t j = 0; j < ncode; j++) { + + pq.decode (codes, decoded_vec); + codes += pq.code_size; + + float dis; + if (METRIC_TYPE == METRIC_INNER_PRODUCT) { + dis = dis0 + fvec_inner_product (decoded_vec, qi, d); + } else { + dis = fvec_L2sqr (decoded_vec, dvec, d); + } + res.add (j, dis, bitset); + } + } + + /***************************************************** + * Scanning codes with polysemous filtering + *****************************************************/ + + template + void scan_list_polysemous_hc ( + size_t ncode, const uint8_t *codes, + SearchResultType & res, + faiss::ConcurrentBitsetPtr bitset = nullptr) const + { + int ht = ivfpq.polysemous_ht; + size_t n_hamming_pass = 0, nup = 0; + + int code_size = pq.code_size; + + HammingComputer hc (q_code.data(), code_size); + + for (size_t j = 0; j < ncode; j++) { + const uint8_t *b_code = codes; + int hd = hc.hamming (b_code); + if (hd < ht) { + n_hamming_pass ++; + PQDecoder decoder(codes, pq.nbits); + + float dis = dis0; + const float *tab = sim_table; + + for (size_t m = 0; m < pq.M; m++) { + dis += tab[decoder.decode()]; + tab += pq.ksub; + } + + res.add (j, dis, bitset); + } + codes += code_size; + } +#pragma omp critical + { + indexIVFPQ_stats.n_hamming_pass += n_hamming_pass; + } + } + + template + void scan_list_polysemous ( + size_t ncode, const uint8_t *codes, + SearchResultType &res, + faiss::ConcurrentBitsetPtr bitset = nullptr) const + { + switch (pq.code_size) { +#define HANDLE_CODE_SIZE(cs) \ + case cs: \ + scan_list_polysemous_hc \ + \ + (ncode, codes, res, bitset); \ + break + HANDLE_CODE_SIZE(4); + HANDLE_CODE_SIZE(8); + HANDLE_CODE_SIZE(16); + HANDLE_CODE_SIZE(20); + HANDLE_CODE_SIZE(32); + HANDLE_CODE_SIZE(64); +#undef HANDLE_CODE_SIZE + default: + if (pq.code_size % 8 == 0) + scan_list_polysemous_hc + + (ncode, codes, res, bitset); + else + scan_list_polysemous_hc + + (ncode, codes, res, bitset); + break; + } + } + +}; + + +/* We put as many parameters as possible in template. Hopefully the + * gain in runtime is worth the code bloat. C is the comparator < or + * >, it is directly related to METRIC_TYPE. precompute_mode is how + * much we precompute (2 = precompute distance tables, 1 = precompute + * pointers to distances, 0 = compute distances one by one). + * Currently only 2 is supported */ +template +struct IVFPQScanner: + IVFPQScannerT, + InvertedListScanner +{ + bool store_pairs; + int precompute_mode; + + IVFPQScanner(const IndexIVFPQ & ivfpq, bool store_pairs, + int precompute_mode): + IVFPQScannerT(ivfpq, nullptr), + store_pairs(store_pairs), precompute_mode(precompute_mode) + { + } + + void set_query (const float *query) override { + this->init_query (query); + } + + void set_list (idx_t list_no, float coarse_dis) override { + this->init_list (list_no, coarse_dis, precompute_mode); + } + + float distance_to_code (const uint8_t *code) const override { + assert(precompute_mode == 2); + float dis = this->dis0; + const float *tab = this->sim_table; + PQDecoder decoder(code, this->pq.nbits); + + for (size_t m = 0; m < this->pq.M; m++) { + dis += tab[decoder.decode()]; + tab += this->pq.ksub; + } + return dis; + } + + size_t scan_codes (size_t ncode, + const uint8_t *codes, + const idx_t *ids, + float *heap_sim, idx_t *heap_ids, + size_t k, + faiss::ConcurrentBitsetPtr bitset) const override + { + KnnSearchResults res = { + /* key */ this->key, + /* ids */ this->store_pairs ? nullptr : ids, + /* k */ k, + /* heap_sim */ heap_sim, + /* heap_ids */ heap_ids, + /* nup */ 0 + }; + + if (this->polysemous_ht > 0) { + assert(precompute_mode == 2); + this->scan_list_polysemous (ncode, codes, res, bitset); + } else if (precompute_mode == 2) { + this->scan_list_with_table (ncode, codes, res, bitset); + } else if (precompute_mode == 1) { + this->scan_list_with_pointer (ncode, codes, res, bitset); + } else if (precompute_mode == 0) { + this->scan_on_the_fly_dist (ncode, codes, res, bitset); + } else { + FAISS_THROW_MSG("bad precomp mode"); + } + return res.nup; + } + + void scan_codes_range (size_t ncode, + const uint8_t *codes, + const idx_t *ids, + float radius, + RangeQueryResult & rres, + faiss::ConcurrentBitsetPtr bitset = nullptr) const override + { + RangeSearchResults res = { + /* key */ this->key, + /* ids */ this->store_pairs ? nullptr : ids, + /* radius */ radius, + /* rres */ rres + }; + + if (this->polysemous_ht > 0) { + assert(precompute_mode == 2); + this->scan_list_polysemous (ncode, codes, res, bitset); + } else if (precompute_mode == 2) { + this->scan_list_with_table (ncode, codes, res, bitset); + } else if (precompute_mode == 1) { + this->scan_list_with_pointer (ncode, codes, res, bitset); + } else if (precompute_mode == 0) { + this->scan_on_the_fly_dist (ncode, codes, res, bitset); + } else { + FAISS_THROW_MSG("bad precomp mode"); + } + + } +}; + +template +InvertedListScanner *get_InvertedListScanner1 (const IndexIVFPQ &index, + bool store_pairs) +{ + + if (index.metric_type == METRIC_INNER_PRODUCT) { + return new IVFPQScanner + , PQDecoder> + (index, store_pairs, 2); + } else if (index.metric_type == METRIC_L2) { + return new IVFPQScanner + , PQDecoder> + (index, store_pairs, 2); + } + return nullptr; +} + + +} // anonymous namespace + +InvertedListScanner * +IndexIVFPQ::get_InvertedListScanner (bool store_pairs) const +{ + + if (pq.nbits == 8) { + return get_InvertedListScanner1 (*this, store_pairs); + } else if (pq.nbits == 16) { + return get_InvertedListScanner1 (*this, store_pairs); + } else { + return get_InvertedListScanner1 (*this, store_pairs); + } + return nullptr; + +} + + + +IndexIVFPQStats indexIVFPQ_stats; + +void IndexIVFPQStats::reset () { + memset (this, 0, sizeof (*this)); +} + + + +IndexIVFPQ::IndexIVFPQ () +{ + // initialize some runtime values + use_precomputed_table = 0; + scan_table_threshold = 0; + do_polysemous_training = false; + polysemous_ht = 0; + polysemous_training = nullptr; +} + + +struct CodeCmp { + const uint8_t *tab; + size_t code_size; + bool operator () (int a, int b) const { + return cmp (a, b) > 0; + } + int cmp (int a, int b) const { + return memcmp (tab + a * code_size, tab + b * code_size, + code_size); + } +}; + + +size_t IndexIVFPQ::find_duplicates (idx_t *dup_ids, size_t *lims) const +{ + size_t ngroup = 0; + lims[0] = 0; + for (size_t list_no = 0; list_no < nlist; list_no++) { + size_t n = invlists->list_size (list_no); + std::vector ord (n); + for (int i = 0; i < n; i++) ord[i] = i; + InvertedLists::ScopedCodes codes (invlists, list_no); + CodeCmp cs = { codes.get(), code_size }; + std::sort (ord.begin(), ord.end(), cs); + + InvertedLists::ScopedIds list_ids (invlists, list_no); + int prev = -1; // all elements from prev to i-1 are equal + for (int i = 0; i < n; i++) { + if (prev >= 0 && cs.cmp (ord [prev], ord [i]) == 0) { + // same as previous => remember + if (prev + 1 == i) { // start new group + ngroup++; + lims[ngroup] = lims[ngroup - 1]; + dup_ids [lims [ngroup]++] = list_ids [ord [prev]]; + } + dup_ids [lims [ngroup]++] = list_ids [ord [i]]; + } else { // not same as previous. + prev = i; + } + } + } + return ngroup; +} + + + + + + +} // namespace faiss diff --git a/core/src/index/thirdparty/faiss/IndexIVFPQ.h b/core/src/index/thirdparty/faiss/IndexIVFPQ.h new file mode 100644 index 0000000000..4ca04e9ef9 --- /dev/null +++ b/core/src/index/thirdparty/faiss/IndexIVFPQ.h @@ -0,0 +1,160 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +#ifndef FAISS_INDEX_IVFPQ_H +#define FAISS_INDEX_IVFPQ_H + + +#include + +#include +#include + + +namespace faiss { + +struct IVFPQSearchParameters: IVFSearchParameters { + size_t scan_table_threshold; ///< use table computation or on-the-fly? + int polysemous_ht; ///< Hamming thresh for polysemous filtering + ~IVFPQSearchParameters () {} +}; + + +/** Inverted file with Product Quantizer encoding. Each residual + * vector is encoded as a product quantizer code. + */ +struct IndexIVFPQ: IndexIVF { + bool by_residual; ///< Encode residual or plain vector? + + ProductQuantizer pq; ///< produces the codes + + bool do_polysemous_training; ///< reorder PQ centroids after training? + PolysemousTraining *polysemous_training; ///< if NULL, use default + + // search-time parameters + size_t scan_table_threshold; ///< use table computation or on-the-fly? + int polysemous_ht; ///< Hamming thresh for polysemous filtering + + /** Precompute table that speed up query preprocessing at some + * memory cost (used only for by_residual with L2 metric) + * =-1: force disable + * =0: decide heuristically (default: use tables only if they are + * < precomputed_tables_max_bytes) + * =1: tables that work for all quantizers (size 256 * nlist * M) + * =2: specific version for MultiIndexQuantizer (much more compact) + */ + int use_precomputed_table; + static size_t precomputed_table_max_bytes; + + /// if use_precompute_table + /// size nlist * pq.M * pq.ksub + std::vector precomputed_table; + + IndexIVFPQ ( + Index * quantizer, size_t d, size_t nlist, + size_t M, size_t nbits_per_idx, MetricType metric = METRIC_L2); + + void add_with_ids(idx_t n, const float* x, const idx_t* xids = nullptr) + override; + + void encode_vectors(idx_t n, const float* x, + const idx_t *list_nos, + uint8_t * codes, + bool include_listnos = false) const override; + + void sa_decode (idx_t n, const uint8_t *bytes, + float *x) const override; + + + /// same as add_core, also: + /// - output 2nd level residuals if residuals_2 != NULL + /// - use precomputed list numbers if precomputed_idx != NULL + void add_core_o (idx_t n, const float *x, + const idx_t *xids, float *residuals_2, + const idx_t *precomputed_idx = nullptr); + + /// trains the product quantizer + void train_residual(idx_t n, const float* x) override; + + /// same as train_residual, also output 2nd level residuals + void train_residual_o (idx_t n, const float *x, float *residuals_2); + + void reconstruct_from_offset (int64_t list_no, int64_t offset, + float* recons) const override; + + /** Find exact duplicates in the dataset. + * + * the duplicates are returned in pre-allocated arrays (see the + * max sizes). + * + * @param lims limits between groups of duplicates + * (max size ntotal / 2 + 1) + * @param ids ids[lims[i]] : ids[lims[i+1]-1] is a group of + * duplicates (max size ntotal) + * @return n number of groups found + */ + size_t find_duplicates (idx_t *ids, size_t *lims) const; + + // map a vector to a binary code knowning the index + void encode (idx_t key, const float * x, uint8_t * code) const; + + /** Encode multiple vectors + * + * @param n nb vectors to encode + * @param keys posting list ids for those vectors (size n) + * @param x vectors (size n * d) + * @param codes output codes (size n * code_size) + * @param compute_keys if false, assume keys are precomputed, + * otherwise compute them + */ + void encode_multiple (size_t n, idx_t *keys, + const float * x, uint8_t * codes, + bool compute_keys = false) const; + + /// inverse of encode_multiple + void decode_multiple (size_t n, const idx_t *keys, + const uint8_t * xcodes, float * x) const; + + InvertedListScanner *get_InvertedListScanner (bool store_pairs) + const override; + + /// build precomputed table + void precompute_table (); + + IndexIVFPQ (); + +}; + + +/// statistics are robust to internal threading, but not if +/// IndexIVFPQ::search_preassigned is called by multiple threads +struct IndexIVFPQStats { + size_t nrefine; ///< nb of refines (IVFPQR) + + size_t n_hamming_pass; + ///< nb of passed Hamming distance tests (for polysemous) + + // timings measured with the CPU RTC on all threads + size_t search_cycles; + size_t refine_cycles; ///< only for IVFPQR + + IndexIVFPQStats () {reset (); } + void reset (); +}; + +// global var that collects them all +extern IndexIVFPQStats indexIVFPQ_stats; + + + + +} // namespace faiss + + +#endif diff --git a/core/src/index/thirdparty/faiss/IndexIVFPQR.cpp b/core/src/index/thirdparty/faiss/IndexIVFPQR.cpp new file mode 100644 index 0000000000..20d849210c --- /dev/null +++ b/core/src/index/thirdparty/faiss/IndexIVFPQR.cpp @@ -0,0 +1,219 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +#include + +#include +#include +#include + +#include +#include + +namespace faiss { + +/***************************************** + * IndexIVFPQR implementation + ******************************************/ + +IndexIVFPQR::IndexIVFPQR ( + Index * quantizer, size_t d, size_t nlist, + size_t M, size_t nbits_per_idx, + size_t M_refine, size_t nbits_per_idx_refine): + IndexIVFPQ (quantizer, d, nlist, M, nbits_per_idx), + refine_pq (d, M_refine, nbits_per_idx_refine), + k_factor (4) +{ + by_residual = true; +} + +IndexIVFPQR::IndexIVFPQR (): + k_factor (1) +{ + by_residual = true; +} + + + +void IndexIVFPQR::reset() +{ + IndexIVFPQ::reset(); + refine_codes.clear(); +} + + + + +void IndexIVFPQR::train_residual (idx_t n, const float *x) +{ + + float * residual_2 = new float [n * d]; + ScopeDeleter del(residual_2); + + train_residual_o (n, x, residual_2); + + if (verbose) + printf ("training %zdx%zd 2nd level PQ quantizer on %ld %dD-vectors\n", + refine_pq.M, refine_pq.ksub, n, d); + + refine_pq.cp.max_points_per_centroid = 1000; + refine_pq.cp.verbose = verbose; + + refine_pq.train (n, residual_2); + +} + + +void IndexIVFPQR::add_with_ids (idx_t n, const float *x, const idx_t *xids) { + add_core (n, x, xids, nullptr); +} + +void IndexIVFPQR::add_core (idx_t n, const float *x, const idx_t *xids, + const idx_t *precomputed_idx) { + + float * residual_2 = new float [n * d]; + ScopeDeleter del(residual_2); + + idx_t n0 = ntotal; + + add_core_o (n, x, xids, residual_2, precomputed_idx); + + refine_codes.resize (ntotal * refine_pq.code_size); + + refine_pq.compute_codes ( + residual_2, &refine_codes[n0 * refine_pq.code_size], n); + + +} +#define TIC t0 = get_cycles() +#define TOC get_cycles () - t0 + + +void IndexIVFPQR::search_preassigned (idx_t n, const float *x, idx_t k, + const idx_t *idx, + const float *L1_dis, + float *distances, idx_t *labels, + bool store_pairs, + const IVFSearchParameters *params, + ConcurrentBitsetPtr bitset + ) const +{ + uint64_t t0; + TIC; + size_t k_coarse = long(k * k_factor); + idx_t *coarse_labels = new idx_t [k_coarse * n]; + ScopeDeleter del1 (coarse_labels); + { // query with quantizer levels 1 and 2. + float *coarse_distances = new float [k_coarse * n]; + ScopeDeleter del(coarse_distances); + + IndexIVFPQ::search_preassigned ( + n, x, k_coarse, + idx, L1_dis, coarse_distances, coarse_labels, + true, params); + } + + + indexIVFPQ_stats.search_cycles += TOC; + + TIC; + + // 3rd level refinement + size_t n_refine = 0; +#pragma omp parallel reduction(+ : n_refine) + { + // tmp buffers + float *residual_1 = new float [2 * d]; + ScopeDeleter del (residual_1); + float *residual_2 = residual_1 + d; +#pragma omp for + for (idx_t i = 0; i < n; i++) { + const float *xq = x + i * d; + const idx_t * shortlist = coarse_labels + k_coarse * i; + float * heap_sim = distances + k * i; + idx_t * heap_ids = labels + k * i; + maxheap_heapify (k, heap_sim, heap_ids); + + for (int j = 0; j < k_coarse; j++) { + idx_t sl = shortlist[j]; + + if (sl == -1) continue; + + int list_no = lo_listno(sl); + int ofs = lo_offset(sl); + + assert (list_no >= 0 && list_no < nlist); + assert (ofs >= 0 && ofs < invlists->list_size (list_no)); + + // 1st level residual + quantizer->compute_residual (xq, residual_1, list_no); + + // 2nd level residual + const uint8_t * l2code = + invlists->get_single_code (list_no, ofs); + + pq.decode (l2code, residual_2); + for (int l = 0; l < d; l++) + residual_2[l] = residual_1[l] - residual_2[l]; + + // 3rd level residual's approximation + idx_t id = invlists->get_single_id (list_no, ofs); + assert (0 <= id && id < ntotal); + refine_pq.decode (&refine_codes [id * refine_pq.code_size], + residual_1); + + float dis = fvec_L2sqr (residual_1, residual_2, d); + + if (dis < heap_sim[0]) { + idx_t id_or_pair = store_pairs ? sl : id; + maxheap_swap_top (k, heap_sim, heap_ids, dis, id_or_pair); + } + n_refine ++; + } + maxheap_reorder (k, heap_sim, heap_ids); + } + } + indexIVFPQ_stats.nrefine += n_refine; + indexIVFPQ_stats.refine_cycles += TOC; +} + +void IndexIVFPQR::reconstruct_from_offset (int64_t list_no, int64_t offset, + float* recons) const +{ + IndexIVFPQ::reconstruct_from_offset (list_no, offset, recons); + + idx_t id = invlists->get_single_id (list_no, offset); + assert (0 <= id && id < ntotal); + + std::vector r3(d); + refine_pq.decode (&refine_codes [id * refine_pq.code_size], r3.data()); + for (int i = 0; i < d; ++i) { + recons[i] += r3[i]; + } +} + +void IndexIVFPQR::merge_from (IndexIVF &other_in, idx_t add_id) +{ + IndexIVFPQR *other = dynamic_cast (&other_in); + FAISS_THROW_IF_NOT(other); + + IndexIVF::merge_from (other_in, add_id); + + refine_codes.insert (refine_codes.end(), + other->refine_codes.begin(), + other->refine_codes.end()); + other->refine_codes.clear(); +} + +size_t IndexIVFPQR::remove_ids(const IDSelector& /*sel*/) { + FAISS_THROW_MSG("not implemented"); + return 0; +} + +} // namespace faiss diff --git a/core/src/index/thirdparty/faiss/IndexIVFPQR.h b/core/src/index/thirdparty/faiss/IndexIVFPQR.h new file mode 100644 index 0000000000..38177bda41 --- /dev/null +++ b/core/src/index/thirdparty/faiss/IndexIVFPQR.h @@ -0,0 +1,66 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +#pragma once + +#include + +#include + + +namespace faiss { + + + +/** Index with an additional level of PQ refinement */ +struct IndexIVFPQR: IndexIVFPQ { + ProductQuantizer refine_pq; ///< 3rd level quantizer + std::vector refine_codes; ///< corresponding codes + + /// factor between k requested in search and the k requested from the IVFPQ + float k_factor; + + IndexIVFPQR ( + Index * quantizer, size_t d, size_t nlist, + size_t M, size_t nbits_per_idx, + size_t M_refine, size_t nbits_per_idx_refine); + + void reset() override; + + size_t remove_ids(const IDSelector& sel) override; + + /// trains the two product quantizers + void train_residual(idx_t n, const float* x) override; + + void add_with_ids(idx_t n, const float* x, const idx_t* xids) override; + + /// same as add_with_ids, but optionally use the precomputed list ids + void add_core (idx_t n, const float *x, const idx_t *xids, + const idx_t *precomputed_idx = nullptr); + + void reconstruct_from_offset (int64_t list_no, int64_t offset, + float* recons) const override; + + void merge_from (IndexIVF &other, idx_t add_id) override; + + + void search_preassigned (idx_t n, const float *x, idx_t k, + const idx_t *assign, + const float *centroid_dis, + float *distances, idx_t *labels, + bool store_pairs, + const IVFSearchParameters *params=nullptr, + ConcurrentBitsetPtr bitset = nullptr + ) const override; + + IndexIVFPQR(); +}; + + +} // namespace faiss diff --git a/core/src/index/thirdparty/faiss/IndexIVFSpectralHash.cpp b/core/src/index/thirdparty/faiss/IndexIVFSpectralHash.cpp new file mode 100644 index 0000000000..4e27500a34 --- /dev/null +++ b/core/src/index/thirdparty/faiss/IndexIVFSpectralHash.cpp @@ -0,0 +1,333 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + + +#include + +#include +#include +#include + +#include +#include +#include +#include +#include + +namespace faiss { + + +IndexIVFSpectralHash::IndexIVFSpectralHash ( + Index * quantizer, size_t d, size_t nlist, + int nbit, float period): + IndexIVF (quantizer, d, nlist, (nbit + 7) / 8, METRIC_L2), + nbit (nbit), period (period), threshold_type (Thresh_global) +{ + FAISS_THROW_IF_NOT (code_size % 4 == 0); + RandomRotationMatrix *rr = new RandomRotationMatrix (d, nbit); + rr->init (1234); + vt = rr; + own_fields = true; + is_trained = false; +} + +IndexIVFSpectralHash::IndexIVFSpectralHash(): + IndexIVF(), vt(nullptr), own_fields(false), + nbit(0), period(0), threshold_type(Thresh_global) +{} + +IndexIVFSpectralHash::~IndexIVFSpectralHash () +{ + if (own_fields) { + delete vt; + } +} + +namespace { + + +float median (size_t n, float *x) { + std::sort(x, x + n); + if (n % 2 == 1) { + return x [n / 2]; + } else { + return (x [n / 2 - 1] + x [n / 2]) / 2; + } +} + +} + + +void IndexIVFSpectralHash::train_residual (idx_t n, const float *x) +{ + if (!vt->is_trained) { + vt->train (n, x); + } + + if (threshold_type == Thresh_global) { + // nothing to do + return; + } else if (threshold_type == Thresh_centroid || + threshold_type == Thresh_centroid_half) { + // convert all centroids with vt + std::vector centroids (nlist * d); + quantizer->reconstruct_n (0, nlist, centroids.data()); + trained.resize(nlist * nbit); + vt->apply_noalloc (nlist, centroids.data(), trained.data()); + if (threshold_type == Thresh_centroid_half) { + for (size_t i = 0; i < nlist * nbit; i++) { + trained[i] -= 0.25 * period; + } + } + return; + } + // otherwise train medians + + // assign + std::unique_ptr idx (new idx_t [n]); + quantizer->assign (n, x, idx.get()); + + std::vector sizes(nlist + 1); + for (size_t i = 0; i < n; i++) { + FAISS_THROW_IF_NOT (idx[i] >= 0); + sizes[idx[i]]++; + } + + size_t ofs = 0; + for (int j = 0; j < nlist; j++) { + size_t o0 = ofs; + ofs += sizes[j]; + sizes[j] = o0; + } + + // transform + std::unique_ptr xt (vt->apply (n, x)); + + // transpose + reorder + std::unique_ptr xo (new float[n * nbit]); + + for (size_t i = 0; i < n; i++) { + size_t idest = sizes[idx[i]]++; + for (size_t j = 0; j < nbit; j++) { + xo[idest + n * j] = xt[i * nbit + j]; + } + } + + trained.resize (n * nbit); + // compute medians +#pragma omp for + for (int i = 0; i < nlist; i++) { + size_t i0 = i == 0 ? 0 : sizes[i - 1]; + size_t i1 = sizes[i]; + for (int j = 0; j < nbit; j++) { + float *xoi = xo.get() + i0 + n * j; + if (i0 == i1) { // nothing to train + trained[i * nbit + j] = 0.0; + } else if (i1 == i0 + 1) { + trained[i * nbit + j] = xoi[0]; + } else { + trained[i * nbit + j] = median(i1 - i0, xoi); + } + } + } +} + + +namespace { + +void binarize_with_freq(size_t nbit, float freq, + const float *x, const float *c, + uint8_t *codes) +{ + memset (codes, 0, (nbit + 7) / 8); + for (size_t i = 0; i < nbit; i++) { + float xf = (x[i] - c[i]); + int xi = int(floor(xf * freq)); + int bit = xi & 1; + codes[i >> 3] |= bit << (i & 7); + } +} + + +}; + + + +void IndexIVFSpectralHash::encode_vectors(idx_t n, const float* x_in, + const idx_t *list_nos, + uint8_t * codes, + bool include_listnos) const +{ + FAISS_THROW_IF_NOT (is_trained); + float freq = 2.0 / period; + + FAISS_THROW_IF_NOT_MSG (!include_listnos, "listnos encoding not supported"); + + // transform with vt + std::unique_ptr x (vt->apply (n, x_in)); + +#pragma omp parallel + { + std::vector zero (nbit); + + // each thread takes care of a subset of lists +#pragma omp for + for (size_t i = 0; i < n; i++) { + int64_t list_no = list_nos [i]; + + if (list_no >= 0) { + const float *c; + if (threshold_type == Thresh_global) { + c = zero.data(); + } else { + c = trained.data() + list_no * nbit; + } + binarize_with_freq (nbit, freq, + x.get() + i * nbit, c, + codes + i * code_size) ; + } + } + } +} + +namespace { + + +template +struct IVFScanner: InvertedListScanner { + + // copied from index structure + const IndexIVFSpectralHash *index; + size_t code_size; + size_t nbit; + bool store_pairs; + + float period, freq; + std::vector q; + std::vector zero; + std::vector qcode; + HammingComputer hc; + + using idx_t = Index::idx_t; + + IVFScanner (const IndexIVFSpectralHash * index, + bool store_pairs): + index (index), + code_size(index->code_size), + nbit(index->nbit), + store_pairs(store_pairs), + period(index->period), freq(2.0 / index->period), + q(nbit), zero(nbit), qcode(code_size), + hc(qcode.data(), code_size) + { + } + + + void set_query (const float *query) override { + FAISS_THROW_IF_NOT(query); + FAISS_THROW_IF_NOT(q.size() == nbit); + index->vt->apply_noalloc (1, query, q.data()); + + if (index->threshold_type == + IndexIVFSpectralHash::Thresh_global) { + binarize_with_freq + (nbit, freq, q.data(), zero.data(), qcode.data()); + hc.set (qcode.data(), code_size); + } + } + + idx_t list_no; + + void set_list (idx_t list_no, float /*coarse_dis*/) override { + this->list_no = list_no; + if (index->threshold_type != IndexIVFSpectralHash::Thresh_global) { + const float *c = index->trained.data() + list_no * nbit; + binarize_with_freq (nbit, freq, q.data(), c, qcode.data()); + hc.set (qcode.data(), code_size); + } + } + + float distance_to_code (const uint8_t *code) const final { + return hc.hamming (code); + } + + size_t scan_codes (size_t list_size, + const uint8_t *codes, + const idx_t *ids, + float *simi, idx_t *idxi, + size_t k, + ConcurrentBitsetPtr bitset) const override + { + size_t nup = 0; + for (size_t j = 0; j < list_size; j++) { + if (!bitset || !bitset->test(ids[j])) { + float dis = hc.hamming (codes); + + if (dis < simi [0]) { + int64_t id = store_pairs ? (list_no << 32 | j) : ids[j]; + maxheap_swap_top (k, simi, idxi, dis, id); + nup++; + } + } + codes += code_size; + } + return nup; + } + + void scan_codes_range (size_t list_size, + const uint8_t *codes, + const idx_t *ids, + float radius, + RangeQueryResult & res, + ConcurrentBitsetPtr bitset = nullptr) const override + { + for (size_t j = 0; j < list_size; j++) { + float dis = hc.hamming (codes); + if (dis < radius) { + int64_t id = store_pairs ? lo_build (list_no, j) : ids[j]; + res.add (dis, id); + } + codes += code_size; + } + } + + +}; + +} // anonymous namespace + +InvertedListScanner* IndexIVFSpectralHash::get_InvertedListScanner + (bool store_pairs) const +{ + switch (code_size) { +#define HANDLE_CODE_SIZE(cs) \ + case cs: \ + return new IVFScanner (this, store_pairs) + HANDLE_CODE_SIZE(4); + HANDLE_CODE_SIZE(8); + HANDLE_CODE_SIZE(16); + HANDLE_CODE_SIZE(20); + HANDLE_CODE_SIZE(32); + HANDLE_CODE_SIZE(64); +#undef HANDLE_CODE_SIZE + default: + if (code_size % 8 == 0) { + return new IVFScanner(this, store_pairs); + } else if (code_size % 4 == 0) { + return new IVFScanner(this, store_pairs); + } else { + FAISS_THROW_MSG("not supported"); + } + } + +} + + + +} // namespace faiss diff --git a/core/src/index/thirdparty/faiss/IndexIVFSpectralHash.h b/core/src/index/thirdparty/faiss/IndexIVFSpectralHash.h new file mode 100644 index 0000000000..ee01ac81cd --- /dev/null +++ b/core/src/index/thirdparty/faiss/IndexIVFSpectralHash.h @@ -0,0 +1,75 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +#ifndef FAISS_INDEX_IVFSH_H +#define FAISS_INDEX_IVFSH_H + + +#include + +#include + + +namespace faiss { + +struct VectorTransform; + +/** Inverted list that stores binary codes of size nbit. Before the + * binary conversion, the dimension of the vectors is transformed from + * dim d into dim nbit by vt (a random rotation by default). + * + * Each coordinate is subtracted from a value determined by + * threshold_type, and split into intervals of size period. Half of + * the interval is a 0 bit, the other half a 1. + */ +struct IndexIVFSpectralHash: IndexIVF { + + VectorTransform *vt; // transformation from d to nbit dim + bool own_fields; + + int nbit; + float period; + + enum ThresholdType { + Thresh_global, + Thresh_centroid, + Thresh_centroid_half, + Thresh_median + }; + ThresholdType threshold_type; + + // size nlist * nbit or 0 if Thresh_global + std::vector trained; + + IndexIVFSpectralHash (Index * quantizer, size_t d, size_t nlist, + int nbit, float period); + + IndexIVFSpectralHash (); + + void train_residual(idx_t n, const float* x) override; + + void encode_vectors(idx_t n, const float* x, + const idx_t *list_nos, + uint8_t * codes, + bool include_listnos = false) const override; + + InvertedListScanner *get_InvertedListScanner (bool store_pairs) + const override; + + ~IndexIVFSpectralHash () override; + +}; + + + + +}; // namespace faiss + + +#endif diff --git a/core/src/index/thirdparty/faiss/IndexLSH.cpp b/core/src/index/thirdparty/faiss/IndexLSH.cpp new file mode 100644 index 0000000000..1a780a2e7d --- /dev/null +++ b/core/src/index/thirdparty/faiss/IndexLSH.cpp @@ -0,0 +1,226 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +#include + +#include +#include + +#include + +#include +#include +#include + + +namespace faiss { + +/*************************************************************** + * IndexLSH + ***************************************************************/ + + +IndexLSH::IndexLSH (idx_t d, int nbits, bool rotate_data, bool train_thresholds): + Index(d), nbits(nbits), rotate_data(rotate_data), + train_thresholds (train_thresholds), rrot(d, nbits) +{ + is_trained = !train_thresholds; + + bytes_per_vec = (nbits + 7) / 8; + + if (rotate_data) { + rrot.init(5); + } else { + FAISS_THROW_IF_NOT (d >= nbits); + } +} + +IndexLSH::IndexLSH (): + nbits (0), bytes_per_vec(0), rotate_data (false), train_thresholds (false) +{ +} + + +const float * IndexLSH::apply_preprocess (idx_t n, const float *x) const +{ + + float *xt = nullptr; + if (rotate_data) { + // also applies bias if exists + xt = rrot.apply (n, x); + } else if (d != nbits) { + assert (nbits < d); + xt = new float [nbits * n]; + float *xp = xt; + for (idx_t i = 0; i < n; i++) { + const float *xl = x + i * d; + for (int j = 0; j < nbits; j++) + *xp++ = xl [j]; + } + } + + if (train_thresholds) { + + if (xt == NULL) { + xt = new float [nbits * n]; + memcpy (xt, x, sizeof(*x) * n * nbits); + } + + float *xp = xt; + for (idx_t i = 0; i < n; i++) + for (int j = 0; j < nbits; j++) + *xp++ -= thresholds [j]; + } + + return xt ? xt : x; +} + + + +void IndexLSH::train (idx_t n, const float *x) +{ + if (train_thresholds) { + thresholds.resize (nbits); + train_thresholds = false; + const float *xt = apply_preprocess (n, x); + ScopeDeleter del (xt == x ? nullptr : xt); + train_thresholds = true; + + float * transposed_x = new float [n * nbits]; + ScopeDeleter del2 (transposed_x); + + for (idx_t i = 0; i < n; i++) + for (idx_t j = 0; j < nbits; j++) + transposed_x [j * n + i] = xt [i * nbits + j]; + + for (idx_t i = 0; i < nbits; i++) { + float *xi = transposed_x + i * n; + // std::nth_element + std::sort (xi, xi + n); + if (n % 2 == 1) + thresholds [i] = xi [n / 2]; + else + thresholds [i] = (xi [n / 2 - 1] + xi [n / 2]) / 2; + + } + } + is_trained = true; +} + + +void IndexLSH::add (idx_t n, const float *x) +{ + FAISS_THROW_IF_NOT (is_trained); + codes.resize ((ntotal + n) * bytes_per_vec); + + sa_encode (n, x, &codes[ntotal * bytes_per_vec]); + + ntotal += n; +} + + +void IndexLSH::search ( + idx_t n, + const float *x, + idx_t k, + float *distances, + idx_t *labels, + ConcurrentBitsetPtr bitset) const +{ + FAISS_THROW_IF_NOT (is_trained); + const float *xt = apply_preprocess (n, x); + ScopeDeleter del (xt == x ? nullptr : xt); + + uint8_t * qcodes = new uint8_t [n * bytes_per_vec]; + ScopeDeleter del2 (qcodes); + + fvecs2bitvecs (xt, qcodes, nbits, n); + + int * idistances = new int [n * k]; + ScopeDeleter del3 (idistances); + + int_maxheap_array_t res = { size_t(n), size_t(k), labels, idistances}; + + hammings_knn_hc (&res, qcodes, codes.data(), + ntotal, bytes_per_vec, true); + + + // convert distances to floats + for (int i = 0; i < k * n; i++) + distances[i] = idistances[i]; + +} + + +void IndexLSH::transfer_thresholds (LinearTransform *vt) { + if (!train_thresholds) return; + FAISS_THROW_IF_NOT (nbits == vt->d_out); + if (!vt->have_bias) { + vt->b.resize (nbits, 0); + vt->have_bias = true; + } + for (int i = 0; i < nbits; i++) + vt->b[i] -= thresholds[i]; + train_thresholds = false; + thresholds.clear(); +} + +void IndexLSH::reset() { + codes.clear(); + ntotal = 0; +} + + +size_t IndexLSH::sa_code_size () const +{ + return bytes_per_vec; +} + +void IndexLSH::sa_encode (idx_t n, const float *x, + uint8_t *bytes) const +{ + FAISS_THROW_IF_NOT (is_trained); + const float *xt = apply_preprocess (n, x); + ScopeDeleter del (xt == x ? nullptr : xt); + fvecs2bitvecs (xt, bytes, nbits, n); +} + +void IndexLSH::sa_decode (idx_t n, const uint8_t *bytes, + float *x) const +{ + float *xt = x; + ScopeDeleter del; + if (rotate_data || nbits != d) { + xt = new float [n * nbits]; + del.set(xt); + } + bitvecs2fvecs (bytes, xt, nbits, n); + + if (train_thresholds) { + float *xp = xt; + for (idx_t i = 0; i < n; i++) { + for (int j = 0; j < nbits; j++) { + *xp++ += thresholds [j]; + } + } + } + + if (rotate_data) { + rrot.reverse_transform (n, xt, x); + } else if (nbits != d) { + for (idx_t i = 0; i < n; i++) { + memcpy (x + i * d, xt + i * nbits, + nbits * sizeof(xt[0])); + } + } +} + + + +} // namespace faiss diff --git a/core/src/index/thirdparty/faiss/IndexLSH.h b/core/src/index/thirdparty/faiss/IndexLSH.h new file mode 100644 index 0000000000..7bcc9c5f84 --- /dev/null +++ b/core/src/index/thirdparty/faiss/IndexLSH.h @@ -0,0 +1,91 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +#ifndef INDEX_LSH_H +#define INDEX_LSH_H + +#include + +#include +#include + +namespace faiss { + + +/** The sign of each vector component is put in a binary signature */ +struct IndexLSH:Index { + typedef unsigned char uint8_t; + + int nbits; ///< nb of bits per vector + int bytes_per_vec; ///< nb of 8-bits per encoded vector + bool rotate_data; ///< whether to apply a random rotation to input + bool train_thresholds; ///< whether we train thresholds or use 0 + + RandomRotationMatrix rrot; ///< optional random rotation + + std::vector thresholds; ///< thresholds to compare with + + /// encoded dataset + std::vector codes; + + IndexLSH ( + idx_t d, int nbits, + bool rotate_data = true, + bool train_thresholds = false); + + /** Preprocesses and resizes the input to the size required to + * binarize the data + * + * @param x input vectors, size n * d + * @return output vectors, size n * bits. May be the same pointer + * as x, otherwise it should be deleted by the caller + */ + const float *apply_preprocess (idx_t n, const float *x) const; + + void train(idx_t n, const float* x) override; + + void add(idx_t n, const float* x) override; + + void search( + idx_t n, + const float* x, + idx_t k, + float* distances, + idx_t* labels, + ConcurrentBitsetPtr bitset = nullptr) const override; + + void reset() override; + + /// transfer the thresholds to a pre-processing stage (and unset + /// train_thresholds) + void transfer_thresholds (LinearTransform * vt); + + ~IndexLSH() override {} + + IndexLSH (); + + /* standalone codec interface. + * + * The vectors are decoded to +/- 1 (not 0, 1) */ + + size_t sa_code_size () const override; + + void sa_encode (idx_t n, const float *x, + uint8_t *bytes) const override; + + void sa_decode (idx_t n, const uint8_t *bytes, + float *x) const override; + +}; + + +} + + +#endif diff --git a/core/src/index/thirdparty/faiss/IndexLattice.cpp b/core/src/index/thirdparty/faiss/IndexLattice.cpp new file mode 100644 index 0000000000..5c7be9fcbc --- /dev/null +++ b/core/src/index/thirdparty/faiss/IndexLattice.cpp @@ -0,0 +1,143 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + + +#include +#include // for the bitstring routines +#include +#include + +namespace faiss { + + +IndexLattice::IndexLattice (idx_t d, int nsq, int scale_nbit, int r2): + Index (d), + nsq (nsq), + dsq (d / nsq), + zn_sphere_codec (dsq, r2), + scale_nbit (scale_nbit) +{ + FAISS_THROW_IF_NOT (d % nsq == 0); + + lattice_nbit = 0; + while (!( ((uint64_t)1 << lattice_nbit) >= zn_sphere_codec.nv)) { + lattice_nbit++; + } + + int total_nbit = (lattice_nbit + scale_nbit) * nsq; + + code_size = (total_nbit + 7) / 8; + + is_trained = false; +} + +void IndexLattice::train(idx_t n, const float* x) +{ + // compute ranges per sub-block + trained.resize (nsq * 2); + float * mins = trained.data(); + float * maxs = trained.data() + nsq; + for (int sq = 0; sq < nsq; sq++) { + mins[sq] = HUGE_VAL; + maxs[sq] = -1; + } + + for (idx_t i = 0; i < n; i++) { + for (int sq = 0; sq < nsq; sq++) { + float norm2 = fvec_norm_L2sqr (x + i * d + sq * dsq, dsq); + if (norm2 > maxs[sq]) maxs[sq] = norm2; + if (norm2 < mins[sq]) mins[sq] = norm2; + } + } + + for (int sq = 0; sq < nsq; sq++) { + mins[sq] = sqrtf (mins[sq]); + maxs[sq] = sqrtf (maxs[sq]); + } + + is_trained = true; +} + +/* The standalone codec interface */ +size_t IndexLattice::sa_code_size () const +{ + return code_size; +} + + + +void IndexLattice::sa_encode (idx_t n, const float *x, uint8_t *codes) const +{ + + const float * mins = trained.data(); + const float * maxs = mins + nsq; + int64_t sc = int64_t(1) << scale_nbit; + +#pragma omp parallel for + for (idx_t i = 0; i < n; i++) { + BitstringWriter wr(codes + i * code_size, code_size); + const float *xi = x + i * d; + for (int j = 0; j < nsq; j++) { + float nj = + (sqrtf(fvec_norm_L2sqr(xi, dsq)) - mins[j]) + * sc / (maxs[j] - mins[j]); + if (nj < 0) nj = 0; + if (nj >= sc) nj = sc - 1; + wr.write((int64_t)nj, scale_nbit); + wr.write(zn_sphere_codec.encode(xi), lattice_nbit); + xi += dsq; + } + } +} + +void IndexLattice::sa_decode (idx_t n, const uint8_t *codes, float *x) const +{ + const float * mins = trained.data(); + const float * maxs = mins + nsq; + float sc = int64_t(1) << scale_nbit; + float r = sqrtf(zn_sphere_codec.r2); + +#pragma omp parallel for + for (idx_t i = 0; i < n; i++) { + BitstringReader rd(codes + i * code_size, code_size); + float *xi = x + i * d; + for (int j = 0; j < nsq; j++) { + float norm = + (rd.read (scale_nbit) + 0.5) * + (maxs[j] - mins[j]) / sc + mins[j]; + norm /= r; + zn_sphere_codec.decode (rd.read (lattice_nbit), xi); + for (int l = 0; l < dsq; l++) { + xi[l] *= norm; + } + xi += dsq; + } + } +} + +void IndexLattice::add(idx_t , const float* ) +{ + FAISS_THROW_MSG("not implemented"); +} + + +void IndexLattice::search(idx_t , const float* , idx_t , + float* , idx_t* , ConcurrentBitsetPtr ) const +{ + FAISS_THROW_MSG("not implemented"); +} + + +void IndexLattice::reset() +{ + FAISS_THROW_MSG("not implemented"); +} + + +} // namespace faiss diff --git a/core/src/index/thirdparty/faiss/IndexLattice.h b/core/src/index/thirdparty/faiss/IndexLattice.h new file mode 100644 index 0000000000..e946fac40a --- /dev/null +++ b/core/src/index/thirdparty/faiss/IndexLattice.h @@ -0,0 +1,69 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +#ifndef FAISS_INDEX_LATTICE_H +#define FAISS_INDEX_LATTICE_H + + +#include + +#include +#include + +namespace faiss { + + + + + +/** Index that encodes a vector with a series of Zn lattice quantizers + */ +struct IndexLattice: Index { + + /// number of sub-vectors + int nsq; + /// dimension of sub-vectors + size_t dsq; + + /// the lattice quantizer + ZnSphereCodecAlt zn_sphere_codec; + + /// nb bits used to encode the scale, per subvector + int scale_nbit, lattice_nbit; + /// total, in bytes + size_t code_size; + + /// mins and maxes of the vector norms, per subquantizer + std::vector trained; + + IndexLattice (idx_t d, int nsq, int scale_nbit, int r2); + + void train(idx_t n, const float* x) override; + + /* The standalone codec interface */ + size_t sa_code_size () const override; + + void sa_encode (idx_t n, const float *x, + uint8_t *bytes) const override; + + void sa_decode (idx_t n, const uint8_t *bytes, + float *x) const override; + + /// not implemented + void add(idx_t n, const float* x) override; + void search(idx_t n, const float* x, idx_t k, + float* distances, idx_t* labels, + ConcurrentBitsetPtr bitset = nullptr) const override; + void reset() override; + +}; + +} // namespace faiss + +#endif diff --git a/core/src/index/thirdparty/faiss/IndexPQ.cpp b/core/src/index/thirdparty/faiss/IndexPQ.cpp new file mode 100644 index 0000000000..6e50ba1a2c --- /dev/null +++ b/core/src/index/thirdparty/faiss/IndexPQ.cpp @@ -0,0 +1,1190 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +#include + + +#include +#include +#include +#include + +#include + +#include +#include +#include + +namespace faiss { + +/********************************************************* + * IndexPQ implementation + ********************************************************/ + + +IndexPQ::IndexPQ (int d, size_t M, size_t nbits, MetricType metric): + Index(d, metric), pq(d, M, nbits) +{ + is_trained = false; + do_polysemous_training = false; + polysemous_ht = nbits * M + 1; + search_type = ST_PQ; + encode_signs = false; +} + +IndexPQ::IndexPQ () +{ + metric_type = METRIC_L2; + is_trained = false; + do_polysemous_training = false; + polysemous_ht = pq.nbits * pq.M + 1; + search_type = ST_PQ; + encode_signs = false; +} + + +void IndexPQ::train (idx_t n, const float *x) +{ + if (!do_polysemous_training) { // standard training + pq.train(n, x); + } else { + idx_t ntrain_perm = polysemous_training.ntrain_permutation; + + if (ntrain_perm > n / 4) + ntrain_perm = n / 4; + if (verbose) { + printf ("PQ training on %ld points, remains %ld points: " + "training polysemous on %s\n", + n - ntrain_perm, ntrain_perm, + ntrain_perm == 0 ? "centroids" : "these"); + } + pq.train(n - ntrain_perm, x); + + polysemous_training.optimize_pq_for_hamming ( + pq, ntrain_perm, x + (n - ntrain_perm) * d); + } + is_trained = true; +} + + +void IndexPQ::add (idx_t n, const float *x) +{ + FAISS_THROW_IF_NOT (is_trained); + codes.resize ((n + ntotal) * pq.code_size); + pq.compute_codes (x, &codes[ntotal * pq.code_size], n); + ntotal += n; +} + + +size_t IndexPQ::remove_ids (const IDSelector & sel) +{ + idx_t j = 0; + for (idx_t i = 0; i < ntotal; i++) { + if (sel.is_member (i)) { + // should be removed + } else { + if (i > j) { + memmove (&codes[pq.code_size * j], &codes[pq.code_size * i], pq.code_size); + } + j++; + } + } + size_t nremove = ntotal - j; + if (nremove > 0) { + ntotal = j; + codes.resize (ntotal * pq.code_size); + } + return nremove; +} + + +void IndexPQ::reset() +{ + codes.clear(); + ntotal = 0; +} + +void IndexPQ::reconstruct_n (idx_t i0, idx_t ni, float *recons) const +{ + FAISS_THROW_IF_NOT (ni == 0 || (i0 >= 0 && i0 + ni <= ntotal)); + for (idx_t i = 0; i < ni; i++) { + const uint8_t * code = &codes[(i0 + i) * pq.code_size]; + pq.decode (code, recons + i * d); + } +} + + +void IndexPQ::reconstruct (idx_t key, float * recons) const +{ + FAISS_THROW_IF_NOT (key >= 0 && key < ntotal); + pq.decode (&codes[key * pq.code_size], recons); +} + + +namespace { + + +struct PQDis: DistanceComputer { + size_t d; + Index::idx_t nb; + const uint8_t *codes; + size_t code_size; + const ProductQuantizer & pq; + const float *sdc; + std::vector precomputed_table; + size_t ndis; + + float operator () (idx_t i) override + { + const uint8_t *code = codes + i * code_size; + const float *dt = precomputed_table.data(); + float accu = 0; + for (int j = 0; j < pq.M; j++) { + accu += dt[*code++]; + dt += 256; + } + ndis++; + return accu; + } + + float symmetric_dis(idx_t i, idx_t j) override + { + const float * sdci = sdc; + float accu = 0; + const uint8_t *codei = codes + i * code_size; + const uint8_t *codej = codes + j * code_size; + + for (int l = 0; l < pq.M; l++) { + accu += sdci[(*codei++) + (*codej++) * 256]; + sdci += 256 * 256; + } + return accu; + } + + explicit PQDis(const IndexPQ& storage, const float* /*q*/ = nullptr) + : pq(storage.pq) { + precomputed_table.resize(pq.M * pq.ksub); + nb = storage.ntotal; + d = storage.d; + codes = storage.codes.data(); + code_size = pq.code_size; + FAISS_ASSERT(pq.ksub == 256); + FAISS_ASSERT(pq.sdc_table.size() == pq.ksub * pq.ksub * pq.M); + sdc = pq.sdc_table.data(); + ndis = 0; + } + + void set_query(const float *x) override { + pq.compute_distance_table(x, precomputed_table.data()); + } +}; + + +} // namespace + + +DistanceComputer * IndexPQ::get_distance_computer() const { + FAISS_THROW_IF_NOT(pq.nbits == 8); + return new PQDis(*this); +} + + +/***************************************** + * IndexPQ polysemous search routines + ******************************************/ + + + + + +void IndexPQ::search (idx_t n, const float *x, idx_t k, + float *distances, idx_t *labels, + ConcurrentBitsetPtr bitset) const +{ + FAISS_THROW_IF_NOT (is_trained); + if (search_type == ST_PQ) { // Simple PQ search + + if (metric_type == METRIC_L2) { + float_maxheap_array_t res = { + size_t(n), size_t(k), labels, distances }; + pq.search (x, n, codes.data(), ntotal, &res, true); + } else { + float_minheap_array_t res = { + size_t(n), size_t(k), labels, distances }; + pq.search_ip (x, n, codes.data(), ntotal, &res, true); + } + indexPQ_stats.nq += n; + indexPQ_stats.ncode += n * ntotal; + + } else if (search_type == ST_polysemous || + search_type == ST_polysemous_generalize) { + + FAISS_THROW_IF_NOT (metric_type == METRIC_L2); + + search_core_polysemous (n, x, k, distances, labels); + + } else { // code-to-code distances + + uint8_t * q_codes = new uint8_t [n * pq.code_size]; + ScopeDeleter del (q_codes); + + + if (!encode_signs) { + pq.compute_codes (x, q_codes, n); + } else { + FAISS_THROW_IF_NOT (d == pq.nbits * pq.M); + memset (q_codes, 0, n * pq.code_size); + for (size_t i = 0; i < n; i++) { + const float *xi = x + i * d; + uint8_t *code = q_codes + i * pq.code_size; + for (int j = 0; j < d; j++) + if (xi[j] > 0) code [j>>3] |= 1 << (j & 7); + } + } + + if (search_type == ST_SDC) { + + float_maxheap_array_t res = { + size_t(n), size_t(k), labels, distances}; + + pq.search_sdc (q_codes, n, codes.data(), ntotal, &res, true); + + } else { + int * idistances = new int [n * k]; + ScopeDeleter del (idistances); + + int_maxheap_array_t res = { + size_t (n), size_t (k), labels, idistances}; + + if (search_type == ST_HE) { + + hammings_knn_hc (&res, q_codes, codes.data(), + ntotal, pq.code_size, true); + + } else if (search_type == ST_generalized_HE) { + + generalized_hammings_knn_hc (&res, q_codes, codes.data(), + ntotal, pq.code_size, true); + } + + // convert distances to floats + for (int i = 0; i < k * n; i++) + distances[i] = idistances[i]; + + } + + + indexPQ_stats.nq += n; + indexPQ_stats.ncode += n * ntotal; + } +} + + + + + +void IndexPQStats::reset() +{ + nq = ncode = n_hamming_pass = 0; +} + +IndexPQStats indexPQ_stats; + + +template +static size_t polysemous_inner_loop ( + const IndexPQ & index, + const float *dis_table_qi, const uint8_t *q_code, + size_t k, float *heap_dis, int64_t *heap_ids) +{ + + int M = index.pq.M; + int code_size = index.pq.code_size; + int ksub = index.pq.ksub; + size_t ntotal = index.ntotal; + int ht = index.polysemous_ht; + + const uint8_t *b_code = index.codes.data(); + + size_t n_pass_i = 0; + + HammingComputer hc (q_code, code_size); + + for (int64_t bi = 0; bi < ntotal; bi++) { + int hd = hc.hamming (b_code); + + if (hd < ht) { + n_pass_i ++; + + float dis = 0; + const float * dis_table = dis_table_qi; + for (int m = 0; m < M; m++) { + dis += dis_table [b_code[m]]; + dis_table += ksub; + } + + if (dis < heap_dis[0]) { + maxheap_swap_top (k, heap_dis, heap_ids, dis, bi); + } + } + b_code += code_size; + } + return n_pass_i; +} + + +void IndexPQ::search_core_polysemous (idx_t n, const float *x, idx_t k, + float *distances, idx_t *labels) const +{ + FAISS_THROW_IF_NOT (pq.nbits == 8); + + // PQ distance tables + float * dis_tables = new float [n * pq.ksub * pq.M]; + ScopeDeleter del (dis_tables); + pq.compute_distance_tables (n, x, dis_tables); + + // Hamming embedding queries + uint8_t * q_codes = new uint8_t [n * pq.code_size]; + ScopeDeleter del2 (q_codes); + + if (false) { + pq.compute_codes (x, q_codes, n); + } else { +#pragma omp parallel for + for (idx_t qi = 0; qi < n; qi++) { + pq.compute_code_from_distance_table + (dis_tables + qi * pq.M * pq.ksub, + q_codes + qi * pq.code_size); + } + } + + size_t n_pass = 0; + +#pragma omp parallel for reduction (+: n_pass) + for (idx_t qi = 0; qi < n; qi++) { + const uint8_t * q_code = q_codes + qi * pq.code_size; + + const float * dis_table_qi = dis_tables + qi * pq.M * pq.ksub; + + int64_t * heap_ids = labels + qi * k; + float *heap_dis = distances + qi * k; + maxheap_heapify (k, heap_dis, heap_ids); + + if (search_type == ST_polysemous) { + + switch (pq.code_size) { + case 4: + n_pass += polysemous_inner_loop + (*this, dis_table_qi, q_code, k, heap_dis, heap_ids); + break; + case 8: + n_pass += polysemous_inner_loop + (*this, dis_table_qi, q_code, k, heap_dis, heap_ids); + break; + case 16: + n_pass += polysemous_inner_loop + (*this, dis_table_qi, q_code, k, heap_dis, heap_ids); + break; + case 32: + n_pass += polysemous_inner_loop + (*this, dis_table_qi, q_code, k, heap_dis, heap_ids); + break; + case 20: + n_pass += polysemous_inner_loop + (*this, dis_table_qi, q_code, k, heap_dis, heap_ids); + break; + default: + if (pq.code_size % 8 == 0) { + n_pass += polysemous_inner_loop + (*this, dis_table_qi, q_code, k, heap_dis, heap_ids); + } else if (pq.code_size % 4 == 0) { + n_pass += polysemous_inner_loop + (*this, dis_table_qi, q_code, k, heap_dis, heap_ids); + } else { + FAISS_THROW_FMT( + "code size %zd not supported for polysemous", + pq.code_size); + } + break; + } + } else { + switch (pq.code_size) { + case 8: + n_pass += polysemous_inner_loop + (*this, dis_table_qi, q_code, k, heap_dis, heap_ids); + break; + case 16: + n_pass += polysemous_inner_loop + (*this, dis_table_qi, q_code, k, heap_dis, heap_ids); + break; + case 32: + n_pass += polysemous_inner_loop + (*this, dis_table_qi, q_code, k, heap_dis, heap_ids); + break; + default: + if (pq.code_size % 8 == 0) { + n_pass += polysemous_inner_loop + (*this, dis_table_qi, q_code, k, heap_dis, heap_ids); + } else { + FAISS_THROW_FMT( + "code size %zd not supported for polysemous", + pq.code_size); + } + break; + } + } + maxheap_reorder (k, heap_dis, heap_ids); + } + + indexPQ_stats.nq += n; + indexPQ_stats.ncode += n * ntotal; + indexPQ_stats.n_hamming_pass += n_pass; + + +} + + +/* The standalone codec interface (just remaps to the PQ functions) */ +size_t IndexPQ::sa_code_size () const +{ + return pq.code_size; +} + +void IndexPQ::sa_encode (idx_t n, const float *x, uint8_t *bytes) const +{ + pq.compute_codes (x, bytes, n); +} + +void IndexPQ::sa_decode (idx_t n, const uint8_t *bytes, float *x) const +{ + pq.decode (bytes, x, n); +} + + + + +/***************************************** + * Stats of IndexPQ codes + ******************************************/ + + + + +void IndexPQ::hamming_distance_table (idx_t n, const float *x, + int32_t *dis) const +{ + uint8_t * q_codes = new uint8_t [n * pq.code_size]; + ScopeDeleter del (q_codes); + + pq.compute_codes (x, q_codes, n); + + hammings (q_codes, codes.data(), n, ntotal, pq.code_size, dis); +} + + +void IndexPQ::hamming_distance_histogram (idx_t n, const float *x, + idx_t nb, const float *xb, + int64_t *hist) +{ + FAISS_THROW_IF_NOT (metric_type == METRIC_L2); + FAISS_THROW_IF_NOT (pq.code_size % 8 == 0); + FAISS_THROW_IF_NOT (pq.nbits == 8); + + // Hamming embedding queries + uint8_t * q_codes = new uint8_t [n * pq.code_size]; + ScopeDeleter del (q_codes); + pq.compute_codes (x, q_codes, n); + + uint8_t * b_codes ; + ScopeDeleter del_b_codes; + + if (xb) { + b_codes = new uint8_t [nb * pq.code_size]; + del_b_codes.set (b_codes); + pq.compute_codes (xb, b_codes, nb); + } else { + nb = ntotal; + b_codes = codes.data(); + } + int nbits = pq.M * pq.nbits; + memset (hist, 0, sizeof(*hist) * (nbits + 1)); + size_t bs = 256; + +#pragma omp parallel + { + std::vector histi (nbits + 1); + hamdis_t *distances = new hamdis_t [nb * bs]; + ScopeDeleter del (distances); +#pragma omp for + for (size_t q0 = 0; q0 < n; q0 += bs) { + // printf ("dis stats: %ld/%ld\n", q0, n); + size_t q1 = q0 + bs; + if (q1 > n) q1 = n; + + hammings (q_codes + q0 * pq.code_size, b_codes, + q1 - q0, nb, + pq.code_size, distances); + + for (size_t i = 0; i < nb * (q1 - q0); i++) + histi [distances [i]]++; + } +#pragma omp critical + { + for (int i = 0; i <= nbits; i++) + hist[i] += histi[i]; + } + } + +} + + + + + + + + + + + + + + + + + + + + +/***************************************** + * MultiIndexQuantizer + ******************************************/ + +namespace { + +template +struct PreSortedArray { + + const T * x; + int N; + + explicit PreSortedArray (int N): N(N) { + } + void init (const T*x) { + this->x = x; + } + // get smallest value + T get_0 () { + return x[0]; + } + + // get delta between n-smallest and n-1 -smallest + T get_diff (int n) { + return x[n] - x[n - 1]; + } + + // remap orders counted from smallest to indices in array + int get_ord (int n) { + return n; + } + +}; + +template +struct ArgSort { + const T * x; + bool operator() (size_t i, size_t j) { + return x[i] < x[j]; + } +}; + + +/** Array that maintains a permutation of its elements so that the + * array's elements are sorted + */ +template +struct SortedArray { + const T * x; + int N; + std::vector perm; + + explicit SortedArray (int N) { + this->N = N; + perm.resize (N); + } + + void init (const T*x) { + this->x = x; + for (int n = 0; n < N; n++) + perm[n] = n; + ArgSort cmp = {x }; + std::sort (perm.begin(), perm.end(), cmp); + } + + // get smallest value + T get_0 () { + return x[perm[0]]; + } + + // get delta between n-smallest and n-1 -smallest + T get_diff (int n) { + return x[perm[n]] - x[perm[n - 1]]; + } + + // remap orders counted from smallest to indices in array + int get_ord (int n) { + return perm[n]; + } +}; + + + +/** Array has n values. Sort the k first ones and copy the other ones + * into elements k..n-1 + */ +template +void partial_sort (int k, int n, + const typename C::T * vals, typename C::TI * perm) { + // insert first k elts in heap + for (int i = 1; i < k; i++) { + indirect_heap_push (i + 1, vals, perm, perm[i]); + } + + // insert next n - k elts in heap + for (int i = k; i < n; i++) { + typename C::TI id = perm[i]; + typename C::TI top = perm[0]; + + if (C::cmp(vals[top], vals[id])) { + indirect_heap_pop (k, vals, perm); + indirect_heap_push (k, vals, perm, id); + perm[i] = top; + } else { + // nothing, elt at i is good where it is. + } + } + + // order the k first elements in heap + for (int i = k - 1; i > 0; i--) { + typename C::TI top = perm[0]; + indirect_heap_pop (i + 1, vals, perm); + perm[i] = top; + } +} + +/** same as SortedArray, but only the k first elements are sorted */ +template +struct SemiSortedArray { + const T * x; + int N; + + // type of the heap: CMax = sort ascending + typedef CMax HC; + std::vector perm; + + int k; // k elements are sorted + + int initial_k, k_factor; + + explicit SemiSortedArray (int N) { + this->N = N; + perm.resize (N); + perm.resize (N); + initial_k = 3; + k_factor = 4; + } + + void init (const T*x) { + this->x = x; + for (int n = 0; n < N; n++) + perm[n] = n; + k = 0; + grow (initial_k); + } + + /// grow the sorted part of the array to size next_k + void grow (int next_k) { + if (next_k < N) { + partial_sort (next_k - k, N - k, x, &perm[k]); + k = next_k; + } else { // full sort of remainder of array + ArgSort cmp = {x }; + std::sort (perm.begin() + k, perm.end(), cmp); + k = N; + } + } + + // get smallest value + T get_0 () { + return x[perm[0]]; + } + + // get delta between n-smallest and n-1 -smallest + T get_diff (int n) { + if (n >= k) { + // want to keep powers of 2 - 1 + int next_k = (k + 1) * k_factor - 1; + grow (next_k); + } + return x[perm[n]] - x[perm[n - 1]]; + } + + // remap orders counted from smallest to indices in array + int get_ord (int n) { + assert (n < k); + return perm[n]; + } +}; + + + +/***************************************** + * Find the k smallest sums of M terms, where each term is taken in a + * table x of n values. + * + * A combination of terms is encoded as a scalar 0 <= t < n^M. The + * combination t0 ... t(M-1) that correspond to the sum + * + * sum = x[0, t0] + x[1, t1] + .... + x[M-1, t(M-1)] + * + * is encoded as + * + * t = t0 + t1 * n + t2 * n^2 + ... + t(M-1) * n^(M-1) + * + * MinSumK is an object rather than a function, so that storage can be + * re-used over several computations with the same sizes. use_seen is + * good when there may be ties in the x array and it is a concern if + * occasionally several t's are returned. + * + * @param x size M * n, values to add up + * @parms k nb of results to retrieve + * @param M nb of terms + * @param n nb of distinct values + * @param sums output, size k, sorted + * @prarm terms output, size k, with encoding as above + * + ******************************************/ +template +struct MinSumK { + int K; ///< nb of sums to return + int M; ///< nb of elements to sum up + int nbit; ///< nb of bits to encode one entry + int N; ///< nb of possible elements for each of the M terms + + /** the heap. + * We use a heap to maintain a queue of sums, with the associated + * terms involved in the sum. + */ + typedef CMin HC; + size_t heap_capacity, heap_size; + T *bh_val; + int64_t *bh_ids; + + std::vector ssx; + + // all results get pushed several times. When there are ties, they + // are popped interleaved with others, so it is not easy to + // identify them. Therefore, this bit array just marks elements + // that were seen before. + std::vector seen; + + MinSumK (int K, int M, int nbit, int N): + K(K), M(M), nbit(nbit), N(N) { + heap_capacity = K * M; + assert (N <= (1 << nbit)); + + // we'll do k steps, each step pushes at most M vals + bh_val = new T[heap_capacity]; + bh_ids = new int64_t[heap_capacity]; + + if (use_seen) { + int64_t n_ids = weight(M); + seen.resize ((n_ids + 7) / 8); + } + + for (int m = 0; m < M; m++) + ssx.push_back (SSA(N)); + + } + + int64_t weight (int i) { + return 1 << (i * nbit); + } + + bool is_seen (int64_t i) { + return (seen[i >> 3] >> (i & 7)) & 1; + } + + void mark_seen (int64_t i) { + if (use_seen) + seen [i >> 3] |= 1 << (i & 7); + } + + void run (const T *x, int64_t ldx, + T * sums, int64_t * terms) { + heap_size = 0; + + for (int m = 0; m < M; m++) { + ssx[m].init(x); + x += ldx; + } + + { // intial result: take min for all elements + T sum = 0; + terms[0] = 0; + mark_seen (0); + for (int m = 0; m < M; m++) { + sum += ssx[m].get_0(); + } + sums[0] = sum; + for (int m = 0; m < M; m++) { + heap_push (++heap_size, bh_val, bh_ids, + sum + ssx[m].get_diff(1), + weight(m)); + } + } + + for (int k = 1; k < K; k++) { + // pop smallest value from heap + if (use_seen) {// skip already seen elements + while (is_seen (bh_ids[0])) { + assert (heap_size > 0); + heap_pop (heap_size--, bh_val, bh_ids); + } + } + assert (heap_size > 0); + + T sum = sums[k] = bh_val[0]; + int64_t ti = terms[k] = bh_ids[0]; + + if (use_seen) { + mark_seen (ti); + heap_pop (heap_size--, bh_val, bh_ids); + } else { + do { + heap_pop (heap_size--, bh_val, bh_ids); + } while (heap_size > 0 && bh_ids[0] == ti); + } + + // enqueue followers + int64_t ii = ti; + for (int m = 0; m < M; m++) { + int64_t n = ii & ((1L << nbit) - 1); + ii >>= nbit; + if (n + 1 >= N) continue; + + enqueue_follower (ti, m, n, sum); + } + } + + /* + for (int k = 0; k < K; k++) + for (int l = k + 1; l < K; l++) + assert (terms[k] != terms[l]); + */ + + // convert indices by applying permutation + for (int k = 0; k < K; k++) { + int64_t ii = terms[k]; + if (use_seen) { + // clear seen for reuse at next loop + seen[ii >> 3] = 0; + } + int64_t ti = 0; + for (int m = 0; m < M; m++) { + int64_t n = ii & ((1L << nbit) - 1); + ti += int64_t(ssx[m].get_ord(n)) << (nbit * m); + ii >>= nbit; + } + terms[k] = ti; + } + } + + + void enqueue_follower (int64_t ti, int m, int n, T sum) { + T next_sum = sum + ssx[m].get_diff(n + 1); + int64_t next_ti = ti + weight(m); + heap_push (++heap_size, bh_val, bh_ids, next_sum, next_ti); + } + + ~MinSumK () { + delete [] bh_ids; + delete [] bh_val; + } +}; + +} // anonymous namespace + + +MultiIndexQuantizer::MultiIndexQuantizer (int d, + size_t M, + size_t nbits): + Index(d, METRIC_L2), pq(d, M, nbits) +{ + is_trained = false; + pq.verbose = verbose; +} + + + +void MultiIndexQuantizer::train(idx_t n, const float *x) +{ + pq.verbose = verbose; + pq.train (n, x); + is_trained = true; + // count virtual elements in index + ntotal = 1; + for (int m = 0; m < pq.M; m++) + ntotal *= pq.ksub; +} + + +void MultiIndexQuantizer::search (idx_t n, const float *x, idx_t k, + float *distances, idx_t *labels, + ConcurrentBitsetPtr bitset) const { + if (n == 0) return; + + // the allocation just below can be severe... + idx_t bs = 32768; + if (n > bs) { + for (idx_t i0 = 0; i0 < n; i0 += bs) { + idx_t i1 = std::min(i0 + bs, n); + if (verbose) { + printf("MultiIndexQuantizer::search: %ld:%ld / %ld\n", + i0, i1, n); + } + search (i1 - i0, x + i0 * d, k, + distances + i0 * k, + labels + i0 * k); + } + return; + } + + float * dis_tables = new float [n * pq.ksub * pq.M]; + ScopeDeleter del (dis_tables); + + pq.compute_distance_tables (n, x, dis_tables); + + if (k == 1) { + // simple version that just finds the min in each table + +#pragma omp parallel for + for (int i = 0; i < n; i++) { + const float * dis_table = dis_tables + i * pq.ksub * pq.M; + float dis = 0; + idx_t label = 0; + + for (int s = 0; s < pq.M; s++) { + float vmin = HUGE_VALF; + idx_t lmin = -1; + + for (idx_t j = 0; j < pq.ksub; j++) { + if (dis_table[j] < vmin) { + vmin = dis_table[j]; + lmin = j; + } + } + dis += vmin; + label |= lmin << (s * pq.nbits); + dis_table += pq.ksub; + } + + distances [i] = dis; + labels [i] = label; + } + + + } else { + +#pragma omp parallel if(n > 1) + { + MinSumK , false> + msk(k, pq.M, pq.nbits, pq.ksub); +#pragma omp for + for (int i = 0; i < n; i++) { + msk.run (dis_tables + i * pq.ksub * pq.M, pq.ksub, + distances + i * k, labels + i * k); + + } + } + } + +} + + +void MultiIndexQuantizer::reconstruct (idx_t key, float * recons) const +{ + + int64_t jj = key; + for (int m = 0; m < pq.M; m++) { + int64_t n = jj & ((1L << pq.nbits) - 1); + jj >>= pq.nbits; + memcpy(recons, pq.get_centroids(m, n), sizeof(recons[0]) * pq.dsub); + recons += pq.dsub; + } +} + +void MultiIndexQuantizer::add(idx_t /*n*/, const float* /*x*/) { + FAISS_THROW_MSG( + "This index has virtual elements, " + "it does not support add"); +} + +void MultiIndexQuantizer::reset () +{ + FAISS_THROW_MSG ( "This index has virtual elements, " + "it does not support reset"); +} + + + + + + + + + + +/***************************************** + * MultiIndexQuantizer2 + ******************************************/ + + + +MultiIndexQuantizer2::MultiIndexQuantizer2 ( + int d, size_t M, size_t nbits, + Index **indexes): + MultiIndexQuantizer (d, M, nbits) +{ + assign_indexes.resize (M); + for (int i = 0; i < M; i++) { + FAISS_THROW_IF_NOT_MSG( + indexes[i]->d == pq.dsub, + "Provided sub-index has incorrect size"); + assign_indexes[i] = indexes[i]; + } + own_fields = false; +} + +MultiIndexQuantizer2::MultiIndexQuantizer2 ( + int d, size_t nbits, + Index *assign_index_0, + Index *assign_index_1): + MultiIndexQuantizer (d, 2, nbits) +{ + FAISS_THROW_IF_NOT_MSG( + assign_index_0->d == pq.dsub && + assign_index_1->d == pq.dsub, + "Provided sub-index has incorrect size"); + assign_indexes.resize (2); + assign_indexes [0] = assign_index_0; + assign_indexes [1] = assign_index_1; + own_fields = false; +} + +void MultiIndexQuantizer2::train(idx_t n, const float* x) +{ + MultiIndexQuantizer::train(n, x); + // add centroids to sub-indexes + for (int i = 0; i < pq.M; i++) { + assign_indexes[i]->add(pq.ksub, pq.get_centroids(i, 0)); + } +} + + +void MultiIndexQuantizer2::search( + idx_t n, const float* x, idx_t K, + float* distances, idx_t* labels, + ConcurrentBitsetPtr bitset) const +{ + + if (n == 0) return; + + int k2 = std::min(K, int64_t(pq.ksub)); + + int64_t M = pq.M; + int64_t dsub = pq.dsub, ksub = pq.ksub; + + // size (M, n, k2) + std::vector sub_ids(n * M * k2); + std::vector sub_dis(n * M * k2); + std::vector xsub(n * dsub); + + for (int m = 0; m < M; m++) { + float *xdest = xsub.data(); + const float *xsrc = x + m * dsub; + for (int j = 0; j < n; j++) { + memcpy(xdest, xsrc, dsub * sizeof(xdest[0])); + xsrc += d; + xdest += dsub; + } + + assign_indexes[m]->search( + n, xsub.data(), k2, + &sub_dis[k2 * n * m], + &sub_ids[k2 * n * m]); + } + + if (K == 1) { + // simple version that just finds the min in each table + assert (k2 == 1); + + for (int i = 0; i < n; i++) { + float dis = 0; + idx_t label = 0; + + for (int m = 0; m < M; m++) { + float vmin = sub_dis[i + m * n]; + idx_t lmin = sub_ids[i + m * n]; + dis += vmin; + label |= lmin << (m * pq.nbits); + } + distances [i] = dis; + labels [i] = label; + } + + } else { + +#pragma omp parallel if(n > 1) + { + MinSumK , false> + msk(K, pq.M, pq.nbits, k2); +#pragma omp for + for (int i = 0; i < n; i++) { + idx_t *li = labels + i * K; + msk.run (&sub_dis[i * k2], k2 * n, + distances + i * K, li); + + // remap ids + + const idx_t *idmap0 = sub_ids.data() + i * k2; + int64_t ld_idmap = k2 * n; + int64_t mask1 = ksub - 1L; + + for (int k = 0; k < K; k++) { + const idx_t *idmap = idmap0; + int64_t vin = li[k]; + int64_t vout = 0; + int bs = 0; + for (int m = 0; m < M; m++) { + int64_t s = vin & mask1; + vin >>= pq.nbits; + vout |= idmap[s] << bs; + bs += pq.nbits; + idmap += ld_idmap; + } + li[k] = vout; + } + } + } + } +} + + +} // namespace faiss diff --git a/core/src/index/thirdparty/faiss/IndexPQ.h b/core/src/index/thirdparty/faiss/IndexPQ.h new file mode 100644 index 0000000000..97ce84f11d --- /dev/null +++ b/core/src/index/thirdparty/faiss/IndexPQ.h @@ -0,0 +1,204 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +#ifndef FAISS_INDEX_PQ_H +#define FAISS_INDEX_PQ_H + +#include + +#include + +#include +#include +#include + +namespace faiss { + + +/** Index based on a product quantizer. Stored vectors are + * approximated by PQ codes. */ +struct IndexPQ: Index { + + /// The product quantizer used to encode the vectors + ProductQuantizer pq; + + /// Codes. Size ntotal * pq.code_size + std::vector codes; + + /** Constructor. + * + * @param d dimensionality of the input vectors + * @param M number of subquantizers + * @param nbits number of bit per subvector index + */ + IndexPQ (int d, ///< dimensionality of the input vectors + size_t M, ///< number of subquantizers + size_t nbits, ///< number of bit per subvector index + MetricType metric = METRIC_L2); + + IndexPQ (); + + void train(idx_t n, const float* x) override; + + void add(idx_t n, const float* x) override; + + void search( + idx_t n, + const float* x, + idx_t k, + float* distances, + idx_t* labels, + ConcurrentBitsetPtr bitset = nullptr) const override; + + void reset() override; + + void reconstruct_n(idx_t i0, idx_t ni, float* recons) const override; + + void reconstruct(idx_t key, float* recons) const override; + + size_t remove_ids(const IDSelector& sel) override; + + /* The standalone codec interface */ + size_t sa_code_size () const override; + + void sa_encode (idx_t n, const float *x, + uint8_t *bytes) const override; + + void sa_decode (idx_t n, const uint8_t *bytes, + float *x) const override; + + + DistanceComputer * get_distance_computer() const override; + + /****************************************************** + * Polysemous codes implementation + ******************************************************/ + bool do_polysemous_training; ///< false = standard PQ + + /// parameters used for the polysemous training + PolysemousTraining polysemous_training; + + /// how to perform the search in search_core + enum Search_type_t { + ST_PQ, ///< asymmetric product quantizer (default) + ST_HE, ///< Hamming distance on codes + ST_generalized_HE, ///< nb of same codes + ST_SDC, ///< symmetric product quantizer (SDC) + ST_polysemous, ///< HE filter (using ht) + PQ combination + ST_polysemous_generalize, ///< Filter on generalized Hamming + }; + + Search_type_t search_type; + + // just encode the sign of the components, instead of using the PQ encoder + // used only for the queries + bool encode_signs; + + /// Hamming threshold used for polysemy + int polysemous_ht; + + // actual polysemous search + void search_core_polysemous (idx_t n, const float *x, idx_t k, + float *distances, idx_t *labels) const; + + /// prepare query for a polysemous search, but instead of + /// computing the result, just get the histogram of Hamming + /// distances. May be computed on a provided dataset if xb != NULL + /// @param dist_histogram (M * nbits + 1) + void hamming_distance_histogram (idx_t n, const float *x, + idx_t nb, const float *xb, + int64_t *dist_histogram); + + /** compute pairwise distances between queries and database + * + * @param n nb of query vectors + * @param x query vector, size n * d + * @param dis output distances, size n * ntotal + */ + void hamming_distance_table (idx_t n, const float *x, + int32_t *dis) const; + + size_t cal_size() { return codes.size() * sizeof(uint8_t) + pq.cal_size(); } + +}; + + +/// statistics are robust to internal threading, but not if +/// IndexPQ::search is called by multiple threads +struct IndexPQStats { + size_t nq; // nb of queries run + size_t ncode; // nb of codes visited + + size_t n_hamming_pass; // nb of passed Hamming distance tests (for polysemy) + + IndexPQStats () {reset (); } + void reset (); +}; + +extern IndexPQStats indexPQ_stats; + + + +/** Quantizer where centroids are virtual: they are the Cartesian + * product of sub-centroids. */ +struct MultiIndexQuantizer: Index { + ProductQuantizer pq; + + MultiIndexQuantizer (int d, ///< dimension of the input vectors + size_t M, ///< number of subquantizers + size_t nbits); ///< number of bit per subvector index + + void train(idx_t n, const float* x) override; + + void search( + idx_t n, const float* x, idx_t k, + float* distances, idx_t* labels, + ConcurrentBitsetPtr bitset = nullptr) const override; + + /// add and reset will crash at runtime + void add(idx_t n, const float* x) override; + void reset() override; + + MultiIndexQuantizer () {} + + void reconstruct(idx_t key, float* recons) const override; +}; + + +/** MultiIndexQuantizer where the PQ assignmnet is performed by sub-indexes + */ +struct MultiIndexQuantizer2: MultiIndexQuantizer { + + /// M Indexes on d / M dimensions + std::vector assign_indexes; + bool own_fields; + + MultiIndexQuantizer2 ( + int d, size_t M, size_t nbits, + Index **indexes); + + MultiIndexQuantizer2 ( + int d, size_t nbits, + Index *assign_index_0, + Index *assign_index_1); + + void train(idx_t n, const float* x) override; + + void search( + idx_t n, const float* x, idx_t k, + float* distances, idx_t* labels, + ConcurrentBitsetPtr bitset = nullptr) const override; + +}; + + +} // namespace faiss + + +#endif diff --git a/core/src/index/thirdparty/faiss/IndexPreTransform.cpp b/core/src/index/thirdparty/faiss/IndexPreTransform.cpp new file mode 100644 index 0000000000..9172978df9 --- /dev/null +++ b/core/src/index/thirdparty/faiss/IndexPreTransform.cpp @@ -0,0 +1,289 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +#include + +#include +#include +#include +#include + +#include + +namespace faiss { + +/********************************************* + * IndexPreTransform + *********************************************/ + +IndexPreTransform::IndexPreTransform (): + index(nullptr), own_fields (false) +{ +} + + +IndexPreTransform::IndexPreTransform ( + Index * index): + Index (index->d, index->metric_type), + index (index), own_fields (false) +{ + is_trained = index->is_trained; + ntotal = index->ntotal; +} + + +IndexPreTransform::IndexPreTransform ( + VectorTransform * ltrans, + Index * index): + Index (index->d, index->metric_type), + index (index), own_fields (false) +{ + is_trained = index->is_trained; + ntotal = index->ntotal; + prepend_transform (ltrans); +} + +void IndexPreTransform::prepend_transform (VectorTransform *ltrans) +{ + FAISS_THROW_IF_NOT (ltrans->d_out == d); + is_trained = is_trained && ltrans->is_trained; + chain.insert (chain.begin(), ltrans); + d = ltrans->d_in; +} + + +IndexPreTransform::~IndexPreTransform () +{ + if (own_fields) { + for (int i = 0; i < chain.size(); i++) + delete chain[i]; + delete index; + } +} + + + + +void IndexPreTransform::train (idx_t n, const float *x) +{ + int last_untrained = 0; + if (!index->is_trained) { + last_untrained = chain.size(); + } else { + for (int i = chain.size() - 1; i >= 0; i--) { + if (!chain[i]->is_trained) { + last_untrained = i; + break; + } + } + } + const float *prev_x = x; + ScopeDeleter del; + + if (verbose) { + printf("IndexPreTransform::train: training chain 0 to %d\n", + last_untrained); + } + + for (int i = 0; i <= last_untrained; i++) { + + if (i < chain.size()) { + VectorTransform *ltrans = chain [i]; + if (!ltrans->is_trained) { + if (verbose) { + printf(" Training chain component %d/%zd\n", + i, chain.size()); + if (OPQMatrix *opqm = dynamic_cast(ltrans)) { + opqm->verbose = true; + } + } + ltrans->train (n, prev_x); + } + } else { + if (verbose) { + printf(" Training sub-index\n"); + } + index->train (n, prev_x); + } + if (i == last_untrained) break; + if (verbose) { + printf(" Applying transform %d/%zd\n", + i, chain.size()); + } + + float * xt = chain[i]->apply (n, prev_x); + + if (prev_x != x) delete [] prev_x; + prev_x = xt; + del.set(xt); + } + + is_trained = true; +} + + +const float *IndexPreTransform::apply_chain (idx_t n, const float *x) const +{ + const float *prev_x = x; + ScopeDeleter del; + + for (int i = 0; i < chain.size(); i++) { + float * xt = chain[i]->apply (n, prev_x); + ScopeDeleter del2 (xt); + del2.swap (del); + prev_x = xt; + } + del.release (); + return prev_x; +} + +void IndexPreTransform::reverse_chain (idx_t n, const float* xt, float* x) const +{ + const float* next_x = xt; + ScopeDeleter del; + + for (int i = chain.size() - 1; i >= 0; i--) { + float* prev_x = (i == 0) ? x : new float [n * chain[i]->d_in]; + ScopeDeleter del2 ((prev_x == x) ? nullptr : prev_x); + chain [i]->reverse_transform (n, next_x, prev_x); + del2.swap (del); + next_x = prev_x; + } +} + +void IndexPreTransform::add (idx_t n, const float *x) +{ + FAISS_THROW_IF_NOT (is_trained); + const float *xt = apply_chain (n, x); + ScopeDeleter del(xt == x ? nullptr : xt); + index->add (n, xt); + ntotal = index->ntotal; +} + +void IndexPreTransform::add_with_ids (idx_t n, const float * x, + const idx_t *xids) +{ + FAISS_THROW_IF_NOT (is_trained); + const float *xt = apply_chain (n, x); + ScopeDeleter del(xt == x ? nullptr : xt); + index->add_with_ids (n, xt, xids); + ntotal = index->ntotal; +} + + + + +void IndexPreTransform::search (idx_t n, const float *x, idx_t k, + float *distances, idx_t *labels, + ConcurrentBitsetPtr bitset) const +{ + FAISS_THROW_IF_NOT (is_trained); + const float *xt = apply_chain (n, x); + ScopeDeleter del(xt == x ? nullptr : xt); + index->search (n, xt, k, distances, labels); +} + +void IndexPreTransform::range_search (idx_t n, const float* x, float radius, + RangeSearchResult* result, + ConcurrentBitsetPtr bitset) const +{ + FAISS_THROW_IF_NOT (is_trained); + const float *xt = apply_chain (n, x); + ScopeDeleter del(xt == x ? nullptr : xt); + index->range_search (n, xt, radius, result); +} + + + +void IndexPreTransform::reset () { + index->reset(); + ntotal = 0; +} + +size_t IndexPreTransform::remove_ids (const IDSelector & sel) { + size_t nremove = index->remove_ids (sel); + ntotal = index->ntotal; + return nremove; +} + + +void IndexPreTransform::reconstruct (idx_t key, float * recons) const +{ + float *x = chain.empty() ? recons : new float [index->d]; + ScopeDeleter del (recons == x ? nullptr : x); + // Initial reconstruction + index->reconstruct (key, x); + + // Revert transformations from last to first + reverse_chain (1, x, recons); +} + + +void IndexPreTransform::reconstruct_n (idx_t i0, idx_t ni, float *recons) const +{ + float *x = chain.empty() ? recons : new float [ni * index->d]; + ScopeDeleter del (recons == x ? nullptr : x); + // Initial reconstruction + index->reconstruct_n (i0, ni, x); + + // Revert transformations from last to first + reverse_chain (ni, x, recons); +} + + +void IndexPreTransform::search_and_reconstruct ( + idx_t n, const float *x, idx_t k, + float *distances, idx_t *labels, float* recons) const +{ + FAISS_THROW_IF_NOT (is_trained); + + const float* xt = apply_chain (n, x); + ScopeDeleter del ((xt == x) ? nullptr : xt); + + float* recons_temp = chain.empty() ? recons : new float [n * k * index->d]; + ScopeDeleter del2 ((recons_temp == recons) ? nullptr : recons_temp); + index->search_and_reconstruct (n, xt, k, distances, labels, recons_temp); + + // Revert transformations from last to first + reverse_chain (n * k, recons_temp, recons); +} + +size_t IndexPreTransform::sa_code_size () const +{ + return index->sa_code_size (); +} + +void IndexPreTransform::sa_encode (idx_t n, const float *x, + uint8_t *bytes) const +{ + if (chain.empty()) { + index->sa_encode (n, x, bytes); + } else { + const float *xt = apply_chain (n, x); + ScopeDeleter del(xt == x ? nullptr : xt); + index->sa_encode (n, xt, bytes); + } +} + +void IndexPreTransform::sa_decode (idx_t n, const uint8_t *bytes, + float *x) const +{ + if (chain.empty()) { + index->sa_decode (n, bytes, x); + } else { + std::unique_ptr x1 (new float [index->d * n]); + index->sa_decode (n, bytes, x1.get()); + // Revert transformations from last to first + reverse_chain (n, x1.get(), x); + } +} + + + +} // namespace faiss diff --git a/core/src/index/thirdparty/faiss/IndexPreTransform.h b/core/src/index/thirdparty/faiss/IndexPreTransform.h new file mode 100644 index 0000000000..605ada9fa4 --- /dev/null +++ b/core/src/index/thirdparty/faiss/IndexPreTransform.h @@ -0,0 +1,93 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +#pragma once + + + +#include +#include + +namespace faiss { + +/** Index that applies a LinearTransform transform on vectors before + * handing them over to a sub-index */ +struct IndexPreTransform: Index { + + std::vector chain; ///! chain of tranforms + Index * index; ///! the sub-index + + bool own_fields; ///! whether pointers are deleted in destructor + + explicit IndexPreTransform (Index *index); + + IndexPreTransform (); + + /// ltrans is the last transform before the index + IndexPreTransform (VectorTransform * ltrans, Index * index); + + void prepend_transform (VectorTransform * ltrans); + + void train(idx_t n, const float* x) override; + + void add(idx_t n, const float* x) override; + + void add_with_ids(idx_t n, const float* x, const idx_t* xids) override; + + void reset() override; + + /** removes IDs from the index. Not supported by all indexes. + */ + size_t remove_ids(const IDSelector& sel) override; + + void search( + idx_t n, + const float* x, + idx_t k, + float* distances, + idx_t* labels, + ConcurrentBitsetPtr bitset = nullptr) const override; + + + /* range search, no attempt is done to change the radius */ + void range_search (idx_t n, const float* x, float radius, + RangeSearchResult* result, + ConcurrentBitsetPtr bitset = nullptr) const override; + + + void reconstruct (idx_t key, float * recons) const override; + + void reconstruct_n (idx_t i0, idx_t ni, float *recons) + const override; + + void search_and_reconstruct (idx_t n, const float *x, idx_t k, + float *distances, idx_t *labels, + float *recons) const override; + + /// apply the transforms in the chain. The returned float * may be + /// equal to x, otherwise it should be deallocated. + const float * apply_chain (idx_t n, const float *x) const; + + /// Reverse the transforms in the chain. May not be implemented for + /// all transforms in the chain or may return approximate results. + void reverse_chain (idx_t n, const float* xt, float* x) const; + + + /* standalone codec interface */ + size_t sa_code_size () const override; + void sa_encode (idx_t n, const float *x, + uint8_t *bytes) const override; + void sa_decode (idx_t n, const uint8_t *bytes, + float *x) const override; + + ~IndexPreTransform() override; +}; + + +} // namespace faiss diff --git a/core/src/index/thirdparty/faiss/IndexRHNSW.cpp b/core/src/index/thirdparty/faiss/IndexRHNSW.cpp new file mode 100644 index 0000000000..bd112e596d --- /dev/null +++ b/core/src/index/thirdparty/faiss/IndexRHNSW.cpp @@ -0,0 +1,812 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +#include + + +#include +#include +#include +#include +#include +#include + +#include +#include + +#include +#include +#include +#include + +#ifdef __SSE__ +#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +extern "C" { + +/* declare BLAS functions, see http://www.netlib.org/clapack/cblas/ */ + +int sgemm_ (const char *transa, const char *transb, FINTEGER *m, FINTEGER * + n, FINTEGER *k, const float *alpha, const float *a, + FINTEGER *lda, const float *b, FINTEGER * + ldb, float *beta, float *c, FINTEGER *ldc); + +} + +namespace faiss { + +using idx_t = Index::idx_t; +using MinimaxHeap = RHNSW::MinimaxHeap; +using storage_idx_t = RHNSW::storage_idx_t; +using NodeDistFarther = RHNSW::NodeDistFarther; + +RHNSWStats rhnsw_stats; + +/************************************************************** + * add / search blocks of descriptors + **************************************************************/ + +namespace { + + +/* Wrap the distance computer into one that negates the + distances. This makes supporting INNER_PRODUCE search easier */ + +struct NegativeDistanceComputer: DistanceComputer { + + /// owned by this + DistanceComputer *basedis; + + explicit NegativeDistanceComputer(DistanceComputer *basedis): + basedis(basedis) + {} + + void set_query(const float *x) override { + basedis->set_query(x); + } + + /// compute distance of vector i to current query + float operator () (idx_t i) override { + return -(*basedis)(i); + } + + /// compute distance between two stored vectors + float symmetric_dis (idx_t i, idx_t j) override { + return -basedis->symmetric_dis(i, j); + } + + virtual ~NegativeDistanceComputer () + { + delete basedis; + } + +}; + +DistanceComputer *storage_distance_computer(const Index *storage) +{ + if (storage->metric_type == METRIC_INNER_PRODUCT) { + return new NegativeDistanceComputer(storage->get_distance_computer()); + } else { + return storage->get_distance_computer(); + } +} + +void hnsw_add_vertices(IndexRHNSW &index_hnsw, + size_t n0, + size_t n, const float *x, + bool verbose, + bool preset_levels = false) { + size_t d = index_hnsw.d; + RHNSW & hnsw = index_hnsw.hnsw; + size_t ntotal = n0 + n; + double t0 = getmillisecs(); + if (verbose) { + printf("hnsw_add_vertices: adding %ld elements on top of %ld " + "(preset_levels=%d)\n", + n, n0, int(preset_levels)); + } + + if (n == 0) { + return; + } + + int max_level = hnsw.prepare_level_tab(n, preset_levels); + + if (verbose) { + printf(" max_level = %d\n", max_level); + } + + + { // perform add + auto tas = getmillisecs(); + RandomGenerator rng2(789); + DistanceComputer *dis0 = + storage_distance_computer (index_hnsw.storage); + ScopeDeleter1 del0(dis0); + + dis0->set_query(x); + hnsw.addPoint(*dis0, hnsw.levels[n0], n0); + +#pragma omp parallel for + for (int i = 1; i < n; ++ i) { + DistanceComputer *dis = + storage_distance_computer (index_hnsw.storage); + ScopeDeleter1 del(dis); + dis->set_query(x + i * d); + hnsw.addPoint(*dis, hnsw.levels[n0 + i], i + n0); + } + } + if (verbose) { + printf("Done in %.3f ms\n", getmillisecs() - t0); + } + +} + +} // namespace + + + + +/************************************************************** + * IndexRHNSW implementation + **************************************************************/ + +IndexRHNSW::IndexRHNSW(int d, int M, MetricType metric): + Index(d, metric), + hnsw(M), + own_fields(false), + storage(nullptr), + reconstruct_from_neighbors(nullptr) +{} + +IndexRHNSW::IndexRHNSW(Index *storage, int M): + Index(storage->d, storage->metric_type), + hnsw(M), + own_fields(false), + storage(storage), + reconstruct_from_neighbors(nullptr) +{} + +IndexRHNSW::~IndexRHNSW() { + if (own_fields) { + delete storage; + } +} + +void IndexRHNSW::init_hnsw() { + hnsw.init(ntotal); +} + +void IndexRHNSW::train(idx_t n, const float* x) +{ + FAISS_THROW_IF_NOT_MSG(storage, + "Please use IndexHSNWFlat (or variants) instead of IndexRHNSW directly"); + // hnsw structure does not require training + storage->train (n, x); + is_trained = true; +} + +void IndexRHNSW::search (idx_t n, const float *x, idx_t k, + float *distances, idx_t *labels, ConcurrentBitsetPtr bitset) const + +{ + FAISS_THROW_IF_NOT_MSG(storage, + "Please use IndexHSNWFlat (or variants) instead of IndexRHNSW directly"); + size_t nreorder = 0; + + idx_t check_period = InterruptCallback::get_period_hint ( + hnsw.max_level * d * hnsw.efSearch); + + for (idx_t i0 = 0; i0 < n; i0 += check_period) { + idx_t i1 = std::min(i0 + check_period, n); + +#pragma omp parallel reduction(+ : nreorder) + { + + DistanceComputer *dis = storage_distance_computer(storage); + ScopeDeleter1 del(dis); + +#pragma omp for + for(idx_t i = i0; i < i1; i++) { + idx_t * idxi = labels + i * k; + float * simi = distances + i * k; + dis->set_query(x + i * d); + + maxheap_heapify (k, simi, idxi); + + hnsw.searchKnn(*dis, k, idxi, simi, bitset); + + maxheap_reorder (k, simi, idxi); + + if (reconstruct_from_neighbors && + reconstruct_from_neighbors->k_reorder != 0) { + int k_reorder = reconstruct_from_neighbors->k_reorder; + if (k_reorder == -1 || k_reorder > k) k_reorder = k; + + nreorder += reconstruct_from_neighbors->compute_distances( + k_reorder, idxi, x + i * d, simi); + + // sort top k_reorder + maxheap_heapify (k_reorder, simi, idxi, simi, idxi, k_reorder); + maxheap_reorder (k_reorder, simi, idxi); + } + + } + + } + InterruptCallback::check (); + } + + if (metric_type == METRIC_INNER_PRODUCT) { + // we need to revert the negated distances + for (size_t i = 0; i < k * n; i++) { + distances[i] = -distances[i]; + } + } + + rhnsw_stats.nreorder += nreorder; +} + + +void IndexRHNSW::add(idx_t n, const float *x) +{ + FAISS_THROW_IF_NOT_MSG(storage, + "Please use IndexHSNWFlat (or variants) instead of IndexRHNSW directly"); + FAISS_THROW_IF_NOT(is_trained); + int n0 = ntotal; + storage->add(n, x); + ntotal = storage->ntotal; + + hnsw_add_vertices (*this, n0, n, x, verbose, + hnsw.levels.size() == ntotal); +} + +void IndexRHNSW::reset() +{ + hnsw.reset(); + storage->reset(); + ntotal = 0; +} + +void IndexRHNSW::reconstruct (idx_t key, float* recons) const +{ + storage->reconstruct(key, recons); +} + +size_t IndexRHNSW::cal_size() { + return hnsw.cal_size(); +} + +/************************************************************** + * ReconstructFromNeighbors implementation + **************************************************************/ + +ReconstructFromNeighbors2::ReconstructFromNeighbors2( + const IndexRHNSW & index, size_t k, size_t nsq): + index(index), k(k), nsq(nsq) { + M = index.hnsw.M << 1; + FAISS_ASSERT(k <= 256); + code_size = k == 1 ? 0 : nsq; + ntotal = 0; + d = index.d; + FAISS_ASSERT(d % nsq == 0); + dsub = d / nsq; + k_reorder = -1; +} + +void ReconstructFromNeighbors2::reconstruct(storage_idx_t i, float *x, float *tmp) const +{ + + + const RHNSW & hnsw = index.hnsw; + int *cur_links = hnsw.get_neighbor_link(i, 0); + int *cur_neighbors = cur_links + 1; + auto cur_neighbor_num = hnsw.get_neighbors_num(cur_links); + + if (k == 1 || nsq == 1) { + const float * beta; + if (k == 1) { + beta = codebook.data(); + } else { + int idx = codes[i]; + beta = codebook.data() + idx * (M + 1); + } + + float w0 = beta[0]; // weight of image itself + index.storage->reconstruct(i, tmp); + + for (int l = 0; l < d; l++) + x[l] = w0 * tmp[l]; + + for (auto j = 0; j < cur_neighbor_num; ++ j) { + + storage_idx_t ji = cur_neighbors[j]; + if (ji < 0) ji = i; + float w = beta[j + 1]; + index.storage->reconstruct(ji, tmp); + for (int l = 0; l < d; l++) + x[l] += w * tmp[l]; + } + } else if (nsq == 2) { + int idx0 = codes[2 * i]; + int idx1 = codes[2 * i + 1]; + + const float *beta0 = codebook.data() + idx0 * (M + 1); + const float *beta1 = codebook.data() + (idx1 + k) * (M + 1); + + index.storage->reconstruct(i, tmp); + + float w0; + + w0 = beta0[0]; + for (int l = 0; l < dsub; l++) + x[l] = w0 * tmp[l]; + + w0 = beta1[0]; + for (int l = dsub; l < d; l++) + x[l] = w0 * tmp[l]; + + for (auto j = 0; j < cur_neighbor_num; ++ j) { + storage_idx_t ji = cur_neighbors[j]; + if (ji < 0) ji = i; + index.storage->reconstruct(ji, tmp); + float w; + w = beta0[j + 1]; + for (int l = 0; l < dsub; l++) + x[l] += w * tmp[l]; + + w = beta1[j + 1]; + for (int l = dsub; l < d; l++) + x[l] += w * tmp[l]; + } + } else { + const float *betas[nsq]; + { + const float *b = codebook.data(); + const uint8_t *c = &codes[i * code_size]; + for (int sq = 0; sq < nsq; sq++) { + betas[sq] = b + (*c++) * (M + 1); + b += (M + 1) * k; + } + } + + index.storage->reconstruct(i, tmp); + { + int d0 = 0; + for (int sq = 0; sq < nsq; sq++) { + float w = *(betas[sq]++); + int d1 = d0 + dsub; + for (int l = d0; l < d1; l++) { + x[l] = w * tmp[l]; + } + d0 = d1; + } + } + + for (auto j = 0; j < cur_neighbor_num; ++ j) { + storage_idx_t ji = cur_neighbors[j]; + if (ji < 0) ji = i; + + index.storage->reconstruct(ji, tmp); + int d0 = 0; + for (int sq = 0; sq < nsq; sq++) { + float w = *(betas[sq]++); + int d1 = d0 + dsub; + for (int l = d0; l < d1; l++) { + x[l] += w * tmp[l]; + } + d0 = d1; + } + } + } +} + +void ReconstructFromNeighbors2::reconstruct_n(storage_idx_t n0, + storage_idx_t ni, + float *x) const +{ +#pragma omp parallel + { + std::vector tmp(index.d); +#pragma omp for + for (storage_idx_t i = 0; i < ni; i++) { + reconstruct(n0 + i, x + i * index.d, tmp.data()); + } + } +} + +size_t ReconstructFromNeighbors2::compute_distances( + size_t n, const idx_t *shortlist, + const float *query, float *distances) const +{ + std::vector tmp(2 * index.d); + size_t ncomp = 0; + for (int i = 0; i < n; i++) { + if (shortlist[i] < 0) break; + reconstruct(shortlist[i], tmp.data(), tmp.data() + index.d); + distances[i] = fvec_L2sqr(query, tmp.data(), index.d); + ncomp++; + } + return ncomp; +} + +void ReconstructFromNeighbors2::get_neighbor_table(storage_idx_t i, float *tmp1) const +{ + const RHNSW & hnsw = index.hnsw; + int *cur_links = hnsw.get_neighbor_link(i, 0); + int *cur_neighbors = cur_links + 1; + auto cur_neighbor_num = hnsw.get_neighbors_num(cur_links); + size_t d = index.d; + + index.storage->reconstruct(i, tmp1); + + for (auto j = 0; j < cur_neighbor_num; ++ j) { + storage_idx_t ji = cur_neighbors[j]; + if (ji < 0) ji = i; + index.storage->reconstruct(ji, tmp1 + (j + 1) * d); + } + +} + + +/// called by add_codes +void ReconstructFromNeighbors2::estimate_code( + const float *x, storage_idx_t i, uint8_t *code) const +{ + + // fill in tmp table with the neighbor values + float *tmp1 = new float[d * (M + 1) + (d * k)]; + float *tmp2 = tmp1 + d * (M + 1); + ScopeDeleter del(tmp1); + + // collect coordinates of base + get_neighbor_table (i, tmp1); + + for (size_t sq = 0; sq < nsq; sq++) { + int d0 = sq * dsub; + + { + FINTEGER ki = k, di = d, m1 = M + 1; + FINTEGER dsubi = dsub; + float zero = 0, one = 1; + + sgemm_ ("N", "N", &dsubi, &ki, &m1, &one, + tmp1 + d0, &di, + codebook.data() + sq * (m1 * k), &m1, + &zero, tmp2, &dsubi); + } + + float min = HUGE_VAL; + int argmin = -1; + for (size_t j = 0; j < k; j++) { + float dis = fvec_L2sqr(x + d0, tmp2 + j * dsub, dsub); + if (dis < min) { + min = dis; + argmin = j; + } + } + code[sq] = argmin; + } + +} + +void ReconstructFromNeighbors2::add_codes(size_t n, const float *x) +{ + if (k == 1) { // nothing to encode + ntotal += n; + return; + } + codes.resize(codes.size() + code_size * n); +#pragma omp parallel for + for (int i = 0; i < n; i++) { + estimate_code(x + i * index.d, ntotal + i, + codes.data() + (ntotal + i) * code_size); + } + ntotal += n; + FAISS_ASSERT (codes.size() == ntotal * code_size); +} + + +/************************************************************** + * IndexRHNSWFlat implementation + **************************************************************/ + + +IndexRHNSWFlat::IndexRHNSWFlat() +{ + is_trained = true; +} + +IndexRHNSWFlat::IndexRHNSWFlat(int d, int M, MetricType metric): + IndexRHNSW(new IndexFlat(d, metric), M) +{ + own_fields = true; + is_trained = true; +} + +size_t IndexRHNSWFlat::cal_size() { + return IndexRHNSW::cal_size() + dynamic_cast(storage)->cal_size(); +} + +/************************************************************** + * IndexRHNSWPQ implementation + **************************************************************/ + + +IndexRHNSWPQ::IndexRHNSWPQ() {} + +IndexRHNSWPQ::IndexRHNSWPQ(int d, int pq_m, int M): + IndexRHNSW(new IndexPQ(d, pq_m, 8), M) +{ + own_fields = true; + is_trained = false; +} + +void IndexRHNSWPQ::train(idx_t n, const float* x) +{ + IndexRHNSW::train (n, x); + (dynamic_cast (storage))->pq.compute_sdc_table(); +} + +size_t IndexRHNSWPQ::cal_size() { + return IndexRHNSW::cal_size() + dynamic_cast(storage)->cal_size(); +} + +/************************************************************** + * IndexRHNSWSQ implementation + **************************************************************/ + + +IndexRHNSWSQ::IndexRHNSWSQ(int d, QuantizerType qtype, int M, + MetricType metric): + IndexRHNSW (new IndexScalarQuantizer (d, qtype, metric), M) +{ + is_trained = false; + own_fields = true; +} + +IndexRHNSWSQ::IndexRHNSWSQ() {} + +size_t IndexRHNSWSQ::cal_size() { + return IndexRHNSW::cal_size() + dynamic_cast(storage)->cal_size(); +} + +/************************************************************** + * IndexRHNSW2Level implementation + **************************************************************/ + + +IndexRHNSW2Level::IndexRHNSW2Level(Index *quantizer, size_t nlist, int m_pq, int M): + IndexRHNSW (new Index2Layer (quantizer, nlist, m_pq), M) +{ + own_fields = true; + is_trained = false; +} + +IndexRHNSW2Level::IndexRHNSW2Level() {} + +size_t IndexRHNSW2Level::cal_size() { + return IndexRHNSW::cal_size() + dynamic_cast(storage)->cal_size(); +} + +namespace { + + +// same as search_from_candidates but uses v +// visno -> is in result list +// visno + 1 -> in result list + in candidates +int search_from_candidates_2(const RHNSW & hnsw, + DistanceComputer & qdis, int k, + idx_t *I, float * D, + MinimaxHeap &candidates, + VisitedList &vt, + int level, int nres_in = 0) +{ + int nres = nres_in; + int ndis = 0; + for (int i = 0; i < candidates.size(); i++) { + idx_t v1 = candidates.ids[i]; + FAISS_ASSERT(v1 >= 0); + vt.mass[v1] = vt.curV + 1; + } + + int nstep = 0; + + while (candidates.size() > 0) { + float d0 = 0; + int v0 = candidates.pop_min(&d0); + + int *cur_links = hnsw.get_neighbor_link(v0, level); + int *cur_neighbors = cur_links + 1; + auto cur_neighbor_num = hnsw.get_neighbors_num(cur_links); + + for (auto j = 0; j < cur_neighbor_num; ++ j) { + int v1 = cur_neighbors[j]; + if (v1 < 0) break; + if (vt.mass[v1] == vt.curV + 1) { + // nothing to do + } else { + ndis++; + float d = qdis(v1); + candidates.push(v1, d); + + // never seen before --> add to heap + if (vt.mass[v1] < vt.curV) { + if (nres < k) { + faiss::maxheap_push (++nres, D, I, d, v1); + } else if (d < D[0]) { + faiss::maxheap_pop (nres--, D, I); + faiss::maxheap_push (++nres, D, I, d, v1); + } + } + vt.mass[v1] = vt.curV + 1; + } + } + + nstep++; + if (nstep > hnsw.efSearch) { + break; + } + } + + if (level == 0) { +#pragma omp critical + { + rhnsw_stats.n1 ++; + if (candidates.size() == 0) + rhnsw_stats.n2 ++; + } + } + + + return nres; +} + + +} // namespace + +void IndexRHNSW2Level::search (idx_t n, const float *x, idx_t k, + float *distances, idx_t *labels, ConcurrentBitsetPtr bitset) const +{ + if (dynamic_cast(storage)) { + IndexRHNSW::search (n, x, k, distances, labels); + + } else { // "mixed" search + + const IndexIVFPQ *index_ivfpq = + dynamic_cast(storage); + + int nprobe = index_ivfpq->nprobe; + + std::unique_ptr coarse_assign(new idx_t[n * nprobe]); + std::unique_ptr coarse_dis(new float[n * nprobe]); + + index_ivfpq->quantizer->search (n, x, nprobe, coarse_dis.get(), + coarse_assign.get()); + + index_ivfpq->search_preassigned (n, x, k, coarse_assign.get(), + coarse_dis.get(), distances, labels, + false); + +#pragma omp parallel + { + VisitedList vt (ntotal); + DistanceComputer *dis = storage_distance_computer(storage); + ScopeDeleter1 del(dis); + + int candidates_size = hnsw.upper_beam; + MinimaxHeap candidates(candidates_size); + +#pragma omp for + for(idx_t i = 0; i < n; i++) { + idx_t * idxi = labels + i * k; + float * simi = distances + i * k; + dis->set_query(x + i * d); + + // mark all inverted list elements as visited + + for (int j = 0; j < nprobe; j++) { + idx_t key = coarse_assign[j + i * nprobe]; + if (key < 0) break; + size_t list_length = index_ivfpq->get_list_size (key); + const idx_t * ids = index_ivfpq->invlists->get_ids (key); + + for (int jj = 0; jj < list_length; jj++) { + vt.set (ids[jj]); + } + } + + candidates.clear(); + // copy the upper_beam elements to candidates list + + int search_policy = 2; + + if (search_policy == 1) { + + for (int j = 0 ; j < hnsw.upper_beam && j < k; j++) { + if (idxi[j] < 0) break; + candidates.push (idxi[j], simi[j]); + // search_from_candidates adds them back + idxi[j] = -1; + simi[j] = HUGE_VAL; + } + + // reorder from sorted to heap + maxheap_heapify (k, simi, idxi, simi, idxi, k); + + // removed from RHNSW, but still available in HNSW +// hnsw.search_from_candidates( +// *dis, k, idxi, simi, +// candidates, vt, 0, k +// ); + + vt.advance(); + + } else if (search_policy == 2) { + + for (int j = 0 ; j < hnsw.upper_beam && j < k; j++) { + if (idxi[j] < 0) break; + candidates.push (idxi[j], simi[j]); + } + + // reorder from sorted to heap + maxheap_heapify (k, simi, idxi, simi, idxi, k); + + search_from_candidates_2 ( + hnsw, *dis, k, idxi, simi, + candidates, vt, 0, k); + vt.advance (); + vt.advance (); + + } + + maxheap_reorder (k, simi, idxi); + } + } + } + + +} + + +void IndexRHNSW2Level::flip_to_ivf () +{ + Index2Layer *storage2l = + dynamic_cast(storage); + + FAISS_THROW_IF_NOT (storage2l); + + IndexIVFPQ * index_ivfpq = + new IndexIVFPQ (storage2l->q1.quantizer, + d, storage2l->q1.nlist, + storage2l->pq.M, 8); + index_ivfpq->pq = storage2l->pq; + index_ivfpq->is_trained = storage2l->is_trained; + index_ivfpq->precompute_table(); + index_ivfpq->own_fields = storage2l->q1.own_fields; + storage2l->transfer_to_IVFPQ(*index_ivfpq); + index_ivfpq->make_direct_map (true); + + storage = index_ivfpq; + delete storage2l; + +} + + +} // namespace faiss diff --git a/core/src/index/thirdparty/faiss/IndexRHNSW.h b/core/src/index/thirdparty/faiss/IndexRHNSW.h new file mode 100644 index 0000000000..f2641a076c --- /dev/null +++ b/core/src/index/thirdparty/faiss/IndexRHNSW.h @@ -0,0 +1,152 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +#pragma once + +#include + +#include +#include +#include +#include +#include +//#include + + +namespace faiss { + +struct IndexRHNSW; + +struct ReconstructFromNeighbors2 { + typedef Index::idx_t idx_t; + typedef RHNSW::storage_idx_t storage_idx_t; + + const IndexRHNSW & index; + size_t M; // number of neighbors + size_t k; // number of codebook entries + size_t nsq; // number of subvectors + size_t code_size; + int k_reorder; // nb to reorder. -1 = all + + std::vector codebook; // size nsq * k * (M + 1) + + std::vector codes; // size ntotal * code_size + size_t ntotal; + size_t d, dsub; // derived values + + explicit ReconstructFromNeighbors2(const IndexRHNSW& index, + size_t k=256, size_t nsq=1); + + /// codes must be added in the correct order and the IndexRHNSW + /// must be populated and sorted + void add_codes(size_t n, const float *x); + + size_t compute_distances(size_t n, const idx_t *shortlist, + const float *query, float *distances) const; + + /// called by add_codes + void estimate_code(const float *x, storage_idx_t i, uint8_t *code) const; + + /// called by compute_distances + void reconstruct(storage_idx_t i, float *x, float *tmp) const; + + void reconstruct_n(storage_idx_t n0, storage_idx_t ni, float *x) const; + + /// get the M+1 -by-d table for neighbor coordinates for vector i + void get_neighbor_table(storage_idx_t i, float *out) const; + +}; + +/** The HNSW index is a normal random-access index with a HNSW + * link structure built on top */ + +struct IndexRHNSW : Index { + + typedef RHNSW::storage_idx_t storage_idx_t; + + // the link strcuture + RHNSW hnsw; + + // the sequential storage + bool own_fields; + Index *storage; + + ReconstructFromNeighbors2 *reconstruct_from_neighbors; + + explicit IndexRHNSW (int d = 0, int M = 32, MetricType metric = METRIC_L2); + explicit IndexRHNSW (Index *storage, int M = 32); + + ~IndexRHNSW() override; + + void add(idx_t n, const float *x) override; + + /// Trains the storage if needed + void train(idx_t n, const float* x) override; + + /// entry point for search + void search (idx_t n, const float *x, idx_t k, + float *distances, idx_t *labels, + ConcurrentBitsetPtr bitset = nullptr) const override; + + void reconstruct(idx_t key, float* recons) const override; + + void reset () override; + + size_t cal_size(); + + void init_hnsw(); +}; + + +/** Flat index topped with with a HNSW structure to access elements + * more efficiently. + */ + +struct IndexRHNSWFlat : IndexRHNSW { + IndexRHNSWFlat(); + IndexRHNSWFlat(int d, int M, MetricType metric = METRIC_L2); + size_t cal_size(); +}; + +/** PQ index topped with with a HNSW structure to access elements + * more efficiently. + */ +struct IndexRHNSWPQ : IndexRHNSW { + IndexRHNSWPQ(); + IndexRHNSWPQ(int d, int pq_m, int M); + void train(idx_t n, const float* x) override; + size_t cal_size(); +}; + +/** SQ index topped with with a HNSW structure to access elements + * more efficiently. + */ +struct IndexRHNSWSQ : IndexRHNSW { + IndexRHNSWSQ(); + IndexRHNSWSQ(int d, QuantizerType qtype, int M, MetricType metric = METRIC_L2); + size_t cal_size(); +}; + +/** 2-level code structure with fast random access + */ +struct IndexRHNSW2Level : IndexRHNSW { + IndexRHNSW2Level(); + IndexRHNSW2Level(Index *quantizer, size_t nlist, int m_pq, int M); + + void flip_to_ivf(); + + /// entry point for search + void search (idx_t n, const float *x, idx_t k, + float *distances, idx_t *labels, + ConcurrentBitsetPtr bitset = nullptr) const override; + size_t cal_size(); +}; + + +} // namespace faiss diff --git a/core/src/index/thirdparty/faiss/IndexReplicas.cpp b/core/src/index/thirdparty/faiss/IndexReplicas.cpp new file mode 100644 index 0000000000..8749ab6cc5 --- /dev/null +++ b/core/src/index/thirdparty/faiss/IndexReplicas.cpp @@ -0,0 +1,124 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +namespace faiss { + +template +IndexReplicasTemplate::IndexReplicasTemplate(bool threaded) + : ThreadedIndex(threaded) { +} + +template +IndexReplicasTemplate::IndexReplicasTemplate(idx_t d, bool threaded) + : ThreadedIndex(d, threaded) { +} + +template +IndexReplicasTemplate::IndexReplicasTemplate(int d, bool threaded) + : ThreadedIndex(d, threaded) { +} + +template +void +IndexReplicasTemplate::onAfterAddIndex(IndexT* index) { + // Make sure that the parameters are the same for all prior indices, unless + // we're the first index to be added + if (this->count() > 0 && this->at(0) != index) { + auto existing = this->at(0); + + FAISS_THROW_IF_NOT_FMT(index->ntotal == existing->ntotal, + "IndexReplicas: newly added index does " + "not have same number of vectors as prior index; " + "prior index has %ld vectors, new index has %ld", + existing->ntotal, index->ntotal); + + FAISS_THROW_IF_NOT_MSG(index->is_trained == existing->is_trained, + "IndexReplicas: newly added index does " + "not have same train status as prior index"); + } else { + // Set our parameters based on the first index we're adding + // (dimension is handled in ThreadedIndex) + this->ntotal = index->ntotal; + this->verbose = index->verbose; + this->is_trained = index->is_trained; + this->metric_type = index->metric_type; + } +} + +template +void +IndexReplicasTemplate::train(idx_t n, const component_t* x) { + this->runOnIndex([n, x](int, IndexT* index){ index->train(n, x); }); +} + +template +void +IndexReplicasTemplate::add(idx_t n, const component_t* x) { + this->runOnIndex([n, x](int, IndexT* index){ index->add(n, x); }); + this->ntotal += n; +} + +template +void +IndexReplicasTemplate::reconstruct(idx_t n, component_t* x) const { + FAISS_THROW_IF_NOT_MSG(this->count() > 0, "no replicas in index"); + + // Just pass to the first replica + this->at(0)->reconstruct(n, x); +} + +template +void +IndexReplicasTemplate::search(idx_t n, + const component_t* x, + idx_t k, + distance_t* distances, + idx_t* labels, + ConcurrentBitsetPtr bitset) const { + FAISS_THROW_IF_NOT_MSG(this->count() > 0, "no replicas in index"); + + if (n == 0) { + return; + } + + auto dim = this->d; + size_t componentsPerVec = + sizeof(component_t) == 1 ? (dim + 7) / 8 : dim; + + // Partition the query by the number of indices we have + faiss::Index::idx_t queriesPerIndex = + (faiss::Index::idx_t) (n + this->count() - 1) / + (faiss::Index::idx_t) this->count(); + FAISS_ASSERT(n / queriesPerIndex <= this->count()); + + auto fn = + [queriesPerIndex, componentsPerVec, + n, x, k, distances, labels](int i, const IndexT* index) { + faiss::Index::idx_t base = (faiss::Index::idx_t) i * queriesPerIndex; + + if (base < n) { + auto numForIndex = std::min(queriesPerIndex, n - base); + + index->search(numForIndex, + x + base * componentsPerVec, + k, + distances + base * k, + labels + base * k); + } + }; + + this->runOnIndex(fn); +} + +// explicit instantiations +template struct IndexReplicasTemplate; +template struct IndexReplicasTemplate; + +} // namespace diff --git a/core/src/index/thirdparty/faiss/IndexReplicas.h b/core/src/index/thirdparty/faiss/IndexReplicas.h new file mode 100644 index 0000000000..a98c28cea5 --- /dev/null +++ b/core/src/index/thirdparty/faiss/IndexReplicas.h @@ -0,0 +1,77 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include + +namespace faiss { + +/// Takes individual faiss::Index instances, and splits queries for +/// sending to each Index instance, and joins the results together +/// when done. +/// Each index is managed by a separate CPU thread. +template +class IndexReplicasTemplate : public ThreadedIndex { + public: + using idx_t = typename IndexT::idx_t; + using component_t = typename IndexT::component_t; + using distance_t = typename IndexT::distance_t; + + /// The dimension that all sub-indices must share will be the dimension of the + /// first sub-index added + /// @param threaded do we use one thread per sub-index or do queries + /// sequentially? + explicit IndexReplicasTemplate(bool threaded = true); + + /// @param d the dimension that all sub-indices must share + /// @param threaded do we use one thread per sub index or do queries + /// sequentially? + explicit IndexReplicasTemplate(idx_t d, bool threaded = true); + + /// int version due to the implicit bool conversion ambiguity of int as + /// dimension + explicit IndexReplicasTemplate(int d, bool threaded = true); + + /// Alias for addIndex() + void add_replica(IndexT* index) { this->addIndex(index); } + + /// Alias for removeIndex() + void remove_replica(IndexT* index) { this->removeIndex(index); } + + /// faiss::Index API + /// All indices receive the same call + void train(idx_t n, const component_t* x) override; + + /// faiss::Index API + /// All indices receive the same call + void add(idx_t n, const component_t* x) override; + + /// faiss::Index API + /// Query is partitioned into a slice for each sub-index + /// split by ceil(n / #indices) for our sub-indices + void search(idx_t n, + const component_t* x, + idx_t k, + distance_t* distances, + idx_t* labels, + ConcurrentBitsetPtr bitset = nullptr) const override; + + /// reconstructs from the first index + void reconstruct(idx_t, component_t *v) const override; + + protected: + /// Called just after an index is added + void onAfterAddIndex(IndexT* index) override; +}; + +using IndexReplicas = IndexReplicasTemplate; +using IndexBinaryReplicas = IndexReplicasTemplate; + +} // namespace diff --git a/core/src/index/thirdparty/faiss/IndexSQHybrid.cpp b/core/src/index/thirdparty/faiss/IndexSQHybrid.cpp new file mode 100644 index 0000000000..8376ca8b31 --- /dev/null +++ b/core/src/index/thirdparty/faiss/IndexSQHybrid.cpp @@ -0,0 +1,183 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +#include +#include + +#include + +#include +#include +#include +#include +#include + +namespace faiss { + +/******************************************************************* + * IndexIVFSQHybrid implementation + ********************************************************************/ + +IndexIVFSQHybrid::IndexIVFSQHybrid ( + Index *quantizer, size_t d, size_t nlist, + QuantizerType qtype, + MetricType metric, bool encode_residual) + : IndexIVF(quantizer, d, nlist, 0, metric), + sq(d, qtype), + by_residual(encode_residual) +{ + code_size = sq.code_size; + // was not known at construction time + invlists->code_size = code_size; + is_trained = false; +} + +IndexIVFSQHybrid::IndexIVFSQHybrid (): + IndexIVF(), + by_residual(true) +{ +} + +void IndexIVFSQHybrid::train_residual (idx_t n, const float *x) +{ + sq.train_residual(n, x, quantizer, by_residual, verbose); +} + +void IndexIVFSQHybrid::encode_vectors(idx_t n, const float* x, + const idx_t *list_nos, + uint8_t * codes, + bool include_listnos) const +{ + std::unique_ptr squant (sq.select_quantizer ()); + size_t coarse_size = include_listnos ? coarse_code_size () : 0; + memset(codes, 0, (code_size + coarse_size) * n); + +#pragma omp parallel if(n > 1) + { + std::vector residual (d); + +#pragma omp for + for (size_t i = 0; i < n; i++) { + int64_t list_no = list_nos [i]; + if (list_no >= 0) { + const float *xi = x + i * d; + uint8_t *code = codes + i * (code_size + coarse_size); + if (by_residual) { + quantizer->compute_residual ( + xi, residual.data(), list_no); + xi = residual.data (); + } + if (coarse_size) { + encode_listno (list_no, code); + } + squant->encode_vector (xi, code + coarse_size); + } + } + } +} + +void IndexIVFSQHybrid::sa_decode (idx_t n, const uint8_t *codes, + float *x) const +{ + std::unique_ptr squant (sq.select_quantizer ()); + size_t coarse_size = coarse_code_size (); + +#pragma omp parallel if(n > 1) + { + std::vector residual (d); + +#pragma omp for + for (size_t i = 0; i < n; i++) { + const uint8_t *code = codes + i * (code_size + coarse_size); + int64_t list_no = decode_listno (code); + float *xi = x + i * d; + squant->decode_vector (code + coarse_size, xi); + if (by_residual) { + quantizer->reconstruct (list_no, residual.data()); + for (size_t j = 0; j < d; j++) { + xi[j] += residual[j]; + } + } + } + } +} + + + +void IndexIVFSQHybrid::add_with_ids + (idx_t n, const float * x, const idx_t *xids) +{ + FAISS_THROW_IF_NOT (is_trained); + std::unique_ptr idx (new int64_t [n]); + quantizer->assign (n, x, idx.get()); + size_t nadd = 0; + std::unique_ptr squant(sq.select_quantizer ()); + +#pragma omp parallel reduction(+: nadd) + { + std::vector residual (d); + std::vector one_code (code_size); + int nt = omp_get_num_threads(); + int rank = omp_get_thread_num(); + + // each thread takes care of a subset of lists + for (size_t i = 0; i < n; i++) { + int64_t list_no = idx [i]; + if (list_no >= 0 && list_no % nt == rank) { + int64_t id = xids ? xids[i] : ntotal + i; + + const float * xi = x + i * d; + if (by_residual) { + quantizer->compute_residual (xi, residual.data(), list_no); + xi = residual.data(); + } + + memset (one_code.data(), 0, code_size); + squant->encode_vector (xi, one_code.data()); + + invlists->add_entry (list_no, id, one_code.data()); + + nadd++; + + } + } + } + ntotal += n; +} + + + + + +InvertedListScanner* IndexIVFSQHybrid::get_InvertedListScanner + (bool store_pairs) const +{ + return sq.select_InvertedListScanner (metric_type, quantizer, store_pairs, + by_residual); +} + + +void IndexIVFSQHybrid::reconstruct_from_offset (int64_t list_no, + int64_t offset, + float* recons) const +{ + std::vector centroid(d); + quantizer->reconstruct (list_no, centroid.data()); + + const uint8_t* code = invlists->get_single_code (list_no, offset); + sq.decode (code, recons, 1); + for (int i = 0; i < d; ++i) { + recons[i] += centroid[i]; + } +} + + + + +} // namespace faiss diff --git a/core/src/index/thirdparty/faiss/IndexSQHybrid.h b/core/src/index/thirdparty/faiss/IndexSQHybrid.h new file mode 100644 index 0000000000..c3bf599b08 --- /dev/null +++ b/core/src/index/thirdparty/faiss/IndexSQHybrid.h @@ -0,0 +1,66 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +#ifndef FAISS_INDEX_SQ_HYBRID_H +#define FAISS_INDEX_SQ_HYBRID_H + +#include +#include + +#include +#include +#include + + +namespace faiss { + + /** An IVF implementation where the components of the residuals are + * encoded with a scalar uniform quantizer. All distance computations + * are asymmetric, so the encoded vectors are decoded and approximate + * distances are computed. + */ + +struct IndexIVFSQHybrid: IndexIVF { + ScalarQuantizer sq; + bool by_residual; + + IndexIVFSQHybrid(Index *quantizer, size_t d, size_t nlist, + QuantizerType qtype, + MetricType metric = METRIC_L2, + bool encode_residual = true); + + IndexIVFSQHybrid(); + + void train_residual(idx_t n, const float* x) override; + + void encode_vectors(idx_t n, const float* x, + const idx_t *list_nos, + uint8_t * codes, + bool include_listnos=false) const override; + + void add_with_ids(idx_t n, const float* x, const idx_t* xids) override; + + InvertedListScanner *get_InvertedListScanner (bool store_pairs) + const override; + + + void reconstruct_from_offset (int64_t list_no, int64_t offset, + float* recons) const override; + + /* standalone codec interface */ + void sa_decode (idx_t n, const uint8_t *bytes, + float *x) const override; + +}; + + +} + + +#endif diff --git a/core/src/index/thirdparty/faiss/IndexScalarQuantizer.cpp b/core/src/index/thirdparty/faiss/IndexScalarQuantizer.cpp new file mode 100644 index 0000000000..38cd28a6cd --- /dev/null +++ b/core/src/index/thirdparty/faiss/IndexScalarQuantizer.cpp @@ -0,0 +1,359 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +#include + +#include +#include + +#include + +#include +#include +#include +#include +#include + +namespace faiss { + + + +/******************************************************************* + * IndexScalarQuantizer implementation + ********************************************************************/ + +IndexScalarQuantizer::IndexScalarQuantizer + (int d, QuantizerType qtype, + MetricType metric): + Index(d, metric), + sq (d, qtype) +{ + is_trained = + qtype == QuantizerType::QT_fp16 || + qtype == QuantizerType::QT_8bit_direct; + code_size = sq.code_size; +} + + +IndexScalarQuantizer::IndexScalarQuantizer (): + IndexScalarQuantizer(0, QuantizerType::QT_8bit) +{} + +void IndexScalarQuantizer::train(idx_t n, const float* x) +{ + sq.train(n, x); + is_trained = true; +} + +void IndexScalarQuantizer::add(idx_t n, const float* x) +{ + FAISS_THROW_IF_NOT (is_trained); + codes.resize ((n + ntotal) * code_size); + sq.compute_codes (x, &codes[ntotal * code_size], n); + ntotal += n; +} + + +void IndexScalarQuantizer::search( + idx_t n, + const float* x, + idx_t k, + float* distances, + idx_t* labels, + ConcurrentBitsetPtr bitset) const +{ + FAISS_THROW_IF_NOT (is_trained); + FAISS_THROW_IF_NOT (metric_type == METRIC_L2 || + metric_type == METRIC_INNER_PRODUCT); + +#pragma omp parallel + { + InvertedListScanner* scanner = sq.select_InvertedListScanner + (metric_type, nullptr, true); + ScopeDeleter1 del(scanner); + +#pragma omp for + for (size_t i = 0; i < n; i++) { + float * D = distances + k * i; + idx_t * I = labels + k * i; + // re-order heap + if (metric_type == METRIC_L2) { + maxheap_heapify (k, D, I); + } else { + minheap_heapify (k, D, I); + } + scanner->set_query (x + i * d); + scanner->scan_codes (ntotal, codes.data(), + nullptr, D, I, k); + + // re-order heap + if (metric_type == METRIC_L2) { + maxheap_reorder (k, D, I); + } else { + minheap_reorder (k, D, I); + } + } + } + +} + + +DistanceComputer *IndexScalarQuantizer::get_distance_computer () const +{ + SQDistanceComputer *dc = sq.get_distance_computer (metric_type); + dc->code_size = sq.code_size; + dc->codes = codes.data(); + return dc; +} + + +void IndexScalarQuantizer::reset() +{ + codes.clear(); + ntotal = 0; +} + +void IndexScalarQuantizer::reconstruct_n( + idx_t i0, idx_t ni, float* recons) const +{ + std::unique_ptr squant(sq.select_quantizer ()); + for (size_t i = 0; i < ni; i++) { + squant->decode_vector(&codes[(i + i0) * code_size], recons + i * d); + } +} + +void IndexScalarQuantizer::reconstruct(idx_t key, float* recons) const +{ + reconstruct_n(key, 1, recons); +} + +/* Codec interface */ +size_t IndexScalarQuantizer::sa_code_size () const +{ + return sq.code_size; +} + +void IndexScalarQuantizer::sa_encode (idx_t n, const float *x, + uint8_t *bytes) const +{ + FAISS_THROW_IF_NOT (is_trained); + sq.compute_codes (x, bytes, n); +} + +void IndexScalarQuantizer::sa_decode (idx_t n, const uint8_t *bytes, + float *x) const +{ + FAISS_THROW_IF_NOT (is_trained); + sq.decode(bytes, x, n); +} + + + +/******************************************************************* + * IndexIVFScalarQuantizer implementation + ********************************************************************/ + +IndexIVFScalarQuantizer::IndexIVFScalarQuantizer ( + Index *quantizer, size_t d, size_t nlist, + QuantizerType qtype, + MetricType metric, bool encode_residual) + : IndexIVF(quantizer, d, nlist, 0, metric), + sq(d, qtype), + by_residual(encode_residual) +{ + code_size = sq.code_size; + // was not known at construction time + invlists->code_size = code_size; + is_trained = false; +} + +IndexIVFScalarQuantizer::IndexIVFScalarQuantizer (): + IndexIVF(), + by_residual(true) +{ +} + +void IndexIVFScalarQuantizer::train_residual (idx_t n, const float *x) +{ + sq.train_residual(n, x, quantizer, by_residual, verbose); +} + +void IndexIVFScalarQuantizer::encode_vectors(idx_t n, const float* x, + const idx_t *list_nos, + uint8_t * codes, + bool include_listnos) const +{ + std::unique_ptr squant (sq.select_quantizer ()); + size_t coarse_size = include_listnos ? coarse_code_size () : 0; + memset(codes, 0, (code_size + coarse_size) * n); + +#pragma omp parallel if(n > 1) + { + std::vector residual (d); + +#pragma omp for + for (size_t i = 0; i < n; i++) { + int64_t list_no = list_nos [i]; + if (list_no >= 0) { + const float *xi = x + i * d; + uint8_t *code = codes + i * (code_size + coarse_size); + if (by_residual) { + quantizer->compute_residual ( + xi, residual.data(), list_no); + xi = residual.data (); + } + if (coarse_size) { + encode_listno (list_no, code); + } + squant->encode_vector (xi, code + coarse_size); + } + } + } +} + +void IndexIVFScalarQuantizer::sa_decode (idx_t n, const uint8_t *codes, + float *x) const +{ + std::unique_ptr squant (sq.select_quantizer ()); + size_t coarse_size = coarse_code_size (); + +#pragma omp parallel if(n > 1) + { + std::vector residual (d); + +#pragma omp for + for (size_t i = 0; i < n; i++) { + const uint8_t *code = codes + i * (code_size + coarse_size); + int64_t list_no = decode_listno (code); + float *xi = x + i * d; + squant->decode_vector (code + coarse_size, xi); + if (by_residual) { + quantizer->reconstruct (list_no, residual.data()); + for (size_t j = 0; j < d; j++) { + xi[j] += residual[j]; + } + } + } + } +} + + + +void IndexIVFScalarQuantizer::add_with_ids + (idx_t n, const float * x, const idx_t *xids) +{ + FAISS_THROW_IF_NOT (is_trained); + std::unique_ptr idx (new int64_t [n]); + quantizer->assign (n, x, idx.get()); + size_t nadd = 0; + std::unique_ptr squant(sq.select_quantizer ()); + + DirectMapAdd dm_add (direct_map, n, xids); + +#pragma omp parallel reduction(+: nadd) + { + std::vector residual (d); + std::vector one_code (code_size); + int nt = omp_get_num_threads(); + int rank = omp_get_thread_num(); + + // each thread takes care of a subset of lists + for (size_t i = 0; i < n; i++) { + int64_t list_no = idx [i]; + if (list_no >= 0 && list_no % nt == rank) { + int64_t id = xids ? xids[i] : ntotal + i; + + const float * xi = x + i * d; + if (by_residual) { + quantizer->compute_residual (xi, residual.data(), list_no); + xi = residual.data(); + } + + memset (one_code.data(), 0, code_size); + squant->encode_vector (xi, one_code.data()); + + size_t ofs = invlists->add_entry (list_no, id, one_code.data()); + + dm_add.add (i, list_no, ofs); + nadd++; + + } else if (rank == 0 && list_no == -1) { + dm_add.add (i, -1, 0); + } + } + } + + + ntotal += n; +} + + +void IndexIVFScalarQuantizer::add_with_ids_without_codes + (idx_t n, const float * x, const idx_t *xids) +{ + FAISS_THROW_IF_NOT (is_trained); + std::unique_ptr idx (new int64_t [n]); + quantizer->assign (n, x, idx.get()); + size_t nadd = 0; + std::unique_ptr squant(sq.select_quantizer ()); + + DirectMapAdd dm_add (direct_map, n, xids); + +#pragma omp parallel reduction(+: nadd) + { + int nt = omp_get_num_threads(); + int rank = omp_get_thread_num(); + + // each thread takes care of a subset of lists + for (size_t i = 0; i < n; i++) { + int64_t list_no = idx [i]; + if (list_no >= 0 && list_no % nt == rank) { + int64_t id = xids ? xids[i] : ntotal + i; + size_t ofs = invlists->add_entry_without_codes (list_no, id); + + dm_add.add (i, list_no, ofs); + nadd++; + + } else if (rank == 0 && list_no == -1) { + dm_add.add (i, -1, 0); + } + } + } + + + ntotal += n; +} + + +InvertedListScanner* IndexIVFScalarQuantizer::get_InvertedListScanner + (bool store_pairs) const +{ + return sq.select_InvertedListScanner (metric_type, quantizer, store_pairs, + by_residual); +} + + +void IndexIVFScalarQuantizer::reconstruct_from_offset (int64_t list_no, + int64_t offset, + float* recons) const +{ + std::vector centroid(d); + quantizer->reconstruct (list_no, centroid.data()); + + const uint8_t* code = invlists->get_single_code (list_no, offset); + sq.decode (code, recons, 1); + for (int i = 0; i < d; ++i) { + recons[i] += centroid[i]; + } +} + + + + +} // namespace faiss diff --git a/core/src/index/thirdparty/faiss/IndexScalarQuantizer.h b/core/src/index/thirdparty/faiss/IndexScalarQuantizer.h new file mode 100644 index 0000000000..4313a5b37e --- /dev/null +++ b/core/src/index/thirdparty/faiss/IndexScalarQuantizer.h @@ -0,0 +1,131 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +#ifndef FAISS_INDEX_SCALAR_QUANTIZER_H +#define FAISS_INDEX_SCALAR_QUANTIZER_H + +#include +#include + +#include +#include +#include + +namespace faiss { + +/** + * The uniform quantizer has a range [vmin, vmax]. The range can be + * the same for all dimensions (uniform) or specific per dimension + * (default). + */ + + + + +struct IndexScalarQuantizer: Index { + /// Used to encode the vectors + ScalarQuantizer sq; + + /// Codes. Size ntotal * pq.code_size + std::vector codes; + + size_t code_size; + + /** Constructor. + * + * @param d dimensionality of the input vectors + * @param M number of subquantizers + * @param nbits number of bit per subvector index + */ + IndexScalarQuantizer (int d, + QuantizerType qtype, + MetricType metric = METRIC_L2); + + IndexScalarQuantizer (); + + void train(idx_t n, const float* x) override; + + void add(idx_t n, const float* x) override; + + void search( + idx_t n, + const float* x, + idx_t k, + float* distances, + idx_t* labels, + ConcurrentBitsetPtr bitset = nullptr) const override; + + void reset() override; + + void reconstruct_n(idx_t i0, idx_t ni, float* recons) const override; + + void reconstruct(idx_t key, float* recons) const override; + + DistanceComputer *get_distance_computer () const override; + + /* standalone codec interface */ + size_t sa_code_size () const override; + + void sa_encode (idx_t n, const float *x, + uint8_t *bytes) const override; + + void sa_decode (idx_t n, const uint8_t *bytes, + float *x) const override; + + size_t cal_size() { return codes.size() * sizeof(uint8_t) + sizeof(size_t) + sq.cal_size(); } + +}; + + + /** An IVF implementation where the components of the residuals are + * encoded with a scalar uniform quantizer. All distance computations + * are asymmetric, so the encoded vectors are decoded and approximate + * distances are computed. + */ + +struct IndexIVFScalarQuantizer: IndexIVF { + ScalarQuantizer sq; + bool by_residual; + + IndexIVFScalarQuantizer(Index *quantizer, size_t d, size_t nlist, + QuantizerType qtype, + MetricType metric = METRIC_L2, + bool encode_residual = true); + + IndexIVFScalarQuantizer(); + + void train_residual(idx_t n, const float* x) override; + + void encode_vectors(idx_t n, const float* x, + const idx_t *list_nos, + uint8_t * codes, + bool include_listnos=false) const override; + + void add_with_ids(idx_t n, const float* x, const idx_t* xids) override; + + void add_with_ids_without_codes(idx_t n, const float* x, const idx_t* xids) override; + + InvertedListScanner *get_InvertedListScanner (bool store_pairs) + const override; + + + void reconstruct_from_offset (int64_t list_no, int64_t offset, + float* recons) const override; + + /* standalone codec interface */ + void sa_decode (idx_t n, const uint8_t *bytes, + float *x) const override; + +}; + + +} + + +#endif diff --git a/core/src/index/thirdparty/faiss/IndexShards.cpp b/core/src/index/thirdparty/faiss/IndexShards.cpp new file mode 100644 index 0000000000..0e0ac16264 --- /dev/null +++ b/core/src/index/thirdparty/faiss/IndexShards.cpp @@ -0,0 +1,318 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +#include + +#include +#include + +#include +#include +#include + +namespace faiss { + +// subroutines +namespace { + +typedef Index::idx_t idx_t; + + +// add translation to all valid labels +void translate_labels (long n, idx_t *labels, long translation) +{ + if (translation == 0) return; + for (long i = 0; i < n; i++) { + if(labels[i] < 0) continue; + labels[i] += translation; + } +} + + +/** merge result tables from several shards. + * @param all_distances size nshard * n * k + * @param all_labels idem + * @param translartions label translations to apply, size nshard + */ + +template +void +merge_tables(long n, long k, long nshard, + typename IndexClass::distance_t *distances, + idx_t *labels, + const std::vector& all_distances, + const std::vector& all_labels, + const std::vector& translations) { + if (k == 0) { + return; + } + using distance_t = typename IndexClass::distance_t; + + long stride = n * k; +#pragma omp parallel + { + std::vector buf (2 * nshard); + int * pointer = buf.data(); + int * shard_ids = pointer + nshard; + std::vector buf2 (nshard); + distance_t * heap_vals = buf2.data(); +#pragma omp for + for (long i = 0; i < n; i++) { + // the heap maps values to the shard where they are + // produced. + const distance_t *D_in = all_distances.data() + i * k; + const idx_t *I_in = all_labels.data() + i * k; + int heap_size = 0; + + for (long s = 0; s < nshard; s++) { + pointer[s] = 0; + if (I_in[stride * s] >= 0) { + heap_push (++heap_size, heap_vals, shard_ids, + D_in[stride * s], s); + } + } + + distance_t *D = distances + i * k; + idx_t *I = labels + i * k; + + for (int j = 0; j < k; j++) { + if (heap_size == 0) { + I[j] = -1; + D[j] = C::neutral(); + } else { + // pop best element + int s = shard_ids[0]; + int & p = pointer[s]; + D[j] = heap_vals[0]; + I[j] = I_in[stride * s + p] + translations[s]; + + heap_pop (heap_size--, heap_vals, shard_ids); + p++; + if (p < k && I_in[stride * s + p] >= 0) { + heap_push (++heap_size, heap_vals, shard_ids, + D_in[stride * s + p], s); + } + } + } + } + } +} + +} // anonymous namespace + +template +IndexShardsTemplate::IndexShardsTemplate(idx_t d, + bool threaded, + bool successive_ids) + : ThreadedIndex(d, threaded), + successive_ids(successive_ids) { +} + +template +IndexShardsTemplate::IndexShardsTemplate(int d, + bool threaded, + bool successive_ids) + : ThreadedIndex(d, threaded), + successive_ids(successive_ids) { +} + +template +IndexShardsTemplate::IndexShardsTemplate(bool threaded, + bool successive_ids) + : ThreadedIndex(threaded), + successive_ids(successive_ids) { +} + +template +void +IndexShardsTemplate::onAfterAddIndex(IndexT* index /* unused */) { + sync_with_shard_indexes(); +} + +template +void +IndexShardsTemplate::onAfterRemoveIndex(IndexT* index /* unused */) { + sync_with_shard_indexes(); +} + +template +void +IndexShardsTemplate::sync_with_shard_indexes() { + if (!this->count()) { + this->is_trained = false; + this->ntotal = 0; + + return; + } + + auto firstIndex = this->at(0); + this->metric_type = firstIndex->metric_type; + this->is_trained = firstIndex->is_trained; + this->ntotal = firstIndex->ntotal; + + for (int i = 1; i < this->count(); ++i) { + auto index = this->at(i); + FAISS_THROW_IF_NOT(this->metric_type == index->metric_type); + FAISS_THROW_IF_NOT(this->d == index->d); + + this->ntotal += index->ntotal; + } +} + +template +void +IndexShardsTemplate::train(idx_t n, + const component_t *x) { + auto fn = + [n, x](int no, IndexT *index) { + if (index->verbose) { + printf("begin train shard %d on %ld points\n", no, n); + } + + index->train(n, x); + + if (index->verbose) { + printf("end train shard %d\n", no); + } + }; + + this->runOnIndex(fn); + sync_with_shard_indexes(); +} + +template +void +IndexShardsTemplate::add(idx_t n, + const component_t *x) { + add_with_ids(n, x, nullptr); +} + +template +void +IndexShardsTemplate::add_with_ids(idx_t n, + const component_t * x, + const idx_t *xids) { + + FAISS_THROW_IF_NOT_MSG(!(successive_ids && xids), + "It makes no sense to pass in ids and " + "request them to be shifted"); + + if (successive_ids) { + FAISS_THROW_IF_NOT_MSG(!xids, + "It makes no sense to pass in ids and " + "request them to be shifted"); + FAISS_THROW_IF_NOT_MSG(this->ntotal == 0, + "when adding to IndexShards with sucessive_ids, " + "only add() in a single pass is supported"); + } + + idx_t nshard = this->count(); + const idx_t *ids = xids; + + std::vector aids; + + if (!ids && !successive_ids) { + aids.resize(n); + + for (idx_t i = 0; i < n; i++) { + aids[i] = this->ntotal + i; + } + + ids = aids.data(); + } + + size_t components_per_vec = + sizeof(component_t) == 1 ? (this->d + 7) / 8 : this->d; + + auto fn = + [n, ids, x, nshard, components_per_vec](int no, IndexT *index) { + idx_t i0 = (idx_t) no * n / nshard; + idx_t i1 = ((idx_t) no + 1) * n / nshard; + auto x0 = x + i0 * components_per_vec; + + if (index->verbose) { + printf ("begin add shard %d on %ld points\n", no, n); + } + + if (ids) { + index->add_with_ids (i1 - i0, x0, ids + i0); + } else { + index->add (i1 - i0, x0); + } + + if (index->verbose) { + printf ("end add shard %d on %ld points\n", no, i1 - i0); + } + }; + + this->runOnIndex(fn); + + // This is safe to do here because the current thread controls execution in + // all threads, and nothing else is happening + this->ntotal += n; +} + +template +void +IndexShardsTemplate::search(idx_t n, + const component_t *x, + idx_t k, + distance_t *distances, + idx_t *labels, + ConcurrentBitsetPtr bitset) const { + long nshard = this->count(); + + std::vector all_distances(nshard * k * n); + std::vector all_labels(nshard * k * n); + + auto fn = + [n, k, x, &all_distances, &all_labels](int no, const IndexT *index) { + if (index->verbose) { + printf ("begin query shard %d on %ld points\n", no, n); + } + + index->search (n, x, k, + all_distances.data() + no * k * n, + all_labels.data() + no * k * n); + + if (index->verbose) { + printf ("end query shard %d\n", no); + } + }; + + this->runOnIndex(fn); + + std::vector translations(nshard, 0); + + // Because we just called runOnIndex above, it is safe to access the sub-index + // ntotal here + if (successive_ids) { + translations[0] = 0; + + for (int s = 0; s + 1 < nshard; s++) { + translations[s + 1] = translations[s] + this->at(s)->ntotal; + } + } + + if (this->metric_type == METRIC_L2) { + merge_tables>( + n, k, nshard, distances, labels, + all_distances, all_labels, translations); + } else { + merge_tables>( + n, k, nshard, distances, labels, + all_distances, all_labels, translations); + } +} + +// explicit instanciations +template struct IndexShardsTemplate; +template struct IndexShardsTemplate; + +} // namespace faiss diff --git a/core/src/index/thirdparty/faiss/IndexShards.h b/core/src/index/thirdparty/faiss/IndexShards.h new file mode 100644 index 0000000000..6fbca6778a --- /dev/null +++ b/core/src/index/thirdparty/faiss/IndexShards.h @@ -0,0 +1,101 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include + +namespace faiss { + +/** + * Index that concatenates the results from several sub-indexes + */ +template +struct IndexShardsTemplate : public ThreadedIndex { + using idx_t = typename IndexT::idx_t; + using component_t = typename IndexT::component_t; + using distance_t = typename IndexT::distance_t; + + /** + * The dimension that all sub-indices must share will be the dimension of the + * first sub-index added + * + * @param threaded do we use one thread per sub_index or do + * queries sequentially? + * @param successive_ids should we shift the returned ids by + * the size of each sub-index or return them + * as they are? + */ + explicit IndexShardsTemplate(bool threaded = false, + bool successive_ids = true); + + /** + * @param threaded do we use one thread per sub_index or do + * queries sequentially? + * @param successive_ids should we shift the returned ids by + * the size of each sub-index or return them + * as they are? + */ + explicit IndexShardsTemplate(idx_t d, + bool threaded = false, + bool successive_ids = true); + + /// int version due to the implicit bool conversion ambiguity of int as + /// dimension + explicit IndexShardsTemplate(int d, + bool threaded = false, + bool successive_ids = true); + + /// Alias for addIndex() + void add_shard(IndexT* index) { this->addIndex(index); } + + /// Alias for removeIndex() + void remove_shard(IndexT* index) { this->removeIndex(index); } + + /// supported only for sub-indices that implement add_with_ids + void add(idx_t n, const component_t* x) override; + + /** + * Cases (successive_ids, xids): + * - true, non-NULL ERROR: it makes no sense to pass in ids and + * request them to be shifted + * - true, NULL OK, but should be called only once (calls add() + * on sub-indexes). + * - false, non-NULL OK: will call add_with_ids with passed in xids + * distributed evenly over shards + * - false, NULL OK: will call add_with_ids on each sub-index, + * starting at ntotal + */ + void add_with_ids(idx_t n, const component_t* x, const idx_t* xids) override; + + void search(idx_t n, const component_t* x, idx_t k, + distance_t* distances, idx_t* labels, + ConcurrentBitsetPtr bitset = nullptr) const override; + + void train(idx_t n, const component_t* x) override; + + // update metric_type and ntotal. Call if you changes something in + // the shard indexes. + void sync_with_shard_indexes(); + + bool successive_ids; + + protected: + /// Called just after an index is added + void onAfterAddIndex(IndexT* index) override; + + /// Called just after an index is removed + void onAfterRemoveIndex(IndexT* index) override; +}; + +using IndexShards = IndexShardsTemplate; +using IndexBinaryShards = IndexShardsTemplate; + + +} // namespace faiss diff --git a/core/src/index/thirdparty/faiss/InvertedLists.cpp b/core/src/index/thirdparty/faiss/InvertedLists.cpp new file mode 100644 index 0000000000..d44e74b58a --- /dev/null +++ b/core/src/index/thirdparty/faiss/InvertedLists.cpp @@ -0,0 +1,919 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +#include + +#include +#include + +#include +#include + +#ifndef USE_CPU +#include "gpu/utils/DeviceUtils.h" +#include "cuda.h" +#include "cuda_runtime.h" + +namespace faiss { + +/* + * Use pin memory to build Readonly Inverted list will accelerate cuda memory copy, but it will downgrade cpu ivf search + * performance. read only inverted list structure will also make ivf search performance not stable. ISSUE 500 mention + * this problem. Best performance is the original inverted list with non pin memory. + */ + +PageLockMemory::PageLockMemory(size_t size) : nbytes(size) { + auto err = cudaHostAlloc(&(this->data), size, 0); + if (err) { + std::string msg = + "Fail to alloc page lock memory " + std::to_string(size) + ", err code " + std::to_string((int32_t)err); + FAISS_THROW_MSG(msg); + } +} + +PageLockMemory::~PageLockMemory() { + CUDA_VERIFY(cudaFreeHost((void*)(this->data))); +} + +PageLockMemory::PageLockMemory(const PageLockMemory& other) { + auto err = cudaHostAlloc(&(this->data), other.nbytes, 0); + if (err) { + std::string msg = "Fail to alloc page lock memory " + std::to_string(other.nbytes) + ", err code " + + std::to_string((int32_t)err); + FAISS_THROW_MSG(msg); + } + memcpy(this->data, other.data, other.nbytes); + this->nbytes = other.nbytes; +} + +PageLockMemory::PageLockMemory(PageLockMemory &&other) { + this->data = other.data; + this->nbytes = other.nbytes; + other.data = nullptr; + other.nbytes = 0; +} + +} +#endif + +namespace faiss { + + + +/***************************************** + * InvertedLists implementation + ******************************************/ + +InvertedLists::InvertedLists (size_t nlist, size_t code_size): + nlist (nlist), code_size (code_size) +{ +} + +InvertedLists::~InvertedLists () +{} + +InvertedLists::idx_t InvertedLists::get_single_id ( + size_t list_no, size_t offset) const +{ + assert (offset < list_size (list_no)); + return get_ids(list_no)[offset]; +} + + +void InvertedLists::release_codes (size_t, const uint8_t *) const +{} + +void InvertedLists::release_ids (size_t, const idx_t *) const +{} + +void InvertedLists::prefetch_lists (const idx_t *, int) const +{} + +const uint8_t * InvertedLists::get_single_code ( + size_t list_no, size_t offset) const +{ + assert (offset < list_size (list_no)); + return get_codes(list_no) + offset * code_size; +} + +size_t InvertedLists::add_entry (size_t list_no, idx_t theid, + const uint8_t *code) +{ + return add_entries (list_no, 1, &theid, code); +} + +size_t InvertedLists::add_entry_without_codes (size_t list_no, idx_t theid) +{ + return add_entries_without_codes (list_no, 1, &theid); +} + +size_t InvertedLists::add_entries_without_codes (size_t list_no, size_t n_entry, + const idx_t* ids) +{ + return 0; +} + +void InvertedLists::update_entry (size_t list_no, size_t offset, + idx_t id, const uint8_t *code) +{ + update_entries (list_no, offset, 1, &id, code); +} + +InvertedLists* InvertedLists::to_readonly() { + return nullptr; +} + +InvertedLists* InvertedLists::to_readonly_without_codes() { + return nullptr; +} + +bool InvertedLists::is_readonly() const { + return false; +} + +void InvertedLists::reset () { + for (size_t i = 0; i < nlist; i++) { + resize (i, 0); + } +} + +void InvertedLists::merge_from (InvertedLists *oivf, size_t add_id) { + +#pragma omp parallel for + for (idx_t i = 0; i < nlist; i++) { + size_t list_size = oivf->list_size (i); + ScopedIds ids (oivf, i); + if (add_id == 0) { + add_entries (i, list_size, ids.get (), + ScopedCodes (oivf, i).get()); + } else { + std::vector new_ids (list_size); + + for (size_t j = 0; j < list_size; j++) { + new_ids [j] = ids[j] + add_id; + } + add_entries (i, list_size, new_ids.data(), + ScopedCodes (oivf, i).get()); + } + oivf->resize (i, 0); + } +} + +double InvertedLists::imbalance_factor () const { + std::vector hist(nlist); + + for (size_t i = 0; i < nlist; i++) { + hist[i] = list_size(i); + } + + return faiss::imbalance_factor(nlist, hist.data()); +} + +void InvertedLists::print_stats () const { + std::vector sizes(40); + for (size_t i = 0; i < nlist; i++) { + for (size_t j = 0; j < sizes.size(); j++) { + if ((list_size(i) >> j) == 0) { + sizes[j]++; + break; + } + } + } + for (size_t i = 0; i < sizes.size(); i++) { + if (sizes[i]) { + printf("list size in < %d: %d instances\n", 1 << i, sizes[i]); + } + } +} + +size_t InvertedLists::compute_ntotal () const { + size_t tot = 0; + for (size_t i = 0; i < nlist; i++) { + tot += list_size(i); + } + return tot; +} + +/***************************************** + * ArrayInvertedLists implementation + ******************************************/ + +ArrayInvertedLists::ArrayInvertedLists (size_t nlist, size_t code_size): + InvertedLists (nlist, code_size) +{ + ids.resize (nlist); + codes.resize (nlist); +} + +size_t ArrayInvertedLists::add_entries ( + size_t list_no, size_t n_entry, + const idx_t* ids_in, const uint8_t *code) +{ + if (n_entry == 0) return 0; + assert (list_no < nlist); + size_t o = ids [list_no].size(); + ids [list_no].resize (o + n_entry); + memcpy (&ids[list_no][o], ids_in, sizeof (ids_in[0]) * n_entry); + codes [list_no].resize ((o + n_entry) * code_size); + memcpy (&codes[list_no][o * code_size], code, code_size * n_entry); + return o; +} + +size_t ArrayInvertedLists::add_entries_without_codes ( + size_t list_no, size_t n_entry, + const idx_t* ids_in) +{ + if (n_entry == 0) return 0; + assert (list_no < nlist); + size_t o = ids [list_no].size(); + ids [list_no].resize (o + n_entry); + memcpy (&ids[list_no][o], ids_in, sizeof (ids_in[0]) * n_entry); + return o; +} + +size_t ArrayInvertedLists::list_size(size_t list_no) const +{ + assert (list_no < nlist); + return ids[list_no].size(); +} + +const uint8_t * ArrayInvertedLists::get_codes (size_t list_no) const +{ + assert (list_no < nlist); + return codes[list_no].data(); +} + + +const InvertedLists::idx_t * ArrayInvertedLists::get_ids (size_t list_no) const +{ + assert (list_no < nlist); + return ids[list_no].data(); +} + +void ArrayInvertedLists::resize (size_t list_no, size_t new_size) +{ + ids[list_no].resize (new_size); + codes[list_no].resize (new_size * code_size); +} + +void ArrayInvertedLists::update_entries ( + size_t list_no, size_t offset, size_t n_entry, + const idx_t *ids_in, const uint8_t *codes_in) +{ + assert (list_no < nlist); + assert (n_entry + offset <= ids[list_no].size()); + memcpy (&ids[list_no][offset], ids_in, sizeof(ids_in[0]) * n_entry); + memcpy (&codes[list_no][offset * code_size], codes_in, code_size * n_entry); +} + +InvertedLists* ArrayInvertedLists::to_readonly() { + ReadOnlyArrayInvertedLists* readonly = new ReadOnlyArrayInvertedLists(*this); + return readonly; +} + +InvertedLists* ArrayInvertedLists::to_readonly_without_codes() { + ReadOnlyArrayInvertedLists* readonly = new ReadOnlyArrayInvertedLists(*this, true); + return readonly; +} + +ArrayInvertedLists::~ArrayInvertedLists () +{} + +/***************************************************************** + * ReadOnlyArrayInvertedLists implementations + *****************************************************************/ + +ReadOnlyArrayInvertedLists::ReadOnlyArrayInvertedLists(size_t nlist, + size_t code_size, const std::vector& list_length) + : InvertedLists (nlist, code_size), + readonly_length(list_length) { + valid = readonly_length.size() == nlist; + if (!valid) { + FAISS_THROW_MSG ("Invalid list_length"); + return; + } + auto total_size = std::accumulate(readonly_length.begin(), readonly_length.end(), 0); + readonly_offset.reserve(nlist); + +#ifdef USE_CPU + readonly_codes.reserve(total_size * code_size); + readonly_ids.reserve(total_size); +#endif + + size_t offset = 0; + for (auto i=0; icode_size) * sizeof(uint8_t); + pin_readonly_codes = std::make_shared(codes_size); + pin_readonly_ids = std::make_shared(ids_size); + + offset = 0; + for (auto i = 0; i < other.ids.size(); i++) { + auto& list_ids = other.ids[i]; + auto& list_codes = other.codes[i]; + + uint8_t* ids_ptr = (uint8_t*)(pin_readonly_ids->data) + offset * sizeof(idx_t); + memcpy(ids_ptr, list_ids.data(), list_ids.size() * sizeof(idx_t)); + + uint8_t* codes_ptr = (uint8_t*)(pin_readonly_codes->data) + offset * (this->code_size) * sizeof(uint8_t); + memcpy(codes_ptr, list_codes.data(), list_codes.size() * sizeof(uint8_t)); + + offset += list_ids.size(); + } +#endif + + valid = true; +} + +ReadOnlyArrayInvertedLists::ReadOnlyArrayInvertedLists(const ArrayInvertedLists& other, bool offset_only) + : InvertedLists (other.nlist, other.code_size) { + readonly_length.resize(nlist); + readonly_offset.resize(nlist); + size_t offset = 0; + for (auto i = 0; i < other.ids.size(); i++) { + auto& list_ids = other.ids[i]; + readonly_length[i] = list_ids.size(); + readonly_offset[i] = offset; + offset += list_ids.size(); + } + +#ifdef USE_CPU + for (auto i = 0; i < other.ids.size(); i++) { + auto& list_ids = other.ids[i]; + readonly_ids.insert(readonly_ids.end(), list_ids.begin(), list_ids.end()); + } +#else + size_t ids_size = offset * sizeof(idx_t); + size_t codes_size = offset * (this->code_size) * sizeof(uint8_t); + pin_readonly_codes = std::make_shared(codes_size); + pin_readonly_ids = std::make_shared(ids_size); + + offset = 0; + for (auto i = 0; i < other.ids.size(); i++) { + auto& list_ids = other.ids[i]; + + uint8_t* ids_ptr = (uint8_t*)(pin_readonly_ids->data) + offset * sizeof(idx_t); + memcpy(ids_ptr, list_ids.data(), list_ids.size() * sizeof(idx_t)); + + offset += list_ids.size(); + } +#endif + + valid = true; +} + + +ReadOnlyArrayInvertedLists::~ReadOnlyArrayInvertedLists() { +} + +bool +ReadOnlyArrayInvertedLists::is_valid() { + return valid; +} + +size_t ReadOnlyArrayInvertedLists::add_entries ( + size_t , size_t , + const idx_t* , const uint8_t *) +{ + FAISS_THROW_MSG ("not implemented"); +} + +size_t ReadOnlyArrayInvertedLists::add_entries_without_codes ( + size_t , size_t , + const idx_t*) +{ + FAISS_THROW_MSG ("not implemented"); +} + +void ReadOnlyArrayInvertedLists::update_entries (size_t, size_t , size_t , + const idx_t *, const uint8_t *) +{ + FAISS_THROW_MSG ("not implemented"); +} + +void ReadOnlyArrayInvertedLists::resize (size_t , size_t ) +{ + FAISS_THROW_MSG ("not implemented"); +} + +size_t ReadOnlyArrayInvertedLists::list_size(size_t list_no) const +{ + FAISS_ASSERT(list_no < nlist && valid); + return readonly_length[list_no]; +} + +const uint8_t * ReadOnlyArrayInvertedLists::get_codes (size_t list_no) const +{ + FAISS_ASSERT(list_no < nlist && valid); +#ifdef USE_CPU + return readonly_codes.data() + readonly_offset[list_no] * code_size; +#else + uint8_t *pcodes = (uint8_t *)(pin_readonly_codes->data); + return pcodes + readonly_offset[list_no] * code_size; +#endif +} + +const InvertedLists::idx_t* ReadOnlyArrayInvertedLists::get_ids (size_t list_no) const +{ + FAISS_ASSERT(list_no < nlist && valid); +#ifdef USE_CPU + return readonly_ids.data() + readonly_offset[list_no]; +#else + idx_t *pids = (idx_t *)pin_readonly_ids->data; + return pids + readonly_offset[list_no]; +#endif +} + +const InvertedLists::idx_t* ReadOnlyArrayInvertedLists::get_all_ids() const { + FAISS_ASSERT(valid); +#ifdef USE_CPU + return readonly_ids.data(); +#else + return (idx_t *)(pin_readonly_ids->data); +#endif +} + +const uint8_t* ReadOnlyArrayInvertedLists::get_all_codes() const { + FAISS_ASSERT(valid); +#ifdef USE_CPU + return readonly_codes.data(); +#else + return (uint8_t *)(pin_readonly_codes->data); +#endif +} + +const std::vector& ReadOnlyArrayInvertedLists::get_list_length() const { + FAISS_ASSERT(valid); + return readonly_length; +} + +bool ReadOnlyArrayInvertedLists::is_readonly() const { + FAISS_ASSERT(valid); + return true; +} + +/***************************************************************** + * Meta-inverted list implementations + *****************************************************************/ + + +size_t ReadOnlyInvertedLists::add_entries ( + size_t , size_t , + const idx_t* , const uint8_t *) +{ + FAISS_THROW_MSG ("not implemented"); +} + +size_t ReadOnlyInvertedLists::add_entries_without_codes ( + size_t , size_t , + const idx_t*) +{ + FAISS_THROW_MSG ("not implemented"); +} + +void ReadOnlyInvertedLists::update_entries (size_t, size_t , size_t , + const idx_t *, const uint8_t *) +{ + FAISS_THROW_MSG ("not implemented"); +} + +void ReadOnlyInvertedLists::resize (size_t , size_t ) +{ + FAISS_THROW_MSG ("not implemented"); +} + + + +/***************************************** + * HStackInvertedLists implementation + ******************************************/ + +HStackInvertedLists::HStackInvertedLists ( + int nil, const InvertedLists **ils_in): + ReadOnlyInvertedLists (nil > 0 ? ils_in[0]->nlist : 0, + nil > 0 ? ils_in[0]->code_size : 0) +{ + FAISS_THROW_IF_NOT (nil > 0); + for (int i = 0; i < nil; i++) { + ils.push_back (ils_in[i]); + FAISS_THROW_IF_NOT (ils_in[i]->code_size == code_size && + ils_in[i]->nlist == nlist); + } +} + +size_t HStackInvertedLists::list_size(size_t list_no) const +{ + size_t sz = 0; + for (int i = 0; i < ils.size(); i++) { + const InvertedLists *il = ils[i]; + sz += il->list_size (list_no); + } + return sz; +} + +const uint8_t * HStackInvertedLists::get_codes (size_t list_no) const +{ + uint8_t *codes = new uint8_t [code_size * list_size(list_no)], *c = codes; + + for (int i = 0; i < ils.size(); i++) { + const InvertedLists *il = ils[i]; + size_t sz = il->list_size(list_no) * code_size; + if (sz > 0) { + memcpy (c, ScopedCodes (il, list_no).get(), sz); + c += sz; + } + } + return codes; +} + +const uint8_t * HStackInvertedLists::get_single_code ( + size_t list_no, size_t offset) const +{ + for (int i = 0; i < ils.size(); i++) { + const InvertedLists *il = ils[i]; + size_t sz = il->list_size (list_no); + if (offset < sz) { + // here we have to copy the code, otherwise it will crash at dealloc + uint8_t * code = new uint8_t [code_size]; + memcpy (code, ScopedCodes (il, list_no, offset).get(), code_size); + return code; + } + offset -= sz; + } + FAISS_THROW_FMT ("offset %ld unknown", offset); +} + + +void HStackInvertedLists::release_codes (size_t, const uint8_t *codes) const { + delete [] codes; +} + +const Index::idx_t * HStackInvertedLists::get_ids (size_t list_no) const +{ + idx_t *ids = new idx_t [list_size(list_no)], *c = ids; + + for (int i = 0; i < ils.size(); i++) { + const InvertedLists *il = ils[i]; + size_t sz = il->list_size(list_no); + if (sz > 0) { + memcpy (c, ScopedIds (il, list_no).get(), sz * sizeof(idx_t)); + c += sz; + } + } + return ids; +} + +Index::idx_t HStackInvertedLists::get_single_id ( + size_t list_no, size_t offset) const +{ + + for (int i = 0; i < ils.size(); i++) { + const InvertedLists *il = ils[i]; + size_t sz = il->list_size (list_no); + if (offset < sz) { + return il->get_single_id (list_no, offset); + } + offset -= sz; + } + FAISS_THROW_FMT ("offset %ld unknown", offset); +} + + +void HStackInvertedLists::release_ids (size_t, const idx_t *ids) const { + delete [] ids; +} + +void HStackInvertedLists::prefetch_lists (const idx_t *list_nos, int nlist) const +{ + for (int i = 0; i < ils.size(); i++) { + const InvertedLists *il = ils[i]; + il->prefetch_lists (list_nos, nlist); + } +} + +/***************************************** + * SliceInvertedLists implementation + ******************************************/ + + +namespace { + + using idx_t = InvertedLists::idx_t; + + idx_t translate_list_no (const SliceInvertedLists *sil, + idx_t list_no) { + FAISS_THROW_IF_NOT (list_no >= 0 && list_no < sil->nlist); + return list_no + sil->i0; + } + +}; + + + +SliceInvertedLists::SliceInvertedLists ( + const InvertedLists *il, idx_t i0, idx_t i1): + ReadOnlyInvertedLists (i1 - i0, il->code_size), + il (il), i0(i0), i1(i1) +{ + +} + +size_t SliceInvertedLists::list_size(size_t list_no) const +{ + return il->list_size (translate_list_no (this, list_no)); +} + +const uint8_t * SliceInvertedLists::get_codes (size_t list_no) const +{ + return il->get_codes (translate_list_no (this, list_no)); +} + +const uint8_t * SliceInvertedLists::get_single_code ( + size_t list_no, size_t offset) const +{ + return il->get_single_code (translate_list_no (this, list_no), offset); +} + + +void SliceInvertedLists::release_codes ( + size_t list_no, const uint8_t *codes) const { + return il->release_codes (translate_list_no (this, list_no), codes); +} + +const Index::idx_t * SliceInvertedLists::get_ids (size_t list_no) const +{ + return il->get_ids (translate_list_no (this, list_no)); +} + +Index::idx_t SliceInvertedLists::get_single_id ( + size_t list_no, size_t offset) const +{ + return il->get_single_id (translate_list_no (this, list_no), offset); +} + + +void SliceInvertedLists::release_ids (size_t list_no, const idx_t *ids) const { + return il->release_ids (translate_list_no (this, list_no), ids); +} + +void SliceInvertedLists::prefetch_lists (const idx_t *list_nos, int nlist) const +{ + std::vector translated_list_nos; + for (int j = 0; j < nlist; j++) { + idx_t list_no = list_nos[j]; + if (list_no < 0) continue; + translated_list_nos.push_back (translate_list_no (this, list_no)); + } + il->prefetch_lists (translated_list_nos.data(), + translated_list_nos.size()); +} + + +/***************************************** + * VStackInvertedLists implementation + ******************************************/ + +namespace { + + using idx_t = InvertedLists::idx_t; + + // find the invlist this number belongs to + int translate_list_no (const VStackInvertedLists *vil, + idx_t list_no) { + FAISS_THROW_IF_NOT (list_no >= 0 && list_no < vil->nlist); + int i0 = 0, i1 = vil->ils.size(); + const idx_t *cumsz = vil->cumsz.data(); + while (i0 + 1 < i1) { + int imed = (i0 + i1) / 2; + if (list_no >= cumsz[imed]) { + i0 = imed; + } else { + i1 = imed; + } + } + assert(list_no >= cumsz[i0] && list_no < cumsz[i0 + 1]); + return i0; + } + + idx_t sum_il_sizes (int nil, const InvertedLists **ils_in) { + idx_t tot = 0; + for (int i = 0; i < nil; i++) { + tot += ils_in[i]->nlist; + } + return tot; + } + +}; + + + +VStackInvertedLists::VStackInvertedLists ( + int nil, const InvertedLists **ils_in): + ReadOnlyInvertedLists (sum_il_sizes(nil, ils_in), + nil > 0 ? ils_in[0]->code_size : 0) +{ + FAISS_THROW_IF_NOT (nil > 0); + cumsz.resize (nil + 1); + for (int i = 0; i < nil; i++) { + ils.push_back (ils_in[i]); + FAISS_THROW_IF_NOT (ils_in[i]->code_size == code_size); + cumsz[i + 1] = cumsz[i] + ils_in[i]->nlist; + } +} + +size_t VStackInvertedLists::list_size(size_t list_no) const +{ + int i = translate_list_no (this, list_no); + list_no -= cumsz[i]; + return ils[i]->list_size (list_no); +} + +const uint8_t * VStackInvertedLists::get_codes (size_t list_no) const +{ + int i = translate_list_no (this, list_no); + list_no -= cumsz[i]; + return ils[i]->get_codes (list_no); +} + +const uint8_t * VStackInvertedLists::get_single_code ( + size_t list_no, size_t offset) const +{ + int i = translate_list_no (this, list_no); + list_no -= cumsz[i]; + return ils[i]->get_single_code (list_no, offset); +} + + +void VStackInvertedLists::release_codes ( + size_t list_no, const uint8_t *codes) const { + int i = translate_list_no (this, list_no); + list_no -= cumsz[i]; + return ils[i]->release_codes (list_no, codes); +} + +const Index::idx_t * VStackInvertedLists::get_ids (size_t list_no) const +{ + int i = translate_list_no (this, list_no); + list_no -= cumsz[i]; + return ils[i]->get_ids (list_no); +} + +Index::idx_t VStackInvertedLists::get_single_id ( + size_t list_no, size_t offset) const +{ + int i = translate_list_no (this, list_no); + list_no -= cumsz[i]; + return ils[i]->get_single_id (list_no, offset); +} + + +void VStackInvertedLists::release_ids (size_t list_no, const idx_t *ids) const { + int i = translate_list_no (this, list_no); + list_no -= cumsz[i]; + return ils[i]->release_ids (list_no, ids); +} + +void VStackInvertedLists::prefetch_lists ( + const idx_t *list_nos, int nlist) const +{ + std::vector ilno (nlist, -1); + std::vector n_per_il (ils.size(), 0); + for (int j = 0; j < nlist; j++) { + idx_t list_no = list_nos[j]; + if (list_no < 0) continue; + int i = ilno[j] = translate_list_no (this, list_no); + n_per_il[i]++; + } + std::vector cum_n_per_il (ils.size() + 1, 0); + for (int j = 0; j < ils.size(); j++) { + cum_n_per_il[j + 1] = cum_n_per_il[j] + n_per_il[j]; + } + std::vector sorted_list_nos (cum_n_per_il.back()); + for (int j = 0; j < nlist; j++) { + idx_t list_no = list_nos[j]; + if (list_no < 0) continue; + int i = ilno[j]; + list_no -= cumsz[i]; + sorted_list_nos[cum_n_per_il[i]++] = list_no; + } + + int i0 = 0; + for (int j = 0; j < ils.size(); j++) { + int i1 = i0 + n_per_il[j]; + if (i1 > i0) { + ils[j]->prefetch_lists (sorted_list_nos.data() + i0, + i1 - i0); + } + i0 = i1; + } +} + + + +/***************************************** + * MaskedInvertedLists implementation + ******************************************/ + + +MaskedInvertedLists::MaskedInvertedLists (const InvertedLists *il0, + const InvertedLists *il1): + ReadOnlyInvertedLists (il0->nlist, il0->code_size), + il0 (il0), il1 (il1) +{ + FAISS_THROW_IF_NOT (il1->nlist == nlist); + FAISS_THROW_IF_NOT (il1->code_size == code_size); +} + +size_t MaskedInvertedLists::list_size(size_t list_no) const +{ + size_t sz = il0->list_size(list_no); + return sz ? sz : il1->list_size(list_no); +} + +const uint8_t * MaskedInvertedLists::get_codes (size_t list_no) const +{ + size_t sz = il0->list_size(list_no); + return (sz ? il0 : il1)->get_codes(list_no); +} + +const idx_t * MaskedInvertedLists::get_ids (size_t list_no) const +{ + size_t sz = il0->list_size (list_no); + return (sz ? il0 : il1)->get_ids (list_no); +} + +void MaskedInvertedLists::release_codes ( + size_t list_no, const uint8_t *codes) const +{ + size_t sz = il0->list_size (list_no); + (sz ? il0 : il1)->release_codes (list_no, codes); +} + +void MaskedInvertedLists::release_ids (size_t list_no, const idx_t *ids) const +{ + size_t sz = il0->list_size (list_no); + (sz ? il0 : il1)->release_ids (list_no, ids); +} + +idx_t MaskedInvertedLists::get_single_id (size_t list_no, size_t offset) const +{ + size_t sz = il0->list_size (list_no); + return (sz ? il0 : il1)->get_single_id (list_no, offset); +} + +const uint8_t * MaskedInvertedLists::get_single_code ( + size_t list_no, size_t offset) const +{ + size_t sz = il0->list_size (list_no); + return (sz ? il0 : il1)->get_single_code (list_no, offset); +} + +void MaskedInvertedLists::prefetch_lists ( + const idx_t *list_nos, int nlist) const +{ + std::vector list0, list1; + for (int i = 0; i < nlist; i++) { + idx_t list_no = list_nos[i]; + if (list_no < 0) continue; + size_t sz = il0->list_size(list_no); + (sz ? list0 : list1).push_back (list_no); + } + il0->prefetch_lists (list0.data(), list0.size()); + il1->prefetch_lists (list1.data(), list1.size()); +} + + + +} // namespace faiss diff --git a/core/src/index/thirdparty/faiss/InvertedLists.h b/core/src/index/thirdparty/faiss/InvertedLists.h new file mode 100644 index 0000000000..c57b7b6961 --- /dev/null +++ b/core/src/index/thirdparty/faiss/InvertedLists.h @@ -0,0 +1,437 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +#ifndef FAISS_INVERTEDLISTS_IVF_H +#define FAISS_INVERTEDLISTS_IVF_H + +/** + * Definition of inverted lists + a few common classes that implement + * the interface. + */ + +#include +#include +#include + +#ifndef USE_CPU +namespace faiss { + +struct PageLockMemory { +public: + PageLockMemory() : data(nullptr), nbytes(0) {} + + PageLockMemory(size_t size); + + ~PageLockMemory(); + + PageLockMemory(const PageLockMemory& other); + + PageLockMemory(PageLockMemory &&other); + + inline size_t size() { + return nbytes; + } + + void *data; + size_t nbytes; +}; +using PageLockMemoryPtr = std::shared_ptr; +} +#endif + +namespace faiss { + +/** Table of inverted lists + * multithreading rules: + * - concurrent read accesses are allowed + * - concurrent update accesses are allowed + * - for resize and add_entries, only concurrent access to different lists + * are allowed + */ +struct InvertedLists { + typedef Index::idx_t idx_t; + + size_t nlist; ///< number of possible key values + size_t code_size; ///< code size per vector in bytes + + InvertedLists (size_t nlist, size_t code_size); + + /************************* + * Read only functions */ + + /// get the size of a list + virtual size_t list_size(size_t list_no) const = 0; + + /** get the codes for an inverted list + * must be released by release_codes + * + * @return codes size list_size * code_size + */ + virtual const uint8_t * get_codes (size_t list_no) const = 0; + + /** get the ids for an inverted list + * must be released by release_ids + * + * @return ids size list_size + */ + virtual const idx_t * get_ids (size_t list_no) const = 0; + + /// release codes returned by get_codes (default implementation is nop + virtual void release_codes (size_t list_no, const uint8_t *codes) const; + + /// release ids returned by get_ids + virtual void release_ids (size_t list_no, const idx_t *ids) const; + + /// @return a single id in an inverted list + virtual idx_t get_single_id (size_t list_no, size_t offset) const; + + /// @return a single code in an inverted list + /// (should be deallocated with release_codes) + virtual const uint8_t * get_single_code ( + size_t list_no, size_t offset) const; + + /// prepare the following lists (default does nothing) + /// a list can be -1 hence the signed long + virtual void prefetch_lists (const idx_t *list_nos, int nlist) const; + + /************************* + * writing functions */ + + /// add one entry to an inverted list + virtual size_t add_entry (size_t list_no, idx_t theid, + const uint8_t *code); + + virtual size_t add_entries ( + size_t list_no, size_t n_entry, + const idx_t* ids, const uint8_t *code) = 0; + + /// add one entry to an inverted list without codes + virtual size_t add_entry_without_codes (size_t list_no, idx_t theid); + + virtual size_t add_entries_without_codes ( size_t list_no, size_t n_entry, + const idx_t* ids); + + virtual void update_entry (size_t list_no, size_t offset, + idx_t id, const uint8_t *code); + + virtual void update_entries (size_t list_no, size_t offset, size_t n_entry, + const idx_t *ids, const uint8_t *code) = 0; + + virtual void resize (size_t list_no, size_t new_size) = 0; + + virtual void reset (); + + virtual InvertedLists* to_readonly(); + + virtual InvertedLists* to_readonly_without_codes(); + + virtual bool is_readonly() const; + + /// move all entries from oivf (empty on output) + void merge_from (InvertedLists *oivf, size_t add_id); + + virtual ~InvertedLists (); + + /************************* + * statistics */ + + /// 1= perfectly balanced, >1: imbalanced + double imbalance_factor () const; + + /// display some stats about the inverted lists + void print_stats () const; + + /// sum up list sizes + size_t compute_ntotal () const; + + /************************************** + * Scoped inverted lists (for automatic deallocation) + * + * instead of writing: + * + * uint8_t * codes = invlists->get_codes (10); + * ... use codes + * invlists->release_codes(10, codes) + * + * write: + * + * ScopedCodes codes (invlists, 10); + * ... use codes.get() + * // release called automatically when codes goes out of scope + * + * the following function call also works: + * + * foo (123, ScopedCodes (invlists, 10).get(), 456); + * + */ + + struct ScopedIds { + const InvertedLists *il; + const idx_t *ids; + size_t list_no; + + ScopedIds (const InvertedLists *il, size_t list_no): + il (il), ids (il->get_ids (list_no)), list_no (list_no) + {} + + const idx_t *get() {return ids; } + + idx_t operator [] (size_t i) const { + return ids[i]; + } + + ~ScopedIds () { + il->release_ids (list_no, ids); + } + }; + + struct ScopedCodes { + const InvertedLists *il; + const uint8_t *codes; + size_t list_no; + + ScopedCodes (const InvertedLists *il, size_t list_no): + il (il), codes (il->get_codes (list_no)), list_no (list_no) + {} + + ScopedCodes (const InvertedLists *il, size_t list_no, size_t offset): + il (il), codes (il->get_single_code (list_no, offset)), + list_no (list_no) + {} + + // For codes outside + ScopedCodes (const InvertedLists *il, size_t list_no, const uint8_t *original_codes): + il (il), codes (original_codes), list_no (list_no) + {} + + const uint8_t *get() {return codes; } + + ~ScopedCodes () { + il->release_codes (list_no, codes); + } + }; + + +}; + + +/// simple (default) implementation as an array of inverted lists +struct ArrayInvertedLists: InvertedLists { + std::vector < std::vector > codes; // binary codes, size nlist + std::vector < std::vector > ids; ///< Inverted lists for indexes + + ArrayInvertedLists (size_t nlist, size_t code_size); + + size_t list_size(size_t list_no) const override; + const uint8_t * get_codes (size_t list_no) const override; + const idx_t * get_ids (size_t list_no) const override; + + size_t add_entries ( + size_t list_no, size_t n_entry, + const idx_t* ids, const uint8_t *code) override; + + size_t add_entries_without_codes ( + size_t list_no, size_t n_entry, + const idx_t* ids) override; + + void update_entries (size_t list_no, size_t offset, size_t n_entry, + const idx_t *ids, const uint8_t *code) override; + + void resize (size_t list_no, size_t new_size) override; + + InvertedLists* to_readonly() override; + + InvertedLists* to_readonly_without_codes() override; + + virtual ~ArrayInvertedLists (); +}; + +struct ReadOnlyArrayInvertedLists: InvertedLists { +#ifdef USE_CPU + std::vector readonly_codes; + std::vector readonly_ids; +#else + PageLockMemoryPtr pin_readonly_codes; + PageLockMemoryPtr pin_readonly_ids; +#endif + + std::vector readonly_length; + std::vector readonly_offset; + bool valid; + + ReadOnlyArrayInvertedLists(size_t nlist, size_t code_size, const std::vector& list_length); + explicit ReadOnlyArrayInvertedLists(const ArrayInvertedLists& other); + explicit ReadOnlyArrayInvertedLists(const ArrayInvertedLists& other, bool offset); + + // Use default copy construct, just copy pointer, DON'T COPY pin_readonly_codes AND pin_readonly_ids +// explicit ReadOnlyArrayInvertedLists(const ReadOnlyArrayInvertedLists &); +// explicit ReadOnlyArrayInvertedLists(ReadOnlyArrayInvertedLists &&); + virtual ~ReadOnlyArrayInvertedLists(); + + size_t list_size(size_t list_no) const override; + const uint8_t * get_codes (size_t list_no) const override; + const idx_t * get_ids (size_t list_no) const override; + + const uint8_t * get_all_codes() const; + const idx_t * get_all_ids() const; + const std::vector& get_list_length() const; + + size_t add_entries ( + size_t list_no, size_t n_entry, + const idx_t* ids, const uint8_t *code) override; + + size_t add_entries_without_codes ( + size_t list_no, size_t n_entry, + const idx_t* ids) override; + + void update_entries (size_t list_no, size_t offset, size_t n_entry, + const idx_t *ids, const uint8_t *code) override; + + void resize (size_t list_no, size_t new_size) override; + + bool is_readonly() const override; + + bool is_valid(); +}; + +/***************************************************************** + * Meta-inverted lists + * + * About terminology: the inverted lists are seen as a sparse matrix, + * that can be stacked horizontally, vertically and sliced. + *****************************************************************/ + +struct ReadOnlyInvertedLists: InvertedLists { + + ReadOnlyInvertedLists (size_t nlist, size_t code_size): + InvertedLists (nlist, code_size) {} + + size_t add_entries ( + size_t list_no, size_t n_entry, + const idx_t* ids, const uint8_t *code) override; + + size_t add_entries_without_codes ( + size_t list_no, size_t n_entry, + const idx_t* ids) override; + + void update_entries (size_t list_no, size_t offset, size_t n_entry, + const idx_t *ids, const uint8_t *code) override; + + void resize (size_t list_no, size_t new_size) override; + +}; + + +/// Horizontal stack of inverted lists +struct HStackInvertedLists: ReadOnlyInvertedLists { + + std::vectorils; + + /// build InvertedLists by concatenating nil of them + HStackInvertedLists (int nil, const InvertedLists **ils); + + size_t list_size(size_t list_no) const override; + const uint8_t * get_codes (size_t list_no) const override; + const idx_t * get_ids (size_t list_no) const override; + + void prefetch_lists (const idx_t *list_nos, int nlist) const override; + + void release_codes (size_t list_no, const uint8_t *codes) const override; + void release_ids (size_t list_no, const idx_t *ids) const override; + + idx_t get_single_id (size_t list_no, size_t offset) const override; + + const uint8_t * get_single_code ( + size_t list_no, size_t offset) const override; + +}; + +using ConcatenatedInvertedLists = HStackInvertedLists; + + +/// vertical slice of indexes in another InvertedLists +struct SliceInvertedLists: ReadOnlyInvertedLists { + const InvertedLists *il; + idx_t i0, i1; + + SliceInvertedLists(const InvertedLists *il, idx_t i0, idx_t i1); + + size_t list_size(size_t list_no) const override; + const uint8_t * get_codes (size_t list_no) const override; + const idx_t * get_ids (size_t list_no) const override; + + void release_codes (size_t list_no, const uint8_t *codes) const override; + void release_ids (size_t list_no, const idx_t *ids) const override; + + idx_t get_single_id (size_t list_no, size_t offset) const override; + + const uint8_t * get_single_code ( + size_t list_no, size_t offset) const override; + + void prefetch_lists (const idx_t *list_nos, int nlist) const override; +}; + + +struct VStackInvertedLists: ReadOnlyInvertedLists { + std::vectorils; + std::vector cumsz; + + /// build InvertedLists by concatenating nil of them + VStackInvertedLists (int nil, const InvertedLists **ils); + + size_t list_size(size_t list_no) const override; + const uint8_t * get_codes (size_t list_no) const override; + const idx_t * get_ids (size_t list_no) const override; + + void release_codes (size_t list_no, const uint8_t *codes) const override; + void release_ids (size_t list_no, const idx_t *ids) const override; + + idx_t get_single_id (size_t list_no, size_t offset) const override; + + const uint8_t * get_single_code ( + size_t list_no, size_t offset) const override; + + void prefetch_lists (const idx_t *list_nos, int nlist) const override; + +}; + + +/** use the first inverted lists if they are non-empty otherwise use the second + * + * This is useful if il1 has a few inverted lists that are too long, + * and that il0 has replacement lists for those, with empty lists for + * the others. */ +struct MaskedInvertedLists: ReadOnlyInvertedLists { + + const InvertedLists *il0; + const InvertedLists *il1; + + MaskedInvertedLists (const InvertedLists *il0, + const InvertedLists *il1); + + size_t list_size(size_t list_no) const override; + const uint8_t * get_codes (size_t list_no) const override; + const idx_t * get_ids (size_t list_no) const override; + + void release_codes (size_t list_no, const uint8_t *codes) const override; + void release_ids (size_t list_no, const idx_t *ids) const override; + + idx_t get_single_id (size_t list_no, size_t offset) const override; + + const uint8_t * get_single_code ( + size_t list_no, size_t offset) const override; + + void prefetch_lists (const idx_t *list_nos, int nlist) const override; + +}; + +} // namespace faiss + + +#endif diff --git a/core/src/index/thirdparty/faiss/LICENSE b/core/src/index/thirdparty/faiss/LICENSE new file mode 100644 index 0000000000..b96dcb0480 --- /dev/null +++ b/core/src/index/thirdparty/faiss/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) Facebook, Inc. and its affiliates. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/core/src/index/thirdparty/faiss/Makefile b/core/src/index/thirdparty/faiss/Makefile new file mode 100644 index 0000000000..f81e67914c --- /dev/null +++ b/core/src/index/thirdparty/faiss/Makefile @@ -0,0 +1,123 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +-include makefile.inc + +HEADERS = $(wildcard *.h impl/*.h utils/*.h) +SRC = $(wildcard *.cpp impl/*.cpp utils/*.cpp) +AVX_SRC = $(wildcard *avx.cpp impl/*avx.cpp utils/*avx.cpp) +AVX512_SRC = $(wildcard *avx512.cpp impl/*avx512.cpp utils/*avx512.cpp) +OBJ = $(SRC:.cpp=.o) +INSTALLDIRS = $(DESTDIR)$(libdir) $(DESTDIR)$(includedir)/faiss + +GPU_HEADERS = $(wildcard gpu/*.h gpu/impl/*.h gpu/impl/*.cuh gpu/utils/*.h gpu/utils/*.cuh) +GPU_CPPSRC = $(wildcard gpu/*.cpp gpu/impl/*.cpp gpu/utils/*.cpp) +GPU_CUSRC = $(wildcard gpu/*.cu gpu/impl/*.cu gpu/utils/*.cu \ +gpu/utils/nvidia/*.cu gpu/utils/blockselect/*.cu gpu/utils/warpselect/*.cu) +GPU_SRC = $(GPU_CPPSRC) $(GPU_CUSRC) +GPU_CPPOBJ = $(GPU_CPPSRC:.cpp=.o) +GPU_CUOBJ = $(GPU_CUSRC:.cu=.o) +GPU_OBJ = $(GPU_CPPOBJ) $(GPU_CUOBJ) + +ifneq ($(strip $(NVCC)),) + OBJ += $(GPU_OBJ) + HEADERS += $(GPU_HEADERS) +endif + +CPPFLAGS += -I. +NVCCFLAGS += -I. + +############################ +# Building + +all: libfaiss.a libfaiss.$(SHAREDEXT) + +libfaiss.a: $(OBJ) + $(AR) r $@ $^ + +libfaiss.$(SHAREDEXT): $(OBJ) + $(CXX) $(SHAREDFLAGS) $(LDFLAGS) -o $@ $^ $(LIBS) + +%.o: %.cpp + $(CXX) $(CPPFLAGS) $(CXXFLAGS) $(CPUFLAGS) -c $< -o $@ + +# support avx +%avx.o: %avx.cpp + $(CXX) $(CPPFLAGS) $(CXXFLAGS) $(CPUFLAGS) -mavx2 -c $< -o $@ + +# support avx512 +%avx512.o: %avx512.cpp + $(CXX) $(CPPFLAGS) $(CXXFLAGS) $(CPUFLAGS) -mavx512f -mavx512dq -mavx512bw -c $< -o $@ + +%.o: %.cu + $(NVCC) $(NVCCFLAGS) -c $< -o $@ + +clean: + rm -f libfaiss.a libfaiss.$(SHAREDEXT) + rm -f $(OBJ) + + +############################ +# Installing + +install: libfaiss.a libfaiss.$(SHAREDEXT) installdirs + cp libfaiss.a libfaiss.$(SHAREDEXT) $(DESTDIR)$(libdir) + tar cf - $(HEADERS) | tar xf - -C $(DESTDIR)$(includedir)/faiss/ + +installdirs: + $(MKDIR_P) $(INSTALLDIRS) + +uninstall: + rm -f $(DESTDIR)$(libdir)/libfaiss.a \ + $(DESTDIR)$(libdir)/libfaiss.$(SHAREDEXT) + rm -rf $(DESTDIR)$(includedir)/faiss + + +############################# +# Dependencies + +-include depend + +depend: $(SRC) $(GPU_SRC) + for i in $^; do \ + $(CXXCPP) $(CPPFLAGS) -DCUDA_VERSION=7050 -x c++ -MM $$i; \ + done > depend + + +############################# +# Python + +py: libfaiss.a + $(MAKE) -C python + + +############################# +# Tests + +test: libfaiss.a py + $(MAKE) -C tests run + PYTHONPATH=./python/build/`ls python/build | grep lib` \ + $(PYTHON) -m unittest discover tests/ -v + +test_gpu: libfaiss.a + $(MAKE) -C gpu/test run + PYTHONPATH=./python/build/`ls python/build | grep lib` \ + $(PYTHON) -m unittest discover gpu/test/ -v + +############################# +# Demos + +demos: libfaiss.a + $(MAKE) -C demos + + +############################# +# Misc + +misc/test_blas: misc/test_blas.cpp + $(CXX) $(CPPFLAGS) $(CXXFLAGS) $(LDFLAGS) -o $@ $^ $(LIBS) + + +.PHONY: all clean demos install installdirs py test test_gpu uninstall diff --git a/core/src/index/thirdparty/faiss/MatrixStats.cpp b/core/src/index/thirdparty/faiss/MatrixStats.cpp new file mode 100644 index 0000000000..1862d1a52f --- /dev/null +++ b/core/src/index/thirdparty/faiss/MatrixStats.cpp @@ -0,0 +1,252 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + + +#include + + +#include /* va_list, va_start, va_arg, va_end */ + +#include +#include +#include + +namespace faiss { + +/********************************************************************* + * MatrixStats + *********************************************************************/ + +MatrixStats::PerDimStats::PerDimStats(): + n(0), n_nan(0), n_inf(0), n0(0), + min(HUGE_VALF), max(-HUGE_VALF), + sum(0), sum2(0), + mean(NAN), stddev(NAN) +{} + + +void MatrixStats::PerDimStats::add (float x) +{ + n++; + if (std::isnan(x)) { + n_nan++; + return; + } + if (!std::isfinite(x)) { + n_inf++; + return; + } + if (x == 0) n0++; + if (x < min) min = x; + if (x > max) max = x; + sum += x; + sum2 += (double)x * (double)x; +} + +void MatrixStats::PerDimStats::compute_mean_std () +{ + n_valid = n - n_nan - n_inf; + mean = sum / n_valid; + double var = sum2 / n_valid - mean * mean; + if (var < 0) var = 0; + stddev = sqrt(var); +} + + +void MatrixStats::do_comment (const char *fmt, ...) +{ + va_list ap; + + /* Determine required size */ + va_start(ap, fmt); + size_t size = vsnprintf(buf, nbuf, fmt, ap); + va_end(ap); + + nbuf -= size; + buf += size; +} + + + +MatrixStats::MatrixStats (size_t n, size_t d, const float *x): + n(n), d(d), + n_collision(0), n_valid(0), n0(0), + min_norm2(HUGE_VAL), max_norm2(0) +{ + std::vector comment_buf (10000); + buf = comment_buf.data (); + nbuf = comment_buf.size(); + + do_comment ("analyzing %ld vectors of size %ld\n", n, d); + + if (d > 1024) { + do_comment ( + "indexing this many dimensions is hard, " + "please consider dimensionality reducution (with PCAMatrix)\n"); + } + + size_t nbytes = sizeof (x[0]) * d; + per_dim_stats.resize (d); + + for (size_t i = 0; i < n; i++) { + const float *xi = x + d * i; + double sum2 = 0; + for (size_t j = 0; j < d; j++) { + per_dim_stats[j].add (xi[j]); + sum2 += xi[j] * (double)xi[j]; + } + + if (std::isfinite (sum2)) { + n_valid++; + if (sum2 == 0) { + n0 ++; + } else { + if (sum2 < min_norm2) min_norm2 = sum2; + if (sum2 > max_norm2) max_norm2 = sum2; + } + } + + { // check hash + uint64_t hash = hash_bytes((const uint8_t*)xi, nbytes); + auto elt = occurrences.find (hash); + if (elt == occurrences.end()) { + Occurrence occ = {i, 1}; + occurrences[hash] = occ; + } else { + if (!memcmp (xi, x + elt->second.first * d, nbytes)) { + elt->second.count ++; + } else { + n_collision ++; + // we should use a list of collisions but overkill + } + } + } + } + + // invalid vecor stats + if (n_valid == n) { + do_comment ("no NaN or Infs in data\n"); + } else { + do_comment ("%ld vectors contain NaN or Inf " + "(or have too large components), " + "expect bad results with indexing!\n", n - n_valid); + } + + // copies in dataset + if (occurrences.size() == n) { + do_comment ("all vectors are distinct\n"); + } else { + do_comment ("%ld vectors are distinct (%.2f%%)\n", + occurrences.size(), + occurrences.size() * 100.0 / n); + + if (n_collision > 0) { + do_comment ("%ld collisions in hash table, " + "counts may be invalid\n", n_collision); + } + + Occurrence max = {0, 0}; + for (auto it = occurrences.begin(); + it != occurrences.end(); ++it) { + if (it->second.count > max.count) { + max = it->second; + } + } + do_comment ("vector %ld has %ld copies\n", max.first, max.count); + } + + { // norm stats + min_norm2 = sqrt (min_norm2); + max_norm2 = sqrt (max_norm2); + do_comment ("range of L2 norms=[%g, %g] (%ld null vectors)\n", + min_norm2, max_norm2, n0); + + if (max_norm2 < min_norm2 * 1.0001) { + do_comment ("vectors are normalized, inner product and " + "L2 search are equivalent\n"); + } + + if (max_norm2 > min_norm2 * 100) { + do_comment ("vectors have very large differences in norms, " + "is this normal?\n"); + } + } + + { // per dimension stats + + double max_std = 0, min_std = HUGE_VAL; + + size_t n_dangerous_range = 0, n_0_range = 0, n0 = 0; + + for (size_t j = 0; j < d; j++) { + PerDimStats &st = per_dim_stats[j]; + st.compute_mean_std (); + n0 += st.n0; + + if (st.max == st.min) { + n_0_range ++; + } else if (st.max < 1.001 * st.min) { + n_dangerous_range ++; + } + + if (st.stddev > max_std) max_std = st.stddev; + if (st.stddev < min_std) min_std = st.stddev; + } + + + + if (n0 == 0) { + do_comment ("matrix contains no 0s\n"); + } else { + do_comment ("matrix contains %.2f %% 0 entries\n", + n0 * 100.0 / (n * d)); + } + + if (n_0_range == 0) { + do_comment ("no constant dimensions\n"); + } else { + do_comment ("%ld dimensions are constant: they can be removed\n", + n_0_range); + } + + if (n_dangerous_range == 0) { + do_comment ("no dimension has a too large mean\n"); + } else { + do_comment ("%ld dimensions are too large " + "wrt. their variance, may loose precision " + "in IndexFlatL2 (use CenteringTransform)\n", + n_dangerous_range); + } + + do_comment ("stddevs per dimension are in [%g %g]\n", min_std, max_std); + + size_t n_small_var = 0; + + for (size_t j = 0; j < d; j++) { + const PerDimStats &st = per_dim_stats[j]; + if (st.stddev < max_std * 1e-4) { + n_small_var++; + } + } + + if (n_small_var > 0) { + do_comment ("%ld dimensions have negligible stddev wrt. " + "the largest dimension, they could be ignored", + n_small_var); + } + + } + comments = comment_buf.data (); + buf = nullptr; + nbuf = 0; +} + + + +} // namespace faiss diff --git a/core/src/index/thirdparty/faiss/MatrixStats.h b/core/src/index/thirdparty/faiss/MatrixStats.h new file mode 100644 index 0000000000..6418644c6e --- /dev/null +++ b/core/src/index/thirdparty/faiss/MatrixStats.h @@ -0,0 +1,62 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +#pragma once + +#include +#include +#include +#include + + +namespace faiss { + + +/** Reports some statistics on a dataset and comments on them. + * + * It is a class rather than a function so that all stats can also be + * accessed from code */ + +struct MatrixStats { + MatrixStats (size_t n, size_t d, const float *x); + std::string comments; + + // raw statistics + size_t n, d; + size_t n_collision, n_valid, n0; + double min_norm2, max_norm2; + + struct PerDimStats { + size_t n, n_nan, n_inf, n0; + + float min, max; + double sum, sum2; + + size_t n_valid; + double mean, stddev; + + PerDimStats(); + void add (float x); + void compute_mean_std (); + }; + + std::vector per_dim_stats; + struct Occurrence { + size_t first; + size_t count; + }; + std::unordered_map occurrences; + + char *buf; + size_t nbuf; + void do_comment (const char *fmt, ...); + +}; + +} // namespace faiss diff --git a/core/src/index/thirdparty/faiss/MetaIndexes.cpp b/core/src/index/thirdparty/faiss/MetaIndexes.cpp new file mode 100644 index 0000000000..0094e17ba4 --- /dev/null +++ b/core/src/index/thirdparty/faiss/MetaIndexes.cpp @@ -0,0 +1,379 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +#include + +#include +#include + +#include +#include +#include +#include + + +namespace faiss { + +namespace { + + +} // namespace + +/***************************************************** + * IndexIDMap implementation + *******************************************************/ + +template +IndexIDMapTemplate::IndexIDMapTemplate (IndexT *index): + index (index), + own_fields (false) +{ + FAISS_THROW_IF_NOT_MSG (index->ntotal == 0, "index must be empty on input"); + this->is_trained = index->is_trained; + this->metric_type = index->metric_type; + this->verbose = index->verbose; + this->d = index->d; +} + +template +void IndexIDMapTemplate::add + (idx_t, const typename IndexT::component_t *) +{ + FAISS_THROW_MSG ("add does not make sense with IndexIDMap, " + "use add_with_ids"); +} + + +template +void IndexIDMapTemplate::train + (idx_t n, const typename IndexT::component_t *x) +{ + index->train (n, x); + this->is_trained = index->is_trained; +} + +template +void IndexIDMapTemplate::reset () +{ + index->reset (); + id_map.clear(); + this->ntotal = 0; +} + + +template +void IndexIDMapTemplate::add_with_ids + (idx_t n, const typename IndexT::component_t * x, + const typename IndexT::idx_t *xids) +{ + index->add (n, x); + for (idx_t i = 0; i < n; i++) + id_map.push_back (xids[i]); + this->ntotal = index->ntotal; +} + + +template +void IndexIDMapTemplate::search + (idx_t n, const typename IndexT::component_t *x, idx_t k, + typename IndexT::distance_t *distances, typename IndexT::idx_t *labels, + ConcurrentBitsetPtr bitset) const +{ + index->search (n, x, k, distances, labels, bitset); + idx_t *li = labels; +#pragma omp parallel for + for (idx_t i = 0; i < n * k; i++) { + li[i] = li[i] < 0 ? li[i] : id_map[li[i]]; + } +} + +#if 0 +template +void IndexIDMapTemplate::get_vector_by_id(idx_t n, const idx_t *xid, component_t *x, + ConcurrentBitsetPtr bitset) +{ + /* only get vector by 1 id */ + FAISS_ASSERT(n == 1); + if (!bitset || !bitset->test(xid[0])) { + index->reconstruct(xid[0], x + 0 * IndexT::d); + } else { + memset(x, UINT8_MAX, IndexT::d * sizeof(component_t)); + } +} + +template +void IndexIDMapTemplate::search_by_id (idx_t n, const idx_t *xid, idx_t k, + typename IndexT::distance_t *distances, idx_t *labels, ConcurrentBitsetPtr bitset) +{ + auto x = new typename IndexT::component_t[n * IndexT::d]; + for (idx_t i = 0; i < n; i++) { + index->reconstruct(xid[i], x + i * IndexT::d); + } + index->search(n, x, k, distances, labels, bitset); + delete []x; +} +#endif + +template +void IndexIDMapTemplate::range_search + (typename IndexT::idx_t n, const typename IndexT::component_t *x, + typename IndexT::distance_t radius, RangeSearchResult *result, + ConcurrentBitsetPtr bitset) const +{ + index->range_search(n, x, radius, result, bitset); +#pragma omp parallel for + for (idx_t i = 0; i < result->lims[result->nq]; i++) { + result->labels[i] = result->labels[i] < 0 ? + result->labels[i] : id_map[result->labels[i]]; + } +} + +namespace { + +struct IDTranslatedSelector: IDSelector { + const std::vector & id_map; + const IDSelector & sel; + IDTranslatedSelector (const std::vector & id_map, + const IDSelector & sel): + id_map (id_map), sel (sel) + {} + bool is_member(idx_t id) const override { + return sel.is_member(id_map[id]); + } +}; + +} + +template +size_t IndexIDMapTemplate::remove_ids (const IDSelector & sel) +{ + // remove in sub-index first + IDTranslatedSelector sel2 (id_map, sel); + size_t nremove = index->remove_ids (sel2); + + int64_t j = 0; + for (idx_t i = 0; i < this->ntotal; i++) { + if (sel.is_member (id_map[i])) { + // remove + } else { + id_map[j] = id_map[i]; + j++; + } + } + FAISS_ASSERT (j == index->ntotal); + this->ntotal = j; + id_map.resize(this->ntotal); + return nremove; +} + +template +IndexIDMapTemplate::~IndexIDMapTemplate () +{ + if (own_fields) delete index; +} + + + +/***************************************************** + * IndexIDMap2 implementation + *******************************************************/ + +template +IndexIDMap2Template::IndexIDMap2Template (IndexT *index): + IndexIDMapTemplate (index) +{} + +template +void IndexIDMap2Template::add_with_ids + (idx_t n, const typename IndexT::component_t* x, + const typename IndexT::idx_t* xids) +{ + size_t prev_ntotal = this->ntotal; + IndexIDMapTemplate::add_with_ids (n, x, xids); + for (size_t i = prev_ntotal; i < this->ntotal; i++) { + rev_map [this->id_map [i]] = i; + } +} + +template +void IndexIDMap2Template::construct_rev_map () +{ + rev_map.clear (); + for (size_t i = 0; i < this->ntotal; i++) { + rev_map [this->id_map [i]] = i; + } +} + + +template +size_t IndexIDMap2Template::remove_ids(const IDSelector& sel) +{ + // This is quite inefficient + size_t nremove = IndexIDMapTemplate::remove_ids (sel); + construct_rev_map (); + return nremove; +} + +template +void IndexIDMap2Template::reconstruct + (idx_t key, typename IndexT::component_t * recons) const +{ + try { + this->index->reconstruct (rev_map.at (key), recons); + } catch (const std::out_of_range& e) { + FAISS_THROW_FMT ("key %ld not found", key); + } +} + + +// explicit template instantiations + +template struct IndexIDMapTemplate; +template struct IndexIDMapTemplate; +template struct IndexIDMap2Template; +template struct IndexIDMap2Template; + + +/***************************************************** + * IndexSplitVectors implementation + *******************************************************/ + + +IndexSplitVectors::IndexSplitVectors (idx_t d, bool threaded): + Index (d), own_fields (false), + threaded (threaded), sum_d (0) +{ + +} + +void IndexSplitVectors::add_sub_index (Index *index) +{ + sub_indexes.push_back (index); + sync_with_sub_indexes (); +} + +void IndexSplitVectors::sync_with_sub_indexes () +{ + if (sub_indexes.empty()) return; + Index * index0 = sub_indexes[0]; + sum_d = index0->d; + metric_type = index0->metric_type; + is_trained = index0->is_trained; + ntotal = index0->ntotal; + for (int i = 1; i < sub_indexes.size(); i++) { + Index * index = sub_indexes[i]; + FAISS_THROW_IF_NOT (metric_type == index->metric_type); + FAISS_THROW_IF_NOT (ntotal == index->ntotal); + sum_d += index->d; + } + +} + +void IndexSplitVectors::add(idx_t /*n*/, const float* /*x*/) { + FAISS_THROW_MSG("not implemented"); +} + + + +void IndexSplitVectors::search ( + idx_t n, const float *x, idx_t k, + float *distances, idx_t *labels, + ConcurrentBitsetPtr bitset) const +{ + FAISS_THROW_IF_NOT_MSG (k == 1, + "search implemented only for k=1"); + FAISS_THROW_IF_NOT_MSG (sum_d == d, + "not enough indexes compared to # dimensions"); + + int64_t nshard = sub_indexes.size(); + float *all_distances = new float [nshard * k * n]; + idx_t *all_labels = new idx_t [nshard * k * n]; + ScopeDeleter del (all_distances); + ScopeDeleter del2 (all_labels); + + auto query_func = [n, x, k, distances, labels, all_distances, all_labels, this] + (int no) { + const IndexSplitVectors *index = this; + float *distances1 = no == 0 ? distances : all_distances + no * k * n; + idx_t *labels1 = no == 0 ? labels : all_labels + no * k * n; + if (index->verbose) + printf ("begin query shard %d on %ld points\n", no, n); + const Index * sub_index = index->sub_indexes[no]; + int64_t sub_d = sub_index->d, d = index->d; + idx_t ofs = 0; + for (int i = 0; i < no; i++) ofs += index->sub_indexes[i]->d; + float *sub_x = new float [sub_d * n]; + ScopeDeleter del1 (sub_x); + for (idx_t i = 0; i < n; i++) + memcpy (sub_x + i * sub_d, x + ofs + i * d, sub_d * sizeof (sub_x)); + sub_index->search (n, sub_x, k, distances1, labels1); + if (index->verbose) + printf ("end query shard %d\n", no); + }; + + if (!threaded) { + for (int i = 0; i < nshard; i++) { + query_func(i); + } + } else { + std::vector > threads; + std::vector> v; + + for (int i = 0; i < nshard; i++) { + threads.emplace_back(new WorkerThread()); + WorkerThread *wt = threads.back().get(); + v.emplace_back(wt->add([i, query_func](){query_func(i); })); + } + + // Blocking wait for completion + for (auto& func : v) { + func.get(); + } + } + + int64_t factor = 1; + for (int i = 0; i < nshard; i++) { + if (i > 0) { // results of 0 are already in the table + const float *distances_i = all_distances + i * k * n; + const idx_t *labels_i = all_labels + i * k * n; + for (int64_t j = 0; j < n; j++) { + if (labels[j] >= 0 && labels_i[j] >= 0) { + labels[j] += labels_i[j] * factor; + distances[j] += distances_i[j]; + } else { + labels[j] = -1; + distances[j] = 0.0 / 0.0; + } + } + } + factor *= sub_indexes[i]->ntotal; + } + +} + +void IndexSplitVectors::train(idx_t /*n*/, const float* /*x*/) { + FAISS_THROW_MSG("not implemented"); +} + +void IndexSplitVectors::reset () +{ + FAISS_THROW_MSG ("not implemented"); +} + + +IndexSplitVectors::~IndexSplitVectors () +{ + if (own_fields) { + for (int s = 0; s < sub_indexes.size(); s++) + delete sub_indexes [s]; + } +} + + +} // namespace faiss diff --git a/core/src/index/thirdparty/faiss/MetaIndexes.h b/core/src/index/thirdparty/faiss/MetaIndexes.h new file mode 100644 index 0000000000..cfac6a5572 --- /dev/null +++ b/core/src/index/thirdparty/faiss/MetaIndexes.h @@ -0,0 +1,135 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +#ifndef META_INDEXES_H +#define META_INDEXES_H + +#include +#include +#include +#include +#include + +namespace faiss { + +/** Index that translates search results to ids */ +template +struct IndexIDMapTemplate : IndexT { + using idx_t = typename IndexT::idx_t; + using component_t = typename IndexT::component_t; + using distance_t = typename IndexT::distance_t; + + IndexT * index; ///! the sub-index + bool own_fields; ///! whether pointers are deleted in destructo + std::vector id_map; + + explicit IndexIDMapTemplate (IndexT *index); + + /// @param xids if non-null, ids to store for the vectors (size n) + void add_with_ids(idx_t n, const component_t* x, const idx_t* xids) override; + + /// this will fail. Use add_with_ids + void add(idx_t n, const component_t* x) override; + + void search( + idx_t n, const component_t* x, idx_t k, + distance_t* distances, idx_t* labels, + ConcurrentBitsetPtr bitset = nullptr) const override; + +#if 0 + void get_vector_by_id(idx_t n, const idx_t *xid, component_t *x, ConcurrentBitsetPtr bitset = nullptr) override; + + void search_by_id (idx_t n, const idx_t *xid, idx_t k, distance_t *distances, idx_t *labels, + ConcurrentBitsetPtr bitset = nullptr) override; +#endif + + void train(idx_t n, const component_t* x) override; + + void reset() override; + + /// remove ids adapted to IndexFlat + size_t remove_ids(const IDSelector& sel) override; + + void range_search (idx_t n, const component_t *x, distance_t radius, + RangeSearchResult *result, + ConcurrentBitsetPtr bitset = nullptr) const override; + + ~IndexIDMapTemplate () override; + IndexIDMapTemplate () {own_fields=false; index=nullptr; } +}; + +using IndexIDMap = IndexIDMapTemplate; +using IndexBinaryIDMap = IndexIDMapTemplate; + + +/** same as IndexIDMap but also provides an efficient reconstruction + * implementation via a 2-way index */ +template +struct IndexIDMap2Template : IndexIDMapTemplate { + using idx_t = typename IndexT::idx_t; + using component_t = typename IndexT::component_t; + using distance_t = typename IndexT::distance_t; + + std::unordered_map rev_map; + + explicit IndexIDMap2Template (IndexT *index); + + /// make the rev_map from scratch + void construct_rev_map (); + + void add_with_ids(idx_t n, const component_t* x, const idx_t* xids) override; + + size_t remove_ids(const IDSelector& sel) override; + + void reconstruct (idx_t key, component_t * recons) const override; + + ~IndexIDMap2Template() override {} + IndexIDMap2Template () {} +}; + +using IndexIDMap2 = IndexIDMap2Template; +using IndexBinaryIDMap2 = IndexIDMap2Template; + + +/** splits input vectors in segments and assigns each segment to a sub-index + * used to distribute a MultiIndexQuantizer + */ +struct IndexSplitVectors: Index { + bool own_fields; + bool threaded; + std::vector sub_indexes; + idx_t sum_d; /// sum of dimensions seen so far + + explicit IndexSplitVectors (idx_t d, bool threaded = false); + + void add_sub_index (Index *); + void sync_with_sub_indexes (); + + void add(idx_t n, const float* x) override; + + void search( + idx_t n, + const float* x, + idx_t k, + float* distances, + idx_t* labels, + ConcurrentBitsetPtr bitset = nullptr) const override; + + void train(idx_t n, const float* x) override; + + void reset() override; + + ~IndexSplitVectors() override; +}; + + +} // namespace faiss + + +#endif diff --git a/core/src/index/thirdparty/faiss/MetricType.h b/core/src/index/thirdparty/faiss/MetricType.h new file mode 100644 index 0000000000..5248f5b801 --- /dev/null +++ b/core/src/index/thirdparty/faiss/MetricType.h @@ -0,0 +1,41 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +#ifndef FAISS_METRIC_TYPE_H +#define FAISS_METRIC_TYPE_H + +namespace faiss { + +/// The metric space for vector comparison for Faiss indices and algorithms. +/// +/// Most algorithms support both inner product and L2, with the flat +/// (brute-force) indices supporting additional metric types for vector +/// comparison. +enum MetricType { + METRIC_INNER_PRODUCT = 0, ///< maximum inner product search + METRIC_L2 = 1, ///< squared L2 search + METRIC_L1, ///< L1 (aka cityblock) + METRIC_Linf, ///< infinity distance + METRIC_Lp, ///< L_p distance, p is given by a faiss::Index + /// metric_arg + METRIC_Jaccard, + METRIC_Tanimoto, + METRIC_Hamming, + METRIC_Substructure, ///< Tversky case alpha = 0, beta = 1 + METRIC_Superstructure, ///< Tversky case alpha = 1, beta = 0 + + /// some additional metrics defined in scipy.spatial.distance + METRIC_Canberra = 20, + METRIC_BrayCurtis, + METRIC_JensenShannon, +}; + +} + +#endif diff --git a/core/src/index/thirdparty/faiss/OnDiskInvertedLists.cpp b/core/src/index/thirdparty/faiss/OnDiskInvertedLists.cpp new file mode 100644 index 0000000000..2b798123d8 --- /dev/null +++ b/core/src/index/thirdparty/faiss/OnDiskInvertedLists.cpp @@ -0,0 +1,674 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +#include + +#include + +#include + +#include +#include +#include + +#include +#include + + +namespace faiss { + + +/********************************************** + * LockLevels + **********************************************/ + + +struct LockLevels { + /* There n times lock1(n), one lock2 and one lock3 + * Invariants: + * a single thread can hold one lock1(n) for some n + * a single thread can hold lock2, if it holds lock1(n) for some n + * a single thread can hold lock3, if it holds lock1(n) for some n + * AND lock2 AND no other thread holds lock1(m) for m != n + */ + pthread_mutex_t mutex1; + pthread_cond_t level1_cv; + pthread_cond_t level2_cv; + pthread_cond_t level3_cv; + + std::unordered_set level1_holders; // which level1 locks are held + int n_level2; // nb threads that wait on level2 + bool level3_in_use; // a threads waits on level3 + bool level2_in_use; + + LockLevels() { + pthread_mutex_init(&mutex1, nullptr); + pthread_cond_init(&level1_cv, nullptr); + pthread_cond_init(&level2_cv, nullptr); + pthread_cond_init(&level3_cv, nullptr); + n_level2 = 0; + level2_in_use = false; + level3_in_use = false; + } + + ~LockLevels() { + pthread_cond_destroy(&level1_cv); + pthread_cond_destroy(&level2_cv); + pthread_cond_destroy(&level3_cv); + pthread_mutex_destroy(&mutex1); + } + + void lock_1(int no) { + pthread_mutex_lock(&mutex1); + while (level3_in_use || level1_holders.count(no) > 0) { + pthread_cond_wait(&level1_cv, &mutex1); + } + level1_holders.insert(no); + pthread_mutex_unlock(&mutex1); + } + + void unlock_1(int no) { + pthread_mutex_lock(&mutex1); + assert(level1_holders.count(no) == 1); + level1_holders.erase(no); + if (level3_in_use) { // a writer is waiting + pthread_cond_signal(&level3_cv); + } else { + pthread_cond_broadcast(&level1_cv); + } + pthread_mutex_unlock(&mutex1); + } + + void lock_2() { + pthread_mutex_lock(&mutex1); + n_level2 ++; + if (level3_in_use) { // tell waiting level3 that we are blocked + pthread_cond_signal(&level3_cv); + } + while (level2_in_use) { + pthread_cond_wait(&level2_cv, &mutex1); + } + level2_in_use = true; + pthread_mutex_unlock(&mutex1); + } + + void unlock_2() { + pthread_mutex_lock(&mutex1); + level2_in_use = false; + n_level2 --; + pthread_cond_signal(&level2_cv); + pthread_mutex_unlock(&mutex1); + } + + void lock_3() { + pthread_mutex_lock(&mutex1); + level3_in_use = true; + // wait until there are no level1 holders anymore except the + // ones that are waiting on level2 (we are holding lock2) + while (level1_holders.size() > n_level2) { + pthread_cond_wait(&level3_cv, &mutex1); + } + // don't release the lock! + } + + void unlock_3() { + level3_in_use = false; + // wake up all level1_holders + pthread_cond_broadcast(&level1_cv); + pthread_mutex_unlock(&mutex1); + } + + void print () { + pthread_mutex_lock(&mutex1); + printf("State: level3_in_use=%d n_level2=%d level1_holders: [", level3_in_use, n_level2); + for (int k : level1_holders) { + printf("%d ", k); + } + printf("]\n"); + pthread_mutex_unlock(&mutex1); + } + +}; + +/********************************************** + * OngoingPrefetch + **********************************************/ + +struct OnDiskInvertedLists::OngoingPrefetch { + + struct Thread { + pthread_t pth; + OngoingPrefetch *pf; + + bool one_list () { + idx_t list_no = pf->get_next_list(); + if(list_no == -1) return false; + const OnDiskInvertedLists *od = pf->od; + od->locks->lock_1 (list_no); + size_t n = od->list_size (list_no); + const Index::idx_t *idx = od->get_ids (list_no); + const uint8_t *codes = od->get_codes (list_no); + int cs = 0; + for (size_t i = 0; i < n;i++) { + cs += idx[i]; + } + const idx_t *codes8 = (const idx_t*)codes; + idx_t n8 = n * od->code_size / 8; + + for (size_t i = 0; i < n8;i++) { + cs += codes8[i]; + } + od->locks->unlock_1(list_no); + + global_cs += cs & 1; + return true; + } + + }; + + std::vector threads; + + pthread_mutex_t list_ids_mutex; + std::vector list_ids; + int cur_list; + + // mutex for the list of tasks + pthread_mutex_t mutex; + + // pretext to avoid code below to be optimized out + static int global_cs; + + const OnDiskInvertedLists *od; + + explicit OngoingPrefetch (const OnDiskInvertedLists *od): od (od) + { + pthread_mutex_init (&mutex, nullptr); + pthread_mutex_init (&list_ids_mutex, nullptr); + cur_list = 0; + } + + static void* prefetch_list (void * arg) { + Thread *th = static_cast(arg); + + while (th->one_list()) ; + + return nullptr; + } + + idx_t get_next_list () { + idx_t list_no = -1; + pthread_mutex_lock (&list_ids_mutex); + if (cur_list >= 0 && cur_list < list_ids.size()) { + list_no = list_ids[cur_list++]; + } + pthread_mutex_unlock (&list_ids_mutex); + return list_no; + } + + void prefetch_lists (const idx_t *list_nos, int n) { + pthread_mutex_lock (&mutex); + pthread_mutex_lock (&list_ids_mutex); + list_ids.clear (); + pthread_mutex_unlock (&list_ids_mutex); + for (auto &th: threads) { + pthread_join (th.pth, nullptr); + } + + threads.resize (0); + cur_list = 0; + int nt = std::min (n, od->prefetch_nthread); + + if (nt > 0) { + // prepare tasks + for (int i = 0; i < n; i++) { + idx_t list_no = list_nos[i]; + if (list_no >= 0 && od->list_size(list_no) > 0) { + list_ids.push_back (list_no); + } + } + // prepare threads + threads.resize (nt); + for (Thread &th: threads) { + th.pf = this; + pthread_create (&th.pth, nullptr, prefetch_list, &th); + } + } + pthread_mutex_unlock (&mutex); + } + + ~OngoingPrefetch () { + pthread_mutex_lock (&mutex); + for (auto &th: threads) { + pthread_join (th.pth, nullptr); + } + pthread_mutex_unlock (&mutex); + pthread_mutex_destroy (&mutex); + pthread_mutex_destroy (&list_ids_mutex); + } + +}; + +int OnDiskInvertedLists::OngoingPrefetch::global_cs = 0; + + +void OnDiskInvertedLists::prefetch_lists (const idx_t *list_nos, int n) const +{ + pf->prefetch_lists (list_nos, n); +} + + + +/********************************************** + * OnDiskInvertedLists: mmapping + **********************************************/ + + +void OnDiskInvertedLists::do_mmap () +{ + const char *rw_flags = read_only ? "r" : "r+"; + int prot = read_only ? PROT_READ : PROT_WRITE | PROT_READ; + FILE *f = fopen (filename.c_str(), rw_flags); + FAISS_THROW_IF_NOT_FMT (f, "could not open %s in mode %s: %s", + filename.c_str(), rw_flags, strerror(errno)); + + uint8_t * ptro = (uint8_t*)mmap (nullptr, totsize, + prot, MAP_SHARED, fileno (f), 0); + + FAISS_THROW_IF_NOT_FMT (ptro != MAP_FAILED, + "could not mmap %s: %s", + filename.c_str(), + strerror(errno)); + ptr = ptro; + fclose (f); + +} + +void OnDiskInvertedLists::update_totsize (size_t new_size) +{ + + // unmap file + if (ptr != nullptr) { + int err = munmap (ptr, totsize); + FAISS_THROW_IF_NOT_FMT (err == 0, "munmap error: %s", + strerror(errno)); + } + if (totsize == 0) { + // must create file before truncating it + FILE *f = fopen (filename.c_str(), "w"); + FAISS_THROW_IF_NOT_FMT (f, "could not open %s in mode W: %s", + filename.c_str(), strerror(errno)); + fclose (f); + } + + if (new_size > totsize) { + if (!slots.empty() && + slots.back().offset + slots.back().capacity == totsize) { + slots.back().capacity += new_size - totsize; + } else { + slots.push_back (Slot(totsize, new_size - totsize)); + } + } else { + assert(!"not implemented"); + } + + totsize = new_size; + + // create file + printf ("resizing %s to %ld bytes\n", filename.c_str(), totsize); + + int err = truncate (filename.c_str(), totsize); + + FAISS_THROW_IF_NOT_FMT (err == 0, "truncate %s to %ld: %s", + filename.c_str(), totsize, + strerror(errno)); + do_mmap (); +} + + + + + + +/********************************************** + * OnDiskInvertedLists + **********************************************/ + +#define INVALID_OFFSET (size_t)(-1) + +OnDiskInvertedLists::List::List (): + size (0), capacity (0), offset (INVALID_OFFSET) +{} + +OnDiskInvertedLists::Slot::Slot (size_t offset, size_t capacity): + offset (offset), capacity (capacity) +{} + +OnDiskInvertedLists::Slot::Slot (): + offset (0), capacity (0) +{} + + + +OnDiskInvertedLists::OnDiskInvertedLists ( + size_t nlist, size_t code_size, + const char *filename): + InvertedLists (nlist, code_size), + filename (filename), + totsize (0), + ptr (nullptr), + read_only (false), + locks (new LockLevels ()), + pf (new OngoingPrefetch (this)), + prefetch_nthread (32) +{ + lists.resize (nlist); + + // slots starts empty +} + +OnDiskInvertedLists::OnDiskInvertedLists (): + OnDiskInvertedLists (0, 0, "") +{ +} + +OnDiskInvertedLists::~OnDiskInvertedLists () +{ + delete pf; + + // unmap all lists + if (ptr != nullptr) { + int err = munmap (ptr, totsize); + if (err != 0) { + fprintf(stderr, "mumap error: %s", + strerror(errno)); + } + } + delete locks; +} + + + + +size_t OnDiskInvertedLists::list_size(size_t list_no) const +{ + return lists[list_no].size; +} + + +const uint8_t * OnDiskInvertedLists::get_codes (size_t list_no) const +{ + if (lists[list_no].offset == INVALID_OFFSET) { + return nullptr; + } + + return ptr + lists[list_no].offset; +} + +const Index::idx_t * OnDiskInvertedLists::get_ids (size_t list_no) const +{ + if (lists[list_no].offset == INVALID_OFFSET) { + return nullptr; + } + + return (const idx_t*)(ptr + lists[list_no].offset + + code_size * lists[list_no].capacity); +} + + +void OnDiskInvertedLists::update_entries ( + size_t list_no, size_t offset, size_t n_entry, + const idx_t *ids_in, const uint8_t *codes_in) +{ + FAISS_THROW_IF_NOT (!read_only); + if (n_entry == 0) return; + const List & l = lists[list_no]; + assert (n_entry + offset <= l.size); + idx_t *ids = const_cast(get_ids (list_no)); + memcpy (ids + offset, ids_in, sizeof(ids_in[0]) * n_entry); + uint8_t *codes = const_cast(get_codes (list_no)); + memcpy (codes + offset * code_size, codes_in, code_size * n_entry); +} + +size_t OnDiskInvertedLists::add_entries ( + size_t list_no, size_t n_entry, + const idx_t* ids, const uint8_t *code) +{ + FAISS_THROW_IF_NOT (!read_only); + locks->lock_1 (list_no); + size_t o = list_size (list_no); + resize_locked (list_no, n_entry + o); + update_entries (list_no, o, n_entry, ids, code); + locks->unlock_1 (list_no); + return o; +} + +void OnDiskInvertedLists::resize (size_t list_no, size_t new_size) +{ + FAISS_THROW_IF_NOT (!read_only); + locks->lock_1 (list_no); + resize_locked (list_no, new_size); + locks->unlock_1 (list_no); +} + + + +void OnDiskInvertedLists::resize_locked (size_t list_no, size_t new_size) +{ + List & l = lists[list_no]; + + if (new_size <= l.capacity && + new_size > l.capacity / 2) { + l.size = new_size; + return; + } + + // otherwise we release the current slot, and find a new one + + locks->lock_2 (); + free_slot (l.offset, l.capacity); + + List new_l; + + if (new_size == 0) { + new_l = List(); + } else { + new_l.size = new_size; + new_l.capacity = 1; + while (new_l.capacity < new_size) { + new_l.capacity *= 2; + } + new_l.offset = allocate_slot ( + new_l.capacity * (sizeof(idx_t) + code_size)); + } + + // copy common data + if (l.offset != new_l.offset) { + size_t n = std::min (new_size, l.size); + if (n > 0) { + memcpy (ptr + new_l.offset, get_codes(list_no), n * code_size); + memcpy (ptr + new_l.offset + new_l.capacity * code_size, + get_ids (list_no), n * sizeof(idx_t)); + } + } + + lists[list_no] = new_l; + locks->unlock_2 (); +} + +size_t OnDiskInvertedLists::allocate_slot (size_t capacity) { + // should hold lock2 + + auto it = slots.begin(); + while (it != slots.end() && it->capacity < capacity) { + it++; + } + + if (it == slots.end()) { + // not enough capacity + size_t new_size = totsize == 0 ? 32 : totsize * 2; + while (new_size - totsize < capacity) + new_size *= 2; + locks->lock_3 (); + update_totsize(new_size); + locks->unlock_3 (); + it = slots.begin(); + while (it != slots.end() && it->capacity < capacity) { + it++; + } + assert (it != slots.end()); + } + + size_t o = it->offset; + if (it->capacity == capacity) { + slots.erase (it); + } else { + // take from beginning of slot + it->capacity -= capacity; + it->offset += capacity; + } + + return o; +} + + + +void OnDiskInvertedLists::free_slot (size_t offset, size_t capacity) { + + // should hold lock2 + if (capacity == 0) return; + + auto it = slots.begin(); + while (it != slots.end() && it->offset <= offset) { + it++; + } + + size_t inf = 1UL << 60; + + size_t end_prev = inf; + if (it != slots.begin()) { + auto prev = it; + prev--; + end_prev = prev->offset + prev->capacity; + } + + size_t begin_next = 1L << 60; + if (it != slots.end()) { + begin_next = it->offset; + } + + assert (end_prev == inf || offset >= end_prev); + assert (offset + capacity <= begin_next); + + if (offset == end_prev) { + auto prev = it; + prev--; + if (offset + capacity == begin_next) { + prev->capacity += capacity + it->capacity; + slots.erase (it); + } else { + prev->capacity += capacity; + } + } else { + if (offset + capacity == begin_next) { + it->offset -= capacity; + it->capacity += capacity; + } else { + slots.insert (it, Slot (offset, capacity)); + } + } + + // TODO shrink global storage if needed +} + + +/***************************************** + * Compact form + *****************************************/ + +size_t OnDiskInvertedLists::merge_from (const InvertedLists **ils, int n_il, + bool verbose) +{ + FAISS_THROW_IF_NOT_MSG (totsize == 0, "works only on an empty InvertedLists"); + + std::vector sizes (nlist); + for (int i = 0; i < n_il; i++) { + const InvertedLists *il = ils[i]; + FAISS_THROW_IF_NOT (il->nlist == nlist && il->code_size == code_size); + + for (size_t j = 0; j < nlist; j++) { + sizes [j] += il->list_size(j); + } + } + + size_t cums = 0; + size_t ntotal = 0; + for (size_t j = 0; j < nlist; j++) { + ntotal += sizes[j]; + lists[j].size = 0; + lists[j].capacity = sizes[j]; + lists[j].offset = cums; + cums += lists[j].capacity * (sizeof(idx_t) + code_size); + } + + update_totsize (cums); + + + size_t nmerged = 0; + double t0 = getmillisecs(), last_t = t0; + +#pragma omp parallel for + for (size_t j = 0; j < nlist; j++) { + List & l = lists[j]; + for (int i = 0; i < n_il; i++) { + const InvertedLists *il = ils[i]; + size_t n_entry = il->list_size(j); + l.size += n_entry; + update_entries (j, l.size - n_entry, n_entry, + ScopedIds(il, j).get(), + ScopedCodes(il, j).get()); + } + assert (l.size == l.capacity); + if (verbose) { +#pragma omp critical + { + nmerged++; + double t1 = getmillisecs(); + if (t1 - last_t > 500) { + printf("merged %ld lists in %.3f s\r", + nmerged, (t1 - t0) / 1000.0); + fflush(stdout); + last_t = t1; + } + } + } + } + if(verbose) { + printf("\n"); + } + + return ntotal; +} + + +void OnDiskInvertedLists::crop_invlists(size_t l0, size_t l1) +{ + FAISS_THROW_IF_NOT(0 <= l0 && l0 <= l1 && l1 <= nlist); + + std::vector new_lists (l1 - l0); + memcpy (new_lists.data(), &lists[l0], (l1 - l0) * sizeof(List)); + + lists.swap(new_lists); + + nlist = l1 - l0; +} + + + + +} // namespace faiss diff --git a/core/src/index/thirdparty/faiss/OnDiskInvertedLists.h b/core/src/index/thirdparty/faiss/OnDiskInvertedLists.h new file mode 100644 index 0000000000..3476b48ca9 --- /dev/null +++ b/core/src/index/thirdparty/faiss/OnDiskInvertedLists.h @@ -0,0 +1,127 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +#ifndef FAISS_ON_DISK_INVERTED_LISTS_H +#define FAISS_ON_DISK_INVERTED_LISTS_H + +#include +#include + +#include + +namespace faiss { + + +struct LockLevels; + +/** On-disk storage of inverted lists. + * + * The data is stored in a mmapped chunk of memory (base ptointer ptr, + * size totsize). Each list is a range of memory that contains (object + * List) that contains: + * + * - uint8_t codes[capacity * code_size] + * - followed by idx_t ids[capacity] + * + * in each of the arrays, the size <= capacity first elements are + * used, the rest is not initialized. + * + * Addition and resize are supported by: + * - roundind up the capacity of the lists to a power of two + * - maintaining a list of empty slots, sorted by size. + * - resizing the mmapped block is adjusted as needed. + * + * An OnDiskInvertedLists is compact if the size == capacity for all + * lists and there are no available slots. + * + * Addition to the invlists is slow. For incremental add it is better + * to use a default ArrayInvertedLists object and convert it to an + * OnDisk with merge_from. + * + * When it is known that a set of lists will be accessed, it is useful + * to call prefetch_lists, that launches a set of threads to read the + * lists in parallel. + */ +struct OnDiskInvertedLists: InvertedLists { + + struct List { + size_t size; // size of inverted list (entries) + size_t capacity; // allocated size (entries) + size_t offset; // offset in buffer (bytes) + List (); + }; + + // size nlist + std::vector lists; + + struct Slot { + size_t offset; // bytes + size_t capacity; // bytes + Slot (size_t offset, size_t capacity); + Slot (); + }; + + // size whatever space remains + std::list slots; + + std::string filename; + size_t totsize; + uint8_t *ptr; // mmap base pointer + bool read_only; /// are inverted lists mapped read-only + + OnDiskInvertedLists (size_t nlist, size_t code_size, + const char *filename); + + size_t list_size(size_t list_no) const override; + const uint8_t * get_codes (size_t list_no) const override; + const idx_t * get_ids (size_t list_no) const override; + + size_t add_entries ( + size_t list_no, size_t n_entry, + const idx_t* ids, const uint8_t *code) override; + + void update_entries (size_t list_no, size_t offset, size_t n_entry, + const idx_t *ids, const uint8_t *code) override; + + void resize (size_t list_no, size_t new_size) override; + + // copy all inverted lists into *this, in compact form (without + // allocating slots) + size_t merge_from (const InvertedLists **ils, int n_il, bool verbose=false); + + /// restrict the inverted lists to l0:l1 without touching the mmapped region + void crop_invlists(size_t l0, size_t l1); + + void prefetch_lists (const idx_t *list_nos, int nlist) const override; + + virtual ~OnDiskInvertedLists (); + + // private + + LockLevels * locks; + + // encapsulates the threads that are busy prefeteching + struct OngoingPrefetch; + OngoingPrefetch *pf; + int prefetch_nthread; + + void do_mmap (); + void update_totsize (size_t new_totsize); + void resize_locked (size_t list_no, size_t new_size); + size_t allocate_slot (size_t capacity); + void free_slot (size_t offset, size_t capacity); + + // empty constructor for the I/O functions + OnDiskInvertedLists (); +}; + + +} // namespace faiss + +#endif diff --git a/core/src/index/thirdparty/faiss/README.md b/core/src/index/thirdparty/faiss/README.md new file mode 100644 index 0000000000..299ad809da --- /dev/null +++ b/core/src/index/thirdparty/faiss/README.md @@ -0,0 +1,91 @@ +# Faiss + +Faiss is a library for efficient similarity search and clustering of dense vectors. It contains algorithms that search in sets of vectors of any size, up to ones that possibly do not fit in RAM. It also contains supporting code for evaluation and parameter tuning. Faiss is written in C++ with complete wrappers for Python/numpy. Some of the most useful algorithms are implemented on the GPU. It is developed by [Facebook AI Research](https://research.fb.com/category/facebook-ai-research-fair/). + +## NEWS + +*NEW: version 1.6.1 (2019-11-29) bugfix.* + +*NEW: version 1.6.0 (2019-10-15) code structure reorg, support for codec interface.* + +*NEW: version 1.5.3 (2019-06-24) fix performance regression in IndexIVF.* + +*NEW: version 1.5.2 (2019-05-27) the license was relaxed to MIT from BSD+Patents. Read LICENSE for details.* + +*NEW: version 1.5.0 (2018-12-19) GPU binary flat index and binary HNSW index* + +*NEW: version 1.4.0 (2018-08-30) no more crashes in pure Python code* + +*NEW: version 1.3.0 (2018-07-12) support for binary indexes* + +*NEW: latest commit (2018-02-22) supports on-disk storage of inverted indexes, see demos/demo_ondisk_ivf.py* + +*NEW: latest commit (2018-01-09) includes an implementation of the HNSW indexing method, see benchs/bench_hnsw.py* + +*NEW: there is now a Facebook public discussion group for Faiss users at https://www.facebook.com/groups/faissusers/* + +*NEW: on 2017-07-30, the license on Faiss was relaxed to BSD from CC-BY-NC. Read LICENSE for details.* + +## Introduction + +Faiss contains several methods for similarity search. It assumes that the instances are represented as vectors and are identified by an integer, and that the vectors can be compared with L2 (Euclidean) distances or dot products. Vectors that are similar to a query vector are those that have the lowest L2 distance or the highest dot product with the query vector. It also supports cosine similarity, since this is a dot product on normalized vectors. + +Most of the methods, like those based on binary vectors and compact quantization codes, solely use a compressed representation of the vectors and do not require to keep the original vectors. This generally comes at the cost of a less precise search but these methods can scale to billions of vectors in main memory on a single server. + +The GPU implementation can accept input from either CPU or GPU memory. On a server with GPUs, the GPU indexes can be used a drop-in replacement for the CPU indexes (e.g., replace `IndexFlatL2` with `GpuIndexFlatL2`) and copies to/from GPU memory are handled automatically. Results will be faster however if both input and output remain resident on the GPU. Both single and multi-GPU usage is supported. + +## Building + +The library is mostly implemented in C++, with optional GPU support provided via CUDA, and an optional Python interface. The CPU version requires a BLAS library. It compiles with a Makefile and can be packaged in a docker image. See [INSTALL.md](INSTALL.md) for details. + +## How Faiss works + +Faiss is built around an index type that stores a set of vectors, and provides a function to search in them with L2 and/or dot product vector comparison. Some index types are simple baselines, such as exact search. Most of the available indexing structures correspond to various trade-offs with respect to + +- search time +- search quality +- memory used per index vector +- training time +- need for external data for unsupervised training + +The optional GPU implementation provides what is likely (as of March 2017) the fastest exact and approximate (compressed-domain) nearest neighbor search implementation for high-dimensional vectors, fastest Lloyd's k-means, and fastest small k-selection algorithm known. [The implementation is detailed here](https://arxiv.org/abs/1702.08734). + +## Full documentation of Faiss + +The following are entry points for documentation: + +- the full documentation, including a [tutorial](https://github.com/facebookresearch/faiss/wiki/Getting-started), a [FAQ](https://github.com/facebookresearch/faiss/wiki/FAQ) and a [troubleshooting section](https://github.com/facebookresearch/faiss/wiki/Troubleshooting) can be found on the [wiki page](http://github.com/facebookresearch/faiss/wiki) +- the [doxygen documentation](http://rawgithub.com/facebookresearch/faiss/master/docs/html/annotated.html) gives per-class information +- to reproduce results from our research papers, [Polysemous codes](https://arxiv.org/abs/1609.01882) and [Billion-scale similarity search with GPUs](https://arxiv.org/abs/1702.08734), refer to the [benchmarks README](benchs/README.md). For [ +Link and code: Fast indexing with graphs and compact regression codes](https://arxiv.org/abs/1804.09996), see the [link_and_code README](benchs/link_and_code) + +## Authors + +The main authors of Faiss are: +- [Hervé Jégou](https://github.com/jegou) initiated the Faiss project and wrote its first implementation +- [Matthijs Douze](https://github.com/mdouze) implemented most of the CPU Faiss +- [Jeff Johnson](https://github.com/wickedfoo) implemented all of the GPU Faiss +- [Lucas Hosseini](https://github.com/beauby) implemented the binary indexes + +## Reference + +Reference to cite when you use Faiss in a research paper: + +``` +@article{JDH17, + title={Billion-scale similarity search with GPUs}, + author={Johnson, Jeff and Douze, Matthijs and J{\'e}gou, Herv{\'e}}, + journal={arXiv preprint arXiv:1702.08734}, + year={2017} +} +``` + +## Join the Faiss community + +For public discussion of Faiss or for questions, there is a Facebook public discussion group at https://www.facebook.com/groups/faissusers/ + +We monitor the [issues page](http://github.com/facebookresearch/faiss/issues) of the repository. You can report bugs, ask questions, etc. + +## License + +Faiss is MIT-licensed. diff --git a/core/src/index/thirdparty/faiss/VectorTransform.cpp b/core/src/index/thirdparty/faiss/VectorTransform.cpp new file mode 100644 index 0000000000..1a6d920171 --- /dev/null +++ b/core/src/index/thirdparty/faiss/VectorTransform.cpp @@ -0,0 +1,1158 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +#include + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +using namespace faiss; + + +extern "C" { + +// this is to keep the clang syntax checker happy +#ifndef FINTEGER +#define FINTEGER int +#endif + + +/* declare BLAS functions, see http://www.netlib.org/clapack/cblas/ */ + +int sgemm_ ( + const char *transa, const char *transb, FINTEGER *m, FINTEGER * + n, FINTEGER *k, const float *alpha, const float *a, + FINTEGER *lda, const float *b, + FINTEGER *ldb, float *beta, + float *c, FINTEGER *ldc); + +int dgemm_ ( + const char *transa, const char *transb, FINTEGER *m, FINTEGER * + n, FINTEGER *k, const double *alpha, const double *a, + FINTEGER *lda, const double *b, + FINTEGER *ldb, double *beta, + double *c, FINTEGER *ldc); + +int ssyrk_ ( + const char *uplo, const char *trans, FINTEGER *n, FINTEGER *k, + float *alpha, float *a, FINTEGER *lda, + float *beta, float *c, FINTEGER *ldc); + +/* Lapack functions from http://www.netlib.org/clapack/old/single/ */ + +int ssyev_ ( + const char *jobz, const char *uplo, FINTEGER *n, float *a, + FINTEGER *lda, float *w, float *work, FINTEGER *lwork, + FINTEGER *info); + +int dsyev_ ( + const char *jobz, const char *uplo, FINTEGER *n, double *a, + FINTEGER *lda, double *w, double *work, FINTEGER *lwork, + FINTEGER *info); + +int sgesvd_( + const char *jobu, const char *jobvt, FINTEGER *m, FINTEGER *n, + float *a, FINTEGER *lda, float *s, float *u, FINTEGER *ldu, float *vt, + FINTEGER *ldvt, float *work, FINTEGER *lwork, FINTEGER *info); + + +int dgesvd_( + const char *jobu, const char *jobvt, FINTEGER *m, FINTEGER *n, + double *a, FINTEGER *lda, double *s, double *u, FINTEGER *ldu, double *vt, + FINTEGER *ldvt, double *work, FINTEGER *lwork, FINTEGER *info); + +} + +/********************************************* + * VectorTransform + *********************************************/ + + + +float * VectorTransform::apply (Index::idx_t n, const float * x) const +{ + float * xt = new float[n * d_out]; + apply_noalloc (n, x, xt); + return xt; +} + + +void VectorTransform::train (idx_t, const float *) { + // does nothing by default +} + + +void VectorTransform::reverse_transform ( + idx_t , const float *, + float *) const +{ + FAISS_THROW_MSG ("reverse transform not implemented"); +} + + + + +/********************************************* + * LinearTransform + *********************************************/ + +/// both d_in > d_out and d_out < d_in are supported +LinearTransform::LinearTransform (int d_in, int d_out, + bool have_bias): + VectorTransform (d_in, d_out), have_bias (have_bias), + is_orthonormal (false), verbose (false) +{ + is_trained = false; // will be trained when A and b are initialized +} + +void LinearTransform::apply_noalloc (Index::idx_t n, const float * x, + float * xt) const +{ + FAISS_THROW_IF_NOT_MSG(is_trained, "Transformation not trained yet"); + + float c_factor; + if (have_bias) { + FAISS_THROW_IF_NOT_MSG (b.size() == d_out, "Bias not initialized"); + float * xi = xt; + for (int i = 0; i < n; i++) + for(int j = 0; j < d_out; j++) + *xi++ = b[j]; + c_factor = 1.0; + } else { + c_factor = 0.0; + } + + FAISS_THROW_IF_NOT_MSG (A.size() == d_out * d_in, + "Transformation matrix not initialized"); + + float one = 1; + FINTEGER nbiti = d_out, ni = n, di = d_in; + sgemm_ ("Transposed", "Not transposed", + &nbiti, &ni, &di, + &one, A.data(), &di, x, &di, &c_factor, xt, &nbiti); + +} + + +void LinearTransform::transform_transpose (idx_t n, const float * y, + float *x) const +{ + if (have_bias) { // allocate buffer to store bias-corrected data + float *y_new = new float [n * d_out]; + const float *yr = y; + float *yw = y_new; + for (idx_t i = 0; i < n; i++) { + for (int j = 0; j < d_out; j++) { + *yw++ = *yr++ - b [j]; + } + } + y = y_new; + } + + { + FINTEGER dii = d_in, doi = d_out, ni = n; + float one = 1.0, zero = 0.0; + sgemm_ ("Not", "Not", &dii, &ni, &doi, + &one, A.data (), &dii, y, &doi, &zero, x, &dii); + } + + if (have_bias) delete [] y; +} + +void LinearTransform::set_is_orthonormal () +{ + if (d_out > d_in) { + // not clear what we should do in this case + is_orthonormal = false; + return; + } + if (d_out == 0) { // borderline case, unnormalized matrix + is_orthonormal = true; + return; + } + + double eps = 4e-5; + FAISS_ASSERT(A.size() >= d_out * d_in); + { + std::vector ATA(d_out * d_out); + FINTEGER dii = d_in, doi = d_out; + float one = 1.0, zero = 0.0; + + sgemm_ ("Transposed", "Not", &doi, &doi, &dii, + &one, A.data (), &dii, + A.data(), &dii, + &zero, ATA.data(), &doi); + + is_orthonormal = true; + for (long i = 0; i < d_out; i++) { + for (long j = 0; j < d_out; j++) { + float v = ATA[i + j * d_out]; + if (i == j) v-= 1; + if (fabs(v) > eps) { + is_orthonormal = false; + } + } + } + } + +} + + +void LinearTransform::reverse_transform (idx_t n, const float * xt, + float *x) const +{ + if (is_orthonormal) { + transform_transpose (n, xt, x); + } else { + FAISS_THROW_MSG ("reverse transform not implemented for non-orthonormal matrices"); + } +} + + +void LinearTransform::print_if_verbose ( + const char*name, const std::vector &mat, + int n, int d) const +{ + if (!verbose) return; + printf("matrix %s: %d*%d [\n", name, n, d); + FAISS_THROW_IF_NOT (mat.size() >= n * d); + for (int i = 0; i < n; i++) { + for (int j = 0; j < d; j++) { + printf("%10.5g ", mat[i * d + j]); + } + printf("\n"); + } + printf("]\n"); +} + +/********************************************* + * RandomRotationMatrix + *********************************************/ + +void RandomRotationMatrix::init (int seed) +{ + + if(d_out <= d_in) { + A.resize (d_out * d_in); + float *q = A.data(); + float_randn(q, d_out * d_in, seed); + matrix_qr(d_in, d_out, q); + } else { + // use tight-frame transformation + A.resize (d_out * d_out); + float *q = A.data(); + float_randn(q, d_out * d_out, seed); + matrix_qr(d_out, d_out, q); + // remove columns + int i, j; + for (i = 0; i < d_out; i++) { + for(j = 0; j < d_in; j++) { + q[i * d_in + j] = q[i * d_out + j]; + } + } + A.resize(d_in * d_out); + } + is_orthonormal = true; + is_trained = true; +} + +void RandomRotationMatrix::train (Index::idx_t /*n*/, const float */*x*/) +{ + // initialize with some arbitrary seed + init (12345); +} + + +/********************************************* + * PCAMatrix + *********************************************/ + +PCAMatrix::PCAMatrix (int d_in, int d_out, + float eigen_power, bool random_rotation): + LinearTransform(d_in, d_out, true), + eigen_power(eigen_power), random_rotation(random_rotation) +{ + is_trained = false; + max_points_per_d = 1000; + balanced_bins = 0; +} + + +namespace { + +/// Compute the eigenvalue decomposition of symmetric matrix cov, +/// dimensions d_in-by-d_in. Output eigenvectors in cov. + +void eig(size_t d_in, double *cov, double *eigenvalues, int verbose) +{ + { // compute eigenvalues and vectors + FINTEGER info = 0, lwork = -1, di = d_in; + double workq; + + dsyev_ ("Vectors as well", "Upper", + &di, cov, &di, eigenvalues, &workq, &lwork, &info); + lwork = FINTEGER(workq); + double *work = new double[lwork]; + + dsyev_ ("Vectors as well", "Upper", + &di, cov, &di, eigenvalues, work, &lwork, &info); + + delete [] work; + + if (info != 0) { + fprintf (stderr, "WARN ssyev info returns %d, " + "a very bad PCA matrix is learnt\n", + int(info)); + // do not throw exception, as the matrix could still be useful + } + + + if(verbose && d_in <= 10) { + printf("info=%ld new eigvals=[", long(info)); + for(int j = 0; j < d_in; j++) printf("%g ", eigenvalues[j]); + printf("]\n"); + + double *ci = cov; + printf("eigenvecs=\n"); + for(int i = 0; i < d_in; i++) { + for(int j = 0; j < d_in; j++) + printf("%10.4g ", *ci++); + printf("\n"); + } + } + + } + + // revert order of eigenvectors & values + + for(int i = 0; i < d_in / 2; i++) { + + std::swap(eigenvalues[i], eigenvalues[d_in - 1 - i]); + double *v1 = cov + i * d_in; + double *v2 = cov + (d_in - 1 - i) * d_in; + for(int j = 0; j < d_in; j++) + std::swap(v1[j], v2[j]); + } + +} + + +} + +void PCAMatrix::train (Index::idx_t n, const float *x) +{ + const float * x_in = x; + + x = fvecs_maybe_subsample (d_in, (size_t*)&n, + max_points_per_d * d_in, x, verbose); + + ScopeDeleter del_x (x != x_in ? x : nullptr); + + // compute mean + mean.clear(); mean.resize(d_in, 0.0); + if (have_bias) { // we may want to skip the bias + const float *xi = x; + for (int i = 0; i < n; i++) { + for(int j = 0; j < d_in; j++) + mean[j] += *xi++; + } + for(int j = 0; j < d_in; j++) + mean[j] /= n; + } + if(verbose) { + printf("mean=["); + for(int j = 0; j < d_in; j++) printf("%g ", mean[j]); + printf("]\n"); + } + + if(n >= d_in) { + // compute covariance matrix, store it in PCA matrix + PCAMat.resize(d_in * d_in); + float * cov = PCAMat.data(); + { // initialize with mean * mean^T term + float *ci = cov; + for(int i = 0; i < d_in; i++) { + for(int j = 0; j < d_in; j++) + *ci++ = - n * mean[i] * mean[j]; + } + } + { + FINTEGER di = d_in, ni = n; + float one = 1.0; + ssyrk_ ("Up", "Non transposed", + &di, &ni, &one, (float*)x, &di, &one, cov, &di); + + } + if(verbose && d_in <= 10) { + float *ci = cov; + printf("cov=\n"); + for(int i = 0; i < d_in; i++) { + for(int j = 0; j < d_in; j++) + printf("%10g ", *ci++); + printf("\n"); + } + } + + std::vector covd (d_in * d_in); + for (size_t i = 0; i < d_in * d_in; i++) covd [i] = cov [i]; + + std::vector eigenvaluesd (d_in); + + eig (d_in, covd.data (), eigenvaluesd.data (), verbose); + + for (size_t i = 0; i < d_in * d_in; i++) PCAMat [i] = covd [i]; + eigenvalues.resize (d_in); + + for (size_t i = 0; i < d_in; i++) + eigenvalues [i] = eigenvaluesd [i]; + + + } else { + + std::vector xc (n * d_in); + + for (size_t i = 0; i < n; i++) + for(size_t j = 0; j < d_in; j++) + xc [i * d_in + j] = x [i * d_in + j] - mean[j]; + + // compute Gram matrix + std::vector gram (n * n); + { + FINTEGER di = d_in, ni = n; + float one = 1.0, zero = 0.0; + ssyrk_ ("Up", "Transposed", + &ni, &di, &one, xc.data(), &di, &zero, gram.data(), &ni); + } + + if(verbose && d_in <= 10) { + float *ci = gram.data(); + printf("gram=\n"); + for(int i = 0; i < n; i++) { + for(int j = 0; j < n; j++) + printf("%10g ", *ci++); + printf("\n"); + } + } + + std::vector gramd (n * n); + for (size_t i = 0; i < n * n; i++) + gramd [i] = gram [i]; + + std::vector eigenvaluesd (n); + + // eig will fill in only the n first eigenvals + + eig (n, gramd.data (), eigenvaluesd.data (), verbose); + + PCAMat.resize(d_in * n); + + for (size_t i = 0; i < n * n; i++) + gram [i] = gramd [i]; + + eigenvalues.resize (d_in); + // fill in only the n first ones + for (size_t i = 0; i < n; i++) + eigenvalues [i] = eigenvaluesd [i]; + + { // compute PCAMat = x' * v + FINTEGER di = d_in, ni = n; + float one = 1.0; + + sgemm_ ("Non", "Non Trans", + &di, &ni, &ni, + &one, xc.data(), &di, gram.data(), &ni, + &one, PCAMat.data(), &di); + } + + if(verbose && d_in <= 10) { + float *ci = PCAMat.data(); + printf("PCAMat=\n"); + for(int i = 0; i < n; i++) { + for(int j = 0; j < d_in; j++) + printf("%10g ", *ci++); + printf("\n"); + } + } + fvec_renorm_L2 (d_in, n, PCAMat.data()); + + } + + prepare_Ab(); + is_trained = true; +} + +void PCAMatrix::copy_from (const PCAMatrix & other) +{ + FAISS_THROW_IF_NOT (other.is_trained); + mean = other.mean; + eigenvalues = other.eigenvalues; + PCAMat = other.PCAMat; + prepare_Ab (); + is_trained = true; +} + +void PCAMatrix::prepare_Ab () +{ + FAISS_THROW_IF_NOT_FMT ( + d_out * d_in <= PCAMat.size(), + "PCA matrix cannot output %d dimensions from %d ", + d_out, d_in); + + if (!random_rotation) { + A = PCAMat; + A.resize(d_out * d_in); // strip off useless dimensions + + // first scale the components + if (eigen_power != 0) { + float *ai = A.data(); + for (int i = 0; i < d_out; i++) { + float factor = pow(eigenvalues[i], eigen_power); + for(int j = 0; j < d_in; j++) + *ai++ *= factor; + } + } + + if (balanced_bins != 0) { + FAISS_THROW_IF_NOT (d_out % balanced_bins == 0); + int dsub = d_out / balanced_bins; + std::vector Ain; + std::swap(A, Ain); + A.resize(d_out * d_in); + + std::vector accu(balanced_bins); + std::vector counter(balanced_bins); + + // greedy assignment + for (int i = 0; i < d_out; i++) { + // find best bin + int best_j = -1; + float min_w = 1e30; + for (int j = 0; j < balanced_bins; j++) { + if (counter[j] < dsub && accu[j] < min_w) { + min_w = accu[j]; + best_j = j; + } + } + int row_dst = best_j * dsub + counter[best_j]; + accu[best_j] += eigenvalues[i]; + counter[best_j] ++; + memcpy (&A[row_dst * d_in], &Ain[i * d_in], + d_in * sizeof (A[0])); + } + + if (verbose) { + printf(" bin accu=["); + for (int i = 0; i < balanced_bins; i++) + printf("%g ", accu[i]); + printf("]\n"); + } + } + + + } else { + FAISS_THROW_IF_NOT_MSG (balanced_bins == 0, + "both balancing bins and applying a random rotation " + "does not make sense"); + RandomRotationMatrix rr(d_out, d_out); + + rr.init(5); + + // apply scaling on the rotation matrix (right multiplication) + if (eigen_power != 0) { + for (int i = 0; i < d_out; i++) { + float factor = pow(eigenvalues[i], eigen_power); + for(int j = 0; j < d_out; j++) + rr.A[j * d_out + i] *= factor; + } + } + + A.resize(d_in * d_out); + { + FINTEGER dii = d_in, doo = d_out; + float one = 1.0, zero = 0.0; + + sgemm_ ("Not", "Not", &dii, &doo, &doo, + &one, PCAMat.data(), &dii, rr.A.data(), &doo, &zero, + A.data(), &dii); + + } + + } + + b.clear(); b.resize(d_out); + + for (int i = 0; i < d_out; i++) { + float accu = 0; + for (int j = 0; j < d_in; j++) + accu -= mean[j] * A[j + i * d_in]; + b[i] = accu; + } + + is_orthonormal = eigen_power == 0; + +} + +/********************************************* + * ITQMatrix + *********************************************/ + +ITQMatrix::ITQMatrix (int d): + LinearTransform(d, d, false), + max_iter (50), + seed (123) +{ +} + + +/** translated from fbcode/deeplearning/catalyzer/catalyzer/quantizers.py */ +void ITQMatrix::train (Index::idx_t n, const float* xf) +{ + size_t d = d_in; + std::vector rotation (d * d); + + if (init_rotation.size() == d * d) { + memcpy (rotation.data(), init_rotation.data(), + d * d * sizeof(rotation[0])); + } else { + RandomRotationMatrix rrot (d, d); + rrot.init (seed); + for (size_t i = 0; i < d * d; i++) { + rotation[i] = rrot.A[i]; + } + } + + std::vector x (n * d); + + for (size_t i = 0; i < n * d; i++) { + x[i] = xf[i]; + } + + std::vector rotated_x (n * d), cov_mat (d * d); + std::vector u (d * d), vt (d * d), singvals (d); + + for (int i = 0; i < max_iter; i++) { + print_if_verbose ("rotation", rotation, d, d); + { // rotated_data = np.dot(training_data, rotation) + FINTEGER di = d, ni = n; + double one = 1, zero = 0; + dgemm_ ("N", "N", &di, &ni, &di, + &one, rotation.data(), &di, x.data(), &di, + &zero, rotated_x.data(), &di); + } + print_if_verbose ("rotated_x", rotated_x, n, d); + // binarize + for (size_t j = 0; j < n * d; j++) { + rotated_x[j] = rotated_x[j] < 0 ? -1 : 1; + } + // covariance matrix + { // rotated_data = np.dot(training_data, rotation) + FINTEGER di = d, ni = n; + double one = 1, zero = 0; + dgemm_ ("N", "T", &di, &di, &ni, + &one, rotated_x.data(), &di, x.data(), &di, + &zero, cov_mat.data(), &di); + } + print_if_verbose ("cov_mat", cov_mat, d, d); + // SVD + { + + FINTEGER di = d; + FINTEGER lwork = -1, info; + double lwork1; + + // workspace query + dgesvd_ ("A", "A", &di, &di, cov_mat.data(), &di, + singvals.data(), u.data(), &di, + vt.data(), &di, + &lwork1, &lwork, &info); + + FAISS_THROW_IF_NOT (info == 0); + lwork = size_t (lwork1); + std::vector work (lwork); + dgesvd_ ("A", "A", &di, &di, cov_mat.data(), &di, + singvals.data(), u.data(), &di, + vt.data(), &di, + work.data(), &lwork, &info); + FAISS_THROW_IF_NOT_FMT (info == 0, "sgesvd returned info=%d", info); + + } + print_if_verbose ("u", u, d, d); + print_if_verbose ("vt", vt, d, d); + // update rotation + { + FINTEGER di = d; + double one = 1, zero = 0; + dgemm_ ("N", "T", &di, &di, &di, + &one, u.data(), &di, vt.data(), &di, + &zero, rotation.data(), &di); + } + print_if_verbose ("final rot", rotation, d, d); + + } + A.resize (d * d); + for (size_t i = 0; i < d; i++) { + for (size_t j = 0; j < d; j++) { + A[i + d * j] = rotation[j + d * i]; + } + } + is_trained = true; + +} + +ITQTransform::ITQTransform (int d_in, int d_out, bool do_pca): + VectorTransform (d_in, d_out), + do_pca (do_pca), + itq (d_out), + pca_then_itq (d_in, d_out, false) +{ + if (!do_pca) { + FAISS_THROW_IF_NOT (d_in == d_out); + } + max_train_per_dim = 10; + is_trained = false; +} + + + + +void ITQTransform::train (idx_t n, const float *x) +{ + FAISS_THROW_IF_NOT (!is_trained); + + const float * x_in = x; + size_t max_train_points = std::max(d_in * max_train_per_dim, 32768); + x = fvecs_maybe_subsample (d_in, (size_t*)&n, max_train_points, x); + + ScopeDeleter del_x (x != x_in ? x : nullptr); + + std::unique_ptr x_norm(new float[n * d_in]); + { // normalize + int d = d_in; + + mean.resize (d, 0); + for (idx_t i = 0; i < n; i++) { + for (idx_t j = 0; j < d; j++) { + mean[j] += x[i * d + j]; + } + } + for (idx_t j = 0; j < d; j++) { + mean[j] /= n; + } + for (idx_t i = 0; i < n; i++) { + for (idx_t j = 0; j < d; j++) { + x_norm[i * d + j] = x[i * d + j] - mean[j]; + } + } + fvec_renorm_L2 (d_in, n, x_norm.get()); + } + + // train PCA + + PCAMatrix pca (d_in, d_out); + float *x_pca; + std::unique_ptr x_pca_del; + if (do_pca) { + pca.have_bias = false; // for consistency with reference implem + pca.train (n, x_norm.get()); + x_pca = pca.apply (n, x_norm.get()); + x_pca_del.reset(x_pca); + } else { + x_pca = x_norm.get(); + } + + // train ITQ + itq.train (n, x_pca); + + // merge PCA and ITQ + if (do_pca) { + FINTEGER di = d_out, dini = d_in; + float one = 1, zero = 0; + pca_then_itq.A.resize(d_in * d_out); + sgemm_ ("N", "N", &dini, &di, &di, + &one, pca.A.data(), &dini, + itq.A.data(), &di, + &zero, pca_then_itq.A.data(), &dini); + } else { + pca_then_itq.A = itq.A; + } + pca_then_itq.is_trained = true; + is_trained = true; +} + +void ITQTransform::apply_noalloc (Index::idx_t n, const float * x, + float * xt) const +{ + FAISS_THROW_IF_NOT_MSG(is_trained, "Transformation not trained yet"); + + std::unique_ptr x_norm(new float[n * d_in]); + { // normalize + int d = d_in; + for (idx_t i = 0; i < n; i++) { + for (idx_t j = 0; j < d; j++) { + x_norm[i * d + j] = x[i * d + j] - mean[j]; + } + } + // this is not really useful if we are going to binarize right + // afterwards but OK + fvec_renorm_L2 (d_in, n, x_norm.get()); + } + + pca_then_itq.apply_noalloc (n, x_norm.get(), xt); +} + +/********************************************* + * OPQMatrix + *********************************************/ + + +OPQMatrix::OPQMatrix (int d, int M, int d2): + LinearTransform (d, d2 == -1 ? d : d2, false), M(M), + niter (50), + niter_pq (4), niter_pq_0 (40), + verbose(false), + pq(nullptr) +{ + is_trained = false; + // OPQ is quite expensive to train, so set this right. + max_train_points = 256 * 256; + pq = nullptr; +} + + + +void OPQMatrix::train (Index::idx_t n, const float *x) +{ + + const float * x_in = x; + + x = fvecs_maybe_subsample (d_in, (size_t*)&n, + max_train_points, x, verbose); + + ScopeDeleter del_x (x != x_in ? x : nullptr); + + // To support d_out > d_in, we pad input vectors with 0s to d_out + size_t d = d_out <= d_in ? d_in : d_out; + size_t d2 = d_out; + +#if 0 + // what this test shows: the only way of getting bit-exact + // reproducible results with sgeqrf and sgesvd seems to be forcing + // single-threading. + { // test repro + std::vector r (d * d); + float * rotation = r.data(); + float_randn (rotation, d * d, 1234); + printf("CS0: %016lx\n", + ivec_checksum (128*128, (int*)rotation)); + matrix_qr (d, d, rotation); + printf("CS1: %016lx\n", + ivec_checksum (128*128, (int*)rotation)); + return; + } +#endif + + if (verbose) { + printf ("OPQMatrix::train: training an OPQ rotation matrix " + "for M=%d from %ld vectors in %dD -> %dD\n", + M, n, d_in, d_out); + } + + std::vector xtrain (n * d); + // center x + { + std::vector sum (d); + const float *xi = x; + for (size_t i = 0; i < n; i++) { + for (int j = 0; j < d_in; j++) + sum [j] += *xi++; + } + for (int i = 0; i < d; i++) sum[i] /= n; + float *yi = xtrain.data(); + xi = x; + for (size_t i = 0; i < n; i++) { + for (int j = 0; j < d_in; j++) + *yi++ = *xi++ - sum[j]; + yi += d - d_in; + } + } + float *rotation; + + if (A.size () == 0) { + A.resize (d * d); + rotation = A.data(); + if (verbose) + printf(" OPQMatrix::train: making random %ld*%ld rotation\n", + d, d); + float_randn (rotation, d * d, 1234); + matrix_qr (d, d, rotation); + // we use only the d * d2 upper part of the matrix + A.resize (d * d2); + } else { + FAISS_THROW_IF_NOT (A.size() == d * d2); + rotation = A.data(); + } + + std::vector + xproj (d2 * n), pq_recons (d2 * n), xxr (d * n), + tmp(d * d * 4); + + + ProductQuantizer pq_default (d2, M, 8); + ProductQuantizer &pq_regular = pq ? *pq : pq_default; + std::vector codes (pq_regular.code_size * n); + + double t0 = getmillisecs(); + for (int iter = 0; iter < niter; iter++) { + + { // torch.mm(xtrain, rotation:t()) + FINTEGER di = d, d2i = d2, ni = n; + float zero = 0, one = 1; + sgemm_ ("Transposed", "Not transposed", + &d2i, &ni, &di, + &one, rotation, &di, + xtrain.data(), &di, + &zero, xproj.data(), &d2i); + } + + pq_regular.cp.max_points_per_centroid = 1000; + pq_regular.cp.niter = iter == 0 ? niter_pq_0 : niter_pq; + pq_regular.verbose = verbose; + pq_regular.train (n, xproj.data()); + + if (verbose) { + printf(" encode / decode\n"); + } + if (pq_regular.assign_index) { + pq_regular.compute_codes_with_assign_index + (xproj.data(), codes.data(), n); + } else { + pq_regular.compute_codes (xproj.data(), codes.data(), n); + } + pq_regular.decode (codes.data(), pq_recons.data(), n); + + float pq_err = fvec_L2sqr (pq_recons.data(), xproj.data(), n * d2) / n; + + if (verbose) + printf (" Iteration %d (%d PQ iterations):" + "%.3f s, obj=%g\n", iter, pq_regular.cp.niter, + (getmillisecs () - t0) / 1000.0, pq_err); + + { + float *u = tmp.data(), *vt = &tmp [d * d]; + float *sing_val = &tmp [2 * d * d]; + FINTEGER di = d, d2i = d2, ni = n; + float one = 1, zero = 0; + + if (verbose) { + printf(" X * recons\n"); + } + // torch.mm(xtrain:t(), pq_recons) + sgemm_ ("Not", "Transposed", + &d2i, &di, &ni, + &one, pq_recons.data(), &d2i, + xtrain.data(), &di, + &zero, xxr.data(), &d2i); + + + FINTEGER lwork = -1, info = -1; + float worksz; + // workspace query + sgesvd_ ("All", "All", + &d2i, &di, xxr.data(), &d2i, + sing_val, + vt, &d2i, u, &di, + &worksz, &lwork, &info); + + lwork = int(worksz); + std::vector work (lwork); + // u and vt swapped + sgesvd_ ("All", "All", + &d2i, &di, xxr.data(), &d2i, + sing_val, + vt, &d2i, u, &di, + work.data(), &lwork, &info); + + sgemm_ ("Transposed", "Transposed", + &di, &d2i, &d2i, + &one, u, &di, vt, &d2i, + &zero, rotation, &di); + + } + pq_regular.train_type = ProductQuantizer::Train_hot_start; + } + + // revert A matrix + if (d > d_in) { + for (long i = 0; i < d_out; i++) + memmove (&A[i * d_in], &A[i * d], sizeof(A[0]) * d_in); + A.resize (d_in * d_out); + } + + is_trained = true; + is_orthonormal = true; +} + + +/********************************************* + * NormalizationTransform + *********************************************/ + +NormalizationTransform::NormalizationTransform (int d, float norm): + VectorTransform (d, d), norm (norm) +{ +} + +NormalizationTransform::NormalizationTransform (): + VectorTransform (-1, -1), norm (-1) +{ +} + +void NormalizationTransform::apply_noalloc + (idx_t n, const float* x, float* xt) const +{ + if (norm == 2.0) { + memcpy (xt, x, sizeof (x[0]) * n * d_in); + fvec_renorm_L2 (d_in, n, xt); + } else { + FAISS_THROW_MSG ("not implemented"); + } +} + +void NormalizationTransform::reverse_transform (idx_t n, const float* xt, + float* x) const +{ + memcpy (x, xt, sizeof (xt[0]) * n * d_in); +} + +/********************************************* + * CenteringTransform + *********************************************/ + +CenteringTransform::CenteringTransform (int d): + VectorTransform (d, d) +{ + is_trained = false; +} + +void CenteringTransform::train(Index::idx_t n, const float *x) { + FAISS_THROW_IF_NOT_MSG(n > 0, "need at least one training vector"); + mean.resize (d_in, 0); + for (idx_t i = 0; i < n; i++) { + for (size_t j = 0; j < d_in; j++) { + mean[j] += *x++; + } + } + + for (size_t j = 0; j < d_in; j++) { + mean[j] /= n; + } + is_trained = true; +} + + +void CenteringTransform::apply_noalloc + (idx_t n, const float* x, float* xt) const +{ + FAISS_THROW_IF_NOT (is_trained); + + for (idx_t i = 0; i < n; i++) { + for (size_t j = 0; j < d_in; j++) { + *xt++ = *x++ - mean[j]; + } + } +} + +void CenteringTransform::reverse_transform (idx_t n, const float* xt, + float* x) const +{ + FAISS_THROW_IF_NOT (is_trained); + + for (idx_t i = 0; i < n; i++) { + for (size_t j = 0; j < d_in; j++) { + *x++ = *xt++ + mean[j]; + } + } + +} + + + + + +/********************************************* + * RemapDimensionsTransform + *********************************************/ + + +RemapDimensionsTransform::RemapDimensionsTransform ( + int d_in, int d_out, const int *map_in): + VectorTransform (d_in, d_out) +{ + map.resize (d_out); + for (int i = 0; i < d_out; i++) { + map[i] = map_in[i]; + FAISS_THROW_IF_NOT (map[i] == -1 || (map[i] >= 0 && map[i] < d_in)); + } +} + +RemapDimensionsTransform::RemapDimensionsTransform ( + int d_in, int d_out, bool uniform): VectorTransform (d_in, d_out) +{ + map.resize (d_out, -1); + + if (uniform) { + if (d_in < d_out) { + for (int i = 0; i < d_in; i++) { + map [i * d_out / d_in] = i; + } + } else { + for (int i = 0; i < d_out; i++) { + map [i] = i * d_in / d_out; + } + } + } else { + for (int i = 0; i < d_in && i < d_out; i++) + map [i] = i; + } +} + + +void RemapDimensionsTransform::apply_noalloc (idx_t n, const float * x, + float *xt) const +{ + for (idx_t i = 0; i < n; i++) { + for (int j = 0; j < d_out; j++) { + xt[j] = map[j] < 0 ? 0 : x[map[j]]; + } + x += d_in; + xt += d_out; + } +} + +void RemapDimensionsTransform::reverse_transform (idx_t n, const float * xt, + float *x) const +{ + memset (x, 0, sizeof (*x) * n * d_in); + for (idx_t i = 0; i < n; i++) { + for (int j = 0; j < d_out; j++) { + if (map[j] >= 0) x[map[j]] = xt[j]; + } + x += d_in; + xt += d_out; + } +} diff --git a/core/src/index/thirdparty/faiss/VectorTransform.h b/core/src/index/thirdparty/faiss/VectorTransform.h new file mode 100644 index 0000000000..4b55245b07 --- /dev/null +++ b/core/src/index/thirdparty/faiss/VectorTransform.h @@ -0,0 +1,322 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +#ifndef FAISS_VECTOR_TRANSFORM_H +#define FAISS_VECTOR_TRANSFORM_H + +/** Defines a few objects that apply transformations to a set of + * vectors Often these are pre-processing steps. + */ + +#include +#include + +#include + + +namespace faiss { + + +/** Any transformation applied on a set of vectors */ +struct VectorTransform { + + typedef Index::idx_t idx_t; + + int d_in; ///! input dimension + int d_out; ///! output dimension + + explicit VectorTransform (int d_in = 0, int d_out = 0): + d_in(d_in), d_out(d_out), is_trained(true) + {} + + + /// set if the VectorTransform does not require training, or if + /// training is done already + bool is_trained; + + + /** Perform training on a representative set of vectors. Does + * nothing by default. + * + * @param n nb of training vectors + * @param x training vecors, size n * d + */ + virtual void train (idx_t n, const float *x); + + /** apply the random roation, return new allocated matrix + * @param x size n * d_in + * @return size n * d_out + */ + float *apply (idx_t n, const float * x) const; + + /// same as apply, but result is pre-allocated + virtual void apply_noalloc (idx_t n, const float * x, + float *xt) const = 0; + + /// reverse transformation. May not be implemented or may return + /// approximate result + virtual void reverse_transform (idx_t n, const float * xt, + float *x) const; + + virtual ~VectorTransform () {} + +}; + + + +/** Generic linear transformation, with bias term applied on output + * y = A * x + b + */ +struct LinearTransform: VectorTransform { + + bool have_bias; ///! whether to use the bias term + + /// check if matrix A is orthonormal (enables reverse_transform) + bool is_orthonormal; + + /// Transformation matrix, size d_out * d_in + std::vector A; + + /// bias vector, size d_out + std::vector b; + + /// both d_in > d_out and d_out < d_in are supported + explicit LinearTransform (int d_in = 0, int d_out = 0, + bool have_bias = false); + + /// same as apply, but result is pre-allocated + void apply_noalloc(idx_t n, const float* x, float* xt) const override; + + /// compute x = A^T * (x - b) + /// is reverse transform if A has orthonormal lines + void transform_transpose (idx_t n, const float * y, + float *x) const; + + /// works only if is_orthonormal + void reverse_transform (idx_t n, const float * xt, + float *x) const override; + + /// compute A^T * A to set the is_orthonormal flag + void set_is_orthonormal (); + + bool verbose; + void print_if_verbose (const char*name, const std::vector &mat, + int n, int d) const; + + ~LinearTransform() override {} +}; + + + +/// Randomly rotate a set of vectors +struct RandomRotationMatrix: LinearTransform { + + /// both d_in > d_out and d_out < d_in are supported + RandomRotationMatrix (int d_in, int d_out): + LinearTransform(d_in, d_out, false) {} + + /// must be called before the transform is used + void init(int seed); + + // intializes with an arbitrary seed + void train(idx_t n, const float* x) override; + + RandomRotationMatrix () {} +}; + + +/** Applies a principal component analysis on a set of vectors, + * with optionally whitening and random rotation. */ +struct PCAMatrix: LinearTransform { + + /** after transformation the components are multiplied by + * eigenvalues^eigen_power + * + * =0: no whitening + * =-0.5: full whitening + */ + float eigen_power; + + /// random rotation after PCA + bool random_rotation; + + /// ratio between # training vectors and dimension + size_t max_points_per_d; + + /// try to distribute output eigenvectors in this many bins + int balanced_bins; + + /// Mean, size d_in + std::vector mean; + + /// eigenvalues of covariance matrix (= squared singular values) + std::vector eigenvalues; + + /// PCA matrix, size d_in * d_in + std::vector PCAMat; + + // the final matrix is computed after random rotation and/or whitening + explicit PCAMatrix (int d_in = 0, int d_out = 0, + float eigen_power = 0, bool random_rotation = false); + + /// train on n vectors. If n < d_in then the eigenvector matrix + /// will be completed with 0s + void train(idx_t n, const float* x) override; + + /// copy pre-trained PCA matrix + void copy_from (const PCAMatrix & other); + + /// called after mean, PCAMat and eigenvalues are computed + void prepare_Ab(); + +}; + + +/** ITQ implementation from + * + * Iterative quantization: A procrustean approach to learning binary codes + * for large-scale image retrieval, + * + * Yunchao Gong, Svetlana Lazebnik, Albert Gordo, Florent Perronnin, + * PAMI'12. + */ + +struct ITQMatrix: LinearTransform { + + int max_iter; + int seed; + + // force initialization of the rotation (for debugging) + std::vector init_rotation; + + explicit ITQMatrix (int d = 0); + + void train (idx_t n, const float* x) override; +}; + + + +/** The full ITQ transform, including normalizations and PCA transformation + */ +struct ITQTransform: VectorTransform { + + std::vector mean; + bool do_pca; + ITQMatrix itq; + + /// max training points per dimension + int max_train_per_dim; + + // concatenation of PCA + ITQ transformation + LinearTransform pca_then_itq; + + explicit ITQTransform (int d_in = 0, int d_out = 0, bool do_pca = false); + + void train (idx_t n, const float *x) override; + + void apply_noalloc (idx_t n, const float* x, float* xt) const override; + +}; + + +struct ProductQuantizer; + +/** Applies a rotation to align the dimensions with a PQ to minimize + * the reconstruction error. Can be used before an IndexPQ or an + * IndexIVFPQ. The method is the non-parametric version described in: + * + * "Optimized Product Quantization for Approximate Nearest Neighbor Search" + * Tiezheng Ge, Kaiming He, Qifa Ke, Jian Sun, CVPR'13 + * + */ +struct OPQMatrix: LinearTransform { + + int M; ///< nb of subquantizers + int niter; ///< Number of outer training iterations + int niter_pq; ///< Number of training iterations for the PQ + int niter_pq_0; ///< same, for the first outer iteration + + /// if there are too many training points, resample + size_t max_train_points; + bool verbose; + + /// if non-NULL, use this product quantizer for training + /// should be constructed with (d_out, M, _) + ProductQuantizer * pq; + + /// if d2 != -1, output vectors of this dimension + explicit OPQMatrix (int d = 0, int M = 1, int d2 = -1); + + void train(idx_t n, const float* x) override; +}; + + +/** remap dimensions for intput vectors, possibly inserting 0s + * strictly speaking this is also a linear transform but we don't want + * to compute it with matrix multiplies */ +struct RemapDimensionsTransform: VectorTransform { + + /// map from output dimension to input, size d_out + /// -1 -> set output to 0 + std::vector map; + + RemapDimensionsTransform (int d_in, int d_out, const int *map); + + /// remap input to output, skipping or inserting dimensions as needed + /// if uniform: distribute dimensions uniformly + /// otherwise just take the d_out first ones. + RemapDimensionsTransform (int d_in, int d_out, bool uniform = true); + + void apply_noalloc(idx_t n, const float* x, float* xt) const override; + + /// reverse transform correct only when the mapping is a permutation + void reverse_transform(idx_t n, const float* xt, float* x) const override; + + RemapDimensionsTransform () {} +}; + + +/** per-vector normalization */ +struct NormalizationTransform: VectorTransform { + float norm; + + explicit NormalizationTransform (int d, float norm = 2.0); + NormalizationTransform (); + + void apply_noalloc(idx_t n, const float* x, float* xt) const override; + + /// Identity transform since norm is not revertible + void reverse_transform(idx_t n, const float* xt, float* x) const override; +}; + +/** Subtract the mean of each component from the vectors. */ +struct CenteringTransform: VectorTransform { + + /// Mean, size d_in = d_out + std::vector mean; + + explicit CenteringTransform (int d = 0); + + /// train on n vectors. + void train(idx_t n, const float* x) override; + + /// subtract the mean + void apply_noalloc(idx_t n, const float* x, float* xt) const override; + + /// add the mean + void reverse_transform (idx_t n, const float * xt, + float *x) const override; + +}; + + +} // namespace faiss + + +#endif diff --git a/core/src/index/thirdparty/faiss/acinclude/ax_blas.m4 b/core/src/index/thirdparty/faiss/acinclude/ax_blas.m4 new file mode 100644 index 0000000000..ada1b17fee --- /dev/null +++ b/core/src/index/thirdparty/faiss/acinclude/ax_blas.m4 @@ -0,0 +1,234 @@ +# =========================================================================== +# https://www.gnu.org/software/autoconf-archive/ax_blas.html +# =========================================================================== +# +# SYNOPSIS +# +# AX_BLAS([ACTION-IF-FOUND[, ACTION-IF-NOT-FOUND]]) +# +# DESCRIPTION +# +# This macro looks for a library that implements the BLAS linear-algebra +# interface (see http://www.netlib.org/blas/). On success, it sets the +# BLAS_LIBS output variable to hold the requisite library linkages. +# +# To link with BLAS, you should link with: +# +# $BLAS_LIBS $LIBS $FLIBS +# +# in that order. FLIBS is the output variable of the +# AC_F77_LIBRARY_LDFLAGS macro (called if necessary by AX_BLAS), and is +# sometimes necessary in order to link with F77 libraries. Users will also +# need to use AC_F77_DUMMY_MAIN (see the autoconf manual), for the same +# reason. +# +# Many libraries are searched for, from ATLAS to CXML to ESSL. The user +# may also use --with-blas= in order to use some specific BLAS +# library . In order to link successfully, however, be aware that you +# will probably need to use the same Fortran compiler (which can be set +# via the F77 env. var.) as was used to compile the BLAS library. +# +# ACTION-IF-FOUND is a list of shell commands to run if a BLAS library is +# found, and ACTION-IF-NOT-FOUND is a list of commands to run it if it is +# not found. If ACTION-IF-FOUND is not specified, the default action will +# define HAVE_BLAS. +# +# LICENSE +# +# Copyright (c) 2008 Steven G. Johnson +# +# This program is free software: you can redistribute it and/or modify it +# under the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or (at your +# option) any later version. +# +# This program is distributed in the hope that it will be useful, but +# WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General +# Public License for more details. +# +# You should have received a copy of the GNU General Public License along +# with this program. If not, see . +# +# As a special exception, the respective Autoconf Macro's copyright owner +# gives unlimited permission to copy, distribute and modify the configure +# scripts that are the output of Autoconf when processing the Macro. You +# need not follow the terms of the GNU General Public License when using +# or distributing such scripts, even though portions of the text of the +# Macro appear in them. The GNU General Public License (GPL) does govern +# all other use of the material that constitutes the Autoconf Macro. +# +# This special exception to the GPL applies to versions of the Autoconf +# Macro released by the Autoconf Archive. When you make and distribute a +# modified version of the Autoconf Macro, you may extend this special +# exception to the GPL to apply to your modified version as well. + +#serial 15 + +AU_ALIAS([ACX_BLAS], [AX_BLAS]) +AC_DEFUN([AX_BLAS], [ +AC_PREREQ(2.50) +# AC_REQUIRE([AC_F77_LIBRARY_LDFLAGS]) +AC_REQUIRE([AC_CANONICAL_HOST]) +ax_blas_ok=no + +AC_ARG_WITH(blas, + [AS_HELP_STRING([--with-blas=], [use BLAS library ])]) +case $with_blas in + yes | "") ;; + no) ax_blas_ok=disable ;; + -* | */* | *.a | *.so | *.so.* | *.o) BLAS_LIBS="$with_blas" ;; + *) BLAS_LIBS="-l$with_blas" ;; +esac + +OPENMP_LDFLAGS="$OPENMP_CXXFLAGS" + +# Get fortran linker names of BLAS functions to check for. +# AC_F77_FUNC(sgemm) +# AC_F77_FUNC(dgemm) +sgemm=sgemm_ +dgemm=dgemm_ + +ax_blas_save_LIBS="$LIBS" +LIBS="$LIBS $FLIBS" + +# First, check BLAS_LIBS environment variable +if test $ax_blas_ok = no; then +if test "x$BLAS_LIBS" != x; then + save_LIBS="$LIBS"; LIBS="$BLAS_LIBS $LIBS" + AC_MSG_CHECKING([for $sgemm in $BLAS_LIBS]) + AC_TRY_LINK_FUNC($sgemm, [ax_blas_ok=yes], [BLAS_LIBS=""]) + AC_MSG_RESULT($ax_blas_ok) + LIBS="$save_LIBS" +fi +fi + +# BLAS linked to by default? (happens on some supercomputers) +if test $ax_blas_ok = no; then + save_LIBS="$LIBS"; LIBS="$LIBS" + AC_MSG_CHECKING([if $sgemm is being linked in already]) + AC_TRY_LINK_FUNC($sgemm, [ax_blas_ok=yes]) + AC_MSG_RESULT($ax_blas_ok) + LIBS="$save_LIBS" +fi + +# BLAS in Intel MKL library? +if test $ax_blas_ok = no; then + case $host_os in + darwin*) + AC_CHECK_LIB(mkl_intel_lp64, $sgemm, + [ax_blas_ok=yes;BLAS_LIBS="-lmkl_intel_lp64 -lmkl_intel_thread -lmkl_core -liomp5 -lpthread"; OPENMP_LDFLAGS=""],, + [-lmkl_intel_lp64 -lmkl_intel_thread -lmkl_core -liomp5 -lpthread]) + ;; + *) + if test $host_cpu = x86_64; then + AC_CHECK_LIB(mkl_intel_lp64, $sgemm, + [ax_blas_ok=yes;BLAS_LIBS="-lmkl_intel_lp64 -lmkl_gnu_thread -lmkl_core -lgomp -lpthread -lm -ldl"],, + [-lmkl_intel_lp64 -lmkl_gnu_thread -lmkl_core -lgomp -lpthread -lm -ldl]) + elif test $host_cpu = i686; then + AC_CHECK_LIB(mkl_intel, $sgemm, + [ax_blas_ok=yes;BLAS_LIBS="-lmkl_intel -lmkl_gnu_thread -lmkl_core -lgomp -lpthread -lm -ldl"],, + [-lmkl_intel -lmkl_gnu_thread -lmkl_core -lgomp -lpthread -lm -ldl]) + fi + ;; + esac +fi +# Old versions of MKL +if test $ax_blas_ok = no; then + AC_CHECK_LIB(mkl, $sgemm, [ax_blas_ok=yes;BLAS_LIBS="-lmkl -lguide -lpthread"],,[-lguide -lpthread]) +fi + +# BLAS in OpenBLAS library? (http://xianyi.github.com/OpenBLAS/) +if test $ax_blas_ok = no; then + AC_CHECK_LIB(openblas, $sgemm, [ax_blas_ok=yes + BLAS_LIBS="-lopenblas"]) +fi + +# BLAS in ATLAS library? (http://math-atlas.sourceforge.net/) +if test $ax_blas_ok = no; then + AC_CHECK_LIB(atlas, ATL_xerbla, + [AC_CHECK_LIB(f77blas, $sgemm, + [AC_CHECK_LIB(cblas, cblas_dgemm, + [ax_blas_ok=yes + BLAS_LIBS="-lcblas -lf77blas -latlas"], + [], [-lf77blas -latlas])], + [], [-latlas])]) +fi + +# BLAS in PhiPACK libraries? (requires generic BLAS lib, too) +if test $ax_blas_ok = no; then + AC_CHECK_LIB(blas, $sgemm, + [AC_CHECK_LIB(dgemm, $dgemm, + [AC_CHECK_LIB(sgemm, $sgemm, + [ax_blas_ok=yes; BLAS_LIBS="-lsgemm -ldgemm -lblas"], + [], [-lblas])], + [], [-lblas])]) +fi + +# BLAS in Apple vecLib library? +if test $ax_blas_ok = no; then + save_LIBS="$LIBS"; LIBS="-framework vecLib $LIBS" + AC_MSG_CHECKING([for $sgemm in -framework vecLib]) + AC_TRY_LINK_FUNC($sgemm, [ax_blas_ok=yes;BLAS_LIBS="-framework vecLib"]) + AC_MSG_RESULT($ax_blas_ok) + LIBS="$save_LIBS" +fi + +# BLAS in Alpha CXML library? +if test $ax_blas_ok = no; then + AC_CHECK_LIB(cxml, $sgemm, [ax_blas_ok=yes;BLAS_LIBS="-lcxml"]) +fi + +# BLAS in Alpha DXML library? (now called CXML, see above) +if test $ax_blas_ok = no; then + AC_CHECK_LIB(dxml, $sgemm, [ax_blas_ok=yes;BLAS_LIBS="-ldxml"]) +fi + +# BLAS in Sun Performance library? +if test $ax_blas_ok = no; then + if test "x$GCC" != xyes; then # only works with Sun CC + AC_CHECK_LIB(sunmath, acosp, + [AC_CHECK_LIB(sunperf, $sgemm, + [BLAS_LIBS="-xlic_lib=sunperf -lsunmath" + ax_blas_ok=yes],[],[-lsunmath])]) + fi +fi + +# BLAS in SCSL library? (SGI/Cray Scientific Library) +if test $ax_blas_ok = no; then + AC_CHECK_LIB(scs, $sgemm, [ax_blas_ok=yes; BLAS_LIBS="-lscs"]) +fi + +# BLAS in SGIMATH library? +if test $ax_blas_ok = no; then + AC_CHECK_LIB(complib.sgimath, $sgemm, + [ax_blas_ok=yes; BLAS_LIBS="-lcomplib.sgimath"]) +fi + +# BLAS in IBM ESSL library? (requires generic BLAS lib, too) +if test $ax_blas_ok = no; then + AC_CHECK_LIB(blas, $sgemm, + [AC_CHECK_LIB(essl, $sgemm, + [ax_blas_ok=yes; BLAS_LIBS="-lessl -lblas"], + [], [-lblas $FLIBS])]) +fi + +# Generic BLAS library? +if test $ax_blas_ok = no; then + AC_CHECK_LIB(blas, $sgemm, [ax_blas_ok=yes; BLAS_LIBS="-lblas"]) +fi + +AC_SUBST(BLAS_LIBS) +AC_SUBST(OPENMP_LDFLAGS) + +LIBS="$ax_blas_save_LIBS" + +# Finally, execute ACTION-IF-FOUND/ACTION-IF-NOT-FOUND: +if test x"$ax_blas_ok" = xyes; then + ifelse([$1],,AC_DEFINE(HAVE_BLAS,1,[Define if you have a BLAS library.]),[$1]) + : +else + ax_blas_ok=no + $2 +fi +])dnl AX_BLAS diff --git a/core/src/index/thirdparty/faiss/acinclude/ax_check_cpu.m4 b/core/src/index/thirdparty/faiss/acinclude/ax_check_cpu.m4 new file mode 100644 index 0000000000..fc61fc91e9 --- /dev/null +++ b/core/src/index/thirdparty/faiss/acinclude/ax_check_cpu.m4 @@ -0,0 +1,26 @@ +# serial 1 + +AC_DEFUN([AX_CPU_ARCH], [ + +AC_MSG_CHECKING([for cpu arch]) + + AC_CANONICAL_TARGET + + case $target in + amd64-* | x86_64-*) + ARCH_CPUFLAGS="-mpopcnt -msse4" + ARCH_CXXFLAGS="-m64" + ;; + aarch64*-*) +dnl This is an arch for Nvidia Xavier a proper detection would be nice. + ARCH_CPUFLAGS="-march=armv8.2-a" + ;; + *) ;; + esac + +AC_MSG_RESULT([$target CPUFLAGS+="$ARCH_CPUFLAGS" CXXFLAGS+="$ARCH_CXXFLAGS"]) + +AC_SUBST(ARCH_CPUFLAGS) +AC_SUBST(ARCH_CXXFLAGS) + +])dnl diff --git a/core/src/index/thirdparty/faiss/acinclude/ax_cxx_compile_stdcxx.m4 b/core/src/index/thirdparty/faiss/acinclude/ax_cxx_compile_stdcxx.m4 new file mode 100644 index 0000000000..0b6cb3a7d7 --- /dev/null +++ b/core/src/index/thirdparty/faiss/acinclude/ax_cxx_compile_stdcxx.m4 @@ -0,0 +1,972 @@ +# =========================================================================== +# https://www.gnu.org/software/autoconf-archive/ax_cxx_compile_stdcxx.html +# =========================================================================== +# +# SYNOPSIS +# +# AX_CXX_COMPILE_STDCXX(VERSION, [ext|noext], [mandatory|optional]) +# +# DESCRIPTION +# +# Check for baseline language coverage in the compiler for the specified +# version of the C++ standard. If necessary, add switches to CXX and +# CXXCPP to enable support. VERSION may be '11' (for the C++11 standard) +# or '14' (for the C++14 standard). +# +# The second argument, if specified, indicates whether you insist on an +# extended mode (e.g. -std=gnu++11) or a strict conformance mode (e.g. +# -std=c++11). If neither is specified, you get whatever works, with +# preference for an extended mode. +# +# The third argument, if specified 'mandatory' or if left unspecified, +# indicates that baseline support for the specified C++ standard is +# required and that the macro should error out if no mode with that +# support is found. If specified 'optional', then configuration proceeds +# regardless, after defining HAVE_CXX${VERSION} if and only if a +# supporting mode is found. +# +# LICENSE +# +# Copyright (c) 2008 Benjamin Kosnik +# Copyright (c) 2012 Zack Weinberg +# Copyright (c) 2013 Roy Stogner +# Copyright (c) 2014, 2015 Google Inc.; contributed by Alexey Sokolov +# Copyright (c) 2015 Paul Norman +# Copyright (c) 2015 Moritz Klammler +# Copyright (c) 2016, 2018 Krzesimir Nowak +# +# Copying and distribution of this file, with or without modification, are +# permitted in any medium without royalty provided the copyright notice +# and this notice are preserved. This file is offered as-is, without any +# warranty. + +#serial 9 + +dnl This macro is based on the code from the AX_CXX_COMPILE_STDCXX_11 macro +dnl (serial version number 13). + +AC_DEFUN([AX_CXX_COMPILE_STDCXX], [dnl + m4_if([$1], [11], [ax_cxx_compile_alternatives="11 0x"], + [$1], [14], [ax_cxx_compile_alternatives="14 1y"], + [$1], [17], [ax_cxx_compile_alternatives="17 1z"], + [m4_fatal([invalid first argument `$1' to AX_CXX_COMPILE_STDCXX])])dnl + m4_if([$2], [], [], + [$2], [ext], [], + [$2], [noext], [], + [m4_fatal([invalid second argument `$2' to AX_CXX_COMPILE_STDCXX])])dnl + m4_if([$3], [], [ax_cxx_compile_cxx$1_required=true], + [$3], [mandatory], [ax_cxx_compile_cxx$1_required=true], + [$3], [optional], [ax_cxx_compile_cxx$1_required=false], + [m4_fatal([invalid third argument `$3' to AX_CXX_COMPILE_STDCXX])]) + AC_LANG_PUSH([C++])dnl + ac_success=no + + m4_if([$2], [noext], [], [dnl + if test x$ac_success = xno; then + for alternative in ${ax_cxx_compile_alternatives}; do + switch="-std=gnu++${alternative}" + cachevar=AS_TR_SH([ax_cv_cxx_compile_cxx$1_$switch]) + AC_CACHE_CHECK(whether $CXX supports C++$1 features with $switch, + $cachevar, + [ac_save_CXX="$CXX" + CXX="$CXX $switch" + AC_COMPILE_IFELSE([AC_LANG_SOURCE([_AX_CXX_COMPILE_STDCXX_testbody_$1])], + [eval $cachevar=yes], + [eval $cachevar=no]) + CXX="$ac_save_CXX"]) + if eval test x\$$cachevar = xyes; then + CXX="$CXX $switch" + if test -n "$CXXCPP" ; then + CXXCPP="$CXXCPP $switch" + fi + ac_success=yes + break + fi + done + fi]) + + m4_if([$2], [ext], [], [dnl + if test x$ac_success = xno; then + dnl HP's aCC needs +std=c++11 according to: + dnl http://h21007.www2.hp.com/portal/download/files/unprot/aCxx/PDF_Release_Notes/769149-001.pdf + dnl Cray's crayCC needs "-h std=c++11" + for alternative in ${ax_cxx_compile_alternatives}; do + for switch in -std=c++${alternative} +std=c++${alternative} "-h std=c++${alternative}"; do + cachevar=AS_TR_SH([ax_cv_cxx_compile_cxx$1_$switch]) + AC_CACHE_CHECK(whether $CXX supports C++$1 features with $switch, + $cachevar, + [ac_save_CXX="$CXX" + CXX="$CXX $switch" + AC_COMPILE_IFELSE([AC_LANG_SOURCE([_AX_CXX_COMPILE_STDCXX_testbody_$1])], + [eval $cachevar=yes], + [eval $cachevar=no]) + CXX="$ac_save_CXX"]) + if eval test x\$$cachevar = xyes; then + CXX="$CXX $switch" + if test -n "$CXXCPP" ; then + CXXCPP="$CXXCPP $switch" + fi + ac_success=yes + break + fi + done + if test x$ac_success = xyes; then + break + fi + done + fi]) + AC_LANG_POP([C++]) + if test x$ax_cxx_compile_cxx$1_required = xtrue; then + if test x$ac_success = xno; then + AC_MSG_ERROR([*** A compiler with support for C++$1 language features is required.]) + fi + fi + if test x$ac_success = xno; then + HAVE_CXX$1=0 + AC_MSG_NOTICE([No compiler with C++$1 support was found]) + else + HAVE_CXX$1=1 + AC_DEFINE(HAVE_CXX$1,1, + [define if the compiler supports basic C++$1 syntax]) + fi + AC_SUBST(HAVE_CXX$1) +]) + + +dnl Test body for checking C++11 support + +m4_define([_AX_CXX_COMPILE_STDCXX_testbody_11], + _AX_CXX_COMPILE_STDCXX_testbody_new_in_11 +) + + +dnl Test body for checking C++14 support + +m4_define([_AX_CXX_COMPILE_STDCXX_testbody_14], + _AX_CXX_COMPILE_STDCXX_testbody_new_in_11 + _AX_CXX_COMPILE_STDCXX_testbody_new_in_14 +) + +m4_define([_AX_CXX_COMPILE_STDCXX_testbody_17], + _AX_CXX_COMPILE_STDCXX_testbody_new_in_11 + _AX_CXX_COMPILE_STDCXX_testbody_new_in_14 + _AX_CXX_COMPILE_STDCXX_testbody_new_in_17 +) + +dnl Tests for new features in C++11 + +m4_define([_AX_CXX_COMPILE_STDCXX_testbody_new_in_11], [[ + +// If the compiler admits that it is not ready for C++11, why torture it? +// Hopefully, this will speed up the test. + +#ifndef __cplusplus + +#error "This is not a C++ compiler" + +#elif __cplusplus < 201103L + +#error "This is not a C++11 compiler" + +#else + +namespace cxx11 +{ + + namespace test_static_assert + { + + template + struct check + { + static_assert(sizeof(int) <= sizeof(T), "not big enough"); + }; + + } + + namespace test_final_override + { + + struct Base + { + virtual void f() {} + }; + + struct Derived : public Base + { + virtual void f() override {} + }; + + } + + namespace test_double_right_angle_brackets + { + + template < typename T > + struct check {}; + + typedef check single_type; + typedef check> double_type; + typedef check>> triple_type; + typedef check>>> quadruple_type; + + } + + namespace test_decltype + { + + int + f() + { + int a = 1; + decltype(a) b = 2; + return a + b; + } + + } + + namespace test_type_deduction + { + + template < typename T1, typename T2 > + struct is_same + { + static const bool value = false; + }; + + template < typename T > + struct is_same + { + static const bool value = true; + }; + + template < typename T1, typename T2 > + auto + add(T1 a1, T2 a2) -> decltype(a1 + a2) + { + return a1 + a2; + } + + int + test(const int c, volatile int v) + { + static_assert(is_same::value == true, ""); + static_assert(is_same::value == false, ""); + static_assert(is_same::value == false, ""); + auto ac = c; + auto av = v; + auto sumi = ac + av + 'x'; + auto sumf = ac + av + 1.0; + static_assert(is_same::value == true, ""); + static_assert(is_same::value == true, ""); + static_assert(is_same::value == true, ""); + static_assert(is_same::value == false, ""); + static_assert(is_same::value == true, ""); + return (sumf > 0.0) ? sumi : add(c, v); + } + + } + + namespace test_noexcept + { + + int f() { return 0; } + int g() noexcept { return 0; } + + static_assert(noexcept(f()) == false, ""); + static_assert(noexcept(g()) == true, ""); + + } + + namespace test_constexpr + { + + template < typename CharT > + unsigned long constexpr + strlen_c_r(const CharT *const s, const unsigned long acc) noexcept + { + return *s ? strlen_c_r(s + 1, acc + 1) : acc; + } + + template < typename CharT > + unsigned long constexpr + strlen_c(const CharT *const s) noexcept + { + return strlen_c_r(s, 0UL); + } + + static_assert(strlen_c("") == 0UL, ""); + static_assert(strlen_c("1") == 1UL, ""); + static_assert(strlen_c("example") == 7UL, ""); + static_assert(strlen_c("another\0example") == 7UL, ""); + + } + + namespace test_rvalue_references + { + + template < int N > + struct answer + { + static constexpr int value = N; + }; + + answer<1> f(int&) { return answer<1>(); } + answer<2> f(const int&) { return answer<2>(); } + answer<3> f(int&&) { return answer<3>(); } + + void + test() + { + int i = 0; + const int c = 0; + static_assert(decltype(f(i))::value == 1, ""); + static_assert(decltype(f(c))::value == 2, ""); + static_assert(decltype(f(0))::value == 3, ""); + } + + } + + namespace test_uniform_initialization + { + + struct test + { + static const int zero {}; + static const int one {1}; + }; + + static_assert(test::zero == 0, ""); + static_assert(test::one == 1, ""); + + } + + namespace test_lambdas + { + + void + test1() + { + auto lambda1 = [](){}; + auto lambda2 = lambda1; + lambda1(); + lambda2(); + } + + int + test2() + { + auto a = [](int i, int j){ return i + j; }(1, 2); + auto b = []() -> int { return '0'; }(); + auto c = [=](){ return a + b; }(); + auto d = [&](){ return c; }(); + auto e = [a, &b](int x) mutable { + const auto identity = [](int y){ return y; }; + for (auto i = 0; i < a; ++i) + a += b--; + return x + identity(a + b); + }(0); + return a + b + c + d + e; + } + + int + test3() + { + const auto nullary = [](){ return 0; }; + const auto unary = [](int x){ return x; }; + using nullary_t = decltype(nullary); + using unary_t = decltype(unary); + const auto higher1st = [](nullary_t f){ return f(); }; + const auto higher2nd = [unary](nullary_t f1){ + return [unary, f1](unary_t f2){ return f2(unary(f1())); }; + }; + return higher1st(nullary) + higher2nd(nullary)(unary); + } + + } + + namespace test_variadic_templates + { + + template + struct sum; + + template + struct sum + { + static constexpr auto value = N0 + sum::value; + }; + + template <> + struct sum<> + { + static constexpr auto value = 0; + }; + + static_assert(sum<>::value == 0, ""); + static_assert(sum<1>::value == 1, ""); + static_assert(sum<23>::value == 23, ""); + static_assert(sum<1, 2>::value == 3, ""); + static_assert(sum<5, 5, 11>::value == 21, ""); + static_assert(sum<2, 3, 5, 7, 11, 13>::value == 41, ""); + + } + + // http://stackoverflow.com/questions/13728184/template-aliases-and-sfinae + // Clang 3.1 fails with headers of libstd++ 4.8.3 when using std::function + // because of this. + namespace test_template_alias_sfinae + { + + struct foo {}; + + template + using member = typename T::member_type; + + template + void func(...) {} + + template + void func(member*) {} + + void test(); + + void test() { func(0); } + + } + +} // namespace cxx11 + +#endif // __cplusplus >= 201103L + +]]) + + +dnl Tests for new features in C++14 + +m4_define([_AX_CXX_COMPILE_STDCXX_testbody_new_in_14], [[ + +// If the compiler admits that it is not ready for C++14, why torture it? +// Hopefully, this will speed up the test. + +#ifndef __cplusplus + +#error "This is not a C++ compiler" + +#elif __cplusplus < 201402L + +#error "This is not a C++14 compiler" + +#else + +namespace cxx14 +{ + + namespace test_polymorphic_lambdas + { + + int + test() + { + const auto lambda = [](auto&&... args){ + const auto istiny = [](auto x){ + return (sizeof(x) == 1UL) ? 1 : 0; + }; + const int aretiny[] = { istiny(args)... }; + return aretiny[0]; + }; + return lambda(1, 1L, 1.0f, '1'); + } + + } + + namespace test_binary_literals + { + + constexpr auto ivii = 0b0000000000101010; + static_assert(ivii == 42, "wrong value"); + + } + + namespace test_generalized_constexpr + { + + template < typename CharT > + constexpr unsigned long + strlen_c(const CharT *const s) noexcept + { + auto length = 0UL; + for (auto p = s; *p; ++p) + ++length; + return length; + } + + static_assert(strlen_c("") == 0UL, ""); + static_assert(strlen_c("x") == 1UL, ""); + static_assert(strlen_c("test") == 4UL, ""); + static_assert(strlen_c("another\0test") == 7UL, ""); + + } + + namespace test_lambda_init_capture + { + + int + test() + { + auto x = 0; + const auto lambda1 = [a = x](int b){ return a + b; }; + const auto lambda2 = [a = lambda1(x)](){ return a; }; + return lambda2(); + } + + } + + namespace test_digit_separators + { + + constexpr auto ten_million = 100'000'000; + static_assert(ten_million == 100000000, ""); + + } + + namespace test_return_type_deduction + { + + auto f(int& x) { return x; } + decltype(auto) g(int& x) { return x; } + + template < typename T1, typename T2 > + struct is_same + { + static constexpr auto value = false; + }; + + template < typename T > + struct is_same + { + static constexpr auto value = true; + }; + + int + test() + { + auto x = 0; + static_assert(is_same::value, ""); + static_assert(is_same::value, ""); + return x; + } + + } + +} // namespace cxx14 + +#endif // __cplusplus >= 201402L + +]]) + + +dnl Tests for new features in C++17 + +m4_define([_AX_CXX_COMPILE_STDCXX_testbody_new_in_17], [[ + +// If the compiler admits that it is not ready for C++17, why torture it? +// Hopefully, this will speed up the test. + +#ifndef __cplusplus + +#error "This is not a C++ compiler" + +#elif __cplusplus <= 201402L + +#error "This is not a C++17 compiler" + +#else + +#if defined(__clang__) + #define REALLY_CLANG +#else + #if defined(__GNUC__) + #define REALLY_GCC + #endif +#endif + +#include +#include +#include + +namespace cxx17 +{ + +#if !defined(REALLY_CLANG) + namespace test_constexpr_lambdas + { + + // TODO: test it with clang++ from git + + constexpr int foo = [](){return 42;}(); + + } +#endif // !defined(REALLY_CLANG) + + namespace test::nested_namespace::definitions + { + + } + + namespace test_fold_expression + { + + template + int multiply(Args... args) + { + return (args * ... * 1); + } + + template + bool all(Args... args) + { + return (args && ...); + } + + } + + namespace test_extended_static_assert + { + + static_assert (true); + + } + + namespace test_auto_brace_init_list + { + + auto foo = {5}; + auto bar {5}; + + static_assert(std::is_same, decltype(foo)>::value); + static_assert(std::is_same::value); + } + + namespace test_typename_in_template_template_parameter + { + + template typename X> struct D; + + } + + namespace test_fallthrough_nodiscard_maybe_unused_attributes + { + + int f1() + { + return 42; + } + + [[nodiscard]] int f2() + { + [[maybe_unused]] auto unused = f1(); + + switch (f1()) + { + case 17: + f1(); + [[fallthrough]]; + case 42: + f1(); + } + return f1(); + } + + } + + namespace test_extended_aggregate_initialization + { + + struct base1 + { + int b1, b2 = 42; + }; + + struct base2 + { + base2() { + b3 = 42; + } + int b3; + }; + + struct derived : base1, base2 + { + int d; + }; + + derived d1 {{1, 2}, {}, 4}; // full initialization + derived d2 {{}, {}, 4}; // value-initialized bases + + } + + namespace test_general_range_based_for_loop + { + + struct iter + { + int i; + + int& operator* () + { + return i; + } + + const int& operator* () const + { + return i; + } + + iter& operator++() + { + ++i; + return *this; + } + }; + + struct sentinel + { + int i; + }; + + bool operator== (const iter& i, const sentinel& s) + { + return i.i == s.i; + } + + bool operator!= (const iter& i, const sentinel& s) + { + return !(i == s); + } + + struct range + { + iter begin() const + { + return {0}; + } + + sentinel end() const + { + return {5}; + } + }; + + void f() + { + range r {}; + + for (auto i : r) + { + [[maybe_unused]] auto v = i; + } + } + + } + + namespace test_lambda_capture_asterisk_this_by_value + { + + struct t + { + int i; + int foo() + { + return [*this]() + { + return i; + }(); + } + }; + + } + + namespace test_enum_class_construction + { + + enum class byte : unsigned char + {}; + + byte foo {42}; + + } + + namespace test_constexpr_if + { + + template + int f () + { + if constexpr(cond) + { + return 13; + } + else + { + return 42; + } + } + + } + + namespace test_selection_statement_with_initializer + { + + int f() + { + return 13; + } + + int f2() + { + if (auto i = f(); i > 0) + { + return 3; + } + + switch (auto i = f(); i + 4) + { + case 17: + return 2; + + default: + return 1; + } + } + + } + +#if !defined(REALLY_CLANG) + namespace test_template_argument_deduction_for_class_templates + { + + // TODO: test it with clang++ from git + + template + struct pair + { + pair (T1 p1, T2 p2) + : m1 {p1}, + m2 {p2} + {} + + T1 m1; + T2 m2; + }; + + void f() + { + [[maybe_unused]] auto p = pair{13, 42u}; + } + + } +#endif // !defined(REALLY_CLANG) + + namespace test_non_type_auto_template_parameters + { + + template + struct B + {}; + + B<5> b1; + B<'a'> b2; + + } + +#if !defined(REALLY_CLANG) + namespace test_structured_bindings + { + + // TODO: test it with clang++ from git + + int arr[2] = { 1, 2 }; + std::pair pr = { 1, 2 }; + + auto f1() -> int(&)[2] + { + return arr; + } + + auto f2() -> std::pair& + { + return pr; + } + + struct S + { + int x1 : 2; + volatile double y1; + }; + + S f3() + { + return {}; + } + + auto [ x1, y1 ] = f1(); + auto& [ xr1, yr1 ] = f1(); + auto [ x2, y2 ] = f2(); + auto& [ xr2, yr2 ] = f2(); + const auto [ x3, y3 ] = f3(); + + } +#endif // !defined(REALLY_CLANG) + +#if !defined(REALLY_CLANG) + namespace test_exception_spec_type_system + { + + // TODO: test it with clang++ from git + + struct Good {}; + struct Bad {}; + + void g1() noexcept; + void g2(); + + template + Bad + f(T*, T*); + + template + Good + f(T1*, T2*); + + static_assert (std::is_same_v); + + } +#endif // !defined(REALLY_CLANG) + + namespace test_inline_variables + { + + template void f(T) + {} + + template inline T g(T) + { + return T{}; + } + + template<> inline void f<>(int) + {} + + template<> int g<>(int) + { + return 5; + } + + } + +} // namespace cxx17 + +#endif // __cplusplus <= 201402L + +]]) diff --git a/core/src/index/thirdparty/faiss/acinclude/ax_lapack.m4 b/core/src/index/thirdparty/faiss/acinclude/ax_lapack.m4 new file mode 100644 index 0000000000..4993f29b9c --- /dev/null +++ b/core/src/index/thirdparty/faiss/acinclude/ax_lapack.m4 @@ -0,0 +1,132 @@ +# =========================================================================== +# https://www.gnu.org/software/autoconf-archive/ax_lapack.html +# =========================================================================== +# +# SYNOPSIS +# +# AX_LAPACK([ACTION-IF-FOUND[, ACTION-IF-NOT-FOUND]]) +# +# DESCRIPTION +# +# This macro looks for a library that implements the LAPACK linear-algebra +# interface (see http://www.netlib.org/lapack/). On success, it sets the +# LAPACK_LIBS output variable to hold the requisite library linkages. +# +# To link with LAPACK, you should link with: +# +# $LAPACK_LIBS $BLAS_LIBS $LIBS $FLIBS +# +# in that order. BLAS_LIBS is the output variable of the AX_BLAS macro, +# called automatically. FLIBS is the output variable of the +# AC_F77_LIBRARY_LDFLAGS macro (called if necessary by AX_BLAS), and is +# sometimes necessary in order to link with F77 libraries. Users will also +# need to use AC_F77_DUMMY_MAIN (see the autoconf manual), for the same +# reason. +# +# The user may also use --with-lapack= in order to use some specific +# LAPACK library . In order to link successfully, however, be aware +# that you will probably need to use the same Fortran compiler (which can +# be set via the F77 env. var.) as was used to compile the LAPACK and BLAS +# libraries. +# +# ACTION-IF-FOUND is a list of shell commands to run if a LAPACK library +# is found, and ACTION-IF-NOT-FOUND is a list of commands to run it if it +# is not found. If ACTION-IF-FOUND is not specified, the default action +# will define HAVE_LAPACK. +# +# LICENSE +# +# Copyright (c) 2009 Steven G. Johnson +# +# This program is free software: you can redistribute it and/or modify it +# under the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or (at your +# option) any later version. +# +# This program is distributed in the hope that it will be useful, but +# WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General +# Public License for more details. +# +# You should have received a copy of the GNU General Public License along +# with this program. If not, see . +# +# As a special exception, the respective Autoconf Macro's copyright owner +# gives unlimited permission to copy, distribute and modify the configure +# scripts that are the output of Autoconf when processing the Macro. You +# need not follow the terms of the GNU General Public License when using +# or distributing such scripts, even though portions of the text of the +# Macro appear in them. The GNU General Public License (GPL) does govern +# all other use of the material that constitutes the Autoconf Macro. +# +# This special exception to the GPL applies to versions of the Autoconf +# Macro released by the Autoconf Archive. When you make and distribute a +# modified version of the Autoconf Macro, you may extend this special +# exception to the GPL to apply to your modified version as well. + +#serial 8 + +AU_ALIAS([ACX_LAPACK], [AX_LAPACK]) +AC_DEFUN([AX_LAPACK], [ +AC_REQUIRE([AX_BLAS]) +ax_lapack_ok=no + +AC_ARG_WITH(lapack, + [AS_HELP_STRING([--with-lapack=], [use LAPACK library ])]) +case $with_lapack in + yes | "") ;; + no) ax_lapack_ok=disable ;; + -* | */* | *.a | *.so | *.so.* | *.o) LAPACK_LIBS="$with_lapack" ;; + *) LAPACK_LIBS="-l$with_lapack" ;; +esac + +# Get fortran linker name of LAPACK function to check for. +# AC_F77_FUNC(cheev) +cheev=cheev_ + +# We cannot use LAPACK if BLAS is not found +if test "x$ax_blas_ok" != xyes; then + ax_lapack_ok=noblas + LAPACK_LIBS="" +fi + +# First, check LAPACK_LIBS environment variable +if test "x$LAPACK_LIBS" != x; then + save_LIBS="$LIBS"; LIBS="$LAPACK_LIBS $BLAS_LIBS $LIBS $FLIBS" + AC_MSG_CHECKING([for $cheev in $LAPACK_LIBS]) + AC_TRY_LINK_FUNC($cheev, [ax_lapack_ok=yes], [LAPACK_LIBS=""]) + AC_MSG_RESULT($ax_lapack_ok) + LIBS="$save_LIBS" + if test $ax_lapack_ok = no; then + LAPACK_LIBS="" + fi +fi + +# LAPACK linked to by default? (is sometimes included in BLAS lib) +if test $ax_lapack_ok = no; then + save_LIBS="$LIBS"; LIBS="$LIBS $BLAS_LIBS $FLIBS" + AC_CHECK_FUNC($cheev, [ax_lapack_ok=yes]) + LIBS="$save_LIBS" +fi + +# Generic LAPACK library? +for lapack in lapack lapack_rs6k; do + if test $ax_lapack_ok = no; then + save_LIBS="$LIBS"; LIBS="$BLAS_LIBS $LIBS" + AC_CHECK_LIB($lapack, $cheev, + [ax_lapack_ok=yes; LAPACK_LIBS="-l$lapack"], [], [$FLIBS]) + LIBS="$save_LIBS" + fi +done + +AC_SUBST(LAPACK_LIBS) + +# Finally, execute ACTION-IF-FOUND/ACTION-IF-NOT-FOUND: +if test x"$ax_lapack_ok" = xyes; then + ifelse([$1],,AC_DEFINE(HAVE_LAPACK,1,[Define if you have LAPACK library.]),[$1]) + : +else + ax_lapack_ok=no + $2 +fi +])dnl AX_LAPACK diff --git a/core/src/index/thirdparty/faiss/acinclude/fa_check_cuda.m4 b/core/src/index/thirdparty/faiss/acinclude/fa_check_cuda.m4 new file mode 100644 index 0000000000..f730bc23e2 --- /dev/null +++ b/core/src/index/thirdparty/faiss/acinclude/fa_check_cuda.m4 @@ -0,0 +1,67 @@ +AC_DEFUN([FA_CHECK_CUDA], [ + +AC_ARG_WITH(cuda, + [AS_HELP_STRING([--with-cuda=], [prefix of the CUDA installation])]) +AC_ARG_WITH(cuda-arch, + [AS_HELP_STRING([--with-cuda-arch=], [device specific -gencode flags])], + [], + [with_cuda_arch=default]) + +if test x$with_cuda != xno; then + if test x$with_cuda != x; then + cuda_prefix=$with_cuda + AC_CHECK_PROG(NVCC, [nvcc], [$cuda_prefix/bin/nvcc], [], [$cuda_prefix/bin]) + NVCC_CPPFLAGS="-I$cuda_prefix/include" + NVCC_LDFLAGS="-L$cuda_prefix/lib64" + else + AC_CHECK_PROGS(NVCC, [nvcc /usr/local/cuda/bin/nvcc], []) + if test "x$NVCC" == "x/usr/local/cuda/bin/nvcc"; then + cuda_prefix="/usr/local/cuda" + NVCC_CPPFLAGS="-I$cuda_prefix/include" + NVCC_LDFLAGS="-L$cuda_prefix/lib64" + else + cuda_prefix="" + NVCC_CPPFLAGS="" + NVCC_LDFLAGS="" + fi + fi + + if test "x$NVCC" == x; then + AC_MSG_ERROR([Couldn't find nvcc]) + fi + + if test "x$with_cuda_arch" == xdefault; then + with_cuda_arch="-gencode=arch=compute_35,code=compute_35 \\ +-gencode=arch=compute_52,code=compute_52 \\ +-gencode=arch=compute_60,code=compute_60 \\ +-gencode=arch=compute_61,code=compute_61 \\ +-gencode=arch=compute_70,code=compute_70 \\ +-gencode=arch=compute_75,code=compute_75" + fi + + fa_save_CPPFLAGS="$CPPFLAGS" + fa_save_LDFLAGS="$LDFLAGS" + fa_save_LIBS="$LIBS" + + CPPFLAGS="$NVCC_CPPFLAGS $CPPFLAGS" + LDFLAGS="$NVCC_LDFLAGS $LDFLAGS" + + AC_CHECK_HEADER([cuda.h], [], AC_MSG_FAILURE([Couldn't find cuda.h])) + AC_CHECK_LIB([cublas], [cublasAlloc], [], AC_MSG_FAILURE([Couldn't find libcublas])) + AC_CHECK_LIB([cudart], [cudaSetDevice], [], AC_MSG_FAILURE([Couldn't find libcudart])) + + NVCC_LIBS="$LIBS" + NVCC_CPPFLAGS="$CPPFLAGS" + NVCC_LDFLAGS="$LDFLAGS" + CPPFLAGS="$fa_save_CPPFLAGS" + LDFLAGS="$fa_save_LDFLAGS" + LIBS="$fa_save_LIBS" +fi + +AC_SUBST(NVCC) +AC_SUBST(NVCC_CPPFLAGS) +AC_SUBST(NVCC_LDFLAGS) +AC_SUBST(NVCC_LIBS) +AC_SUBST(CUDA_PREFIX, $cuda_prefix) +AC_SUBST(CUDA_ARCH, $with_cuda_arch) +]) diff --git a/core/src/index/thirdparty/faiss/acinclude/fa_numpy.m4 b/core/src/index/thirdparty/faiss/acinclude/fa_numpy.m4 new file mode 100644 index 0000000000..6e3dcde531 --- /dev/null +++ b/core/src/index/thirdparty/faiss/acinclude/fa_numpy.m4 @@ -0,0 +1,20 @@ +AC_DEFUN([FA_NUMPY], [ +AC_REQUIRE([FA_PYTHON]) + +AC_MSG_CHECKING([for numpy headers path]) + +fa_numpy_headers=`$PYTHON -c "import numpy; print(numpy.get_include())"` + +if test $? == 0; then + if test x$fa_numpy_headers != x; then + AC_MSG_RESULT($fa_numpy_headers) + AC_SUBST(NUMPY_INCLUDE, $fa_numpy_headers) + else + AC_MSG_RESULT([not found]) + AC_MSG_WARN([You won't be able to build the python interface.]) + fi +else + AC_MSG_RESULT([not found]) + AC_MSG_WARN([You won't be able to build the python interface.]) +fi +])dnl diff --git a/core/src/index/thirdparty/faiss/acinclude/fa_prog_nm.m4 b/core/src/index/thirdparty/faiss/acinclude/fa_prog_nm.m4 new file mode 100644 index 0000000000..f450ba7645 --- /dev/null +++ b/core/src/index/thirdparty/faiss/acinclude/fa_prog_nm.m4 @@ -0,0 +1,16 @@ +dnl +dnl Check for an nm(1) utility. +dnl +AC_DEFUN([FA_PROG_NM], +[ + case "${NM-unset}" in + unset) AC_CHECK_PROGS(NM, nm, nm) ;; + *) AC_CHECK_PROGS(NM, $NM nm, nm) ;; + esac + AC_MSG_CHECKING(nm flags) + case "${NMFLAGS-unset}" in + unset) NMFLAGS= ;; + esac + AC_MSG_RESULT($NMFLAGS) + AC_SUBST(NMFLAGS) +]) diff --git a/core/src/index/thirdparty/faiss/acinclude/fa_prog_swig.m4 b/core/src/index/thirdparty/faiss/acinclude/fa_prog_swig.m4 new file mode 100644 index 0000000000..1e6ab8e49d --- /dev/null +++ b/core/src/index/thirdparty/faiss/acinclude/fa_prog_swig.m4 @@ -0,0 +1,11 @@ +AC_DEFUN([FA_PROG_SWIG], [ + +AC_ARG_WITH(swig, +[AS_HELP_STRING([--with-swig=], [use SWIG binary ])]) +case $with_swig in + "") AC_CHECK_PROG(SWIG, swig, swig);; + *) SWIG="$with_swig" +esac + +AC_SUBST(SWIG) +]) diff --git a/core/src/index/thirdparty/faiss/acinclude/fa_python.m4 b/core/src/index/thirdparty/faiss/acinclude/fa_python.m4 new file mode 100644 index 0000000000..a58a9d15ec --- /dev/null +++ b/core/src/index/thirdparty/faiss/acinclude/fa_python.m4 @@ -0,0 +1,21 @@ +AC_DEFUN([FA_PYTHON], [ + +AC_ARG_WITH(python, + [AS_HELP_STRING([--with-python=], [use Python binary ])]) +case $with_python in + "") PYTHON_BIN=python ;; + *) PYTHON_BIN="$with_python" +esac + +AC_CHECK_PROG(PYTHON, $PYTHON_BIN, $PYTHON_BIN) +fa_python_bin=$PYTHON + +AC_MSG_CHECKING([for Python C flags]) +fa_python_cflags=`$PYTHON -c " +import sysconfig +paths = [['-I' + sysconfig.get_path(p) for p in ['include', 'platinclude']]] +print(' '.join(paths))"` +AC_MSG_RESULT($fa_python_cflags) +AC_SUBST(PYTHON_CFLAGS, "$PYTHON_CFLAGS $fa_python_cflags") + +])dnl FA_PYTHON diff --git a/core/src/index/thirdparty/faiss/benchs/README.md b/core/src/index/thirdparty/faiss/benchs/README.md new file mode 100644 index 0000000000..7e95a7673d --- /dev/null +++ b/core/src/index/thirdparty/faiss/benchs/README.md @@ -0,0 +1,338 @@ + +# Benchmarking scripts + +This directory contains benchmarking scripts that can reproduce the +numbers reported in the two papers + +``` +@inproceedings{DJP16, + Author = {Douze, Matthijs and J{\'e}gou, Herv{\'e} and Perronnin, Florent}, + Booktitle = "ECCV", + Organization = {Springer}, + Title = {Polysemous codes}, + Year = {2016} +} +``` +and + +``` +@inproceedings{JDJ17, + Author = {Jeff Johnson and Matthijs Douze and Herv{\'e} J{\'e}gou}, + journal= {arXiv:1702.08734},, + Title = {Billion-scale similarity search with GPUs}, + Year = {2017}, +} +``` + +Note that the numbers (especially timings) change slightly due to changes in the implementation, different machines, etc. + +The scripts are self-contained. They depend only on Faiss and external training data that should be stored in sub-directories. + +## SIFT1M experiments + +The script [`bench_polysemous_sift1m.py`](bench_polysemous_sift1m.py) reproduces the numbers in +Figure 3 from the "Polysemous" paper. + +### Getting SIFT1M + +To run it, please download the ANN_SIFT1M dataset from + +http://corpus-texmex.irisa.fr/ + +and unzip it to the subdirectory sift1M. + +### Result + +The output looks like: + +``` +PQ training on 100000 points, remains 0 points: training polysemous on centroids +add vectors to index +PQ baseline 7.517 ms per query, R@1 0.4474 +Polysemous 64 9.875 ms per query, R@1 0.4474 +Polysemous 62 8.358 ms per query, R@1 0.4474 +Polysemous 58 5.531 ms per query, R@1 0.4474 +Polysemous 54 3.420 ms per query, R@1 0.4478 +Polysemous 50 2.182 ms per query, R@1 0.4475 +Polysemous 46 1.621 ms per query, R@1 0.4408 +Polysemous 42 1.448 ms per query, R@1 0.4174 +Polysemous 38 1.331 ms per query, R@1 0.3563 +Polysemous 34 1.334 ms per query, R@1 0.2661 +Polysemous 30 1.272 ms per query, R@1 0.1794 +``` + + +## Experiments on 1B elements dataset + +The script [`bench_polysemous_1bn.py`](bench_polysemous_1bn.py) reproduces a few experiments on +two datasets of size 1B from the Polysemous codes" paper. + + +### Getting BIGANN + +Download the four files of ANN_SIFT1B from +http://corpus-texmex.irisa.fr/ to subdirectory bigann/ + +### Getting Deep1B + +The ground-truth and queries are available here + +https://yadi.sk/d/11eDCm7Dsn9GA + +For the learning and database vectors, use the script + +https://github.com/arbabenko/GNOIMI/blob/master/downloadDeep1B.py + +to download the data to subdirectory deep1b/, then concatenate the +database files to base.fvecs and the training files to learn.fvecs + +### Running the experiments + +These experiments are quite long. To support resuming, the script +stores the result of training to a temporary directory, `/tmp/bench_polysemous`. + +The script `bench_polysemous_1bn.py` takes at least two arguments: + +- the dataset name: SIFT1000M (aka SIFT1B, aka BIGANN) or Deep1B. SIFT1M, SIFT2M,... are also supported to make subsets of for small experiments (note that SIFT1M as a subset of SIFT1B is not the same as the SIFT1M above) + +- the type of index to build, which should be a valid [index_factory key](https://github.com/facebookresearch/faiss/wiki/High-level-interface-and-auto-tuning#index-factory) (see below for examples) + +- the remaining arguments are parsed as search-time parameters. + +### Experiments of Table 2 + +The `IMI*+PolyD+ADC` results in Table 2 can be reproduced with (for 16 bytes): + +``` +python bench_polysemous_1bn.par SIFT1000M IMI2x12,PQ16 nprobe=16,max_codes={10000,30000},ht={44..54} +``` + +Training takes about 2 minutes and adding vectors to the dataset +takes 3.1 h. These operations are multithreaded. Note that in the command +above, we use bash's [brace expansion](https://www.gnu.org/software/bash/manual/html_node/Brace-Expansion.html) to set a grid of parameters. + +The search is *not* multithreaded, and the output looks like: + +``` + R@1 R@10 R@100 time %pass +nprobe=16,max_codes=10000,ht=44 0.1779 0.2994 0.3139 0.194 12.45 +nprobe=16,max_codes=10000,ht=45 0.1859 0.3183 0.3339 0.197 14.24 +nprobe=16,max_codes=10000,ht=46 0.1930 0.3366 0.3543 0.202 16.22 +nprobe=16,max_codes=10000,ht=47 0.1993 0.3550 0.3745 0.209 18.39 +nprobe=16,max_codes=10000,ht=48 0.2033 0.3694 0.3917 0.640 20.77 +nprobe=16,max_codes=10000,ht=49 0.2070 0.3839 0.4077 0.229 23.36 +nprobe=16,max_codes=10000,ht=50 0.2101 0.3949 0.4205 0.232 26.17 +nprobe=16,max_codes=10000,ht=51 0.2120 0.4042 0.4310 0.239 29.21 +nprobe=16,max_codes=10000,ht=52 0.2134 0.4113 0.4402 0.245 32.47 +nprobe=16,max_codes=10000,ht=53 0.2157 0.4184 0.4482 0.250 35.96 +nprobe=16,max_codes=10000,ht=54 0.2170 0.4240 0.4546 0.256 39.66 +nprobe=16,max_codes=30000,ht=44 0.1882 0.3327 0.3555 0.226 11.29 +nprobe=16,max_codes=30000,ht=45 0.1964 0.3525 0.3771 0.231 13.05 +nprobe=16,max_codes=30000,ht=46 0.2039 0.3713 0.3987 0.236 15.01 +nprobe=16,max_codes=30000,ht=47 0.2103 0.3907 0.4202 0.245 17.19 +nprobe=16,max_codes=30000,ht=48 0.2145 0.4055 0.4384 0.251 19.60 +nprobe=16,max_codes=30000,ht=49 0.2179 0.4198 0.4550 0.257 22.25 +nprobe=16,max_codes=30000,ht=50 0.2208 0.4305 0.4681 0.268 25.15 +nprobe=16,max_codes=30000,ht=51 0.2227 0.4402 0.4791 0.275 28.30 +nprobe=16,max_codes=30000,ht=52 0.2241 0.4473 0.4884 0.284 31.70 +nprobe=16,max_codes=30000,ht=53 0.2265 0.4544 0.4965 0.294 35.34 +nprobe=16,max_codes=30000,ht=54 0.2278 0.4601 0.5031 0.303 39.20 +``` + +The result reported in table 2 is the one for which the %pass (percentage of code comparisons that pass the Hamming check) is around 20%, which occurs for Hamming threshold `ht=48`. + +The 8-byte results can be reproduced with the factory key `IMI2x12,PQ8` + +### Experiments of the appendix + +The experiments in the appendix are only in the ArXiv version of the paper (table 3). + +``` +python bench_polysemous_1bn.py SIFT1000M OPQ8_64,IMI2x13,PQ8 nprobe={1,2,4,8,16,32,64,128},ht={20,24,26,28,30} + + R@1 R@10 R@100 time %pass +nprobe=1,ht=20 0.0351 0.0616 0.0751 0.158 19.01 +... +nprobe=32,ht=28 0.1256 0.3563 0.5026 0.561 52.61 +... +``` +Here again the runs are not exactly the same but the original result was obtained from nprobe=32,ht=28. + +For Deep1B, we used a simple version of [auto-tuning](https://github.com/facebookresearch/faiss/wiki/High-level-interface-and-auto-tuning/_edit#auto-tuning-the-runtime-parameters) to sweep through the set of operating points: + +``` +python bench_polysemous_1bn.py Deep1B OPQ20_80,IMI2x14,PQ20 autotune +... +Done in 4067.555 s, available OPs: +Parameters 1-R@1 time + 0.0000 0.000 +nprobe=1,ht=22,max_codes=256 0.0215 3.115 +nprobe=1,ht=30,max_codes=256 0.0381 3.120 +... +nprobe=512,ht=68,max_codes=524288 0.4478 36.903 +nprobe=1024,ht=80,max_codes=131072 0.4557 46.363 +nprobe=1024,ht=78,max_codes=262144 0.4616 61.939 +... +``` +The original results were obtained with `nprobe=1024,ht=66,max_codes=262144`. + + +## GPU experiments + +The benchmarks below run 1 or 4 Titan X GPUs and reproduce the results of the "GPU paper". They are also a good starting point on how to use GPU Faiss. + +### Search on SIFT1M + +See above on how to get SIFT1M into subdirectory sift1M/. The script [`bench_gpu_sift1m.py`](bench_gpu_sift1m.py) reproduces the "exact k-NN time" plot in the ArXiv paper, and the SIFT1M numbers. + +The output is: +``` +============ Exact search +add vectors to index +warmup +benchmark +k=1 0.715 s, R@1 0.9914 +k=2 0.729 s, R@1 0.9935 +k=4 0.731 s, R@1 0.9935 +k=8 0.732 s, R@1 0.9935 +k=16 0.742 s, R@1 0.9935 +k=32 0.737 s, R@1 0.9935 +k=64 0.753 s, R@1 0.9935 +k=128 0.761 s, R@1 0.9935 +k=256 0.799 s, R@1 0.9935 +k=512 0.975 s, R@1 0.9935 +k=1024 1.424 s, R@1 0.9935 +============ Approximate search +train +WARNING clustering 100000 points to 4096 centroids: please provide at least 159744 training points +add vectors to index +WARN: increase temp memory to avoid cudaMalloc, or decrease query/add size (alloc 256000000 B, highwater 256000000 B) +warmup +benchmark +nprobe= 1 0.043 s recalls= 0.3909 0.4312 0.4312 +nprobe= 2 0.040 s recalls= 0.5041 0.5636 0.5636 +nprobe= 4 0.048 s recalls= 0.6048 0.6897 0.6897 +nprobe= 8 0.064 s recalls= 0.6879 0.8028 0.8028 +nprobe= 16 0.088 s recalls= 0.7534 0.8940 0.8940 +nprobe= 32 0.134 s recalls= 0.7957 0.9549 0.9550 +nprobe= 64 0.224 s recalls= 0.8125 0.9833 0.9834 +nprobe= 128 0.395 s recalls= 0.8205 0.9953 0.9954 +nprobe= 256 0.717 s recalls= 0.8227 0.9993 0.9994 +nprobe= 512 1.348 s recalls= 0.8228 0.9999 1.0000 +``` +The run produces two warnings: + +- the clustering complains that it does not have enough training data, there is not much we can do about this. + +- the add() function complains that there is an inefficient memory allocation, but this is a concern only when it happens often, and we are not benchmarking the add time anyways. + +To index small datasets, it is more efficient to use a `GpuIVFFlat`, which just stores the full vectors in the inverted lists. We did not mention this in the the paper because it is not as scalable. To experiment with this setting, change the `index_factory` string from "IVF4096,PQ64" to "IVF16384,Flat". This gives: + +``` +nprobe= 1 0.025 s recalls= 0.4084 0.4105 0.4105 +nprobe= 2 0.033 s recalls= 0.5235 0.5264 0.5264 +nprobe= 4 0.033 s recalls= 0.6332 0.6367 0.6367 +nprobe= 8 0.040 s recalls= 0.7358 0.7403 0.7403 +nprobe= 16 0.049 s recalls= 0.8273 0.8324 0.8324 +nprobe= 32 0.068 s recalls= 0.8957 0.9024 0.9024 +nprobe= 64 0.104 s recalls= 0.9477 0.9549 0.9549 +nprobe= 128 0.174 s recalls= 0.9760 0.9837 0.9837 +nprobe= 256 0.299 s recalls= 0.9866 0.9944 0.9944 +nprobe= 512 0.527 s recalls= 0.9907 0.9987 0.9987 +``` + +### Clustering on MNIST8m + +To get the "infinite MNIST dataset", follow the instructions on [Léon Bottou's website](http://leon.bottou.org/projects/infimnist). The script assumes the file `mnist8m-patterns-idx3-ubyte` is in subdirectory `mnist8m` + +The script [`kmeans_mnist.py`](kmeans_mnist.py) produces the following output: + +``` +python kmeans_mnist.py 1 256 +... +Clustering 8100000 points in 784D to 256 clusters, redo 1 times, 20 iterations + Preprocessing in 7.94526 s + Iteration 19 (131.697 s, search 114.78 s): objective=1.44881e+13 imbalance=1.05963 nsplit=0 +final objective: 1.449e+13 +total runtime: 140.615 s +``` + +### search on SIFT1B + +The script [`bench_gpu_1bn.py`](bench_gpu_1bn.py) runs multi-gpu searches on the two 1-billion vector datasets we considered. It is more complex than the previous scripts, because it supports many search options and decomposes the dataset build process in Python to exploit the best possible CPU/GPU parallelism and GPU distribution. + +Even on multiple GPUs, building the 1B datasets can last several hours. It is often a good idea to validate that everything is working fine on smaller datasets like SIFT1M, SIFT2M, etc. + +The search results on SIFT1B in the "GPU paper" can be obtained with + + + +``` +python bench_gpu_1bn.py SIFT1000M OPQ8_32,IVF262144,PQ8 -nnn 10 -ngpu 1 -tempmem $[1536*1024*1024] +... +0/10000 (0.024 s) probe=1 : 0.161 s 1-R@1: 0.0752 1-R@10: 0.1924 +0/10000 (0.005 s) probe=2 : 0.150 s 1-R@1: 0.0964 1-R@10: 0.2693 +0/10000 (0.005 s) probe=4 : 0.153 s 1-R@1: 0.1102 1-R@10: 0.3328 +0/10000 (0.005 s) probe=8 : 0.170 s 1-R@1: 0.1220 1-R@10: 0.3827 +0/10000 (0.005 s) probe=16 : 0.196 s 1-R@1: 0.1290 1-R@10: 0.4151 +0/10000 (0.006 s) probe=32 : 0.244 s 1-R@1: 0.1314 1-R@10: 0.4345 +0/10000 (0.006 s) probe=64 : 0.353 s 1-R@1: 0.1332 1-R@10: 0.4461 +0/10000 (0.005 s) probe=128: 0.587 s 1-R@1: 0.1341 1-R@10: 0.4502 +0/10000 (0.006 s) probe=256: 1.160 s 1-R@1: 0.1342 1-R@10: 0.4511 +``` + +We use the `-tempmem` option to reduce the temporary memory allocation to 1.5G, otherwise the dataset does not fit in GPU memory + +### search on Deep1B + +The same script generates the GPU search results on Deep1B. + +``` +python bench_gpu_1bn.py Deep1B OPQ20_80,IVF262144,PQ20 -nnn 10 -R 2 -ngpu 4 -altadd -noptables -tempmem $[1024*1024*1024] +... + +0/10000 (0.115 s) probe=1 : 0.239 s 1-R@1: 0.2387 1-R@10: 0.3420 +0/10000 (0.006 s) probe=2 : 0.103 s 1-R@1: 0.3110 1-R@10: 0.4623 +0/10000 (0.005 s) probe=4 : 0.105 s 1-R@1: 0.3772 1-R@10: 0.5862 +0/10000 (0.005 s) probe=8 : 0.116 s 1-R@1: 0.4235 1-R@10: 0.6889 +0/10000 (0.005 s) probe=16 : 0.133 s 1-R@1: 0.4517 1-R@10: 0.7693 +0/10000 (0.005 s) probe=32 : 0.168 s 1-R@1: 0.4713 1-R@10: 0.8281 +0/10000 (0.005 s) probe=64 : 0.238 s 1-R@1: 0.4841 1-R@10: 0.8649 +0/10000 (0.007 s) probe=128: 0.384 s 1-R@1: 0.4900 1-R@10: 0.8816 +0/10000 (0.005 s) probe=256: 0.736 s 1-R@1: 0.4933 1-R@10: 0.8912 +``` + +Here we are a bit tight on memory so we disable precomputed tables (`-noptables`) and restrict the amount of temporary memory. The `-altadd` option avoids GPU memory overflows during add. + + +### knn-graph on Deep1B + +The same script generates the KNN-graph on Deep1B. Note that the inverted file from above will not be re-used because the training sets are different. For the knngraph, the script will first do a pass over the whole dataset to compute the ground-truth knn for a subset of 10k nodes, for evaluation. + +``` +python bench_gpu_1bn.py Deep1B OPQ20_80,IVF262144,PQ20 -nnn 10 -altadd -knngraph -R 2 -noptables -tempmem $[1<<30] -ngpu 4 +... +CPU index contains 1000000000 vectors, move to GPU +Copy CPU index to 2 sharded GPU indexes + dispatch to GPUs 0:2 +IndexShards shard 0 indices 0:500000000 + IndexIVFPQ size 500000000 -> GpuIndexIVFPQ indicesOptions=0 usePrecomputed=0 useFloat16=0 reserveVecs=0 +IndexShards shard 1 indices 500000000:1000000000 + IndexIVFPQ size 500000000 -> GpuIndexIVFPQ indicesOptions=0 usePrecomputed=0 useFloat16=0 reserveVecs=0 + dispatch to GPUs 2:4 +IndexShards shard 0 indices 0:500000000 + IndexIVFPQ size 500000000 -> GpuIndexIVFPQ indicesOptions=0 usePrecomputed=0 useFloat16=0 reserveVecs=0 +IndexShards shard 1 indices 500000000:1000000000 + IndexIVFPQ size 500000000 -> GpuIndexIVFPQ indicesOptions=0 usePrecomputed=0 useFloat16=0 reserveVecs=0 +move to GPU done in 151.535 s +search... +999997440/1000000000 (8389.961 s, 0.3379) probe=1 : 8389.990 s rank-10 intersection results: 0.3379 +999997440/1000000000 (9205.934 s, 0.4079) probe=2 : 9205.966 s rank-10 intersection results: 0.4079 +999997440/1000000000 (9741.095 s, 0.4722) probe=4 : 9741.128 s rank-10 intersection results: 0.4722 +999997440/1000000000 (10830.420 s, 0.5256) probe=8 : 10830.455 s rank-10 intersection results: 0.5256 +999997440/1000000000 (12531.716 s, 0.5603) probe=16 : 12531.758 s rank-10 intersection results: 0.5603 +999997440/1000000000 (15922.519 s, 0.5825) probe=32 : 15922.571 s rank-10 intersection results: 0.5825 +999997440/1000000000 (22774.153 s, 0.5950) probe=64 : 22774.220 s rank-10 intersection results: 0.5950 +999997440/1000000000 (36717.207 s, 0.6015) probe=128: 36717.309 s rank-10 intersection results: 0.6015 +999997440/1000000000 (70616.392 s, 0.6047) probe=256: 70616.581 s rank-10 intersection results: 0.6047 +``` diff --git a/core/src/index/thirdparty/faiss/benchs/bench_all_ivf/README.md b/core/src/index/thirdparty/faiss/benchs/bench_all_ivf/README.md new file mode 100644 index 0000000000..2f7c76b5ac --- /dev/null +++ b/core/src/index/thirdparty/faiss/benchs/bench_all_ivf/README.md @@ -0,0 +1,20 @@ +# Benchmark of IVF variants + +This is a benchmark of IVF index variants, looking at compression vs. speed vs. accuracy. +The results are in [this wiki chapter](https://github.com/facebookresearch/faiss/wiki/Indexing-1G-vectors) + + +The code is organized as: + +- `datasets.py`: code to access the datafiles, compute the ground-truth and report accuracies + +- `bench_all_ivf.py`: evaluate one type of inverted file + +- `run_on_cluster_generic.bash`: call `bench_all_ivf.py` for all tested types of indices. +Since the number of experiments is quite large the script is structued so that the benchmark can be run on a cluster. + +- `parse_bench_all_ivf.py`: make nice tradeoff plots from all the results. + +The code depends on Faiss and can use 1 to 8 GPUs to do the k-means clustering for large vocabularies. + +It was run in October 2018 for the results in the wiki. diff --git a/core/src/index/thirdparty/faiss/benchs/bench_all_ivf/bench_all_ivf.py b/core/src/index/thirdparty/faiss/benchs/bench_all_ivf/bench_all_ivf.py new file mode 100644 index 0000000000..ee53018828 --- /dev/null +++ b/core/src/index/thirdparty/faiss/benchs/bench_all_ivf/bench_all_ivf.py @@ -0,0 +1,308 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +#!/usr/bin/env python2 + +import os +import sys +import time +import numpy as np +import faiss +import argparse +import datasets +from datasets import sanitize + +###################################################### +# Command-line parsing +###################################################### + + +parser = argparse.ArgumentParser() + +def aa(*args, **kwargs): + group.add_argument(*args, **kwargs) + + +group = parser.add_argument_group('dataset options') + +aa('--db', default='deep1M', help='dataset') +aa('--compute_gt', default=False, action='store_true', + help='compute and store the groundtruth') + +group = parser.add_argument_group('index consturction') + +aa('--indexkey', default='HNSW32', help='index_factory type') +aa('--efConstruction', default=200, type=int, + help='HNSW construction factor') +aa('--M0', default=-1, type=int, help='size of base level') +aa('--maxtrain', default=256 * 256, type=int, + help='maximum number of training points (0 to set automatically)') +aa('--indexfile', default='', help='file to read or write index from') +aa('--add_bs', default=-1, type=int, + help='add elements index by batches of this size') +aa('--no_precomputed_tables', action='store_true', default=False, + help='disable precomputed tables (uses less memory)') +aa('--clustering_niter', default=-1, type=int, + help='number of clustering iterations (-1 = leave default)') +aa('--train_on_gpu', default=False, action='store_true', + help='do training on GPU') +aa('--get_centroids_from', default='', + help='get the centroids from this index (to speed up training)') + +group = parser.add_argument_group('searching') + +aa('--k', default=100, type=int, help='nb of nearest neighbors') +aa('--searchthreads', default=-1, type=int, + help='nb of threads to use at search time') +aa('--searchparams', nargs='+', default=['autotune'], + help="search parameters to use (can be autotune or a list of params)") +aa('--n_autotune', default=500, type=int, + help="max nb of autotune experiments") +aa('--autotune_max', default=[], nargs='*', + help='set max value for autotune variables format "var:val" (exclusive)') +aa('--autotune_range', default=[], nargs='*', + help='set complete autotune range, format "var:val1,val2,..."') +aa('--min_test_duration', default=0, type=float, + help='run test at least for so long to avoid jitter') + +args = parser.parse_args() + +print("args:", args) + +os.system('echo -n "nb processors "; ' + 'cat /proc/cpuinfo | grep ^processor | wc -l; ' + 'cat /proc/cpuinfo | grep ^"model name" | tail -1') + +###################################################### +# Load dataset +###################################################### + +xt, xb, xq, gt = datasets.load_data( + dataset=args.db, compute_gt=args.compute_gt) + + +print("dataset sizes: train %s base %s query %s GT %s" % ( + xt.shape, xb.shape, xq.shape, gt.shape)) + +nq, d = xq.shape +nb, d = xb.shape + + +###################################################### +# Make index +###################################################### + +if args.indexfile and os.path.exists(args.indexfile): + + print("reading", args.indexfile) + index = faiss.read_index(args.indexfile) + + if isinstance(index, faiss.IndexPreTransform): + index_ivf = faiss.downcast_index(index.index) + else: + index_ivf = index + assert isinstance(index_ivf, faiss.IndexIVF) + vec_transform = lambda x: x + assert isinstance(index_ivf, faiss.IndexIVF) + +else: + + print("build index, key=", args.indexkey) + + index = faiss.index_factory(d, args.indexkey) + + if isinstance(index, faiss.IndexPreTransform): + index_ivf = faiss.downcast_index(index.index) + vec_transform = index.chain.at(0).apply_py + else: + index_ivf = index + vec_transform = lambda x:x + assert isinstance(index_ivf, faiss.IndexIVF) + index_ivf.verbose = True + index_ivf.quantizer.verbose = True + index_ivf.cp.verbose = True + + maxtrain = args.maxtrain + if maxtrain == 0: + if 'IMI' in args.indexkey: + maxtrain = int(256 * 2 ** (np.log2(index_ivf.nlist) / 2)) + else: + maxtrain = 50 * index_ivf.nlist + print("setting maxtrain to %d" % maxtrain) + args.maxtrain = maxtrain + + xt2 = sanitize(xt[:args.maxtrain]) + assert np.all(np.isfinite(xt2)) + + print("train, size", xt2.shape) + + if args.get_centroids_from == '': + + if args.clustering_niter >= 0: + print(("setting nb of clustering iterations to %d" % + args.clustering_niter)) + index_ivf.cp.niter = args.clustering_niter + + if args.train_on_gpu: + print("add a training index on GPU") + train_index = faiss.index_cpu_to_all_gpus(faiss.IndexFlatL2(d)) + index_ivf.clustering_index = train_index + + else: + print("Getting centroids from", args.get_centroids_from) + src_index = faiss.read_index(args.get_centroids_from) + src_quant = faiss.downcast_index(src_index.quantizer) + centroids = faiss.vector_to_array(src_quant.xb) + centroids = centroids.reshape(-1, d) + print(" centroid table shape", centroids.shape) + + if isinstance(index, faiss.IndexPreTransform): + print(" training vector transform") + assert index.chain.size() == 1 + vt = index.chain.at(0) + vt.train(xt2) + print(" transform centroids") + centroids = vt.apply_py(centroids) + + print(" add centroids to quantizer") + index_ivf.quantizer.add(centroids) + del src_index + + t0 = time.time() + index.train(xt2) + print(" train in %.3f s" % (time.time() - t0)) + + print("adding") + t0 = time.time() + if args.add_bs == -1: + index.add(sanitize(xb)) + else: + for i0 in range(0, nb, args.add_bs): + i1 = min(nb, i0 + args.add_bs) + print(" adding %d:%d / %d" % (i0, i1, nb)) + index.add(sanitize(xb[i0:i1])) + + print(" add in %.3f s" % (time.time() - t0)) + if args.indexfile: + print("storing", args.indexfile) + faiss.write_index(index, args.indexfile) + +if args.no_precomputed_tables: + if isinstance(index_ivf, faiss.IndexIVFPQ): + print("disabling precomputed table") + index_ivf.use_precomputed_table = -1 + index_ivf.precomputed_table.clear() + +if args.indexfile: + print("index size on disk: ", os.stat(args.indexfile).st_size) + +print("current RSS:", faiss.get_mem_usage_kb() * 1024) + +precomputed_table_size = 0 +if hasattr(index_ivf, 'precomputed_table'): + precomputed_table_size = index_ivf.precomputed_table.size() * 4 + +print("precomputed tables size:", precomputed_table_size) + + +############################################################# +# Index is ready +############################################################# + +xq = sanitize(xq) + +if args.searchthreads != -1: + print("Setting nb of threads to", args.searchthreads) + faiss.omp_set_num_threads(args.searchthreads) + + +ps = faiss.ParameterSpace() +ps.initialize(index) + + +parametersets = args.searchparams + +header = '%-40s R@1 R@10 R@100 time(ms/q) nb distances #runs' % "parameters" + + +def eval_setting(index, xq, gt, min_time): + nq = xq.shape[0] + ivf_stats = faiss.cvar.indexIVF_stats + ivf_stats.reset() + nrun = 0 + t0 = time.time() + while True: + D, I = index.search(xq, 100) + nrun += 1 + t1 = time.time() + if t1 - t0 > min_time: + break + ms_per_query = ((t1 - t0) * 1000.0 / nq / nrun) + for rank in 1, 10, 100: + n_ok = (I[:, :rank] == gt[:, :1]).sum() + print("%.4f" % (n_ok / float(nq)), end=' ') + print(" %8.3f " % ms_per_query, end=' ') + print("%12d " % (ivf_stats.ndis / nrun), end=' ') + print(nrun) + + +if parametersets == ['autotune']: + + ps.n_experiments = args.n_autotune + ps.min_test_duration = args.min_test_duration + + for kv in args.autotune_max: + k, vmax = kv.split(':') + vmax = float(vmax) + print("limiting %s to %g" % (k, vmax)) + pr = ps.add_range(k) + values = faiss.vector_to_array(pr.values) + values = np.array([v for v in values if v < vmax]) + faiss.copy_array_to_vector(values, pr.values) + + for kv in args.autotune_range: + k, vals = kv.split(':') + vals = np.fromstring(vals, sep=',') + print("setting %s to %s" % (k, vals)) + pr = ps.add_range(k) + faiss.copy_array_to_vector(vals, pr.values) + + # setup the Criterion object: optimize for 1-R@1 + crit = faiss.OneRecallAtRCriterion(nq, 1) + + # by default, the criterion will request only 1 NN + crit.nnn = 100 + crit.set_groundtruth(None, gt.astype('int64')) + + # then we let Faiss find the optimal parameters by itself + print("exploring operating points") + ps.display() + + t0 = time.time() + op = ps.explore(index, xq, crit) + print("Done in %.3f s, available OPs:" % (time.time() - t0)) + + op.display() + + print(header) + opv = op.optimal_pts + for i in range(opv.size()): + opt = opv.at(i) + + ps.set_index_parameters(index, opt.key) + + print("%-40s " % opt.key, end=' ') + sys.stdout.flush() + + eval_setting(index, xq, gt, args.min_test_duration) + +else: + print(header) + for param in parametersets: + print("%-40s " % param, end=' ') + sys.stdout.flush() + ps.set_index_parameters(index, param) + + eval_setting(index, xq, gt, args.min_test_duration) diff --git a/core/src/index/thirdparty/faiss/benchs/bench_all_ivf/bench_kmeans.py b/core/src/index/thirdparty/faiss/benchs/bench_all_ivf/bench_kmeans.py new file mode 100644 index 0000000000..90cb4e83d9 --- /dev/null +++ b/core/src/index/thirdparty/faiss/benchs/bench_all_ivf/bench_kmeans.py @@ -0,0 +1,118 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +#!/usr/bin/env python2 + +from __future__ import print_function +import os +import numpy as np +import faiss +import argparse +import datasets +from datasets import sanitize + +###################################################### +# Command-line parsing +###################################################### + +parser = argparse.ArgumentParser() + + +def aa(*args, **kwargs): + group.add_argument(*args, **kwargs) + + +group = parser.add_argument_group('dataset options') + +aa('--db', default='deep1M', help='dataset') +aa('--nt', default=65536, type=int) +aa('--nb', default=100000, type=int) +aa('--nt_sample', default=0, type=int) + +group = parser.add_argument_group('kmeans options') +aa('--k', default=256, type=int) +aa('--seed', default=12345, type=int) +aa('--pcadim', default=-1, type=int, help='PCA to this dimension') +aa('--niter', default=25, type=int) +aa('--eval_freq', default=100, type=int) + + +args = parser.parse_args() + +print("args:", args) + +os.system('echo -n "nb processors "; ' + 'cat /proc/cpuinfo | grep ^processor | wc -l; ' + 'cat /proc/cpuinfo | grep ^"model name" | tail -1') + +ngpu = faiss.get_num_gpus() +print("nb GPUs:", ngpu) + +###################################################### +# Load dataset +###################################################### + +xt, xb, xq, gt = datasets.load_data(dataset=args.db) + + +if args.nt_sample == 0: + xt_pca = xt[args.nt:args.nt + 10000] + xt = xt[:args.nt] +else: + xt_pca = xt[args.nt_sample:args.nt_sample + 10000] + rs = np.random.RandomState(args.seed) + idx = rs.choice(args.nt_sample, size=args.nt, replace=False) + xt = xt[idx] + +xb = xb[:args.nb] + +d = xb.shape[1] + +if args.pcadim != -1: + print("training PCA: %d -> %d" % (d, args.pcadim)) + pca = faiss.PCAMatrix(d, args.pcadim) + pca.train(sanitize(xt_pca)) + xt = pca.apply_py(sanitize(xt)) + xb = pca.apply_py(sanitize(xb)) + d = xb.shape[1] + + +###################################################### +# Run clustering +###################################################### + + +index = faiss.IndexFlatL2(d) + +if ngpu > 0: + print("moving index to GPU") + index = faiss.index_cpu_to_all_gpus(index) + + +clustering = faiss.Clustering(d, args.k) + +clustering.verbose = True +clustering.seed = args.seed +clustering.max_points_per_centroid = 10**6 +clustering.min_points_per_centroid = 1 + + +for iter0 in range(0, args.niter, args.eval_freq): + iter1 = min(args.niter, iter0 + args.eval_freq) + clustering.niter = iter1 - iter0 + + if iter0 > 0: + faiss.copy_array_to_vector(centroids.ravel(), clustering.centroids) + + clustering.train(sanitize(xt), index) + index.reset() + centroids = faiss.vector_to_array(clustering.centroids).reshape(args.k, d) + index.add(centroids) + + _, I = index.search(sanitize(xb), 1) + + error = ((xb - centroids[I.ravel()]) ** 2).sum() + + print("iter1=%d quantization error on test: %.4f" % (iter1, error)) diff --git a/core/src/index/thirdparty/faiss/benchs/bench_all_ivf/datasets.py b/core/src/index/thirdparty/faiss/benchs/bench_all_ivf/datasets.py new file mode 100644 index 0000000000..9f90643217 --- /dev/null +++ b/core/src/index/thirdparty/faiss/benchs/bench_all_ivf/datasets.py @@ -0,0 +1,235 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +#! /usr/bin/env python2 + +""" +Common functions to load datasets and compute their ground-truth +""" + +from __future__ import print_function +import time +import numpy as np +import faiss +import sys + +# set this to the directory that contains the datafiles. +# deep1b data should be at simdir + 'deep1b' +# bigann data should be at simdir + 'bigann' +simdir = '/mnt/vol/gfsai-east/ai-group/datasets/simsearch/' + +################################################################# +# Small I/O functions +################################################################# + + +def ivecs_read(fname): + a = np.fromfile(fname, dtype='int32') + d = a[0] + return a.reshape(-1, d + 1)[:, 1:].copy() + + +def fvecs_read(fname): + return ivecs_read(fname).view('float32') + + +def ivecs_mmap(fname): + a = np.memmap(fname, dtype='int32', mode='r') + d = a[0] + return a.reshape(-1, d + 1)[:, 1:] + + +def fvecs_mmap(fname): + return ivecs_mmap(fname).view('float32') + + +def bvecs_mmap(fname): + x = np.memmap(fname, dtype='uint8', mode='r') + d = x[:4].view('int32')[0] + return x.reshape(-1, d + 4)[:, 4:] + + +def ivecs_write(fname, m): + n, d = m.shape + m1 = np.empty((n, d + 1), dtype='int32') + m1[:, 0] = d + m1[:, 1:] = m + m1.tofile(fname) + + +def fvecs_write(fname, m): + m = m.astype('float32') + ivecs_write(fname, m.view('int32')) + + + +################################################################# +# Dataset +################################################################# + +def sanitize(x): + return np.ascontiguousarray(x, dtype='float32') + + +class ResultHeap: + """ Combine query results from a sliced dataset """ + + def __init__(self, nq, k): + " nq: number of query vectors, k: number of results per query " + self.I = np.zeros((nq, k), dtype='int64') + self.D = np.zeros((nq, k), dtype='float32') + self.nq, self.k = nq, k + heaps = faiss.float_maxheap_array_t() + heaps.k = k + heaps.nh = nq + heaps.val = faiss.swig_ptr(self.D) + heaps.ids = faiss.swig_ptr(self.I) + heaps.heapify() + self.heaps = heaps + + def add_batch_result(self, D, I, i0): + assert D.shape == (self.nq, self.k) + assert I.shape == (self.nq, self.k) + I += i0 + self.heaps.addn_with_ids( + self.k, faiss.swig_ptr(D), + faiss.swig_ptr(I), self.k) + + def finalize(self): + self.heaps.reorder() + + +def compute_GT_sliced(xb, xq, k): + print("compute GT") + t0 = time.time() + nb, d = xb.shape + nq, d = xq.shape + rh = ResultHeap(nq, k) + bs = 10 ** 5 + + xqs = sanitize(xq) + + db_gt = faiss.index_cpu_to_all_gpus(faiss.IndexFlatL2(d)) + + # compute ground-truth by blocks of bs, and add to heaps + for i0 in range(0, nb, bs): + i1 = min(nb, i0 + bs) + xsl = sanitize(xb[i0:i1]) + db_gt.add(xsl) + D, I = db_gt.search(xqs, k) + rh.add_batch_result(D, I, i0) + db_gt.reset() + print("\r %d/%d, %.3f s" % (i0, nb, time.time() - t0), end=' ') + sys.stdout.flush() + print() + rh.finalize() + gt_I = rh.I + + print("GT time: %.3f s" % (time.time() - t0)) + return gt_I + + +def do_compute_gt(xb, xq, k): + print("computing GT") + nb, d = xb.shape + index = faiss.index_cpu_to_all_gpus(faiss.IndexFlatL2(d)) + if nb < 100 * 1000: + print(" add") + index.add(np.ascontiguousarray(xb, dtype='float32')) + print(" search") + D, I = index.search(np.ascontiguousarray(xq, dtype='float32'), k) + else: + I = compute_GT_sliced(xb, xq, k) + + return I.astype('int32') + + +def load_data(dataset='deep1M', compute_gt=False): + + print("load data", dataset) + + if dataset == 'sift1M': + basedir = simdir + 'sift1M/' + + xt = fvecs_read(basedir + "sift_learn.fvecs") + xb = fvecs_read(basedir + "sift_base.fvecs") + xq = fvecs_read(basedir + "sift_query.fvecs") + gt = ivecs_read(basedir + "sift_groundtruth.ivecs") + + elif dataset.startswith('bigann'): + basedir = simdir + 'bigann/' + + dbsize = 1000 if dataset == "bigann1B" else int(dataset[6:-1]) + xb = bvecs_mmap(basedir + 'bigann_base.bvecs') + xq = bvecs_mmap(basedir + 'bigann_query.bvecs') + xt = bvecs_mmap(basedir + 'bigann_learn.bvecs') + # trim xb to correct size + xb = xb[:dbsize * 1000 * 1000] + gt = ivecs_read(basedir + 'gnd/idx_%dM.ivecs' % dbsize) + + elif dataset.startswith("deep"): + basedir = simdir + 'deep1b/' + szsuf = dataset[4:] + if szsuf[-1] == 'M': + dbsize = 10 ** 6 * int(szsuf[:-1]) + elif szsuf == '1B': + dbsize = 10 ** 9 + elif szsuf[-1] == 'k': + dbsize = 1000 * int(szsuf[:-1]) + else: + assert False, "did not recognize suffix " + szsuf + + xt = fvecs_mmap(basedir + "learn.fvecs") + xb = fvecs_mmap(basedir + "base.fvecs") + xq = fvecs_read(basedir + "deep1B_queries.fvecs") + + xb = xb[:dbsize] + + gt_fname = basedir + "%s_groundtruth.ivecs" % dataset + if compute_gt: + gt = do_compute_gt(xb, xq, 100) + print("store", gt_fname) + ivecs_write(gt_fname, gt) + + gt = ivecs_read(gt_fname) + + else: + assert False + + print("dataset %s sizes: B %s Q %s T %s" % ( + dataset, xb.shape, xq.shape, xt.shape)) + + return xt, xb, xq, gt + +################################################################# +# Evaluation +################################################################# + + +def evaluate_DI(D, I, gt): + nq = gt.shape[0] + k = I.shape[1] + rank = 1 + while rank <= k: + recall = (I[:, :rank] == gt[:, :1]).sum() / float(nq) + print("R@%d: %.4f" % (rank, recall), end=' ') + rank *= 10 + + +def evaluate(xq, gt, index, k=100, endl=True): + t0 = time.time() + D, I = index.search(xq, k) + t1 = time.time() + nq = xq.shape[0] + print("\t %8.4f ms per query, " % ( + (t1 - t0) * 1000.0 / nq), end=' ') + rank = 1 + while rank <= k: + recall = (I[:, :rank] == gt[:, :1]).sum() / float(nq) + print("R@%d: %.4f" % (rank, recall), end=' ') + rank *= 10 + if endl: + print() + return D, I diff --git a/core/src/index/thirdparty/faiss/benchs/bench_all_ivf/parse_bench_all_ivf.py b/core/src/index/thirdparty/faiss/benchs/bench_all_ivf/parse_bench_all_ivf.py new file mode 100644 index 0000000000..1a4d260ea5 --- /dev/null +++ b/core/src/index/thirdparty/faiss/benchs/bench_all_ivf/parse_bench_all_ivf.py @@ -0,0 +1,268 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +#! /usr/bin/env python2 + +import os +import numpy as np +from matplotlib import pyplot + +import re + +from argparse import Namespace + + +# the directory used in run_on_cluster.bash +basedir = '/mnt/vol/gfsai-east/ai-group/users/matthijs/bench_all_ivf/' +logdir = basedir + 'logs/' + + +# which plot to output +db = 'bigann1B' +code_size = 8 + + + +def unitsize(indexkey): + """ size of one vector in the index """ + mo = re.match('.*,PQ(\\d+)', indexkey) + if mo: + return int(mo.group(1)) + if indexkey.endswith('SQ8'): + bits_per_d = 8 + elif indexkey.endswith('SQ4'): + bits_per_d = 4 + elif indexkey.endswith('SQfp16'): + bits_per_d = 16 + else: + assert False + mo = re.match('PCAR(\\d+),.*', indexkey) + if mo: + return bits_per_d * int(mo.group(1)) / 8 + mo = re.match('OPQ\\d+_(\\d+),.*', indexkey) + if mo: + return bits_per_d * int(mo.group(1)) / 8 + mo = re.match('RR(\\d+),.*', indexkey) + if mo: + return bits_per_d * int(mo.group(1)) / 8 + assert False + + +def dbsize_from_name(dbname): + sufs = { + '1B': 10**9, + '100M': 10**8, + '10M': 10**7, + '1M': 10**6, + } + for s in sufs: + if dbname.endswith(s): + return sufs[s] + else: + assert False + + +def keep_latest_stdout(fnames): + fnames = [fname for fname in fnames if fname.endswith('.stdout')] + fnames.sort() + n = len(fnames) + fnames2 = [] + for i, fname in enumerate(fnames): + if i + 1 < n and fnames[i + 1][:-8] == fname[:-8]: + continue + fnames2.append(fname) + return fnames2 + + +def parse_result_file(fname): + # print fname + st = 0 + res = [] + keys = [] + stats = {} + stats['run_version'] = fname[-8] + for l in open(fname): + if st == 0: + if l.startswith('CHRONOS_JOB_INSTANCE_ID'): + stats['CHRONOS_JOB_INSTANCE_ID'] = l.split()[-1] + if l.startswith('index size on disk:'): + stats['index_size'] = int(l.split()[-1]) + if l.startswith('current RSS:'): + stats['RSS'] = int(l.split()[-1]) + if l.startswith('precomputed tables size:'): + stats['tables_size'] = int(l.split()[-1]) + if l.startswith('Setting nb of threads to'): + stats['n_threads'] = int(l.split()[-1]) + if l.startswith(' add in'): + stats['add_time'] = float(l.split()[-2]) + if l.startswith('args:'): + args = eval(l[l.find(' '):]) + indexkey = args.indexkey + elif 'R@1 R@10 R@100' in l: + st = 1 + elif 'index size on disk:' in l: + index_size = int(l.split()[-1]) + elif st == 1: + st = 2 + elif st == 2: + fi = l.split() + keys.append(fi[0]) + res.append([float(x) for x in fi[1:]]) + return indexkey, np.array(res), keys, stats + +# run parsing +allres = {} +allstats = {} +nts = [] +missing = [] +versions = {} + +fnames = keep_latest_stdout(os.listdir(logdir)) +# print fnames +# filenames are in the form .x.stdout +# where x is a version number (from a to z) +# keep only latest version of each name + +for fname in fnames: + if not ('db' + db in fname and fname.endswith('.stdout')): + continue + indexkey, res, _, stats = parse_result_file(logdir + fname) + if res.size == 0: + missing.append(fname) + errorline = open( + logdir + fname.replace('.stdout', '.stderr')).readlines() + if len(errorline) > 0: + errorline = errorline[-1] + else: + errorline = 'NO STDERR' + print fname, stats['CHRONOS_JOB_INSTANCE_ID'], errorline + + else: + if indexkey in allres: + if allstats[indexkey]['run_version'] > stats['run_version']: + # don't use this run + continue + n_threads = stats.get('n_threads', 1) + nts.append(n_threads) + allres[indexkey] = res + allstats[indexkey] = stats + +assert len(set(nts)) == 1 +n_threads = nts[0] + + +def plot_tradeoffs(allres, code_size, recall_rank): + dbsize = dbsize_from_name(db) + recall_idx = int(np.log10(recall_rank)) + + bigtab = [] + names = [] + + for k,v in sorted(allres.items()): + if v.ndim != 2: continue + us = unitsize(k) + if us != code_size: continue + perf = v[:, recall_idx] + times = v[:, 3] + bigtab.append( + np.vstack(( + np.ones(times.size, dtype=int) * len(names), + perf, times + )) + ) + names.append(k) + + bigtab = np.hstack(bigtab) + + perm = np.argsort(bigtab[1, :]) + bigtab = bigtab[:, perm] + + times = np.minimum.accumulate(bigtab[2, ::-1])[::-1] + selection = np.where(bigtab[2, :] == times) + + selected_methods = [names[i] for i in + np.unique(bigtab[0, selection].astype(int))] + not_selected = list(set(names) - set(selected_methods)) + + print "methods without an optimal OP: ", not_selected + + nq = 10000 + pyplot.title('database ' + db + ' code_size=%d' % code_size) + + # grayed out lines + + for k in not_selected: + v = allres[k] + if v.ndim != 2: continue + us = unitsize(k) + if us != code_size: continue + + linestyle = (':' if 'PQ' in k else + '-.' if 'SQ4' in k else + '--' if 'SQ8' in k else '-') + + pyplot.semilogy(v[:, recall_idx], v[:, 3], label=None, + linestyle=linestyle, + marker='o' if 'HNSW' in k else '+', + color='#cccccc', linewidth=0.2) + + # important methods + for k in selected_methods: + v = allres[k] + if v.ndim != 2: continue + us = unitsize(k) + if us != code_size: continue + + stats = allstats[k] + tot_size = stats['index_size'] + stats['tables_size'] + id_size = 8 # 64 bit + + addt = '' + if 'add_time' in stats: + add_time = stats['add_time'] + if add_time > 7200: + add_min = add_time / 60 + addt = ', %dh%02d' % (add_min / 60, add_min % 60) + else: + add_sec = int(add_time) + addt = ', %dm%02d' % (add_sec / 60, add_sec % 60) + + + label = k + ' (size+%.1f%%%s)' % ( + tot_size / float((code_size + id_size) * dbsize) * 100 - 100, + addt) + + linestyle = (':' if 'PQ' in k else + '-.' if 'SQ4' in k else + '--' if 'SQ8' in k else '-') + + pyplot.semilogy(v[:, recall_idx], v[:, 3], label=label, + linestyle=linestyle, + marker='o' if 'HNSW' in k else '+') + + if len(not_selected) == 0: + om = '' + else: + om = '\nomitted:' + nc = len(om) + for m in not_selected: + if nc > 80: + om += '\n' + nc = 0 + om += ' ' + m + nc += len(m) + 1 + + pyplot.xlabel('1-recall at %d %s' % (recall_rank, om) ) + pyplot.ylabel('search time per query (ms, %d threads)' % n_threads) + pyplot.legend() + pyplot.grid() + pyplot.savefig('figs/tradeoffs_%s_cs%d_r%d.png' % ( + db, code_size, recall_rank)) + return selected_methods, not_selected + + +pyplot.gcf().set_size_inches(15, 10) + +plot_tradeoffs(allres, code_size=code_size, recall_rank=1) diff --git a/core/src/index/thirdparty/faiss/benchs/bench_all_ivf/run_on_cluster_generic.bash b/core/src/index/thirdparty/faiss/benchs/bench_all_ivf/run_on_cluster_generic.bash new file mode 100644 index 0000000000..6d88f43d9a --- /dev/null +++ b/core/src/index/thirdparty/faiss/benchs/bench_all_ivf/run_on_cluster_generic.bash @@ -0,0 +1,249 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# @nolint + +# This script launches the experiments on a cluster +# It assumes two shell functions are defined: +# +# run_on_1machine: runs a command on one (full) machine on a cluster +# +# run_on_8gpu: runs a command on one machine with 8 GPUs +# +# the two functions are called as: +# +# run_on_1machine +# +# the stdout of the command should be stored in $logdir/.stdout + +function run_on_1machine () { + # To be implemented +} + +function run_on_8gpu () { + # To be implemented +} + + +# prepare output directories +# set to some directory where all indexes, can be written. +basedir=XXXXX + +logdir=$basedir/logs +indexdir=$basedir/indexes + +mkdir -p $lars $logdir $indexdir + + +############################### 1M experiments + +for db in sift1M deep1M bigann1M; do + + for coarse in IMI2x9 IMI2x10 IVF1024_HNSW32 IVF4096_HNSW32 IVF16384_HNSW32 + do + + for indexkey in \ + OPQ8_64,$coarse,PQ8 \ + PCAR16,$coarse,SQ4 \ + OPQ16_64,$coarse,PQ16 \ + PCAR32,$coarse,SQ4 \ + PCAR16,$coarse,SQ8 \ + OPQ32_128,$coarse,PQ32 \ + PCAR64,$coarse,SQ4 \ + PCAR32,$coarse,SQ8 \ + PCAR16,$coarse,SQfp16 \ + PCAR64,$coarse,SQ8 \ + PCAR32,$coarse,SQfp16 \ + PCAR128,$coarse,SQ4 + do + key=autotune.db$db.${indexkey//,/_} + run_on_1machine $key \ + python -u bench_all_ivf.py \ + --db $db \ + --indexkey $indexkey \ + --maxtrain 0 \ + --indexfile $indexdir/$key.faissindex + + done + done +done + + + +############################### 10M experiments + + +for db in deep10M bigann10M; do + + for coarse in \ + IMI2x10 IMI2x11 IMI2x12 IMI2x13 IVF4096_HNSW32 \ + IVF16384_HNSW32 IVF65536_HNSW32 IVF262144_HNSW32 + do + + for indexkey in \ + OPQ8_64,$coarse,PQ8 \ + PCAR16,$coarse,SQ4 \ + OPQ16_64,$coarse,PQ16 \ + PCAR32,$coarse,SQ4 \ + PCAR16,$coarse,SQ8 \ + OPQ32_128,$coarse,PQ32 \ + PCAR64,$coarse,SQ4 \ + PCAR32,$coarse,SQ8 \ + PCAR16,$coarse,SQfp16 \ + PCAR64,$coarse,SQ8 \ + PCAR32,$coarse,SQfp16 \ + PCAR128,$coarse,SQ4 \ + OPQ64_128,$coarse,PQ64 + do + key=autotune.db$db.${indexkey//,/_} + run_on_1machine $key \ + python -u bench_all_ivf.py \ + --db $db \ + --indexkey $indexkey \ + --maxtrain 0 \ + --indexfile $indexdir/$key.faissindex \ + --searchthreads 16 \ + --min_test_duration 3 \ + + done + done +done + + +############################### 100M experiments + +for db in deep100M bigann100M; do + + for coarse in IMI2x11 IMI2x12 IVF65536_HNSW32 IVF262144_HNSW32 + do + + for indexkey in \ + OPQ8_64,$coarse,PQ8 \ + OPQ16_64,$coarse,PQ16 \ + PCAR32,$coarse,SQ4 \ + OPQ32_128,$coarse,PQ32 \ + PCAR64,$coarse,SQ4 \ + PCAR32,$coarse,SQ8 \ + PCAR64,$coarse,SQ8 \ + PCAR32,$coarse,SQfp16 \ + PCAR128,$coarse,SQ4 \ + OPQ64_128,$coarse,PQ64 + do + key=autotune.db$db.${indexkey//,/_} + run_on_1machine $key \ + python -u bench_all_ivf.py \ + --db $db \ + --indexkey $indexkey \ + --maxtrain 0 \ + --indexfile $indexdir/$key.faissindex \ + --searchthreads 16 \ + --min_test_duration 3 \ + --add_bs 1000000 + + done + done +done + + +############################### 1B experiments + +for db in deep1B bigann1B; do + + for coarse in IMI2x12 IMI2x13 IVF262144_HNSW32 + do + + for indexkey in \ + OPQ8_64,$coarse,PQ8 \ + OPQ16_64,$coarse,PQ16 \ + PCAR32,$coarse,SQ4 \ + OPQ32_128,$coarse,PQ32 \ + PCAR64,$coarse,SQ4 \ + PCAR32,$coarse,SQ8 \ + PCAR64,$coarse,SQ8 \ + PCAR32,$coarse,SQfp16 \ + PCAR128,$coarse,SQ4 \ + PQ64_128,$coarse,PQ64 \ + RR128,$coarse,SQ4 + do + key=autotune.db$db.${indexkey//,/_} + run_on_1machine $key \ + python -u bench_all_ivf.py \ + --db $db \ + --indexkey $indexkey \ + --maxtrain 0 \ + --indexfile $indexdir/$key.faissindex \ + --searchthreads 16 \ + --min_test_duration 3 \ + --add_bs 1000000 + + done + done + +done + +############################################ +# precompute centroids on GPU for large vocabularies + + +for db in deep1M bigann1M; do + + for ncent in 1048576 4194304; do + + key=clustering.db$db.IVF$ncent + run_on_8gpu $key \ + python -u bench_all_ivf.py \ + --db $db \ + --indexkey IVF$ncent,SQ8 \ + --maxtrain 100000000 \ + --indexfile $indexdir/$key.faissindex \ + --searchthreads 16 \ + --min_test_duration 3 \ + --add_bs 1000000 \ + --train_on_gpu + + done +done + + +################################# +# Run actual experiment + +for db in deep1B bigann1B; do + + for ncent in 1048576 4194304; do + coarse=IVF${ncent}_HNSW32 + centroidsname=clustering.db${db/1B/1M}.IVF${ncent}.faissindex + + for indexkey in \ + OPQ8_64,$coarse,PQ8 \ + OPQ16_64,$coarse,PQ16 \ + PCAR32,$coarse,SQ4 \ + OPQ32_128,$coarse,PQ32 \ + PCAR64,$coarse,SQ4 \ + PCAR32,$coarse,SQ8 \ + PCAR64,$coarse,SQ8 \ + PCAR32,$coarse,SQfp16 \ + OPQ64_128,$coarse,PQ64 \ + RR128,$coarse,SQ4 \ + OPQ64_128,$coarse,PQ64 \ + RR128,$coarse,SQ4 + do + key=autotune.db$db.${indexkey//,/_} + + run_on_1machine $key.c $key \ + python -u bench_all_ivf.py \ + --db $db \ + --indexkey $indexkey \ + --maxtrain 256000 \ + --indexfile $indexdir/$key.faissindex \ + --get_centroids_from $indexdir/$centroidsname \ + --searchthreads 16 \ + --min_test_duration 3 \ + --add_bs 1000000 + + done + done + +done diff --git a/core/src/index/thirdparty/faiss/benchs/bench_for_interrupt.py b/core/src/index/thirdparty/faiss/benchs/bench_for_interrupt.py new file mode 100644 index 0000000000..b72d825ef9 --- /dev/null +++ b/core/src/index/thirdparty/faiss/benchs/bench_for_interrupt.py @@ -0,0 +1,155 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +#! /usr/bin/env python3 + +from __future__ import print_function +import numpy as np +import faiss +import time +import os +import argparse + + +parser = argparse.ArgumentParser() + +def aa(*args, **kwargs): + group.add_argument(*args, **kwargs) + +group = parser.add_argument_group('dataset options') +aa('--dim', type=int, default=64) +aa('--nb', type=int, default=int(1e6)) +aa('--subset_len', type=int, default=int(1e5)) +aa('--key', default='IVF1000,Flat') +aa('--nprobe', type=int, default=640) +aa('--no_intcallback', default=False, action='store_true') +aa('--twostage', default=False, action='store_true') +aa('--nt', type=int, default=-1) + + +args = parser.parse_args() +print("args:", args) + + +d = args.dim # dimension +nb = args.nb # database size +nq = 1000 # nb of queries +nt = 100000 +subset_len = args.subset_len + + +np.random.seed(1234) # make reproducible +xb = np.random.random((nb, d)).astype('float32') +xq = np.random.random((nq, d)).astype('float32') +xt = np.random.random((nt, d)).astype('float32') +k = 100 + +if args.no_intcallback: + faiss.InterruptCallback.clear_instance() + +if args.nt != -1: + faiss.omp_set_num_threads(args.nt) + +nprobe = args.nprobe +key = args.key +#key = 'IVF1000,Flat' +# key = 'IVF1000,PQ64' +# key = 'IVF100_HNSW32,PQ64' + +# faiss.omp_set_num_threads(1) + +pf = 'dim%d_' % d +if d == 64: + pf = '' + +basename = '/tmp/base%s%s.index' % (pf, key) + +if os.path.exists(basename): + print('load', basename) + index_1 = faiss.read_index(basename) +else: + print('train + write', basename) + index_1 = faiss.index_factory(d, key) + index_1.train(xt) + faiss.write_index(index_1, basename) + +print('add') +index_1.add(xb) + +print('set nprobe=', nprobe) +faiss.ParameterSpace().set_index_parameter(index_1, 'nprobe', nprobe) + +class ResultHeap: + """ Combine query results from a sliced dataset """ + + def __init__(self, nq, k): + " nq: number of query vectors, k: number of results per query " + self.I = np.zeros((nq, k), dtype='int64') + self.D = np.zeros((nq, k), dtype='float32') + self.nq, self.k = nq, k + heaps = faiss.float_maxheap_array_t() + heaps.k = k + heaps.nh = nq + heaps.val = faiss.swig_ptr(self.D) + heaps.ids = faiss.swig_ptr(self.I) + heaps.heapify() + self.heaps = heaps + + def add_batch_result(self, D, I, i0): + assert D.shape == (self.nq, self.k) + assert I.shape == (self.nq, self.k) + I += i0 + self.heaps.addn_with_ids( + self.k, faiss.swig_ptr(D), + faiss.swig_ptr(I), self.k) + + def finalize(self): + self.heaps.reorder() + +stats = faiss.cvar.indexIVF_stats +stats.reset() + +print('index size', index_1.ntotal, + 'imbalance', index_1.invlists.imbalance_factor()) +start = time.time() +Dref, Iref = index_1.search(xq, k) +print('time of searching: %.3f s = %.3f + %.3f ms' % ( + time.time() - start, stats.quantization_time, stats.search_time)) + +indexes = {} +if args.twostage: + + for i in range(0, nb, subset_len): + index = faiss.read_index(basename) + faiss.ParameterSpace().set_index_parameter(index, 'nprobe', nprobe) + print("add %d:%d" %(i, i+subset_len)) + index.add(xb[i:i + subset_len]) + indexes[i] = index + +rh = ResultHeap(nq, k) +sum_time = tq = ts = 0 +for i in range(0, nb, subset_len): + if not args.twostage: + index = faiss.read_index(basename) + faiss.ParameterSpace().set_index_parameter(index, 'nprobe', nprobe) + print("add %d:%d" %(i, i+subset_len)) + index.add(xb[i:i + subset_len]) + else: + index = indexes[i] + + stats.reset() + start = time.time() + Di, Ii = index.search(xq, k) + sum_time = sum_time + time.time() - start + tq += stats.quantization_time + ts += stats.search_time + rh.add_batch_result(Di, Ii, i) + +print('time of searching separately: %.3f s = %.3f + %.3f ms' % + (sum_time, tq, ts)) + +rh.finalize() + +print('diffs: %d / %d' % ((Iref != rh.I).sum(), Iref.size)) diff --git a/core/src/index/thirdparty/faiss/benchs/bench_gpu_1bn.py b/core/src/index/thirdparty/faiss/benchs/bench_gpu_1bn.py new file mode 100644 index 0000000000..c676f7c793 --- /dev/null +++ b/core/src/index/thirdparty/faiss/benchs/bench_gpu_1bn.py @@ -0,0 +1,747 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +#! /usr/bin/env python2 + +from __future__ import print_function +import numpy as np +import time +import os +import sys +import faiss +import re + +from multiprocessing.dummy import Pool as ThreadPool +from datasets import ivecs_read + +#################################################################### +# Parse command line +#################################################################### + + +def usage(): + print(""" + +Usage: bench_gpu_1bn.py dataset indextype [options] + +dataset: set of vectors to operate on. + Supported: SIFT1M, SIFT2M, ..., SIFT1000M or Deep1B + +indextype: any index type supported by index_factory that runs on GPU. + + General options + +-ngpu ngpu nb of GPUs to use (default = all) +-tempmem N use N bytes of temporary GPU memory +-nocache do not read or write intermediate files +-float16 use 16-bit floats on the GPU side + + Add options + +-abs N split adds in blocks of no more than N vectors +-max_add N copy sharded dataset to CPU each max_add additions + (to avoid memory overflows with geometric reallocations) +-altadd Alternative add function, where the index is not stored + on GPU during add. Slightly faster for big datasets on + slow GPUs + + Search options + +-R R: nb of replicas of the same dataset (the dataset + will be copied across ngpu/R, default R=1) +-noptables do not use precomputed tables in IVFPQ. +-qbs N split queries in blocks of no more than N vectors +-nnn N search N neighbors for each query +-nprobe 4,16,64 try this number of probes +-knngraph instead of the standard setup for the dataset, + compute a k-nn graph with nnn neighbors per element +-oI xx%d.npy output the search result indices to this numpy file, + %d will be replaced with the nprobe +-oD xx%d.npy output the search result distances to this file + +""", file=sys.stderr) + sys.exit(1) + + +# default values + +dbname = None +index_key = None + +ngpu = faiss.get_num_gpus() + +replicas = 1 # nb of replicas of sharded dataset +add_batch_size = 32768 +query_batch_size = 16384 +nprobes = [1 << l for l in range(9)] +knngraph = False +use_precomputed_tables = True +tempmem = -1 # if -1, use system default +max_add = -1 +use_float16 = False +use_cache = True +nnn = 10 +altadd = False +I_fname = None +D_fname = None + +args = sys.argv[1:] + +while args: + a = args.pop(0) + if a == '-h': usage() + elif a == '-ngpu': ngpu = int(args.pop(0)) + elif a == '-R': replicas = int(args.pop(0)) + elif a == '-noptables': use_precomputed_tables = False + elif a == '-abs': add_batch_size = int(args.pop(0)) + elif a == '-qbs': query_batch_size = int(args.pop(0)) + elif a == '-nnn': nnn = int(args.pop(0)) + elif a == '-tempmem': tempmem = int(args.pop(0)) + elif a == '-nocache': use_cache = False + elif a == '-knngraph': knngraph = True + elif a == '-altadd': altadd = True + elif a == '-float16': use_float16 = True + elif a == '-nprobe': nprobes = [int(x) for x in args.pop(0).split(',')] + elif a == '-max_add': max_add = int(args.pop(0)) + elif not dbname: dbname = a + elif not index_key: index_key = a + else: + print("argument %s unknown" % a, file=sys.stderr) + sys.exit(1) + +cacheroot = '/tmp/bench_gpu_1bn' + +if not os.path.isdir(cacheroot): + print("%s does not exist, creating it" % cacheroot) + os.mkdir(cacheroot) + +################################################################# +# Small Utility Functions +################################################################# + +# we mem-map the biggest files to avoid having them in memory all at +# once + +def mmap_fvecs(fname): + x = np.memmap(fname, dtype='int32', mode='r') + d = x[0] + return x.view('float32').reshape(-1, d + 1)[:, 1:] + +def mmap_bvecs(fname): + x = np.memmap(fname, dtype='uint8', mode='r') + d = x[:4].view('int32')[0] + return x.reshape(-1, d + 4)[:, 4:] + + +def rate_limited_imap(f, l): + """A threaded imap that does not produce elements faster than they + are consumed""" + pool = ThreadPool(1) + res = None + for i in l: + res_next = pool.apply_async(f, (i, )) + if res: + yield res.get() + res = res_next + yield res.get() + + +class IdentPreproc: + """a pre-processor is either a faiss.VectorTransform or an IndentPreproc""" + + def __init__(self, d): + self.d_in = self.d_out = d + + def apply_py(self, x): + return x + + +def sanitize(x): + """ convert array to a c-contiguous float array """ + return np.ascontiguousarray(x.astype('float32')) + + +def dataset_iterator(x, preproc, bs): + """ iterate over the lines of x in blocks of size bs""" + + nb = x.shape[0] + block_ranges = [(i0, min(nb, i0 + bs)) + for i0 in range(0, nb, bs)] + + def prepare_block(i01): + i0, i1 = i01 + xb = sanitize(x[i0:i1]) + return i0, preproc.apply_py(xb) + + return rate_limited_imap(prepare_block, block_ranges) + + +def eval_intersection_measure(gt_I, I): + """ measure intersection measure (used for knngraph)""" + inter = 0 + rank = I.shape[1] + assert gt_I.shape[1] >= rank + for q in range(nq_gt): + inter += faiss.ranklist_intersection_size( + rank, faiss.swig_ptr(gt_I[q, :]), + rank, faiss.swig_ptr(I[q, :].astype('int64'))) + return inter / float(rank * nq_gt) + + +################################################################# +# Prepare dataset +################################################################# + +print("Preparing dataset", dbname) + +if dbname.startswith('SIFT'): + # SIFT1M to SIFT1000M + dbsize = int(dbname[4:-1]) + xb = mmap_bvecs('bigann/bigann_base.bvecs') + xq = mmap_bvecs('bigann/bigann_query.bvecs') + xt = mmap_bvecs('bigann/bigann_learn.bvecs') + + # trim xb to correct size + xb = xb[:dbsize * 1000 * 1000] + + gt_I = ivecs_read('bigann/gnd/idx_%dM.ivecs' % dbsize) + +elif dbname == 'Deep1B': + xb = mmap_fvecs('deep1b/base.fvecs') + xq = mmap_fvecs('deep1b/deep1B_queries.fvecs') + xt = mmap_fvecs('deep1b/learn.fvecs') + # deep1B's train is is outrageously big + xt = xt[:10 * 1000 * 1000] + gt_I = ivecs_read('deep1b/deep1B_groundtruth.ivecs') + +else: + print('unknown dataset', dbname, file=sys.stderr) + sys.exit(1) + + +if knngraph: + # convert to knn-graph dataset + xq = xb + xt = xb + # we compute the ground-truth on this number of queries for validation + nq_gt = 10000 + gt_sl = 100 + + # ground truth will be computed below + gt_I = None + + +print("sizes: B %s Q %s T %s gt %s" % ( + xb.shape, xq.shape, xt.shape, + gt_I.shape if gt_I is not None else None)) + + + +################################################################# +# Parse index_key and set cache files +# +# The index_key is a valid factory key that would work, but we +# decompose the training to do it faster +################################################################# + + +pat = re.compile('(OPQ[0-9]+(_[0-9]+)?,|PCAR[0-9]+,)?' + + '(IVF[0-9]+),' + + '(PQ[0-9]+|Flat)') + +matchobject = pat.match(index_key) + +assert matchobject, 'could not parse ' + index_key + +mog = matchobject.groups() + +preproc_str = mog[0] +ivf_str = mog[2] +pqflat_str = mog[3] + +ncent = int(ivf_str[3:]) + +prefix = '' + +if knngraph: + gt_cachefile = '%s/BK_gt_%s.npy' % (cacheroot, dbname) + prefix = 'BK_' + # files must be kept distinct because the training set is not the + # same for the knngraph + +if preproc_str: + preproc_cachefile = '%s/%spreproc_%s_%s.vectrans' % ( + cacheroot, prefix, dbname, preproc_str[:-1]) +else: + preproc_cachefile = None + preproc_str = '' + +cent_cachefile = '%s/%scent_%s_%s%s.npy' % ( + cacheroot, prefix, dbname, preproc_str, ivf_str) + +index_cachefile = '%s/%s%s_%s%s,%s.index' % ( + cacheroot, prefix, dbname, preproc_str, ivf_str, pqflat_str) + + +if not use_cache: + preproc_cachefile = None + cent_cachefile = None + index_cachefile = None + +print("cachefiles:") +print(preproc_cachefile) +print(cent_cachefile) +print(index_cachefile) + + +################################################################# +# Wake up GPUs +################################################################# + +print("preparing resources for %d GPUs" % ngpu) + +gpu_resources = [] + +for i in range(ngpu): + res = faiss.StandardGpuResources() + if tempmem >= 0: + res.setTempMemory(tempmem) + gpu_resources.append(res) + + +def make_vres_vdev(i0=0, i1=-1): + " return vectors of device ids and resources useful for gpu_multiple" + vres = faiss.GpuResourcesVector() + vdev = faiss.IntVector() + if i1 == -1: + i1 = ngpu + for i in range(i0, i1): + vdev.push_back(i) + vres.push_back(gpu_resources[i]) + return vres, vdev + + +################################################################# +# Prepare ground truth (for the knngraph) +################################################################# + + +def compute_GT(): + print("compute GT") + t0 = time.time() + + gt_I = np.zeros((nq_gt, gt_sl), dtype='int64') + gt_D = np.zeros((nq_gt, gt_sl), dtype='float32') + heaps = faiss.float_maxheap_array_t() + heaps.k = gt_sl + heaps.nh = nq_gt + heaps.val = faiss.swig_ptr(gt_D) + heaps.ids = faiss.swig_ptr(gt_I) + heaps.heapify() + bs = 10 ** 5 + + n, d = xb.shape + xqs = sanitize(xq[:nq_gt]) + + db_gt = faiss.IndexFlatL2(d) + vres, vdev = make_vres_vdev() + db_gt_gpu = faiss.index_cpu_to_gpu_multiple( + vres, vdev, db_gt) + + # compute ground-truth by blocks of bs, and add to heaps + for i0, xsl in dataset_iterator(xb, IdentPreproc(d), bs): + db_gt_gpu.add(xsl) + D, I = db_gt_gpu.search(xqs, gt_sl) + I += i0 + heaps.addn_with_ids( + gt_sl, faiss.swig_ptr(D), faiss.swig_ptr(I), gt_sl) + db_gt_gpu.reset() + print("\r %d/%d, %.3f s" % (i0, n, time.time() - t0), end=' ') + print() + heaps.reorder() + + print("GT time: %.3f s" % (time.time() - t0)) + return gt_I + + +if knngraph: + + if gt_cachefile and os.path.exists(gt_cachefile): + print("load GT", gt_cachefile) + gt_I = np.load(gt_cachefile) + else: + gt_I = compute_GT() + if gt_cachefile: + print("store GT", gt_cachefile) + np.save(gt_cachefile, gt_I) + +################################################################# +# Prepare the vector transformation object (pure CPU) +################################################################# + + +def train_preprocessor(): + print("train preproc", preproc_str) + d = xt.shape[1] + t0 = time.time() + if preproc_str.startswith('OPQ'): + fi = preproc_str[3:-1].split('_') + m = int(fi[0]) + dout = int(fi[1]) if len(fi) == 2 else d + preproc = faiss.OPQMatrix(d, m, dout) + elif preproc_str.startswith('PCAR'): + dout = int(preproc_str[4:-1]) + preproc = faiss.PCAMatrix(d, dout, 0, True) + else: + assert False + preproc.train(sanitize(xt[:1000000])) + print("preproc train done in %.3f s" % (time.time() - t0)) + return preproc + + +def get_preprocessor(): + if preproc_str: + if not preproc_cachefile or not os.path.exists(preproc_cachefile): + preproc = train_preprocessor() + if preproc_cachefile: + print("store", preproc_cachefile) + faiss.write_VectorTransform(preproc, preproc_cachefile) + else: + print("load", preproc_cachefile) + preproc = faiss.read_VectorTransform(preproc_cachefile) + else: + d = xb.shape[1] + preproc = IdentPreproc(d) + return preproc + + +################################################################# +# Prepare the coarse quantizer +################################################################# + + +def train_coarse_quantizer(x, k, preproc): + d = preproc.d_out + clus = faiss.Clustering(d, k) + clus.verbose = True + # clus.niter = 2 + clus.max_points_per_centroid = 10000000 + + print("apply preproc on shape", x.shape, 'k=', k) + t0 = time.time() + x = preproc.apply_py(sanitize(x)) + print(" preproc %.3f s output shape %s" % ( + time.time() - t0, x.shape)) + + vres, vdev = make_vres_vdev() + index = faiss.index_cpu_to_gpu_multiple( + vres, vdev, faiss.IndexFlatL2(d)) + + clus.train(x, index) + centroids = faiss.vector_float_to_array(clus.centroids) + + return centroids.reshape(k, d) + + +def prepare_coarse_quantizer(preproc): + + if cent_cachefile and os.path.exists(cent_cachefile): + print("load centroids", cent_cachefile) + centroids = np.load(cent_cachefile) + else: + nt = max(1000000, 256 * ncent) + print("train coarse quantizer...") + t0 = time.time() + centroids = train_coarse_quantizer(xt[:nt], ncent, preproc) + print("Coarse train time: %.3f s" % (time.time() - t0)) + if cent_cachefile: + print("store centroids", cent_cachefile) + np.save(cent_cachefile, centroids) + + coarse_quantizer = faiss.IndexFlatL2(preproc.d_out) + coarse_quantizer.add(centroids) + + return coarse_quantizer + + +################################################################# +# Make index and add elements to it +################################################################# + + +def prepare_trained_index(preproc): + + coarse_quantizer = prepare_coarse_quantizer(preproc) + d = preproc.d_out + if pqflat_str == 'Flat': + print("making an IVFFlat index") + idx_model = faiss.IndexIVFFlat(coarse_quantizer, d, ncent, + faiss.METRIC_L2) + else: + m = int(pqflat_str[2:]) + assert m < 56 or use_float16, "PQ%d will work only with -float16" % m + print("making an IVFPQ index, m = ", m) + idx_model = faiss.IndexIVFPQ(coarse_quantizer, d, ncent, m, 8) + + coarse_quantizer.this.disown() + idx_model.own_fields = True + + # finish training on CPU + t0 = time.time() + print("Training vector codes") + x = preproc.apply_py(sanitize(xt[:1000000])) + idx_model.train(x) + print(" done %.3f s" % (time.time() - t0)) + + return idx_model + + +def compute_populated_index(preproc): + """Add elements to a sharded index. Return the index and if available + a sharded gpu_index that contains the same data. """ + + indexall = prepare_trained_index(preproc) + + co = faiss.GpuMultipleClonerOptions() + co.useFloat16 = use_float16 + co.useFloat16CoarseQuantizer = False + co.usePrecomputed = use_precomputed_tables + co.indicesOptions = faiss.INDICES_CPU + co.verbose = True + co.reserveVecs = max_add if max_add > 0 else xb.shape[0] + co.shard = True + assert co.shard_type in (0, 1, 2) + vres, vdev = make_vres_vdev() + gpu_index = faiss.index_cpu_to_gpu_multiple( + vres, vdev, indexall, co) + + print("add...") + t0 = time.time() + nb = xb.shape[0] + for i0, xs in dataset_iterator(xb, preproc, add_batch_size): + i1 = i0 + xs.shape[0] + gpu_index.add_with_ids(xs, np.arange(i0, i1)) + if max_add > 0 and gpu_index.ntotal > max_add: + print("Flush indexes to CPU") + for i in range(ngpu): + index_src_gpu = faiss.downcast_index(gpu_index.at(i)) + index_src = faiss.index_gpu_to_cpu(index_src_gpu) + print(" index %d size %d" % (i, index_src.ntotal)) + index_src.copy_subset_to(indexall, 0, 0, nb) + index_src_gpu.reset() + index_src_gpu.reserveMemory(max_add) + gpu_index.sync_with_shard_indexes() + + print('\r%d/%d (%.3f s) ' % ( + i0, nb, time.time() - t0), end=' ') + sys.stdout.flush() + print("Add time: %.3f s" % (time.time() - t0)) + + print("Aggregate indexes to CPU") + t0 = time.time() + + if hasattr(gpu_index, 'at'): + # it is a sharded index + for i in range(ngpu): + index_src = faiss.index_gpu_to_cpu(gpu_index.at(i)) + print(" index %d size %d" % (i, index_src.ntotal)) + index_src.copy_subset_to(indexall, 0, 0, nb) + else: + # simple index + index_src = faiss.index_gpu_to_cpu(gpu_index) + index_src.copy_subset_to(indexall, 0, 0, nb) + + print(" done in %.3f s" % (time.time() - t0)) + + if max_add > 0: + # it does not contain all the vectors + gpu_index = None + + return gpu_index, indexall + +def compute_populated_index_2(preproc): + + indexall = prepare_trained_index(preproc) + + # set up a 3-stage pipeline that does: + # - stage 1: load + preproc + # - stage 2: assign on GPU + # - stage 3: add to index + + stage1 = dataset_iterator(xb, preproc, add_batch_size) + + vres, vdev = make_vres_vdev() + coarse_quantizer_gpu = faiss.index_cpu_to_gpu_multiple( + vres, vdev, indexall.quantizer) + + def quantize(args): + (i0, xs) = args + _, assign = coarse_quantizer_gpu.search(xs, 1) + return i0, xs, assign.ravel() + + stage2 = rate_limited_imap(quantize, stage1) + + print("add...") + t0 = time.time() + nb = xb.shape[0] + + for i0, xs, assign in stage2: + i1 = i0 + xs.shape[0] + if indexall.__class__ == faiss.IndexIVFPQ: + indexall.add_core_o(i1 - i0, faiss.swig_ptr(xs), + None, None, faiss.swig_ptr(assign)) + elif indexall.__class__ == faiss.IndexIVFFlat: + indexall.add_core(i1 - i0, faiss.swig_ptr(xs), None, + faiss.swig_ptr(assign)) + else: + assert False + + print('\r%d/%d (%.3f s) ' % ( + i0, nb, time.time() - t0), end=' ') + sys.stdout.flush() + print("Add time: %.3f s" % (time.time() - t0)) + + return None, indexall + + + +def get_populated_index(preproc): + + if not index_cachefile or not os.path.exists(index_cachefile): + if not altadd: + gpu_index, indexall = compute_populated_index(preproc) + else: + gpu_index, indexall = compute_populated_index_2(preproc) + if index_cachefile: + print("store", index_cachefile) + faiss.write_index(indexall, index_cachefile) + else: + print("load", index_cachefile) + indexall = faiss.read_index(index_cachefile) + gpu_index = None + + co = faiss.GpuMultipleClonerOptions() + co.useFloat16 = use_float16 + co.useFloat16CoarseQuantizer = False + co.usePrecomputed = use_precomputed_tables + co.indicesOptions = 0 + co.verbose = True + co.shard = True # the replicas will be made "manually" + t0 = time.time() + print("CPU index contains %d vectors, move to GPU" % indexall.ntotal) + if replicas == 1: + + if not gpu_index: + print("copying loaded index to GPUs") + vres, vdev = make_vres_vdev() + index = faiss.index_cpu_to_gpu_multiple( + vres, vdev, indexall, co) + else: + index = gpu_index + + else: + del gpu_index # We override the GPU index + + print("Copy CPU index to %d sharded GPU indexes" % replicas) + + index = faiss.IndexReplicas() + + for i in range(replicas): + gpu0 = ngpu * i / replicas + gpu1 = ngpu * (i + 1) / replicas + vres, vdev = make_vres_vdev(gpu0, gpu1) + + print(" dispatch to GPUs %d:%d" % (gpu0, gpu1)) + + index1 = faiss.index_cpu_to_gpu_multiple( + vres, vdev, indexall, co) + index1.this.disown() + index.addIndex(index1) + index.own_fields = True + del indexall + print("move to GPU done in %.3f s" % (time.time() - t0)) + return index + + + +################################################################# +# Perform search +################################################################# + + +def eval_dataset(index, preproc): + + ps = faiss.GpuParameterSpace() + ps.initialize(index) + + nq_gt = gt_I.shape[0] + print("search...") + sl = query_batch_size + nq = xq.shape[0] + for nprobe in nprobes: + ps.set_index_parameter(index, 'nprobe', nprobe) + t0 = time.time() + + if sl == 0: + D, I = index.search(preproc.apply_py(sanitize(xq)), nnn) + else: + I = np.empty((nq, nnn), dtype='int32') + D = np.empty((nq, nnn), dtype='float32') + + inter_res = '' + + for i0, xs in dataset_iterator(xq, preproc, sl): + print('\r%d/%d (%.3f s%s) ' % ( + i0, nq, time.time() - t0, inter_res), end=' ') + sys.stdout.flush() + + i1 = i0 + xs.shape[0] + Di, Ii = index.search(xs, nnn) + + I[i0:i1] = Ii + D[i0:i1] = Di + + if knngraph and not inter_res and i1 >= nq_gt: + ires = eval_intersection_measure( + gt_I[:, :nnn], I[:nq_gt]) + inter_res = ', %.4f' % ires + + t1 = time.time() + if knngraph: + ires = eval_intersection_measure(gt_I[:, :nnn], I[:nq_gt]) + print(" probe=%-3d: %.3f s rank-%d intersection results: %.4f" % ( + nprobe, t1 - t0, nnn, ires)) + else: + print(" probe=%-3d: %.3f s" % (nprobe, t1 - t0), end=' ') + gtc = gt_I[:, :1] + nq = xq.shape[0] + for rank in 1, 10, 100: + if rank > nnn: continue + nok = (I[:, :rank] == gtc).sum() + print("1-R@%d: %.4f" % (rank, nok / float(nq)), end=' ') + print() + if I_fname: + I_fname_i = I_fname % I + print("storing", I_fname_i) + np.save(I, I_fname_i) + if D_fname: + D_fname_i = I_fname % I + print("storing", D_fname_i) + np.save(D, D_fname_i) + + +################################################################# +# Driver +################################################################# + + +preproc = get_preprocessor() + +index = get_populated_index(preproc) + +eval_dataset(index, preproc) + +# make sure index is deleted before the resources +del index diff --git a/core/src/index/thirdparty/faiss/benchs/bench_gpu_sift1m.py b/core/src/index/thirdparty/faiss/benchs/bench_gpu_sift1m.py new file mode 100644 index 0000000000..76c312b5c5 --- /dev/null +++ b/core/src/index/thirdparty/faiss/benchs/bench_gpu_sift1m.py @@ -0,0 +1,94 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +#!/usr/bin/env python2 + +from __future__ import print_function +import os +import time +import numpy as np +import pdb + +import faiss +from datasets import load_sift1M, evaluate + + +print("load data") + +xb, xq, xt, gt = load_sift1M() +nq, d = xq.shape + +# we need only a StandardGpuResources per GPU +res = faiss.StandardGpuResources() + + +################################################################# +# Exact search experiment +################################################################# + +print("============ Exact search") + +flat_config = faiss.GpuIndexFlatConfig() +flat_config.device = 0 + +index = faiss.GpuIndexFlatL2(res, d, flat_config) + +print("add vectors to index") + +index.add(xb) + +print("warmup") + +index.search(xq, 123) + +print("benchmark") + +for lk in range(11): + k = 1 << lk + t, r = evaluate(index, xq, gt, k) + + # the recall should be 1 at all times + print("k=%d %.3f ms, R@1 %.4f" % (k, t, r[1])) + + +################################################################# +# Approximate search experiment +################################################################# + +print("============ Approximate search") + +index = faiss.index_factory(d, "IVF4096,PQ64") + +# faster, uses more memory +# index = faiss.index_factory(d, "IVF16384,Flat") + +co = faiss.GpuClonerOptions() + +# here we are using a 64-byte PQ, so we must set the lookup tables to +# 16 bit float (this is due to the limited temporary memory). +co.useFloat16 = True + +index = faiss.index_cpu_to_gpu(res, 0, index, co) + +print("train") + +index.train(xt) + +print("add vectors to index") + +index.add(xb) + +print("warmup") + +index.search(xq, 123) + +print("benchmark") + +for lnprobe in range(10): + nprobe = 1 << lnprobe + index.setNumProbes(nprobe) + t, r = evaluate(index, xq, gt, 100) + + print("nprobe=%4d %.3f ms recalls= %.4f %.4f %.4f" % (nprobe, t, r[1], r[10], r[100])) diff --git a/core/src/index/thirdparty/faiss/benchs/bench_hnsw.py b/core/src/index/thirdparty/faiss/benchs/bench_hnsw.py new file mode 100644 index 0000000000..dea13da8c2 --- /dev/null +++ b/core/src/index/thirdparty/faiss/benchs/bench_hnsw.py @@ -0,0 +1,158 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +#!/usr/bin/env python2 + +from __future__ import print_function +import time +import sys +import numpy as np +import faiss +from datasets import load_sift1M + + +k = int(sys.argv[1]) +todo = sys.argv[1:] + +print("load data") +xb, xq, xt, gt = load_sift1M() +nq, d = xq.shape + +if todo == []: + todo = 'hnsw hnsw_sq ivf ivf_hnsw_quantizer kmeans kmeans_hnsw'.split() + + +def evaluate(index): + # for timing with a single core + # faiss.omp_set_num_threads(1) + + t0 = time.time() + D, I = index.search(xq, k) + t1 = time.time() + + missing_rate = (I == -1).sum() / float(k * nq) + recall_at_1 = (I == gt[:, :1]).sum() / float(nq) + print("\t %7.3f ms per query, R@1 %.4f, missing rate %.4f" % ( + (t1 - t0) * 1000.0 / nq, recall_at_1, missing_rate)) + + +if 'hnsw' in todo: + + print("Testing HNSW Flat") + + index = faiss.IndexHNSWFlat(d, 32) + + # training is not needed + + # this is the default, higher is more accurate and slower to + # construct + index.hnsw.efConstruction = 40 + + print("add") + # to see progress + index.verbose = True + index.add(xb) + + print("search") + for efSearch in 16, 32, 64, 128, 256: + for bounded_queue in [True, False]: + print("efSearch", efSearch, "bounded queue", bounded_queue, end=' ') + index.hnsw.search_bounded_queue = bounded_queue + index.hnsw.efSearch = efSearch + evaluate(index) + +if 'hnsw_sq' in todo: + + print("Testing HNSW with a scalar quantizer") + # also set M so that the vectors and links both use 128 bytes per + # entry (total 256 bytes) + index = faiss.IndexHNSWSQ(d, faiss.ScalarQuantizer.QT_8bit, 16) + + print("training") + # training for the scalar quantizer + index.train(xt) + + # this is the default, higher is more accurate and slower to + # construct + index.hnsw.efConstruction = 40 + + print("add") + # to see progress + index.verbose = True + index.add(xb) + + print("search") + for efSearch in 16, 32, 64, 128, 256: + print("efSearch", efSearch, end=' ') + index.hnsw.efSearch = efSearch + evaluate(index) + +if 'ivf' in todo: + + print("Testing IVF Flat (baseline)") + quantizer = faiss.IndexFlatL2(d) + index = faiss.IndexIVFFlat(quantizer, d, 16384) + index.cp.min_points_per_centroid = 5 # quiet warning + + # to see progress + index.verbose = True + + print("training") + index.train(xt) + + print("add") + index.add(xb) + + print("search") + for nprobe in 1, 4, 16, 64, 256: + print("nprobe", nprobe, end=' ') + index.nprobe = nprobe + evaluate(index) + +if 'ivf_hnsw_quantizer' in todo: + + print("Testing IVF Flat with HNSW quantizer") + quantizer = faiss.IndexHNSWFlat(d, 32) + index = faiss.IndexIVFFlat(quantizer, d, 16384) + index.cp.min_points_per_centroid = 5 # quiet warning + index.quantizer_trains_alone = 2 + + # to see progress + index.verbose = True + + print("training") + index.train(xt) + + print("add") + index.add(xb) + + print("search") + quantizer.hnsw.efSearch = 64 + for nprobe in 1, 4, 16, 64, 256: + print("nprobe", nprobe, end=' ') + index.nprobe = nprobe + evaluate(index) + +# Bonus: 2 kmeans tests + +if 'kmeans' in todo: + print("Performing kmeans on sift1M database vectors (baseline)") + clus = faiss.Clustering(d, 16384) + clus.verbose = True + clus.niter = 10 + index = faiss.IndexFlatL2(d) + clus.train(xb, index) + + +if 'kmeans_hnsw' in todo: + print("Performing kmeans on sift1M using HNSW assignment") + clus = faiss.Clustering(d, 16384) + clus.verbose = True + clus.niter = 10 + index = faiss.IndexHNSWFlat(d, 32) + # increase the default efSearch, otherwise the number of empty + # clusters is too high. + index.hnsw.efSearch = 128 + clus.train(xb, index) diff --git a/core/src/index/thirdparty/faiss/benchs/bench_index_pq.py b/core/src/index/thirdparty/faiss/benchs/bench_index_pq.py new file mode 100644 index 0000000000..4fd5ccfeb0 --- /dev/null +++ b/core/src/index/thirdparty/faiss/benchs/bench_index_pq.py @@ -0,0 +1,22 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import print_function +import faiss +from datasets import load_sift1M, evaluate + +xb, xq, xt, gt = load_sift1M() +nq, d = xq.shape + +k = 32 + +for nbits in 4, 6, 8, 10, 12: + index = faiss.IndexPQ(d, 8, nbits) + index.train(xt) + index.add(xb) + + t, r = evaluate(index, xq, gt, k) + print("\t %7.3f ms per query, R@1 %.4f" % (t, r[1])) + del index diff --git a/core/src/index/thirdparty/faiss/benchs/bench_pairwise_distances.py b/core/src/index/thirdparty/faiss/benchs/bench_pairwise_distances.py new file mode 100644 index 0000000000..bde8cc908e --- /dev/null +++ b/core/src/index/thirdparty/faiss/benchs/bench_pairwise_distances.py @@ -0,0 +1,36 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +#! /usr/bin/env python3 + +"""small test script to benchmark the SIMD implementation of the +distance computations for the additional metrics. Call eg. with L1 to +get L1 distance computations. +""" + +import faiss + +import sys +import time + +d = 64 +nq = 4096 +nb = 16384 + +print("sample") + +xq = faiss.randn((nq, d), 123) +xb = faiss.randn((nb, d), 123) + +mt_name = "L2" if len(sys.argv) < 2 else sys.argv[1] + +mt = getattr(faiss, "METRIC_" + mt_name) + +print("distances") +t0 = time.time() +dis = faiss.pairwise_distances(xq, xb, mt) +t1 = time.time() + +print("nq=%d nb=%d d=%d %s: %.3f s" % (nq, nb, d, mt_name, t1 - t0)) diff --git a/core/src/index/thirdparty/faiss/benchs/bench_polysemous_1bn.py b/core/src/index/thirdparty/faiss/benchs/bench_polysemous_1bn.py new file mode 100644 index 0000000000..0cf3b723a1 --- /dev/null +++ b/core/src/index/thirdparty/faiss/benchs/bench_polysemous_1bn.py @@ -0,0 +1,254 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +#!/usr/bin/env python2 + +from __future__ import print_function +import os +import sys +import time +import numpy as np +import re +import faiss +from multiprocessing.dummy import Pool as ThreadPool +from datasets import ivecs_read + + +# we mem-map the biggest files to avoid having them in memory all at +# once + + +def mmap_fvecs(fname): + x = np.memmap(fname, dtype='int32', mode='r') + d = x[0] + return x.view('float32').reshape(-1, d + 1)[:, 1:] + + +def mmap_bvecs(fname): + x = np.memmap(fname, dtype='uint8', mode='r') + d = x[:4].view('int32')[0] + return x.reshape(-1, d + 4)[:, 4:] + + +################################################################# +# Bookkeeping +################################################################# + + +dbname = sys.argv[1] +index_key = sys.argv[2] +parametersets = sys.argv[3:] + + +tmpdir = '/tmp/bench_polysemous' + +if not os.path.isdir(tmpdir): + print("%s does not exist, creating it" % tmpdir) + os.mkdir(tmpdir) + + +################################################################# +# Prepare dataset +################################################################# + + +print("Preparing dataset", dbname) + +if dbname.startswith('SIFT'): + # SIFT1M to SIFT1000M + dbsize = int(dbname[4:-1]) + xb = mmap_bvecs('bigann/bigann_base.bvecs') + xq = mmap_bvecs('bigann/bigann_query.bvecs') + xt = mmap_bvecs('bigann/bigann_learn.bvecs') + + # trim xb to correct size + xb = xb[:dbsize * 1000 * 1000] + + gt = ivecs_read('bigann/gnd/idx_%dM.ivecs' % dbsize) + +elif dbname == 'Deep1B': + xb = mmap_fvecs('deep1b/base.fvecs') + xq = mmap_fvecs('deep1b/deep1B_queries.fvecs') + xt = mmap_fvecs('deep1b/learn.fvecs') + # deep1B's train is is outrageously big + xt = xt[:10 * 1000 * 1000] + gt = ivecs_read('deep1b/deep1B_groundtruth.ivecs') + +else: + print('unknown dataset', dbname, file=sys.stderr) + sys.exit(1) + + +print("sizes: B %s Q %s T %s gt %s" % ( + xb.shape, xq.shape, xt.shape, gt.shape)) + +nq, d = xq.shape +nb, d = xb.shape +assert gt.shape[0] == nq + + +################################################################# +# Training +################################################################# + + +def choose_train_size(index_key): + + # some training vectors for PQ and the PCA + n_train = 256 * 1000 + + if "IVF" in index_key: + matches = re.findall('IVF([0-9]+)', index_key) + ncentroids = int(matches[0]) + n_train = max(n_train, 100 * ncentroids) + elif "IMI" in index_key: + matches = re.findall('IMI2x([0-9]+)', index_key) + nbit = int(matches[0]) + n_train = max(n_train, 256 * (1 << nbit)) + return n_train + + +def get_trained_index(): + filename = "%s/%s_%s_trained.index" % ( + tmpdir, dbname, index_key) + + if not os.path.exists(filename): + index = faiss.index_factory(d, index_key) + + n_train = choose_train_size(index_key) + + xtsub = xt[:n_train] + print("Keeping %d train vectors" % xtsub.shape[0]) + # make sure the data is actually in RAM and in float + xtsub = xtsub.astype('float32').copy() + index.verbose = True + + t0 = time.time() + index.train(xtsub) + index.verbose = False + print("train done in %.3f s" % (time.time() - t0)) + print("storing", filename) + faiss.write_index(index, filename) + else: + print("loading", filename) + index = faiss.read_index(filename) + return index + + +################################################################# +# Adding vectors to dataset +################################################################# + +def rate_limited_imap(f, l): + 'a thread pre-processes the next element' + pool = ThreadPool(1) + res = None + for i in l: + res_next = pool.apply_async(f, (i, )) + if res: + yield res.get() + res = res_next + yield res.get() + + +def matrix_slice_iterator(x, bs): + " iterate over the lines of x in blocks of size bs" + nb = x.shape[0] + block_ranges = [(i0, min(nb, i0 + bs)) + for i0 in range(0, nb, bs)] + + return rate_limited_imap( + lambda i01: x[i01[0]:i01[1]].astype('float32').copy(), + block_ranges) + + +def get_populated_index(): + + filename = "%s/%s_%s_populated.index" % ( + tmpdir, dbname, index_key) + + if not os.path.exists(filename): + index = get_trained_index() + i0 = 0 + t0 = time.time() + for xs in matrix_slice_iterator(xb, 100000): + i1 = i0 + xs.shape[0] + print('\radd %d:%d, %.3f s' % (i0, i1, time.time() - t0), end=' ') + sys.stdout.flush() + index.add(xs) + i0 = i1 + print() + print("Add done in %.3f s" % (time.time() - t0)) + print("storing", filename) + faiss.write_index(index, filename) + else: + print("loading", filename) + index = faiss.read_index(filename) + return index + + +################################################################# +# Perform searches +################################################################# + +index = get_populated_index() + +ps = faiss.ParameterSpace() +ps.initialize(index) + +# make sure queries are in RAM +xq = xq.astype('float32').copy() + +# a static C++ object that collects statistics about searches +ivfpq_stats = faiss.cvar.indexIVFPQ_stats +ivf_stats = faiss.cvar.indexIVF_stats + + +if parametersets == ['autotune'] or parametersets == ['autotuneMT']: + + if parametersets == ['autotune']: + faiss.omp_set_num_threads(1) + + # setup the Criterion object: optimize for 1-R@1 + crit = faiss.OneRecallAtRCriterion(nq, 1) + # by default, the criterion will request only 1 NN + crit.nnn = 100 + crit.set_groundtruth(None, gt.astype('int64')) + + # then we let Faiss find the optimal parameters by itself + print("exploring operating points") + + t0 = time.time() + op = ps.explore(index, xq, crit) + print("Done in %.3f s, available OPs:" % (time.time() - t0)) + + # opv is a C++ vector, so it cannot be accessed like a Python array + opv = op.optimal_pts + print("%-40s 1-R@1 time" % "Parameters") + for i in range(opv.size()): + opt = opv.at(i) + print("%-40s %.4f %7.3f" % (opt.key, opt.perf, opt.t)) + +else: + + # we do queries in a single thread + faiss.omp_set_num_threads(1) + + print(' ' * len(parametersets[0]), '\t', 'R@1 R@10 R@100 time %pass') + + for param in parametersets: + print(param, '\t', end=' ') + sys.stdout.flush() + ps.set_index_parameters(index, param) + t0 = time.time() + ivfpq_stats.reset() + ivf_stats.reset() + D, I = index.search(xq, 100) + t1 = time.time() + for rank in 1, 10, 100: + n_ok = (I[:, :rank] == gt[:, :1]).sum() + print("%.4f" % (n_ok / float(nq)), end=' ') + print("%8.3f " % ((t1 - t0) * 1000.0 / nq), end=' ') + print("%5.2f" % (ivfpq_stats.n_hamming_pass * 100.0 / ivf_stats.ndis)) diff --git a/core/src/index/thirdparty/faiss/benchs/bench_polysemous_sift1m.py b/core/src/index/thirdparty/faiss/benchs/bench_polysemous_sift1m.py new file mode 100644 index 0000000000..f54c66bc2b --- /dev/null +++ b/core/src/index/thirdparty/faiss/benchs/bench_polysemous_sift1m.py @@ -0,0 +1,47 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +#!/usr/bin/env python2 + +from __future__ import print_function +import time +import numpy as np + +import faiss +from datasets import load_sift1M, evaluate + + +print("load data") +xb, xq, xt, gt = load_sift1M() +nq, d = xq.shape + +# index with 16 subquantizers, 8 bit each +index = faiss.IndexPQ(d, 16, 8) +index.do_polysemous_training = True +index.verbose = True + +print("train") + +index.train(xt) + +print("add vectors to index") + +index.add(xb) + +nt = 1 +faiss.omp_set_num_threads(1) + + +print("PQ baseline", end=' ') +index.search_type = faiss.IndexPQ.ST_PQ +t, r = evaluate(index, xq, gt, 1) +print("\t %7.3f ms per query, R@1 %.4f" % (t, r[1])) + +for ht in 64, 62, 58, 54, 50, 46, 42, 38, 34, 30: + print("Polysemous", ht, end=' ') + index.search_type = faiss.IndexPQ.ST_polysemous + index.polysemous_ht = ht + t, r = evaluate(index, xq, gt, 1) + print("\t %7.3f ms per query, R@1 %.4f" % (t, r[1])) diff --git a/core/src/index/thirdparty/faiss/benchs/bench_scalar_quantizer.py b/core/src/index/thirdparty/faiss/benchs/bench_scalar_quantizer.py new file mode 100644 index 0000000000..a990b485f1 --- /dev/null +++ b/core/src/index/thirdparty/faiss/benchs/bench_scalar_quantizer.py @@ -0,0 +1,85 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +#!/usr/bin/env python2 + +from __future__ import print_function +import time +import numpy as np +import faiss +from datasets import load_sift1M + + +print("load data") + +xb, xq, xt, gt = load_sift1M() +nq, d = xq.shape + +ncent = 256 + +variants = [(name, getattr(faiss.ScalarQuantizer, name)) + for name in dir(faiss.ScalarQuantizer) + if name.startswith('QT_')] + +quantizer = faiss.IndexFlatL2(d) +# quantizer.add(np.zeros((1, d), dtype='float32')) + +if False: + for name, qtype in [('flat', 0)] + variants: + + print("============== test", name) + t0 = time.time() + + if name == 'flat': + index = faiss.IndexIVFFlat(quantizer, d, ncent, + faiss.METRIC_L2) + else: + index = faiss.IndexIVFScalarQuantizer(quantizer, d, ncent, + qtype, faiss.METRIC_L2) + + index.nprobe = 16 + print("[%.3f s] train" % (time.time() - t0)) + index.train(xt) + print("[%.3f s] add" % (time.time() - t0)) + index.add(xb) + print("[%.3f s] search" % (time.time() - t0)) + D, I = index.search(xq, 100) + print("[%.3f s] eval" % (time.time() - t0)) + + for rank in 1, 10, 100: + n_ok = (I[:, :rank] == gt[:, :1]).sum() + print("%.4f" % (n_ok / float(nq)), end=' ') + print() + +if True: + for name, qtype in variants: + + print("============== test", name) + + for rsname, vals in [('RS_minmax', + [-0.4, -0.2, -0.1, -0.05, 0.0, 0.1, 0.5]), + ('RS_meanstd', [0.8, 1.0, 1.5, 2.0, 3.0, 5.0, 10.0]), + ('RS_quantiles', [0.02, 0.05, 0.1, 0.15]), + ('RS_optim', [0.0])]: + for val in vals: + print("%-15s %5g " % (rsname, val), end=' ') + index = faiss.IndexIVFScalarQuantizer(quantizer, d, ncent, + qtype, faiss.METRIC_L2) + index.nprobe = 16 + index.sq.rangestat = getattr(faiss.ScalarQuantizer, + rsname) + + index.rangestat_arg = val + + index.train(xt) + index.add(xb) + t0 = time.time() + D, I = index.search(xq, 100) + t1 = time.time() + + for rank in 1, 10, 100: + n_ok = (I[:, :rank] == gt[:, :1]).sum() + print("%.4f" % (n_ok / float(nq)), end=' ') + print(" %.3f s" % (t1 - t0)) diff --git a/core/src/index/thirdparty/faiss/benchs/bench_vector_ops.py b/core/src/index/thirdparty/faiss/benchs/bench_vector_ops.py new file mode 100644 index 0000000000..331a9923e2 --- /dev/null +++ b/core/src/index/thirdparty/faiss/benchs/bench_vector_ops.py @@ -0,0 +1,85 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +#! /usr/bin/env python2 + +from __future__ import print_function +import numpy as np +import faiss +import time + +swig_ptr = faiss.swig_ptr + +if False: + a = np.arange(10, 14).astype('float32') + b = np.arange(20, 24).astype('float32') + + faiss.fvec_inner_product (swig_ptr(a), swig_ptr(b), 4) + + 1/0 + +xd = 100 +yd = 1000000 + +np.random.seed(1234) + +faiss.omp_set_num_threads(1) + +print('xd=%d yd=%d' % (xd, yd)) + +print('Running inner products test..') +for d in 3, 4, 12, 36, 64: + + x = faiss.rand(xd * d).reshape(xd, d) + y = faiss.rand(yd * d).reshape(yd, d) + + distances = np.empty((xd, yd), dtype='float32') + + t0 = time.time() + for i in range(xd): + faiss.fvec_inner_products_ny(swig_ptr(distances[i]), + swig_ptr(x[i]), + swig_ptr(y), + d, yd) + t1 = time.time() + + # sparse verification + ntry = 100 + num, denom = 0, 0 + for t in range(ntry): + xi = np.random.randint(xd) + yi = np.random.randint(yd) + num += abs(distances[xi, yi] - np.dot(x[xi], y[yi])) + denom += abs(distances[xi, yi]) + + print('d=%d t=%.3f s diff=%g' % (d, t1 - t0, num / denom)) + + +print('Running L2sqr test..') +for d in 3, 4, 12, 36, 64: + + x = faiss.rand(xd * d).reshape(xd, d) + y = faiss.rand(yd * d).reshape(yd, d) + + distances = np.empty((xd, yd), dtype='float32') + + t0 = time.time() + for i in range(xd): + faiss.fvec_L2sqr_ny(swig_ptr(distances[i]), + swig_ptr(x[i]), + swig_ptr(y), + d, yd) + t1 = time.time() + + # sparse verification + ntry = 100 + num, denom = 0, 0 + for t in range(ntry): + xi = np.random.randint(xd) + yi = np.random.randint(yd) + num += abs(distances[xi, yi] - np.sum((x[xi] - y[yi]) ** 2)) + denom += abs(distances[xi, yi]) + + print('d=%d t=%.3f s diff=%g' % (d, t1 - t0, num / denom)) diff --git a/core/src/index/thirdparty/faiss/benchs/datasets.py b/core/src/index/thirdparty/faiss/benchs/datasets.py new file mode 100644 index 0000000000..3971f278f9 --- /dev/null +++ b/core/src/index/thirdparty/faiss/benchs/datasets.py @@ -0,0 +1,45 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import print_function +import sys +import time +import numpy as np + + +def ivecs_read(fname): + a = np.fromfile(fname, dtype='int32') + d = a[0] + return a.reshape(-1, d + 1)[:, 1:].copy() + + +def fvecs_read(fname): + return ivecs_read(fname).view('float32') + + +def load_sift1M(): + print("Loading sift1M...", end='', file=sys.stderr) + xt = fvecs_read("sift1M/sift_learn.fvecs") + xb = fvecs_read("sift1M/sift_base.fvecs") + xq = fvecs_read("sift1M/sift_query.fvecs") + gt = ivecs_read("sift1M/sift_groundtruth.ivecs") + print("done", file=sys.stderr) + + return xb, xq, xt, gt + + +def evaluate(index, xq, gt, k): + nq = xq.shape[0] + t0 = time.time() + D, I = index.search(xq, k) # noqa: E741 + t1 = time.time() + + recalls = {} + i = 1 + while i <= k: + recalls[i] = (I[:, :i] == gt[:, :1]).sum() / float(nq) + i *= 10 + + return (t1 - t0) * 1000.0 / nq, recalls diff --git a/core/src/index/thirdparty/faiss/benchs/distributed_ondisk/README.md b/core/src/index/thirdparty/faiss/benchs/distributed_ondisk/README.md new file mode 100644 index 0000000000..c2c792992b --- /dev/null +++ b/core/src/index/thirdparty/faiss/benchs/distributed_ondisk/README.md @@ -0,0 +1,198 @@ + +# Distributed on-disk index for 1T-scale datasets + +This is code corresponding to the description in [Indexing 1T vectors](https://github.com/facebookresearch/faiss/wiki/Indexing-1T-vectors). +All the code is in python 3 (and not compatible with Python 2). +The current code uses the Deep1B dataset for demonstration purposes, but can scale to 1000x larger. +To run it, download the Deep1B dataset as explained [here](../#getting-deep1b), and edit paths to the dataset in the scripts. + +The cluster commands are written for the Slurm batch scheduling system. +Hopefully, changing to another type of scheduler should be quite straightforward. + +## Distributed k-means + +To cluster 500M vectors to 10M centroids, it is useful to have a distriubuted k-means implementation. +The distribution simply consists in splitting the training vectors across machines (servers) and have them do the assignment. +The master/client then synthesizes the results and updates the centroids. + +The distributed k-means implementation here is based on 3 files: + +- [`rpc.py`](rpc.py) is a very simple remote procedure call implementation based on sockets and pickle. +It exposes the methods of an object on the server side so that they can be called from the client as if the object was local. + +- [`distributed_kmeans.py`](distributed_kmeans.py) contains the k-means implementation. +The main loop of k-means is re-implemented in python but follows closely the Faiss C++ implementation, and should not be significantly less efficient. +It relies on a `DatasetAssign` object that does the assignement to centrtoids, which is the bulk of the computation. +The object can be a Faiss CPU index, a GPU index or a set of remote GPU or CPU indexes. + +- [`run_on_cluster.bash`](run_on_cluster.bash) contains the shell code to run the distributed k-means on a cluster. + +The distributed k-means works with a Python install that contains faiss and scipy (for sparse matrices). +It clusters the training data of Deep1B, this can be changed easily to any file in fvecs, bvecs or npy format that contains the training set. +The training vectors may be too large to fit in RAM, but they are memory-mapped so that should not be a problem. +The file is also assumed to be accessible from all server machines with eg. a distributed file system. + +### Local tests + +Edit `distibuted_kmeans.py` to point `testdata` to your local copy of the dataset. + +Then, 4 levels of sanity check can be run: +```bash +# reference Faiss C++ run +python distributed_kmeans.py --test 0 +# using the Python implementation +python distributed_kmeans.py --test 1 +# use the dispatch object (on local datasets) +python distributed_kmeans.py --test 2 +# same, with GPUs +python distributed_kmeans.py --test 3 +``` +The output should look like [This gist](https://gist.github.com/mdouze/ffa01fe666a9325761266fe55ead72ad). + +### Distributed sanity check + +To run the distributed k-means, `distibuted_kmeans.py` has to be run both on the servers (`--server` option) and client sides (`--client` option). +Edit the top of `run_on_cluster.bash` to set the path of the data to cluster. + +Sanity checks can be run with +```bash +# non distributed baseline +bash run_on_cluster.bash test_kmeans_0 +# using all the machine's GPUs +bash run_on_cluster.bash test_kmeans_1 +# distrbuted run, with one local server per GPU +bash run_on_cluster.bash test_kmeans_2 +``` +The test `test_kmeans_2` simulates a distributed run on a single machine by starting one server process per GPU and connecting to the servers via the rpc protocol. +The output should look like [this gist](https://gist.github.com/mdouze/5b2dc69b74579ecff04e1686a277d32e). + + + +### Distributed run + +The way the script can be distributed depends on the cluster's scheduling system. +Here we use Slurm, but it should be relatively easy to adapt to any scheduler that can allocate a set of matchines and start the same exectuable on all of them. + +The command +``` +bash run_on_cluster.bash slurm_distributed_kmeans +``` +asks SLURM for 5 machines with 4 GPUs each with the `srun` command. +All 5 machines run the script with the `slurm_within_kmeans_server` option. +They determine the number of servers and their own server id via the `SLURM_NPROCS` and `SLURM_PROCID` environment variables. + +All machines start `distributed_kmeans.py` in server mode for the slice of the dataset they are responsible for. + +In addition, the machine #0 also starts the client. +The client knows who are the other servers via the variable `SLURM_JOB_NODELIST`. +It connects to all clients and performs the clustering. + +The output should look like [this gist](https://gist.github.com/mdouze/8d25e89fb4af5093057cae0f917da6cd). + +### Run used for deep1B + +For the real run, we run the clustering on 50M vectors to 1M centroids. +This is just a matter of using as many machines / GPUs as possible in setting the output centroids with the `--out filename` option. +Then run +``` +bash run_on_cluster.bash deep1b_clustering +``` + +The last lines of output read like: +``` + Iteration 19 (898.92 s, search 875.71 s): objective=1.33601e+07 imbalance=1.303 nsplit=0 + 0: writing centroids to /checkpoint/matthijs/ondisk_distributed/1M_centroids.npy +``` + +This means that the total training time was 899s, of which 876s were used for computation. +However, the computation includes the I/O overhead to the assignment servers. +In this implementation, the overhead of transmitting the data is non-negligible and so is the centroid computation stage. +This is due to the inefficient Python implementation and the RPC protocol that is not optimized for broadcast / gather (like MPI). +However, it is a simple implementation that should run on most clusters. + +## Making the trained index + +After the centroids are obtained, an empty trained index must be constructed. +This is done by: + +- applying a pre-processing stage (a random rotation) to balance the dimensions of the vectors. This can be done after clustering, the clusters are just rotated as well. + +- wrapping the centroids into a HNSW index to speed up the CPU-based assignment of vectors + +- training the 6-bit scalar quantizer used to encode the vectors + +This is performed by the script [`make_trained_index.py`](make_trained_index.py). + +## Building the index by slices + +We call the slices "vslisces" as they are vertical slices of the big matrix, see explanation in the wiki section [Split across datanbase partitions](https://github.com/facebookresearch/faiss/wiki/Indexing-1T-vectors#split-across-database-partitions). + +The script [make_index_vslice.py](make_index_vslice.py) makes an index for a subset of the vectors of the input data and stores it as an independent index. +There are 200 slices of 5M vectors each for Deep1B. +It can be run in a brute-force parallel fashion, there is no constraint on ordering. +To run the script in parallel on a slurm cluster, use: +``` +bash run_on_cluster.bash make_index_vslices +``` +For a real dataset, the data would be read from a DBMS. +In that case, reading the data and indexing it in parallel is worthwhile because reading is very slow. + +## Splitting accross inverted lists + +The 200 slices need to be merged together. +This is done with the script [merge_to_ondisk.py](merge_to_ondisk.py), that memory maps the 200 vertical slice indexes, extracts a subset of the inverted lists and writes them to a contiguous horizontal slice. +We slice the inverted lists into 50 horizontal slices. +This is run with +``` +bash run_on_cluster.bash make_index_hslices +``` + +## Querying the index + +At this point the index is ready. +The horizontal slices need to be loaded in the right order and combined into an index to be usable. +This is done in the [combined_index.py](combined_index.py) script. +It provides a `CombinedIndexDeep1B` object that contains an index object that can be searched. +To test, run: +``` +python combined_index.py +``` +The output should look like: +``` +(faiss_1.5.2) matthijs@devfair0144:~/faiss_versions/faiss_1Tcode/faiss/benchs/distributed_ondisk$ python combined_index.py +reading /checkpoint/matthijs/ondisk_distributed//hslices/slice49.faissindex +loading empty index /checkpoint/matthijs/ondisk_distributed/trained.faissindex +replace invlists +loaded index of size 1000000000 +nprobe=1 1-recall@1=0.2904 t=12.35s +nnprobe=10 1-recall@1=0.6499 t=17.67s +nprobe=100 1-recall@1=0.8673 t=29.23s +nprobe=1000 1-recall@1=0.9132 t=129.58s +``` +ie. searching is a lot slower than from RAM. + +## Distributed query + +To reduce the bandwidth required from the machine that does the queries, it is possible to split the search accross several search servers. +This way, only the effective results are returned to the main machine. + +The search client and server are implemented in [`search_server.py`](search_server.py). +It can be used as a script to start a search server for `CombinedIndexDeep1B` or as a module to load the clients. + +The search servers can be started with +``` +bash run_on_cluster.bash run_search_servers +``` +(adjust to the number of servers that can be used). + +Then an example of search client is [`distributed_query_demo.py`](distributed_query_demo.py). +It connects to the servers and assigns subsets of inverted lists to visit to each of them. + +A typical output is [this gist](https://gist.github.com/mdouze/1585b9854a9a2437d71f2b2c3c05c7c5). +The number in MiB indicates the amount of data that is read from disk to perform the search. +In this case, the scale of the dataset is too small for the distributed search to have much impact, but on datasets > 10x larger, the difference becomes more significant. + +## Conclusion + +This code contains the core components to make an index that scales up to 1T vectors. +There are a few simplifications wrt. the index that was effectively used in [Indexing 1T vectors](https://github.com/facebookresearch/faiss/wiki/Indexing-1T-vectors). diff --git a/core/src/index/thirdparty/faiss/benchs/distributed_ondisk/combined_index.py b/core/src/index/thirdparty/faiss/benchs/distributed_ondisk/combined_index.py new file mode 100644 index 0000000000..3df2a0180a --- /dev/null +++ b/core/src/index/thirdparty/faiss/benchs/distributed_ondisk/combined_index.py @@ -0,0 +1,194 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +#!/usr/bin/env python3 + +import os +import faiss +import numpy as np + + +class CombinedIndex: + """ + combines a set of inverted lists into a hstack + masks part of those lists + adds these inverted lists to an empty index that contains + the info on how to perform searches + """ + + def __init__(self, invlist_fnames, empty_index_fname, + masked_index_fname=None): + + self.indexes = indexes = [] + ilv = faiss.InvertedListsPtrVector() + + for fname in invlist_fnames: + if os.path.exists(fname): + print('reading', fname, end='\r', flush=True) + index = faiss.read_index(fname) + indexes.append(index) + il = faiss.extract_index_ivf(index).invlists + else: + raise AssertionError + ilv.push_back(il) + print() + + self.big_il = faiss.VStackInvertedLists(ilv.size(), ilv.data()) + if masked_index_fname: + self.big_il_base = self.big_il + print('loading', masked_index_fname) + self.masked_index = faiss.read_index( + masked_index_fname, + faiss.IO_FLAG_MMAP | faiss.IO_FLAG_READ_ONLY) + self.big_il = faiss.MaskedInvertedLists( + faiss.extract_index_ivf(self.masked_index).invlists, + self.big_il_base) + + print('loading empty index', empty_index_fname) + self.index = faiss.read_index(empty_index_fname) + ntotal = self.big_il.compute_ntotal() + + print('replace invlists') + index_ivf = faiss.extract_index_ivf(self.index) + index_ivf.replace_invlists(self.big_il, False) + index_ivf.ntotal = self.index.ntotal = ntotal + index_ivf.parallel_mode = 1 # seems reasonable to do this all the time + + quantizer = faiss.downcast_index(index_ivf.quantizer) + quantizer.hnsw.efSearch = 1024 + + ############################################################ + # Expose fields and functions of the index as methods so that they + # can be called by RPC + + def search(self, x, k): + return self.index.search(x, k) + + def range_search(self, x, radius): + return self.index.range_search(x, radius) + + def transform_and_assign(self, xq): + index = self.index + + if isinstance(index, faiss.IndexPreTransform): + assert index.chain.size() == 1 + vt = index.chain.at(0) + xq = vt.apply_py(xq) + + # perform quantization + index_ivf = faiss.extract_index_ivf(index) + quantizer = index_ivf.quantizer + coarse_dis, list_nos = quantizer.search(xq, index_ivf.nprobe) + return xq, list_nos, coarse_dis + + + def ivf_search_preassigned(self, xq, list_nos, coarse_dis, k): + index_ivf = faiss.extract_index_ivf(self.index) + n, d = xq.shape + assert d == index_ivf.d + n2, d2 = list_nos.shape + assert list_nos.shape == coarse_dis.shape + assert n2 == n + assert d2 == index_ivf.nprobe + D = np.empty((n, k), dtype='float32') + I = np.empty((n, k), dtype='int64') + index_ivf.search_preassigned( + n, faiss.swig_ptr(xq), k, + faiss.swig_ptr(list_nos), faiss.swig_ptr(coarse_dis), + faiss.swig_ptr(D), faiss.swig_ptr(I), False) + return D, I + + + def ivf_range_search_preassigned(self, xq, list_nos, coarse_dis, radius): + index_ivf = faiss.extract_index_ivf(self.index) + n, d = xq.shape + assert d == index_ivf.d + n2, d2 = list_nos.shape + assert list_nos.shape == coarse_dis.shape + assert n2 == n + assert d2 == index_ivf.nprobe + res = faiss.RangeSearchResult(n) + + index_ivf.range_search_preassigned( + n, faiss.swig_ptr(xq), radius, + faiss.swig_ptr(list_nos), faiss.swig_ptr(coarse_dis), + res) + + lims = faiss.rev_swig_ptr(res.lims, n + 1).copy() + nd = int(lims[-1]) + D = faiss.rev_swig_ptr(res.distances, nd).copy() + I = faiss.rev_swig_ptr(res.labels, nd).copy() + return lims, D, I + + def set_nprobe(self, nprobe): + index_ivf = faiss.extract_index_ivf(self.index) + index_ivf.nprobe = nprobe + + def set_parallel_mode(self, pm): + index_ivf = faiss.extract_index_ivf(self.index) + index_ivf.parallel_mode = pm + + def get_ntotal(self): + return self.index.ntotal + + def set_prefetch_nthread(self, nt): + for idx in self.indexes: + il = faiss.downcast_InvertedLists( + faiss.extract_index_ivf(idx).invlists) + il.prefetch_nthread + il.prefetch_nthread = nt + + def set_omp_num_threads(self, nt): + faiss.omp_set_num_threads(nt) + +class CombinedIndexDeep1B(CombinedIndex): + """ loads a CombinedIndex with the data from the big photodna index """ + + def __init__(self): + # set some paths + workdir = "/checkpoint/matthijs/ondisk_distributed/" + + # empty index with the proper quantizer + indexfname = workdir + 'trained.faissindex' + + # index that has some invlists that override the big one + masked_index_fname = None + invlist_fnames = [ + '%s/hslices/slice%d.faissindex' % (workdir, i) + for i in range(50) + ] + CombinedIndex.__init__(self, invlist_fnames, indexfname, masked_index_fname) + + +def ivecs_read(fname): + a = np.fromfile(fname, dtype='int32') + d = a[0] + return a.reshape(-1, d + 1)[:, 1:].copy() + + +def fvecs_read(fname): + return ivecs_read(fname).view('float32') + + +if __name__ == '__main__': + import time + ci = CombinedIndexDeep1B() + print('loaded index of size ', ci.index.ntotal) + + deep1bdir = "/datasets01_101/simsearch/041218/deep1b/" + + xq = fvecs_read(deep1bdir + "deep1B_queries.fvecs") + gt_fname = deep1bdir + "deep1B_groundtruth.ivecs" + gt = ivecs_read(gt_fname) + + for nprobe in 1, 10, 100, 1000: + ci.set_nprobe(nprobe) + t0 = time.time() + D, I = ci.search(xq, 100) + t1 = time.time() + print('nprobe=%d 1-recall@1=%.4f t=%.2fs' % ( + nprobe, (I[:, 0] == gt[:, 0]).sum() / len(xq), + t1 - t0 + )) diff --git a/core/src/index/thirdparty/faiss/benchs/distributed_ondisk/distributed_kmeans.py b/core/src/index/thirdparty/faiss/benchs/distributed_ondisk/distributed_kmeans.py new file mode 100644 index 0000000000..ae7a292d3d --- /dev/null +++ b/core/src/index/thirdparty/faiss/benchs/distributed_ondisk/distributed_kmeans.py @@ -0,0 +1,411 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +#! /usr/bin/env python3 + +""" +Simple distributed kmeans implementation Relies on an abstraction +for the training matrix, that can be sharded over several machines. +""" + +import faiss +import time +import numpy as np +import sys +import pdb +import argparse + +from scipy.sparse import csc_matrix + +from multiprocessing.dummy import Pool as ThreadPool + +import rpc + + + + +class DatasetAssign: + """Wrapper for a matrix that offers a function to assign the vectors + to centroids. All other implementations offer the same interface""" + + def __init__(self, x): + self.x = np.ascontiguousarray(x, dtype='float32') + + def count(self): + return self.x.shape[0] + + def dim(self): + return self.x.shape[1] + + def get_subset(self, indices): + return self.x[indices] + + def perform_search(self, centroids): + index = faiss.IndexFlatL2(self.x.shape[1]) + index.add(centroids) + return index.search(self.x, 1) + + def assign_to(self, centroids, weights=None): + D, I = self.perform_search(centroids) + + I = I.ravel() + D = D.ravel() + n = len(self.x) + if weights is None: + weights = np.ones(n, dtype='float32') + nc = len(centroids) + m = csc_matrix((weights, I, np.arange(n + 1)), + shape=(nc, n)) + sum_per_centroid = m * self.x + + return I, D, sum_per_centroid + + +class DatasetAssignGPU(DatasetAssign): + """ GPU version of the previous """ + + def __init__(self, x, gpu_id, verbose=False): + DatasetAssign.__init__(self, x) + index = faiss.IndexFlatL2(x.shape[1]) + if gpu_id >= 0: + self.index = faiss.index_cpu_to_gpu( + faiss.StandardGpuResources(), + gpu_id, index) + else: + # -1 -> assign to all GPUs + self.index = faiss.index_cpu_to_all_gpus(index) + + + def perform_search(self, centroids): + self.index.reset() + self.index.add(centroids) + return self.index.search(self.x, 1) + + +class DatasetAssignDispatch: + """dispatches to several other DatasetAssigns and combines the + results""" + + def __init__(self, xes, in_parallel): + self.xes = xes + self.d = xes[0].dim() + if not in_parallel: + self.imap = map + else: + self.pool = ThreadPool(len(self.xes)) + self.imap = self.pool.imap + self.sizes = list(map(lambda x: x.count(), self.xes)) + self.cs = np.cumsum([0] + self.sizes) + + def count(self): + return self.cs[-1] + + def dim(self): + return self.d + + def get_subset(self, indices): + res = np.zeros((len(indices), self.d), dtype='float32') + nos = np.searchsorted(self.cs[1:], indices, side='right') + + def handle(i): + mask = nos == i + sub_indices = indices[mask] - self.cs[i] + subset = self.xes[i].get_subset(sub_indices) + res[mask] = subset + + list(self.imap(handle, range(len(self.xes)))) + return res + + def assign_to(self, centroids, weights=None): + src = self.imap( + lambda x: x.assign_to(centroids, weights), + self.xes + ) + I = [] + D = [] + sum_per_centroid = None + for Ii, Di, sum_per_centroid_i in src: + I.append(Ii) + D.append(Di) + if sum_per_centroid is None: + sum_per_centroid = sum_per_centroid_i + else: + sum_per_centroid += sum_per_centroid_i + return np.hstack(I), np.hstack(D), sum_per_centroid + + +def imbalance_factor(k , assign): + return faiss.imbalance_factor(len(assign), k, faiss.swig_ptr(assign)) + + +def reassign_centroids(hassign, centroids, rs=None): + """ reassign centroids when some of them collapse """ + if rs is None: + rs = np.random + k, d = centroids.shape + nsplit = 0 + empty_cents = np.where(hassign == 0)[0] + + if empty_cents.size == 0: + return 0 + + fac = np.ones(d) + fac[::2] += 1 / 1024. + fac[1::2] -= 1 / 1024. + + # this is a single pass unless there are more than k/2 + # empty centroids + while empty_cents.size > 0: + # choose which centroids to split + probas = hassign.astype('float') - 1 + probas[probas < 0] = 0 + probas /= probas.sum() + nnz = (probas > 0).sum() + + nreplace = min(nnz, empty_cents.size) + cjs = rs.choice(k, size=nreplace, p=probas) + + for ci, cj in zip(empty_cents[:nreplace], cjs): + + c = centroids[cj] + centroids[ci] = c * fac + centroids[cj] = c / fac + + hassign[ci] = hassign[cj] // 2 + hassign[cj] -= hassign[ci] + nsplit += 1 + + empty_cents = empty_cents[nreplace:] + + return nsplit + + +def kmeans(k, data, niter=25, seed=1234, checkpoint=None): + """Pure python kmeans implementation. Follows the Faiss C++ version + quite closely, but takes a DatasetAssign instead of a training data + matrix. Also redo is not implemented. """ + n, d = data.count(), data.dim() + + print(("Clustering %d points in %dD to %d clusters, " + + "%d iterations seed %d") % (n, d, k, niter, seed)) + + rs = np.random.RandomState(seed) + print("preproc...") + t0 = time.time() + # initialization + perm = rs.choice(n, size=k, replace=False) + centroids = data.get_subset(perm) + + print(" done") + t_search_tot = 0 + obj = [] + for i in range(niter): + t0s = time.time() + + print('assigning', end='\r', flush=True) + assign, D, sums = data.assign_to(centroids) + + print('compute centroids', end='\r', flush=True) + + # pdb.set_trace() + + t_search_tot += time.time() - t0s; + + err = D.sum() + obj.append(err) + + hassign = np.bincount(assign, minlength=k) + + fac = hassign.reshape(-1, 1).astype('float32') + fac[fac == 0] = 1 # quiet warning + + centroids = sums / fac + + nsplit = reassign_centroids(hassign, centroids, rs) + + print((" Iteration %d (%.2f s, search %.2f s): " + "objective=%g imbalance=%.3f nsplit=%d") % ( + i, (time.time() - t0), t_search_tot, + err, imbalance_factor (k, assign), + nsplit) + ) + + if checkpoint is not None: + print('storing centroids in', checkpoint) + np.save(checkpoint, centroids) + + return centroids + + +class AssignServer(rpc.Server): + """ Assign version that can be exposed via RPC """ + + def __init__(self, s, assign, log_prefix=''): + rpc.Server.__init__(self, s, log_prefix=log_prefix) + self.assign = assign + + def __getattr__(self, f): + return getattr(self.assign, f) + + + +def bvecs_mmap(fname): + x = np.memmap(fname, dtype='uint8', mode='r') + d = x[:4].view('int32')[0] + return x.reshape(-1, d + 4)[:, 4:] + + +def ivecs_mmap(fname): + a = np.memmap(fname, dtype='int32', mode='r') + d = a[0] + return a.reshape(-1, d + 1)[:, 1:] + +def fvecs_mmap(fname): + return ivecs_mmap(fname).view('float32') + + +def do_test(todo): + testdata = '/datasets01_101/simsearch/041218/bigann/bigann_learn.bvecs' + + x = bvecs_mmap(testdata) + + # bad distribution to stress-test split code + xx = x[:100000].copy() + xx[:50000] = x[0] + + todo = sys.argv[1:] + + if "0" in todo: + # reference C++ run + km = faiss.Kmeans(x.shape[1], 1000, niter=20, verbose=True) + km.train(xx.astype('float32')) + + if "1" in todo: + # using the Faiss c++ implementation + data = DatasetAssign(xx) + kmeans(1000, data, 20) + + if "2" in todo: + # use the dispatch object (on local datasets) + data = DatasetAssignDispatch([ + DatasetAssign(xx[20000 * i : 20000 * (i + 1)]) + for i in range(5) + ], False + ) + kmeans(1000, data, 20) + + if "3" in todo: + # same, with GPU + ngpu = faiss.get_num_gpus() + print('using %d GPUs' % ngpu) + data = DatasetAssignDispatch([ + DatasetAssignGPU(xx[100000 * i // ngpu: 100000 * (i + 1) // ngpu], i) + for i in range(ngpu) + ], True + ) + kmeans(1000, data, 20) + + +def main(): + parser = argparse.ArgumentParser() + + def aa(*args, **kwargs): + group.add_argument(*args, **kwargs) + + group = parser.add_argument_group('general options') + aa('--test', default='', help='perform tests (comma-separated numbers)') + + aa('--k', default=0, type=int, help='nb centroids') + aa('--seed', default=1234, type=int, help='random seed') + aa('--niter', default=20, type=int, help='nb iterations') + aa('--gpu', default=-2, type=int, help='GPU to use (-2:none, -1: all)') + + group = parser.add_argument_group('I/O options') + aa('--indata', default='', + help='data file to load (supported formats fvecs, bvecs, npy') + aa('--i0', default=0, type=int, help='first vector to keep') + aa('--i1', default=-1, type=int, help='last vec to keep + 1') + aa('--out', default='', help='file to store centroids') + aa('--store_each_iteration', default=False, action='store_true', + help='store centroid checkpoints') + + group = parser.add_argument_group('server options') + aa('--server', action='store_true', default=False, help='run server') + aa('--port', default=12345, type=int, help='server port') + aa('--when_ready', default=None, help='store host:port to this file when ready') + aa('--ipv4', default=False, action='store_true', help='force ipv4') + + group = parser.add_argument_group('client options') + aa('--client', action='store_true', default=False, help='run client') + aa('--servers', default='', help='list of server:port separated by spaces') + + args = parser.parse_args() + + if args.test: + do_test(args.test.split(',')) + return + + # prepare data matrix (either local or remote) + if args.indata: + print('loading ', args.indata) + if args.indata.endswith('.bvecs'): + x = bvecs_mmap(args.indata) + elif args.indata.endswith('.fvecs'): + x = fvecs_mmap(args.indata) + elif args.indata.endswith('.npy'): + x = np.load(args.indata, mmap_mode='r') + else: + raise AssertionError + + if args.i1 == -1: + args.i1 = len(x) + x = x[args.i0:args.i1] + if args.gpu == -2: + data = DatasetAssign(x) + else: + print('moving to GPU') + data = DatasetAssignGPU(x, args.gpu) + + elif args.client: + print('connecting to servers') + + def connect_client(hostport): + host, port = hostport.split(':') + port = int(port) + print('connecting %s:%d' % (host, port)) + client = rpc.Client(host, port, v6=not args.ipv4) + print('client %s:%d ready' % (host, port)) + return client + + hostports = args.servers.strip().split(' ') + # pool = ThreadPool(len(hostports)) + + data = DatasetAssignDispatch( + list(map(connect_client, hostports)), + True + ) + else: + raise AssertionError + + + if args.server: + print('starting server') + log_prefix = f"{rpc.socket.gethostname()}:{args.port}" + rpc.run_server( + lambda s: AssignServer(s, data, log_prefix=log_prefix), + args.port, report_to_file=args.when_ready, + v6=not args.ipv4) + + else: + print('running kmeans') + centroids = kmeans(args.k, data, niter=args.niter, seed=args.seed, + checkpoint=args.out if args.store_each_iteration else None) + if args.out != '': + print('writing centroids to', args.out) + np.save(args.out, centroids) + + +if __name__ == '__main__': + main() diff --git a/core/src/index/thirdparty/faiss/benchs/distributed_ondisk/distributed_query_demo.py b/core/src/index/thirdparty/faiss/benchs/distributed_ondisk/distributed_query_demo.py new file mode 100644 index 0000000000..9453c0ec27 --- /dev/null +++ b/core/src/index/thirdparty/faiss/benchs/distributed_ondisk/distributed_query_demo.py @@ -0,0 +1,70 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import os +import faiss +import numpy as np +import time +import rpc +import sys + +import combined_index +import search_server + +hostnames = sys.argv[1:] + +print("Load local index") +ci = combined_index.CombinedIndexDeep1B() + +print("connect to clients") +clients = [] +for host in hostnames: + client = rpc.Client(host, 12012, v6=False) + clients.append(client) + +# check if all servers respond +print("sizes seen by servers:", [cl.get_ntotal() for cl in clients]) + + +# aggregate all clients into a one that uses them all for speed +# note that it also requires a local index ci +sindex = search_server.SplitPerListIndex(ci, clients) +sindex.verbose = True + +# set reasonable parameters +ci.set_parallel_mode(1) +ci.set_prefetch_nthread(0) +ci.set_omp_num_threads(64) + +# initialize params +sindex.set_parallel_mode(1) +sindex.set_prefetch_nthread(0) +sindex.set_omp_num_threads(64) + +def ivecs_read(fname): + a = np.fromfile(fname, dtype='int32') + d = a[0] + return a.reshape(-1, d + 1)[:, 1:].copy() + +def fvecs_read(fname): + return ivecs_read(fname).view('float32') + + +deep1bdir = "/datasets01_101/simsearch/041218/deep1b/" + +xq = fvecs_read(deep1bdir + "deep1B_queries.fvecs") +gt_fname = deep1bdir + "deep1B_groundtruth.ivecs" +gt = ivecs_read(gt_fname) + + +for nprobe in 1, 10, 100, 1000: + sindex.set_nprobe(nprobe) + t0 = time.time() + D, I = sindex.search(xq, 100) + t1 = time.time() + print('nprobe=%d 1-recall@1=%.4f t=%.2fs' % ( + nprobe, (I[:, 0] == gt[:, 0]).sum() / len(xq), + t1 - t0 + )) diff --git a/core/src/index/thirdparty/faiss/benchs/distributed_ondisk/make_index_vslice.py b/core/src/index/thirdparty/faiss/benchs/distributed_ondisk/make_index_vslice.py new file mode 100644 index 0000000000..ca58425b25 --- /dev/null +++ b/core/src/index/thirdparty/faiss/benchs/distributed_ondisk/make_index_vslice.py @@ -0,0 +1,117 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import os +import time +import numpy as np +import faiss +import argparse +from multiprocessing.dummy import Pool as ThreadPool + +def ivecs_mmap(fname): + a = np.memmap(fname, dtype='int32', mode='r') + d = a[0] + return a.reshape(-1, d + 1)[:, 1:] + +def fvecs_mmap(fname): + return ivecs_mmap(fname).view('float32') + + +def produce_batches(args): + + x = fvecs_mmap(args.input) + + if args.i1 == -1: + args.i1 = len(x) + + print("Iterating on vectors %d:%d from %s by batches of size %d" % ( + args.i0, args.i1, args.input, args.bs)) + + for j0 in range(args.i0, args.i1, args.bs): + j1 = min(j0 + args.bs, args.i1) + yield np.arange(j0, j1), x[j0:j1] + + +def rate_limited_iter(l): + 'a thread pre-processes the next element' + pool = ThreadPool(1) + res = None + + def next_or_None(): + try: + return next(l) + except StopIteration: + return None + + while True: + res_next = pool.apply_async(next_or_None) + if res is not None: + res = res.get() + if res is None: + return + yield res + res = res_next + +deep1bdir = "/datasets01_101/simsearch/041218/deep1b/" +workdir = "/checkpoint/matthijs/ondisk_distributed/" + +def main(): + parser = argparse.ArgumentParser( + description='make index for a subset of the data') + + def aa(*args, **kwargs): + group.add_argument(*args, **kwargs) + + group = parser.add_argument_group('index type') + aa('--inputindex', + default=workdir + 'trained.faissindex', + help='empty input index to fill in') + aa('--nt', default=-1, type=int, help='nb of openmp threads to use') + + group = parser.add_argument_group('db options') + aa('--input', default=deep1bdir + "base.fvecs") + aa('--bs', default=2**18, type=int, + help='batch size for db access') + aa('--i0', default=0, type=int, help='lower bound to index') + aa('--i1', default=-1, type=int, help='upper bound of vectors to index') + + group = parser.add_argument_group('output') + aa('-o', default='/tmp/x', help='output index') + aa('--keepquantizer', default=False, action='store_true', + help='by default we remove the data from the quantizer to save space') + + args = parser.parse_args() + print('args=', args) + + print('start accessing data') + src = produce_batches(args) + + print('loading index', args.inputindex) + index = faiss.read_index(args.inputindex) + + if args.nt != -1: + faiss.omp_set_num_threads(args.nt) + + t0 = time.time() + ntot = 0 + for ids, x in rate_limited_iter(src): + print('add %d:%d (%.3f s)' % (ntot, ntot + ids.size, time.time() - t0)) + index.add_with_ids(np.ascontiguousarray(x, dtype='float32'), ids) + ntot += ids.size + + index_ivf = faiss.extract_index_ivf(index) + print('invlists stats: imbalance %.3f' % index_ivf.invlists.imbalance_factor()) + index_ivf.invlists.print_stats() + + if not args.keepquantizer: + print('resetting quantizer content') + index_ivf = faiss.extract_index_ivf(index) + index_ivf.quantizer.reset() + + print('store output', args.o) + faiss.write_index(index, args.o) + +if __name__ == '__main__': + main() diff --git a/core/src/index/thirdparty/faiss/benchs/distributed_ondisk/make_trained_index.py b/core/src/index/thirdparty/faiss/benchs/distributed_ondisk/make_trained_index.py new file mode 100644 index 0000000000..50e4668f1b --- /dev/null +++ b/core/src/index/thirdparty/faiss/benchs/distributed_ondisk/make_trained_index.py @@ -0,0 +1,52 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import faiss + +deep1bdir = "/datasets01_101/simsearch/041218/deep1b/" +workdir = "/checkpoint/matthijs/ondisk_distributed/" + + +print('Load centroids') +centroids = np.load(workdir + '1M_centroids.npy') +ncent, d = centroids.shape + + +print('apply random rotation') +rrot = faiss.RandomRotationMatrix(d, d) +rrot.init(1234) +centroids = rrot.apply_py(centroids) + +print('make HNSW index as quantizer') +quantizer = faiss.IndexHNSWFlat(d, 32) +quantizer.hnsw.efSearch = 1024 +quantizer.hnsw.efConstruction = 200 +quantizer.add(centroids) + +print('build index') +index = faiss.IndexPreTransform( + rrot, + faiss.IndexIVFScalarQuantizer( + quantizer, d, ncent, faiss.ScalarQuantizer.QT_6bit + ) + ) + +def ivecs_mmap(fname): + a = np.memmap(fname, dtype='int32', mode='r') + d = a[0] + return a.reshape(-1, d + 1)[:, 1:] + +def fvecs_mmap(fname): + return ivecs_mmap(fname).view('float32') + + +print('finish training index') +xt = fvecs_mmap(deep1bdir + 'learn.fvecs') +xt = np.ascontiguousarray(xt[:256 * 1000], dtype='float32') +index.train(xt) + +print('write output') +faiss.write_index(index, workdir + 'trained.faissindex') diff --git a/core/src/index/thirdparty/faiss/benchs/distributed_ondisk/merge_to_ondisk.py b/core/src/index/thirdparty/faiss/benchs/distributed_ondisk/merge_to_ondisk.py new file mode 100644 index 0000000000..5c8f3ace94 --- /dev/null +++ b/core/src/index/thirdparty/faiss/benchs/distributed_ondisk/merge_to_ondisk.py @@ -0,0 +1,96 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import os +import faiss +import argparse +from multiprocessing.dummy import Pool as ThreadPool + +if __name__ == '__main__': + + parser = argparse.ArgumentParser() + + parser.add_argument('--inputs', nargs='*', required=True, + help='input indexes to merge') + parser.add_argument('--l0', type=int, default=0) + parser.add_argument('--l1', type=int, default=-1) + + parser.add_argument('--nt', default=-1, + help='nb threads') + + parser.add_argument('--output', required=True, + help='output index filename') + parser.add_argument('--outputIL', + help='output invfile filename') + + args = parser.parse_args() + + if args.nt != -1: + print('set nb of threads to', args.nt) + + + ils = faiss.InvertedListsPtrVector() + ils_dont_dealloc = [] + + pool = ThreadPool(20) + + def load_index(fname): + print("loading", fname) + try: + index = faiss.read_index(fname, faiss.IO_FLAG_MMAP | faiss.IO_FLAG_READ_ONLY) + except RuntimeError as e: + print('could not load %s: %s' % (fname, e)) + return fname, None + + print(" %d entries" % index.ntotal) + return fname, index + + index0 = None + + for _, index in pool.imap(load_index, args.inputs): + if index is None: + continue + index_ivf = faiss.extract_index_ivf(index) + il = faiss.downcast_InvertedLists(index_ivf.invlists) + index_ivf.invlists = None + il.this.own() + ils_dont_dealloc.append(il) + if (args.l0, args.l1) != (0, -1): + print('restricting to lists %d:%d' % (args.l0, args.l1)) + # il = faiss.SliceInvertedLists(il, args.l0, args.l1) + + il.crop_invlists(args.l0, args.l1) + ils_dont_dealloc.append(il) + ils.push_back(il) + + if index0 is None: + index0 = index + + print("loaded %d invlists" % ils.size()) + + if not args.outputIL: + args.outputIL = args.output + '_invlists' + + il0 = ils.at(0) + + il = faiss.OnDiskInvertedLists( + il0.nlist, il0.code_size, + args.outputIL) + + print("perform merge") + + ntotal = il.merge_from(ils.data(), ils.size(), True) + + print("swap into index0") + + index0_ivf = faiss.extract_index_ivf(index0) + index0_ivf.nlist = il0.nlist + index0_ivf.ntotal = index0.ntotal = ntotal + index0_ivf.invlists = il + index0_ivf.own_invlists = False + + print("write", args.output) + + faiss.write_index(index0, args.output) diff --git a/core/src/index/thirdparty/faiss/benchs/distributed_ondisk/rpc.py b/core/src/index/thirdparty/faiss/benchs/distributed_ondisk/rpc.py new file mode 100644 index 0000000000..7b248ea0a1 --- /dev/null +++ b/core/src/index/thirdparty/faiss/benchs/distributed_ondisk/rpc.py @@ -0,0 +1,252 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +#!/usr/bin/env python3 + +""" +Simplistic RPC implementation. +Exposes all functions of a Server object. + +Uses pickle for serialization and the socket interface. +""" + +import os,pdb,pickle,time,errno,sys,_thread,traceback,socket,threading,gc + + +# default +PORT=12032 + + +######################################################################### +# simple I/O functions + + + +def inline_send_handle(f, conn): + st = os.fstat(f.fileno()) + size = st.st_size + pickle.dump(size, conn) + conn.write(f.read(size)) + +def inline_send_string(s, conn): + size = len(s) + pickle.dump(size, conn) + conn.write(s) + + +class FileSock: + " wraps a socket so that it is usable by pickle/cPickle " + + def __init__(self,sock): + self.sock = sock + self.nr=0 + + def write(self, buf): + # print("sending %d bytes"%len(buf)) + #self.sock.sendall(buf) + # print("...done") + bs = 512 * 1024 + ns = 0 + while ns < len(buf): + sent = self.sock.send(buf[ns:ns + bs]) + ns += sent + + + def read(self,bs=512*1024): + #if self.nr==10000: pdb.set_trace() + self.nr+=1 + # print("read bs=%d"%bs) + b = [] + nb = 0 + while len(b) $workdir/vslices/slice$i.bash < $workdir/hslices/slice$i.bash <0 nodes have 32 links (theses ones are "cheap" to store + because there are fewer nodes in the upper levels. + +- `--indexfile $bdir/deep1M_PQ36_M6.index`: name of the index file + (without information for the L&C extension) + +- `--beta_nsq 4`: number of bytes to allocate for the codes (M in the + paper) + +- `--beta_centroids $bdir/deep1M_PQ36_M6_nsq4.npy`: filename to store + the trained beta centroids + +- `--neigh_recons_codes $bdir/deep1M_PQ36_M6_nsq4_codes.npy`: filename + for the encoded weights (beta) of the combination + +- `--k_reorder 0,5`: number of restults to reorder. 0 = baseline + without reordering, 5 = value used throughout the paper + +- `--efSearch 1,1024`: number of nodes to visit (T in the paper) + +The script will proceed with the following steps: + +0. load dataset (and possibly compute the ground-truth if the +ground-truth file is not provided) + +1. train the OPQ encoder + +2. build the index and store it + +3. compute the residuals and train the beta vocabulary to do the reconstuction + +4. encode the vertices + +5. search and evaluate the search results. + +With option `--exhaustive` the results of the exhaustive column can be +obtained. + +The run above should output: +``` +... +setting k_reorder=5 +... +efSearch=1024 0.3132 ms per query, R@1: 0.4283 R@10: 0.6337 R@100: 0.6520 ndis 40941919 nreorder 50000 + +``` +which matches the paper's table 2. + +Note that in multi-threaded mode, the building of the HNSW strcuture +is not deterministic. Therefore, the results across runs may not be exactly the same. + +Reproducing Figure 5 in the paper +--------------------------------- + +Figure 5 just evaluates the combination of HNSW and PQ. For example, +the operating point L6&OPQ40 can be obtained with + +``` +python bench_link_and_code.py \ + --db deep1M \ + --M0 6 \ + --indexkey OPQ40_160,HNSW32_PQ40 \ + --indexfile $bdir/deep1M_PQ40_M6.index \ + --beta_nsq 1 --beta_k 1 \ + --beta_centroids $bdir/deep1M_PQ40_M6_nsq0.npy \ + --neigh_recons_codes $bdir/deep1M_PQ36_M6_nsq0_codes.npy \ + --k_reorder 0 --efSearch 16,64,256,1024 +``` + +The arguments are similar to the previous table. Note that nsq = 0 is +simulated by setting beta_nsq = 1 and beta_k = 1 (ie a code with a single +reproduction value). + +The output should look like: + +``` +setting k_reorder=0 +efSearch=16 0.0147 ms per query, R@1: 0.3409 R@10: 0.4388 R@100: 0.4394 ndis 2629735 nreorder 0 +efSearch=64 0.0122 ms per query, R@1: 0.4836 R@10: 0.6490 R@100: 0.6509 ndis 4623221 nreorder 0 +efSearch=256 0.0344 ms per query, R@1: 0.5730 R@10: 0.7915 R@100: 0.7951 ndis 11090176 nreorder 0 +efSearch=1024 0.2656 ms per query, R@1: 0.6212 R@10: 0.8722 R@100: 0.8765 ndis 33501951 nreorder 0 +``` + +The results with k_reorder=5 are not reported in the paper, they +represent the performance of a "free coding" version of the algorithm. diff --git a/core/src/index/thirdparty/faiss/benchs/link_and_code/bench_link_and_code.py b/core/src/index/thirdparty/faiss/benchs/link_and_code/bench_link_and_code.py new file mode 100644 index 0000000000..0b055169e4 --- /dev/null +++ b/core/src/index/thirdparty/faiss/benchs/link_and_code/bench_link_and_code.py @@ -0,0 +1,304 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +#!/usr/bin/env python2 + +import os +import sys +import time +import numpy as np +import re +import faiss +from multiprocessing.dummy import Pool as ThreadPool +import pdb +import argparse +import datasets +from datasets import sanitize +import neighbor_codec + +###################################################### +# Command-line parsing +###################################################### + + +parser = argparse.ArgumentParser() + +def aa(*args, **kwargs): + group.add_argument(*args, **kwargs) + +group = parser.add_argument_group('dataset options') + +aa('--db', default='deep1M', help='dataset') +aa( '--compute_gt', default=False, action='store_true', + help='compute and store the groundtruth') + +group = parser.add_argument_group('index consturction') + +aa('--indexkey', default='HNSW32', help='index_factory type') +aa('--efConstruction', default=200, type=int, + help='HNSW construction factor') +aa('--M0', default=-1, type=int, help='size of base level') +aa('--maxtrain', default=256 * 256, type=int, + help='maximum number of training points') +aa('--indexfile', default='', help='file to read or write index from') +aa('--add_bs', default=-1, type=int, + help='add elements index by batches of this size') +aa('--link_singletons', default=False, action='store_true', + help='do a pass to link in the singletons') + +group = parser.add_argument_group( + 'searching (reconstruct_from_neighbors options)') + +aa('--beta_centroids', default='', + help='file with codebook') +aa('--neigh_recons_codes', default='', + help='file with codes for reconstruction') +aa('--beta_ntrain', default=250000, type=int, help='') +aa('--beta_k', default=256, type=int, help='beta codebook size') +aa('--beta_nsq', default=1, type=int, help='number of beta sub-vectors') +aa('--beta_niter', default=10, type=int, help='') +aa('--k_reorder', default='-1', help='') + +group = parser.add_argument_group('searching') + +aa('--k', default=100, type=int, help='nb of nearest neighbors') +aa('--exhaustive', default=False, action='store_true', + help='report the exhaustive search topline') +aa('--searchthreads', default=-1, type=int, + help='nb of threads to use at search time') +aa('--efSearch', default='', type=str, + help='comma-separated values of efSearch to try') + +args = parser.parse_args() + +print "args:", args + + +###################################################### +# Load dataset +###################################################### + +xt, xb, xq, gt = datasets.load_data( + dataset=args.db, compute_gt=args.compute_gt) + +nq, d = xq.shape +nb, d = xb.shape + + +###################################################### +# Make index +###################################################### + +if os.path.exists(args.indexfile): + + print "reading", args.indexfile + index = faiss.read_index(args.indexfile) + + if isinstance(index, faiss.IndexPreTransform): + index_hnsw = faiss.downcast_index(index.index) + vec_transform = index.chain.at(0).apply_py + else: + index_hnsw = index + vec_transform = lambda x:x + + hnsw = index_hnsw.hnsw + hnsw_stats = faiss.cvar.hnsw_stats + +else: + + print "build index, key=", args.indexkey + + index = faiss.index_factory(d, args.indexkey) + + if isinstance(index, faiss.IndexPreTransform): + index_hnsw = faiss.downcast_index(index.index) + vec_transform = index.chain.at(0).apply_py + else: + index_hnsw = index + vec_transform = lambda x:x + + hnsw = index_hnsw.hnsw + hnsw.efConstruction = args.efConstruction + hnsw_stats = faiss.cvar.hnsw_stats + index.verbose = True + index_hnsw.verbose = True + index_hnsw.storage.verbose = True + + if args.M0 != -1: + print "set level 0 nb of neighbors to", args.M0 + hnsw.set_nb_neighbors(0, args.M0) + + xt2 = sanitize(xt[:args.maxtrain]) + assert np.all(np.isfinite(xt2)) + + print "train, size", xt.shape + t0 = time.time() + index.train(xt2) + print " train in %.3f s" % (time.time() - t0) + + print "adding" + t0 = time.time() + if args.add_bs == -1: + index.add(sanitize(xb)) + else: + for i0 in range(0, nb, args.add_bs): + i1 = min(nb, i0 + args.add_bs) + print " adding %d:%d / %d" % (i0, i1, nb) + index.add(sanitize(xb[i0:i1])) + + print " add in %.3f s" % (time.time() - t0) + print "storing", args.indexfile + faiss.write_index(index, args.indexfile) + + +###################################################### +# Train beta centroids and encode dataset +###################################################### + +if args.beta_centroids: + print "reordering links" + index_hnsw.reorder_links() + + if os.path.exists(args.beta_centroids): + print "load", args.beta_centroids + beta_centroids = np.load(args.beta_centroids) + nsq, k, M1 = beta_centroids.shape + assert M1 == hnsw.nb_neighbors(0) + 1 + + rfn = faiss.ReconstructFromNeighbors(index_hnsw, k, nsq) + else: + print "train beta centroids" + rfn = faiss.ReconstructFromNeighbors( + index_hnsw, args.beta_k, args.beta_nsq) + + xb_full = vec_transform(sanitize(xb[:args.beta_ntrain])) + + beta_centroids = neighbor_codec.train_beta_codebook( + rfn, xb_full, niter=args.beta_niter) + + print " storing", args.beta_centroids + np.save(args.beta_centroids, beta_centroids) + + + faiss.copy_array_to_vector(beta_centroids.ravel(), + rfn.codebook) + index_hnsw.reconstruct_from_neighbors = rfn + + if rfn.k == 1: + pass # no codes to take care of + elif os.path.exists(args.neigh_recons_codes): + print "loading neigh codes", args.neigh_recons_codes + codes = np.load(args.neigh_recons_codes) + assert codes.size == rfn.code_size * index.ntotal + faiss.copy_array_to_vector(codes.astype('uint8'), + rfn.codes) + rfn.ntotal = index.ntotal + else: + print "encoding neigh codes" + t0 = time.time() + + bs = 1000000 if args.add_bs == -1 else args.add_bs + + for i0 in range(0, nb, bs): + i1 = min(i0 + bs, nb) + print " encode %d:%d / %d [%.3f s]\r" % ( + i0, i1, nb, time.time() - t0), + sys.stdout.flush() + xbatch = vec_transform(sanitize(xb[i0:i1])) + rfn.add_codes(i1 - i0, faiss.swig_ptr(xbatch)) + print + + print "storing %s" % args.neigh_recons_codes + codes = faiss.vector_to_array(rfn.codes) + np.save(args.neigh_recons_codes, codes) + +###################################################### +# Exhaustive evaluation +###################################################### + +if args.exhaustive: + print "exhaustive evaluation" + xq_tr = vec_transform(sanitize(xq)) + index2 = faiss.IndexFlatL2(index_hnsw.d) + accu_recons_error = 0.0 + + if faiss.get_num_gpus() > 0: + print "do eval on GPU" + co = faiss.GpuMultipleClonerOptions() + co.shard = False + index2 = faiss.index_cpu_to_all_gpus(index2, co) + + # process in batches in case the dataset does not fit in RAM + rh = datasets.ResultHeap(xq_tr.shape[0], 100) + t0 = time.time() + bs = 500000 + for i0 in range(0, nb, bs): + i1 = min(nb, i0 + bs) + print ' handling batch %d:%d' % (i0, i1) + + xb_recons = np.empty( + (i1 - i0, index_hnsw.d), dtype='float32') + rfn.reconstruct_n(i0, i1 - i0, faiss.swig_ptr(xb_recons)) + + accu_recons_error += ( + (vec_transform(sanitize(xb[i0:i1])) - + xb_recons)**2).sum() + + index2.reset() + index2.add(xb_recons) + D, I = index2.search(xq_tr, 100) + rh.add_batch_result(D, I, i0) + + rh.finalize() + del index2 + t1 = time.time() + print "done in %.3f s" % (t1 - t0) + print "total reconstruction error: ", accu_recons_error + print "eval retrieval:" + datasets.evaluate_DI(rh.D, rh.I, gt) + + +def get_neighbors(hnsw, i, level): + " list the neighbors for node i at level " + assert i < hnsw.levels.size() + assert level < hnsw.levels.at(i) + be = np.empty(2, 'uint64') + hnsw.neighbor_range(i, level, faiss.swig_ptr(be), faiss.swig_ptr(be[1:])) + return [hnsw.neighbors.at(j) for j in range(be[0], be[1])] + + +############################################################# +# Index is ready +############################################################# + +xq = sanitize(xq) + +if args.searchthreads != -1: + print "Setting nb of threads to", args.searchthreads + faiss.omp_set_num_threads(args.searchthreads) + + +if gt is None: + print "no valid groundtruth -- exit" + sys.exit() + + +k_reorders = [int(x) for x in args.k_reorder.split(',')] +efSearchs = [int(x) for x in args.efSearch.split(',')] + + +for k_reorder in k_reorders: + + if index_hnsw.reconstruct_from_neighbors: + print "setting k_reorder=%d" % k_reorder + index_hnsw.reconstruct_from_neighbors.k_reorder = k_reorder + + for efSearch in efSearchs: + print "efSearch=%-4d" % efSearch, + hnsw.efSearch = efSearch + hnsw_stats.reset() + datasets.evaluate(xq, gt, index, k=args.k, endl=False) + + print "ndis %d nreorder %d" % (hnsw_stats.ndis, hnsw_stats.nreorder) diff --git a/core/src/index/thirdparty/faiss/benchs/link_and_code/datasets.py b/core/src/index/thirdparty/faiss/benchs/link_and_code/datasets.py new file mode 100644 index 0000000000..ce1379f408 --- /dev/null +++ b/core/src/index/thirdparty/faiss/benchs/link_and_code/datasets.py @@ -0,0 +1,235 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +#! /usr/bin/env python2 + +""" +Common functions to load datasets and compute their ground-truth +""" + +import time +import numpy as np +import faiss +import pdb +import sys + +# set this to the directory that contains the datafiles. +# deep1b data should be at simdir + 'deep1b' +# bigann data should be at simdir + 'bigann' +simdir = '/mnt/vol/gfsai-east/ai-group/datasets/simsearch/' + +################################################################# +# Small I/O functions +################################################################# + + +def ivecs_read(fname): + a = np.fromfile(fname, dtype='int32') + d = a[0] + return a.reshape(-1, d + 1)[:, 1:].copy() + + +def fvecs_read(fname): + return ivecs_read(fname).view('float32') + + +def ivecs_mmap(fname): + a = np.memmap(fname, dtype='int32', mode='r') + d = a[0] + return a.reshape(-1, d + 1)[:, 1:] + + +def fvecs_mmap(fname): + return ivecs_mmap(fname).view('float32') + + +def bvecs_mmap(fname): + x = np.memmap(fname, dtype='uint8', mode='r') + d = x[:4].view('int32')[0] + return x.reshape(-1, d + 4)[:, 4:] + + +def ivecs_write(fname, m): + n, d = m.shape + m1 = np.empty((n, d + 1), dtype='int32') + m1[:, 0] = d + m1[:, 1:] = m + m1.tofile(fname) + + +def fvecs_write(fname, m): + m = m.astype('float32') + ivecs_write(fname, m.view('int32')) + + +################################################################# +# Dataset +################################################################# + +def sanitize(x): + return np.ascontiguousarray(x, dtype='float32') + + +class ResultHeap: + """ Combine query results from a sliced dataset """ + + def __init__(self, nq, k): + " nq: number of query vectors, k: number of results per query " + self.I = np.zeros((nq, k), dtype='int64') + self.D = np.zeros((nq, k), dtype='float32') + self.nq, self.k = nq, k + heaps = faiss.float_maxheap_array_t() + heaps.k = k + heaps.nh = nq + heaps.val = faiss.swig_ptr(self.D) + heaps.ids = faiss.swig_ptr(self.I) + heaps.heapify() + self.heaps = heaps + + def add_batch_result(self, D, I, i0): + assert D.shape == (self.nq, self.k) + assert I.shape == (self.nq, self.k) + I += i0 + self.heaps.addn_with_ids( + self.k, faiss.swig_ptr(D), + faiss.swig_ptr(I), self.k) + + def finalize(self): + self.heaps.reorder() + + + +def compute_GT_sliced(xb, xq, k): + print "compute GT" + t0 = time.time() + nb, d = xb.shape + nq, d = xq.shape + rh = ResultHeap(nq, k) + bs = 10 ** 5 + + xqs = sanitize(xq) + + db_gt = faiss.index_cpu_to_all_gpus(faiss.IndexFlatL2(d)) + + # compute ground-truth by blocks of bs, and add to heaps + for i0 in range(0, nb, bs): + i1 = min(nb, i0 + bs) + xsl = sanitize(xb[i0:i1]) + db_gt.add(xsl) + D, I = db_gt.search(xqs, k) + rh.add_batch_result(D, I, i0) + db_gt.reset() + print "\r %d/%d, %.3f s" % (i0, nb, time.time() - t0), + sys.stdout.flush() + print + rh.finalize() + gt_I = rh.I + + print "GT time: %.3f s" % (time.time() - t0) + return gt_I + + +def do_compute_gt(xb, xq, k): + print "computing GT" + nb, d = xb.shape + index = faiss.index_cpu_to_all_gpus(faiss.IndexFlatL2(d)) + if nb < 100 * 1000: + print " add" + index.add(np.ascontiguousarray(xb, dtype='float32')) + print " search" + D, I = index.search(np.ascontiguousarray(xq, dtype='float32'), k) + else: + I = compute_GT_sliced(xb, xq, k) + + return I.astype('int32') + + +def load_data(dataset='deep1M', compute_gt=False): + + print "load data", dataset + + if dataset == 'sift1M': + basedir = simdir + 'sift1M/' + + xt = fvecs_read(basedir + "sift_learn.fvecs") + xb = fvecs_read(basedir + "sift_base.fvecs") + xq = fvecs_read(basedir + "sift_query.fvecs") + gt = ivecs_read(basedir + "sift_groundtruth.ivecs") + + elif dataset.startswith('bigann'): + basedir = simdir + 'bigann/' + + dbsize = 1000 if dataset == "bigann1B" else int(dataset[6:-1]) + xb = bvecs_mmap(basedir + 'bigann_base.bvecs') + xq = bvecs_mmap(basedir + 'bigann_query.bvecs') + xt = bvecs_mmap(basedir + 'bigann_learn.bvecs') + # trim xb to correct size + xb = xb[:dbsize * 1000 * 1000] + gt = ivecs_read(basedir + 'gnd/idx_%dM.ivecs' % dbsize) + + elif dataset.startswith("deep"): + basedir = simdir + 'deep1b/' + szsuf = dataset[4:] + if szsuf[-1] == 'M': + dbsize = 10 ** 6 * int(szsuf[:-1]) + elif szsuf == '1B': + dbsize = 10 ** 9 + elif szsuf[-1] == 'k': + dbsize = 1000 * int(szsuf[:-1]) + else: + assert False, "did not recognize suffix " + szsuf + + xt = fvecs_mmap(basedir + "learn.fvecs") + xb = fvecs_mmap(basedir + "base.fvecs") + xq = fvecs_read(basedir + "deep1B_queries.fvecs") + + xb = xb[:dbsize] + + gt_fname = basedir + "%s_groundtruth.ivecs" % dataset + if compute_gt: + gt = do_compute_gt(xb, xq, 100) + print "store", gt_fname + ivecs_write(gt_fname, gt) + + gt = ivecs_read(gt_fname) + + else: + assert False + + print "dataset %s sizes: B %s Q %s T %s" % ( + dataset, xb.shape, xq.shape, xt.shape) + + return xt, xb, xq, gt + +################################################################# +# Evaluation +################################################################# + + +def evaluate_DI(D, I, gt): + nq = gt.shape[0] + k = I.shape[1] + rank = 1 + while rank <= k: + recall = (I[:, :rank] == gt[:, :1]).sum() / float(nq) + print "R@%d: %.4f" % (rank, recall), + rank *= 10 + + +def evaluate(xq, gt, index, k=100, endl=True): + t0 = time.time() + D, I = index.search(xq, k) + t1 = time.time() + nq = xq.shape[0] + print "\t %8.4f ms per query, " % ( + (t1 - t0) * 1000.0 / nq), + rank = 1 + while rank <= k: + recall = (I[:, :rank] == gt[:, :1]).sum() / float(nq) + print "R@%d: %.4f" % (rank, recall), + rank *= 10 + if endl: + print + return D, I diff --git a/core/src/index/thirdparty/faiss/benchs/link_and_code/neighbor_codec.py b/core/src/index/thirdparty/faiss/benchs/link_and_code/neighbor_codec.py new file mode 100644 index 0000000000..3869a2c109 --- /dev/null +++ b/core/src/index/thirdparty/faiss/benchs/link_and_code/neighbor_codec.py @@ -0,0 +1,239 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +#! /usr/bin/env python2 + +""" +This is the training code for the link and code. Especially the +neighbors_kmeans function implements the EM-algorithm to find the +appropriate weightings and cluster them. +""" + +import time +import numpy as np +import faiss + +#---------------------------------------------------------- +# Utils +#---------------------------------------------------------- + +def sanitize(x): + return np.ascontiguousarray(x, dtype='float32') + + +def train_kmeans(x, k, ngpu, max_points_per_centroid=256): + "Runs kmeans on one or several GPUs" + d = x.shape[1] + clus = faiss.Clustering(d, k) + clus.verbose = True + clus.niter = 20 + clus.max_points_per_centroid = max_points_per_centroid + + if ngpu == 0: + index = faiss.IndexFlatL2(d) + else: + res = [faiss.StandardGpuResources() for i in range(ngpu)] + + flat_config = [] + for i in range(ngpu): + cfg = faiss.GpuIndexFlatConfig() + cfg.useFloat16 = False + cfg.device = i + flat_config.append(cfg) + + if ngpu == 1: + index = faiss.GpuIndexFlatL2(res[0], d, flat_config[0]) + else: + indexes = [faiss.GpuIndexFlatL2(res[i], d, flat_config[i]) + for i in range(ngpu)] + index = faiss.IndexReplicas() + for sub_index in indexes: + index.addIndex(sub_index) + + # perform the training + clus.train(x, index) + centroids = faiss.vector_float_to_array(clus.centroids) + + obj = faiss.vector_float_to_array(clus.obj) + print "final objective: %.4g" % obj[-1] + + return centroids.reshape(k, d) + + +#---------------------------------------------------------- +# Learning the codebook from neighbors +#---------------------------------------------------------- + + +# works with both a full Inn table and dynamically generated neighbors + +def get_Inn_shape(Inn): + if type(Inn) != tuple: + return Inn.shape + return Inn[:2] + +def get_neighbor_table(x_coded, Inn, i): + if type(Inn) != tuple: + return x_coded[Inn[i,:],:] + rfn = x_coded + M, d = rfn.M, rfn.index.d + out = np.zeros((M + 1, d), dtype='float32') + rfn.get_neighbor_table(i, faiss.swig_ptr(out)) + _, _, sq = Inn + return out[:, sq * rfn.dsub : (sq + 1) * rfn.dsub] + + +# Function that produces the best regression values from the vector +# and its neighbors +def regress_from_neighbors (x, x_coded, Inn): + (N, knn) = get_Inn_shape(Inn) + betas = np.zeros((N,knn)) + t0 = time.time() + for i in xrange (N): + xi = x[i,:] + NNi = get_neighbor_table(x_coded, Inn, i) + betas[i,:] = np.linalg.lstsq(NNi.transpose(), xi, rcond=0.01)[0] + if i % (N / 10) == 0: + print ("[%d:%d] %6.3fs" % (i, i + N / 10, time.time() - t0)) + return betas + + + +# find the best beta minimizing ||x-x_coded[Inn,:]*beta||^2 +def regress_opt_beta (x, x_coded, Inn): + (N, knn) = get_Inn_shape(Inn) + d = x.shape[1] + + # construct the linear system to be solved + X = np.zeros ((d*N)) + Y = np.zeros ((d*N, knn)) + for i in xrange (N): + X[i*d:(i+1)*d] = x[i,:] + neighbor_table = get_neighbor_table(x_coded, Inn, i) + Y[i*d:(i+1)*d, :] = neighbor_table.transpose() + beta_opt = np.linalg.lstsq(Y, X, rcond=0.01)[0] + return beta_opt + + +# Find the best encoding by minimizing the reconstruction error using +# a set of pre-computed beta values +def assign_beta (beta_centroids, x, x_coded, Inn, verbose=True): + if type(Inn) == tuple: + return assign_beta_2(beta_centroids, x, x_coded, Inn) + (N, knn) = Inn.shape + x_ibeta = np.zeros ((N), dtype='int32') + t0= time.time() + for i in xrange (N): + NNi = x_coded[Inn[i,:]] + # Consider all possible betas for the encoding and compute the + # encoding error + x_reg_all = np.dot (beta_centroids, NNi) + err = ((x_reg_all - x[i,:]) ** 2).sum(axis=1) + x_ibeta[i] = err.argmin() + if verbose: + if i % (N / 10) == 0: + print ("[%d:%d] %6.3fs" % (i, i + N / 10, time.time() - t0)) + return x_ibeta + + +# Reconstruct a set of vectors using the beta_centroids, the +# assignment, the encoded neighbors identified by the list Inn (which +# includes the vector itself) +def recons_from_neighbors (beta_centroids, x_ibeta, x_coded, Inn): + (N, knn) = Inn.shape + x_rec = np.zeros(x_coded.shape) + t0= time.time() + for i in xrange (N): + NNi = x_coded[Inn[i,:]] + x_rec[i, :] = np.dot (beta_centroids[x_ibeta[i]], NNi) + if i % (N / 10) == 0: + print ("[%d:%d] %6.3fs" % (i, i + N / 10, time.time() - t0)) + return x_rec + + +# Compute a EM-like algorithm trying at optimizing the beta such as they +# minimize the reconstruction error from the neighbors +def neighbors_kmeans (x, x_coded, Inn, K, ngpus=1, niter=5): + # First compute centroids using a regular k-means algorithm + betas = regress_from_neighbors (x, x_coded, Inn) + beta_centroids = train_kmeans( + sanitize(betas), K, ngpus, max_points_per_centroid=1000000) + _, knn = get_Inn_shape(Inn) + d = x.shape[1] + + rs = np.random.RandomState() + for iter in range(niter): + print 'iter', iter + idx = assign_beta (beta_centroids, x, x_coded, Inn, verbose=False) + + hist = np.bincount(idx) + for cl0 in np.where(hist == 0)[0]: + print " cluster %d empty, split" % cl0, + cl1 = idx[np.random.randint(idx.size)] + pos = np.nonzero (idx == cl1)[0] + pos = rs.choice(pos, pos.size / 2) + print " cl %d -> %d + %d" % (cl1, len(pos), hist[cl1] - len(pos)) + idx[pos] = cl0 + hist = np.bincount(idx) + + tot_err = 0 + for k in range (K): + pos = np.nonzero (idx == k)[0] + npos = pos.shape[0] + + X = np.zeros (d*npos) + Y = np.zeros ((d*npos, knn)) + + for i in range(npos): + X[i*d:(i+1)*d] = x[pos[i],:] + neighbor_table = get_neighbor_table(x_coded, Inn, pos[i]) + Y[i*d:(i+1)*d, :] = neighbor_table.transpose() + sol, residuals, _, _ = np.linalg.lstsq(Y, X, rcond=0.01) + if residuals.size > 0: + tot_err += residuals.sum() + beta_centroids[k, :] = sol + print ' err=%g' % tot_err + return beta_centroids + + +# assign the betas in C++ +def assign_beta_2(beta_centroids, x, rfn, Inn): + _, _, sq = Inn + if rfn.k == 1: + return np.zeros(x.shape[0], dtype=int) + # add dummy dimensions to beta_centroids and x + all_beta_centroids = np.zeros( + (rfn.nsq, rfn.k, rfn.M + 1), dtype='float32') + all_beta_centroids[sq] = beta_centroids + all_x = np.zeros((len(x), rfn.d), dtype='float32') + all_x[:, sq * rfn.dsub : (sq + 1) * rfn.dsub] = x + rfn.codes.clear() + rfn.ntotal = 0 + faiss.copy_array_to_vector( + all_beta_centroids.ravel(), rfn.codebook) + rfn.add_codes(len(x), faiss.swig_ptr(all_x)) + codes = faiss.vector_to_array(rfn.codes) + codes = codes.reshape(-1, rfn.nsq) + return codes[:, sq] + + +####################################################### +# For usage from bench_storages.py + +def train_beta_codebook(rfn, xb_full, niter=10): + beta_centroids = [] + for sq in range(rfn.nsq): + d0, d1 = sq * rfn.dsub, (sq + 1) * rfn.dsub + print "training subquantizer %d/%d on dimensions %d:%d" % ( + sq, rfn.nsq, d0, d1) + beta_centroids_i = neighbors_kmeans( + xb_full[:, d0:d1], rfn, (xb_full.shape[0], rfn.M + 1, sq), + rfn.k, + ngpus=0, niter=niter) + beta_centroids.append(beta_centroids_i) + rfn.ntotal = 0 + rfn.codes.clear() + rfn.codebook.clear() + return np.stack(beta_centroids) diff --git a/core/src/index/thirdparty/faiss/build-aux/config.guess b/core/src/index/thirdparty/faiss/build-aux/config.guess new file mode 100755 index 0000000000..2193702b12 --- /dev/null +++ b/core/src/index/thirdparty/faiss/build-aux/config.guess @@ -0,0 +1,1473 @@ +#! /bin/sh +# Attempt to guess a canonical system name. +# Copyright 1992-2017 Free Software Foundation, Inc. + +timestamp='2017-05-27' + +# This file is free software; you can redistribute it and/or modify it +# under the terms of the GNU General Public License as published by +# the Free Software Foundation; either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, but +# WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program; if not, see . +# +# As a special exception to the GNU General Public License, if you +# distribute this file as part of a program that contains a +# configuration script generated by Autoconf, you may include it under +# the same distribution terms that you use for the rest of that +# program. This Exception is an additional permission under section 7 +# of the GNU General Public License, version 3 ("GPLv3"). +# +# Originally written by Per Bothner; maintained since 2000 by Ben Elliston. +# +# You can get the latest version of this script from: +# http://git.savannah.gnu.org/gitweb/?p=config.git;a=blob_plain;f=config.guess +# +# Please send patches to . + + +me=`echo "$0" | sed -e 's,.*/,,'` + +usage="\ +Usage: $0 [OPTION] + +Output the configuration name of the system \`$me' is run on. + +Operation modes: + -h, --help print this help, then exit + -t, --time-stamp print date of last modification, then exit + -v, --version print version number, then exit + +Report bugs and patches to ." + +version="\ +GNU config.guess ($timestamp) + +Originally written by Per Bothner. +Copyright 1992-2017 Free Software Foundation, Inc. + +This is free software; see the source for copying conditions. There is NO +warranty; not even for MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE." + +help=" +Try \`$me --help' for more information." + +# Parse command line +while test $# -gt 0 ; do + case $1 in + --time-stamp | --time* | -t ) + echo "$timestamp" ; exit ;; + --version | -v ) + echo "$version" ; exit ;; + --help | --h* | -h ) + echo "$usage"; exit ;; + -- ) # Stop option processing + shift; break ;; + - ) # Use stdin as input. + break ;; + -* ) + echo "$me: invalid option $1$help" >&2 + exit 1 ;; + * ) + break ;; + esac +done + +if test $# != 0; then + echo "$me: too many arguments$help" >&2 + exit 1 +fi + +trap 'exit 1' 1 2 15 + +# CC_FOR_BUILD -- compiler used by this script. Note that the use of a +# compiler to aid in system detection is discouraged as it requires +# temporary files to be created and, as you can see below, it is a +# headache to deal with in a portable fashion. + +# Historically, `CC_FOR_BUILD' used to be named `HOST_CC'. We still +# use `HOST_CC' if defined, but it is deprecated. + +# Portable tmp directory creation inspired by the Autoconf team. + +set_cc_for_build=' +trap "exitcode=\$?; (rm -f \$tmpfiles 2>/dev/null; rmdir \$tmp 2>/dev/null) && exit \$exitcode" 0 ; +trap "rm -f \$tmpfiles 2>/dev/null; rmdir \$tmp 2>/dev/null; exit 1" 1 2 13 15 ; +: ${TMPDIR=/tmp} ; + { tmp=`(umask 077 && mktemp -d "$TMPDIR/cgXXXXXX") 2>/dev/null` && test -n "$tmp" && test -d "$tmp" ; } || + { test -n "$RANDOM" && tmp=$TMPDIR/cg$$-$RANDOM && (umask 077 && mkdir $tmp) ; } || + { tmp=$TMPDIR/cg-$$ && (umask 077 && mkdir $tmp) && echo "Warning: creating insecure temp directory" >&2 ; } || + { echo "$me: cannot create a temporary directory in $TMPDIR" >&2 ; exit 1 ; } ; +dummy=$tmp/dummy ; +tmpfiles="$dummy.c $dummy.o $dummy.rel $dummy" ; +case $CC_FOR_BUILD,$HOST_CC,$CC in + ,,) echo "int x;" > $dummy.c ; + for c in cc gcc c89 c99 ; do + if ($c -c -o $dummy.o $dummy.c) >/dev/null 2>&1 ; then + CC_FOR_BUILD="$c"; break ; + fi ; + done ; + if test x"$CC_FOR_BUILD" = x ; then + CC_FOR_BUILD=no_compiler_found ; + fi + ;; + ,,*) CC_FOR_BUILD=$CC ;; + ,*,*) CC_FOR_BUILD=$HOST_CC ;; +esac ; set_cc_for_build= ;' + +# This is needed to find uname on a Pyramid OSx when run in the BSD universe. +# (ghazi@noc.rutgers.edu 1994-08-24) +if (test -f /.attbin/uname) >/dev/null 2>&1 ; then + PATH=$PATH:/.attbin ; export PATH +fi + +UNAME_MACHINE=`(uname -m) 2>/dev/null` || UNAME_MACHINE=unknown +UNAME_RELEASE=`(uname -r) 2>/dev/null` || UNAME_RELEASE=unknown +UNAME_SYSTEM=`(uname -s) 2>/dev/null` || UNAME_SYSTEM=unknown +UNAME_VERSION=`(uname -v) 2>/dev/null` || UNAME_VERSION=unknown + +case "${UNAME_SYSTEM}" in +Linux|GNU|GNU/*) + # If the system lacks a compiler, then just pick glibc. + # We could probably try harder. + LIBC=gnu + + eval $set_cc_for_build + cat <<-EOF > $dummy.c + #include + #if defined(__UCLIBC__) + LIBC=uclibc + #elif defined(__dietlibc__) + LIBC=dietlibc + #else + LIBC=gnu + #endif + EOF + eval `$CC_FOR_BUILD -E $dummy.c 2>/dev/null | grep '^LIBC' | sed 's, ,,g'` + ;; +esac + +# Note: order is significant - the case branches are not exclusive. + +case "${UNAME_MACHINE}:${UNAME_SYSTEM}:${UNAME_RELEASE}:${UNAME_VERSION}" in + *:NetBSD:*:*) + # NetBSD (nbsd) targets should (where applicable) match one or + # more of the tuples: *-*-netbsdelf*, *-*-netbsdaout*, + # *-*-netbsdecoff* and *-*-netbsd*. For targets that recently + # switched to ELF, *-*-netbsd* would select the old + # object file format. This provides both forward + # compatibility and a consistent mechanism for selecting the + # object file format. + # + # Note: NetBSD doesn't particularly care about the vendor + # portion of the name. We always set it to "unknown". + sysctl="sysctl -n hw.machine_arch" + UNAME_MACHINE_ARCH=`(uname -p 2>/dev/null || \ + /sbin/$sysctl 2>/dev/null || \ + /usr/sbin/$sysctl 2>/dev/null || \ + echo unknown)` + case "${UNAME_MACHINE_ARCH}" in + armeb) machine=armeb-unknown ;; + arm*) machine=arm-unknown ;; + sh3el) machine=shl-unknown ;; + sh3eb) machine=sh-unknown ;; + sh5el) machine=sh5le-unknown ;; + earmv*) + arch=`echo ${UNAME_MACHINE_ARCH} | sed -e 's,^e\(armv[0-9]\).*$,\1,'` + endian=`echo ${UNAME_MACHINE_ARCH} | sed -ne 's,^.*\(eb\)$,\1,p'` + machine=${arch}${endian}-unknown + ;; + *) machine=${UNAME_MACHINE_ARCH}-unknown ;; + esac + # The Operating System including object format, if it has switched + # to ELF recently (or will in the future) and ABI. + case "${UNAME_MACHINE_ARCH}" in + earm*) + os=netbsdelf + ;; + arm*|i386|m68k|ns32k|sh3*|sparc|vax) + eval $set_cc_for_build + if echo __ELF__ | $CC_FOR_BUILD -E - 2>/dev/null \ + | grep -q __ELF__ + then + # Once all utilities can be ECOFF (netbsdecoff) or a.out (netbsdaout). + # Return netbsd for either. FIX? + os=netbsd + else + os=netbsdelf + fi + ;; + *) + os=netbsd + ;; + esac + # Determine ABI tags. + case "${UNAME_MACHINE_ARCH}" in + earm*) + expr='s/^earmv[0-9]/-eabi/;s/eb$//' + abi=`echo ${UNAME_MACHINE_ARCH} | sed -e "$expr"` + ;; + esac + # The OS release + # Debian GNU/NetBSD machines have a different userland, and + # thus, need a distinct triplet. However, they do not need + # kernel version information, so it can be replaced with a + # suitable tag, in the style of linux-gnu. + case "${UNAME_VERSION}" in + Debian*) + release='-gnu' + ;; + *) + release=`echo ${UNAME_RELEASE} | sed -e 's/[-_].*//' | cut -d. -f1,2` + ;; + esac + # Since CPU_TYPE-MANUFACTURER-KERNEL-OPERATING_SYSTEM: + # contains redundant information, the shorter form: + # CPU_TYPE-MANUFACTURER-OPERATING_SYSTEM is used. + echo "${machine}-${os}${release}${abi}" + exit ;; + *:Bitrig:*:*) + UNAME_MACHINE_ARCH=`arch | sed 's/Bitrig.//'` + echo ${UNAME_MACHINE_ARCH}-unknown-bitrig${UNAME_RELEASE} + exit ;; + *:OpenBSD:*:*) + UNAME_MACHINE_ARCH=`arch | sed 's/OpenBSD.//'` + echo ${UNAME_MACHINE_ARCH}-unknown-openbsd${UNAME_RELEASE} + exit ;; + *:LibertyBSD:*:*) + UNAME_MACHINE_ARCH=`arch | sed 's/^.*BSD\.//'` + echo ${UNAME_MACHINE_ARCH}-unknown-libertybsd${UNAME_RELEASE} + exit ;; + *:ekkoBSD:*:*) + echo ${UNAME_MACHINE}-unknown-ekkobsd${UNAME_RELEASE} + exit ;; + *:SolidBSD:*:*) + echo ${UNAME_MACHINE}-unknown-solidbsd${UNAME_RELEASE} + exit ;; + macppc:MirBSD:*:*) + echo powerpc-unknown-mirbsd${UNAME_RELEASE} + exit ;; + *:MirBSD:*:*) + echo ${UNAME_MACHINE}-unknown-mirbsd${UNAME_RELEASE} + exit ;; + *:Sortix:*:*) + echo ${UNAME_MACHINE}-unknown-sortix + exit ;; + alpha:OSF1:*:*) + case $UNAME_RELEASE in + *4.0) + UNAME_RELEASE=`/usr/sbin/sizer -v | awk '{print $3}'` + ;; + *5.*) + UNAME_RELEASE=`/usr/sbin/sizer -v | awk '{print $4}'` + ;; + esac + # According to Compaq, /usr/sbin/psrinfo has been available on + # OSF/1 and Tru64 systems produced since 1995. I hope that + # covers most systems running today. This code pipes the CPU + # types through head -n 1, so we only detect the type of CPU 0. + ALPHA_CPU_TYPE=`/usr/sbin/psrinfo -v | sed -n -e 's/^ The alpha \(.*\) processor.*$/\1/p' | head -n 1` + case "$ALPHA_CPU_TYPE" in + "EV4 (21064)") + UNAME_MACHINE=alpha ;; + "EV4.5 (21064)") + UNAME_MACHINE=alpha ;; + "LCA4 (21066/21068)") + UNAME_MACHINE=alpha ;; + "EV5 (21164)") + UNAME_MACHINE=alphaev5 ;; + "EV5.6 (21164A)") + UNAME_MACHINE=alphaev56 ;; + "EV5.6 (21164PC)") + UNAME_MACHINE=alphapca56 ;; + "EV5.7 (21164PC)") + UNAME_MACHINE=alphapca57 ;; + "EV6 (21264)") + UNAME_MACHINE=alphaev6 ;; + "EV6.7 (21264A)") + UNAME_MACHINE=alphaev67 ;; + "EV6.8CB (21264C)") + UNAME_MACHINE=alphaev68 ;; + "EV6.8AL (21264B)") + UNAME_MACHINE=alphaev68 ;; + "EV6.8CX (21264D)") + UNAME_MACHINE=alphaev68 ;; + "EV6.9A (21264/EV69A)") + UNAME_MACHINE=alphaev69 ;; + "EV7 (21364)") + UNAME_MACHINE=alphaev7 ;; + "EV7.9 (21364A)") + UNAME_MACHINE=alphaev79 ;; + esac + # A Pn.n version is a patched version. + # A Vn.n version is a released version. + # A Tn.n version is a released field test version. + # A Xn.n version is an unreleased experimental baselevel. + # 1.2 uses "1.2" for uname -r. + echo ${UNAME_MACHINE}-dec-osf`echo ${UNAME_RELEASE} | sed -e 's/^[PVTX]//' | tr ABCDEFGHIJKLMNOPQRSTUVWXYZ abcdefghijklmnopqrstuvwxyz` + # Reset EXIT trap before exiting to avoid spurious non-zero exit code. + exitcode=$? + trap '' 0 + exit $exitcode ;; + Alpha\ *:Windows_NT*:*) + # How do we know it's Interix rather than the generic POSIX subsystem? + # Should we change UNAME_MACHINE based on the output of uname instead + # of the specific Alpha model? + echo alpha-pc-interix + exit ;; + 21064:Windows_NT:50:3) + echo alpha-dec-winnt3.5 + exit ;; + Amiga*:UNIX_System_V:4.0:*) + echo m68k-unknown-sysv4 + exit ;; + *:[Aa]miga[Oo][Ss]:*:*) + echo ${UNAME_MACHINE}-unknown-amigaos + exit ;; + *:[Mm]orph[Oo][Ss]:*:*) + echo ${UNAME_MACHINE}-unknown-morphos + exit ;; + *:OS/390:*:*) + echo i370-ibm-openedition + exit ;; + *:z/VM:*:*) + echo s390-ibm-zvmoe + exit ;; + *:OS400:*:*) + echo powerpc-ibm-os400 + exit ;; + arm:RISC*:1.[012]*:*|arm:riscix:1.[012]*:*) + echo arm-acorn-riscix${UNAME_RELEASE} + exit ;; + arm*:riscos:*:*|arm*:RISCOS:*:*) + echo arm-unknown-riscos + exit ;; + SR2?01:HI-UX/MPP:*:* | SR8000:HI-UX/MPP:*:*) + echo hppa1.1-hitachi-hiuxmpp + exit ;; + Pyramid*:OSx*:*:* | MIS*:OSx*:*:* | MIS*:SMP_DC-OSx*:*:*) + # akee@wpdis03.wpafb.af.mil (Earle F. Ake) contributed MIS and NILE. + if test "`(/bin/universe) 2>/dev/null`" = att ; then + echo pyramid-pyramid-sysv3 + else + echo pyramid-pyramid-bsd + fi + exit ;; + NILE*:*:*:dcosx) + echo pyramid-pyramid-svr4 + exit ;; + DRS?6000:unix:4.0:6*) + echo sparc-icl-nx6 + exit ;; + DRS?6000:UNIX_SV:4.2*:7* | DRS?6000:isis:4.2*:7*) + case `/usr/bin/uname -p` in + sparc) echo sparc-icl-nx7; exit ;; + esac ;; + s390x:SunOS:*:*) + echo ${UNAME_MACHINE}-ibm-solaris2`echo ${UNAME_RELEASE}|sed -e 's/[^.]*//'` + exit ;; + sun4H:SunOS:5.*:*) + echo sparc-hal-solaris2`echo ${UNAME_RELEASE}|sed -e 's/[^.]*//'` + exit ;; + sun4*:SunOS:5.*:* | tadpole*:SunOS:5.*:*) + echo sparc-sun-solaris2`echo ${UNAME_RELEASE}|sed -e 's/[^.]*//'` + exit ;; + i86pc:AuroraUX:5.*:* | i86xen:AuroraUX:5.*:*) + echo i386-pc-auroraux${UNAME_RELEASE} + exit ;; + i86pc:SunOS:5.*:* | i86xen:SunOS:5.*:*) + eval $set_cc_for_build + SUN_ARCH=i386 + # If there is a compiler, see if it is configured for 64-bit objects. + # Note that the Sun cc does not turn __LP64__ into 1 like gcc does. + # This test works for both compilers. + if [ "$CC_FOR_BUILD" != no_compiler_found ]; then + if (echo '#ifdef __amd64'; echo IS_64BIT_ARCH; echo '#endif') | \ + (CCOPTS="" $CC_FOR_BUILD -E - 2>/dev/null) | \ + grep IS_64BIT_ARCH >/dev/null + then + SUN_ARCH=x86_64 + fi + fi + echo ${SUN_ARCH}-pc-solaris2`echo ${UNAME_RELEASE}|sed -e 's/[^.]*//'` + exit ;; + sun4*:SunOS:6*:*) + # According to config.sub, this is the proper way to canonicalize + # SunOS6. Hard to guess exactly what SunOS6 will be like, but + # it's likely to be more like Solaris than SunOS4. + echo sparc-sun-solaris3`echo ${UNAME_RELEASE}|sed -e 's/[^.]*//'` + exit ;; + sun4*:SunOS:*:*) + case "`/usr/bin/arch -k`" in + Series*|S4*) + UNAME_RELEASE=`uname -v` + ;; + esac + # Japanese Language versions have a version number like `4.1.3-JL'. + echo sparc-sun-sunos`echo ${UNAME_RELEASE}|sed -e 's/-/_/'` + exit ;; + sun3*:SunOS:*:*) + echo m68k-sun-sunos${UNAME_RELEASE} + exit ;; + sun*:*:4.2BSD:*) + UNAME_RELEASE=`(sed 1q /etc/motd | awk '{print substr($5,1,3)}') 2>/dev/null` + test "x${UNAME_RELEASE}" = x && UNAME_RELEASE=3 + case "`/bin/arch`" in + sun3) + echo m68k-sun-sunos${UNAME_RELEASE} + ;; + sun4) + echo sparc-sun-sunos${UNAME_RELEASE} + ;; + esac + exit ;; + aushp:SunOS:*:*) + echo sparc-auspex-sunos${UNAME_RELEASE} + exit ;; + # The situation for MiNT is a little confusing. The machine name + # can be virtually everything (everything which is not + # "atarist" or "atariste" at least should have a processor + # > m68000). The system name ranges from "MiNT" over "FreeMiNT" + # to the lowercase version "mint" (or "freemint"). Finally + # the system name "TOS" denotes a system which is actually not + # MiNT. But MiNT is downward compatible to TOS, so this should + # be no problem. + atarist[e]:*MiNT:*:* | atarist[e]:*mint:*:* | atarist[e]:*TOS:*:*) + echo m68k-atari-mint${UNAME_RELEASE} + exit ;; + atari*:*MiNT:*:* | atari*:*mint:*:* | atarist[e]:*TOS:*:*) + echo m68k-atari-mint${UNAME_RELEASE} + exit ;; + *falcon*:*MiNT:*:* | *falcon*:*mint:*:* | *falcon*:*TOS:*:*) + echo m68k-atari-mint${UNAME_RELEASE} + exit ;; + milan*:*MiNT:*:* | milan*:*mint:*:* | *milan*:*TOS:*:*) + echo m68k-milan-mint${UNAME_RELEASE} + exit ;; + hades*:*MiNT:*:* | hades*:*mint:*:* | *hades*:*TOS:*:*) + echo m68k-hades-mint${UNAME_RELEASE} + exit ;; + *:*MiNT:*:* | *:*mint:*:* | *:*TOS:*:*) + echo m68k-unknown-mint${UNAME_RELEASE} + exit ;; + m68k:machten:*:*) + echo m68k-apple-machten${UNAME_RELEASE} + exit ;; + powerpc:machten:*:*) + echo powerpc-apple-machten${UNAME_RELEASE} + exit ;; + RISC*:Mach:*:*) + echo mips-dec-mach_bsd4.3 + exit ;; + RISC*:ULTRIX:*:*) + echo mips-dec-ultrix${UNAME_RELEASE} + exit ;; + VAX*:ULTRIX*:*:*) + echo vax-dec-ultrix${UNAME_RELEASE} + exit ;; + 2020:CLIX:*:* | 2430:CLIX:*:*) + echo clipper-intergraph-clix${UNAME_RELEASE} + exit ;; + mips:*:*:UMIPS | mips:*:*:RISCos) + eval $set_cc_for_build + sed 's/^ //' << EOF >$dummy.c +#ifdef __cplusplus +#include /* for printf() prototype */ + int main (int argc, char *argv[]) { +#else + int main (argc, argv) int argc; char *argv[]; { +#endif + #if defined (host_mips) && defined (MIPSEB) + #if defined (SYSTYPE_SYSV) + printf ("mips-mips-riscos%ssysv\n", argv[1]); exit (0); + #endif + #if defined (SYSTYPE_SVR4) + printf ("mips-mips-riscos%ssvr4\n", argv[1]); exit (0); + #endif + #if defined (SYSTYPE_BSD43) || defined(SYSTYPE_BSD) + printf ("mips-mips-riscos%sbsd\n", argv[1]); exit (0); + #endif + #endif + exit (-1); + } +EOF + $CC_FOR_BUILD -o $dummy $dummy.c && + dummyarg=`echo "${UNAME_RELEASE}" | sed -n 's/\([0-9]*\).*/\1/p'` && + SYSTEM_NAME=`$dummy $dummyarg` && + { echo "$SYSTEM_NAME"; exit; } + echo mips-mips-riscos${UNAME_RELEASE} + exit ;; + Motorola:PowerMAX_OS:*:*) + echo powerpc-motorola-powermax + exit ;; + Motorola:*:4.3:PL8-*) + echo powerpc-harris-powermax + exit ;; + Night_Hawk:*:*:PowerMAX_OS | Synergy:PowerMAX_OS:*:*) + echo powerpc-harris-powermax + exit ;; + Night_Hawk:Power_UNIX:*:*) + echo powerpc-harris-powerunix + exit ;; + m88k:CX/UX:7*:*) + echo m88k-harris-cxux7 + exit ;; + m88k:*:4*:R4*) + echo m88k-motorola-sysv4 + exit ;; + m88k:*:3*:R3*) + echo m88k-motorola-sysv3 + exit ;; + AViiON:dgux:*:*) + # DG/UX returns AViiON for all architectures + UNAME_PROCESSOR=`/usr/bin/uname -p` + if [ $UNAME_PROCESSOR = mc88100 ] || [ $UNAME_PROCESSOR = mc88110 ] + then + if [ ${TARGET_BINARY_INTERFACE}x = m88kdguxelfx ] || \ + [ ${TARGET_BINARY_INTERFACE}x = x ] + then + echo m88k-dg-dgux${UNAME_RELEASE} + else + echo m88k-dg-dguxbcs${UNAME_RELEASE} + fi + else + echo i586-dg-dgux${UNAME_RELEASE} + fi + exit ;; + M88*:DolphinOS:*:*) # DolphinOS (SVR3) + echo m88k-dolphin-sysv3 + exit ;; + M88*:*:R3*:*) + # Delta 88k system running SVR3 + echo m88k-motorola-sysv3 + exit ;; + XD88*:*:*:*) # Tektronix XD88 system running UTekV (SVR3) + echo m88k-tektronix-sysv3 + exit ;; + Tek43[0-9][0-9]:UTek:*:*) # Tektronix 4300 system running UTek (BSD) + echo m68k-tektronix-bsd + exit ;; + *:IRIX*:*:*) + echo mips-sgi-irix`echo ${UNAME_RELEASE}|sed -e 's/-/_/g'` + exit ;; + ????????:AIX?:[12].1:2) # AIX 2.2.1 or AIX 2.1.1 is RT/PC AIX. + echo romp-ibm-aix # uname -m gives an 8 hex-code CPU id + exit ;; # Note that: echo "'`uname -s`'" gives 'AIX ' + i*86:AIX:*:*) + echo i386-ibm-aix + exit ;; + ia64:AIX:*:*) + if [ -x /usr/bin/oslevel ] ; then + IBM_REV=`/usr/bin/oslevel` + else + IBM_REV=${UNAME_VERSION}.${UNAME_RELEASE} + fi + echo ${UNAME_MACHINE}-ibm-aix${IBM_REV} + exit ;; + *:AIX:2:3) + if grep bos325 /usr/include/stdio.h >/dev/null 2>&1; then + eval $set_cc_for_build + sed 's/^ //' << EOF >$dummy.c + #include + + main() + { + if (!__power_pc()) + exit(1); + puts("powerpc-ibm-aix3.2.5"); + exit(0); + } +EOF + if $CC_FOR_BUILD -o $dummy $dummy.c && SYSTEM_NAME=`$dummy` + then + echo "$SYSTEM_NAME" + else + echo rs6000-ibm-aix3.2.5 + fi + elif grep bos324 /usr/include/stdio.h >/dev/null 2>&1; then + echo rs6000-ibm-aix3.2.4 + else + echo rs6000-ibm-aix3.2 + fi + exit ;; + *:AIX:*:[4567]) + IBM_CPU_ID=`/usr/sbin/lsdev -C -c processor -S available | sed 1q | awk '{ print $1 }'` + if /usr/sbin/lsattr -El ${IBM_CPU_ID} | grep ' POWER' >/dev/null 2>&1; then + IBM_ARCH=rs6000 + else + IBM_ARCH=powerpc + fi + if [ -x /usr/bin/lslpp ] ; then + IBM_REV=`/usr/bin/lslpp -Lqc bos.rte.libc | + awk -F: '{ print $3 }' | sed s/[0-9]*$/0/` + else + IBM_REV=${UNAME_VERSION}.${UNAME_RELEASE} + fi + echo ${IBM_ARCH}-ibm-aix${IBM_REV} + exit ;; + *:AIX:*:*) + echo rs6000-ibm-aix + exit ;; + ibmrt:4.4BSD:*|romp-ibm:BSD:*) + echo romp-ibm-bsd4.4 + exit ;; + ibmrt:*BSD:*|romp-ibm:BSD:*) # covers RT/PC BSD and + echo romp-ibm-bsd${UNAME_RELEASE} # 4.3 with uname added to + exit ;; # report: romp-ibm BSD 4.3 + *:BOSX:*:*) + echo rs6000-bull-bosx + exit ;; + DPX/2?00:B.O.S.:*:*) + echo m68k-bull-sysv3 + exit ;; + 9000/[34]??:4.3bsd:1.*:*) + echo m68k-hp-bsd + exit ;; + hp300:4.4BSD:*:* | 9000/[34]??:4.3bsd:2.*:*) + echo m68k-hp-bsd4.4 + exit ;; + 9000/[34678]??:HP-UX:*:*) + HPUX_REV=`echo ${UNAME_RELEASE}|sed -e 's/[^.]*.[0B]*//'` + case "${UNAME_MACHINE}" in + 9000/31? ) HP_ARCH=m68000 ;; + 9000/[34]?? ) HP_ARCH=m68k ;; + 9000/[678][0-9][0-9]) + if [ -x /usr/bin/getconf ]; then + sc_cpu_version=`/usr/bin/getconf SC_CPU_VERSION 2>/dev/null` + sc_kernel_bits=`/usr/bin/getconf SC_KERNEL_BITS 2>/dev/null` + case "${sc_cpu_version}" in + 523) HP_ARCH=hppa1.0 ;; # CPU_PA_RISC1_0 + 528) HP_ARCH=hppa1.1 ;; # CPU_PA_RISC1_1 + 532) # CPU_PA_RISC2_0 + case "${sc_kernel_bits}" in + 32) HP_ARCH=hppa2.0n ;; + 64) HP_ARCH=hppa2.0w ;; + '') HP_ARCH=hppa2.0 ;; # HP-UX 10.20 + esac ;; + esac + fi + if [ "${HP_ARCH}" = "" ]; then + eval $set_cc_for_build + sed 's/^ //' << EOF >$dummy.c + + #define _HPUX_SOURCE + #include + #include + + int main () + { + #if defined(_SC_KERNEL_BITS) + long bits = sysconf(_SC_KERNEL_BITS); + #endif + long cpu = sysconf (_SC_CPU_VERSION); + + switch (cpu) + { + case CPU_PA_RISC1_0: puts ("hppa1.0"); break; + case CPU_PA_RISC1_1: puts ("hppa1.1"); break; + case CPU_PA_RISC2_0: + #if defined(_SC_KERNEL_BITS) + switch (bits) + { + case 64: puts ("hppa2.0w"); break; + case 32: puts ("hppa2.0n"); break; + default: puts ("hppa2.0"); break; + } break; + #else /* !defined(_SC_KERNEL_BITS) */ + puts ("hppa2.0"); break; + #endif + default: puts ("hppa1.0"); break; + } + exit (0); + } +EOF + (CCOPTS="" $CC_FOR_BUILD -o $dummy $dummy.c 2>/dev/null) && HP_ARCH=`$dummy` + test -z "$HP_ARCH" && HP_ARCH=hppa + fi ;; + esac + if [ ${HP_ARCH} = hppa2.0w ] + then + eval $set_cc_for_build + + # hppa2.0w-hp-hpux* has a 64-bit kernel and a compiler generating + # 32-bit code. hppa64-hp-hpux* has the same kernel and a compiler + # generating 64-bit code. GNU and HP use different nomenclature: + # + # $ CC_FOR_BUILD=cc ./config.guess + # => hppa2.0w-hp-hpux11.23 + # $ CC_FOR_BUILD="cc +DA2.0w" ./config.guess + # => hppa64-hp-hpux11.23 + + if echo __LP64__ | (CCOPTS="" $CC_FOR_BUILD -E - 2>/dev/null) | + grep -q __LP64__ + then + HP_ARCH=hppa2.0w + else + HP_ARCH=hppa64 + fi + fi + echo ${HP_ARCH}-hp-hpux${HPUX_REV} + exit ;; + ia64:HP-UX:*:*) + HPUX_REV=`echo ${UNAME_RELEASE}|sed -e 's/[^.]*.[0B]*//'` + echo ia64-hp-hpux${HPUX_REV} + exit ;; + 3050*:HI-UX:*:*) + eval $set_cc_for_build + sed 's/^ //' << EOF >$dummy.c + #include + int + main () + { + long cpu = sysconf (_SC_CPU_VERSION); + /* The order matters, because CPU_IS_HP_MC68K erroneously returns + true for CPU_PA_RISC1_0. CPU_IS_PA_RISC returns correct + results, however. */ + if (CPU_IS_PA_RISC (cpu)) + { + switch (cpu) + { + case CPU_PA_RISC1_0: puts ("hppa1.0-hitachi-hiuxwe2"); break; + case CPU_PA_RISC1_1: puts ("hppa1.1-hitachi-hiuxwe2"); break; + case CPU_PA_RISC2_0: puts ("hppa2.0-hitachi-hiuxwe2"); break; + default: puts ("hppa-hitachi-hiuxwe2"); break; + } + } + else if (CPU_IS_HP_MC68K (cpu)) + puts ("m68k-hitachi-hiuxwe2"); + else puts ("unknown-hitachi-hiuxwe2"); + exit (0); + } +EOF + $CC_FOR_BUILD -o $dummy $dummy.c && SYSTEM_NAME=`$dummy` && + { echo "$SYSTEM_NAME"; exit; } + echo unknown-hitachi-hiuxwe2 + exit ;; + 9000/7??:4.3bsd:*:* | 9000/8?[79]:4.3bsd:*:* ) + echo hppa1.1-hp-bsd + exit ;; + 9000/8??:4.3bsd:*:*) + echo hppa1.0-hp-bsd + exit ;; + *9??*:MPE/iX:*:* | *3000*:MPE/iX:*:*) + echo hppa1.0-hp-mpeix + exit ;; + hp7??:OSF1:*:* | hp8?[79]:OSF1:*:* ) + echo hppa1.1-hp-osf + exit ;; + hp8??:OSF1:*:*) + echo hppa1.0-hp-osf + exit ;; + i*86:OSF1:*:*) + if [ -x /usr/sbin/sysversion ] ; then + echo ${UNAME_MACHINE}-unknown-osf1mk + else + echo ${UNAME_MACHINE}-unknown-osf1 + fi + exit ;; + parisc*:Lites*:*:*) + echo hppa1.1-hp-lites + exit ;; + C1*:ConvexOS:*:* | convex:ConvexOS:C1*:*) + echo c1-convex-bsd + exit ;; + C2*:ConvexOS:*:* | convex:ConvexOS:C2*:*) + if getsysinfo -f scalar_acc + then echo c32-convex-bsd + else echo c2-convex-bsd + fi + exit ;; + C34*:ConvexOS:*:* | convex:ConvexOS:C34*:*) + echo c34-convex-bsd + exit ;; + C38*:ConvexOS:*:* | convex:ConvexOS:C38*:*) + echo c38-convex-bsd + exit ;; + C4*:ConvexOS:*:* | convex:ConvexOS:C4*:*) + echo c4-convex-bsd + exit ;; + CRAY*Y-MP:*:*:*) + echo ymp-cray-unicos${UNAME_RELEASE} | sed -e 's/\.[^.]*$/.X/' + exit ;; + CRAY*[A-Z]90:*:*:*) + echo ${UNAME_MACHINE}-cray-unicos${UNAME_RELEASE} \ + | sed -e 's/CRAY.*\([A-Z]90\)/\1/' \ + -e y/ABCDEFGHIJKLMNOPQRSTUVWXYZ/abcdefghijklmnopqrstuvwxyz/ \ + -e 's/\.[^.]*$/.X/' + exit ;; + CRAY*TS:*:*:*) + echo t90-cray-unicos${UNAME_RELEASE} | sed -e 's/\.[^.]*$/.X/' + exit ;; + CRAY*T3E:*:*:*) + echo alphaev5-cray-unicosmk${UNAME_RELEASE} | sed -e 's/\.[^.]*$/.X/' + exit ;; + CRAY*SV1:*:*:*) + echo sv1-cray-unicos${UNAME_RELEASE} | sed -e 's/\.[^.]*$/.X/' + exit ;; + *:UNICOS/mp:*:*) + echo craynv-cray-unicosmp${UNAME_RELEASE} | sed -e 's/\.[^.]*$/.X/' + exit ;; + F30[01]:UNIX_System_V:*:* | F700:UNIX_System_V:*:*) + FUJITSU_PROC=`uname -m | tr ABCDEFGHIJKLMNOPQRSTUVWXYZ abcdefghijklmnopqrstuvwxyz` + FUJITSU_SYS=`uname -p | tr ABCDEFGHIJKLMNOPQRSTUVWXYZ abcdefghijklmnopqrstuvwxyz | sed -e 's/\///'` + FUJITSU_REL=`echo ${UNAME_RELEASE} | sed -e 's/ /_/'` + echo "${FUJITSU_PROC}-fujitsu-${FUJITSU_SYS}${FUJITSU_REL}" + exit ;; + 5000:UNIX_System_V:4.*:*) + FUJITSU_SYS=`uname -p | tr ABCDEFGHIJKLMNOPQRSTUVWXYZ abcdefghijklmnopqrstuvwxyz | sed -e 's/\///'` + FUJITSU_REL=`echo ${UNAME_RELEASE} | tr ABCDEFGHIJKLMNOPQRSTUVWXYZ abcdefghijklmnopqrstuvwxyz | sed -e 's/ /_/'` + echo "sparc-fujitsu-${FUJITSU_SYS}${FUJITSU_REL}" + exit ;; + i*86:BSD/386:*:* | i*86:BSD/OS:*:* | *:Ascend\ Embedded/OS:*:*) + echo ${UNAME_MACHINE}-pc-bsdi${UNAME_RELEASE} + exit ;; + sparc*:BSD/OS:*:*) + echo sparc-unknown-bsdi${UNAME_RELEASE} + exit ;; + *:BSD/OS:*:*) + echo ${UNAME_MACHINE}-unknown-bsdi${UNAME_RELEASE} + exit ;; + *:FreeBSD:*:*) + UNAME_PROCESSOR=`/usr/bin/uname -p` + case ${UNAME_PROCESSOR} in + amd64) + UNAME_PROCESSOR=x86_64 ;; + i386) + UNAME_PROCESSOR=i586 ;; + esac + echo ${UNAME_PROCESSOR}-unknown-freebsd`echo ${UNAME_RELEASE}|sed -e 's/[-(].*//'` + exit ;; + i*:CYGWIN*:*) + echo ${UNAME_MACHINE}-pc-cygwin + exit ;; + *:MINGW64*:*) + echo ${UNAME_MACHINE}-pc-mingw64 + exit ;; + *:MINGW*:*) + echo ${UNAME_MACHINE}-pc-mingw32 + exit ;; + *:MSYS*:*) + echo ${UNAME_MACHINE}-pc-msys + exit ;; + i*:windows32*:*) + # uname -m includes "-pc" on this system. + echo ${UNAME_MACHINE}-mingw32 + exit ;; + i*:PW*:*) + echo ${UNAME_MACHINE}-pc-pw32 + exit ;; + *:Interix*:*) + case ${UNAME_MACHINE} in + x86) + echo i586-pc-interix${UNAME_RELEASE} + exit ;; + authenticamd | genuineintel | EM64T) + echo x86_64-unknown-interix${UNAME_RELEASE} + exit ;; + IA64) + echo ia64-unknown-interix${UNAME_RELEASE} + exit ;; + esac ;; + [345]86:Windows_95:* | [345]86:Windows_98:* | [345]86:Windows_NT:*) + echo i${UNAME_MACHINE}-pc-mks + exit ;; + 8664:Windows_NT:*) + echo x86_64-pc-mks + exit ;; + i*:Windows_NT*:* | Pentium*:Windows_NT*:*) + # How do we know it's Interix rather than the generic POSIX subsystem? + # It also conflicts with pre-2.0 versions of AT&T UWIN. Should we + # UNAME_MACHINE based on the output of uname instead of i386? + echo i586-pc-interix + exit ;; + i*:UWIN*:*) + echo ${UNAME_MACHINE}-pc-uwin + exit ;; + amd64:CYGWIN*:*:* | x86_64:CYGWIN*:*:*) + echo x86_64-unknown-cygwin + exit ;; + p*:CYGWIN*:*) + echo powerpcle-unknown-cygwin + exit ;; + prep*:SunOS:5.*:*) + echo powerpcle-unknown-solaris2`echo ${UNAME_RELEASE}|sed -e 's/[^.]*//'` + exit ;; + *:GNU:*:*) + # the GNU system + echo `echo ${UNAME_MACHINE}|sed -e 's,[-/].*$,,'`-unknown-${LIBC}`echo ${UNAME_RELEASE}|sed -e 's,/.*$,,'` + exit ;; + *:GNU/*:*:*) + # other systems with GNU libc and userland + echo ${UNAME_MACHINE}-unknown-`echo ${UNAME_SYSTEM} | sed 's,^[^/]*/,,' | tr "[:upper:]" "[:lower:]"``echo ${UNAME_RELEASE}|sed -e 's/[-(].*//'`-${LIBC} + exit ;; + i*86:Minix:*:*) + echo ${UNAME_MACHINE}-pc-minix + exit ;; + aarch64:Linux:*:*) + echo ${UNAME_MACHINE}-unknown-linux-${LIBC} + exit ;; + aarch64_be:Linux:*:*) + UNAME_MACHINE=aarch64_be + echo ${UNAME_MACHINE}-unknown-linux-${LIBC} + exit ;; + alpha:Linux:*:*) + case `sed -n '/^cpu model/s/^.*: \(.*\)/\1/p' < /proc/cpuinfo` in + EV5) UNAME_MACHINE=alphaev5 ;; + EV56) UNAME_MACHINE=alphaev56 ;; + PCA56) UNAME_MACHINE=alphapca56 ;; + PCA57) UNAME_MACHINE=alphapca56 ;; + EV6) UNAME_MACHINE=alphaev6 ;; + EV67) UNAME_MACHINE=alphaev67 ;; + EV68*) UNAME_MACHINE=alphaev68 ;; + esac + objdump --private-headers /bin/sh | grep -q ld.so.1 + if test "$?" = 0 ; then LIBC=gnulibc1 ; fi + echo ${UNAME_MACHINE}-unknown-linux-${LIBC} + exit ;; + arc:Linux:*:* | arceb:Linux:*:*) + echo ${UNAME_MACHINE}-unknown-linux-${LIBC} + exit ;; + arm*:Linux:*:*) + eval $set_cc_for_build + if echo __ARM_EABI__ | $CC_FOR_BUILD -E - 2>/dev/null \ + | grep -q __ARM_EABI__ + then + echo ${UNAME_MACHINE}-unknown-linux-${LIBC} + else + if echo __ARM_PCS_VFP | $CC_FOR_BUILD -E - 2>/dev/null \ + | grep -q __ARM_PCS_VFP + then + echo ${UNAME_MACHINE}-unknown-linux-${LIBC}eabi + else + echo ${UNAME_MACHINE}-unknown-linux-${LIBC}eabihf + fi + fi + exit ;; + avr32*:Linux:*:*) + echo ${UNAME_MACHINE}-unknown-linux-${LIBC} + exit ;; + cris:Linux:*:*) + echo ${UNAME_MACHINE}-axis-linux-${LIBC} + exit ;; + crisv32:Linux:*:*) + echo ${UNAME_MACHINE}-axis-linux-${LIBC} + exit ;; + e2k:Linux:*:*) + echo ${UNAME_MACHINE}-unknown-linux-${LIBC} + exit ;; + frv:Linux:*:*) + echo ${UNAME_MACHINE}-unknown-linux-${LIBC} + exit ;; + hexagon:Linux:*:*) + echo ${UNAME_MACHINE}-unknown-linux-${LIBC} + exit ;; + i*86:Linux:*:*) + echo ${UNAME_MACHINE}-pc-linux-${LIBC} + exit ;; + ia64:Linux:*:*) + echo ${UNAME_MACHINE}-unknown-linux-${LIBC} + exit ;; + k1om:Linux:*:*) + echo ${UNAME_MACHINE}-unknown-linux-${LIBC} + exit ;; + m32r*:Linux:*:*) + echo ${UNAME_MACHINE}-unknown-linux-${LIBC} + exit ;; + m68*:Linux:*:*) + echo ${UNAME_MACHINE}-unknown-linux-${LIBC} + exit ;; + mips:Linux:*:* | mips64:Linux:*:*) + eval $set_cc_for_build + sed 's/^ //' << EOF >$dummy.c + #undef CPU + #undef ${UNAME_MACHINE} + #undef ${UNAME_MACHINE}el + #if defined(__MIPSEL__) || defined(__MIPSEL) || defined(_MIPSEL) || defined(MIPSEL) + CPU=${UNAME_MACHINE}el + #else + #if defined(__MIPSEB__) || defined(__MIPSEB) || defined(_MIPSEB) || defined(MIPSEB) + CPU=${UNAME_MACHINE} + #else + CPU= + #endif + #endif +EOF + eval `$CC_FOR_BUILD -E $dummy.c 2>/dev/null | grep '^CPU'` + test x"${CPU}" != x && { echo "${CPU}-unknown-linux-${LIBC}"; exit; } + ;; + mips64el:Linux:*:*) + echo ${UNAME_MACHINE}-unknown-linux-${LIBC} + exit ;; + openrisc*:Linux:*:*) + echo or1k-unknown-linux-${LIBC} + exit ;; + or32:Linux:*:* | or1k*:Linux:*:*) + echo ${UNAME_MACHINE}-unknown-linux-${LIBC} + exit ;; + padre:Linux:*:*) + echo sparc-unknown-linux-${LIBC} + exit ;; + parisc64:Linux:*:* | hppa64:Linux:*:*) + echo hppa64-unknown-linux-${LIBC} + exit ;; + parisc:Linux:*:* | hppa:Linux:*:*) + # Look for CPU level + case `grep '^cpu[^a-z]*:' /proc/cpuinfo 2>/dev/null | cut -d' ' -f2` in + PA7*) echo hppa1.1-unknown-linux-${LIBC} ;; + PA8*) echo hppa2.0-unknown-linux-${LIBC} ;; + *) echo hppa-unknown-linux-${LIBC} ;; + esac + exit ;; + ppc64:Linux:*:*) + echo powerpc64-unknown-linux-${LIBC} + exit ;; + ppc:Linux:*:*) + echo powerpc-unknown-linux-${LIBC} + exit ;; + ppc64le:Linux:*:*) + echo powerpc64le-unknown-linux-${LIBC} + exit ;; + ppcle:Linux:*:*) + echo powerpcle-unknown-linux-${LIBC} + exit ;; + riscv32:Linux:*:* | riscv64:Linux:*:*) + echo ${UNAME_MACHINE}-unknown-linux-${LIBC} + exit ;; + s390:Linux:*:* | s390x:Linux:*:*) + echo ${UNAME_MACHINE}-ibm-linux-${LIBC} + exit ;; + sh64*:Linux:*:*) + echo ${UNAME_MACHINE}-unknown-linux-${LIBC} + exit ;; + sh*:Linux:*:*) + echo ${UNAME_MACHINE}-unknown-linux-${LIBC} + exit ;; + sparc:Linux:*:* | sparc64:Linux:*:*) + echo ${UNAME_MACHINE}-unknown-linux-${LIBC} + exit ;; + tile*:Linux:*:*) + echo ${UNAME_MACHINE}-unknown-linux-${LIBC} + exit ;; + vax:Linux:*:*) + echo ${UNAME_MACHINE}-dec-linux-${LIBC} + exit ;; + x86_64:Linux:*:*) + echo ${UNAME_MACHINE}-pc-linux-${LIBC} + exit ;; + xtensa*:Linux:*:*) + echo ${UNAME_MACHINE}-unknown-linux-${LIBC} + exit ;; + i*86:DYNIX/ptx:4*:*) + # ptx 4.0 does uname -s correctly, with DYNIX/ptx in there. + # earlier versions are messed up and put the nodename in both + # sysname and nodename. + echo i386-sequent-sysv4 + exit ;; + i*86:UNIX_SV:4.2MP:2.*) + # Unixware is an offshoot of SVR4, but it has its own version + # number series starting with 2... + # I am not positive that other SVR4 systems won't match this, + # I just have to hope. -- rms. + # Use sysv4.2uw... so that sysv4* matches it. + echo ${UNAME_MACHINE}-pc-sysv4.2uw${UNAME_VERSION} + exit ;; + i*86:OS/2:*:*) + # If we were able to find `uname', then EMX Unix compatibility + # is probably installed. + echo ${UNAME_MACHINE}-pc-os2-emx + exit ;; + i*86:XTS-300:*:STOP) + echo ${UNAME_MACHINE}-unknown-stop + exit ;; + i*86:atheos:*:*) + echo ${UNAME_MACHINE}-unknown-atheos + exit ;; + i*86:syllable:*:*) + echo ${UNAME_MACHINE}-pc-syllable + exit ;; + i*86:LynxOS:2.*:* | i*86:LynxOS:3.[01]*:* | i*86:LynxOS:4.[02]*:*) + echo i386-unknown-lynxos${UNAME_RELEASE} + exit ;; + i*86:*DOS:*:*) + echo ${UNAME_MACHINE}-pc-msdosdjgpp + exit ;; + i*86:*:4.*:* | i*86:SYSTEM_V:4.*:*) + UNAME_REL=`echo ${UNAME_RELEASE} | sed 's/\/MP$//'` + if grep Novell /usr/include/link.h >/dev/null 2>/dev/null; then + echo ${UNAME_MACHINE}-univel-sysv${UNAME_REL} + else + echo ${UNAME_MACHINE}-pc-sysv${UNAME_REL} + fi + exit ;; + i*86:*:5:[678]*) + # UnixWare 7.x, OpenUNIX and OpenServer 6. + case `/bin/uname -X | grep "^Machine"` in + *486*) UNAME_MACHINE=i486 ;; + *Pentium) UNAME_MACHINE=i586 ;; + *Pent*|*Celeron) UNAME_MACHINE=i686 ;; + esac + echo ${UNAME_MACHINE}-unknown-sysv${UNAME_RELEASE}${UNAME_SYSTEM}${UNAME_VERSION} + exit ;; + i*86:*:3.2:*) + if test -f /usr/options/cb.name; then + UNAME_REL=`sed -n 's/.*Version //p' /dev/null >/dev/null ; then + UNAME_REL=`(/bin/uname -X|grep Release|sed -e 's/.*= //')` + (/bin/uname -X|grep i80486 >/dev/null) && UNAME_MACHINE=i486 + (/bin/uname -X|grep '^Machine.*Pentium' >/dev/null) \ + && UNAME_MACHINE=i586 + (/bin/uname -X|grep '^Machine.*Pent *II' >/dev/null) \ + && UNAME_MACHINE=i686 + (/bin/uname -X|grep '^Machine.*Pentium Pro' >/dev/null) \ + && UNAME_MACHINE=i686 + echo ${UNAME_MACHINE}-pc-sco$UNAME_REL + else + echo ${UNAME_MACHINE}-pc-sysv32 + fi + exit ;; + pc:*:*:*) + # Left here for compatibility: + # uname -m prints for DJGPP always 'pc', but it prints nothing about + # the processor, so we play safe by assuming i586. + # Note: whatever this is, it MUST be the same as what config.sub + # prints for the "djgpp" host, or else GDB configure will decide that + # this is a cross-build. + echo i586-pc-msdosdjgpp + exit ;; + Intel:Mach:3*:*) + echo i386-pc-mach3 + exit ;; + paragon:*:*:*) + echo i860-intel-osf1 + exit ;; + i860:*:4.*:*) # i860-SVR4 + if grep Stardent /usr/include/sys/uadmin.h >/dev/null 2>&1 ; then + echo i860-stardent-sysv${UNAME_RELEASE} # Stardent Vistra i860-SVR4 + else # Add other i860-SVR4 vendors below as they are discovered. + echo i860-unknown-sysv${UNAME_RELEASE} # Unknown i860-SVR4 + fi + exit ;; + mini*:CTIX:SYS*5:*) + # "miniframe" + echo m68010-convergent-sysv + exit ;; + mc68k:UNIX:SYSTEM5:3.51m) + echo m68k-convergent-sysv + exit ;; + M680?0:D-NIX:5.3:*) + echo m68k-diab-dnix + exit ;; + M68*:*:R3V[5678]*:*) + test -r /sysV68 && { echo 'm68k-motorola-sysv'; exit; } ;; + 3[345]??:*:4.0:3.0 | 3[34]??A:*:4.0:3.0 | 3[34]??,*:*:4.0:3.0 | 3[34]??/*:*:4.0:3.0 | 4400:*:4.0:3.0 | 4850:*:4.0:3.0 | SKA40:*:4.0:3.0 | SDS2:*:4.0:3.0 | SHG2:*:4.0:3.0 | S7501*:*:4.0:3.0) + OS_REL='' + test -r /etc/.relid \ + && OS_REL=.`sed -n 's/[^ ]* [^ ]* \([0-9][0-9]\).*/\1/p' < /etc/.relid` + /bin/uname -p 2>/dev/null | grep 86 >/dev/null \ + && { echo i486-ncr-sysv4.3${OS_REL}; exit; } + /bin/uname -p 2>/dev/null | /bin/grep entium >/dev/null \ + && { echo i586-ncr-sysv4.3${OS_REL}; exit; } ;; + 3[34]??:*:4.0:* | 3[34]??,*:*:4.0:*) + /bin/uname -p 2>/dev/null | grep 86 >/dev/null \ + && { echo i486-ncr-sysv4; exit; } ;; + NCR*:*:4.2:* | MPRAS*:*:4.2:*) + OS_REL='.3' + test -r /etc/.relid \ + && OS_REL=.`sed -n 's/[^ ]* [^ ]* \([0-9][0-9]\).*/\1/p' < /etc/.relid` + /bin/uname -p 2>/dev/null | grep 86 >/dev/null \ + && { echo i486-ncr-sysv4.3${OS_REL}; exit; } + /bin/uname -p 2>/dev/null | /bin/grep entium >/dev/null \ + && { echo i586-ncr-sysv4.3${OS_REL}; exit; } + /bin/uname -p 2>/dev/null | /bin/grep pteron >/dev/null \ + && { echo i586-ncr-sysv4.3${OS_REL}; exit; } ;; + m68*:LynxOS:2.*:* | m68*:LynxOS:3.0*:*) + echo m68k-unknown-lynxos${UNAME_RELEASE} + exit ;; + mc68030:UNIX_System_V:4.*:*) + echo m68k-atari-sysv4 + exit ;; + TSUNAMI:LynxOS:2.*:*) + echo sparc-unknown-lynxos${UNAME_RELEASE} + exit ;; + rs6000:LynxOS:2.*:*) + echo rs6000-unknown-lynxos${UNAME_RELEASE} + exit ;; + PowerPC:LynxOS:2.*:* | PowerPC:LynxOS:3.[01]*:* | PowerPC:LynxOS:4.[02]*:*) + echo powerpc-unknown-lynxos${UNAME_RELEASE} + exit ;; + SM[BE]S:UNIX_SV:*:*) + echo mips-dde-sysv${UNAME_RELEASE} + exit ;; + RM*:ReliantUNIX-*:*:*) + echo mips-sni-sysv4 + exit ;; + RM*:SINIX-*:*:*) + echo mips-sni-sysv4 + exit ;; + *:SINIX-*:*:*) + if uname -p 2>/dev/null >/dev/null ; then + UNAME_MACHINE=`(uname -p) 2>/dev/null` + echo ${UNAME_MACHINE}-sni-sysv4 + else + echo ns32k-sni-sysv + fi + exit ;; + PENTIUM:*:4.0*:*) # Unisys `ClearPath HMP IX 4000' SVR4/MP effort + # says + echo i586-unisys-sysv4 + exit ;; + *:UNIX_System_V:4*:FTX*) + # From Gerald Hewes . + # How about differentiating between stratus architectures? -djm + echo hppa1.1-stratus-sysv4 + exit ;; + *:*:*:FTX*) + # From seanf@swdc.stratus.com. + echo i860-stratus-sysv4 + exit ;; + i*86:VOS:*:*) + # From Paul.Green@stratus.com. + echo ${UNAME_MACHINE}-stratus-vos + exit ;; + *:VOS:*:*) + # From Paul.Green@stratus.com. + echo hppa1.1-stratus-vos + exit ;; + mc68*:A/UX:*:*) + echo m68k-apple-aux${UNAME_RELEASE} + exit ;; + news*:NEWS-OS:6*:*) + echo mips-sony-newsos6 + exit ;; + R[34]000:*System_V*:*:* | R4000:UNIX_SYSV:*:* | R*000:UNIX_SV:*:*) + if [ -d /usr/nec ]; then + echo mips-nec-sysv${UNAME_RELEASE} + else + echo mips-unknown-sysv${UNAME_RELEASE} + fi + exit ;; + BeBox:BeOS:*:*) # BeOS running on hardware made by Be, PPC only. + echo powerpc-be-beos + exit ;; + BeMac:BeOS:*:*) # BeOS running on Mac or Mac clone, PPC only. + echo powerpc-apple-beos + exit ;; + BePC:BeOS:*:*) # BeOS running on Intel PC compatible. + echo i586-pc-beos + exit ;; + BePC:Haiku:*:*) # Haiku running on Intel PC compatible. + echo i586-pc-haiku + exit ;; + x86_64:Haiku:*:*) + echo x86_64-unknown-haiku + exit ;; + SX-4:SUPER-UX:*:*) + echo sx4-nec-superux${UNAME_RELEASE} + exit ;; + SX-5:SUPER-UX:*:*) + echo sx5-nec-superux${UNAME_RELEASE} + exit ;; + SX-6:SUPER-UX:*:*) + echo sx6-nec-superux${UNAME_RELEASE} + exit ;; + SX-7:SUPER-UX:*:*) + echo sx7-nec-superux${UNAME_RELEASE} + exit ;; + SX-8:SUPER-UX:*:*) + echo sx8-nec-superux${UNAME_RELEASE} + exit ;; + SX-8R:SUPER-UX:*:*) + echo sx8r-nec-superux${UNAME_RELEASE} + exit ;; + SX-ACE:SUPER-UX:*:*) + echo sxace-nec-superux${UNAME_RELEASE} + exit ;; + Power*:Rhapsody:*:*) + echo powerpc-apple-rhapsody${UNAME_RELEASE} + exit ;; + *:Rhapsody:*:*) + echo ${UNAME_MACHINE}-apple-rhapsody${UNAME_RELEASE} + exit ;; + *:Darwin:*:*) + UNAME_PROCESSOR=`uname -p` || UNAME_PROCESSOR=unknown + eval $set_cc_for_build + if test "$UNAME_PROCESSOR" = unknown ; then + UNAME_PROCESSOR=powerpc + fi + if test `echo "$UNAME_RELEASE" | sed -e 's/\..*//'` -le 10 ; then + if [ "$CC_FOR_BUILD" != no_compiler_found ]; then + if (echo '#ifdef __LP64__'; echo IS_64BIT_ARCH; echo '#endif') | \ + (CCOPTS="" $CC_FOR_BUILD -E - 2>/dev/null) | \ + grep IS_64BIT_ARCH >/dev/null + then + case $UNAME_PROCESSOR in + i386) UNAME_PROCESSOR=x86_64 ;; + powerpc) UNAME_PROCESSOR=powerpc64 ;; + esac + fi + # On 10.4-10.6 one might compile for PowerPC via gcc -arch ppc + if (echo '#ifdef __POWERPC__'; echo IS_PPC; echo '#endif') | \ + (CCOPTS="" $CC_FOR_BUILD -E - 2>/dev/null) | \ + grep IS_PPC >/dev/null + then + UNAME_PROCESSOR=powerpc + fi + fi + elif test "$UNAME_PROCESSOR" = i386 ; then + # Avoid executing cc on OS X 10.9, as it ships with a stub + # that puts up a graphical alert prompting to install + # developer tools. Any system running Mac OS X 10.7 or + # later (Darwin 11 and later) is required to have a 64-bit + # processor. This is not true of the ARM version of Darwin + # that Apple uses in portable devices. + UNAME_PROCESSOR=x86_64 + fi + echo ${UNAME_PROCESSOR}-apple-darwin${UNAME_RELEASE} + exit ;; + *:procnto*:*:* | *:QNX:[0123456789]*:*) + UNAME_PROCESSOR=`uname -p` + if test "$UNAME_PROCESSOR" = x86; then + UNAME_PROCESSOR=i386 + UNAME_MACHINE=pc + fi + echo ${UNAME_PROCESSOR}-${UNAME_MACHINE}-nto-qnx${UNAME_RELEASE} + exit ;; + *:QNX:*:4*) + echo i386-pc-qnx + exit ;; + NEO-*:NONSTOP_KERNEL:*:*) + echo neo-tandem-nsk${UNAME_RELEASE} + exit ;; + NSE-*:NONSTOP_KERNEL:*:*) + echo nse-tandem-nsk${UNAME_RELEASE} + exit ;; + NSR-*:NONSTOP_KERNEL:*:*) + echo nsr-tandem-nsk${UNAME_RELEASE} + exit ;; + NSX-*:NONSTOP_KERNEL:*:*) + echo nsx-tandem-nsk${UNAME_RELEASE} + exit ;; + *:NonStop-UX:*:*) + echo mips-compaq-nonstopux + exit ;; + BS2000:POSIX*:*:*) + echo bs2000-siemens-sysv + exit ;; + DS/*:UNIX_System_V:*:*) + echo ${UNAME_MACHINE}-${UNAME_SYSTEM}-${UNAME_RELEASE} + exit ;; + *:Plan9:*:*) + # "uname -m" is not consistent, so use $cputype instead. 386 + # is converted to i386 for consistency with other x86 + # operating systems. + if test "$cputype" = 386; then + UNAME_MACHINE=i386 + else + UNAME_MACHINE="$cputype" + fi + echo ${UNAME_MACHINE}-unknown-plan9 + exit ;; + *:TOPS-10:*:*) + echo pdp10-unknown-tops10 + exit ;; + *:TENEX:*:*) + echo pdp10-unknown-tenex + exit ;; + KS10:TOPS-20:*:* | KL10:TOPS-20:*:* | TYPE4:TOPS-20:*:*) + echo pdp10-dec-tops20 + exit ;; + XKL-1:TOPS-20:*:* | TYPE5:TOPS-20:*:*) + echo pdp10-xkl-tops20 + exit ;; + *:TOPS-20:*:*) + echo pdp10-unknown-tops20 + exit ;; + *:ITS:*:*) + echo pdp10-unknown-its + exit ;; + SEI:*:*:SEIUX) + echo mips-sei-seiux${UNAME_RELEASE} + exit ;; + *:DragonFly:*:*) + echo ${UNAME_MACHINE}-unknown-dragonfly`echo ${UNAME_RELEASE}|sed -e 's/[-(].*//'` + exit ;; + *:*VMS:*:*) + UNAME_MACHINE=`(uname -p) 2>/dev/null` + case "${UNAME_MACHINE}" in + A*) echo alpha-dec-vms ; exit ;; + I*) echo ia64-dec-vms ; exit ;; + V*) echo vax-dec-vms ; exit ;; + esac ;; + *:XENIX:*:SysV) + echo i386-pc-xenix + exit ;; + i*86:skyos:*:*) + echo ${UNAME_MACHINE}-pc-skyos`echo ${UNAME_RELEASE} | sed -e 's/ .*$//'` + exit ;; + i*86:rdos:*:*) + echo ${UNAME_MACHINE}-pc-rdos + exit ;; + i*86:AROS:*:*) + echo ${UNAME_MACHINE}-pc-aros + exit ;; + x86_64:VMkernel:*:*) + echo ${UNAME_MACHINE}-unknown-esx + exit ;; + amd64:Isilon\ OneFS:*:*) + echo x86_64-unknown-onefs + exit ;; +esac + +cat >&2 </dev/null || echo unknown` +uname -r = `(uname -r) 2>/dev/null || echo unknown` +uname -s = `(uname -s) 2>/dev/null || echo unknown` +uname -v = `(uname -v) 2>/dev/null || echo unknown` + +/usr/bin/uname -p = `(/usr/bin/uname -p) 2>/dev/null` +/bin/uname -X = `(/bin/uname -X) 2>/dev/null` + +hostinfo = `(hostinfo) 2>/dev/null` +/bin/universe = `(/bin/universe) 2>/dev/null` +/usr/bin/arch -k = `(/usr/bin/arch -k) 2>/dev/null` +/bin/arch = `(/bin/arch) 2>/dev/null` +/usr/bin/oslevel = `(/usr/bin/oslevel) 2>/dev/null` +/usr/convex/getsysinfo = `(/usr/convex/getsysinfo) 2>/dev/null` + +UNAME_MACHINE = ${UNAME_MACHINE} +UNAME_RELEASE = ${UNAME_RELEASE} +UNAME_SYSTEM = ${UNAME_SYSTEM} +UNAME_VERSION = ${UNAME_VERSION} +EOF + +exit 1 + +# Local variables: +# eval: (add-hook 'write-file-hooks 'time-stamp) +# time-stamp-start: "timestamp='" +# time-stamp-format: "%:y-%02m-%02d" +# time-stamp-end: "'" +# End: diff --git a/core/src/index/thirdparty/faiss/build-aux/config.sub b/core/src/index/thirdparty/faiss/build-aux/config.sub new file mode 100755 index 0000000000..40ea5dfe11 --- /dev/null +++ b/core/src/index/thirdparty/faiss/build-aux/config.sub @@ -0,0 +1,1836 @@ +#! /bin/sh +# Configuration validation subroutine script. +# Copyright 1992-2017 Free Software Foundation, Inc. + +timestamp='2017-04-02' + +# This file is free software; you can redistribute it and/or modify it +# under the terms of the GNU General Public License as published by +# the Free Software Foundation; either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, but +# WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program; if not, see . +# +# As a special exception to the GNU General Public License, if you +# distribute this file as part of a program that contains a +# configuration script generated by Autoconf, you may include it under +# the same distribution terms that you use for the rest of that +# program. This Exception is an additional permission under section 7 +# of the GNU General Public License, version 3 ("GPLv3"). + + +# Please send patches to . +# +# Configuration subroutine to validate and canonicalize a configuration type. +# Supply the specified configuration type as an argument. +# If it is invalid, we print an error message on stderr and exit with code 1. +# Otherwise, we print the canonical config type on stdout and succeed. + +# You can get the latest version of this script from: +# http://git.savannah.gnu.org/gitweb/?p=config.git;a=blob_plain;f=config.sub + +# This file is supposed to be the same for all GNU packages +# and recognize all the CPU types, system types and aliases +# that are meaningful with *any* GNU software. +# Each package is responsible for reporting which valid configurations +# it does not support. The user should be able to distinguish +# a failure to support a valid configuration from a meaningless +# configuration. + +# The goal of this file is to map all the various variations of a given +# machine specification into a single specification in the form: +# CPU_TYPE-MANUFACTURER-OPERATING_SYSTEM +# or in some cases, the newer four-part form: +# CPU_TYPE-MANUFACTURER-KERNEL-OPERATING_SYSTEM +# It is wrong to echo any other type of specification. + +me=`echo "$0" | sed -e 's,.*/,,'` + +usage="\ +Usage: $0 [OPTION] CPU-MFR-OPSYS or ALIAS + +Canonicalize a configuration name. + +Operation modes: + -h, --help print this help, then exit + -t, --time-stamp print date of last modification, then exit + -v, --version print version number, then exit + +Report bugs and patches to ." + +version="\ +GNU config.sub ($timestamp) + +Copyright 1992-2017 Free Software Foundation, Inc. + +This is free software; see the source for copying conditions. There is NO +warranty; not even for MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE." + +help=" +Try \`$me --help' for more information." + +# Parse command line +while test $# -gt 0 ; do + case $1 in + --time-stamp | --time* | -t ) + echo "$timestamp" ; exit ;; + --version | -v ) + echo "$version" ; exit ;; + --help | --h* | -h ) + echo "$usage"; exit ;; + -- ) # Stop option processing + shift; break ;; + - ) # Use stdin as input. + break ;; + -* ) + echo "$me: invalid option $1$help" + exit 1 ;; + + *local*) + # First pass through any local machine types. + echo $1 + exit ;; + + * ) + break ;; + esac +done + +case $# in + 0) echo "$me: missing argument$help" >&2 + exit 1;; + 1) ;; + *) echo "$me: too many arguments$help" >&2 + exit 1;; +esac + +# Separate what the user gave into CPU-COMPANY and OS or KERNEL-OS (if any). +# Here we must recognize all the valid KERNEL-OS combinations. +maybe_os=`echo $1 | sed 's/^\(.*\)-\([^-]*-[^-]*\)$/\2/'` +case $maybe_os in + nto-qnx* | linux-gnu* | linux-android* | linux-dietlibc | linux-newlib* | \ + linux-musl* | linux-uclibc* | uclinux-uclibc* | uclinux-gnu* | kfreebsd*-gnu* | \ + knetbsd*-gnu* | netbsd*-gnu* | netbsd*-eabi* | \ + kopensolaris*-gnu* | cloudabi*-eabi* | \ + storm-chaos* | os2-emx* | rtmk-nova*) + os=-$maybe_os + basic_machine=`echo $1 | sed 's/^\(.*\)-\([^-]*-[^-]*\)$/\1/'` + ;; + android-linux) + os=-linux-android + basic_machine=`echo $1 | sed 's/^\(.*\)-\([^-]*-[^-]*\)$/\1/'`-unknown + ;; + *) + basic_machine=`echo $1 | sed 's/-[^-]*$//'` + if [ $basic_machine != $1 ] + then os=`echo $1 | sed 's/.*-/-/'` + else os=; fi + ;; +esac + +### Let's recognize common machines as not being operating systems so +### that things like config.sub decstation-3100 work. We also +### recognize some manufacturers as not being operating systems, so we +### can provide default operating systems below. +case $os in + -sun*os*) + # Prevent following clause from handling this invalid input. + ;; + -dec* | -mips* | -sequent* | -encore* | -pc532* | -sgi* | -sony* | \ + -att* | -7300* | -3300* | -delta* | -motorola* | -sun[234]* | \ + -unicom* | -ibm* | -next | -hp | -isi* | -apollo | -altos* | \ + -convergent* | -ncr* | -news | -32* | -3600* | -3100* | -hitachi* |\ + -c[123]* | -convex* | -sun | -crds | -omron* | -dg | -ultra | -tti* | \ + -harris | -dolphin | -highlevel | -gould | -cbm | -ns | -masscomp | \ + -apple | -axis | -knuth | -cray | -microblaze*) + os= + basic_machine=$1 + ;; + -bluegene*) + os=-cnk + ;; + -sim | -cisco | -oki | -wec | -winbond) + os= + basic_machine=$1 + ;; + -scout) + ;; + -wrs) + os=-vxworks + basic_machine=$1 + ;; + -chorusos*) + os=-chorusos + basic_machine=$1 + ;; + -chorusrdb) + os=-chorusrdb + basic_machine=$1 + ;; + -hiux*) + os=-hiuxwe2 + ;; + -sco6) + os=-sco5v6 + basic_machine=`echo $1 | sed -e 's/86-.*/86-pc/'` + ;; + -sco5) + os=-sco3.2v5 + basic_machine=`echo $1 | sed -e 's/86-.*/86-pc/'` + ;; + -sco4) + os=-sco3.2v4 + basic_machine=`echo $1 | sed -e 's/86-.*/86-pc/'` + ;; + -sco3.2.[4-9]*) + os=`echo $os | sed -e 's/sco3.2./sco3.2v/'` + basic_machine=`echo $1 | sed -e 's/86-.*/86-pc/'` + ;; + -sco3.2v[4-9]*) + # Don't forget version if it is 3.2v4 or newer. + basic_machine=`echo $1 | sed -e 's/86-.*/86-pc/'` + ;; + -sco5v6*) + # Don't forget version if it is 3.2v4 or newer. + basic_machine=`echo $1 | sed -e 's/86-.*/86-pc/'` + ;; + -sco*) + os=-sco3.2v2 + basic_machine=`echo $1 | sed -e 's/86-.*/86-pc/'` + ;; + -udk*) + basic_machine=`echo $1 | sed -e 's/86-.*/86-pc/'` + ;; + -isc) + os=-isc2.2 + basic_machine=`echo $1 | sed -e 's/86-.*/86-pc/'` + ;; + -clix*) + basic_machine=clipper-intergraph + ;; + -isc*) + basic_machine=`echo $1 | sed -e 's/86-.*/86-pc/'` + ;; + -lynx*178) + os=-lynxos178 + ;; + -lynx*5) + os=-lynxos5 + ;; + -lynx*) + os=-lynxos + ;; + -ptx*) + basic_machine=`echo $1 | sed -e 's/86-.*/86-sequent/'` + ;; + -windowsnt*) + os=`echo $os | sed -e 's/windowsnt/winnt/'` + ;; + -psos*) + os=-psos + ;; + -mint | -mint[0-9]*) + basic_machine=m68k-atari + os=-mint + ;; +esac + +# Decode aliases for certain CPU-COMPANY combinations. +case $basic_machine in + # Recognize the basic CPU types without company name. + # Some are omitted here because they have special meanings below. + 1750a | 580 \ + | a29k \ + | aarch64 | aarch64_be \ + | alpha | alphaev[4-8] | alphaev56 | alphaev6[78] | alphapca5[67] \ + | alpha64 | alpha64ev[4-8] | alpha64ev56 | alpha64ev6[78] | alpha64pca5[67] \ + | am33_2.0 \ + | arc | arceb \ + | arm | arm[bl]e | arme[lb] | armv[2-8] | armv[3-8][lb] | armv7[arm] \ + | avr | avr32 \ + | ba \ + | be32 | be64 \ + | bfin \ + | c4x | c8051 | clipper \ + | d10v | d30v | dlx | dsp16xx \ + | e2k | epiphany \ + | fido | fr30 | frv | ft32 \ + | h8300 | h8500 | hppa | hppa1.[01] | hppa2.0 | hppa2.0[nw] | hppa64 \ + | hexagon \ + | i370 | i860 | i960 | ia16 | ia64 \ + | ip2k | iq2000 \ + | k1om \ + | le32 | le64 \ + | lm32 \ + | m32c | m32r | m32rle | m68000 | m68k | m88k \ + | maxq | mb | microblaze | microblazeel | mcore | mep | metag \ + | mips | mipsbe | mipseb | mipsel | mipsle \ + | mips16 \ + | mips64 | mips64el \ + | mips64octeon | mips64octeonel \ + | mips64orion | mips64orionel \ + | mips64r5900 | mips64r5900el \ + | mips64vr | mips64vrel \ + | mips64vr4100 | mips64vr4100el \ + | mips64vr4300 | mips64vr4300el \ + | mips64vr5000 | mips64vr5000el \ + | mips64vr5900 | mips64vr5900el \ + | mipsisa32 | mipsisa32el \ + | mipsisa32r2 | mipsisa32r2el \ + | mipsisa32r6 | mipsisa32r6el \ + | mipsisa64 | mipsisa64el \ + | mipsisa64r2 | mipsisa64r2el \ + | mipsisa64r6 | mipsisa64r6el \ + | mipsisa64sb1 | mipsisa64sb1el \ + | mipsisa64sr71k | mipsisa64sr71kel \ + | mipsr5900 | mipsr5900el \ + | mipstx39 | mipstx39el \ + | mn10200 | mn10300 \ + | moxie \ + | mt \ + | msp430 \ + | nds32 | nds32le | nds32be \ + | nios | nios2 | nios2eb | nios2el \ + | ns16k | ns32k \ + | open8 | or1k | or1knd | or32 \ + | pdp10 | pdp11 | pj | pjl \ + | powerpc | powerpc64 | powerpc64le | powerpcle \ + | pru \ + | pyramid \ + | riscv32 | riscv64 \ + | rl78 | rx \ + | score \ + | sh | sh[1234] | sh[24]a | sh[24]aeb | sh[23]e | sh[234]eb | sheb | shbe | shle | sh[1234]le | sh3ele \ + | sh64 | sh64le \ + | sparc | sparc64 | sparc64b | sparc64v | sparc86x | sparclet | sparclite \ + | sparcv8 | sparcv9 | sparcv9b | sparcv9v \ + | spu \ + | tahoe | tic4x | tic54x | tic55x | tic6x | tic80 | tron \ + | ubicom32 \ + | v850 | v850e | v850e1 | v850e2 | v850es | v850e2v3 \ + | visium \ + | wasm32 \ + | we32k \ + | x86 | xc16x | xstormy16 | xtensa \ + | z8k | z80) + basic_machine=$basic_machine-unknown + ;; + c54x) + basic_machine=tic54x-unknown + ;; + c55x) + basic_machine=tic55x-unknown + ;; + c6x) + basic_machine=tic6x-unknown + ;; + leon|leon[3-9]) + basic_machine=sparc-$basic_machine + ;; + m6811 | m68hc11 | m6812 | m68hc12 | m68hcs12x | nvptx | picochip) + basic_machine=$basic_machine-unknown + os=-none + ;; + m88110 | m680[12346]0 | m683?2 | m68360 | m5200 | v70 | w65 | z8k) + ;; + ms1) + basic_machine=mt-unknown + ;; + + strongarm | thumb | xscale) + basic_machine=arm-unknown + ;; + xgate) + basic_machine=$basic_machine-unknown + os=-none + ;; + xscaleeb) + basic_machine=armeb-unknown + ;; + + xscaleel) + basic_machine=armel-unknown + ;; + + # We use `pc' rather than `unknown' + # because (1) that's what they normally are, and + # (2) the word "unknown" tends to confuse beginning users. + i*86 | x86_64) + basic_machine=$basic_machine-pc + ;; + # Object if more than one company name word. + *-*-*) + echo Invalid configuration \`$1\': machine \`$basic_machine\' not recognized 1>&2 + exit 1 + ;; + # Recognize the basic CPU types with company name. + 580-* \ + | a29k-* \ + | aarch64-* | aarch64_be-* \ + | alpha-* | alphaev[4-8]-* | alphaev56-* | alphaev6[78]-* \ + | alpha64-* | alpha64ev[4-8]-* | alpha64ev56-* | alpha64ev6[78]-* \ + | alphapca5[67]-* | alpha64pca5[67]-* | arc-* | arceb-* \ + | arm-* | armbe-* | armle-* | armeb-* | armv*-* \ + | avr-* | avr32-* \ + | ba-* \ + | be32-* | be64-* \ + | bfin-* | bs2000-* \ + | c[123]* | c30-* | [cjt]90-* | c4x-* \ + | c8051-* | clipper-* | craynv-* | cydra-* \ + | d10v-* | d30v-* | dlx-* \ + | e2k-* | elxsi-* \ + | f30[01]-* | f700-* | fido-* | fr30-* | frv-* | fx80-* \ + | h8300-* | h8500-* \ + | hppa-* | hppa1.[01]-* | hppa2.0-* | hppa2.0[nw]-* | hppa64-* \ + | hexagon-* \ + | i*86-* | i860-* | i960-* | ia16-* | ia64-* \ + | ip2k-* | iq2000-* \ + | k1om-* \ + | le32-* | le64-* \ + | lm32-* \ + | m32c-* | m32r-* | m32rle-* \ + | m68000-* | m680[012346]0-* | m68360-* | m683?2-* | m68k-* \ + | m88110-* | m88k-* | maxq-* | mcore-* | metag-* \ + | microblaze-* | microblazeel-* \ + | mips-* | mipsbe-* | mipseb-* | mipsel-* | mipsle-* \ + | mips16-* \ + | mips64-* | mips64el-* \ + | mips64octeon-* | mips64octeonel-* \ + | mips64orion-* | mips64orionel-* \ + | mips64r5900-* | mips64r5900el-* \ + | mips64vr-* | mips64vrel-* \ + | mips64vr4100-* | mips64vr4100el-* \ + | mips64vr4300-* | mips64vr4300el-* \ + | mips64vr5000-* | mips64vr5000el-* \ + | mips64vr5900-* | mips64vr5900el-* \ + | mipsisa32-* | mipsisa32el-* \ + | mipsisa32r2-* | mipsisa32r2el-* \ + | mipsisa32r6-* | mipsisa32r6el-* \ + | mipsisa64-* | mipsisa64el-* \ + | mipsisa64r2-* | mipsisa64r2el-* \ + | mipsisa64r6-* | mipsisa64r6el-* \ + | mipsisa64sb1-* | mipsisa64sb1el-* \ + | mipsisa64sr71k-* | mipsisa64sr71kel-* \ + | mipsr5900-* | mipsr5900el-* \ + | mipstx39-* | mipstx39el-* \ + | mmix-* \ + | mt-* \ + | msp430-* \ + | nds32-* | nds32le-* | nds32be-* \ + | nios-* | nios2-* | nios2eb-* | nios2el-* \ + | none-* | np1-* | ns16k-* | ns32k-* \ + | open8-* \ + | or1k*-* \ + | orion-* \ + | pdp10-* | pdp11-* | pj-* | pjl-* | pn-* | power-* \ + | powerpc-* | powerpc64-* | powerpc64le-* | powerpcle-* \ + | pru-* \ + | pyramid-* \ + | riscv32-* | riscv64-* \ + | rl78-* | romp-* | rs6000-* | rx-* \ + | sh-* | sh[1234]-* | sh[24]a-* | sh[24]aeb-* | sh[23]e-* | sh[34]eb-* | sheb-* | shbe-* \ + | shle-* | sh[1234]le-* | sh3ele-* | sh64-* | sh64le-* \ + | sparc-* | sparc64-* | sparc64b-* | sparc64v-* | sparc86x-* | sparclet-* \ + | sparclite-* \ + | sparcv8-* | sparcv9-* | sparcv9b-* | sparcv9v-* | sv1-* | sx*-* \ + | tahoe-* \ + | tic30-* | tic4x-* | tic54x-* | tic55x-* | tic6x-* | tic80-* \ + | tile*-* \ + | tron-* \ + | ubicom32-* \ + | v850-* | v850e-* | v850e1-* | v850es-* | v850e2-* | v850e2v3-* \ + | vax-* \ + | visium-* \ + | wasm32-* \ + | we32k-* \ + | x86-* | x86_64-* | xc16x-* | xps100-* \ + | xstormy16-* | xtensa*-* \ + | ymp-* \ + | z8k-* | z80-*) + ;; + # Recognize the basic CPU types without company name, with glob match. + xtensa*) + basic_machine=$basic_machine-unknown + ;; + # Recognize the various machine names and aliases which stand + # for a CPU type and a company and sometimes even an OS. + 386bsd) + basic_machine=i386-unknown + os=-bsd + ;; + 3b1 | 7300 | 7300-att | att-7300 | pc7300 | safari | unixpc) + basic_machine=m68000-att + ;; + 3b*) + basic_machine=we32k-att + ;; + a29khif) + basic_machine=a29k-amd + os=-udi + ;; + abacus) + basic_machine=abacus-unknown + ;; + adobe68k) + basic_machine=m68010-adobe + os=-scout + ;; + alliant | fx80) + basic_machine=fx80-alliant + ;; + altos | altos3068) + basic_machine=m68k-altos + ;; + am29k) + basic_machine=a29k-none + os=-bsd + ;; + amd64) + basic_machine=x86_64-pc + ;; + amd64-*) + basic_machine=x86_64-`echo $basic_machine | sed 's/^[^-]*-//'` + ;; + amdahl) + basic_machine=580-amdahl + os=-sysv + ;; + amiga | amiga-*) + basic_machine=m68k-unknown + ;; + amigaos | amigados) + basic_machine=m68k-unknown + os=-amigaos + ;; + amigaunix | amix) + basic_machine=m68k-unknown + os=-sysv4 + ;; + apollo68) + basic_machine=m68k-apollo + os=-sysv + ;; + apollo68bsd) + basic_machine=m68k-apollo + os=-bsd + ;; + aros) + basic_machine=i386-pc + os=-aros + ;; + asmjs) + basic_machine=asmjs-unknown + ;; + aux) + basic_machine=m68k-apple + os=-aux + ;; + balance) + basic_machine=ns32k-sequent + os=-dynix + ;; + blackfin) + basic_machine=bfin-unknown + os=-linux + ;; + blackfin-*) + basic_machine=bfin-`echo $basic_machine | sed 's/^[^-]*-//'` + os=-linux + ;; + bluegene*) + basic_machine=powerpc-ibm + os=-cnk + ;; + c54x-*) + basic_machine=tic54x-`echo $basic_machine | sed 's/^[^-]*-//'` + ;; + c55x-*) + basic_machine=tic55x-`echo $basic_machine | sed 's/^[^-]*-//'` + ;; + c6x-*) + basic_machine=tic6x-`echo $basic_machine | sed 's/^[^-]*-//'` + ;; + c90) + basic_machine=c90-cray + os=-unicos + ;; + cegcc) + basic_machine=arm-unknown + os=-cegcc + ;; + convex-c1) + basic_machine=c1-convex + os=-bsd + ;; + convex-c2) + basic_machine=c2-convex + os=-bsd + ;; + convex-c32) + basic_machine=c32-convex + os=-bsd + ;; + convex-c34) + basic_machine=c34-convex + os=-bsd + ;; + convex-c38) + basic_machine=c38-convex + os=-bsd + ;; + cray | j90) + basic_machine=j90-cray + os=-unicos + ;; + craynv) + basic_machine=craynv-cray + os=-unicosmp + ;; + cr16 | cr16-*) + basic_machine=cr16-unknown + os=-elf + ;; + crds | unos) + basic_machine=m68k-crds + ;; + crisv32 | crisv32-* | etraxfs*) + basic_machine=crisv32-axis + ;; + cris | cris-* | etrax*) + basic_machine=cris-axis + ;; + crx) + basic_machine=crx-unknown + os=-elf + ;; + da30 | da30-*) + basic_machine=m68k-da30 + ;; + decstation | decstation-3100 | pmax | pmax-* | pmin | dec3100 | decstatn) + basic_machine=mips-dec + ;; + decsystem10* | dec10*) + basic_machine=pdp10-dec + os=-tops10 + ;; + decsystem20* | dec20*) + basic_machine=pdp10-dec + os=-tops20 + ;; + delta | 3300 | motorola-3300 | motorola-delta \ + | 3300-motorola | delta-motorola) + basic_machine=m68k-motorola + ;; + delta88) + basic_machine=m88k-motorola + os=-sysv3 + ;; + dicos) + basic_machine=i686-pc + os=-dicos + ;; + djgpp) + basic_machine=i586-pc + os=-msdosdjgpp + ;; + dpx20 | dpx20-*) + basic_machine=rs6000-bull + os=-bosx + ;; + dpx2* | dpx2*-bull) + basic_machine=m68k-bull + os=-sysv3 + ;; + e500v[12]) + basic_machine=powerpc-unknown + os=$os"spe" + ;; + e500v[12]-*) + basic_machine=powerpc-`echo $basic_machine | sed 's/^[^-]*-//'` + os=$os"spe" + ;; + ebmon29k) + basic_machine=a29k-amd + os=-ebmon + ;; + elxsi) + basic_machine=elxsi-elxsi + os=-bsd + ;; + encore | umax | mmax) + basic_machine=ns32k-encore + ;; + es1800 | OSE68k | ose68k | ose | OSE) + basic_machine=m68k-ericsson + os=-ose + ;; + fx2800) + basic_machine=i860-alliant + ;; + genix) + basic_machine=ns32k-ns + ;; + gmicro) + basic_machine=tron-gmicro + os=-sysv + ;; + go32) + basic_machine=i386-pc + os=-go32 + ;; + h3050r* | hiux*) + basic_machine=hppa1.1-hitachi + os=-hiuxwe2 + ;; + h8300hms) + basic_machine=h8300-hitachi + os=-hms + ;; + h8300xray) + basic_machine=h8300-hitachi + os=-xray + ;; + h8500hms) + basic_machine=h8500-hitachi + os=-hms + ;; + harris) + basic_machine=m88k-harris + os=-sysv3 + ;; + hp300-*) + basic_machine=m68k-hp + ;; + hp300bsd) + basic_machine=m68k-hp + os=-bsd + ;; + hp300hpux) + basic_machine=m68k-hp + os=-hpux + ;; + hp3k9[0-9][0-9] | hp9[0-9][0-9]) + basic_machine=hppa1.0-hp + ;; + hp9k2[0-9][0-9] | hp9k31[0-9]) + basic_machine=m68000-hp + ;; + hp9k3[2-9][0-9]) + basic_machine=m68k-hp + ;; + hp9k6[0-9][0-9] | hp6[0-9][0-9]) + basic_machine=hppa1.0-hp + ;; + hp9k7[0-79][0-9] | hp7[0-79][0-9]) + basic_machine=hppa1.1-hp + ;; + hp9k78[0-9] | hp78[0-9]) + # FIXME: really hppa2.0-hp + basic_machine=hppa1.1-hp + ;; + hp9k8[67]1 | hp8[67]1 | hp9k80[24] | hp80[24] | hp9k8[78]9 | hp8[78]9 | hp9k893 | hp893) + # FIXME: really hppa2.0-hp + basic_machine=hppa1.1-hp + ;; + hp9k8[0-9][13679] | hp8[0-9][13679]) + basic_machine=hppa1.1-hp + ;; + hp9k8[0-9][0-9] | hp8[0-9][0-9]) + basic_machine=hppa1.0-hp + ;; + hppa-next) + os=-nextstep3 + ;; + hppaosf) + basic_machine=hppa1.1-hp + os=-osf + ;; + hppro) + basic_machine=hppa1.1-hp + os=-proelf + ;; + i370-ibm* | ibm*) + basic_machine=i370-ibm + ;; + i*86v32) + basic_machine=`echo $1 | sed -e 's/86.*/86-pc/'` + os=-sysv32 + ;; + i*86v4*) + basic_machine=`echo $1 | sed -e 's/86.*/86-pc/'` + os=-sysv4 + ;; + i*86v) + basic_machine=`echo $1 | sed -e 's/86.*/86-pc/'` + os=-sysv + ;; + i*86sol2) + basic_machine=`echo $1 | sed -e 's/86.*/86-pc/'` + os=-solaris2 + ;; + i386mach) + basic_machine=i386-mach + os=-mach + ;; + i386-vsta | vsta) + basic_machine=i386-unknown + os=-vsta + ;; + iris | iris4d) + basic_machine=mips-sgi + case $os in + -irix*) + ;; + *) + os=-irix4 + ;; + esac + ;; + isi68 | isi) + basic_machine=m68k-isi + os=-sysv + ;; + leon-*|leon[3-9]-*) + basic_machine=sparc-`echo $basic_machine | sed 's/-.*//'` + ;; + m68knommu) + basic_machine=m68k-unknown + os=-linux + ;; + m68knommu-*) + basic_machine=m68k-`echo $basic_machine | sed 's/^[^-]*-//'` + os=-linux + ;; + m88k-omron*) + basic_machine=m88k-omron + ;; + magnum | m3230) + basic_machine=mips-mips + os=-sysv + ;; + merlin) + basic_machine=ns32k-utek + os=-sysv + ;; + microblaze*) + basic_machine=microblaze-xilinx + ;; + mingw64) + basic_machine=x86_64-pc + os=-mingw64 + ;; + mingw32) + basic_machine=i686-pc + os=-mingw32 + ;; + mingw32ce) + basic_machine=arm-unknown + os=-mingw32ce + ;; + miniframe) + basic_machine=m68000-convergent + ;; + *mint | -mint[0-9]* | *MiNT | *MiNT[0-9]*) + basic_machine=m68k-atari + os=-mint + ;; + mips3*-*) + basic_machine=`echo $basic_machine | sed -e 's/mips3/mips64/'` + ;; + mips3*) + basic_machine=`echo $basic_machine | sed -e 's/mips3/mips64/'`-unknown + ;; + monitor) + basic_machine=m68k-rom68k + os=-coff + ;; + morphos) + basic_machine=powerpc-unknown + os=-morphos + ;; + moxiebox) + basic_machine=moxie-unknown + os=-moxiebox + ;; + msdos) + basic_machine=i386-pc + os=-msdos + ;; + ms1-*) + basic_machine=`echo $basic_machine | sed -e 's/ms1-/mt-/'` + ;; + msys) + basic_machine=i686-pc + os=-msys + ;; + mvs) + basic_machine=i370-ibm + os=-mvs + ;; + nacl) + basic_machine=le32-unknown + os=-nacl + ;; + ncr3000) + basic_machine=i486-ncr + os=-sysv4 + ;; + netbsd386) + basic_machine=i386-unknown + os=-netbsd + ;; + netwinder) + basic_machine=armv4l-rebel + os=-linux + ;; + news | news700 | news800 | news900) + basic_machine=m68k-sony + os=-newsos + ;; + news1000) + basic_machine=m68030-sony + os=-newsos + ;; + news-3600 | risc-news) + basic_machine=mips-sony + os=-newsos + ;; + necv70) + basic_machine=v70-nec + os=-sysv + ;; + next | m*-next ) + basic_machine=m68k-next + case $os in + -nextstep* ) + ;; + -ns2*) + os=-nextstep2 + ;; + *) + os=-nextstep3 + ;; + esac + ;; + nh3000) + basic_machine=m68k-harris + os=-cxux + ;; + nh[45]000) + basic_machine=m88k-harris + os=-cxux + ;; + nindy960) + basic_machine=i960-intel + os=-nindy + ;; + mon960) + basic_machine=i960-intel + os=-mon960 + ;; + nonstopux) + basic_machine=mips-compaq + os=-nonstopux + ;; + np1) + basic_machine=np1-gould + ;; + neo-tandem) + basic_machine=neo-tandem + ;; + nse-tandem) + basic_machine=nse-tandem + ;; + nsr-tandem) + basic_machine=nsr-tandem + ;; + nsx-tandem) + basic_machine=nsx-tandem + ;; + op50n-* | op60c-*) + basic_machine=hppa1.1-oki + os=-proelf + ;; + openrisc | openrisc-*) + basic_machine=or32-unknown + ;; + os400) + basic_machine=powerpc-ibm + os=-os400 + ;; + OSE68000 | ose68000) + basic_machine=m68000-ericsson + os=-ose + ;; + os68k) + basic_machine=m68k-none + os=-os68k + ;; + pa-hitachi) + basic_machine=hppa1.1-hitachi + os=-hiuxwe2 + ;; + paragon) + basic_machine=i860-intel + os=-osf + ;; + parisc) + basic_machine=hppa-unknown + os=-linux + ;; + parisc-*) + basic_machine=hppa-`echo $basic_machine | sed 's/^[^-]*-//'` + os=-linux + ;; + pbd) + basic_machine=sparc-tti + ;; + pbb) + basic_machine=m68k-tti + ;; + pc532 | pc532-*) + basic_machine=ns32k-pc532 + ;; + pc98) + basic_machine=i386-pc + ;; + pc98-*) + basic_machine=i386-`echo $basic_machine | sed 's/^[^-]*-//'` + ;; + pentium | p5 | k5 | k6 | nexgen | viac3) + basic_machine=i586-pc + ;; + pentiumpro | p6 | 6x86 | athlon | athlon_*) + basic_machine=i686-pc + ;; + pentiumii | pentium2 | pentiumiii | pentium3) + basic_machine=i686-pc + ;; + pentium4) + basic_machine=i786-pc + ;; + pentium-* | p5-* | k5-* | k6-* | nexgen-* | viac3-*) + basic_machine=i586-`echo $basic_machine | sed 's/^[^-]*-//'` + ;; + pentiumpro-* | p6-* | 6x86-* | athlon-*) + basic_machine=i686-`echo $basic_machine | sed 's/^[^-]*-//'` + ;; + pentiumii-* | pentium2-* | pentiumiii-* | pentium3-*) + basic_machine=i686-`echo $basic_machine | sed 's/^[^-]*-//'` + ;; + pentium4-*) + basic_machine=i786-`echo $basic_machine | sed 's/^[^-]*-//'` + ;; + pn) + basic_machine=pn-gould + ;; + power) basic_machine=power-ibm + ;; + ppc | ppcbe) basic_machine=powerpc-unknown + ;; + ppc-* | ppcbe-*) + basic_machine=powerpc-`echo $basic_machine | sed 's/^[^-]*-//'` + ;; + ppcle | powerpclittle) + basic_machine=powerpcle-unknown + ;; + ppcle-* | powerpclittle-*) + basic_machine=powerpcle-`echo $basic_machine | sed 's/^[^-]*-//'` + ;; + ppc64) basic_machine=powerpc64-unknown + ;; + ppc64-*) basic_machine=powerpc64-`echo $basic_machine | sed 's/^[^-]*-//'` + ;; + ppc64le | powerpc64little) + basic_machine=powerpc64le-unknown + ;; + ppc64le-* | powerpc64little-*) + basic_machine=powerpc64le-`echo $basic_machine | sed 's/^[^-]*-//'` + ;; + ps2) + basic_machine=i386-ibm + ;; + pw32) + basic_machine=i586-unknown + os=-pw32 + ;; + rdos | rdos64) + basic_machine=x86_64-pc + os=-rdos + ;; + rdos32) + basic_machine=i386-pc + os=-rdos + ;; + rom68k) + basic_machine=m68k-rom68k + os=-coff + ;; + rm[46]00) + basic_machine=mips-siemens + ;; + rtpc | rtpc-*) + basic_machine=romp-ibm + ;; + s390 | s390-*) + basic_machine=s390-ibm + ;; + s390x | s390x-*) + basic_machine=s390x-ibm + ;; + sa29200) + basic_machine=a29k-amd + os=-udi + ;; + sb1) + basic_machine=mipsisa64sb1-unknown + ;; + sb1el) + basic_machine=mipsisa64sb1el-unknown + ;; + sde) + basic_machine=mipsisa32-sde + os=-elf + ;; + sei) + basic_machine=mips-sei + os=-seiux + ;; + sequent) + basic_machine=i386-sequent + ;; + sh) + basic_machine=sh-hitachi + os=-hms + ;; + sh5el) + basic_machine=sh5le-unknown + ;; + sh64) + basic_machine=sh64-unknown + ;; + sparclite-wrs | simso-wrs) + basic_machine=sparclite-wrs + os=-vxworks + ;; + sps7) + basic_machine=m68k-bull + os=-sysv2 + ;; + spur) + basic_machine=spur-unknown + ;; + st2000) + basic_machine=m68k-tandem + ;; + stratus) + basic_machine=i860-stratus + os=-sysv4 + ;; + strongarm-* | thumb-*) + basic_machine=arm-`echo $basic_machine | sed 's/^[^-]*-//'` + ;; + sun2) + basic_machine=m68000-sun + ;; + sun2os3) + basic_machine=m68000-sun + os=-sunos3 + ;; + sun2os4) + basic_machine=m68000-sun + os=-sunos4 + ;; + sun3os3) + basic_machine=m68k-sun + os=-sunos3 + ;; + sun3os4) + basic_machine=m68k-sun + os=-sunos4 + ;; + sun4os3) + basic_machine=sparc-sun + os=-sunos3 + ;; + sun4os4) + basic_machine=sparc-sun + os=-sunos4 + ;; + sun4sol2) + basic_machine=sparc-sun + os=-solaris2 + ;; + sun3 | sun3-*) + basic_machine=m68k-sun + ;; + sun4) + basic_machine=sparc-sun + ;; + sun386 | sun386i | roadrunner) + basic_machine=i386-sun + ;; + sv1) + basic_machine=sv1-cray + os=-unicos + ;; + symmetry) + basic_machine=i386-sequent + os=-dynix + ;; + t3e) + basic_machine=alphaev5-cray + os=-unicos + ;; + t90) + basic_machine=t90-cray + os=-unicos + ;; + tile*) + basic_machine=$basic_machine-unknown + os=-linux-gnu + ;; + tx39) + basic_machine=mipstx39-unknown + ;; + tx39el) + basic_machine=mipstx39el-unknown + ;; + toad1) + basic_machine=pdp10-xkl + os=-tops20 + ;; + tower | tower-32) + basic_machine=m68k-ncr + ;; + tpf) + basic_machine=s390x-ibm + os=-tpf + ;; + udi29k) + basic_machine=a29k-amd + os=-udi + ;; + ultra3) + basic_machine=a29k-nyu + os=-sym1 + ;; + v810 | necv810) + basic_machine=v810-nec + os=-none + ;; + vaxv) + basic_machine=vax-dec + os=-sysv + ;; + vms) + basic_machine=vax-dec + os=-vms + ;; + vpp*|vx|vx-*) + basic_machine=f301-fujitsu + ;; + vxworks960) + basic_machine=i960-wrs + os=-vxworks + ;; + vxworks68) + basic_machine=m68k-wrs + os=-vxworks + ;; + vxworks29k) + basic_machine=a29k-wrs + os=-vxworks + ;; + wasm32) + basic_machine=wasm32-unknown + ;; + w65*) + basic_machine=w65-wdc + os=-none + ;; + w89k-*) + basic_machine=hppa1.1-winbond + os=-proelf + ;; + xbox) + basic_machine=i686-pc + os=-mingw32 + ;; + xps | xps100) + basic_machine=xps100-honeywell + ;; + xscale-* | xscalee[bl]-*) + basic_machine=`echo $basic_machine | sed 's/^xscale/arm/'` + ;; + ymp) + basic_machine=ymp-cray + os=-unicos + ;; + z8k-*-coff) + basic_machine=z8k-unknown + os=-sim + ;; + z80-*-coff) + basic_machine=z80-unknown + os=-sim + ;; + none) + basic_machine=none-none + os=-none + ;; + +# Here we handle the default manufacturer of certain CPU types. It is in +# some cases the only manufacturer, in others, it is the most popular. + w89k) + basic_machine=hppa1.1-winbond + ;; + op50n) + basic_machine=hppa1.1-oki + ;; + op60c) + basic_machine=hppa1.1-oki + ;; + romp) + basic_machine=romp-ibm + ;; + mmix) + basic_machine=mmix-knuth + ;; + rs6000) + basic_machine=rs6000-ibm + ;; + vax) + basic_machine=vax-dec + ;; + pdp10) + # there are many clones, so DEC is not a safe bet + basic_machine=pdp10-unknown + ;; + pdp11) + basic_machine=pdp11-dec + ;; + we32k) + basic_machine=we32k-att + ;; + sh[1234] | sh[24]a | sh[24]aeb | sh[34]eb | sh[1234]le | sh[23]ele) + basic_machine=sh-unknown + ;; + sparc | sparcv8 | sparcv9 | sparcv9b | sparcv9v) + basic_machine=sparc-sun + ;; + cydra) + basic_machine=cydra-cydrome + ;; + orion) + basic_machine=orion-highlevel + ;; + orion105) + basic_machine=clipper-highlevel + ;; + mac | mpw | mac-mpw) + basic_machine=m68k-apple + ;; + pmac | pmac-mpw) + basic_machine=powerpc-apple + ;; + *-unknown) + # Make sure to match an already-canonicalized machine name. + ;; + *) + echo Invalid configuration \`$1\': machine \`$basic_machine\' not recognized 1>&2 + exit 1 + ;; +esac + +# Here we canonicalize certain aliases for manufacturers. +case $basic_machine in + *-digital*) + basic_machine=`echo $basic_machine | sed 's/digital.*/dec/'` + ;; + *-commodore*) + basic_machine=`echo $basic_machine | sed 's/commodore.*/cbm/'` + ;; + *) + ;; +esac + +# Decode manufacturer-specific aliases for certain operating systems. + +if [ x"$os" != x"" ] +then +case $os in + # First match some system type aliases + # that might get confused with valid system types. + # -solaris* is a basic system type, with this one exception. + -auroraux) + os=-auroraux + ;; + -solaris1 | -solaris1.*) + os=`echo $os | sed -e 's|solaris1|sunos4|'` + ;; + -solaris) + os=-solaris2 + ;; + -svr4*) + os=-sysv4 + ;; + -unixware*) + os=-sysv4.2uw + ;; + -gnu/linux*) + os=`echo $os | sed -e 's|gnu/linux|linux-gnu|'` + ;; + # First accept the basic system types. + # The portable systems comes first. + # Each alternative MUST END IN A *, to match a version number. + # -sysv* is not here because it comes later, after sysvr4. + -gnu* | -bsd* | -mach* | -minix* | -genix* | -ultrix* | -irix* \ + | -*vms* | -sco* | -esix* | -isc* | -aix* | -cnk* | -sunos | -sunos[34]*\ + | -hpux* | -unos* | -osf* | -luna* | -dgux* | -auroraux* | -solaris* \ + | -sym* | -kopensolaris* | -plan9* \ + | -amigaos* | -amigados* | -msdos* | -newsos* | -unicos* | -aof* \ + | -aos* | -aros* | -cloudabi* | -sortix* \ + | -nindy* | -vxsim* | -vxworks* | -ebmon* | -hms* | -mvs* \ + | -clix* | -riscos* | -uniplus* | -iris* | -rtu* | -xenix* \ + | -hiux* | -386bsd* | -knetbsd* | -mirbsd* | -netbsd* \ + | -bitrig* | -openbsd* | -solidbsd* | -libertybsd* \ + | -ekkobsd* | -kfreebsd* | -freebsd* | -riscix* | -lynxos* \ + | -bosx* | -nextstep* | -cxux* | -aout* | -elf* | -oabi* \ + | -ptx* | -coff* | -ecoff* | -winnt* | -domain* | -vsta* \ + | -udi* | -eabi* | -lites* | -ieee* | -go32* | -aux* \ + | -chorusos* | -chorusrdb* | -cegcc* | -glidix* \ + | -cygwin* | -msys* | -pe* | -psos* | -moss* | -proelf* | -rtems* \ + | -midipix* | -mingw32* | -mingw64* | -linux-gnu* | -linux-android* \ + | -linux-newlib* | -linux-musl* | -linux-uclibc* \ + | -uxpv* | -beos* | -mpeix* | -udk* | -moxiebox* \ + | -interix* | -uwin* | -mks* | -rhapsody* | -darwin* | -opened* \ + | -openstep* | -oskit* | -conix* | -pw32* | -nonstopux* \ + | -storm-chaos* | -tops10* | -tenex* | -tops20* | -its* \ + | -os2* | -vos* | -palmos* | -uclinux* | -nucleus* \ + | -morphos* | -superux* | -rtmk* | -rtmk-nova* | -windiss* \ + | -powermax* | -dnix* | -nx6 | -nx7 | -sei* | -dragonfly* \ + | -skyos* | -haiku* | -rdos* | -toppers* | -drops* | -es* \ + | -onefs* | -tirtos* | -phoenix* | -fuchsia* | -redox*) + # Remember, each alternative MUST END IN *, to match a version number. + ;; + -qnx*) + case $basic_machine in + x86-* | i*86-*) + ;; + *) + os=-nto$os + ;; + esac + ;; + -nto-qnx*) + ;; + -nto*) + os=`echo $os | sed -e 's|nto|nto-qnx|'` + ;; + -sim | -es1800* | -hms* | -xray | -os68k* | -none* | -v88r* \ + | -windows* | -osx | -abug | -netware* | -os9* | -beos* | -haiku* \ + | -macos* | -mpw* | -magic* | -mmixware* | -mon960* | -lnews*) + ;; + -mac*) + os=`echo $os | sed -e 's|mac|macos|'` + ;; + -linux-dietlibc) + os=-linux-dietlibc + ;; + -linux*) + os=`echo $os | sed -e 's|linux|linux-gnu|'` + ;; + -sunos5*) + os=`echo $os | sed -e 's|sunos5|solaris2|'` + ;; + -sunos6*) + os=`echo $os | sed -e 's|sunos6|solaris3|'` + ;; + -opened*) + os=-openedition + ;; + -os400*) + os=-os400 + ;; + -wince*) + os=-wince + ;; + -osfrose*) + os=-osfrose + ;; + -osf*) + os=-osf + ;; + -utek*) + os=-bsd + ;; + -dynix*) + os=-bsd + ;; + -acis*) + os=-aos + ;; + -atheos*) + os=-atheos + ;; + -syllable*) + os=-syllable + ;; + -386bsd) + os=-bsd + ;; + -ctix* | -uts*) + os=-sysv + ;; + -nova*) + os=-rtmk-nova + ;; + -ns2 ) + os=-nextstep2 + ;; + -nsk*) + os=-nsk + ;; + # Preserve the version number of sinix5. + -sinix5.*) + os=`echo $os | sed -e 's|sinix|sysv|'` + ;; + -sinix*) + os=-sysv4 + ;; + -tpf*) + os=-tpf + ;; + -triton*) + os=-sysv3 + ;; + -oss*) + os=-sysv3 + ;; + -svr4) + os=-sysv4 + ;; + -svr3) + os=-sysv3 + ;; + -sysvr4) + os=-sysv4 + ;; + # This must come after -sysvr4. + -sysv*) + ;; + -ose*) + os=-ose + ;; + -es1800*) + os=-ose + ;; + -xenix) + os=-xenix + ;; + -*mint | -mint[0-9]* | -*MiNT | -MiNT[0-9]*) + os=-mint + ;; + -aros*) + os=-aros + ;; + -zvmoe) + os=-zvmoe + ;; + -dicos*) + os=-dicos + ;; + -nacl*) + ;; + -ios) + ;; + -none) + ;; + *) + # Get rid of the `-' at the beginning of $os. + os=`echo $os | sed 's/[^-]*-//'` + echo Invalid configuration \`$1\': system \`$os\' not recognized 1>&2 + exit 1 + ;; +esac +else + +# Here we handle the default operating systems that come with various machines. +# The value should be what the vendor currently ships out the door with their +# machine or put another way, the most popular os provided with the machine. + +# Note that if you're going to try to match "-MANUFACTURER" here (say, +# "-sun"), then you have to tell the case statement up towards the top +# that MANUFACTURER isn't an operating system. Otherwise, code above +# will signal an error saying that MANUFACTURER isn't an operating +# system, and we'll never get to this point. + +case $basic_machine in + score-*) + os=-elf + ;; + spu-*) + os=-elf + ;; + *-acorn) + os=-riscix1.2 + ;; + arm*-rebel) + os=-linux + ;; + arm*-semi) + os=-aout + ;; + c4x-* | tic4x-*) + os=-coff + ;; + c8051-*) + os=-elf + ;; + hexagon-*) + os=-elf + ;; + tic54x-*) + os=-coff + ;; + tic55x-*) + os=-coff + ;; + tic6x-*) + os=-coff + ;; + # This must come before the *-dec entry. + pdp10-*) + os=-tops20 + ;; + pdp11-*) + os=-none + ;; + *-dec | vax-*) + os=-ultrix4.2 + ;; + m68*-apollo) + os=-domain + ;; + i386-sun) + os=-sunos4.0.2 + ;; + m68000-sun) + os=-sunos3 + ;; + m68*-cisco) + os=-aout + ;; + mep-*) + os=-elf + ;; + mips*-cisco) + os=-elf + ;; + mips*-*) + os=-elf + ;; + or32-*) + os=-coff + ;; + *-tti) # must be before sparc entry or we get the wrong os. + os=-sysv3 + ;; + sparc-* | *-sun) + os=-sunos4.1.1 + ;; + pru-*) + os=-elf + ;; + *-be) + os=-beos + ;; + *-haiku) + os=-haiku + ;; + *-ibm) + os=-aix + ;; + *-knuth) + os=-mmixware + ;; + *-wec) + os=-proelf + ;; + *-winbond) + os=-proelf + ;; + *-oki) + os=-proelf + ;; + *-hp) + os=-hpux + ;; + *-hitachi) + os=-hiux + ;; + i860-* | *-att | *-ncr | *-altos | *-motorola | *-convergent) + os=-sysv + ;; + *-cbm) + os=-amigaos + ;; + *-dg) + os=-dgux + ;; + *-dolphin) + os=-sysv3 + ;; + m68k-ccur) + os=-rtu + ;; + m88k-omron*) + os=-luna + ;; + *-next ) + os=-nextstep + ;; + *-sequent) + os=-ptx + ;; + *-crds) + os=-unos + ;; + *-ns) + os=-genix + ;; + i370-*) + os=-mvs + ;; + *-next) + os=-nextstep3 + ;; + *-gould) + os=-sysv + ;; + *-highlevel) + os=-bsd + ;; + *-encore) + os=-bsd + ;; + *-sgi) + os=-irix + ;; + *-siemens) + os=-sysv4 + ;; + *-masscomp) + os=-rtu + ;; + f30[01]-fujitsu | f700-fujitsu) + os=-uxpv + ;; + *-rom68k) + os=-coff + ;; + *-*bug) + os=-coff + ;; + *-apple) + os=-macos + ;; + *-atari*) + os=-mint + ;; + *) + os=-none + ;; +esac +fi + +# Here we handle the case where we know the os, and the CPU type, but not the +# manufacturer. We pick the logical manufacturer. +vendor=unknown +case $basic_machine in + *-unknown) + case $os in + -riscix*) + vendor=acorn + ;; + -sunos*) + vendor=sun + ;; + -cnk*|-aix*) + vendor=ibm + ;; + -beos*) + vendor=be + ;; + -hpux*) + vendor=hp + ;; + -mpeix*) + vendor=hp + ;; + -hiux*) + vendor=hitachi + ;; + -unos*) + vendor=crds + ;; + -dgux*) + vendor=dg + ;; + -luna*) + vendor=omron + ;; + -genix*) + vendor=ns + ;; + -mvs* | -opened*) + vendor=ibm + ;; + -os400*) + vendor=ibm + ;; + -ptx*) + vendor=sequent + ;; + -tpf*) + vendor=ibm + ;; + -vxsim* | -vxworks* | -windiss*) + vendor=wrs + ;; + -aux*) + vendor=apple + ;; + -hms*) + vendor=hitachi + ;; + -mpw* | -macos*) + vendor=apple + ;; + -*mint | -mint[0-9]* | -*MiNT | -MiNT[0-9]*) + vendor=atari + ;; + -vos*) + vendor=stratus + ;; + esac + basic_machine=`echo $basic_machine | sed "s/unknown/$vendor/"` + ;; +esac + +echo $basic_machine$os +exit + +# Local variables: +# eval: (add-hook 'write-file-hooks 'time-stamp) +# time-stamp-start: "timestamp='" +# time-stamp-format: "%:y-%02m-%02d" +# time-stamp-end: "'" +# End: diff --git a/core/src/index/thirdparty/faiss/build-aux/install-sh b/core/src/index/thirdparty/faiss/build-aux/install-sh new file mode 100755 index 0000000000..0360b79e7d --- /dev/null +++ b/core/src/index/thirdparty/faiss/build-aux/install-sh @@ -0,0 +1,501 @@ +#!/bin/sh +# install - install a program, script, or datafile + +scriptversion=2016-01-11.22; # UTC + +# This originates from X11R5 (mit/util/scripts/install.sh), which was +# later released in X11R6 (xc/config/util/install.sh) with the +# following copyright and license. +# +# Copyright (C) 1994 X Consortium +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to +# deal in the Software without restriction, including without limitation the +# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +# sell copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# X CONSORTIUM BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN +# AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNEC- +# TION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +# +# Except as contained in this notice, the name of the X Consortium shall not +# be used in advertising or otherwise to promote the sale, use or other deal- +# ings in this Software without prior written authorization from the X Consor- +# tium. +# +# +# FSF changes to this file are in the public domain. +# +# Calling this script install-sh is preferred over install.sh, to prevent +# 'make' implicit rules from creating a file called install from it +# when there is no Makefile. +# +# This script is compatible with the BSD install script, but was written +# from scratch. + +tab=' ' +nl=' +' +IFS=" $tab$nl" + +# Set DOITPROG to "echo" to test this script. + +doit=${DOITPROG-} +doit_exec=${doit:-exec} + +# Put in absolute file names if you don't have them in your path; +# or use environment vars. + +chgrpprog=${CHGRPPROG-chgrp} +chmodprog=${CHMODPROG-chmod} +chownprog=${CHOWNPROG-chown} +cmpprog=${CMPPROG-cmp} +cpprog=${CPPROG-cp} +mkdirprog=${MKDIRPROG-mkdir} +mvprog=${MVPROG-mv} +rmprog=${RMPROG-rm} +stripprog=${STRIPPROG-strip} + +posix_mkdir= + +# Desired mode of installed file. +mode=0755 + +chgrpcmd= +chmodcmd=$chmodprog +chowncmd= +mvcmd=$mvprog +rmcmd="$rmprog -f" +stripcmd= + +src= +dst= +dir_arg= +dst_arg= + +copy_on_change=false +is_target_a_directory=possibly + +usage="\ +Usage: $0 [OPTION]... [-T] SRCFILE DSTFILE + or: $0 [OPTION]... SRCFILES... DIRECTORY + or: $0 [OPTION]... -t DIRECTORY SRCFILES... + or: $0 [OPTION]... -d DIRECTORIES... + +In the 1st form, copy SRCFILE to DSTFILE. +In the 2nd and 3rd, copy all SRCFILES to DIRECTORY. +In the 4th, create DIRECTORIES. + +Options: + --help display this help and exit. + --version display version info and exit. + + -c (ignored) + -C install only if different (preserve the last data modification time) + -d create directories instead of installing files. + -g GROUP $chgrpprog installed files to GROUP. + -m MODE $chmodprog installed files to MODE. + -o USER $chownprog installed files to USER. + -s $stripprog installed files. + -t DIRECTORY install into DIRECTORY. + -T report an error if DSTFILE is a directory. + +Environment variables override the default commands: + CHGRPPROG CHMODPROG CHOWNPROG CMPPROG CPPROG MKDIRPROG MVPROG + RMPROG STRIPPROG +" + +while test $# -ne 0; do + case $1 in + -c) ;; + + -C) copy_on_change=true;; + + -d) dir_arg=true;; + + -g) chgrpcmd="$chgrpprog $2" + shift;; + + --help) echo "$usage"; exit $?;; + + -m) mode=$2 + case $mode in + *' '* | *"$tab"* | *"$nl"* | *'*'* | *'?'* | *'['*) + echo "$0: invalid mode: $mode" >&2 + exit 1;; + esac + shift;; + + -o) chowncmd="$chownprog $2" + shift;; + + -s) stripcmd=$stripprog;; + + -t) + is_target_a_directory=always + dst_arg=$2 + # Protect names problematic for 'test' and other utilities. + case $dst_arg in + -* | [=\(\)!]) dst_arg=./$dst_arg;; + esac + shift;; + + -T) is_target_a_directory=never;; + + --version) echo "$0 $scriptversion"; exit $?;; + + --) shift + break;; + + -*) echo "$0: invalid option: $1" >&2 + exit 1;; + + *) break;; + esac + shift +done + +# We allow the use of options -d and -T together, by making -d +# take the precedence; this is for compatibility with GNU install. + +if test -n "$dir_arg"; then + if test -n "$dst_arg"; then + echo "$0: target directory not allowed when installing a directory." >&2 + exit 1 + fi +fi + +if test $# -ne 0 && test -z "$dir_arg$dst_arg"; then + # When -d is used, all remaining arguments are directories to create. + # When -t is used, the destination is already specified. + # Otherwise, the last argument is the destination. Remove it from $@. + for arg + do + if test -n "$dst_arg"; then + # $@ is not empty: it contains at least $arg. + set fnord "$@" "$dst_arg" + shift # fnord + fi + shift # arg + dst_arg=$arg + # Protect names problematic for 'test' and other utilities. + case $dst_arg in + -* | [=\(\)!]) dst_arg=./$dst_arg;; + esac + done +fi + +if test $# -eq 0; then + if test -z "$dir_arg"; then + echo "$0: no input file specified." >&2 + exit 1 + fi + # It's OK to call 'install-sh -d' without argument. + # This can happen when creating conditional directories. + exit 0 +fi + +if test -z "$dir_arg"; then + if test $# -gt 1 || test "$is_target_a_directory" = always; then + if test ! -d "$dst_arg"; then + echo "$0: $dst_arg: Is not a directory." >&2 + exit 1 + fi + fi +fi + +if test -z "$dir_arg"; then + do_exit='(exit $ret); exit $ret' + trap "ret=129; $do_exit" 1 + trap "ret=130; $do_exit" 2 + trap "ret=141; $do_exit" 13 + trap "ret=143; $do_exit" 15 + + # Set umask so as not to create temps with too-generous modes. + # However, 'strip' requires both read and write access to temps. + case $mode in + # Optimize common cases. + *644) cp_umask=133;; + *755) cp_umask=22;; + + *[0-7]) + if test -z "$stripcmd"; then + u_plus_rw= + else + u_plus_rw='% 200' + fi + cp_umask=`expr '(' 777 - $mode % 1000 ')' $u_plus_rw`;; + *) + if test -z "$stripcmd"; then + u_plus_rw= + else + u_plus_rw=,u+rw + fi + cp_umask=$mode$u_plus_rw;; + esac +fi + +for src +do + # Protect names problematic for 'test' and other utilities. + case $src in + -* | [=\(\)!]) src=./$src;; + esac + + if test -n "$dir_arg"; then + dst=$src + dstdir=$dst + test -d "$dstdir" + dstdir_status=$? + else + + # Waiting for this to be detected by the "$cpprog $src $dsttmp" command + # might cause directories to be created, which would be especially bad + # if $src (and thus $dsttmp) contains '*'. + if test ! -f "$src" && test ! -d "$src"; then + echo "$0: $src does not exist." >&2 + exit 1 + fi + + if test -z "$dst_arg"; then + echo "$0: no destination specified." >&2 + exit 1 + fi + dst=$dst_arg + + # If destination is a directory, append the input filename; won't work + # if double slashes aren't ignored. + if test -d "$dst"; then + if test "$is_target_a_directory" = never; then + echo "$0: $dst_arg: Is a directory" >&2 + exit 1 + fi + dstdir=$dst + dst=$dstdir/`basename "$src"` + dstdir_status=0 + else + dstdir=`dirname "$dst"` + test -d "$dstdir" + dstdir_status=$? + fi + fi + + obsolete_mkdir_used=false + + if test $dstdir_status != 0; then + case $posix_mkdir in + '') + # Create intermediate dirs using mode 755 as modified by the umask. + # This is like FreeBSD 'install' as of 1997-10-28. + umask=`umask` + case $stripcmd.$umask in + # Optimize common cases. + *[2367][2367]) mkdir_umask=$umask;; + .*0[02][02] | .[02][02] | .[02]) mkdir_umask=22;; + + *[0-7]) + mkdir_umask=`expr $umask + 22 \ + - $umask % 100 % 40 + $umask % 20 \ + - $umask % 10 % 4 + $umask % 2 + `;; + *) mkdir_umask=$umask,go-w;; + esac + + # With -d, create the new directory with the user-specified mode. + # Otherwise, rely on $mkdir_umask. + if test -n "$dir_arg"; then + mkdir_mode=-m$mode + else + mkdir_mode= + fi + + posix_mkdir=false + case $umask in + *[123567][0-7][0-7]) + # POSIX mkdir -p sets u+wx bits regardless of umask, which + # is incompatible with FreeBSD 'install' when (umask & 300) != 0. + ;; + *) + tmpdir=${TMPDIR-/tmp}/ins$RANDOM-$$ + trap 'ret=$?; rmdir "$tmpdir/d" "$tmpdir" 2>/dev/null; exit $ret' 0 + + if (umask $mkdir_umask && + exec $mkdirprog $mkdir_mode -p -- "$tmpdir/d") >/dev/null 2>&1 + then + if test -z "$dir_arg" || { + # Check for POSIX incompatibilities with -m. + # HP-UX 11.23 and IRIX 6.5 mkdir -m -p sets group- or + # other-writable bit of parent directory when it shouldn't. + # FreeBSD 6.1 mkdir -m -p sets mode of existing directory. + ls_ld_tmpdir=`ls -ld "$tmpdir"` + case $ls_ld_tmpdir in + d????-?r-*) different_mode=700;; + d????-?--*) different_mode=755;; + *) false;; + esac && + $mkdirprog -m$different_mode -p -- "$tmpdir" && { + ls_ld_tmpdir_1=`ls -ld "$tmpdir"` + test "$ls_ld_tmpdir" = "$ls_ld_tmpdir_1" + } + } + then posix_mkdir=: + fi + rmdir "$tmpdir/d" "$tmpdir" + else + # Remove any dirs left behind by ancient mkdir implementations. + rmdir ./$mkdir_mode ./-p ./-- 2>/dev/null + fi + trap '' 0;; + esac;; + esac + + if + $posix_mkdir && ( + umask $mkdir_umask && + $doit_exec $mkdirprog $mkdir_mode -p -- "$dstdir" + ) + then : + else + + # The umask is ridiculous, or mkdir does not conform to POSIX, + # or it failed possibly due to a race condition. Create the + # directory the slow way, step by step, checking for races as we go. + + case $dstdir in + /*) prefix='/';; + [-=\(\)!]*) prefix='./';; + *) prefix='';; + esac + + oIFS=$IFS + IFS=/ + set -f + set fnord $dstdir + shift + set +f + IFS=$oIFS + + prefixes= + + for d + do + test X"$d" = X && continue + + prefix=$prefix$d + if test -d "$prefix"; then + prefixes= + else + if $posix_mkdir; then + (umask=$mkdir_umask && + $doit_exec $mkdirprog $mkdir_mode -p -- "$dstdir") && break + # Don't fail if two instances are running concurrently. + test -d "$prefix" || exit 1 + else + case $prefix in + *\'*) qprefix=`echo "$prefix" | sed "s/'/'\\\\\\\\''/g"`;; + *) qprefix=$prefix;; + esac + prefixes="$prefixes '$qprefix'" + fi + fi + prefix=$prefix/ + done + + if test -n "$prefixes"; then + # Don't fail if two instances are running concurrently. + (umask $mkdir_umask && + eval "\$doit_exec \$mkdirprog $prefixes") || + test -d "$dstdir" || exit 1 + obsolete_mkdir_used=true + fi + fi + fi + + if test -n "$dir_arg"; then + { test -z "$chowncmd" || $doit $chowncmd "$dst"; } && + { test -z "$chgrpcmd" || $doit $chgrpcmd "$dst"; } && + { test "$obsolete_mkdir_used$chowncmd$chgrpcmd" = false || + test -z "$chmodcmd" || $doit $chmodcmd $mode "$dst"; } || exit 1 + else + + # Make a couple of temp file names in the proper directory. + dsttmp=$dstdir/_inst.$$_ + rmtmp=$dstdir/_rm.$$_ + + # Trap to clean up those temp files at exit. + trap 'ret=$?; rm -f "$dsttmp" "$rmtmp" && exit $ret' 0 + + # Copy the file name to the temp name. + (umask $cp_umask && $doit_exec $cpprog "$src" "$dsttmp") && + + # and set any options; do chmod last to preserve setuid bits. + # + # If any of these fail, we abort the whole thing. If we want to + # ignore errors from any of these, just make sure not to ignore + # errors from the above "$doit $cpprog $src $dsttmp" command. + # + { test -z "$chowncmd" || $doit $chowncmd "$dsttmp"; } && + { test -z "$chgrpcmd" || $doit $chgrpcmd "$dsttmp"; } && + { test -z "$stripcmd" || $doit $stripcmd "$dsttmp"; } && + { test -z "$chmodcmd" || $doit $chmodcmd $mode "$dsttmp"; } && + + # If -C, don't bother to copy if it wouldn't change the file. + if $copy_on_change && + old=`LC_ALL=C ls -dlL "$dst" 2>/dev/null` && + new=`LC_ALL=C ls -dlL "$dsttmp" 2>/dev/null` && + set -f && + set X $old && old=:$2:$4:$5:$6 && + set X $new && new=:$2:$4:$5:$6 && + set +f && + test "$old" = "$new" && + $cmpprog "$dst" "$dsttmp" >/dev/null 2>&1 + then + rm -f "$dsttmp" + else + # Rename the file to the real destination. + $doit $mvcmd -f "$dsttmp" "$dst" 2>/dev/null || + + # The rename failed, perhaps because mv can't rename something else + # to itself, or perhaps because mv is so ancient that it does not + # support -f. + { + # Now remove or move aside any old file at destination location. + # We try this two ways since rm can't unlink itself on some + # systems and the destination file might be busy for other + # reasons. In this case, the final cleanup might fail but the new + # file should still install successfully. + { + test ! -f "$dst" || + $doit $rmcmd -f "$dst" 2>/dev/null || + { $doit $mvcmd -f "$dst" "$rmtmp" 2>/dev/null && + { $doit $rmcmd -f "$rmtmp" 2>/dev/null; :; } + } || + { echo "$0: cannot unlink or rename $dst" >&2 + (exit 1); exit 1 + } + } && + + # Now rename the file to the real destination. + $doit $mvcmd "$dsttmp" "$dst" + } + fi || exit 1 + + trap '' 0 + fi +done + +# Local variables: +# eval: (add-hook 'write-file-hooks 'time-stamp) +# time-stamp-start: "scriptversion=" +# time-stamp-format: "%:y-%02m-%02d.%02H" +# time-stamp-time-zone: "UTC0" +# time-stamp-end: "; # UTC" +# End: diff --git a/core/src/index/thirdparty/faiss/build.sh b/core/src/index/thirdparty/faiss/build.sh new file mode 100755 index 0000000000..a58a6e6134 --- /dev/null +++ b/core/src/index/thirdparty/faiss/build.sh @@ -0,0 +1,3 @@ +#./configure CPUFLAGS='-mavx -mf16c -msse4 -mpopcnt' CXXFLAGS='-O0 -g -fPIC -m64 -Wno-sign-compare -Wall -Wextra' --prefix=$PWD --with-cuda-arch=-gencode=arch=compute_75,code=sm_75 --with-cuda=/usr/local/cuda +./configure --prefix=$PWD CFLAGS='-g -fPIC' CXXFLAGS='-O3 -g -fPIC -DELPP_THREAD_SAFE -fopenmp -g -fPIC -mf16c -O3 -DNDEBUG' --without-python --with-cuda=/usr/local/cuda --with-cuda-arch='-gencode=arch=compute_60,code=sm_60 -gencode=arch=compute_61,code=sm_61 -gencode=arch=compute_70,code=sm_70 -gencode=arch=compute_75,code=sm_75' +make install -j8 diff --git a/core/src/index/thirdparty/faiss/c_api/AutoTune_c.cpp b/core/src/index/thirdparty/faiss/c_api/AutoTune_c.cpp new file mode 100644 index 0000000000..2f412d6aaa --- /dev/null +++ b/core/src/index/thirdparty/faiss/c_api/AutoTune_c.cpp @@ -0,0 +1,83 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// Copyright 2004-present Facebook. All Rights Reserved. +// -*- c++ -*- + +#include +#include "AutoTune.h" +#include "AutoTune_c.h" +#include "macros_impl.h" + +using faiss::Index; +using faiss::ParameterRange; +using faiss::ParameterSpace; + +const char* faiss_ParameterRange_name(const FaissParameterRange* range) { + return reinterpret_cast(range)->name.c_str(); +} + +void faiss_ParameterRange_values(FaissParameterRange* range, double** p_values, size_t* p_size) { + auto& values = reinterpret_cast(range)->values; + *p_values = values.data(); + *p_size = values.size(); +} + +int faiss_ParameterSpace_new(FaissParameterSpace** space) { + try { + auto new_space = new ParameterSpace(); + *space = reinterpret_cast(new_space); + } CATCH_AND_HANDLE +} + +DEFINE_DESTRUCTOR(ParameterSpace) + +size_t faiss_ParameterSpace_n_combinations(const FaissParameterSpace* space) { + return reinterpret_cast(space)->n_combinations(); +} + +int faiss_ParameterSpace_combination_name(const FaissParameterSpace* space, size_t cno, char* char_buffer, size_t size) { + try { + auto rep = reinterpret_cast(space)->combination_name(cno); + strncpy(char_buffer, rep.c_str(), size); + } CATCH_AND_HANDLE +} + +int faiss_ParameterSpace_set_index_parameters(const FaissParameterSpace* space, FaissIndex* cindex, const char* param_string) { + try { + auto index = reinterpret_cast(cindex); + reinterpret_cast(space)->set_index_parameters(index, param_string); + } CATCH_AND_HANDLE +} + +/// set a combination of parameters on an index +int faiss_ParameterSpace_set_index_parameters_cno(const FaissParameterSpace* space, FaissIndex* cindex, size_t cno) { + try { + auto index = reinterpret_cast(cindex); + reinterpret_cast(space)->set_index_parameters(index, cno); + } CATCH_AND_HANDLE +} + +int faiss_ParameterSpace_set_index_parameter(const FaissParameterSpace* space, FaissIndex* cindex, const char * name, double value) { + try { + auto index = reinterpret_cast(cindex); + reinterpret_cast(space)->set_index_parameter(index, name, value); + } CATCH_AND_HANDLE +} + +void faiss_ParameterSpace_display(const FaissParameterSpace* space) { + reinterpret_cast(space)->display(); +} + +int faiss_ParameterSpace_add_range(FaissParameterSpace* space, const char* name, FaissParameterRange** p_range) { + try { + ParameterRange& range = reinterpret_cast(space)->add_range(name); + if (p_range) { + *p_range = reinterpret_cast(&range); + } + } CATCH_AND_HANDLE +} \ No newline at end of file diff --git a/core/src/index/thirdparty/faiss/c_api/AutoTune_c.h b/core/src/index/thirdparty/faiss/c_api/AutoTune_c.h new file mode 100644 index 0000000000..d870921c04 --- /dev/null +++ b/core/src/index/thirdparty/faiss/c_api/AutoTune_c.h @@ -0,0 +1,64 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// Copyright 2004-present Facebook. All Rights Reserved. +// -*- c -*- + +#ifndef FAISS_AUTO_TUNE_C_H +#define FAISS_AUTO_TUNE_C_H + +#include "faiss_c.h" +#include "Index_c.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/// possible values of a parameter, sorted from least to most expensive/accurate +FAISS_DECLARE_CLASS(ParameterRange) + +FAISS_DECLARE_GETTER(ParameterRange, const char*, name) + +/// Getter for the values in the range. The output values are invalidated +/// upon any other modification of the range. +void faiss_ParameterRange_values(FaissParameterRange*, double**, size_t*); + +/** Uses a-priori knowledge on the Faiss indexes to extract tunable parameters. + */ +FAISS_DECLARE_CLASS(ParameterSpace) + +/// Parameter space default constructor +int faiss_ParameterSpace_new(FaissParameterSpace** space); + +/// nb of combinations, = product of values sizes +size_t faiss_ParameterSpace_n_combinations(const FaissParameterSpace*); + +/// get string representation of the combination +/// by writing it to the given character buffer. +/// A buffer size of 1000 ensures that the full name is collected. +int faiss_ParameterSpace_combination_name(const FaissParameterSpace*, size_t, char*, size_t); + +/// set a combination of parameters described by a string +int faiss_ParameterSpace_set_index_parameters(const FaissParameterSpace*, FaissIndex*, const char *); + +/// set a combination of parameters on an index +int faiss_ParameterSpace_set_index_parameters_cno(const FaissParameterSpace*, FaissIndex*, size_t); + +/// set one of the parameters +int faiss_ParameterSpace_set_index_parameter(const FaissParameterSpace*, FaissIndex*, const char *, double); + +/// print a description on stdout +void faiss_ParameterSpace_display(const FaissParameterSpace*); + +/// add a new parameter (or return it if it exists) +int faiss_ParameterSpace_add_range(FaissParameterSpace*, const char*, FaissParameterRange**); + +#ifdef __cplusplus +} +#endif + +#endif \ No newline at end of file diff --git a/core/src/index/thirdparty/faiss/c_api/Clustering_c.cpp b/core/src/index/thirdparty/faiss/c_api/Clustering_c.cpp new file mode 100644 index 0000000000..e4541458c0 --- /dev/null +++ b/core/src/index/thirdparty/faiss/c_api/Clustering_c.cpp @@ -0,0 +1,145 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// Copyright 2004-present Facebook. All Rights Reserved. +// -*- c++ -*- + +#include "Clustering_c.h" +#include "Clustering.h" +#include "Index.h" +#include +#include "macros_impl.h" + +extern "C" { + +using faiss::Clustering; +using faiss::ClusteringParameters; +using faiss::Index; +using faiss::ClusteringIterationStats; + +DEFINE_GETTER(Clustering, int, niter) +DEFINE_GETTER(Clustering, int, nredo) +DEFINE_GETTER(Clustering, int, verbose) +DEFINE_GETTER(Clustering, int, spherical) +DEFINE_GETTER(Clustering, int, update_index) +DEFINE_GETTER(Clustering, int, frozen_centroids) + +DEFINE_GETTER(Clustering, int, min_points_per_centroid) +DEFINE_GETTER(Clustering, int, max_points_per_centroid) + +DEFINE_GETTER(Clustering, int, seed) + +/// getter for d +DEFINE_GETTER(Clustering, size_t, d) + +/// getter for k +DEFINE_GETTER(Clustering, size_t, k) + +DEFINE_GETTER(ClusteringIterationStats, float, obj) +DEFINE_GETTER(ClusteringIterationStats, double, time) +DEFINE_GETTER(ClusteringIterationStats, double, time_search) +DEFINE_GETTER(ClusteringIterationStats, double, imbalance_factor) +DEFINE_GETTER(ClusteringIterationStats, int, nsplit) + +void faiss_ClusteringParameters_init(FaissClusteringParameters* params) { + ClusteringParameters d; + params->frozen_centroids = d.frozen_centroids; + params->max_points_per_centroid = d.max_points_per_centroid; + params->min_points_per_centroid = d.min_points_per_centroid; + params->niter = d.niter; + params->nredo = d.nredo; + params->seed = d.seed; + params->spherical = d.spherical; + params->update_index = d.update_index; + params->verbose = d.verbose; +} + +// This conversion is required because the two types are not memory-compatible +inline ClusteringParameters from_faiss_c(const FaissClusteringParameters* params) { + ClusteringParameters o; + o.frozen_centroids = params->frozen_centroids; + o.max_points_per_centroid = params->max_points_per_centroid; + o.min_points_per_centroid = params->min_points_per_centroid; + o.niter = params->niter; + o.nredo = params->nredo; + o.seed = params->seed; + o.spherical = params->spherical; + o.update_index = params->update_index; + o.verbose = params->verbose; + return o; +} + +/// getter for centroids (size = k * d) +void faiss_Clustering_centroids( + FaissClustering* clustering, float** centroids, size_t* size) { + std::vector& v = reinterpret_cast(clustering)->centroids; + if (centroids) { + *centroids = v.data(); + } + if (size) { + *size = v.size(); + } +} + +/// getter for iteration stats +void faiss_Clustering_iteration_stats( + FaissClustering* clustering, FaissClusteringIterationStats** iteration_stats, size_t* size) { + std::vector& v = reinterpret_cast(clustering)->iteration_stats; + if (iteration_stats) { + *iteration_stats = reinterpret_cast(v.data()); + } + if (size) { + *size = v.size(); + } +} + +/// the only mandatory parameters are k and d +int faiss_Clustering_new(FaissClustering** p_clustering, int d, int k) { + try { + Clustering* c = new Clustering(d, k); + *p_clustering = reinterpret_cast(c); + return 0; + } CATCH_AND_HANDLE +} + +int faiss_Clustering_new_with_params( + FaissClustering** p_clustering, int d, int k, const FaissClusteringParameters* cp) { + try { + Clustering* c = new Clustering(d, k, from_faiss_c(cp)); + *p_clustering = reinterpret_cast(c); + return 0; + } CATCH_AND_HANDLE +} + +/// Index is used during the assignment stage +int faiss_Clustering_train( + FaissClustering* clustering, idx_t n, const float* x, FaissIndex* index) { + try { + reinterpret_cast(clustering)->train( + n, x, *reinterpret_cast(index)); + return 0; + } CATCH_AND_HANDLE +} + +void faiss_Clustering_free(FaissClustering* clustering) { + delete reinterpret_cast(clustering); +} + +int faiss_kmeans_clustering (size_t d, size_t n, size_t k, + const float *x, + float *centroids, + float *q_error) { + try { + float out = faiss::kmeans_clustering(d, n, k, x, centroids); + if (q_error) { + *q_error = out; + } + return 0; + } CATCH_AND_HANDLE +} + +} diff --git a/core/src/index/thirdparty/faiss/c_api/Clustering_c.h b/core/src/index/thirdparty/faiss/c_api/Clustering_c.h new file mode 100644 index 0000000000..af82152e60 --- /dev/null +++ b/core/src/index/thirdparty/faiss/c_api/Clustering_c.h @@ -0,0 +1,123 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// Copyright 2004-present Facebook. All Rights Reserved +// -*- c -*- + +#ifndef FAISS_CLUSTERING_C_H +#define FAISS_CLUSTERING_C_H + +#include "Index_c.h" +#include "faiss_c.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/** Class for the clustering parameters. Can be passed to the + * constructor of the Clustering object. + */ +typedef struct FaissClusteringParameters { + int niter; ///< clustering iterations + int nredo; ///< redo clustering this many times and keep best + + int verbose; ///< (bool) + int spherical; ///< (bool) do we want normalized centroids? + int update_index; ///< (bool) update index after each iteration? + int frozen_centroids; ///< (bool) use the centroids provided as input and do not change them during iterations + + int min_points_per_centroid; ///< otherwise you get a warning + int max_points_per_centroid; ///< to limit size of dataset + + int seed; ///< seed for the random number generator +} FaissClusteringParameters; + + +/// Sets the ClusteringParameters object with reasonable defaults +void faiss_ClusteringParameters_init(FaissClusteringParameters* params); + + +/** clustering based on assignment - centroid update iterations + * + * The clustering is based on an Index object that assigns training + * points to the centroids. Therefore, at each iteration the centroids + * are added to the index. + * + * On output, the centroids table is set to the latest version + * of the centroids and they are also added to the index. If the + * centroids table it is not empty on input, it is also used for + * initialization. + * + * To do several clusterings, just call train() several times on + * different training sets, clearing the centroid table in between. + */ +FAISS_DECLARE_CLASS(Clustering) + +FAISS_DECLARE_GETTER(Clustering, int, niter) +FAISS_DECLARE_GETTER(Clustering, int, nredo) +FAISS_DECLARE_GETTER(Clustering, int, verbose) +FAISS_DECLARE_GETTER(Clustering, int, spherical) +FAISS_DECLARE_GETTER(Clustering, int, update_index) +FAISS_DECLARE_GETTER(Clustering, int, frozen_centroids) + +FAISS_DECLARE_GETTER(Clustering, int, min_points_per_centroid) +FAISS_DECLARE_GETTER(Clustering, int, max_points_per_centroid) + +FAISS_DECLARE_GETTER(Clustering, int, seed) + +/// getter for d +FAISS_DECLARE_GETTER(Clustering, size_t, d) + +/// getter for k +FAISS_DECLARE_GETTER(Clustering, size_t, k) + +FAISS_DECLARE_CLASS(ClusteringIterationStats) +FAISS_DECLARE_GETTER(ClusteringIterationStats, float, obj) +FAISS_DECLARE_GETTER(ClusteringIterationStats, double, time) +FAISS_DECLARE_GETTER(ClusteringIterationStats, double, time_search) +FAISS_DECLARE_GETTER(ClusteringIterationStats, double, imbalance_factor) +FAISS_DECLARE_GETTER(ClusteringIterationStats, int, nsplit) + +/// getter for centroids (size = k * d) +void faiss_Clustering_centroids( + FaissClustering* clustering, float** centroids, size_t* size); + +/// getter for iteration stats +void faiss_Clustering_iteration_stats( + FaissClustering* clustering, FaissClusteringIterationStats** iteration_stats, size_t* size); + +/// the only mandatory parameters are k and d +int faiss_Clustering_new(FaissClustering** p_clustering, int d, int k); + +int faiss_Clustering_new_with_params( + FaissClustering** p_clustering, int d, int k, const FaissClusteringParameters* cp); + +int faiss_Clustering_train( + FaissClustering* clustering, idx_t n, const float* x, FaissIndex* index); + +void faiss_Clustering_free(FaissClustering* clustering); + +/** simplified interface + * + * @param d dimension of the data + * @param n nb of training vectors + * @param k nb of output centroids + * @param x training set (size n * d) + * @param centroids output centroids (size k * d) + * @param q_error final quantization error + * @return error code + */ +int faiss_kmeans_clustering (size_t d, size_t n, size_t k, + const float *x, + float *centroids, + float *q_error); + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/core/src/index/thirdparty/faiss/c_api/INSTALL.md b/core/src/index/thirdparty/faiss/c_api/INSTALL.md new file mode 100644 index 0000000000..b640d7db73 --- /dev/null +++ b/core/src/index/thirdparty/faiss/c_api/INSTALL.md @@ -0,0 +1,100 @@ +Faiss C API +=========== + +Faiss provides a pure C interface, which can subsequently be used either in pure C programs or to produce bindings for programming languages with Foreign Function Interface (FFI) support. Although this is not required for the Python interface, some other programming languages (e.g. Rust and Julia) do not have SWIG support. + +Compilation instructions +------------------------ + +The full contents of the pure C API are in the ["c_api"](c_api/) folder. +Please be sure to follow the instructions on [building the main C++ library](../INSTALL.md#step-1-compiling-the-c-faiss) first. +Then, enter the [c_api](c_api/) directory and run + + `make` + +This builds the dynamic library "faiss_c", containing the full implementation of Faiss and the necessary wrappers for the C interface. It does not depend on libfaiss.a or the C++ standard library. It will also build an example program `bin/example_c`. + +Using the API +------------- + +The C API is composed of: + +- A set of C header files comprising the main Faiss interfaces, converted for use in C. Each file follows the format `«name»_c.h`, where `«name»` is the respective name from the C++ API. For example, the file [Index_c.h](./Index_c.h) file corresponds to the base `Index` API. Functions are declared with the `faiss_` prefix (e.g. `faiss_IndexFlat_new`), whereas new types have the `Faiss` prefix (e.g. `FaissIndex`, `FaissMetricType`, ...). +- A dynamic library, compiled from the sources in the same folder, encloses the implementation of the library and wrapper functions. + +The index factory is available via the `faiss_index_factory` function in `AutoTune_c.h`: + +```c +FaissIndex* index = NULL; +int c = faiss_index_factory(&index, 64, "Flat", METRIC_L2); +if (c) { + // operation failed +} +``` + +Most operations that you would find as member functions are available with the format `faiss_«classname»_«member»`. + +```c +idx_t ntotal = faiss_Index_ntotal(index); +``` + +Since this is C, the index needs to be freed manually in the end: + +```c +faiss_Index_free(index); +``` + +Error handling is done by examining the error code returned by operations with recoverable errors. +The code identifies the type of exception that rose from the implementation. Fetching the +corresponding error message can be done by calling the function `faiss_get_last_error()` from +`error_c.h`. Getter functions and `free` functions do not return an error code. + +```c +int c = faiss_Index_add(index, nb, xb); +if (c) { + printf("%s", faiss_get_last_error()); + exit(-1); +} +``` + +An example is included, which is built automatically for the target `all`. It can also be built separately: + + `make bin/example_c` + +Building with GPU support +------------------------- + +For GPU support, a separate dynamic library in the "c_api/gpu" directory needs to be built. + + `make` + +The "gpufaiss_c" dynamic library contains the GPU and CPU implementations of Faiss, which means that +it can be used in place of "faiss_c". The same library will dynamically link with the CUDA runtime +and cuBLAS. + +Using the GPU with the C API +---------------------------- + +A standard GPU resurces object can be obtained by the name `FaissStandardGpuResources`: + +```c +FaissStandardGpuResources* gpu_res = NULL; +int c = faiss_StandardGpuResources_new(&gpu_res); +if (c) { + printf("%s", faiss_get_last_error()); + exit(-1); +} +``` + +Similarly to the C++ API, a CPU index can be converted to a GPU index: + +```c +FaissIndex* cpu_index = NULL; +int c = faiss_index_factory(&cpu_index, d, "Flat", METRIC_L2); +if (c) { /* ... */ } +FaissGpuIndex* gpu_index = NULL; +c = faiss_index_cpu_to_gpu(gpu_res, 0, cpu_index, &gpu_index); +if (c) { /* ... */ } +``` + +A more complete example is available by the name `bin/example_gpu_c`. diff --git a/core/src/index/thirdparty/faiss/c_api/IndexFlat_c.cpp b/core/src/index/thirdparty/faiss/c_api/IndexFlat_c.cpp new file mode 100644 index 0000000000..4b741922e8 --- /dev/null +++ b/core/src/index/thirdparty/faiss/c_api/IndexFlat_c.cpp @@ -0,0 +1,140 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// Copyright 2004-present Facebook. All Rights Reserved. +// -*- c++ -*- + +#include "IndexFlat_c.h" +#include "IndexFlat.h" +#include "Index.h" +#include "macros_impl.h" + +extern "C" { + +using faiss::Index; +using faiss::IndexFlat; +using faiss::IndexFlatIP; +using faiss::IndexFlatL2; +using faiss::IndexFlatL2BaseShift; +using faiss::IndexRefineFlat; +using faiss::IndexFlat1D; + +DEFINE_DESTRUCTOR(IndexFlat) +DEFINE_INDEX_DOWNCAST(IndexFlat) + +int faiss_IndexFlat_new(FaissIndexFlat** p_index) { + try { + *p_index = reinterpret_cast(new IndexFlat()); + return 0; + } CATCH_AND_HANDLE +} + +int faiss_IndexFlat_new_with(FaissIndexFlat** p_index, idx_t d, FaissMetricType metric) { + try { + IndexFlat* index = new IndexFlat(d, static_cast(metric)); + *p_index = reinterpret_cast(index); + return 0; + } CATCH_AND_HANDLE +} + +void faiss_IndexFlat_xb(FaissIndexFlat* index, float** p_xb, size_t* p_size) { + auto& xb = reinterpret_cast(index)->xb; + *p_xb = xb.data(); + if (p_size) { + *p_size = xb.size(); + } +} + +int faiss_IndexFlat_compute_distance_subset( + FaissIndex* index, + idx_t n, + const float *x, + idx_t k, + float *distances, + const idx_t *labels) { + try { + reinterpret_cast(index)->compute_distance_subset( + n, x, k, distances, labels); + return 0; + } CATCH_AND_HANDLE +} + +int faiss_IndexFlatIP_new(FaissIndexFlatIP** p_index) { + try { + IndexFlatIP* index = new IndexFlatIP(); + *p_index = reinterpret_cast(index); + return 0; + } CATCH_AND_HANDLE +} + +int faiss_IndexFlatIP_new_with(FaissIndexFlatIP** p_index, idx_t d) { + try { + IndexFlatIP* index = new IndexFlatIP(d); + *p_index = reinterpret_cast(index); + return 0; + } CATCH_AND_HANDLE +} + +int faiss_IndexFlatL2_new(FaissIndexFlatL2** p_index) { + try { + IndexFlatL2* index = new IndexFlatL2(); + *p_index = reinterpret_cast(index); + return 0; + } CATCH_AND_HANDLE +} + +int faiss_IndexFlatL2_new_with(FaissIndexFlatL2** p_index, idx_t d) { + try { + IndexFlatL2* index = new IndexFlatL2(d); + *p_index = reinterpret_cast(index); + return 0; + } CATCH_AND_HANDLE +} + +int faiss_IndexFlatL2BaseShift_new(FaissIndexFlatL2BaseShift** p_index, idx_t d, size_t nshift, const float *shift) { + try { + IndexFlatL2BaseShift* index = new IndexFlatL2BaseShift(d, nshift, shift); + *p_index = reinterpret_cast(index); + return 0; + } CATCH_AND_HANDLE +} + +int faiss_IndexRefineFlat_new(FaissIndexRefineFlat** p_index, FaissIndex* base_index) { + try { + IndexRefineFlat* index = new IndexRefineFlat( + reinterpret_cast(base_index)); + *p_index = reinterpret_cast(index); + return 0; + } CATCH_AND_HANDLE +} + +DEFINE_DESTRUCTOR(IndexRefineFlat) + +int faiss_IndexFlat1D_new(FaissIndexFlat1D** p_index) { + try { + IndexFlat1D* index = new IndexFlat1D(); + *p_index = reinterpret_cast(index); + return 0; + } CATCH_AND_HANDLE +} + +int faiss_IndexFlat1D_new_with(FaissIndexFlat1D** p_index, int continuous_update) { + try { + IndexFlat1D* index = new IndexFlat1D(static_cast(continuous_update)); + *p_index = reinterpret_cast(index); + return 0; + } CATCH_AND_HANDLE +} + +int faiss_IndexFlat1D_update_permutation(FaissIndexFlat1D* index) { + try { + reinterpret_cast(index)->update_permutation(); + return 0; + } CATCH_AND_HANDLE +} + +} diff --git a/core/src/index/thirdparty/faiss/c_api/IndexFlat_c.h b/core/src/index/thirdparty/faiss/c_api/IndexFlat_c.h new file mode 100644 index 0000000000..072ba7dcf3 --- /dev/null +++ b/core/src/index/thirdparty/faiss/c_api/IndexFlat_c.h @@ -0,0 +1,115 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// Copyright 2004-present Facebook. All Rights Reserved +// -*- c -*- + +#ifndef FAISS_INDEX_FLAT_C_H +#define FAISS_INDEX_FLAT_C_H + +#include "Index_c.h" +#include "faiss_c.h" + +#ifdef __cplusplus +extern "C" { +#endif + +// forward declaration +typedef enum FaissMetricType FaissMetricType; + +/** Opaque type for IndexFlat */ +FAISS_DECLARE_CLASS_INHERITED(IndexFlat, Index) + +int faiss_IndexFlat_new(FaissIndexFlat** p_index); + +int faiss_IndexFlat_new_with(FaissIndexFlat** p_index, idx_t d, FaissMetricType metric); + +/** get a pointer to the index's internal data (the `xb` field). The outputs + * become invalid after any data addition or removal operation. + * + * @param index opaque pointer to index object + * @param p_xb output, the pointer to the beginning of `xb`. + * @param p_size output, the current size of `sb` in number of float values. + */ +void faiss_IndexFlat_xb(FaissIndexFlat* index, float** p_xb, size_t* p_size); + +/** attempt a dynamic cast to a flat index, thus checking + * check whether the underlying index type is `IndexFlat`. + * + * @param index opaque pointer to index object + * @return the same pointer if the index is a flat index, NULL otherwise + */ +FAISS_DECLARE_INDEX_DOWNCAST(IndexFlat) + +FAISS_DECLARE_DESTRUCTOR(IndexFlat) + +/** compute distance with a subset of vectors + * + * @param index opaque pointer to index object + * @param x query vectors, size n * d + * @param labels indices of the vectors that should be compared + * for each query vector, size n * k + * @param distances + * corresponding output distances, size n * k + */ +int faiss_IndexFlat_compute_distance_subset( + FaissIndex *index, + idx_t n, + const float *x, + idx_t k, + float *distances, + const idx_t *labels); + +/** Opaque type for IndexFlatIP */ +FAISS_DECLARE_CLASS_INHERITED(IndexFlatIP, Index) + +int faiss_IndexFlatIP_new(FaissIndexFlatIP** p_index); + +int faiss_IndexFlatIP_new_with(FaissIndexFlatIP** p_index, idx_t d); + +/** Opaque type for IndexFlatL2 */ +FAISS_DECLARE_CLASS_INHERITED(IndexFlatL2, Index) + +int faiss_IndexFlatL2_new(FaissIndexFlatL2** p_index); + +int faiss_IndexFlatL2_new_with(FaissIndexFlatL2** p_index, idx_t d); + +/** Opaque type for IndexFlatL2BaseShift + * + * same as an IndexFlatL2 but a value is subtracted from each distance + */ +FAISS_DECLARE_CLASS_INHERITED(IndexFlatL2BaseShift, Index) + +int faiss_IndexFlatL2BaseShift_new(FaissIndexFlatL2BaseShift** p_index, idx_t d, size_t nshift, const float *shift); + +/** Opaque type for IndexRefineFlat + * + * Index that queries in a base_index (a fast one) and refines the + * results with an exact search, hopefully improving the results. + */ +FAISS_DECLARE_CLASS_INHERITED(IndexRefineFlat, Index) + +int faiss_IndexRefineFlat_new(FaissIndexRefineFlat** p_index, FaissIndex* base_index); + +FAISS_DECLARE_DESTRUCTOR(IndexRefineFlat) + +/** Opaque type for IndexFlat1D + * + * optimized version for 1D "vectors" + */ +FAISS_DECLARE_CLASS_INHERITED(IndexFlat1D, Index) + +int faiss_IndexFlat1D_new(FaissIndexFlat1D** p_index); +int faiss_IndexFlat1D_new_with(FaissIndexFlat1D** p_index, int continuous_update); + +int faiss_IndexFlat1D_update_permutation(FaissIndexFlat1D* index); + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/core/src/index/thirdparty/faiss/c_api/IndexIVFFlat_c.cpp b/core/src/index/thirdparty/faiss/c_api/IndexIVFFlat_c.cpp new file mode 100644 index 0000000000..410e39a6c5 --- /dev/null +++ b/core/src/index/thirdparty/faiss/c_api/IndexIVFFlat_c.cpp @@ -0,0 +1,64 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// Copyright 2004-present Facebook. All Rights Reserved. +// -*- c++ -*- + +#include "Index_c.h" +#include "Clustering_c.h" +#include "IndexIVFFlat_c.h" +#include "IndexIVFFlat.h" +#include "macros_impl.h" + +using faiss::Index; +using faiss::IndexIVFFlat; +using faiss::MetricType; + +DEFINE_DESTRUCTOR(IndexIVFFlat) +DEFINE_INDEX_DOWNCAST(IndexIVFFlat) + +int faiss_IndexIVFFlat_new(FaissIndexIVFFlat** p_index) { + try { + *p_index = reinterpret_cast(new IndexIVFFlat()); + } CATCH_AND_HANDLE +} + +int faiss_IndexIVFFlat_new_with(FaissIndexIVFFlat** p_index, + FaissIndex* quantizer, size_t d, size_t nlist) +{ + try { + auto q = reinterpret_cast(quantizer); + *p_index = reinterpret_cast(new IndexIVFFlat(q, d, nlist)); + } CATCH_AND_HANDLE +} + +int faiss_IndexIVFFlat_new_with_metric( + FaissIndexIVFFlat** p_index, FaissIndex* quantizer, size_t d, size_t nlist, + FaissMetricType metric) +{ + try { + auto q = reinterpret_cast(quantizer); + auto m = static_cast(metric); + *p_index = reinterpret_cast(new IndexIVFFlat(q, d, nlist, m)); + } CATCH_AND_HANDLE +} + +int faiss_IndexIVFFlat_add_core(FaissIndexIVFFlat* index, idx_t n, + const float * x, const idx_t *xids, const int64_t *precomputed_idx) +{ + try { + reinterpret_cast(index)->add_core(n, x, xids, precomputed_idx); + } CATCH_AND_HANDLE +} + +int faiss_IndexIVFFlat_update_vectors(FaissIndexIVFFlat* index, int nv, + idx_t *idx, const float *v) +{ + try { + reinterpret_cast(index)->update_vectors(nv, idx, v); + } CATCH_AND_HANDLE +} diff --git a/core/src/index/thirdparty/faiss/c_api/IndexIVFFlat_c.h b/core/src/index/thirdparty/faiss/c_api/IndexIVFFlat_c.h new file mode 100644 index 0000000000..4c5f3ec25b --- /dev/null +++ b/core/src/index/thirdparty/faiss/c_api/IndexIVFFlat_c.h @@ -0,0 +1,58 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// Copyright 2004-present Facebook. All Rights Reserved. +// -*- c -*- + +#ifndef FAISS_INDEX_IVF_FLAT_C_H +#define FAISS_INDEX_IVF_FLAT_C_H + +#include "faiss_c.h" +#include "Index_c.h" +#include "Clustering_c.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/** Inverted file with stored vectors. Here the inverted file + * pre-selects the vectors to be searched, but they are not otherwise + * encoded, the code array just contains the raw float entries. + */ +FAISS_DECLARE_CLASS(IndexIVFFlat) +FAISS_DECLARE_DESTRUCTOR(IndexIVFFlat) +FAISS_DECLARE_INDEX_DOWNCAST(IndexIVFFlat) + +int faiss_IndexIVFFlat_new(FaissIndexIVFFlat** p_index); + +int faiss_IndexIVFFlat_new_with(FaissIndexIVFFlat** p_index, + FaissIndex* quantizer, size_t d, size_t nlist); + +int faiss_IndexIVFFlat_new_with_metric( + FaissIndexIVFFlat** p_index, FaissIndex* quantizer, size_t d, size_t nlist, + FaissMetricType metric); + +int faiss_IndexIVFFlat_add_core(FaissIndexIVFFlat* index, idx_t n, + const float * x, const idx_t *xids, const int64_t *precomputed_idx); + +/** Update a subset of vectors. + * + * The index must have a direct_map + * + * @param nv nb of vectors to update + * @param idx vector indices to update, size nv + * @param v vectors of new values, size nv*d + */ +int faiss_IndexIVFFlat_update_vectors(FaissIndexIVFFlat* index, int nv, + idx_t *idx, const float *v); + +#ifdef __cplusplus +} +#endif + + +#endif diff --git a/core/src/index/thirdparty/faiss/c_api/IndexIVF_c.cpp b/core/src/index/thirdparty/faiss/c_api/IndexIVF_c.cpp new file mode 100644 index 0000000000..4f7983723b --- /dev/null +++ b/core/src/index/thirdparty/faiss/c_api/IndexIVF_c.cpp @@ -0,0 +1,99 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// Copyright 2004-present Facebook. All Rights Reserved. +// -*- c++ -*- + +#include "Index_c.h" +#include "Clustering_c.h" +#include "IndexIVF_c.h" +#include "IndexIVF.h" +#include "macros_impl.h" + +using faiss::IndexIVF; +using faiss::IndexIVFStats; + +DEFINE_DESTRUCTOR(IndexIVF) +DEFINE_INDEX_DOWNCAST(IndexIVF) + +/// number of possible key values +DEFINE_GETTER(IndexIVF, size_t, nlist) +/// number of probes at query time +DEFINE_GETTER(IndexIVF, size_t, nprobe) +/// quantizer that maps vectors to inverted lists +DEFINE_GETTER_PERMISSIVE(IndexIVF, FaissIndex*, quantizer) + +/** + * = 0: use the quantizer as index in a kmeans training + * = 1: just pass on the training set to the train() of the quantizer + * = 2: kmeans training on a flat index + add the centroids to the quantizer + */ +DEFINE_GETTER(IndexIVF, char, quantizer_trains_alone) + +/// whether object owns the quantizer +DEFINE_GETTER(IndexIVF, int, own_fields) + +using faiss::IndexIVF; + +int faiss_IndexIVF_merge_from( + FaissIndexIVF* index, FaissIndexIVF* other, idx_t add_id) { + try { + reinterpret_cast(index)->merge_from( + *reinterpret_cast(other), add_id); + } CATCH_AND_HANDLE +} + +int faiss_IndexIVF_copy_subset_to( + const FaissIndexIVF* index, FaissIndexIVF* other, int subset_type, idx_t a1, + idx_t a2) { + try { + reinterpret_cast(index)->copy_subset_to( + *reinterpret_cast(other), subset_type, a1, a2); + } CATCH_AND_HANDLE +} + +int faiss_IndexIVF_search_preassigned (const FaissIndexIVF* index, + idx_t n, const float *x, idx_t k, const idx_t *assign, + const float *centroid_dis, float *distances, idx_t *labels, + int store_pairs) { + try { + reinterpret_cast(index)->search_preassigned( + n, x, k, assign, centroid_dis, distances, labels, store_pairs); + } CATCH_AND_HANDLE +} + +size_t faiss_IndexIVF_get_list_size(const FaissIndexIVF* index, size_t list_no) { + return reinterpret_cast(index)->get_list_size(list_no); +} + +int faiss_IndexIVF_make_direct_map(FaissIndexIVF* index, + int new_maintain_direct_map) { + try { + reinterpret_cast(index)->make_direct_map( + static_cast(new_maintain_direct_map)); + } CATCH_AND_HANDLE +} + +double faiss_IndexIVF_imbalance_factor (const FaissIndexIVF* index) { + return reinterpret_cast(index)->invlists->imbalance_factor(); +} + +/// display some stats about the inverted lists +void faiss_IndexIVF_print_stats (const FaissIndexIVF* index) { + reinterpret_cast(index)->invlists->print_stats(); +} + +/// get inverted lists ids +void faiss_IndexIVF_invlists_get_ids (const FaissIndexIVF* index, size_t list_no, idx_t* invlist) { + const idx_t* list = reinterpret_cast(index)->invlists->get_ids(list_no); + size_t list_size = reinterpret_cast(index)->get_list_size(list_no); + memcpy(invlist, list, list_size*sizeof(idx_t)); +} + +void faiss_IndexIVFStats_reset(FaissIndexIVFStats* stats) { + reinterpret_cast(stats)->reset(); +} diff --git a/core/src/index/thirdparty/faiss/c_api/IndexIVF_c.h b/core/src/index/thirdparty/faiss/c_api/IndexIVF_c.h new file mode 100644 index 0000000000..5aa907c8c2 --- /dev/null +++ b/core/src/index/thirdparty/faiss/c_api/IndexIVF_c.h @@ -0,0 +1,142 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// Copyright 2004-present Facebook. All Rights Reserved. +// -*- c -*- + +#ifndef FAISS_INDEX_IVF_C_H +#define FAISS_INDEX_IVF_C_H + +#include "faiss_c.h" +#include "Index_c.h" +#include "Clustering_c.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/** Index based on a inverted file (IVF) + * + * In the inverted file, the quantizer (an Index instance) provides a + * quantization index for each vector to be added. The quantization + * index maps to a list (aka inverted list or posting list), where the + * id of the vector is then stored. + * + * At search time, the vector to be searched is also quantized, and + * only the list corresponding to the quantization index is + * searched. This speeds up the search by making it + * non-exhaustive. This can be relaxed using multi-probe search: a few + * (nprobe) quantization indices are selected and several inverted + * lists are visited. + * + * Sub-classes implement a post-filtering of the index that refines + * the distance estimation from the query to databse vectors. + */ +FAISS_DECLARE_CLASS_INHERITED(IndexIVF, Index) +FAISS_DECLARE_DESTRUCTOR(IndexIVF) +FAISS_DECLARE_INDEX_DOWNCAST(IndexIVF) + +/// number of possible key values +FAISS_DECLARE_GETTER(IndexIVF, size_t, nlist) +/// number of probes at query time +FAISS_DECLARE_GETTER(IndexIVF, size_t, nprobe) +/// quantizer that maps vectors to inverted lists +FAISS_DECLARE_GETTER(IndexIVF, FaissIndex*, quantizer) +/** + * = 0: use the quantizer as index in a kmeans training + * = 1: just pass on the training set to the train() of the quantizer + * = 2: kmeans training on a flat index + add the centroids to the quantizer + */ +FAISS_DECLARE_GETTER(IndexIVF, char, quantizer_trains_alone) + +/// whether object owns the quantizer +FAISS_DECLARE_GETTER(IndexIVF, int, own_fields) + +/** moves the entries from another dataset to self. On output, + * other is empty. add_id is added to all moved ids (for + * sequential ids, this would be this->ntotal */ +int faiss_IndexIVF_merge_from( + FaissIndexIVF* index, FaissIndexIVF* other, idx_t add_id); + +/** copy a subset of the entries index to the other index + * + * if subset_type == 0: copies ids in [a1, a2) + * if subset_type == 1: copies ids if id % a1 == a2 + * if subset_type == 2: copies inverted lists such that a1 + * elements are left before and a2 elements are after + */ +int faiss_IndexIVF_copy_subset_to( + const FaissIndexIVF* index, FaissIndexIVF* other, int subset_type, idx_t a1, + idx_t a2); + +/** search a set of vectors, that are pre-quantized by the IVF + * quantizer. Fill in the corresponding heaps with the query + * results. search() calls this. + * + * @param n nb of vectors to query + * @param x query vectors, size nx * d + * @param assign coarse quantization indices, size nx * nprobe + * @param centroid_dis + * distances to coarse centroids, size nx * nprobe + * @param distance + * output distances, size n * k + * @param labels output labels, size n * k + * @param store_pairs store inv list index + inv list offset + * instead in upper/lower 32 bit of result, + * instead of ids (used for reranking). + */ +int faiss_IndexIVF_search_preassigned (const FaissIndexIVF* index, + idx_t n, const float *x, idx_t k, const idx_t *assign, + const float *centroid_dis, float *distances, idx_t *labels, + int store_pairs); + +size_t faiss_IndexIVF_get_list_size(const FaissIndexIVF* index, + size_t list_no); + +/** intialize a direct map + * + * @param new_maintain_direct_map if true, create a direct map, + * else clear it + */ +int faiss_IndexIVF_make_direct_map(FaissIndexIVF* index, + int new_maintain_direct_map); + +/** Check the inverted lists' imbalance factor. + * + * 1= perfectly balanced, >1: imbalanced + */ +double faiss_IndexIVF_imbalance_factor (const FaissIndexIVF* index); + +/// display some stats about the inverted lists of the index +void faiss_IndexIVF_print_stats (const FaissIndexIVF* index); + +/// Get the IDs in an inverted list. IDs are written to `invlist`, which must be large enough +//// to accommodate the full list. +/// +/// @param list_no the list ID +/// @param invlist output pointer to a slice of memory, at least as long as the list's size +/// @see faiss_IndexIVF_get_list_size(size_t) +void faiss_IndexIVF_invlists_get_ids (const FaissIndexIVF* index, size_t list_no, idx_t* invlist); + +typedef struct FaissIndexIVFStats { + size_t nq; // nb of queries run + size_t nlist; // nb of inverted lists scanned + size_t ndis; // nb of distancs computed +} FaissIndexIVFStats; + +void faiss_IndexIVFStats_reset(FaissIndexIVFStats* stats); + +inline void faiss_IndexIVFStats_init(FaissIndexIVFStats* stats) { + faiss_IndexIVFStats_reset(stats); +} + +#ifdef __cplusplus +} +#endif + + +#endif diff --git a/core/src/index/thirdparty/faiss/c_api/IndexLSH_c.cpp b/core/src/index/thirdparty/faiss/c_api/IndexLSH_c.cpp new file mode 100644 index 0000000000..39a348f807 --- /dev/null +++ b/core/src/index/thirdparty/faiss/c_api/IndexLSH_c.cpp @@ -0,0 +1,37 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// Copyright 2004-present Facebook. All Rights Reserved. +// -*- c++ -*- + +#include "IndexLSH_c.h" +#include "IndexLSH.h" +#include "macros_impl.h" + +using faiss::Index; +using faiss::IndexLSH; + +DEFINE_DESTRUCTOR(IndexLSH) +DEFINE_INDEX_DOWNCAST(IndexLSH) + +DEFINE_GETTER(IndexLSH, int, nbits) +DEFINE_GETTER(IndexLSH, int, bytes_per_vec) +DEFINE_GETTER_PERMISSIVE(IndexLSH, int, rotate_data) +DEFINE_GETTER_PERMISSIVE(IndexLSH, int, train_thresholds) + +int faiss_IndexLSH_new(FaissIndexLSH** p_index, idx_t d, int nbits) { + try { + *p_index = reinterpret_cast(new IndexLSH(d, nbits)); + } CATCH_AND_HANDLE +} + +int faiss_IndexLSH_new_with_options(FaissIndexLSH** p_index, idx_t d, int nbits, int rotate_data, int train_thresholds) { + try { + *p_index = reinterpret_cast( + new IndexLSH(d, nbits, static_cast(rotate_data), static_cast(train_thresholds))); + } CATCH_AND_HANDLE +} diff --git a/core/src/index/thirdparty/faiss/c_api/IndexLSH_c.h b/core/src/index/thirdparty/faiss/c_api/IndexLSH_c.h new file mode 100644 index 0000000000..4a3dab418d --- /dev/null +++ b/core/src/index/thirdparty/faiss/c_api/IndexLSH_c.h @@ -0,0 +1,40 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// Copyright 2004-present Facebook. All Rights Reserved. +// -*- c++ -*- + +#ifndef INDEX_LSH_C_H +#define INDEX_LSH_C_H + +#include "faiss_c.h" +#include "Index_c.h" +#include "Clustering_c.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/** The sign of each vector component is put in a binary signature */ +FAISS_DECLARE_CLASS_INHERITED(IndexLSH, Index) +FAISS_DECLARE_DESTRUCTOR(IndexLSH) +FAISS_DECLARE_INDEX_DOWNCAST(IndexLSH) + +FAISS_DECLARE_GETTER(IndexLSH, int, nbits) +FAISS_DECLARE_GETTER(IndexLSH, int, bytes_per_vec) +FAISS_DECLARE_GETTER(IndexLSH, int, rotate_data) +FAISS_DECLARE_GETTER(IndexLSH, int, train_thresholds) + +int faiss_IndexLSH_new(FaissIndexLSH** p_index, idx_t d, int nbits); + +int faiss_IndexLSH_new_with_options(FaissIndexLSH** p_index, idx_t d, int nbits, int rotate_data, int train_thresholds); + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/core/src/index/thirdparty/faiss/c_api/IndexPreTransform_c.cpp b/core/src/index/thirdparty/faiss/c_api/IndexPreTransform_c.cpp new file mode 100644 index 0000000000..7d99602edd --- /dev/null +++ b/core/src/index/thirdparty/faiss/c_api/IndexPreTransform_c.cpp @@ -0,0 +1,21 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// Copyright 2004-present Facebook. All Rights Reserved. +// -*- c++ -*- + +#include "IndexPreTransform_c.h" +#include "IndexPreTransform.h" +#include "macros_impl.h" + +using faiss::Index; +using faiss::IndexPreTransform; + +DEFINE_DESTRUCTOR(IndexPreTransform) +DEFINE_INDEX_DOWNCAST(IndexPreTransform) + +DEFINE_GETTER_PERMISSIVE(IndexPreTransform, FaissIndex*, index) diff --git a/core/src/index/thirdparty/faiss/c_api/IndexPreTransform_c.h b/core/src/index/thirdparty/faiss/c_api/IndexPreTransform_c.h new file mode 100644 index 0000000000..c6d34b23c7 --- /dev/null +++ b/core/src/index/thirdparty/faiss/c_api/IndexPreTransform_c.h @@ -0,0 +1,32 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// Copyright 2004-present Facebook. All Rights Reserved. +// -*- c -*- + +#ifndef FAISS_INDEX_PRETRANSFORM_C_H +#define FAISS_INDEX_PRETRANSFORM_C_H + +#include "faiss_c.h" +#include "Index_c.h" + +#ifdef __cplusplus +extern "C" { +#endif + +FAISS_DECLARE_CLASS(IndexPreTransform) +FAISS_DECLARE_DESTRUCTOR(IndexPreTransform) +FAISS_DECLARE_INDEX_DOWNCAST(IndexPreTransform) + +FAISS_DECLARE_GETTER(IndexPreTransform, FaissIndex*, index) + +#ifdef __cplusplus +} +#endif + + +#endif diff --git a/core/src/index/thirdparty/faiss/c_api/IndexShards_c.cpp b/core/src/index/thirdparty/faiss/c_api/IndexShards_c.cpp new file mode 100644 index 0000000000..e66aeb7ed0 --- /dev/null +++ b/core/src/index/thirdparty/faiss/c_api/IndexShards_c.cpp @@ -0,0 +1,44 @@ +#include "IndexShards_c.h" +#include "IndexShards.h" +#include "macros_impl.h" + +using faiss::Index; +using faiss::IndexShards; + +DEFINE_GETTER(IndexShards, int, own_fields) +DEFINE_SETTER(IndexShards, int, own_fields) + +DEFINE_GETTER(IndexShards, int, successive_ids) +DEFINE_SETTER(IndexShards, int, successive_ids) + +int faiss_IndexShards_new(FaissIndexShards** p_index, idx_t d) { + try { + auto out = new IndexShards(d); + *p_index = reinterpret_cast(out); + } CATCH_AND_HANDLE +} + +int faiss_IndexShards_new_with_options(FaissIndexShards** p_index, idx_t d, int threaded, int successive_ids) { + try { + auto out = new IndexShards(d, static_cast(threaded), static_cast(successive_ids)); + *p_index = reinterpret_cast(out); + } CATCH_AND_HANDLE +} + +int faiss_IndexShards_add_shard(FaissIndexShards* index, FaissIndex* shard) { + try { + reinterpret_cast(index)->add_shard( + reinterpret_cast(shard)); + } CATCH_AND_HANDLE +} + +int faiss_IndexShards_sync_with_shard_indexes(FaissIndexShards* index) { + try { + reinterpret_cast(index)->sync_with_shard_indexes(); + } CATCH_AND_HANDLE +} + +FaissIndex* faiss_IndexShards_at(FaissIndexShards* index, int i) { + auto shard = reinterpret_cast(index)->at(i); + return reinterpret_cast(shard); +} diff --git a/core/src/index/thirdparty/faiss/c_api/IndexShards_c.h b/core/src/index/thirdparty/faiss/c_api/IndexShards_c.h new file mode 100644 index 0000000000..7e6a30b2a9 --- /dev/null +++ b/core/src/index/thirdparty/faiss/c_api/IndexShards_c.h @@ -0,0 +1,42 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// Copyright 2004-present Facebook. All Rights Reserved. +// -*- c++ -*- + +#ifndef INDEXSHARDS_C_H +#define INDEXSHARDS_C_H + +#include "faiss_c.h" +#include "Index_c.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/** Index that concatenates the results from several sub-indexes + */ +FAISS_DECLARE_CLASS_INHERITED(IndexShards, Index) + +FAISS_DECLARE_GETTER_SETTER(IndexShards, int, own_fields) +FAISS_DECLARE_GETTER_SETTER(IndexShards, int, successive_ids) + +int faiss_IndexShards_new(FaissIndexShards** p_index, idx_t d); + +int faiss_IndexShards_new_with_options(FaissIndexShards** p_index, idx_t d, int threaded, int successive_ids); + +int faiss_IndexShards_add_shard(FaissIndexShards* index, FaissIndex* shard); + +/// update metric_type and ntotal +int faiss_IndexShards_sync_with_shard_indexes(FaissIndexShards* index); + +FaissIndex* faiss_IndexShards_at(FaissIndexShards* index, int i); + +#ifdef __cplusplus +} +#endif +#endif \ No newline at end of file diff --git a/core/src/index/thirdparty/faiss/c_api/Index_c.cpp b/core/src/index/thirdparty/faiss/c_api/Index_c.cpp new file mode 100644 index 0000000000..38263f4333 --- /dev/null +++ b/core/src/index/thirdparty/faiss/c_api/Index_c.cpp @@ -0,0 +1,105 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// Copyright 2004-present Facebook. All Rights Reserved. +// -*- c++ -*- + +#include "Index_c.h" +#include "Index.h" +#include "macros_impl.h" + +extern "C" { + +DEFINE_DESTRUCTOR(Index) + +DEFINE_GETTER(Index, int, d) + +DEFINE_GETTER(Index, int, is_trained) + +DEFINE_GETTER(Index, idx_t, ntotal) + +DEFINE_GETTER(Index, FaissMetricType, metric_type) + +int faiss_Index_train(FaissIndex* index, idx_t n, const float* x) { + try { + reinterpret_cast(index)->train(n, x); + } CATCH_AND_HANDLE +} + +int faiss_Index_add(FaissIndex* index, idx_t n, const float* x) { + try { + reinterpret_cast(index)->add(n, x); + } CATCH_AND_HANDLE +} + +int faiss_Index_add_with_ids(FaissIndex* index, idx_t n, const float* x, const idx_t* xids) { + try { + reinterpret_cast(index)->add_with_ids(n, x, xids); + } CATCH_AND_HANDLE +} + +int faiss_Index_search(const FaissIndex* index, idx_t n, const float* x, idx_t k, + float* distances, idx_t* labels) { + try { + reinterpret_cast(index)->search(n, x, k, distances, labels); + } CATCH_AND_HANDLE +} + +int faiss_Index_range_search(const FaissIndex* index, idx_t n, const float* x, float radius, + FaissRangeSearchResult* result) { + try { + reinterpret_cast(index)->range_search( + n, x, radius, reinterpret_cast(result)); + } CATCH_AND_HANDLE +} + +int faiss_Index_assign(FaissIndex* index, idx_t n, const float * x, idx_t * labels) { + try { + reinterpret_cast(index)->assign(n, x, labels); + } CATCH_AND_HANDLE +} + +int faiss_Index_reset(FaissIndex* index) { + try { + reinterpret_cast(index)->reset(); + } CATCH_AND_HANDLE +} + +int faiss_Index_remove_ids(FaissIndex* index, const FaissIDSelector* sel, size_t* n_removed) { + try { + size_t n {reinterpret_cast(index)->remove_ids( + *reinterpret_cast(sel))}; + if (n_removed) { + *n_removed = n; + } + } CATCH_AND_HANDLE +} + +int faiss_Index_reconstruct(const FaissIndex* index, idx_t key, float* recons) { + try { + reinterpret_cast(index)->reconstruct(key, recons); + } CATCH_AND_HANDLE +} + +int faiss_Index_reconstruct_n (const FaissIndex* index, idx_t i0, idx_t ni, float* recons) { + try { + reinterpret_cast(index)->reconstruct_n(i0, ni, recons); + } CATCH_AND_HANDLE +} + +int faiss_Index_compute_residual(const FaissIndex* index, const float* x, float* residual, idx_t key) { + try { + reinterpret_cast(index)->compute_residual(x, residual, key); + } CATCH_AND_HANDLE +} + +int faiss_Index_compute_residual_n(const FaissIndex* index, idx_t n, const float* x, float* residuals, const idx_t* keys) { + try { + reinterpret_cast(index)->compute_residual_n(n, x, residuals, keys); + } CATCH_AND_HANDLE +} +} \ No newline at end of file diff --git a/core/src/index/thirdparty/faiss/c_api/Index_c.h b/core/src/index/thirdparty/faiss/c_api/Index_c.h new file mode 100644 index 0000000000..4b6a30c7cd --- /dev/null +++ b/core/src/index/thirdparty/faiss/c_api/Index_c.h @@ -0,0 +1,183 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// Copyright 2004-present Facebook. All Rights Reserved +// -*- c -*- + +#ifndef FAISS_INDEX_C_H +#define FAISS_INDEX_C_H + +#include +#include "faiss_c.h" + +#ifdef __cplusplus +extern "C" { +#endif + +// forward declaration required here +FAISS_DECLARE_CLASS(RangeSearchResult) + +//typedef struct FaissRangeSearchResult_H FaissRangeSearchResult; +typedef struct FaissIDSelector_H FaissIDSelector; + +/// Some algorithms support both an inner product version and a L2 search version. +typedef enum FaissMetricType { + METRIC_INNER_PRODUCT = 0, ///< maximum inner product search + METRIC_L2 = 1, ///< squared L2 search + METRIC_L1, ///< L1 (aka cityblock) + METRIC_Linf, ///< infinity distance + METRIC_Lp, ///< L_p distance, p is given by metric_arg + + /// some additional metrics defined in scipy.spatial.distance + METRIC_Canberra = 20, + METRIC_BrayCurtis, + METRIC_JensenShannon, +} FaissMetricType; + +/// Opaque type for referencing to an index object +FAISS_DECLARE_CLASS(Index) +FAISS_DECLARE_DESTRUCTOR(Index) + +/// Getter for d +FAISS_DECLARE_GETTER(Index, int, d) + +/// Getter for is_trained +FAISS_DECLARE_GETTER(Index, int, is_trained) + +/// Getter for ntotal +FAISS_DECLARE_GETTER(Index, idx_t, ntotal) + +/// Getter for metric_type +FAISS_DECLARE_GETTER(Index, FaissMetricType, metric_type) + +/** Perform training on a representative set of vectors + * + * @param index opaque pointer to index object + * @param n nb of training vectors + * @param x training vecors, size n * d + */ +int faiss_Index_train(FaissIndex* index, idx_t n, const float* x); + +/** Add n vectors of dimension d to the index. + * + * Vectors are implicitly assigned labels ntotal .. ntotal + n - 1 + * This function slices the input vectors in chuncks smaller than + * blocksize_add and calls add_core. + * @param index opaque pointer to index object + * @param x input matrix, size n * d + */ +int faiss_Index_add(FaissIndex* index, idx_t n, const float* x); + +/** Same as add, but stores xids instead of sequential ids. + * + * The default implementation fails with an assertion, as it is + * not supported by all indexes. + * + * @param index opaque pointer to index object + * @param xids if non-null, ids to store for the vectors (size n) + */ +int faiss_Index_add_with_ids(FaissIndex* index, idx_t n, const float* x, const idx_t* xids); + +/** query n vectors of dimension d to the index. + * + * return at most k vectors. If there are not enough results for a + * query, the result array is padded with -1s. + * + * @param index opaque pointer to index object + * @param x input vectors to search, size n * d + * @param labels output labels of the NNs, size n*k + * @param distances output pairwise distances, size n*k + */ +int faiss_Index_search(const FaissIndex* index, idx_t n, const float* x, idx_t k, + float* distances, idx_t* labels); + +/** query n vectors of dimension d to the index. + * + * return all vectors with distance < radius. Note that many + * indexes do not implement the range_search (only the k-NN search + * is mandatory). + * + * @param index opaque pointer to index object + * @param x input vectors to search, size n * d + * @param radius search radius + * @param result result table + */ +int faiss_Index_range_search(const FaissIndex* index, idx_t n, const float* x, + float radius, FaissRangeSearchResult* result); + +/** return the indexes of the k vectors closest to the query x. + * + * This function is identical as search but only return labels of neighbors. + * @param index opaque pointer to index object + * @param x input vectors to search, size n * d + * @param labels output labels of the NNs, size n + */ +int faiss_Index_assign(FaissIndex* index, idx_t n, const float * x, idx_t * labels); + +/** removes all elements from the database. + * @param index opaque pointer to index object + */ +int faiss_Index_reset(FaissIndex* index); + +/** removes IDs from the index. Not supported by all indexes + * @param index opaque pointer to index object + * @param nremove output for the number of IDs removed + */ +int faiss_Index_remove_ids(FaissIndex* index, const FaissIDSelector* sel, size_t* n_removed); + +/** Reconstruct a stored vector (or an approximation if lossy coding) + * + * this function may not be defined for some indexes + * @param index opaque pointer to index object + * @param key id of the vector to reconstruct + * @param recons reconstucted vector (size d) + */ +int faiss_Index_reconstruct(const FaissIndex* index, idx_t key, float* recons); + +/** Reconstruct vectors i0 to i0 + ni - 1 + * + * this function may not be defined for some indexes + * @param index opaque pointer to index object + * @param recons reconstucted vector (size ni * d) + */ +int faiss_Index_reconstruct_n (const FaissIndex* index, idx_t i0, idx_t ni, float* recons); + +/** Computes a residual vector after indexing encoding. + * + * The residual vector is the difference between a vector and the + * reconstruction that can be decoded from its representation in + * the index. The residual can be used for multiple-stage indexing + * methods, like IndexIVF's methods. + * + * @param index opaque pointer to index object + * @param x input vector, size d + * @param residual output residual vector, size d + * @param key encoded index, as returned by search and assign + */ +int faiss_Index_compute_residual(const FaissIndex* index, const float* x, float* residual, idx_t key); + +/** Computes a residual vector after indexing encoding. + * + * The residual vector is the difference between a vector and the + * reconstruction that can be decoded from its representation in + * the index. The residual can be used for multiple-stage indexing + * methods, like IndexIVF's methods. + * + * @param index opaque pointer to index object + * @param n number of vectors + * @param x input vector, size (n x d) + * @param residuals output residual vectors, size (n x d) + * @param keys encoded index, as returned by search and assign + */ +int faiss_Index_compute_residual_n(const FaissIndex* index, idx_t n, const float* x, float* residuals, const idx_t* keys); + + +#ifdef __cplusplus +} +#endif + +#endif \ No newline at end of file diff --git a/core/src/index/thirdparty/faiss/c_api/Makefile b/core/src/index/thirdparty/faiss/c_api/Makefile new file mode 100644 index 0000000000..c47c465f00 --- /dev/null +++ b/core/src/index/thirdparty/faiss/c_api/Makefile @@ -0,0 +1,89 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +.SUFFIXES: .cpp .o + +# C API + +include ../makefile.inc +DEBUGFLAG=-DNDEBUG # no debugging + +LIBNAME=libfaiss +CLIBNAME=libfaiss_c +LIBCOBJ=error_impl.o Index_c.o IndexFlat_c.o Clustering_c.o AutoTune_c.o \ + impl/AuxIndexStructures_c.o IndexIVF_c.o IndexIVFFlat_c.o IndexLSH_c.o \ + index_io_c.o MetaIndexes_c.o IndexShards_c.o index_factory_c.o \ + clone_index_c.o IndexPreTransform_c.o +CFLAGS=-fPIC -m64 -Wno-sign-compare -g -O3 -Wall -Wextra + +# Build static and shared object files by default +all: $(CLIBNAME).a $(CLIBNAME).$(SHAREDEXT) + +# Build static object file containing the wrapper implementation only. +# Consumers are required to link with libfaiss.a and libstdc++. +$(CLIBNAME).a: $(LIBCOBJ) + ar r $@ $^ + +# Build dynamic library (independent object) +$(CLIBNAME).$(SHAREDEXT): $(LIBCOBJ) ../$(LIBNAME).a + $(CXX) $(LDFLAGS) $(SHAREDFLAGS) -o $@ \ + -Wl,--whole-archive $^ -Wl,--no-whole-archive $(LIBS) -static-libstdc++ + +bin/example_c: example_c.c $(CLIBNAME).$(SHAREDEXT) + $(CC) $(CFLAGS) -std=c99 -I. -I.. -L. -o $@ example_c.c \ + $(LDFLAGS) -lm -lfaiss_c + +clean: + rm -f $(CLIBNAME).a $(CLIBNAME).$(SHAREDEXT)* *.o bin/example_c + +%.o: %.cpp + $(CXX) $(CPPFLAGS) $(CXXFLAGS) $(CPUFLAGS) -c $< -o $@ + +# Dependencies + +error_impl.o: CXXFLAGS += -I.. -I ../impl $(DEBUGFLAG) +error_impl.o: error_impl.cpp error_c.h error_impl.h macros_impl.h + +index_io_c.o: CXXFLAGS += -I.. -I ../impl $(DEBUGFLAG) +index_io_c.o: index_io_c.cpp error_impl.cpp ../index_io.h macros_impl.h + +index_factory_c.o: CXXFLAGS += -I.. -I ../impl $(DEBUGFLAG) +index_factory_c.o: index_factory_c.cpp error_impl.cpp ../index_io.h macros_impl.h + +clone_index_c.o: CXXFLAGS += -I.. -I ../impl $(DEBUGFLAG) +clone_index_c.o: index_factory_c.cpp error_impl.cpp ../index_io.h macros_impl.h + +Index_c.o: CXXFLAGS += -I.. -I ../impl $(DEBUGFLAG) +Index_c.o: Index_c.cpp Index_c.h ../Index.h macros_impl.h + +IndexFlat_c.o: CXXFLAGS += -I.. -I ../impl $(DEBUGFLAG) +IndexFlat_c.o: IndexFlat_c.cpp IndexFlat_c.h ../IndexFlat.h macros_impl.h + +IndexIVF_c.o: CXXFLAGS += -I.. -I ../impl $(DEBUGFLAG) +IndexIVF_c.o: IndexIVF_c.cpp IndexIVF_c.h ../IndexIVF.h macros_impl.h + +IndexIVFFlat_c.o: CXXFLAGS += -I.. -I ../impl $(DEBUGFLAG) +IndexIVFFlat_c.o: IndexIVFFlat_c.cpp IndexIVFFlat_c.h ../IndexIVFFlat.h macros_impl.h + +IndexLSH_c.o: CXXFLAGS += -I.. -I ../impl $(DEBUGFLAG) +IndexLSH_c.o: IndexLSH_c.cpp IndexLSH_c.h ../IndexLSH.h macros_impl.h + +IndexShards_c.o: CXXFLAGS += -I.. -I ../impl $(DEBUGFLAG) +IndexShards_c.o: IndexShards_c.cpp IndexShards_c.h ../Index.h ../IndexShards.h macros_impl.h + +Clustering_c.o: CXXFLAGS += -I.. -I ../impl $(DEBUGFLAG) +Clustering_c.o: Clustering_c.cpp Clustering_c.h ../Clustering.h macros_impl.h + +AutoTune_c.o: CXXFLAGS += -I.. -I ../impl $(DEBUGFLAG) +AutoTune_c.o: AutoTune_c.cpp AutoTune_c.h ../AutoTune.h macros_impl.h + +impl/AuxIndexStructures_c.o: CXXFLAGS += -I.. -I ../impl $(DEBUGFLAG) +impl/AuxIndexStructures_c.o: impl/AuxIndexStructures_c.cpp impl/AuxIndexStructures_c.h ../impl/AuxIndexStructures.h macros_impl.h + +MetaIndexes_c.o: CXXFLAGS += -I.. -I ../impl $(DEBUGFLAG) +MetaIndexes_c.o: MetaIndexes_c.cpp MetaIndexes_c.h ../MetaIndexes.h macros_impl.h + +IndexPreTransform_c.o: CXXFLAGS += -I.. -I ../impl $(DEBUGFLAG) +IndexPreTransform_c.o: IndexPreTransform_c.cpp IndexPreTransform_c.h ../IndexPreTransform.h macros_impl.h diff --git a/core/src/index/thirdparty/faiss/c_api/MetaIndexes_c.cpp b/core/src/index/thirdparty/faiss/c_api/MetaIndexes_c.cpp new file mode 100644 index 0000000000..72abd9e793 --- /dev/null +++ b/core/src/index/thirdparty/faiss/c_api/MetaIndexes_c.cpp @@ -0,0 +1,49 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// Copyright 2004-present Facebook. All Rights Reserved. +// -*- c++ -*- + +#include "MetaIndexes_c.h" +#include "MetaIndexes.h" +#include "macros_impl.h" + +using faiss::Index; +using faiss::IndexIDMap; +using faiss::IndexIDMap2; + +DEFINE_GETTER(IndexIDMap, int, own_fields) +DEFINE_SETTER(IndexIDMap, int, own_fields) + +int faiss_IndexIDMap_new(FaissIndexIDMap** p_index, FaissIndex* index) { + try { + auto out = new IndexIDMap(reinterpret_cast(index)); + *p_index = reinterpret_cast(out); + } CATCH_AND_HANDLE +} + +void faiss_IndexIDMap_id_map(FaissIndexIDMap* index, idx_t** p_id_map, size_t* p_size) { + auto idx = reinterpret_cast(index); + if (p_id_map) + *p_id_map = idx->id_map.data(); + if (p_size) + *p_size = idx->id_map.size(); +} + +int faiss_IndexIDMap2_new(FaissIndexIDMap2** p_index, FaissIndex* index) { + try { + auto out = new IndexIDMap2(reinterpret_cast(index)); + *p_index = reinterpret_cast(out); + } CATCH_AND_HANDLE +} + +int faiss_IndexIDMap2_construct_rev_map(FaissIndexIDMap2* index) { + try { + reinterpret_cast(index)->construct_rev_map(); + } CATCH_AND_HANDLE +} + diff --git a/core/src/index/thirdparty/faiss/c_api/MetaIndexes_c.h b/core/src/index/thirdparty/faiss/c_api/MetaIndexes_c.h new file mode 100644 index 0000000000..940394f92f --- /dev/null +++ b/core/src/index/thirdparty/faiss/c_api/MetaIndexes_c.h @@ -0,0 +1,49 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// Copyright 2004-present Facebook. All Rights Reserved. +// -*- c++ -*- + +#ifndef METAINDEXES_C_H +#define METAINDEXES_C_H + +#include "faiss_c.h" +#include "Index_c.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/** Index that translates search results to ids */ +FAISS_DECLARE_CLASS_INHERITED(IndexIDMap, Index) + +FAISS_DECLARE_GETTER_SETTER(IndexIDMap, int, own_fields) + +int faiss_IndexIDMap_new(FaissIndexIDMap** p_index, FaissIndex* index); + +/** get a pointer to the index map's internal ID vector (the `id_map` field). The + * outputs of this function become invalid after any operation that can modify the index. + * + * @param index opaque pointer to index object + * @param p_id_map output, the pointer to the beginning of `id_map`. + * @param p_size output, the current length of `id_map`. + */ +void faiss_IndexIDMap_id_map(FaissIndexIDMap* index, idx_t** p_id_map, size_t* p_size); + +/** same as IndexIDMap but also provides an efficient reconstruction + implementation via a 2-way index */ +FAISS_DECLARE_CLASS_INHERITED(IndexIDMap2, IndexIDMap) + +int faiss_IndexIDMap2_new(FaissIndexIDMap2** p_index, FaissIndex* index); + +/// make the rev_map from scratch +int faiss_IndexIDMap2_construct_rev_map(FaissIndexIDMap2* index); + +#ifdef __cplusplus +} +#endif +#endif diff --git a/core/src/index/thirdparty/faiss/c_api/clone_index_c.cpp b/core/src/index/thirdparty/faiss/c_api/clone_index_c.cpp new file mode 100644 index 0000000000..999b139a7c --- /dev/null +++ b/core/src/index/thirdparty/faiss/c_api/clone_index_c.cpp @@ -0,0 +1,23 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// Copyright 2004-present Facebook. All Rights Reserved +// -*- c++ -*- +// I/O code for indexes + +#include "clone_index_c.h" +#include "clone_index.h" +#include "macros_impl.h" + +using faiss::Index; + +int faiss_clone_index (const FaissIndex *idx, FaissIndex **p_out) { + try { + auto out = faiss::clone_index(reinterpret_cast(idx)); + *p_out = reinterpret_cast(out); + } CATCH_AND_HANDLE +} \ No newline at end of file diff --git a/core/src/index/thirdparty/faiss/c_api/clone_index_c.h b/core/src/index/thirdparty/faiss/c_api/clone_index_c.h new file mode 100644 index 0000000000..3cf7e1a658 --- /dev/null +++ b/core/src/index/thirdparty/faiss/c_api/clone_index_c.h @@ -0,0 +1,32 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// Copyright 2004-present Facebook. All Rights Reserved +// -*- c++ -*- +// I/O code for indexes + + +#ifndef FAISS_CLONE_INDEX_C_H +#define FAISS_CLONE_INDEX_C_H + +#include +#include "faiss_c.h" +#include "Index_c.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/* cloning functions */ + +/** Clone an index. This is equivalent to `faiss::clone_index` */ +int faiss_clone_index (const FaissIndex *, FaissIndex ** p_out); + +#ifdef __cplusplus +} +#endif +#endif \ No newline at end of file diff --git a/core/src/index/thirdparty/faiss/c_api/error_c.h b/core/src/index/thirdparty/faiss/c_api/error_c.h new file mode 100644 index 0000000000..5aa5664feb --- /dev/null +++ b/core/src/index/thirdparty/faiss/c_api/error_c.h @@ -0,0 +1,42 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// Copyright 2004-present Facebook. All Rights Reserved. +// -*- c -*- + +#ifndef FAISS_ERROR_C_H +#define FAISS_ERROR_C_H + +#ifdef __cplusplus +extern "C" { +#endif + +/// An error code which depends on the exception thrown from the previous +/// operation. See `faiss_get_last_error` to retrieve the error message. +typedef enum FaissErrorCode { + /// No error + OK = 0, + /// Any exception other than Faiss or standard C++ library exceptions + UNKNOWN_EXCEPT = -1, + /// Faiss library exception + FAISS_EXCEPT = -2, + /// Standard C++ library exception + STD_EXCEPT = -4 +} FaissErrorCode; + +/** + * Get the error message of the last failed operation performed by Faiss. + * The given pointer is only invalid until another Faiss function is + * called. + */ +const char* faiss_get_last_error(); + +#ifdef __cplusplus +} +#endif + +#endif \ No newline at end of file diff --git a/core/src/index/thirdparty/faiss/c_api/error_impl.cpp b/core/src/index/thirdparty/faiss/c_api/error_impl.cpp new file mode 100644 index 0000000000..25793eb0e8 --- /dev/null +++ b/core/src/index/thirdparty/faiss/c_api/error_impl.cpp @@ -0,0 +1,27 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// Copyright 2004-present Facebook. All Rights Reserved. +// -*- c++ -*- + +#include "error_c.h" +#include "error_impl.h" +#include "FaissException.h" +#include + +thread_local std::exception_ptr faiss_last_exception; + +const char* faiss_get_last_error() { + if (faiss_last_exception) { + try { + std::rethrow_exception(faiss_last_exception); + } catch (std::exception& e) { + return e.what(); + } + } + return nullptr; +} diff --git a/core/src/index/thirdparty/faiss/c_api/error_impl.h b/core/src/index/thirdparty/faiss/c_api/error_impl.h new file mode 100644 index 0000000000..b44254ad94 --- /dev/null +++ b/core/src/index/thirdparty/faiss/c_api/error_impl.h @@ -0,0 +1,16 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// Copyright 2004-present Facebook. All Rights Reserved. +// -*- c++ -*- + +#include + +/** global variable for holding the last exception thrown by + * calls to Faiss functions through the C API + */ +extern thread_local std::exception_ptr faiss_last_exception; diff --git a/core/src/index/thirdparty/faiss/c_api/example_c.c b/core/src/index/thirdparty/faiss/c_api/example_c.c new file mode 100644 index 0000000000..2e9a78a1ad --- /dev/null +++ b/core/src/index/thirdparty/faiss/c_api/example_c.c @@ -0,0 +1,97 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// Copyright 2004-present Facebook. All Rights Reserved. +// -*- c -*- + +#include +#include +#include + +#include "error_c.h" +#include "index_io_c.h" +#include "Index_c.h" +#include "IndexFlat_c.h" +#include "AutoTune_c.h" +#include "clone_index_c.h" + +#define FAISS_TRY(C) \ + { \ + if (C) { \ + fprintf(stderr, "%s", faiss_get_last_error()); \ + exit(-1); \ + } \ + } + +double drand() { + return (double)rand() / (double)RAND_MAX; +} + +int main() { + time_t seed = time(NULL); + srand(seed); + printf("Generating some data...\n"); + int d = 128; // dimension + int nb = 100000; // database size + int nq = 10000; // nb of queries + float *xb = malloc(d * nb * sizeof(float)); + float *xq = malloc(d * nq * sizeof(float)); + + for(int i = 0; i < nb; i++) { + for(int j = 0; j < d; j++) xb[d * i + j] = drand(); + xb[d * i] += i / 1000.; + } + for(int i = 0; i < nq; i++) { + for(int j = 0; j < d; j++) xq[d * i + j] = drand(); + xq[d * i] += i / 1000.; + } + + printf("Building an index...\n"); + + FaissIndex* index = NULL; + FAISS_TRY(faiss_index_factory(&index, d, "Flat", METRIC_L2)); // use factory to create index + printf("is_trained = %s\n", faiss_Index_is_trained(index) ? "true" : "false"); + FAISS_TRY(faiss_Index_add(index, nb, xb)); // add vectors to the index + printf("ntotal = %ld\n", faiss_Index_ntotal(index)); + + printf("Searching...\n"); + int k = 5; + + { // sanity check: search 5 first vectors of xb + idx_t *I = malloc(k * 5 * sizeof(idx_t)); + float *D = malloc(k * 5 * sizeof(float)); + FAISS_TRY(faiss_Index_search(index, 5, xb, k, D, I)); + printf("I=\n"); + for(int i = 0; i < 5; i++) { + for(int j = 0; j < k; j++) printf("%5ld (d=%2.3f) ", I[i * k + j], D[i * k + j]); + printf("\n"); + } + free(I); + free(D); + } + { // search xq + idx_t *I = malloc(k * nq * sizeof(idx_t)); + float *D = malloc(k * nq * sizeof(float)); + FAISS_TRY(faiss_Index_search(index, 5, xb, k, D, I)); + printf("I=\n"); + for(int i = 0; i < 5; i++) { + for(int j = 0; j < k; j++) printf("%5ld (d=%2.3f) ", I[i * k + j], D[i * k + j]); + printf("\n"); + } + free(I); + free(D); + } + + printf("Saving index to disk...\n"); + FAISS_TRY(faiss_write_index_fname(index, "example.index")); + + printf("Freeing index...\n"); + faiss_Index_free(index); + printf("Done.\n"); + + return 0; +} diff --git a/core/src/index/thirdparty/faiss/c_api/faiss_c.h b/core/src/index/thirdparty/faiss/c_api/faiss_c.h new file mode 100644 index 0000000000..2357f71327 --- /dev/null +++ b/core/src/index/thirdparty/faiss/c_api/faiss_c.h @@ -0,0 +1,58 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// Copyright 2004-present Facebook. All Rights Reserved. +// -*- c -*- + +/// Macros and typedefs for C wrapper API declarations + +#ifndef FAISS_C_H +#define FAISS_C_H + +#include + +typedef int64_t faiss_idx_t; ///< all indices are this type +typedef faiss_idx_t idx_t; +typedef float faiss_component_t; ///< all vector components are this type +typedef float faiss_distance_t; ///< all distances between vectors are this type + +/// Declare an opaque type for a class type `clazz`. +#define FAISS_DECLARE_CLASS(clazz) \ + typedef struct Faiss ## clazz ## _H Faiss ## clazz; + +/// Declare an opaque type for a class type `clazz`, while +/// actually aliasing it to an existing parent class type `parent`. +#define FAISS_DECLARE_CLASS_INHERITED(clazz, parent) \ + typedef struct Faiss ## parent ## _H Faiss ## clazz; + +/// Declare a dynamic downcast operation from a base `FaissIndex*` pointer +/// type to a more specific index type. The function returns the same pointer +/// if the downcast is valid, and `NULL` otherwise. +#define FAISS_DECLARE_INDEX_DOWNCAST(clazz) \ + Faiss ## clazz * faiss_ ## clazz ## _cast (FaissIndex*); + +/// Declare a getter for the field `name` in class `clazz`, +/// of return type `ty` +#define FAISS_DECLARE_GETTER(clazz, ty, name) \ + ty faiss_ ## clazz ## _ ## name (const Faiss ## clazz *); + +/// Declare a setter for the field `name` in class `clazz`, +/// in which the user provides a value of type `ty` +#define FAISS_DECLARE_SETTER(clazz, ty, name) \ + void faiss_ ## clazz ## _set_ ## name (Faiss ## clazz *, ty); + +/// Declare a getter and setter for the field `name` in class `clazz`. +#define FAISS_DECLARE_GETTER_SETTER(clazz, ty, name) \ + FAISS_DECLARE_GETTER(clazz, ty, name) \ + FAISS_DECLARE_SETTER(clazz, ty, name) + +/// Declare a destructor function which frees an object of +/// type `clazz`. +#define FAISS_DECLARE_DESTRUCTOR(clazz) \ + void faiss_ ## clazz ## _free (Faiss ## clazz *obj); + +#endif diff --git a/core/src/index/thirdparty/faiss/c_api/gpu/GpuAutoTune_c.cpp b/core/src/index/thirdparty/faiss/c_api/gpu/GpuAutoTune_c.cpp new file mode 100644 index 0000000000..7336d5d7d3 --- /dev/null +++ b/core/src/index/thirdparty/faiss/c_api/gpu/GpuAutoTune_c.cpp @@ -0,0 +1,96 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// Copyright 2004-present Facebook. All Rights Reserved. +// -*- c++ -*- + +#include "GpuAutoTune_c.h" +#include "GpuClonerOptions_c.h" +#include "macros_impl.h" +#include "Index.h" +#include "gpu/GpuAutoTune.h" +#include "gpu/GpuClonerOptions.h" +#include + +using faiss::Index; +using faiss::gpu::GpuResources; +using faiss::gpu::GpuClonerOptions; +using faiss::gpu::GpuMultipleClonerOptions; + +int faiss_index_gpu_to_cpu(const FaissIndex* gpu_index, FaissIndex** p_out) { + try { + auto cpu_index = faiss::gpu::index_gpu_to_cpu( + reinterpret_cast(gpu_index) + ); + *p_out = reinterpret_cast(cpu_index); + } CATCH_AND_HANDLE +} + +/// converts any CPU index that can be converted to GPU +int faiss_index_cpu_to_gpu(FaissGpuResources* resources, int device, const FaissIndex *index, FaissGpuIndex** p_out) { + try { + auto res = reinterpret_cast(resources); + auto gpu_index = faiss::gpu::index_cpu_to_gpu( + res, device, reinterpret_cast(index) + ); + *p_out = reinterpret_cast(gpu_index); + } CATCH_AND_HANDLE +} + +int faiss_index_cpu_to_gpu_with_options( + FaissGpuResources* resources, int device, + const FaissIndex *index, const FaissGpuClonerOptions* options, + FaissGpuIndex** p_out) +{ + try { + auto res = reinterpret_cast(resources); + auto gpu_index = faiss::gpu::index_cpu_to_gpu( + res, device, reinterpret_cast(index), + reinterpret_cast(options)); + *p_out = reinterpret_cast(gpu_index); + } CATCH_AND_HANDLE +} + +int faiss_index_cpu_to_gpu_multiple( + FaissGpuResources* const* resources_vec, + const int* devices, size_t devices_size, + const FaissIndex* index, FaissGpuIndex** p_out) +{ + try { + std::vector res(devices_size); + for (auto i = 0u; i < devices_size; ++i) { + res[i] = reinterpret_cast(resources_vec[i]); + } + + std::vector dev(devices, devices + devices_size); + + auto gpu_index = faiss::gpu::index_cpu_to_gpu_multiple( + res, dev, reinterpret_cast(index)); + *p_out = reinterpret_cast(gpu_index); + } CATCH_AND_HANDLE +} + +int faiss_index_cpu_to_gpu_multiple_with_options( + FaissGpuResources** resources_vec, size_t resources_vec_size, + int* devices, size_t devices_size, + const FaissIndex* index, const FaissGpuMultipleClonerOptions* options, + FaissGpuIndex** p_out) +{ + try { + std::vector res(resources_vec_size); + for (auto i = 0u; i < resources_vec_size; ++i) { + res[i] = reinterpret_cast(resources_vec[i]); + } + + std::vector dev(devices, devices + devices_size); + + auto gpu_index = faiss::gpu::index_cpu_to_gpu_multiple( + res, dev, reinterpret_cast(index), + reinterpret_cast(options)); + *p_out = reinterpret_cast(gpu_index); + } CATCH_AND_HANDLE +} diff --git a/core/src/index/thirdparty/faiss/c_api/gpu/GpuAutoTune_c.h b/core/src/index/thirdparty/faiss/c_api/gpu/GpuAutoTune_c.h new file mode 100644 index 0000000000..5dbd15c977 --- /dev/null +++ b/core/src/index/thirdparty/faiss/c_api/gpu/GpuAutoTune_c.h @@ -0,0 +1,56 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// Copyright 2004-present Facebook. All Rights Reserved. +// -*- c -*- + +#ifndef FAISS_GPU_AUTO_TUNE_C_H +#define FAISS_GPU_AUTO_TUNE_C_H + +#include +#include "faiss_c.h" +#include "GpuClonerOptions_c.h" +#include "GpuResources_c.h" +#include "GpuIndex_c.h" +#include "Index_c.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/// converts any GPU index inside gpu_index to a CPU index +int faiss_index_gpu_to_cpu(const FaissIndex* gpu_index, FaissIndex** p_out); + +/// converts any CPU index that can be converted to GPU +int faiss_index_cpu_to_gpu( + FaissGpuResources* resources, int device, + const FaissIndex *index, FaissGpuIndex** p_out); + +/// converts any CPU index that can be converted to GPU +int faiss_index_cpu_to_gpu_with_options( + FaissGpuResources* resources, int device, + const FaissIndex *index, const FaissGpuClonerOptions* options, + FaissGpuIndex** p_out); + +/// converts any CPU index that can be converted to GPU +int faiss_index_cpu_to_gpu_multiple( + FaissGpuResources* const* resources_vec, const int* devices, size_t devices_size, + const FaissIndex* index, FaissGpuIndex** p_out); + +/// converts any CPU index that can be converted to GPU +int faiss_index_cpu_to_gpu_multiple_with_options( + FaissGpuResources* const* resources_vec, const int* devices, size_t devices_size, + const FaissIndex* index, const FaissGpuMultipleClonerOptions* options, + FaissGpuIndex** p_out); + +/// parameter space and setters for GPU indexes +FAISS_DECLARE_CLASS_INHERITED(GpuParameterSpace, ParameterSpace) + +#ifdef __cplusplus +} +#endif +#endif diff --git a/core/src/index/thirdparty/faiss/c_api/gpu/GpuClonerOptions_c.cpp b/core/src/index/thirdparty/faiss/c_api/gpu/GpuClonerOptions_c.cpp new file mode 100644 index 0000000000..c61fc5e34c --- /dev/null +++ b/core/src/index/thirdparty/faiss/c_api/gpu/GpuClonerOptions_c.cpp @@ -0,0 +1,52 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// Copyright 2004-present Facebook. All Rights Reserved. +// -*- c++ -*- + +#include "GpuClonerOptions_c.h" +#include "gpu/GpuClonerOptions.h" +#include "macros_impl.h" + +using faiss::gpu::IndicesOptions; +using faiss::gpu::GpuClonerOptions; +using faiss::gpu::GpuMultipleClonerOptions; + +int faiss_GpuClonerOptions_new(FaissGpuClonerOptions** p) { + try { + *p = reinterpret_cast(new GpuClonerOptions()); + } CATCH_AND_HANDLE +} + +int faiss_GpuMultipleClonerOptions_new(FaissGpuMultipleClonerOptions** p) { + try { + *p = reinterpret_cast(new GpuMultipleClonerOptions()); + } CATCH_AND_HANDLE +} + +DEFINE_DESTRUCTOR(GpuClonerOptions) +DEFINE_DESTRUCTOR(GpuMultipleClonerOptions) + +DEFINE_GETTER(GpuClonerOptions, FaissIndicesOptions, indicesOptions) +DEFINE_GETTER(GpuClonerOptions, int, useFloat16CoarseQuantizer) +DEFINE_GETTER(GpuClonerOptions, int, useFloat16) +DEFINE_GETTER(GpuClonerOptions, int, usePrecomputed) +DEFINE_GETTER(GpuClonerOptions, long, reserveVecs) +DEFINE_GETTER(GpuClonerOptions, int, storeTransposed) +DEFINE_GETTER(GpuClonerOptions, int, verbose) +DEFINE_GETTER(GpuMultipleClonerOptions, int, shard) +DEFINE_GETTER(GpuMultipleClonerOptions, int, shard_type) + +DEFINE_SETTER_STATIC(GpuClonerOptions, IndicesOptions, FaissIndicesOptions, indicesOptions) +DEFINE_SETTER_STATIC(GpuClonerOptions, bool, int, useFloat16CoarseQuantizer) +DEFINE_SETTER_STATIC(GpuClonerOptions, bool, int, useFloat16) +DEFINE_SETTER_STATIC(GpuClonerOptions, bool, int, usePrecomputed) +DEFINE_SETTER(GpuClonerOptions, long, reserveVecs) +DEFINE_SETTER_STATIC(GpuClonerOptions, bool, int, storeTransposed) +DEFINE_SETTER_STATIC(GpuClonerOptions, bool, int, verbose) +DEFINE_SETTER_STATIC(GpuMultipleClonerOptions, bool, int, shard) +DEFINE_SETTER(GpuMultipleClonerOptions, int, shard_type) diff --git a/core/src/index/thirdparty/faiss/c_api/gpu/GpuClonerOptions_c.h b/core/src/index/thirdparty/faiss/c_api/gpu/GpuClonerOptions_c.h new file mode 100644 index 0000000000..94ff403e7a --- /dev/null +++ b/core/src/index/thirdparty/faiss/c_api/gpu/GpuClonerOptions_c.h @@ -0,0 +1,68 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// Copyright 2004-present Facebook. All Rights Reserved. +// -*- c -*- + +#ifndef FAISS_GPU_CLONER_OPTIONS_C_H +#define FAISS_GPU_CLONER_OPTIONS_C_H + +#include "faiss_c.h" +#include "GpuIndicesOptions_c.h" + +#ifdef __cplusplus +extern "C" { +#endif + +FAISS_DECLARE_CLASS(GpuClonerOptions) + +FAISS_DECLARE_DESTRUCTOR(GpuClonerOptions) + +/// Default constructor for GpuClonerOptions +int faiss_GpuClonerOptions_new(FaissGpuClonerOptions**); + +/// how should indices be stored on index types that support indices +/// (anything but GpuIndexFlat*)? +FAISS_DECLARE_GETTER_SETTER(GpuClonerOptions, FaissIndicesOptions, indicesOptions) + +/// (boolean) is the coarse quantizer in float16? +FAISS_DECLARE_GETTER_SETTER(GpuClonerOptions, int, useFloat16CoarseQuantizer) + +/// (boolean) for GpuIndexIVFFlat, is storage in float16? +/// for GpuIndexIVFPQ, are intermediate calculations in float16? +FAISS_DECLARE_GETTER_SETTER(GpuClonerOptions, int, useFloat16) + +/// (boolean) use precomputed tables? +FAISS_DECLARE_GETTER_SETTER(GpuClonerOptions, int, usePrecomputed) + +/// reserve vectors in the invfiles? +FAISS_DECLARE_GETTER_SETTER(GpuClonerOptions, long, reserveVecs) + +/// (boolean) For GpuIndexFlat, store data in transposed layout? +FAISS_DECLARE_GETTER_SETTER(GpuClonerOptions, int, storeTransposed) + +/// (boolean) Set verbose options on the index +FAISS_DECLARE_GETTER_SETTER(GpuClonerOptions, int, verbose) + +FAISS_DECLARE_CLASS_INHERITED(GpuMultipleClonerOptions, GpuClonerOptions) + +FAISS_DECLARE_DESTRUCTOR(GpuMultipleClonerOptions) + +/// Default constructor for GpuMultipleClonerOptions +int faiss_GpuMultipleClonerOptions_new(FaissGpuMultipleClonerOptions**); + +/// (boolean) Whether to shard the index across GPUs, versus replication +/// across GPUs +FAISS_DECLARE_GETTER_SETTER(GpuMultipleClonerOptions, int, shard) + +/// IndexIVF::copy_subset_to subset type +FAISS_DECLARE_GETTER_SETTER(GpuMultipleClonerOptions, int, shard_type) + +#ifdef __cplusplus +} +#endif +#endif \ No newline at end of file diff --git a/core/src/index/thirdparty/faiss/c_api/gpu/GpuIndex_c.cpp b/core/src/index/thirdparty/faiss/c_api/gpu/GpuIndex_c.cpp new file mode 100644 index 0000000000..bdef82766e --- /dev/null +++ b/core/src/index/thirdparty/faiss/c_api/gpu/GpuIndex_c.cpp @@ -0,0 +1,17 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// Copyright 2004-present Facebook. All Rights Reserved. +// -*- c++ -*- + +#include "gpu/GpuIndex.h" +#include "GpuIndex_c.h" +#include "macros_impl.h" + +using faiss::gpu::GpuIndexConfig; + +DEFINE_GETTER(GpuIndexConfig, int, device) diff --git a/core/src/index/thirdparty/faiss/c_api/gpu/GpuIndex_c.h b/core/src/index/thirdparty/faiss/c_api/gpu/GpuIndex_c.h new file mode 100644 index 0000000000..664c76101f --- /dev/null +++ b/core/src/index/thirdparty/faiss/c_api/gpu/GpuIndex_c.h @@ -0,0 +1,30 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// Copyright 2004-present Facebook. All Rights Reserved. +// -*- c -*- + +#ifndef FAISS_GPU_INDEX_C_H +#define FAISS_GPU_INDEX_C_H + +#include "faiss_c.h" + +#ifdef __cplusplus +extern "C" { +#endif + +FAISS_DECLARE_CLASS(GpuIndexConfig) + +FAISS_DECLARE_GETTER(GpuIndexConfig, int, device) + +FAISS_DECLARE_CLASS_INHERITED(GpuIndex, Index) + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/core/src/index/thirdparty/faiss/c_api/gpu/GpuIndicesOptions_c.h b/core/src/index/thirdparty/faiss/c_api/gpu/GpuIndicesOptions_c.h new file mode 100644 index 0000000000..6a49773bc6 --- /dev/null +++ b/core/src/index/thirdparty/faiss/c_api/gpu/GpuIndicesOptions_c.h @@ -0,0 +1,38 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// Copyright 2004-present Facebook. All Rights Reserved. +// -*- c -*- + +#ifndef FAISS_GPU_INDICES_OPTIONS_C_H +#define FAISS_GPU_INDICES_OPTIONS_C_H + +#ifdef __cplusplus +extern "C" { +#endif + +/// How user vector index data is stored on the GPU +typedef enum FaissIndicesOptions { + /// The user indices are only stored on the CPU; the GPU returns + /// (inverted list, offset) to the CPU which is then translated to + /// the real user index. + INDICES_CPU = 0, + /// The indices are not stored at all, on either the CPU or + /// GPU. Only (inverted list, offset) is returned to the user as the + /// index. + INDICES_IVF = 1, + /// Indices are stored as 32 bit integers on the GPU, but returned + /// as 64 bit integers + INDICES_32_BIT = 2, + /// Indices are stored as 64 bit integers on the GPU + INDICES_64_BIT = 3, +} FaissIndicesOptions; + +#ifdef __cplusplus +} +#endif +#endif \ No newline at end of file diff --git a/core/src/index/thirdparty/faiss/c_api/gpu/GpuResources_c.cpp b/core/src/index/thirdparty/faiss/c_api/gpu/GpuResources_c.cpp new file mode 100644 index 0000000000..3f6525125d --- /dev/null +++ b/core/src/index/thirdparty/faiss/c_api/gpu/GpuResources_c.cpp @@ -0,0 +1,86 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// Copyright 2004-present Facebook. All Rights Reserved. +// -*- c++ -*- + +#include "gpu/GpuResources_c.h" +#include "gpu/GpuResources.h" +#include "macros_impl.h" + +using faiss::gpu::GpuResources; + +DEFINE_DESTRUCTOR(GpuResources) + +int faiss_GpuResources_initializeForDevice(FaissGpuResources* res, int device) { + try { + reinterpret_cast(res)->initializeForDevice(device); + } CATCH_AND_HANDLE +} + +int faiss_GpuResources_getBlasHandle(FaissGpuResources* res, int device, cublasHandle_t* out) { + try { + auto o = reinterpret_cast(res)->getBlasHandle(device); + *out = o; + } CATCH_AND_HANDLE +} + +int faiss_GpuResources_getDefaultStream(FaissGpuResources* res, int device, cudaStream_t* out) { + try { + auto o = reinterpret_cast(res)->getDefaultStream(device); + *out = o; + } CATCH_AND_HANDLE +} + +int faiss_GpuResources_getPinnedMemory(FaissGpuResources* res, void** p_buffer, size_t* p_size) { + try { + auto o = reinterpret_cast(res)->getPinnedMemory(); + *p_buffer = o.first; + *p_size = o.second; + } CATCH_AND_HANDLE +} + +int faiss_GpuResources_getAsyncCopyStream(FaissGpuResources* res, int device, cudaStream_t* out) { + try { + auto o = reinterpret_cast(res)->getAsyncCopyStream(device); + *out = o; + } CATCH_AND_HANDLE +} + +int faiss_GpuResources_getBlasHandleCurrentDevice(FaissGpuResources* res, cublasHandle_t* out) { + try { + auto o = reinterpret_cast(res)->getBlasHandleCurrentDevice(); + *out = o; + } CATCH_AND_HANDLE +} + +int faiss_GpuResources_getDefaultStreamCurrentDevice(FaissGpuResources* res, cudaStream_t* out) { + try { + auto o = reinterpret_cast(res)->getDefaultStreamCurrentDevice(); + *out = o; + } CATCH_AND_HANDLE +} + +int faiss_GpuResources_syncDefaultStream(FaissGpuResources* res, int device) { + try { + reinterpret_cast(res)->syncDefaultStream(device); + } CATCH_AND_HANDLE +} + +int faiss_GpuResources_syncDefaultStreamCurrentDevice(FaissGpuResources* res) { + try { + reinterpret_cast(res)->syncDefaultStreamCurrentDevice(); + } CATCH_AND_HANDLE +} + +int faiss_GpuResources_getAsyncCopyStreamCurrentDevice(FaissGpuResources* res, cudaStream_t* out) { + try { + auto o = reinterpret_cast(res)->getAsyncCopyStreamCurrentDevice(); + *out = o; + } CATCH_AND_HANDLE +} + diff --git a/core/src/index/thirdparty/faiss/c_api/gpu/GpuResources_c.h b/core/src/index/thirdparty/faiss/c_api/gpu/GpuResources_c.h new file mode 100644 index 0000000000..bb9cefde36 --- /dev/null +++ b/core/src/index/thirdparty/faiss/c_api/gpu/GpuResources_c.h @@ -0,0 +1,66 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// Copyright 2004-present Facebook. All Rights Reserved. +// -*- c -*- + +#ifndef FAISS_GPU_RESOURCES_C_H +#define FAISS_GPU_RESOURCES_C_H + +#include +#include +#include "faiss_c.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/// Base class of GPU-side resource provider; hides provision of +/// cuBLAS handles, CUDA streams and a temporary memory manager +FAISS_DECLARE_CLASS(GpuResources) + +FAISS_DECLARE_DESTRUCTOR(GpuResources) + +/// Call to pre-allocate resources for a particular device. If this is +/// not called, then resources will be allocated at the first time +/// of demand +int faiss_GpuResources_initializeForDevice(FaissGpuResources*, int); + +/// Returns the cuBLAS handle that we use for the given device +int faiss_GpuResources_getBlasHandle(FaissGpuResources*, int, cublasHandle_t*); + +/// Returns the stream that we order all computation on for the +/// given device +int faiss_GpuResources_getDefaultStream(FaissGpuResources*, int, cudaStream_t*); + +/// Returns the available CPU pinned memory buffer +int faiss_GpuResources_getPinnedMemory(FaissGpuResources*, void**, size_t*); + +/// Returns the stream on which we perform async CPU <-> GPU copies +int faiss_GpuResources_getAsyncCopyStream(FaissGpuResources*, int, cudaStream_t*); + +/// Calls getBlasHandle with the current device +int faiss_GpuResources_getBlasHandleCurrentDevice(FaissGpuResources*, cublasHandle_t*); + +/// Calls getDefaultStream with the current device +int faiss_GpuResources_getDefaultStreamCurrentDevice(FaissGpuResources*, cudaStream_t*); + +/// Synchronizes the CPU with respect to the default stream for the +/// given device +// equivalent to cudaDeviceSynchronize(getDefaultStream(device)) +int faiss_GpuResources_syncDefaultStream(FaissGpuResources*, int); + +/// Calls syncDefaultStream for the current device +int faiss_GpuResources_syncDefaultStreamCurrentDevice(FaissGpuResources*); + +/// Calls getAsyncCopyStream for the current device +int faiss_GpuResources_getAsyncCopyStreamCurrentDevice(FaissGpuResources*, cudaStream_t*); + +#ifdef __cplusplus +} +#endif +#endif \ No newline at end of file diff --git a/core/src/index/thirdparty/faiss/c_api/gpu/Makefile b/core/src/index/thirdparty/faiss/c_api/gpu/Makefile new file mode 100644 index 0000000000..ab1f707cee --- /dev/null +++ b/core/src/index/thirdparty/faiss/c_api/gpu/Makefile @@ -0,0 +1,63 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +.SUFFIXES: .cpp .o + +# C API with GPU support + +include ../../makefile.inc +DEBUGFLAG=-DNDEBUG # no debugging + +LIBNAME=libgpufaiss +CLIBNAME=libgpufaiss_c +LIBGPUCOBJ=GpuAutoTune_c.o GpuClonerOptions_c.o GpuIndex_c.o GpuResources_c.o \ + StandardGpuResources_c.o +LIBCOBJ=../libfaiss_c.a +CFLAGS=-fPIC -m64 -Wno-sign-compare -g -O3 -Wall -Wextra +CUDACFLAGS=-I$(CUDA_ROOT)/include + +# Build shared object file by default +all: $(CLIBNAME).$(SHAREDEXT) + +# Build static object file containing the wrapper implementation only. +# Consumers are required to link with the C++ standard library and remaining +# portions of this library: libfaiss_c.a, libfaiss.a, libgpufaiss.a, and libstdc++. +$(CLIBNAME).a: $(LIBGPUCOBJ) ../../gpu/$(LIBNAME).a + ar r $@ $^ + +# Build dynamic library +$(CLIBNAME).$(SHAREDEXT): $(LIBCOBJ) $(LIBGPUCOBJ) ../../libfaiss.a ../../gpu/$(LIBNAME).a + $(CXX) $(LDFLAGS) $(SHAREDFLAGS) $(CUDACFLAGS) -o $@ \ + -Wl,--whole-archive $(LIBCOBJ) ../../libfaiss.a \ + -Wl,--no-whole-archive -static-libstdc++ $(LIBGPUCOBJ) $(LIBS) ../../gpu/$(LIBNAME).a \ + $(NVCCLDFLAGS) $(NVCCLIBS) + +# Build GPU example +bin/example_gpu_c: example_gpu_c.c $(CLIBNAME).$(SHAREDEXT) + $(CC) $(CFLAGS) $(CUDACFLAGS) $(NVCCLIBS) -std=c99 -I. -I.. -o $@ example_gpu_c.c \ + -L. -lgpufaiss_c + +clean: + rm -f $(CLIBNAME).a $(CLIBNAME).$(SHAREDEXT)* *.o bin/example_gpu_c + +%.o: %.cpp + $(CXX) $(CPPFLAGS) $(CXXFLAGS) $(CPUFLAGS) -c $< -o $@ + +# Dependencies + +GpuAutoTune_c.o: CXXFLAGS += -I.. -I../.. $(CUDACFLAGS) $(DEBUGFLAG) +GpuAutoTune_c.o: GpuAutoTune_c.cpp GpuAutoTune_c.h ../../gpu/GpuAutoTune.h ../Index_c.h ../macros_impl.h + +GpuClonerOptions_c.o: CXXFLAGS += -I.. -I../.. $(CUDACFLAGS) $(DEBUGFLAG) +GpuClonerOptions_c.o: GpuClonerOptions_c.cpp GpuClonerOptions_c.h GpuIndicesOptions_c.h ../../gpu/GpuClonerOptions.h ../macros_impl.h + +GpuIndex_c.o: CXXFLAGS += -I.. -I../.. $(CUDACFLAGS) $(DEBUGFLAG) +GpuIndex_c.o: GpuIndex_c.cpp GpuIndex_c.h ../../gpu/GpuIndex.h ../macros_impl.h + +GpuResources_c.o: CXXFLAGS += -I.. -I../.. $(CUDACFLAGS) $(DEBUGFLAG) +GpuResources_c.o: GpuResources_c.cpp GpuResources_c.h ../../gpu/GpuResources.h ../macros_impl.h + +StandardGpuResources_c.o: CXXFLAGS += -I.. -I../.. $(CUDACFLAGS) $(DEBUGFLAG) +StandardGpuResources_c.o: StandardGpuResources_c.cpp StandardGpuResources_c.h ../../gpu/StandardGpuResources.h ../macros_impl.h diff --git a/core/src/index/thirdparty/faiss/c_api/gpu/StandardGpuResources_c.cpp b/core/src/index/thirdparty/faiss/c_api/gpu/StandardGpuResources_c.cpp new file mode 100644 index 0000000000..84afb027eb --- /dev/null +++ b/core/src/index/thirdparty/faiss/c_api/gpu/StandardGpuResources_c.cpp @@ -0,0 +1,54 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// Copyright 2004-present Facebook. All Rights Reserved. +// -*- c++ -*- + +#include "gpu/StandardGpuResources_c.h" +#include "gpu/StandardGpuResources.h" +#include "macros_impl.h" + +using faiss::gpu::StandardGpuResources; + +DEFINE_DESTRUCTOR(StandardGpuResources) + +int faiss_StandardGpuResources_new(FaissStandardGpuResources** p_res) { + try { + auto p = new StandardGpuResources(); + *p_res = reinterpret_cast(p); + } CATCH_AND_HANDLE +} + +int faiss_StandardGpuResources_noTempMemory(FaissStandardGpuResources* res) { + try { + reinterpret_cast(res)->noTempMemory(); + } CATCH_AND_HANDLE +} + +int faiss_StandardGpuResources_setTempMemory(FaissStandardGpuResources* res, size_t size) { + try { + reinterpret_cast(res)->setTempMemory(size); + } CATCH_AND_HANDLE +} + +int faiss_StandardGpuResources_setPinnedMemory(FaissStandardGpuResources* res, size_t size) { + try { + reinterpret_cast(res)->setPinnedMemory(size); + } CATCH_AND_HANDLE +} + +int faiss_StandardGpuResources_setDefaultStream(FaissStandardGpuResources* res, int device, cudaStream_t stream) { + try { + reinterpret_cast(res)->setDefaultStream(device, stream); + } CATCH_AND_HANDLE +} + +int faiss_StandardGpuResources_setDefaultNullStreamAllDevices(FaissStandardGpuResources* res) { + try { + reinterpret_cast(res)->setDefaultNullStreamAllDevices(); + } CATCH_AND_HANDLE +} diff --git a/core/src/index/thirdparty/faiss/c_api/gpu/StandardGpuResources_c.h b/core/src/index/thirdparty/faiss/c_api/gpu/StandardGpuResources_c.h new file mode 100644 index 0000000000..f9a3c854f0 --- /dev/null +++ b/core/src/index/thirdparty/faiss/c_api/gpu/StandardGpuResources_c.h @@ -0,0 +1,53 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// Copyright 2004-present Facebook. All Rights Reserved. +// -*- c -*- + +#ifndef FAISS_STANDARD_GPURESOURCES_C_H +#define FAISS_STANDARD_GPURESOURCES_C_H + +#include +#include "faiss_c.h" +#include "gpu/GpuResources_c.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/// Default implementation of GpuResources that allocates a cuBLAS +/// stream and 2 streams for use, as well as temporary memory +FAISS_DECLARE_CLASS_INHERITED(StandardGpuResources, GpuResources) + +FAISS_DECLARE_DESTRUCTOR(StandardGpuResources) + +/// Default constructor for StandardGpuResources +int faiss_StandardGpuResources_new(FaissStandardGpuResources**); + +/// Disable allocation of temporary memory; all temporary memory +/// requests will call cudaMalloc / cudaFree at the point of use +int faiss_StandardGpuResources_noTempMemory(FaissStandardGpuResources*); + +/// Specify that we wish to use a certain fixed size of memory on +/// all devices as temporary memory +int faiss_StandardGpuResources_setTempMemory(FaissStandardGpuResources*, size_t size); + +/// Set amount of pinned memory to allocate, for async GPU <-> CPU +/// transfers +int faiss_StandardGpuResources_setPinnedMemory(FaissStandardGpuResources*, size_t size); + +/// Called to change the stream for work ordering +int faiss_StandardGpuResources_setDefaultStream(FaissStandardGpuResources*, int device, cudaStream_t stream); + +/// Called to change the work ordering streams to the null stream +/// for all devices +int faiss_StandardGpuResources_setDefaultNullStreamAllDevices(FaissStandardGpuResources*); + +#ifdef __cplusplus +} +#endif +#endif \ No newline at end of file diff --git a/core/src/index/thirdparty/faiss/c_api/gpu/example_gpu_c.c b/core/src/index/thirdparty/faiss/c_api/gpu/example_gpu_c.c new file mode 100644 index 0000000000..c2a10a2e30 --- /dev/null +++ b/core/src/index/thirdparty/faiss/c_api/gpu/example_gpu_c.c @@ -0,0 +1,106 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// Copyright 2004-present Facebook. All Rights Reserved. +// -*- c -*- + +#include +#include +#include + +#include "error_c.h" +#include "Index_c.h" +#include "AutoTune_c.h" +#include "GpuAutoTune_c.h" +#include "StandardGpuResources_c.h" + +#define FAISS_TRY(C) \ + { \ + if (C) { \ + fprintf(stderr, "%s", faiss_get_last_error()); \ + exit(-1); \ + } \ + } + +double drand() { + return (double)rand() / (double)RAND_MAX; +} + +int main() { + time_t seed = time(NULL); + srand(seed); + printf("Generating some data...\n"); + int d = 128; // dimension + int nb = 100000; // database size + int nq = 10000; // nb of queries + float *xb = malloc(d * nb * sizeof(float)); + float *xq = malloc(d * nq * sizeof(float)); + + for(int i = 0; i < nb; i++) { + for(int j = 0; j < d; j++) xb[d * i + j] = drand(); + xb[d * i] += i / 1000.; + } + for(int i = 0; i < nq; i++) { + for(int j = 0; j < d; j++) xq[d * i + j] = drand(); + xq[d * i] += i / 1000.; + } + + printf("Loading standard GPU resources...\n"); + FaissStandardGpuResources* gpu_res = NULL; + FAISS_TRY(faiss_StandardGpuResources_new(&gpu_res)); + + printf("Building an index...\n"); + FaissIndex* cpu_index = NULL; + FAISS_TRY(faiss_index_factory(&cpu_index, d, "Flat", METRIC_L2)); // use factory to create index + + printf("Moving index to the GPU...\n"); + FaissGpuIndex* index = NULL; + FaissGpuClonerOptions* options = NULL; + FAISS_TRY(faiss_GpuClonerOptions_new(&options)); + FAISS_TRY(faiss_index_cpu_to_gpu_with_options(gpu_res, 0, cpu_index, options, &index)); + + printf("is_trained = %s\n", faiss_Index_is_trained(index) ? "true" : "false"); + FAISS_TRY(faiss_Index_add(index, nb, xb)); // add vectors to the index + printf("ntotal = %ld\n", faiss_Index_ntotal(index)); + + printf("Searching...\n"); + int k = 5; + + { // sanity check: search 5 first vectors of xb + idx_t *I = malloc(k * 5 * sizeof(idx_t)); + float *D = malloc(k * 5 * sizeof(float)); + FAISS_TRY(faiss_Index_search(index, 5, xb, k, D, I)); + printf("I=\n"); + for(int i = 0; i < 5; i++) { + for(int j = 0; j < k; j++) printf("%5ld (d=%2.3f) ", I[i * k + j], D[i * k + j]); + printf("\n"); + } + free(I); + free(D); + } + { // search xq + idx_t *I = malloc(k * nq * sizeof(idx_t)); + float *D = malloc(k * nq * sizeof(float)); + FAISS_TRY(faiss_Index_search(index, 5, xb, k, D, I)); + printf("I=\n"); + for(int i = 0; i < 5; i++) { + for(int j = 0; j < k; j++) printf("%5ld (d=%2.3f) ", I[i * k + j], D[i * k + j]); + printf("\n"); + } + free(I); + free(D); + } + + printf("Freeing index...\n"); + faiss_Index_free(index); + printf("Freeing GPU resources...\n"); + faiss_GpuResources_free(gpu_res); + faiss_GpuClonerOptions_free(options); + printf("Done.\n"); + + return 0; +} \ No newline at end of file diff --git a/core/src/index/thirdparty/faiss/c_api/gpu/macros_impl.h b/core/src/index/thirdparty/faiss/c_api/gpu/macros_impl.h new file mode 100644 index 0000000000..3f6ea5844a --- /dev/null +++ b/core/src/index/thirdparty/faiss/c_api/gpu/macros_impl.h @@ -0,0 +1,42 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// Copyright 2004-present Facebook. All Rights Reserved. +// -*- c++ -*- + +#ifndef GPU_MACROS_IMPL_H +#define GPU_MACROS_IMPL_H +#include "../macros_impl.h" + +#undef DEFINE_GETTER +#define DEFINE_GETTER(clazz, ty, name) \ + ty faiss_ ## clazz ## _ ## name (const Faiss ## clazz *obj) { \ + return static_cast< ty >( \ + reinterpret_cast< const faiss::gpu::clazz *>(obj)-> name \ + ); \ + } + +#undef DEFINE_SETTER +#define DEFINE_SETTER(clazz, ty, name) \ + void faiss_ ## clazz ## _set_ ## name (Faiss ## clazz *obj, ty val) { \ + reinterpret_cast< faiss::gpu::clazz *>(obj)-> name = val; \ + } + +#undef DEFINE_SETTER_STATIC +#define DEFINE_SETTER_STATIC(clazz, ty_to, ty_from, name) \ + void faiss_ ## clazz ## _set_ ## name (Faiss ## clazz *obj, ty_from val) { \ + reinterpret_cast< faiss::gpu::clazz *>(obj)-> name = \ + static_cast< ty_to >(val); \ + } + +#undef DEFINE_DESTRUCTOR +#define DEFINE_DESTRUCTOR(clazz) \ + void faiss_ ## clazz ## _free (Faiss ## clazz *obj) { \ + delete reinterpret_cast(obj); \ + } + +#endif diff --git a/core/src/index/thirdparty/faiss/c_api/impl/AuxIndexStructures_c.cpp b/core/src/index/thirdparty/faiss/c_api/impl/AuxIndexStructures_c.cpp new file mode 100644 index 0000000000..4b3fec8fb7 --- /dev/null +++ b/core/src/index/thirdparty/faiss/c_api/impl/AuxIndexStructures_c.cpp @@ -0,0 +1,220 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// Copyright 2004-present Facebook. All Rights Reserved. +// -*- c++ -*- + +#include "AuxIndexStructures_c.h" +#include "../../impl/AuxIndexStructures.h" +#include "../macros_impl.h" +#include + +using faiss::BufferList; +using faiss::IDSelector; +using faiss::IDSelectorBatch; +using faiss::IDSelectorRange; +using faiss::RangeSearchResult; +using faiss::RangeSearchPartialResult; +using faiss::RangeQueryResult; +using faiss::DistanceComputer; + +DEFINE_GETTER(RangeSearchResult, size_t, nq) + +int faiss_RangeSearchResult_new(FaissRangeSearchResult** p_rsr, idx_t nq) { + try { + *p_rsr = reinterpret_cast( + new RangeSearchResult(nq)); + return 0; + } CATCH_AND_HANDLE +} + +int faiss_RangeSearchResult_new_with(FaissRangeSearchResult** p_rsr, idx_t nq, int alloc_lims) { + try { + *p_rsr = reinterpret_cast( + new RangeSearchResult(nq, static_cast(alloc_lims))); + return 0; + } CATCH_AND_HANDLE +} + +/// called when lims contains the nb of elements result entries +/// for each query +int faiss_RangeSearchResult_do_allocation(FaissRangeSearchResult* rsr) { + try { + reinterpret_cast(rsr)->do_allocation(); + return 0; + } CATCH_AND_HANDLE +} + +DEFINE_DESTRUCTOR(RangeSearchResult) + +/// getter for buffer_size +DEFINE_GETTER(RangeSearchResult, size_t, buffer_size) + +/// getter for lims: size (nq + 1) +void faiss_RangeSearchResult_lims(FaissRangeSearchResult* rsr, size_t** lims) { + *lims = reinterpret_cast(rsr)->lims; +} + +/// getter for labels and respective distances (not sorted): +/// result for query i is labels[lims[i]:lims[i+1]] +void faiss_RangeSearchResult_labels(FaissRangeSearchResult* rsr, idx_t** labels, float** distances) { + auto sr = reinterpret_cast(rsr); + *labels = sr->labels; + *distances = sr->distances; +} + +DEFINE_DESTRUCTOR(IDSelector) + +int faiss_IDSelector_is_member(const FaissIDSelector* sel, idx_t id) { + return reinterpret_cast(sel)->is_member(id); +} + +DEFINE_DESTRUCTOR(IDSelectorRange) + +DEFINE_GETTER(IDSelectorRange, idx_t, imin) +DEFINE_GETTER(IDSelectorRange, idx_t, imax) + +int faiss_IDSelectorRange_new(FaissIDSelectorRange** p_sel, idx_t imin, idx_t imax) { + try { + *p_sel = reinterpret_cast( + new IDSelectorRange(imin, imax) + ); + return 0; + } CATCH_AND_HANDLE +} + +DEFINE_GETTER(IDSelectorBatch, int, nbits) +DEFINE_GETTER(IDSelectorBatch, idx_t, mask) + +int faiss_IDSelectorBatch_new(FaissIDSelectorBatch** p_sel, size_t n, const idx_t* indices) { + try { + *p_sel = reinterpret_cast( + new IDSelectorBatch(n, indices) + ); + return 0; + } CATCH_AND_HANDLE +} + +// Below are structures used only by Index implementations + +DEFINE_DESTRUCTOR(BufferList) + +DEFINE_GETTER(BufferList, size_t, buffer_size) +DEFINE_GETTER(BufferList, size_t, wp) + +int faiss_BufferList_append_buffer(FaissBufferList* bl) { + try { + reinterpret_cast(bl)->append_buffer(); + return 0; + } CATCH_AND_HANDLE +} + +int faiss_BufferList_new(FaissBufferList** p_bl, size_t buffer_size) { + try { + *p_bl = reinterpret_cast( + new BufferList(buffer_size) + ); + return 0; + } CATCH_AND_HANDLE +} + +int faiss_BufferList_add(FaissBufferList* bl, idx_t id, float dis) { + try { + reinterpret_cast(bl)->add(id, dis); + return 0; + } CATCH_AND_HANDLE +} + +/// copy elemnts ofs:ofs+n-1 seen as linear data in the buffers to +/// tables dest_ids, dest_dis +int faiss_BufferList_copy_range( + FaissBufferList* bl, size_t ofs, size_t n, idx_t *dest_ids, float *dest_dis) { + try { + reinterpret_cast(bl)->copy_range(ofs, n, dest_ids, dest_dis); + return 0; + } CATCH_AND_HANDLE +} + +DEFINE_GETTER(RangeQueryResult, idx_t, qno) +DEFINE_GETTER(RangeQueryResult, size_t, nres) +DEFINE_GETTER_PERMISSIVE(RangeQueryResult, FaissRangeSearchPartialResult*, pres) + +int faiss_RangeQueryResult_add(FaissRangeQueryResult* qr, float dis, idx_t id) { + try { + reinterpret_cast(qr)->add(dis, id); + return 0; + } CATCH_AND_HANDLE +} + +DEFINE_GETTER_PERMISSIVE(RangeSearchPartialResult, FaissRangeSearchResult*, res) + +int faiss_RangeSearchPartialResult_new( + FaissRangeSearchPartialResult** p_res, FaissRangeSearchResult* res_in) { + try { + *p_res = reinterpret_cast( + new RangeSearchPartialResult( + reinterpret_cast(res_in)) + ); + return 0; + } CATCH_AND_HANDLE +} + +int faiss_RangeSearchPartialResult_finalize( + FaissRangeSearchPartialResult* res) { + try { + reinterpret_cast(res)->finalize(); + return 0; + } CATCH_AND_HANDLE +} + +/// called by range_search before do_allocation +int faiss_RangeSearchPartialResult_set_lims( + FaissRangeSearchPartialResult* res) { + try { + reinterpret_cast(res)->set_lims(); + return 0; + } CATCH_AND_HANDLE +} + +int faiss_RangeSearchPartialResult_new_result( + FaissRangeSearchPartialResult* res, idx_t qno, FaissRangeQueryResult** qr) { + + try { + auto q = + &reinterpret_cast(res)->new_result(qno); + if (qr) { + *qr = reinterpret_cast(&q); + } + return 0; + } CATCH_AND_HANDLE +} + +DEFINE_DESTRUCTOR(DistanceComputer) + +int faiss_DistanceComputer_set_query(FaissDistanceComputer *dc, const float *x) { + try { + reinterpret_cast(dc)->set_query(x); + return 0; + } + CATCH_AND_HANDLE +} + +int faiss_DistanceComputer_vector_to_query_dis(FaissDistanceComputer *dc, idx_t i, float *qd) { + try { + *qd = reinterpret_cast(dc)->operator()(i); + return 0; + } + CATCH_AND_HANDLE +} + +int faiss_DistanceComputer_symmetric_dis(FaissDistanceComputer *dc, idx_t i, idx_t j, float *vd) { + try { + *vd = reinterpret_cast(dc)->symmetric_dis(i, j); + return 0; + } + CATCH_AND_HANDLE +} \ No newline at end of file diff --git a/core/src/index/thirdparty/faiss/c_api/impl/AuxIndexStructures_c.h b/core/src/index/thirdparty/faiss/c_api/impl/AuxIndexStructures_c.h new file mode 100644 index 0000000000..1d66b0aac0 --- /dev/null +++ b/core/src/index/thirdparty/faiss/c_api/impl/AuxIndexStructures_c.h @@ -0,0 +1,149 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// Copyright 2004-present Facebook. All Rights Reserved. +// -*- c -*- + +#ifndef FAISS_AUX_INDEX_STRUCTURES_C_H +#define FAISS_AUX_INDEX_STRUCTURES_C_H + +#include "../Index_c.h" +#include "../faiss_c.h" + +#ifdef __cplusplus +extern "C" { +#endif + +FAISS_DECLARE_CLASS(RangeSearchResult) + +FAISS_DECLARE_GETTER(RangeSearchResult, size_t, nq) + +int faiss_RangeSearchResult_new(FaissRangeSearchResult** p_rsr, idx_t nq); + +int faiss_RangeSearchResult_new_with(FaissRangeSearchResult** p_rsr, idx_t nq, int alloc_lims); + +/// called when lims contains the nb of elements result entries +/// for each query +int faiss_RangeSearchResult_do_allocation(FaissRangeSearchResult* rsr); + +FAISS_DECLARE_DESTRUCTOR(RangeSearchResult) + +/// getter for buffer_size +FAISS_DECLARE_GETTER(RangeSearchResult, size_t, buffer_size) + +/// getter for lims: size (nq + 1) +void faiss_RangeSearchResult_lims( + FaissRangeSearchResult* rsr, size_t** lims); + +/// getter for labels and respective distances (not sorted): +/// result for query i is labels[lims[i]:lims[i+1]] +void faiss_RangeSearchResult_labels( + FaissRangeSearchResult* rsr, idx_t** labels, float** distances); + + +/** Encapsulates a set of ids to remove. */ +FAISS_DECLARE_CLASS(IDSelector) +FAISS_DECLARE_DESTRUCTOR(IDSelector) + +int faiss_IDSelector_is_member(const FaissIDSelector* sel, idx_t id); + +/** remove ids between [imni, imax) */ +FAISS_DECLARE_CLASS(IDSelectorRange) +FAISS_DECLARE_DESTRUCTOR(IDSelectorRange) + +FAISS_DECLARE_GETTER(IDSelectorRange, idx_t, imin) +FAISS_DECLARE_GETTER(IDSelectorRange, idx_t, imax) + +int faiss_IDSelectorRange_new(FaissIDSelectorRange** p_sel, idx_t imin, idx_t imax); + +/** Remove ids from a set. Repetitions of ids in the indices set + * passed to the constructor does not hurt performance. The hash + * function used for the bloom filter and GCC's implementation of + * unordered_set are just the least significant bits of the id. This + * works fine for random ids or ids in sequences but will produce many + * hash collisions if lsb's are always the same */ +FAISS_DECLARE_CLASS(IDSelectorBatch) + +FAISS_DECLARE_GETTER(IDSelectorBatch, int, nbits) +FAISS_DECLARE_GETTER(IDSelectorBatch, idx_t, mask) + +int faiss_IDSelectorBatch_new(FaissIDSelectorBatch** p_sel, size_t n, const idx_t* indices); + +// Below are structures used only by Index implementations + +/** List of temporary buffers used to store results before they are + * copied to the RangeSearchResult object. */ +FAISS_DECLARE_CLASS(BufferList) +FAISS_DECLARE_DESTRUCTOR(BufferList) + +FAISS_DECLARE_GETTER(BufferList, size_t, buffer_size) +FAISS_DECLARE_GETTER(BufferList, size_t, wp) + +typedef struct FaissBuffer { + idx_t *ids; + float *dis; +} FaissBuffer; + +int faiss_BufferList_append_buffer(FaissBufferList* bl); + +int faiss_BufferList_new(FaissBufferList** p_bl, size_t buffer_size); + +int faiss_BufferList_add(FaissBufferList* bl, idx_t id, float dis); + +/// copy elemnts ofs:ofs+n-1 seen as linear data in the buffers to +/// tables dest_ids, dest_dis +int faiss_BufferList_copy_range( + FaissBufferList* bl, size_t ofs, size_t n, idx_t *dest_ids, float *dest_dis); + +/// the entries in the buffers are split per query +FAISS_DECLARE_CLASS(RangeSearchPartialResult) + +/// result structure for a single query +FAISS_DECLARE_CLASS(RangeQueryResult) +FAISS_DECLARE_GETTER(RangeQueryResult, idx_t, qno) +FAISS_DECLARE_GETTER(RangeQueryResult, size_t, nres) +FAISS_DECLARE_GETTER(RangeQueryResult, FaissRangeSearchPartialResult*, pres) + +int faiss_RangeQueryResult_add(FaissRangeQueryResult* qr, float dis, idx_t id); + + +FAISS_DECLARE_GETTER(RangeSearchPartialResult, FaissRangeSearchResult*, res) + +int faiss_RangeSearchPartialResult_new( + FaissRangeSearchPartialResult** p_res, FaissRangeSearchResult* res_in); + +int faiss_RangeSearchPartialResult_finalize( + FaissRangeSearchPartialResult* res); + +/// called by range_search before do_allocation +int faiss_RangeSearchPartialResult_set_lims( + FaissRangeSearchPartialResult* res); + +int faiss_RangeSearchPartialResult_new_result( + FaissRangeSearchPartialResult* res, idx_t qno, FaissRangeQueryResult** qr); + + +FAISS_DECLARE_CLASS(DistanceComputer) +/// called before computing distances +int faiss_DistanceComputer_set_query(FaissDistanceComputer *dc, const float *x); + +/** + * Compute distance of vector i to current query. + * This function corresponds to the function call operator: DistanceComputer::operator() + */ +int faiss_DistanceComputer_vector_to_query_dis( FaissDistanceComputer *dc, idx_t i, float *qd); +/// compute distance between two stored vectors +int faiss_DistanceComputer_symmetric_dis(FaissDistanceComputer *dc, idx_t i, idx_t j, float *vd); + +FAISS_DECLARE_DESTRUCTOR(DistanceComputer) + + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/core/src/index/thirdparty/faiss/c_api/index_factory_c.cpp b/core/src/index/thirdparty/faiss/c_api/index_factory_c.cpp new file mode 100644 index 0000000000..f7f00c4132 --- /dev/null +++ b/core/src/index/thirdparty/faiss/c_api/index_factory_c.cpp @@ -0,0 +1,26 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// Copyright 2004-present Facebook. All Rights Reserved. +// -*- c++ -*- + +#include +#include "index_factory.h" +#include "index_factory_c.h" +#include "macros_impl.h" + +using faiss::Index; + +/** Build and index with the sequence of processing steps described in + * the string. + */ +int faiss_index_factory(FaissIndex** p_index, int d, const char* description, FaissMetricType metric) { + try { + *p_index = reinterpret_cast(faiss::index_factory( + d, description, static_cast(metric))); + } CATCH_AND_HANDLE +} \ No newline at end of file diff --git a/core/src/index/thirdparty/faiss/c_api/index_factory_c.h b/core/src/index/thirdparty/faiss/c_api/index_factory_c.h new file mode 100644 index 0000000000..4262fe09a2 --- /dev/null +++ b/core/src/index/thirdparty/faiss/c_api/index_factory_c.h @@ -0,0 +1,30 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// Copyright 2004-present Facebook. All Rights Reserved. +// -*- c -*- + +#ifndef FAISS_INDEX_FACTORY_C_H +#define FAISS_INDEX_FACTORY_C_H + +#include "faiss_c.h" +#include "Index_c.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/** Build and index with the sequence of processing steps described in + * the string. + */ +int faiss_index_factory(FaissIndex** p_index, int d, const char* description, FaissMetricType metric); + +#ifdef __cplusplus +} +#endif + +#endif \ No newline at end of file diff --git a/core/src/index/thirdparty/faiss/c_api/index_io_c.cpp b/core/src/index/thirdparty/faiss/c_api/index_io_c.cpp new file mode 100644 index 0000000000..8c0ca4420e --- /dev/null +++ b/core/src/index/thirdparty/faiss/c_api/index_io_c.cpp @@ -0,0 +1,42 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// Copyright 2004-present Facebook. All Rights Reserved +// -*- c++ -*- +// I/O code for indexes + +#include "index_io_c.h" +#include "index_io.h" +#include "macros_impl.h" + +using faiss::Index; + +int faiss_write_index(const FaissIndex *idx, FILE *f) { + try { + faiss::write_index(reinterpret_cast(idx), f); + } CATCH_AND_HANDLE +} + +int faiss_write_index_fname(const FaissIndex *idx, const char *fname) { + try { + faiss::write_index(reinterpret_cast(idx), fname); + } CATCH_AND_HANDLE +} + +int faiss_read_index(FILE *f, int io_flags, FaissIndex **p_out) { + try { + auto out = faiss::read_index(f, io_flags); + *p_out = reinterpret_cast(out); + } CATCH_AND_HANDLE +} + +int faiss_read_index_fname(const char *fname, int io_flags, FaissIndex **p_out) { + try { + auto out = faiss::read_index(fname, io_flags); + *p_out = reinterpret_cast(out); + } CATCH_AND_HANDLE +} \ No newline at end of file diff --git a/core/src/index/thirdparty/faiss/c_api/index_io_c.h b/core/src/index/thirdparty/faiss/c_api/index_io_c.h new file mode 100644 index 0000000000..f703e491ca --- /dev/null +++ b/core/src/index/thirdparty/faiss/c_api/index_io_c.h @@ -0,0 +1,50 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// Copyright 2004-present Facebook. All Rights Reserved +// -*- c++ -*- +// I/O code for indexes + + +#ifndef FAISS_INDEX_IO_C_H +#define FAISS_INDEX_IO_C_H + +#include +#include "faiss_c.h" +#include "Index_c.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/** Write index to a file. + * This is equivalent to `faiss::write_index` when a file descriptor is provided. + */ +int faiss_write_index(const FaissIndex *idx, FILE *f); + +/** Write index to a file. + * This is equivalent to `faiss::write_index` when a file path is provided. + */ +int faiss_write_index_fname(const FaissIndex *idx, const char *fname); + +#define FAISS_IO_FLAG_MMAP 1 +#define FAISS_IO_FLAG_READ_ONLY 2 + +/** Read index from a file. + * This is equivalent to `faiss:read_index` when a file descriptor is given. + */ +int faiss_read_index(FILE *f, int io_flags, FaissIndex **p_out); + +/** Read index from a file. + * This is equivalent to `faiss:read_index` when a file path is given. + */ +int faiss_read_index_fname(const char *fname, int io_flags, FaissIndex **p_out); + +#ifdef __cplusplus +} +#endif +#endif \ No newline at end of file diff --git a/core/src/index/thirdparty/faiss/c_api/macros_impl.h b/core/src/index/thirdparty/faiss/c_api/macros_impl.h new file mode 100644 index 0000000000..af07938018 --- /dev/null +++ b/core/src/index/thirdparty/faiss/c_api/macros_impl.h @@ -0,0 +1,110 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// Copyright 2004-present Facebook. All Rights Reserved. +// -*- c++ -*- + +/// Utility macros for the C wrapper implementation. + +#ifndef MACROS_IMPL_H +#define MACROS_IMPL_H + +#include "faiss_c.h" +#include "FaissException.h" +#include "error_impl.h" +#include +#include + +#ifdef NDEBUG +#define CATCH_AND_HANDLE \ + catch (faiss::FaissException& e) { \ + faiss_last_exception = \ + std::make_exception_ptr(e); \ + return -2; \ + } catch (std::exception& e) { \ + faiss_last_exception = \ + std::make_exception_ptr(e); \ + return -4; \ + } catch (...) { \ + faiss_last_exception = \ + std::make_exception_ptr( \ + std::runtime_error("Unknown error")); \ + return -1; \ + } return 0; +#else +#define CATCH_AND_HANDLE \ + catch (faiss::FaissException& e) { \ + std::cerr << e.what() << '\n'; \ + faiss_last_exception = \ + std::make_exception_ptr(e); \ + return -2; \ + } catch (std::exception& e) { \ + std::cerr << e.what() << '\n'; \ + faiss_last_exception = \ + std::make_exception_ptr(e); \ + return -4; \ + } catch (...) { \ + std::cerr << "Unrecognized exception!\n"; \ + faiss_last_exception = \ + std::make_exception_ptr( \ + std::runtime_error("Unknown error")); \ + return -1; \ + } return 0; +#endif + +#define DEFINE_GETTER(clazz, ty, name) \ + ty faiss_ ## clazz ## _ ## name (const Faiss ## clazz *obj) { \ + return static_cast< ty >( \ + reinterpret_cast< const faiss::clazz *>(obj)-> name \ + ); \ + } + +#define DEFINE_GETTER_SUBCLASS(clazz, parent, ty, name) \ + ty faiss_ ## clazz ## _ ## name (const Faiss ## clazz *obj) { \ + return static_cast< ty >( \ + reinterpret_cast(obj)-> name \ + ); \ + } + +#define DEFINE_GETTER_PERMISSIVE(clazz, ty, name) \ + ty faiss_ ## clazz ## _ ## name (const Faiss ## clazz *obj) { \ + return ( ty ) ( \ + reinterpret_cast(obj)-> name \ + ); \ + } + +#define DEFINE_GETTER_SUBCLASS_PERMISSIVE(clazz, parent, ty, name) \ + ty faiss_ ## clazz ## _ ## name (const Faiss ## clazz *obj) { \ + return ( ty ) ( \ + reinterpret_cast(obj)-> name \ + ); \ + } + +#define DEFINE_SETTER(clazz, ty, name) \ + void faiss_ ## clazz ## _set_ ## name (Faiss ## clazz *obj, ty val) { \ + reinterpret_cast< faiss::clazz *>(obj)-> name = val; \ + } + +#define DEFINE_SETTER_STATIC(clazz, ty_to, ty_from, name) \ + void faiss_ ## clazz ## _set_ ## name (Faiss ## clazz *obj, ty_from val) { \ + reinterpret_cast< faiss::clazz *>(obj)-> name = \ + static_cast< ty_to >(val); \ + } + +#define DEFINE_DESTRUCTOR(clazz) \ + void faiss_ ## clazz ## _free (Faiss ## clazz *obj) { \ + delete reinterpret_cast(obj); \ + } + +#define DEFINE_INDEX_DOWNCAST(clazz) \ + Faiss ## clazz * faiss_ ## clazz ## _cast (FaissIndex* index) { \ + return reinterpret_cast( \ + dynamic_cast< faiss::clazz *>( \ + reinterpret_cast(index))); \ + } + +#endif diff --git a/core/src/index/thirdparty/faiss/clone_index.cpp b/core/src/index/thirdparty/faiss/clone_index.cpp new file mode 100644 index 0000000000..ca9809d284 --- /dev/null +++ b/core/src/index/thirdparty/faiss/clone_index.cpp @@ -0,0 +1,156 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +#include + +#include +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace faiss { + +/************************************************************* + * cloning functions + **************************************************************/ + + + +Index * clone_index (const Index *index) +{ + Cloner cl; + return cl.clone_Index (index); +} + +// assumes there is a copy constructor ready. Always try from most +// specific to most general. Most indexes don't have complicated +// structs, the default copy constructor often just works. +#define TRYCLONE(classname, obj) \ + if (const classname *clo = dynamic_cast(obj)) { \ + return new classname(*clo); \ + } else + +VectorTransform *Cloner::clone_VectorTransform (const VectorTransform *vt) +{ + TRYCLONE (RemapDimensionsTransform, vt) + TRYCLONE (OPQMatrix, vt) + TRYCLONE (PCAMatrix, vt) + TRYCLONE (ITQMatrix, vt) + TRYCLONE (RandomRotationMatrix, vt) + TRYCLONE (LinearTransform, vt) + { + FAISS_THROW_MSG("clone not supported for this type of VectorTransform"); + } + return nullptr; +} + +IndexIVF * Cloner::clone_IndexIVF (const IndexIVF *ivf) +{ + TRYCLONE (IndexIVFPQR, ivf) + TRYCLONE (IndexIVFPQ, ivf) + TRYCLONE (IndexIVFFlat, ivf) + TRYCLONE (IndexIVFScalarQuantizer, ivf) + TRYCLONE (IndexIVFSQHybrid, ivf) + { + FAISS_THROW_MSG("clone not supported for this type of IndexIVF"); + } + return nullptr; +} + +Index *Cloner::clone_Index (IndexComposition* index_composition) { + FAISS_THROW_MSG( "Not implemented"); +} + +Index *Cloner::clone_Index (const Index *index) +{ + TRYCLONE (IndexPQ, index) + TRYCLONE (IndexLSH, index) + TRYCLONE (IndexFlatL2, index) + TRYCLONE (IndexFlatIP, index) + TRYCLONE (IndexFlat, index) + TRYCLONE (IndexLattice, index) + TRYCLONE (IndexScalarQuantizer, index) + TRYCLONE (MultiIndexQuantizer, index) + if (const IndexIVF * ivf = dynamic_cast(index)) { + IndexIVF *res = clone_IndexIVF (ivf); + if (ivf->invlists == nullptr) { + res->invlists = nullptr; + } else if (auto *ails = dynamic_cast + (ivf->invlists)) { + res->invlists = new ArrayInvertedLists(*ails); + res->own_invlists = true; + } else if (auto *ails = dynamic_cast(ivf->invlists)) { + res->invlists = new ReadOnlyArrayInvertedLists(*ails); + res->own_invlists = true; + } else { + FAISS_THROW_MSG( "clone not supported for this type of inverted lists"); + } + res->own_fields = true; + res->quantizer = clone_Index (ivf->quantizer); + return res; + } else if (const IndexPreTransform * ipt = + dynamic_cast (index)) { + IndexPreTransform *res = new IndexPreTransform (); + res->d = ipt->d; + res->ntotal = ipt->ntotal; + res->is_trained = ipt->is_trained; + res->metric_type = ipt->metric_type; + res->metric_arg = ipt->metric_arg; + + + res->index = clone_Index (ipt->index); + for (int i = 0; i < ipt->chain.size(); i++) + res->chain.push_back (clone_VectorTransform (ipt->chain[i])); + res->own_fields = true; + return res; + } else if (const IndexIDMap *idmap = + dynamic_cast (index)) { + IndexIDMap *res = new IndexIDMap (*idmap); + res->own_fields = true; + res->index = clone_Index (idmap->index); + return res; + } else if (const IndexHNSW *ihnsw = + dynamic_cast (index)) { + IndexHNSW *res = new IndexHNSW (*ihnsw); + res->own_fields = true; + res->storage = clone_Index (ihnsw->storage); + return res; + } else if (const Index2Layer *i2l = + dynamic_cast (index)) { + Index2Layer *res = new Index2Layer (*i2l); + res->q1.own_fields = true; + res->q1.quantizer = clone_Index (i2l->q1.quantizer); + return res; + } else { + FAISS_THROW_MSG( "clone not supported for this type of Index"); + } + return nullptr; +} + + + +} // namespace faiss diff --git a/core/src/index/thirdparty/faiss/clone_index.h b/core/src/index/thirdparty/faiss/clone_index.h new file mode 100644 index 0000000000..45990c93f7 --- /dev/null +++ b/core/src/index/thirdparty/faiss/clone_index.h @@ -0,0 +1,48 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +// I/O code for indexes + +#pragma once + + + +namespace faiss { + +struct Index; +struct IndexIVF; +struct VectorTransform; + +namespace gpu { +struct GpuIndexFlat; +} + +/* cloning functions */ +Index *clone_index (const Index *); + +struct IndexComposition { + Index *index = nullptr; + gpu::GpuIndexFlat *quantizer = nullptr; + long mode = 0; // 0: all data, 1: copy quantizer, 2: copy data +}; + +/** Cloner class, useful to override classes with other cloning + * functions. The cloning function above just calls + * Cloner::clone_Index. */ +struct Cloner { + virtual VectorTransform *clone_VectorTransform (const VectorTransform *); + virtual Index *clone_Index (const Index *); + virtual Index *clone_Index (IndexComposition* index_composition); + virtual IndexIVF *clone_IndexIVF (const IndexIVF *); + virtual ~Cloner() {} +}; + + + +} // namespace faiss diff --git a/core/src/index/thirdparty/faiss/conda/Dockerfile b/core/src/index/thirdparty/faiss/conda/Dockerfile new file mode 100644 index 0000000000..9184e8fea3 --- /dev/null +++ b/core/src/index/thirdparty/faiss/conda/Dockerfile @@ -0,0 +1,33 @@ +FROM soumith/conda-cuda:latest + +COPY ./ faiss +WORKDIR /faiss/conda + +ENV FAISS_BUILD_VERSION 1.5.3 +ENV FAISS_BUILD_NUMBER 0 +RUN conda build faiss --no-anaconda-upload -c pytorch +RUN CUDA_ROOT=/usr/local/cuda-8.0 \ + CUDA_ARCH="-gencode=arch=compute_35,code=compute_35 \ + -gencode=arch=compute_52,code=compute_52 \ + -gencode=arch=compute_60,code=compute_60 \ + -gencode=arch=compute_61,code=compute_61" \ + conda build faiss-gpu --variants '{ "cudatoolkit": "8.0" }' \ + --no-anaconda-upload -c pytorch --no-test +RUN CUDA_ROOT=/usr/local/cuda-9.0 \ + CUDA_ARCH="-gencode=arch=compute_35,code=compute_35 \ + -gencode=arch=compute_52,code=compute_52 \ + -gencode=arch=compute_60,code=compute_60 \ + -gencode=arch=compute_61,code=compute_61 \ + -gencode=arch=compute_70,code=compute_70" \ + conda build faiss-gpu --variants '{ "cudatoolkit": "9.0" }' \ + --no-anaconda-upload -c pytorch --no-test +RUN CUDA_ROOT=/usr/local/cuda-10.0 \ + CUDA_ARCH="-gencode=arch=compute_35,code=compute_35 \ + -gencode=arch=compute_52,code=compute_52 \ + -gencode=arch=compute_60,code=compute_60 \ + -gencode=arch=compute_61,code=compute_61 \ + -gencode=arch=compute_70,code=compute_70 \ + -gencode=arch=compute_72,code=compute_72 \ + -gencode=arch=compute_75,code=compute_75" \ + conda build faiss-gpu --variants '{ "cudatoolkit": "10.0" }' \ + --no-anaconda-upload -c pytorch --no-test diff --git a/core/src/index/thirdparty/faiss/conda/conda_build_config.yaml b/core/src/index/thirdparty/faiss/conda/conda_build_config.yaml new file mode 100644 index 0000000000..e9f0a51d26 --- /dev/null +++ b/core/src/index/thirdparty/faiss/conda/conda_build_config.yaml @@ -0,0 +1,7 @@ +CONDA_BUILD_SYSROOT: + - /opt/MacOSX10.9.sdk # [osx] +python: + - 2.7 + - 3.5 + - 3.6 + - 3.7 diff --git a/core/src/index/thirdparty/faiss/conda/faiss-gpu/build.sh b/core/src/index/thirdparty/faiss/conda/faiss-gpu/build.sh new file mode 100644 index 0000000000..25326c90d9 --- /dev/null +++ b/core/src/index/thirdparty/faiss/conda/faiss-gpu/build.sh @@ -0,0 +1,16 @@ +# Build avx2 version +CXXFLAGS="-mavx2 -mf16c" ./configure --with-cuda=$CUDA_ROOT --with-cuda-arch="$CUDA_ARCH" +make -j $CPU_COUNT +make -C python _swigfaiss_avx2.so +make clean + +# Build vanilla version (no avx) +./configure --with-cuda=$CUDA_ROOT --with-cuda-arch="$CUDA_ARCH" +make -j $CPU_COUNT +make -C python _swigfaiss.so + +make -C python build + +cd python + +$PYTHON setup.py install --single-version-externally-managed --record=record.txt diff --git a/core/src/index/thirdparty/faiss/conda/faiss-gpu/conda_build_config.yaml b/core/src/index/thirdparty/faiss/conda/faiss-gpu/conda_build_config.yaml new file mode 100644 index 0000000000..da98e5d414 --- /dev/null +++ b/core/src/index/thirdparty/faiss/conda/faiss-gpu/conda_build_config.yaml @@ -0,0 +1,11 @@ +cxx_compiler_version: + - 5.4 +cudatoolkit: + - 8.0 + - 9.0 + - 9.2 + - 10.0 + - 10.1 +pin_run_as_build: + cudatoolkit: + max_pin: x.x diff --git a/core/src/index/thirdparty/faiss/conda/faiss-gpu/meta.yaml b/core/src/index/thirdparty/faiss/conda/faiss-gpu/meta.yaml new file mode 100644 index 0000000000..886531bafc --- /dev/null +++ b/core/src/index/thirdparty/faiss/conda/faiss-gpu/meta.yaml @@ -0,0 +1,41 @@ +package: + name: faiss-gpu + version: "{{ FAISS_BUILD_VERSION }}" + +source: + git_url: ../../ + +requirements: + build: + - {{ compiler('cxx') }} + - llvm-openmp # [osx] + - setuptools + - swig + + host: + - python {{ python }} + - intel-openmp # [osx] + - numpy 1.11.* + - mkl >=2018 + - cudatoolkit {{ cudatoolkit }} + + run: + - python {{ python }} + - intel-openmp # [osx] + - numpy >=1.11 + - mkl >=2018 + - blas=*=mkl + - {{ pin_compatible('cudatoolkit') }} + +build: + number: {{ FAISS_BUILD_NUMBER }} + script_env: + - CUDA_ROOT + - CUDA_ARCH + +about: + home: https://github.com/facebookresearch/faiss + license: MIT + license_family: MIT + license_file: LICENSE + summary: A library for efficient similarity search and clustering of dense vectors. diff --git a/core/src/index/thirdparty/faiss/conda/faiss-gpu/run_test.py b/core/src/index/thirdparty/faiss/conda/faiss-gpu/run_test.py new file mode 100644 index 0000000000..68e0bbc3e3 --- /dev/null +++ b/core/src/index/thirdparty/faiss/conda/faiss-gpu/run_test.py @@ -0,0 +1,16 @@ +import faiss +import numpy as np + +d = 128 +n = 100 + +rs = np.random.RandomState(1337) +x = rs.rand(n, d).astype(np.float32) + +index = faiss.IndexFlatL2(d) + +res = faiss.StandardGpuResources() +gpu_index = faiss.index_cpu_to_gpu(res, 0, index) +gpu_index.add(x) + +D, I = index.search(x, 10) diff --git a/core/src/index/thirdparty/faiss/conda/faiss/build.sh b/core/src/index/thirdparty/faiss/conda/faiss/build.sh new file mode 100644 index 0000000000..87ccb4cad0 --- /dev/null +++ b/core/src/index/thirdparty/faiss/conda/faiss/build.sh @@ -0,0 +1,16 @@ +# Build avx2 version +CXXFLAGS="-mavx2 -mf16c" ./configure --without-cuda +make -j $CPU_COUNT +make -C python _swigfaiss_avx2.so +make clean + +# Build vanilla version (no avx) +./configure --without-cuda +make -j $CPU_COUNT +make -C python _swigfaiss.so + +make -C python build + +cd python + +$PYTHON setup.py install --single-version-externally-managed --record=record.txt diff --git a/core/src/index/thirdparty/faiss/conda/faiss/meta.yaml b/core/src/index/thirdparty/faiss/conda/faiss/meta.yaml new file mode 100644 index 0000000000..e765cf388d --- /dev/null +++ b/core/src/index/thirdparty/faiss/conda/faiss/meta.yaml @@ -0,0 +1,36 @@ +package: + name: faiss-cpu + version: "{{ FAISS_BUILD_VERSION }}" + +source: + git_url: ../../ + +requirements: + build: + - {{ compiler('cxx') }} + - llvm-openmp # [osx] + - setuptools + - swig + + host: + - python {{ python }} + - intel-openmp # [osx] + - numpy 1.11.* + - mkl >=2018 + + run: + - python {{ python }} + - intel-openmp # [osx] + - numpy >=1.11 + - blas=*=mkl + - mkl >=2018 + +build: + number: {{ FAISS_BUILD_NUMBER }} + +about: + home: https://github.com/facebookresearch/faiss + license: MIT + license_family: MIT + license_file: LICENSE + summary: A library for efficient similarity search and clustering of dense vectors. diff --git a/core/src/index/thirdparty/faiss/conda/faiss/run_test.py b/core/src/index/thirdparty/faiss/conda/faiss/run_test.py new file mode 100644 index 0000000000..57e6d7d92c --- /dev/null +++ b/core/src/index/thirdparty/faiss/conda/faiss/run_test.py @@ -0,0 +1,14 @@ +import faiss +import numpy as np + +d = 128 +# NOTE: BLAS kicks in only when n > distance_compute_blas_threshold = 20 +n = 100 + +rs = np.random.RandomState(1337) +x = rs.rand(n, d).astype(np.float32) + +index = faiss.IndexFlatL2(d) +index.add(x) + +D, I = index.search(x, 10) diff --git a/core/src/index/thirdparty/faiss/configure b/core/src/index/thirdparty/faiss/configure new file mode 100755 index 0000000000..ed40daefd9 --- /dev/null +++ b/core/src/index/thirdparty/faiss/configure @@ -0,0 +1,7998 @@ +#! /bin/sh +# Guess values for system-dependent variables and create Makefiles. +# Generated by GNU Autoconf 2.69 for faiss 1.0. +# +# +# Copyright (C) 1992-1996, 1998-2012 Free Software Foundation, Inc. +# +# +# This configure script is free software; the Free Software Foundation +# gives unlimited permission to copy, distribute and modify it. +# +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +## -------------------- ## +## M4sh Initialization. ## +## -------------------- ## + +# Be more Bourne compatible +DUALCASE=1; export DUALCASE # for MKS sh +if test -n "${ZSH_VERSION+set}" && (emulate sh) >/dev/null 2>&1; then : + emulate sh + NULLCMD=: + # Pre-4.2 versions of Zsh do word splitting on ${1+"$@"}, which + # is contrary to our usage. Disable this feature. + alias -g '${1+"$@"}'='"$@"' + setopt NO_GLOB_SUBST +else + case `(set -o) 2>/dev/null` in #( + *posix*) : + set -o posix ;; #( + *) : + ;; +esac +fi + + +as_nl=' +' +export as_nl +# Printing a long string crashes Solaris 7 /usr/bin/printf. +as_echo='\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\' +as_echo=$as_echo$as_echo$as_echo$as_echo$as_echo +as_echo=$as_echo$as_echo$as_echo$as_echo$as_echo$as_echo +# Prefer a ksh shell builtin over an external printf program on Solaris, +# but without wasting forks for bash or zsh. +if test -z "$BASH_VERSION$ZSH_VERSION" \ + && (test "X`print -r -- $as_echo`" = "X$as_echo") 2>/dev/null; then + as_echo='print -r --' + as_echo_n='print -rn --' +elif (test "X`printf %s $as_echo`" = "X$as_echo") 2>/dev/null; then + as_echo='printf %s\n' + as_echo_n='printf %s' +else + if test "X`(/usr/ucb/echo -n -n $as_echo) 2>/dev/null`" = "X-n $as_echo"; then + as_echo_body='eval /usr/ucb/echo -n "$1$as_nl"' + as_echo_n='/usr/ucb/echo -n' + else + as_echo_body='eval expr "X$1" : "X\\(.*\\)"' + as_echo_n_body='eval + arg=$1; + case $arg in #( + *"$as_nl"*) + expr "X$arg" : "X\\(.*\\)$as_nl"; + arg=`expr "X$arg" : ".*$as_nl\\(.*\\)"`;; + esac; + expr "X$arg" : "X\\(.*\\)" | tr -d "$as_nl" + ' + export as_echo_n_body + as_echo_n='sh -c $as_echo_n_body as_echo' + fi + export as_echo_body + as_echo='sh -c $as_echo_body as_echo' +fi + +# The user is always right. +if test "${PATH_SEPARATOR+set}" != set; then + PATH_SEPARATOR=: + (PATH='/bin;/bin'; FPATH=$PATH; sh -c :) >/dev/null 2>&1 && { + (PATH='/bin:/bin'; FPATH=$PATH; sh -c :) >/dev/null 2>&1 || + PATH_SEPARATOR=';' + } +fi + + +# IFS +# We need space, tab and new line, in precisely that order. Quoting is +# there to prevent editors from complaining about space-tab. +# (If _AS_PATH_WALK were called with IFS unset, it would disable word +# splitting by setting IFS to empty value.) +IFS=" "" $as_nl" + +# Find who we are. Look in the path if we contain no directory separator. +as_myself= +case $0 in #(( + *[\\/]* ) as_myself=$0 ;; + *) as_save_IFS=$IFS; IFS=$PATH_SEPARATOR +for as_dir in $PATH +do + IFS=$as_save_IFS + test -z "$as_dir" && as_dir=. + test -r "$as_dir/$0" && as_myself=$as_dir/$0 && break + done +IFS=$as_save_IFS + + ;; +esac +# We did not find ourselves, most probably we were run as `sh COMMAND' +# in which case we are not to be found in the path. +if test "x$as_myself" = x; then + as_myself=$0 +fi +if test ! -f "$as_myself"; then + $as_echo "$as_myself: error: cannot find myself; rerun with an absolute file name" >&2 + exit 1 +fi + +# Unset variables that we do not need and which cause bugs (e.g. in +# pre-3.0 UWIN ksh). But do not cause bugs in bash 2.01; the "|| exit 1" +# suppresses any "Segmentation fault" message there. '((' could +# trigger a bug in pdksh 5.2.14. +for as_var in BASH_ENV ENV MAIL MAILPATH +do eval test x\${$as_var+set} = xset \ + && ( (unset $as_var) || exit 1) >/dev/null 2>&1 && unset $as_var || : +done +PS1='$ ' +PS2='> ' +PS4='+ ' + +# NLS nuisances. +LC_ALL=C +export LC_ALL +LANGUAGE=C +export LANGUAGE + +# CDPATH. +(unset CDPATH) >/dev/null 2>&1 && unset CDPATH + +# Use a proper internal environment variable to ensure we don't fall + # into an infinite loop, continuously re-executing ourselves. + if test x"${_as_can_reexec}" != xno && test "x$CONFIG_SHELL" != x; then + _as_can_reexec=no; export _as_can_reexec; + # We cannot yet assume a decent shell, so we have to provide a +# neutralization value for shells without unset; and this also +# works around shells that cannot unset nonexistent variables. +# Preserve -v and -x to the replacement shell. +BASH_ENV=/dev/null +ENV=/dev/null +(unset BASH_ENV) >/dev/null 2>&1 && unset BASH_ENV ENV +case $- in # (((( + *v*x* | *x*v* ) as_opts=-vx ;; + *v* ) as_opts=-v ;; + *x* ) as_opts=-x ;; + * ) as_opts= ;; +esac +exec $CONFIG_SHELL $as_opts "$as_myself" ${1+"$@"} +# Admittedly, this is quite paranoid, since all the known shells bail +# out after a failed `exec'. +$as_echo "$0: could not re-execute with $CONFIG_SHELL" >&2 +as_fn_exit 255 + fi + # We don't want this to propagate to other subprocesses. + { _as_can_reexec=; unset _as_can_reexec;} +if test "x$CONFIG_SHELL" = x; then + as_bourne_compatible="if test -n \"\${ZSH_VERSION+set}\" && (emulate sh) >/dev/null 2>&1; then : + emulate sh + NULLCMD=: + # Pre-4.2 versions of Zsh do word splitting on \${1+\"\$@\"}, which + # is contrary to our usage. Disable this feature. + alias -g '\${1+\"\$@\"}'='\"\$@\"' + setopt NO_GLOB_SUBST +else + case \`(set -o) 2>/dev/null\` in #( + *posix*) : + set -o posix ;; #( + *) : + ;; +esac +fi +" + as_required="as_fn_return () { (exit \$1); } +as_fn_success () { as_fn_return 0; } +as_fn_failure () { as_fn_return 1; } +as_fn_ret_success () { return 0; } +as_fn_ret_failure () { return 1; } + +exitcode=0 +as_fn_success || { exitcode=1; echo as_fn_success failed.; } +as_fn_failure && { exitcode=1; echo as_fn_failure succeeded.; } +as_fn_ret_success || { exitcode=1; echo as_fn_ret_success failed.; } +as_fn_ret_failure && { exitcode=1; echo as_fn_ret_failure succeeded.; } +if ( set x; as_fn_ret_success y && test x = \"\$1\" ); then : + +else + exitcode=1; echo positional parameters were not saved. +fi +test x\$exitcode = x0 || exit 1 +test -x / || exit 1" + as_suggested=" as_lineno_1=";as_suggested=$as_suggested$LINENO;as_suggested=$as_suggested" as_lineno_1a=\$LINENO + as_lineno_2=";as_suggested=$as_suggested$LINENO;as_suggested=$as_suggested" as_lineno_2a=\$LINENO + eval 'test \"x\$as_lineno_1'\$as_run'\" != \"x\$as_lineno_2'\$as_run'\" && + test \"x\`expr \$as_lineno_1'\$as_run' + 1\`\" = \"x\$as_lineno_2'\$as_run'\"' || exit 1 +test \$(( 1 + 1 )) = 2 || exit 1" + if (eval "$as_required") 2>/dev/null; then : + as_have_required=yes +else + as_have_required=no +fi + if test x$as_have_required = xyes && (eval "$as_suggested") 2>/dev/null; then : + +else + as_save_IFS=$IFS; IFS=$PATH_SEPARATOR +as_found=false +for as_dir in /bin$PATH_SEPARATOR/usr/bin$PATH_SEPARATOR$PATH +do + IFS=$as_save_IFS + test -z "$as_dir" && as_dir=. + as_found=: + case $as_dir in #( + /*) + for as_base in sh bash ksh sh5; do + # Try only shells that exist, to save several forks. + as_shell=$as_dir/$as_base + if { test -f "$as_shell" || test -f "$as_shell.exe"; } && + { $as_echo "$as_bourne_compatible""$as_required" | as_run=a "$as_shell"; } 2>/dev/null; then : + CONFIG_SHELL=$as_shell as_have_required=yes + if { $as_echo "$as_bourne_compatible""$as_suggested" | as_run=a "$as_shell"; } 2>/dev/null; then : + break 2 +fi +fi + done;; + esac + as_found=false +done +$as_found || { if { test -f "$SHELL" || test -f "$SHELL.exe"; } && + { $as_echo "$as_bourne_compatible""$as_required" | as_run=a "$SHELL"; } 2>/dev/null; then : + CONFIG_SHELL=$SHELL as_have_required=yes +fi; } +IFS=$as_save_IFS + + + if test "x$CONFIG_SHELL" != x; then : + export CONFIG_SHELL + # We cannot yet assume a decent shell, so we have to provide a +# neutralization value for shells without unset; and this also +# works around shells that cannot unset nonexistent variables. +# Preserve -v and -x to the replacement shell. +BASH_ENV=/dev/null +ENV=/dev/null +(unset BASH_ENV) >/dev/null 2>&1 && unset BASH_ENV ENV +case $- in # (((( + *v*x* | *x*v* ) as_opts=-vx ;; + *v* ) as_opts=-v ;; + *x* ) as_opts=-x ;; + * ) as_opts= ;; +esac +exec $CONFIG_SHELL $as_opts "$as_myself" ${1+"$@"} +# Admittedly, this is quite paranoid, since all the known shells bail +# out after a failed `exec'. +$as_echo "$0: could not re-execute with $CONFIG_SHELL" >&2 +exit 255 +fi + + if test x$as_have_required = xno; then : + $as_echo "$0: This script requires a shell more modern than all" + $as_echo "$0: the shells that I found on your system." + if test x${ZSH_VERSION+set} = xset ; then + $as_echo "$0: In particular, zsh $ZSH_VERSION has bugs and should" + $as_echo "$0: be upgraded to zsh 4.3.4 or later." + else + $as_echo "$0: Please tell bug-autoconf@gnu.org about your system, +$0: including any error possibly output before this +$0: message. Then install a modern shell, or manually run +$0: the script under such a shell if you do have one." + fi + exit 1 +fi +fi +fi +SHELL=${CONFIG_SHELL-/bin/sh} +export SHELL +# Unset more variables known to interfere with behavior of common tools. +CLICOLOR_FORCE= GREP_OPTIONS= +unset CLICOLOR_FORCE GREP_OPTIONS + +## --------------------- ## +## M4sh Shell Functions. ## +## --------------------- ## +# as_fn_unset VAR +# --------------- +# Portably unset VAR. +as_fn_unset () +{ + { eval $1=; unset $1;} +} +as_unset=as_fn_unset + +# as_fn_set_status STATUS +# ----------------------- +# Set $? to STATUS, without forking. +as_fn_set_status () +{ + return $1 +} # as_fn_set_status + +# as_fn_exit STATUS +# ----------------- +# Exit the shell with STATUS, even in a "trap 0" or "set -e" context. +as_fn_exit () +{ + set +e + as_fn_set_status $1 + exit $1 +} # as_fn_exit + +# as_fn_mkdir_p +# ------------- +# Create "$as_dir" as a directory, including parents if necessary. +as_fn_mkdir_p () +{ + + case $as_dir in #( + -*) as_dir=./$as_dir;; + esac + test -d "$as_dir" || eval $as_mkdir_p || { + as_dirs= + while :; do + case $as_dir in #( + *\'*) as_qdir=`$as_echo "$as_dir" | sed "s/'/'\\\\\\\\''/g"`;; #'( + *) as_qdir=$as_dir;; + esac + as_dirs="'$as_qdir' $as_dirs" + as_dir=`$as_dirname -- "$as_dir" || +$as_expr X"$as_dir" : 'X\(.*[^/]\)//*[^/][^/]*/*$' \| \ + X"$as_dir" : 'X\(//\)[^/]' \| \ + X"$as_dir" : 'X\(//\)$' \| \ + X"$as_dir" : 'X\(/\)' \| . 2>/dev/null || +$as_echo X"$as_dir" | + sed '/^X\(.*[^/]\)\/\/*[^/][^/]*\/*$/{ + s//\1/ + q + } + /^X\(\/\/\)[^/].*/{ + s//\1/ + q + } + /^X\(\/\/\)$/{ + s//\1/ + q + } + /^X\(\/\).*/{ + s//\1/ + q + } + s/.*/./; q'` + test -d "$as_dir" && break + done + test -z "$as_dirs" || eval "mkdir $as_dirs" + } || test -d "$as_dir" || as_fn_error $? "cannot create directory $as_dir" + + +} # as_fn_mkdir_p + +# as_fn_executable_p FILE +# ----------------------- +# Test if FILE is an executable regular file. +as_fn_executable_p () +{ + test -f "$1" && test -x "$1" +} # as_fn_executable_p +# as_fn_append VAR VALUE +# ---------------------- +# Append the text in VALUE to the end of the definition contained in VAR. Take +# advantage of any shell optimizations that allow amortized linear growth over +# repeated appends, instead of the typical quadratic growth present in naive +# implementations. +if (eval "as_var=1; as_var+=2; test x\$as_var = x12") 2>/dev/null; then : + eval 'as_fn_append () + { + eval $1+=\$2 + }' +else + as_fn_append () + { + eval $1=\$$1\$2 + } +fi # as_fn_append + +# as_fn_arith ARG... +# ------------------ +# Perform arithmetic evaluation on the ARGs, and store the result in the +# global $as_val. Take advantage of shells that can avoid forks. The arguments +# must be portable across $(()) and expr. +if (eval "test \$(( 1 + 1 )) = 2") 2>/dev/null; then : + eval 'as_fn_arith () + { + as_val=$(( $* )) + }' +else + as_fn_arith () + { + as_val=`expr "$@" || test $? -eq 1` + } +fi # as_fn_arith + + +# as_fn_error STATUS ERROR [LINENO LOG_FD] +# ---------------------------------------- +# Output "`basename $0`: error: ERROR" to stderr. If LINENO and LOG_FD are +# provided, also output the error to LOG_FD, referencing LINENO. Then exit the +# script with STATUS, using 1 if that was 0. +as_fn_error () +{ + as_status=$1; test $as_status -eq 0 && as_status=1 + if test "$4"; then + as_lineno=${as_lineno-"$3"} as_lineno_stack=as_lineno_stack=$as_lineno_stack + $as_echo "$as_me:${as_lineno-$LINENO}: error: $2" >&$4 + fi + $as_echo "$as_me: error: $2" >&2 + as_fn_exit $as_status +} # as_fn_error + +if expr a : '\(a\)' >/dev/null 2>&1 && + test "X`expr 00001 : '.*\(...\)'`" = X001; then + as_expr=expr +else + as_expr=false +fi + +if (basename -- /) >/dev/null 2>&1 && test "X`basename -- / 2>&1`" = "X/"; then + as_basename=basename +else + as_basename=false +fi + +if (as_dir=`dirname -- /` && test "X$as_dir" = X/) >/dev/null 2>&1; then + as_dirname=dirname +else + as_dirname=false +fi + +as_me=`$as_basename -- "$0" || +$as_expr X/"$0" : '.*/\([^/][^/]*\)/*$' \| \ + X"$0" : 'X\(//\)$' \| \ + X"$0" : 'X\(/\)' \| . 2>/dev/null || +$as_echo X/"$0" | + sed '/^.*\/\([^/][^/]*\)\/*$/{ + s//\1/ + q + } + /^X\/\(\/\/\)$/{ + s//\1/ + q + } + /^X\/\(\/\).*/{ + s//\1/ + q + } + s/.*/./; q'` + +# Avoid depending upon Character Ranges. +as_cr_letters='abcdefghijklmnopqrstuvwxyz' +as_cr_LETTERS='ABCDEFGHIJKLMNOPQRSTUVWXYZ' +as_cr_Letters=$as_cr_letters$as_cr_LETTERS +as_cr_digits='0123456789' +as_cr_alnum=$as_cr_Letters$as_cr_digits + + + as_lineno_1=$LINENO as_lineno_1a=$LINENO + as_lineno_2=$LINENO as_lineno_2a=$LINENO + eval 'test "x$as_lineno_1'$as_run'" != "x$as_lineno_2'$as_run'" && + test "x`expr $as_lineno_1'$as_run' + 1`" = "x$as_lineno_2'$as_run'"' || { + # Blame Lee E. McMahon (1931-1989) for sed's syntax. :-) + sed -n ' + p + /[$]LINENO/= + ' <$as_myself | + sed ' + s/[$]LINENO.*/&-/ + t lineno + b + :lineno + N + :loop + s/[$]LINENO\([^'$as_cr_alnum'_].*\n\)\(.*\)/\2\1\2/ + t loop + s/-\n.*// + ' >$as_me.lineno && + chmod +x "$as_me.lineno" || + { $as_echo "$as_me: error: cannot create $as_me.lineno; rerun with a POSIX shell" >&2; as_fn_exit 1; } + + # If we had to re-execute with $CONFIG_SHELL, we're ensured to have + # already done that, so ensure we don't try to do so again and fall + # in an infinite loop. This has already happened in practice. + _as_can_reexec=no; export _as_can_reexec + # Don't try to exec as it changes $[0], causing all sort of problems + # (the dirname of $[0] is not the place where we might find the + # original and so on. Autoconf is especially sensitive to this). + . "./$as_me.lineno" + # Exit status is that of the last command. + exit +} + +ECHO_C= ECHO_N= ECHO_T= +case `echo -n x` in #((((( +-n*) + case `echo 'xy\c'` in + *c*) ECHO_T=' ';; # ECHO_T is single tab character. + xy) ECHO_C='\c';; + *) echo `echo ksh88 bug on AIX 6.1` > /dev/null + ECHO_T=' ';; + esac;; +*) + ECHO_N='-n';; +esac + +rm -f conf$$ conf$$.exe conf$$.file +if test -d conf$$.dir; then + rm -f conf$$.dir/conf$$.file +else + rm -f conf$$.dir + mkdir conf$$.dir 2>/dev/null +fi +if (echo >conf$$.file) 2>/dev/null; then + if ln -s conf$$.file conf$$ 2>/dev/null; then + as_ln_s='ln -s' + # ... but there are two gotchas: + # 1) On MSYS, both `ln -s file dir' and `ln file dir' fail. + # 2) DJGPP < 2.04 has no symlinks; `ln -s' creates a wrapper executable. + # In both cases, we have to default to `cp -pR'. + ln -s conf$$.file conf$$.dir 2>/dev/null && test ! -f conf$$.exe || + as_ln_s='cp -pR' + elif ln conf$$.file conf$$ 2>/dev/null; then + as_ln_s=ln + else + as_ln_s='cp -pR' + fi +else + as_ln_s='cp -pR' +fi +rm -f conf$$ conf$$.exe conf$$.dir/conf$$.file conf$$.file +rmdir conf$$.dir 2>/dev/null + +if mkdir -p . 2>/dev/null; then + as_mkdir_p='mkdir -p "$as_dir"' +else + test -d ./-p && rmdir ./-p + as_mkdir_p=false +fi + +as_test_x='test -x' +as_executable_p=as_fn_executable_p + +# Sed expression to map a string onto a valid CPP name. +as_tr_cpp="eval sed 'y%*$as_cr_letters%P$as_cr_LETTERS%;s%[^_$as_cr_alnum]%_%g'" + +# Sed expression to map a string onto a valid variable name. +as_tr_sh="eval sed 'y%*+%pp%;s%[^_$as_cr_alnum]%_%g'" + + +test -n "$DJDIR" || exec 7<&0 &1 + +# Name of the host. +# hostname on some systems (SVR3.2, old GNU/Linux) returns a bogus exit status, +# so uname gets run too. +ac_hostname=`(hostname || uname -n) 2>/dev/null | sed 1q` + +# +# Initializations. +# +ac_default_prefix=/usr/local +ac_clean_files= +ac_config_libobj_dir=. +LIBOBJS= +cross_compiling=no +subdirs= +MFLAGS= +MAKEFLAGS= + +# Identity of this package. +PACKAGE_NAME='faiss' +PACKAGE_TARNAME='faiss' +PACKAGE_VERSION='1.0' +PACKAGE_STRING='faiss 1.0' +PACKAGE_BUGREPORT='' +PACKAGE_URL='' + +ac_unique_file="Index.h" +# Factoring default headers for most tests. +ac_includes_default="\ +#include +#ifdef HAVE_SYS_TYPES_H +# include +#endif +#ifdef HAVE_SYS_STAT_H +# include +#endif +#ifdef STDC_HEADERS +# include +# include +#else +# ifdef HAVE_STDLIB_H +# include +# endif +#endif +#ifdef HAVE_STRING_H +# if !defined STDC_HEADERS && defined HAVE_MEMORY_H +# include +# endif +# include +#endif +#ifdef HAVE_STRINGS_H +# include +#endif +#ifdef HAVE_INTTYPES_H +# include +#endif +#ifdef HAVE_STDINT_H +# include +#endif +#ifdef HAVE_UNISTD_H +# include +#endif" + +ac_header_list= +ac_subst_vars='LTLIBOBJS +ARCH_CXXFLAGS +ARCH_CPUFLAGS +target_os +target_vendor +target_cpu +target +LAPACK_LIBS +OPENMP_LDFLAGS +BLAS_LIBS +host_os +host_vendor +host_cpu +host +build_os +build_vendor +build_cpu +build +OPENMP_CXXFLAGS +LIBOBJS +CUDA_ARCH +CUDA_PREFIX +NVCC_LIBS +NVCC_LDFLAGS +NVCC_CPPFLAGS +EGREP +GREP +CXXCPP +NVCC +SWIG +NUMPY_INCLUDE +PYTHON_CFLAGS +PYTHON +MKDIR_P +SET_MAKE +CPP +ac_ct_CC +CFLAGS +CC +HAVE_CXX11 +OBJEXT +EXEEXT +ac_ct_CXX +CPPFLAGS +LDFLAGS +CXXFLAGS +CXX +target_alias +host_alias +build_alias +LIBS +ECHO_T +ECHO_N +ECHO_C +DEFS +mandir +localedir +libdir +psdir +pdfdir +dvidir +htmldir +infodir +docdir +oldincludedir +includedir +localstatedir +sharedstatedir +sysconfdir +datadir +datarootdir +libexecdir +sbindir +bindir +program_transform_name +prefix +exec_prefix +PACKAGE_URL +PACKAGE_BUGREPORT +PACKAGE_STRING +PACKAGE_VERSION +PACKAGE_TARNAME +PACKAGE_NAME +PATH_SEPARATOR +SHELL' +ac_subst_files='' +ac_user_opts=' +enable_option_checking +with_python +with_swig +with_cuda +with_cuda_arch +enable_openmp +with_blas +with_lapack +' + ac_precious_vars='build_alias +host_alias +target_alias +CXX +CXXFLAGS +LDFLAGS +LIBS +CPPFLAGS +CCC +CC +CFLAGS +CPP +CXXCPP' + + +# Initialize some variables set by options. +ac_init_help= +ac_init_version=false +ac_unrecognized_opts= +ac_unrecognized_sep= +# The variables have the same names as the options, with +# dashes changed to underlines. +cache_file=/dev/null +exec_prefix=NONE +no_create= +no_recursion= +prefix=NONE +program_prefix=NONE +program_suffix=NONE +program_transform_name=s,x,x, +silent= +site= +srcdir= +verbose= +x_includes=NONE +x_libraries=NONE + +# Installation directory options. +# These are left unexpanded so users can "make install exec_prefix=/foo" +# and all the variables that are supposed to be based on exec_prefix +# by default will actually change. +# Use braces instead of parens because sh, perl, etc. also accept them. +# (The list follows the same order as the GNU Coding Standards.) +bindir='${exec_prefix}/bin' +sbindir='${exec_prefix}/sbin' +libexecdir='${exec_prefix}/libexec' +datarootdir='${prefix}/share' +datadir='${datarootdir}' +sysconfdir='${prefix}/etc' +sharedstatedir='${prefix}/com' +localstatedir='${prefix}/var' +includedir='${prefix}/include' +oldincludedir='/usr/include' +docdir='${datarootdir}/doc/${PACKAGE_TARNAME}' +infodir='${datarootdir}/info' +htmldir='${docdir}' +dvidir='${docdir}' +pdfdir='${docdir}' +psdir='${docdir}' +libdir='${exec_prefix}/lib' +localedir='${datarootdir}/locale' +mandir='${datarootdir}/man' + +ac_prev= +ac_dashdash= +for ac_option +do + # If the previous option needs an argument, assign it. + if test -n "$ac_prev"; then + eval $ac_prev=\$ac_option + ac_prev= + continue + fi + + case $ac_option in + *=?*) ac_optarg=`expr "X$ac_option" : '[^=]*=\(.*\)'` ;; + *=) ac_optarg= ;; + *) ac_optarg=yes ;; + esac + + # Accept the important Cygnus configure options, so we can diagnose typos. + + case $ac_dashdash$ac_option in + --) + ac_dashdash=yes ;; + + -bindir | --bindir | --bindi | --bind | --bin | --bi) + ac_prev=bindir ;; + -bindir=* | --bindir=* | --bindi=* | --bind=* | --bin=* | --bi=*) + bindir=$ac_optarg ;; + + -build | --build | --buil | --bui | --bu) + ac_prev=build_alias ;; + -build=* | --build=* | --buil=* | --bui=* | --bu=*) + build_alias=$ac_optarg ;; + + -cache-file | --cache-file | --cache-fil | --cache-fi \ + | --cache-f | --cache- | --cache | --cach | --cac | --ca | --c) + ac_prev=cache_file ;; + -cache-file=* | --cache-file=* | --cache-fil=* | --cache-fi=* \ + | --cache-f=* | --cache-=* | --cache=* | --cach=* | --cac=* | --ca=* | --c=*) + cache_file=$ac_optarg ;; + + --config-cache | -C) + cache_file=config.cache ;; + + -datadir | --datadir | --datadi | --datad) + ac_prev=datadir ;; + -datadir=* | --datadir=* | --datadi=* | --datad=*) + datadir=$ac_optarg ;; + + -datarootdir | --datarootdir | --datarootdi | --datarootd | --dataroot \ + | --dataroo | --dataro | --datar) + ac_prev=datarootdir ;; + -datarootdir=* | --datarootdir=* | --datarootdi=* | --datarootd=* \ + | --dataroot=* | --dataroo=* | --dataro=* | --datar=*) + datarootdir=$ac_optarg ;; + + -disable-* | --disable-*) + ac_useropt=`expr "x$ac_option" : 'x-*disable-\(.*\)'` + # Reject names that are not valid shell variable names. + expr "x$ac_useropt" : ".*[^-+._$as_cr_alnum]" >/dev/null && + as_fn_error $? "invalid feature name: $ac_useropt" + ac_useropt_orig=$ac_useropt + ac_useropt=`$as_echo "$ac_useropt" | sed 's/[-+.]/_/g'` + case $ac_user_opts in + *" +"enable_$ac_useropt" +"*) ;; + *) ac_unrecognized_opts="$ac_unrecognized_opts$ac_unrecognized_sep--disable-$ac_useropt_orig" + ac_unrecognized_sep=', ';; + esac + eval enable_$ac_useropt=no ;; + + -docdir | --docdir | --docdi | --doc | --do) + ac_prev=docdir ;; + -docdir=* | --docdir=* | --docdi=* | --doc=* | --do=*) + docdir=$ac_optarg ;; + + -dvidir | --dvidir | --dvidi | --dvid | --dvi | --dv) + ac_prev=dvidir ;; + -dvidir=* | --dvidir=* | --dvidi=* | --dvid=* | --dvi=* | --dv=*) + dvidir=$ac_optarg ;; + + -enable-* | --enable-*) + ac_useropt=`expr "x$ac_option" : 'x-*enable-\([^=]*\)'` + # Reject names that are not valid shell variable names. + expr "x$ac_useropt" : ".*[^-+._$as_cr_alnum]" >/dev/null && + as_fn_error $? "invalid feature name: $ac_useropt" + ac_useropt_orig=$ac_useropt + ac_useropt=`$as_echo "$ac_useropt" | sed 's/[-+.]/_/g'` + case $ac_user_opts in + *" +"enable_$ac_useropt" +"*) ;; + *) ac_unrecognized_opts="$ac_unrecognized_opts$ac_unrecognized_sep--enable-$ac_useropt_orig" + ac_unrecognized_sep=', ';; + esac + eval enable_$ac_useropt=\$ac_optarg ;; + + -exec-prefix | --exec_prefix | --exec-prefix | --exec-prefi \ + | --exec-pref | --exec-pre | --exec-pr | --exec-p | --exec- \ + | --exec | --exe | --ex) + ac_prev=exec_prefix ;; + -exec-prefix=* | --exec_prefix=* | --exec-prefix=* | --exec-prefi=* \ + | --exec-pref=* | --exec-pre=* | --exec-pr=* | --exec-p=* | --exec-=* \ + | --exec=* | --exe=* | --ex=*) + exec_prefix=$ac_optarg ;; + + -gas | --gas | --ga | --g) + # Obsolete; use --with-gas. + with_gas=yes ;; + + -help | --help | --hel | --he | -h) + ac_init_help=long ;; + -help=r* | --help=r* | --hel=r* | --he=r* | -hr*) + ac_init_help=recursive ;; + -help=s* | --help=s* | --hel=s* | --he=s* | -hs*) + ac_init_help=short ;; + + -host | --host | --hos | --ho) + ac_prev=host_alias ;; + -host=* | --host=* | --hos=* | --ho=*) + host_alias=$ac_optarg ;; + + -htmldir | --htmldir | --htmldi | --htmld | --html | --htm | --ht) + ac_prev=htmldir ;; + -htmldir=* | --htmldir=* | --htmldi=* | --htmld=* | --html=* | --htm=* \ + | --ht=*) + htmldir=$ac_optarg ;; + + -includedir | --includedir | --includedi | --included | --include \ + | --includ | --inclu | --incl | --inc) + ac_prev=includedir ;; + -includedir=* | --includedir=* | --includedi=* | --included=* | --include=* \ + | --includ=* | --inclu=* | --incl=* | --inc=*) + includedir=$ac_optarg ;; + + -infodir | --infodir | --infodi | --infod | --info | --inf) + ac_prev=infodir ;; + -infodir=* | --infodir=* | --infodi=* | --infod=* | --info=* | --inf=*) + infodir=$ac_optarg ;; + + -libdir | --libdir | --libdi | --libd) + ac_prev=libdir ;; + -libdir=* | --libdir=* | --libdi=* | --libd=*) + libdir=$ac_optarg ;; + + -libexecdir | --libexecdir | --libexecdi | --libexecd | --libexec \ + | --libexe | --libex | --libe) + ac_prev=libexecdir ;; + -libexecdir=* | --libexecdir=* | --libexecdi=* | --libexecd=* | --libexec=* \ + | --libexe=* | --libex=* | --libe=*) + libexecdir=$ac_optarg ;; + + -localedir | --localedir | --localedi | --localed | --locale) + ac_prev=localedir ;; + -localedir=* | --localedir=* | --localedi=* | --localed=* | --locale=*) + localedir=$ac_optarg ;; + + -localstatedir | --localstatedir | --localstatedi | --localstated \ + | --localstate | --localstat | --localsta | --localst | --locals) + ac_prev=localstatedir ;; + -localstatedir=* | --localstatedir=* | --localstatedi=* | --localstated=* \ + | --localstate=* | --localstat=* | --localsta=* | --localst=* | --locals=*) + localstatedir=$ac_optarg ;; + + -mandir | --mandir | --mandi | --mand | --man | --ma | --m) + ac_prev=mandir ;; + -mandir=* | --mandir=* | --mandi=* | --mand=* | --man=* | --ma=* | --m=*) + mandir=$ac_optarg ;; + + -nfp | --nfp | --nf) + # Obsolete; use --without-fp. + with_fp=no ;; + + -no-create | --no-create | --no-creat | --no-crea | --no-cre \ + | --no-cr | --no-c | -n) + no_create=yes ;; + + -no-recursion | --no-recursion | --no-recursio | --no-recursi \ + | --no-recurs | --no-recur | --no-recu | --no-rec | --no-re | --no-r) + no_recursion=yes ;; + + -oldincludedir | --oldincludedir | --oldincludedi | --oldincluded \ + | --oldinclude | --oldinclud | --oldinclu | --oldincl | --oldinc \ + | --oldin | --oldi | --old | --ol | --o) + ac_prev=oldincludedir ;; + -oldincludedir=* | --oldincludedir=* | --oldincludedi=* | --oldincluded=* \ + | --oldinclude=* | --oldinclud=* | --oldinclu=* | --oldincl=* | --oldinc=* \ + | --oldin=* | --oldi=* | --old=* | --ol=* | --o=*) + oldincludedir=$ac_optarg ;; + + -prefix | --prefix | --prefi | --pref | --pre | --pr | --p) + ac_prev=prefix ;; + -prefix=* | --prefix=* | --prefi=* | --pref=* | --pre=* | --pr=* | --p=*) + prefix=$ac_optarg ;; + + -program-prefix | --program-prefix | --program-prefi | --program-pref \ + | --program-pre | --program-pr | --program-p) + ac_prev=program_prefix ;; + -program-prefix=* | --program-prefix=* | --program-prefi=* \ + | --program-pref=* | --program-pre=* | --program-pr=* | --program-p=*) + program_prefix=$ac_optarg ;; + + -program-suffix | --program-suffix | --program-suffi | --program-suff \ + | --program-suf | --program-su | --program-s) + ac_prev=program_suffix ;; + -program-suffix=* | --program-suffix=* | --program-suffi=* \ + | --program-suff=* | --program-suf=* | --program-su=* | --program-s=*) + program_suffix=$ac_optarg ;; + + -program-transform-name | --program-transform-name \ + | --program-transform-nam | --program-transform-na \ + | --program-transform-n | --program-transform- \ + | --program-transform | --program-transfor \ + | --program-transfo | --program-transf \ + | --program-trans | --program-tran \ + | --progr-tra | --program-tr | --program-t) + ac_prev=program_transform_name ;; + -program-transform-name=* | --program-transform-name=* \ + | --program-transform-nam=* | --program-transform-na=* \ + | --program-transform-n=* | --program-transform-=* \ + | --program-transform=* | --program-transfor=* \ + | --program-transfo=* | --program-transf=* \ + | --program-trans=* | --program-tran=* \ + | --progr-tra=* | --program-tr=* | --program-t=*) + program_transform_name=$ac_optarg ;; + + -pdfdir | --pdfdir | --pdfdi | --pdfd | --pdf | --pd) + ac_prev=pdfdir ;; + -pdfdir=* | --pdfdir=* | --pdfdi=* | --pdfd=* | --pdf=* | --pd=*) + pdfdir=$ac_optarg ;; + + -psdir | --psdir | --psdi | --psd | --ps) + ac_prev=psdir ;; + -psdir=* | --psdir=* | --psdi=* | --psd=* | --ps=*) + psdir=$ac_optarg ;; + + -q | -quiet | --quiet | --quie | --qui | --qu | --q \ + | -silent | --silent | --silen | --sile | --sil) + silent=yes ;; + + -sbindir | --sbindir | --sbindi | --sbind | --sbin | --sbi | --sb) + ac_prev=sbindir ;; + -sbindir=* | --sbindir=* | --sbindi=* | --sbind=* | --sbin=* \ + | --sbi=* | --sb=*) + sbindir=$ac_optarg ;; + + -sharedstatedir | --sharedstatedir | --sharedstatedi \ + | --sharedstated | --sharedstate | --sharedstat | --sharedsta \ + | --sharedst | --shareds | --shared | --share | --shar \ + | --sha | --sh) + ac_prev=sharedstatedir ;; + -sharedstatedir=* | --sharedstatedir=* | --sharedstatedi=* \ + | --sharedstated=* | --sharedstate=* | --sharedstat=* | --sharedsta=* \ + | --sharedst=* | --shareds=* | --shared=* | --share=* | --shar=* \ + | --sha=* | --sh=*) + sharedstatedir=$ac_optarg ;; + + -site | --site | --sit) + ac_prev=site ;; + -site=* | --site=* | --sit=*) + site=$ac_optarg ;; + + -srcdir | --srcdir | --srcdi | --srcd | --src | --sr) + ac_prev=srcdir ;; + -srcdir=* | --srcdir=* | --srcdi=* | --srcd=* | --src=* | --sr=*) + srcdir=$ac_optarg ;; + + -sysconfdir | --sysconfdir | --sysconfdi | --sysconfd | --sysconf \ + | --syscon | --sysco | --sysc | --sys | --sy) + ac_prev=sysconfdir ;; + -sysconfdir=* | --sysconfdir=* | --sysconfdi=* | --sysconfd=* | --sysconf=* \ + | --syscon=* | --sysco=* | --sysc=* | --sys=* | --sy=*) + sysconfdir=$ac_optarg ;; + + -target | --target | --targe | --targ | --tar | --ta | --t) + ac_prev=target_alias ;; + -target=* | --target=* | --targe=* | --targ=* | --tar=* | --ta=* | --t=*) + target_alias=$ac_optarg ;; + + -v | -verbose | --verbose | --verbos | --verbo | --verb) + verbose=yes ;; + + -version | --version | --versio | --versi | --vers | -V) + ac_init_version=: ;; + + -with-* | --with-*) + ac_useropt=`expr "x$ac_option" : 'x-*with-\([^=]*\)'` + # Reject names that are not valid shell variable names. + expr "x$ac_useropt" : ".*[^-+._$as_cr_alnum]" >/dev/null && + as_fn_error $? "invalid package name: $ac_useropt" + ac_useropt_orig=$ac_useropt + ac_useropt=`$as_echo "$ac_useropt" | sed 's/[-+.]/_/g'` + case $ac_user_opts in + *" +"with_$ac_useropt" +"*) ;; + *) ac_unrecognized_opts="$ac_unrecognized_opts$ac_unrecognized_sep--with-$ac_useropt_orig" + ac_unrecognized_sep=', ';; + esac + eval with_$ac_useropt=\$ac_optarg ;; + + -without-* | --without-*) + ac_useropt=`expr "x$ac_option" : 'x-*without-\(.*\)'` + # Reject names that are not valid shell variable names. + expr "x$ac_useropt" : ".*[^-+._$as_cr_alnum]" >/dev/null && + as_fn_error $? "invalid package name: $ac_useropt" + ac_useropt_orig=$ac_useropt + ac_useropt=`$as_echo "$ac_useropt" | sed 's/[-+.]/_/g'` + case $ac_user_opts in + *" +"with_$ac_useropt" +"*) ;; + *) ac_unrecognized_opts="$ac_unrecognized_opts$ac_unrecognized_sep--without-$ac_useropt_orig" + ac_unrecognized_sep=', ';; + esac + eval with_$ac_useropt=no ;; + + --x) + # Obsolete; use --with-x. + with_x=yes ;; + + -x-includes | --x-includes | --x-include | --x-includ | --x-inclu \ + | --x-incl | --x-inc | --x-in | --x-i) + ac_prev=x_includes ;; + -x-includes=* | --x-includes=* | --x-include=* | --x-includ=* | --x-inclu=* \ + | --x-incl=* | --x-inc=* | --x-in=* | --x-i=*) + x_includes=$ac_optarg ;; + + -x-libraries | --x-libraries | --x-librarie | --x-librari \ + | --x-librar | --x-libra | --x-libr | --x-lib | --x-li | --x-l) + ac_prev=x_libraries ;; + -x-libraries=* | --x-libraries=* | --x-librarie=* | --x-librari=* \ + | --x-librar=* | --x-libra=* | --x-libr=* | --x-lib=* | --x-li=* | --x-l=*) + x_libraries=$ac_optarg ;; + + -*) as_fn_error $? "unrecognized option: \`$ac_option' +Try \`$0 --help' for more information" + ;; + + *=*) + ac_envvar=`expr "x$ac_option" : 'x\([^=]*\)='` + # Reject names that are not valid shell variable names. + case $ac_envvar in #( + '' | [0-9]* | *[!_$as_cr_alnum]* ) + as_fn_error $? "invalid variable name: \`$ac_envvar'" ;; + esac + eval $ac_envvar=\$ac_optarg + export $ac_envvar ;; + + *) + # FIXME: should be removed in autoconf 3.0. + $as_echo "$as_me: WARNING: you should use --build, --host, --target" >&2 + expr "x$ac_option" : ".*[^-._$as_cr_alnum]" >/dev/null && + $as_echo "$as_me: WARNING: invalid host type: $ac_option" >&2 + : "${build_alias=$ac_option} ${host_alias=$ac_option} ${target_alias=$ac_option}" + ;; + + esac +done + +if test -n "$ac_prev"; then + ac_option=--`echo $ac_prev | sed 's/_/-/g'` + as_fn_error $? "missing argument to $ac_option" +fi + +if test -n "$ac_unrecognized_opts"; then + case $enable_option_checking in + no) ;; + fatal) as_fn_error $? "unrecognized options: $ac_unrecognized_opts" ;; + *) $as_echo "$as_me: WARNING: unrecognized options: $ac_unrecognized_opts" >&2 ;; + esac +fi + +# Check all directory arguments for consistency. +for ac_var in exec_prefix prefix bindir sbindir libexecdir datarootdir \ + datadir sysconfdir sharedstatedir localstatedir includedir \ + oldincludedir docdir infodir htmldir dvidir pdfdir psdir \ + libdir localedir mandir +do + eval ac_val=\$$ac_var + # Remove trailing slashes. + case $ac_val in + */ ) + ac_val=`expr "X$ac_val" : 'X\(.*[^/]\)' \| "X$ac_val" : 'X\(.*\)'` + eval $ac_var=\$ac_val;; + esac + # Be sure to have absolute directory names. + case $ac_val in + [\\/$]* | ?:[\\/]* ) continue;; + NONE | '' ) case $ac_var in *prefix ) continue;; esac;; + esac + as_fn_error $? "expected an absolute directory name for --$ac_var: $ac_val" +done + +# There might be people who depend on the old broken behavior: `$host' +# used to hold the argument of --host etc. +# FIXME: To remove some day. +build=$build_alias +host=$host_alias +target=$target_alias + +# FIXME: To remove some day. +if test "x$host_alias" != x; then + if test "x$build_alias" = x; then + cross_compiling=maybe + elif test "x$build_alias" != "x$host_alias"; then + cross_compiling=yes + fi +fi + +ac_tool_prefix= +test -n "$host_alias" && ac_tool_prefix=$host_alias- + +test "$silent" = yes && exec 6>/dev/null + + +ac_pwd=`pwd` && test -n "$ac_pwd" && +ac_ls_di=`ls -di .` && +ac_pwd_ls_di=`cd "$ac_pwd" && ls -di .` || + as_fn_error $? "working directory cannot be determined" +test "X$ac_ls_di" = "X$ac_pwd_ls_di" || + as_fn_error $? "pwd does not report name of working directory" + + +# Find the source files, if location was not specified. +if test -z "$srcdir"; then + ac_srcdir_defaulted=yes + # Try the directory containing this script, then the parent directory. + ac_confdir=`$as_dirname -- "$as_myself" || +$as_expr X"$as_myself" : 'X\(.*[^/]\)//*[^/][^/]*/*$' \| \ + X"$as_myself" : 'X\(//\)[^/]' \| \ + X"$as_myself" : 'X\(//\)$' \| \ + X"$as_myself" : 'X\(/\)' \| . 2>/dev/null || +$as_echo X"$as_myself" | + sed '/^X\(.*[^/]\)\/\/*[^/][^/]*\/*$/{ + s//\1/ + q + } + /^X\(\/\/\)[^/].*/{ + s//\1/ + q + } + /^X\(\/\/\)$/{ + s//\1/ + q + } + /^X\(\/\).*/{ + s//\1/ + q + } + s/.*/./; q'` + srcdir=$ac_confdir + if test ! -r "$srcdir/$ac_unique_file"; then + srcdir=.. + fi +else + ac_srcdir_defaulted=no +fi +if test ! -r "$srcdir/$ac_unique_file"; then + test "$ac_srcdir_defaulted" = yes && srcdir="$ac_confdir or .." + as_fn_error $? "cannot find sources ($ac_unique_file) in $srcdir" +fi +ac_msg="sources are in $srcdir, but \`cd $srcdir' does not work" +ac_abs_confdir=`( + cd "$srcdir" && test -r "./$ac_unique_file" || as_fn_error $? "$ac_msg" + pwd)` +# When building in place, set srcdir=. +if test "$ac_abs_confdir" = "$ac_pwd"; then + srcdir=. +fi +# Remove unnecessary trailing slashes from srcdir. +# Double slashes in file names in object file debugging info +# mess up M-x gdb in Emacs. +case $srcdir in +*/) srcdir=`expr "X$srcdir" : 'X\(.*[^/]\)' \| "X$srcdir" : 'X\(.*\)'`;; +esac +for ac_var in $ac_precious_vars; do + eval ac_env_${ac_var}_set=\${${ac_var}+set} + eval ac_env_${ac_var}_value=\$${ac_var} + eval ac_cv_env_${ac_var}_set=\${${ac_var}+set} + eval ac_cv_env_${ac_var}_value=\$${ac_var} +done + +# +# Report the --help message. +# +if test "$ac_init_help" = "long"; then + # Omit some internal or obsolete options to make the list less imposing. + # This message is too long to be a string in the A/UX 3.1 sh. + cat <<_ACEOF +\`configure' configures faiss 1.0 to adapt to many kinds of systems. + +Usage: $0 [OPTION]... [VAR=VALUE]... + +To assign environment variables (e.g., CC, CFLAGS...), specify them as +VAR=VALUE. See below for descriptions of some of the useful variables. + +Defaults for the options are specified in brackets. + +Configuration: + -h, --help display this help and exit + --help=short display options specific to this package + --help=recursive display the short help of all the included packages + -V, --version display version information and exit + -q, --quiet, --silent do not print \`checking ...' messages + --cache-file=FILE cache test results in FILE [disabled] + -C, --config-cache alias for \`--cache-file=config.cache' + -n, --no-create do not create output files + --srcdir=DIR find the sources in DIR [configure dir or \`..'] + +Installation directories: + --prefix=PREFIX install architecture-independent files in PREFIX + [$ac_default_prefix] + --exec-prefix=EPREFIX install architecture-dependent files in EPREFIX + [PREFIX] + +By default, \`make install' will install all the files in +\`$ac_default_prefix/bin', \`$ac_default_prefix/lib' etc. You can specify +an installation prefix other than \`$ac_default_prefix' using \`--prefix', +for instance \`--prefix=\$HOME'. + +For better control, use the options below. + +Fine tuning of the installation directories: + --bindir=DIR user executables [EPREFIX/bin] + --sbindir=DIR system admin executables [EPREFIX/sbin] + --libexecdir=DIR program executables [EPREFIX/libexec] + --sysconfdir=DIR read-only single-machine data [PREFIX/etc] + --sharedstatedir=DIR modifiable architecture-independent data [PREFIX/com] + --localstatedir=DIR modifiable single-machine data [PREFIX/var] + --libdir=DIR object code libraries [EPREFIX/lib] + --includedir=DIR C header files [PREFIX/include] + --oldincludedir=DIR C header files for non-gcc [/usr/include] + --datarootdir=DIR read-only arch.-independent data root [PREFIX/share] + --datadir=DIR read-only architecture-independent data [DATAROOTDIR] + --infodir=DIR info documentation [DATAROOTDIR/info] + --localedir=DIR locale-dependent data [DATAROOTDIR/locale] + --mandir=DIR man documentation [DATAROOTDIR/man] + --docdir=DIR documentation root [DATAROOTDIR/doc/faiss] + --htmldir=DIR html documentation [DOCDIR] + --dvidir=DIR dvi documentation [DOCDIR] + --pdfdir=DIR pdf documentation [DOCDIR] + --psdir=DIR ps documentation [DOCDIR] +_ACEOF + + cat <<\_ACEOF + +System types: + --build=BUILD configure for building on BUILD [guessed] + --host=HOST cross-compile to build programs to run on HOST [BUILD] + --target=TARGET configure for building compilers for TARGET [HOST] +_ACEOF +fi + +if test -n "$ac_init_help"; then + case $ac_init_help in + short | recursive ) echo "Configuration of faiss 1.0:";; + esac + cat <<\_ACEOF + +Optional Features: + --disable-option-checking ignore unrecognized --enable/--with options + --disable-FEATURE do not include FEATURE (same as --enable-FEATURE=no) + --enable-FEATURE[=ARG] include FEATURE [ARG=yes] + --disable-openmp do not use OpenMP + +Optional Packages: + --with-PACKAGE[=ARG] use PACKAGE [ARG=yes] + --without-PACKAGE do not use PACKAGE (same as --with-PACKAGE=no) + --with-python= use Python binary + --with-swig= use SWIG binary + --with-cuda= prefix of the CUDA installation + --with-cuda-arch= + device specific -gencode flags + --with-blas= use BLAS library + --with-lapack= use LAPACK library + +Some influential environment variables: + CXX C++ compiler command + CXXFLAGS C++ compiler flags + LDFLAGS linker flags, e.g. -L if you have libraries in a + nonstandard directory + LIBS libraries to pass to the linker, e.g. -l + CPPFLAGS (Objective) C/C++ preprocessor flags, e.g. -I if + you have headers in a nonstandard directory + CC C compiler command + CFLAGS C compiler flags + CPP C preprocessor + CXXCPP C++ preprocessor + +Use these variables to override the choices made by `configure' or to help +it to find libraries and programs with nonstandard names/locations. + +Report bugs to the package provider. +_ACEOF +ac_status=$? +fi + +if test "$ac_init_help" = "recursive"; then + # If there are subdirs, report their specific --help. + for ac_dir in : $ac_subdirs_all; do test "x$ac_dir" = x: && continue + test -d "$ac_dir" || + { cd "$srcdir" && ac_pwd=`pwd` && srcdir=. && test -d "$ac_dir"; } || + continue + ac_builddir=. + +case "$ac_dir" in +.) ac_dir_suffix= ac_top_builddir_sub=. ac_top_build_prefix= ;; +*) + ac_dir_suffix=/`$as_echo "$ac_dir" | sed 's|^\.[\\/]||'` + # A ".." for each directory in $ac_dir_suffix. + ac_top_builddir_sub=`$as_echo "$ac_dir_suffix" | sed 's|/[^\\/]*|/..|g;s|/||'` + case $ac_top_builddir_sub in + "") ac_top_builddir_sub=. ac_top_build_prefix= ;; + *) ac_top_build_prefix=$ac_top_builddir_sub/ ;; + esac ;; +esac +ac_abs_top_builddir=$ac_pwd +ac_abs_builddir=$ac_pwd$ac_dir_suffix +# for backward compatibility: +ac_top_builddir=$ac_top_build_prefix + +case $srcdir in + .) # We are building in place. + ac_srcdir=. + ac_top_srcdir=$ac_top_builddir_sub + ac_abs_top_srcdir=$ac_pwd ;; + [\\/]* | ?:[\\/]* ) # Absolute name. + ac_srcdir=$srcdir$ac_dir_suffix; + ac_top_srcdir=$srcdir + ac_abs_top_srcdir=$srcdir ;; + *) # Relative name. + ac_srcdir=$ac_top_build_prefix$srcdir$ac_dir_suffix + ac_top_srcdir=$ac_top_build_prefix$srcdir + ac_abs_top_srcdir=$ac_pwd/$srcdir ;; +esac +ac_abs_srcdir=$ac_abs_top_srcdir$ac_dir_suffix + + cd "$ac_dir" || { ac_status=$?; continue; } + # Check for guested configure. + if test -f "$ac_srcdir/configure.gnu"; then + echo && + $SHELL "$ac_srcdir/configure.gnu" --help=recursive + elif test -f "$ac_srcdir/configure"; then + echo && + $SHELL "$ac_srcdir/configure" --help=recursive + else + $as_echo "$as_me: WARNING: no configuration information is in $ac_dir" >&2 + fi || ac_status=$? + cd "$ac_pwd" || { ac_status=$?; break; } + done +fi + +test -n "$ac_init_help" && exit $ac_status +if $ac_init_version; then + cat <<\_ACEOF +faiss configure 1.0 +generated by GNU Autoconf 2.69 + +Copyright (C) 2012 Free Software Foundation, Inc. +This configure script is free software; the Free Software Foundation +gives unlimited permission to copy, distribute and modify it. + +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +_ACEOF + exit +fi + +## ------------------------ ## +## Autoconf initialization. ## +## ------------------------ ## + +# ac_fn_cxx_try_compile LINENO +# ---------------------------- +# Try to compile conftest.$ac_ext, and return whether this succeeded. +ac_fn_cxx_try_compile () +{ + as_lineno=${as_lineno-"$1"} as_lineno_stack=as_lineno_stack=$as_lineno_stack + rm -f conftest.$ac_objext + if { { ac_try="$ac_compile" +case "(($ac_try" in + *\"* | *\`* | *\\*) ac_try_echo=\$ac_try;; + *) ac_try_echo=$ac_try;; +esac +eval ac_try_echo="\"\$as_me:${as_lineno-$LINENO}: $ac_try_echo\"" +$as_echo "$ac_try_echo"; } >&5 + (eval "$ac_compile") 2>conftest.err + ac_status=$? + if test -s conftest.err; then + grep -v '^ *+' conftest.err >conftest.er1 + cat conftest.er1 >&5 + mv -f conftest.er1 conftest.err + fi + $as_echo "$as_me:${as_lineno-$LINENO}: \$? = $ac_status" >&5 + test $ac_status = 0; } && { + test -z "$ac_cxx_werror_flag" || + test ! -s conftest.err + } && test -s conftest.$ac_objext; then : + ac_retval=0 +else + $as_echo "$as_me: failed program was:" >&5 +sed 's/^/| /' conftest.$ac_ext >&5 + + ac_retval=1 +fi + eval $as_lineno_stack; ${as_lineno_stack:+:} unset as_lineno + as_fn_set_status $ac_retval + +} # ac_fn_cxx_try_compile + +# ac_fn_c_try_compile LINENO +# -------------------------- +# Try to compile conftest.$ac_ext, and return whether this succeeded. +ac_fn_c_try_compile () +{ + as_lineno=${as_lineno-"$1"} as_lineno_stack=as_lineno_stack=$as_lineno_stack + rm -f conftest.$ac_objext + if { { ac_try="$ac_compile" +case "(($ac_try" in + *\"* | *\`* | *\\*) ac_try_echo=\$ac_try;; + *) ac_try_echo=$ac_try;; +esac +eval ac_try_echo="\"\$as_me:${as_lineno-$LINENO}: $ac_try_echo\"" +$as_echo "$ac_try_echo"; } >&5 + (eval "$ac_compile") 2>conftest.err + ac_status=$? + if test -s conftest.err; then + grep -v '^ *+' conftest.err >conftest.er1 + cat conftest.er1 >&5 + mv -f conftest.er1 conftest.err + fi + $as_echo "$as_me:${as_lineno-$LINENO}: \$? = $ac_status" >&5 + test $ac_status = 0; } && { + test -z "$ac_c_werror_flag" || + test ! -s conftest.err + } && test -s conftest.$ac_objext; then : + ac_retval=0 +else + $as_echo "$as_me: failed program was:" >&5 +sed 's/^/| /' conftest.$ac_ext >&5 + + ac_retval=1 +fi + eval $as_lineno_stack; ${as_lineno_stack:+:} unset as_lineno + as_fn_set_status $ac_retval + +} # ac_fn_c_try_compile + +# ac_fn_c_try_cpp LINENO +# ---------------------- +# Try to preprocess conftest.$ac_ext, and return whether this succeeded. +ac_fn_c_try_cpp () +{ + as_lineno=${as_lineno-"$1"} as_lineno_stack=as_lineno_stack=$as_lineno_stack + if { { ac_try="$ac_cpp conftest.$ac_ext" +case "(($ac_try" in + *\"* | *\`* | *\\*) ac_try_echo=\$ac_try;; + *) ac_try_echo=$ac_try;; +esac +eval ac_try_echo="\"\$as_me:${as_lineno-$LINENO}: $ac_try_echo\"" +$as_echo "$ac_try_echo"; } >&5 + (eval "$ac_cpp conftest.$ac_ext") 2>conftest.err + ac_status=$? + if test -s conftest.err; then + grep -v '^ *+' conftest.err >conftest.er1 + cat conftest.er1 >&5 + mv -f conftest.er1 conftest.err + fi + $as_echo "$as_me:${as_lineno-$LINENO}: \$? = $ac_status" >&5 + test $ac_status = 0; } > conftest.i && { + test -z "$ac_c_preproc_warn_flag$ac_c_werror_flag" || + test ! -s conftest.err + }; then : + ac_retval=0 +else + $as_echo "$as_me: failed program was:" >&5 +sed 's/^/| /' conftest.$ac_ext >&5 + + ac_retval=1 +fi + eval $as_lineno_stack; ${as_lineno_stack:+:} unset as_lineno + as_fn_set_status $ac_retval + +} # ac_fn_c_try_cpp + +# ac_fn_cxx_try_cpp LINENO +# ------------------------ +# Try to preprocess conftest.$ac_ext, and return whether this succeeded. +ac_fn_cxx_try_cpp () +{ + as_lineno=${as_lineno-"$1"} as_lineno_stack=as_lineno_stack=$as_lineno_stack + if { { ac_try="$ac_cpp conftest.$ac_ext" +case "(($ac_try" in + *\"* | *\`* | *\\*) ac_try_echo=\$ac_try;; + *) ac_try_echo=$ac_try;; +esac +eval ac_try_echo="\"\$as_me:${as_lineno-$LINENO}: $ac_try_echo\"" +$as_echo "$ac_try_echo"; } >&5 + (eval "$ac_cpp conftest.$ac_ext") 2>conftest.err + ac_status=$? + if test -s conftest.err; then + grep -v '^ *+' conftest.err >conftest.er1 + cat conftest.er1 >&5 + mv -f conftest.er1 conftest.err + fi + $as_echo "$as_me:${as_lineno-$LINENO}: \$? = $ac_status" >&5 + test $ac_status = 0; } > conftest.i && { + test -z "$ac_cxx_preproc_warn_flag$ac_cxx_werror_flag" || + test ! -s conftest.err + }; then : + ac_retval=0 +else + $as_echo "$as_me: failed program was:" >&5 +sed 's/^/| /' conftest.$ac_ext >&5 + + ac_retval=1 +fi + eval $as_lineno_stack; ${as_lineno_stack:+:} unset as_lineno + as_fn_set_status $ac_retval + +} # ac_fn_cxx_try_cpp + +# ac_fn_cxx_check_header_mongrel LINENO HEADER VAR INCLUDES +# --------------------------------------------------------- +# Tests whether HEADER exists, giving a warning if it cannot be compiled using +# the include files in INCLUDES and setting the cache variable VAR +# accordingly. +ac_fn_cxx_check_header_mongrel () +{ + as_lineno=${as_lineno-"$1"} as_lineno_stack=as_lineno_stack=$as_lineno_stack + if eval \${$3+:} false; then : + { $as_echo "$as_me:${as_lineno-$LINENO}: checking for $2" >&5 +$as_echo_n "checking for $2... " >&6; } +if eval \${$3+:} false; then : + $as_echo_n "(cached) " >&6 +fi +eval ac_res=\$$3 + { $as_echo "$as_me:${as_lineno-$LINENO}: result: $ac_res" >&5 +$as_echo "$ac_res" >&6; } +else + # Is the header compilable? +{ $as_echo "$as_me:${as_lineno-$LINENO}: checking $2 usability" >&5 +$as_echo_n "checking $2 usability... " >&6; } +cat confdefs.h - <<_ACEOF >conftest.$ac_ext +/* end confdefs.h. */ +$4 +#include <$2> +_ACEOF +if ac_fn_cxx_try_compile "$LINENO"; then : + ac_header_compiler=yes +else + ac_header_compiler=no +fi +rm -f core conftest.err conftest.$ac_objext conftest.$ac_ext +{ $as_echo "$as_me:${as_lineno-$LINENO}: result: $ac_header_compiler" >&5 +$as_echo "$ac_header_compiler" >&6; } + +# Is the header present? +{ $as_echo "$as_me:${as_lineno-$LINENO}: checking $2 presence" >&5 +$as_echo_n "checking $2 presence... " >&6; } +cat confdefs.h - <<_ACEOF >conftest.$ac_ext +/* end confdefs.h. */ +#include <$2> +_ACEOF +if ac_fn_cxx_try_cpp "$LINENO"; then : + ac_header_preproc=yes +else + ac_header_preproc=no +fi +rm -f conftest.err conftest.i conftest.$ac_ext +{ $as_echo "$as_me:${as_lineno-$LINENO}: result: $ac_header_preproc" >&5 +$as_echo "$ac_header_preproc" >&6; } + +# So? What about this header? +case $ac_header_compiler:$ac_header_preproc:$ac_cxx_preproc_warn_flag in #(( + yes:no: ) + { $as_echo "$as_me:${as_lineno-$LINENO}: WARNING: $2: accepted by the compiler, rejected by the preprocessor!" >&5 +$as_echo "$as_me: WARNING: $2: accepted by the compiler, rejected by the preprocessor!" >&2;} + { $as_echo "$as_me:${as_lineno-$LINENO}: WARNING: $2: proceeding with the compiler's result" >&5 +$as_echo "$as_me: WARNING: $2: proceeding with the compiler's result" >&2;} + ;; + no:yes:* ) + { $as_echo "$as_me:${as_lineno-$LINENO}: WARNING: $2: present but cannot be compiled" >&5 +$as_echo "$as_me: WARNING: $2: present but cannot be compiled" >&2;} + { $as_echo "$as_me:${as_lineno-$LINENO}: WARNING: $2: check for missing prerequisite headers?" >&5 +$as_echo "$as_me: WARNING: $2: check for missing prerequisite headers?" >&2;} + { $as_echo "$as_me:${as_lineno-$LINENO}: WARNING: $2: see the Autoconf documentation" >&5 +$as_echo "$as_me: WARNING: $2: see the Autoconf documentation" >&2;} + { $as_echo "$as_me:${as_lineno-$LINENO}: WARNING: $2: section \"Present But Cannot Be Compiled\"" >&5 +$as_echo "$as_me: WARNING: $2: section \"Present But Cannot Be Compiled\"" >&2;} + { $as_echo "$as_me:${as_lineno-$LINENO}: WARNING: $2: proceeding with the compiler's result" >&5 +$as_echo "$as_me: WARNING: $2: proceeding with the compiler's result" >&2;} + ;; +esac + { $as_echo "$as_me:${as_lineno-$LINENO}: checking for $2" >&5 +$as_echo_n "checking for $2... " >&6; } +if eval \${$3+:} false; then : + $as_echo_n "(cached) " >&6 +else + eval "$3=\$ac_header_compiler" +fi +eval ac_res=\$$3 + { $as_echo "$as_me:${as_lineno-$LINENO}: result: $ac_res" >&5 +$as_echo "$ac_res" >&6; } +fi + eval $as_lineno_stack; ${as_lineno_stack:+:} unset as_lineno + +} # ac_fn_cxx_check_header_mongrel + +# ac_fn_cxx_try_run LINENO +# ------------------------ +# Try to link conftest.$ac_ext, and return whether this succeeded. Assumes +# that executables *can* be run. +ac_fn_cxx_try_run () +{ + as_lineno=${as_lineno-"$1"} as_lineno_stack=as_lineno_stack=$as_lineno_stack + if { { ac_try="$ac_link" +case "(($ac_try" in + *\"* | *\`* | *\\*) ac_try_echo=\$ac_try;; + *) ac_try_echo=$ac_try;; +esac +eval ac_try_echo="\"\$as_me:${as_lineno-$LINENO}: $ac_try_echo\"" +$as_echo "$ac_try_echo"; } >&5 + (eval "$ac_link") 2>&5 + ac_status=$? + $as_echo "$as_me:${as_lineno-$LINENO}: \$? = $ac_status" >&5 + test $ac_status = 0; } && { ac_try='./conftest$ac_exeext' + { { case "(($ac_try" in + *\"* | *\`* | *\\*) ac_try_echo=\$ac_try;; + *) ac_try_echo=$ac_try;; +esac +eval ac_try_echo="\"\$as_me:${as_lineno-$LINENO}: $ac_try_echo\"" +$as_echo "$ac_try_echo"; } >&5 + (eval "$ac_try") 2>&5 + ac_status=$? + $as_echo "$as_me:${as_lineno-$LINENO}: \$? = $ac_status" >&5 + test $ac_status = 0; }; }; then : + ac_retval=0 +else + $as_echo "$as_me: program exited with status $ac_status" >&5 + $as_echo "$as_me: failed program was:" >&5 +sed 's/^/| /' conftest.$ac_ext >&5 + + ac_retval=$ac_status +fi + rm -rf conftest.dSYM conftest_ipa8_conftest.oo + eval $as_lineno_stack; ${as_lineno_stack:+:} unset as_lineno + as_fn_set_status $ac_retval + +} # ac_fn_cxx_try_run + +# ac_fn_cxx_check_header_compile LINENO HEADER VAR INCLUDES +# --------------------------------------------------------- +# Tests whether HEADER exists and can be compiled using the include files in +# INCLUDES, setting the cache variable VAR accordingly. +ac_fn_cxx_check_header_compile () +{ + as_lineno=${as_lineno-"$1"} as_lineno_stack=as_lineno_stack=$as_lineno_stack + { $as_echo "$as_me:${as_lineno-$LINENO}: checking for $2" >&5 +$as_echo_n "checking for $2... " >&6; } +if eval \${$3+:} false; then : + $as_echo_n "(cached) " >&6 +else + cat confdefs.h - <<_ACEOF >conftest.$ac_ext +/* end confdefs.h. */ +$4 +#include <$2> +_ACEOF +if ac_fn_cxx_try_compile "$LINENO"; then : + eval "$3=yes" +else + eval "$3=no" +fi +rm -f core conftest.err conftest.$ac_objext conftest.$ac_ext +fi +eval ac_res=\$$3 + { $as_echo "$as_me:${as_lineno-$LINENO}: result: $ac_res" >&5 +$as_echo "$ac_res" >&6; } + eval $as_lineno_stack; ${as_lineno_stack:+:} unset as_lineno + +} # ac_fn_cxx_check_header_compile + +# ac_fn_cxx_try_link LINENO +# ------------------------- +# Try to link conftest.$ac_ext, and return whether this succeeded. +ac_fn_cxx_try_link () +{ + as_lineno=${as_lineno-"$1"} as_lineno_stack=as_lineno_stack=$as_lineno_stack + rm -f conftest.$ac_objext conftest$ac_exeext + if { { ac_try="$ac_link" +case "(($ac_try" in + *\"* | *\`* | *\\*) ac_try_echo=\$ac_try;; + *) ac_try_echo=$ac_try;; +esac +eval ac_try_echo="\"\$as_me:${as_lineno-$LINENO}: $ac_try_echo\"" +$as_echo "$ac_try_echo"; } >&5 + (eval "$ac_link") 2>conftest.err + ac_status=$? + if test -s conftest.err; then + grep -v '^ *+' conftest.err >conftest.er1 + cat conftest.er1 >&5 + mv -f conftest.er1 conftest.err + fi + $as_echo "$as_me:${as_lineno-$LINENO}: \$? = $ac_status" >&5 + test $ac_status = 0; } && { + test -z "$ac_cxx_werror_flag" || + test ! -s conftest.err + } && test -s conftest$ac_exeext && { + test "$cross_compiling" = yes || + test -x conftest$ac_exeext + }; then : + ac_retval=0 +else + $as_echo "$as_me: failed program was:" >&5 +sed 's/^/| /' conftest.$ac_ext >&5 + + ac_retval=1 +fi + # Delete the IPA/IPO (Inter Procedural Analysis/Optimization) information + # created by the PGI compiler (conftest_ipa8_conftest.oo), as it would + # interfere with the next link command; also delete a directory that is + # left behind by Apple's compiler. We do this before executing the actions. + rm -rf conftest.dSYM conftest_ipa8_conftest.oo + eval $as_lineno_stack; ${as_lineno_stack:+:} unset as_lineno + as_fn_set_status $ac_retval + +} # ac_fn_cxx_try_link + +# ac_fn_cxx_check_type LINENO TYPE VAR INCLUDES +# --------------------------------------------- +# Tests whether TYPE exists after having included INCLUDES, setting cache +# variable VAR accordingly. +ac_fn_cxx_check_type () +{ + as_lineno=${as_lineno-"$1"} as_lineno_stack=as_lineno_stack=$as_lineno_stack + { $as_echo "$as_me:${as_lineno-$LINENO}: checking for $2" >&5 +$as_echo_n "checking for $2... " >&6; } +if eval \${$3+:} false; then : + $as_echo_n "(cached) " >&6 +else + eval "$3=no" + cat confdefs.h - <<_ACEOF >conftest.$ac_ext +/* end confdefs.h. */ +$4 +int +main () +{ +if (sizeof ($2)) + return 0; + ; + return 0; +} +_ACEOF +if ac_fn_cxx_try_compile "$LINENO"; then : + cat confdefs.h - <<_ACEOF >conftest.$ac_ext +/* end confdefs.h. */ +$4 +int +main () +{ +if (sizeof (($2))) + return 0; + ; + return 0; +} +_ACEOF +if ac_fn_cxx_try_compile "$LINENO"; then : + +else + eval "$3=yes" +fi +rm -f core conftest.err conftest.$ac_objext conftest.$ac_ext +fi +rm -f core conftest.err conftest.$ac_objext conftest.$ac_ext +fi +eval ac_res=\$$3 + { $as_echo "$as_me:${as_lineno-$LINENO}: result: $ac_res" >&5 +$as_echo "$ac_res" >&6; } + eval $as_lineno_stack; ${as_lineno_stack:+:} unset as_lineno + +} # ac_fn_cxx_check_type + +# ac_fn_c_find_intX_t LINENO BITS VAR +# ----------------------------------- +# Finds a signed integer type with width BITS, setting cache variable VAR +# accordingly. +ac_fn_c_find_intX_t () +{ + as_lineno=${as_lineno-"$1"} as_lineno_stack=as_lineno_stack=$as_lineno_stack + { $as_echo "$as_me:${as_lineno-$LINENO}: checking for int$2_t" >&5 +$as_echo_n "checking for int$2_t... " >&6; } +if eval \${$3+:} false; then : + $as_echo_n "(cached) " >&6 +else + eval "$3=no" + # Order is important - never check a type that is potentially smaller + # than half of the expected target width. + for ac_type in int$2_t 'int' 'long int' \ + 'long long int' 'short int' 'signed char'; do + cat confdefs.h - <<_ACEOF >conftest.$ac_ext +/* end confdefs.h. */ +$ac_includes_default + enum { N = $2 / 2 - 1 }; +int +main () +{ +static int test_array [1 - 2 * !(0 < ($ac_type) ((((($ac_type) 1 << N) << N) - 1) * 2 + 1))]; +test_array [0] = 0; +return test_array [0]; + + ; + return 0; +} +_ACEOF +if ac_fn_cxx_try_compile "$LINENO"; then : + cat confdefs.h - <<_ACEOF >conftest.$ac_ext +/* end confdefs.h. */ +$ac_includes_default + enum { N = $2 / 2 - 1 }; +int +main () +{ +static int test_array [1 - 2 * !(($ac_type) ((((($ac_type) 1 << N) << N) - 1) * 2 + 1) + < ($ac_type) ((((($ac_type) 1 << N) << N) - 1) * 2 + 2))]; +test_array [0] = 0; +return test_array [0]; + + ; + return 0; +} +_ACEOF +if ac_fn_cxx_try_compile "$LINENO"; then : + +else + case $ac_type in #( + int$2_t) : + eval "$3=yes" ;; #( + *) : + eval "$3=\$ac_type" ;; +esac +fi +rm -f core conftest.err conftest.$ac_objext conftest.$ac_ext +fi +rm -f core conftest.err conftest.$ac_objext conftest.$ac_ext + if eval test \"x\$"$3"\" = x"no"; then : + +else + break +fi + done +fi +eval ac_res=\$$3 + { $as_echo "$as_me:${as_lineno-$LINENO}: result: $ac_res" >&5 +$as_echo "$ac_res" >&6; } + eval $as_lineno_stack; ${as_lineno_stack:+:} unset as_lineno + +} # ac_fn_c_find_intX_t + +# ac_fn_c_find_uintX_t LINENO BITS VAR +# ------------------------------------ +# Finds an unsigned integer type with width BITS, setting cache variable VAR +# accordingly. +ac_fn_c_find_uintX_t () +{ + as_lineno=${as_lineno-"$1"} as_lineno_stack=as_lineno_stack=$as_lineno_stack + { $as_echo "$as_me:${as_lineno-$LINENO}: checking for uint$2_t" >&5 +$as_echo_n "checking for uint$2_t... " >&6; } +if eval \${$3+:} false; then : + $as_echo_n "(cached) " >&6 +else + eval "$3=no" + # Order is important - never check a type that is potentially smaller + # than half of the expected target width. + for ac_type in uint$2_t 'unsigned int' 'unsigned long int' \ + 'unsigned long long int' 'unsigned short int' 'unsigned char'; do + cat confdefs.h - <<_ACEOF >conftest.$ac_ext +/* end confdefs.h. */ +$ac_includes_default +int +main () +{ +static int test_array [1 - 2 * !((($ac_type) -1 >> ($2 / 2 - 1)) >> ($2 / 2 - 1) == 3)]; +test_array [0] = 0; +return test_array [0]; + + ; + return 0; +} +_ACEOF +if ac_fn_cxx_try_compile "$LINENO"; then : + case $ac_type in #( + uint$2_t) : + eval "$3=yes" ;; #( + *) : + eval "$3=\$ac_type" ;; +esac +fi +rm -f core conftest.err conftest.$ac_objext conftest.$ac_ext + if eval test \"x\$"$3"\" = x"no"; then : + +else + break +fi + done +fi +eval ac_res=\$$3 + { $as_echo "$as_me:${as_lineno-$LINENO}: result: $ac_res" >&5 +$as_echo "$ac_res" >&6; } + eval $as_lineno_stack; ${as_lineno_stack:+:} unset as_lineno + +} # ac_fn_c_find_uintX_t + +# ac_fn_cxx_check_func LINENO FUNC VAR +# ------------------------------------ +# Tests whether FUNC exists, setting the cache variable VAR accordingly +ac_fn_cxx_check_func () +{ + as_lineno=${as_lineno-"$1"} as_lineno_stack=as_lineno_stack=$as_lineno_stack + { $as_echo "$as_me:${as_lineno-$LINENO}: checking for $2" >&5 +$as_echo_n "checking for $2... " >&6; } +if eval \${$3+:} false; then : + $as_echo_n "(cached) " >&6 +else + cat confdefs.h - <<_ACEOF >conftest.$ac_ext +/* end confdefs.h. */ +/* Define $2 to an innocuous variant, in case declares $2. + For example, HP-UX 11i declares gettimeofday. */ +#define $2 innocuous_$2 + +/* System header to define __stub macros and hopefully few prototypes, + which can conflict with char $2 (); below. + Prefer to if __STDC__ is defined, since + exists even on freestanding compilers. */ + +#ifdef __STDC__ +# include +#else +# include +#endif + +#undef $2 + +/* Override any GCC internal prototype to avoid an error. + Use char because int might match the return type of a GCC + builtin and then its argument prototype would still apply. */ +#ifdef __cplusplus +extern "C" +#endif +char $2 (); +/* The GNU C library defines this for functions which it implements + to always fail with ENOSYS. Some functions are actually named + something starting with __ and the normal name is an alias. */ +#if defined __stub_$2 || defined __stub___$2 +choke me +#endif + +int +main () +{ +return $2 (); + ; + return 0; +} +_ACEOF +if ac_fn_cxx_try_link "$LINENO"; then : + eval "$3=yes" +else + eval "$3=no" +fi +rm -f core conftest.err conftest.$ac_objext \ + conftest$ac_exeext conftest.$ac_ext +fi +eval ac_res=\$$3 + { $as_echo "$as_me:${as_lineno-$LINENO}: result: $ac_res" >&5 +$as_echo "$ac_res" >&6; } + eval $as_lineno_stack; ${as_lineno_stack:+:} unset as_lineno + +} # ac_fn_cxx_check_func +cat >config.log <<_ACEOF +This file contains any messages produced by compilers while +running configure, to aid debugging if configure makes a mistake. + +It was created by faiss $as_me 1.0, which was +generated by GNU Autoconf 2.69. Invocation command line was + + $ $0 $@ + +_ACEOF +exec 5>>config.log +{ +cat <<_ASUNAME +## --------- ## +## Platform. ## +## --------- ## + +hostname = `(hostname || uname -n) 2>/dev/null | sed 1q` +uname -m = `(uname -m) 2>/dev/null || echo unknown` +uname -r = `(uname -r) 2>/dev/null || echo unknown` +uname -s = `(uname -s) 2>/dev/null || echo unknown` +uname -v = `(uname -v) 2>/dev/null || echo unknown` + +/usr/bin/uname -p = `(/usr/bin/uname -p) 2>/dev/null || echo unknown` +/bin/uname -X = `(/bin/uname -X) 2>/dev/null || echo unknown` + +/bin/arch = `(/bin/arch) 2>/dev/null || echo unknown` +/usr/bin/arch -k = `(/usr/bin/arch -k) 2>/dev/null || echo unknown` +/usr/convex/getsysinfo = `(/usr/convex/getsysinfo) 2>/dev/null || echo unknown` +/usr/bin/hostinfo = `(/usr/bin/hostinfo) 2>/dev/null || echo unknown` +/bin/machine = `(/bin/machine) 2>/dev/null || echo unknown` +/usr/bin/oslevel = `(/usr/bin/oslevel) 2>/dev/null || echo unknown` +/bin/universe = `(/bin/universe) 2>/dev/null || echo unknown` + +_ASUNAME + +as_save_IFS=$IFS; IFS=$PATH_SEPARATOR +for as_dir in $PATH +do + IFS=$as_save_IFS + test -z "$as_dir" && as_dir=. + $as_echo "PATH: $as_dir" + done +IFS=$as_save_IFS + +} >&5 + +cat >&5 <<_ACEOF + + +## ----------- ## +## Core tests. ## +## ----------- ## + +_ACEOF + + +# Keep a trace of the command line. +# Strip out --no-create and --no-recursion so they do not pile up. +# Strip out --silent because we don't want to record it for future runs. +# Also quote any args containing shell meta-characters. +# Make two passes to allow for proper duplicate-argument suppression. +ac_configure_args= +ac_configure_args0= +ac_configure_args1= +ac_must_keep_next=false +for ac_pass in 1 2 +do + for ac_arg + do + case $ac_arg in + -no-create | --no-c* | -n | -no-recursion | --no-r*) continue ;; + -q | -quiet | --quiet | --quie | --qui | --qu | --q \ + | -silent | --silent | --silen | --sile | --sil) + continue ;; + *\'*) + ac_arg=`$as_echo "$ac_arg" | sed "s/'/'\\\\\\\\''/g"` ;; + esac + case $ac_pass in + 1) as_fn_append ac_configure_args0 " '$ac_arg'" ;; + 2) + as_fn_append ac_configure_args1 " '$ac_arg'" + if test $ac_must_keep_next = true; then + ac_must_keep_next=false # Got value, back to normal. + else + case $ac_arg in + *=* | --config-cache | -C | -disable-* | --disable-* \ + | -enable-* | --enable-* | -gas | --g* | -nfp | --nf* \ + | -q | -quiet | --q* | -silent | --sil* | -v | -verb* \ + | -with-* | --with-* | -without-* | --without-* | --x) + case "$ac_configure_args0 " in + "$ac_configure_args1"*" '$ac_arg' "* ) continue ;; + esac + ;; + -* ) ac_must_keep_next=true ;; + esac + fi + as_fn_append ac_configure_args " '$ac_arg'" + ;; + esac + done +done +{ ac_configure_args0=; unset ac_configure_args0;} +{ ac_configure_args1=; unset ac_configure_args1;} + +# When interrupted or exit'd, cleanup temporary files, and complete +# config.log. We remove comments because anyway the quotes in there +# would cause problems or look ugly. +# WARNING: Use '\'' to represent an apostrophe within the trap. +# WARNING: Do not start the trap code with a newline, due to a FreeBSD 4.0 bug. +trap 'exit_status=$? + # Save into config.log some information that might help in debugging. + { + echo + + $as_echo "## ---------------- ## +## Cache variables. ## +## ---------------- ##" + echo + # The following way of writing the cache mishandles newlines in values, +( + for ac_var in `(set) 2>&1 | sed -n '\''s/^\([a-zA-Z_][a-zA-Z0-9_]*\)=.*/\1/p'\''`; do + eval ac_val=\$$ac_var + case $ac_val in #( + *${as_nl}*) + case $ac_var in #( + *_cv_*) { $as_echo "$as_me:${as_lineno-$LINENO}: WARNING: cache variable $ac_var contains a newline" >&5 +$as_echo "$as_me: WARNING: cache variable $ac_var contains a newline" >&2;} ;; + esac + case $ac_var in #( + _ | IFS | as_nl) ;; #( + BASH_ARGV | BASH_SOURCE) eval $ac_var= ;; #( + *) { eval $ac_var=; unset $ac_var;} ;; + esac ;; + esac + done + (set) 2>&1 | + case $as_nl`(ac_space='\'' '\''; set) 2>&1` in #( + *${as_nl}ac_space=\ *) + sed -n \ + "s/'\''/'\''\\\\'\'''\''/g; + s/^\\([_$as_cr_alnum]*_cv_[_$as_cr_alnum]*\\)=\\(.*\\)/\\1='\''\\2'\''/p" + ;; #( + *) + sed -n "/^[_$as_cr_alnum]*_cv_[_$as_cr_alnum]*=/p" + ;; + esac | + sort +) + echo + + $as_echo "## ----------------- ## +## Output variables. ## +## ----------------- ##" + echo + for ac_var in $ac_subst_vars + do + eval ac_val=\$$ac_var + case $ac_val in + *\'\''*) ac_val=`$as_echo "$ac_val" | sed "s/'\''/'\''\\\\\\\\'\'''\''/g"`;; + esac + $as_echo "$ac_var='\''$ac_val'\''" + done | sort + echo + + if test -n "$ac_subst_files"; then + $as_echo "## ------------------- ## +## File substitutions. ## +## ------------------- ##" + echo + for ac_var in $ac_subst_files + do + eval ac_val=\$$ac_var + case $ac_val in + *\'\''*) ac_val=`$as_echo "$ac_val" | sed "s/'\''/'\''\\\\\\\\'\'''\''/g"`;; + esac + $as_echo "$ac_var='\''$ac_val'\''" + done | sort + echo + fi + + if test -s confdefs.h; then + $as_echo "## ----------- ## +## confdefs.h. ## +## ----------- ##" + echo + cat confdefs.h + echo + fi + test "$ac_signal" != 0 && + $as_echo "$as_me: caught signal $ac_signal" + $as_echo "$as_me: exit $exit_status" + } >&5 + rm -f core *.core core.conftest.* && + rm -f -r conftest* confdefs* conf$$* $ac_clean_files && + exit $exit_status +' 0 +for ac_signal in 1 2 13 15; do + trap 'ac_signal='$ac_signal'; as_fn_exit 1' $ac_signal +done +ac_signal=0 + +# confdefs.h avoids OS command line length limits that DEFS can exceed. +rm -f -r conftest* confdefs.h + +$as_echo "/* confdefs.h */" > confdefs.h + +# Predefined preprocessor variables. + +cat >>confdefs.h <<_ACEOF +#define PACKAGE_NAME "$PACKAGE_NAME" +_ACEOF + +cat >>confdefs.h <<_ACEOF +#define PACKAGE_TARNAME "$PACKAGE_TARNAME" +_ACEOF + +cat >>confdefs.h <<_ACEOF +#define PACKAGE_VERSION "$PACKAGE_VERSION" +_ACEOF + +cat >>confdefs.h <<_ACEOF +#define PACKAGE_STRING "$PACKAGE_STRING" +_ACEOF + +cat >>confdefs.h <<_ACEOF +#define PACKAGE_BUGREPORT "$PACKAGE_BUGREPORT" +_ACEOF + +cat >>confdefs.h <<_ACEOF +#define PACKAGE_URL "$PACKAGE_URL" +_ACEOF + + +# Let the site file select an alternate cache file if it wants to. +# Prefer an explicitly selected file to automatically selected ones. +ac_site_file1=NONE +ac_site_file2=NONE +if test -n "$CONFIG_SITE"; then + # We do not want a PATH search for config.site. + case $CONFIG_SITE in #(( + -*) ac_site_file1=./$CONFIG_SITE;; + */*) ac_site_file1=$CONFIG_SITE;; + *) ac_site_file1=./$CONFIG_SITE;; + esac +elif test "x$prefix" != xNONE; then + ac_site_file1=$prefix/share/config.site + ac_site_file2=$prefix/etc/config.site +else + ac_site_file1=$ac_default_prefix/share/config.site + ac_site_file2=$ac_default_prefix/etc/config.site +fi +for ac_site_file in "$ac_site_file1" "$ac_site_file2" +do + test "x$ac_site_file" = xNONE && continue + if test /dev/null != "$ac_site_file" && test -r "$ac_site_file"; then + { $as_echo "$as_me:${as_lineno-$LINENO}: loading site script $ac_site_file" >&5 +$as_echo "$as_me: loading site script $ac_site_file" >&6;} + sed 's/^/| /' "$ac_site_file" >&5 + . "$ac_site_file" \ + || { { $as_echo "$as_me:${as_lineno-$LINENO}: error: in \`$ac_pwd':" >&5 +$as_echo "$as_me: error: in \`$ac_pwd':" >&2;} +as_fn_error $? "failed to load site script $ac_site_file +See \`config.log' for more details" "$LINENO" 5; } + fi +done + +if test -r "$cache_file"; then + # Some versions of bash will fail to source /dev/null (special files + # actually), so we avoid doing that. DJGPP emulates it as a regular file. + if test /dev/null != "$cache_file" && test -f "$cache_file"; then + { $as_echo "$as_me:${as_lineno-$LINENO}: loading cache $cache_file" >&5 +$as_echo "$as_me: loading cache $cache_file" >&6;} + case $cache_file in + [\\/]* | ?:[\\/]* ) . "$cache_file";; + *) . "./$cache_file";; + esac + fi +else + { $as_echo "$as_me:${as_lineno-$LINENO}: creating cache $cache_file" >&5 +$as_echo "$as_me: creating cache $cache_file" >&6;} + >$cache_file +fi + +as_fn_append ac_header_list " stdlib.h" +as_fn_append ac_header_list " unistd.h" +as_fn_append ac_header_list " sys/param.h" +# Check that the precious variables saved in the cache have kept the same +# value. +ac_cache_corrupted=false +for ac_var in $ac_precious_vars; do + eval ac_old_set=\$ac_cv_env_${ac_var}_set + eval ac_new_set=\$ac_env_${ac_var}_set + eval ac_old_val=\$ac_cv_env_${ac_var}_value + eval ac_new_val=\$ac_env_${ac_var}_value + case $ac_old_set,$ac_new_set in + set,) + { $as_echo "$as_me:${as_lineno-$LINENO}: error: \`$ac_var' was set to \`$ac_old_val' in the previous run" >&5 +$as_echo "$as_me: error: \`$ac_var' was set to \`$ac_old_val' in the previous run" >&2;} + ac_cache_corrupted=: ;; + ,set) + { $as_echo "$as_me:${as_lineno-$LINENO}: error: \`$ac_var' was not set in the previous run" >&5 +$as_echo "$as_me: error: \`$ac_var' was not set in the previous run" >&2;} + ac_cache_corrupted=: ;; + ,);; + *) + if test "x$ac_old_val" != "x$ac_new_val"; then + # differences in whitespace do not lead to failure. + ac_old_val_w=`echo x $ac_old_val` + ac_new_val_w=`echo x $ac_new_val` + if test "$ac_old_val_w" != "$ac_new_val_w"; then + { $as_echo "$as_me:${as_lineno-$LINENO}: error: \`$ac_var' has changed since the previous run:" >&5 +$as_echo "$as_me: error: \`$ac_var' has changed since the previous run:" >&2;} + ac_cache_corrupted=: + else + { $as_echo "$as_me:${as_lineno-$LINENO}: warning: ignoring whitespace changes in \`$ac_var' since the previous run:" >&5 +$as_echo "$as_me: warning: ignoring whitespace changes in \`$ac_var' since the previous run:" >&2;} + eval $ac_var=\$ac_old_val + fi + { $as_echo "$as_me:${as_lineno-$LINENO}: former value: \`$ac_old_val'" >&5 +$as_echo "$as_me: former value: \`$ac_old_val'" >&2;} + { $as_echo "$as_me:${as_lineno-$LINENO}: current value: \`$ac_new_val'" >&5 +$as_echo "$as_me: current value: \`$ac_new_val'" >&2;} + fi;; + esac + # Pass precious variables to config.status. + if test "$ac_new_set" = set; then + case $ac_new_val in + *\'*) ac_arg=$ac_var=`$as_echo "$ac_new_val" | sed "s/'/'\\\\\\\\''/g"` ;; + *) ac_arg=$ac_var=$ac_new_val ;; + esac + case " $ac_configure_args " in + *" '$ac_arg' "*) ;; # Avoid dups. Use of quotes ensures accuracy. + *) as_fn_append ac_configure_args " '$ac_arg'" ;; + esac + fi +done +if $ac_cache_corrupted; then + { $as_echo "$as_me:${as_lineno-$LINENO}: error: in \`$ac_pwd':" >&5 +$as_echo "$as_me: error: in \`$ac_pwd':" >&2;} + { $as_echo "$as_me:${as_lineno-$LINENO}: error: changes in the environment can compromise the build" >&5 +$as_echo "$as_me: error: changes in the environment can compromise the build" >&2;} + as_fn_error $? "run \`make distclean' and/or \`rm $cache_file' and start over" "$LINENO" 5 +fi +## -------------------- ## +## Main body of script. ## +## -------------------- ## + +ac_ext=c +ac_cpp='$CPP $CPPFLAGS' +ac_compile='$CC -c $CFLAGS $CPPFLAGS conftest.$ac_ext >&5' +ac_link='$CC -o conftest$ac_exeext $CFLAGS $CPPFLAGS $LDFLAGS conftest.$ac_ext $LIBS >&5' +ac_compiler_gnu=$ac_cv_c_compiler_gnu + + + + +ac_aux_dir= +for ac_dir in build-aux "$srcdir"/build-aux; do + if test -f "$ac_dir/install-sh"; then + ac_aux_dir=$ac_dir + ac_install_sh="$ac_aux_dir/install-sh -c" + break + elif test -f "$ac_dir/install.sh"; then + ac_aux_dir=$ac_dir + ac_install_sh="$ac_aux_dir/install.sh -c" + break + elif test -f "$ac_dir/shtool"; then + ac_aux_dir=$ac_dir + ac_install_sh="$ac_aux_dir/shtool install -c" + break + fi +done +if test -z "$ac_aux_dir"; then + as_fn_error $? "cannot find install-sh, install.sh, or shtool in build-aux \"$srcdir\"/build-aux" "$LINENO" 5 +fi + +# These three variables are undocumented and unsupported, +# and are intended to be withdrawn in a future Autoconf release. +# They can cause serious problems if a builder's source tree is in a directory +# whose full name contains unusual characters. +ac_config_guess="$SHELL $ac_aux_dir/config.guess" # Please don't use this var. +ac_config_sub="$SHELL $ac_aux_dir/config.sub" # Please don't use this var. +ac_configure="$SHELL $ac_aux_dir/configure" # Please don't use this var. + + + + +: ${CXXFLAGS="-g -O3 -Wall -Wextra"} + +# Checks for programs. +ac_ext=cpp +ac_cpp='$CXXCPP $CPPFLAGS' +ac_compile='$CXX -c $CXXFLAGS $CPPFLAGS conftest.$ac_ext >&5' +ac_link='$CXX -o conftest$ac_exeext $CXXFLAGS $CPPFLAGS $LDFLAGS conftest.$ac_ext $LIBS >&5' +ac_compiler_gnu=$ac_cv_cxx_compiler_gnu + +ac_ext=cpp +ac_cpp='$CXXCPP $CPPFLAGS' +ac_compile='$CXX -c $CXXFLAGS $CPPFLAGS conftest.$ac_ext >&5' +ac_link='$CXX -o conftest$ac_exeext $CXXFLAGS $CPPFLAGS $LDFLAGS conftest.$ac_ext $LIBS >&5' +ac_compiler_gnu=$ac_cv_cxx_compiler_gnu +if test -z "$CXX"; then + if test -n "$CCC"; then + CXX=$CCC + else + if test -n "$ac_tool_prefix"; then + for ac_prog in g++ c++ gpp aCC CC cxx cc++ cl.exe FCC KCC RCC xlC_r xlC + do + # Extract the first word of "$ac_tool_prefix$ac_prog", so it can be a program name with args. +set dummy $ac_tool_prefix$ac_prog; ac_word=$2 +{ $as_echo "$as_me:${as_lineno-$LINENO}: checking for $ac_word" >&5 +$as_echo_n "checking for $ac_word... " >&6; } +if ${ac_cv_prog_CXX+:} false; then : + $as_echo_n "(cached) " >&6 +else + if test -n "$CXX"; then + ac_cv_prog_CXX="$CXX" # Let the user override the test. +else +as_save_IFS=$IFS; IFS=$PATH_SEPARATOR +for as_dir in $PATH +do + IFS=$as_save_IFS + test -z "$as_dir" && as_dir=. + for ac_exec_ext in '' $ac_executable_extensions; do + if as_fn_executable_p "$as_dir/$ac_word$ac_exec_ext"; then + ac_cv_prog_CXX="$ac_tool_prefix$ac_prog" + $as_echo "$as_me:${as_lineno-$LINENO}: found $as_dir/$ac_word$ac_exec_ext" >&5 + break 2 + fi +done + done +IFS=$as_save_IFS + +fi +fi +CXX=$ac_cv_prog_CXX +if test -n "$CXX"; then + { $as_echo "$as_me:${as_lineno-$LINENO}: result: $CXX" >&5 +$as_echo "$CXX" >&6; } +else + { $as_echo "$as_me:${as_lineno-$LINENO}: result: no" >&5 +$as_echo "no" >&6; } +fi + + + test -n "$CXX" && break + done +fi +if test -z "$CXX"; then + ac_ct_CXX=$CXX + for ac_prog in g++ c++ gpp aCC CC cxx cc++ cl.exe FCC KCC RCC xlC_r xlC +do + # Extract the first word of "$ac_prog", so it can be a program name with args. +set dummy $ac_prog; ac_word=$2 +{ $as_echo "$as_me:${as_lineno-$LINENO}: checking for $ac_word" >&5 +$as_echo_n "checking for $ac_word... " >&6; } +if ${ac_cv_prog_ac_ct_CXX+:} false; then : + $as_echo_n "(cached) " >&6 +else + if test -n "$ac_ct_CXX"; then + ac_cv_prog_ac_ct_CXX="$ac_ct_CXX" # Let the user override the test. +else +as_save_IFS=$IFS; IFS=$PATH_SEPARATOR +for as_dir in $PATH +do + IFS=$as_save_IFS + test -z "$as_dir" && as_dir=. + for ac_exec_ext in '' $ac_executable_extensions; do + if as_fn_executable_p "$as_dir/$ac_word$ac_exec_ext"; then + ac_cv_prog_ac_ct_CXX="$ac_prog" + $as_echo "$as_me:${as_lineno-$LINENO}: found $as_dir/$ac_word$ac_exec_ext" >&5 + break 2 + fi +done + done +IFS=$as_save_IFS + +fi +fi +ac_ct_CXX=$ac_cv_prog_ac_ct_CXX +if test -n "$ac_ct_CXX"; then + { $as_echo "$as_me:${as_lineno-$LINENO}: result: $ac_ct_CXX" >&5 +$as_echo "$ac_ct_CXX" >&6; } +else + { $as_echo "$as_me:${as_lineno-$LINENO}: result: no" >&5 +$as_echo "no" >&6; } +fi + + + test -n "$ac_ct_CXX" && break +done + + if test "x$ac_ct_CXX" = x; then + CXX="g++" + else + case $cross_compiling:$ac_tool_warned in +yes:) +{ $as_echo "$as_me:${as_lineno-$LINENO}: WARNING: using cross tools not prefixed with host triplet" >&5 +$as_echo "$as_me: WARNING: using cross tools not prefixed with host triplet" >&2;} +ac_tool_warned=yes ;; +esac + CXX=$ac_ct_CXX + fi +fi + + fi +fi +# Provide some information about the compiler. +$as_echo "$as_me:${as_lineno-$LINENO}: checking for C++ compiler version" >&5 +set X $ac_compile +ac_compiler=$2 +for ac_option in --version -v -V -qversion; do + { { ac_try="$ac_compiler $ac_option >&5" +case "(($ac_try" in + *\"* | *\`* | *\\*) ac_try_echo=\$ac_try;; + *) ac_try_echo=$ac_try;; +esac +eval ac_try_echo="\"\$as_me:${as_lineno-$LINENO}: $ac_try_echo\"" +$as_echo "$ac_try_echo"; } >&5 + (eval "$ac_compiler $ac_option >&5") 2>conftest.err + ac_status=$? + if test -s conftest.err; then + sed '10a\ +... rest of stderr output deleted ... + 10q' conftest.err >conftest.er1 + cat conftest.er1 >&5 + fi + rm -f conftest.er1 conftest.err + $as_echo "$as_me:${as_lineno-$LINENO}: \$? = $ac_status" >&5 + test $ac_status = 0; } +done + +cat confdefs.h - <<_ACEOF >conftest.$ac_ext +/* end confdefs.h. */ + +int +main () +{ + + ; + return 0; +} +_ACEOF +ac_clean_files_save=$ac_clean_files +ac_clean_files="$ac_clean_files a.out a.out.dSYM a.exe b.out" +# Try to create an executable without -o first, disregard a.out. +# It will help us diagnose broken compilers, and finding out an intuition +# of exeext. +{ $as_echo "$as_me:${as_lineno-$LINENO}: checking whether the C++ compiler works" >&5 +$as_echo_n "checking whether the C++ compiler works... " >&6; } +ac_link_default=`$as_echo "$ac_link" | sed 's/ -o *conftest[^ ]*//'` + +# The possible output files: +ac_files="a.out conftest.exe conftest a.exe a_out.exe b.out conftest.*" + +ac_rmfiles= +for ac_file in $ac_files +do + case $ac_file in + *.$ac_ext | *.xcoff | *.tds | *.d | *.pdb | *.xSYM | *.bb | *.bbg | *.map | *.inf | *.dSYM | *.o | *.obj ) ;; + * ) ac_rmfiles="$ac_rmfiles $ac_file";; + esac +done +rm -f $ac_rmfiles + +if { { ac_try="$ac_link_default" +case "(($ac_try" in + *\"* | *\`* | *\\*) ac_try_echo=\$ac_try;; + *) ac_try_echo=$ac_try;; +esac +eval ac_try_echo="\"\$as_me:${as_lineno-$LINENO}: $ac_try_echo\"" +$as_echo "$ac_try_echo"; } >&5 + (eval "$ac_link_default") 2>&5 + ac_status=$? + $as_echo "$as_me:${as_lineno-$LINENO}: \$? = $ac_status" >&5 + test $ac_status = 0; }; then : + # Autoconf-2.13 could set the ac_cv_exeext variable to `no'. +# So ignore a value of `no', otherwise this would lead to `EXEEXT = no' +# in a Makefile. We should not override ac_cv_exeext if it was cached, +# so that the user can short-circuit this test for compilers unknown to +# Autoconf. +for ac_file in $ac_files '' +do + test -f "$ac_file" || continue + case $ac_file in + *.$ac_ext | *.xcoff | *.tds | *.d | *.pdb | *.xSYM | *.bb | *.bbg | *.map | *.inf | *.dSYM | *.o | *.obj ) + ;; + [ab].out ) + # We found the default executable, but exeext='' is most + # certainly right. + break;; + *.* ) + if test "${ac_cv_exeext+set}" = set && test "$ac_cv_exeext" != no; + then :; else + ac_cv_exeext=`expr "$ac_file" : '[^.]*\(\..*\)'` + fi + # We set ac_cv_exeext here because the later test for it is not + # safe: cross compilers may not add the suffix if given an `-o' + # argument, so we may need to know it at that point already. + # Even if this section looks crufty: it has the advantage of + # actually working. + break;; + * ) + break;; + esac +done +test "$ac_cv_exeext" = no && ac_cv_exeext= + +else + ac_file='' +fi +if test -z "$ac_file"; then : + { $as_echo "$as_me:${as_lineno-$LINENO}: result: no" >&5 +$as_echo "no" >&6; } +$as_echo "$as_me: failed program was:" >&5 +sed 's/^/| /' conftest.$ac_ext >&5 + +{ { $as_echo "$as_me:${as_lineno-$LINENO}: error: in \`$ac_pwd':" >&5 +$as_echo "$as_me: error: in \`$ac_pwd':" >&2;} +as_fn_error 77 "C++ compiler cannot create executables +See \`config.log' for more details" "$LINENO" 5; } +else + { $as_echo "$as_me:${as_lineno-$LINENO}: result: yes" >&5 +$as_echo "yes" >&6; } +fi +{ $as_echo "$as_me:${as_lineno-$LINENO}: checking for C++ compiler default output file name" >&5 +$as_echo_n "checking for C++ compiler default output file name... " >&6; } +{ $as_echo "$as_me:${as_lineno-$LINENO}: result: $ac_file" >&5 +$as_echo "$ac_file" >&6; } +ac_exeext=$ac_cv_exeext + +rm -f -r a.out a.out.dSYM a.exe conftest$ac_cv_exeext b.out +ac_clean_files=$ac_clean_files_save +{ $as_echo "$as_me:${as_lineno-$LINENO}: checking for suffix of executables" >&5 +$as_echo_n "checking for suffix of executables... " >&6; } +if { { ac_try="$ac_link" +case "(($ac_try" in + *\"* | *\`* | *\\*) ac_try_echo=\$ac_try;; + *) ac_try_echo=$ac_try;; +esac +eval ac_try_echo="\"\$as_me:${as_lineno-$LINENO}: $ac_try_echo\"" +$as_echo "$ac_try_echo"; } >&5 + (eval "$ac_link") 2>&5 + ac_status=$? + $as_echo "$as_me:${as_lineno-$LINENO}: \$? = $ac_status" >&5 + test $ac_status = 0; }; then : + # If both `conftest.exe' and `conftest' are `present' (well, observable) +# catch `conftest.exe'. For instance with Cygwin, `ls conftest' will +# work properly (i.e., refer to `conftest.exe'), while it won't with +# `rm'. +for ac_file in conftest.exe conftest conftest.*; do + test -f "$ac_file" || continue + case $ac_file in + *.$ac_ext | *.xcoff | *.tds | *.d | *.pdb | *.xSYM | *.bb | *.bbg | *.map | *.inf | *.dSYM | *.o | *.obj ) ;; + *.* ) ac_cv_exeext=`expr "$ac_file" : '[^.]*\(\..*\)'` + break;; + * ) break;; + esac +done +else + { { $as_echo "$as_me:${as_lineno-$LINENO}: error: in \`$ac_pwd':" >&5 +$as_echo "$as_me: error: in \`$ac_pwd':" >&2;} +as_fn_error $? "cannot compute suffix of executables: cannot compile and link +See \`config.log' for more details" "$LINENO" 5; } +fi +rm -f conftest conftest$ac_cv_exeext +{ $as_echo "$as_me:${as_lineno-$LINENO}: result: $ac_cv_exeext" >&5 +$as_echo "$ac_cv_exeext" >&6; } + +rm -f conftest.$ac_ext +EXEEXT=$ac_cv_exeext +ac_exeext=$EXEEXT +cat confdefs.h - <<_ACEOF >conftest.$ac_ext +/* end confdefs.h. */ +#include +int +main () +{ +FILE *f = fopen ("conftest.out", "w"); + return ferror (f) || fclose (f) != 0; + + ; + return 0; +} +_ACEOF +ac_clean_files="$ac_clean_files conftest.out" +# Check that the compiler produces executables we can run. If not, either +# the compiler is broken, or we cross compile. +{ $as_echo "$as_me:${as_lineno-$LINENO}: checking whether we are cross compiling" >&5 +$as_echo_n "checking whether we are cross compiling... " >&6; } +if test "$cross_compiling" != yes; then + { { ac_try="$ac_link" +case "(($ac_try" in + *\"* | *\`* | *\\*) ac_try_echo=\$ac_try;; + *) ac_try_echo=$ac_try;; +esac +eval ac_try_echo="\"\$as_me:${as_lineno-$LINENO}: $ac_try_echo\"" +$as_echo "$ac_try_echo"; } >&5 + (eval "$ac_link") 2>&5 + ac_status=$? + $as_echo "$as_me:${as_lineno-$LINENO}: \$? = $ac_status" >&5 + test $ac_status = 0; } + if { ac_try='./conftest$ac_cv_exeext' + { { case "(($ac_try" in + *\"* | *\`* | *\\*) ac_try_echo=\$ac_try;; + *) ac_try_echo=$ac_try;; +esac +eval ac_try_echo="\"\$as_me:${as_lineno-$LINENO}: $ac_try_echo\"" +$as_echo "$ac_try_echo"; } >&5 + (eval "$ac_try") 2>&5 + ac_status=$? + $as_echo "$as_me:${as_lineno-$LINENO}: \$? = $ac_status" >&5 + test $ac_status = 0; }; }; then + cross_compiling=no + else + if test "$cross_compiling" = maybe; then + cross_compiling=yes + else + { { $as_echo "$as_me:${as_lineno-$LINENO}: error: in \`$ac_pwd':" >&5 +$as_echo "$as_me: error: in \`$ac_pwd':" >&2;} +as_fn_error $? "cannot run C++ compiled programs. +If you meant to cross compile, use \`--host'. +See \`config.log' for more details" "$LINENO" 5; } + fi + fi +fi +{ $as_echo "$as_me:${as_lineno-$LINENO}: result: $cross_compiling" >&5 +$as_echo "$cross_compiling" >&6; } + +rm -f conftest.$ac_ext conftest$ac_cv_exeext conftest.out +ac_clean_files=$ac_clean_files_save +{ $as_echo "$as_me:${as_lineno-$LINENO}: checking for suffix of object files" >&5 +$as_echo_n "checking for suffix of object files... " >&6; } +if ${ac_cv_objext+:} false; then : + $as_echo_n "(cached) " >&6 +else + cat confdefs.h - <<_ACEOF >conftest.$ac_ext +/* end confdefs.h. */ + +int +main () +{ + + ; + return 0; +} +_ACEOF +rm -f conftest.o conftest.obj +if { { ac_try="$ac_compile" +case "(($ac_try" in + *\"* | *\`* | *\\*) ac_try_echo=\$ac_try;; + *) ac_try_echo=$ac_try;; +esac +eval ac_try_echo="\"\$as_me:${as_lineno-$LINENO}: $ac_try_echo\"" +$as_echo "$ac_try_echo"; } >&5 + (eval "$ac_compile") 2>&5 + ac_status=$? + $as_echo "$as_me:${as_lineno-$LINENO}: \$? = $ac_status" >&5 + test $ac_status = 0; }; then : + for ac_file in conftest.o conftest.obj conftest.*; do + test -f "$ac_file" || continue; + case $ac_file in + *.$ac_ext | *.xcoff | *.tds | *.d | *.pdb | *.xSYM | *.bb | *.bbg | *.map | *.inf | *.dSYM ) ;; + *) ac_cv_objext=`expr "$ac_file" : '.*\.\(.*\)'` + break;; + esac +done +else + $as_echo "$as_me: failed program was:" >&5 +sed 's/^/| /' conftest.$ac_ext >&5 + +{ { $as_echo "$as_me:${as_lineno-$LINENO}: error: in \`$ac_pwd':" >&5 +$as_echo "$as_me: error: in \`$ac_pwd':" >&2;} +as_fn_error $? "cannot compute suffix of object files: cannot compile +See \`config.log' for more details" "$LINENO" 5; } +fi +rm -f conftest.$ac_cv_objext conftest.$ac_ext +fi +{ $as_echo "$as_me:${as_lineno-$LINENO}: result: $ac_cv_objext" >&5 +$as_echo "$ac_cv_objext" >&6; } +OBJEXT=$ac_cv_objext +ac_objext=$OBJEXT +{ $as_echo "$as_me:${as_lineno-$LINENO}: checking whether we are using the GNU C++ compiler" >&5 +$as_echo_n "checking whether we are using the GNU C++ compiler... " >&6; } +if ${ac_cv_cxx_compiler_gnu+:} false; then : + $as_echo_n "(cached) " >&6 +else + cat confdefs.h - <<_ACEOF >conftest.$ac_ext +/* end confdefs.h. */ + +int +main () +{ +#ifndef __GNUC__ + choke me +#endif + + ; + return 0; +} +_ACEOF +if ac_fn_cxx_try_compile "$LINENO"; then : + ac_compiler_gnu=yes +else + ac_compiler_gnu=no +fi +rm -f core conftest.err conftest.$ac_objext conftest.$ac_ext +ac_cv_cxx_compiler_gnu=$ac_compiler_gnu + +fi +{ $as_echo "$as_me:${as_lineno-$LINENO}: result: $ac_cv_cxx_compiler_gnu" >&5 +$as_echo "$ac_cv_cxx_compiler_gnu" >&6; } +if test $ac_compiler_gnu = yes; then + GXX=yes +else + GXX= +fi +ac_test_CXXFLAGS=${CXXFLAGS+set} +ac_save_CXXFLAGS=$CXXFLAGS +{ $as_echo "$as_me:${as_lineno-$LINENO}: checking whether $CXX accepts -g" >&5 +$as_echo_n "checking whether $CXX accepts -g... " >&6; } +if ${ac_cv_prog_cxx_g+:} false; then : + $as_echo_n "(cached) " >&6 +else + ac_save_cxx_werror_flag=$ac_cxx_werror_flag + ac_cxx_werror_flag=yes + ac_cv_prog_cxx_g=no + CXXFLAGS="-g" + cat confdefs.h - <<_ACEOF >conftest.$ac_ext +/* end confdefs.h. */ + +int +main () +{ + + ; + return 0; +} +_ACEOF +if ac_fn_cxx_try_compile "$LINENO"; then : + ac_cv_prog_cxx_g=yes +else + CXXFLAGS="" + cat confdefs.h - <<_ACEOF >conftest.$ac_ext +/* end confdefs.h. */ + +int +main () +{ + + ; + return 0; +} +_ACEOF +if ac_fn_cxx_try_compile "$LINENO"; then : + +else + ac_cxx_werror_flag=$ac_save_cxx_werror_flag + CXXFLAGS="-g" + cat confdefs.h - <<_ACEOF >conftest.$ac_ext +/* end confdefs.h. */ + +int +main () +{ + + ; + return 0; +} +_ACEOF +if ac_fn_cxx_try_compile "$LINENO"; then : + ac_cv_prog_cxx_g=yes +fi +rm -f core conftest.err conftest.$ac_objext conftest.$ac_ext +fi +rm -f core conftest.err conftest.$ac_objext conftest.$ac_ext +fi +rm -f core conftest.err conftest.$ac_objext conftest.$ac_ext + ac_cxx_werror_flag=$ac_save_cxx_werror_flag +fi +{ $as_echo "$as_me:${as_lineno-$LINENO}: result: $ac_cv_prog_cxx_g" >&5 +$as_echo "$ac_cv_prog_cxx_g" >&6; } +if test "$ac_test_CXXFLAGS" = set; then + CXXFLAGS=$ac_save_CXXFLAGS +elif test $ac_cv_prog_cxx_g = yes; then + if test "$GXX" = yes; then + CXXFLAGS="-g -O2" + else + CXXFLAGS="-g" + fi +else + if test "$GXX" = yes; then + CXXFLAGS="-O2" + else + CXXFLAGS= + fi +fi +ac_ext=cpp +ac_cpp='$CXXCPP $CPPFLAGS' +ac_compile='$CXX -c $CXXFLAGS $CPPFLAGS conftest.$ac_ext >&5' +ac_link='$CXX -o conftest$ac_exeext $CXXFLAGS $CPPFLAGS $LDFLAGS conftest.$ac_ext $LIBS >&5' +ac_compiler_gnu=$ac_cv_cxx_compiler_gnu + + + ax_cxx_compile_alternatives="11 0x" ax_cxx_compile_cxx11_required=true + ac_ext=cpp +ac_cpp='$CXXCPP $CPPFLAGS' +ac_compile='$CXX -c $CXXFLAGS $CPPFLAGS conftest.$ac_ext >&5' +ac_link='$CXX -o conftest$ac_exeext $CXXFLAGS $CPPFLAGS $LDFLAGS conftest.$ac_ext $LIBS >&5' +ac_compiler_gnu=$ac_cv_cxx_compiler_gnu + ac_success=no + + + + if test x$ac_success = xno; then + for alternative in ${ax_cxx_compile_alternatives}; do + for switch in -std=c++${alternative} +std=c++${alternative} "-h std=c++${alternative}"; do + cachevar=`$as_echo "ax_cv_cxx_compile_cxx11_$switch" | $as_tr_sh` + { $as_echo "$as_me:${as_lineno-$LINENO}: checking whether $CXX supports C++11 features with $switch" >&5 +$as_echo_n "checking whether $CXX supports C++11 features with $switch... " >&6; } +if eval \${$cachevar+:} false; then : + $as_echo_n "(cached) " >&6 +else + ac_save_CXX="$CXX" + CXX="$CXX $switch" + cat confdefs.h - <<_ACEOF >conftest.$ac_ext +/* end confdefs.h. */ + + +// If the compiler admits that it is not ready for C++11, why torture it? +// Hopefully, this will speed up the test. + +#ifndef __cplusplus + +#error "This is not a C++ compiler" + +#elif __cplusplus < 201103L + +#error "This is not a C++11 compiler" + +#else + +namespace cxx11 +{ + + namespace test_static_assert + { + + template + struct check + { + static_assert(sizeof(int) <= sizeof(T), "not big enough"); + }; + + } + + namespace test_final_override + { + + struct Base + { + virtual void f() {} + }; + + struct Derived : public Base + { + virtual void f() override {} + }; + + } + + namespace test_double_right_angle_brackets + { + + template < typename T > + struct check {}; + + typedef check single_type; + typedef check> double_type; + typedef check>> triple_type; + typedef check>>> quadruple_type; + + } + + namespace test_decltype + { + + int + f() + { + int a = 1; + decltype(a) b = 2; + return a + b; + } + + } + + namespace test_type_deduction + { + + template < typename T1, typename T2 > + struct is_same + { + static const bool value = false; + }; + + template < typename T > + struct is_same + { + static const bool value = true; + }; + + template < typename T1, typename T2 > + auto + add(T1 a1, T2 a2) -> decltype(a1 + a2) + { + return a1 + a2; + } + + int + test(const int c, volatile int v) + { + static_assert(is_same::value == true, ""); + static_assert(is_same::value == false, ""); + static_assert(is_same::value == false, ""); + auto ac = c; + auto av = v; + auto sumi = ac + av + 'x'; + auto sumf = ac + av + 1.0; + static_assert(is_same::value == true, ""); + static_assert(is_same::value == true, ""); + static_assert(is_same::value == true, ""); + static_assert(is_same::value == false, ""); + static_assert(is_same::value == true, ""); + return (sumf > 0.0) ? sumi : add(c, v); + } + + } + + namespace test_noexcept + { + + int f() { return 0; } + int g() noexcept { return 0; } + + static_assert(noexcept(f()) == false, ""); + static_assert(noexcept(g()) == true, ""); + + } + + namespace test_constexpr + { + + template < typename CharT > + unsigned long constexpr + strlen_c_r(const CharT *const s, const unsigned long acc) noexcept + { + return *s ? strlen_c_r(s + 1, acc + 1) : acc; + } + + template < typename CharT > + unsigned long constexpr + strlen_c(const CharT *const s) noexcept + { + return strlen_c_r(s, 0UL); + } + + static_assert(strlen_c("") == 0UL, ""); + static_assert(strlen_c("1") == 1UL, ""); + static_assert(strlen_c("example") == 7UL, ""); + static_assert(strlen_c("another\0example") == 7UL, ""); + + } + + namespace test_rvalue_references + { + + template < int N > + struct answer + { + static constexpr int value = N; + }; + + answer<1> f(int&) { return answer<1>(); } + answer<2> f(const int&) { return answer<2>(); } + answer<3> f(int&&) { return answer<3>(); } + + void + test() + { + int i = 0; + const int c = 0; + static_assert(decltype(f(i))::value == 1, ""); + static_assert(decltype(f(c))::value == 2, ""); + static_assert(decltype(f(0))::value == 3, ""); + } + + } + + namespace test_uniform_initialization + { + + struct test + { + static const int zero {}; + static const int one {1}; + }; + + static_assert(test::zero == 0, ""); + static_assert(test::one == 1, ""); + + } + + namespace test_lambdas + { + + void + test1() + { + auto lambda1 = [](){}; + auto lambda2 = lambda1; + lambda1(); + lambda2(); + } + + int + test2() + { + auto a = [](int i, int j){ return i + j; }(1, 2); + auto b = []() -> int { return '0'; }(); + auto c = [=](){ return a + b; }(); + auto d = [&](){ return c; }(); + auto e = [a, &b](int x) mutable { + const auto identity = [](int y){ return y; }; + for (auto i = 0; i < a; ++i) + a += b--; + return x + identity(a + b); + }(0); + return a + b + c + d + e; + } + + int + test3() + { + const auto nullary = [](){ return 0; }; + const auto unary = [](int x){ return x; }; + using nullary_t = decltype(nullary); + using unary_t = decltype(unary); + const auto higher1st = [](nullary_t f){ return f(); }; + const auto higher2nd = [unary](nullary_t f1){ + return [unary, f1](unary_t f2){ return f2(unary(f1())); }; + }; + return higher1st(nullary) + higher2nd(nullary)(unary); + } + + } + + namespace test_variadic_templates + { + + template + struct sum; + + template + struct sum + { + static constexpr auto value = N0 + sum::value; + }; + + template <> + struct sum<> + { + static constexpr auto value = 0; + }; + + static_assert(sum<>::value == 0, ""); + static_assert(sum<1>::value == 1, ""); + static_assert(sum<23>::value == 23, ""); + static_assert(sum<1, 2>::value == 3, ""); + static_assert(sum<5, 5, 11>::value == 21, ""); + static_assert(sum<2, 3, 5, 7, 11, 13>::value == 41, ""); + + } + + // http://stackoverflow.com/questions/13728184/template-aliases-and-sfinae + // Clang 3.1 fails with headers of libstd++ 4.8.3 when using std::function + // because of this. + namespace test_template_alias_sfinae + { + + struct foo {}; + + template + using member = typename T::member_type; + + template + void func(...) {} + + template + void func(member*) {} + + void test(); + + void test() { func(0); } + + } + +} // namespace cxx11 + +#endif // __cplusplus >= 201103L + + + +_ACEOF +if ac_fn_cxx_try_compile "$LINENO"; then : + eval $cachevar=yes +else + eval $cachevar=no +fi +rm -f core conftest.err conftest.$ac_objext conftest.$ac_ext + CXX="$ac_save_CXX" +fi +eval ac_res=\$$cachevar + { $as_echo "$as_me:${as_lineno-$LINENO}: result: $ac_res" >&5 +$as_echo "$ac_res" >&6; } + if eval test x\$$cachevar = xyes; then + CXX="$CXX $switch" + if test -n "$CXXCPP" ; then + CXXCPP="$CXXCPP $switch" + fi + ac_success=yes + break + fi + done + if test x$ac_success = xyes; then + break + fi + done + fi + ac_ext=cpp +ac_cpp='$CXXCPP $CPPFLAGS' +ac_compile='$CXX -c $CXXFLAGS $CPPFLAGS conftest.$ac_ext >&5' +ac_link='$CXX -o conftest$ac_exeext $CXXFLAGS $CPPFLAGS $LDFLAGS conftest.$ac_ext $LIBS >&5' +ac_compiler_gnu=$ac_cv_cxx_compiler_gnu + + if test x$ax_cxx_compile_cxx11_required = xtrue; then + if test x$ac_success = xno; then + as_fn_error $? "*** A compiler with support for C++11 language features is required." "$LINENO" 5 + fi + fi + if test x$ac_success = xno; then + HAVE_CXX11=0 + { $as_echo "$as_me:${as_lineno-$LINENO}: No compiler with C++11 support was found" >&5 +$as_echo "$as_me: No compiler with C++11 support was found" >&6;} + else + HAVE_CXX11=1 + +$as_echo "#define HAVE_CXX11 1" >>confdefs.h + + fi + + +ac_ext=c +ac_cpp='$CPP $CPPFLAGS' +ac_compile='$CC -c $CFLAGS $CPPFLAGS conftest.$ac_ext >&5' +ac_link='$CC -o conftest$ac_exeext $CFLAGS $CPPFLAGS $LDFLAGS conftest.$ac_ext $LIBS >&5' +ac_compiler_gnu=$ac_cv_c_compiler_gnu +if test -n "$ac_tool_prefix"; then + # Extract the first word of "${ac_tool_prefix}gcc", so it can be a program name with args. +set dummy ${ac_tool_prefix}gcc; ac_word=$2 +{ $as_echo "$as_me:${as_lineno-$LINENO}: checking for $ac_word" >&5 +$as_echo_n "checking for $ac_word... " >&6; } +if ${ac_cv_prog_CC+:} false; then : + $as_echo_n "(cached) " >&6 +else + if test -n "$CC"; then + ac_cv_prog_CC="$CC" # Let the user override the test. +else +as_save_IFS=$IFS; IFS=$PATH_SEPARATOR +for as_dir in $PATH +do + IFS=$as_save_IFS + test -z "$as_dir" && as_dir=. + for ac_exec_ext in '' $ac_executable_extensions; do + if as_fn_executable_p "$as_dir/$ac_word$ac_exec_ext"; then + ac_cv_prog_CC="${ac_tool_prefix}gcc" + $as_echo "$as_me:${as_lineno-$LINENO}: found $as_dir/$ac_word$ac_exec_ext" >&5 + break 2 + fi +done + done +IFS=$as_save_IFS + +fi +fi +CC=$ac_cv_prog_CC +if test -n "$CC"; then + { $as_echo "$as_me:${as_lineno-$LINENO}: result: $CC" >&5 +$as_echo "$CC" >&6; } +else + { $as_echo "$as_me:${as_lineno-$LINENO}: result: no" >&5 +$as_echo "no" >&6; } +fi + + +fi +if test -z "$ac_cv_prog_CC"; then + ac_ct_CC=$CC + # Extract the first word of "gcc", so it can be a program name with args. +set dummy gcc; ac_word=$2 +{ $as_echo "$as_me:${as_lineno-$LINENO}: checking for $ac_word" >&5 +$as_echo_n "checking for $ac_word... " >&6; } +if ${ac_cv_prog_ac_ct_CC+:} false; then : + $as_echo_n "(cached) " >&6 +else + if test -n "$ac_ct_CC"; then + ac_cv_prog_ac_ct_CC="$ac_ct_CC" # Let the user override the test. +else +as_save_IFS=$IFS; IFS=$PATH_SEPARATOR +for as_dir in $PATH +do + IFS=$as_save_IFS + test -z "$as_dir" && as_dir=. + for ac_exec_ext in '' $ac_executable_extensions; do + if as_fn_executable_p "$as_dir/$ac_word$ac_exec_ext"; then + ac_cv_prog_ac_ct_CC="gcc" + $as_echo "$as_me:${as_lineno-$LINENO}: found $as_dir/$ac_word$ac_exec_ext" >&5 + break 2 + fi +done + done +IFS=$as_save_IFS + +fi +fi +ac_ct_CC=$ac_cv_prog_ac_ct_CC +if test -n "$ac_ct_CC"; then + { $as_echo "$as_me:${as_lineno-$LINENO}: result: $ac_ct_CC" >&5 +$as_echo "$ac_ct_CC" >&6; } +else + { $as_echo "$as_me:${as_lineno-$LINENO}: result: no" >&5 +$as_echo "no" >&6; } +fi + + if test "x$ac_ct_CC" = x; then + CC="" + else + case $cross_compiling:$ac_tool_warned in +yes:) +{ $as_echo "$as_me:${as_lineno-$LINENO}: WARNING: using cross tools not prefixed with host triplet" >&5 +$as_echo "$as_me: WARNING: using cross tools not prefixed with host triplet" >&2;} +ac_tool_warned=yes ;; +esac + CC=$ac_ct_CC + fi +else + CC="$ac_cv_prog_CC" +fi + +if test -z "$CC"; then + if test -n "$ac_tool_prefix"; then + # Extract the first word of "${ac_tool_prefix}cc", so it can be a program name with args. +set dummy ${ac_tool_prefix}cc; ac_word=$2 +{ $as_echo "$as_me:${as_lineno-$LINENO}: checking for $ac_word" >&5 +$as_echo_n "checking for $ac_word... " >&6; } +if ${ac_cv_prog_CC+:} false; then : + $as_echo_n "(cached) " >&6 +else + if test -n "$CC"; then + ac_cv_prog_CC="$CC" # Let the user override the test. +else +as_save_IFS=$IFS; IFS=$PATH_SEPARATOR +for as_dir in $PATH +do + IFS=$as_save_IFS + test -z "$as_dir" && as_dir=. + for ac_exec_ext in '' $ac_executable_extensions; do + if as_fn_executable_p "$as_dir/$ac_word$ac_exec_ext"; then + ac_cv_prog_CC="${ac_tool_prefix}cc" + $as_echo "$as_me:${as_lineno-$LINENO}: found $as_dir/$ac_word$ac_exec_ext" >&5 + break 2 + fi +done + done +IFS=$as_save_IFS + +fi +fi +CC=$ac_cv_prog_CC +if test -n "$CC"; then + { $as_echo "$as_me:${as_lineno-$LINENO}: result: $CC" >&5 +$as_echo "$CC" >&6; } +else + { $as_echo "$as_me:${as_lineno-$LINENO}: result: no" >&5 +$as_echo "no" >&6; } +fi + + + fi +fi +if test -z "$CC"; then + # Extract the first word of "cc", so it can be a program name with args. +set dummy cc; ac_word=$2 +{ $as_echo "$as_me:${as_lineno-$LINENO}: checking for $ac_word" >&5 +$as_echo_n "checking for $ac_word... " >&6; } +if ${ac_cv_prog_CC+:} false; then : + $as_echo_n "(cached) " >&6 +else + if test -n "$CC"; then + ac_cv_prog_CC="$CC" # Let the user override the test. +else + ac_prog_rejected=no +as_save_IFS=$IFS; IFS=$PATH_SEPARATOR +for as_dir in $PATH +do + IFS=$as_save_IFS + test -z "$as_dir" && as_dir=. + for ac_exec_ext in '' $ac_executable_extensions; do + if as_fn_executable_p "$as_dir/$ac_word$ac_exec_ext"; then + if test "$as_dir/$ac_word$ac_exec_ext" = "/usr/ucb/cc"; then + ac_prog_rejected=yes + continue + fi + ac_cv_prog_CC="cc" + $as_echo "$as_me:${as_lineno-$LINENO}: found $as_dir/$ac_word$ac_exec_ext" >&5 + break 2 + fi +done + done +IFS=$as_save_IFS + +if test $ac_prog_rejected = yes; then + # We found a bogon in the path, so make sure we never use it. + set dummy $ac_cv_prog_CC + shift + if test $# != 0; then + # We chose a different compiler from the bogus one. + # However, it has the same basename, so the bogon will be chosen + # first if we set CC to just the basename; use the full file name. + shift + ac_cv_prog_CC="$as_dir/$ac_word${1+' '}$@" + fi +fi +fi +fi +CC=$ac_cv_prog_CC +if test -n "$CC"; then + { $as_echo "$as_me:${as_lineno-$LINENO}: result: $CC" >&5 +$as_echo "$CC" >&6; } +else + { $as_echo "$as_me:${as_lineno-$LINENO}: result: no" >&5 +$as_echo "no" >&6; } +fi + + +fi +if test -z "$CC"; then + if test -n "$ac_tool_prefix"; then + for ac_prog in cl.exe + do + # Extract the first word of "$ac_tool_prefix$ac_prog", so it can be a program name with args. +set dummy $ac_tool_prefix$ac_prog; ac_word=$2 +{ $as_echo "$as_me:${as_lineno-$LINENO}: checking for $ac_word" >&5 +$as_echo_n "checking for $ac_word... " >&6; } +if ${ac_cv_prog_CC+:} false; then : + $as_echo_n "(cached) " >&6 +else + if test -n "$CC"; then + ac_cv_prog_CC="$CC" # Let the user override the test. +else +as_save_IFS=$IFS; IFS=$PATH_SEPARATOR +for as_dir in $PATH +do + IFS=$as_save_IFS + test -z "$as_dir" && as_dir=. + for ac_exec_ext in '' $ac_executable_extensions; do + if as_fn_executable_p "$as_dir/$ac_word$ac_exec_ext"; then + ac_cv_prog_CC="$ac_tool_prefix$ac_prog" + $as_echo "$as_me:${as_lineno-$LINENO}: found $as_dir/$ac_word$ac_exec_ext" >&5 + break 2 + fi +done + done +IFS=$as_save_IFS + +fi +fi +CC=$ac_cv_prog_CC +if test -n "$CC"; then + { $as_echo "$as_me:${as_lineno-$LINENO}: result: $CC" >&5 +$as_echo "$CC" >&6; } +else + { $as_echo "$as_me:${as_lineno-$LINENO}: result: no" >&5 +$as_echo "no" >&6; } +fi + + + test -n "$CC" && break + done +fi +if test -z "$CC"; then + ac_ct_CC=$CC + for ac_prog in cl.exe +do + # Extract the first word of "$ac_prog", so it can be a program name with args. +set dummy $ac_prog; ac_word=$2 +{ $as_echo "$as_me:${as_lineno-$LINENO}: checking for $ac_word" >&5 +$as_echo_n "checking for $ac_word... " >&6; } +if ${ac_cv_prog_ac_ct_CC+:} false; then : + $as_echo_n "(cached) " >&6 +else + if test -n "$ac_ct_CC"; then + ac_cv_prog_ac_ct_CC="$ac_ct_CC" # Let the user override the test. +else +as_save_IFS=$IFS; IFS=$PATH_SEPARATOR +for as_dir in $PATH +do + IFS=$as_save_IFS + test -z "$as_dir" && as_dir=. + for ac_exec_ext in '' $ac_executable_extensions; do + if as_fn_executable_p "$as_dir/$ac_word$ac_exec_ext"; then + ac_cv_prog_ac_ct_CC="$ac_prog" + $as_echo "$as_me:${as_lineno-$LINENO}: found $as_dir/$ac_word$ac_exec_ext" >&5 + break 2 + fi +done + done +IFS=$as_save_IFS + +fi +fi +ac_ct_CC=$ac_cv_prog_ac_ct_CC +if test -n "$ac_ct_CC"; then + { $as_echo "$as_me:${as_lineno-$LINENO}: result: $ac_ct_CC" >&5 +$as_echo "$ac_ct_CC" >&6; } +else + { $as_echo "$as_me:${as_lineno-$LINENO}: result: no" >&5 +$as_echo "no" >&6; } +fi + + + test -n "$ac_ct_CC" && break +done + + if test "x$ac_ct_CC" = x; then + CC="" + else + case $cross_compiling:$ac_tool_warned in +yes:) +{ $as_echo "$as_me:${as_lineno-$LINENO}: WARNING: using cross tools not prefixed with host triplet" >&5 +$as_echo "$as_me: WARNING: using cross tools not prefixed with host triplet" >&2;} +ac_tool_warned=yes ;; +esac + CC=$ac_ct_CC + fi +fi + +fi + + +test -z "$CC" && { { $as_echo "$as_me:${as_lineno-$LINENO}: error: in \`$ac_pwd':" >&5 +$as_echo "$as_me: error: in \`$ac_pwd':" >&2;} +as_fn_error $? "no acceptable C compiler found in \$PATH +See \`config.log' for more details" "$LINENO" 5; } + +# Provide some information about the compiler. +$as_echo "$as_me:${as_lineno-$LINENO}: checking for C compiler version" >&5 +set X $ac_compile +ac_compiler=$2 +for ac_option in --version -v -V -qversion; do + { { ac_try="$ac_compiler $ac_option >&5" +case "(($ac_try" in + *\"* | *\`* | *\\*) ac_try_echo=\$ac_try;; + *) ac_try_echo=$ac_try;; +esac +eval ac_try_echo="\"\$as_me:${as_lineno-$LINENO}: $ac_try_echo\"" +$as_echo "$ac_try_echo"; } >&5 + (eval "$ac_compiler $ac_option >&5") 2>conftest.err + ac_status=$? + if test -s conftest.err; then + sed '10a\ +... rest of stderr output deleted ... + 10q' conftest.err >conftest.er1 + cat conftest.er1 >&5 + fi + rm -f conftest.er1 conftest.err + $as_echo "$as_me:${as_lineno-$LINENO}: \$? = $ac_status" >&5 + test $ac_status = 0; } +done + +{ $as_echo "$as_me:${as_lineno-$LINENO}: checking whether we are using the GNU C compiler" >&5 +$as_echo_n "checking whether we are using the GNU C compiler... " >&6; } +if ${ac_cv_c_compiler_gnu+:} false; then : + $as_echo_n "(cached) " >&6 +else + cat confdefs.h - <<_ACEOF >conftest.$ac_ext +/* end confdefs.h. */ + +int +main () +{ +#ifndef __GNUC__ + choke me +#endif + + ; + return 0; +} +_ACEOF +if ac_fn_c_try_compile "$LINENO"; then : + ac_compiler_gnu=yes +else + ac_compiler_gnu=no +fi +rm -f core conftest.err conftest.$ac_objext conftest.$ac_ext +ac_cv_c_compiler_gnu=$ac_compiler_gnu + +fi +{ $as_echo "$as_me:${as_lineno-$LINENO}: result: $ac_cv_c_compiler_gnu" >&5 +$as_echo "$ac_cv_c_compiler_gnu" >&6; } +if test $ac_compiler_gnu = yes; then + GCC=yes +else + GCC= +fi +ac_test_CFLAGS=${CFLAGS+set} +ac_save_CFLAGS=$CFLAGS +{ $as_echo "$as_me:${as_lineno-$LINENO}: checking whether $CC accepts -g" >&5 +$as_echo_n "checking whether $CC accepts -g... " >&6; } +if ${ac_cv_prog_cc_g+:} false; then : + $as_echo_n "(cached) " >&6 +else + ac_save_c_werror_flag=$ac_c_werror_flag + ac_c_werror_flag=yes + ac_cv_prog_cc_g=no + CFLAGS="-g" + cat confdefs.h - <<_ACEOF >conftest.$ac_ext +/* end confdefs.h. */ + +int +main () +{ + + ; + return 0; +} +_ACEOF +if ac_fn_c_try_compile "$LINENO"; then : + ac_cv_prog_cc_g=yes +else + CFLAGS="" + cat confdefs.h - <<_ACEOF >conftest.$ac_ext +/* end confdefs.h. */ + +int +main () +{ + + ; + return 0; +} +_ACEOF +if ac_fn_c_try_compile "$LINENO"; then : + +else + ac_c_werror_flag=$ac_save_c_werror_flag + CFLAGS="-g" + cat confdefs.h - <<_ACEOF >conftest.$ac_ext +/* end confdefs.h. */ + +int +main () +{ + + ; + return 0; +} +_ACEOF +if ac_fn_c_try_compile "$LINENO"; then : + ac_cv_prog_cc_g=yes +fi +rm -f core conftest.err conftest.$ac_objext conftest.$ac_ext +fi +rm -f core conftest.err conftest.$ac_objext conftest.$ac_ext +fi +rm -f core conftest.err conftest.$ac_objext conftest.$ac_ext + ac_c_werror_flag=$ac_save_c_werror_flag +fi +{ $as_echo "$as_me:${as_lineno-$LINENO}: result: $ac_cv_prog_cc_g" >&5 +$as_echo "$ac_cv_prog_cc_g" >&6; } +if test "$ac_test_CFLAGS" = set; then + CFLAGS=$ac_save_CFLAGS +elif test $ac_cv_prog_cc_g = yes; then + if test "$GCC" = yes; then + CFLAGS="-g -O2" + else + CFLAGS="-g" + fi +else + if test "$GCC" = yes; then + CFLAGS="-O2" + else + CFLAGS= + fi +fi +{ $as_echo "$as_me:${as_lineno-$LINENO}: checking for $CC option to accept ISO C89" >&5 +$as_echo_n "checking for $CC option to accept ISO C89... " >&6; } +if ${ac_cv_prog_cc_c89+:} false; then : + $as_echo_n "(cached) " >&6 +else + ac_cv_prog_cc_c89=no +ac_save_CC=$CC +cat confdefs.h - <<_ACEOF >conftest.$ac_ext +/* end confdefs.h. */ +#include +#include +struct stat; +/* Most of the following tests are stolen from RCS 5.7's src/conf.sh. */ +struct buf { int x; }; +FILE * (*rcsopen) (struct buf *, struct stat *, int); +static char *e (p, i) + char **p; + int i; +{ + return p[i]; +} +static char *f (char * (*g) (char **, int), char **p, ...) +{ + char *s; + va_list v; + va_start (v,p); + s = g (p, va_arg (v,int)); + va_end (v); + return s; +} + +/* OSF 4.0 Compaq cc is some sort of almost-ANSI by default. It has + function prototypes and stuff, but not '\xHH' hex character constants. + These don't provoke an error unfortunately, instead are silently treated + as 'x'. The following induces an error, until -std is added to get + proper ANSI mode. Curiously '\x00'!='x' always comes out true, for an + array size at least. It's necessary to write '\x00'==0 to get something + that's true only with -std. */ +int osf4_cc_array ['\x00' == 0 ? 1 : -1]; + +/* IBM C 6 for AIX is almost-ANSI by default, but it replaces macro parameters + inside strings and character constants. */ +#define FOO(x) 'x' +int xlc6_cc_array[FOO(a) == 'x' ? 1 : -1]; + +int test (int i, double x); +struct s1 {int (*f) (int a);}; +struct s2 {int (*f) (double a);}; +int pairnames (int, char **, FILE *(*)(struct buf *, struct stat *, int), int, int); +int argc; +char **argv; +int +main () +{ +return f (e, argv, 0) != argv[0] || f (e, argv, 1) != argv[1]; + ; + return 0; +} +_ACEOF +for ac_arg in '' -qlanglvl=extc89 -qlanglvl=ansi -std \ + -Ae "-Aa -D_HPUX_SOURCE" "-Xc -D__EXTENSIONS__" +do + CC="$ac_save_CC $ac_arg" + if ac_fn_c_try_compile "$LINENO"; then : + ac_cv_prog_cc_c89=$ac_arg +fi +rm -f core conftest.err conftest.$ac_objext + test "x$ac_cv_prog_cc_c89" != "xno" && break +done +rm -f conftest.$ac_ext +CC=$ac_save_CC + +fi +# AC_CACHE_VAL +case "x$ac_cv_prog_cc_c89" in + x) + { $as_echo "$as_me:${as_lineno-$LINENO}: result: none needed" >&5 +$as_echo "none needed" >&6; } ;; + xno) + { $as_echo "$as_me:${as_lineno-$LINENO}: result: unsupported" >&5 +$as_echo "unsupported" >&6; } ;; + *) + CC="$CC $ac_cv_prog_cc_c89" + { $as_echo "$as_me:${as_lineno-$LINENO}: result: $ac_cv_prog_cc_c89" >&5 +$as_echo "$ac_cv_prog_cc_c89" >&6; } ;; +esac +if test "x$ac_cv_prog_cc_c89" != xno; then : + +fi + +ac_ext=cpp +ac_cpp='$CXXCPP $CPPFLAGS' +ac_compile='$CXX -c $CXXFLAGS $CPPFLAGS conftest.$ac_ext >&5' +ac_link='$CXX -o conftest$ac_exeext $CXXFLAGS $CPPFLAGS $LDFLAGS conftest.$ac_ext $LIBS >&5' +ac_compiler_gnu=$ac_cv_cxx_compiler_gnu + +ac_ext=c +ac_cpp='$CPP $CPPFLAGS' +ac_compile='$CC -c $CFLAGS $CPPFLAGS conftest.$ac_ext >&5' +ac_link='$CC -o conftest$ac_exeext $CFLAGS $CPPFLAGS $LDFLAGS conftest.$ac_ext $LIBS >&5' +ac_compiler_gnu=$ac_cv_c_compiler_gnu +{ $as_echo "$as_me:${as_lineno-$LINENO}: checking how to run the C preprocessor" >&5 +$as_echo_n "checking how to run the C preprocessor... " >&6; } +# On Suns, sometimes $CPP names a directory. +if test -n "$CPP" && test -d "$CPP"; then + CPP= +fi +if test -z "$CPP"; then + if ${ac_cv_prog_CPP+:} false; then : + $as_echo_n "(cached) " >&6 +else + # Double quotes because CPP needs to be expanded + for CPP in "$CC -E" "$CC -E -traditional-cpp" "/lib/cpp" + do + ac_preproc_ok=false +for ac_c_preproc_warn_flag in '' yes +do + # Use a header file that comes with gcc, so configuring glibc + # with a fresh cross-compiler works. + # Prefer to if __STDC__ is defined, since + # exists even on freestanding compilers. + # On the NeXT, cc -E runs the code through the compiler's parser, + # not just through cpp. "Syntax error" is here to catch this case. + cat confdefs.h - <<_ACEOF >conftest.$ac_ext +/* end confdefs.h. */ +#ifdef __STDC__ +# include +#else +# include +#endif + Syntax error +_ACEOF +if ac_fn_c_try_cpp "$LINENO"; then : + +else + # Broken: fails on valid input. +continue +fi +rm -f conftest.err conftest.i conftest.$ac_ext + + # OK, works on sane cases. Now check whether nonexistent headers + # can be detected and how. + cat confdefs.h - <<_ACEOF >conftest.$ac_ext +/* end confdefs.h. */ +#include +_ACEOF +if ac_fn_c_try_cpp "$LINENO"; then : + # Broken: success on invalid input. +continue +else + # Passes both tests. +ac_preproc_ok=: +break +fi +rm -f conftest.err conftest.i conftest.$ac_ext + +done +# Because of `break', _AC_PREPROC_IFELSE's cleaning code was skipped. +rm -f conftest.i conftest.err conftest.$ac_ext +if $ac_preproc_ok; then : + break +fi + + done + ac_cv_prog_CPP=$CPP + +fi + CPP=$ac_cv_prog_CPP +else + ac_cv_prog_CPP=$CPP +fi +{ $as_echo "$as_me:${as_lineno-$LINENO}: result: $CPP" >&5 +$as_echo "$CPP" >&6; } +ac_preproc_ok=false +for ac_c_preproc_warn_flag in '' yes +do + # Use a header file that comes with gcc, so configuring glibc + # with a fresh cross-compiler works. + # Prefer to if __STDC__ is defined, since + # exists even on freestanding compilers. + # On the NeXT, cc -E runs the code through the compiler's parser, + # not just through cpp. "Syntax error" is here to catch this case. + cat confdefs.h - <<_ACEOF >conftest.$ac_ext +/* end confdefs.h. */ +#ifdef __STDC__ +# include +#else +# include +#endif + Syntax error +_ACEOF +if ac_fn_c_try_cpp "$LINENO"; then : + +else + # Broken: fails on valid input. +continue +fi +rm -f conftest.err conftest.i conftest.$ac_ext + + # OK, works on sane cases. Now check whether nonexistent headers + # can be detected and how. + cat confdefs.h - <<_ACEOF >conftest.$ac_ext +/* end confdefs.h. */ +#include +_ACEOF +if ac_fn_c_try_cpp "$LINENO"; then : + # Broken: success on invalid input. +continue +else + # Passes both tests. +ac_preproc_ok=: +break +fi +rm -f conftest.err conftest.i conftest.$ac_ext + +done +# Because of `break', _AC_PREPROC_IFELSE's cleaning code was skipped. +rm -f conftest.i conftest.err conftest.$ac_ext +if $ac_preproc_ok; then : + +else + { { $as_echo "$as_me:${as_lineno-$LINENO}: error: in \`$ac_pwd':" >&5 +$as_echo "$as_me: error: in \`$ac_pwd':" >&2;} +as_fn_error $? "C preprocessor \"$CPP\" fails sanity check +See \`config.log' for more details" "$LINENO" 5; } +fi + +ac_ext=cpp +ac_cpp='$CXXCPP $CPPFLAGS' +ac_compile='$CXX -c $CXXFLAGS $CPPFLAGS conftest.$ac_ext >&5' +ac_link='$CXX -o conftest$ac_exeext $CXXFLAGS $CPPFLAGS $LDFLAGS conftest.$ac_ext $LIBS >&5' +ac_compiler_gnu=$ac_cv_cxx_compiler_gnu + +{ $as_echo "$as_me:${as_lineno-$LINENO}: checking whether ${MAKE-make} sets \$(MAKE)" >&5 +$as_echo_n "checking whether ${MAKE-make} sets \$(MAKE)... " >&6; } +set x ${MAKE-make} +ac_make=`$as_echo "$2" | sed 's/+/p/g; s/[^a-zA-Z0-9_]/_/g'` +if eval \${ac_cv_prog_make_${ac_make}_set+:} false; then : + $as_echo_n "(cached) " >&6 +else + cat >conftest.make <<\_ACEOF +SHELL = /bin/sh +all: + @echo '@@@%%%=$(MAKE)=@@@%%%' +_ACEOF +# GNU make sometimes prints "make[1]: Entering ...", which would confuse us. +case `${MAKE-make} -f conftest.make 2>/dev/null` in + *@@@%%%=?*=@@@%%%*) + eval ac_cv_prog_make_${ac_make}_set=yes;; + *) + eval ac_cv_prog_make_${ac_make}_set=no;; +esac +rm -f conftest.make +fi +if eval test \$ac_cv_prog_make_${ac_make}_set = yes; then + { $as_echo "$as_me:${as_lineno-$LINENO}: result: yes" >&5 +$as_echo "yes" >&6; } + SET_MAKE= +else + { $as_echo "$as_me:${as_lineno-$LINENO}: result: no" >&5 +$as_echo "no" >&6; } + SET_MAKE="MAKE=${MAKE-make}" +fi + +{ $as_echo "$as_me:${as_lineno-$LINENO}: checking for a thread-safe mkdir -p" >&5 +$as_echo_n "checking for a thread-safe mkdir -p... " >&6; } +if test -z "$MKDIR_P"; then + if ${ac_cv_path_mkdir+:} false; then : + $as_echo_n "(cached) " >&6 +else + as_save_IFS=$IFS; IFS=$PATH_SEPARATOR +for as_dir in $PATH$PATH_SEPARATOR/opt/sfw/bin +do + IFS=$as_save_IFS + test -z "$as_dir" && as_dir=. + for ac_prog in mkdir gmkdir; do + for ac_exec_ext in '' $ac_executable_extensions; do + as_fn_executable_p "$as_dir/$ac_prog$ac_exec_ext" || continue + case `"$as_dir/$ac_prog$ac_exec_ext" --version 2>&1` in #( + 'mkdir (GNU coreutils) '* | \ + 'mkdir (coreutils) '* | \ + 'mkdir (fileutils) '4.1*) + ac_cv_path_mkdir=$as_dir/$ac_prog$ac_exec_ext + break 3;; + esac + done + done + done +IFS=$as_save_IFS + +fi + + test -d ./--version && rmdir ./--version + if test "${ac_cv_path_mkdir+set}" = set; then + MKDIR_P="$ac_cv_path_mkdir -p" + else + # As a last resort, use the slow shell script. Don't cache a + # value for MKDIR_P within a source directory, because that will + # break other packages using the cache if that directory is + # removed, or if the value is a relative name. + MKDIR_P="$ac_install_sh -d" + fi +fi +{ $as_echo "$as_me:${as_lineno-$LINENO}: result: $MKDIR_P" >&5 +$as_echo "$MKDIR_P" >&6; } + + + + + +# Check whether --with-python was given. +if test "${with_python+set}" = set; then : + withval=$with_python; +fi + +case $with_python in + "") PYTHON_BIN=python ;; + *) PYTHON_BIN="$with_python" +esac + +# Extract the first word of "$PYTHON_BIN", so it can be a program name with args. +set dummy $PYTHON_BIN; ac_word=$2 +{ $as_echo "$as_me:${as_lineno-$LINENO}: checking for $ac_word" >&5 +$as_echo_n "checking for $ac_word... " >&6; } +if ${ac_cv_prog_PYTHON+:} false; then : + $as_echo_n "(cached) " >&6 +else + if test -n "$PYTHON"; then + ac_cv_prog_PYTHON="$PYTHON" # Let the user override the test. +else +as_save_IFS=$IFS; IFS=$PATH_SEPARATOR +for as_dir in $PATH +do + IFS=$as_save_IFS + test -z "$as_dir" && as_dir=. + for ac_exec_ext in '' $ac_executable_extensions; do + if as_fn_executable_p "$as_dir/$ac_word$ac_exec_ext"; then + ac_cv_prog_PYTHON="$PYTHON_BIN" + $as_echo "$as_me:${as_lineno-$LINENO}: found $as_dir/$ac_word$ac_exec_ext" >&5 + break 2 + fi +done + done +IFS=$as_save_IFS + +fi +fi +PYTHON=$ac_cv_prog_PYTHON +if test -n "$PYTHON"; then + { $as_echo "$as_me:${as_lineno-$LINENO}: result: $PYTHON" >&5 +$as_echo "$PYTHON" >&6; } +else + { $as_echo "$as_me:${as_lineno-$LINENO}: result: no" >&5 +$as_echo "no" >&6; } +fi + + +fa_python_bin=$PYTHON + +{ $as_echo "$as_me:${as_lineno-$LINENO}: checking for Python C flags" >&5 +$as_echo_n "checking for Python C flags... " >&6; } +fa_python_cflags=`$PYTHON -c " +import sysconfig +paths = ['-I' + sysconfig.get_path(p) for p in ['include', 'platinclude']] +print(' '.join(paths))"` +{ $as_echo "$as_me:${as_lineno-$LINENO}: result: $fa_python_cflags" >&5 +$as_echo "$fa_python_cflags" >&6; } +PYTHON_CFLAGS="$PYTHON_CFLAGS $fa_python_cflags" + + + + +if test x$PYTHON != x; then + + + +{ $as_echo "$as_me:${as_lineno-$LINENO}: checking for numpy headers path" >&5 +$as_echo_n "checking for numpy headers path... " >&6; } + +fa_numpy_headers=`$PYTHON -c "import numpy; print(numpy.get_include())"` + +if test $? == 0; then + if test x$fa_numpy_headers != x; then + { $as_echo "$as_me:${as_lineno-$LINENO}: result: $fa_numpy_headers" >&5 +$as_echo "$fa_numpy_headers" >&6; } + NUMPY_INCLUDE=$fa_numpy_headers + + else + { $as_echo "$as_me:${as_lineno-$LINENO}: result: not found" >&5 +$as_echo "not found" >&6; } + { $as_echo "$as_me:${as_lineno-$LINENO}: WARNING: You won't be able to build the python interface." >&5 +$as_echo "$as_me: WARNING: You won't be able to build the python interface." >&2;} + fi +else + { $as_echo "$as_me:${as_lineno-$LINENO}: result: not found" >&5 +$as_echo "not found" >&6; } + { $as_echo "$as_me:${as_lineno-$LINENO}: WARNING: You won't be able to build the python interface." >&5 +$as_echo "$as_me: WARNING: You won't be able to build the python interface." >&2;} +fi + +fi + + + + +# Check whether --with-swig was given. +if test "${with_swig+set}" = set; then : + withval=$with_swig; +fi + +case $with_swig in + "") # Extract the first word of "swig", so it can be a program name with args. +set dummy swig; ac_word=$2 +{ $as_echo "$as_me:${as_lineno-$LINENO}: checking for $ac_word" >&5 +$as_echo_n "checking for $ac_word... " >&6; } +if ${ac_cv_prog_SWIG+:} false; then : + $as_echo_n "(cached) " >&6 +else + if test -n "$SWIG"; then + ac_cv_prog_SWIG="$SWIG" # Let the user override the test. +else +as_save_IFS=$IFS; IFS=$PATH_SEPARATOR +for as_dir in $PATH +do + IFS=$as_save_IFS + test -z "$as_dir" && as_dir=. + for ac_exec_ext in '' $ac_executable_extensions; do + if as_fn_executable_p "$as_dir/$ac_word$ac_exec_ext"; then + ac_cv_prog_SWIG="swig" + $as_echo "$as_me:${as_lineno-$LINENO}: found $as_dir/$ac_word$ac_exec_ext" >&5 + break 2 + fi +done + done +IFS=$as_save_IFS + +fi +fi +SWIG=$ac_cv_prog_SWIG +if test -n "$SWIG"; then + { $as_echo "$as_me:${as_lineno-$LINENO}: result: $SWIG" >&5 +$as_echo "$SWIG" >&6; } +else + { $as_echo "$as_me:${as_lineno-$LINENO}: result: no" >&5 +$as_echo "no" >&6; } +fi + +;; + *) SWIG="$with_swig" +esac + + + + +ac_ext=cpp +ac_cpp='$CXXCPP $CPPFLAGS' +ac_compile='$CXX -c $CXXFLAGS $CPPFLAGS conftest.$ac_ext >&5' +ac_link='$CXX -o conftest$ac_exeext $CXXFLAGS $CPPFLAGS $LDFLAGS conftest.$ac_ext $LIBS >&5' +ac_compiler_gnu=$ac_cv_cxx_compiler_gnu +{ $as_echo "$as_me:${as_lineno-$LINENO}: checking how to run the C++ preprocessor" >&5 +$as_echo_n "checking how to run the C++ preprocessor... " >&6; } +if test -z "$CXXCPP"; then + if ${ac_cv_prog_CXXCPP+:} false; then : + $as_echo_n "(cached) " >&6 +else + # Double quotes because CXXCPP needs to be expanded + for CXXCPP in "$CXX -E" "/lib/cpp" + do + ac_preproc_ok=false +for ac_cxx_preproc_warn_flag in '' yes +do + # Use a header file that comes with gcc, so configuring glibc + # with a fresh cross-compiler works. + # Prefer to if __STDC__ is defined, since + # exists even on freestanding compilers. + # On the NeXT, cc -E runs the code through the compiler's parser, + # not just through cpp. "Syntax error" is here to catch this case. + cat confdefs.h - <<_ACEOF >conftest.$ac_ext +/* end confdefs.h. */ +#ifdef __STDC__ +# include +#else +# include +#endif + Syntax error +_ACEOF +if ac_fn_cxx_try_cpp "$LINENO"; then : + +else + # Broken: fails on valid input. +continue +fi +rm -f conftest.err conftest.i conftest.$ac_ext + + # OK, works on sane cases. Now check whether nonexistent headers + # can be detected and how. + cat confdefs.h - <<_ACEOF >conftest.$ac_ext +/* end confdefs.h. */ +#include +_ACEOF +if ac_fn_cxx_try_cpp "$LINENO"; then : + # Broken: success on invalid input. +continue +else + # Passes both tests. +ac_preproc_ok=: +break +fi +rm -f conftest.err conftest.i conftest.$ac_ext + +done +# Because of `break', _AC_PREPROC_IFELSE's cleaning code was skipped. +rm -f conftest.i conftest.err conftest.$ac_ext +if $ac_preproc_ok; then : + break +fi + + done + ac_cv_prog_CXXCPP=$CXXCPP + +fi + CXXCPP=$ac_cv_prog_CXXCPP +else + ac_cv_prog_CXXCPP=$CXXCPP +fi +{ $as_echo "$as_me:${as_lineno-$LINENO}: result: $CXXCPP" >&5 +$as_echo "$CXXCPP" >&6; } +ac_preproc_ok=false +for ac_cxx_preproc_warn_flag in '' yes +do + # Use a header file that comes with gcc, so configuring glibc + # with a fresh cross-compiler works. + # Prefer to if __STDC__ is defined, since + # exists even on freestanding compilers. + # On the NeXT, cc -E runs the code through the compiler's parser, + # not just through cpp. "Syntax error" is here to catch this case. + cat confdefs.h - <<_ACEOF >conftest.$ac_ext +/* end confdefs.h. */ +#ifdef __STDC__ +# include +#else +# include +#endif + Syntax error +_ACEOF +if ac_fn_cxx_try_cpp "$LINENO"; then : + +else + # Broken: fails on valid input. +continue +fi +rm -f conftest.err conftest.i conftest.$ac_ext + + # OK, works on sane cases. Now check whether nonexistent headers + # can be detected and how. + cat confdefs.h - <<_ACEOF >conftest.$ac_ext +/* end confdefs.h. */ +#include +_ACEOF +if ac_fn_cxx_try_cpp "$LINENO"; then : + # Broken: success on invalid input. +continue +else + # Passes both tests. +ac_preproc_ok=: +break +fi +rm -f conftest.err conftest.i conftest.$ac_ext + +done +# Because of `break', _AC_PREPROC_IFELSE's cleaning code was skipped. +rm -f conftest.i conftest.err conftest.$ac_ext +if $ac_preproc_ok; then : + +else + { { $as_echo "$as_me:${as_lineno-$LINENO}: error: in \`$ac_pwd':" >&5 +$as_echo "$as_me: error: in \`$ac_pwd':" >&2;} +as_fn_error $? "C++ preprocessor \"$CXXCPP\" fails sanity check +See \`config.log' for more details" "$LINENO" 5; } +fi + +ac_ext=cpp +ac_cpp='$CXXCPP $CPPFLAGS' +ac_compile='$CXX -c $CXXFLAGS $CPPFLAGS conftest.$ac_ext >&5' +ac_link='$CXX -o conftest$ac_exeext $CXXFLAGS $CPPFLAGS $LDFLAGS conftest.$ac_ext $LIBS >&5' +ac_compiler_gnu=$ac_cv_cxx_compiler_gnu + + +{ $as_echo "$as_me:${as_lineno-$LINENO}: checking for grep that handles long lines and -e" >&5 +$as_echo_n "checking for grep that handles long lines and -e... " >&6; } +if ${ac_cv_path_GREP+:} false; then : + $as_echo_n "(cached) " >&6 +else + if test -z "$GREP"; then + ac_path_GREP_found=false + # Loop through the user's path and test for each of PROGNAME-LIST + as_save_IFS=$IFS; IFS=$PATH_SEPARATOR +for as_dir in $PATH$PATH_SEPARATOR/usr/xpg4/bin +do + IFS=$as_save_IFS + test -z "$as_dir" && as_dir=. + for ac_prog in grep ggrep; do + for ac_exec_ext in '' $ac_executable_extensions; do + ac_path_GREP="$as_dir/$ac_prog$ac_exec_ext" + as_fn_executable_p "$ac_path_GREP" || continue +# Check for GNU ac_path_GREP and select it if it is found. + # Check for GNU $ac_path_GREP +case `"$ac_path_GREP" --version 2>&1` in +*GNU*) + ac_cv_path_GREP="$ac_path_GREP" ac_path_GREP_found=:;; +*) + ac_count=0 + $as_echo_n 0123456789 >"conftest.in" + while : + do + cat "conftest.in" "conftest.in" >"conftest.tmp" + mv "conftest.tmp" "conftest.in" + cp "conftest.in" "conftest.nl" + $as_echo 'GREP' >> "conftest.nl" + "$ac_path_GREP" -e 'GREP$' -e '-(cannot match)-' < "conftest.nl" >"conftest.out" 2>/dev/null || break + diff "conftest.out" "conftest.nl" >/dev/null 2>&1 || break + as_fn_arith $ac_count + 1 && ac_count=$as_val + if test $ac_count -gt ${ac_path_GREP_max-0}; then + # Best one so far, save it but keep looking for a better one + ac_cv_path_GREP="$ac_path_GREP" + ac_path_GREP_max=$ac_count + fi + # 10*(2^10) chars as input seems more than enough + test $ac_count -gt 10 && break + done + rm -f conftest.in conftest.tmp conftest.nl conftest.out;; +esac + + $ac_path_GREP_found && break 3 + done + done + done +IFS=$as_save_IFS + if test -z "$ac_cv_path_GREP"; then + as_fn_error $? "no acceptable grep could be found in $PATH$PATH_SEPARATOR/usr/xpg4/bin" "$LINENO" 5 + fi +else + ac_cv_path_GREP=$GREP +fi + +fi +{ $as_echo "$as_me:${as_lineno-$LINENO}: result: $ac_cv_path_GREP" >&5 +$as_echo "$ac_cv_path_GREP" >&6; } + GREP="$ac_cv_path_GREP" + + +{ $as_echo "$as_me:${as_lineno-$LINENO}: checking for egrep" >&5 +$as_echo_n "checking for egrep... " >&6; } +if ${ac_cv_path_EGREP+:} false; then : + $as_echo_n "(cached) " >&6 +else + if echo a | $GREP -E '(a|b)' >/dev/null 2>&1 + then ac_cv_path_EGREP="$GREP -E" + else + if test -z "$EGREP"; then + ac_path_EGREP_found=false + # Loop through the user's path and test for each of PROGNAME-LIST + as_save_IFS=$IFS; IFS=$PATH_SEPARATOR +for as_dir in $PATH$PATH_SEPARATOR/usr/xpg4/bin +do + IFS=$as_save_IFS + test -z "$as_dir" && as_dir=. + for ac_prog in egrep; do + for ac_exec_ext in '' $ac_executable_extensions; do + ac_path_EGREP="$as_dir/$ac_prog$ac_exec_ext" + as_fn_executable_p "$ac_path_EGREP" || continue +# Check for GNU ac_path_EGREP and select it if it is found. + # Check for GNU $ac_path_EGREP +case `"$ac_path_EGREP" --version 2>&1` in +*GNU*) + ac_cv_path_EGREP="$ac_path_EGREP" ac_path_EGREP_found=:;; +*) + ac_count=0 + $as_echo_n 0123456789 >"conftest.in" + while : + do + cat "conftest.in" "conftest.in" >"conftest.tmp" + mv "conftest.tmp" "conftest.in" + cp "conftest.in" "conftest.nl" + $as_echo 'EGREP' >> "conftest.nl" + "$ac_path_EGREP" 'EGREP$' < "conftest.nl" >"conftest.out" 2>/dev/null || break + diff "conftest.out" "conftest.nl" >/dev/null 2>&1 || break + as_fn_arith $ac_count + 1 && ac_count=$as_val + if test $ac_count -gt ${ac_path_EGREP_max-0}; then + # Best one so far, save it but keep looking for a better one + ac_cv_path_EGREP="$ac_path_EGREP" + ac_path_EGREP_max=$ac_count + fi + # 10*(2^10) chars as input seems more than enough + test $ac_count -gt 10 && break + done + rm -f conftest.in conftest.tmp conftest.nl conftest.out;; +esac + + $ac_path_EGREP_found && break 3 + done + done + done +IFS=$as_save_IFS + if test -z "$ac_cv_path_EGREP"; then + as_fn_error $? "no acceptable egrep could be found in $PATH$PATH_SEPARATOR/usr/xpg4/bin" "$LINENO" 5 + fi +else + ac_cv_path_EGREP=$EGREP +fi + + fi +fi +{ $as_echo "$as_me:${as_lineno-$LINENO}: result: $ac_cv_path_EGREP" >&5 +$as_echo "$ac_cv_path_EGREP" >&6; } + EGREP="$ac_cv_path_EGREP" + + +{ $as_echo "$as_me:${as_lineno-$LINENO}: checking for ANSI C header files" >&5 +$as_echo_n "checking for ANSI C header files... " >&6; } +if ${ac_cv_header_stdc+:} false; then : + $as_echo_n "(cached) " >&6 +else + cat confdefs.h - <<_ACEOF >conftest.$ac_ext +/* end confdefs.h. */ +#include +#include +#include +#include + +int +main () +{ + + ; + return 0; +} +_ACEOF +if ac_fn_cxx_try_compile "$LINENO"; then : + ac_cv_header_stdc=yes +else + ac_cv_header_stdc=no +fi +rm -f core conftest.err conftest.$ac_objext conftest.$ac_ext + +if test $ac_cv_header_stdc = yes; then + # SunOS 4.x string.h does not declare mem*, contrary to ANSI. + cat confdefs.h - <<_ACEOF >conftest.$ac_ext +/* end confdefs.h. */ +#include + +_ACEOF +if (eval "$ac_cpp conftest.$ac_ext") 2>&5 | + $EGREP "memchr" >/dev/null 2>&1; then : + +else + ac_cv_header_stdc=no +fi +rm -f conftest* + +fi + +if test $ac_cv_header_stdc = yes; then + # ISC 2.0.2 stdlib.h does not declare free, contrary to ANSI. + cat confdefs.h - <<_ACEOF >conftest.$ac_ext +/* end confdefs.h. */ +#include + +_ACEOF +if (eval "$ac_cpp conftest.$ac_ext") 2>&5 | + $EGREP "free" >/dev/null 2>&1; then : + +else + ac_cv_header_stdc=no +fi +rm -f conftest* + +fi + +if test $ac_cv_header_stdc = yes; then + # /bin/cc in Irix-4.0.5 gets non-ANSI ctype macros unless using -ansi. + if test "$cross_compiling" = yes; then : + : +else + cat confdefs.h - <<_ACEOF >conftest.$ac_ext +/* end confdefs.h. */ +#include +#include +#if ((' ' & 0x0FF) == 0x020) +# define ISLOWER(c) ('a' <= (c) && (c) <= 'z') +# define TOUPPER(c) (ISLOWER(c) ? 'A' + ((c) - 'a') : (c)) +#else +# define ISLOWER(c) \ + (('a' <= (c) && (c) <= 'i') \ + || ('j' <= (c) && (c) <= 'r') \ + || ('s' <= (c) && (c) <= 'z')) +# define TOUPPER(c) (ISLOWER(c) ? ((c) | 0x40) : (c)) +#endif + +#define XOR(e, f) (((e) && !(f)) || (!(e) && (f))) +int +main () +{ + int i; + for (i = 0; i < 256; i++) + if (XOR (islower (i), ISLOWER (i)) + || toupper (i) != TOUPPER (i)) + return 2; + return 0; +} +_ACEOF +if ac_fn_cxx_try_run "$LINENO"; then : + +else + ac_cv_header_stdc=no +fi +rm -f core *.core core.conftest.* gmon.out bb.out conftest$ac_exeext \ + conftest.$ac_objext conftest.beam conftest.$ac_ext +fi + +fi +fi +{ $as_echo "$as_me:${as_lineno-$LINENO}: result: $ac_cv_header_stdc" >&5 +$as_echo "$ac_cv_header_stdc" >&6; } +if test $ac_cv_header_stdc = yes; then + +$as_echo "#define STDC_HEADERS 1" >>confdefs.h + +fi + +# On IRIX 5.3, sys/types and inttypes.h are conflicting. +for ac_header in sys/types.h sys/stat.h stdlib.h string.h memory.h strings.h \ + inttypes.h stdint.h unistd.h +do : + as_ac_Header=`$as_echo "ac_cv_header_$ac_header" | $as_tr_sh` +ac_fn_cxx_check_header_compile "$LINENO" "$ac_header" "$as_ac_Header" "$ac_includes_default +" +if eval test \"x\$"$as_ac_Header"\" = x"yes"; then : + cat >>confdefs.h <<_ACEOF +#define `$as_echo "HAVE_$ac_header" | $as_tr_cpp` 1 +_ACEOF + +fi + +done + + + + + +# Check whether --with-cuda was given. +if test "${with_cuda+set}" = set; then : + withval=$with_cuda; +fi + + +# Check whether --with-cuda-arch was given. +if test "${with_cuda_arch+set}" = set; then : + withval=$with_cuda_arch; +else + with_cuda_arch=default +fi + + +if test x$with_cuda != xno; then + if test x$with_cuda != x; then + cuda_prefix=$with_cuda + # Extract the first word of "nvcc", so it can be a program name with args. +set dummy nvcc; ac_word=$2 +{ $as_echo "$as_me:${as_lineno-$LINENO}: checking for $ac_word" >&5 +$as_echo_n "checking for $ac_word... " >&6; } +if ${ac_cv_prog_NVCC+:} false; then : + $as_echo_n "(cached) " >&6 +else + if test -n "$NVCC"; then + ac_cv_prog_NVCC="$NVCC" # Let the user override the test. +else +as_save_IFS=$IFS; IFS=$PATH_SEPARATOR +for as_dir in $cuda_prefix/bin +do + IFS=$as_save_IFS + test -z "$as_dir" && as_dir=. + for ac_exec_ext in '' $ac_executable_extensions; do + if as_fn_executable_p "$as_dir/$ac_word$ac_exec_ext"; then + ac_cv_prog_NVCC="$cuda_prefix/bin/nvcc" + $as_echo "$as_me:${as_lineno-$LINENO}: found $as_dir/$ac_word$ac_exec_ext" >&5 + break 2 + fi +done + done +IFS=$as_save_IFS + +fi +fi +NVCC=$ac_cv_prog_NVCC +if test -n "$NVCC"; then + { $as_echo "$as_me:${as_lineno-$LINENO}: result: $NVCC" >&5 +$as_echo "$NVCC" >&6; } +else + { $as_echo "$as_me:${as_lineno-$LINENO}: result: no" >&5 +$as_echo "no" >&6; } +fi + + + NVCC_CPPFLAGS="-I$cuda_prefix/include" + NVCC_LDFLAGS="-L$cuda_prefix/lib64" + else + for ac_prog in nvcc /usr/local/cuda/bin/nvcc +do + # Extract the first word of "$ac_prog", so it can be a program name with args. +set dummy $ac_prog; ac_word=$2 +{ $as_echo "$as_me:${as_lineno-$LINENO}: checking for $ac_word" >&5 +$as_echo_n "checking for $ac_word... " >&6; } +if ${ac_cv_prog_NVCC+:} false; then : + $as_echo_n "(cached) " >&6 +else + if test -n "$NVCC"; then + ac_cv_prog_NVCC="$NVCC" # Let the user override the test. +else +as_save_IFS=$IFS; IFS=$PATH_SEPARATOR +for as_dir in $PATH +do + IFS=$as_save_IFS + test -z "$as_dir" && as_dir=. + for ac_exec_ext in '' $ac_executable_extensions; do + if as_fn_executable_p "$as_dir/$ac_word$ac_exec_ext"; then + ac_cv_prog_NVCC="$ac_prog" + $as_echo "$as_me:${as_lineno-$LINENO}: found $as_dir/$ac_word$ac_exec_ext" >&5 + break 2 + fi +done + done +IFS=$as_save_IFS + +fi +fi +NVCC=$ac_cv_prog_NVCC +if test -n "$NVCC"; then + { $as_echo "$as_me:${as_lineno-$LINENO}: result: $NVCC" >&5 +$as_echo "$NVCC" >&6; } +else + { $as_echo "$as_me:${as_lineno-$LINENO}: result: no" >&5 +$as_echo "no" >&6; } +fi + + + test -n "$NVCC" && break +done + + if test "x$NVCC" == "x/usr/local/cuda/bin/nvcc"; then + cuda_prefix="/usr/local/cuda" + NVCC_CPPFLAGS="-I$cuda_prefix/include" + NVCC_LDFLAGS="-L$cuda_prefix/lib64" + else + cuda_prefix="" + NVCC_CPPFLAGS="" + NVCC_LDFLAGS="" + fi + fi + + if test "x$NVCC" == x; then + as_fn_error $? "Couldn't find nvcc" "$LINENO" 5 + fi + + if test "x$with_cuda_arch" == xdefault; then + with_cuda_arch="-gencode=arch=compute_35,code=compute_35 \\ +-gencode=arch=compute_52,code=compute_52 \\ +-gencode=arch=compute_60,code=compute_60 \\ +-gencode=arch=compute_61,code=compute_61 \\ +-gencode=arch=compute_70,code=compute_70 \\ +-gencode=arch=compute_75,code=compute_75" + fi + + fa_save_CPPFLAGS="$CPPFLAGS" + fa_save_LDFLAGS="$LDFLAGS" + fa_save_LIBS="$LIBS" + + CPPFLAGS="$NVCC_CPPFLAGS $CPPFLAGS" + LDFLAGS="$NVCC_LDFLAGS $LDFLAGS" + + ac_fn_cxx_check_header_mongrel "$LINENO" "cuda.h" "ac_cv_header_cuda_h" "$ac_includes_default" +if test "x$ac_cv_header_cuda_h" = xyes; then : + +else + { { $as_echo "$as_me:${as_lineno-$LINENO}: error: in \`$ac_pwd':" >&5 +$as_echo "$as_me: error: in \`$ac_pwd':" >&2;} +as_fn_error $? "Couldn't find cuda.h +See \`config.log' for more details" "$LINENO" 5; } +fi + + + { $as_echo "$as_me:${as_lineno-$LINENO}: checking for cublasAlloc in -lcublas" >&5 +$as_echo_n "checking for cublasAlloc in -lcublas... " >&6; } +if ${ac_cv_lib_cublas_cublasAlloc+:} false; then : + $as_echo_n "(cached) " >&6 +else + ac_check_lib_save_LIBS=$LIBS +LIBS="-lcublas $LIBS" +cat confdefs.h - <<_ACEOF >conftest.$ac_ext +/* end confdefs.h. */ + +/* Override any GCC internal prototype to avoid an error. + Use char because int might match the return type of a GCC + builtin and then its argument prototype would still apply. */ +#ifdef __cplusplus +extern "C" +#endif +char cublasAlloc (); +int +main () +{ +return cublasAlloc (); + ; + return 0; +} +_ACEOF +if ac_fn_cxx_try_link "$LINENO"; then : + ac_cv_lib_cublas_cublasAlloc=yes +else + ac_cv_lib_cublas_cublasAlloc=no +fi +rm -f core conftest.err conftest.$ac_objext \ + conftest$ac_exeext conftest.$ac_ext +LIBS=$ac_check_lib_save_LIBS +fi +{ $as_echo "$as_me:${as_lineno-$LINENO}: result: $ac_cv_lib_cublas_cublasAlloc" >&5 +$as_echo "$ac_cv_lib_cublas_cublasAlloc" >&6; } +if test "x$ac_cv_lib_cublas_cublasAlloc" = xyes; then : + cat >>confdefs.h <<_ACEOF +#define HAVE_LIBCUBLAS 1 +_ACEOF + + LIBS="-lcublas $LIBS" + +else + { { $as_echo "$as_me:${as_lineno-$LINENO}: error: in \`$ac_pwd':" >&5 +$as_echo "$as_me: error: in \`$ac_pwd':" >&2;} +as_fn_error $? "Couldn't find libcublas +See \`config.log' for more details" "$LINENO" 5; } +fi + + { $as_echo "$as_me:${as_lineno-$LINENO}: checking for cudaSetDevice in -lcudart" >&5 +$as_echo_n "checking for cudaSetDevice in -lcudart... " >&6; } +if ${ac_cv_lib_cudart_cudaSetDevice+:} false; then : + $as_echo_n "(cached) " >&6 +else + ac_check_lib_save_LIBS=$LIBS +LIBS="-lcudart $LIBS" +cat confdefs.h - <<_ACEOF >conftest.$ac_ext +/* end confdefs.h. */ + +/* Override any GCC internal prototype to avoid an error. + Use char because int might match the return type of a GCC + builtin and then its argument prototype would still apply. */ +#ifdef __cplusplus +extern "C" +#endif +char cudaSetDevice (); +int +main () +{ +return cudaSetDevice (); + ; + return 0; +} +_ACEOF +if ac_fn_cxx_try_link "$LINENO"; then : + ac_cv_lib_cudart_cudaSetDevice=yes +else + ac_cv_lib_cudart_cudaSetDevice=no +fi +rm -f core conftest.err conftest.$ac_objext \ + conftest$ac_exeext conftest.$ac_ext +LIBS=$ac_check_lib_save_LIBS +fi +{ $as_echo "$as_me:${as_lineno-$LINENO}: result: $ac_cv_lib_cudart_cudaSetDevice" >&5 +$as_echo "$ac_cv_lib_cudart_cudaSetDevice" >&6; } +if test "x$ac_cv_lib_cudart_cudaSetDevice" = xyes; then : + cat >>confdefs.h <<_ACEOF +#define HAVE_LIBCUDART 1 +_ACEOF + + LIBS="-lcudart $LIBS" + +else + { { $as_echo "$as_me:${as_lineno-$LINENO}: error: in \`$ac_pwd':" >&5 +$as_echo "$as_me: error: in \`$ac_pwd':" >&2;} +as_fn_error $? "Couldn't find libcudart +See \`config.log' for more details" "$LINENO" 5; } +fi + + + NVCC_LIBS="$LIBS" + NVCC_CPPFLAGS="$CPPFLAGS" + NVCC_LDFLAGS="$LDFLAGS" + CPPFLAGS="$fa_save_CPPFLAGS" + LDFLAGS="$fa_save_LDFLAGS" + LIBS="$fa_save_LIBS" +fi + + + + + +CUDA_PREFIX=$cuda_prefix + +CUDA_ARCH=$with_cuda_arch + + + + +# Checks for header files. +for ac_header in float.h limits.h stddef.h stdint.h stdlib.h string.h sys/time.h unistd.h +do : + as_ac_Header=`$as_echo "ac_cv_header_$ac_header" | $as_tr_sh` +ac_fn_cxx_check_header_mongrel "$LINENO" "$ac_header" "$as_ac_Header" "$ac_includes_default" +if eval test \"x\$"$as_ac_Header"\" = x"yes"; then : + cat >>confdefs.h <<_ACEOF +#define `$as_echo "HAVE_$ac_header" | $as_tr_cpp` 1 +_ACEOF + +fi + +done + + +# Checks for typedefs, structures, and compiler characteristics. +{ $as_echo "$as_me:${as_lineno-$LINENO}: checking for stdbool.h that conforms to C99" >&5 +$as_echo_n "checking for stdbool.h that conforms to C99... " >&6; } +if ${ac_cv_header_stdbool_h+:} false; then : + $as_echo_n "(cached) " >&6 +else + cat confdefs.h - <<_ACEOF >conftest.$ac_ext +/* end confdefs.h. */ + + #include + #ifndef bool + "error: bool is not defined" + #endif + #ifndef false + "error: false is not defined" + #endif + #if false + "error: false is not 0" + #endif + #ifndef true + "error: true is not defined" + #endif + #if true != 1 + "error: true is not 1" + #endif + #ifndef __bool_true_false_are_defined + "error: __bool_true_false_are_defined is not defined" + #endif + + struct s { _Bool s: 1; _Bool t; } s; + + char a[true == 1 ? 1 : -1]; + char b[false == 0 ? 1 : -1]; + char c[__bool_true_false_are_defined == 1 ? 1 : -1]; + char d[(bool) 0.5 == true ? 1 : -1]; + /* See body of main program for 'e'. */ + char f[(_Bool) 0.0 == false ? 1 : -1]; + char g[true]; + char h[sizeof (_Bool)]; + char i[sizeof s.t]; + enum { j = false, k = true, l = false * true, m = true * 256 }; + /* The following fails for + HP aC++/ANSI C B3910B A.05.55 [Dec 04 2003]. */ + _Bool n[m]; + char o[sizeof n == m * sizeof n[0] ? 1 : -1]; + char p[-1 - (_Bool) 0 < 0 && -1 - (bool) 0 < 0 ? 1 : -1]; + /* Catch a bug in an HP-UX C compiler. See + http://gcc.gnu.org/ml/gcc-patches/2003-12/msg02303.html + http://lists.gnu.org/archive/html/bug-coreutils/2005-11/msg00161.html + */ + _Bool q = true; + _Bool *pq = &q; + +int +main () +{ + + bool e = &s; + *pq |= q; + *pq |= ! q; + /* Refer to every declared value, to avoid compiler optimizations. */ + return (!a + !b + !c + !d + !e + !f + !g + !h + !i + !!j + !k + !!l + + !m + !n + !o + !p + !q + !pq); + + ; + return 0; +} +_ACEOF +if ac_fn_cxx_try_compile "$LINENO"; then : + ac_cv_header_stdbool_h=yes +else + ac_cv_header_stdbool_h=no +fi +rm -f core conftest.err conftest.$ac_objext conftest.$ac_ext +fi +{ $as_echo "$as_me:${as_lineno-$LINENO}: result: $ac_cv_header_stdbool_h" >&5 +$as_echo "$ac_cv_header_stdbool_h" >&6; } + ac_fn_cxx_check_type "$LINENO" "_Bool" "ac_cv_type__Bool" "$ac_includes_default" +if test "x$ac_cv_type__Bool" = xyes; then : + +cat >>confdefs.h <<_ACEOF +#define HAVE__BOOL 1 +_ACEOF + + +fi + + +{ $as_echo "$as_me:${as_lineno-$LINENO}: checking for inline" >&5 +$as_echo_n "checking for inline... " >&6; } +if ${ac_cv_c_inline+:} false; then : + $as_echo_n "(cached) " >&6 +else + ac_cv_c_inline=no +for ac_kw in inline __inline__ __inline; do + cat confdefs.h - <<_ACEOF >conftest.$ac_ext +/* end confdefs.h. */ +#ifndef __cplusplus +typedef int foo_t; +static $ac_kw foo_t static_foo () {return 0; } +$ac_kw foo_t foo () {return 0; } +#endif + +_ACEOF +if ac_fn_cxx_try_compile "$LINENO"; then : + ac_cv_c_inline=$ac_kw +fi +rm -f core conftest.err conftest.$ac_objext conftest.$ac_ext + test "$ac_cv_c_inline" != no && break +done + +fi +{ $as_echo "$as_me:${as_lineno-$LINENO}: result: $ac_cv_c_inline" >&5 +$as_echo "$ac_cv_c_inline" >&6; } + +case $ac_cv_c_inline in + inline | yes) ;; + *) + case $ac_cv_c_inline in + no) ac_val=;; + *) ac_val=$ac_cv_c_inline;; + esac + cat >>confdefs.h <<_ACEOF +#ifndef __cplusplus +#define inline $ac_val +#endif +_ACEOF + ;; +esac + +ac_fn_c_find_intX_t "$LINENO" "32" "ac_cv_c_int32_t" +case $ac_cv_c_int32_t in #( + no|yes) ;; #( + *) + +cat >>confdefs.h <<_ACEOF +#define int32_t $ac_cv_c_int32_t +_ACEOF +;; +esac + +ac_fn_c_find_intX_t "$LINENO" "64" "ac_cv_c_int64_t" +case $ac_cv_c_int64_t in #( + no|yes) ;; #( + *) + +cat >>confdefs.h <<_ACEOF +#define int64_t $ac_cv_c_int64_t +_ACEOF +;; +esac + +{ $as_echo "$as_me:${as_lineno-$LINENO}: checking for C/C++ restrict keyword" >&5 +$as_echo_n "checking for C/C++ restrict keyword... " >&6; } +if ${ac_cv_c_restrict+:} false; then : + $as_echo_n "(cached) " >&6 +else + ac_cv_c_restrict=no + # The order here caters to the fact that C++ does not require restrict. + for ac_kw in __restrict __restrict__ _Restrict restrict; do + cat confdefs.h - <<_ACEOF >conftest.$ac_ext +/* end confdefs.h. */ +typedef int * int_ptr; + int foo (int_ptr $ac_kw ip) { + return ip[0]; + } +int +main () +{ +int s[1]; + int * $ac_kw t = s; + t[0] = 0; + return foo(t) + ; + return 0; +} +_ACEOF +if ac_fn_cxx_try_compile "$LINENO"; then : + ac_cv_c_restrict=$ac_kw +fi +rm -f core conftest.err conftest.$ac_objext conftest.$ac_ext + test "$ac_cv_c_restrict" != no && break + done + +fi +{ $as_echo "$as_me:${as_lineno-$LINENO}: result: $ac_cv_c_restrict" >&5 +$as_echo "$ac_cv_c_restrict" >&6; } + + case $ac_cv_c_restrict in + restrict) ;; + no) $as_echo "#define restrict /**/" >>confdefs.h + ;; + *) cat >>confdefs.h <<_ACEOF +#define restrict $ac_cv_c_restrict +_ACEOF + ;; + esac + +ac_fn_cxx_check_type "$LINENO" "size_t" "ac_cv_type_size_t" "$ac_includes_default" +if test "x$ac_cv_type_size_t" = xyes; then : + +else + +cat >>confdefs.h <<_ACEOF +#define size_t unsigned int +_ACEOF + +fi + +ac_fn_c_find_uintX_t "$LINENO" "16" "ac_cv_c_uint16_t" +case $ac_cv_c_uint16_t in #( + no|yes) ;; #( + *) + + +cat >>confdefs.h <<_ACEOF +#define uint16_t $ac_cv_c_uint16_t +_ACEOF +;; + esac + +ac_fn_c_find_uintX_t "$LINENO" "32" "ac_cv_c_uint32_t" +case $ac_cv_c_uint32_t in #( + no|yes) ;; #( + *) + +$as_echo "#define _UINT32_T 1" >>confdefs.h + + +cat >>confdefs.h <<_ACEOF +#define uint32_t $ac_cv_c_uint32_t +_ACEOF +;; + esac + +ac_fn_c_find_uintX_t "$LINENO" "64" "ac_cv_c_uint64_t" +case $ac_cv_c_uint64_t in #( + no|yes) ;; #( + *) + +$as_echo "#define _UINT64_T 1" >>confdefs.h + + +cat >>confdefs.h <<_ACEOF +#define uint64_t $ac_cv_c_uint64_t +_ACEOF +;; + esac + +ac_fn_c_find_uintX_t "$LINENO" "8" "ac_cv_c_uint8_t" +case $ac_cv_c_uint8_t in #( + no|yes) ;; #( + *) + +$as_echo "#define _UINT8_T 1" >>confdefs.h + + +cat >>confdefs.h <<_ACEOF +#define uint8_t $ac_cv_c_uint8_t +_ACEOF +;; + esac + + +# Checks for library functions. +for ac_header in stdlib.h +do : + ac_fn_cxx_check_header_mongrel "$LINENO" "stdlib.h" "ac_cv_header_stdlib_h" "$ac_includes_default" +if test "x$ac_cv_header_stdlib_h" = xyes; then : + cat >>confdefs.h <<_ACEOF +#define HAVE_STDLIB_H 1 +_ACEOF + +fi + +done + +{ $as_echo "$as_me:${as_lineno-$LINENO}: checking for GNU libc compatible malloc" >&5 +$as_echo_n "checking for GNU libc compatible malloc... " >&6; } +if ${ac_cv_func_malloc_0_nonnull+:} false; then : + $as_echo_n "(cached) " >&6 +else + if test "$cross_compiling" = yes; then : + ac_cv_func_malloc_0_nonnull=no +else + cat confdefs.h - <<_ACEOF >conftest.$ac_ext +/* end confdefs.h. */ +#if defined STDC_HEADERS || defined HAVE_STDLIB_H +# include +#else +char *malloc (); +#endif + +int +main () +{ +return ! malloc (0); + ; + return 0; +} +_ACEOF +if ac_fn_cxx_try_run "$LINENO"; then : + ac_cv_func_malloc_0_nonnull=yes +else + ac_cv_func_malloc_0_nonnull=no +fi +rm -f core *.core core.conftest.* gmon.out bb.out conftest$ac_exeext \ + conftest.$ac_objext conftest.beam conftest.$ac_ext +fi + +fi +{ $as_echo "$as_me:${as_lineno-$LINENO}: result: $ac_cv_func_malloc_0_nonnull" >&5 +$as_echo "$ac_cv_func_malloc_0_nonnull" >&6; } +if test $ac_cv_func_malloc_0_nonnull = yes; then : + +$as_echo "#define HAVE_MALLOC 1" >>confdefs.h + +else + $as_echo "#define HAVE_MALLOC 0" >>confdefs.h + + case " $LIBOBJS " in + *" malloc.$ac_objext "* ) ;; + *) LIBOBJS="$LIBOBJS malloc.$ac_objext" + ;; +esac + + +$as_echo "#define malloc rpl_malloc" >>confdefs.h + +fi + + + + + + for ac_header in $ac_header_list +do : + as_ac_Header=`$as_echo "ac_cv_header_$ac_header" | $as_tr_sh` +ac_fn_cxx_check_header_compile "$LINENO" "$ac_header" "$as_ac_Header" "$ac_includes_default +" +if eval test \"x\$"$as_ac_Header"\" = x"yes"; then : + cat >>confdefs.h <<_ACEOF +#define `$as_echo "HAVE_$ac_header" | $as_tr_cpp` 1 +_ACEOF + +fi + +done + + + + + + + + +for ac_func in getpagesize +do : + ac_fn_cxx_check_func "$LINENO" "getpagesize" "ac_cv_func_getpagesize" +if test "x$ac_cv_func_getpagesize" = xyes; then : + cat >>confdefs.h <<_ACEOF +#define HAVE_GETPAGESIZE 1 +_ACEOF + +fi +done + +{ $as_echo "$as_me:${as_lineno-$LINENO}: checking for working mmap" >&5 +$as_echo_n "checking for working mmap... " >&6; } +if ${ac_cv_func_mmap_fixed_mapped+:} false; then : + $as_echo_n "(cached) " >&6 +else + if test "$cross_compiling" = yes; then : + ac_cv_func_mmap_fixed_mapped=no +else + cat confdefs.h - <<_ACEOF >conftest.$ac_ext +/* end confdefs.h. */ +$ac_includes_default +/* malloc might have been renamed as rpl_malloc. */ +#undef malloc + +/* Thanks to Mike Haertel and Jim Avera for this test. + Here is a matrix of mmap possibilities: + mmap private not fixed + mmap private fixed at somewhere currently unmapped + mmap private fixed at somewhere already mapped + mmap shared not fixed + mmap shared fixed at somewhere currently unmapped + mmap shared fixed at somewhere already mapped + For private mappings, we should verify that changes cannot be read() + back from the file, nor mmap's back from the file at a different + address. (There have been systems where private was not correctly + implemented like the infamous i386 svr4.0, and systems where the + VM page cache was not coherent with the file system buffer cache + like early versions of FreeBSD and possibly contemporary NetBSD.) + For shared mappings, we should conversely verify that changes get + propagated back to all the places they're supposed to be. + + Grep wants private fixed already mapped. + The main things grep needs to know about mmap are: + * does it exist and is it safe to write into the mmap'd area + * how to use it (BSD variants) */ + +#include +#include + +#if !defined STDC_HEADERS && !defined HAVE_STDLIB_H +char *malloc (); +#endif + +/* This mess was copied from the GNU getpagesize.h. */ +#ifndef HAVE_GETPAGESIZE +# ifdef _SC_PAGESIZE +# define getpagesize() sysconf(_SC_PAGESIZE) +# else /* no _SC_PAGESIZE */ +# ifdef HAVE_SYS_PARAM_H +# include +# ifdef EXEC_PAGESIZE +# define getpagesize() EXEC_PAGESIZE +# else /* no EXEC_PAGESIZE */ +# ifdef NBPG +# define getpagesize() NBPG * CLSIZE +# ifndef CLSIZE +# define CLSIZE 1 +# endif /* no CLSIZE */ +# else /* no NBPG */ +# ifdef NBPC +# define getpagesize() NBPC +# else /* no NBPC */ +# ifdef PAGESIZE +# define getpagesize() PAGESIZE +# endif /* PAGESIZE */ +# endif /* no NBPC */ +# endif /* no NBPG */ +# endif /* no EXEC_PAGESIZE */ +# else /* no HAVE_SYS_PARAM_H */ +# define getpagesize() 8192 /* punt totally */ +# endif /* no HAVE_SYS_PARAM_H */ +# endif /* no _SC_PAGESIZE */ + +#endif /* no HAVE_GETPAGESIZE */ + +int +main () +{ + char *data, *data2, *data3; + const char *cdata2; + int i, pagesize; + int fd, fd2; + + pagesize = getpagesize (); + + /* First, make a file with some known garbage in it. */ + data = (char *) malloc (pagesize); + if (!data) + return 1; + for (i = 0; i < pagesize; ++i) + *(data + i) = rand (); + umask (0); + fd = creat ("conftest.mmap", 0600); + if (fd < 0) + return 2; + if (write (fd, data, pagesize) != pagesize) + return 3; + close (fd); + + /* Next, check that the tail of a page is zero-filled. File must have + non-zero length, otherwise we risk SIGBUS for entire page. */ + fd2 = open ("conftest.txt", O_RDWR | O_CREAT | O_TRUNC, 0600); + if (fd2 < 0) + return 4; + cdata2 = ""; + if (write (fd2, cdata2, 1) != 1) + return 5; + data2 = (char *) mmap (0, pagesize, PROT_READ | PROT_WRITE, MAP_SHARED, fd2, 0L); + if (data2 == MAP_FAILED) + return 6; + for (i = 0; i < pagesize; ++i) + if (*(data2 + i)) + return 7; + close (fd2); + if (munmap (data2, pagesize)) + return 8; + + /* Next, try to mmap the file at a fixed address which already has + something else allocated at it. If we can, also make sure that + we see the same garbage. */ + fd = open ("conftest.mmap", O_RDWR); + if (fd < 0) + return 9; + if (data2 != mmap (data2, pagesize, PROT_READ | PROT_WRITE, + MAP_PRIVATE | MAP_FIXED, fd, 0L)) + return 10; + for (i = 0; i < pagesize; ++i) + if (*(data + i) != *(data2 + i)) + return 11; + + /* Finally, make sure that changes to the mapped area do not + percolate back to the file as seen by read(). (This is a bug on + some variants of i386 svr4.0.) */ + for (i = 0; i < pagesize; ++i) + *(data2 + i) = *(data2 + i) + 1; + data3 = (char *) malloc (pagesize); + if (!data3) + return 12; + if (read (fd, data3, pagesize) != pagesize) + return 13; + for (i = 0; i < pagesize; ++i) + if (*(data + i) != *(data3 + i)) + return 14; + close (fd); + return 0; +} +_ACEOF +if ac_fn_cxx_try_run "$LINENO"; then : + ac_cv_func_mmap_fixed_mapped=yes +else + ac_cv_func_mmap_fixed_mapped=no +fi +rm -f core *.core core.conftest.* gmon.out bb.out conftest$ac_exeext \ + conftest.$ac_objext conftest.beam conftest.$ac_ext +fi + +fi +{ $as_echo "$as_me:${as_lineno-$LINENO}: result: $ac_cv_func_mmap_fixed_mapped" >&5 +$as_echo "$ac_cv_func_mmap_fixed_mapped" >&6; } +if test $ac_cv_func_mmap_fixed_mapped = yes; then + +$as_echo "#define HAVE_MMAP 1" >>confdefs.h + +fi +rm -f conftest.mmap conftest.txt + +for ac_func in clock_gettime floor gettimeofday memmove memset munmap pow sqrt strerror strstr +do : + as_ac_var=`$as_echo "ac_cv_func_$ac_func" | $as_tr_sh` +ac_fn_cxx_check_func "$LINENO" "$ac_func" "$as_ac_var" +if eval test \"x\$"$as_ac_var"\" = x"yes"; then : + cat >>confdefs.h <<_ACEOF +#define `$as_echo "HAVE_$ac_func" | $as_tr_cpp` 1 +_ACEOF + +fi +done + + + + OPENMP_CXXFLAGS= + # Check whether --enable-openmp was given. +if test "${enable_openmp+set}" = set; then : + enableval=$enable_openmp; +fi + + if test "$enable_openmp" != no; then + { $as_echo "$as_me:${as_lineno-$LINENO}: checking for $CXX option to support OpenMP" >&5 +$as_echo_n "checking for $CXX option to support OpenMP... " >&6; } +if ${ac_cv_prog_cxx_openmp+:} false; then : + $as_echo_n "(cached) " >&6 +else + cat confdefs.h - <<_ACEOF >conftest.$ac_ext +/* end confdefs.h. */ + +#ifndef _OPENMP + choke me +#endif +#include +int main () { return omp_get_num_threads (); } + +_ACEOF +if ac_fn_cxx_try_link "$LINENO"; then : + ac_cv_prog_cxx_openmp='none needed' +else + ac_cv_prog_cxx_openmp='unsupported' + for ac_option in -fopenmp -xopenmp -openmp -mp -omp -qsmp=omp -homp \ + -Popenmp --openmp; do + ac_save_CXXFLAGS=$CXXFLAGS + CXXFLAGS="$CXXFLAGS $ac_option" + cat confdefs.h - <<_ACEOF >conftest.$ac_ext +/* end confdefs.h. */ + +#ifndef _OPENMP + choke me +#endif +#include +int main () { return omp_get_num_threads (); } + +_ACEOF +if ac_fn_cxx_try_link "$LINENO"; then : + ac_cv_prog_cxx_openmp=$ac_option +fi +rm -f core conftest.err conftest.$ac_objext \ + conftest$ac_exeext conftest.$ac_ext + CXXFLAGS=$ac_save_CXXFLAGS + if test "$ac_cv_prog_cxx_openmp" != unsupported; then + break + fi + done +fi +rm -f core conftest.err conftest.$ac_objext \ + conftest$ac_exeext conftest.$ac_ext +fi +{ $as_echo "$as_me:${as_lineno-$LINENO}: result: $ac_cv_prog_cxx_openmp" >&5 +$as_echo "$ac_cv_prog_cxx_openmp" >&6; } + case $ac_cv_prog_cxx_openmp in #( + "none needed" | unsupported) + ;; #( + *) + OPENMP_CXXFLAGS=$ac_cv_prog_cxx_openmp ;; + esac + fi + + + +# Make sure we can run config.sub. +$SHELL "$ac_aux_dir/config.sub" sun4 >/dev/null 2>&1 || + as_fn_error $? "cannot run $SHELL $ac_aux_dir/config.sub" "$LINENO" 5 + +{ $as_echo "$as_me:${as_lineno-$LINENO}: checking build system type" >&5 +$as_echo_n "checking build system type... " >&6; } +if ${ac_cv_build+:} false; then : + $as_echo_n "(cached) " >&6 +else + ac_build_alias=$build_alias +test "x$ac_build_alias" = x && + ac_build_alias=`$SHELL "$ac_aux_dir/config.guess"` +test "x$ac_build_alias" = x && + as_fn_error $? "cannot guess build type; you must specify one" "$LINENO" 5 +ac_cv_build=`$SHELL "$ac_aux_dir/config.sub" $ac_build_alias` || + as_fn_error $? "$SHELL $ac_aux_dir/config.sub $ac_build_alias failed" "$LINENO" 5 + +fi +{ $as_echo "$as_me:${as_lineno-$LINENO}: result: $ac_cv_build" >&5 +$as_echo "$ac_cv_build" >&6; } +case $ac_cv_build in +*-*-*) ;; +*) as_fn_error $? "invalid value of canonical build" "$LINENO" 5;; +esac +build=$ac_cv_build +ac_save_IFS=$IFS; IFS='-' +set x $ac_cv_build +shift +build_cpu=$1 +build_vendor=$2 +shift; shift +# Remember, the first character of IFS is used to create $*, +# except with old shells: +build_os=$* +IFS=$ac_save_IFS +case $build_os in *\ *) build_os=`echo "$build_os" | sed 's/ /-/g'`;; esac + + +{ $as_echo "$as_me:${as_lineno-$LINENO}: checking host system type" >&5 +$as_echo_n "checking host system type... " >&6; } +if ${ac_cv_host+:} false; then : + $as_echo_n "(cached) " >&6 +else + if test "x$host_alias" = x; then + ac_cv_host=$ac_cv_build +else + ac_cv_host=`$SHELL "$ac_aux_dir/config.sub" $host_alias` || + as_fn_error $? "$SHELL $ac_aux_dir/config.sub $host_alias failed" "$LINENO" 5 +fi + +fi +{ $as_echo "$as_me:${as_lineno-$LINENO}: result: $ac_cv_host" >&5 +$as_echo "$ac_cv_host" >&6; } +case $ac_cv_host in +*-*-*) ;; +*) as_fn_error $? "invalid value of canonical host" "$LINENO" 5;; +esac +host=$ac_cv_host +ac_save_IFS=$IFS; IFS='-' +set x $ac_cv_host +shift +host_cpu=$1 +host_vendor=$2 +shift; shift +# Remember, the first character of IFS is used to create $*, +# except with old shells: +host_os=$* +IFS=$ac_save_IFS +case $host_os in *\ *) host_os=`echo "$host_os" | sed 's/ /-/g'`;; esac + + + + +# AC_REQUIRE([AC_F77_LIBRARY_LDFLAGS]) + +ax_blas_ok=no + + +# Check whether --with-blas was given. +if test "${with_blas+set}" = set; then : + withval=$with_blas; +fi + +case $with_blas in + yes | "") ;; + no) ax_blas_ok=disable ;; + -* | */* | *.a | *.so | *.so.* | *.o) BLAS_LIBS="$with_blas" ;; + *) BLAS_LIBS="-l$with_blas" ;; +esac + +OPENMP_LDFLAGS="$OPENMP_CXXFLAGS" + +# Get fortran linker names of BLAS functions to check for. +# AC_F77_FUNC(sgemm) +# AC_F77_FUNC(dgemm) +sgemm=sgemm_ +dgemm=dgemm_ + +ax_blas_save_LIBS="$LIBS" +LIBS="$LIBS $FLIBS" + +# First, check BLAS_LIBS environment variable +if test $ax_blas_ok = no; then +if test "x$BLAS_LIBS" != x; then + save_LIBS="$LIBS"; LIBS="$BLAS_LIBS $LIBS" + { $as_echo "$as_me:${as_lineno-$LINENO}: checking for $sgemm in $BLAS_LIBS" >&5 +$as_echo_n "checking for $sgemm in $BLAS_LIBS... " >&6; } + cat confdefs.h - <<_ACEOF >conftest.$ac_ext +/* end confdefs.h. */ + +/* Override any GCC internal prototype to avoid an error. + Use char because int might match the return type of a GCC + builtin and then its argument prototype would still apply. */ +#ifdef __cplusplus +extern "C" +#endif +char $sgemm (); +int +main () +{ +return $sgemm (); + ; + return 0; +} +_ACEOF +if ac_fn_cxx_try_link "$LINENO"; then : + ax_blas_ok=yes +else + BLAS_LIBS="" +fi +rm -f core conftest.err conftest.$ac_objext \ + conftest$ac_exeext conftest.$ac_ext + { $as_echo "$as_me:${as_lineno-$LINENO}: result: $ax_blas_ok" >&5 +$as_echo "$ax_blas_ok" >&6; } + LIBS="$save_LIBS" +fi +fi + +# BLAS linked to by default? (happens on some supercomputers) +if test $ax_blas_ok = no; then + save_LIBS="$LIBS"; LIBS="$LIBS" + { $as_echo "$as_me:${as_lineno-$LINENO}: checking if $sgemm is being linked in already" >&5 +$as_echo_n "checking if $sgemm is being linked in already... " >&6; } + cat confdefs.h - <<_ACEOF >conftest.$ac_ext +/* end confdefs.h. */ + +/* Override any GCC internal prototype to avoid an error. + Use char because int might match the return type of a GCC + builtin and then its argument prototype would still apply. */ +#ifdef __cplusplus +extern "C" +#endif +char $sgemm (); +int +main () +{ +return $sgemm (); + ; + return 0; +} +_ACEOF +if ac_fn_cxx_try_link "$LINENO"; then : + ax_blas_ok=yes +fi +rm -f core conftest.err conftest.$ac_objext \ + conftest$ac_exeext conftest.$ac_ext + { $as_echo "$as_me:${as_lineno-$LINENO}: result: $ax_blas_ok" >&5 +$as_echo "$ax_blas_ok" >&6; } + LIBS="$save_LIBS" +fi + +# BLAS in Intel MKL library? +if test $ax_blas_ok = no; then + case $host_os in + darwin*) + as_ac_Lib=`$as_echo "ac_cv_lib_mkl_intel_lp64_$sgemm" | $as_tr_sh` +{ $as_echo "$as_me:${as_lineno-$LINENO}: checking for $sgemm in -lmkl_intel_lp64" >&5 +$as_echo_n "checking for $sgemm in -lmkl_intel_lp64... " >&6; } +if eval \${$as_ac_Lib+:} false; then : + $as_echo_n "(cached) " >&6 +else + ac_check_lib_save_LIBS=$LIBS +LIBS="-lmkl_intel_lp64 -lmkl_intel_lp64 -lmkl_intel_thread -lmkl_core -liomp5 -lpthread $LIBS" +cat confdefs.h - <<_ACEOF >conftest.$ac_ext +/* end confdefs.h. */ + +/* Override any GCC internal prototype to avoid an error. + Use char because int might match the return type of a GCC + builtin and then its argument prototype would still apply. */ +#ifdef __cplusplus +extern "C" +#endif +char $sgemm (); +int +main () +{ +return $sgemm (); + ; + return 0; +} +_ACEOF +if ac_fn_cxx_try_link "$LINENO"; then : + eval "$as_ac_Lib=yes" +else + eval "$as_ac_Lib=no" +fi +rm -f core conftest.err conftest.$ac_objext \ + conftest$ac_exeext conftest.$ac_ext +LIBS=$ac_check_lib_save_LIBS +fi +eval ac_res=\$$as_ac_Lib + { $as_echo "$as_me:${as_lineno-$LINENO}: result: $ac_res" >&5 +$as_echo "$ac_res" >&6; } +if eval test \"x\$"$as_ac_Lib"\" = x"yes"; then : + ax_blas_ok=yes;BLAS_LIBS="-lmkl_intel_lp64 -lmkl_intel_thread -lmkl_core -liomp5 -lpthread"; OPENMP_LDFLAGS="" +fi + + ;; + *) + if test $host_cpu = x86_64; then + as_ac_Lib=`$as_echo "ac_cv_lib_mkl_intel_lp64_$sgemm" | $as_tr_sh` +{ $as_echo "$as_me:${as_lineno-$LINENO}: checking for $sgemm in -lmkl_intel_lp64" >&5 +$as_echo_n "checking for $sgemm in -lmkl_intel_lp64... " >&6; } +if eval \${$as_ac_Lib+:} false; then : + $as_echo_n "(cached) " >&6 +else + ac_check_lib_save_LIBS=$LIBS +LIBS="-lmkl_intel_lp64 -lmkl_intel_lp64 -lmkl_gnu_thread -lmkl_core -lgomp -lpthread -lm -ldl $LIBS" +cat confdefs.h - <<_ACEOF >conftest.$ac_ext +/* end confdefs.h. */ + +/* Override any GCC internal prototype to avoid an error. + Use char because int might match the return type of a GCC + builtin and then its argument prototype would still apply. */ +#ifdef __cplusplus +extern "C" +#endif +char $sgemm (); +int +main () +{ +return $sgemm (); + ; + return 0; +} +_ACEOF +if ac_fn_cxx_try_link "$LINENO"; then : + eval "$as_ac_Lib=yes" +else + eval "$as_ac_Lib=no" +fi +rm -f core conftest.err conftest.$ac_objext \ + conftest$ac_exeext conftest.$ac_ext +LIBS=$ac_check_lib_save_LIBS +fi +eval ac_res=\$$as_ac_Lib + { $as_echo "$as_me:${as_lineno-$LINENO}: result: $ac_res" >&5 +$as_echo "$ac_res" >&6; } +if eval test \"x\$"$as_ac_Lib"\" = x"yes"; then : + ax_blas_ok=yes;BLAS_LIBS="-lmkl_intel_lp64 -lmkl_gnu_thread -lmkl_core -lgomp -lpthread -lm -ldl" +fi + + elif test $host_cpu = i686; then + as_ac_Lib=`$as_echo "ac_cv_lib_mkl_intel_$sgemm" | $as_tr_sh` +{ $as_echo "$as_me:${as_lineno-$LINENO}: checking for $sgemm in -lmkl_intel" >&5 +$as_echo_n "checking for $sgemm in -lmkl_intel... " >&6; } +if eval \${$as_ac_Lib+:} false; then : + $as_echo_n "(cached) " >&6 +else + ac_check_lib_save_LIBS=$LIBS +LIBS="-lmkl_intel -lmkl_intel -lmkl_gnu_thread -lmkl_core -lgomp -lpthread -lm -ldl $LIBS" +cat confdefs.h - <<_ACEOF >conftest.$ac_ext +/* end confdefs.h. */ + +/* Override any GCC internal prototype to avoid an error. + Use char because int might match the return type of a GCC + builtin and then its argument prototype would still apply. */ +#ifdef __cplusplus +extern "C" +#endif +char $sgemm (); +int +main () +{ +return $sgemm (); + ; + return 0; +} +_ACEOF +if ac_fn_cxx_try_link "$LINENO"; then : + eval "$as_ac_Lib=yes" +else + eval "$as_ac_Lib=no" +fi +rm -f core conftest.err conftest.$ac_objext \ + conftest$ac_exeext conftest.$ac_ext +LIBS=$ac_check_lib_save_LIBS +fi +eval ac_res=\$$as_ac_Lib + { $as_echo "$as_me:${as_lineno-$LINENO}: result: $ac_res" >&5 +$as_echo "$ac_res" >&6; } +if eval test \"x\$"$as_ac_Lib"\" = x"yes"; then : + ax_blas_ok=yes;BLAS_LIBS="-lmkl_intel -lmkl_gnu_thread -lmkl_core -lgomp -lpthread -lm -ldl" +fi + + fi + ;; + esac +fi +# Old versions of MKL +if test $ax_blas_ok = no; then + as_ac_Lib=`$as_echo "ac_cv_lib_mkl_$sgemm" | $as_tr_sh` +{ $as_echo "$as_me:${as_lineno-$LINENO}: checking for $sgemm in -lmkl" >&5 +$as_echo_n "checking for $sgemm in -lmkl... " >&6; } +if eval \${$as_ac_Lib+:} false; then : + $as_echo_n "(cached) " >&6 +else + ac_check_lib_save_LIBS=$LIBS +LIBS="-lmkl -lguide -lpthread $LIBS" +cat confdefs.h - <<_ACEOF >conftest.$ac_ext +/* end confdefs.h. */ + +/* Override any GCC internal prototype to avoid an error. + Use char because int might match the return type of a GCC + builtin and then its argument prototype would still apply. */ +#ifdef __cplusplus +extern "C" +#endif +char $sgemm (); +int +main () +{ +return $sgemm (); + ; + return 0; +} +_ACEOF +if ac_fn_cxx_try_link "$LINENO"; then : + eval "$as_ac_Lib=yes" +else + eval "$as_ac_Lib=no" +fi +rm -f core conftest.err conftest.$ac_objext \ + conftest$ac_exeext conftest.$ac_ext +LIBS=$ac_check_lib_save_LIBS +fi +eval ac_res=\$$as_ac_Lib + { $as_echo "$as_me:${as_lineno-$LINENO}: result: $ac_res" >&5 +$as_echo "$ac_res" >&6; } +if eval test \"x\$"$as_ac_Lib"\" = x"yes"; then : + ax_blas_ok=yes;BLAS_LIBS="-lmkl -lguide -lpthread" +fi + +fi + +# BLAS in OpenBLAS library? (http://xianyi.github.com/OpenBLAS/) +if test $ax_blas_ok = no; then + as_ac_Lib=`$as_echo "ac_cv_lib_openblas_$sgemm" | $as_tr_sh` +{ $as_echo "$as_me:${as_lineno-$LINENO}: checking for $sgemm in -lopenblas" >&5 +$as_echo_n "checking for $sgemm in -lopenblas... " >&6; } +if eval \${$as_ac_Lib+:} false; then : + $as_echo_n "(cached) " >&6 +else + ac_check_lib_save_LIBS=$LIBS +LIBS="-lopenblas $LIBS" +cat confdefs.h - <<_ACEOF >conftest.$ac_ext +/* end confdefs.h. */ + +/* Override any GCC internal prototype to avoid an error. + Use char because int might match the return type of a GCC + builtin and then its argument prototype would still apply. */ +#ifdef __cplusplus +extern "C" +#endif +char $sgemm (); +int +main () +{ +return $sgemm (); + ; + return 0; +} +_ACEOF +if ac_fn_cxx_try_link "$LINENO"; then : + eval "$as_ac_Lib=yes" +else + eval "$as_ac_Lib=no" +fi +rm -f core conftest.err conftest.$ac_objext \ + conftest$ac_exeext conftest.$ac_ext +LIBS=$ac_check_lib_save_LIBS +fi +eval ac_res=\$$as_ac_Lib + { $as_echo "$as_me:${as_lineno-$LINENO}: result: $ac_res" >&5 +$as_echo "$ac_res" >&6; } +if eval test \"x\$"$as_ac_Lib"\" = x"yes"; then : + ax_blas_ok=yes + BLAS_LIBS="-lopenblas" +fi + +fi + +# BLAS in ATLAS library? (http://math-atlas.sourceforge.net/) +if test $ax_blas_ok = no; then + { $as_echo "$as_me:${as_lineno-$LINENO}: checking for ATL_xerbla in -latlas" >&5 +$as_echo_n "checking for ATL_xerbla in -latlas... " >&6; } +if ${ac_cv_lib_atlas_ATL_xerbla+:} false; then : + $as_echo_n "(cached) " >&6 +else + ac_check_lib_save_LIBS=$LIBS +LIBS="-latlas $LIBS" +cat confdefs.h - <<_ACEOF >conftest.$ac_ext +/* end confdefs.h. */ + +/* Override any GCC internal prototype to avoid an error. + Use char because int might match the return type of a GCC + builtin and then its argument prototype would still apply. */ +#ifdef __cplusplus +extern "C" +#endif +char ATL_xerbla (); +int +main () +{ +return ATL_xerbla (); + ; + return 0; +} +_ACEOF +if ac_fn_cxx_try_link "$LINENO"; then : + ac_cv_lib_atlas_ATL_xerbla=yes +else + ac_cv_lib_atlas_ATL_xerbla=no +fi +rm -f core conftest.err conftest.$ac_objext \ + conftest$ac_exeext conftest.$ac_ext +LIBS=$ac_check_lib_save_LIBS +fi +{ $as_echo "$as_me:${as_lineno-$LINENO}: result: $ac_cv_lib_atlas_ATL_xerbla" >&5 +$as_echo "$ac_cv_lib_atlas_ATL_xerbla" >&6; } +if test "x$ac_cv_lib_atlas_ATL_xerbla" = xyes; then : + as_ac_Lib=`$as_echo "ac_cv_lib_f77blas_$sgemm" | $as_tr_sh` +{ $as_echo "$as_me:${as_lineno-$LINENO}: checking for $sgemm in -lf77blas" >&5 +$as_echo_n "checking for $sgemm in -lf77blas... " >&6; } +if eval \${$as_ac_Lib+:} false; then : + $as_echo_n "(cached) " >&6 +else + ac_check_lib_save_LIBS=$LIBS +LIBS="-lf77blas -latlas $LIBS" +cat confdefs.h - <<_ACEOF >conftest.$ac_ext +/* end confdefs.h. */ + +/* Override any GCC internal prototype to avoid an error. + Use char because int might match the return type of a GCC + builtin and then its argument prototype would still apply. */ +#ifdef __cplusplus +extern "C" +#endif +char $sgemm (); +int +main () +{ +return $sgemm (); + ; + return 0; +} +_ACEOF +if ac_fn_cxx_try_link "$LINENO"; then : + eval "$as_ac_Lib=yes" +else + eval "$as_ac_Lib=no" +fi +rm -f core conftest.err conftest.$ac_objext \ + conftest$ac_exeext conftest.$ac_ext +LIBS=$ac_check_lib_save_LIBS +fi +eval ac_res=\$$as_ac_Lib + { $as_echo "$as_me:${as_lineno-$LINENO}: result: $ac_res" >&5 +$as_echo "$ac_res" >&6; } +if eval test \"x\$"$as_ac_Lib"\" = x"yes"; then : + { $as_echo "$as_me:${as_lineno-$LINENO}: checking for cblas_dgemm in -lcblas" >&5 +$as_echo_n "checking for cblas_dgemm in -lcblas... " >&6; } +if ${ac_cv_lib_cblas_cblas_dgemm+:} false; then : + $as_echo_n "(cached) " >&6 +else + ac_check_lib_save_LIBS=$LIBS +LIBS="-lcblas -lf77blas -latlas $LIBS" +cat confdefs.h - <<_ACEOF >conftest.$ac_ext +/* end confdefs.h. */ + +/* Override any GCC internal prototype to avoid an error. + Use char because int might match the return type of a GCC + builtin and then its argument prototype would still apply. */ +#ifdef __cplusplus +extern "C" +#endif +char cblas_dgemm (); +int +main () +{ +return cblas_dgemm (); + ; + return 0; +} +_ACEOF +if ac_fn_cxx_try_link "$LINENO"; then : + ac_cv_lib_cblas_cblas_dgemm=yes +else + ac_cv_lib_cblas_cblas_dgemm=no +fi +rm -f core conftest.err conftest.$ac_objext \ + conftest$ac_exeext conftest.$ac_ext +LIBS=$ac_check_lib_save_LIBS +fi +{ $as_echo "$as_me:${as_lineno-$LINENO}: result: $ac_cv_lib_cblas_cblas_dgemm" >&5 +$as_echo "$ac_cv_lib_cblas_cblas_dgemm" >&6; } +if test "x$ac_cv_lib_cblas_cblas_dgemm" = xyes; then : + ax_blas_ok=yes + BLAS_LIBS="-lcblas -lf77blas -latlas" +fi + +fi + +fi + +fi + +# BLAS in PhiPACK libraries? (requires generic BLAS lib, too) +if test $ax_blas_ok = no; then + as_ac_Lib=`$as_echo "ac_cv_lib_blas_$sgemm" | $as_tr_sh` +{ $as_echo "$as_me:${as_lineno-$LINENO}: checking for $sgemm in -lblas" >&5 +$as_echo_n "checking for $sgemm in -lblas... " >&6; } +if eval \${$as_ac_Lib+:} false; then : + $as_echo_n "(cached) " >&6 +else + ac_check_lib_save_LIBS=$LIBS +LIBS="-lblas $LIBS" +cat confdefs.h - <<_ACEOF >conftest.$ac_ext +/* end confdefs.h. */ + +/* Override any GCC internal prototype to avoid an error. + Use char because int might match the return type of a GCC + builtin and then its argument prototype would still apply. */ +#ifdef __cplusplus +extern "C" +#endif +char $sgemm (); +int +main () +{ +return $sgemm (); + ; + return 0; +} +_ACEOF +if ac_fn_cxx_try_link "$LINENO"; then : + eval "$as_ac_Lib=yes" +else + eval "$as_ac_Lib=no" +fi +rm -f core conftest.err conftest.$ac_objext \ + conftest$ac_exeext conftest.$ac_ext +LIBS=$ac_check_lib_save_LIBS +fi +eval ac_res=\$$as_ac_Lib + { $as_echo "$as_me:${as_lineno-$LINENO}: result: $ac_res" >&5 +$as_echo "$ac_res" >&6; } +if eval test \"x\$"$as_ac_Lib"\" = x"yes"; then : + as_ac_Lib=`$as_echo "ac_cv_lib_dgemm_$dgemm" | $as_tr_sh` +{ $as_echo "$as_me:${as_lineno-$LINENO}: checking for $dgemm in -ldgemm" >&5 +$as_echo_n "checking for $dgemm in -ldgemm... " >&6; } +if eval \${$as_ac_Lib+:} false; then : + $as_echo_n "(cached) " >&6 +else + ac_check_lib_save_LIBS=$LIBS +LIBS="-ldgemm -lblas $LIBS" +cat confdefs.h - <<_ACEOF >conftest.$ac_ext +/* end confdefs.h. */ + +/* Override any GCC internal prototype to avoid an error. + Use char because int might match the return type of a GCC + builtin and then its argument prototype would still apply. */ +#ifdef __cplusplus +extern "C" +#endif +char $dgemm (); +int +main () +{ +return $dgemm (); + ; + return 0; +} +_ACEOF +if ac_fn_cxx_try_link "$LINENO"; then : + eval "$as_ac_Lib=yes" +else + eval "$as_ac_Lib=no" +fi +rm -f core conftest.err conftest.$ac_objext \ + conftest$ac_exeext conftest.$ac_ext +LIBS=$ac_check_lib_save_LIBS +fi +eval ac_res=\$$as_ac_Lib + { $as_echo "$as_me:${as_lineno-$LINENO}: result: $ac_res" >&5 +$as_echo "$ac_res" >&6; } +if eval test \"x\$"$as_ac_Lib"\" = x"yes"; then : + as_ac_Lib=`$as_echo "ac_cv_lib_sgemm_$sgemm" | $as_tr_sh` +{ $as_echo "$as_me:${as_lineno-$LINENO}: checking for $sgemm in -lsgemm" >&5 +$as_echo_n "checking for $sgemm in -lsgemm... " >&6; } +if eval \${$as_ac_Lib+:} false; then : + $as_echo_n "(cached) " >&6 +else + ac_check_lib_save_LIBS=$LIBS +LIBS="-lsgemm -lblas $LIBS" +cat confdefs.h - <<_ACEOF >conftest.$ac_ext +/* end confdefs.h. */ + +/* Override any GCC internal prototype to avoid an error. + Use char because int might match the return type of a GCC + builtin and then its argument prototype would still apply. */ +#ifdef __cplusplus +extern "C" +#endif +char $sgemm (); +int +main () +{ +return $sgemm (); + ; + return 0; +} +_ACEOF +if ac_fn_cxx_try_link "$LINENO"; then : + eval "$as_ac_Lib=yes" +else + eval "$as_ac_Lib=no" +fi +rm -f core conftest.err conftest.$ac_objext \ + conftest$ac_exeext conftest.$ac_ext +LIBS=$ac_check_lib_save_LIBS +fi +eval ac_res=\$$as_ac_Lib + { $as_echo "$as_me:${as_lineno-$LINENO}: result: $ac_res" >&5 +$as_echo "$ac_res" >&6; } +if eval test \"x\$"$as_ac_Lib"\" = x"yes"; then : + ax_blas_ok=yes; BLAS_LIBS="-lsgemm -ldgemm -lblas" +fi + +fi + +fi + +fi + +# BLAS in Apple vecLib library? +if test $ax_blas_ok = no; then + save_LIBS="$LIBS"; LIBS="-framework vecLib $LIBS" + { $as_echo "$as_me:${as_lineno-$LINENO}: checking for $sgemm in -framework vecLib" >&5 +$as_echo_n "checking for $sgemm in -framework vecLib... " >&6; } + cat confdefs.h - <<_ACEOF >conftest.$ac_ext +/* end confdefs.h. */ + +/* Override any GCC internal prototype to avoid an error. + Use char because int might match the return type of a GCC + builtin and then its argument prototype would still apply. */ +#ifdef __cplusplus +extern "C" +#endif +char $sgemm (); +int +main () +{ +return $sgemm (); + ; + return 0; +} +_ACEOF +if ac_fn_cxx_try_link "$LINENO"; then : + ax_blas_ok=yes;BLAS_LIBS="-framework vecLib" +fi +rm -f core conftest.err conftest.$ac_objext \ + conftest$ac_exeext conftest.$ac_ext + { $as_echo "$as_me:${as_lineno-$LINENO}: result: $ax_blas_ok" >&5 +$as_echo "$ax_blas_ok" >&6; } + LIBS="$save_LIBS" +fi + +# BLAS in Alpha CXML library? +if test $ax_blas_ok = no; then + as_ac_Lib=`$as_echo "ac_cv_lib_cxml_$sgemm" | $as_tr_sh` +{ $as_echo "$as_me:${as_lineno-$LINENO}: checking for $sgemm in -lcxml" >&5 +$as_echo_n "checking for $sgemm in -lcxml... " >&6; } +if eval \${$as_ac_Lib+:} false; then : + $as_echo_n "(cached) " >&6 +else + ac_check_lib_save_LIBS=$LIBS +LIBS="-lcxml $LIBS" +cat confdefs.h - <<_ACEOF >conftest.$ac_ext +/* end confdefs.h. */ + +/* Override any GCC internal prototype to avoid an error. + Use char because int might match the return type of a GCC + builtin and then its argument prototype would still apply. */ +#ifdef __cplusplus +extern "C" +#endif +char $sgemm (); +int +main () +{ +return $sgemm (); + ; + return 0; +} +_ACEOF +if ac_fn_cxx_try_link "$LINENO"; then : + eval "$as_ac_Lib=yes" +else + eval "$as_ac_Lib=no" +fi +rm -f core conftest.err conftest.$ac_objext \ + conftest$ac_exeext conftest.$ac_ext +LIBS=$ac_check_lib_save_LIBS +fi +eval ac_res=\$$as_ac_Lib + { $as_echo "$as_me:${as_lineno-$LINENO}: result: $ac_res" >&5 +$as_echo "$ac_res" >&6; } +if eval test \"x\$"$as_ac_Lib"\" = x"yes"; then : + ax_blas_ok=yes;BLAS_LIBS="-lcxml" +fi + +fi + +# BLAS in Alpha DXML library? (now called CXML, see above) +if test $ax_blas_ok = no; then + as_ac_Lib=`$as_echo "ac_cv_lib_dxml_$sgemm" | $as_tr_sh` +{ $as_echo "$as_me:${as_lineno-$LINENO}: checking for $sgemm in -ldxml" >&5 +$as_echo_n "checking for $sgemm in -ldxml... " >&6; } +if eval \${$as_ac_Lib+:} false; then : + $as_echo_n "(cached) " >&6 +else + ac_check_lib_save_LIBS=$LIBS +LIBS="-ldxml $LIBS" +cat confdefs.h - <<_ACEOF >conftest.$ac_ext +/* end confdefs.h. */ + +/* Override any GCC internal prototype to avoid an error. + Use char because int might match the return type of a GCC + builtin and then its argument prototype would still apply. */ +#ifdef __cplusplus +extern "C" +#endif +char $sgemm (); +int +main () +{ +return $sgemm (); + ; + return 0; +} +_ACEOF +if ac_fn_cxx_try_link "$LINENO"; then : + eval "$as_ac_Lib=yes" +else + eval "$as_ac_Lib=no" +fi +rm -f core conftest.err conftest.$ac_objext \ + conftest$ac_exeext conftest.$ac_ext +LIBS=$ac_check_lib_save_LIBS +fi +eval ac_res=\$$as_ac_Lib + { $as_echo "$as_me:${as_lineno-$LINENO}: result: $ac_res" >&5 +$as_echo "$ac_res" >&6; } +if eval test \"x\$"$as_ac_Lib"\" = x"yes"; then : + ax_blas_ok=yes;BLAS_LIBS="-ldxml" +fi + +fi + +# BLAS in Sun Performance library? +if test $ax_blas_ok = no; then + if test "x$GCC" != xyes; then # only works with Sun CC + { $as_echo "$as_me:${as_lineno-$LINENO}: checking for acosp in -lsunmath" >&5 +$as_echo_n "checking for acosp in -lsunmath... " >&6; } +if ${ac_cv_lib_sunmath_acosp+:} false; then : + $as_echo_n "(cached) " >&6 +else + ac_check_lib_save_LIBS=$LIBS +LIBS="-lsunmath $LIBS" +cat confdefs.h - <<_ACEOF >conftest.$ac_ext +/* end confdefs.h. */ + +/* Override any GCC internal prototype to avoid an error. + Use char because int might match the return type of a GCC + builtin and then its argument prototype would still apply. */ +#ifdef __cplusplus +extern "C" +#endif +char acosp (); +int +main () +{ +return acosp (); + ; + return 0; +} +_ACEOF +if ac_fn_cxx_try_link "$LINENO"; then : + ac_cv_lib_sunmath_acosp=yes +else + ac_cv_lib_sunmath_acosp=no +fi +rm -f core conftest.err conftest.$ac_objext \ + conftest$ac_exeext conftest.$ac_ext +LIBS=$ac_check_lib_save_LIBS +fi +{ $as_echo "$as_me:${as_lineno-$LINENO}: result: $ac_cv_lib_sunmath_acosp" >&5 +$as_echo "$ac_cv_lib_sunmath_acosp" >&6; } +if test "x$ac_cv_lib_sunmath_acosp" = xyes; then : + as_ac_Lib=`$as_echo "ac_cv_lib_sunperf_$sgemm" | $as_tr_sh` +{ $as_echo "$as_me:${as_lineno-$LINENO}: checking for $sgemm in -lsunperf" >&5 +$as_echo_n "checking for $sgemm in -lsunperf... " >&6; } +if eval \${$as_ac_Lib+:} false; then : + $as_echo_n "(cached) " >&6 +else + ac_check_lib_save_LIBS=$LIBS +LIBS="-lsunperf -lsunmath $LIBS" +cat confdefs.h - <<_ACEOF >conftest.$ac_ext +/* end confdefs.h. */ + +/* Override any GCC internal prototype to avoid an error. + Use char because int might match the return type of a GCC + builtin and then its argument prototype would still apply. */ +#ifdef __cplusplus +extern "C" +#endif +char $sgemm (); +int +main () +{ +return $sgemm (); + ; + return 0; +} +_ACEOF +if ac_fn_cxx_try_link "$LINENO"; then : + eval "$as_ac_Lib=yes" +else + eval "$as_ac_Lib=no" +fi +rm -f core conftest.err conftest.$ac_objext \ + conftest$ac_exeext conftest.$ac_ext +LIBS=$ac_check_lib_save_LIBS +fi +eval ac_res=\$$as_ac_Lib + { $as_echo "$as_me:${as_lineno-$LINENO}: result: $ac_res" >&5 +$as_echo "$ac_res" >&6; } +if eval test \"x\$"$as_ac_Lib"\" = x"yes"; then : + BLAS_LIBS="-xlic_lib=sunperf -lsunmath" + ax_blas_ok=yes +fi + +fi + + fi +fi + +# BLAS in SCSL library? (SGI/Cray Scientific Library) +if test $ax_blas_ok = no; then + as_ac_Lib=`$as_echo "ac_cv_lib_scs_$sgemm" | $as_tr_sh` +{ $as_echo "$as_me:${as_lineno-$LINENO}: checking for $sgemm in -lscs" >&5 +$as_echo_n "checking for $sgemm in -lscs... " >&6; } +if eval \${$as_ac_Lib+:} false; then : + $as_echo_n "(cached) " >&6 +else + ac_check_lib_save_LIBS=$LIBS +LIBS="-lscs $LIBS" +cat confdefs.h - <<_ACEOF >conftest.$ac_ext +/* end confdefs.h. */ + +/* Override any GCC internal prototype to avoid an error. + Use char because int might match the return type of a GCC + builtin and then its argument prototype would still apply. */ +#ifdef __cplusplus +extern "C" +#endif +char $sgemm (); +int +main () +{ +return $sgemm (); + ; + return 0; +} +_ACEOF +if ac_fn_cxx_try_link "$LINENO"; then : + eval "$as_ac_Lib=yes" +else + eval "$as_ac_Lib=no" +fi +rm -f core conftest.err conftest.$ac_objext \ + conftest$ac_exeext conftest.$ac_ext +LIBS=$ac_check_lib_save_LIBS +fi +eval ac_res=\$$as_ac_Lib + { $as_echo "$as_me:${as_lineno-$LINENO}: result: $ac_res" >&5 +$as_echo "$ac_res" >&6; } +if eval test \"x\$"$as_ac_Lib"\" = x"yes"; then : + ax_blas_ok=yes; BLAS_LIBS="-lscs" +fi + +fi + +# BLAS in SGIMATH library? +if test $ax_blas_ok = no; then + as_ac_Lib=`$as_echo "ac_cv_lib_complib.sgimath_$sgemm" | $as_tr_sh` +{ $as_echo "$as_me:${as_lineno-$LINENO}: checking for $sgemm in -lcomplib.sgimath" >&5 +$as_echo_n "checking for $sgemm in -lcomplib.sgimath... " >&6; } +if eval \${$as_ac_Lib+:} false; then : + $as_echo_n "(cached) " >&6 +else + ac_check_lib_save_LIBS=$LIBS +LIBS="-lcomplib.sgimath $LIBS" +cat confdefs.h - <<_ACEOF >conftest.$ac_ext +/* end confdefs.h. */ + +/* Override any GCC internal prototype to avoid an error. + Use char because int might match the return type of a GCC + builtin and then its argument prototype would still apply. */ +#ifdef __cplusplus +extern "C" +#endif +char $sgemm (); +int +main () +{ +return $sgemm (); + ; + return 0; +} +_ACEOF +if ac_fn_cxx_try_link "$LINENO"; then : + eval "$as_ac_Lib=yes" +else + eval "$as_ac_Lib=no" +fi +rm -f core conftest.err conftest.$ac_objext \ + conftest$ac_exeext conftest.$ac_ext +LIBS=$ac_check_lib_save_LIBS +fi +eval ac_res=\$$as_ac_Lib + { $as_echo "$as_me:${as_lineno-$LINENO}: result: $ac_res" >&5 +$as_echo "$ac_res" >&6; } +if eval test \"x\$"$as_ac_Lib"\" = x"yes"; then : + ax_blas_ok=yes; BLAS_LIBS="-lcomplib.sgimath" +fi + +fi + +# BLAS in IBM ESSL library? (requires generic BLAS lib, too) +if test $ax_blas_ok = no; then + as_ac_Lib=`$as_echo "ac_cv_lib_blas_$sgemm" | $as_tr_sh` +{ $as_echo "$as_me:${as_lineno-$LINENO}: checking for $sgemm in -lblas" >&5 +$as_echo_n "checking for $sgemm in -lblas... " >&6; } +if eval \${$as_ac_Lib+:} false; then : + $as_echo_n "(cached) " >&6 +else + ac_check_lib_save_LIBS=$LIBS +LIBS="-lblas $LIBS" +cat confdefs.h - <<_ACEOF >conftest.$ac_ext +/* end confdefs.h. */ + +/* Override any GCC internal prototype to avoid an error. + Use char because int might match the return type of a GCC + builtin and then its argument prototype would still apply. */ +#ifdef __cplusplus +extern "C" +#endif +char $sgemm (); +int +main () +{ +return $sgemm (); + ; + return 0; +} +_ACEOF +if ac_fn_cxx_try_link "$LINENO"; then : + eval "$as_ac_Lib=yes" +else + eval "$as_ac_Lib=no" +fi +rm -f core conftest.err conftest.$ac_objext \ + conftest$ac_exeext conftest.$ac_ext +LIBS=$ac_check_lib_save_LIBS +fi +eval ac_res=\$$as_ac_Lib + { $as_echo "$as_me:${as_lineno-$LINENO}: result: $ac_res" >&5 +$as_echo "$ac_res" >&6; } +if eval test \"x\$"$as_ac_Lib"\" = x"yes"; then : + as_ac_Lib=`$as_echo "ac_cv_lib_essl_$sgemm" | $as_tr_sh` +{ $as_echo "$as_me:${as_lineno-$LINENO}: checking for $sgemm in -lessl" >&5 +$as_echo_n "checking for $sgemm in -lessl... " >&6; } +if eval \${$as_ac_Lib+:} false; then : + $as_echo_n "(cached) " >&6 +else + ac_check_lib_save_LIBS=$LIBS +LIBS="-lessl -lblas $FLIBS $LIBS" +cat confdefs.h - <<_ACEOF >conftest.$ac_ext +/* end confdefs.h. */ + +/* Override any GCC internal prototype to avoid an error. + Use char because int might match the return type of a GCC + builtin and then its argument prototype would still apply. */ +#ifdef __cplusplus +extern "C" +#endif +char $sgemm (); +int +main () +{ +return $sgemm (); + ; + return 0; +} +_ACEOF +if ac_fn_cxx_try_link "$LINENO"; then : + eval "$as_ac_Lib=yes" +else + eval "$as_ac_Lib=no" +fi +rm -f core conftest.err conftest.$ac_objext \ + conftest$ac_exeext conftest.$ac_ext +LIBS=$ac_check_lib_save_LIBS +fi +eval ac_res=\$$as_ac_Lib + { $as_echo "$as_me:${as_lineno-$LINENO}: result: $ac_res" >&5 +$as_echo "$ac_res" >&6; } +if eval test \"x\$"$as_ac_Lib"\" = x"yes"; then : + ax_blas_ok=yes; BLAS_LIBS="-lessl -lblas" +fi + +fi + +fi + +# Generic BLAS library? +if test $ax_blas_ok = no; then + as_ac_Lib=`$as_echo "ac_cv_lib_blas_$sgemm" | $as_tr_sh` +{ $as_echo "$as_me:${as_lineno-$LINENO}: checking for $sgemm in -lblas" >&5 +$as_echo_n "checking for $sgemm in -lblas... " >&6; } +if eval \${$as_ac_Lib+:} false; then : + $as_echo_n "(cached) " >&6 +else + ac_check_lib_save_LIBS=$LIBS +LIBS="-lblas $LIBS" +cat confdefs.h - <<_ACEOF >conftest.$ac_ext +/* end confdefs.h. */ + +/* Override any GCC internal prototype to avoid an error. + Use char because int might match the return type of a GCC + builtin and then its argument prototype would still apply. */ +#ifdef __cplusplus +extern "C" +#endif +char $sgemm (); +int +main () +{ +return $sgemm (); + ; + return 0; +} +_ACEOF +if ac_fn_cxx_try_link "$LINENO"; then : + eval "$as_ac_Lib=yes" +else + eval "$as_ac_Lib=no" +fi +rm -f core conftest.err conftest.$ac_objext \ + conftest$ac_exeext conftest.$ac_ext +LIBS=$ac_check_lib_save_LIBS +fi +eval ac_res=\$$as_ac_Lib + { $as_echo "$as_me:${as_lineno-$LINENO}: result: $ac_res" >&5 +$as_echo "$ac_res" >&6; } +if eval test \"x\$"$as_ac_Lib"\" = x"yes"; then : + ax_blas_ok=yes; BLAS_LIBS="-lblas" +fi + +fi + + + + +LIBS="$ax_blas_save_LIBS" + +# Finally, execute ACTION-IF-FOUND/ACTION-IF-NOT-FOUND: +if test x"$ax_blas_ok" = xyes; then + +$as_echo "#define HAVE_BLAS 1" >>confdefs.h + + : +else + ax_blas_ok=no + +fi + +if test "x$ax_blas_ok" == "xno"; then + as_fn_error $? "An implementation of BLAS is required but none was found." "$LINENO" 5 +fi + + + +ax_lapack_ok=no + + +# Check whether --with-lapack was given. +if test "${with_lapack+set}" = set; then : + withval=$with_lapack; +fi + +case $with_lapack in + yes | "") ;; + no) ax_lapack_ok=disable ;; + -* | */* | *.a | *.so | *.so.* | *.o) LAPACK_LIBS="$with_lapack" ;; + *) LAPACK_LIBS="-l$with_lapack" ;; +esac + +# Get fortran linker name of LAPACK function to check for. +# AC_F77_FUNC(cheev) +cheev=cheev_ + +# We cannot use LAPACK if BLAS is not found +if test "x$ax_blas_ok" != xyes; then + ax_lapack_ok=noblas + LAPACK_LIBS="" +fi + +# First, check LAPACK_LIBS environment variable +if test "x$LAPACK_LIBS" != x; then + save_LIBS="$LIBS"; LIBS="$LAPACK_LIBS $BLAS_LIBS $LIBS $FLIBS" + { $as_echo "$as_me:${as_lineno-$LINENO}: checking for $cheev in $LAPACK_LIBS" >&5 +$as_echo_n "checking for $cheev in $LAPACK_LIBS... " >&6; } + cat confdefs.h - <<_ACEOF >conftest.$ac_ext +/* end confdefs.h. */ + +/* Override any GCC internal prototype to avoid an error. + Use char because int might match the return type of a GCC + builtin and then its argument prototype would still apply. */ +#ifdef __cplusplus +extern "C" +#endif +char $cheev (); +int +main () +{ +return $cheev (); + ; + return 0; +} +_ACEOF +if ac_fn_cxx_try_link "$LINENO"; then : + ax_lapack_ok=yes +else + LAPACK_LIBS="" +fi +rm -f core conftest.err conftest.$ac_objext \ + conftest$ac_exeext conftest.$ac_ext + { $as_echo "$as_me:${as_lineno-$LINENO}: result: $ax_lapack_ok" >&5 +$as_echo "$ax_lapack_ok" >&6; } + LIBS="$save_LIBS" + if test $ax_lapack_ok = no; then + LAPACK_LIBS="" + fi +fi + +# LAPACK linked to by default? (is sometimes included in BLAS lib) +if test $ax_lapack_ok = no; then + save_LIBS="$LIBS"; LIBS="$LIBS $BLAS_LIBS $FLIBS" + as_ac_var=`$as_echo "ac_cv_func_$cheev" | $as_tr_sh` +ac_fn_cxx_check_func "$LINENO" "$cheev" "$as_ac_var" +if eval test \"x\$"$as_ac_var"\" = x"yes"; then : + ax_lapack_ok=yes +fi + + LIBS="$save_LIBS" +fi + +# Generic LAPACK library? +for lapack in lapack lapack_rs6k; do + if test $ax_lapack_ok = no; then + save_LIBS="$LIBS"; LIBS="$BLAS_LIBS $LIBS" + as_ac_Lib=`$as_echo "ac_cv_lib_$lapack''_$cheev" | $as_tr_sh` +{ $as_echo "$as_me:${as_lineno-$LINENO}: checking for $cheev in -l$lapack" >&5 +$as_echo_n "checking for $cheev in -l$lapack... " >&6; } +if eval \${$as_ac_Lib+:} false; then : + $as_echo_n "(cached) " >&6 +else + ac_check_lib_save_LIBS=$LIBS +LIBS="-l$lapack $FLIBS $LIBS" +cat confdefs.h - <<_ACEOF >conftest.$ac_ext +/* end confdefs.h. */ + +/* Override any GCC internal prototype to avoid an error. + Use char because int might match the return type of a GCC + builtin and then its argument prototype would still apply. */ +#ifdef __cplusplus +extern "C" +#endif +char $cheev (); +int +main () +{ +return $cheev (); + ; + return 0; +} +_ACEOF +if ac_fn_cxx_try_link "$LINENO"; then : + eval "$as_ac_Lib=yes" +else + eval "$as_ac_Lib=no" +fi +rm -f core conftest.err conftest.$ac_objext \ + conftest$ac_exeext conftest.$ac_ext +LIBS=$ac_check_lib_save_LIBS +fi +eval ac_res=\$$as_ac_Lib + { $as_echo "$as_me:${as_lineno-$LINENO}: result: $ac_res" >&5 +$as_echo "$ac_res" >&6; } +if eval test \"x\$"$as_ac_Lib"\" = x"yes"; then : + ax_lapack_ok=yes; LAPACK_LIBS="-l$lapack" +fi + + LIBS="$save_LIBS" + fi +done + + + +# Finally, execute ACTION-IF-FOUND/ACTION-IF-NOT-FOUND: +if test x"$ax_lapack_ok" = xyes; then + +$as_echo "#define HAVE_LAPACK 1" >>confdefs.h + + : +else + ax_lapack_ok=no + +fi + +if test "x$ax_lapack_ok" == "xno"; then + as_fn_error $? "An implementation of LAPACK is required but none was found." "$LINENO" 5 +fi + +{ $as_echo "$as_me:${as_lineno-$LINENO}: checking target system type" >&5 +$as_echo_n "checking target system type... " >&6; } +if ${ac_cv_target+:} false; then : + $as_echo_n "(cached) " >&6 +else + if test "x$target_alias" = x; then + ac_cv_target=$ac_cv_host +else + ac_cv_target=`$SHELL "$ac_aux_dir/config.sub" $target_alias` || + as_fn_error $? "$SHELL $ac_aux_dir/config.sub $target_alias failed" "$LINENO" 5 +fi + +fi +{ $as_echo "$as_me:${as_lineno-$LINENO}: result: $ac_cv_target" >&5 +$as_echo "$ac_cv_target" >&6; } +case $ac_cv_target in +*-*-*) ;; +*) as_fn_error $? "invalid value of canonical target" "$LINENO" 5;; +esac +target=$ac_cv_target +ac_save_IFS=$IFS; IFS='-' +set x $ac_cv_target +shift +target_cpu=$1 +target_vendor=$2 +shift; shift +# Remember, the first character of IFS is used to create $*, +# except with old shells: +target_os=$* +IFS=$ac_save_IFS +case $target_os in *\ *) target_os=`echo "$target_os" | sed 's/ /-/g'`;; esac + + +# The aliases save the names the user supplied, while $host etc. +# will get canonicalized. +test -n "$target_alias" && + test "$program_prefix$program_suffix$program_transform_name" = \ + NONENONEs,x,x, && + program_prefix=${target_alias}- + + +{ $as_echo "$as_me:${as_lineno-$LINENO}: checking for cpu arch" >&5 +$as_echo_n "checking for cpu arch... " >&6; } + + + + case $target in + amd64-* | x86_64-*) + ARCH_CPUFLAGS="-mpopcnt -msse4" + ARCH_CXXFLAGS="-m64" + ;; + aarch64*-*) + ARCH_CPUFLAGS="-march=armv8.2-a" + ;; + *) ;; + esac + +{ $as_echo "$as_me:${as_lineno-$LINENO}: result: $target CPUFLAGS+=\"$ARCH_CPUFLAGS\" CXXFLAGS+=\"$ARCH_CXXFLAGS\"" >&5 +$as_echo "$target CPUFLAGS+=\"$ARCH_CPUFLAGS\" CXXFLAGS+=\"$ARCH_CXXFLAGS\"" >&6; } + + + + + + +ac_config_files="$ac_config_files makefile.inc" + +cat >confcache <<\_ACEOF +# This file is a shell script that caches the results of configure +# tests run on this system so they can be shared between configure +# scripts and configure runs, see configure's option --config-cache. +# It is not useful on other systems. If it contains results you don't +# want to keep, you may remove or edit it. +# +# config.status only pays attention to the cache file if you give it +# the --recheck option to rerun configure. +# +# `ac_cv_env_foo' variables (set or unset) will be overridden when +# loading this file, other *unset* `ac_cv_foo' will be assigned the +# following values. + +_ACEOF + +# The following way of writing the cache mishandles newlines in values, +# but we know of no workaround that is simple, portable, and efficient. +# So, we kill variables containing newlines. +# Ultrix sh set writes to stderr and can't be redirected directly, +# and sets the high bit in the cache file unless we assign to the vars. +( + for ac_var in `(set) 2>&1 | sed -n 's/^\([a-zA-Z_][a-zA-Z0-9_]*\)=.*/\1/p'`; do + eval ac_val=\$$ac_var + case $ac_val in #( + *${as_nl}*) + case $ac_var in #( + *_cv_*) { $as_echo "$as_me:${as_lineno-$LINENO}: WARNING: cache variable $ac_var contains a newline" >&5 +$as_echo "$as_me: WARNING: cache variable $ac_var contains a newline" >&2;} ;; + esac + case $ac_var in #( + _ | IFS | as_nl) ;; #( + BASH_ARGV | BASH_SOURCE) eval $ac_var= ;; #( + *) { eval $ac_var=; unset $ac_var;} ;; + esac ;; + esac + done + + (set) 2>&1 | + case $as_nl`(ac_space=' '; set) 2>&1` in #( + *${as_nl}ac_space=\ *) + # `set' does not quote correctly, so add quotes: double-quote + # substitution turns \\\\ into \\, and sed turns \\ into \. + sed -n \ + "s/'/'\\\\''/g; + s/^\\([_$as_cr_alnum]*_cv_[_$as_cr_alnum]*\\)=\\(.*\\)/\\1='\\2'/p" + ;; #( + *) + # `set' quotes correctly as required by POSIX, so do not add quotes. + sed -n "/^[_$as_cr_alnum]*_cv_[_$as_cr_alnum]*=/p" + ;; + esac | + sort +) | + sed ' + /^ac_cv_env_/b end + t clear + :clear + s/^\([^=]*\)=\(.*[{}].*\)$/test "${\1+set}" = set || &/ + t end + s/^\([^=]*\)=\(.*\)$/\1=${\1=\2}/ + :end' >>confcache +if diff "$cache_file" confcache >/dev/null 2>&1; then :; else + if test -w "$cache_file"; then + if test "x$cache_file" != "x/dev/null"; then + { $as_echo "$as_me:${as_lineno-$LINENO}: updating cache $cache_file" >&5 +$as_echo "$as_me: updating cache $cache_file" >&6;} + if test ! -f "$cache_file" || test -h "$cache_file"; then + cat confcache >"$cache_file" + else + case $cache_file in #( + */* | ?:*) + mv -f confcache "$cache_file"$$ && + mv -f "$cache_file"$$ "$cache_file" ;; #( + *) + mv -f confcache "$cache_file" ;; + esac + fi + fi + else + { $as_echo "$as_me:${as_lineno-$LINENO}: not updating unwritable cache $cache_file" >&5 +$as_echo "$as_me: not updating unwritable cache $cache_file" >&6;} + fi +fi +rm -f confcache + +test "x$prefix" = xNONE && prefix=$ac_default_prefix +# Let make expand exec_prefix. +test "x$exec_prefix" = xNONE && exec_prefix='${prefix}' + +# Transform confdefs.h into DEFS. +# Protect against shell expansion while executing Makefile rules. +# Protect against Makefile macro expansion. +# +# If the first sed substitution is executed (which looks for macros that +# take arguments), then branch to the quote section. Otherwise, +# look for a macro that doesn't take arguments. +ac_script=' +:mline +/\\$/{ + N + s,\\\n,, + b mline +} +t clear +:clear +s/^[ ]*#[ ]*define[ ][ ]*\([^ (][^ (]*([^)]*)\)[ ]*\(.*\)/-D\1=\2/g +t quote +s/^[ ]*#[ ]*define[ ][ ]*\([^ ][^ ]*\)[ ]*\(.*\)/-D\1=\2/g +t quote +b any +:quote +s/[ `~#$^&*(){}\\|;'\''"<>?]/\\&/g +s/\[/\\&/g +s/\]/\\&/g +s/\$/$$/g +H +:any +${ + g + s/^\n// + s/\n/ /g + p +} +' +DEFS=`sed -n "$ac_script" confdefs.h` + + +ac_libobjs= +ac_ltlibobjs= +U= +for ac_i in : $LIBOBJS; do test "x$ac_i" = x: && continue + # 1. Remove the extension, and $U if already installed. + ac_script='s/\$U\././;s/\.o$//;s/\.obj$//' + ac_i=`$as_echo "$ac_i" | sed "$ac_script"` + # 2. Prepend LIBOBJDIR. When used with automake>=1.10 LIBOBJDIR + # will be set to the directory where LIBOBJS objects are built. + as_fn_append ac_libobjs " \${LIBOBJDIR}$ac_i\$U.$ac_objext" + as_fn_append ac_ltlibobjs " \${LIBOBJDIR}$ac_i"'$U.lo' +done +LIBOBJS=$ac_libobjs + +LTLIBOBJS=$ac_ltlibobjs + + + +: "${CONFIG_STATUS=./config.status}" +ac_write_fail=0 +ac_clean_files_save=$ac_clean_files +ac_clean_files="$ac_clean_files $CONFIG_STATUS" +{ $as_echo "$as_me:${as_lineno-$LINENO}: creating $CONFIG_STATUS" >&5 +$as_echo "$as_me: creating $CONFIG_STATUS" >&6;} +as_write_fail=0 +cat >$CONFIG_STATUS <<_ASEOF || as_write_fail=1 +#! $SHELL +# Generated by $as_me. +# Run this file to recreate the current configuration. +# Compiler output produced by configure, useful for debugging +# configure, is in config.log if it exists. + +debug=false +ac_cs_recheck=false +ac_cs_silent=false + +SHELL=\${CONFIG_SHELL-$SHELL} +export SHELL +_ASEOF +cat >>$CONFIG_STATUS <<\_ASEOF || as_write_fail=1 +## -------------------- ## +## M4sh Initialization. ## +## -------------------- ## + +# Be more Bourne compatible +DUALCASE=1; export DUALCASE # for MKS sh +if test -n "${ZSH_VERSION+set}" && (emulate sh) >/dev/null 2>&1; then : + emulate sh + NULLCMD=: + # Pre-4.2 versions of Zsh do word splitting on ${1+"$@"}, which + # is contrary to our usage. Disable this feature. + alias -g '${1+"$@"}'='"$@"' + setopt NO_GLOB_SUBST +else + case `(set -o) 2>/dev/null` in #( + *posix*) : + set -o posix ;; #( + *) : + ;; +esac +fi + + +as_nl=' +' +export as_nl +# Printing a long string crashes Solaris 7 /usr/bin/printf. +as_echo='\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\' +as_echo=$as_echo$as_echo$as_echo$as_echo$as_echo +as_echo=$as_echo$as_echo$as_echo$as_echo$as_echo$as_echo +# Prefer a ksh shell builtin over an external printf program on Solaris, +# but without wasting forks for bash or zsh. +if test -z "$BASH_VERSION$ZSH_VERSION" \ + && (test "X`print -r -- $as_echo`" = "X$as_echo") 2>/dev/null; then + as_echo='print -r --' + as_echo_n='print -rn --' +elif (test "X`printf %s $as_echo`" = "X$as_echo") 2>/dev/null; then + as_echo='printf %s\n' + as_echo_n='printf %s' +else + if test "X`(/usr/ucb/echo -n -n $as_echo) 2>/dev/null`" = "X-n $as_echo"; then + as_echo_body='eval /usr/ucb/echo -n "$1$as_nl"' + as_echo_n='/usr/ucb/echo -n' + else + as_echo_body='eval expr "X$1" : "X\\(.*\\)"' + as_echo_n_body='eval + arg=$1; + case $arg in #( + *"$as_nl"*) + expr "X$arg" : "X\\(.*\\)$as_nl"; + arg=`expr "X$arg" : ".*$as_nl\\(.*\\)"`;; + esac; + expr "X$arg" : "X\\(.*\\)" | tr -d "$as_nl" + ' + export as_echo_n_body + as_echo_n='sh -c $as_echo_n_body as_echo' + fi + export as_echo_body + as_echo='sh -c $as_echo_body as_echo' +fi + +# The user is always right. +if test "${PATH_SEPARATOR+set}" != set; then + PATH_SEPARATOR=: + (PATH='/bin;/bin'; FPATH=$PATH; sh -c :) >/dev/null 2>&1 && { + (PATH='/bin:/bin'; FPATH=$PATH; sh -c :) >/dev/null 2>&1 || + PATH_SEPARATOR=';' + } +fi + + +# IFS +# We need space, tab and new line, in precisely that order. Quoting is +# there to prevent editors from complaining about space-tab. +# (If _AS_PATH_WALK were called with IFS unset, it would disable word +# splitting by setting IFS to empty value.) +IFS=" "" $as_nl" + +# Find who we are. Look in the path if we contain no directory separator. +as_myself= +case $0 in #(( + *[\\/]* ) as_myself=$0 ;; + *) as_save_IFS=$IFS; IFS=$PATH_SEPARATOR +for as_dir in $PATH +do + IFS=$as_save_IFS + test -z "$as_dir" && as_dir=. + test -r "$as_dir/$0" && as_myself=$as_dir/$0 && break + done +IFS=$as_save_IFS + + ;; +esac +# We did not find ourselves, most probably we were run as `sh COMMAND' +# in which case we are not to be found in the path. +if test "x$as_myself" = x; then + as_myself=$0 +fi +if test ! -f "$as_myself"; then + $as_echo "$as_myself: error: cannot find myself; rerun with an absolute file name" >&2 + exit 1 +fi + +# Unset variables that we do not need and which cause bugs (e.g. in +# pre-3.0 UWIN ksh). But do not cause bugs in bash 2.01; the "|| exit 1" +# suppresses any "Segmentation fault" message there. '((' could +# trigger a bug in pdksh 5.2.14. +for as_var in BASH_ENV ENV MAIL MAILPATH +do eval test x\${$as_var+set} = xset \ + && ( (unset $as_var) || exit 1) >/dev/null 2>&1 && unset $as_var || : +done +PS1='$ ' +PS2='> ' +PS4='+ ' + +# NLS nuisances. +LC_ALL=C +export LC_ALL +LANGUAGE=C +export LANGUAGE + +# CDPATH. +(unset CDPATH) >/dev/null 2>&1 && unset CDPATH + + +# as_fn_error STATUS ERROR [LINENO LOG_FD] +# ---------------------------------------- +# Output "`basename $0`: error: ERROR" to stderr. If LINENO and LOG_FD are +# provided, also output the error to LOG_FD, referencing LINENO. Then exit the +# script with STATUS, using 1 if that was 0. +as_fn_error () +{ + as_status=$1; test $as_status -eq 0 && as_status=1 + if test "$4"; then + as_lineno=${as_lineno-"$3"} as_lineno_stack=as_lineno_stack=$as_lineno_stack + $as_echo "$as_me:${as_lineno-$LINENO}: error: $2" >&$4 + fi + $as_echo "$as_me: error: $2" >&2 + as_fn_exit $as_status +} # as_fn_error + + +# as_fn_set_status STATUS +# ----------------------- +# Set $? to STATUS, without forking. +as_fn_set_status () +{ + return $1 +} # as_fn_set_status + +# as_fn_exit STATUS +# ----------------- +# Exit the shell with STATUS, even in a "trap 0" or "set -e" context. +as_fn_exit () +{ + set +e + as_fn_set_status $1 + exit $1 +} # as_fn_exit + +# as_fn_unset VAR +# --------------- +# Portably unset VAR. +as_fn_unset () +{ + { eval $1=; unset $1;} +} +as_unset=as_fn_unset +# as_fn_append VAR VALUE +# ---------------------- +# Append the text in VALUE to the end of the definition contained in VAR. Take +# advantage of any shell optimizations that allow amortized linear growth over +# repeated appends, instead of the typical quadratic growth present in naive +# implementations. +if (eval "as_var=1; as_var+=2; test x\$as_var = x12") 2>/dev/null; then : + eval 'as_fn_append () + { + eval $1+=\$2 + }' +else + as_fn_append () + { + eval $1=\$$1\$2 + } +fi # as_fn_append + +# as_fn_arith ARG... +# ------------------ +# Perform arithmetic evaluation on the ARGs, and store the result in the +# global $as_val. Take advantage of shells that can avoid forks. The arguments +# must be portable across $(()) and expr. +if (eval "test \$(( 1 + 1 )) = 2") 2>/dev/null; then : + eval 'as_fn_arith () + { + as_val=$(( $* )) + }' +else + as_fn_arith () + { + as_val=`expr "$@" || test $? -eq 1` + } +fi # as_fn_arith + + +if expr a : '\(a\)' >/dev/null 2>&1 && + test "X`expr 00001 : '.*\(...\)'`" = X001; then + as_expr=expr +else + as_expr=false +fi + +if (basename -- /) >/dev/null 2>&1 && test "X`basename -- / 2>&1`" = "X/"; then + as_basename=basename +else + as_basename=false +fi + +if (as_dir=`dirname -- /` && test "X$as_dir" = X/) >/dev/null 2>&1; then + as_dirname=dirname +else + as_dirname=false +fi + +as_me=`$as_basename -- "$0" || +$as_expr X/"$0" : '.*/\([^/][^/]*\)/*$' \| \ + X"$0" : 'X\(//\)$' \| \ + X"$0" : 'X\(/\)' \| . 2>/dev/null || +$as_echo X/"$0" | + sed '/^.*\/\([^/][^/]*\)\/*$/{ + s//\1/ + q + } + /^X\/\(\/\/\)$/{ + s//\1/ + q + } + /^X\/\(\/\).*/{ + s//\1/ + q + } + s/.*/./; q'` + +# Avoid depending upon Character Ranges. +as_cr_letters='abcdefghijklmnopqrstuvwxyz' +as_cr_LETTERS='ABCDEFGHIJKLMNOPQRSTUVWXYZ' +as_cr_Letters=$as_cr_letters$as_cr_LETTERS +as_cr_digits='0123456789' +as_cr_alnum=$as_cr_Letters$as_cr_digits + +ECHO_C= ECHO_N= ECHO_T= +case `echo -n x` in #((((( +-n*) + case `echo 'xy\c'` in + *c*) ECHO_T=' ';; # ECHO_T is single tab character. + xy) ECHO_C='\c';; + *) echo `echo ksh88 bug on AIX 6.1` > /dev/null + ECHO_T=' ';; + esac;; +*) + ECHO_N='-n';; +esac + +rm -f conf$$ conf$$.exe conf$$.file +if test -d conf$$.dir; then + rm -f conf$$.dir/conf$$.file +else + rm -f conf$$.dir + mkdir conf$$.dir 2>/dev/null +fi +if (echo >conf$$.file) 2>/dev/null; then + if ln -s conf$$.file conf$$ 2>/dev/null; then + as_ln_s='ln -s' + # ... but there are two gotchas: + # 1) On MSYS, both `ln -s file dir' and `ln file dir' fail. + # 2) DJGPP < 2.04 has no symlinks; `ln -s' creates a wrapper executable. + # In both cases, we have to default to `cp -pR'. + ln -s conf$$.file conf$$.dir 2>/dev/null && test ! -f conf$$.exe || + as_ln_s='cp -pR' + elif ln conf$$.file conf$$ 2>/dev/null; then + as_ln_s=ln + else + as_ln_s='cp -pR' + fi +else + as_ln_s='cp -pR' +fi +rm -f conf$$ conf$$.exe conf$$.dir/conf$$.file conf$$.file +rmdir conf$$.dir 2>/dev/null + + +# as_fn_mkdir_p +# ------------- +# Create "$as_dir" as a directory, including parents if necessary. +as_fn_mkdir_p () +{ + + case $as_dir in #( + -*) as_dir=./$as_dir;; + esac + test -d "$as_dir" || eval $as_mkdir_p || { + as_dirs= + while :; do + case $as_dir in #( + *\'*) as_qdir=`$as_echo "$as_dir" | sed "s/'/'\\\\\\\\''/g"`;; #'( + *) as_qdir=$as_dir;; + esac + as_dirs="'$as_qdir' $as_dirs" + as_dir=`$as_dirname -- "$as_dir" || +$as_expr X"$as_dir" : 'X\(.*[^/]\)//*[^/][^/]*/*$' \| \ + X"$as_dir" : 'X\(//\)[^/]' \| \ + X"$as_dir" : 'X\(//\)$' \| \ + X"$as_dir" : 'X\(/\)' \| . 2>/dev/null || +$as_echo X"$as_dir" | + sed '/^X\(.*[^/]\)\/\/*[^/][^/]*\/*$/{ + s//\1/ + q + } + /^X\(\/\/\)[^/].*/{ + s//\1/ + q + } + /^X\(\/\/\)$/{ + s//\1/ + q + } + /^X\(\/\).*/{ + s//\1/ + q + } + s/.*/./; q'` + test -d "$as_dir" && break + done + test -z "$as_dirs" || eval "mkdir $as_dirs" + } || test -d "$as_dir" || as_fn_error $? "cannot create directory $as_dir" + + +} # as_fn_mkdir_p +if mkdir -p . 2>/dev/null; then + as_mkdir_p='mkdir -p "$as_dir"' +else + test -d ./-p && rmdir ./-p + as_mkdir_p=false +fi + + +# as_fn_executable_p FILE +# ----------------------- +# Test if FILE is an executable regular file. +as_fn_executable_p () +{ + test -f "$1" && test -x "$1" +} # as_fn_executable_p +as_test_x='test -x' +as_executable_p=as_fn_executable_p + +# Sed expression to map a string onto a valid CPP name. +as_tr_cpp="eval sed 'y%*$as_cr_letters%P$as_cr_LETTERS%;s%[^_$as_cr_alnum]%_%g'" + +# Sed expression to map a string onto a valid variable name. +as_tr_sh="eval sed 'y%*+%pp%;s%[^_$as_cr_alnum]%_%g'" + + +exec 6>&1 +## ----------------------------------- ## +## Main body of $CONFIG_STATUS script. ## +## ----------------------------------- ## +_ASEOF +test $as_write_fail = 0 && chmod +x $CONFIG_STATUS || ac_write_fail=1 + +cat >>$CONFIG_STATUS <<\_ACEOF || ac_write_fail=1 +# Save the log message, to keep $0 and so on meaningful, and to +# report actual input values of CONFIG_FILES etc. instead of their +# values after options handling. +ac_log=" +This file was extended by faiss $as_me 1.0, which was +generated by GNU Autoconf 2.69. Invocation command line was + + CONFIG_FILES = $CONFIG_FILES + CONFIG_HEADERS = $CONFIG_HEADERS + CONFIG_LINKS = $CONFIG_LINKS + CONFIG_COMMANDS = $CONFIG_COMMANDS + $ $0 $@ + +on `(hostname || uname -n) 2>/dev/null | sed 1q` +" + +_ACEOF + +case $ac_config_files in *" +"*) set x $ac_config_files; shift; ac_config_files=$*;; +esac + + + +cat >>$CONFIG_STATUS <<_ACEOF || ac_write_fail=1 +# Files that config.status was made for. +config_files="$ac_config_files" + +_ACEOF + +cat >>$CONFIG_STATUS <<\_ACEOF || ac_write_fail=1 +ac_cs_usage="\ +\`$as_me' instantiates files and other configuration actions +from templates according to the current configuration. Unless the files +and actions are specified as TAGs, all are instantiated by default. + +Usage: $0 [OPTION]... [TAG]... + + -h, --help print this help, then exit + -V, --version print version number and configuration settings, then exit + --config print configuration, then exit + -q, --quiet, --silent + do not print progress messages + -d, --debug don't remove temporary files + --recheck update $as_me by reconfiguring in the same conditions + --file=FILE[:TEMPLATE] + instantiate the configuration file FILE + +Configuration files: +$config_files + +Report bugs to the package provider." + +_ACEOF +cat >>$CONFIG_STATUS <<_ACEOF || ac_write_fail=1 +ac_cs_config="`$as_echo "$ac_configure_args" | sed 's/^ //; s/[\\""\`\$]/\\\\&/g'`" +ac_cs_version="\\ +faiss config.status 1.0 +configured by $0, generated by GNU Autoconf 2.69, + with options \\"\$ac_cs_config\\" + +Copyright (C) 2012 Free Software Foundation, Inc. +This config.status script is free software; the Free Software Foundation +gives unlimited permission to copy, distribute and modify it." + +ac_pwd='$ac_pwd' +srcdir='$srcdir' +MKDIR_P='$MKDIR_P' +test -n "\$AWK" || AWK=awk +_ACEOF + +cat >>$CONFIG_STATUS <<\_ACEOF || ac_write_fail=1 +# The default lists apply if the user does not specify any file. +ac_need_defaults=: +while test $# != 0 +do + case $1 in + --*=?*) + ac_option=`expr "X$1" : 'X\([^=]*\)='` + ac_optarg=`expr "X$1" : 'X[^=]*=\(.*\)'` + ac_shift=: + ;; + --*=) + ac_option=`expr "X$1" : 'X\([^=]*\)='` + ac_optarg= + ac_shift=: + ;; + *) + ac_option=$1 + ac_optarg=$2 + ac_shift=shift + ;; + esac + + case $ac_option in + # Handling of the options. + -recheck | --recheck | --rechec | --reche | --rech | --rec | --re | --r) + ac_cs_recheck=: ;; + --version | --versio | --versi | --vers | --ver | --ve | --v | -V ) + $as_echo "$ac_cs_version"; exit ;; + --config | --confi | --conf | --con | --co | --c ) + $as_echo "$ac_cs_config"; exit ;; + --debug | --debu | --deb | --de | --d | -d ) + debug=: ;; + --file | --fil | --fi | --f ) + $ac_shift + case $ac_optarg in + *\'*) ac_optarg=`$as_echo "$ac_optarg" | sed "s/'/'\\\\\\\\''/g"` ;; + '') as_fn_error $? "missing file argument" ;; + esac + as_fn_append CONFIG_FILES " '$ac_optarg'" + ac_need_defaults=false;; + --he | --h | --help | --hel | -h ) + $as_echo "$ac_cs_usage"; exit ;; + -q | -quiet | --quiet | --quie | --qui | --qu | --q \ + | -silent | --silent | --silen | --sile | --sil | --si | --s) + ac_cs_silent=: ;; + + # This is an error. + -*) as_fn_error $? "unrecognized option: \`$1' +Try \`$0 --help' for more information." ;; + + *) as_fn_append ac_config_targets " $1" + ac_need_defaults=false ;; + + esac + shift +done + +ac_configure_extra_args= + +if $ac_cs_silent; then + exec 6>/dev/null + ac_configure_extra_args="$ac_configure_extra_args --silent" +fi + +_ACEOF +cat >>$CONFIG_STATUS <<_ACEOF || ac_write_fail=1 +if \$ac_cs_recheck; then + set X $SHELL '$0' $ac_configure_args \$ac_configure_extra_args --no-create --no-recursion + shift + \$as_echo "running CONFIG_SHELL=$SHELL \$*" >&6 + CONFIG_SHELL='$SHELL' + export CONFIG_SHELL + exec "\$@" +fi + +_ACEOF +cat >>$CONFIG_STATUS <<\_ACEOF || ac_write_fail=1 +exec 5>>config.log +{ + echo + sed 'h;s/./-/g;s/^.../## /;s/...$/ ##/;p;x;p;x' <<_ASBOX +## Running $as_me. ## +_ASBOX + $as_echo "$ac_log" +} >&5 + +_ACEOF +cat >>$CONFIG_STATUS <<_ACEOF || ac_write_fail=1 +_ACEOF + +cat >>$CONFIG_STATUS <<\_ACEOF || ac_write_fail=1 + +# Handling of arguments. +for ac_config_target in $ac_config_targets +do + case $ac_config_target in + "makefile.inc") CONFIG_FILES="$CONFIG_FILES makefile.inc" ;; + + *) as_fn_error $? "invalid argument: \`$ac_config_target'" "$LINENO" 5;; + esac +done + + +# If the user did not use the arguments to specify the items to instantiate, +# then the envvar interface is used. Set only those that are not. +# We use the long form for the default assignment because of an extremely +# bizarre bug on SunOS 4.1.3. +if $ac_need_defaults; then + test "${CONFIG_FILES+set}" = set || CONFIG_FILES=$config_files +fi + +# Have a temporary directory for convenience. Make it in the build tree +# simply because there is no reason against having it here, and in addition, +# creating and moving files from /tmp can sometimes cause problems. +# Hook for its removal unless debugging. +# Note that there is a small window in which the directory will not be cleaned: +# after its creation but before its name has been assigned to `$tmp'. +$debug || +{ + tmp= ac_tmp= + trap 'exit_status=$? + : "${ac_tmp:=$tmp}" + { test ! -d "$ac_tmp" || rm -fr "$ac_tmp"; } && exit $exit_status +' 0 + trap 'as_fn_exit 1' 1 2 13 15 +} +# Create a (secure) tmp directory for tmp files. + +{ + tmp=`(umask 077 && mktemp -d "./confXXXXXX") 2>/dev/null` && + test -d "$tmp" +} || +{ + tmp=./conf$$-$RANDOM + (umask 077 && mkdir "$tmp") +} || as_fn_error $? "cannot create a temporary directory in ." "$LINENO" 5 +ac_tmp=$tmp + +# Set up the scripts for CONFIG_FILES section. +# No need to generate them if there are no CONFIG_FILES. +# This happens for instance with `./config.status config.h'. +if test -n "$CONFIG_FILES"; then + + +ac_cr=`echo X | tr X '\015'` +# On cygwin, bash can eat \r inside `` if the user requested igncr. +# But we know of no other shell where ac_cr would be empty at this +# point, so we can use a bashism as a fallback. +if test "x$ac_cr" = x; then + eval ac_cr=\$\'\\r\' +fi +ac_cs_awk_cr=`$AWK 'BEGIN { print "a\rb" }' /dev/null` +if test "$ac_cs_awk_cr" = "a${ac_cr}b"; then + ac_cs_awk_cr='\\r' +else + ac_cs_awk_cr=$ac_cr +fi + +echo 'BEGIN {' >"$ac_tmp/subs1.awk" && +_ACEOF + + +{ + echo "cat >conf$$subs.awk <<_ACEOF" && + echo "$ac_subst_vars" | sed 's/.*/&!$&$ac_delim/' && + echo "_ACEOF" +} >conf$$subs.sh || + as_fn_error $? "could not make $CONFIG_STATUS" "$LINENO" 5 +ac_delim_num=`echo "$ac_subst_vars" | grep -c '^'` +ac_delim='%!_!# ' +for ac_last_try in false false false false false :; do + . ./conf$$subs.sh || + as_fn_error $? "could not make $CONFIG_STATUS" "$LINENO" 5 + + ac_delim_n=`sed -n "s/.*$ac_delim\$/X/p" conf$$subs.awk | grep -c X` + if test $ac_delim_n = $ac_delim_num; then + break + elif $ac_last_try; then + as_fn_error $? "could not make $CONFIG_STATUS" "$LINENO" 5 + else + ac_delim="$ac_delim!$ac_delim _$ac_delim!! " + fi +done +rm -f conf$$subs.sh + +cat >>$CONFIG_STATUS <<_ACEOF || ac_write_fail=1 +cat >>"\$ac_tmp/subs1.awk" <<\\_ACAWK && +_ACEOF +sed -n ' +h +s/^/S["/; s/!.*/"]=/ +p +g +s/^[^!]*!// +:repl +t repl +s/'"$ac_delim"'$// +t delim +:nl +h +s/\(.\{148\}\)..*/\1/ +t more1 +s/["\\]/\\&/g; s/^/"/; s/$/\\n"\\/ +p +n +b repl +:more1 +s/["\\]/\\&/g; s/^/"/; s/$/"\\/ +p +g +s/.\{148\}// +t nl +:delim +h +s/\(.\{148\}\)..*/\1/ +t more2 +s/["\\]/\\&/g; s/^/"/; s/$/"/ +p +b +:more2 +s/["\\]/\\&/g; s/^/"/; s/$/"\\/ +p +g +s/.\{148\}// +t delim +' >$CONFIG_STATUS || ac_write_fail=1 +rm -f conf$$subs.awk +cat >>$CONFIG_STATUS <<_ACEOF || ac_write_fail=1 +_ACAWK +cat >>"\$ac_tmp/subs1.awk" <<_ACAWK && + for (key in S) S_is_set[key] = 1 + FS = "" + +} +{ + line = $ 0 + nfields = split(line, field, "@") + substed = 0 + len = length(field[1]) + for (i = 2; i < nfields; i++) { + key = field[i] + keylen = length(key) + if (S_is_set[key]) { + value = S[key] + line = substr(line, 1, len) "" value "" substr(line, len + keylen + 3) + len += length(value) + length(field[++i]) + substed = 1 + } else + len += 1 + keylen + } + + print line +} + +_ACAWK +_ACEOF +cat >>$CONFIG_STATUS <<\_ACEOF || ac_write_fail=1 +if sed "s/$ac_cr//" < /dev/null > /dev/null 2>&1; then + sed "s/$ac_cr\$//; s/$ac_cr/$ac_cs_awk_cr/g" +else + cat +fi < "$ac_tmp/subs1.awk" > "$ac_tmp/subs.awk" \ + || as_fn_error $? "could not setup config files machinery" "$LINENO" 5 +_ACEOF + +# VPATH may cause trouble with some makes, so we remove sole $(srcdir), +# ${srcdir} and @srcdir@ entries from VPATH if srcdir is ".", strip leading and +# trailing colons and then remove the whole line if VPATH becomes empty +# (actually we leave an empty line to preserve line numbers). +if test "x$srcdir" = x.; then + ac_vpsub='/^[ ]*VPATH[ ]*=[ ]*/{ +h +s/// +s/^/:/ +s/[ ]*$/:/ +s/:\$(srcdir):/:/g +s/:\${srcdir}:/:/g +s/:@srcdir@:/:/g +s/^:*// +s/:*$// +x +s/\(=[ ]*\).*/\1/ +G +s/\n// +s/^[^=]*=[ ]*$// +}' +fi + +cat >>$CONFIG_STATUS <<\_ACEOF || ac_write_fail=1 +fi # test -n "$CONFIG_FILES" + + +eval set X " :F $CONFIG_FILES " +shift +for ac_tag +do + case $ac_tag in + :[FHLC]) ac_mode=$ac_tag; continue;; + esac + case $ac_mode$ac_tag in + :[FHL]*:*);; + :L* | :C*:*) as_fn_error $? "invalid tag \`$ac_tag'" "$LINENO" 5;; + :[FH]-) ac_tag=-:-;; + :[FH]*) ac_tag=$ac_tag:$ac_tag.in;; + esac + ac_save_IFS=$IFS + IFS=: + set x $ac_tag + IFS=$ac_save_IFS + shift + ac_file=$1 + shift + + case $ac_mode in + :L) ac_source=$1;; + :[FH]) + ac_file_inputs= + for ac_f + do + case $ac_f in + -) ac_f="$ac_tmp/stdin";; + *) # Look for the file first in the build tree, then in the source tree + # (if the path is not absolute). The absolute path cannot be DOS-style, + # because $ac_f cannot contain `:'. + test -f "$ac_f" || + case $ac_f in + [\\/$]*) false;; + *) test -f "$srcdir/$ac_f" && ac_f="$srcdir/$ac_f";; + esac || + as_fn_error 1 "cannot find input file: \`$ac_f'" "$LINENO" 5;; + esac + case $ac_f in *\'*) ac_f=`$as_echo "$ac_f" | sed "s/'/'\\\\\\\\''/g"`;; esac + as_fn_append ac_file_inputs " '$ac_f'" + done + + # Let's still pretend it is `configure' which instantiates (i.e., don't + # use $as_me), people would be surprised to read: + # /* config.h. Generated by config.status. */ + configure_input='Generated from '` + $as_echo "$*" | sed 's|^[^:]*/||;s|:[^:]*/|, |g' + `' by configure.' + if test x"$ac_file" != x-; then + configure_input="$ac_file. $configure_input" + { $as_echo "$as_me:${as_lineno-$LINENO}: creating $ac_file" >&5 +$as_echo "$as_me: creating $ac_file" >&6;} + fi + # Neutralize special characters interpreted by sed in replacement strings. + case $configure_input in #( + *\&* | *\|* | *\\* ) + ac_sed_conf_input=`$as_echo "$configure_input" | + sed 's/[\\\\&|]/\\\\&/g'`;; #( + *) ac_sed_conf_input=$configure_input;; + esac + + case $ac_tag in + *:-:* | *:-) cat >"$ac_tmp/stdin" \ + || as_fn_error $? "could not create $ac_file" "$LINENO" 5 ;; + esac + ;; + esac + + ac_dir=`$as_dirname -- "$ac_file" || +$as_expr X"$ac_file" : 'X\(.*[^/]\)//*[^/][^/]*/*$' \| \ + X"$ac_file" : 'X\(//\)[^/]' \| \ + X"$ac_file" : 'X\(//\)$' \| \ + X"$ac_file" : 'X\(/\)' \| . 2>/dev/null || +$as_echo X"$ac_file" | + sed '/^X\(.*[^/]\)\/\/*[^/][^/]*\/*$/{ + s//\1/ + q + } + /^X\(\/\/\)[^/].*/{ + s//\1/ + q + } + /^X\(\/\/\)$/{ + s//\1/ + q + } + /^X\(\/\).*/{ + s//\1/ + q + } + s/.*/./; q'` + as_dir="$ac_dir"; as_fn_mkdir_p + ac_builddir=. + +case "$ac_dir" in +.) ac_dir_suffix= ac_top_builddir_sub=. ac_top_build_prefix= ;; +*) + ac_dir_suffix=/`$as_echo "$ac_dir" | sed 's|^\.[\\/]||'` + # A ".." for each directory in $ac_dir_suffix. + ac_top_builddir_sub=`$as_echo "$ac_dir_suffix" | sed 's|/[^\\/]*|/..|g;s|/||'` + case $ac_top_builddir_sub in + "") ac_top_builddir_sub=. ac_top_build_prefix= ;; + *) ac_top_build_prefix=$ac_top_builddir_sub/ ;; + esac ;; +esac +ac_abs_top_builddir=$ac_pwd +ac_abs_builddir=$ac_pwd$ac_dir_suffix +# for backward compatibility: +ac_top_builddir=$ac_top_build_prefix + +case $srcdir in + .) # We are building in place. + ac_srcdir=. + ac_top_srcdir=$ac_top_builddir_sub + ac_abs_top_srcdir=$ac_pwd ;; + [\\/]* | ?:[\\/]* ) # Absolute name. + ac_srcdir=$srcdir$ac_dir_suffix; + ac_top_srcdir=$srcdir + ac_abs_top_srcdir=$srcdir ;; + *) # Relative name. + ac_srcdir=$ac_top_build_prefix$srcdir$ac_dir_suffix + ac_top_srcdir=$ac_top_build_prefix$srcdir + ac_abs_top_srcdir=$ac_pwd/$srcdir ;; +esac +ac_abs_srcdir=$ac_abs_top_srcdir$ac_dir_suffix + + + case $ac_mode in + :F) + # + # CONFIG_FILE + # + + ac_MKDIR_P=$MKDIR_P + case $MKDIR_P in + [\\/$]* | ?:[\\/]* ) ;; + */*) ac_MKDIR_P=$ac_top_build_prefix$MKDIR_P ;; + esac +_ACEOF + +cat >>$CONFIG_STATUS <<\_ACEOF || ac_write_fail=1 +# If the template does not know about datarootdir, expand it. +# FIXME: This hack should be removed a few years after 2.60. +ac_datarootdir_hack=; ac_datarootdir_seen= +ac_sed_dataroot=' +/datarootdir/ { + p + q +} +/@datadir@/p +/@docdir@/p +/@infodir@/p +/@localedir@/p +/@mandir@/p' +case `eval "sed -n \"\$ac_sed_dataroot\" $ac_file_inputs"` in +*datarootdir*) ac_datarootdir_seen=yes;; +*@datadir@*|*@docdir@*|*@infodir@*|*@localedir@*|*@mandir@*) + { $as_echo "$as_me:${as_lineno-$LINENO}: WARNING: $ac_file_inputs seems to ignore the --datarootdir setting" >&5 +$as_echo "$as_me: WARNING: $ac_file_inputs seems to ignore the --datarootdir setting" >&2;} +_ACEOF +cat >>$CONFIG_STATUS <<_ACEOF || ac_write_fail=1 + ac_datarootdir_hack=' + s&@datadir@&$datadir&g + s&@docdir@&$docdir&g + s&@infodir@&$infodir&g + s&@localedir@&$localedir&g + s&@mandir@&$mandir&g + s&\\\${datarootdir}&$datarootdir&g' ;; +esac +_ACEOF + +# Neutralize VPATH when `$srcdir' = `.'. +# Shell code in configure.ac might set extrasub. +# FIXME: do we really want to maintain this feature? +cat >>$CONFIG_STATUS <<_ACEOF || ac_write_fail=1 +ac_sed_extra="$ac_vpsub +$extrasub +_ACEOF +cat >>$CONFIG_STATUS <<\_ACEOF || ac_write_fail=1 +:t +/@[a-zA-Z_][a-zA-Z_0-9]*@/!b +s|@configure_input@|$ac_sed_conf_input|;t t +s&@top_builddir@&$ac_top_builddir_sub&;t t +s&@top_build_prefix@&$ac_top_build_prefix&;t t +s&@srcdir@&$ac_srcdir&;t t +s&@abs_srcdir@&$ac_abs_srcdir&;t t +s&@top_srcdir@&$ac_top_srcdir&;t t +s&@abs_top_srcdir@&$ac_abs_top_srcdir&;t t +s&@builddir@&$ac_builddir&;t t +s&@abs_builddir@&$ac_abs_builddir&;t t +s&@abs_top_builddir@&$ac_abs_top_builddir&;t t +s&@MKDIR_P@&$ac_MKDIR_P&;t t +$ac_datarootdir_hack +" +eval sed \"\$ac_sed_extra\" "$ac_file_inputs" | $AWK -f "$ac_tmp/subs.awk" \ + >$ac_tmp/out || as_fn_error $? "could not create $ac_file" "$LINENO" 5 + +test -z "$ac_datarootdir_hack$ac_datarootdir_seen" && + { ac_out=`sed -n '/\${datarootdir}/p' "$ac_tmp/out"`; test -n "$ac_out"; } && + { ac_out=`sed -n '/^[ ]*datarootdir[ ]*:*=/p' \ + "$ac_tmp/out"`; test -z "$ac_out"; } && + { $as_echo "$as_me:${as_lineno-$LINENO}: WARNING: $ac_file contains a reference to the variable \`datarootdir' +which seems to be undefined. Please make sure it is defined" >&5 +$as_echo "$as_me: WARNING: $ac_file contains a reference to the variable \`datarootdir' +which seems to be undefined. Please make sure it is defined" >&2;} + + rm -f "$ac_tmp/stdin" + case $ac_file in + -) cat "$ac_tmp/out" && rm -f "$ac_tmp/out";; + *) rm -f "$ac_file" && mv "$ac_tmp/out" "$ac_file";; + esac \ + || as_fn_error $? "could not create $ac_file" "$LINENO" 5 + ;; + + + + esac + +done # for ac_tag + + +as_fn_exit 0 +_ACEOF +ac_clean_files=$ac_clean_files_save + +test $ac_write_fail = 0 || + as_fn_error $? "write failure creating $CONFIG_STATUS" "$LINENO" 5 + + +# configure is writing to config.log, and then calls config.status. +# config.status does its own redirection, appending to config.log. +# Unfortunately, on DOS this fails, as config.log is still kept open +# by configure, so config.status won't be able to write to it; its +# output is simply discarded. So we exec the FD to /dev/null, +# effectively closing config.log, so it can be properly (re)opened and +# appended to by config.status. When coming back to configure, we +# need to make the FD available again. +if test "$no_create" != yes; then + ac_cs_success=: + ac_config_status_args= + test "$silent" = yes && + ac_config_status_args="$ac_config_status_args --quiet" + exec 5>/dev/null + $SHELL $CONFIG_STATUS $ac_config_status_args || ac_cs_success=false + exec 5>>config.log + # Use ||, not &&, to avoid exiting from the if with $? = 1, which + # would make configure fail if this is the last instruction. + $ac_cs_success || as_fn_exit 1 +fi +if test -n "$ac_unrecognized_opts" && test "$enable_option_checking" != no; then + { $as_echo "$as_me:${as_lineno-$LINENO}: WARNING: unrecognized options: $ac_unrecognized_opts" >&5 +$as_echo "$as_me: WARNING: unrecognized options: $ac_unrecognized_opts" >&2;} +fi + diff --git a/core/src/index/thirdparty/faiss/configure.ac b/core/src/index/thirdparty/faiss/configure.ac new file mode 100644 index 0000000000..31b587b86d --- /dev/null +++ b/core/src/index/thirdparty/faiss/configure.ac @@ -0,0 +1,70 @@ +# -*- Autoconf -*- +# Process this file with autoconf to produce a configure script. + +AC_PREREQ([2.69]) +AC_INIT([faiss], [1.0]) +AC_COPYRIGHT([Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree.]) +AC_CONFIG_SRCDIR([Index.h]) +AC_CONFIG_AUX_DIR([build-aux]) +AC_CONFIG_MACRO_DIR([acinclude]) + +: ${CXXFLAGS="-g -O3 -Wall -Wextra"} + +# Checks for programs. +AC_LANG(C++) +AC_PROG_CXX +AX_CXX_COMPILE_STDCXX([11], [noext], [mandatory]) +AC_PROG_CPP +AC_PROG_MAKE_SET +AC_PROG_MKDIR_P + +FA_PYTHON + +if test x$PYTHON != x; then + FA_NUMPY +fi + +FA_PROG_SWIG + +FA_CHECK_CUDA + + +# Checks for header files. +AC_CHECK_HEADERS([float.h limits.h stddef.h stdint.h stdlib.h string.h sys/time.h unistd.h]) + +# Checks for typedefs, structures, and compiler characteristics. +AC_CHECK_HEADER_STDBOOL +AC_C_INLINE +AC_TYPE_INT32_T +AC_TYPE_INT64_T +AC_C_RESTRICT +AC_TYPE_SIZE_T +AC_TYPE_UINT16_T +AC_TYPE_UINT32_T +AC_TYPE_UINT64_T +AC_TYPE_UINT8_T + +# Checks for library functions. +AC_FUNC_MALLOC +AC_FUNC_MMAP +AC_CHECK_FUNCS([clock_gettime floor gettimeofday memmove memset munmap pow sqrt strerror strstr]) + +AC_OPENMP + +AX_BLAS +if test "x$ax_blas_ok" == "xno"; then + AC_MSG_ERROR([An implementation of BLAS is required but none was found.]) +fi + +AX_LAPACK +if test "x$ax_lapack_ok" == "xno"; then + AC_MSG_ERROR([An implementation of LAPACK is required but none was found.]) +fi + +AX_CPU_ARCH + +AC_CONFIG_FILES([makefile.inc]) +AC_OUTPUT diff --git a/core/src/index/thirdparty/faiss/demos/Makefile b/core/src/index/thirdparty/faiss/demos/Makefile new file mode 100644 index 0000000000..9d871697a9 --- /dev/null +++ b/core/src/index/thirdparty/faiss/demos/Makefile @@ -0,0 +1,21 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +-include ../makefile.inc + +DEMOS_SRC=$(wildcard demo_*.cpp) +DEMOS=$(DEMOS_SRC:.cpp=) + + +all: $(DEMOS) + +clean: + rm -f $(DEMOS) + +%: %.cpp + $(CXX) $(CPPFLAGS) $(CXXFLAGS) $(CPUFLAGS) -o $@ $^ $(LDFLAGS) $(LIBS) -lfaiss + + +.PHONY: all clean diff --git a/core/src/index/thirdparty/faiss/demos/README.md b/core/src/index/thirdparty/faiss/demos/README.md new file mode 100644 index 0000000000..71a23f272e --- /dev/null +++ b/core/src/index/thirdparty/faiss/demos/README.md @@ -0,0 +1,28 @@ + + +Demos for a few Faiss functionalities +===================================== + + +demo_auto_tune.py +----------------- + +Demonstrates the auto-tuning functionality of Faiss + + +demo_ondisk_ivf.py +------------------ + +Shows how to construct a Faiss index that stores the inverted file +data on disk, eg. when it does not fit in RAM. The script works on a +small dataset (sift1M) for demonstration and proceeds in stages: + +0: train on the dataset + +1-4: build 4 indexes, each containing 1/4 of the dataset. This can be +done in parallel on several machines + +5: merge the 4 indexes into one that is written directly to disk +(needs not to fit in RAM) + +6: load and test the index diff --git a/core/src/index/thirdparty/faiss/demos/demo_auto_tune.py b/core/src/index/thirdparty/faiss/demos/demo_auto_tune.py new file mode 100644 index 0000000000..eb7c709a1b --- /dev/null +++ b/core/src/index/thirdparty/faiss/demos/demo_auto_tune.py @@ -0,0 +1,170 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +#!/usr/bin/env python2 + +from __future__ import print_function +import os +import time +import numpy as np + +try: + import matplotlib + matplotlib.use('Agg') + from matplotlib import pyplot + graphical_output = True +except ImportError: + graphical_output = False + +import faiss + +################################################################# +# Small I/O functions +################################################################# + +def ivecs_read(fname): + a = np.fromfile(fname, dtype="int32") + d = a[0] + return a.reshape(-1, d + 1)[:, 1:].copy() + +def fvecs_read(fname): + return ivecs_read(fname).view('float32') + + +def plot_OperatingPoints(ops, nq, **kwargs): + ops = ops.optimal_pts + n = ops.size() * 2 - 1 + pyplot.plot([ops.at( i // 2).perf for i in range(n)], + [ops.at((i + 1) // 2).t / nq * 1000 for i in range(n)], + **kwargs) + + +################################################################# +# prepare common data for all indexes +################################################################# + + + +t0 = time.time() + +print("load data") + +xt = fvecs_read("sift1M/sift_learn.fvecs") +xb = fvecs_read("sift1M/sift_base.fvecs") +xq = fvecs_read("sift1M/sift_query.fvecs") + +d = xt.shape[1] + +print("load GT") + +gt = ivecs_read("sift1M/sift_groundtruth.ivecs") +gt = gt.astype('int64') +k = gt.shape[1] + +print("prepare criterion") + +# criterion = 1-recall at 1 +crit = faiss.OneRecallAtRCriterion(xq.shape[0], 1) +crit.set_groundtruth(None, gt) +crit.nnn = k + +# indexes that are useful when there is no limitation on memory usage +unlimited_mem_keys = [ + "IMI2x10,Flat", "IMI2x11,Flat", + "IVF4096,Flat", "IVF16384,Flat", + "PCA64,IMI2x10,Flat"] + +# memory limited to 16 bytes / vector +keys_mem_16 = [ + 'IMI2x10,PQ16', 'IVF4096,PQ16', + 'IMI2x10,PQ8+8', 'OPQ16_64,IMI2x10,PQ16' + ] + +# limited to 32 bytes / vector +keys_mem_32 = [ + 'IMI2x10,PQ32', 'IVF4096,PQ32', 'IVF16384,PQ32', + 'IMI2x10,PQ16+16', + 'OPQ32,IVF4096,PQ32', 'IVF4096,PQ16+16', 'OPQ16,IMI2x10,PQ16+16' + ] + +# indexes that can run on the GPU +keys_gpu = [ + "PCA64,IVF4096,Flat", + "PCA64,Flat", "Flat", "IVF4096,Flat", "IVF16384,Flat", + "IVF4096,PQ32"] + + +keys_to_test = unlimited_mem_keys +use_gpu = False + + +if use_gpu: + # if this fails, it means that the GPU version was not comp + assert faiss.StandardGpuResources, \ + "FAISS was not compiled with GPU support, or loading _swigfaiss_gpu.so failed" + res = faiss.StandardGpuResources() + dev_no = 0 + +# remember results from other index types +op_per_key = [] + + +# keep track of optimal operating points seen so far +op = faiss.OperatingPoints() + + +for index_key in keys_to_test: + + print("============ key", index_key) + + # make the index described by the key + index = faiss.index_factory(d, index_key) + + + if use_gpu: + # transfer to GPU (may be partial) + index = faiss.index_cpu_to_gpu(res, dev_no, index) + params = faiss.GpuParameterSpace() + else: + params = faiss.ParameterSpace() + + params.initialize(index) + + print("[%.3f s] train & add" % (time.time() - t0)) + + index.train(xt) + index.add(xb) + + print("[%.3f s] explore op points" % (time.time() - t0)) + + # find operating points for this index + opi = params.explore(index, xq, crit) + + print("[%.3f s] result operating points:" % (time.time() - t0)) + opi.display() + + # update best operating points so far + op.merge_with(opi, index_key + " ") + + op_per_key.append((index_key, opi)) + + if graphical_output: + # graphical output (to tmp/ subdirectory) + + fig = pyplot.figure(figsize=(12, 9)) + pyplot.xlabel("1-recall at 1") + pyplot.ylabel("search time (ms/query, %d threads)" % faiss.omp_get_max_threads()) + pyplot.gca().set_yscale('log') + pyplot.grid() + for i2, opi2 in op_per_key: + plot_OperatingPoints(opi2, crit.nq, label = i2, marker = 'o') + # plot_OperatingPoints(op, crit.nq, label = 'best', marker = 'o', color = 'r') + pyplot.legend(loc=2) + fig.savefig('tmp/demo_auto_tune.png') + + +print("[%.3f s] final result:" % (time.time() - t0)) + +op.display() diff --git a/core/src/index/thirdparty/faiss/demos/demo_imi_flat.cpp b/core/src/index/thirdparty/faiss/demos/demo_imi_flat.cpp new file mode 100644 index 0000000000..b037817321 --- /dev/null +++ b/core/src/index/thirdparty/faiss/demos/demo_imi_flat.cpp @@ -0,0 +1,151 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + + +#include +#include +#include + +#include + + +#include +#include +#include +#include + +double elapsed () +{ + struct timeval tv; + gettimeofday (&tv, nullptr); + return tv.tv_sec + tv.tv_usec * 1e-6; +} + + +int main () +{ + double t0 = elapsed(); + + // dimension of the vectors to index + int d = 128; + + // size of the database we plan to index + size_t nb = 1000 * 1000; + + // make a set of nt training vectors in the unit cube + // (could be the database) + size_t nt = 100 * 1000; + + //--------------------------------------------------------------- + // Define the core quantizer + // We choose a multiple inverted index for faster training with less data + // and because it usually offers best accuracy/speed trade-offs + // + // We here assume that its lifespan of this coarse quantizer will cover the + // lifespan of the inverted-file quantizer IndexIVFFlat below + // With dynamic allocation, one may give the responsability to free the + // quantizer to the inverted-file index (with attribute do_delete_quantizer) + // + // Note: a regular clustering algorithm would be defined as: + // faiss::IndexFlatL2 coarse_quantizer (d); + // + // Use nhash=2 subquantizers used to define the product coarse quantizer + // Number of bits: we will have 2^nbits_coarse centroids per subquantizer + // meaning (2^12)^nhash distinct inverted lists + size_t nhash = 2; + size_t nbits_subq = int (log2 (nb+1) / 2); // good choice in general + size_t ncentroids = 1 << (nhash * nbits_subq); // total # of centroids + + faiss::MultiIndexQuantizer coarse_quantizer (d, nhash, nbits_subq); + + printf ("IMI (%ld,%ld): %ld virtual centroids (target: %ld base vectors)", + nhash, nbits_subq, ncentroids, nb); + + // the coarse quantizer should not be dealloced before the index + // 4 = nb of bytes per code (d must be a multiple of this) + // 8 = nb of bits per sub-code (almost always 8) + faiss::MetricType metric = faiss::METRIC_L2; // can be METRIC_INNER_PRODUCT + faiss::IndexIVFFlat index (&coarse_quantizer, d, ncentroids, metric); + index.quantizer_trains_alone = true; + + // define the number of probes. 2048 is for high-dim, overkilled in practice + // Use 4-1024 depending on the trade-off speed accuracy that you want + index.nprobe = 2048; + + + { // training + printf ("[%.3f s] Generating %ld vectors in %dD for training\n", + elapsed() - t0, nt, d); + + std::vector trainvecs (nt * d); + for (size_t i = 0; i < nt * d; i++) { + trainvecs[i] = drand48(); + } + + printf ("[%.3f s] Training the index\n", elapsed() - t0); + index.verbose = true; + index.train (nt, trainvecs.data()); + } + + size_t nq; + std::vector queries; + + { // populating the database + printf ("[%.3f s] Building a dataset of %ld vectors to index\n", + elapsed() - t0, nb); + + std::vector database (nb * d); + for (size_t i = 0; i < nb * d; i++) { + database[i] = drand48(); + } + + printf ("[%.3f s] Adding the vectors to the index\n", elapsed() - t0); + + index.add (nb, database.data()); + + // remember a few elements from the database as queries + int i0 = 1234; + int i1 = 1244; + + nq = i1 - i0; + queries.resize (nq * d); + for (int i = i0; i < i1; i++) { + for (int j = 0; j < d; j++) { + queries [(i - i0) * d + j] = database [i * d + j]; + } + } + } + + { // searching the database + int k = 5; + printf ("[%.3f s] Searching the %d nearest neighbors " + "of %ld vectors in the index\n", + elapsed() - t0, k, nq); + + std::vector nns (k * nq); + std::vector dis (k * nq); + + index.search (nq, queries.data(), k, dis.data(), nns.data()); + + printf ("[%.3f s] Query results (vector ids, then distances):\n", + elapsed() - t0); + + for (int i = 0; i < nq; i++) { + printf ("query %2d: ", i); + for (int j = 0; j < k; j++) { + printf ("%7ld ", nns[j + i * k]); + } + printf ("\n dis: "); + for (int j = 0; j < k; j++) { + printf ("%7g ", dis[j + i * k]); + } + printf ("\n"); + } + } + return 0; +} diff --git a/core/src/index/thirdparty/faiss/demos/demo_imi_pq.cpp b/core/src/index/thirdparty/faiss/demos/demo_imi_pq.cpp new file mode 100644 index 0000000000..ea6f998c6e --- /dev/null +++ b/core/src/index/thirdparty/faiss/demos/demo_imi_pq.cpp @@ -0,0 +1,199 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + + +#include +#include +#include + +#include + + +#include +#include +#include +#include + +double elapsed () +{ + struct timeval tv; + gettimeofday (&tv, nullptr); + return tv.tv_sec + tv.tv_usec * 1e-6; +} + + +int main () +{ + double t0 = elapsed(); + + // dimension of the vectors to index + int d = 64; + + // size of the database we plan to index + size_t nb = 1000 * 1000; + size_t add_bs = 10000; // # size of the blocks to add + + // make a set of nt training vectors in the unit cube + // (could be the database) + size_t nt = 100 * 1000; + + //--------------------------------------------------------------- + // Define the core quantizer + // We choose a multiple inverted index for faster training with less data + // and because it usually offers best accuracy/speed trade-offs + // + // We here assume that its lifespan of this coarse quantizer will cover the + // lifespan of the inverted-file quantizer IndexIVFFlat below + // With dynamic allocation, one may give the responsability to free the + // quantizer to the inverted-file index (with attribute do_delete_quantizer) + // + // Note: a regular clustering algorithm would be defined as: + // faiss::IndexFlatL2 coarse_quantizer (d); + // + // Use nhash=2 subquantizers used to define the product coarse quantizer + // Number of bits: we will have 2^nbits_coarse centroids per subquantizer + // meaning (2^12)^nhash distinct inverted lists + // + // The parameter bytes_per_code is determined by the memory + // constraint, the dataset will use nb * (bytes_per_code + 8) + // bytes. + // + // The parameter nbits_subq is determined by the size of the dataset to index. + // + size_t nhash = 2; + size_t nbits_subq = 9; + size_t ncentroids = 1 << (nhash * nbits_subq); // total # of centroids + int bytes_per_code = 16; + + faiss::MultiIndexQuantizer coarse_quantizer (d, nhash, nbits_subq); + + printf ("IMI (%ld,%ld): %ld virtual centroids (target: %ld base vectors)", + nhash, nbits_subq, ncentroids, nb); + + // the coarse quantizer should not be dealloced before the index + // 4 = nb of bytes per code (d must be a multiple of this) + // 8 = nb of bits per sub-code (almost always 8) + faiss::MetricType metric = faiss::METRIC_L2; // can be METRIC_INNER_PRODUCT + faiss::IndexIVFPQ index (&coarse_quantizer, d, ncentroids, bytes_per_code, 8); + index.quantizer_trains_alone = true; + + // define the number of probes. 2048 is for high-dim, overkill in practice + // Use 4-1024 depending on the trade-off speed accuracy that you want + index.nprobe = 2048; + + + { // training. + + // The distribution of the training vectors should be the same + // as the database vectors. It could be a sub-sample of the + // database vectors, if sampling is not biased. Here we just + // randomly generate the vectors. + + printf ("[%.3f s] Generating %ld vectors in %dD for training\n", + elapsed() - t0, nt, d); + + std::vector trainvecs (nt * d); + for (size_t i = 0; i < nt; i++) { + for (size_t j = 0; j < d; j++) { + trainvecs[i * d + j] = drand48(); + } + } + + printf ("[%.3f s] Training the index\n", elapsed() - t0); + index.verbose = true; + index.train (nt, trainvecs.data()); + } + + // the index can be re-loaded later with + // faiss::Index * idx = faiss::read_index("/tmp/trained_index.faissindex"); + faiss::write_index(&index, "/tmp/trained_index.faissindex"); + + size_t nq; + std::vector queries; + + { // populating the database + printf ("[%.3f s] Building a dataset of %ld vectors to index\n", + elapsed() - t0, nb); + + std::vector database (nb * d); + std::vector ids (nb); + for (size_t i = 0; i < nb; i++) { + for (size_t j = 0; j < d; j++) { + database[i * d + j] = drand48(); + } + ids[i] = 8760000000L + i; + } + + printf ("[%.3f s] Adding the vectors to the index\n", elapsed() - t0); + + for (size_t begin = 0; begin < nb; begin += add_bs) { + size_t end = std::min (begin + add_bs, nb); + index.add_with_ids (end - begin, + database.data() + d * begin, + ids.data() + begin); + } + + // remember a few elements from the database as queries + int i0 = 1234; + int i1 = 1244; + + nq = i1 - i0; + queries.resize (nq * d); + for (int i = i0; i < i1; i++) { + for (int j = 0; j < d; j++) { + queries [(i - i0) * d + j] = database [i * d + j]; + } + } + } + + // A few notes on the internal format of the index: + // + // - the positing lists for PQ codes are index.codes, which is a + // std::vector < std::vector > + // if n is the length of posting list #i, codes[i] has length bytes_per_code * n + // + // - the corresponding ids are stored in index.ids + // + // - given a vector float *x, finding which k centroids are + // closest to it (ie to find the nearest neighbors) can be done with + // + // long *centroid_ids = new long[k]; + // float *distances = new float[k]; + // index.quantizer->search (1, x, k, dis, centroids_ids); + // + + faiss::write_index(&index, "/tmp/populated_index.faissindex"); + + { // searching the database + int k = 5; + printf ("[%.3f s] Searching the %d nearest neighbors " + "of %ld vectors in the index\n", + elapsed() - t0, k, nq); + + std::vector nns (k * nq); + std::vector dis (k * nq); + + index.search (nq, queries.data(), k, dis.data(), nns.data()); + + printf ("[%.3f s] Query results (vector ids, then distances):\n", + elapsed() - t0); + + for (int i = 0; i < nq; i++) { + printf ("query %2d: ", i); + for (int j = 0; j < k; j++) { + printf ("%7ld ", nns[j + i * k]); + } + printf ("\n dis: "); + for (int j = 0; j < k; j++) { + printf ("%7g ", dis[j + i * k]); + } + printf ("\n"); + } + } + return 0; +} diff --git a/core/src/index/thirdparty/faiss/demos/demo_ivfpq_indexing.cpp b/core/src/index/thirdparty/faiss/demos/demo_ivfpq_indexing.cpp new file mode 100644 index 0000000000..743395ec2f --- /dev/null +++ b/core/src/index/thirdparty/faiss/demos/demo_ivfpq_indexing.cpp @@ -0,0 +1,146 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + + +#include +#include +#include + +#include + + +#include +#include +#include + +double elapsed () +{ + struct timeval tv; + gettimeofday (&tv, NULL); + return tv.tv_sec + tv.tv_usec * 1e-6; +} + + +int main () +{ + + double t0 = elapsed(); + + // dimension of the vectors to index + int d = 128; + + // size of the database we plan to index + size_t nb = 200 * 1000; + + // make a set of nt training vectors in the unit cube + // (could be the database) + size_t nt = 100 * 1000; + + // make the index object and train it + faiss::IndexFlatL2 coarse_quantizer (d); + + // a reasonable number of centroids to index nb vectors + int ncentroids = int (4 * sqrt (nb)); + + // the coarse quantizer should not be dealloced before the index + // 4 = nb of bytes per code (d must be a multiple of this) + // 8 = nb of bits per sub-code (almost always 8) + faiss::IndexIVFPQ index (&coarse_quantizer, d, + ncentroids, 4, 8); + + + { // training + printf ("[%.3f s] Generating %ld vectors in %dD for training\n", + elapsed() - t0, nt, d); + + std::vector trainvecs (nt * d); + for (size_t i = 0; i < nt * d; i++) { + trainvecs[i] = drand48(); + } + + printf ("[%.3f s] Training the index\n", + elapsed() - t0); + index.verbose = true; + + index.train (nt, trainvecs.data()); + } + + { // I/O demo + const char *outfilename = "/tmp/index_trained.faissindex"; + printf ("[%.3f s] storing the pre-trained index to %s\n", + elapsed() - t0, outfilename); + + write_index (&index, outfilename); + } + + size_t nq; + std::vector queries; + + { // populating the database + printf ("[%.3f s] Building a dataset of %ld vectors to index\n", + elapsed() - t0, nb); + + std::vector database (nb * d); + for (size_t i = 0; i < nb * d; i++) { + database[i] = drand48(); + } + + printf ("[%.3f s] Adding the vectors to the index\n", + elapsed() - t0); + + index.add (nb, database.data()); + + printf ("[%.3f s] imbalance factor: %g\n", + elapsed() - t0, index.invlists->imbalance_factor ()); + + // remember a few elements from the database as queries + int i0 = 1234; + int i1 = 1243; + + nq = i1 - i0; + queries.resize (nq * d); + for (int i = i0; i < i1; i++) { + for (int j = 0; j < d; j++) { + queries [(i - i0) * d + j] = database [i * d + j]; + } + } + + } + + { // searching the database + int k = 5; + printf ("[%.3f s] Searching the %d nearest neighbors " + "of %ld vectors in the index\n", + elapsed() - t0, k, nq); + + std::vector nns (k * nq); + std::vector dis (k * nq); + + index.search (nq, queries.data(), k, dis.data(), nns.data()); + + printf ("[%.3f s] Query results (vector ids, then distances):\n", + elapsed() - t0); + + for (int i = 0; i < nq; i++) { + printf ("query %2d: ", i); + for (int j = 0; j < k; j++) { + printf ("%7ld ", nns[j + i * k]); + } + printf ("\n dis: "); + for (int j = 0; j < k; j++) { + printf ("%7g ", dis[j + i * k]); + } + printf ("\n"); + } + + printf ("note that the nearest neighbor is not at " + "distance 0 due to quantization errors\n"); + } + + return 0; +} diff --git a/core/src/index/thirdparty/faiss/demos/demo_ondisk_ivf.py b/core/src/index/thirdparty/faiss/demos/demo_ondisk_ivf.py new file mode 100644 index 0000000000..c89acc8402 --- /dev/null +++ b/core/src/index/thirdparty/faiss/demos/demo_ondisk_ivf.py @@ -0,0 +1,112 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +#!/usr/bin/env python2 + +import sys +import numpy as np +import faiss + + +################################################################# +# Small I/O functions +################################################################# + + +def ivecs_read(fname): + a = np.fromfile(fname, dtype='int32') + d = a[0] + return a.reshape(-1, d + 1)[:, 1:].copy() + + +def fvecs_read(fname): + return ivecs_read(fname).view('float32') + + +################################################################# +# Main program +################################################################# + +stage = int(sys.argv[1]) + +tmpdir = '/tmp/' + +if stage == 0: + # train the index + xt = fvecs_read("sift1M/sift_learn.fvecs") + index = faiss.index_factory(xt.shape[1], "IVF4096,Flat") + print("training index") + index.train(xt) + print("write " + tmpdir + "trained.index") + faiss.write_index(index, tmpdir + "trained.index") + + +if 1 <= stage <= 4: + # add 1/4 of the database to 4 independent indexes + bno = stage - 1 + xb = fvecs_read("sift1M/sift_base.fvecs") + i0, i1 = int(bno * xb.shape[0] / 4), int((bno + 1) * xb.shape[0] / 4) + index = faiss.read_index(tmpdir + "trained.index") + print("adding vectors %d:%d" % (i0, i1)) + index.add_with_ids(xb[i0:i1], np.arange(i0, i1)) + print("write " + tmpdir + "block_%d.index" % bno) + faiss.write_index(index, tmpdir + "block_%d.index" % bno) + + +if stage == 5: + # merge the images into an on-disk index + # first load the inverted lists + ivfs = [] + for bno in range(4): + # the IO_FLAG_MMAP is to avoid actually loading the data thus + # the total size of the inverted lists can exceed the + # available RAM + print("read " + tmpdir + "block_%d.index" % bno) + index = faiss.read_index(tmpdir + "block_%d.index" % bno, + faiss.IO_FLAG_MMAP) + ivfs.append(index.invlists) + + # avoid that the invlists get deallocated with the index + index.own_invlists = False + + # construct the output index + index = faiss.read_index(tmpdir + "trained.index") + + # prepare the output inverted lists. They will be written + # to merged_index.ivfdata + invlists = faiss.OnDiskInvertedLists( + index.nlist, index.code_size, + tmpdir + "merged_index.ivfdata") + + # merge all the inverted lists + ivf_vector = faiss.InvertedListsPtrVector() + for ivf in ivfs: + ivf_vector.push_back(ivf) + + print("merge %d inverted lists " % ivf_vector.size()) + ntotal = invlists.merge_from(ivf_vector.data(), ivf_vector.size()) + + # now replace the inverted lists in the output index + index.ntotal = ntotal + index.replace_invlists(invlists) + + print("write " + tmpdir + "populated.index") + faiss.write_index(index, tmpdir + "populated.index") + + +if stage == 6: + # perform a search from disk + print("read " + tmpdir + "populated.index") + index = faiss.read_index(tmpdir + "populated.index") + index.nprobe = 16 + + # load query vectors and ground-truth + xq = fvecs_read("sift1M/sift_query.fvecs") + gt = ivecs_read("sift1M/sift_groundtruth.ivecs") + + D, I = index.search(xq, 5) + + recall_at_1 = (I[:, :1] == gt[:, :1]).sum() / float(xq.shape[0]) + print("recall@1: %.3f" % recall_at_1) diff --git a/core/src/index/thirdparty/faiss/demos/demo_sift1M.cpp b/core/src/index/thirdparty/faiss/demos/demo_sift1M.cpp new file mode 100644 index 0000000000..dd91c59080 --- /dev/null +++ b/core/src/index/thirdparty/faiss/demos/demo_sift1M.cpp @@ -0,0 +1,252 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + + +#include +#include +#include +#include +#include + +#include +#include +#include + +#include + +#include +#include + +/** + * To run this demo, please download the ANN_SIFT1M dataset from + * + * http://corpus-texmex.irisa.fr/ + * + * and unzip it to the sudirectory sift1M. + **/ + +/***************************************************** + * I/O functions for fvecs and ivecs + *****************************************************/ + + +float * fvecs_read (const char *fname, + size_t *d_out, size_t *n_out) +{ + FILE *f = fopen(fname, "r"); + if(!f) { + fprintf(stderr, "could not open %s\n", fname); + perror(""); + abort(); + } + int d; + fread(&d, 1, sizeof(int), f); + assert((d > 0 && d < 1000000) || !"unreasonable dimension"); + fseek(f, 0, SEEK_SET); + struct stat st; + fstat(fileno(f), &st); + size_t sz = st.st_size; + assert(sz % ((d + 1) * 4) == 0 || !"weird file size"); + size_t n = sz / ((d + 1) * 4); + + *d_out = d; *n_out = n; + float *x = new float[n * (d + 1)]; + size_t nr = fread(x, sizeof(float), n * (d + 1), f); + assert(nr == n * (d + 1) || !"could not read whole file"); + + // shift array to remove row headers + for(size_t i = 0; i < n; i++) + memmove(x + i * d, x + 1 + i * (d + 1), d * sizeof(*x)); + + fclose(f); + return x; +} + +// not very clean, but works as long as sizeof(int) == sizeof(float) +int *ivecs_read(const char *fname, size_t *d_out, size_t *n_out) +{ + return (int*)fvecs_read(fname, d_out, n_out); +} + +double elapsed () +{ + struct timeval tv; + gettimeofday (&tv, nullptr); + return tv.tv_sec + tv.tv_usec * 1e-6; +} + + + +int main() +{ + double t0 = elapsed(); + + // this is typically the fastest one. + const char *index_key = "IVF4096,Flat"; + + // these ones have better memory usage + // const char *index_key = "Flat"; + // const char *index_key = "PQ32"; + // const char *index_key = "PCA80,Flat"; + // const char *index_key = "IVF4096,PQ8+16"; + // const char *index_key = "IVF4096,PQ32"; + // const char *index_key = "IMI2x8,PQ32"; + // const char *index_key = "IMI2x8,PQ8+16"; + // const char *index_key = "OPQ16_64,IMI2x8,PQ8+16"; + + faiss::Index * index; + + size_t d; + + { + printf ("[%.3f s] Loading train set\n", elapsed() - t0); + + size_t nt; + float *xt = fvecs_read("sift1M/sift_learn.fvecs", &d, &nt); + + printf ("[%.3f s] Preparing index \"%s\" d=%ld\n", + elapsed() - t0, index_key, d); + index = faiss::index_factory(d, index_key); + + printf ("[%.3f s] Training on %ld vectors\n", elapsed() - t0, nt); + + index->train(nt, xt); + delete [] xt; + } + + + { + printf ("[%.3f s] Loading database\n", elapsed() - t0); + + size_t nb, d2; + float *xb = fvecs_read("sift1M/sift_base.fvecs", &d2, &nb); + assert(d == d2 || !"dataset does not have same dimension as train set"); + + printf ("[%.3f s] Indexing database, size %ld*%ld\n", + elapsed() - t0, nb, d); + + index->add(nb, xb); + + delete [] xb; + } + + size_t nq; + float *xq; + + { + printf ("[%.3f s] Loading queries\n", elapsed() - t0); + + size_t d2; + xq = fvecs_read("sift1M/sift_query.fvecs", &d2, &nq); + assert(d == d2 || !"query does not have same dimension as train set"); + + } + + size_t k; // nb of results per query in the GT + faiss::Index::idx_t *gt; // nq * k matrix of ground-truth nearest-neighbors + + { + printf ("[%.3f s] Loading ground truth for %ld queries\n", + elapsed() - t0, nq); + + // load ground-truth and convert int to long + size_t nq2; + int *gt_int = ivecs_read("sift1M/sift_groundtruth.ivecs", &k, &nq2); + assert(nq2 == nq || !"incorrect nb of ground truth entries"); + + gt = new faiss::Index::idx_t[k * nq]; + for(int i = 0; i < k * nq; i++) { + gt[i] = gt_int[i]; + } + delete [] gt_int; + } + + // Result of the auto-tuning + std::string selected_params; + + { // run auto-tuning + + printf ("[%.3f s] Preparing auto-tune criterion 1-recall at 1 " + "criterion, with k=%ld nq=%ld\n", elapsed() - t0, k, nq); + + faiss::OneRecallAtRCriterion crit(nq, 1); + crit.set_groundtruth (k, nullptr, gt); + crit.nnn = k; // by default, the criterion will request only 1 NN + + printf ("[%.3f s] Preparing auto-tune parameters\n", elapsed() - t0); + + faiss::ParameterSpace params; + params.initialize(index); + + printf ("[%.3f s] Auto-tuning over %ld parameters (%ld combinations)\n", + elapsed() - t0, params.parameter_ranges.size(), + params.n_combinations()); + + faiss::OperatingPoints ops; + params.explore (index, nq, xq, crit, &ops); + + printf ("[%.3f s] Found the following operating points: \n", + elapsed() - t0); + + ops.display (); + + // keep the first parameter that obtains > 0.5 1-recall@1 + for (int i = 0; i < ops.optimal_pts.size(); i++) { + if (ops.optimal_pts[i].perf > 0.5) { + selected_params = ops.optimal_pts[i].key; + break; + } + } + assert (selected_params.size() >= 0 || + !"could not find good enough op point"); + } + + + { // Use the found configuration to perform a search + + faiss::ParameterSpace params; + + printf ("[%.3f s] Setting parameter configuration \"%s\" on index\n", + elapsed() - t0, selected_params.c_str()); + + params.set_index_parameters (index, selected_params.c_str()); + + printf ("[%.3f s] Perform a search on %ld queries\n", + elapsed() - t0, nq); + + // output buffers + faiss::Index::idx_t *I = new faiss::Index::idx_t[nq * k]; + float *D = new float[nq * k]; + + index->search(nq, xq, k, D, I); + + printf ("[%.3f s] Compute recalls\n", elapsed() - t0); + + // evaluate result by hand. + int n_1 = 0, n_10 = 0, n_100 = 0; + for(int i = 0; i < nq; i++) { + int gt_nn = gt[i * k]; + for(int j = 0; j < k; j++) { + if (I[i * k + j] == gt_nn) { + if(j < 1) n_1++; + if(j < 10) n_10++; + if(j < 100) n_100++; + } + } + } + printf("R@1 = %.4f\n", n_1 / float(nq)); + printf("R@10 = %.4f\n", n_10 / float(nq)); + printf("R@100 = %.4f\n", n_100 / float(nq)); + + } + + delete [] xq; + delete [] gt; + delete index; + return 0; +} diff --git a/core/src/index/thirdparty/faiss/demos/demo_weighted_kmeans.cpp b/core/src/index/thirdparty/faiss/demos/demo_weighted_kmeans.cpp new file mode 100644 index 0000000000..eee188e4b3 --- /dev/null +++ b/core/src/index/thirdparty/faiss/demos/demo_weighted_kmeans.cpp @@ -0,0 +1,185 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +#include +#include +#include +#include +#include + + +namespace { + + +enum WeightedKMeansType { + WKMT_FlatL2, + WKMT_FlatIP, + WKMT_FlatIP_spherical, + WKMT_HNSW, +}; + + +float weighted_kmeans_clustering (size_t d, size_t n, size_t k, + const float *input, + const float *weights, + float *centroids, + WeightedKMeansType index_num) +{ + using namespace faiss; + Clustering clus (d, k); + clus.verbose = true; + + std::unique_ptr index; + + switch (index_num) { + case WKMT_FlatL2: + index.reset(new IndexFlatL2 (d)); + break; + case WKMT_FlatIP: + index.reset(new IndexFlatIP (d)); + break; + case WKMT_FlatIP_spherical: + index.reset(new IndexFlatIP (d)); + clus.spherical = true; + break; + case WKMT_HNSW: + IndexHNSWFlat *ihnsw = new IndexHNSWFlat (d, 32); + ihnsw->hnsw.efSearch = 128; + index.reset(ihnsw); + break; + } + + clus.train(n, input, *index.get(), weights); + // on output the index contains the centroids. + memcpy(centroids, clus.centroids.data(), sizeof(*centroids) * d * k); + return clus.iteration_stats.back().obj; +} + + +int d = 32; +float sigma = 0.1; + +#define BIGTEST + +#ifdef BIGTEST +// the production setup = setting of https://fb.quip.com/CWgnAAYbwtgs +int nc = 200000; +int n_big = 4; +int n_small = 2; +#else +int nc = 5; +int n_big = 100; +int n_small = 10; +#endif + +int n; // number of training points + +void generate_trainset (std::vector & ccent, + std::vector & x, + std::vector & weights) +{ + // same sampling as test_build_blocks.py test_weighted + + ccent.resize (d * 2 * nc); + faiss::float_randn (ccent.data(), d * 2 * nc, 123); + faiss::fvec_renorm_L2 (d, 2 * nc, ccent.data()); + n = nc * n_big + nc * n_small; + x.resize(d * n); + weights.resize(n); + faiss::float_randn (x.data(), x.size(), 1234); + + float *xi = x.data(); + float *w = weights.data(); + for (int ci = 0; ci < nc * 2; ci++) { // loop over centroids + int np = ci < nc ? n_big : n_small; // nb of points around this centroid + for (int i = 0; i < np; i++) { + for (int j = 0; j < d; j++) { + xi[j] = xi[j] * sigma + ccent[ci * d + j]; + } + *w++ = ci < nc ? 0.1 : 10; + xi += d; + } + } +} + +} + + +int main(int argc, char **argv) { + std::vector ccent; + std::vector x; + std::vector weights; + + printf("generate training set\n"); + generate_trainset(ccent, x, weights); + + std::vector centroids; + centroids.resize(nc * d); + + int the_index_num = -1; + int the_with_weights = -1; + + if (argc == 3) { + the_index_num = atoi(argv[1]); + the_with_weights = atoi(argv[2]); + } + + + for (int index_num = WKMT_FlatL2; + index_num <= WKMT_HNSW; + index_num++) { + + if (the_index_num >= 0 && index_num != the_index_num) { + continue; + } + + for (int with_weights = 0; with_weights <= 1; with_weights++) { + if (the_with_weights >= 0 && with_weights != the_with_weights) { + continue; + } + + printf("=================== index_num=%d Run %s weights\n", + index_num, with_weights ? "with" : "without"); + + weighted_kmeans_clustering ( + d, n, nc, x.data(), + with_weights ? weights.data() : nullptr, + centroids.data(), (WeightedKMeansType)index_num + ); + + { // compute distance of points to centroids + faiss::IndexFlatL2 cent_index(d); + cent_index.add(nc, centroids.data()); + std::vector dis (n); + std::vector idx (n); + + cent_index.search (nc * 2, ccent.data(), 1, + dis.data(), idx.data()); + + float dis1 = 0, dis2 = 0; + for (int i = 0; i < nc ; i++) { + dis1 += dis[i]; + } + printf("average distance of points from big clusters: %g\n", + dis1 / nc); + + for (int i = 0; i < nc ; i++) { + dis2 += dis[i + nc]; + } + + printf("average distance of points from small clusters: %g\n", + dis2 / nc); + + } + + } + } + return 0; +} diff --git a/core/src/index/thirdparty/faiss/example_makefiles/makefile.inc.Linux b/core/src/index/thirdparty/faiss/example_makefiles/makefile.inc.Linux new file mode 100644 index 0000000000..12da227039 --- /dev/null +++ b/core/src/index/thirdparty/faiss/example_makefiles/makefile.inc.Linux @@ -0,0 +1,140 @@ +# -*- makefile -*- +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# tested on CentOS 7, Ubuntu 16 and Ubuntu 14, see below to adjust flags to distribution. + + +CXX = g++ -std=c++11 +CXXFLAGS = -fPIC -m64 -Wall -g -O3 -fopenmp -Wno-sign-compare +CPUFLAGS = -mavx -msse4 -mpopcnt +LDFLAGS = -fPIC -fopenmp + +# common linux flags +SHAREDEXT = so +SHAREDFLAGS = -shared +MKDIR_P = mkdir -p + +prefix ?= /usr/local +exec_prefix ?= ${prefix} +libdir = ${exec_prefix}/lib +includedir = ${prefix}/include + +########################################################################## +# Uncomment one of the 4 BLAS/Lapack implementation options +# below. They are sorted # from fastest to slowest (in our +# experiments). +########################################################################## + +# +# 1. Intel MKL +# +# This is the fastest BLAS implementation we tested. Unfortunately it +# is not open-source and determining the correct linking flags is a +# nightmare. See +# +# https://software.intel.com/en-us/articles/intel-mkl-link-line-advisor +# +# The latest tested version is MKL 2017.0.098 (2017 Initial Release) and can +# be downloaded here: +# +# https://registrationcenter.intel.com/en/forms/?productid=2558&licensetype=2 +# +# The following settings are working if MKL is installed on its default folder: +# +# MKLROOT = /opt/intel/compilers_and_libraries/linux/mkl/ +# +# LDFLAGS += -Wl,--no-as-needed -L$(MKLROOT)/lib/intel64 +# LIBS += -lmkl_intel_ilp64 -lmkl_core -lmkl_gnu_thread -ldl -lpthread +# +# CPPFLAGS += -DFINTEGER=long +# +# You may have to set the LD_LIBRARY_PATH=$MKLROOT/lib/intel64 at runtime. +# +# If at runtime you get the error: +# Intel MKL FATAL ERROR: Cannot load libmkl_avx2.so or libmkl_def.so +# you may set +# LD_PRELOAD=$MKLROOT/lib/intel64/libmkl_core.so:$MKLROOT/lib/intel64/libmkl_sequential.so +# at runtime as well. + +# +# 2. Openblas +# +# The library contains both BLAS and Lapack. About 30% slower than MKL. Please see +# https://github.com/facebookresearch/faiss/wiki/Troubleshooting#slow-brute-force-search-with-openblas +# to fix performance problemes with OpenBLAS + +# for Ubuntu 16: +# sudo apt-get install libopenblas-dev python-numpy python-dev + +# for Ubuntu 14: +# sudo apt-get install libopenblas-dev liblapack3 python-numpy python-dev + +CPPFLAGS += -DFINTEGER=int +LIBS += -lopenblas -llapack + +# 3. Atlas +# +# Automatically tuned linear algebra package. As the name indicates, +# it is tuned automatically for a give architecture, and in Linux +# distributions, it the architecture is typically indicated by the +# directory name, eg. atlas-sse3 = optimized for SSE3 architecture. +# +# BLASCFLAGS=-DFINTEGER=int +# BLASLDFLAGS=/usr/lib64/atlas-sse3/libptf77blas.so.3 /usr/lib64/atlas-sse3/liblapack.so +# +# 4. reference implementation +# +# This is just a compiled version of the reference BLAS +# implementation, that is not optimized at all. +# +# CPPFLAGS += -DFINTEGER=int +# LIBS += /usr/lib64/libblas.so.3 /usr/lib64/liblapack.so.3.2 +# + + +########################################################################## +# SWIG and Python flags +########################################################################## + +# SWIG executable. This should be at least version 3.x +SWIG = swig + +# The Python include directories for a given python executable can +# typically be found with +# +# python -c "import distutils.sysconfig; print distutils.sysconfig.get_python_inc()" +# python -c "import numpy ; print numpy.get_include()" +# +# or, for Python 3, with +# +# python3 -c "import distutils.sysconfig; print(distutils.sysconfig.get_python_inc())" +# python3 -c "import numpy ; print(numpy.get_include())" +# + +PYTHONCFLAGS = -I/usr/include/python2.7/ -I/usr/lib64/python2.7/site-packages/numpy/core/include/ +PYTHONLIB = -lpython +PYTHON = /usr/bin/python + +########################################################################### +# Cuda GPU flags +########################################################################### + + + +# root of the cuda 8 installation +CUDAROOT = /usr/local/cuda-8.0 +NVCC = $(CUDAROOT)/bin/nvcc +NVCCLDFLAGS = -L$(CUDAROOT)/lib64 +NVCCLIBS = -lcudart -lcublas -lcuda +CUDACFLAGS = -I$(CUDAROOT)/include +NVCCFLAGS = -I $(CUDAROOT)/targets/x86_64-linux/include/ \ +-Xcompiler -fPIC \ +-Xcudafe --diag_suppress=unrecognized_attribute \ +-gencode arch=compute_35,code="compute_35" \ +-gencode arch=compute_52,code="compute_52" \ +-gencode arch=compute_60,code="compute_60" \ +-lineinfo \ +-ccbin $(CXX) -DFAISS_USE_FLOAT16 diff --git a/core/src/index/thirdparty/faiss/example_makefiles/makefile.inc.Mac.brew b/core/src/index/thirdparty/faiss/example_makefiles/makefile.inc.Mac.brew new file mode 100644 index 0000000000..8fa6fe7616 --- /dev/null +++ b/core/src/index/thirdparty/faiss/example_makefiles/makefile.inc.Mac.brew @@ -0,0 +1,99 @@ +# -*- makefile -*- +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +# Tested on macOS Sierra (10.12.2) with llvm installed using Homebrew (https://brew.sh) +# brew install llvm +CXX = /usr/local/opt/llvm/bin/clang++ -std=c++11 +CXXFLAGS = -fPIC -m64 -Wall -g -O3 -fopenmp -Wno-sign-compare -I/usr/local/opt/llvm/include +CPUFLAGS = -msse4 -mpopcnt +LLVM_VERSION_PATH=$(shell ls -rt /usr/local/Cellar/llvm/ | tail -n1) +LDFLAGS = -fPIC -fopenmp -L/usr/local/opt/llvm/lib -L/usr/local/Cellar/llvm/${LLVM_VERSION_PATH}/lib + +# common mac flags +SHAREDEXT = dylib +SHAREDFLAGS = -dynamiclib +MKDIR_P = mkdir -p + +prefix ?= /usr/local +exec_prefix ?= ${prefix} +libdir = ${exec_prefix}/lib +includedir = ${prefix}/include + +########################################################################## +# Uncomment one of the 4 BLAS/Lapack implementation options +# below. They are sorted # from fastest to slowest (in our +# experiments). +########################################################################## + +# +# 1. Intel MKL +# +# This is the fastest BLAS implementation we tested. Unfortunately it +# is not open-source and determining the correct linking flags is a +# nightmare. See +# +# https://software.intel.com/en-us/articles/intel-mkl-link-line-advisor +# +# The latest tested version is MKL 2017.0.098 (2017 Initial Release) and can +# be downloaded here: +# +# https://registrationcenter.intel.com/en/forms/?productid=2558&licensetype=2 +# +# The following settings are working if MKL is installed on its default folder: +# +# MKLROOT = /opt/intel/compilers_and_libraries/linux/mkl/ +# +# LDFLAGS += -Wl,--no-as-needed -L$(MKLROOT)/lib/intel64 +# LIBS += -lmkl_intel_ilp64 -lmkl_core -lmkl_gnu_thread -ldl -lpthread +# +# CPPFLAGS += -DFINTEGER=long +# +# You may have to set the LD_LIBRARY_PATH=$MKLROOT/lib/intel64 at runtime. + +# +# 2. Openblas +# +# The library contains both BLAS and Lapack. Install with brew install OpenBLAS +# +# CPPFLAGS += -DFINTEGER=int +# LIBS += /usr/local/opt/openblas/lib/libblas.dylib +# + +# +# 3. Apple's framework accelerate +# +# This has the advantage that it does not require to install anything, +# as it is provided by default on the mac. It is not very fast, though. +# + +CPPFLAGS += -DFINTEGER=int +LIBS += -framework Accelerate + + + +########################################################################## +# SWIG and Python flags +########################################################################## + +# SWIG executable. This should be at least version 3.x +# brew install swig + +SWIG = /usr/local/bin/swig + +# The Python include directories for the current python executable + +PYTHON_INC=$(shell python -c "import distutils.sysconfig; print(distutils.sysconfig.get_python_inc())") +NUMPY_INC=$(shell python -c "import numpy ; print(numpy.get_include())") +PYTHONCFLAGS=-I${PYTHON_INC} -I${NUMPY_INC} +PYTHONLIB=-lpython + +########################################################################## +# Faiss GPU +########################################################################## + +# As we don't have access to a Mac with nvidia GPUs installed, we +# could not validate the GPU compile of Faiss. diff --git a/core/src/index/thirdparty/faiss/example_makefiles/makefile.inc.Mac.port b/core/src/index/thirdparty/faiss/example_makefiles/makefile.inc.Mac.port new file mode 100644 index 0000000000..6b2c292220 --- /dev/null +++ b/core/src/index/thirdparty/faiss/example_makefiles/makefile.inc.Mac.port @@ -0,0 +1,106 @@ +# -*- makefile -*- +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +# +# tested on Mac OS X 10.12.2 Sierra with additional software installed via macports + + +# The system default clang does not support openmp +# You can install an openmp compatible g++ with macports: +# port install g++-mp-6 +CXX = /opt/local/bin/g++-mp-6 -std=c++11 +CXXFLAGS = -fPIC -m64 -Wall -g -O3 -fopenmp -Wno-sign-compare +CPUFLAGS = -msse4 -mpopcnt +LDFLAGS = -g -fPIC -fopenmp + +# common linux flags +SHAREDEXT = dylib +SHAREDFLAGS = -dynamiclib +MKDIR_P = mkdir -p + +prefix ?= /usr/local +exec_prefix ?= ${prefix} +libdir = ${exec_prefix}/lib +includedir = ${prefix}/include + +########################################################################## +# Uncomment one of the 4 BLAS/Lapack implementation options +# below. They are sorted # from fastest to slowest (in our +# experiments). +########################################################################## + +# +# 1. Intel MKL +# +# This is the fastest BLAS implementation we tested. Unfortunately it +# is not open-source and determining the correct linking flags is a +# nightmare. See +# +# https://software.intel.com/en-us/articles/intel-mkl-link-line-advisor +# +# The latest tested version is MKL 2017.0.098 (2017 Initial Release) and can +# be downloaded here: +# +# https://registrationcenter.intel.com/en/forms/?productid=2558&licensetype=2 +# +# The following settings are working if MKL is installed on its default folder: +# +# MKLROOT = /opt/intel/compilers_and_libraries/linux/mkl/ +# +# LDFLAGS += -Wl,--no-as-needed -L$(MKLROOT)/lib/intel64 +# LIBS += -lmkl_intel_ilp64 -lmkl_core -lmkl_gnu_thread -ldl -lpthread +# +# CPPFLAGS += -DFINTEGER=long +# +# You may have to set the LD_LIBRARY_PATH=$MKLROOT/lib/intel64 at runtime. + +# +# 2. Openblas +# +# The library contains both BLAS and Lapack. Install with port install OpenBLAS +# +# CPPFLAGS += -DFINTEGER=int +# LIBS += /opt/local/lib/libopenblas.dylib +# + +# +# 3. Apple's framework accelerate +# +# This has the advantage that it does not require to install anything, +# as it is provided by default on the mac. It is not very fast, though. +# + +CPPFLAGS += -DFINTEGER=int +LIBS += -framework Accelerate + + + +########################################################################## +# SWIG and Python flags +########################################################################## + +# SWIG executable. This should be at least version 3.x +# port install swig swig-python + +SWIG = /opt/local/bin/swig + +# The Python include directories for the current python executable can +# typically be found with +# +# python -c "import distutils.sysconfig; print distutils.sysconfig.get_python_inc()" +# python -c "import numpy ; print numpy.get_include()" +# +# the paths below are for the system python (not the macports one) + +PYTHONCFLAGS=-I/System/Library/Frameworks/Python.framework/Versions/2.7/include/python2.7 \ +-I/System/Library/Frameworks/Python.framework/Versions/2.7/Extras/lib/python/numpy/core/include +PYTHONLIB=-lpython + +########################################################################## +# Faiss GPU +########################################################################## + +# As we don't have access to a Mac with nvidia GPUs installed, we +# could not validate the GPU compile of Faiss. diff --git a/core/src/index/thirdparty/faiss/faiss b/core/src/index/thirdparty/faiss/faiss new file mode 120000 index 0000000000..945c9b46d6 --- /dev/null +++ b/core/src/index/thirdparty/faiss/faiss @@ -0,0 +1 @@ +. \ No newline at end of file diff --git a/core/src/index/thirdparty/faiss/gpu/GpuAutoTune.cpp b/core/src/index/thirdparty/faiss/gpu/GpuAutoTune.cpp new file mode 100644 index 0000000000..c734fdabb5 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/GpuAutoTune.cpp @@ -0,0 +1,95 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace faiss { namespace gpu { + + +using namespace ::faiss; + +/********************************************************** + * Parameters to auto-tune on GpuIndex'es + **********************************************************/ + +#define DC(classname) auto ix = dynamic_cast(index) + + +void GpuParameterSpace::initialize (const Index * index) +{ + if (DC (IndexPreTransform)) { + index = ix->index; + } + if (DC (IndexReplicas)) { + if (ix->count() == 0) return; + index = ix->at(0); + } + if (DC (IndexShards)) { + if (ix->count() == 0) return; + index = ix->at(0); + } + if (DC (GpuIndexIVF)) { + ParameterRange & pr = add_range("nprobe"); + for (int i = 0; i < 12; i++) { + size_t nprobe = 1 << i; + if (nprobe >= ix->getNumLists() || + nprobe > getMaxKSelection()) break; + pr.values.push_back (nprobe); + } + } + // not sure we should call the parent initializer +} + + + +#undef DC +// non-const version +#define DC(classname) auto *ix = dynamic_cast(index) + + + +void GpuParameterSpace::set_index_parameter ( + Index * index, const std::string & name, double val) const +{ + if (DC (IndexReplicas)) { + for (int i = 0; i < ix->count(); i++) + set_index_parameter (ix->at(i), name, val); + return; + } + if (name == "nprobe") { + if (DC (GpuIndexIVF)) { + ix->setNumProbes (int (val)); + return; + } + } + if (name == "use_precomputed_table") { + if (DC (GpuIndexIVFPQ)) { + ix->setPrecomputedCodes(bool (val)); + return; + } + } + + // maybe normal index parameters apply? + ParameterSpace::set_index_parameter (index, name, val); +} + + + + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/GpuAutoTune.h b/core/src/index/thirdparty/faiss/gpu/GpuAutoTune.h new file mode 100644 index 0000000000..1bcc9205d8 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/GpuAutoTune.h @@ -0,0 +1,27 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +namespace faiss { namespace gpu { + + +/// parameter space and setters for GPU indexes +struct GpuParameterSpace: faiss::ParameterSpace { + /// initialize with reasonable parameters for the index + void initialize (const faiss::Index * index) override; + + /// set a combination of parameters on an index + void set_index_parameter ( + faiss::Index * index, const std::string & name, + double val) const override; +}; + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/GpuCloner.cpp b/core/src/index/thirdparty/faiss/gpu/GpuCloner.cpp new file mode 100644 index 0000000000..c97bcd7542 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/GpuCloner.cpp @@ -0,0 +1,564 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace faiss { namespace gpu { + + +/********************************************************** + * Cloning to CPU + **********************************************************/ + +void ToCPUCloner::merge_index(Index *dst, Index *src, bool successive_ids) +{ + if (auto ifl = dynamic_cast(dst)) { + auto ifl2 = dynamic_cast(src); + FAISS_ASSERT(ifl2); + FAISS_ASSERT(successive_ids); + ifl->add(ifl2->ntotal, ifl2->xb.data()); + } else if(auto ifl = dynamic_cast(dst)) { + auto ifl2 = dynamic_cast(src); + FAISS_ASSERT(ifl2); + ifl->merge_from(*ifl2, successive_ids ? ifl->ntotal : 0); + } else if(auto ifl = dynamic_cast(dst)) { + auto ifl2 = dynamic_cast(src); + FAISS_ASSERT(ifl2); + ifl->merge_from(*ifl2, successive_ids ? ifl->ntotal : 0); + } else if(auto ifl = dynamic_cast(dst)) { + auto ifl2 = dynamic_cast(src); + FAISS_ASSERT(ifl2); + ifl->merge_from(*ifl2, successive_ids ? ifl->ntotal : 0); + } else { + FAISS_ASSERT(!"merging not implemented for this type of class"); + } +} + + +Index *ToCPUCloner::clone_Index(const Index *index) +{ + if(auto ifl = dynamic_cast(index)) { + IndexFlat *res = new IndexFlat(); + ifl->copyTo(res); + return res; + } else if(auto ifl = dynamic_cast(index)) { + IndexIVFFlat *res = new IndexIVFFlat(); + ifl->copyTo(res); + return res; + } else if(auto ifl = + dynamic_cast(index)) { + IndexIVFScalarQuantizer *res = new IndexIVFScalarQuantizer(); + ifl->copyTo(res); + return res; + } else if(auto ifl = + dynamic_cast(index)) { + IndexIVFSQHybrid *res = new IndexIVFSQHybrid(); + ifl->copyTo(res); + return res; + } else if(auto ipq = dynamic_cast(index)) { + IndexIVFPQ *res = new IndexIVFPQ(); + ipq->copyTo(res); + return res; + + // for IndexShards and IndexReplicas we assume that the + // objective is to make a single component out of them + // (inverse op of ToGpuClonerMultiple) + + } else if(auto ish = dynamic_cast(index)) { + int nshard = ish->count(); + FAISS_ASSERT(nshard > 0); + Index *res = clone_Index(ish->at(0)); + for(int i = 1; i < ish->count(); i++) { + Index *res_i = clone_Index(ish->at(i)); + merge_index(res, res_i, ish->successive_ids); + delete res_i; + } + return res; + } else if(auto ipr = dynamic_cast(index)) { + // just clone one of the replicas + FAISS_ASSERT(ipr->count() > 0); + return clone_Index(ipr->at(0)); + } else { + return Cloner::clone_Index(index); + } +} + +Index *ToCPUCloner::clone_Index_Without_Codes(const Index *index) +{ + if(auto ifl = dynamic_cast(index)) { + IndexIVFFlat *res = new IndexIVFFlat(); + ifl->copyToWithoutCodes(res); + return res; + } else if(auto ifl = + dynamic_cast(index)) { + IndexIVFScalarQuantizer *res = new IndexIVFScalarQuantizer(); + ifl->copyToWithoutCodes(res); + return res; + } else { + return Cloner::clone_Index(index); + } +} + +faiss::Index * index_gpu_to_cpu(const faiss::Index *gpu_index) +{ + ToCPUCloner cl; + return cl.clone_Index(gpu_index); +} + +faiss::Index * index_gpu_to_cpu_without_codes(const faiss::Index *gpu_index) +{ + ToCPUCloner cl; + return cl.clone_Index_Without_Codes(gpu_index); +} + + + +/********************************************************** + * Cloning to 1 GPU + **********************************************************/ + +ToGpuCloner::ToGpuCloner(GpuResources *resources, int device, + const GpuClonerOptions &options): + GpuClonerOptions(options), resources(resources), device(device) +{} + +Index *ToGpuCloner::clone_Index (IndexComposition* index_composition) { + Index* index = index_composition->index; + + if(auto ifl = dynamic_cast(index)) { + gpu::GpuIndexFlat *&quantizer = index_composition->quantizer; + long mode = index_composition->mode; + + GpuIndexIVFSQHybridConfig config; + config.device = device; + config.indicesOptions = indicesOptions; + config.flatConfig.useFloat16 = useFloat16CoarseQuantizer; + config.flatConfig.storeTransposed = storeTransposed; + + GpuIndexIVFSQHybrid *res = + new GpuIndexIVFSQHybrid(resources, + ifl->d, + ifl->nlist, + ifl->sq.qtype, + ifl->metric_type, + ifl->by_residual, + config); + if(reserveVecs > 0 && ifl->ntotal == 0) { + res->reserveMemory(reserveVecs); + } + + res->copyFrom(ifl, quantizer, mode); + return res; + } else { + return clone_Index(index); + } +} + +Index *ToGpuCloner::clone_Index(const Index *index) +{ + auto ivf_sqh = dynamic_cast(index); + if(ivf_sqh) { + auto ifl = ivf_sqh; + GpuIndexIVFSQHybridConfig config; + config.device = device; + config.indicesOptions = indicesOptions; + config.flatConfig.useFloat16 = useFloat16CoarseQuantizer; + config.flatConfig.storeTransposed = storeTransposed; + + GpuIndexIVFSQHybrid *res = + new GpuIndexIVFSQHybrid(resources, + ifl->d, + ifl->nlist, + ifl->sq.qtype, + ifl->metric_type, + ifl->by_residual, + config); + if(reserveVecs > 0 && ifl->ntotal == 0) { + res->reserveMemory(reserveVecs); + } + + res->copyFrom(ifl); + return res; + } else if(auto ifl = dynamic_cast(index)) { + GpuIndexFlatConfig config; + config.device = device; + config.useFloat16 = useFloat16; + config.storeTransposed = storeTransposed; + config.storeInCpu = storeInCpu; + + return new GpuIndexFlat(resources, ifl, config); + } else if(auto ifl = dynamic_cast(index)) { + GpuIndexIVFFlatConfig config; + config.device = device; + config.indicesOptions = indicesOptions; + config.flatConfig.useFloat16 = useFloat16CoarseQuantizer; + config.flatConfig.storeTransposed = storeTransposed; + + GpuIndexIVFFlat *res = + new GpuIndexIVFFlat(resources, + ifl->d, + ifl->nlist, + ifl->metric_type, + config); + if(reserveVecs > 0 && ifl->ntotal == 0) { + res->reserveMemory(reserveVecs); + } + + res->copyFrom(ifl); + return res; + } else if(auto ifl = + dynamic_cast(index)) { + GpuIndexIVFScalarQuantizerConfig config; + config.device = device; + config.indicesOptions = indicesOptions; + config.flatConfig.useFloat16 = useFloat16CoarseQuantizer; + config.flatConfig.storeTransposed = storeTransposed; + + GpuIndexIVFScalarQuantizer *res = + new GpuIndexIVFScalarQuantizer(resources, + ifl->d, + ifl->nlist, + ifl->sq.qtype, + ifl->metric_type, + ifl->by_residual, + config); + if(reserveVecs > 0 && ifl->ntotal == 0) { + res->reserveMemory(reserveVecs); + } + + res->copyFrom(ifl); + return res; + } else if(auto ipq = dynamic_cast(index)) { + if(verbose) + printf(" IndexIVFPQ size %ld -> GpuIndexIVFPQ " + "indicesOptions=%d " + "usePrecomputed=%d useFloat16=%d reserveVecs=%ld\n", + ipq->ntotal, indicesOptions, usePrecomputed, + useFloat16, reserveVecs); + GpuIndexIVFPQConfig config; + config.device = device; + config.indicesOptions = indicesOptions; + config.flatConfig.useFloat16 = useFloat16CoarseQuantizer; + config.flatConfig.storeTransposed = storeTransposed; + config.useFloat16LookupTables = useFloat16; + config.usePrecomputedTables = usePrecomputed; + + GpuIndexIVFPQ *res = new GpuIndexIVFPQ(resources, ipq, config); + + if(reserveVecs > 0 && ipq->ntotal == 0) { + res->reserveMemory(reserveVecs); + } + + return res; + } else { + return Cloner::clone_Index(index); + + } + +} + + +Index *ToGpuCloner::clone_Index_Without_Codes(const Index *index, const uint8_t *arranged_data) +{ + auto ivf_sqh = dynamic_cast(index); + if(ivf_sqh) { + // should not happen + } else if(auto ifl = dynamic_cast(index)) { + GpuIndexIVFFlatConfig config; + config.device = device; + config.indicesOptions = indicesOptions; + config.flatConfig.useFloat16 = useFloat16CoarseQuantizer; + config.flatConfig.storeTransposed = storeTransposed; + + GpuIndexIVFFlat *res = + new GpuIndexIVFFlat(resources, + ifl->d, + ifl->nlist, + ifl->metric_type, + config); + if(reserveVecs > 0 && ifl->ntotal == 0) { + res->reserveMemory(reserveVecs); + } + + res->copyFromWithoutCodes(ifl, arranged_data); + return res; + } else if(auto ifl = + dynamic_cast(index)) { + GpuIndexIVFScalarQuantizerConfig config; + config.device = device; + config.indicesOptions = indicesOptions; + config.flatConfig.useFloat16 = useFloat16CoarseQuantizer; + config.flatConfig.storeTransposed = storeTransposed; + + GpuIndexIVFScalarQuantizer *res = + new GpuIndexIVFScalarQuantizer(resources, + ifl->d, + ifl->nlist, + ifl->sq.qtype, + ifl->metric_type, + ifl->by_residual, + config); + if(reserveVecs > 0 && ifl->ntotal == 0) { + res->reserveMemory(reserveVecs); + } + + res->copyFromWithoutCodes(ifl, arranged_data); + return res; + } + + return Cloner::clone_Index(index); +} + + +faiss::Index * index_cpu_to_gpu( + GpuResources* resources, int device, + const faiss::Index *index, + const GpuClonerOptions *options) +{ + GpuClonerOptions defaults; + ToGpuCloner cl(resources, device, options ? *options : defaults); + return cl.clone_Index(index); +} + +faiss::Index * index_cpu_to_gpu_without_codes( + GpuResources* resources, int device, + const faiss::Index *index, + const uint8_t *arranged_data, + const GpuClonerOptions *options) +{ + GpuClonerOptions defaults; + ToGpuCloner cl(resources, device, options ? *options : defaults); + return cl.clone_Index_Without_Codes(index, arranged_data); +} + +faiss::Index * index_cpu_to_gpu( + GpuResources* resources, int device, + IndexComposition* index_composition, + const GpuClonerOptions *options) { + GpuClonerOptions defaults; + ToGpuCloner cl(resources, device, options ? *options : defaults); + return cl.clone_Index(index_composition); +} + +/********************************************************** + * Cloning to multiple GPUs + **********************************************************/ + +ToGpuClonerMultiple::ToGpuClonerMultiple( + std::vector & resources, + std::vector& devices, + const GpuMultipleClonerOptions &options): + GpuMultipleClonerOptions(options) +{ + FAISS_ASSERT(resources.size() == devices.size()); + for(int i = 0; i < resources.size(); i++) { + sub_cloners.push_back(ToGpuCloner(resources[i], devices[i], options)); + } +} + + +ToGpuClonerMultiple::ToGpuClonerMultiple( + const std::vector & sub_cloners, + const GpuMultipleClonerOptions &options): + GpuMultipleClonerOptions(options), + sub_cloners(sub_cloners) +{} + + +void ToGpuClonerMultiple::copy_ivf_shard ( + const IndexIVF *index_ivf, IndexIVF *idx2, + long n, long i) +{ + if (shard_type == 2) { + long i0 = i * index_ivf->ntotal / n; + long i1 = (i + 1) * index_ivf->ntotal / n; + + if(verbose) + printf("IndexShards shard %ld indices %ld:%ld\n", + i, i0, i1); + index_ivf->copy_subset_to(*idx2, 2, i0, i1); + FAISS_ASSERT(idx2->ntotal == i1 - i0); + } else if (shard_type == 1) { + if(verbose) + printf("IndexShards shard %ld select modulo %ld = %ld\n", + i, n, i); + index_ivf->copy_subset_to(*idx2, 1, n, i); + } else { + FAISS_THROW_FMT ("shard_type %d not implemented", shard_type); + } + +} + +Index * ToGpuClonerMultiple::clone_Index_to_shards (const Index *index) +{ + long n = sub_cloners.size(); + + auto index_ivfpq = + dynamic_cast(index); + auto index_ivfflat = + dynamic_cast(index); + auto index_ivfsq = + dynamic_cast(index); + auto index_flat = + dynamic_cast(index); + FAISS_THROW_IF_NOT_MSG ( + index_ivfpq || index_ivfflat || index_flat || index_ivfsq, + "IndexShards implemented only for " + "IndexIVFFlat, IndexIVFScalarQuantizer, " + "IndexFlat and IndexIVFPQ"); + + std::vector shards(n); + + for(long i = 0; i < n; i++) { + // make a shallow copy + if(reserveVecs) + sub_cloners[i].reserveVecs = + (reserveVecs + n - 1) / n; + + if (index_ivfpq) { + faiss::IndexIVFPQ idx2( + index_ivfpq->quantizer, index_ivfpq->d, + index_ivfpq->nlist, index_ivfpq->code_size, + index_ivfpq->pq.nbits); + idx2.metric_type = index_ivfpq->metric_type; + idx2.pq = index_ivfpq->pq; + idx2.nprobe = index_ivfpq->nprobe; + idx2.use_precomputed_table = 0; + idx2.is_trained = index->is_trained; + copy_ivf_shard (index_ivfpq, &idx2, n, i); + shards[i] = sub_cloners[i].clone_Index(&idx2); + } else if (index_ivfflat) { + faiss::IndexIVFFlat idx2( + index_ivfflat->quantizer, index->d, + index_ivfflat->nlist, index_ivfflat->metric_type); + idx2.nprobe = index_ivfflat->nprobe; + idx2.is_trained = index->is_trained; + copy_ivf_shard (index_ivfflat, &idx2, n, i); + shards[i] = sub_cloners[i].clone_Index(&idx2); + } else if (index_ivfsq) { + faiss::IndexIVFScalarQuantizer idx2( + index_ivfsq->quantizer, index->d, index_ivfsq->nlist, + index_ivfsq->sq.qtype, + index_ivfsq->metric_type, + index_ivfsq->by_residual); + + idx2.nprobe = index_ivfsq->nprobe; + idx2.is_trained = index->is_trained; + idx2.sq = index_ivfsq->sq; + copy_ivf_shard (index_ivfsq, &idx2, n, i); + shards[i] = sub_cloners[i].clone_Index(&idx2); + } else if (index_flat) { + faiss::IndexFlat idx2 ( + index->d, index->metric_type); + shards[i] = sub_cloners[i].clone_Index(&idx2); + if (index->ntotal > 0) { + long i0 = index->ntotal * i / n; + long i1 = index->ntotal * (i + 1) / n; + shards[i]->add (i1 - i0, + index_flat->xb.data() + i0 * index->d); + } + } + } + + bool successive_ids = index_flat != nullptr; + faiss::IndexShards *res = + new faiss::IndexShards(index->d, true, + successive_ids); + + for (int i = 0; i < n; i++) { + res->add_shard(shards[i]); + } + res->own_fields = true; + FAISS_ASSERT(index->ntotal == res->ntotal); + return res; +} + +Index *ToGpuClonerMultiple::clone_Index(const Index *index) +{ + long n = sub_cloners.size(); + if (n == 1) + return sub_cloners[0].clone_Index(index); + + if(dynamic_cast(index) || + dynamic_cast(index) || + dynamic_cast(index) || + dynamic_cast(index)) { + if(!shard) { + IndexReplicas * res = new IndexReplicas(); + for(auto & sub_cloner: sub_cloners) { + res->addIndex(sub_cloner.clone_Index(index)); + } + res->own_fields = true; + return res; + } else { + return clone_Index_to_shards (index); + } + } else if(auto miq = dynamic_cast(index)) { + if (verbose) { + printf("cloning MultiIndexQuantizer: " + "will be valid only for search k=1\n"); + } + const ProductQuantizer & pq = miq->pq; + IndexSplitVectors *splitv = new IndexSplitVectors(pq.d, true); + splitv->own_fields = true; + + for (int m = 0; m < pq.M; m++) { + // which GPU(s) will be assigned to this sub-quantizer + + long i0 = m * n / pq.M; + long i1 = pq.M <= n ? (m + 1) * n / pq.M : i0 + 1; + std::vector sub_cloners_2; + sub_cloners_2.insert( + sub_cloners_2.begin(), sub_cloners.begin() + i0, + sub_cloners.begin() + i1); + ToGpuClonerMultiple cm(sub_cloners_2, *this); + IndexFlatL2 idxc (pq.dsub); + idxc.add (pq.ksub, pq.centroids.data() + m * pq.d * pq.ksub); + Index *idx2 = cm.clone_Index(&idxc); + splitv->add_sub_index(idx2); + } + return splitv; + } else { + return Cloner::clone_Index(index); + } +} + + + +faiss::Index * index_cpu_to_gpu_multiple( + std::vector & resources, + std::vector &devices, + const faiss::Index *index, + const GpuMultipleClonerOptions *options) +{ + GpuMultipleClonerOptions defaults; + ToGpuClonerMultiple cl(resources, devices, options ? *options : defaults); + return cl.clone_Index(index); +} + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/GpuCloner.h b/core/src/index/thirdparty/faiss/gpu/GpuCloner.h new file mode 100644 index 0000000000..c01029279e --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/GpuCloner.h @@ -0,0 +1,101 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +#include +#include +#include +#include +#include + +namespace faiss { namespace gpu { + +class GpuResources; + + +/// Cloner specialized for GPU -> CPU +struct ToCPUCloner: faiss::Cloner { + void merge_index(Index *dst, Index *src, bool successive_ids); + + Index *clone_Index(const Index *index) override; + + Index *clone_Index_Without_Codes(const Index *index); +}; + + +/// Cloner specialized for CPU -> 1 GPU +struct ToGpuCloner: faiss::Cloner, GpuClonerOptions { + GpuResources *resources; + int device; + + ToGpuCloner(GpuResources *resources, int device, + const GpuClonerOptions &options); + + Index *clone_Index(const Index *index) override; + + Index *clone_Index (IndexComposition* index_composition) override; + + Index *clone_Index_Without_Codes(const Index *index, const uint8_t *arranged_data); +}; + +/// Cloner specialized for CPU -> multiple GPUs +struct ToGpuClonerMultiple: faiss::Cloner, GpuMultipleClonerOptions { + std::vector sub_cloners; + + ToGpuClonerMultiple(std::vector & resources, + std::vector& devices, + const GpuMultipleClonerOptions &options); + + ToGpuClonerMultiple(const std::vector & sub_cloners, + const GpuMultipleClonerOptions &options); + + void copy_ivf_shard (const IndexIVF *index_ivf, IndexIVF *idx2, + long n, long i); + + Index * clone_Index_to_shards (const Index *index); + + /// main function + Index *clone_Index(const Index *index) override; +}; + + + + +/// converts any GPU index inside gpu_index to a CPU index +faiss::Index * index_gpu_to_cpu(const faiss::Index *gpu_index); + +faiss::Index * index_gpu_to_cpu_without_codes(const faiss::Index *gpu_index); + +/// converts any CPU index that can be converted to GPU +faiss::Index * index_cpu_to_gpu( + GpuResources* resources, int device, + const faiss::Index *index, + const GpuClonerOptions *options = nullptr); + +faiss::Index * index_cpu_to_gpu_without_codes( + GpuResources* resources, int device, + const faiss::Index *index, + const uint8_t *arranged_data, + const GpuClonerOptions *options = nullptr); + +faiss::Index * index_cpu_to_gpu( + GpuResources* resources, int device, + IndexComposition* index_composition, + const GpuClonerOptions *options = nullptr); + +faiss::Index * index_cpu_to_gpu_multiple( + std::vector & resources, + std::vector &devices, + const faiss::Index *index, + const GpuMultipleClonerOptions *options = nullptr); + + + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/GpuClonerOptions.cpp b/core/src/index/thirdparty/faiss/gpu/GpuClonerOptions.cpp new file mode 100644 index 0000000000..4e0b40bd84 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/GpuClonerOptions.cpp @@ -0,0 +1,30 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +namespace faiss { namespace gpu { + +GpuClonerOptions::GpuClonerOptions() + : indicesOptions(INDICES_64_BIT), + useFloat16CoarseQuantizer(false), + useFloat16(false), + usePrecomputed(false), + reserveVecs(0), + storeTransposed(false), + storeInCpu(false), + allInGpu(false), + verbose(false) { +} + +GpuMultipleClonerOptions::GpuMultipleClonerOptions() + : shard(false), + shard_type(1) +{ +} + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/GpuClonerOptions.h b/core/src/index/thirdparty/faiss/gpu/GpuClonerOptions.h new file mode 100644 index 0000000000..b56a33d8d7 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/GpuClonerOptions.h @@ -0,0 +1,58 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +namespace faiss { namespace gpu { + +/// set some options on how to copy to GPU +struct GpuClonerOptions { + GpuClonerOptions(); + + /// how should indices be stored on index types that support indices + /// (anything but GpuIndexFlat*)? + IndicesOptions indicesOptions; + + /// is the coarse quantizer in float16? + bool useFloat16CoarseQuantizer; + + /// for GpuIndexIVFFlat, is storage in float16? + /// for GpuIndexIVFPQ, are intermediate calculations in float16? + bool useFloat16; + + /// use precomputed tables? + bool usePrecomputed; + + /// reserve vectors in the invfiles? + long reserveVecs; + + /// For GpuIndexFlat, store data in transposed layout? + bool storeTransposed; + + bool storeInCpu; + + /// For IndexIVFScalarQuantizer + bool allInGpu; + + /// Set verbose options on the index + bool verbose; +}; + +struct GpuMultipleClonerOptions : public GpuClonerOptions { + GpuMultipleClonerOptions (); + + /// Whether to shard the index across GPUs, versus replication + /// across GPUs + bool shard; + + /// IndexIVF::copy_subset_to subset type + int shard_type; +}; + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/GpuDistance.cu b/core/src/index/thirdparty/faiss/gpu/GpuDistance.cu new file mode 100644 index 0000000000..f5ce8aa24e --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/GpuDistance.cu @@ -0,0 +1,157 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace faiss { namespace gpu { + +template +void bfKnnConvert(GpuResources* resources, const GpuDistanceParams& args) { + auto device = getCurrentDevice(); + auto stream = resources->getDefaultStreamCurrentDevice(); + auto& mem = resources->getMemoryManagerCurrentDevice(); + + auto tVectors = + toDevice(resources, + device, + const_cast(reinterpret_cast(args.vectors)), + stream, + {args.vectorsRowMajor ? args.numVectors : args.dims, + args.vectorsRowMajor ? args.dims : args.numVectors}); + auto tQueries = + toDevice(resources, + device, + const_cast(reinterpret_cast(args.queries)), + stream, + {args.queriesRowMajor ? args.numQueries : args.dims, + args.queriesRowMajor ? args.dims : args.numQueries}); + + DeviceTensor tVectorNorms; + if (args.vectorNorms) { + tVectorNorms = toDevice(resources, + device, + const_cast(args.vectorNorms), + stream, + {args.numVectors}); + } + + auto tOutDistances = + toDevice(resources, + device, + args.outDistances, + stream, + {args.numQueries, args.k}); + + // The brute-force API only supports an interface for integer indices + DeviceTensor + tOutIntIndices(mem, {args.numQueries, args.k}, stream); + + // Empty bitset + auto bitsetDevice = toDevice(resources, device, nullptr, stream, {0}); + + // Since we've guaranteed that all arguments are on device, call the + // implementation + bfKnnOnDevice(resources, + device, + stream, + tVectors, + args.vectorsRowMajor, + args.vectorNorms ? &tVectorNorms : nullptr, + tQueries, + args.queriesRowMajor, + bitsetDevice, + args.k, + args.metric, + args.metricArg, + tOutDistances, + tOutIntIndices, + args.ignoreOutDistances); + + // Convert and copy int indices out + auto tOutIndices = + toDevice(resources, + device, + args.outIndices, + stream, + {args.numQueries, args.k}); + + // Convert int to idx_t + convertTensor(stream, + tOutIntIndices, + tOutIndices); + + // Copy back if necessary + fromDevice(tOutDistances, args.outDistances, stream); + fromDevice(tOutIndices, args.outIndices, stream); +} + +void +bfKnn(GpuResources* resources, const GpuDistanceParams& args) { + // For now, both vectors and queries must be of the same data type + FAISS_THROW_IF_NOT_MSG( + args.vectorType == args.queryType, + "limitation: both vectorType and queryType must currently " + "be the same (F32 or F16"); + + if (args.vectorType == DistanceDataType::F32) { + bfKnnConvert(resources, args); + } else if (args.vectorType == DistanceDataType::F16) { + bfKnnConvert(resources, args); + } else { + FAISS_THROW_MSG("unknown vectorType"); + } +} + +// legacy version +void +bruteForceKnn(GpuResources* resources, + faiss::MetricType metric, + // A region of memory size numVectors x dims, with dims + // innermost + const float* vectors, + bool vectorsRowMajor, + int numVectors, + // A region of memory size numQueries x dims, with dims + // innermost + const float* queries, + bool queriesRowMajor, + int numQueries, + int dims, + int k, + // A region of memory size numQueries x k, with k + // innermost + float* outDistances, + // A region of memory size numQueries x k, with k + // innermost + faiss::Index::idx_t* outIndices) { + std::cerr << "bruteForceKnn is deprecated; call bfKnn instead" << std::endl; + + GpuDistanceParams args; + args.metric = metric; + args.k = k; + args.dims = dims; + args.vectors = vectors; + args.vectorsRowMajor = vectorsRowMajor; + args.numVectors = numVectors; + args.queries = queries; + args.queriesRowMajor = queriesRowMajor; + args.numQueries = numQueries; + args.outDistances = outDistances; + args.outIndices = outIndices; + + bfKnn(resources, args); +} + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/GpuDistance.h b/core/src/index/thirdparty/faiss/gpu/GpuDistance.h new file mode 100644 index 0000000000..05667e70f7 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/GpuDistance.h @@ -0,0 +1,145 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#pragma once + +#include + +namespace faiss { namespace gpu { + +class GpuResources; + +// Scalar type of the vector data +enum class DistanceDataType { + F32 = 1, + F16, +}; + +/// Arguments to brute-force GPU k-nearest neighbor searching +struct GpuDistanceParams { + GpuDistanceParams() + : metric(faiss::MetricType::METRIC_L2), + metricArg(0), + k(0), + dims(0), + vectors(nullptr), + vectorType(DistanceDataType::F32), + vectorsRowMajor(true), + numVectors(0), + vectorNorms(nullptr), + queries(nullptr), + queryType(DistanceDataType::F32), + queriesRowMajor(true), + numQueries(0), + outDistances(nullptr), + ignoreOutDistances(false), + outIndices(nullptr) { + } + + // + // Search parameters + // + + // Search parameter: distance metric + faiss::MetricType metric; + + // Search parameter: distance metric argument (if applicable) + // For metric == METRIC_Lp, this is the p-value + float metricArg; + + // Search parameter: return k nearest neighbors + int k; + + // Vector dimensionality + int dims; + + // + // Vectors being queried + // + + // If vectorsRowMajor is true, this is + // numVectors x dims, with dims innermost; otherwise, + // dims x numVectors, with numVectors innermost + const void* vectors; + DistanceDataType vectorType; + bool vectorsRowMajor; + int numVectors; + + // Precomputed L2 norms for each vector in `vectors`, which can be optionally + // provided in advance to speed computation for METRIC_L2 + const float* vectorNorms; + + // + // The query vectors (i.e., find k-nearest neighbors in `vectors` for each of + // the `queries` + // + + // If queriesRowMajor is true, this is + // numQueries x dims, with dims innermost; otherwise, + // dims x numQueries, with numQueries innermost + const void* queries; + DistanceDataType queryType; + bool queriesRowMajor; + int numQueries; + + // + // Output results + // + + // A region of memory size numQueries x k, with k + // innermost (row major) + float* outDistances; + + // Do we only care abouty the indices reported, rather than the output + // distances? + bool ignoreOutDistances; + + // A region of memory size numQueries x k, with k + // innermost (row major) + faiss::Index::idx_t* outIndices; +}; + +/// A wrapper for gpu/impl/Distance.cuh to expose direct brute-force k-nearest +/// neighbor searches on an externally-provided region of memory (e.g., from a +/// pytorch tensor). +/// The data (vectors, queries, outDistances, outIndices) can be resident on the +/// GPU or the CPU, but all calculations are performed on the GPU. If the result +/// buffers are on the CPU, results will be copied back when done. +/// +/// All GPU computation is performed on the current CUDA device, and ordered +/// with respect to resources->getDefaultStreamCurrentDevice(). +/// +/// For each vector in `queries`, searches all of `vectors` to find its k +/// nearest neighbors with respect to the given metric +void bfKnn(GpuResources* resources, const GpuDistanceParams& args); + +/// Deprecated legacy implementation +void bruteForceKnn(GpuResources* resources, + faiss::MetricType metric, + // If vectorsRowMajor is true, this is + // numVectors x dims, with dims innermost; otherwise, + // dims x numVectors, with numVectors innermost + const float* vectors, + bool vectorsRowMajor, + int numVectors, + // If queriesRowMajor is true, this is + // numQueries x dims, with dims innermost; otherwise, + // dims x numQueries, with numQueries innermost + const float* queries, + bool queriesRowMajor, + int numQueries, + int dims, + int k, + // A region of memory size numQueries x k, with k + // innermost (row major) + float* outDistances, + // A region of memory size numQueries x k, with k + // innermost (row major) + faiss::Index::idx_t* outIndices); + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/GpuFaissAssert.h b/core/src/index/thirdparty/faiss/gpu/GpuFaissAssert.h new file mode 100644 index 0000000000..1931b916cc --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/GpuFaissAssert.h @@ -0,0 +1,29 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#ifndef GPU_FAISS_ASSERT_INCLUDED +#define GPU_FAISS_ASSERT_INCLUDED + +#include +#include + +/// +/// Assertions +/// + +#ifdef __CUDA_ARCH__ +#define GPU_FAISS_ASSERT(X) assert(X) +#define GPU_FAISS_ASSERT_MSG(X, MSG) assert(X) +#define GPU_FAISS_ASSERT_FMT(X, FMT, ...) assert(X) +#else +#define GPU_FAISS_ASSERT(X) FAISS_ASSERT(X) +#define GPU_FAISS_ASSERT_MSG(X, MSG) FAISS_ASSERT_MSG(X, MSG) +#define GPU_FAISS_ASSERT_FMT(X, FMT, ...) FAISS_ASSERT_FMT(X, FMT, __VA_ARGS) +#endif // __CUDA_ARCH__ + +#endif diff --git a/core/src/index/thirdparty/faiss/gpu/GpuIndex.cu b/core/src/index/thirdparty/faiss/gpu/GpuIndex.cu new file mode 100644 index 0000000000..173b3206f2 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/GpuIndex.cu @@ -0,0 +1,485 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace faiss { namespace gpu { + +/// Default CPU search size for which we use paged copies +constexpr size_t kMinPageSize = (size_t) 256 * 1024 * 1024; + +/// Size above which we page copies from the CPU to GPU (non-paged +/// memory usage) +constexpr size_t kNonPinnedPageSize = (size_t) 256 * 1024 * 1024; + +// Default size for which we page add or search +constexpr size_t kAddPageSize = (size_t) 256 * 1024 * 1024; + +// Or, maximum number of vectors to consider per page of add or search +constexpr size_t kAddVecSize = (size_t) 512 * 1024; + +// Use a smaller search size, as precomputed code usage on IVFPQ +// requires substantial amounts of memory +// FIXME: parameterize based on algorithm need +constexpr size_t kSearchVecSize = (size_t) 32 * 1024; + +GpuIndex::GpuIndex(GpuResources* resources, + int dims, + faiss::MetricType metric, + float metricArg, + GpuIndexConfig config) : + Index(dims, metric), + resources_(resources), + device_(config.device), + memorySpace_(config.memorySpace), + minPagedSize_(kMinPageSize) { + FAISS_THROW_IF_NOT_FMT(device_ < getNumDevices(), + "Invalid GPU device %d", device_); + + FAISS_THROW_IF_NOT_MSG(dims > 0, "Invalid number of dimensions"); + +#ifdef FAISS_UNIFIED_MEM + FAISS_THROW_IF_NOT_FMT( + memorySpace_ == MemorySpace::Device || + (memorySpace_ == MemorySpace::Unified && + getFullUnifiedMemSupport(device_)), + "Device %d does not support full CUDA 8 Unified Memory (CC 6.0+)", + config.device); +#else + FAISS_THROW_IF_NOT_MSG(memorySpace_ == MemorySpace::Device, + "Must compile with CUDA 8+ for Unified Memory support"); +#endif + + metric_arg = metricArg; + + FAISS_ASSERT(resources_); + resources_->initializeForDevice(device_); +} + +void +GpuIndex::copyFrom(const faiss::Index* index) { + d = index->d; + metric_type = index->metric_type; + metric_arg = index->metric_arg; + ntotal = index->ntotal; + is_trained = index->is_trained; +} + +void +GpuIndex::copyTo(faiss::Index* index) const { + index->d = d; + index->metric_type = metric_type; + index->metric_arg = metric_arg; + index->ntotal = ntotal; + index->is_trained = is_trained; +} + +void +GpuIndex::setMinPagingSize(size_t size) { + minPagedSize_ = size; +} + +size_t +GpuIndex::getMinPagingSize() const { + return minPagedSize_; +} + +void +GpuIndex::add(Index::idx_t n, const float* x) { + // Pass to add_with_ids + add_with_ids(n, x, nullptr); +} + +void +GpuIndex::add_with_ids(Index::idx_t n, + const float* x, + const Index::idx_t* ids) { + FAISS_THROW_IF_NOT_MSG(this->is_trained, "Index not trained"); + + // For now, only support <= max int results + FAISS_THROW_IF_NOT_FMT(n <= (Index::idx_t) std::numeric_limits::max(), + "GPU index only supports up to %d indices", + std::numeric_limits::max()); + + if (n == 0) { + // nothing to add + return; + } + + std::vector generatedIds; + + // Generate IDs if we need them + if (!ids && addImplRequiresIDs_()) { + generatedIds = std::vector(n); + + for (Index::idx_t i = 0; i < n; ++i) { + generatedIds[i] = this->ntotal + i; + } + } + + DeviceScope scope(device_); + addPaged_((int) n, x, ids ? ids : generatedIds.data()); +} + +void +GpuIndex::addPaged_(int n, + const float* x, + const Index::idx_t* ids) { + if (n > 0) { + size_t totalSize = (size_t) n * this->d * sizeof(float); + + if (totalSize > kAddPageSize || n > kAddVecSize) { + // How many vectors fit into kAddPageSize? + size_t maxNumVecsForPageSize = + kAddPageSize / ((size_t) this->d * sizeof(float)); + + // Always add at least 1 vector, if we have huge vectors + maxNumVecsForPageSize = std::max(maxNumVecsForPageSize, (size_t) 1); + + size_t tileSize = std::min((size_t) n, maxNumVecsForPageSize); + tileSize = std::min(tileSize, kSearchVecSize); + + for (size_t i = 0; i < (size_t) n; i += tileSize) { + size_t curNum = std::min(tileSize, n - i); + + addPage_(curNum, + x + i * (size_t) this->d, + ids ? ids + i : nullptr); + } + } else { + addPage_(n, x, ids); + } + } +} + +void +GpuIndex::addPage_(int n, + const float* x, + const Index::idx_t* ids) { + // At this point, `x` can be resident on CPU or GPU, and `ids` may be resident + // on CPU, GPU or may be null. + // + // Before continuing, we guarantee that all data will be resident on the GPU. + auto stream = resources_->getDefaultStreamCurrentDevice(); + + auto vecs = toDevice(resources_, + device_, + const_cast(x), + stream, + {n, this->d}); + + if (ids) { + auto indices = toDevice(resources_, + device_, + const_cast(ids), + stream, + {n}); + + addImpl_(n, vecs.data(), ids ? indices.data() : nullptr); + } else { + addImpl_(n, vecs.data(), nullptr); + } +} + +void +GpuIndex::search(Index::idx_t n, + const float* x, + Index::idx_t k, + float* distances, + Index::idx_t* labels, + ConcurrentBitsetPtr bitset) const { + FAISS_THROW_IF_NOT_MSG(this->is_trained, "Index not trained"); + + // For now, only support <= max int results + FAISS_THROW_IF_NOT_FMT(n <= (Index::idx_t) std::numeric_limits::max(), + "GPU index only supports up to %d indices", + std::numeric_limits::max()); + + // Maximum k-selection supported is based on the CUDA SDK + FAISS_THROW_IF_NOT_FMT(k <= (Index::idx_t) getMaxKSelection(), + "GPU index only supports k <= %d (requested %d)", + getMaxKSelection(), + (int) k); // select limitation + + if (n == 0 || k == 0) { + // nothing to search + return; + } + + DeviceScope scope(device_); + auto stream = resources_->getDefaultStream(device_); + + // We guarantee that the searchImpl_ will be called with device-resident + // pointers. + + // The input vectors may be too large for the GPU, but we still + // assume that the output distances and labels are not. + // Go ahead and make space for output distances and labels on the + // GPU. + // If we reach a point where all inputs are too big, we can add + // another level of tiling. + auto outDistances = + toDevice(resources_, device_, distances, stream, + {(int) n, (int) k}); + + auto outLabels = + toDevice(resources_, device_, labels, stream, + {(int) n, (int) k}); + + bool usePaged = false; + + if (getDeviceForAddress(x) == -1) { + // It is possible that the user is querying for a vector set size + // `x` that won't fit on the GPU. + // In this case, we will have to handle paging of the data from CPU + // -> GPU. + // Currently, we don't handle the case where the output data won't + // fit on the GPU (e.g., n * k is too large for the GPU memory). + size_t dataSize = (size_t) n * this->d * sizeof(float); + + if (dataSize >= minPagedSize_) { + searchFromCpuPaged_(n, x, k, + outDistances.data(), + outLabels.data(), + bitset); + usePaged = true; + } + } + + if (!usePaged) { + searchNonPaged_(n, x, k, + outDistances.data(), + outLabels.data(), + bitset); + } + + // Copy back if necessary + fromDevice(outDistances, distances, stream); + fromDevice(outLabels, labels, stream); +} + +void +GpuIndex::searchNonPaged_(int n, + const float* x, + int k, + float* outDistancesData, + Index::idx_t* outIndicesData, + ConcurrentBitsetPtr bitset) const { + auto stream = resources_->getDefaultStream(device_); + + // Make sure arguments are on the device we desire; use temporary + // memory allocations to move it if necessary + auto vecs = toDevice(resources_, + device_, + const_cast(x), + stream, + {n, (int) this->d}); + + searchImpl_(n, vecs.data(), k, outDistancesData, outIndicesData, bitset); +} + +void +GpuIndex::searchFromCpuPaged_(int n, + const float* x, + int k, + float* outDistancesData, + Index::idx_t* outIndicesData, + ConcurrentBitsetPtr bitset) const { + Tensor outDistances(outDistancesData, {n, k}); + Tensor outIndices(outIndicesData, {n, k}); + + // Is pinned memory available? + auto pinnedAlloc = resources_->getPinnedMemory(); + int pageSizeInVecs = + (int) ((pinnedAlloc.second / 2) / (sizeof(float) * this->d)); + + if (!pinnedAlloc.first || pageSizeInVecs < 1) { + // Just page without overlapping copy with compute + int batchSize = utils::nextHighestPowerOf2( + (int) ((size_t) kNonPinnedPageSize / + (sizeof(float) * this->d))); + + for (int cur = 0; cur < n; cur += batchSize) { + int num = std::min(batchSize, n - cur); + + auto outDistancesSlice = outDistances.narrowOutermost(cur, num); + auto outIndicesSlice = outIndices.narrowOutermost(cur, num); + + searchNonPaged_(num, + x + (size_t) cur * this->d, + k, + outDistancesSlice.data(), + outIndicesSlice.data(), + bitset); + } + + return; + } + + // + // Pinned memory is available, so we can overlap copy with compute. + // We use two pinned memory buffers, and triple-buffer the + // procedure: + // + // 1 CPU copy -> pinned + // 2 pinned copy -> GPU + // 3 GPU compute + // + // 1 2 3 1 2 3 ... (pinned buf A) + // 1 2 3 1 2 ... (pinned buf B) + // 1 2 3 1 ... (pinned buf A) + // time -> + // + auto defaultStream = resources_->getDefaultStream(device_); + auto copyStream = resources_->getAsyncCopyStream(device_); + + FAISS_ASSERT((size_t) pageSizeInVecs * this->d <= + (size_t) std::numeric_limits::max()); + + float* bufPinnedA = (float*) pinnedAlloc.first; + float* bufPinnedB = bufPinnedA + (size_t) pageSizeInVecs * this->d; + float* bufPinned[2] = {bufPinnedA, bufPinnedB}; + + // Reserve space on the GPU for the destination of the pinned buffer + // copy + DeviceTensor bufGpuA( + resources_->getMemoryManagerCurrentDevice(), + {(int) pageSizeInVecs, (int) this->d}, + defaultStream); + DeviceTensor bufGpuB( + resources_->getMemoryManagerCurrentDevice(), + {(int) pageSizeInVecs, (int) this->d}, + defaultStream); + DeviceTensor* bufGpus[2] = {&bufGpuA, &bufGpuB}; + + // Copy completion events for the pinned buffers + std::unique_ptr eventPinnedCopyDone[2]; + + // Execute completion events for the GPU buffers + std::unique_ptr eventGpuExecuteDone[2]; + + // All offsets are in terms of number of vectors; they remain within + // int bounds (as this function only handles max in vectors) + + // Current start offset for buffer 1 + int cur1 = 0; + int cur1BufIndex = 0; + + // Current start offset for buffer 2 + int cur2 = -1; + int cur2BufIndex = 0; + + // Current start offset for buffer 3 + int cur3 = -1; + int cur3BufIndex = 0; + + while (cur3 < n) { + // Start async pinned -> GPU copy first (buf 2) + if (cur2 != -1 && cur2 < n) { + // Copy pinned to GPU + int numToCopy = std::min(pageSizeInVecs, n - cur2); + + // Make sure any previous execution has completed before continuing + auto& eventPrev = eventGpuExecuteDone[cur2BufIndex]; + if (eventPrev.get()) { + eventPrev->streamWaitOnEvent(copyStream); + } + + CUDA_VERIFY(cudaMemcpyAsync(bufGpus[cur2BufIndex]->data(), + bufPinned[cur2BufIndex], + (size_t) numToCopy * this->d * sizeof(float), + cudaMemcpyHostToDevice, + copyStream)); + + // Mark a completion event in this stream + eventPinnedCopyDone[cur2BufIndex] = + std::move(std::unique_ptr(new CudaEvent(copyStream))); + + // We pick up from here + cur3 = cur2; + cur2 += numToCopy; + cur2BufIndex = (cur2BufIndex == 0) ? 1 : 0; + } + + if (cur3 != -1 && cur3 < n) { + // Process on GPU + int numToProcess = std::min(pageSizeInVecs, n - cur3); + + // Make sure the previous copy has completed before continuing + auto& eventPrev = eventPinnedCopyDone[cur3BufIndex]; + FAISS_ASSERT(eventPrev.get()); + + eventPrev->streamWaitOnEvent(defaultStream); + + // Create tensor wrappers + // DeviceTensor input(bufGpus[cur3BufIndex]->data(), + // {numToProcess, this->d}); + auto outDistancesSlice = outDistances.narrowOutermost(cur3, numToProcess); + auto outIndicesSlice = outIndices.narrowOutermost(cur3, numToProcess); + + searchImpl_(numToProcess, + bufGpus[cur3BufIndex]->data(), + k, + outDistancesSlice.data(), + outIndicesSlice.data(), + bitset); + + // Create completion event + eventGpuExecuteDone[cur3BufIndex] = + std::move(std::unique_ptr(new CudaEvent(defaultStream))); + + // We pick up from here + cur3BufIndex = (cur3BufIndex == 0) ? 1 : 0; + cur3 += numToProcess; + } + + if (cur1 < n) { + // Copy CPU mem to CPU pinned + int numToCopy = std::min(pageSizeInVecs, n - cur1); + + // Make sure any previous copy has completed before continuing + auto& eventPrev = eventPinnedCopyDone[cur1BufIndex]; + if (eventPrev.get()) { + eventPrev->cpuWaitOnEvent(); + } + + memcpy(bufPinned[cur1BufIndex], + x + (size_t) cur1 * this->d, + (size_t) numToCopy * this->d * sizeof(float)); + + // We pick up from here + cur2 = cur1; + cur1 += numToCopy; + cur1BufIndex = (cur1BufIndex == 0) ? 1 : 0; + } + } +} + +void +GpuIndex::compute_residual(const float* x, + float* residual, + Index::idx_t key) const { + FAISS_THROW_MSG("compute_residual not implemented for this type of index"); +} + +void +GpuIndex::compute_residual_n(Index::idx_t n, + const float* xs, + float* residuals, + const Index::idx_t* keys) const { + FAISS_THROW_MSG("compute_residual_n not implemented for this type of index"); +} + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/GpuIndex.h b/core/src/index/thirdparty/faiss/gpu/GpuIndex.h new file mode 100644 index 0000000000..ae902f57a8 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/GpuIndex.h @@ -0,0 +1,160 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#pragma once + +#include +#include +#include + +namespace faiss { namespace gpu { + +class GpuResources; + +struct GpuIndexConfig { + inline GpuIndexConfig() + : device(0), + memorySpace(MemorySpace::Device) { + } + + /// GPU device on which the index is resident + int device; + + /// What memory space to use for primary storage. + /// On Pascal and above (CC 6+) architectures, allows GPUs to use + /// more memory than is available on the GPU. + MemorySpace memorySpace; +}; + +class GpuIndex : public faiss::Index { + public: + GpuIndex(GpuResources* resources, + int dims, + faiss::MetricType metric, + float metricArg, + GpuIndexConfig config); + + inline int getDevice() const { + return device_; + } + + inline GpuResources* getResources() { + return resources_; + } + + /// Set the minimum data size for searches (in MiB) for which we use + /// CPU -> GPU paging + void setMinPagingSize(size_t size); + + /// Returns the current minimum data size for paged searches + size_t getMinPagingSize() const; + + /// `x` can be resident on the CPU or any GPU; copies are performed + /// as needed + /// Handles paged adds if the add set is too large; calls addInternal_ + void add(faiss::Index::idx_t, const float* x) override; + + /// `x` and `ids` can be resident on the CPU or any GPU; copies are + /// performed as needed + /// Handles paged adds if the add set is too large; calls addInternal_ + void add_with_ids(Index::idx_t n, + const float* x, + const Index::idx_t* ids) override; + + /// `x`, `distances` and `labels` can be resident on the CPU or any + /// GPU; copies are performed as needed + void search(Index::idx_t n, + const float* x, + Index::idx_t k, + float* distances, + Index::idx_t* labels, + ConcurrentBitsetPtr bitset = nullptr) const override; + + /// Overridden to force GPU indices to provide their own GPU-friendly + /// implementation + void compute_residual(const float* x, + float* residual, + Index::idx_t key) const override; + + /// Overridden to force GPU indices to provide their own GPU-friendly + /// implementation + void compute_residual_n(Index::idx_t n, + const float* xs, + float* residuals, + const Index::idx_t* keys) const override; + + protected: + /// Copy what we need from the CPU equivalent + void copyFrom(const faiss::Index* index); + + /// Copy what we have to the CPU equivalent + void copyTo(faiss::Index* index) const; + + /// Does addImpl_ require IDs? If so, and no IDs are provided, we will + /// generate them sequentially based on the order in which the IDs are added + virtual bool addImplRequiresIDs_() const = 0; + + /// Overridden to actually perform the add + /// All data is guaranteed to be resident on our device + virtual void addImpl_(int n, + const float* x, + const Index::idx_t* ids) = 0; + + /// Overridden to actually perform the search + /// All data is guaranteed to be resident on our device + virtual void searchImpl_(int n, + const float* x, + int k, + float* distances, + Index::idx_t* labels, + ConcurrentBitsetPtr bitset = nullptr) const = 0; + +private: + /// Handles paged adds if the add set is too large, passes to + /// addImpl_ to actually perform the add for the current page + void addPaged_(int n, + const float* x, + const Index::idx_t* ids); + + /// Calls addImpl_ for a single page of GPU-resident data + void addPage_(int n, + const float* x, + const Index::idx_t* ids); + + /// Calls searchImpl_ for a single page of GPU-resident data + void searchNonPaged_(int n, + const float* x, + int k, + float* outDistancesData, + Index::idx_t* outIndicesData, + ConcurrentBitsetPtr bitset = nullptr) const; + + /// Calls searchImpl_ for a single page of GPU-resident data, + /// handling paging of the data and copies from the CPU + void searchFromCpuPaged_(int n, + const float* x, + int k, + float* outDistancesData, + Index::idx_t* outIndicesData, + ConcurrentBitsetPtr bitset = nullptr) const; + + protected: + /// Manages streams, cuBLAS handles and scratch memory for devices + GpuResources* resources_; + + /// The GPU device we are resident on + const int device_; + + /// The memory space of our primary storage on the GPU + const MemorySpace memorySpace_; + + /// Size above which we page copies from the CPU to GPU + size_t minPagedSize_; +}; + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/GpuIndexBinaryFlat.cu b/core/src/index/thirdparty/faiss/gpu/GpuIndexBinaryFlat.cu new file mode 100644 index 0000000000..cd412be944 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/GpuIndexBinaryFlat.cu @@ -0,0 +1,290 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include +#include +#include +#include + +namespace faiss { namespace gpu { + +/// Default CPU search size for which we use paged copies +constexpr size_t kMinPageSize = (size_t) 256 * 1024 * 1024; + +GpuIndexBinaryFlat::GpuIndexBinaryFlat(GpuResources* resources, + const faiss::IndexBinaryFlat* index, + GpuIndexBinaryFlatConfig config) + : IndexBinary(index->d), + resources_(resources), + config_(std::move(config)), + data_(nullptr) { + FAISS_THROW_IF_NOT_FMT(this->d % 8 == 0, + "vector dimension (number of bits) " + "must be divisible by 8 (passed %d)", + this->d); + + // Flat index doesn't need training + this->is_trained = true; + + copyFrom(index); +} + + +GpuIndexBinaryFlat::GpuIndexBinaryFlat(GpuResources* resources, + int dims, + GpuIndexBinaryFlatConfig config) + : IndexBinary(dims), + resources_(resources), + config_(std::move(config)), + data_(nullptr) { + FAISS_THROW_IF_NOT_FMT(this->d % 8 == 0, + "vector dimension (number of bits) " + "must be divisible by 8 (passed %d)", + this->d); + + // Flat index doesn't need training + this->is_trained = true; + + // Construct index + DeviceScope scope(config_.device); + data_ = new BinaryFlatIndex(resources, + this->d, + config_.memorySpace); +} + +GpuIndexBinaryFlat::~GpuIndexBinaryFlat() { + delete data_; +} + +void +GpuIndexBinaryFlat::copyFrom(const faiss::IndexBinaryFlat* index) { + DeviceScope scope(config_.device); + + this->d = index->d; + + // GPU code has 32 bit indices + FAISS_THROW_IF_NOT_FMT(index->ntotal <= + (faiss::Index::idx_t) std::numeric_limits::max(), + "GPU index only supports up to %zu indices; " + "attempting to copy CPU index with %zu parameters", + (size_t) std::numeric_limits::max(), + (size_t) index->ntotal); + this->ntotal = index->ntotal; + + delete data_; + data_ = new BinaryFlatIndex(resources_, + this->d, + config_.memorySpace); + + // The index could be empty + if (index->ntotal > 0) { + data_->add(index->xb.data(), + index->ntotal, + resources_->getDefaultStream(config_.device)); + } +} + +void +GpuIndexBinaryFlat::copyTo(faiss::IndexBinaryFlat* index) const { + DeviceScope scope(config_.device); + + index->d = this->d; + index->ntotal = this->ntotal; + + FAISS_ASSERT(data_); + FAISS_ASSERT(data_->getSize() == this->ntotal); + index->xb.resize(this->ntotal * (this->d / 8)); + + if (this->ntotal > 0) { + fromDevice(data_->getVectorsRef(), + index->xb.data(), + resources_->getDefaultStream(config_.device)); + } +} + +void +GpuIndexBinaryFlat::add(faiss::IndexBinary::idx_t n, + const uint8_t* x) { + DeviceScope scope(config_.device); + + // To avoid multiple re-allocations, ensure we have enough storage + // available + data_->reserve(n, resources_->getDefaultStream(config_.device)); + + // Due to GPU indexing in int32, we can't store more than this + // number of vectors on a GPU + FAISS_THROW_IF_NOT_FMT(this->ntotal + n <= + (faiss::Index::idx_t) std::numeric_limits::max(), + "GPU index only supports up to %zu indices", + (size_t) std::numeric_limits::max()); + + data_->add((const unsigned char*) x, + n, + resources_->getDefaultStream(config_.device)); + this->ntotal += n; +} + +void +GpuIndexBinaryFlat::reset() { + DeviceScope scope(config_.device); + + // Free the underlying memory + data_->reset(); + this->ntotal = 0; +} + +void +GpuIndexBinaryFlat::search(faiss::IndexBinary::idx_t n, + const uint8_t* x, + faiss::IndexBinary::idx_t k, + int32_t* distances, + faiss::IndexBinary::idx_t* labels, + ConcurrentBitsetPtr bitset) const { + if (n == 0) { + return; + } + + // For now, only support <= max int results + FAISS_THROW_IF_NOT_FMT(n <= (Index::idx_t) std::numeric_limits::max(), + "GPU index only supports up to %zu indices", + (size_t) std::numeric_limits::max()); + FAISS_THROW_IF_NOT_FMT(k <= (Index::idx_t) getMaxKSelection(), + "GPU only supports k <= %d (requested %d)", + getMaxKSelection(), + (int) k); // select limitation + + DeviceScope scope(config_.device); + auto stream = resources_->getDefaultStream(config_.device); + + // The input vectors may be too large for the GPU, but we still + // assume that the output distances and labels are not. + // Go ahead and make space for output distances and labels on the + // GPU. + // If we reach a point where all inputs are too big, we can add + // another level of tiling. + auto outDistances = toDevice(resources_, + config_.device, + distances, + stream, + {(int) n, (int) k}); + + // FlatIndex only supports an interface returning int indices + DeviceTensor outIntIndices( + resources_->getMemoryManagerCurrentDevice(), + {(int) n, (int) k}, stream); + + bool usePaged = false; + + if (getDeviceForAddress(x) == -1) { + // It is possible that the user is querying for a vector set size + // `x` that won't fit on the GPU. + // In this case, we will have to handle paging of the data from CPU + // -> GPU. + // Currently, we don't handle the case where the output data won't + // fit on the GPU (e.g., n * k is too large for the GPU memory). + size_t dataSize = (size_t) n * (this->d / 8) * sizeof(uint8_t); + + if (dataSize >= kMinPageSize) { + searchFromCpuPaged_(n, x, k, + outDistances.data(), + outIntIndices.data()); + usePaged = true; + } + } + + if (!usePaged) { + searchNonPaged_(n, x, k, + outDistances.data(), + outIntIndices.data()); + } + + // Convert and copy int indices out + auto outIndices = toDevice(resources_, + config_.device, + labels, + stream, + {(int) n, (int) k}); + + // Convert int to long + convertTensor(stream, + outIntIndices, + outIndices); + + // Copy back if necessary + fromDevice(outDistances, distances, stream); + fromDevice(outIndices, labels, stream); +} + +void +GpuIndexBinaryFlat::searchNonPaged_(int n, + const uint8_t* x, + int k, + int32_t* outDistancesData, + int* outIndicesData) const { + Tensor outDistances(outDistancesData, {n, k}); + Tensor outIndices(outIndicesData, {n, k}); + + auto stream = resources_->getDefaultStream(config_.device); + + // Make sure arguments are on the device we desire; use temporary + // memory allocations to move it if necessary + auto vecs = toDevice(resources_, + config_.device, + const_cast(x), + stream, + {n, (int) (this->d / 8)}); + + data_->query(vecs, k, outDistances, outIndices); +} + +void +GpuIndexBinaryFlat::searchFromCpuPaged_(int n, + const uint8_t* x, + int k, + int32_t* outDistancesData, + int* outIndicesData) const { + Tensor outDistances(outDistancesData, {n, k}); + Tensor outIndices(outIndicesData, {n, k}); + + auto vectorSize = sizeof(uint8_t) * (this->d / 8); + + // Just page without overlapping copy with compute (as GpuIndexFlat does) + int batchSize = utils::nextHighestPowerOf2( + (int) ((size_t) kMinPageSize / vectorSize)); + + for (int cur = 0; cur < n; cur += batchSize) { + int num = std::min(batchSize, n - cur); + + auto outDistancesSlice = outDistances.narrowOutermost(cur, num); + auto outIndicesSlice = outIndices.narrowOutermost(cur, num); + + searchNonPaged_(num, + x + (size_t) cur * (this->d / 8), + k, + outDistancesSlice.data(), + outIndicesSlice.data()); + } +} + +void +GpuIndexBinaryFlat::reconstruct(faiss::IndexBinary::idx_t key, + uint8_t* out) const { + DeviceScope scope(config_.device); + + FAISS_THROW_IF_NOT_MSG(key < this->ntotal, "index out of bounds"); + auto stream = resources_->getDefaultStream(config_.device); + + auto& vecs = data_->getVectorsRef(); + auto vec = vecs[key]; + + fromDevice(vec.data(), out, vecs.getSize(1), stream); +} + +} } // namespace gpu diff --git a/core/src/index/thirdparty/faiss/gpu/GpuIndexBinaryFlat.h b/core/src/index/thirdparty/faiss/gpu/GpuIndexBinaryFlat.h new file mode 100644 index 0000000000..da559cb4f1 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/GpuIndexBinaryFlat.h @@ -0,0 +1,90 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +namespace faiss { namespace gpu { + +class BinaryFlatIndex; +class GpuResources; + +struct GpuIndexBinaryFlatConfig : public GpuIndexConfig { +}; + +/// A GPU version of IndexBinaryFlat for brute-force comparison of bit vectors +/// via Hamming distance +class GpuIndexBinaryFlat : public IndexBinary { + public: + /// Construct from a pre-existing faiss::IndexBinaryFlat instance, copying + /// data over to the given GPU + GpuIndexBinaryFlat(GpuResources* resources, + const faiss::IndexBinaryFlat* index, + GpuIndexBinaryFlatConfig config = + GpuIndexBinaryFlatConfig()); + + /// Construct an empty instance that can be added to + GpuIndexBinaryFlat(GpuResources* resources, + int dims, + GpuIndexBinaryFlatConfig config = + GpuIndexBinaryFlatConfig()); + + ~GpuIndexBinaryFlat() override; + + /// Initialize ourselves from the given CPU index; will overwrite + /// all data in ourselves + void copyFrom(const faiss::IndexBinaryFlat* index); + + /// Copy ourselves to the given CPU index; will overwrite all data + /// in the index instance + void copyTo(faiss::IndexBinaryFlat* index) const; + + void add(faiss::IndexBinary::idx_t n, + const uint8_t* x) override; + + void reset() override; + + void search(faiss::IndexBinary::idx_t n, + const uint8_t* x, + faiss::IndexBinary::idx_t k, + int32_t* distances, + faiss::IndexBinary::idx_t* labels, + ConcurrentBitsetPtr bitset = nullptr) const override; + + void reconstruct(faiss::IndexBinary::idx_t key, + uint8_t* recons) const override; + + protected: + /// Called from search when the input data is on the CPU; + /// potentially allows for pinned memory usage + void searchFromCpuPaged_(int n, + const uint8_t* x, + int k, + int32_t* outDistancesData, + int* outIndicesData) const; + + void searchNonPaged_(int n, + const uint8_t* x, + int k, + int32_t* outDistancesData, + int* outIndicesData) const; + + protected: + /// Manages streans, cuBLAS handles and scratch memory for devices + GpuResources* resources_; + + /// Configuration options + GpuIndexBinaryFlatConfig config_; + + /// Holds our GPU data containing the list of vectors; is managed via raw + /// pointer so as to allow non-CUDA compilers to see this header + BinaryFlatIndex* data_; +}; + +} } // namespace gpu diff --git a/core/src/index/thirdparty/faiss/gpu/GpuIndexFlat.cu b/core/src/index/thirdparty/faiss/gpu/GpuIndexFlat.cu new file mode 100644 index 0000000000..5f4893586e --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/GpuIndexFlat.cu @@ -0,0 +1,386 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace faiss { namespace gpu { + +GpuIndexFlat::GpuIndexFlat(GpuResources* resources, + const faiss::IndexFlat* index, + GpuIndexFlatConfig config) : + GpuIndex(resources, + index->d, + index->metric_type, + index->metric_arg, + config), + config_(std::move(config)), + data_(nullptr) { + // Flat index doesn't need training + this->is_trained = true; + + copyFrom(index); +} + +GpuIndexFlat::GpuIndexFlat(GpuResources* resources, + int dims, + faiss::MetricType metric, + GpuIndexFlatConfig config) : + GpuIndex(resources, dims, metric, 0, config), + config_(std::move(config)), + data_(nullptr) { + // Flat index doesn't need training + this->is_trained = true; + + // Construct index + DeviceScope scope(device_); + data_ = new FlatIndex(resources, + dims, + config_.useFloat16, + config_.storeTransposed, + memorySpace_); +} + +GpuIndexFlat::~GpuIndexFlat() { + delete data_; +} + +void +GpuIndexFlat::copyFrom(const faiss::IndexFlat* index) { + DeviceScope scope(device_); + + GpuIndex::copyFrom(index); + + // GPU code has 32 bit indices + FAISS_THROW_IF_NOT_FMT(index->ntotal <= + (faiss::Index::idx_t) std::numeric_limits::max(), + "GPU index only supports up to %zu indices; " + "attempting to copy CPU index with %zu parameters", + (size_t) std::numeric_limits::max(), + (size_t) index->ntotal); + + delete data_; + data_ = new FlatIndex(resources_, + this->d, + config_.useFloat16, + config_.storeTransposed, + memorySpace_); + + // The index could be empty + if (index->ntotal > 0) { + data_->add(index->xb.data(), + index->ntotal, + resources_->getDefaultStream(device_)); + } + + xb_.clear(); + + if (config_.storeInCpu) { + xb_ = index->xb; + } +} + +void +GpuIndexFlat::copyTo(faiss::IndexFlat* index) const { + DeviceScope scope(device_); + + GpuIndex::copyTo(index); + + FAISS_ASSERT(data_); + FAISS_ASSERT(data_->getSize() == this->ntotal); + index->xb.resize(this->ntotal * this->d); + + auto stream = resources_->getDefaultStream(device_); + + if (this->ntotal > 0) { + if (config_.useFloat16) { + auto vecFloat32 = data_->getVectorsFloat32Copy(stream); + fromDevice(vecFloat32, index->xb.data(), stream); + } else { + fromDevice(data_->getVectorsFloat32Ref(), index->xb.data(), stream); + } + } +} + +size_t +GpuIndexFlat::getNumVecs() const { + return this->ntotal; +} + +void +GpuIndexFlat::reset() { + DeviceScope scope(device_); + + // Free the underlying memory + data_->reset(); + this->ntotal = 0; +} + +void +GpuIndexFlat::train(Index::idx_t n, const float* x) { + // nothing to do +} + +void +GpuIndexFlat::add(Index::idx_t n, const float* x) { + FAISS_THROW_IF_NOT_MSG(this->is_trained, "Index not trained"); + + // For now, only support <= max int results + FAISS_THROW_IF_NOT_FMT(n <= (Index::idx_t) std::numeric_limits::max(), + "GPU index only supports up to %d indices", + std::numeric_limits::max()); + + if (n == 0) { + // nothing to add + return; + } + + DeviceScope scope(device_); + + // To avoid multiple re-allocations, ensure we have enough storage + // available + data_->reserve(n, resources_->getDefaultStream(device_)); + + // If we're not operating in float16 mode, we don't need the input + // data to be resident on our device; we can add directly. + if (!config_.useFloat16) { + addImpl_(n, x, nullptr); + } else { + // Otherwise, perform the paging + GpuIndex::add(n, x); + } +} + +bool +GpuIndexFlat::addImplRequiresIDs_() const { + return false; +} + +void +GpuIndexFlat::addImpl_(int n, + const float* x, + const Index::idx_t* ids) { + FAISS_ASSERT(data_); + FAISS_ASSERT(n > 0); + + // We do not support add_with_ids + FAISS_THROW_IF_NOT_MSG(!ids, "add_with_ids not supported"); + + // Due to GPU indexing in int32, we can't store more than this + // number of vectors on a GPU + FAISS_THROW_IF_NOT_FMT(this->ntotal + n <= + (faiss::Index::idx_t) std::numeric_limits::max(), + "GPU index only supports up to %zu indices", + (size_t) std::numeric_limits::max()); + + data_->add(x, n, resources_->getDefaultStream(device_)); + this->ntotal += n; +} + +void +GpuIndexFlat::searchImpl_(int n, + const float* x, + int k, + float* distances, + Index::idx_t* labels, + ConcurrentBitsetPtr bitset) const { + auto stream = resources_->getDefaultStream(device_); + + // Input and output data are already resident on the GPU + Tensor queries(const_cast(x), {n, (int) this->d}); + Tensor outDistances(distances, {n, k}); + Tensor outLabels(labels, {n, k}); + + // FlatIndex only supports int indices + DeviceTensor outIntLabels( + resources_->getMemoryManagerCurrentDevice(), {n, k}, stream); + + // Copy bitset to GPU + if (!bitset) { + auto bitsetDevice = toDevice(resources_, device_, nullptr, stream, {0}); + data_->query(queries, bitsetDevice, k, metric_type, metric_arg, outDistances, outIntLabels, true); + } else { + auto bitsetDevice = toDevice(resources_, device_, + const_cast(bitset->data()), stream, + {(int) bitset->size()}); + data_->query(queries, bitsetDevice, k, metric_type, metric_arg, outDistances, outIntLabels, true); + } + + // Convert int to idx_t + convertTensor(stream, + outIntLabels, + outLabels); +} + +void +GpuIndexFlat::reconstruct(faiss::Index::idx_t key, + float* out) const { + if (config_.storeInCpu && xb_.size() > 0) { + memcpy (out, &(this->xb_[key * this->d]), sizeof(*out) * this->d); + return; + } + + DeviceScope scope(device_); + + FAISS_THROW_IF_NOT_MSG(key < this->ntotal, "index out of bounds"); + auto stream = resources_->getDefaultStream(device_); + + if (config_.useFloat16) { + // FIXME jhj: kernel for copy + auto vec = data_->getVectorsFloat32Copy(key, 1, stream); + fromDevice(vec.data(), out, this->d, stream); + } else { + auto vec = data_->getVectorsFloat32Ref()[key]; + fromDevice(vec.data(), out, this->d, stream); + } +} + +void +GpuIndexFlat::reconstruct_n(faiss::Index::idx_t i0, + faiss::Index::idx_t num, + float* out) const { + DeviceScope scope(device_); + + FAISS_THROW_IF_NOT_MSG(i0 < this->ntotal, "index out of bounds"); + FAISS_THROW_IF_NOT_MSG(i0 + num - 1 < this->ntotal, "num out of bounds"); + auto stream = resources_->getDefaultStream(device_); + + if (config_.useFloat16) { + // FIXME jhj: kernel for copy + auto vec = data_->getVectorsFloat32Copy(i0, num, stream); + fromDevice(vec.data(), out, num * this->d, stream); + } else { + auto vec = data_->getVectorsFloat32Ref()[i0]; + fromDevice(vec.data(), out, this->d * num, stream); + } +} + +void +GpuIndexFlat::compute_residual(const float* x, + float* residual, + faiss::Index::idx_t key) const { + compute_residual_n(1, x, residual, &key); +} + +void +GpuIndexFlat::compute_residual_n(faiss::Index::idx_t n, + const float* xs, + float* residuals, + const faiss::Index::idx_t* keys) const { + FAISS_THROW_IF_NOT_FMT(n <= + (faiss::Index::idx_t) std::numeric_limits::max(), + "GPU index only supports up to %zu indices", + (size_t) std::numeric_limits::max()); + + auto stream = resources_->getDefaultStream(device_); + + DeviceScope scope(device_); + + auto vecsDevice = + toDevice(resources_, device_, + const_cast(xs), stream, + {(int) n, (int) this->d}); + auto idsDevice = + toDevice(resources_, device_, + const_cast(keys), + stream, + {(int) n}); + auto residualDevice = + toDevice(resources_, device_, residuals, stream, + {(int) n, (int) this->d}); + + // Convert idx_t to int + auto keysInt = + convertTensor(resources_, stream, idsDevice); + + FAISS_ASSERT(data_); + data_->computeResidual(vecsDevice, + keysInt, + residualDevice); + + fromDevice(residualDevice, residuals, stream); +} + +// +// GpuIndexFlatL2 +// + +GpuIndexFlatL2::GpuIndexFlatL2(GpuResources* resources, + faiss::IndexFlatL2* index, + GpuIndexFlatConfig config) : + GpuIndexFlat(resources, index, config) { +} + +GpuIndexFlatL2::GpuIndexFlatL2(GpuResources* resources, + int dims, + GpuIndexFlatConfig config) : + GpuIndexFlat(resources, dims, faiss::METRIC_L2, config) { +} + +void +GpuIndexFlatL2::copyFrom(faiss::IndexFlat* index) { + FAISS_THROW_IF_NOT_MSG(index->metric_type == metric_type, + "Cannot copy a GpuIndexFlatL2 from an index of " + "different metric_type"); + + GpuIndexFlat::copyFrom(index); +} + +void +GpuIndexFlatL2::copyTo(faiss::IndexFlat* index) { + FAISS_THROW_IF_NOT_MSG(index->metric_type == metric_type, + "Cannot copy a GpuIndexFlatL2 to an index of " + "different metric_type"); + + GpuIndexFlat::copyTo(index); +} + +// +// GpuIndexFlatIP +// + +GpuIndexFlatIP::GpuIndexFlatIP(GpuResources* resources, + faiss::IndexFlatIP* index, + GpuIndexFlatConfig config) : + GpuIndexFlat(resources, index, config) { +} + +GpuIndexFlatIP::GpuIndexFlatIP(GpuResources* resources, + int dims, + GpuIndexFlatConfig config) : + GpuIndexFlat(resources, dims, faiss::METRIC_INNER_PRODUCT, config) { +} + +void +GpuIndexFlatIP::copyFrom(faiss::IndexFlat* index) { + FAISS_THROW_IF_NOT_MSG(index->metric_type == metric_type, + "Cannot copy a GpuIndexFlatIP from an index of " + "different metric_type"); + + GpuIndexFlat::copyFrom(index); +} + +void +GpuIndexFlatIP::copyTo(faiss::IndexFlat* index) { + // The passed in index must be IP + FAISS_THROW_IF_NOT_MSG(index->metric_type == metric_type, + "Cannot copy a GpuIndexFlatIP to an index of " + "different metric_type"); + + GpuIndexFlat::copyTo(index); +} + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/GpuIndexFlat.h b/core/src/index/thirdparty/faiss/gpu/GpuIndexFlat.h new file mode 100644 index 0000000000..90823d69a4 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/GpuIndexFlat.h @@ -0,0 +1,188 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#pragma once + +#include +#include +#include + +namespace faiss { + +struct IndexFlat; +struct IndexFlatL2; +struct IndexFlatIP; + +} + +namespace faiss { namespace gpu { + +struct FlatIndex; + +struct GpuIndexFlatConfig : public GpuIndexConfig { + inline GpuIndexFlatConfig() + : useFloat16(false), + storeTransposed(false), + storeInCpu(false) { + } + + /// Whether or not data is stored as float16 + bool useFloat16; + + /// Whether or not data is stored (transparently) in a transposed + /// layout, enabling use of the NN GEMM call, which is ~10% faster. + /// This will improve the speed of the flat index, but will + /// substantially slow down any add() calls made, as all data must + /// be transposed, and will increase storage requirements (we store + /// data in both transposed and non-transposed layouts). + bool storeTransposed; + + bool storeInCpu; +}; + +/// Wrapper around the GPU implementation that looks like +/// faiss::IndexFlat; copies over centroid data from a given +/// faiss::IndexFlat +class GpuIndexFlat : public GpuIndex { + public: + /// Construct from a pre-existing faiss::IndexFlat instance, copying + /// data over to the given GPU + GpuIndexFlat(GpuResources* resources, + const faiss::IndexFlat* index, + GpuIndexFlatConfig config = GpuIndexFlatConfig()); + + /// Construct an empty instance that can be added to + GpuIndexFlat(GpuResources* resources, + int dims, + faiss::MetricType metric, + GpuIndexFlatConfig config = GpuIndexFlatConfig()); + + ~GpuIndexFlat() override; + + /// Initialize ourselves from the given CPU index; will overwrite + /// all data in ourselves + void copyFrom(const faiss::IndexFlat* index); + + /// Copy ourselves to the given CPU index; will overwrite all data + /// in the index instance + void copyTo(faiss::IndexFlat* index) const; + + /// Returns the number of vectors we contain + size_t getNumVecs() const; + + /// Clears all vectors from this index + void reset() override; + + /// This index is not trained, so this does nothing + void train(Index::idx_t n, const float* x) override; + + /// Overrides to avoid excessive copies + void add(faiss::Index::idx_t, const float* x) override; + + /// Reconstruction methods; prefer the batch reconstruct as it will + /// be more efficient + void reconstruct(faiss::Index::idx_t key, float* out) const override; + + /// Batch reconstruction method + void reconstruct_n(faiss::Index::idx_t i0, + faiss::Index::idx_t num, + float* out) const override; + + /// Compute residual + void compute_residual(const float* x, + float* residual, + faiss::Index::idx_t key) const override; + + /// Compute residual (batch mode) + void compute_residual_n(faiss::Index::idx_t n, + const float* xs, + float* residuals, + const faiss::Index::idx_t* keys) const override; + + /// For internal access + inline FlatIndex* getGpuData() { return data_; } + + protected: + /// Flat index does not require IDs as there is no storage available for them + bool addImplRequiresIDs_() const override; + + /// Called from GpuIndex for add + void addImpl_(int n, + const float* x, + const Index::idx_t* ids) override; + + /// Called from GpuIndex for search + void searchImpl_(int n, + const float* x, + int k, + float* distances, + faiss::Index::idx_t* labels, + ConcurrentBitsetPtr bitset = nullptr) const override; + + protected: + /// Our config object + const GpuIndexFlatConfig config_; + + /// Holds our GPU data containing the list of vectors; is managed via raw + /// pointer so as to allow non-CUDA compilers to see this header + FlatIndex* data_; + + std::vector xb_; +}; + +/// Wrapper around the GPU implementation that looks like +/// faiss::IndexFlatL2; copies over centroid data from a given +/// faiss::IndexFlat +class GpuIndexFlatL2 : public GpuIndexFlat { + public: + /// Construct from a pre-existing faiss::IndexFlatL2 instance, copying + /// data over to the given GPU + GpuIndexFlatL2(GpuResources* resources, + faiss::IndexFlatL2* index, + GpuIndexFlatConfig config = GpuIndexFlatConfig()); + + /// Construct an empty instance that can be added to + GpuIndexFlatL2(GpuResources* resources, + int dims, + GpuIndexFlatConfig config = GpuIndexFlatConfig()); + + /// Initialize ourselves from the given CPU index; will overwrite + /// all data in ourselves + void copyFrom(faiss::IndexFlat* index); + + /// Copy ourselves to the given CPU index; will overwrite all data + /// in the index instance + void copyTo(faiss::IndexFlat* index); +}; + +/// Wrapper around the GPU implementation that looks like +/// faiss::IndexFlatIP; copies over centroid data from a given +/// faiss::IndexFlat +class GpuIndexFlatIP : public GpuIndexFlat { + public: + /// Construct from a pre-existing faiss::IndexFlatIP instance, copying + /// data over to the given GPU + GpuIndexFlatIP(GpuResources* resources, + faiss::IndexFlatIP* index, + GpuIndexFlatConfig config = GpuIndexFlatConfig()); + + /// Construct an empty instance that can be added to + GpuIndexFlatIP(GpuResources* resources, + int dims, + GpuIndexFlatConfig config = GpuIndexFlatConfig()); + + /// Initialize ourselves from the given CPU index; will overwrite + /// all data in ourselves + void copyFrom(faiss::IndexFlat* index); + + /// Copy ourselves to the given CPU index; will overwrite all data + /// in the index instance + void copyTo(faiss::IndexFlat* index); +}; + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/GpuIndexIVF.cu b/core/src/index/thirdparty/faiss/gpu/GpuIndexIVF.cu new file mode 100644 index 0000000000..130e95f866 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/GpuIndexIVF.cu @@ -0,0 +1,324 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#include +#include +#include +#include +#include +#include +#include + +namespace faiss { namespace gpu { + +GpuIndexIVF::GpuIndexIVF(GpuResources* resources, + int dims, + faiss::MetricType metric, + float metricArg, + int nlistIn, + GpuIndexIVFConfig config) : + GpuIndex(resources, dims, metric, metricArg, config), + ivfConfig_(std::move(config)), + nlist(nlistIn), + nprobe(1), + quantizer(nullptr) { + + init_(); + + // Only IP and L2 are supported for now + if (!(metric_type == faiss::METRIC_L2 || + metric_type == faiss::METRIC_INNER_PRODUCT)) { + FAISS_THROW_FMT("unsupported metric type %d", (int) metric_type); + } +} + +void +GpuIndexIVF::init_() { + FAISS_THROW_IF_NOT_MSG(nlist > 0, "nlist must be > 0"); + + // Spherical by default if the metric is inner_product + if (metric_type == faiss::METRIC_INNER_PRODUCT) { + cp.spherical = true; + } + + // here we set a low # iterations because this is typically used + // for large clusterings + cp.niter = 10; + cp.verbose = verbose; + + if (!quantizer) { + // Construct an empty quantizer + GpuIndexFlatConfig config = ivfConfig_.flatConfig; + // FIXME: inherit our same device + config.device = device_; + + if (metric_type == faiss::METRIC_L2) { + quantizer = new GpuIndexFlatL2(resources_, d, config); + } else if (metric_type == faiss::METRIC_INNER_PRODUCT) { + quantizer = new GpuIndexFlatIP(resources_, d, config); + } else { + // unknown metric type + FAISS_THROW_FMT("unsupported metric type %d", (int) metric_type); + } + } +} + +GpuIndexIVF::~GpuIndexIVF() { + if (remove_quantizer == 1) { + delete quantizer; + } +} + +GpuIndexFlat* +GpuIndexIVF::getQuantizer() { + return quantizer; +} + +void +GpuIndexIVF::copyFrom(const faiss::IndexIVF* index) { + DeviceScope scope(device_); + + GpuIndex::copyFrom(index); + + FAISS_ASSERT(index->nlist > 0); + FAISS_THROW_IF_NOT_FMT(index->nlist <= + (faiss::Index::idx_t) std::numeric_limits::max(), + "GPU index only supports %zu inverted lists", + (size_t) std::numeric_limits::max()); + nlist = index->nlist; + + FAISS_THROW_IF_NOT_FMT(index->nprobe > 0 && + index->nprobe <= getMaxKSelection(), + "GPU index only supports nprobe <= %zu; passed %zu", + (size_t) getMaxKSelection(), + index->nprobe); + nprobe = index->nprobe; + + // The metric type may have changed as well, so we might have to + // change our quantizer + delete quantizer; + quantizer = nullptr; + + // Construct an empty quantizer + GpuIndexFlatConfig config = ivfConfig_.flatConfig; + // FIXME: inherit our same device + config.device = device_; + + if (index->metric_type == faiss::METRIC_L2) { + // FIXME: 2 different float16 options? + quantizer = new GpuIndexFlatL2(resources_, this->d, config); + } else if (index->metric_type == faiss::METRIC_INNER_PRODUCT) { + // FIXME: 2 different float16 options? + quantizer = new GpuIndexFlatIP(resources_, this->d, config); + } else { + // unknown metric type + FAISS_ASSERT(false); + } + + if (!index->is_trained) { + // copied in GpuIndex::copyFrom + FAISS_ASSERT(!is_trained && ntotal == 0); + return; + } + + // copied in GpuIndex::copyFrom + // ntotal can exceed max int, but the number of vectors per inverted + // list cannot exceed this. We check this in the subclasses. + FAISS_ASSERT(is_trained && (ntotal == index->ntotal)); + + // Since we're trained, the quantizer must have data + FAISS_ASSERT(index->quantizer->ntotal > 0); + + // Right now, we can only handle IndexFlat or derived classes + auto qFlat = dynamic_cast(index->quantizer); + FAISS_THROW_IF_NOT_MSG(qFlat, + "Only IndexFlat is supported for the coarse quantizer " + "for copying from an IndexIVF into a GpuIndexIVF"); + + quantizer->copyFrom(qFlat); +} + +void +GpuIndexIVF::copyFrom(faiss::IndexIVF* index, gpu::GpuIndexFlat *&qt, int64_t mode) { + DeviceScope scope(device_); + + this->d = index->d; + this->metric_type = index->metric_type; + + FAISS_ASSERT(index->nlist > 0); + FAISS_THROW_IF_NOT_FMT(index->nlist <= + (faiss::Index::idx_t) std::numeric_limits::max(), + "GPU index only supports %zu inverted lists", + (size_t) std::numeric_limits::max()); + nlist = index->nlist; + + FAISS_THROW_IF_NOT_FMT(index->nprobe > 0 && + index->nprobe <= getMaxKSelection(), + "GPU index only supports nprobe <= %zu; passed %zu", + (size_t) getMaxKSelection(), + index->nprobe); + nprobe = index->nprobe; + + // The metric type may have changed as well, so we might have to + // change our quantizer + delete quantizer; + quantizer = nullptr; + + // Construct an empty quantizer + GpuIndexFlatConfig config = ivfConfig_.flatConfig; + // FIXME: inherit our same device + config.device = device_; + config.storeInCpu = true; + + if(qt == nullptr) { + if (index->metric_type == faiss::METRIC_L2) { + // FIXME: 2 different float16 options? + quantizer = new GpuIndexFlatL2(resources_, this->d, config); + } else if (index->metric_type == faiss::METRIC_INNER_PRODUCT) { + // FIXME: 2 different float16 options? + quantizer = new GpuIndexFlatIP(resources_, this->d, config); + } else { + // unknown metric type + FAISS_ASSERT(false); + } + } + + if (!index->is_trained) { + this->is_trained = false; + this->ntotal = 0; + return; + } + + // Otherwise, we can populate ourselves from the other index + this->is_trained = true; + + // restore quantizer from backup ptr + index->restore_quantizer(); + + // ntotal can exceed max int, but the number of vectors per inverted + // list cannot exceed this. We check this in the subclasses. + this->ntotal = index->ntotal; + + // Since we're trained, the quantizer must have data + FAISS_ASSERT(index->quantizer->ntotal > 0); + + if(qt == nullptr) { + // Right now, we can only handle IndexFlat or derived classes + auto qFlat = dynamic_cast(index->quantizer); + FAISS_THROW_IF_NOT_MSG(qFlat, + "Only IndexFlat is supported for the coarse quantizer " + "for copying from an IndexIVF into a GpuIndexIVF"); + quantizer->copyFrom(qFlat); + qt = quantizer; + } else { + quantizer = qt; + } + remove_quantizer = 0; +} + +void +GpuIndexIVF::copyTo(faiss::IndexIVF* index) const { + DeviceScope scope(device_); + + // + // Index information + // + GpuIndex::copyTo(index); + + // + // IndexIVF information + // + index->nlist = nlist; + index->nprobe = nprobe; + + // Construct and copy the appropriate quantizer + faiss::IndexFlat* q = nullptr; + + if (this->metric_type == faiss::METRIC_L2) { + q = new faiss::IndexFlatL2(this->d); + + } else if (this->metric_type == faiss::METRIC_INNER_PRODUCT) { + q = new faiss::IndexFlatIP(this->d); + + } else { + // we should have one of the above metrics + FAISS_ASSERT(false); + } + + FAISS_ASSERT(quantizer); + quantizer->copyTo(q); + + if (index->own_fields) { + delete index->quantizer; + } + + index->quantizer = q; + index->quantizer_trains_alone = 0; + index->own_fields = true; + index->cp = this->cp; + index->make_direct_map(false); +} + +int +GpuIndexIVF::getNumLists() const { + return nlist; +} + +void +GpuIndexIVF::setNumProbes(int nprobe) { + FAISS_THROW_IF_NOT_FMT(nprobe > 0 && nprobe <= getMaxKSelection(), + "GPU index only supports nprobe <= %d; passed %d", + getMaxKSelection(), + nprobe); + this->nprobe = nprobe; +} + +int +GpuIndexIVF::getNumProbes() const { + return nprobe; +} + +bool +GpuIndexIVF::addImplRequiresIDs_() const { + // All IVF indices have storage for IDs + return true; +} + +void +GpuIndexIVF::trainQuantizer_(faiss::Index::idx_t n, const float* x) { + if (n == 0) { + // nothing to do + return; + } + + if (quantizer->is_trained && (quantizer->ntotal == nlist)) { + if (this->verbose) { + printf ("IVF quantizer does not need training.\n"); + } + + return; + } + + if (this->verbose) { + printf ("Training IVF quantizer on %ld vectors in %dD\n", n, d); + } + + DeviceScope scope(device_); + + // leverage the CPU-side k-means code, which works for the GPU + // flat index as well + quantizer->reset(); + Clustering clus(this->d, nlist, this->cp); + clus.verbose = verbose; + clus.train(n, x, *quantizer); + quantizer->is_trained = true; + + FAISS_ASSERT(quantizer->ntotal == nlist); +} + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/GpuIndexIVF.h b/core/src/index/thirdparty/faiss/gpu/GpuIndexIVF.h new file mode 100644 index 0000000000..bc0dddc9a6 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/GpuIndexIVF.h @@ -0,0 +1,94 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#pragma once + +#include +#include +#include +#include + +namespace faiss { struct IndexIVF; } + +namespace faiss { namespace gpu { + +class GpuIndexFlat; +class GpuResources; + +struct GpuIndexIVFConfig : public GpuIndexConfig { + inline GpuIndexIVFConfig() + : indicesOptions(INDICES_64_BIT) { + } + + /// Index storage options for the GPU + IndicesOptions indicesOptions; + + /// Configuration for the coarse quantizer object + GpuIndexFlatConfig flatConfig; +}; + +class GpuIndexIVF : public GpuIndex { + public: + GpuIndexIVF(GpuResources* resources, + int dims, + faiss::MetricType metric, + float metricArg, + int nlist, + GpuIndexIVFConfig config = GpuIndexIVFConfig()); + + ~GpuIndexIVF() override; + + private: + /// Shared initialization functions + void init_(); + + public: + /// Copy what we need from the CPU equivalent + void copyFrom(const faiss::IndexIVF* index); + + void copyFrom(faiss::IndexIVF* index, gpu::GpuIndexFlat *&qt, int64_t mode); + + /// Copy what we have to the CPU equivalent + void copyTo(faiss::IndexIVF* index) const; + + /// Returns the number of inverted lists we're managing + int getNumLists() const; + + /// Return the quantizer we're using + GpuIndexFlat* getQuantizer(); + + /// Sets the number of list probes per query + void setNumProbes(int nprobe); + + /// Returns our current number of list probes per query + int getNumProbes() const; + + protected: + bool addImplRequiresIDs_() const override; + void trainQuantizer_(faiss::Index::idx_t n, const float* x); + + public: + /// Exposing this like the CPU version for manipulation + ClusteringParameters cp; + + /// Exposing this like the CPU version for query + int nlist; + + /// Exposing this like the CPU version for manipulation + int nprobe; + + /// Exposeing this like the CPU version for query + GpuIndexFlat* quantizer; + + int remove_quantizer = 1; + + protected: + GpuIndexIVFConfig ivfConfig_; +}; + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/GpuIndexIVFFlat.cu b/core/src/index/thirdparty/faiss/gpu/GpuIndexIVFFlat.cu new file mode 100644 index 0000000000..938ac989ae --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/GpuIndexIVFFlat.cu @@ -0,0 +1,330 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + + #include + #include + #include + #include + #include + #include + #include + #include + #include + + #include + + namespace faiss { namespace gpu { + + GpuIndexIVFFlat::GpuIndexIVFFlat(GpuResources* resources, + const faiss::IndexIVFFlat* index, + GpuIndexIVFFlatConfig config) : + GpuIndexIVF(resources, + index->d, + index->metric_type, + index->metric_arg, + index->nlist, + config), + ivfFlatConfig_(config), + reserveMemoryVecs_(0), + index_(nullptr) { + + copyFrom(index); + } + + GpuIndexIVFFlat::GpuIndexIVFFlat(GpuResources* resources, + int dims, + int nlist, + faiss::MetricType metric, + GpuIndexIVFFlatConfig config) : + GpuIndexIVF(resources, dims, metric, 0, nlist, config), + ivfFlatConfig_(config), + reserveMemoryVecs_(0), + index_(nullptr) { + + // faiss::Index params + this->is_trained = false; + + // We haven't trained ourselves, so don't construct the IVFFlat + // index yet + } + + GpuIndexIVFFlat::~GpuIndexIVFFlat() { + delete index_; + } + + void + GpuIndexIVFFlat::reserveMemory(size_t numVecs) { + reserveMemoryVecs_ = numVecs; + if (index_) { + DeviceScope scope(device_); + index_->reserveMemory(numVecs); + } + } + + void + GpuIndexIVFFlat::copyFrom(const faiss::IndexIVFFlat* index) { + DeviceScope scope(device_); + + GpuIndexIVF::copyFrom(index); + + // Clear out our old data + delete index_; + index_ = nullptr; + + // The other index might not be trained + if (!index->is_trained) { + FAISS_ASSERT(!is_trained); + return; + } + + // Otherwise, we can populate ourselves from the other index + FAISS_ASSERT(is_trained); + + // Copy our lists as well + index_ = new IVFFlat(resources_, + quantizer->getGpuData(), + index->metric_type, + index->metric_arg, + false, // no residual + nullptr, // no scalar quantizer + ivfFlatConfig_.indicesOptions, + memorySpace_); + InvertedLists *ivf = index->invlists; + + if (ReadOnlyArrayInvertedLists* rol = dynamic_cast(ivf)) { + index_->copyCodeVectorsFromCpu((const float* )(rol->pin_readonly_codes->data), + (const long *)(rol->pin_readonly_ids->data), rol->readonly_length); + /* double t0 = getmillisecs(); */ + /* std::cout << "Readonly Takes " << getmillisecs() - t0 << " ms" << std::endl; */ + } else { + for (size_t i = 0; i < ivf->nlist; ++i) { + auto numVecs = ivf->list_size(i); + + // GPU index can only support max int entries per list + FAISS_THROW_IF_NOT_FMT(numVecs <= + (size_t) std::numeric_limits::max(), + "GPU inverted list can only support " + "%zu entries; %zu found", + (size_t) std::numeric_limits::max(), + numVecs); + + index_->addCodeVectorsFromCpu(i, + (const unsigned char*)(ivf->get_codes(i)), + ivf->get_ids(i), + numVecs); + } + } + } + + void + GpuIndexIVFFlat::copyFromWithoutCodes(const faiss::IndexIVFFlat* index, const uint8_t* arranged_data) { + DeviceScope scope(device_); + + GpuIndexIVF::copyFrom(index); + + // Clear out our old data + delete index_; + index_ = nullptr; + + // The other index might not be trained + if (!index->is_trained) { + FAISS_ASSERT(!is_trained); + return; + } + + // Otherwise, we can populate ourselves from the other index + FAISS_ASSERT(is_trained); + + // Copy our lists as well + index_ = new IVFFlat(resources_, + quantizer->getGpuData(), + index->metric_type, + index->metric_arg, + false, // no residual + nullptr, // no scalar quantizer + ivfFlatConfig_.indicesOptions, + memorySpace_); + InvertedLists *ivf = index->invlists; + + if (ReadOnlyArrayInvertedLists* rol = dynamic_cast(ivf)) { + index_->copyCodeVectorsFromCpu((const float *) arranged_data, + (const long *)(rol->pin_readonly_ids->data), rol->readonly_length); + } else { + // should not happen + } + } + + void + GpuIndexIVFFlat::copyTo(faiss::IndexIVFFlat* index) const { + DeviceScope scope(device_); + + // We must have the indices in order to copy to ourselves + FAISS_THROW_IF_NOT_MSG(ivfFlatConfig_.indicesOptions != INDICES_IVF, + "Cannot copy to CPU as GPU index doesn't retain " + "indices (INDICES_IVF)"); + + GpuIndexIVF::copyTo(index); + index->code_size = this->d * sizeof(float); + + InvertedLists *ivf = new ArrayInvertedLists(nlist, index->code_size); + index->replace_invlists(ivf, true); + + // Copy the inverted lists + if (index_) { + for (int i = 0; i < nlist; ++i) { + auto listIndices = index_->getListIndices(i); + auto listData = index_->getListVectors(i); + + ivf->add_entries(i, + listIndices.size(), + listIndices.data(), + (const uint8_t*) listData.data()); + } + } + } + + void + GpuIndexIVFFlat::copyToWithoutCodes(faiss::IndexIVFFlat* index) const { + DeviceScope scope(device_); + + // We must have the indices in order to copy to ourselves + FAISS_THROW_IF_NOT_MSG(ivfFlatConfig_.indicesOptions != INDICES_IVF, + "Cannot copy to CPU as GPU index doesn't retain " + "indices (INDICES_IVF)"); + + GpuIndexIVF::copyTo(index); + index->code_size = this->d * sizeof(float); + + InvertedLists *ivf = new ArrayInvertedLists(nlist, index->code_size); + index->replace_invlists(ivf, true); + + // Copy the inverted lists + if (index_) { + for (int i = 0; i < nlist; ++i) { + auto listIndices = index_->getListIndices(i); + + ivf->add_entries_without_codes(i, + listIndices.size(), + listIndices.data()); + } + } + } + + size_t + GpuIndexIVFFlat::reclaimMemory() { + if (index_) { + DeviceScope scope(device_); + + return index_->reclaimMemory(); + } + + return 0; + } + + void + GpuIndexIVFFlat::reset() { + if (index_) { + DeviceScope scope(device_); + + index_->reset(); + this->ntotal = 0; + } else { + FAISS_ASSERT(this->ntotal == 0); + } + } + + void + GpuIndexIVFFlat::train(Index::idx_t n, const float* x) { + DeviceScope scope(device_); + + if (this->is_trained) { + FAISS_ASSERT(quantizer->is_trained); + FAISS_ASSERT(quantizer->ntotal == nlist); + FAISS_ASSERT(index_); + return; + } + + FAISS_ASSERT(!index_); + + trainQuantizer_(n, x); + + // The quantizer is now trained; construct the IVF index + index_ = new IVFFlat(resources_, + quantizer->getGpuData(), + this->metric_type, + this->metric_arg, + false, // no residual + nullptr, // no scalar quantizer + ivfFlatConfig_.indicesOptions, + memorySpace_); + + if (reserveMemoryVecs_) { + index_->reserveMemory(reserveMemoryVecs_); + } + + this->is_trained = true; + } + + void + GpuIndexIVFFlat::addImpl_(int n, + const float* x, + const Index::idx_t* xids) { + // Device is already set in GpuIndex::add + FAISS_ASSERT(index_); + FAISS_ASSERT(n > 0); + + auto stream = resources_->getDefaultStream(device_); + + // Data is already resident on the GPU + Tensor data(const_cast(x), {n, (int) this->d}); + + static_assert(sizeof(long) == sizeof(Index::idx_t), "size mismatch"); + Tensor labels(const_cast(xids), {n}); + + // Not all vectors may be able to be added (some may contain NaNs etc) + index_->classifyAndAddVectors(data, labels); + + // but keep the ntotal based on the total number of vectors that we attempted + // to add + ntotal += n; + } + + void + GpuIndexIVFFlat::searchImpl_(int n, + const float* x, + int k, + float* distances, + Index::idx_t* labels, + ConcurrentBitsetPtr bitset) const { + // Device is already set in GpuIndex::search + FAISS_ASSERT(index_); + FAISS_ASSERT(n > 0); + + auto stream = resources_->getDefaultStream(device_); + + // Data is already resident on the GPU + Tensor queries(const_cast(x), {n, (int) this->d}); + Tensor outDistances(distances, {n, k}); + + static_assert(sizeof(long) == sizeof(Index::idx_t), "size mismatch"); + Tensor outLabels(const_cast(labels), {n, k}); + + if (!bitset) { + auto bitsetDevice = toDevice(resources_, device_, nullptr, stream, {0}); + index_->query(queries, bitsetDevice, nprobe, k, outDistances, outLabels); + } else { + auto bitsetDevice = toDevice(resources_, device_, + const_cast(bitset->data()), stream, + {(int) bitset->size()}); + index_->query(queries, bitsetDevice, nprobe, k, outDistances, outLabels); + } + } + + + } } // namespace + \ No newline at end of file diff --git a/core/src/index/thirdparty/faiss/gpu/GpuIndexIVFFlat.h b/core/src/index/thirdparty/faiss/gpu/GpuIndexIVFFlat.h new file mode 100644 index 0000000000..e0b79aaee1 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/GpuIndexIVFFlat.h @@ -0,0 +1,91 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#pragma once + +#include +#include + +namespace faiss { struct IndexIVFFlat; } + +namespace faiss { namespace gpu { + +class IVFFlat; +class GpuIndexFlat; + +struct GpuIndexIVFFlatConfig : public GpuIndexIVFConfig { +}; + +/// Wrapper around the GPU implementation that looks like +/// faiss::IndexIVFFlat +class GpuIndexIVFFlat : public GpuIndexIVF { + public: + /// Construct from a pre-existing faiss::IndexIVFFlat instance, copying + /// data over to the given GPU, if the input index is trained. + GpuIndexIVFFlat(GpuResources* resources, + const faiss::IndexIVFFlat* index, + GpuIndexIVFFlatConfig config = GpuIndexIVFFlatConfig()); + + /// Constructs a new instance with an empty flat quantizer; the user + /// provides the number of lists desired. + GpuIndexIVFFlat(GpuResources* resources, + int dims, + int nlist, + faiss::MetricType metric, + GpuIndexIVFFlatConfig config = GpuIndexIVFFlatConfig()); + + ~GpuIndexIVFFlat() override; + + /// Reserve GPU memory in our inverted lists for this number of vectors + void reserveMemory(size_t numVecs); + + /// Initialize ourselves from the given CPU index; will overwrite + /// all data in ourselves + void copyFrom(const faiss::IndexIVFFlat* index); + + void copyFromWithoutCodes(const faiss::IndexIVFFlat* index, const uint8_t* arranged_data); + + /// Copy ourselves to the given CPU index; will overwrite all data + /// in the index instance + void copyTo(faiss::IndexIVFFlat* index) const; + + void copyToWithoutCodes(faiss::IndexIVFFlat* index) const; + + /// After adding vectors, one can call this to reclaim device memory + /// to exactly the amount needed. Returns space reclaimed in bytes + size_t reclaimMemory(); + + void reset() override; + + void train(Index::idx_t n, const float* x) override; + + protected: + /// Called from GpuIndex for add/add_with_ids + void addImpl_(int n, + const float* x, + const Index::idx_t* ids) override; + + /// Called from GpuIndex for search + void searchImpl_(int n, + const float* x, + int k, + float* distances, + Index::idx_t* labels, + ConcurrentBitsetPtr bitset = nullptr) const override; + + private: + GpuIndexIVFFlatConfig ivfFlatConfig_; + + /// Desired inverted list memory reservation + size_t reserveMemoryVecs_; + + /// Instance that we own; contains the inverted list + IVFFlat* index_; +}; + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/GpuIndexIVFPQ.cu b/core/src/index/thirdparty/faiss/gpu/GpuIndexIVFPQ.cu new file mode 100644 index 0000000000..d6095a58e8 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/GpuIndexIVFPQ.cu @@ -0,0 +1,471 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace faiss { namespace gpu { + +GpuIndexIVFPQ::GpuIndexIVFPQ(GpuResources* resources, + const faiss::IndexIVFPQ* index, + GpuIndexIVFPQConfig config) : + GpuIndexIVF(resources, + index->d, + index->metric_type, + index->metric_arg, + index->nlist, + config), + ivfpqConfig_(config), + subQuantizers_(0), + bitsPerCode_(0), + reserveMemoryVecs_(0), + index_(nullptr) { +#ifndef FAISS_USE_FLOAT16 + FAISS_ASSERT(!ivfpqConfig_.useFloat16LookupTables); +#endif + + copyFrom(index); +} + +GpuIndexIVFPQ::GpuIndexIVFPQ(GpuResources* resources, + int dims, + int nlist, + int subQuantizers, + int bitsPerCode, + faiss::MetricType metric, + GpuIndexIVFPQConfig config) : + GpuIndexIVF(resources, + dims, + metric, + 0, + nlist, + config), + ivfpqConfig_(config), + subQuantizers_(subQuantizers), + bitsPerCode_(bitsPerCode), + reserveMemoryVecs_(0), + index_(nullptr) { +#ifndef FAISS_USE_FLOAT16 + FAISS_ASSERT(!config.useFloat16LookupTables); +#endif + + verifySettings_(); + + // We haven't trained ourselves, so don't construct the PQ index yet + this->is_trained = false; +} + +GpuIndexIVFPQ::~GpuIndexIVFPQ() { + delete index_; +} + +void +GpuIndexIVFPQ::copyFrom(const faiss::IndexIVFPQ* index) { + DeviceScope scope(device_); + + GpuIndexIVF::copyFrom(index); + + // Clear out our old data + delete index_; + index_ = nullptr; + + subQuantizers_ = index->pq.M; + bitsPerCode_ = index->pq.nbits; + + // We only support this + FAISS_THROW_IF_NOT_MSG(index->pq.nbits == 8, + "GPU: only pq.nbits == 8 is supported"); + FAISS_THROW_IF_NOT_MSG(index->by_residual, + "GPU: only by_residual = true is supported"); + FAISS_THROW_IF_NOT_MSG(index->polysemous_ht == 0, + "GPU: polysemous codes not supported"); + + verifySettings_(); + + // The other index might not be trained + if (!index->is_trained) { + // copied in GpuIndex::copyFrom + FAISS_ASSERT(!is_trained); + return; + } + + // Copy our lists as well + // The product quantizer must have data in it + FAISS_ASSERT(index->pq.centroids.size() > 0); + index_ = new IVFPQ(resources_, + index->metric_type, + index->metric_arg, + quantizer->getGpuData(), + subQuantizers_, + bitsPerCode_, + (float*) index->pq.centroids.data(), + ivfpqConfig_.indicesOptions, + ivfpqConfig_.useFloat16LookupTables, + memorySpace_); + // Doesn't make sense to reserve memory here + index_->setPrecomputedCodes(ivfpqConfig_.usePrecomputedTables); + + // Copy database vectors, if any + const InvertedLists *ivf = index->invlists; + size_t nlist = ivf ? ivf->nlist : 0; + for (size_t i = 0; i < nlist; ++i) { + size_t list_size = ivf->list_size(i); + + // GPU index can only support max int entries per list + FAISS_THROW_IF_NOT_FMT(list_size <= + (size_t) std::numeric_limits::max(), + "GPU inverted list can only support " + "%zu entries; %zu found", + (size_t) std::numeric_limits::max(), + list_size); + + index_->addCodeVectorsFromCpu( + i, ivf->get_codes(i), ivf->get_ids(i), list_size); + } +} + +void +GpuIndexIVFPQ::copyTo(faiss::IndexIVFPQ* index) const { + DeviceScope scope(device_); + + // We must have the indices in order to copy to ourselves + FAISS_THROW_IF_NOT_MSG(ivfpqConfig_.indicesOptions != INDICES_IVF, + "Cannot copy to CPU as GPU index doesn't retain " + "indices (INDICES_IVF)"); + + GpuIndexIVF::copyTo(index); + + // + // IndexIVFPQ information + // + index->by_residual = true; + index->use_precomputed_table = 0; + index->code_size = subQuantizers_; + index->pq = faiss::ProductQuantizer(this->d, subQuantizers_, bitsPerCode_); + + index->do_polysemous_training = false; + index->polysemous_training = nullptr; + + index->scan_table_threshold = 0; + index->max_codes = 0; + index->polysemous_ht = 0; + index->precomputed_table.clear(); + + InvertedLists *ivf = new ArrayInvertedLists( + nlist, index->code_size); + + index->replace_invlists(ivf, true); + + if (index_) { + // Copy the inverted lists + for (int i = 0; i < nlist; ++i) { + auto ids = getListIndices(i); + auto codes = getListCodes(i); + index->invlists->add_entries (i, ids.size(), ids.data(), codes.data()); + } + + // Copy PQ centroids + auto devPQCentroids = index_->getPQCentroids(); + index->pq.centroids.resize(devPQCentroids.numElements()); + + fromDevice(devPQCentroids, + index->pq.centroids.data(), + resources_->getDefaultStream(device_)); + + if (ivfpqConfig_.usePrecomputedTables) { + index->precompute_table(); + } + } +} + +void +GpuIndexIVFPQ::reserveMemory(size_t numVecs) { + reserveMemoryVecs_ = numVecs; + if (index_) { + DeviceScope scope(device_); + index_->reserveMemory(numVecs); + } +} + +void +GpuIndexIVFPQ::setPrecomputedCodes(bool enable) { + ivfpqConfig_.usePrecomputedTables = enable; + if (index_) { + DeviceScope scope(device_); + index_->setPrecomputedCodes(enable); + } + + verifySettings_(); +} + +bool +GpuIndexIVFPQ::getPrecomputedCodes() const { + return ivfpqConfig_.usePrecomputedTables; +} + +int +GpuIndexIVFPQ::getNumSubQuantizers() const { + return subQuantizers_; +} + +int +GpuIndexIVFPQ::getBitsPerCode() const { + return bitsPerCode_; +} + +int +GpuIndexIVFPQ::getCentroidsPerSubQuantizer() const { + return utils::pow2(bitsPerCode_); +} + +size_t +GpuIndexIVFPQ::reclaimMemory() { + if (index_) { + DeviceScope scope(device_); + return index_->reclaimMemory(); + } + + return 0; +} + +void +GpuIndexIVFPQ::reset() { + if (index_) { + DeviceScope scope(device_); + + index_->reset(); + this->ntotal = 0; + } else { + FAISS_ASSERT(this->ntotal == 0); + } +} + +void +GpuIndexIVFPQ::trainResidualQuantizer_(Index::idx_t n, const float* x) { + // Code largely copied from faiss::IndexIVFPQ + // FIXME: GPUize more of this + n = std::min(n, (Index::idx_t) (1 << bitsPerCode_) * 64); + + if (this->verbose) { + printf("computing residuals\n"); + } + + std::vector assign(n); + quantizer->assign (n, x, assign.data()); + + std::vector residuals(n * d); + + // FIXME jhj convert to _n version + for (idx_t i = 0; i < n; i++) { + quantizer->compute_residual(x + i * d, &residuals[i * d], assign[i]); + } + + if (this->verbose) { + printf("training %d x %d product quantizer on %ld vectors in %dD\n", + subQuantizers_, getCentroidsPerSubQuantizer(), n, this->d); + } + + // Just use the CPU product quantizer to determine sub-centroids + faiss::ProductQuantizer pq(this->d, subQuantizers_, bitsPerCode_); + pq.verbose = this->verbose; + pq.train(n, residuals.data()); + + index_ = new IVFPQ(resources_, + metric_type, + metric_arg, + quantizer->getGpuData(), + subQuantizers_, + bitsPerCode_, + pq.centroids.data(), + ivfpqConfig_.indicesOptions, + ivfpqConfig_.useFloat16LookupTables, + memorySpace_); + if (reserveMemoryVecs_) { + index_->reserveMemory(reserveMemoryVecs_); + } + + index_->setPrecomputedCodes(ivfpqConfig_.usePrecomputedTables); +} + +void +GpuIndexIVFPQ::train(Index::idx_t n, const float* x) { + DeviceScope scope(device_); + + if (this->is_trained) { + FAISS_ASSERT(quantizer->is_trained); + FAISS_ASSERT(quantizer->ntotal == nlist); + FAISS_ASSERT(index_); + return; + } + + FAISS_ASSERT(!index_); + + // FIXME: GPUize more of this + // First, make sure that the data is resident on the CPU, if it is not on the + // CPU, as we depend upon parts of the CPU code + auto hostData = toHost((float*) x, + resources_->getDefaultStream(device_), + {(int) n, (int) this->d}); + + trainQuantizer_(n, hostData.data()); + trainResidualQuantizer_(n, hostData.data()); + + FAISS_ASSERT(index_); + + this->is_trained = true; +} + +void +GpuIndexIVFPQ::addImpl_(int n, + const float* x, + const Index::idx_t* xids) { + // Device is already set in GpuIndex::add + FAISS_ASSERT(index_); + FAISS_ASSERT(n > 0); + + auto stream = resources_->getDefaultStream(device_); + + // Data is already resident on the GPU + Tensor data(const_cast(x), {n, (int) this->d}); + + static_assert(sizeof(long) == sizeof(Index::idx_t), "size mismatch"); + Tensor labels(const_cast(xids), {n}); + + // Not all vectors may be able to be added (some may contain NaNs etc) + index_->classifyAndAddVectors(data, labels); + + // but keep the ntotal based on the total number of vectors that we attempted + // to add + ntotal += n; +} + +void +GpuIndexIVFPQ::searchImpl_(int n, + const float* x, + int k, + float* distances, + Index::idx_t* labels, + ConcurrentBitsetPtr bitset) const { + // Device is already set in GpuIndex::search + FAISS_ASSERT(index_); + FAISS_ASSERT(n > 0); + + auto stream = resources_->getDefaultStream(device_); + + // Data is already resident on the GPU + Tensor queries(const_cast(x), {n, (int) this->d}); + Tensor outDistances(distances, {n, k}); + + static_assert(sizeof(long) == sizeof(Index::idx_t), "size mismatch"); + Tensor outLabels(const_cast(labels), {n, k}); + + if (!bitset) { + auto bitsetDevice = toDevice(resources_, device_, nullptr, stream, {0}); + index_->query(queries, bitsetDevice, nprobe, k, outDistances, outLabels); + } else { + auto bitsetDevice = toDevice(resources_, device_, + const_cast(bitset->data()), stream, + {(int) bitset->size()}); + index_->query(queries, bitsetDevice, nprobe, k, outDistances, outLabels); + } +} + +int +GpuIndexIVFPQ::getListLength(int listId) const { + FAISS_ASSERT(index_); + return index_->getListLength(listId); +} + +std::vector +GpuIndexIVFPQ::getListCodes(int listId) const { + FAISS_ASSERT(index_); + DeviceScope scope(device_); + + return index_->getListCodes(listId); +} + +std::vector +GpuIndexIVFPQ::getListIndices(int listId) const { + FAISS_ASSERT(index_); + DeviceScope scope(device_); + + return index_->getListIndices(listId); +} + +void +GpuIndexIVFPQ::verifySettings_() const { + // Our implementation has these restrictions: + + // Must have some number of lists + FAISS_THROW_IF_NOT_MSG(nlist > 0, "nlist must be >0"); + + // up to a single byte per code + FAISS_THROW_IF_NOT_FMT(bitsPerCode_ <= 8, + "Bits per code must be <= 8 (passed %d)", bitsPerCode_); + + // Sub-quantizers must evenly divide dimensions available + FAISS_THROW_IF_NOT_FMT(this->d % subQuantizers_ == 0, + "Number of sub-quantizers (%d) must be an " + "even divisor of the number of dimensions (%d)", + subQuantizers_, this->d); + + // The number of bytes per encoded vector must be one we support + FAISS_THROW_IF_NOT_FMT(IVFPQ::isSupportedPQCodeLength(subQuantizers_), + "Number of bytes per encoded vector / sub-quantizers (%d) " + "is not supported", + subQuantizers_); + + // We must have enough shared memory on the current device to store + // our lookup distances + int lookupTableSize = sizeof(float); +#ifdef FAISS_USE_FLOAT16 + if (ivfpqConfig_.useFloat16LookupTables) { + lookupTableSize = sizeof(half); + } +#endif + + // 64 bytes per code is only supported with usage of float16, at 2^8 + // codes per subquantizer + size_t requiredSmemSize = + lookupTableSize * subQuantizers_ * utils::pow2(bitsPerCode_); + size_t smemPerBlock = getMaxSharedMemPerBlock(device_); + + FAISS_THROW_IF_NOT_FMT(requiredSmemSize + <= getMaxSharedMemPerBlock(device_), + "Device %d has %zu bytes of shared memory, while " + "%d bits per code and %d sub-quantizers requires %zu " + "bytes. Consider useFloat16LookupTables and/or " + "reduce parameters", + device_, smemPerBlock, bitsPerCode_, subQuantizers_, + requiredSmemSize); + + // If precomputed codes are disabled, we have an extra limitation in + // terms of the number of dimensions per subquantizer + FAISS_THROW_IF_NOT_FMT(ivfpqConfig_.usePrecomputedTables || + IVFPQ::isSupportedNoPrecomputedSubDimSize( + this->d / subQuantizers_), + "Number of dimensions per sub-quantizer (%d) " + "is not currently supported without precomputed codes. " + "Only 1, 2, 3, 4, 6, 8, 10, 12, 16, 20, 24, 28, 32 dims " + "per sub-quantizer are currently supported with no " + "precomputed codes. " + "Precomputed codes supports any number of dimensions, but " + "will involve memory overheads.", + this->d / subQuantizers_); +} + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/GpuIndexIVFPQ.h b/core/src/index/thirdparty/faiss/gpu/GpuIndexIVFPQ.h new file mode 100644 index 0000000000..54b65980e8 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/GpuIndexIVFPQ.h @@ -0,0 +1,145 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#pragma once + +#include +#include +#include + +namespace faiss { struct IndexIVFPQ; } + +namespace faiss { namespace gpu { + +class GpuIndexFlat; +class IVFPQ; + +struct GpuIndexIVFPQConfig : public GpuIndexIVFConfig { + inline GpuIndexIVFPQConfig() + : useFloat16LookupTables(false), + usePrecomputedTables(false) { + } + + /// Whether or not float16 residual distance tables are used in the + /// list scanning kernels. When subQuantizers * 2^bitsPerCode > + /// 16384, this is required. + bool useFloat16LookupTables; + + /// Whether or not we enable the precomputed table option for + /// search, which can substantially increase the memory requirement. + bool usePrecomputedTables; +}; + +/// IVFPQ index for the GPU +class GpuIndexIVFPQ : public GpuIndexIVF { + public: + /// Construct from a pre-existing faiss::IndexIVFPQ instance, copying + /// data over to the given GPU, if the input index is trained. + GpuIndexIVFPQ(GpuResources* resources, + const faiss::IndexIVFPQ* index, + GpuIndexIVFPQConfig config = GpuIndexIVFPQConfig()); + + /// Construct an empty index + GpuIndexIVFPQ(GpuResources* resources, + int dims, + int nlist, + int subQuantizers, + int bitsPerCode, + faiss::MetricType metric, + GpuIndexIVFPQConfig config = GpuIndexIVFPQConfig()); + + ~GpuIndexIVFPQ() override; + + /// Reserve space on the GPU for the inverted lists for `num` + /// vectors, assumed equally distributed among + + /// Initialize ourselves from the given CPU index; will overwrite + /// all data in ourselves + void copyFrom(const faiss::IndexIVFPQ* index); + + /// Copy ourselves to the given CPU index; will overwrite all data + /// in the index instance + void copyTo(faiss::IndexIVFPQ* index) const; + + /// Reserve GPU memory in our inverted lists for this number of vectors + void reserveMemory(size_t numVecs); + + /// Enable or disable pre-computed codes + void setPrecomputedCodes(bool enable); + + /// Are pre-computed codes enabled? + bool getPrecomputedCodes() const; + + /// Return the number of sub-quantizers we are using + int getNumSubQuantizers() const; + + /// Return the number of bits per PQ code + int getBitsPerCode() const; + + /// Return the number of centroids per PQ code (2^bits per code) + int getCentroidsPerSubQuantizer() const; + + /// After adding vectors, one can call this to reclaim device memory + /// to exactly the amount needed. Returns space reclaimed in bytes + size_t reclaimMemory(); + + /// Clears out all inverted lists, but retains the coarse and + /// product centroid information + void reset() override; + + void train(Index::idx_t n, const float* x) override; + + /// For debugging purposes, return the list length of a particular + /// list + int getListLength(int listId) const; + + /// For debugging purposes, return the list codes of a particular + /// list + std::vector getListCodes(int listId) const; + + /// For debugging purposes, return the list indices of a particular + /// list + std::vector getListIndices(int listId) const; + + protected: + /// Called from GpuIndex for add/add_with_ids + void addImpl_(int n, + const float* x, + const Index::idx_t* ids) override; + + /// Called from GpuIndex for search + void searchImpl_(int n, + const float* x, + int k, + float* distances, + Index::idx_t* labels, + ConcurrentBitsetPtr bitset = nullptr) const override; + + private: + void verifySettings_() const; + + void trainResidualQuantizer_(Index::idx_t n, const float* x); + + private: + GpuIndexIVFPQConfig ivfpqConfig_; + + /// Number of sub-quantizers per encoded vector + int subQuantizers_; + + /// Bits per sub-quantizer code + int bitsPerCode_; + + /// Desired inverted list memory reservation + size_t reserveMemoryVecs_; + + /// The product quantizer instance that we own; contains the + /// inverted lists + IVFPQ* index_; +}; + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/GpuIndexIVFSQHybrid.cu b/core/src/index/thirdparty/faiss/gpu/GpuIndexIVFSQHybrid.cu new file mode 100644 index 0000000000..27da743cb4 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/GpuIndexIVFSQHybrid.cu @@ -0,0 +1,357 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace faiss { namespace gpu { + +GpuIndexIVFSQHybrid::GpuIndexIVFSQHybrid( + GpuResources* resources, + faiss::IndexIVFSQHybrid* index, + GpuIndexIVFSQHybridConfig config) : + GpuIndexIVF(resources, + index->d, + index->metric_type, + index->metric_arg, + index->nlist, + config), + ivfSQConfig_(config), + sq(index->sq), + by_residual(index->by_residual), + reserveMemoryVecs_(0), + index_(nullptr) { + gpu::GpuIndexFlat *quantizer = nullptr; + copyFrom(index, quantizer, 0); + + FAISS_THROW_IF_NOT_MSG(isSQSupported(sq.qtype), + "Unsupported QuantizerType on GPU"); +} + +GpuIndexIVFSQHybrid::GpuIndexIVFSQHybrid( + GpuResources* resources, + int dims, + int nlist, + faiss::QuantizerType qtype, + faiss::MetricType metric, + bool encodeResidual, + GpuIndexIVFSQHybridConfig config) : + GpuIndexIVF(resources, dims, metric, 0, nlist, config), + ivfSQConfig_(config), + sq(dims, qtype), + by_residual(encodeResidual), + reserveMemoryVecs_(0), + index_(nullptr) { + + // faiss::Index params + this->is_trained = false; + + // We haven't trained ourselves, so don't construct the IVFFlat + // index yet + FAISS_THROW_IF_NOT_MSG(isSQSupported(sq.qtype), + "Unsupported QuantizerType on GPU"); +} + +GpuIndexIVFSQHybrid::~GpuIndexIVFSQHybrid() { + delete index_; +} + +void +GpuIndexIVFSQHybrid::reserveMemory(size_t numVecs) { + reserveMemoryVecs_ = numVecs; + if (index_) { + index_->reserveMemory(numVecs); + } +} + +void +GpuIndexIVFSQHybrid::copyFrom( + const faiss::IndexIVFSQHybrid* index) { + DeviceScope scope(device_); + + // Clear out our old data + delete index_; + index_ = nullptr; + + // Copy what we need from the CPU index + GpuIndexIVF::copyFrom(index); + sq = index->sq; + by_residual = index->by_residual; + + // The other index might not be trained, in which case we don't need to copy + // over the lists + if (!index->is_trained) { + return; + } + + // Otherwise, we can populate ourselves from the other index + this->is_trained = true; + + // Copy our lists as well + index_ = new IVFFlat(resources_, + quantizer->getGpuData(), + index->metric_type, + index->metric_arg, + by_residual, + &sq, + ivfSQConfig_.indicesOptions, + memorySpace_); + + InvertedLists* ivf = index->invlists; + if(ReadOnlyArrayInvertedLists* rol = dynamic_cast(ivf)) { + index_->copyCodeVectorsFromCpu((const float* )(rol->pin_readonly_codes->data), + (const long *)(rol->pin_readonly_ids->data), rol->readonly_length); + } else { + for (size_t i = 0; i < ivf->nlist; ++i) { + auto numVecs = ivf->list_size(i); + + // GPU index can only support max int entries per list + FAISS_THROW_IF_NOT_FMT(numVecs <= + (size_t) std::numeric_limits::max(), + "GPU inverted list can only support " + "%zu entries; %zu found", + (size_t) std::numeric_limits::max(), + numVecs); + + index_->addCodeVectorsFromCpu( + i, + (const unsigned char*) ivf->get_codes(i), + ivf->get_ids(i), + numVecs); + } + } +} + +void +GpuIndexIVFSQHybrid::copyFrom( + faiss::IndexIVFSQHybrid* index, + gpu::GpuIndexFlat *&qt, + long mode) { + DeviceScope scope(device_); + + // Clear out our old data + delete index_; + index_ = nullptr; + + GpuIndexIVF::copyFrom(index, qt, mode); + if(mode == 1) { + // Only copy quantizer + return ; + } + + sq = index->sq; + by_residual = index->by_residual; + + // The other index might not be trained, in which case we don't need to copy + // over the lists + if (!index->is_trained) { + return; + } + + // Otherwise, we can populate ourselves from the other index + this->is_trained = true; + + // Copy our lists as well + index_ = new IVFFlat(resources_, + quantizer->getGpuData(), + index->metric_type, + index->metric_arg, + by_residual, + &sq, + ivfSQConfig_.indicesOptions, + memorySpace_); + + InvertedLists* ivf = index->invlists; + if(ReadOnlyArrayInvertedLists* rol = dynamic_cast(ivf)) { + index_->copyCodeVectorsFromCpu((const float* )(rol->pin_readonly_codes->data), + (const long *)(rol->pin_readonly_ids->data), rol->readonly_length); + } else { + for (size_t i = 0; i < ivf->nlist; ++i) { + auto numVecs = ivf->list_size(i); + + // GPU index can only support max int entries per list + FAISS_THROW_IF_NOT_FMT(numVecs <= + (size_t) std::numeric_limits::max(), + "GPU inverted list can only support " + "%zu entries; %zu found", + (size_t) std::numeric_limits::max(), + numVecs); + + index_->addCodeVectorsFromCpu( + i, + (const unsigned char*) ivf->get_codes(i), + ivf->get_ids(i), + numVecs); + } + } +} + +void +GpuIndexIVFSQHybrid::copyTo( + faiss::IndexIVFSQHybrid* index) const { + DeviceScope scope(device_); + + // We must have the indices in order to copy to ourselves + FAISS_THROW_IF_NOT_MSG( + ivfSQConfig_.indicesOptions != INDICES_IVF, + "Cannot copy to CPU as GPU index doesn't retain " + "indices (INDICES_IVF)"); + + GpuIndexIVF::copyTo(index); + index->sq = sq; + index->by_residual = by_residual; + index->code_size = sq.code_size; + + InvertedLists* ivf = new ArrayInvertedLists(nlist, index->code_size); + index->replace_invlists(ivf, true); + + // Copy the inverted lists + if (index_) { + for (int i = 0; i < nlist; ++i) { + auto listIndices = index_->getListIndices(i); + auto listData = index_->getListVectors(i); + + ivf->add_entries(i, + listIndices.size(), + listIndices.data(), + (const uint8_t*) listData.data()); + } + } +} + +size_t +GpuIndexIVFSQHybrid::reclaimMemory() { + if (index_) { + DeviceScope scope(device_); + + return index_->reclaimMemory(); + } + + return 0; +} + +void +GpuIndexIVFSQHybrid::reset() { + if (index_) { + DeviceScope scope(device_); + + index_->reset(); + this->ntotal = 0; + } else { + FAISS_ASSERT(this->ntotal == 0); + } +} + +void +GpuIndexIVFSQHybrid::trainResiduals_(Index::idx_t n, const float* x) { + // The input is already guaranteed to be on the CPU + sq.train_residual(n, x, quantizer, by_residual, verbose); +} + +void +GpuIndexIVFSQHybrid::train(Index::idx_t n, const float* x) { + DeviceScope scope(device_); + + if (this->is_trained) { + FAISS_ASSERT(quantizer->is_trained); + FAISS_ASSERT(quantizer->ntotal == nlist); + FAISS_ASSERT(index_); + return; + } + + FAISS_ASSERT(!index_); + + // FIXME: GPUize more of this + // First, make sure that the data is resident on the CPU, if it is not on the + // CPU, as we depend upon parts of the CPU code + auto hostData = toHost((float*) x, + resources_->getDefaultStream(device_), + {(int) n, (int) this->d}); + + trainQuantizer_(n, hostData.data()); + trainResiduals_(n, hostData.data()); + + // The quantizer is now trained; construct the IVF index + index_ = new IVFFlat(resources_, + quantizer->getGpuData(), + this->metric_type, + this->metric_arg, + by_residual, + &sq, + ivfSQConfig_.indicesOptions, + memorySpace_); + + if (reserveMemoryVecs_) { + index_->reserveMemory(reserveMemoryVecs_); + } + + this->is_trained = true; +} + +void +GpuIndexIVFSQHybrid::addImpl_(int n, + const float* x, + const Index::idx_t* xids) { + // Device is already set in GpuIndex::add + FAISS_ASSERT(index_); + FAISS_ASSERT(n > 0); + + auto stream = resources_->getDefaultStream(device_); + + // Data is already resident on the GPU + Tensor data(const_cast(x), {n, (int) this->d}); + + static_assert(sizeof(long) == sizeof(Index::idx_t), "size mismatch"); + Tensor labels(const_cast(xids), {n}); + + // Not all vectors may be able to be added (some may contain NaNs etc) + index_->classifyAndAddVectors(data, labels); + + // but keep the ntotal based on the total number of vectors that we attempted + // to add + ntotal += n; +} + +void +GpuIndexIVFSQHybrid::searchImpl_(int n, + const float* x, + int k, + float* distances, + Index::idx_t* labels, + ConcurrentBitsetPtr bitset) const { + // Device is already set in GpuIndex::search + FAISS_ASSERT(index_); + FAISS_ASSERT(n > 0); + + auto stream = resources_->getDefaultStream(device_); + + // Data is already resident on the GPU + Tensor queries(const_cast(x), {n, (int) this->d}); + Tensor outDistances(distances, {n, k}); + + static_assert(sizeof(long) == sizeof(Index::idx_t), "size mismatch"); + Tensor outLabels(const_cast(labels), {n, k}); + + if (!bitset) { + auto bitsetDevice = toDevice(resources_, device_, nullptr, stream, {0}); + index_->query(queries, bitsetDevice, nprobe, k, outDistances, outLabels); + } else { + auto bitsetDevice = toDevice(resources_, device_, + const_cast(bitset->data()), stream,{(int) bitset->size()}); + index_->query(queries, bitsetDevice, nprobe, k, outDistances, outLabels); + } +} + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/GpuIndexIVFSQHybrid.h b/core/src/index/thirdparty/faiss/gpu/GpuIndexIVFSQHybrid.h new file mode 100644 index 0000000000..049372c10f --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/GpuIndexIVFSQHybrid.h @@ -0,0 +1,105 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#pragma once + +#include +#include + +namespace faiss { namespace gpu { + +class IVFFlat; +class GpuIndexFlat; + +struct GpuIndexIVFSQHybridConfig : public GpuIndexIVFConfig { +}; + +/// Wrapper around the GPU implementation that looks like +/// faiss::IndexIVFSQHybrid +class GpuIndexIVFSQHybrid : public GpuIndexIVF { + public: + /// Construct from a pre-existing faiss::IndexIVFSQHybrid instance, + /// copying data over to the given GPU, if the input index is trained. + GpuIndexIVFSQHybrid( + GpuResources* resources, + faiss::IndexIVFSQHybrid* index, + GpuIndexIVFSQHybridConfig config = + GpuIndexIVFSQHybridConfig()); + + /// Constructs a new instance with an empty flat quantizer; the user + /// provides the number of lists desired. + GpuIndexIVFSQHybrid( + GpuResources* resources, + int dims, + int nlist, + faiss::QuantizerType qtype, + faiss::MetricType metric = MetricType::METRIC_L2, + bool encodeResidual = true, + GpuIndexIVFSQHybridConfig config = + GpuIndexIVFSQHybridConfig()); + + ~GpuIndexIVFSQHybrid() override; + + /// Reserve GPU memory in our inverted lists for this number of vectors + void reserveMemory(size_t numVecs); + + /// Initialize ourselves from the given CPU index; will overwrite + /// all data in ourselves + void copyFrom(const faiss::IndexIVFSQHybrid* index); + + /// Initialize ourselves from the given CPU index; will overwrite + /// all data in ourselves + void copyFrom(faiss::IndexIVFSQHybrid* index, gpu::GpuIndexFlat *&quantizer, int64_t mode); + + /// Copy ourselves to the given CPU index; will overwrite all data + /// in the index instance + void copyTo(faiss::IndexIVFSQHybrid* index) const; + + /// After adding vectors, one can call this to reclaim device memory + /// to exactly the amount needed. Returns space reclaimed in bytes + size_t reclaimMemory(); + + void reset() override; + + void train(Index::idx_t n, const float* x) override; + + protected: + /// Called from GpuIndex for add/add_with_ids + void addImpl_(int n, + const float* x, + const Index::idx_t* ids) override; + + /// Called from GpuIndex for search + void searchImpl_(int n, + const float* x, + int k, + float* distances, + Index::idx_t* labels, + ConcurrentBitsetPtr bitset = nullptr) const override; + + /// Called from train to handle SQ residual training + void trainResiduals_(Index::idx_t n, const float* x); + + public: + /// Exposed like the CPU version + faiss::ScalarQuantizer sq; + + /// Exposed like the CPU version + bool by_residual; + + private: + GpuIndexIVFSQHybridConfig ivfSQConfig_; + + /// Desired inverted list memory reservation + size_t reserveMemoryVecs_; + + /// Instance that we own; contains the inverted list + IVFFlat* index_; +}; + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/GpuIndexIVFScalarQuantizer.cu b/core/src/index/thirdparty/faiss/gpu/GpuIndexIVFScalarQuantizer.cu new file mode 100644 index 0000000000..9a3d908b9c --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/GpuIndexIVFScalarQuantizer.cu @@ -0,0 +1,370 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace faiss { namespace gpu { + +GpuIndexIVFScalarQuantizer::GpuIndexIVFScalarQuantizer( + GpuResources* resources, + const faiss::IndexIVFScalarQuantizer* index, + GpuIndexIVFScalarQuantizerConfig config) : + GpuIndexIVF(resources, + index->d, + index->metric_type, + index->metric_arg, + index->nlist, + config), + ivfSQConfig_(config), + sq(index->sq), + by_residual(index->by_residual), + reserveMemoryVecs_(0), + index_(nullptr) { + copyFrom(index); + + FAISS_THROW_IF_NOT_MSG(isSQSupported(sq.qtype), + "Unsupported QuantizerType on GPU"); +} + +GpuIndexIVFScalarQuantizer::GpuIndexIVFScalarQuantizer( + GpuResources* resources, + int dims, + int nlist, + faiss::QuantizerType qtype, + faiss::MetricType metric, + bool encodeResidual, + GpuIndexIVFScalarQuantizerConfig config) : + GpuIndexIVF(resources, dims, metric, 0, nlist, config), + ivfSQConfig_(config), + sq(dims, qtype), + by_residual(encodeResidual), + reserveMemoryVecs_(0), + index_(nullptr) { + + // faiss::Index params + this->is_trained = false; + + // We haven't trained ourselves, so don't construct the IVFFlat + // index yet + FAISS_THROW_IF_NOT_MSG(isSQSupported(sq.qtype), + "Unsupported QuantizerType on GPU"); +} + +GpuIndexIVFScalarQuantizer::~GpuIndexIVFScalarQuantizer() { + delete index_; +} + +void +GpuIndexIVFScalarQuantizer::reserveMemory(size_t numVecs) { + reserveMemoryVecs_ = numVecs; + if (index_) { + DeviceScope scope(device_); + index_->reserveMemory(numVecs); + } +} + +void +GpuIndexIVFScalarQuantizer::copyFrom( + const faiss::IndexIVFScalarQuantizer* index) { + DeviceScope scope(device_); + + // Clear out our old data + delete index_; + index_ = nullptr; + + // Copy what we need from the CPU index + GpuIndexIVF::copyFrom(index); + sq = index->sq; + by_residual = index->by_residual; + + // The other index might not be trained, in which case we don't need to copy + // over the lists + if (!index->is_trained) { + return; + } + + // Otherwise, we can populate ourselves from the other index + this->is_trained = true; + + // Copy our lists as well + index_ = new IVFFlat(resources_, + quantizer->getGpuData(), + index->metric_type, + index->metric_arg, + by_residual, + &sq, + ivfSQConfig_.indicesOptions, + memorySpace_); + + InvertedLists* ivf = index->invlists; + + if(ReadOnlyArrayInvertedLists* rol = dynamic_cast(ivf)) { + index_->copyCodeVectorsFromCpu((const float* )(rol->pin_readonly_codes->data), + (const long *)(rol->pin_readonly_ids->data), rol->readonly_length); + } else { + for (size_t i = 0; i < ivf->nlist; ++i) { + auto numVecs = ivf->list_size(i); + + // GPU index can only support max int entries per list + FAISS_THROW_IF_NOT_FMT(numVecs <= + (size_t) std::numeric_limits::max(), + "GPU inverted list can only support " + "%zu entries; %zu found", + (size_t) std::numeric_limits::max(), + numVecs); + + index_->addCodeVectorsFromCpu( + i, + (const unsigned char*) ivf->get_codes(i), + ivf->get_ids(i), + numVecs); + } + } +} + +void +GpuIndexIVFScalarQuantizer::copyFromWithoutCodes( + const faiss::IndexIVFScalarQuantizer* index, const uint8_t* arranged_data) { + DeviceScope scope(device_); + + // Clear out our old data + delete index_; + index_ = nullptr; + + // Copy what we need from the CPU index + GpuIndexIVF::copyFrom(index); + sq = index->sq; + by_residual = index->by_residual; + + // The other index might not be trained, in which case we don't need to copy + // over the lists + if (!index->is_trained) { + return; + } + + // Otherwise, we can populate ourselves from the other index + this->is_trained = true; + + // Copy our lists as well + index_ = new IVFFlat(resources_, + quantizer->getGpuData(), + index->metric_type, + index->metric_arg, + by_residual, + &sq, + ivfSQConfig_.indicesOptions, + memorySpace_); + + InvertedLists* ivf = index->invlists; + + if(ReadOnlyArrayInvertedLists* rol = dynamic_cast(ivf)) { + index_->copyCodeVectorsFromCpu((const float *)arranged_data, + (const long *)(rol->pin_readonly_ids->data), rol->readonly_length); + } else { + // should not happen + } +} + +void +GpuIndexIVFScalarQuantizer::copyTo( + faiss::IndexIVFScalarQuantizer* index) const { + DeviceScope scope(device_); + + // We must have the indices in order to copy to ourselves + FAISS_THROW_IF_NOT_MSG( + ivfSQConfig_.indicesOptions != INDICES_IVF, + "Cannot copy to CPU as GPU index doesn't retain " + "indices (INDICES_IVF)"); + + GpuIndexIVF::copyTo(index); + index->sq = sq; + index->code_size = sq.code_size; + index->by_residual = by_residual; + index->code_size = sq.code_size; + + InvertedLists* ivf = new ArrayInvertedLists(nlist, index->code_size); + index->replace_invlists(ivf, true); + + // Copy the inverted lists + if (index_) { + for (int i = 0; i < nlist; ++i) { + auto listIndices = index_->getListIndices(i); + auto listData = index_->getListVectors(i); + + ivf->add_entries(i, + listIndices.size(), + listIndices.data(), + (const uint8_t*) listData.data()); + } + } +} + +void +GpuIndexIVFScalarQuantizer::copyToWithoutCodes( + faiss::IndexIVFScalarQuantizer* index) const { + DeviceScope scope(device_); + + // We must have the indices in order to copy to ourselves + FAISS_THROW_IF_NOT_MSG( + ivfSQConfig_.indicesOptions != INDICES_IVF, + "Cannot copy to CPU as GPU index doesn't retain " + "indices (INDICES_IVF)"); + + GpuIndexIVF::copyTo(index); + index->sq = sq; + index->code_size = sq.code_size; + index->by_residual = by_residual; + index->code_size = sq.code_size; + + InvertedLists* ivf = new ArrayInvertedLists(nlist, index->code_size); + index->replace_invlists(ivf, true); + + // Copy the inverted lists + if (index_) { + for (int i = 0; i < nlist; ++i) { + auto listIndices = index_->getListIndices(i); + + ivf->add_entries_without_codes(i, + listIndices.size(), + listIndices.data()); + } + } +} + +size_t +GpuIndexIVFScalarQuantizer::reclaimMemory() { + if (index_) { + DeviceScope scope(device_); + + return index_->reclaimMemory(); + } + + return 0; +} + +void +GpuIndexIVFScalarQuantizer::reset() { + if (index_) { + DeviceScope scope(device_); + + index_->reset(); + this->ntotal = 0; + } else { + FAISS_ASSERT(this->ntotal == 0); + } +} + +void +GpuIndexIVFScalarQuantizer::trainResiduals_(Index::idx_t n, const float* x) { + // The input is already guaranteed to be on the CPU + sq.train_residual(n, x, quantizer, by_residual, verbose); +} + +void +GpuIndexIVFScalarQuantizer::train(Index::idx_t n, const float* x) { + DeviceScope scope(device_); + + if (this->is_trained) { + FAISS_ASSERT(quantizer->is_trained); + FAISS_ASSERT(quantizer->ntotal == nlist); + FAISS_ASSERT(index_); + return; + } + + FAISS_ASSERT(!index_); + + // FIXME: GPUize more of this + // First, make sure that the data is resident on the CPU, if it is not on the + // CPU, as we depend upon parts of the CPU code + auto hostData = toHost((float*) x, + resources_->getDefaultStream(device_), + {(int) n, (int) this->d}); + + trainQuantizer_(n, hostData.data()); + trainResiduals_(n, hostData.data()); + + // The quantizer is now trained; construct the IVF index + index_ = new IVFFlat(resources_, + quantizer->getGpuData(), + this->metric_type, + this->metric_arg, + by_residual, + &sq, + ivfSQConfig_.indicesOptions, + memorySpace_); + + if (reserveMemoryVecs_) { + index_->reserveMemory(reserveMemoryVecs_); + } + + this->is_trained = true; +} + +void +GpuIndexIVFScalarQuantizer::addImpl_(int n, + const float* x, + const Index::idx_t* xids) { + // Device is already set in GpuIndex::add + FAISS_ASSERT(index_); + FAISS_ASSERT(n > 0); + + auto stream = resources_->getDefaultStream(device_); + + // Data is already resident on the GPU + Tensor data(const_cast(x), {n, (int) this->d}); + + static_assert(sizeof(long) == sizeof(Index::idx_t), "size mismatch"); + Tensor labels(const_cast(xids), {n}); + + // Not all vectors may be able to be added (some may contain NaNs etc) + index_->classifyAndAddVectors(data, labels); + + // but keep the ntotal based on the total number of vectors that we attempted + // to add + ntotal += n; +} + +void +GpuIndexIVFScalarQuantizer::searchImpl_(int n, + const float* x, + int k, + float* distances, + Index::idx_t* labels, + ConcurrentBitsetPtr bitset) const { + // Device is already set in GpuIndex::search + FAISS_ASSERT(index_); + FAISS_ASSERT(n > 0); + + auto stream = resources_->getDefaultStream(device_); + + // Data is already resident on the GPU + Tensor queries(const_cast(x), {n, (int) this->d}); + Tensor outDistances(distances, {n, k}); + + static_assert(sizeof(long) == sizeof(Index::idx_t), "size mismatch"); + Tensor outLabels(const_cast(labels), {n, k}); + + if (!bitset) { + auto bitsetDevice = toDevice(resources_, device_, nullptr, stream, {0}); + index_->query(queries, bitsetDevice, nprobe, k, outDistances, outLabels); + } else { + auto bitsetDevice = toDevice(resources_, device_, + const_cast(bitset->data()), stream, + {(int) bitset->size()}); + index_->query(queries, bitsetDevice, nprobe, k, outDistances, outLabels); + } +} + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/GpuIndexIVFScalarQuantizer.h b/core/src/index/thirdparty/faiss/gpu/GpuIndexIVFScalarQuantizer.h new file mode 100644 index 0000000000..427d9e6702 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/GpuIndexIVFScalarQuantizer.h @@ -0,0 +1,105 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#pragma once + +#include +#include + +namespace faiss { namespace gpu { + +class IVFFlat; +class GpuIndexFlat; + +struct GpuIndexIVFScalarQuantizerConfig : public GpuIndexIVFConfig { +}; + +/// Wrapper around the GPU implementation that looks like +/// faiss::IndexIVFScalarQuantizer +class GpuIndexIVFScalarQuantizer : public GpuIndexIVF { + public: + /// Construct from a pre-existing faiss::IndexIVFScalarQuantizer instance, + /// copying data over to the given GPU, if the input index is trained. + GpuIndexIVFScalarQuantizer( + GpuResources* resources, + const faiss::IndexIVFScalarQuantizer* index, + GpuIndexIVFScalarQuantizerConfig config = + GpuIndexIVFScalarQuantizerConfig()); + + /// Constructs a new instance with an empty flat quantizer; the user + /// provides the number of lists desired. + GpuIndexIVFScalarQuantizer( + GpuResources* resources, + int dims, + int nlist, + faiss::QuantizerType qtype, + faiss::MetricType metric = MetricType::METRIC_L2, + bool encodeResidual = true, + GpuIndexIVFScalarQuantizerConfig config = + GpuIndexIVFScalarQuantizerConfig()); + + ~GpuIndexIVFScalarQuantizer() override; + + /// Reserve GPU memory in our inverted lists for this number of vectors + void reserveMemory(size_t numVecs); + + /// Initialize ourselves from the given CPU index; will overwrite + /// all data in ourselves + void copyFrom(const faiss::IndexIVFScalarQuantizer* index); + + void copyFromWithoutCodes(const faiss::IndexIVFScalarQuantizer* index, const uint8_t* arranged_data); + + /// Copy ourselves to the given CPU index; will overwrite all data + /// in the index instance + void copyTo(faiss::IndexIVFScalarQuantizer* index) const; + + void copyToWithoutCodes(faiss::IndexIVFScalarQuantizer* index) const; + + /// After adding vectors, one can call this to reclaim device memory + /// to exactly the amount needed. Returns space reclaimed in bytes + size_t reclaimMemory(); + + void reset() override; + + void train(Index::idx_t n, const float* x) override; + + protected: + /// Called from GpuIndex for add/add_with_ids + void addImpl_(int n, + const float* x, + const Index::idx_t* ids) override; + + /// Called from GpuIndex for search + void searchImpl_(int n, + const float* x, + int k, + float* distances, + Index::idx_t* labels, + ConcurrentBitsetPtr bitset = nullptr) const override; + + /// Called from train to handle SQ residual training + void trainResiduals_(Index::idx_t n, const float* x); + + public: + /// Exposed like the CPU version + faiss::ScalarQuantizer sq; + + /// Exposed like the CPU version + bool by_residual; + + private: + GpuIndexIVFScalarQuantizerConfig ivfSQConfig_; + + /// Desired inverted list memory reservation + size_t reserveMemoryVecs_; + + /// Instance that we own; contains the inverted list + IVFFlat* index_; +}; + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/GpuIndicesOptions.h b/core/src/index/thirdparty/faiss/gpu/GpuIndicesOptions.h new file mode 100644 index 0000000000..768f981f71 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/GpuIndicesOptions.h @@ -0,0 +1,30 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#pragma once + +namespace faiss { namespace gpu { + +/// How user vector index data is stored on the GPU +enum IndicesOptions { + /// The user indices are only stored on the CPU; the GPU returns + /// (inverted list, offset) to the CPU which is then translated to + /// the real user index. + INDICES_CPU = 0, + /// The indices are not stored at all, on either the CPU or + /// GPU. Only (inverted list, offset) is returned to the user as the + /// index. + INDICES_IVF = 1, + /// Indices are stored as 32 bit integers on the GPU, but returned + /// as 64 bit integers + INDICES_32_BIT = 2, + /// Indices are stored as 64 bit integers on the GPU + INDICES_64_BIT = 3, +}; + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/GpuResources.cpp b/core/src/index/thirdparty/faiss/gpu/GpuResources.cpp new file mode 100644 index 0000000000..fe386c2cf8 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/GpuResources.cpp @@ -0,0 +1,52 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#include +#include + +namespace faiss { namespace gpu { + +GpuResources::~GpuResources() { +} + +cublasHandle_t +GpuResources::getBlasHandleCurrentDevice() { + return getBlasHandle(getCurrentDevice()); +} + +cudaStream_t +GpuResources::getDefaultStreamCurrentDevice() { + return getDefaultStream(getCurrentDevice()); +} + +std::vector +GpuResources::getAlternateStreamsCurrentDevice() { + return getAlternateStreams(getCurrentDevice()); +} + +DeviceMemory& +GpuResources::getMemoryManagerCurrentDevice() { + return getMemoryManager(getCurrentDevice()); +} + +cudaStream_t +GpuResources::getAsyncCopyStreamCurrentDevice() { + return getAsyncCopyStream(getCurrentDevice()); +} + +void +GpuResources::syncDefaultStream(int device) { + CUDA_VERIFY(cudaStreamSynchronize(getDefaultStream(device))); +} + +void +GpuResources::syncDefaultStreamCurrentDevice() { + syncDefaultStream(getCurrentDevice()); +} + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/GpuResources.h b/core/src/index/thirdparty/faiss/gpu/GpuResources.h new file mode 100644 index 0000000000..bdea4f630a --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/GpuResources.h @@ -0,0 +1,73 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#pragma once + +#include +#include +#include +#include +#include + +namespace faiss { namespace gpu { + +/// Base class of GPU-side resource provider; hides provision of +/// cuBLAS handles, CUDA streams and a temporary memory manager +class GpuResources { + public: + virtual ~GpuResources(); + + /// Call to pre-allocate resources for a particular device. If this is + /// not called, then resources will be allocated at the first time + /// of demand + virtual void initializeForDevice(int device) = 0; + + /// Returns the cuBLAS handle that we use for the given device + virtual cublasHandle_t getBlasHandle(int device) = 0; + + /// Returns the stream that we order all computation on for the + /// given device + virtual cudaStream_t getDefaultStream(int device) = 0; + + /// Returns the set of alternative streams that we use for the given device + virtual std::vector getAlternateStreams(int device) = 0; + + /// Returns the temporary memory manager for the given device + virtual DeviceMemory& getMemoryManager(int device) = 0; + + /// Returns the available CPU pinned memory buffer + virtual std::pair getPinnedMemory() = 0; + + /// Returns the stream on which we perform async CPU <-> GPU copies + virtual cudaStream_t getAsyncCopyStream(int device) = 0; + + /// Calls getBlasHandle with the current device + cublasHandle_t getBlasHandleCurrentDevice(); + + /// Calls getDefaultStream with the current device + cudaStream_t getDefaultStreamCurrentDevice(); + + /// Synchronizes the CPU with respect to the default stream for the + /// given device + // equivalent to cudaDeviceSynchronize(getDefaultStream(device)) + void syncDefaultStream(int device); + + /// Calls syncDefaultStream for the current device + void syncDefaultStreamCurrentDevice(); + + /// Calls getAlternateStreams for the current device + std::vector getAlternateStreamsCurrentDevice(); + + /// Calls getMemoryManager for the current device + DeviceMemory& getMemoryManagerCurrentDevice(); + + /// Calls getAsyncCopyStream for the current device + cudaStream_t getAsyncCopyStreamCurrentDevice(); +}; + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/StandardGpuResources.cpp b/core/src/index/thirdparty/faiss/gpu/StandardGpuResources.cpp new file mode 100644 index 0000000000..e564f8e367 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/StandardGpuResources.cpp @@ -0,0 +1,303 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#include +#include +#include +#include +#include + +namespace faiss { namespace gpu { + +namespace { + +// How many streams per device we allocate by default (for multi-streaming) +constexpr int kNumStreams = 2; + +// Use 256 MiB of pinned memory for async CPU <-> GPU copies by default +constexpr size_t kDefaultPinnedMemoryAllocation = (size_t) 256 * 1024 * 1024; + +// Default temporary memory allocation for <= 4 GiB memory GPUs +constexpr size_t k4GiBTempMem = (size_t) 512 * 1024 * 1024; + +// Default temporary memory allocation for <= 8 GiB memory GPUs +constexpr size_t k8GiBTempMem = (size_t) 1024 * 1024 * 1024; + +// Maximum temporary memory allocation for all GPUs +constexpr size_t kMaxTempMem = (size_t) 1536 * 1024 * 1024; + +} + +StandardGpuResources::StandardGpuResources() : + pinnedMemAlloc_(nullptr), + pinnedMemAllocSize_(0), + // let the adjustment function determine the memory size for us by passing + // in a huge value that will then be adjusted + tempMemSize_(getDefaultTempMemForGPU(-1, + std::numeric_limits::max())), + pinnedMemSize_(kDefaultPinnedMemoryAllocation), + cudaMallocWarning_(true) { +} + +StandardGpuResources::~StandardGpuResources() { + for (auto& entry : defaultStreams_) { + DeviceScope scope(entry.first); + + auto it = userDefaultStreams_.find(entry.first); + if (it == userDefaultStreams_.end()) { + // The user did not specify this stream, thus we are the ones + // who have created it + CUDA_VERIFY(cudaStreamDestroy(entry.second)); + } + } + + for (auto& entry : alternateStreams_) { + DeviceScope scope(entry.first); + + for (auto stream : entry.second) { + CUDA_VERIFY(cudaStreamDestroy(stream)); + } + } + + for (auto& entry : asyncCopyStreams_) { + DeviceScope scope(entry.first); + + CUDA_VERIFY(cudaStreamDestroy(entry.second)); + } + + for (auto& entry : blasHandles_) { + DeviceScope scope(entry.first); + + auto blasStatus = cublasDestroy(entry.second); + FAISS_ASSERT(blasStatus == CUBLAS_STATUS_SUCCESS); + } + + if (pinnedMemAlloc_) { + freeMemorySpace(MemorySpace::HostPinned, pinnedMemAlloc_); + } +} + +size_t +StandardGpuResources::getDefaultTempMemForGPU(int device, + size_t requested) { + auto totalMem = device != -1 ? + getDeviceProperties(device).totalGlobalMem : + std::numeric_limits::max(); + + if (totalMem <= (size_t) 4 * 1024 * 1024 * 1024) { + // If the GPU has <= 4 GiB of memory, reserve 512 MiB + + if (requested > k4GiBTempMem) { + return k4GiBTempMem; + } + } else if (totalMem <= (size_t) 8 * 1024 * 1024 * 1024) { + // If the GPU has <= 8 GiB of memory, reserve 1 GiB + + if (requested > k8GiBTempMem) { + return k8GiBTempMem; + } + } else { + // Never use more than 1.5 GiB + if (requested > kMaxTempMem) { + return kMaxTempMem; + } + } + + // use whatever lower limit the user requested + return requested; +} + +void +StandardGpuResources::noTempMemory() { + setTempMemory(0); + setCudaMallocWarning(false); +} + +void +StandardGpuResources::setTempMemory(size_t size) { + if (tempMemSize_ != size) { + // adjust based on general limits + tempMemSize_ = getDefaultTempMemForGPU(-1, size); + + // We need to re-initialize memory resources for all current devices that + // have been initialized. + // This should be safe to do, even if we are currently running work, because + // the cudaFree call that this implies will force-synchronize all GPUs with + // the CPU + for (auto& p : memory_) { + int device = p.first; + // Free the existing memory first + p.second.reset(); + + // Allocate new + p.second = std::unique_ptr( + new StackDeviceMemory(p.first, + // adjust for this specific device + getDefaultTempMemForGPU(device, tempMemSize_))); + } + } +} + +void +StandardGpuResources::setPinnedMemory(size_t size) { + // Should not call this after devices have been initialized + FAISS_ASSERT(defaultStreams_.size() == 0); + FAISS_ASSERT(!pinnedMemAlloc_); + + pinnedMemSize_ = size; +} + +void +StandardGpuResources::setDefaultStream(int device, cudaStream_t stream) { + auto it = defaultStreams_.find(device); + if (it != defaultStreams_.end()) { + // Replace this stream with the user stream + CUDA_VERIFY(cudaStreamDestroy(it->second)); + it->second = stream; + } + + userDefaultStreams_[device] = stream; +} + +void +StandardGpuResources::setDefaultNullStreamAllDevices() { + for (int dev = 0; dev < getNumDevices(); ++dev) { + setDefaultStream(dev, nullptr); + } +} + +void +StandardGpuResources::setCudaMallocWarning(bool b) { + cudaMallocWarning_ = b; + + for (auto& v : memory_) { + v.second->setCudaMallocWarning(b); + } +} + +bool +StandardGpuResources::isInitialized(int device) const { + // Use default streams as a marker for whether or not a certain + // device has been initialized + return defaultStreams_.count(device) != 0; +} + +void +StandardGpuResources::initializeForDevice(int device) { + if (isInitialized(device)) { + return; + } + + // If this is the first device that we're initializing, create our + // pinned memory allocation + if (defaultStreams_.empty() && pinnedMemSize_ > 0) { + allocMemorySpace(MemorySpace::HostPinned, &pinnedMemAlloc_, pinnedMemSize_); + pinnedMemAllocSize_ = pinnedMemSize_; + } + + FAISS_ASSERT(device < getNumDevices()); + DeviceScope scope(device); + + // Make sure that device properties for all devices are cached + auto& prop = getDeviceProperties(device); + + // Also check to make sure we meet our minimum compute capability (3.0) + FAISS_ASSERT_FMT(prop.major >= 3, + "Device id %d with CC %d.%d not supported, " + "need 3.0+ compute capability", + device, prop.major, prop.minor); + + // Create streams + cudaStream_t defaultStream = 0; + auto it = userDefaultStreams_.find(device); + if (it != userDefaultStreams_.end()) { + // We already have a stream provided by the user + defaultStream = it->second; + } else { + CUDA_VERIFY(cudaStreamCreateWithFlags(&defaultStream, + cudaStreamNonBlocking)); + } + + defaultStreams_[device] = defaultStream; + + cudaStream_t asyncCopyStream = 0; + CUDA_VERIFY(cudaStreamCreateWithFlags(&asyncCopyStream, + cudaStreamNonBlocking)); + + asyncCopyStreams_[device] = asyncCopyStream; + + std::vector deviceStreams; + for (int j = 0; j < kNumStreams; ++j) { + cudaStream_t stream = 0; + CUDA_VERIFY(cudaStreamCreateWithFlags(&stream, + cudaStreamNonBlocking)); + + deviceStreams.push_back(stream); + } + + alternateStreams_[device] = std::move(deviceStreams); + + // Create cuBLAS handle + cublasHandle_t blasHandle = 0; + auto blasStatus = cublasCreate(&blasHandle); + FAISS_ASSERT(blasStatus == CUBLAS_STATUS_SUCCESS); + blasHandles_[device] = blasHandle; + + // Enable tensor core support if available +#if CUDA_VERSION >= 9000 + if (getTensorCoreSupport(device)) { + cublasSetMathMode(blasHandle, CUBLAS_TENSOR_OP_MATH); + } +#endif + + FAISS_ASSERT(memory_.count(device) == 0); + + auto mem = std::unique_ptr( + new StackDeviceMemory(device, + // adjust for this specific device + getDefaultTempMemForGPU(device, tempMemSize_))); + mem->setCudaMallocWarning(cudaMallocWarning_); + + memory_.emplace(device, std::move(mem)); +} + +cublasHandle_t +StandardGpuResources::getBlasHandle(int device) { + initializeForDevice(device); + return blasHandles_[device]; +} + +cudaStream_t +StandardGpuResources::getDefaultStream(int device) { + initializeForDevice(device); + return defaultStreams_[device]; +} + +std::vector +StandardGpuResources::getAlternateStreams(int device) { + initializeForDevice(device); + return alternateStreams_[device]; +} + +DeviceMemory& StandardGpuResources::getMemoryManager(int device) { + initializeForDevice(device); + return *memory_[device]; +} + +std::pair +StandardGpuResources::getPinnedMemory() { + return std::make_pair(pinnedMemAlloc_, pinnedMemAllocSize_); +} + +cudaStream_t +StandardGpuResources::getAsyncCopyStream(int device) { + initializeForDevice(device); + return asyncCopyStreams_[device]; +} + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/StandardGpuResources.h b/core/src/index/thirdparty/faiss/gpu/StandardGpuResources.h new file mode 100644 index 0000000000..9d4ffa4c44 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/StandardGpuResources.h @@ -0,0 +1,114 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#pragma once + +#include +#include +#include +#include +#include + +namespace faiss { namespace gpu { + +/// Default implementation of GpuResources that allocates a cuBLAS +/// stream and 2 streams for use, as well as temporary memory +class StandardGpuResources : public GpuResources { + public: + StandardGpuResources(); + + ~StandardGpuResources() override; + + /// Disable allocation of temporary memory; all temporary memory + /// requests will call cudaMalloc / cudaFree at the point of use + void noTempMemory(); + + /// Specify that we wish to use a certain fixed size of memory on + /// all devices as temporary memory. This is the upper bound for the GPU + /// memory that we will reserve. We will never go above 1.5 GiB on any GPU; + /// smaller GPUs (with <= 4 GiB or <= 8 GiB) will use less memory than that. + /// To avoid any temporary memory allocation, pass 0. + void setTempMemory(size_t size); + + /// Set amount of pinned memory to allocate, for async GPU <-> CPU + /// transfers + void setPinnedMemory(size_t size); + + /// Called to change the stream for work ordering + void setDefaultStream(int device, cudaStream_t stream); + + /// Called to change the work ordering streams to the null stream + /// for all devices + void setDefaultNullStreamAllDevices(); + + /// Enable or disable the warning about not having enough temporary memory + /// when cudaMalloc gets called + void setCudaMallocWarning(bool b); + + public: + /// Internal system calls + + /// Initialize resources for this device + void initializeForDevice(int device) override; + + cublasHandle_t getBlasHandle(int device) override; + + cudaStream_t getDefaultStream(int device) override; + + std::vector getAlternateStreams(int device) override; + + DeviceMemory& getMemoryManager(int device) override; + + std::pair getPinnedMemory() override; + + cudaStream_t getAsyncCopyStream(int device) override; + + private: + /// Have GPU resources been initialized for this device yet? + bool isInitialized(int device) const; + + /// Adjust the default temporary memory allocation based on the total GPU + /// memory size + static size_t getDefaultTempMemForGPU(int device, size_t requested); + + private: + /// Our default stream that work is ordered on, one per each device + std::unordered_map defaultStreams_; + + /// This contains particular streams as set by the user for + /// ordering, if any + std::unordered_map userDefaultStreams_; + + /// Other streams we can use, per each device + std::unordered_map > alternateStreams_; + + /// Async copy stream to use for GPU <-> CPU pinned memory copies + std::unordered_map asyncCopyStreams_; + + /// cuBLAS handle for each device + std::unordered_map blasHandles_; + + /// Temporary memory provider, per each device + std::unordered_map > memory_; + + /// Pinned memory allocation for use with this GPU + void* pinnedMemAlloc_; + size_t pinnedMemAllocSize_; + + /// Another option is to use a specified amount of memory on all + /// devices + size_t tempMemSize_; + + /// Amount of pinned memory we should allocate + size_t pinnedMemSize_; + + /// Whether or not a warning upon cudaMalloc is generated + bool cudaMallocWarning_; +}; + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/impl/BinaryDistance.cu b/core/src/index/thirdparty/faiss/gpu/impl/BinaryDistance.cu new file mode 100644 index 0000000000..9c91ae2182 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/impl/BinaryDistance.cu @@ -0,0 +1,316 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include + +namespace faiss { namespace gpu { + +// Number of warps that the kernel is instantiated with +constexpr int kWarps = 8; +constexpr int kLanes = kWarpSize; + +constexpr int kMaxDistance = std::numeric_limits::max(); + +// Performs a binary matrix multiplication, returning the lowest k results in +// `vecs` for each `query` in terms of Hamming distance (a fused kernel) +// Each warp calculates distance for a single query +template +__launch_bounds__(kWarps * kLanes) +__global__ void binaryDistanceAnySize(const Tensor vecs, + const Tensor query, + Tensor outK, + Tensor outV, + int k) { + // A matrix tile (query, k) + __shared__ BinaryType queryTile[kWarps][kLanes + 1]; // avoid bank conflict + + // B matrix tile (vec, k) + __shared__ BinaryType vecTile[kLanes][kLanes + 1]; // avoid bank conflict + + WarpSelect, + NumWarpQ, NumThreadQ, kWarps * kLanes> + heap(kMaxDistance, -1, k); + + int warpId = threadIdx.y; + int laneId = threadIdx.x; + + // Each warp handles a single query + int warpQuery = blockIdx.x * kWarps + warpId; + bool queryInBounds = warpQuery < query.getSize(0); + + // Each warp loops through the entire chunk of vectors + for (int blockVec = 0; blockVec < vecs.getSize(0); blockVec += kLanes) { + int threadDistance = 0; + + // Reduction dimension + for (int blockK = 0; blockK < vecs.getSize(1); blockK += kLanes) { + int laneK = blockK + laneId; + bool kInBounds = laneK < vecs.getSize(1); + + queryTile[warpId][laneId] = queryInBounds && kInBounds ? + query[warpQuery][laneK] : 0; + + // kWarps warps are responsible for loading 32 vecs +#pragma unroll + for (int i = 0; i < kLanes / kWarps; ++i) { + int warpVec = i * kWarps + warpId; + int vec = blockVec + warpVec; + bool vecInBounds = vec < vecs.getSize(0); + + vecTile[warpVec][laneId] = vecInBounds && kInBounds ? + vecs[vec][laneK] : 0; + } + + __syncthreads(); + + // Compare distances +#pragma unroll + for (int i = 0; i < kLanes; ++i) { + threadDistance += __popc(queryTile[warpId][i] ^ vecTile[laneId][i]); + } + + __syncthreads(); + } + + // Lanes within a warp are different vec results against the same query + // Only submit distances which represent real (query, vec) pairs + bool valInBounds = queryInBounds && (blockVec + laneId < vecs.getSize(0)); + threadDistance = valInBounds ? threadDistance : kMaxDistance; + int id = valInBounds ? blockVec + laneId : -1; + + heap.add(threadDistance, id); + } + + heap.reduce(); + + if (warpQuery < query.getSize(0)) { + heap.writeOut(outK[warpQuery].data(), + outV[warpQuery].data(), + k); + } +} + +// Version of the kernel that avoids a loop over the reduction dimension, and +// thus avoids reloading the query vectors +template +__global__ void +__launch_bounds__(kWarps * kLanes) +binaryDistanceLimitSize(const Tensor vecs, + const Tensor query, + Tensor outK, + Tensor outV, + int k) { + // A matrix tile (query, k) + __shared__ BinaryType queryTile[kWarps][kLanes + 1]; // avoid bank conflict + + // B matrix tile (vec, k) + __shared__ BinaryType vecTile[kLanes][kLanes + 1]; // avoid bank conflict + + WarpSelect, + NumWarpQ, NumThreadQ, kWarps * kLanes> + heap(kMaxDistance, -1, k); + + int warpId = threadIdx.y; + int laneId = threadIdx.x; + + // Each warp handles a single query + int laneK = laneId; + int warpQuery = blockIdx.x * kWarps + warpId; + bool kInBounds = laneK < vecs.getSize(1); + bool queryInBounds = warpQuery < query.getSize(0); + + + queryTile[warpId][laneId] = queryInBounds && kInBounds ? + query[warpQuery][laneK] : 0; + + // Each warp loops through the entire chunk of vectors + for (int blockVec = 0; blockVec < vecs.getSize(0); blockVec += kLanes) { + int threadDistance = 0; + + // kWarps warps are responsible for loading 32 vecs +#pragma unroll + for (int i = 0; i < kLanes / kWarps; ++i) { + int warpVec = i * kWarps + warpId; + int vec = blockVec + warpVec; + bool vecInBounds = vec < vecs.getSize(0); + + vecTile[warpVec][laneId] = vecInBounds && kInBounds ? + vecs[vec][laneK] : 0; + } + + __syncthreads(); + + // Compare distances +#pragma unroll + for (int i = 0; i < ReductionLimit; ++i) { + threadDistance += __popc(queryTile[warpId][i] ^ vecTile[laneId][i]); + } + + __syncthreads(); + + // Lanes within a warp are different vec results against the same query + // Only submit distances which represent real (query, vec) pairs + bool valInBounds = queryInBounds && (blockVec + laneId < vecs.getSize(0)); + threadDistance = valInBounds ? threadDistance : kMaxDistance; + int id = valInBounds ? blockVec + laneId : -1; + + heap.add(threadDistance, id); + } + + heap.reduce(); + + if (warpQuery < query.getSize(0)) { + heap.writeOut(outK[warpQuery].data(), + outV[warpQuery].data(), + k); + } +} + +template +void runBinaryDistanceAnySize(Tensor& vecs, + Tensor& query, + Tensor& outK, + Tensor& outV, + int k, cudaStream_t stream) { + dim3 grid(utils::divUp(query.getSize(0), kWarps)); + dim3 block(kLanes, kWarps); + + if (k == 1) { + binaryDistanceAnySize<1, 1, BinaryType> + <<>>( + vecs, query, outK, outV, k); + } else if (k <= 32) { + binaryDistanceAnySize<32, 2, BinaryType> + <<>>( + vecs, query, outK, outV, k); + } else if (k <= 64) { + binaryDistanceAnySize<64, 3, BinaryType> + <<>>( + vecs, query, outK, outV, k); + } else if (k <= 128) { + binaryDistanceAnySize<128, 3, BinaryType> + <<>>( + vecs, query, outK, outV, k); + } else if (k <= 256) { + binaryDistanceAnySize<256, 4, BinaryType> + <<>>( + vecs, query, outK, outV, k); + } else if (k <= 512) { + binaryDistanceAnySize<512, 8, BinaryType> + <<>>( + vecs, query, outK, outV, k); + } else if (k <= 1024) { + binaryDistanceAnySize<1024, 8, BinaryType> + <<>>( + vecs, query, outK, outV, k); + } +#if GPU_MAX_SELECTION_K >= 2048 + else if (k <= 2048) { + binaryDistanceAnySize<2048, 8, BinaryType> + <<>>( + vecs, query, outK, outV, k); + } +#endif +} + +template +void runBinaryDistanceLimitSize(Tensor& vecs, + Tensor& query, + Tensor& outK, + Tensor& outV, + int k, cudaStream_t stream) { + dim3 grid(utils::divUp(query.getSize(0), kWarps)); + dim3 block(kLanes, kWarps); + + if (k == 1) { + binaryDistanceLimitSize<1, 1, BinaryType, ReductionLimit> + <<>>( + vecs, query, outK, outV, k); + } else if (k <= 32) { + binaryDistanceLimitSize<32, 2, BinaryType, ReductionLimit> + <<>>( + vecs, query, outK, outV, k); + } else if (k <= 64) { + binaryDistanceLimitSize<64, 3, BinaryType, ReductionLimit> + <<>>( + vecs, query, outK, outV, k); + } else if (k <= 128) { + binaryDistanceLimitSize<128, 3, BinaryType, ReductionLimit> + <<>>( + vecs, query, outK, outV, k); + } else if (k <= 256) { + binaryDistanceLimitSize<256, 4, BinaryType, ReductionLimit> + <<>>( + vecs, query, outK, outV, k); + } else if (k <= 512) { + binaryDistanceLimitSize<512, 8, BinaryType, ReductionLimit> + <<>>( + vecs, query, outK, outV, k); + } else if (k <= 1024) { + binaryDistanceLimitSize<1024, 8, BinaryType, ReductionLimit> + <<>>( + vecs, query, outK, outV, k); + } +#if GPU_MAX_SELECTION_K >= 2048 + else if (k <= 2048) { + binaryDistanceLimitSize<2048, 8, BinaryType, ReductionLimit> + <<>>( + vecs, query, outK, outV, k); + } +#endif +} + +void runBinaryDistance(Tensor& vecs, + Tensor& query, + Tensor& outK, + Tensor& outV, + int k, cudaStream_t stream) { + FAISS_ASSERT(k <= GPU_MAX_SELECTION_K); + FAISS_ASSERT(vecs.getSize(1) == query.getSize(1)); + + FAISS_ASSERT(outK.getSize(1) == k); + FAISS_ASSERT(outV.getSize(1) == k); + + // For the optimized uint32 kernel, we handle 32 * 8 = 256 max dims + constexpr int kReductionLimit32 = 8; + + // For the optimized uint8 kernel, we handle 8 * 16 = 128 max dims + constexpr int kReductionLimit8 = 16; + + // All other cases (large or small) go through the general kernel + + if (vecs.getSize(1) % sizeof(unsigned int) == 0 && + (vecs.getSize(1) / sizeof(unsigned int)) <= kReductionLimit32) { + auto vecs32 = vecs.castResize(); + auto query32 = query.castResize(); + + // Optimize for vectors with dimensions a multiple of 32 that are less than + // 32 * kReductionLimit (256) dimensions in size + runBinaryDistanceLimitSize( + vecs32, query32, outK, outV, k, stream); + + } else if (vecs.getSize(1) <= kReductionLimit8) { + // Optimize for vectors with dimensions a multiple of 32 that are less than + // 32 * kReductionLimit (256) dimensions in size + runBinaryDistanceLimitSize( + vecs, query, outK, outV, k, stream); + } else { + // Arbitrary size kernel + runBinaryDistanceAnySize( + vecs, query, outK, outV, k, stream); + } +} + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/impl/BinaryDistance.cuh b/core/src/index/thirdparty/faiss/gpu/impl/BinaryDistance.cuh new file mode 100644 index 0000000000..149accc016 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/impl/BinaryDistance.cuh @@ -0,0 +1,21 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#include + +namespace faiss { namespace gpu { + +// Performs brute-force k-NN comparison between `vecs` and `query`, where they +// are encoded as binary vectors +void runBinaryDistance(Tensor& vecs, + Tensor& query, + Tensor& outK, + Tensor& outV, + int k, cudaStream_t stream); + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/impl/BinaryFlatIndex.cu b/core/src/index/thirdparty/faiss/gpu/impl/BinaryFlatIndex.cu new file mode 100644 index 0000000000..dd38fdd7dd --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/impl/BinaryFlatIndex.cu @@ -0,0 +1,88 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#include +#include +#include +#include + +namespace faiss { namespace gpu { + +BinaryFlatIndex::BinaryFlatIndex(GpuResources* res, + int dim, + MemorySpace space) : + resources_(res), + dim_(dim), + space_(space), + num_(0), + rawData_(space) { + FAISS_ASSERT(dim % 8 == 0); +} + +/// Returns the number of vectors we contain +int BinaryFlatIndex::getSize() const { + return vectors_.getSize(0); +} + +int BinaryFlatIndex::getDim() const { + return vectors_.getSize(1) * 8; +} + +void +BinaryFlatIndex::reserve(size_t numVecs, cudaStream_t stream) { + rawData_.reserve(numVecs * (dim_ / 8) * sizeof(unsigned int), stream); +} + +Tensor& +BinaryFlatIndex::getVectorsRef() { + return vectors_; +} + +void +BinaryFlatIndex::query(Tensor& input, + int k, + Tensor& outDistances, + Tensor& outIndices) { + auto stream = resources_->getDefaultStreamCurrentDevice(); + + runBinaryDistance(vectors_, + input, + outDistances, + outIndices, + k, + stream); +} + +void +BinaryFlatIndex::add(const unsigned char* data, + int numVecs, + cudaStream_t stream) { + if (numVecs == 0) { + return; + } + + rawData_.append((char*) data, + (size_t) (dim_ / 8) * numVecs * sizeof(unsigned char), + stream, + true /* reserve exactly */); + + num_ += numVecs; + + DeviceTensor vectors( + (unsigned char*) rawData_.data(), {(int) num_, (dim_ / 8)}, space_); + vectors_ = std::move(vectors); +} + +void +BinaryFlatIndex::reset() { + rawData_.clear(); + vectors_ = std::move(DeviceTensor()); + num_ = 0; +} + +} } diff --git a/core/src/index/thirdparty/faiss/gpu/impl/BinaryFlatIndex.cuh b/core/src/index/thirdparty/faiss/gpu/impl/BinaryFlatIndex.cuh new file mode 100644 index 0000000000..c99afc45a7 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/impl/BinaryFlatIndex.cuh @@ -0,0 +1,69 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#pragma once + +#include +#include +#include + +namespace faiss { namespace gpu { + +class GpuResources; + +/// Holder of GPU resources for a particular flat index +class BinaryFlatIndex { + public: + BinaryFlatIndex(GpuResources* res, + int dim, + MemorySpace space); + + /// Returns the number of vectors we contain + int getSize() const; + + int getDim() const; + + /// Reserve storage that can contain at least this many vectors + void reserve(size_t numVecs, cudaStream_t stream); + + /// Returns a reference to our vectors currently in use + Tensor& getVectorsRef(); + + void query(Tensor& vecs, + int k, + Tensor& outDistances, + Tensor& outIndices); + + /// Add vectors to ourselves; the pointer passed can be on the host + /// or the device + void add(const unsigned char* data, int numVecs, cudaStream_t stream); + + /// Free all storage + void reset(); + + private: + /// Collection of GPU resources that we use + GpuResources* resources_; + + /// Dimensionality of our vectors + const int dim_; + + /// Memory space for our allocations + MemorySpace space_; + + /// How many vectors we have + int num_; + + /// The underlying expandable storage + DeviceVector rawData_; + + /// Vectors currently in rawData_ + DeviceTensor vectors_; +}; + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/impl/BroadcastSum.cu b/core/src/index/thirdparty/faiss/gpu/impl/BroadcastSum.cu new file mode 100644 index 0000000000..e9f7548e25 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/impl/BroadcastSum.cu @@ -0,0 +1,360 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#include +#include + +#include +#include +#include +#include + +namespace faiss { namespace gpu { + +template +__global__ void sumAlongColumns(Tensor input, + Tensor output) { + static_assert(kRowsPerBlock % kRowUnroll == 0, "must fit rows"); + + // blockIdx.x: which chunk of rows we are responsible for updating + // blockIdx.y: which chunk of columns we are responsible for + // updating + int rowStart = blockIdx.x * kRowsPerBlock; + int rowEnd = rowStart + kRowsPerBlock; + int colStart = blockIdx.y * blockDim.x * kColLoad; + + // FIXME: if we have exact multiples, don't need this + bool endRow = (blockIdx.x == gridDim.x - 1); + bool endCol = (blockIdx.y == gridDim.y - 1); + + if (endRow) { + if (output.getSize(0) % kRowsPerBlock == 0) { + endRow = false; + } + } + + if (endCol) { + for (int col = colStart + threadIdx.x; + col < input.getSize(0); col += blockDim.x) { + T val = input[col]; + + if (endRow) { + for (int row = rowStart; row < output.getSize(0); ++row) { + T out = output[row][col]; + out = Math::add(out, val); + output[row][col] = out; + } + } else { + T rows[kRowUnroll]; + + for (int row = rowStart; row < rowEnd; row += kRowUnroll) { +#pragma unroll + for (int i = 0; i < kRowUnroll; ++i) { + rows[i] = output[row + i][col]; + } + +#pragma unroll + for (int i = 0; i < kRowUnroll; ++i) { + rows[i] = Math::add(rows[i], val); + } + +#pragma unroll + for (int i = 0; i < kRowUnroll; ++i) { + output[row + i][col] = rows[i]; + } + } + } + } + } else { + int col = colStart + threadIdx.x; + + T val[kColLoad]; + +#pragma unroll + for (int i = 0; i < kColLoad; ++i) { + val[i] = input[col + i * blockDim.x]; + } + + if (endRow) { + for (int row = rowStart; row < output.getSize(0); ++row) { +#pragma unroll + for (int i = 0; i < kColLoad; ++i) { + T out = output[row][col + i * blockDim.x]; + out = Math::add(out, val[i]); + output[row][col + i * blockDim.x] = out; + } + } + } else { + T rows[kRowUnroll * kColLoad]; + + for (int row = rowStart; row < rowEnd; row += kRowUnroll) { +#pragma unroll + for (int i = 0; i < kRowUnroll; ++i) { +#pragma unroll + for (int j = 0; j < kColLoad; ++j) { + rows[i * kColLoad + j] = + output[row + i][col + j * blockDim.x]; + } + } + +#pragma unroll + for (int i = 0; i < kRowUnroll; ++i) { +#pragma unroll + for (int j = 0; j < kColLoad; ++j) { + rows[i * kColLoad + j] = + Math::add(rows[i * kColLoad + j], val[j]); + } + } + +#pragma unroll + for (int i = 0; i < kRowUnroll; ++i) { +#pragma unroll + for (int j = 0; j < kColLoad; ++j) { + output[row + i][col + j * blockDim.x] = + rows[i * kColLoad + j]; + } + } + } + } + } +} + +template +__global__ void assignAlongColumns(Tensor input, + Tensor output) { + static_assert(kRowsPerBlock % kRowUnroll == 0, "must fit rows"); + + // blockIdx.x: which chunk of rows we are responsible for updating + // blockIdx.y: which chunk of columns we are responsible for + // updating + int rowStart = blockIdx.x * kRowsPerBlock; + int rowEnd = rowStart + kRowsPerBlock; + int colStart = blockIdx.y * blockDim.x * kColLoad; + + // FIXME: if we have exact multiples, don't need this + bool endRow = (blockIdx.x == gridDim.x - 1); + bool endCol = (blockIdx.y == gridDim.y - 1); + + if (endRow) { + if (output.getSize(0) % kRowsPerBlock == 0) { + endRow = false; + } + } + + if (endCol) { + for (int col = colStart + threadIdx.x; + col < input.getSize(0); col += blockDim.x) { + T val = input[col]; + + if (endRow) { + for (int row = rowStart; row < output.getSize(0); ++row) { + output[row][col] = val; + } + } else { + for (int row = rowStart; row < rowEnd; row += kRowUnroll) { +#pragma unroll + for (int i = 0; i < kRowUnroll; ++i) { + output[row + i][col] = val; + } + } + } + } + } else { + int col = colStart + threadIdx.x; + + T val[kColLoad]; + +#pragma unroll + for (int i = 0; i < kColLoad; ++i) { + val[i] = input[col + i * blockDim.x]; + } + + if (endRow) { + for (int row = rowStart; row < output.getSize(0); ++row) { +#pragma unroll + for (int i = 0; i < kColLoad; ++i) { + output[row][col + i * blockDim.x] = val[i]; + } + } + } else { + for (int row = rowStart; row < rowEnd; row += kRowUnroll) { +#pragma unroll + for (int i = 0; i < kRowUnroll; ++i) { +#pragma unroll + for (int j = 0; j < kColLoad; ++j) { + output[row + i][col + j * blockDim.x] = val[j]; + } + } + } + } + } +} + +template +__global__ void sumAlongRows(Tensor input, + Tensor output) { + __shared__ T sval; + + int row = blockIdx.x; + + if (threadIdx.x == 0) { + sval = input[row]; + } + + __syncthreads(); + + T val = sval; + + // FIXME: speed up + for (int i = threadIdx.x; i < output.getSize(1); i += blockDim.x) { + T out = output[row][i]; + out = Math::add(out, val); + out = Math::lt(out, Math::zero()) ? Math::zero() : out; + + output[row][i] = out; + } +} + +template +void runSumAlongColumns(Tensor& input, + Tensor& output, + cudaStream_t stream) { + FAISS_ASSERT(input.getSize(0) == output.getSize(1)); + + int threadsPerBlock = 256; + constexpr int kRowUnroll = 4; + constexpr int kRowsPerBlock = kRowUnroll * 4; + constexpr int kColLoad = 4; + + auto block = dim3(threadsPerBlock); + + if (input.template canCastResize() && + output.template canCastResize()) { + auto inputV = input.template castResize(); + auto outputV = output.template castResize(); + + auto grid = + dim3(utils::divUp(outputV.getSize(0), kRowsPerBlock), + utils::divUp(outputV.getSize(1), threadsPerBlock * kColLoad)); + + sumAlongColumns + <<>>(inputV, outputV); + } else { + auto grid = + dim3(utils::divUp(output.getSize(0), kRowsPerBlock), + utils::divUp(output.getSize(1), threadsPerBlock * kColLoad)); + + sumAlongColumns + <<>>(input, output); + } + + CUDA_TEST_ERROR(); +} + +void runSumAlongColumns(Tensor& input, + Tensor& output, + cudaStream_t stream) { + runSumAlongColumns(input, output, stream); +} + +#ifdef FAISS_USE_FLOAT16 +void runSumAlongColumns(Tensor& input, + Tensor& output, + cudaStream_t stream) { + runSumAlongColumns(input, output, stream); +} +#endif + +template +void runAssignAlongColumns(Tensor& input, + Tensor& output, + cudaStream_t stream) { + FAISS_ASSERT(input.getSize(0) == output.getSize(1)); + + int threadsPerBlock = 256; + constexpr int kRowUnroll = 4; + constexpr int kRowsPerBlock = kRowUnroll * 4; + constexpr int kColLoad = 4; + + auto block = dim3(threadsPerBlock); + + if (input.template canCastResize() && + output.template canCastResize()) { + auto inputV = input.template castResize(); + auto outputV = output.template castResize(); + + auto grid = + dim3(utils::divUp(outputV.getSize(0), kRowsPerBlock), + utils::divUp(outputV.getSize(1), threadsPerBlock * kColLoad)); + + assignAlongColumns + <<>>(inputV, outputV); + } else { + auto grid = + dim3(utils::divUp(output.getSize(0), kRowsPerBlock), + utils::divUp(output.getSize(1), threadsPerBlock * kColLoad)); + + assignAlongColumns + <<>>(input, output); + } + + CUDA_TEST_ERROR(); +} + +void runAssignAlongColumns(Tensor& input, + Tensor& output, + cudaStream_t stream) { + runAssignAlongColumns(input, output, stream); +} + +#ifdef FAISS_USE_FLOAT16 +void runAssignAlongColumns(Tensor& input, + Tensor& output, + cudaStream_t stream) { + runAssignAlongColumns(input, output, stream); +} +#endif + +template +void runSumAlongRows(Tensor& input, + Tensor& output, + bool zeroClamp, + cudaStream_t stream) { + FAISS_ASSERT(input.getSize(0) == output.getSize(0)); + + int threadsPerBlock = + std::min(output.getSize(1), getMaxThreadsCurrentDevice()); + auto grid = dim3(output.getSize(0)); + auto block = dim3(threadsPerBlock); + + if (zeroClamp) { + sumAlongRows<<>>(input, output); + } else { + sumAlongRows<<>>(input, output); + } + + CUDA_TEST_ERROR(); +} + +void runSumAlongRows(Tensor& input, + Tensor& output, + bool zeroClamp, + cudaStream_t stream) { + runSumAlongRows(input, output, zeroClamp, stream); +} + +#ifdef FAISS_USE_FLOAT16 +void runSumAlongRows(Tensor& input, + Tensor& output, + bool zeroClamp, + cudaStream_t stream) { + runSumAlongRows(input, output, zeroClamp, stream); +} +#endif + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/impl/BroadcastSum.cuh b/core/src/index/thirdparty/faiss/gpu/impl/BroadcastSum.cuh new file mode 100644 index 0000000000..6641aadd40 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/impl/BroadcastSum.cuh @@ -0,0 +1,51 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#pragma once + +#include + +namespace faiss { namespace gpu { + +// output[x][i] += input[i] for all x +void runSumAlongColumns(Tensor& input, + Tensor& output, + cudaStream_t stream); + +#ifdef FAISS_USE_FLOAT16 +void runSumAlongColumns(Tensor& input, + Tensor& output, + cudaStream_t stream); +#endif + +// output[x][i] = input[i] for all x +void runAssignAlongColumns(Tensor& input, + Tensor& output, + cudaStream_t stream); + +#ifdef FAISS_USE_FLOAT16 +void runAssignAlongColumns(Tensor& input, + Tensor& output, + cudaStream_t stream); +#endif + +// output[i][x] += input[i] for all x +// If zeroClamp, output[i][x] = max(output[i][x] + input[i], 0) for all x +void runSumAlongRows(Tensor& input, + Tensor& output, + bool zeroClamp, + cudaStream_t stream); + +#ifdef FAISS_USE_FLOAT16 +void runSumAlongRows(Tensor& input, + Tensor& output, + bool zeroClamp, + cudaStream_t stream); + +#endif +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/impl/Distance.cu b/core/src/index/thirdparty/faiss/gpu/impl/Distance.cu new file mode 100644 index 0000000000..0856396cc1 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/impl/Distance.cu @@ -0,0 +1,448 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +namespace faiss { namespace gpu { + +template +void runDistance(bool computeL2, + GpuResources* resources, + Tensor& centroids, + bool centroidsRowMajor, + Tensor* centroidNorms, + Tensor& queries, + bool queriesRowMajor, + Tensor& bitset, + int k, + Tensor& outDistances, + Tensor& outIndices, + bool ignoreOutDistances) { + // The # of centroids in `centroids` based on memory layout + auto numCentroids = centroids.getSize(centroidsRowMajor ? 0 : 1); + + // The # of queries in `queries` based on memory layout + auto numQueries = queries.getSize(queriesRowMajor ? 0 : 1); + + // The dimensions of the vectors to consider + auto dim = queries.getSize(queriesRowMajor ? 1 : 0); + FAISS_ASSERT((numQueries == 0 || numCentroids == 0) || + dim == centroids.getSize(centroidsRowMajor ? 1 : 0)); + + FAISS_ASSERT(outDistances.getSize(0) == numQueries); + FAISS_ASSERT(outIndices.getSize(0) == numQueries); + FAISS_ASSERT(outDistances.getSize(1) == k); + FAISS_ASSERT(outIndices.getSize(1) == k); + + auto& mem = resources->getMemoryManagerCurrentDevice(); + auto defaultStream = resources->getDefaultStreamCurrentDevice(); + + // If we're quering against a 0 sized set, just return empty results + if (centroids.numElements() == 0) { + thrust::fill(thrust::cuda::par.on(defaultStream), + outDistances.data(), outDistances.end(), + Limits::getMax()); + + thrust::fill(thrust::cuda::par.on(defaultStream), + outIndices.data(), outIndices.end(), + -1); + + return; + } + + // L2: If ||c||^2 is not pre-computed, calculate it + DeviceTensor cNorms; + if (computeL2 && !centroidNorms) { + cNorms = + std::move(DeviceTensor( + mem, {numCentroids}, defaultStream)); + runL2Norm(centroids, centroidsRowMajor, cNorms, true, defaultStream); + centroidNorms = &cNorms; + } + + // + // Prepare norm vector ||q||^2; ||c||^2 is already pre-computed + // + int qNormSize[1] = {numQueries}; + DeviceTensor queryNorms(mem, qNormSize, defaultStream); + + // ||q||^2 + if (computeL2) { + runL2Norm(queries, queriesRowMajor, queryNorms, true, defaultStream); + } + + // By default, aim to use up to 512 MB of memory for the processing, with both + // number of queries and number of centroids being at least 512. + int tileRows = 0; + int tileCols = 0; + chooseTileSize(numQueries, + numCentroids, + dim, + sizeof(T), + mem.getSizeAvailable(), + tileRows, + tileCols); + + int numColTiles = utils::divUp(numCentroids, tileCols); + + // We can have any number of vectors to query against, even less than k, in + // which case we'll return -1 for the index + FAISS_ASSERT(k <= GPU_MAX_SELECTION_K); // select limitation + + // Temporary output memory space we'll use + DeviceTensor distanceBuf1( + mem, {tileRows, tileCols}, defaultStream); + DeviceTensor distanceBuf2( + mem, {tileRows, tileCols}, defaultStream); + DeviceTensor* distanceBufs[2] = + {&distanceBuf1, &distanceBuf2}; + + DeviceTensor outDistanceBuf1( + mem, {tileRows, numColTiles * k}, defaultStream); + DeviceTensor outDistanceBuf2( + mem, {tileRows, numColTiles * k}, defaultStream); + DeviceTensor* outDistanceBufs[2] = + {&outDistanceBuf1, &outDistanceBuf2}; + + DeviceTensor outIndexBuf1( + mem, {tileRows, numColTiles * k}, defaultStream); + DeviceTensor outIndexBuf2( + mem, {tileRows, numColTiles * k}, defaultStream); + DeviceTensor* outIndexBufs[2] = + {&outIndexBuf1, &outIndexBuf2}; + + auto streams = resources->getAlternateStreamsCurrentDevice(); + streamWait(streams, {defaultStream}); + + int curStream = 0; + bool interrupt = false; + + // Tile over the input queries + for (int i = 0; i < numQueries; i += tileRows) { + if (interrupt || InterruptCallback::is_interrupted()) { + interrupt = true; + break; + } + + int curQuerySize = std::min(tileRows, numQueries - i); + + auto outDistanceView = + outDistances.narrow(0, i, curQuerySize); + auto outIndexView = + outIndices.narrow(0, i, curQuerySize); + + auto queryView = + queries.narrow(queriesRowMajor ? 0 : 1, i, curQuerySize); + auto queryNormNiew = + queryNorms.narrow(0, i, curQuerySize); + + auto outDistanceBufRowView = + outDistanceBufs[curStream]->narrow(0, 0, curQuerySize); + auto outIndexBufRowView = + outIndexBufs[curStream]->narrow(0, 0, curQuerySize); + + // Tile over the centroids + for (int j = 0; j < numCentroids; j += tileCols) { + if (InterruptCallback::is_interrupted()) { + interrupt = true; + break; + } + + int curCentroidSize = std::min(tileCols, numCentroids - j); + int curColTile = j / tileCols; + + auto centroidsView = + sliceCentroids(centroids, centroidsRowMajor, j, curCentroidSize); + + auto distanceBufView = distanceBufs[curStream]-> + narrow(0, 0, curQuerySize).narrow(1, 0, curCentroidSize); + + auto outDistanceBufColView = + outDistanceBufRowView.narrow(1, k * curColTile, k); + auto outIndexBufColView = + outIndexBufRowView.narrow(1, k * curColTile, k); + + // L2: distance is ||c||^2 - 2qc + ||q||^2, we compute -2qc + // IP: just compute qc + // (query id x dim) x (centroid id, dim)' = (query id, centroid id) + runMatrixMult(distanceBufView, + false, // not transposed + queryView, + !queriesRowMajor, // transposed MM if col major + centroidsView, + centroidsRowMajor, // transposed MM if row major + computeL2 ? -2.0f : 1.0f, + 0.0f, + resources->getBlasHandleCurrentDevice(), + streams[curStream]); + + if (computeL2) { + // For L2 distance, we use this fused kernel that performs both + // adding ||c||^2 to -2qc and k-selection, so we only need two + // passes (one write by the gemm, one read here) over the huge + // region of output memory + // + // If we aren't tiling along the number of centroids, we can perform the + // output work directly + if (tileCols == numCentroids) { + // Write into the final output + runL2SelectMin(distanceBufView, + *centroidNorms, + bitset, + outDistanceView, + outIndexView, + k, + streams[curStream]); + + if (!ignoreOutDistances) { + // expand (query id) to (query id, k) by duplicating along rows + // top-k ||c||^2 - 2qc + ||q||^2 in the form (query id, k) + runSumAlongRows(queryNormNiew, + outDistanceView, + true, // L2 distances should not go below zero due + // to roundoff error + streams[curStream]); + } + } else { + auto centroidNormsView = centroidNorms->narrow(0, j, curCentroidSize); + + // Write into our intermediate output + runL2SelectMin(distanceBufView, + centroidNormsView, + bitset, + outDistanceBufColView, + outIndexBufColView, + k, + streams[curStream]); + + if (!ignoreOutDistances) { + // expand (query id) to (query id, k) by duplicating along rows + // top-k ||c||^2 - 2qc + ||q||^2 in the form (query id, k) + runSumAlongRows(queryNormNiew, + outDistanceBufColView, + true, // L2 distances should not go below zero due + // to roundoff error + streams[curStream]); + } + } + } else { + // For IP, just k-select the output for this tile + if (tileCols == numCentroids) { + // Write into the final output + runBlockSelect(distanceBufView, + bitset, + outDistanceView, + outIndexView, + true, k, streams[curStream]); + } else { + // Write into the intermediate output + runBlockSelect(distanceBufView, + bitset, + outDistanceBufColView, + outIndexBufColView, + true, k, streams[curStream]); + } + } + } + + // As we're finished with processing a full set of centroids, perform the + // final k-selection + if (tileCols != numCentroids) { + // The indices are tile-relative; for each tile of k, we need to add + // tileCols to the index + runIncrementIndex(outIndexBufRowView, k, tileCols, streams[curStream]); + + runBlockSelectPair(outDistanceBufRowView, + outIndexBufRowView, + bitset, + outDistanceView, + outIndexView, + computeL2 ? false : true, k, streams[curStream]); + } + + curStream = (curStream + 1) % 2; + } + + // Have the desired ordering stream wait on the multi-stream + streamWait({defaultStream}, streams); + + if (interrupt) { + FAISS_THROW_MSG("interrupted"); + } +} + +template +void runL2Distance(GpuResources* resources, + Tensor& centroids, + bool centroidsRowMajor, + Tensor* centroidNorms, + Tensor& queries, + bool queriesRowMajor, + Tensor& bitset, + int k, + Tensor& outDistances, + Tensor& outIndices, + bool ignoreOutDistances = false) { + runDistance(true, // L2 + resources, + centroids, + centroidsRowMajor, + centroidNorms, + queries, + queriesRowMajor, + bitset, + k, + outDistances, + outIndices, + ignoreOutDistances); +} + +template +void runIPDistance(GpuResources* resources, + Tensor& centroids, + bool centroidsRowMajor, + Tensor& queries, + bool queriesRowMajor, + Tensor& bitset, + int k, + Tensor& outDistances, + Tensor& outIndices) { + runDistance(false, // IP + resources, + centroids, + centroidsRowMajor, + nullptr, // no centroid norms provided + queries, + queriesRowMajor, + bitset, + k, + outDistances, + outIndices, + false); +} + +// +// Instantiations of the distance templates +// + +void +runIPDistance(GpuResources* resources, + Tensor& vectors, + bool vectorsRowMajor, + Tensor& queries, + bool queriesRowMajor, + Tensor& bitset, + int k, + Tensor& outDistances, + Tensor& outIndices) { + runIPDistance(resources, + vectors, + vectorsRowMajor, + queries, + queriesRowMajor, + bitset, + k, + outDistances, + outIndices); +} + +#ifdef FAISS_USE_FLOAT16 +void +runIPDistance(GpuResources* resources, + Tensor& vectors, + bool vectorsRowMajor, + Tensor& queries, + bool queriesRowMajor, + Tensor& bitset, + int k, + Tensor& outDistances, + Tensor& outIndices) { + runIPDistance(resources, + vectors, + vectorsRowMajor, + queries, + queriesRowMajor, + bitset, + k, + outDistances, + outIndices); +} +#endif + +void +runL2Distance(GpuResources* resources, + Tensor& vectors, + bool vectorsRowMajor, + Tensor* vectorNorms, + Tensor& queries, + bool queriesRowMajor, + Tensor& bitset, + int k, + Tensor& outDistances, + Tensor& outIndices, + bool ignoreOutDistances) { + runL2Distance(resources, + vectors, + vectorsRowMajor, + vectorNorms, + queries, + queriesRowMajor, + bitset, + k, + outDistances, + outIndices, + ignoreOutDistances); +} + +#ifdef FAISS_USE_FLOAT16 +void +runL2Distance(GpuResources* resources, + Tensor& vectors, + bool vectorsRowMajor, + Tensor* vectorNorms, + Tensor& queries, + bool queriesRowMajor, + Tensor& bitset, + int k, + Tensor& outDistances, + Tensor& outIndices, + bool ignoreOutDistances) { + runL2Distance(resources, + vectors, + vectorsRowMajor, + vectorNorms, + queries, + queriesRowMajor, + bitset, + k, + outDistances, + outIndices, + ignoreOutDistances); +} +#endif + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/impl/Distance.cuh b/core/src/index/thirdparty/faiss/gpu/impl/Distance.cuh new file mode 100644 index 0000000000..3430ddf87f --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/impl/Distance.cuh @@ -0,0 +1,216 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#pragma once + +#include +#include +#include +#include + +namespace faiss { namespace gpu { + +class GpuResources; + +/// Calculates brute-force L2 distance between `vectors` and +/// `queries`, returning the k closest results seen +void runL2Distance(GpuResources* resources, + Tensor& vectors, + bool vectorsRowMajor, + // can be optionally pre-computed; nullptr if we + // have to compute it upon the call + Tensor* vectorNorms, + Tensor& queries, + bool queriesRowMajor, + Tensor& bitset, + int k, + Tensor& outDistances, + Tensor& outIndices, + // Do we care about `outDistances`? If not, we can + // take shortcuts. + bool ignoreOutDistances = false); + +/// Calculates brute-force inner product distance between `vectors` +/// and `queries`, returning the k closest results seen +void runIPDistance(GpuResources* resources, + Tensor& vectors, + bool vectorsRowMajor, + Tensor& queries, + bool queriesRowMajor, + Tensor& bitset, + int k, + Tensor& outDistances, + Tensor& outIndices); + +void runIPDistance(GpuResources* resources, + Tensor& vectors, + bool vectorsRowMajor, + Tensor& queries, + bool queriesRowMajor, + Tensor& bitset, + int k, + Tensor& outDistances, + Tensor& outIndices); + + +void runL2Distance(GpuResources* resources, + Tensor& vectors, + bool vectorsRowMajor, + Tensor* vectorNorms, + Tensor& queries, + bool queriesRowMajor, + Tensor& bitset, + int k, + Tensor& outDistances, + Tensor& outIndices, + bool ignoreOutDistances = false); + +// +// General distance implementation, assumes that all arguments are on the +// device. This is the top-level internal distance function to call to dispatch +// based on metric type. +// +template +void bfKnnOnDevice(GpuResources* resources, + int device, + cudaStream_t stream, + Tensor& vectors, + bool vectorsRowMajor, + Tensor* vectorNorms, + Tensor& queries, + bool queriesRowMajor, + Tensor& bitset, + int k, + faiss::MetricType metric, + float metricArg, + Tensor& outDistances, + Tensor& outIndices, + bool ignoreOutDistances) { + // We are guaranteed that all data arguments are resident on our preferred + // `device` here, and are ordered wrt `stream` + + // L2 and IP are specialized to use GEMM and an optimized L2 + selection or + // pure k-selection kernel. + if ((metric == faiss::MetricType::METRIC_L2) || + (metric == faiss::MetricType::METRIC_Lp && + metricArg == 2)) { + runL2Distance(resources, + vectors, + vectorsRowMajor, + vectorNorms, + queries, + queriesRowMajor, + bitset, + k, + outDistances, + outIndices); + } else if (metric == faiss::MetricType::METRIC_INNER_PRODUCT) { + runIPDistance(resources, + vectors, + vectorsRowMajor, + queries, + queriesRowMajor, + bitset, + k, + outDistances, + outIndices); + } else { + // + // General pairwise distance kernel + // + // The general distance kernel does not have specializations for + // transpositions (NN, NT, TN); instead, the transposition is just handled + // upon data load for now, which could result in poor data loading behavior + // for NT / TN. This can be fixed at a later date if desired, but efficiency + // is low versus GEMM anyways. + // + + Tensor tVectorsDimInnermost = + vectorsRowMajor ? + vectors.transposeInnermost(1) : + vectors.transposeInnermost(0); + Tensor tQueriesDimInnermost = + queriesRowMajor ? + queries.transposeInnermost(1) : + queries.transposeInnermost(0); + + if ((metric == faiss::MetricType::METRIC_L1) || + (metric == faiss::MetricType::METRIC_Lp && + metricArg == 1)) { + runGeneralDistance(resources, + tVectorsDimInnermost, + tQueriesDimInnermost, + bitset, + k, + L1Distance(), + outDistances, + outIndices); + } else if (metric == faiss::MetricType::METRIC_Lp && + metricArg == -1) { + // A way to test L2 distance + runGeneralDistance(resources, + tVectorsDimInnermost, + tQueriesDimInnermost, + bitset, + k, + L2Distance(), + outDistances, + outIndices); + } else if (metric == faiss::MetricType::METRIC_Lp) { + runGeneralDistance(resources, + tVectorsDimInnermost, + tQueriesDimInnermost, + bitset, + k, + LpDistance(metricArg), + outDistances, + outIndices); + } else if (metric == faiss::MetricType::METRIC_Linf) { + runGeneralDistance(resources, + tVectorsDimInnermost, + tQueriesDimInnermost, + bitset, + k, + LinfDistance(), + outDistances, + outIndices); + } else if (metric == faiss::MetricType::METRIC_Canberra) { + runGeneralDistance(resources, + tVectorsDimInnermost, + tQueriesDimInnermost, + bitset, + k, + CanberraDistance(), + outDistances, + outIndices); + } else if (metric == faiss::MetricType::METRIC_BrayCurtis) { + runGeneralDistance(resources, + tVectorsDimInnermost, + tQueriesDimInnermost, + bitset, + k, + BrayCurtisDistance(), + outDistances, + outIndices); + } else if (metric == faiss::MetricType::METRIC_JensenShannon) { + runGeneralDistance(resources, + tVectorsDimInnermost, + tQueriesDimInnermost, + bitset, + k, + JensenShannonDistance(), + outDistances, + outIndices); + } else { + FAISS_THROW_FMT("unsupported metric type %d", metric); + } + } +} + + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/impl/DistanceUtils.cuh b/core/src/index/thirdparty/faiss/gpu/impl/DistanceUtils.cuh new file mode 100644 index 0000000000..42d815a5f3 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/impl/DistanceUtils.cuh @@ -0,0 +1,343 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#pragma once + +#include + +// +// Shared utilities for brute-force distance calculations +// + +namespace faiss { namespace gpu { + +struct IPDistance { + __host__ __device__ IPDistance() : dist(0) {} + + static constexpr bool kDirection = true; // maximize + static constexpr float kIdentityData = 0; + static constexpr float kMaxDistance = -std::numeric_limits::max(); + + __host__ __device__ void handle(float a, float b) { + dist += a * b; + } + + __host__ __device__ float reduce() { + return dist; + } + + __host__ __device__ void combine(const IPDistance& v) { + dist += v.dist; + } + + __host__ __device__ IPDistance zero() const { + return IPDistance(); + } + + float dist; +}; + +struct L1Distance { + __host__ __device__ L1Distance() : dist(0) {} + + static constexpr bool kDirection = false; // minimize + static constexpr float kIdentityData = 0; + static constexpr float kMaxDistance = std::numeric_limits::max(); + + __host__ __device__ void handle(float a, float b) { + dist += fabsf(a - b); + } + + __host__ __device__ float reduce() { + return dist; + } + + __host__ __device__ void combine(const L1Distance& v) { + dist += v.dist; + } + + __host__ __device__ L1Distance zero() const { + return L1Distance(); + } + + float dist; +}; + +struct L2Distance { + __host__ __device__ L2Distance() : dist(0) {} + + static constexpr bool kDirection = false; // minimize + static constexpr float kIdentityData = 0; + static constexpr float kMaxDistance = std::numeric_limits::max(); + + __host__ __device__ void handle(float a, float b) { + float v = a - b; + dist += v * v; + } + + __host__ __device__ float reduce() { + return dist; + } + + __host__ __device__ void combine(const L2Distance& v) { + dist += v.dist; + } + + __host__ __device__ L2Distance zero() const { + return L2Distance(); + } + + float dist; +}; + +struct LpDistance { + __host__ __device__ LpDistance() + : p(2), dist(0) {} + + __host__ __device__ LpDistance(float arg) + : p(arg), dist(0) {} + + __host__ __device__ LpDistance(const LpDistance& v) + : p(v.p), dist(v.dist) {} + + __host__ __device__ LpDistance& operator=(const LpDistance& v) { + p = v.p; + dist = v.dist; + return *this; + } + + static constexpr bool kDirection = false; // minimize + static constexpr float kIdentityData = 0; + static constexpr float kMaxDistance = std::numeric_limits::max(); + + __host__ __device__ void handle(float a, float b) { + dist += powf(fabsf(a - b), p); + } + + __host__ __device__ float reduce() { + return dist; + } + + __host__ __device__ void combine(const LpDistance& v) { + dist += v.dist; + } + + __host__ __device__ LpDistance zero() const { + return LpDistance(p); + } + + float p; + float dist; +}; + +struct LinfDistance { + __host__ __device__ LinfDistance() : dist(0) {} + + static constexpr bool kDirection = false; // minimize + static constexpr float kIdentityData = 0; + static constexpr float kMaxDistance = std::numeric_limits::max(); + + __host__ __device__ void handle(float a, float b) { + dist = fmaxf(dist, fabsf(a - b)); + } + + __host__ __device__ float reduce() { + return dist; + } + + __host__ __device__ void combine(const LinfDistance& v) { + dist = fmaxf(dist, v.dist); + } + + __host__ __device__ LinfDistance zero() const { + return LinfDistance(); + } + + float dist; +}; + +struct CanberraDistance { + __host__ __device__ CanberraDistance() : dist(0) {} + + static constexpr bool kDirection = false; // minimize + static constexpr float kIdentityData = 0; + static constexpr float kMaxDistance = std::numeric_limits::max(); + + __host__ __device__ void handle(float a, float b) { + float denom = fabsf(a) + fabsf(b); + dist += fabsf(a - b) / denom; + } + + __host__ __device__ float reduce() { + return dist; + } + + __host__ __device__ void combine(const CanberraDistance& v) { + dist += v.dist; + } + + __host__ __device__ CanberraDistance zero() const { + return CanberraDistance(); + } + + float dist; +}; + +struct BrayCurtisDistance { + __host__ __device__ BrayCurtisDistance() + : numerator(0), denominator(0) {} + + static constexpr bool kDirection = false; // minimize + static constexpr float kIdentityData = 0; + static constexpr float kMaxDistance = std::numeric_limits::max(); + + __host__ __device__ void handle(float a, float b) { + numerator += fabsf(a - b); + denominator += fabsf(a + b); + } + + __host__ __device__ float reduce() { + return (numerator / denominator); + } + + __host__ __device__ void combine(const BrayCurtisDistance& v) { + numerator += v.numerator; + denominator += v.denominator; + } + + __host__ __device__ BrayCurtisDistance zero() const { + return BrayCurtisDistance(); + } + + float numerator; + float denominator; +}; + +struct JensenShannonDistance { + __host__ __device__ JensenShannonDistance() + : dist(0) {} + + static constexpr bool kDirection = false; // minimize + static constexpr float kIdentityData = 0; + static constexpr float kMaxDistance = std::numeric_limits::max(); + + __host__ __device__ void handle(float a, float b) { + float m = 0.5f * (a + b); + + float x = m / a; + float y = m / b; + + float kl1 = -a * log(x); + float kl2 = -b * log(y); + + dist += kl1 + kl2; + } + + __host__ __device__ float reduce() { + return 0.5 * dist; + } + + __host__ __device__ void combine(const JensenShannonDistance& v) { + dist += v.dist; + } + + __host__ __device__ JensenShannonDistance zero() const { + return JensenShannonDistance(); + } + + float dist; +}; + +template +Tensor sliceCentroids(Tensor& centroids, + bool centroidsRowMajor, + int startCentroid, + int num) { + // Row major is (num, dim) + // Col major is (dim, num) + if (startCentroid == 0 && + num == centroids.getSize(centroidsRowMajor ? 0 : 1)) { + return centroids; + } + + return centroids.narrow(centroidsRowMajor ? 0 : 1, startCentroid, num); +} + +// For each chunk of k indices, increment the index by chunk * increment +template +__global__ void incrementIndex(Tensor indices, + int k, + int increment) { + for (int i = threadIdx.x; i < k; i += blockDim.x) { + indices[blockIdx.y][blockIdx.x * k + i] += blockIdx.x * increment; + } +} + +// Used to update result indices in distance computation where the number of +// centroids is high, and is tiled +template +void runIncrementIndex(Tensor& indices, + int k, + int increment, + cudaStream_t stream) { + dim3 grid(indices.getSize(1) / k, indices.getSize(0)); + int block = std::min(k, 512); + + // should be exact + FAISS_ASSERT(grid.x * k == indices.getSize(1)); + + incrementIndex<<>>(indices, k, increment); +} + +// If the inner size (dim) of the vectors is small, we want a larger query tile +// size, like 1024 +inline void chooseTileSize(int numQueries, + int numCentroids, + int dim, + int elementSize, + size_t tempMemAvailable, + int& tileRows, + int& tileCols) { + // The matrix multiplication should be large enough to be efficient, but if it + // is too large, we seem to lose efficiency as opposed to double-streaming. + // Each tile size here defines 1/2 of the memory use due to double streaming. + // We ignore available temporary memory, as that is adjusted independently by + // the user and can thus meet these requirements (or not). + // For <= 4 GB GPUs, prefer 512 MB of usage. + // For <= 8 GB GPUs, prefer 768 MB of usage. + // Otherwise, prefer 1 GB of usage. + auto totalMem = getCurrentDeviceProperties().totalGlobalMem; + + int targetUsage = 0; + + if (totalMem <= ((size_t) 4) * 1024 * 1024 * 1024) { + targetUsage = 512 * 1024 * 1024; + } else if (totalMem <= ((size_t) 8) * 1024 * 1024 * 1024) { + targetUsage = 768 * 1024 * 1024; + } else { + targetUsage = 1024 * 1024 * 1024; + } + + targetUsage /= 2 * elementSize; + + // 512 seems to be a batch size sweetspot for float32. + // If we are on float16, increase to 512. + // If the k size (vec dim) of the matrix multiplication is small (<= 32), + // increase to 1024. + int preferredTileRows = 512; + if (dim <= 32) { + preferredTileRows = 1024; + } + + tileRows = std::min(preferredTileRows, numQueries); + + // tileCols is the remainder size + tileCols = std::min(targetUsage / preferredTileRows, numCentroids); +} + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/impl/FlatIndex.cu b/core/src/index/thirdparty/faiss/gpu/impl/FlatIndex.cu new file mode 100644 index 0000000000..29480fa84f --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/impl/FlatIndex.cu @@ -0,0 +1,388 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace faiss { namespace gpu { + +FlatIndex::FlatIndex(GpuResources* res, + int dim, + bool useFloat16, + bool storeTransposed, + MemorySpace space) : + resources_(res), + dim_(dim), + useFloat16_(useFloat16), + storeTransposed_(storeTransposed), + space_(space), + num_(0), + rawData_(space) { +#ifndef FAISS_USE_FLOAT16 + FAISS_ASSERT(!useFloat16_); +#endif +} + +bool +FlatIndex::getUseFloat16() const { + return useFloat16_; +} + +/// Returns the number of vectors we contain +int FlatIndex::getSize() const { +#ifdef FAISS_USE_FLOAT16 + if (useFloat16_) { + return vectorsHalf_.getSize(0); + } else { + return vectors_.getSize(0); + } +#else + return vectors_.getSize(0); +#endif +} + +int FlatIndex::getDim() const { +#ifdef FAISS_USE_FLOAT16 + if (useFloat16_) { + return vectorsHalf_.getSize(1); + } else { + return vectors_.getSize(1); + } +#else + return vectors_.getSize(1); +#endif +} + +void +FlatIndex::reserve(size_t numVecs, cudaStream_t stream) { +#ifdef FAISS_USE_FLOAT16 + if (useFloat16_) { + rawData_.reserve(numVecs * dim_ * sizeof(half), stream); + } else { + rawData_.reserve(numVecs * dim_ * sizeof(float), stream); + } +#else + rawData_.reserve(numVecs * dim_ * sizeof(float), stream); +#endif +} + +template <> +Tensor& +FlatIndex::getVectorsRef() { + // Should not call this unless we are in float32 mode + FAISS_ASSERT(!useFloat16_); + return getVectorsFloat32Ref(); +} + +#ifdef FAISS_USE_FLOAT16 +template <> +Tensor& +FlatIndex::getVectorsRef() { + // Should not call this unless we are in float16 mode + FAISS_ASSERT(useFloat16_); + return getVectorsFloat16Ref(); +} +#endif + +Tensor& +FlatIndex::getVectorsFloat32Ref() { + // Should not call this unless we are in float32 mode + FAISS_ASSERT(!useFloat16_); + + return vectors_; +} + +#ifdef FAISS_USE_FLOAT16 +Tensor& +FlatIndex::getVectorsFloat16Ref() { + // Should not call this unless we are in float16 mode + FAISS_ASSERT(useFloat16_); + + return vectorsHalf_; +} +#endif + +DeviceTensor +FlatIndex::getVectorsFloat32Copy(cudaStream_t stream) { + return getVectorsFloat32Copy(0, num_, stream); +} + +DeviceTensor +FlatIndex::getVectorsFloat32Copy(int from, int num, cudaStream_t stream) { + DeviceTensor vecFloat32({num, dim_}, space_); + +#ifdef FAISS_USE_FLOAT16 + if (useFloat16_) { + auto halfNarrow = vectorsHalf_.narrowOutermost(from, num); + convertTensor(stream, halfNarrow, vecFloat32); + } else { + vectors_.copyTo(vecFloat32, stream); + } +#else + vectors_.copyTo(vecFloat32, stream); +#endif + + return vecFloat32; +} + +void +FlatIndex::query(Tensor& input, + Tensor& bitset, + int k, + faiss::MetricType metric, + float metricArg, + Tensor& outDistances, + Tensor& outIndices, + bool exactDistance) { + auto stream = resources_->getDefaultStreamCurrentDevice(); + auto& mem = resources_->getMemoryManagerCurrentDevice(); + +#ifdef FAISS_USE_FLOAT16 + if (useFloat16_) { + // We need to convert the input to float16 for comparison to ourselves + + auto inputHalf = + convertTensor(resources_, stream, input); + + query(inputHalf, bitset, k, metric, metricArg, + outDistances, outIndices, exactDistance); + + } else { + bfKnnOnDevice(resources_, + getCurrentDevice(), + stream, + storeTransposed_ ? vectorsTransposed_ : vectors_, + !storeTransposed_, // is vectors row major? + &norms_, + input, + true, // input is row major + bitset, + k, + metric, + metricArg, + outDistances, + outIndices, + !exactDistance); + } +#else + bfKnnOnDevice(resources_, + getCurrentDevice(), + stream, + storeTransposed_ ? vectorsTransposed_ : vectors_, + !storeTransposed_, // is vectors row major? + &norms_, + input, + true, // input is row major + bitset, + k, + metric, + metricArg, + outDistances, + outIndices, + !exactDistance); +#endif +} + +#ifdef FAISS_USE_FLOAT16 +void +FlatIndex::query(Tensor& input, + Tensor& bitset, + int k, + faiss::MetricType metric, + float metricArg, + Tensor& outDistances, + Tensor& outIndices, + bool exactDistance) { + FAISS_ASSERT(useFloat16_); + + bfKnnOnDevice(resources_, + getCurrentDevice(), + resources_->getDefaultStreamCurrentDevice(), + storeTransposed_ ? vectorsHalfTransposed_ : vectorsHalf_, + !storeTransposed_, // is vectors row major? + &norms_, + input, + true, // input is row major + bitset, + k, + metric, + metricArg, + outDistances, + outIndices, + !exactDistance); +} +#endif + +void +FlatIndex::computeResidual(Tensor& vecs, + Tensor& listIds, + Tensor& residuals) { +#ifdef FAISS_USE_FLOAT16 + if (useFloat16_) { + runCalcResidual(vecs, + getVectorsFloat16Ref(), + listIds, + residuals, + resources_->getDefaultStreamCurrentDevice()); + } else { + runCalcResidual(vecs, + getVectorsFloat32Ref(), + listIds, + residuals, + resources_->getDefaultStreamCurrentDevice()); + } +#else + runCalcResidual(vecs, + getVectorsFloat32Ref(), + listIds, + residuals, + resources_->getDefaultStreamCurrentDevice()); +#endif +} + +void +FlatIndex::reconstruct(Tensor& listIds, + Tensor& vecs) { +#ifdef FAISS_USE_FLOAT16 + if (useFloat16_) { + runReconstruct(listIds, + getVectorsFloat16Ref(), + vecs, + resources_->getDefaultStreamCurrentDevice()); + } else { + runReconstruct(listIds, + getVectorsFloat32Ref(), + vecs, + resources_->getDefaultStreamCurrentDevice()); + } +#else + runReconstruct(listIds, + getVectorsFloat32Ref(), + vecs, + resources_->getDefaultStreamCurrentDevice()); +#endif +} +void +FlatIndex::reconstruct(Tensor& listIds, + Tensor& vecs) { + auto listIds1 = listIds.downcastOuter<1>(); + auto vecs2 = vecs.downcastOuter<2>(); + + reconstruct(listIds1, vecs2); +} + +void +FlatIndex::add(const float* data, int numVecs, cudaStream_t stream) { + if (numVecs == 0) { + return; + } + +#ifdef FAISS_USE_FLOAT16 + if (useFloat16_) { + // Make sure that `data` is on our device; we'll run the + // conversion on our device + auto devData = toDevice(resources_, + getCurrentDevice(), + (float*) data, + stream, + {numVecs, dim_}); + + auto devDataHalf = + convertTensor(resources_, stream, devData); + + rawData_.append((char*) devDataHalf.data(), + devDataHalf.getSizeInBytes(), + stream, + true /* reserve exactly */); + } else { + rawData_.append((char*) data, + (size_t) dim_ * numVecs * sizeof(float), + stream, + true /* reserve exactly */); + } + +#else + rawData_.append((char*) data, + (size_t) dim_ * numVecs * sizeof(float), + stream, + true /* reserve exactly */); +#endif + num_ += numVecs; + +#ifdef FAISS_USE_FLOAT16 + if (useFloat16_) { + DeviceTensor vectorsHalf( + (half*) rawData_.data(), {(int) num_, dim_}, space_); + vectorsHalf_ = std::move(vectorsHalf); + } else { + DeviceTensor vectors( + (float*) rawData_.data(), {(int) num_, dim_}, space_); + vectors_ = std::move(vectors); + } +#else + DeviceTensor vectors( + (float*) rawData_.data(), {(int) num_, dim_}, space_); + vectors_ = std::move(vectors); +#endif + + if (storeTransposed_) { +#ifdef FAISS_USE_FLOAT16 + if (useFloat16_) { + vectorsHalfTransposed_ = + std::move(DeviceTensor({dim_, (int) num_}, space_)); + runTransposeAny(vectorsHalf_, 0, 1, vectorsHalfTransposed_, stream); + } else { + vectorsTransposed_ = + std::move(DeviceTensor({dim_, (int) num_}, space_)); + runTransposeAny(vectors_, 0, 1, vectorsTransposed_, stream); + } +#else + vectorsTransposed_ = + std::move(DeviceTensor({dim_, (int) num_}, space_)); + runTransposeAny(vectors_, 0, 1, vectorsTransposed_, stream); +#endif + } + + // Precompute L2 norms of our database +#ifdef FAISS_USE_FLOAT16 + if (useFloat16_) { + DeviceTensor norms({(int) num_}, space_); + runL2Norm(vectorsHalf_, true, norms, true, stream); + norms_ = std::move(norms); + } else { + DeviceTensor norms({(int) num_}, space_); + runL2Norm(vectors_, true, norms, true, stream); + norms_ = std::move(norms); + } +#else + DeviceTensor norms({(int) num_}, space_); + runL2Norm(vectors_, true, norms, true, stream); + norms_ = std::move(norms); +#endif +} + +void +FlatIndex::reset() { + rawData_.clear(); + vectors_ = std::move(DeviceTensor()); + vectorsTransposed_ = std::move(DeviceTensor()); +#ifdef FAISS_USE_FLOAT16 + vectorsHalf_ = std::move(DeviceTensor()); + vectorsHalfTransposed_ = std::move(DeviceTensor()); +#endif + norms_ = std::move(DeviceTensor()); + num_ = 0; +} + +} } diff --git a/core/src/index/thirdparty/faiss/gpu/impl/FlatIndex.cuh b/core/src/index/thirdparty/faiss/gpu/impl/FlatIndex.cuh new file mode 100644 index 0000000000..eef07df24c --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/impl/FlatIndex.cuh @@ -0,0 +1,139 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#pragma once + +#include +#include +#include +#include + +namespace faiss { namespace gpu { + +class GpuResources; + +/// Holder of GPU resources for a particular flat index +class FlatIndex { + public: + FlatIndex(GpuResources* res, + int dim, + bool useFloat16, + bool storeTransposed, + MemorySpace space); + + /// Whether or not this flat index primarily stores data in float16 + bool getUseFloat16() const; + + /// Returns the number of vectors we contain + int getSize() const; + + /// Returns the dimensionality of the vectors + int getDim() const; + + /// Reserve storage that can contain at least this many vectors + void reserve(size_t numVecs, cudaStream_t stream); + + /// Returns the vectors based on the type desired; the FlatIndex must be of + /// the same type (float16 or float32) to not assert + template + Tensor& getVectorsRef(); + + /// Returns a reference to our vectors currently in use + Tensor& getVectorsFloat32Ref(); + + /// Returns a reference to our vectors currently in use (useFloat16 mode) +#ifdef FAISS_USE_FLOAT16 + Tensor& getVectorsFloat16Ref(); +#endif + + /// Performs a copy of the vectors on the given device, converting + /// as needed from float16 + DeviceTensor getVectorsFloat32Copy(cudaStream_t stream); + + /// Returns only a subset of the vectors + DeviceTensor getVectorsFloat32Copy(int from, + int num, + cudaStream_t stream); + + void query(Tensor& vecs, + Tensor& bitset, + int k, + faiss::MetricType metric, + float metricArg, + Tensor& outDistances, + Tensor& outIndices, + bool exactDistance); + +#ifdef FAISS_USE_FLOAT16 + void query(Tensor& vecs, + Tensor& bitset, + int k, + faiss::MetricType metric, + float metricArg, + Tensor& outDistances, + Tensor& outIndices, + bool exactDistance); +#endif + + /// Compute residual for set of vectors + void computeResidual(Tensor& vecs, + Tensor& listIds, + Tensor& residuals); + + /// Gather vectors given the set of IDs + void reconstruct(Tensor& listIds, + Tensor& vecs); + + void reconstruct(Tensor& listIds, + Tensor& vecs); + + /// Add vectors to ourselves; the pointer passed can be on the host + /// or the device + void add(const float* data, int numVecs, cudaStream_t stream); + + /// Free all storage + void reset(); + + private: + /// Collection of GPU resources that we use + GpuResources* resources_; + + /// Dimensionality of our vectors + const int dim_; + + /// Float16 data format + const bool useFloat16_; + + /// Store vectors in transposed layout for speed; makes addition to + /// the index slower + const bool storeTransposed_; + + /// Memory space for our allocations + MemorySpace space_; + + /// How many vectors we have + int num_; + + /// The underlying expandable storage + DeviceVector rawData_; + + /// Vectors currently in rawData_ + DeviceTensor vectors_; + DeviceTensor vectorsTransposed_; + + /// Vectors currently in rawData_, float16 form +#ifdef FAISS_USE_FLOAT16 + DeviceTensor vectorsHalf_; + DeviceTensor vectorsHalfTransposed_; +#endif + + /// Precomputed L2 norms + DeviceTensor norms_; +}; + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/impl/GeneralDistance.cuh b/core/src/index/thirdparty/faiss/gpu/impl/GeneralDistance.cuh new file mode 100644 index 0000000000..5dae58638c --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/impl/GeneralDistance.cuh @@ -0,0 +1,432 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +// +// Kernels for non-L2 / inner product distances +// + +namespace faiss { namespace gpu { + +// Reduction tree operator +template +struct ReduceDistanceOp { + __device__ static DistanceOp reduce(DistanceOp ops[N]) { + DistanceOp vals[N/2]; +#pragma unroll + for (int i = 0; i < N / 2; ++i) { + vals[i] = ops[i * 2]; + vals[i].combine(ops[i * 2 + 1]); + } + + return ReduceDistanceOp::reduce(vals); + } +}; + +template +struct ReduceDistanceOp { + __device__ static DistanceOp reduce(DistanceOp ops[1]) { + return ops[0]; + } +}; + +// Implements a pairwise reduction tree +template +inline __device__ DistanceOp +reduce(const DistanceOp& in, + const T queryTile[kWarpSize][DimMultiple * kWarpSize + 1], + const T vecTile[kWarpSize][DimMultiple * kWarpSize + 1]) { + DistanceOp accs[Unroll]; +#pragma unroll + for (int i = 0; i < Unroll; ++i) { + accs[i] = in.zero(); + } + + auto vecTileBase = vecTile[threadIdx.x]; + auto queryTileBase = queryTile[threadIdx.y]; + +#pragma unroll + for (int i = 0; i < Unroll; ++i) { +#pragma unroll + for (int j = 0; j < (kWarpSize * DimMultiple / Unroll); ++j) { + int idx = i * (kWarpSize * DimMultiple / Unroll) + j; + accs[i].handle(ConvertTo::to(queryTileBase[idx]), + ConvertTo::to(vecTileBase[idx])); + } + } + + return ReduceDistanceOp::reduce(accs); +} + +// Our general distance matrix "multiplication" kernel +template +__launch_bounds__(kWarpSize * kWarpSize) +__global__ void +generalDistance(Tensor query, // m x k + Tensor vec, // n x k + DistanceOp op, + Tensor out) { // m x n + constexpr int kDimMultiple = 1; + + __shared__ T queryTile[kWarpSize][kWarpSize * kDimMultiple + 1]; + __shared__ T vecTile[kWarpSize][kWarpSize * kDimMultiple + 1]; + + // block y -> query + // block x -> vector + + int queryBlock = blockIdx.y * kWarpSize; + int queryThread = queryBlock + threadIdx.y; + + int vecBlock = blockIdx.x * kWarpSize; + int vecThreadLoad = vecBlock + threadIdx.y; + int vecThreadSave = vecBlock + threadIdx.x; + + DistanceOp acc = op.zero(); + + auto queryTileBase = queryTile[threadIdx.y]; + auto vecTileBase = vecTile[threadIdx.y]; + + auto queryBase = query[queryThread]; + auto vecBase = vec[vecThreadLoad]; + + if ((blockIdx.x != (gridDim.x - 1)) && (blockIdx.y != (gridDim.y - 1))) { + // + // Interior tile + // + int limit = utils::roundDown(query.getSize(1), kWarpSize * kDimMultiple); + + for (int k = threadIdx.x; k < limit; k += kWarpSize * kDimMultiple) { + // Load query tile +#pragma unroll + for (int i = 0; i < kDimMultiple; ++i) { + queryTileBase[threadIdx.x + i * kWarpSize] = + queryBase[k + i * kWarpSize]; + vecTileBase[threadIdx.x + i * kWarpSize] = + vecBase[k + i * kWarpSize]; + } + + __syncthreads(); + + // thread (y, x) does (query y, vec x) + acc.combine( + reduce(op, queryTile, vecTile)); + + __syncthreads(); + } + + // Handle remainder + if (limit < query.getSize(1)) { +#pragma unroll + for (int i = 0; i < kDimMultiple; ++i) { + int k = limit + threadIdx.x + i * kWarpSize; + bool kInBounds = k < query.getSize(1); + + queryTileBase[threadIdx.x + i * kWarpSize] = + kInBounds ? + queryBase[k] : (T) 0; //DistanceOp::kIdentityData; + + vecTileBase[threadIdx.x + i * kWarpSize] = + kInBounds ? + vecBase[k] : (T) 0; // DistanceOp::kIdentityData; + } + + __syncthreads(); + + int remainder = query.getSize(1) - limit; + + // thread (y, x) does (query y, vec x) +#pragma unroll + for (int i = 0; i < remainder; ++i) { + acc.handle(ConvertTo::to(queryTileBase[i]), + ConvertTo::to(vecTile[threadIdx.x][i])); + } + } + + // Write out results + out[queryThread][vecThreadSave] = acc.reduce(); + } else { + // + // Otherwise, we're an exterior tile + // + + bool queryThreadInBounds = queryThread < query.getSize(0); + bool vecThreadInBoundsLoad = vecThreadLoad < vec.getSize(0); + bool vecThreadInBoundsSave = vecThreadSave < vec.getSize(0); + int limit = utils::roundDown(query.getSize(1), kWarpSize); + + for (int k = threadIdx.x; k < limit; k += kWarpSize) { + // Load query tile + queryTileBase[threadIdx.x] = + queryThreadInBounds ? + queryBase[k] : (T) 0; // DistanceOp::kIdentityData; + + vecTileBase[threadIdx.x] = + vecThreadInBoundsLoad ? + vecBase[k] : (T) 0; // DistanceOp::kIdentityData; + + __syncthreads(); + + // thread (y, x) does (query y, vec x) +#pragma unroll + for (int i = 0; i < kWarpSize; ++i) { + acc.handle(ConvertTo::to(queryTileBase[i]), + ConvertTo::to(vecTile[threadIdx.x][i])); + } + + __syncthreads(); + } + + // Handle remainder + if (limit < query.getSize(1)) { + int k = limit + threadIdx.x; + bool kInBounds = k < query.getSize(1); + + // Load query tile + queryTileBase[threadIdx.x] = + queryThreadInBounds && kInBounds ? + queryBase[k] : (T) 0; // DistanceOp::kIdentityData; + + vecTileBase[threadIdx.x] = + vecThreadInBoundsLoad && kInBounds ? + vecBase[k] : (T) 0; // DistanceOp::kIdentityData; + + __syncthreads(); + + int remainder = query.getSize(1) - limit; + + // thread (y, x) does (query y, vec x) + for (int i = 0; i < remainder; ++i) { + acc.handle(ConvertTo::to(queryTileBase[i]), + ConvertTo::to(vecTile[threadIdx.x][i])); + } + } + + // Write out results + if (queryThreadInBounds && vecThreadInBoundsSave) { + out[queryThread][vecThreadSave] = acc.reduce(); + } + } +} + + +template +void runGeneralDistanceKernel(Tensor& vecs, + Tensor& query, + Tensor& out, + const DistanceOp& op, + cudaStream_t stream) { + FAISS_ASSERT(vecs.getSize(1) == query.getSize(1)); + FAISS_ASSERT(out.getSize(0) == query.getSize(0)); + FAISS_ASSERT(out.getSize(1) == vecs.getSize(0)); + + dim3 grid(utils::divUp(vecs.getSize(0), kWarpSize), + utils::divUp(query.getSize(0), kWarpSize)); + dim3 block(kWarpSize, kWarpSize); + + generalDistance<<>>(query, vecs, op, out); +} + +template +void runGeneralDistance(GpuResources* resources, + Tensor& centroids, + Tensor& queries, + Tensor& bitset, + int k, + const DistanceOp& op, + Tensor& outDistances, + Tensor& outIndices) { + // The # of centroids in `centroids` based on memory layout + auto numCentroids = centroids.getSize(0); + + // The # of queries in `queries` based on memory layout + auto numQueries = queries.getSize(0); + + // The dimensions of the vectors to consider + auto dim = queries.getSize(1); + FAISS_ASSERT((numQueries == 0 || numCentroids == 0) || + dim == centroids.getSize(1)); + + FAISS_ASSERT(outDistances.getSize(0) == numQueries); + FAISS_ASSERT(outIndices.getSize(0) == numQueries); + FAISS_ASSERT(outDistances.getSize(1) == k); + FAISS_ASSERT(outIndices.getSize(1) == k); + + auto& mem = resources->getMemoryManagerCurrentDevice(); + auto defaultStream = resources->getDefaultStreamCurrentDevice(); + + // If we're quering against a 0 sized set, just return empty results + if (centroids.numElements() == 0) { + thrust::fill(thrust::cuda::par.on(defaultStream), + outDistances.data(), outDistances.end(), + Limits::getMax()); + + thrust::fill(thrust::cuda::par.on(defaultStream), + outIndices.data(), outIndices.end(), + -1); + + return; + } + + // By default, aim to use up to 512 MB of memory for the processing, with both + // number of queries and number of centroids being at least 512. + int tileRows = 0; + int tileCols = 0; + chooseTileSize(numQueries, + numCentroids, + dim, + sizeof(T), + mem.getSizeAvailable(), + tileRows, + tileCols); + + int numColTiles = utils::divUp(numCentroids, tileCols); + + // We can have any number of vectors to query against, even less than k, in + // which case we'll return -1 for the index + FAISS_ASSERT(k <= GPU_MAX_SELECTION_K); // select limitation + + // Temporary output memory space we'll use + DeviceTensor distanceBuf1( + mem, {tileRows, tileCols}, defaultStream); + DeviceTensor distanceBuf2( + mem, {tileRows, tileCols}, defaultStream); + DeviceTensor* distanceBufs[2] = + {&distanceBuf1, &distanceBuf2}; + + DeviceTensor outDistanceBuf1( + mem, {tileRows, numColTiles * k}, defaultStream); + DeviceTensor outDistanceBuf2( + mem, {tileRows, numColTiles * k}, defaultStream); + DeviceTensor* outDistanceBufs[2] = + {&outDistanceBuf1, &outDistanceBuf2}; + + DeviceTensor outIndexBuf1( + mem, {tileRows, numColTiles * k}, defaultStream); + DeviceTensor outIndexBuf2( + mem, {tileRows, numColTiles * k}, defaultStream); + DeviceTensor* outIndexBufs[2] = + {&outIndexBuf1, &outIndexBuf2}; + + auto streams = resources->getAlternateStreamsCurrentDevice(); + streamWait(streams, {defaultStream}); + + int curStream = 0; + bool interrupt = false; + + // Tile over the input queries + for (int i = 0; i < numQueries; i += tileRows) { + if (interrupt || InterruptCallback::is_interrupted()) { + interrupt = true; + break; + } + + int curQuerySize = std::min(tileRows, numQueries - i); + + auto outDistanceView = + outDistances.narrow(0, i, curQuerySize); + auto outIndexView = + outIndices.narrow(0, i, curQuerySize); + + auto queryView = + queries.narrow(0, i, curQuerySize); + + auto outDistanceBufRowView = + outDistanceBufs[curStream]->narrow(0, 0, curQuerySize); + auto outIndexBufRowView = + outIndexBufs[curStream]->narrow(0, 0, curQuerySize); + + // Tile over the centroids + for (int j = 0; j < numCentroids; j += tileCols) { + if (InterruptCallback::is_interrupted()) { + interrupt = true; + break; + } + + int curCentroidSize = std::min(tileCols, numCentroids - j); + int curColTile = j / tileCols; + + auto centroidsView = + sliceCentroids(centroids, true, j, curCentroidSize); + + auto distanceBufView = distanceBufs[curStream]-> + narrow(0, 0, curQuerySize).narrow(1, 0, curCentroidSize); + + auto outDistanceBufColView = + outDistanceBufRowView.narrow(1, k * curColTile, k); + auto outIndexBufColView = + outIndexBufRowView.narrow(1, k * curColTile, k); + + runGeneralDistanceKernel(centroidsView, + queryView, + distanceBufView, + op, + streams[curStream]); + + // For IP, just k-select the output for this tile + if (tileCols == numCentroids) { + // Write into the final output + runBlockSelect(distanceBufView, + bitset, + outDistanceView, + outIndexView, + DistanceOp::kDirection, k, streams[curStream]); + } else { + // Write into the intermediate output + runBlockSelect(distanceBufView, + bitset, + outDistanceBufColView, + outIndexBufColView, + DistanceOp::kDirection, k, streams[curStream]); + } + } + + // As we're finished with processing a full set of centroids, perform the + // final k-selection + if (tileCols != numCentroids) { + // The indices are tile-relative; for each tile of k, we need to add + // tileCols to the index + runIncrementIndex(outIndexBufRowView, k, tileCols, streams[curStream]); + + runBlockSelectPair(outDistanceBufRowView, + outIndexBufRowView, + bitset, + outDistanceView, + outIndexView, + DistanceOp::kDirection, k, streams[curStream]); + } + + curStream = (curStream + 1) % 2; + } + + // Have the desired ordering stream wait on the multi-stream + streamWait({defaultStream}, streams); + + if (interrupt) { + FAISS_THROW_MSG("interrupted"); + } + + CUDA_TEST_ERROR(); +} + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/impl/GpuScalarQuantizer.cuh b/core/src/index/thirdparty/faiss/gpu/impl/GpuScalarQuantizer.cuh new file mode 100644 index 0000000000..32675a5a4e --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/impl/GpuScalarQuantizer.cuh @@ -0,0 +1,607 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include +#include + +namespace faiss { namespace gpu { + +inline bool isSQSupported(QuantizerType qtype) { + switch (qtype) { + case QuantizerType::QT_8bit: + case QuantizerType::QT_8bit_uniform: + case QuantizerType::QT_8bit_direct: + case QuantizerType::QT_4bit: + case QuantizerType::QT_4bit_uniform: + case QuantizerType::QT_fp16: + return true; + default: + return false; + } +} + +// Wrapper around the CPU ScalarQuantizer that allows storage of parameters in +// GPU memory +struct GpuScalarQuantizer : public ScalarQuantizer { + GpuScalarQuantizer(const ScalarQuantizer& sq) + : ScalarQuantizer(sq), + gpuTrained(DeviceTensor({(int) sq.trained.size()})) { + HostTensor + cpuTrained((float*) sq.trained.data(), {(int) sq.trained.size()}); + + // Just use the default stream, as we're allocating memory above in any case + gpuTrained.copyFrom(cpuTrained, 0); + CUDA_VERIFY(cudaStreamSynchronize(0)); + } + + // ScalarQuantizer::trained copied to GPU memory + DeviceTensor gpuTrained; +}; + +// +// Quantizer codecs +// + +// QT is the quantizer type implemented +// DimMultiple is the minimum guaranteed dimension multiple of the vectors +// encoded (used for ensuring alignment for memory load/stores) +template +struct Codec { }; + +///// +// +// 32 bit encodings +// (does not use qtype) +// +///// + +struct CodecFloat { + /// How many dimensions per iteration we are handling for encoding or decoding + static constexpr int kDimPerIter = 1; + + CodecFloat(int vecBytes) : bytesPerVec(vecBytes) { } + + size_t getSmemSize(int dim) { return 0; } + inline __device__ void setSmem(float* smem, int dim) { } + + inline __device__ void decode(void* data, int vec, int d, + float* out) const { + float* p = (float*) &((uint8_t*) data)[vec * bytesPerVec]; + out[0] = p[d]; + } + + inline __device__ float decodePartial(void* data, int vec, int d, + int subD) const { + // doesn't need implementing (kDimPerIter == 1) + return 0.0f; + } + + inline __device__ void encode(void* data, int vec, int d, + float v[kDimPerIter]) const { + float* p = (float*) &((uint8_t*) data)[vec * bytesPerVec]; + p[d] = v[0]; + } + + inline __device__ void encodePartial(void* data, int vec, int d, + int remaining, + float v[kDimPerIter]) const { + // doesn't need implementing (kDimPerIter == 1) + } + + int bytesPerVec; +}; + +///// +// +// 16 bit encodings +// +///// + +// Arbitrary dimension fp16 +template <> +struct Codec<(int)QuantizerType::QT_fp16, 1> { + /// How many dimensions per iteration we are handling for encoding or decoding + static constexpr int kDimPerIter = 1; + + Codec(int vecBytes) : bytesPerVec(vecBytes) { } + + size_t getSmemSize(int dim) { return 0; } + inline __device__ void setSmem(float* smem, int dim) { } + + inline __device__ void decode(void* data, int vec, int d, + float* out) const { + half* p = (half*) &((uint8_t*) data)[vec * bytesPerVec]; + out[0] = Convert()(p[d]); + } + + inline __device__ float decodePartial(void* data, int vec, int d, + int subD) const { + // doesn't need implementing (kDimPerIter == 1) + return 0.0f; + } + + inline __device__ void encode(void* data, int vec, int d, + float v[kDimPerIter]) const { + half* p = (half*) &((uint8_t*) data)[vec * bytesPerVec]; + p[d] = Convert()(v[0]); + } + + inline __device__ void encodePartial(void* data, int vec, int d, + int remaining, + float v[kDimPerIter]) const { + // doesn't need implementing (kDimPerIter == 1) + } + + int bytesPerVec; +}; + +// dim % 2 == 0, ensures uint32 alignment +template <> +struct Codec<(int)QuantizerType::QT_fp16, 2> { + /// How many dimensions per iteration we are handling for encoding or decoding + static constexpr int kDimPerIter = 2; + + Codec(int vecBytes) : bytesPerVec(vecBytes) { } + + size_t getSmemSize(int dim) { return 0; } + inline __device__ void setSmem(float* smem, int dim) { } + + inline __device__ void decode(void* data, int vec, int d, + float* out) const { + half2* p = (half2*) &((uint8_t*) data)[vec * bytesPerVec]; + half2 pd = p[d]; + + out[0] = Convert()(__low2half(pd)); + out[1] = Convert()(__high2half(pd)); + } + + inline __device__ float decodePartial(void* data, int vec, int d, + int subD) const { + // should not be called + assert(false); + return 0; + } + + inline __device__ void encode(void* data, int vec, int d, + float v[kDimPerIter]) const { + half2* p = (half2*) &((uint8_t*) data)[vec * bytesPerVec]; + half h0 = Convert()(v[0]); + half h1 = Convert()(v[1]); + + p[d] = __halves2half2(h0, h1); + } + + inline __device__ void encodePartial(void* data, int vec, int d, + int remaining, + float v[kDimPerIter]) const { + // should not be called + assert(false); + } + + int bytesPerVec; +}; + +///// +// +// 8 bit encodings +// +///// + +template +struct Get8BitType { }; + +template <> +struct Get8BitType<1> { using T = uint8_t; }; + +template <> +struct Get8BitType<2> { using T = uint16_t; }; + +template <> +struct Get8BitType<4> { using T = uint32_t; }; + +// Uniform quantization across all dimensions +template +struct Codec<(int)QuantizerType::QT_8bit_uniform, DimMultiple> { + /// How many dimensions per iteration we are handling for encoding or decoding + static constexpr int kDimPerIter = DimMultiple; + using MemT = typename Get8BitType::T; + + Codec(int vecBytes, float min, float diff) + : bytesPerVec(vecBytes), vmin(min), vdiff(diff) { + } + + size_t getSmemSize(int dim) { return 0; } + inline __device__ void setSmem(float* smem, int dim) { } + + inline __device__ float decodeHelper(uint8_t v) const { + float x = (((float) v) + 0.5f) / 255.0f; + return vmin + x * vdiff; + } + + inline __device__ void decode(void* data, int vec, int d, + float* out) const { + MemT* p = (MemT*) &((uint8_t*) data)[vec * bytesPerVec]; + MemT pv = p[d]; + + uint8_t x[kDimPerIter]; +#pragma unroll + for (int i = 0; i < kDimPerIter; ++i) { + x[i] = (uint8_t) ((pv >> (i * 8)) & 0xffU); + } + + float xDec[kDimPerIter]; +#pragma unroll + for (int i = 0; i < kDimPerIter; ++i) { + xDec[i] = decodeHelper(x[i]); + } + + #pragma unroll + for (int i = 0; i < kDimPerIter; ++i) { + out[i] = xDec[i]; + } + } + + inline __device__ float decodePartial(void* data, int vec, int d, + int subD) const { + if (DimMultiple > 1) { + // should not be called + assert(false); + } + + // otherwise does not need implementing + return 0; + } + + inline __device__ uint8_t encodeHelper(float v) const { + float x = (v - vmin) / vdiff; + x = fminf(1.0f, fmaxf(0.0f, x)); + return (uint8_t) (255 * x); + } + + inline __device__ void encode(void* data, int vec, int d, + float v[kDimPerIter]) const { + MemT* p = (MemT*) &((uint8_t*) data)[vec * bytesPerVec]; + + MemT x[kDimPerIter]; +#pragma unroll + for (int i = 0; i < kDimPerIter; ++i) { + x[i] = encodeHelper(v[i]); + } + + MemT out = 0; +#pragma unroll + for (int i = 0; i < kDimPerIter; ++i) { + out |= (x[i] << (i * 8)); + } + + p[d] = out; + } + + inline __device__ void encodePartial(void* data, int vec, int d, + int remaining, + float v[kDimPerIter]) const { + if (DimMultiple > 1) { + // should not be called + assert(false); + } + + // otherwise does not need implementing + } + + int bytesPerVec; + const float vmin; + const float vdiff; +}; + +// Uniform quantization per each dimension +template +struct Codec<(int)QuantizerType::QT_8bit, DimMultiple> { + /// How many dimensions per iteration we are handling for encoding or decoding + static constexpr int kDimPerIter = DimMultiple; + using MemT = typename Get8BitType::T; + + Codec(int vecBytes, float* min, float* diff) + : bytesPerVec(vecBytes), vmin(min), vdiff(diff), + smemVmin(nullptr), + smemVdiff(nullptr) { + } + + size_t getSmemSize(int dim) { + return sizeof(float) * dim * 2; + } + + inline __device__ void setSmem(float* smem, int dim) { + smemVmin = smem; + smemVdiff = smem + dim; + + for (int i = threadIdx.x; i < dim; i += blockDim.x) { + smemVmin[i] = vmin[i]; + smemVdiff[i] = vdiff[i]; + } + } + + inline __device__ float decodeHelper(uint8_t v, int realDim) const { + float x = (((float) v) + 0.5f) / 255.0f; + return smemVmin[realDim] + x * smemVdiff[realDim]; + } + + inline __device__ void decode(void* data, int vec, int d, + float* out) const { + MemT* p = (MemT*) &((uint8_t*) data)[vec * bytesPerVec]; + MemT pv = p[d]; + int realDim = d * kDimPerIter; + + uint8_t x[kDimPerIter]; +#pragma unroll + for (int i = 0; i < kDimPerIter; ++i) { + x[i] = (uint8_t) ((pv >> (i * 8)) & 0xffU); + } + + float xDec[kDimPerIter]; +#pragma unroll + for (int i = 0; i < kDimPerIter; ++i) { + xDec[i] = decodeHelper(x[i], realDim + i); + } + + #pragma unroll + for (int i = 0; i < kDimPerIter; ++i) { + out[i] = xDec[i]; + } + } + + inline __device__ float decodePartial(void* data, int vec, int d, + int subD) const { + if (DimMultiple > 1) { + // should not be called + assert(false); + } + + // otherwise does not need implementing + return 0; + } + + inline __device__ uint8_t encodeHelper(float v, int realDim) const { + float x = (v - vmin[realDim]) / vdiff[realDim]; + x = fminf(1.0f, fmaxf(0.0f, x)); + return (uint8_t) (255 * x); + } + + inline __device__ void encode(void* data, int vec, int d, + float v[kDimPerIter]) const { + MemT* p = (MemT*) &((uint8_t*) data)[vec * bytesPerVec]; + int realDim = d * kDimPerIter; + + MemT x[kDimPerIter]; +#pragma unroll + for (int i = 0; i < kDimPerIter; ++i) { + x[i] = encodeHelper(v[i], realDim + i); + } + + MemT out = 0; +#pragma unroll + for (int i = 0; i < kDimPerIter; ++i) { + out |= (x[i] << (i * 8)); + } + + p[d] = out; + } + + inline __device__ void encodePartial(void* data, int vec, int d, + int remaining, + float v[kDimPerIter]) const { + if (DimMultiple > 1) { + // should not be called + assert(false); + } + + // otherwise does not need implementing + } + + int bytesPerVec; + + // gmem pointers + const float* vmin; + const float* vdiff; + + // smem pointers (configured in the kernel) + float* smemVmin; + float* smemVdiff; +}; + +template <> +struct Codec<(int)QuantizerType::QT_8bit_direct, 1> { + /// How many dimensions per iteration we are handling for encoding or decoding + static constexpr int kDimPerIter = 1; + + Codec(int vecBytes) : bytesPerVec(vecBytes) { } + + size_t getSmemSize(int dim) { return 0; } + inline __device__ void setSmem(float* smem, int dim) { } + + inline __device__ void decode(void* data, int vec, int d, + float* out) const { + uint8_t* p = &((uint8_t*) data)[vec * bytesPerVec]; + out[0] = (float) p[d]; + } + + inline __device__ float decodePartial(void* data, int vec, int d, + int subD) const { + // doesn't need implementing (kDimPerIter == 1) + return 0.0f; + } + + inline __device__ void encode(void* data, int vec, int d, + float v[kDimPerIter]) const { + uint8_t* p = &((uint8_t*) data)[vec * bytesPerVec]; + p[d] = (uint8_t) v[0]; + } + + inline __device__ void encodePartial(void* data, int vec, int d, + int remaining, + float v[kDimPerIter]) const { + // doesn't need implementing (kDimPerIter == 1) + } + + int bytesPerVec; +}; + +///// +// +// 4 bit encodings +// +///// + +// Uniform quantization across all dimensions +template <> +struct Codec<(int)QuantizerType::QT_4bit_uniform, 1> { + /// How many dimensions per iteration we are handling for encoding or decoding + static constexpr int kDimPerIter = 2; + + Codec(int vecBytes, float min, float diff) + : bytesPerVec(vecBytes), vmin(min), vdiff(diff) { + } + + size_t getSmemSize(int dim) { return 0; } + inline __device__ void setSmem(float* smem, int dim) { } + + inline __device__ float decodeHelper(uint8_t v) const { + float x = (((float) v) + 0.5f) / 15.0f; + return vmin + x * vdiff; + } + + inline __device__ void decode(void* data, int vec, int d, + float* out) const { + uint8_t* p = &((uint8_t*) data)[vec * bytesPerVec]; + uint8_t pv = p[d]; + + out[0] = decodeHelper(pv & 0xf); + out[1] = decodeHelper(pv >> 4); + } + + inline __device__ float decodePartial(void* data, int vec, int d, + int subD /* unused */) const { + // We can only be called for a single input + uint8_t* p = &((uint8_t*) data)[vec * bytesPerVec]; + uint8_t pv = p[d]; + + return decodeHelper(pv & 0xf); + } + + inline __device__ uint8_t encodeHelper(float v) const { + float x = (v - vmin) / vdiff; + x = fminf(1.0f, fmaxf(0.0f, x)); + return (uint8_t) (x * 15.0f); + } + + inline __device__ void encode(void* data, int vec, int d, + float v[kDimPerIter]) const { + uint8_t* p = &((uint8_t*) data)[vec * bytesPerVec]; + p[d] = encodeHelper(v[0]) | (encodeHelper(v[1]) << 4); + } + + inline __device__ void encodePartial(void* data, int vec, int d, + int remaining, /* unused */ + float v[kDimPerIter]) const { + // We can only be called for a single output + uint8_t* p = &((uint8_t*) data)[vec * bytesPerVec]; + p[d] = encodeHelper(v[0]); + } + + int bytesPerVec; + const float vmin; + const float vdiff; +}; + +template <> +struct Codec<(int)QuantizerType::QT_4bit, 1> { + /// How many dimensions per iteration we are handling for encoding or decoding + static constexpr int kDimPerIter = 2; + + Codec(int vecBytes, float* min, float* diff) + : bytesPerVec(vecBytes), vmin(min), vdiff(diff), + smemVmin(nullptr), + smemVdiff(nullptr) { + } + + size_t getSmemSize(int dim) { + return sizeof(float) * dim * 2; + } + + inline __device__ void setSmem(float* smem, int dim) { + smemVmin = smem; + smemVdiff = smem + dim; + + for (int i = threadIdx.x; i < dim; i += blockDim.x) { + smemVmin[i] = vmin[i]; + smemVdiff[i] = vdiff[i]; + } + } + + inline __device__ float decodeHelper(uint8_t v, int realDim) const { + float x = (((float) v) + 0.5f) / 15.0f; + return smemVmin[realDim] + x * smemVdiff[realDim]; + } + + inline __device__ void decode(void* data, int vec, int d, + float* out) const { + uint8_t* p = &((uint8_t*) data)[vec * bytesPerVec]; + uint8_t pv = p[d]; + int realDim = d * kDimPerIter; + + out[0] = decodeHelper(pv & 0xf, realDim); + out[1] = decodeHelper(pv >> 4, realDim + 1); + } + + inline __device__ float decodePartial(void* data, int vec, int d, + int subD /* unused */) const { + // We can only be called for a single input + uint8_t* p = &((uint8_t*) data)[vec * bytesPerVec]; + uint8_t pv = p[d]; + int realDim = d * kDimPerIter; + + return decodeHelper(pv & 0xf, realDim); + } + + inline __device__ uint8_t encodeHelper(float v, int realDim) const { + float x = (v - vmin[realDim]) / vdiff[realDim]; + x = fminf(1.0f, fmaxf(0.0f, x)); + return (uint8_t) (x * 15.0f); + } + + inline __device__ void encode(void* data, int vec, int d, + float v[kDimPerIter]) const { + uint8_t* p = &((uint8_t*) data)[vec * bytesPerVec]; + int realDim = d * kDimPerIter; + p[d] = encodeHelper(v[0], realDim) | (encodeHelper(v[1], realDim + 1) << 4); + } + + inline __device__ void encodePartial(void* data, int vec, int d, + int remaining, /* unused */ + float v[kDimPerIter]) const { + // We can only be called for a single output + uint8_t* p = &((uint8_t*) data)[vec * bytesPerVec]; + int realDim = d * kDimPerIter; + + p[d] = encodeHelper(v[0], realDim); + } + + int bytesPerVec; + + // gmem pointers + const float* vmin; + const float* vdiff; + + // smem pointers + float* smemVmin; + float* smemVdiff; +}; + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/impl/IVFAppend.cu b/core/src/index/thirdparty/faiss/gpu/impl/IVFAppend.cu new file mode 100644 index 0000000000..ace37549b9 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/impl/IVFAppend.cu @@ -0,0 +1,369 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#include +#include +#include +#include +#include +#include + +namespace faiss { namespace gpu { + +// +// IVF list length update +// + +__global__ void +runUpdateListPointers(Tensor listIds, + Tensor newListLength, + Tensor newCodePointers, + Tensor newIndexPointers, + int* listLengths, + void** listCodes, + void** listIndices) { + int i = blockIdx.x * blockDim.x + threadIdx.x; + + if (i < listIds.getSize(0)) { + int listId = listIds[i]; + listLengths[listId] = newListLength[i]; + listCodes[listId] = newCodePointers[i]; + listIndices[listId] = newIndexPointers[i]; + } +} + +void +runUpdateListPointers(Tensor& listIds, + Tensor& newListLength, + Tensor& newCodePointers, + Tensor& newIndexPointers, + thrust::device_vector& listLengths, + thrust::device_vector& listCodes, + thrust::device_vector& listIndices, + cudaStream_t stream) { + int numThreads = std::min(listIds.getSize(0), getMaxThreadsCurrentDevice()); + int numBlocks = utils::divUp(listIds.getSize(0), numThreads); + + dim3 grid(numBlocks); + dim3 block(numThreads); + + runUpdateListPointers<<>>( + listIds, newListLength, newCodePointers, newIndexPointers, + listLengths.data().get(), + listCodes.data().get(), + listIndices.data().get()); + + CUDA_TEST_ERROR(); +} + +// +// IVF PQ append +// + +template +__global__ void +ivfpqInvertedListAppend(Tensor listIds, + Tensor listOffset, + Tensor encodings, + Tensor indices, + void** listCodes, + void** listIndices) { + int encodingToAdd = blockIdx.x * blockDim.x + threadIdx.x; + + if (encodingToAdd >= listIds.getSize(0)) { + return; + } + + int listId = listIds[encodingToAdd]; + int offset = listOffset[encodingToAdd]; + + // Add vector could be invalid (contains NaNs etc) + if (listId == -1 || offset == -1) { + return; + } + + auto encoding = encodings[encodingToAdd]; + long index = indices[encodingToAdd]; + + if (Opt == INDICES_32_BIT) { + // FIXME: there could be overflow here, but where should we check this? + ((int*) listIndices[listId])[offset] = (int) index; + } else if (Opt == INDICES_64_BIT) { + ((long*) listIndices[listId])[offset] = (long) index; + } else { + // INDICES_CPU or INDICES_IVF; no indices are being stored + } + + unsigned char* codeStart = + ((unsigned char*) listCodes[listId]) + offset * encodings.getSize(1); + + // FIXME: slow + for (int i = 0; i < encodings.getSize(1); ++i) { + codeStart[i] = (unsigned char) encoding[i]; + } +} + +void +runIVFPQInvertedListAppend(Tensor& listIds, + Tensor& listOffset, + Tensor& encodings, + Tensor& indices, + thrust::device_vector& listCodes, + thrust::device_vector& listIndices, + IndicesOptions indicesOptions, + cudaStream_t stream) { + int numThreads = std::min(listIds.getSize(0), getMaxThreadsCurrentDevice()); + int numBlocks = utils::divUp(listIds.getSize(0), numThreads); + + dim3 grid(numBlocks); + dim3 block(numThreads); + +#define RUN_APPEND(IND) \ + do { \ + ivfpqInvertedListAppend<<>>( \ + listIds, listOffset, encodings, indices, \ + listCodes.data().get(), \ + listIndices.data().get()); \ + } while (0) + + if ((indicesOptions == INDICES_CPU) || (indicesOptions == INDICES_IVF)) { + // no need to maintain indices on the GPU + RUN_APPEND(INDICES_IVF); + } else if (indicesOptions == INDICES_32_BIT) { + RUN_APPEND(INDICES_32_BIT); + } else if (indicesOptions == INDICES_64_BIT) { + RUN_APPEND(INDICES_64_BIT); + } else { + // unknown index storage type + FAISS_ASSERT(false); + } + + CUDA_TEST_ERROR(); + +#undef RUN_APPEND +} + +// +// IVF flat append +// + +__global__ void +ivfFlatIndicesAppend(Tensor listIds, + Tensor listOffset, + Tensor indices, + IndicesOptions opt, + void** listIndices) { + int vec = blockIdx.x * blockDim.x + threadIdx.x; + + if (vec >= listIds.getSize(0)) { + return; + } + + int listId = listIds[vec]; + int offset = listOffset[vec]; + + // Add vector could be invalid (contains NaNs etc) + if (listId == -1 || offset == -1) { + return; + } + + long index = indices[vec]; + + if (opt == INDICES_32_BIT) { + // FIXME: there could be overflow here, but where should we check this? + ((int*) listIndices[listId])[offset] = (int) index; + } else if (opt == INDICES_64_BIT) { + ((long*) listIndices[listId])[offset] = (long) index; + } +} + +template +__global__ void +ivfFlatInvertedListAppend(Tensor listIds, + Tensor listOffset, + Tensor vecs, + void** listData, + Codec codec) { + int vec = blockIdx.x; + + int listId = listIds[vec]; + int offset = listOffset[vec]; + + // Add vector could be invalid (contains NaNs etc) + if (listId == -1 || offset == -1) { + return; + } + + // Handle whole encoding (only thread 0 will handle the remainder) + int limit = utils::divDown(vecs.getSize(1), Codec::kDimPerIter); + + int i; + for (i = threadIdx.x; i < limit; i += blockDim.x) { + int realDim = i * Codec::kDimPerIter; + float toEncode[Codec::kDimPerIter]; + +#pragma unroll + for (int j = 0; j < Codec::kDimPerIter; ++j) { + toEncode[j] = vecs[vec][realDim + j]; + } + + codec.encode(listData[listId], offset, i, toEncode); + } + + // Handle remainder with a single thread, if any + if (Codec::kDimPerIter > 1) { + int realDim = limit * Codec::kDimPerIter; + + // Was there any remainder? + if (realDim < vecs.getSize(1)) { + if (threadIdx.x == 0) { + float toEncode[Codec::kDimPerIter]; + + // How many remaining that we need to encode + int remaining = vecs.getSize(1) - realDim; + +#pragma unroll + for (int j = 0; j < Codec::kDimPerIter; ++j) { + int idx = realDim + j; + toEncode[j] = idx < vecs.getSize(1) ? vecs[vec][idx] : 0.0f; + } + + codec.encodePartial(listData[listId], offset, i, remaining, toEncode); + } + } + } +} + +void +runIVFFlatInvertedListAppend(Tensor& listIds, + Tensor& listOffset, + Tensor& vecs, + Tensor& indices, + bool useResidual, + Tensor& residuals, + GpuScalarQuantizer* scalarQ, + thrust::device_vector& listData, + thrust::device_vector& listIndices, + IndicesOptions indicesOptions, + cudaStream_t stream) { + int dim = vecs.getSize(1); + int maxThreads = getMaxThreadsCurrentDevice(); + + // First, append the indices that we're about to add, if any + if (indicesOptions != INDICES_CPU && indicesOptions != INDICES_IVF) { + int blocks = utils::divUp(vecs.getSize(0), maxThreads); + + ivfFlatIndicesAppend<<>>( + listIds, + listOffset, + indices, + indicesOptions, + listIndices.data().get()); + } + + // Each block will handle appending a single vector +#define RUN_APPEND \ + do { \ + dim3 grid(vecs.getSize(0)); \ + dim3 block(std::min(dim / codec.kDimPerIter, maxThreads)); \ + \ + ivfFlatInvertedListAppend \ + <<>>( \ + listIds, \ + listOffset, \ + useResidual ? residuals : vecs, \ + listData.data().get(), \ + codec); \ + } while (0) + + if (!scalarQ) { + CodecFloat codec(dim * sizeof(float)); + RUN_APPEND; + } else { + switch (scalarQ->qtype) { + case QuantizerType::QT_8bit: + { + if (false) { +// if (dim % 4 == 0) { + Codec<(int)QuantizerType::QT_8bit, 4> + codec(scalarQ->code_size, + scalarQ->gpuTrained.data(), + scalarQ->gpuTrained.data() + dim); + RUN_APPEND; + } else { + Codec<(int)QuantizerType::QT_8bit, 1> + codec(scalarQ->code_size, + scalarQ->gpuTrained.data(), + scalarQ->gpuTrained.data() + dim); + RUN_APPEND; + } + } + break; + case QuantizerType::QT_8bit_uniform: + { +// if (dim % 4 == 0) { + if (false) { + Codec<(int)QuantizerType::QT_8bit_uniform, 4> + codec(scalarQ->code_size, scalarQ->trained[0], scalarQ->trained[1]); + RUN_APPEND; + } else { + Codec<(int)QuantizerType::QT_8bit_uniform, 1> + codec(scalarQ->code_size, scalarQ->trained[0], scalarQ->trained[1]); + RUN_APPEND; + } + } + break; + case QuantizerType::QT_fp16: + { +// if (dim % 2 == 0) { + if (false) { + Codec<(int)QuantizerType::QT_fp16, 2> + codec(scalarQ->code_size); + RUN_APPEND; + } else { + Codec<(int)QuantizerType::QT_fp16, 1> + codec(scalarQ->code_size); + RUN_APPEND; + } + } + break; + case QuantizerType::QT_8bit_direct: + { + Codec<(int)QuantizerType::QT_8bit_direct, 1> + codec(scalarQ->code_size); + RUN_APPEND; + } + break; + case QuantizerType::QT_4bit: + { + Codec<(int)QuantizerType::QT_4bit, 1> + codec(scalarQ->code_size, + scalarQ->gpuTrained.data(), + scalarQ->gpuTrained.data() + dim); + RUN_APPEND; + } + break; + case QuantizerType::QT_4bit_uniform: + { + Codec<(int)QuantizerType::QT_4bit_uniform, 1> + codec(scalarQ->code_size, scalarQ->trained[0], scalarQ->trained[1]); + RUN_APPEND; + } + break; + default: + // unimplemented, should be handled at a higher level + FAISS_ASSERT(false); + } + } + + CUDA_TEST_ERROR(); + +#undef RUN_APPEND +} + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/impl/IVFAppend.cuh b/core/src/index/thirdparty/faiss/gpu/impl/IVFAppend.cuh new file mode 100644 index 0000000000..3d61248082 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/impl/IVFAppend.cuh @@ -0,0 +1,53 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#pragma once + +#include +#include +#include +#include + +namespace faiss { namespace gpu { + +/// Update device-side list pointers in a batch +void runUpdateListPointers(Tensor& listIds, + Tensor& newListLength, + Tensor& newCodePointers, + Tensor& newIndexPointers, + thrust::device_vector& listLengths, + thrust::device_vector& listCodes, + thrust::device_vector& listIndices, + cudaStream_t stream); + +/// Actually append the new codes / vector indices to the individual lists + +/// IVFPQ +void runIVFPQInvertedListAppend(Tensor& listIds, + Tensor& listOffset, + Tensor& encodings, + Tensor& indices, + thrust::device_vector& listCodes, + thrust::device_vector& listIndices, + IndicesOptions indicesOptions, + cudaStream_t stream); + +/// IVF flat storage +void runIVFFlatInvertedListAppend(Tensor& listIds, + Tensor& listOffset, + Tensor& vecs, + Tensor& indices, + bool useResidual, + Tensor& residuals, + GpuScalarQuantizer* scalarQ, + thrust::device_vector& listData, + thrust::device_vector& listIndices, + IndicesOptions indicesOptions, + cudaStream_t stream); + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/impl/IVFBase.cu b/core/src/index/thirdparty/faiss/gpu/impl/IVFBase.cu new file mode 100644 index 0000000000..48c362e36e --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/impl/IVFBase.cu @@ -0,0 +1,379 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace faiss { namespace gpu { + +IVFBase::IVFBase(GpuResources* resources, + faiss::MetricType metric, + float metricArg, + FlatIndex* quantizer, + int bytesPerVector, + IndicesOptions indicesOptions, + MemorySpace space) : + resources_(resources), + metric_(metric), + metricArg_(metricArg), + quantizer_(quantizer), + bytesPerVector_(bytesPerVector), + indicesOptions_(indicesOptions), + space_(space), + dim_(quantizer->getDim()), + numLists_(quantizer->getSize()), + maxListLength_(0) { + reset(); +} + +IVFBase::~IVFBase() { +} + +void +IVFBase::reserveMemory(size_t numVecs) { + size_t vecsPerList = numVecs / deviceListData_.size(); + if (vecsPerList < 1) { + return; + } + + auto stream = resources_->getDefaultStreamCurrentDevice(); + + size_t bytesPerDataList = vecsPerList * bytesPerVector_; + for (auto& list : deviceListData_) { + list->reserve(bytesPerDataList, stream); + } + + if ((indicesOptions_ == INDICES_32_BIT) || + (indicesOptions_ == INDICES_64_BIT)) { + // Reserve for index lists as well + size_t bytesPerIndexList = vecsPerList * + (indicesOptions_ == INDICES_32_BIT ? sizeof(int) : sizeof(long)); + + for (auto& list : deviceListIndices_) { + list->reserve(bytesPerIndexList, stream); + } + } + + // Update device info for all lists, since the base pointers may + // have changed + updateDeviceListInfo_(stream); +} + +void +IVFBase::reset() { + deviceListData_.clear(); + deviceListIndices_.clear(); + deviceListDataPointers_.clear(); + deviceListIndexPointers_.clear(); + deviceListLengths_.clear(); + listOffsetToUserIndex_.clear(); + + deviceListData_.reserve(numLists_); + deviceListIndices_.reserve(numLists_); + listOffsetToUserIndex_.resize(numLists_); + + for (size_t i = 0; i < numLists_; ++i) { + deviceListData_.emplace_back( + std::unique_ptr>( + new DeviceVector(space_))); + deviceListIndices_.emplace_back( + std::unique_ptr>( + new DeviceVector(space_))); + listOffsetToUserIndex_.emplace_back(std::vector()); + } + + deviceListDataPointers_.resize(numLists_, nullptr); + deviceListIndexPointers_.resize(numLists_, nullptr); + deviceListLengths_.resize(numLists_, 0); + maxListLength_ = 0; + + deviceData_.reset(new DeviceVector(space_)); + deviceIndices_.reset(new DeviceVector(space_)); + deviceTrained_.reset(new DeviceVector(space_)); +} + +int +IVFBase::getDim() const { + return dim_; +} + +size_t +IVFBase::reclaimMemory() { + // Reclaim all unused memory exactly + return reclaimMemory_(true); +} + +size_t +IVFBase::reclaimMemory_(bool exact) { + auto stream = resources_->getDefaultStreamCurrentDevice(); + + size_t totalReclaimed = 0; + + for (int i = 0; i < deviceListData_.size(); ++i) { + auto& data = deviceListData_[i]; + totalReclaimed += data->reclaim(exact, stream); + + deviceListDataPointers_[i] = data->data(); + } + + for (int i = 0; i < deviceListIndices_.size(); ++i) { + auto& indices = deviceListIndices_[i]; + totalReclaimed += indices->reclaim(exact, stream); + + deviceListIndexPointers_[i] = indices->data(); + } + + // Update device info for all lists, since the base pointers may + // have changed + updateDeviceListInfo_(stream); + + return totalReclaimed; +} + +void +IVFBase::updateDeviceListInfo_(cudaStream_t stream) { + std::vector listIds(deviceListData_.size()); + for (int i = 0; i < deviceListData_.size(); ++i) { + listIds[i] = i; + } + + updateDeviceListInfo_(listIds, stream); +} + +void +IVFBase::updateDeviceListInfo_(const std::vector& listIds, + cudaStream_t stream) { + auto& mem = resources_->getMemoryManagerCurrentDevice(); + + HostTensor + hostListsToUpdate({(int) listIds.size()}); + HostTensor + hostNewListLength({(int) listIds.size()}); + HostTensor + hostNewDataPointers({(int) listIds.size()}); + HostTensor + hostNewIndexPointers({(int) listIds.size()}); + + for (int i = 0; i < listIds.size(); ++i) { + auto listId = listIds[i]; + auto& data = deviceListData_[listId]; + auto& indices = deviceListIndices_[listId]; + + hostListsToUpdate[i] = listId; + hostNewListLength[i] = data->size() / bytesPerVector_; + hostNewDataPointers[i] = data->data(); + hostNewIndexPointers[i] = indices->data(); + } + + // Copy the above update sets to the GPU + DeviceTensor listsToUpdate( + mem, hostListsToUpdate, stream); + DeviceTensor newListLength( + mem, hostNewListLength, stream); + DeviceTensor newDataPointers( + mem, hostNewDataPointers, stream); + DeviceTensor newIndexPointers( + mem, hostNewIndexPointers, stream); + + // Update all pointers to the lists on the device that may have + // changed + runUpdateListPointers(listsToUpdate, + newListLength, + newDataPointers, + newIndexPointers, + deviceListLengths_, + deviceListDataPointers_, + deviceListIndexPointers_, + stream); +} + +size_t +IVFBase::getNumLists() const { + return numLists_; +} + +int +IVFBase::getListLength(int listId) const { + FAISS_ASSERT(listId < deviceListLengths_.size()); + + return deviceListLengths_[listId]; +} + +std::vector +IVFBase::getListIndices(int listId) const { + FAISS_ASSERT(listId < numLists_); + + if (indicesOptions_ == INDICES_32_BIT) { + FAISS_ASSERT(listId < deviceListIndices_.size()); + + auto intInd = deviceListIndices_[listId]->copyToHost( + resources_->getDefaultStreamCurrentDevice()); + + std::vector out(intInd.size()); + for (size_t i = 0; i < intInd.size(); ++i) { + out[i] = (long) intInd[i]; + } + + return out; + } else if (indicesOptions_ == INDICES_64_BIT) { + FAISS_ASSERT(listId < deviceListIndices_.size()); + + return deviceListIndices_[listId]->copyToHost( + resources_->getDefaultStreamCurrentDevice()); + } else if (indicesOptions_ == INDICES_CPU) { + FAISS_ASSERT(listId < deviceListData_.size()); + FAISS_ASSERT(listId < listOffsetToUserIndex_.size()); + + auto& userIds = listOffsetToUserIndex_[listId]; + FAISS_ASSERT(userIds.size() == + deviceListData_[listId]->size() / bytesPerVector_); + + // this will return a copy + return userIds; + } else { + // unhandled indices type (includes INDICES_IVF) + FAISS_ASSERT(false); + return std::vector(); + } +} + +std::vector +IVFBase::getListVectors(int listId) const { + FAISS_ASSERT(listId < deviceListData_.size()); + auto& list = *deviceListData_[listId]; + auto stream = resources_->getDefaultStreamCurrentDevice(); + + return list.copyToHost(stream); +} + +void +IVFBase::copyIndicesFromCpu_(const long* indices, + const std::vector& list_length) { + FAISS_ASSERT_FMT(list_length.size() == this->getNumLists(), "Expect list size %zu but %zu received!", + this->getNumLists(), list_length.size()); + auto numVecs = std::accumulate(list_length.begin(), list_length.end(), 0); + + auto stream = resources_->getDefaultStreamCurrentDevice(); + int bytesPerRecord; + + if (indicesOptions_ == INDICES_32_BIT) { + std::vector indices32(numVecs); + for (size_t i = 0; i < numVecs; ++i) { + auto ind = indices[i]; + FAISS_ASSERT(ind <= (long) std::numeric_limits::max()); + indices32[i] = (int) ind; + } + + bytesPerRecord = sizeof(int); + + deviceIndices_->append((unsigned char*) indices32.data(), + numVecs * bytesPerRecord, + stream, + true); + } else if (indicesOptions_ == INDICES_64_BIT) { + bytesPerRecord = sizeof(long); + deviceIndices_->append((unsigned char*) indices, + numVecs * bytesPerRecord, + stream, + true); + } else if (indicesOptions_ == INDICES_CPU) { + FAISS_ASSERT(false); + size_t listId = 0; + auto curr_indices = indices; + for (auto& userIndices : listOffsetToUserIndex_) { + userIndices.insert(userIndices.begin(), curr_indices, curr_indices + list_length[listId]); + curr_indices += list_length[listId]; + listId++; + } + } else { + // indices are not stored + FAISS_ASSERT(indicesOptions_ == INDICES_IVF); + } + + size_t listId = 0; + size_t pos = 0; + size_t size = 0; + + thrust::host_vector hostPointers(deviceListData_.size(), nullptr); + for (auto& device_indice : deviceListIndices_) { + auto data = deviceIndices_->data() + pos; + size = list_length[listId] * bytesPerRecord; + device_indice->reset(data, size, size); + hostPointers[listId] = device_indice->data(); + pos += size; + ++ listId; + } + + deviceListIndexPointers_ = hostPointers; +} + +void +IVFBase::addIndicesFromCpu_(int listId, + const long* indices, + size_t numVecs) { + auto stream = resources_->getDefaultStreamCurrentDevice(); + + auto& listIndices = deviceListIndices_[listId]; + auto prevIndicesData = listIndices->data(); + + if (indicesOptions_ == INDICES_32_BIT) { + // Make sure that all indices are in bounds + std::vector indices32(numVecs); + for (size_t i = 0; i < numVecs; ++i) { + auto ind = indices[i]; + FAISS_ASSERT(ind <= (long) std::numeric_limits::max()); + indices32[i] = (int) ind; + } + + listIndices->append((unsigned char*) indices32.data(), + numVecs * sizeof(int), + stream, + true /* exact reserved size */); + } else if (indicesOptions_ == INDICES_64_BIT) { + listIndices->append((unsigned char*) indices, + numVecs * sizeof(long), + stream, + true /* exact reserved size */); + } else if (indicesOptions_ == INDICES_CPU) { + // indices are stored on the CPU + FAISS_ASSERT(listId < listOffsetToUserIndex_.size()); + + auto& userIndices = listOffsetToUserIndex_[listId]; + userIndices.insert(userIndices.begin(), indices, indices + numVecs); + } else { + // indices are not stored + FAISS_ASSERT(indicesOptions_ == INDICES_IVF); + } + + if (prevIndicesData != listIndices->data()) { + deviceListIndexPointers_[listId] = listIndices->data(); + } +} + +void +IVFBase::addTrainedDataFromCpu_(const uint8_t* trained, + size_t numData) { + auto stream = resources_->getDefaultStreamCurrentDevice(); + + deviceTrained_->append((unsigned char*)trained, + numData, + stream, + true); +} + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/impl/IVFBase.cuh b/core/src/index/thirdparty/faiss/gpu/impl/IVFBase.cuh new file mode 100644 index 0000000000..987439269d --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/impl/IVFBase.cuh @@ -0,0 +1,151 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace faiss { namespace gpu { + +class GpuResources; +struct FlatIndex; + +/// Base inverted list functionality for IVFFlat and IVFPQ +class IVFBase { + public: + IVFBase(GpuResources* resources, + faiss::MetricType metric, + float metricArg, + /// We do not own this reference + FlatIndex* quantizer, + int bytesPerVector, + IndicesOptions indicesOptions, + MemorySpace space); + + virtual ~IVFBase(); + + /// Reserve GPU memory in our inverted lists for this number of vectors + void reserveMemory(size_t numVecs); + + /// Clear out all inverted lists, but retain the coarse quantizer + /// and the product quantizer info + void reset(); + + /// Return the number of dimensions we are indexing + int getDim() const; + + /// After adding vectors, one can call this to reclaim device memory + /// to exactly the amount needed. Returns space reclaimed in bytes + size_t reclaimMemory(); + + /// Returns the number of inverted lists + size_t getNumLists() const; + + /// For debugging purposes, return the list length of a particular + /// list + int getListLength(int listId) const; + + /// Return the list indices of a particular list back to the CPU + std::vector getListIndices(int listId) const; + + DeviceVector* getTrainedData() { return deviceTrained_.get(); }; + + /// Return the encoded vectors of a particular list back to the CPU + std::vector getListVectors(int listId) const; + + protected: + /// Reclaim memory consumed on the device for our inverted lists + /// `exact` means we trim exactly to the memory needed + size_t reclaimMemory_(bool exact); + + /// Update all device-side list pointer and size information + void updateDeviceListInfo_(cudaStream_t stream); + + /// For a set of list IDs, update device-side list pointer and size + /// information + void updateDeviceListInfo_(const std::vector& listIds, + cudaStream_t stream); + + /// Shared function to copy indices from CPU to GPU + void addIndicesFromCpu_(int listId, + const long* indices, + size_t numVecs); + + void copyIndicesFromCpu_(const long* indices, + const std::vector& list_length); + + void addTrainedDataFromCpu_(const uint8_t* trained, size_t numData); + + protected: + /// Collection of GPU resources that we use + GpuResources* resources_; + + /// Metric type of the index + faiss::MetricType metric_; + + /// Metric arg + float metricArg_; + + /// Quantizer object + FlatIndex* quantizer_; + + /// Expected dimensionality of the vectors + const int dim_; + + /// Number of inverted lists we maintain + const int numLists_; + + /// Number of bytes per vector in the list + const int bytesPerVector_; + + /// How are user indices stored on the GPU? + const IndicesOptions indicesOptions_; + + /// What memory space our inverted list storage is in + const MemorySpace space_; + + /// Device representation of all inverted list data + /// id -> data + thrust::device_vector deviceListDataPointers_; + + /// Device representation of all inverted list index pointers + /// id -> data + thrust::device_vector deviceListIndexPointers_; + + /// Device representation of all inverted list lengths + /// id -> length + thrust::device_vector deviceListLengths_; + + /// Maximum list length seen + int maxListLength_; + + /// Device memory for each separate list, as managed by the host. + /// Device memory as stored in DeviceVector is stored as unique_ptr + /// since deviceListSummary_ pointers must remain valid despite + /// resizing of deviceLists_ + std::vector>> deviceListData_; + std::vector>> deviceListIndices_; + + std::unique_ptr> deviceData_; + std::unique_ptr> deviceIndices_; + std::unique_ptr> deviceTrained_; + + /// If we are storing indices on the CPU (indicesOptions_ is + /// INDICES_CPU), then this maintains a CPU-side map of what + /// (inverted list id, offset) maps to which user index + std::vector> listOffsetToUserIndex_; +}; + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/impl/IVFFlat.cu b/core/src/index/thirdparty/faiss/gpu/impl/IVFFlat.cu new file mode 100644 index 0000000000..acebb5799e --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/impl/IVFFlat.cu @@ -0,0 +1,413 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace faiss { namespace gpu { + +IVFFlat::IVFFlat(GpuResources* resources, + FlatIndex* quantizer, + faiss::MetricType metric, + float metricArg, + bool useResidual, + faiss::ScalarQuantizer* scalarQ, + IndicesOptions indicesOptions, + MemorySpace space) : + IVFBase(resources, + metric, + metricArg, + quantizer, + scalarQ ? scalarQ->code_size : + sizeof(float) * quantizer->getDim(), + indicesOptions, + space), + useResidual_(useResidual), + scalarQ_(scalarQ ? new GpuScalarQuantizer(*scalarQ) : nullptr) { +} + +IVFFlat::~IVFFlat() { +} + +void +IVFFlat::copyCodeVectorsFromCpu(const float* vecs, + const long* indices, + const std::vector& list_length) { + FAISS_ASSERT_FMT(list_length.size() == this->getNumLists(), "Expect list size %zu but %zu received!", + this->getNumLists(), list_length.size()); + int64_t numVecs = std::accumulate(list_length.begin(), list_length.end(), 0); + if (numVecs == 0) { + return; + } + + auto stream = resources_->getDefaultStreamCurrentDevice(); + + deviceListLengths_ = list_length; + + int64_t lengthInBytes = numVecs * bytesPerVector_; + + // We only have int32 length representations on the GPU per each + // list; the length is in sizeof(char) + FAISS_ASSERT(deviceData_->size() + lengthInBytes <= std::numeric_limits::max()); + + deviceData_->append((unsigned char*) vecs, + lengthInBytes, + stream, + true /* exact reserved size */); + copyIndicesFromCpu_(indices, list_length); + maxListLength_ = 0; + + size_t listId = 0; + size_t pos = 0; + size_t size = 0; + thrust::host_vector hostPointers(deviceListData_.size(), nullptr); + + for (auto& device_data : deviceListData_) { + auto data = deviceData_->data() + pos; + + size = list_length[listId] * bytesPerVector_; + + device_data->reset(data, size, size); + hostPointers[listId] = device_data->data(); + maxListLength_ = std::max(maxListLength_, (int)list_length[listId]); + pos += size; + ++ listId; + } + + deviceListDataPointers_ = hostPointers; + + // device_vector add is potentially happening on a different stream + // than our default stream + if (stream != 0) { + streamWait({stream}, {0}); + } +} + +void +IVFFlat::addCodeVectorsFromCpu(int listId, + const unsigned char* vecs, + const long* indices, + size_t numVecs) { + // This list must already exist + FAISS_ASSERT(listId < deviceListData_.size()); + auto stream = resources_->getDefaultStreamCurrentDevice(); + + // If there's nothing to add, then there's nothing we have to do + if (numVecs == 0) { + return; + } + + size_t lengthInBytes = numVecs * bytesPerVector_; + + auto& listData = deviceListData_[listId]; + auto prevData = listData->data(); + + // We only have int32 length representations on the GPU per each + // list; the length is in sizeof(char) + FAISS_ASSERT(listData->size() + lengthInBytes <= + (size_t) std::numeric_limits::max()); + + listData->append(vecs, + lengthInBytes, + stream, + true /* exact reserved size */); + + // Handle the indices as well + addIndicesFromCpu_(listId, indices, numVecs); + + // This list address may have changed due to vector resizing, but + // only bother updating it on the device if it has changed + if (prevData != listData->data()) { + deviceListDataPointers_[listId] = listData->data(); + } + + // And our size has changed too + int listLength = listData->size() / bytesPerVector_; + deviceListLengths_[listId] = listLength; + + // We update this as well, since the multi-pass algorithm uses it + maxListLength_ = std::max(maxListLength_, listLength); + + // device_vector add is potentially happening on a different stream + // than our default stream + if (stream != 0) { + streamWait({stream}, {0}); + } +} + +int +IVFFlat::classifyAndAddVectors(Tensor& vecs, + Tensor& indices) { + FAISS_ASSERT(vecs.getSize(0) == indices.getSize(0)); + FAISS_ASSERT(vecs.getSize(1) == dim_); + + auto& mem = resources_->getMemoryManagerCurrentDevice(); + auto stream = resources_->getDefaultStreamCurrentDevice(); + + // Number of valid vectors that we actually add; we return this + int numAdded = 0; + + DeviceTensor + listDistance2d(mem, {vecs.getSize(0), 1}, stream); + + DeviceTensor + listIds2d(mem, {vecs.getSize(0), 1}, stream); + auto listIds = listIds2d.view<1>({vecs.getSize(0)}); + + /* pseudo bitset */ + DeviceTensor bitset(mem, {0}, stream); + quantizer_->query(vecs, bitset, 1, metric_, metricArg_, + listDistance2d, listIds2d, false); + + // Calculate residuals for these vectors, if needed + DeviceTensor + residuals(mem, {vecs.getSize(0), dim_}, stream); + + if (useResidual_) { + quantizer_->computeResidual(vecs, listIds, residuals); + } + + // Copy the lists that we wish to append to back to the CPU + // FIXME: really this can be into pinned memory and a true async + // copy on a different stream; we can start the copy early, but it's + // tiny + HostTensor listIdsHost(listIds, stream); + + // Now we add the encoded vectors to the individual lists + // First, make sure that there is space available for adding the new + // encoded vectors and indices + + // list id -> # being added + std::unordered_map assignCounts; + + // vector id -> offset in list + // (we already have vector id -> list id in listIds) + HostTensor listOffsetHost({listIdsHost.getSize(0)}); + + for (int i = 0; i < listIds.getSize(0); ++i) { + int listId = listIdsHost[i]; + + // Add vector could be invalid (contains NaNs etc) + if (listId < 0) { + listOffsetHost[i] = -1; + continue; + } + + FAISS_ASSERT(listId < numLists_); + ++numAdded; + + int offset = deviceListData_[listId]->size() / bytesPerVector_; + + auto it = assignCounts.find(listId); + if (it != assignCounts.end()) { + offset += it->second; + it->second++; + } else { + assignCounts[listId] = 1; + } + + listOffsetHost[i] = offset; + } + + // If we didn't add anything (all invalid vectors), no need to + // continue + if (numAdded == 0) { + return 0; + } + + // We need to resize the data structures for the inverted lists on + // the GPUs, which means that they might need reallocation, which + // means that their base address may change. Figure out the new base + // addresses, and update those in a batch on the device + { + for (auto& counts : assignCounts) { + auto& data = deviceListData_[counts.first]; + data->resize(data->size() + counts.second * bytesPerVector_, + stream); + int newNumVecs = (int) (data->size() / bytesPerVector_); + + auto& indices = deviceListIndices_[counts.first]; + if ((indicesOptions_ == INDICES_32_BIT) || + (indicesOptions_ == INDICES_64_BIT)) { + size_t indexSize = + (indicesOptions_ == INDICES_32_BIT) ? sizeof(int) : sizeof(long); + + indices->resize(indices->size() + counts.second * indexSize, stream); + } else if (indicesOptions_ == INDICES_CPU) { + // indices are stored on the CPU side + FAISS_ASSERT(counts.first < listOffsetToUserIndex_.size()); + + auto& userIndices = listOffsetToUserIndex_[counts.first]; + userIndices.resize(newNumVecs); + } else { + // indices are not stored on the GPU or CPU side + FAISS_ASSERT(indicesOptions_ == INDICES_IVF); + } + + // This is used by the multi-pass query to decide how much scratch + // space to allocate for intermediate results + maxListLength_ = std::max(maxListLength_, newNumVecs); + } + + // Update all pointers to the lists on the device that may have + // changed + { + std::vector listIds(assignCounts.size()); + int i = 0; + for (auto& counts : assignCounts) { + listIds[i++] = counts.first; + } + + updateDeviceListInfo_(listIds, stream); + } + } + + // If we're maintaining the indices on the CPU side, update our + // map. We already resized our map above. + if (indicesOptions_ == INDICES_CPU) { + // We need to maintain the indices on the CPU side + HostTensor hostIndices(indices, stream); + + for (int i = 0; i < hostIndices.getSize(0); ++i) { + int listId = listIdsHost[i]; + + // Add vector could be invalid (contains NaNs etc) + if (listId < 0) { + continue; + } + + int offset = listOffsetHost[i]; + + FAISS_ASSERT(listId < listOffsetToUserIndex_.size()); + auto& userIndices = listOffsetToUserIndex_[listId]; + + FAISS_ASSERT(offset < userIndices.size()); + userIndices[offset] = hostIndices[i]; + } + } + + // We similarly need to actually append the new vectors + { + DeviceTensor listOffset(mem, listOffsetHost, stream); + + // Now, for each list to which a vector is being assigned, write it + runIVFFlatInvertedListAppend(listIds, + listOffset, + vecs, + indices, + useResidual_, + residuals, + scalarQ_.get(), + deviceListDataPointers_, + deviceListIndexPointers_, + indicesOptions_, + stream); + } + + return numAdded; +} + +void +IVFFlat::query(Tensor& queries, + Tensor& bitset, + int nprobe, + int k, + Tensor& outDistances, + Tensor& outIndices) { + auto& mem = resources_->getMemoryManagerCurrentDevice(); + auto stream = resources_->getDefaultStreamCurrentDevice(); + + // These are caught at a higher level + FAISS_ASSERT(nprobe <= GPU_MAX_SELECTION_K); + FAISS_ASSERT(k <= GPU_MAX_SELECTION_K); + nprobe = std::min(nprobe, quantizer_->getSize()); + + FAISS_ASSERT(queries.getSize(1) == dim_); + + FAISS_ASSERT(outDistances.getSize(0) == queries.getSize(0)); + FAISS_ASSERT(outIndices.getSize(0) == queries.getSize(0)); + + // Reserve space for the quantized information + DeviceTensor + coarseDistances(mem, {queries.getSize(0), nprobe}, stream); + DeviceTensor + coarseIndices(mem, {queries.getSize(0), nprobe}, stream); + + DeviceTensor coarseBitset(mem, {0}, stream); + // Find the `nprobe` closest lists; we can use int indices both + // internally and externally + quantizer_->query(queries, + coarseBitset, + nprobe, + metric_, + metricArg_, + coarseDistances, + coarseIndices, + false); + + DeviceTensor + residualBase(mem, {queries.getSize(0), nprobe, dim_}, stream); + + if (useResidual_) { + // Reconstruct vectors from the quantizer + quantizer_->reconstruct(coarseIndices, residualBase); + } + + runIVFFlatScan(queries, + coarseIndices, + bitset, + deviceListDataPointers_, + deviceListIndexPointers_, + indicesOptions_, + deviceListLengths_, + maxListLength_, + k, + metric_, + useResidual_, + residualBase, + scalarQ_.get(), + outDistances, + outIndices, + resources_); + + // If the GPU isn't storing indices (they are on the CPU side), we + // need to perform the re-mapping here + // FIXME: we might ultimately be calling this function with inputs + // from the CPU, these are unnecessary copies + if (indicesOptions_ == INDICES_CPU) { + HostTensor hostOutIndices(outIndices, stream); + + ivfOffsetToUserIndex(hostOutIndices.data(), + numLists_, + hostOutIndices.getSize(0), + hostOutIndices.getSize(1), + listOffsetToUserIndex_); + + // Copy back to GPU, since the input to this function is on the + // GPU + outIndices.copyFrom(hostOutIndices, stream); + } +} + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/impl/IVFFlat.cuh b/core/src/index/thirdparty/faiss/gpu/impl/IVFFlat.cuh new file mode 100644 index 0000000000..6b29419121 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/impl/IVFFlat.cuh @@ -0,0 +1,71 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#pragma once + +#include +#include + +namespace faiss { namespace gpu { + +class IVFFlat : public IVFBase { + public: + /// Construct from a quantizer that has elemen + IVFFlat(GpuResources* resources, + /// We do not own this reference + FlatIndex* quantizer, + faiss::MetricType metric, + float metricArg, + bool useResidual, + /// Optional ScalarQuantizer + faiss::ScalarQuantizer* scalarQ, + IndicesOptions indicesOptions, + MemorySpace space); + + ~IVFFlat() override; + + /// Add vectors to a specific list; the input data can be on the + /// host or on our current device + void addCodeVectorsFromCpu(int listId, + const unsigned char* vecs, + const long* indices, + size_t numVecs); + + void copyCodeVectorsFromCpu(const float* vecs, + const long* indices, + const std::vector& list_length); + + /// Adds the given vectors to this index. + /// The input data must be on our current device. + /// Returns the number of vectors successfully added. Vectors may + /// not be able to be added because they contain NaNs. + int classifyAndAddVectors(Tensor& vecs, + Tensor& indices); + + /// Find the approximate k nearest neigbors for `queries` against + /// our database + void query(Tensor& queries, + Tensor& bitset, + int nprobe, + int k, + Tensor& outDistances, + Tensor& outIndices); + + private: + /// Returns the size of our stored vectors, in bytes + size_t getVectorMemorySize() const; + + private: + /// Do we encode the residual from a coarse quantizer or not? + bool useResidual_; + + /// Scalar quantizer for encoded vectors, if any + std::unique_ptr scalarQ_; +}; + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/impl/IVFFlatScan.cu b/core/src/index/thirdparty/faiss/gpu/impl/IVFFlatScan.cu new file mode 100644 index 0000000000..2b76e0a09b --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/impl/IVFFlatScan.cu @@ -0,0 +1,542 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace faiss { namespace gpu { + +namespace { + +/// Sort direction per each metric +inline bool metricToSortDirection(MetricType mt) { + switch (mt) { + case MetricType::METRIC_INNER_PRODUCT: + // highest + return true; + case MetricType::METRIC_L2: + // lowest + return false; + default: + // unhandled metric + FAISS_ASSERT(false); + return false; + } +} + +} + +// Number of warps we create per block of IVFFlatScan +constexpr int kIVFFlatScanWarps = 4; + +// Works for any dimension size +template +struct IVFFlatScan { + static __device__ void scan(float* query, + bool useResidual, + float* residualBaseSlice, + void* vecData, + const Codec& codec, + const Metric& metric, + int numVecs, + int dim, + float* distanceOut) { + // How many separate loading points are there for the decoder? + int limit = utils::divDown(dim, Codec::kDimPerIter); + + // Each warp handles a separate chunk of vectors + int warpId = threadIdx.x / kWarpSize; + // FIXME: why does getLaneId() not work when we write out below!?!?! + int laneId = threadIdx.x % kWarpSize; // getLaneId(); + + // Divide the set of vectors among the warps + int vecsPerWarp = utils::divUp(numVecs, kIVFFlatScanWarps); + + int vecStart = vecsPerWarp * warpId; + int vecEnd = min(vecsPerWarp * (warpId + 1), numVecs); + + // Walk the list of vectors for this warp + for (int vec = vecStart; vec < vecEnd; ++vec) { + Metric dist = metric.zero(); + + // Scan the dimensions availabe that have whole units for the decoder, + // as the decoder may handle more than one dimension at once (leaving the + // remainder to be handled separately) + for (int d = laneId; d < limit; d += kWarpSize) { + int realDim = d * Codec::kDimPerIter; + float vecVal[Codec::kDimPerIter]; + + // Decode the kDimPerIter dimensions + codec.decode(vecData, vec, d, vecVal); + +#pragma unroll + for (int j = 0; j < Codec::kDimPerIter; ++j) { + vecVal[j] += useResidual ? residualBaseSlice[realDim + j] : 0.0f; + } + +#pragma unroll + for (int j = 0; j < Codec::kDimPerIter; ++j) { + dist.handle(query[realDim + j], vecVal[j]); + } + } + + // Handle remainder by a single thread, if any + // Not needed if we decode 1 dim per time + if (Codec::kDimPerIter > 1) { + int realDim = limit * Codec::kDimPerIter; + + // Was there any remainder? + if (realDim < dim) { + // Let the first threads in the block sequentially perform it + int remainderDim = realDim + laneId; + + if (remainderDim < dim) { + float vecVal = + codec.decodePartial(vecData, vec, limit, laneId); + vecVal += useResidual ? residualBaseSlice[remainderDim] : 0.0f; + dist.handle(query[remainderDim], vecVal); + } + } + } + + // Reduce distance within warp + auto warpDist = warpReduceAllSum(dist.reduce()); + + if (laneId == 0) { + distanceOut[vec] = warpDist; + } + } + } +}; + +template +__global__ void +ivfFlatScan(Tensor queries, + bool useResidual, + Tensor residualBase, + Tensor listIds, + void** allListData, + int* listLengths, + Codec codec, + Metric metric, + Tensor prefixSumOffsets, + Tensor distance) { + extern __shared__ float smem[]; + + auto queryId = blockIdx.y; + auto probeId = blockIdx.x; + + // This is where we start writing out data + // We ensure that before the array (at offset -1), there is a 0 value + int outBase = *(prefixSumOffsets[queryId][probeId].data() - 1); + + auto listId = listIds[queryId][probeId]; + // Safety guard in case NaNs in input cause no list ID to be generated + if (listId == -1) { + return; + } + + auto query = queries[queryId].data(); + auto vecs = allListData[listId]; + auto numVecs = listLengths[listId]; + auto dim = queries.getSize(1); + auto distanceOut = distance[outBase].data(); + + auto residualBaseSlice = residualBase[queryId][probeId].data(); + + codec.setSmem(smem, dim); + + IVFFlatScan::scan(query, + useResidual, + residualBaseSlice, + vecs, + codec, + metric, + numVecs, + dim, + distanceOut); +} + +void +runIVFFlatScanTile(Tensor& queries, + Tensor& listIds, + Tensor& bitset, + thrust::device_vector& listData, + thrust::device_vector& listIndices, + IndicesOptions indicesOptions, + thrust::device_vector& listLengths, + Tensor& thrustMem, + Tensor& prefixSumOffsets, + Tensor& allDistances, + Tensor& heapDistances, + Tensor& heapIndices, + int k, + faiss::MetricType metricType, + bool useResidual, + Tensor& residualBase, + GpuScalarQuantizer* scalarQ, + Tensor& outDistances, + Tensor& outIndices, + cudaStream_t stream) { + int dim = queries.getSize(1); + + // Check the amount of shared memory per block available based on our type is + // sufficient + if (scalarQ && + (scalarQ->qtype == QuantizerType::QT_8bit || + scalarQ->qtype == QuantizerType::QT_4bit)) { + int maxDim = getMaxSharedMemPerBlockCurrentDevice() / + (sizeof(float) * 2); + + FAISS_THROW_IF_NOT_FMT(dim < maxDim, + "Insufficient shared memory available on the GPU " + "for QT_8bit or QT_4bit with %d dimensions; " + "maximum dimensions possible is %d", dim, maxDim); + } + + + // Calculate offset lengths, so we know where to write out + // intermediate results + runCalcListOffsets(listIds, listLengths, prefixSumOffsets, thrustMem, stream); + + auto grid = dim3(listIds.getSize(1), listIds.getSize(0)); + auto block = dim3(kWarpSize * kIVFFlatScanWarps); + +#define RUN_IVF_FLAT \ + do { \ + ivfFlatScan \ + <<>>( \ + queries, \ + useResidual, \ + residualBase, \ + listIds, \ + listData.data().get(), \ + listLengths.data().get(), \ + codec, \ + metric, \ + prefixSumOffsets, \ + allDistances); \ + } while (0) + +#define HANDLE_METRICS \ + do { \ + if (metricType == MetricType::METRIC_L2) { \ + L2Distance metric; RUN_IVF_FLAT; \ + } else { \ + IPDistance metric; RUN_IVF_FLAT; \ + } \ + } while (0) + + if (!scalarQ) { + CodecFloat codec(dim * sizeof(float)); + HANDLE_METRICS; + } else { + switch (scalarQ->qtype) { + case QuantizerType::QT_8bit: + { + // FIXME: investigate 32 bit load perf issues +// if (dim % 4 == 0) { + if (false) { + Codec<(int)QuantizerType::QT_8bit, 4> + codec(scalarQ->code_size, + scalarQ->gpuTrained.data(), + scalarQ->gpuTrained.data() + dim); + HANDLE_METRICS; + } else { + Codec<(int)QuantizerType::QT_8bit, 1> + codec(scalarQ->code_size, + scalarQ->gpuTrained.data(), + scalarQ->gpuTrained.data() + dim); + HANDLE_METRICS; + } + } + break; + case QuantizerType::QT_8bit_uniform: + { + // FIXME: investigate 32 bit load perf issues + if (false) { +// if (dim % 4 == 0) { + Codec<(int)QuantizerType::QT_8bit_uniform, 4> + codec(scalarQ->code_size, scalarQ->trained[0], scalarQ->trained[1]); + HANDLE_METRICS; + } else { + Codec<(int)QuantizerType::QT_8bit_uniform, 1> + codec(scalarQ->code_size, scalarQ->trained[0], scalarQ->trained[1]); + HANDLE_METRICS; + } + } + break; + case QuantizerType::QT_fp16: + { + if (false) { + // FIXME: investigate 32 bit load perf issues +// if (dim % 2 == 0) { + Codec<(int)QuantizerType::QT_fp16, 2> + codec(scalarQ->code_size); + HANDLE_METRICS; + } else { + Codec<(int)QuantizerType::QT_fp16, 1> + codec(scalarQ->code_size); + HANDLE_METRICS; + } + } + break; + case QuantizerType::QT_8bit_direct: + { + Codec<(int)QuantizerType::QT_8bit_direct, 1> + codec(scalarQ->code_size); + HANDLE_METRICS; + } + break; + case QuantizerType::QT_4bit: + { + Codec<(int)QuantizerType::QT_4bit, 1> + codec(scalarQ->code_size, + scalarQ->gpuTrained.data(), + scalarQ->gpuTrained.data() + dim); + HANDLE_METRICS; + } + break; + case QuantizerType::QT_4bit_uniform: + { + Codec<(int)QuantizerType::QT_4bit_uniform, 1> + codec(scalarQ->code_size, scalarQ->trained[0], scalarQ->trained[1]); + HANDLE_METRICS; + } + break; + default: + // unimplemented, should be handled at a higher level + FAISS_ASSERT(false); + } + } + + CUDA_TEST_ERROR(); + +#undef HANDLE_METRICS +#undef RUN_IVF_FLAT + + // k-select the output in chunks, to increase parallelism + runPass1SelectLists(listIndices, + indicesOptions, + prefixSumOffsets, + listIds, + bitset, + allDistances, + listIds.getSize(1), + k, + metricToSortDirection(metricType), + heapDistances, + heapIndices, + stream); + + // k-select final output + auto flatHeapDistances = heapDistances.downcastInner<2>(); + auto flatHeapIndices = heapIndices.downcastInner<2>(); + + runPass2SelectLists(flatHeapDistances, + flatHeapIndices, + listIndices, + indicesOptions, + prefixSumOffsets, + listIds, + k, + metricToSortDirection(metricType), + outDistances, + outIndices, + stream); +} + +void +runIVFFlatScan(Tensor& queries, + Tensor& listIds, + Tensor& bitset, + thrust::device_vector& listData, + thrust::device_vector& listIndices, + IndicesOptions indicesOptions, + thrust::device_vector& listLengths, + int maxListLength, + int k, + faiss::MetricType metric, + bool useResidual, + Tensor& residualBase, + GpuScalarQuantizer* scalarQ, + // output + Tensor& outDistances, + // output + Tensor& outIndices, + GpuResources* res) { + constexpr int kMinQueryTileSize = 8; + constexpr int kMaxQueryTileSize = 128; + constexpr int kThrustMemSize = 16384; + + int nprobe = listIds.getSize(1); + + auto& mem = res->getMemoryManagerCurrentDevice(); + auto stream = res->getDefaultStreamCurrentDevice(); + + // Make a reservation for Thrust to do its dirty work (global memory + // cross-block reduction space); hopefully this is large enough. + DeviceTensor thrustMem1( + mem, {kThrustMemSize}, stream); + DeviceTensor thrustMem2( + mem, {kThrustMemSize}, stream); + DeviceTensor* thrustMem[2] = + {&thrustMem1, &thrustMem2}; + + // How much temporary storage is available? + // If possible, we'd like to fit within the space available. + size_t sizeAvailable = mem.getSizeAvailable(); + + // We run two passes of heap selection + // This is the size of the first-level heap passes + constexpr int kNProbeSplit = 8; + int pass2Chunks = std::min(nprobe, kNProbeSplit); + + size_t sizeForFirstSelectPass = + pass2Chunks * k * (sizeof(float) + sizeof(int)); + + // How much temporary storage we need per each query + size_t sizePerQuery = + 2 * // # streams + ((nprobe * sizeof(int) + sizeof(int)) + // prefixSumOffsets + nprobe * maxListLength * sizeof(float) + // allDistances + sizeForFirstSelectPass); + + int queryTileSize = (int) (sizeAvailable / sizePerQuery); + + if (queryTileSize < kMinQueryTileSize) { + queryTileSize = kMinQueryTileSize; + } else if (queryTileSize > kMaxQueryTileSize) { + queryTileSize = kMaxQueryTileSize; + } + + // FIXME: we should adjust queryTileSize to deal with this, since + // indexing is in int32 + FAISS_ASSERT(queryTileSize * nprobe * maxListLength < + std::numeric_limits::max()); + + // Temporary memory buffers + // Make sure there is space prior to the start which will be 0, and + // will handle the boundary condition without branches + DeviceTensor prefixSumOffsetSpace1( + mem, {queryTileSize * nprobe + 1}, stream); + DeviceTensor prefixSumOffsetSpace2( + mem, {queryTileSize * nprobe + 1}, stream); + + DeviceTensor prefixSumOffsets1( + prefixSumOffsetSpace1[1].data(), + {queryTileSize, nprobe}); + DeviceTensor prefixSumOffsets2( + prefixSumOffsetSpace2[1].data(), + {queryTileSize, nprobe}); + DeviceTensor* prefixSumOffsets[2] = + {&prefixSumOffsets1, &prefixSumOffsets2}; + + // Make sure the element before prefixSumOffsets is 0, since we + // depend upon simple, boundary-less indexing to get proper results + CUDA_VERIFY(cudaMemsetAsync(prefixSumOffsetSpace1.data(), + 0, + sizeof(int), + stream)); + CUDA_VERIFY(cudaMemsetAsync(prefixSumOffsetSpace2.data(), + 0, + sizeof(int), + stream)); + + DeviceTensor allDistances1( + mem, {queryTileSize * nprobe * maxListLength}, stream); + DeviceTensor allDistances2( + mem, {queryTileSize * nprobe * maxListLength}, stream); + DeviceTensor* allDistances[2] = + {&allDistances1, &allDistances2}; + + DeviceTensor heapDistances1( + mem, {queryTileSize, pass2Chunks, k}, stream); + DeviceTensor heapDistances2( + mem, {queryTileSize, pass2Chunks, k}, stream); + DeviceTensor* heapDistances[2] = + {&heapDistances1, &heapDistances2}; + + DeviceTensor heapIndices1( + mem, {queryTileSize, pass2Chunks, k}, stream); + DeviceTensor heapIndices2( + mem, {queryTileSize, pass2Chunks, k}, stream); + DeviceTensor* heapIndices[2] = + {&heapIndices1, &heapIndices2}; + + auto streams = res->getAlternateStreamsCurrentDevice(); + streamWait(streams, {stream}); + + int curStream = 0; + + for (int query = 0; query < queries.getSize(0); query += queryTileSize) { + int numQueriesInTile = + std::min(queryTileSize, queries.getSize(0) - query); + + auto prefixSumOffsetsView = + prefixSumOffsets[curStream]->narrowOutermost(0, numQueriesInTile); + + auto listIdsView = + listIds.narrowOutermost(query, numQueriesInTile); + auto queryView = + queries.narrowOutermost(query, numQueriesInTile); + auto residualBaseView = + residualBase.narrowOutermost(query, numQueriesInTile); + + auto heapDistancesView = + heapDistances[curStream]->narrowOutermost(0, numQueriesInTile); + auto heapIndicesView = + heapIndices[curStream]->narrowOutermost(0, numQueriesInTile); + + auto outDistanceView = + outDistances.narrowOutermost(query, numQueriesInTile); + auto outIndicesView = + outIndices.narrowOutermost(query, numQueriesInTile); + + runIVFFlatScanTile(queryView, + listIdsView, + bitset, + listData, + listIndices, + indicesOptions, + listLengths, + *thrustMem[curStream], + prefixSumOffsetsView, + *allDistances[curStream], + heapDistancesView, + heapIndicesView, + k, + metric, + useResidual, + residualBaseView, + scalarQ, + outDistanceView, + outIndicesView, + streams[curStream]); + + curStream = (curStream + 1) % 2; + } + + streamWait({stream}, streams); +} + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/impl/IVFFlatScan.cuh b/core/src/index/thirdparty/faiss/gpu/impl/IVFFlatScan.cuh new file mode 100644 index 0000000000..2b67cba06f --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/impl/IVFFlatScan.cuh @@ -0,0 +1,40 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#pragma once + +#include +#include +#include +#include +#include + +namespace faiss { namespace gpu { + +class GpuResources; + +void runIVFFlatScan(Tensor& queries, + Tensor& listIds, + Tensor& bitset, + thrust::device_vector& listData, + thrust::device_vector& listIndices, + IndicesOptions indicesOptions, + thrust::device_vector& listLengths, + int maxListLength, + int k, + faiss::MetricType metric, + bool useResidual, + Tensor& residualBase, + GpuScalarQuantizer* scalarQ, + // output + Tensor& outDistances, + // output + Tensor& outIndices, + GpuResources* res); + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/impl/IVFPQ.cu b/core/src/index/thirdparty/faiss/gpu/impl/IVFPQ.cu new file mode 100644 index 0000000000..48254c1f5b --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/impl/IVFPQ.cu @@ -0,0 +1,811 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace faiss { namespace gpu { + +IVFPQ::IVFPQ(GpuResources* resources, + faiss::MetricType metric, + float metricArg, + FlatIndex* quantizer, + int numSubQuantizers, + int bitsPerSubQuantizer, + float* pqCentroidData, + IndicesOptions indicesOptions, + bool useFloat16LookupTables, + MemorySpace space) : + IVFBase(resources, + metric, + metricArg, + quantizer, + numSubQuantizers, + indicesOptions, + space), + numSubQuantizers_(numSubQuantizers), + bitsPerSubQuantizer_(bitsPerSubQuantizer), + numSubQuantizerCodes_(utils::pow2(bitsPerSubQuantizer_)), + dimPerSubQuantizer_(dim_ / numSubQuantizers), + precomputedCodes_(false), + useFloat16LookupTables_(useFloat16LookupTables) { + FAISS_ASSERT(pqCentroidData); + + FAISS_ASSERT(bitsPerSubQuantizer_ <= 8); + FAISS_ASSERT(dim_ % numSubQuantizers_ == 0); + FAISS_ASSERT(isSupportedPQCodeLength(bytesPerVector_)); + +#ifndef FAISS_USE_FLOAT16 + FAISS_ASSERT(!useFloat16LookupTables_); +#endif + + setPQCentroids_(pqCentroidData); +} + +IVFPQ::~IVFPQ() { +} + + +bool +IVFPQ::isSupportedPQCodeLength(int size) { + switch (size) { + case 1: + case 2: + case 3: + case 4: + case 8: + case 12: + case 16: + case 20: + case 24: + case 28: + case 32: + case 40: + case 48: + case 56: // only supported with float16 + case 64: // only supported with float16 + case 96: // only supported with float16 + return true; + default: + return false; + } +} + +bool +IVFPQ::isSupportedNoPrecomputedSubDimSize(int dims) { + return faiss::gpu::isSupportedNoPrecomputedSubDimSize(dims); +} + +void +IVFPQ::setPrecomputedCodes(bool enable) { + if (enable && metric_ == MetricType::METRIC_INNER_PRODUCT) { + FAISS_THROW_MSG("Precomputed codes are not needed for GpuIndexIVFPQ " + "with METRIC_INNER_PRODUCT"); + } + + if (precomputedCodes_ != enable) { + precomputedCodes_ = enable; + + if (precomputedCodes_) { + precomputeCodes_(); + } else { + // Clear out old precomputed code data + precomputedCode_ = std::move(DeviceTensor()); +#ifdef FAISS_USE_FLOAT16 + precomputedCodeHalf_ = std::move(DeviceTensor()); +#endif + } + } +} + +int +IVFPQ::classifyAndAddVectors(Tensor& vecs, + Tensor& indices) { + FAISS_ASSERT(vecs.getSize(0) == indices.getSize(0)); + FAISS_ASSERT(vecs.getSize(1) == dim_); + + auto& mem = resources_->getMemoryManagerCurrentDevice(); + auto stream = resources_->getDefaultStreamCurrentDevice(); + + // Number of valid vectors that we actually add; we return this + int numAdded = 0; + + // We don't actually need this + DeviceTensor listDistance(mem, {vecs.getSize(0), 1}, stream); + // We use this + DeviceTensor listIds2d(mem, {vecs.getSize(0), 1}, stream); + auto listIds = listIds2d.view<1>({vecs.getSize(0)}); + + /* pseudo bitset */ + DeviceTensor bitset(mem, {0}, stream); + quantizer_->query(vecs, + bitset, + 1, + metric_, + metricArg_, + listDistance, + listIds2d, + false); + + // Copy the lists that we wish to append to back to the CPU + // FIXME: really this can be into pinned memory and a true async + // copy on a different stream; we can start the copy early, but it's + // tiny + HostTensor listIdsHost(listIds, stream); + + // Calculate the residual for each closest centroid + DeviceTensor residuals( + mem, {vecs.getSize(0), vecs.getSize(1)}, stream); + +#ifdef FAISS_USE_FLOAT16 + if (quantizer_->getUseFloat16()) { + auto& coarseCentroids = quantizer_->getVectorsFloat16Ref(); + runCalcResidual(vecs, coarseCentroids, listIds, residuals, stream); + } else { + auto& coarseCentroids = quantizer_->getVectorsFloat32Ref(); + runCalcResidual(vecs, coarseCentroids, listIds, residuals, stream); + } +#else + auto& coarseCentroids = quantizer_->getVectorsFloat32Ref(); + runCalcResidual(vecs, coarseCentroids, listIds, residuals, stream); +#endif + + // Residuals are in the form + // (vec x numSubQuantizer x dimPerSubQuantizer) + // transpose to + // (numSubQuantizer x vec x dimPerSubQuantizer) + auto residualsView = residuals.view<3>( + {residuals.getSize(0), numSubQuantizers_, dimPerSubQuantizer_}); + + DeviceTensor residualsTranspose( + mem, + {numSubQuantizers_, residuals.getSize(0), dimPerSubQuantizer_}, + stream); + + runTransposeAny(residualsView, 0, 1, residualsTranspose, stream); + + // Get the product quantizer centroids in the form + // (numSubQuantizer x numSubQuantizerCodes x dimPerSubQuantizer) + // which is pqCentroidsMiddleCode_ + + // We now have a batch operation to find the top-1 distances: + // batch size: numSubQuantizer + // centroids: (numSubQuantizerCodes x dimPerSubQuantizer) + // residuals: (vec x dimPerSubQuantizer) + // => (numSubQuantizer x vec x 1) + + DeviceTensor closestSubQDistance( + mem, {numSubQuantizers_, residuals.getSize(0), 1}, stream); + DeviceTensor closestSubQIndex( + mem, {numSubQuantizers_, residuals.getSize(0), 1}, stream); + + for (int subQ = 0; subQ < numSubQuantizers_; ++subQ) { + auto closestSubQDistanceView = closestSubQDistance[subQ].view(); + auto closestSubQIndexView = closestSubQIndex[subQ].view(); + + auto pqCentroidsMiddleCodeView = pqCentroidsMiddleCode_[subQ].view(); + auto residualsTransposeView = residualsTranspose[subQ].view(); + + runL2Distance(resources_, + pqCentroidsMiddleCodeView, + true, // pqCentroidsMiddleCodeView is row major + nullptr, // no precomputed norms + residualsTransposeView, + true, // residualsTransposeView is row major + bitset, + 1, + closestSubQDistanceView, + closestSubQIndexView, + // We don't care about distances + true); + } + + // Now, we have the nearest sub-q centroid for each slice of the + // residual vector. + auto closestSubQIndexView = closestSubQIndex.view<2>( + {numSubQuantizers_, residuals.getSize(0)}); + + // Transpose this for easy use + DeviceTensor encodings( + mem, {residuals.getSize(0), numSubQuantizers_}, stream); + + runTransposeAny(closestSubQIndexView, 0, 1, encodings, stream); + + // Now we add the encoded vectors to the individual lists + // First, make sure that there is space available for adding the new + // encoded vectors and indices + + // list id -> # being added + std::unordered_map assignCounts; + + // vector id -> offset in list + // (we already have vector id -> list id in listIds) + HostTensor listOffsetHost({listIdsHost.getSize(0)}); + + for (int i = 0; i < listIdsHost.getSize(0); ++i) { + int listId = listIdsHost[i]; + + // Add vector could be invalid (contains NaNs etc) + if (listId < 0) { + listOffsetHost[i] = -1; + continue; + } + + FAISS_ASSERT(listId < numLists_); + ++numAdded; + + int offset = deviceListData_[listId]->size() / bytesPerVector_; + + auto it = assignCounts.find(listId); + if (it != assignCounts.end()) { + offset += it->second; + it->second++; + } else { + assignCounts[listId] = 1; + } + + listOffsetHost[i] = offset; + } + + // If we didn't add anything (all invalid vectors), no need to + // continue + if (numAdded == 0) { + return 0; + } + + // We need to resize the data structures for the inverted lists on + // the GPUs, which means that they might need reallocation, which + // means that their base address may change. Figure out the new base + // addresses, and update those in a batch on the device + { + // Resize all of the lists that we are appending to + for (auto& counts : assignCounts) { + auto& codes = deviceListData_[counts.first]; + codes->resize(codes->size() + counts.second * bytesPerVector_, + stream); + int newNumVecs = (int) (codes->size() / bytesPerVector_); + + auto& indices = deviceListIndices_[counts.first]; + if ((indicesOptions_ == INDICES_32_BIT) || + (indicesOptions_ == INDICES_64_BIT)) { + size_t indexSize = + (indicesOptions_ == INDICES_32_BIT) ? sizeof(int) : sizeof(long); + + indices->resize(indices->size() + counts.second * indexSize, stream); + } else if (indicesOptions_ == INDICES_CPU) { + // indices are stored on the CPU side + FAISS_ASSERT(counts.first < listOffsetToUserIndex_.size()); + + auto& userIndices = listOffsetToUserIndex_[counts.first]; + userIndices.resize(newNumVecs); + } else { + // indices are not stored on the GPU or CPU side + FAISS_ASSERT(indicesOptions_ == INDICES_IVF); + } + + // This is used by the multi-pass query to decide how much scratch + // space to allocate for intermediate results + maxListLength_ = std::max(maxListLength_, newNumVecs); + } + + // Update all pointers and sizes on the device for lists that we + // appended to + { + std::vector listIds(assignCounts.size()); + int i = 0; + for (auto& counts : assignCounts) { + listIds[i++] = counts.first; + } + + updateDeviceListInfo_(listIds, stream); + } + } + + // If we're maintaining the indices on the CPU side, update our + // map. We already resized our map above. + if (indicesOptions_ == INDICES_CPU) { + // We need to maintain the indices on the CPU side + HostTensor hostIndices(indices, stream); + + for (int i = 0; i < hostIndices.getSize(0); ++i) { + int listId = listIdsHost[i]; + + // Add vector could be invalid (contains NaNs etc) + if (listId < 0) { + continue; + } + + int offset = listOffsetHost[i]; + + FAISS_ASSERT(listId < listOffsetToUserIndex_.size()); + auto& userIndices = listOffsetToUserIndex_[listId]; + + FAISS_ASSERT(offset < userIndices.size()); + userIndices[offset] = hostIndices[i]; + } + } + + // We similarly need to actually append the new encoded vectors + { + DeviceTensor listOffset(mem, listOffsetHost, stream); + + // This kernel will handle appending each encoded vector + index to + // the appropriate list + runIVFPQInvertedListAppend(listIds, + listOffset, + encodings, + indices, + deviceListDataPointers_, + deviceListIndexPointers_, + indicesOptions_, + stream); + } + + return numAdded; +} + +void +IVFPQ::addCodeVectorsFromCpu(int listId, + const void* codes, + const long* indices, + size_t numVecs) { + // This list must already exist + FAISS_ASSERT(listId < deviceListData_.size()); + auto stream = resources_->getDefaultStreamCurrentDevice(); + + // If there's nothing to add, then there's nothing we have to do + if (numVecs == 0) { + return; + } + + size_t lengthInBytes = numVecs * bytesPerVector_; + + auto& listCodes = deviceListData_[listId]; + auto prevCodeData = listCodes->data(); + + // We only have int32 length representations on the GPU per each + // list; the length is in sizeof(char) + FAISS_ASSERT(listCodes->size() % bytesPerVector_ == 0); + FAISS_ASSERT(listCodes->size() + lengthInBytes <= + (size_t) std::numeric_limits::max()); + + listCodes->append((unsigned char*) codes, + lengthInBytes, + stream, + true /* exact reserved size */); + + // Handle the indices as well + addIndicesFromCpu_(listId, indices, numVecs); + + // This list address may have changed due to vector resizing, but + // only bother updating it on the device if it has changed + if (prevCodeData != listCodes->data()) { + deviceListDataPointers_[listId] = listCodes->data(); + } + + // And our size has changed too + int listLength = listCodes->size() / bytesPerVector_; + deviceListLengths_[listId] = listLength; + + // We update this as well, since the multi-pass algorithm uses it + maxListLength_ = std::max(maxListLength_, listLength); + + // device_vector add is potentially happening on a different stream + // than our default stream + if (resources_->getDefaultStreamCurrentDevice() != 0) { + streamWait({stream}, {0}); + } +} + +void +IVFPQ::setPQCentroids_(float* data) { + size_t pqSize = + numSubQuantizers_ * numSubQuantizerCodes_ * dimPerSubQuantizer_; + + // Make sure the data is on the host + // FIXME: why are we doing this? + thrust::host_vector hostMemory; + hostMemory.insert(hostMemory.end(), data, data + pqSize); + + HostTensor pqHost( + hostMemory.data(), + {numSubQuantizers_, numSubQuantizerCodes_, dimPerSubQuantizer_}); + DeviceTensor pqDevice( + pqHost, + resources_->getDefaultStreamCurrentDevice()); + + DeviceTensor pqDeviceTranspose( + {numSubQuantizers_, dimPerSubQuantizer_, numSubQuantizerCodes_}); + runTransposeAny(pqDevice, 1, 2, pqDeviceTranspose, + resources_->getDefaultStreamCurrentDevice()); + + pqCentroidsInnermostCode_ = std::move(pqDeviceTranspose); + + // Also maintain the PQ centroids in the form + // (sub q)(code id)(sub dim) + DeviceTensor pqCentroidsMiddleCode( + {numSubQuantizers_, numSubQuantizerCodes_, dimPerSubQuantizer_}); + runTransposeAny(pqCentroidsInnermostCode_, 1, 2, pqCentroidsMiddleCode, + resources_->getDefaultStreamCurrentDevice()); + + pqCentroidsMiddleCode_ = std::move(pqCentroidsMiddleCode); +} + +template +void +IVFPQ::precomputeCodesT_() { + FAISS_ASSERT(metric_ == MetricType::METRIC_L2); + + // + // d = || x - y_C ||^2 + || y_R ||^2 + 2 * (y_C|y_R) - 2 * (x|y_R) + // --------------- --------------------------- ------- + // term 1 term 2 term 3 + // + + // Terms 1 and 3 are available only at query time. We compute term 2 + // here. + + // Compute ||y_R||^2 by treating + // (sub q)(code id)(sub dim) as (sub q * code id)(sub dim) + auto pqCentroidsMiddleCodeView = + pqCentroidsMiddleCode_.view<2>( + {numSubQuantizers_ * numSubQuantizerCodes_, dimPerSubQuantizer_}); + DeviceTensor subQuantizerNorms( + {numSubQuantizers_ * numSubQuantizerCodes_}); + + runL2Norm(pqCentroidsMiddleCodeView, true, + subQuantizerNorms, true, + resources_->getDefaultStreamCurrentDevice()); + + // Compute 2 * (y_C|y_R) via batch matrix multiplication + // batch size (sub q) x {(centroid id)(sub dim) x (code id)(sub dim)'} + // => (sub q) x {(centroid id)(code id)} + // => (sub q)(centroid id)(code id) + + // View (centroid id)(dim) as + // (centroid id)(sub q)(dim) + // Transpose (centroid id)(sub q)(sub dim) to + // (sub q)(centroid id)(sub dim) + auto& coarseCentroids = quantizer_->template getVectorsRef(); + auto centroidView = coarseCentroids.template view<3>( + {coarseCentroids.getSize(0), numSubQuantizers_, dimPerSubQuantizer_}); + DeviceTensor centroidsTransposed( + {numSubQuantizers_, coarseCentroids.getSize(0), dimPerSubQuantizer_}); + + runTransposeAny(centroidView, 0, 1, centroidsTransposed, + resources_->getDefaultStreamCurrentDevice()); + + DeviceTensor coarsePQProduct( + {numSubQuantizers_, coarseCentroids.getSize(0), numSubQuantizerCodes_}); + + runIteratedMatrixMult(coarsePQProduct, false, + centroidsTransposed, false, + pqCentroidsMiddleCode_, true, + 2.0f, 0.0f, + resources_->getBlasHandleCurrentDevice(), + resources_->getDefaultStreamCurrentDevice()); + + // Transpose (sub q)(centroid id)(code id) to + // (centroid id)(sub q)(code id) + DeviceTensor coarsePQProductTransposed( + {coarseCentroids.getSize(0), numSubQuantizers_, numSubQuantizerCodes_}); + runTransposeAny(coarsePQProduct, 0, 1, coarsePQProductTransposed, + resources_->getDefaultStreamCurrentDevice()); + + // View (centroid id)(sub q)(code id) as + // (centroid id)(sub q * code id) + auto coarsePQProductTransposedView = coarsePQProductTransposed.view<2>( + {coarseCentroids.getSize(0), numSubQuantizers_ * numSubQuantizerCodes_}); + + // Sum || y_R ||^2 + 2 * (y_C|y_R) + // i.e., add norms (sub q * code id) + // along columns of inner product (centroid id)(sub q * code id) + runSumAlongColumns(subQuantizerNorms, coarsePQProductTransposedView, + resources_->getDefaultStreamCurrentDevice()); + + // We added into the view, so `coarsePQProductTransposed` is now our + // precomputed term 2. +#ifdef FAISS_USE_FLOAT16 + if (useFloat16LookupTables_) { + precomputedCodeHalf_ = + convertTensor(resources_, + resources_->getDefaultStreamCurrentDevice(), + coarsePQProductTransposed); + } else { + precomputedCode_ = std::move(coarsePQProductTransposed); + } +#else + precomputedCode_ = std::move(coarsePQProductTransposed); +#endif + +} + +void +IVFPQ::precomputeCodes_() { +#ifdef FAISS_USE_FLOAT16 + if (quantizer_->getUseFloat16()) { + precomputeCodesT_(); + } else { + precomputeCodesT_(); + } +#else + precomputeCodesT_(); +#endif +} + +void +IVFPQ::query(Tensor& queries, + Tensor& bitset, + int nprobe, + int k, + Tensor& outDistances, + Tensor& outIndices) { + // These are caught at a higher level + FAISS_ASSERT(nprobe <= GPU_MAX_SELECTION_K); + FAISS_ASSERT(k <= GPU_MAX_SELECTION_K); + + auto& mem = resources_->getMemoryManagerCurrentDevice(); + auto stream = resources_->getDefaultStreamCurrentDevice(); + nprobe = std::min(nprobe, quantizer_->getSize()); + + FAISS_ASSERT(queries.getSize(1) == dim_); + FAISS_ASSERT(outDistances.getSize(0) == queries.getSize(0)); + FAISS_ASSERT(outIndices.getSize(0) == queries.getSize(0)); + + // Reserve space for the closest coarse centroids + DeviceTensor + coarseDistances(mem, {queries.getSize(0), nprobe}, stream); + DeviceTensor + coarseIndices(mem, {queries.getSize(0), nprobe}, stream); + + DeviceTensor coarseBitset(mem, {0}, stream); + // Find the `nprobe` closest coarse centroids; we can use int + // indices both internally and externally + quantizer_->query(queries, + coarseBitset, + nprobe, + metric_, + metricArg_, + coarseDistances, + coarseIndices, + true); + + if (precomputedCodes_) { + FAISS_ASSERT(metric_ == MetricType::METRIC_L2); + + runPQPrecomputedCodes_(queries, + bitset, + coarseDistances, + coarseIndices, + k, + outDistances, + outIndices); + } else { + runPQNoPrecomputedCodes_(queries, + bitset, + coarseDistances, + coarseIndices, + k, + outDistances, + outIndices); + } + + // If the GPU isn't storing indices (they are on the CPU side), we + // need to perform the re-mapping here + // FIXME: we might ultimately be calling this function with inputs + // from the CPU, these are unnecessary copies + if (indicesOptions_ == INDICES_CPU) { + HostTensor hostOutIndices(outIndices, stream); + + ivfOffsetToUserIndex(hostOutIndices.data(), + numLists_, + hostOutIndices.getSize(0), + hostOutIndices.getSize(1), + listOffsetToUserIndex_); + + // Copy back to GPU, since the input to this function is on the + // GPU + outIndices.copyFrom(hostOutIndices, stream); + } +} + +std::vector +IVFPQ::getListCodes(int listId) const { + FAISS_ASSERT(listId < deviceListData_.size()); + + return deviceListData_[listId]->copyToHost( + resources_->getDefaultStreamCurrentDevice()); +} + +Tensor +IVFPQ::getPQCentroids() { + return pqCentroidsMiddleCode_; +} + +void +IVFPQ::runPQPrecomputedCodes_( + Tensor& queries, + Tensor& bitset, + DeviceTensor& coarseDistances, + DeviceTensor& coarseIndices, + int k, + Tensor& outDistances, + Tensor& outIndices) { + FAISS_ASSERT(metric_ == MetricType::METRIC_L2); + + auto& mem = resources_->getMemoryManagerCurrentDevice(); + auto stream = resources_->getDefaultStreamCurrentDevice(); + + // Compute precomputed code term 3, - 2 * (x|y_R) + // This is done via batch MM + // {sub q} x {(query id)(sub dim) * (code id)(sub dim)'} => + // {sub q} x {(query id)(code id)} + DeviceTensor term3Transposed( + mem, + {queries.getSize(0), numSubQuantizers_, numSubQuantizerCodes_}, + stream); + + // These allocations within are only temporary, so release them when + // we're done to maximize free space + { + auto querySubQuantizerView = queries.view<3>( + {queries.getSize(0), numSubQuantizers_, dimPerSubQuantizer_}); + DeviceTensor queriesTransposed( + mem, + {numSubQuantizers_, queries.getSize(0), dimPerSubQuantizer_}, + stream); + runTransposeAny(querySubQuantizerView, 0, 1, queriesTransposed, stream); + + DeviceTensor term3( + mem, + {numSubQuantizers_, queries.getSize(0), numSubQuantizerCodes_}, + stream); + + runIteratedMatrixMult(term3, false, + queriesTransposed, false, + pqCentroidsMiddleCode_, true, + -2.0f, 0.0f, + resources_->getBlasHandleCurrentDevice(), + stream); + + runTransposeAny(term3, 0, 1, term3Transposed, stream); + } + + NoTypeTensor<3, true> term2; + NoTypeTensor<3, true> term3; +#ifdef FAISS_USE_FLOAT16 + DeviceTensor term3Half; + + if (useFloat16LookupTables_) { + term3Half = + convertTensor(resources_, stream, term3Transposed); + + term2 = NoTypeTensor<3, true>(precomputedCodeHalf_); + term3 = NoTypeTensor<3, true>(term3Half); + } +#endif + + if (!useFloat16LookupTables_) { + term2 = NoTypeTensor<3, true>(precomputedCode_); + term3 = NoTypeTensor<3, true>(term3Transposed); + } + + runPQScanMultiPassPrecomputed(queries, + coarseDistances, // term 1 + term2, // term 2 + term3, // term 3 + coarseIndices, + bitset, + useFloat16LookupTables_, + bytesPerVector_, + numSubQuantizers_, + numSubQuantizerCodes_, + deviceListDataPointers_, + deviceListIndexPointers_, + indicesOptions_, + deviceListLengths_, + maxListLength_, + k, + outDistances, + outIndices, + resources_); +} + +template +void +IVFPQ::runPQNoPrecomputedCodesT_( + Tensor& queries, + Tensor& bitset, + DeviceTensor& coarseDistances, + DeviceTensor& coarseIndices, + int k, + Tensor& outDistances, + Tensor& outIndices) { + auto& coarseCentroids = quantizer_->template getVectorsRef(); + + runPQScanMultiPassNoPrecomputed(queries, + coarseCentroids, + pqCentroidsInnermostCode_, + coarseIndices, + bitset, + useFloat16LookupTables_, + bytesPerVector_, + numSubQuantizers_, + numSubQuantizerCodes_, + deviceListDataPointers_, + deviceListIndexPointers_, + indicesOptions_, + deviceListLengths_, + maxListLength_, + k, + metric_, + outDistances, + outIndices, + resources_); +} + +void +IVFPQ::runPQNoPrecomputedCodes_( + Tensor& queries, + Tensor& bitset, + DeviceTensor& coarseDistances, + DeviceTensor& coarseIndices, + int k, + Tensor& outDistances, + Tensor& outIndices) { +#ifdef FAISS_USE_FLOAT16 + if (quantizer_->getUseFloat16()) { + runPQNoPrecomputedCodesT_(queries, + bitset, + coarseDistances, + coarseIndices, + k, + outDistances, + outIndices); + } else { + runPQNoPrecomputedCodesT_(queries, + bitset, + coarseDistances, + coarseIndices, + k, + outDistances, + outIndices); + } +#else + runPQNoPrecomputedCodesT_(queries, + bitset, + coarseDistances, + coarseIndices, + k, + outDistances, + outIndices); +#endif + +} + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/impl/IVFPQ.cuh b/core/src/index/thirdparty/faiss/gpu/impl/IVFPQ.cuh new file mode 100644 index 0000000000..ad03fb4f89 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/impl/IVFPQ.cuh @@ -0,0 +1,161 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#pragma once + +#include +#include +#include + +namespace faiss { namespace gpu { + +/// Implementing class for IVFPQ on the GPU +class IVFPQ : public IVFBase { + public: + IVFPQ(GpuResources* resources, + faiss::MetricType metric, + float metricArg, + /// We do not own this reference + FlatIndex* quantizer, + int numSubQuantizers, + int bitsPerSubQuantizer, + float* pqCentroidData, + IndicesOptions indicesOptions, + bool useFloat16LookupTables, + MemorySpace space); + + /// Returns true if we support PQ in this size + static bool isSupportedPQCodeLength(int size); + + /// For no precomputed codes, is this a supported sub-dimension + /// size? + /// FIXME: get MM implementation working again + static bool isSupportedNoPrecomputedSubDimSize(int dims); + + ~IVFPQ() override; + + /// Enable or disable pre-computed codes + void setPrecomputedCodes(bool enable); + + /// Adds a set of codes and indices to a list; the data can be + /// resident on either the host or the device + void addCodeVectorsFromCpu(int listId, + const void* codes, + const long* indices, + size_t numVecs); + + /// Calcuates the residual and quantizes the vectors, adding them to + /// this index + /// The input data must be on our current device. + /// Returns the number of vectors successfully added. Vectors may + /// not be able to be added because they contain NaNs. + int classifyAndAddVectors(Tensor& vecs, + Tensor& indices); + + /// Find the approximate k nearest neigbors for `queries` against + /// our database + void query(Tensor& queries, + Tensor& bitset, + int nprobe, + int k, + Tensor& outDistances, + Tensor& outIndices); + + /// Return the list codes of a particular list back to the CPU + std::vector getListCodes(int listId) const; + + /// Returns our set of sub-quantizers of the form + /// (sub q)(code id)(sub dim) + Tensor getPQCentroids(); + + private: + /// Sets the current product quantizer centroids; the data can be + /// resident on either the host or the device. It will be transposed + /// into our preferred data layout + /// Data must be a row-major, 3-d array of size + /// (numSubQuantizers, numSubQuantizerCodes, dim / numSubQuantizers) + void setPQCentroids_(float* data); + + /// Calculate precomputed residual distance information + void precomputeCodes_(); + + /// Calculate precomputed residual distance information (for different coarse + /// centroid type) + template + void precomputeCodesT_(); + + /// Runs kernels for scanning inverted lists with precomputed codes + void runPQPrecomputedCodes_(Tensor& queries, + Tensor& bitset, + DeviceTensor& coarseDistances, + DeviceTensor& coarseIndices, + int k, + Tensor& outDistances, + Tensor& outIndices); + + /// Runs kernels for scanning inverted lists without precomputed codes + void runPQNoPrecomputedCodes_(Tensor& queries, + Tensor& bitset, + DeviceTensor& coarseDistances, + DeviceTensor& coarseIndices, + int k, + Tensor& outDistances, + Tensor& outIndices); + + /// Runs kernels for scanning inverted lists without precomputed codes (for + /// different coarse centroid type) + template + void runPQNoPrecomputedCodesT_(Tensor& queries, + Tensor& bitset, + DeviceTensor& coarseDistances, + DeviceTensor& coarseIndices, + int k, + Tensor& outDistances, + Tensor& outIndices); + + private: + /// Number of sub-quantizers per vector + const int numSubQuantizers_; + + /// Number of bits per sub-quantizer + const int bitsPerSubQuantizer_; + + /// Number of per sub-quantizer codes (2^bits) + const int numSubQuantizerCodes_; + + /// Number of dimensions per each sub-quantizer + const int dimPerSubQuantizer_; + + /// Do we maintain precomputed terms and lookup tables in float16 + /// form? + const bool useFloat16LookupTables_; + + /// On the GPU, we prefer different PQ centroid data layouts for + /// different purposes. + /// + /// (sub q)(sub dim)(code id) + DeviceTensor pqCentroidsInnermostCode_; + + /// (sub q)(code id)(sub dim) + DeviceTensor pqCentroidsMiddleCode_; + + /// Are precomputed codes enabled? (additional factoring and + /// precomputation of the residual distance, to reduce query-time work) + bool precomputedCodes_; + + /// Precomputed term 2 in float form + /// (centroid id)(sub q)(code id) + DeviceTensor precomputedCode_; + + /// Precomputed term 2 in half form +#ifdef FAISS_USE_FLOAT16 + DeviceTensor precomputedCodeHalf_; +#endif +}; + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/impl/IVFUtils.cu b/core/src/index/thirdparty/faiss/gpu/impl/IVFUtils.cu new file mode 100644 index 0000000000..fda439fea2 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/impl/IVFUtils.cu @@ -0,0 +1,78 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#include +#include +#include +#include +#include +#include +#include + +namespace faiss { namespace gpu { + +// Calculates the total number of intermediate distances to consider +// for all queries +__global__ void +getResultLengths(Tensor topQueryToCentroid, + int* listLengths, + int totalSize, + Tensor length) { + int linearThreadId = blockIdx.x * blockDim.x + threadIdx.x; + if (linearThreadId >= totalSize) { + return; + } + + int nprobe = topQueryToCentroid.getSize(1); + int queryId = linearThreadId / nprobe; + int listId = linearThreadId % nprobe; + + int centroidId = topQueryToCentroid[queryId][listId]; + + // Safety guard in case NaNs in input cause no list ID to be generated + length[queryId][listId] = (centroidId != -1) ? listLengths[centroidId] : 0; +} + +void runCalcListOffsets(Tensor& topQueryToCentroid, + thrust::device_vector& listLengths, + Tensor& prefixSumOffsets, + Tensor& thrustMem, + cudaStream_t stream) { + FAISS_ASSERT(topQueryToCentroid.getSize(0) == prefixSumOffsets.getSize(0)); + FAISS_ASSERT(topQueryToCentroid.getSize(1) == prefixSumOffsets.getSize(1)); + + int totalSize = topQueryToCentroid.numElements(); + + int numThreads = std::min(totalSize, getMaxThreadsCurrentDevice()); + int numBlocks = utils::divUp(totalSize, numThreads); + + auto grid = dim3(numBlocks); + auto block = dim3(numThreads); + + getResultLengths<<>>( + topQueryToCentroid, + listLengths.data().get(), + totalSize, + prefixSumOffsets); + CUDA_TEST_ERROR(); + + // Prefix sum of the indices, so we know where the intermediate + // results should be maintained + // Thrust wants a place for its temporary allocations, so provide + // one, so it won't call cudaMalloc/Free + GpuResourcesThrustAllocator alloc(thrustMem.data(), + thrustMem.getSizeInBytes()); + + thrust::inclusive_scan(thrust::cuda::par(alloc).on(stream), + prefixSumOffsets.data(), + prefixSumOffsets.data() + totalSize, + prefixSumOffsets.data()); + CUDA_TEST_ERROR(); +} + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/impl/IVFUtils.cuh b/core/src/index/thirdparty/faiss/gpu/impl/IVFUtils.cuh new file mode 100644 index 0000000000..3eb226568d --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/impl/IVFUtils.cuh @@ -0,0 +1,119 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#pragma once + +#include +#include +#include + +// A collection of utility functions for IVFPQ and IVFFlat, for +// post-processing and k-selecting the results +namespace faiss { namespace gpu { + +// This is warp divergence central, but this is really a final step +// and happening a small number of times +inline __device__ int binarySearchForBucket(int* prefixSumOffsets, + int size, + int val) { + int start = 0; + int end = size; + + while (end - start > 0) { + int mid = start + (end - start) / 2; + + int midVal = prefixSumOffsets[mid]; + + // Find the first bucket that we are <= + if (midVal <= val) { + start = mid + 1; + } else { + end = mid; + } + } + + // We must find the bucket that it is in + assert(start != size); + + return start; +} + +inline __device__ long +getListIndex(int queryId, + int offset, + void** listIndices, + Tensor& prefixSumOffsets, + Tensor& topQueryToCentroid, + IndicesOptions opt) { + long index = -1; + + // In order to determine the actual user index, we need to first + // determine what list it was in. + // We do this by binary search in the prefix sum list. + int probe = binarySearchForBucket(prefixSumOffsets[queryId].data(), + prefixSumOffsets.getSize(1), + offset); + + // This is then the probe for the query; we can find the actual + // list ID from this + int listId = topQueryToCentroid[queryId][probe]; + + // Now, we need to know the offset within the list + // We ensure that before the array (at offset -1), there is a 0 value + int listStart = *(prefixSumOffsets[queryId][probe].data() - 1); + int listOffset = offset - listStart; + + // This gives us our final index + if (opt == INDICES_32_BIT) { + index = (long) ((int*) listIndices[listId])[listOffset]; + } else if (opt == INDICES_64_BIT) { + index = ((long*) listIndices[listId])[listOffset]; + } else { + index = ((long) listId << 32 | (long) listOffset); + } + + return index; +} + +/// Function for multi-pass scanning that collects the length of +/// intermediate results for all (query, probe) pair +void runCalcListOffsets(Tensor& topQueryToCentroid, + thrust::device_vector& listLengths, + Tensor& prefixSumOffsets, + Tensor& thrustMem, + cudaStream_t stream); + +/// Performs a first pass of k-selection on the results +void runPass1SelectLists(thrust::device_vector& listIndices, + IndicesOptions indicesOptions, + Tensor& prefixSumOffsets, + Tensor& topQueryToCentroid, + Tensor& bitset, + Tensor& distance, + int nprobe, + int k, + bool chooseLargest, + Tensor& heapDistances, + Tensor& heapIndices, + cudaStream_t stream); + +/// Performs a final pass of k-selection on the results, producing the +/// final indices +void runPass2SelectLists(Tensor& heapDistances, + Tensor& heapIndices, + thrust::device_vector& listIndices, + IndicesOptions indicesOptions, + Tensor& prefixSumOffsets, + Tensor& topQueryToCentroid, + int k, + bool chooseLargest, + Tensor& outDistances, + Tensor& outIndices, + cudaStream_t stream); + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/impl/IVFUtilsSelect1.cu b/core/src/index/thirdparty/faiss/gpu/impl/IVFUtilsSelect1.cu new file mode 100644 index 0000000000..b575d3c0a4 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/impl/IVFUtilsSelect1.cu @@ -0,0 +1,199 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#include +#include +#include +#include +#include +#include +#include + +// +// This kernel is split into a separate compilation unit to cut down +// on compile time +// + +namespace faiss { namespace gpu { + +template +__global__ void +pass1SelectLists(void** listIndices, + Tensor prefixSumOffsets, + Tensor topQueryToCentroid, + Tensor bitset, + Tensor distance, + int nprobe, + int k, + IndicesOptions opt, + Tensor heapDistances, + Tensor heapIndices) { + constexpr int kNumWarps = ThreadsPerBlock / kWarpSize; + + __shared__ float smemK[kNumWarps * NumWarpQ]; + __shared__ int smemV[kNumWarps * NumWarpQ]; + + constexpr auto kInit = Dir ? kFloatMin : kFloatMax; + BlockSelect, + NumWarpQ, NumThreadQ, ThreadsPerBlock> + heap(kInit, -1, smemK, smemV, k); + + auto queryId = blockIdx.y; + auto sliceId = blockIdx.x; + auto numSlices = gridDim.x; + + int sliceSize = (nprobe / numSlices); + int sliceStart = sliceSize * sliceId; + int sliceEnd = sliceId == (numSlices - 1) ? nprobe : + sliceStart + sliceSize; + auto offsets = prefixSumOffsets[queryId].data(); + + // We ensure that before the array (at offset -1), there is a 0 value + int start = *(&offsets[sliceStart] - 1); + int end = offsets[sliceEnd - 1]; + + int num = end - start; + int limit = utils::roundDown(num, kWarpSize); + + int i = threadIdx.x; + auto distanceStart = distance[start].data(); + bool bitsetEmpty = (bitset.getSize(0) == 0); + long index = -1; + + // BlockSelect add cannot be used in a warp divergent circumstance; we + // handle the remainder warp below + for (; i < limit; i += blockDim.x) { + index = getListIndex(queryId, + start + i, + listIndices, + prefixSumOffsets, + topQueryToCentroid, + opt); + if (bitsetEmpty || (!(bitset[index >> 3] & (0x1 << (index & 0x7))))) { + heap.addThreadQ(distanceStart[i], start + i); + } + heap.checkThreadQ(); + } + + // Handle warp divergence separately + if (i < num) { + index = getListIndex(queryId, + start + i, + listIndices, + prefixSumOffsets, + topQueryToCentroid, + opt); + if (bitsetEmpty || (!(bitset[index >> 3] & (0x1 << (index & 0x7))))) { + heap.addThreadQ(distanceStart[i], start + i); + } + } + + // Merge all final results + heap.reduce(); + + // Write out the final k-selected values; they should be all + // together + for (int i = threadIdx.x; i < k; i += blockDim.x) { + heapDistances[queryId][sliceId][i] = smemK[i]; + heapIndices[queryId][sliceId][i] = smemV[i]; + } +} + +void +runPass1SelectLists(thrust::device_vector& listIndices, + IndicesOptions indicesOptions, + Tensor& prefixSumOffsets, + Tensor& topQueryToCentroid, + Tensor& bitset, + Tensor& distance, + int nprobe, + int k, + bool chooseLargest, + Tensor& heapDistances, + Tensor& heapIndices, + cudaStream_t stream) { + // This is caught at a higher level + FAISS_ASSERT(k <= GPU_MAX_SELECTION_K); + + auto grid = dim3(heapDistances.getSize(1), prefixSumOffsets.getSize(0)); + +#define RUN_PASS(BLOCK, NUM_WARP_Q, NUM_THREAD_Q, DIR) \ + do { \ + pass1SelectLists \ + <<>>(listIndices.data().get(), \ + prefixSumOffsets, \ + topQueryToCentroid, \ + bitset, \ + distance, \ + nprobe, \ + k, \ + indicesOptions, \ + heapDistances, \ + heapIndices); \ + CUDA_TEST_ERROR(); \ + return; /* success */ \ + } while (0) + +#if GPU_MAX_SELECTION_K >= 2048 + + // block size 128 for k <= 1024, 64 for k = 2048 +#define RUN_PASS_DIR(DIR) \ + do { \ + if (k == 1) { \ + RUN_PASS(128, 1, 1, DIR); \ + } else if (k <= 32) { \ + RUN_PASS(128, 32, 2, DIR); \ + } else if (k <= 64) { \ + RUN_PASS(128, 64, 3, DIR); \ + } else if (k <= 128) { \ + RUN_PASS(128, 128, 3, DIR); \ + } else if (k <= 256) { \ + RUN_PASS(128, 256, 4, DIR); \ + } else if (k <= 512) { \ + RUN_PASS(128, 512, 8, DIR); \ + } else if (k <= 1024) { \ + RUN_PASS(128, 1024, 8, DIR); \ + } else if (k <= 2048) { \ + RUN_PASS(64, 2048, 8, DIR); \ + } \ + } while (0) + +#else + +#define RUN_PASS_DIR(DIR) \ + do { \ + if (k == 1) { \ + RUN_PASS(128, 1, 1, DIR); \ + } else if (k <= 32) { \ + RUN_PASS(128, 32, 2, DIR); \ + } else if (k <= 64) { \ + RUN_PASS(128, 64, 3, DIR); \ + } else if (k <= 128) { \ + RUN_PASS(128, 128, 3, DIR); \ + } else if (k <= 256) { \ + RUN_PASS(128, 256, 4, DIR); \ + } else if (k <= 512) { \ + RUN_PASS(128, 512, 8, DIR); \ + } else if (k <= 1024) { \ + RUN_PASS(128, 1024, 8, DIR); \ + } \ + } while (0) + +#endif // GPU_MAX_SELECTION_K + + if (chooseLargest) { + RUN_PASS_DIR(true); + } else { + RUN_PASS_DIR(false); + } + +#undef RUN_PASS_DIR +#undef RUN_PASS +} + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/impl/IVFUtilsSelect2.cu b/core/src/index/thirdparty/faiss/gpu/impl/IVFUtilsSelect2.cu new file mode 100644 index 0000000000..8c6b9eb3b8 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/impl/IVFUtilsSelect2.cu @@ -0,0 +1,218 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#include +#include +#include +#include +#include +#include +#include + +// +// This kernel is split into a separate compilation unit to cut down +// on compile time +// + +namespace faiss { namespace gpu { + +// This is warp divergence central, but this is really a final step +// and happening a small number of times +//inline __device__ int binarySearchForBucket(int* prefixSumOffsets, +// int size, +// int val) { +// int start = 0; +// int end = size; +// +// while (end - start > 0) { +// int mid = start + (end - start) / 2; +// +// int midVal = prefixSumOffsets[mid]; +// +// // Find the first bucket that we are <= +// if (midVal <= val) { +// start = mid + 1; +// } else { +// end = mid; +// } +// } +// +// // We must find the bucket that it is in +// assert(start != size); +// +// return start; +//} + +template +__global__ void +pass2SelectLists(Tensor heapDistances, + Tensor heapIndices, + void** listIndices, + Tensor prefixSumOffsets, + Tensor topQueryToCentroid, + int k, + IndicesOptions opt, + Tensor outDistances, + Tensor outIndices) { + constexpr int kNumWarps = ThreadsPerBlock / kWarpSize; + + __shared__ float smemK[kNumWarps * NumWarpQ]; + __shared__ int smemV[kNumWarps * NumWarpQ]; + + constexpr auto kInit = Dir ? kFloatMin : kFloatMax; + BlockSelect, + NumWarpQ, NumThreadQ, ThreadsPerBlock> + heap(kInit, -1, smemK, smemV, k); + + auto queryId = blockIdx.x; + int num = heapDistances.getSize(1); + int limit = utils::roundDown(num, kWarpSize); + + int i = threadIdx.x; + auto heapDistanceStart = heapDistances[queryId]; + + // BlockSelect add cannot be used in a warp divergent circumstance; we + // handle the remainder warp below + for (; i < limit; i += blockDim.x) { + heap.add(heapDistanceStart[i], i); + } + + // Handle warp divergence separately + if (i < num) { + heap.addThreadQ(heapDistanceStart[i], i); + } + + // Merge all final results + heap.reduce(); + + for (int i = threadIdx.x; i < k; i += blockDim.x) { + outDistances[queryId][i] = smemK[i]; + + // `v` is the index in `heapIndices` + // We need to translate this into an original user index. The + // reason why we don't maintain intermediate results in terms of + // user indices is to substantially reduce temporary memory + // requirements and global memory write traffic for the list + // scanning. + // This code is highly divergent, but it's probably ok, since this + // is the very last step and it is happening a small number of + // times (#queries x k). + int v = smemV[i]; + long index = -1; + + if (v != -1) { + // `offset` is the offset of the intermediate result, as + // calculated by the original scan. + int offset = heapIndices[queryId][v]; + + index = getListIndex(queryId, + offset, + listIndices, + prefixSumOffsets, + topQueryToCentroid, + opt); + } + + outIndices[queryId][i] = index; + } +} + +void +runPass2SelectLists(Tensor& heapDistances, + Tensor& heapIndices, + thrust::device_vector& listIndices, + IndicesOptions indicesOptions, + Tensor& prefixSumOffsets, + Tensor& topQueryToCentroid, + int k, + bool chooseLargest, + Tensor& outDistances, + Tensor& outIndices, + cudaStream_t stream) { + auto grid = dim3(topQueryToCentroid.getSize(0)); + +#define RUN_PASS(BLOCK, NUM_WARP_Q, NUM_THREAD_Q, DIR) \ + do { \ + pass2SelectLists \ + <<>>(heapDistances, \ + heapIndices, \ + listIndices.data().get(), \ + prefixSumOffsets, \ + topQueryToCentroid, \ + k, \ + indicesOptions, \ + outDistances, \ + outIndices); \ + CUDA_TEST_ERROR(); \ + return; /* success */ \ + } while (0) + +#if GPU_MAX_SELECTION_K >= 2048 + + // block size 128 for k <= 1024, 64 for k = 2048 +#define RUN_PASS_DIR(DIR) \ + do { \ + if (k == 1) { \ + RUN_PASS(128, 1, 1, DIR); \ + } else if (k <= 32) { \ + RUN_PASS(128, 32, 2, DIR); \ + } else if (k <= 64) { \ + RUN_PASS(128, 64, 3, DIR); \ + } else if (k <= 128) { \ + RUN_PASS(128, 128, 3, DIR); \ + } else if (k <= 256) { \ + RUN_PASS(128, 256, 4, DIR); \ + } else if (k <= 512) { \ + RUN_PASS(128, 512, 8, DIR); \ + } else if (k <= 1024) { \ + RUN_PASS(128, 1024, 8, DIR); \ + } else if (k <= 2048) { \ + RUN_PASS(64, 2048, 8, DIR); \ + } \ + } while (0) + +#else + +#define RUN_PASS_DIR(DIR) \ + do { \ + if (k == 1) { \ + RUN_PASS(128, 1, 1, DIR); \ + } else if (k <= 32) { \ + RUN_PASS(128, 32, 2, DIR); \ + } else if (k <= 64) { \ + RUN_PASS(128, 64, 3, DIR); \ + } else if (k <= 128) { \ + RUN_PASS(128, 128, 3, DIR); \ + } else if (k <= 256) { \ + RUN_PASS(128, 256, 4, DIR); \ + } else if (k <= 512) { \ + RUN_PASS(128, 512, 8, DIR); \ + } else if (k <= 1024) { \ + RUN_PASS(128, 1024, 8, DIR); \ + } \ + } while (0) + +#endif // GPU_MAX_SELECTION_K + + if (chooseLargest) { + RUN_PASS_DIR(true); + } else { + RUN_PASS_DIR(false); + } + + // unimplemented / too many resources + FAISS_ASSERT_FMT(false, "unimplemented k value (%d)", k); + +#undef RUN_PASS_DIR +#undef RUN_PASS +} + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/impl/L2Norm.cu b/core/src/index/thirdparty/faiss/gpu/impl/L2Norm.cu new file mode 100644 index 0000000000..bdf812524e --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/impl/L2Norm.cu @@ -0,0 +1,331 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace faiss { namespace gpu { + +// Input: (batch x dim) +// Output: (batch norm) +// Done under the presumption that the dimension size is not too large +// (<10k or so), since there wouldn't be enough parallelism applying a +// single block to the problem. Also that each vector is large enough +// (>64), since a single block works on multiple rows' norms at the +// same time. +// T: the type we are doing the math in (e.g., float, half) +// TVec: the potentially vectorized type we are loading in (e.g., +// float4, half2) +template +__global__ void +l2NormRowMajor(Tensor input, + Tensor output) { + extern __shared__ char smemByte[]; // #warps * RowTileSize elements + float* smem = (float*) smemByte; + + IndexType numWarps = utils::divUp(blockDim.x, kWarpSize); + IndexType laneId = getLaneId(); + IndexType warpId = threadIdx.x / kWarpSize; + + bool lastRowTile = (blockIdx.x == (gridDim.x - 1)); + IndexType rowStart = RowTileSize * blockIdx.x; + // accumulate in f32 + float rowNorm[RowTileSize]; + + if (lastRowTile) { + // We are handling the very end of the input matrix rows + for (IndexType row = 0; row < input.getSize(0) - rowStart; ++row) { + if (NormLoop) { + rowNorm[0] = 0; + + for (IndexType col = threadIdx.x; + col < input.getSize(1); col += blockDim.x) { + TVec val = input[rowStart + row][col]; + val = Math::mul(val, val); + rowNorm[0] = rowNorm[0] + Math::reduceAdd(val); + } + } else { + TVec val = input[rowStart + row][threadIdx.x]; + val = Math::mul(val, val); + rowNorm[0] = Math::reduceAdd(val); + } + + rowNorm[0] = warpReduceAllSum(rowNorm[0]); + if (laneId == 0) { + smem[row * numWarps + warpId] = rowNorm[0]; + } + } + } else { + // We are guaranteed that all RowTileSize rows are available in + // [rowStart, rowStart + RowTileSize) + + if (NormLoop) { + // A single block of threads is not big enough to span each + // vector + TVec tmp[RowTileSize]; + +#pragma unroll + for (int row = 0; row < RowTileSize; ++row) { + rowNorm[row] = 0; + } + + for (IndexType col = threadIdx.x; + col < input.getSize(1); col += blockDim.x) { +#pragma unroll + for (int row = 0; row < RowTileSize; ++row) { + tmp[row] = input[rowStart + row][col]; + } + +#pragma unroll + for (int row = 0; row < RowTileSize; ++row) { + tmp[row] = Math::mul(tmp[row], tmp[row]); + } + +#pragma unroll + for (int row = 0; row < RowTileSize; ++row) { + rowNorm[row] = rowNorm[row] + + Math::reduceAdd(tmp[row]); + } + } + } else { + TVec tmp[RowTileSize]; + + // A block of threads is the exact size of the vector +#pragma unroll + for (int row = 0; row < RowTileSize; ++row) { + tmp[row] = input[rowStart + row][threadIdx.x]; + } + +#pragma unroll + for (int row = 0; row < RowTileSize; ++row) { + tmp[row] = Math::mul(tmp[row], tmp[row]); + } + +#pragma unroll + for (int row = 0; row < RowTileSize; ++row) { + rowNorm[row] = Math::reduceAdd(tmp[row]); + } + } + + // Sum up all parts in each warp +#pragma unroll + for (int row = 0; row < RowTileSize; ++row) { + rowNorm[row] = warpReduceAllSum(rowNorm[row]); + } + + if (laneId == 0) { +#pragma unroll + for (int row = 0; row < RowTileSize; ++row) { + smem[row * numWarps + warpId] = rowNorm[row]; + } + } + } + + __syncthreads(); + + // Sum across warps + if (warpId == 0) { +#pragma unroll + for (int row = 0; row < RowTileSize; ++row) { + rowNorm[row] = laneId < numWarps ? smem[row * numWarps + laneId] : 0; + } + +#pragma unroll + for (int row = 0; row < RowTileSize; ++row) { + rowNorm[row] = warpReduceAllSum(rowNorm[row]); + } + + // Write out answer + if (laneId == 0) { +#pragma unroll + for (int row = 0; row < RowTileSize; ++row) { + int outCol = rowStart + row; + + if (lastRowTile) { + if (outCol < output.getSize(0)) { + output[outCol] = + NormSquared ? ConvertTo::to(rowNorm[row]) : + sqrtf(ConvertTo::to(rowNorm[row])); + } + } else { + output[outCol] = + NormSquared ? ConvertTo::to(rowNorm[row]) : + sqrtf(ConvertTo::to(rowNorm[row])); + } + } + } + } +} + +// Input: (dim x batch) +// Output: (batch norm) +// Handles the case where `input` is column major. A single thread calculates +// the norm of each vector instead of a block-wide reduction. +template +__global__ void +l2NormColMajor(Tensor input, + Tensor output) { + // grid-stride loop to handle all batch elements + for (IndexType batch = blockIdx.x * blockDim.x + threadIdx.x; + batch < input.getSize(1); + batch += gridDim.x * blockDim.x) { + float sum = 0; + + // This is still a coalesced load from the memory + for (IndexType dim = 0; dim < input.getSize(0); ++dim) { + // Just do the math in float32, even if the input is float16 + float v = ConvertTo::to(input[dim][batch]); + sum += v * v; + } + + if (!NormSquared) { + sum = sqrtf(sum); + } + + output[batch] = ConvertTo::to(sum); + } +} + +template +void runL2Norm(Tensor& input, + bool inputRowMajor, + Tensor& output, + bool normSquared, + cudaStream_t stream) { + IndexType maxThreads = (IndexType) getMaxThreadsCurrentDevice(); + constexpr int rowTileSize = 8; + +#define RUN_L2_ROW_MAJOR(TYPE_T, TYPE_TVEC, INPUT) \ + do { \ + if (normLoop) { \ + if (normSquared) { \ + l2NormRowMajor \ + <<>>(INPUT, output); \ + } else { \ + l2NormRowMajor \ + <<>>(INPUT, output); \ + } \ + } else { \ + if (normSquared) { \ + l2NormRowMajor \ + <<>>(INPUT, output); \ + } else { \ + l2NormRowMajor \ + <<>>(INPUT, output); \ + } \ + } \ + } while (0) + + if (inputRowMajor) { + // + // Row-major kernel + /// + + if (input.template canCastResize()) { + // Can load using the vectorized type + auto inputV = input.template castResize(); + + auto dim = inputV.getSize(1); + bool normLoop = dim > maxThreads; + auto numThreads = min(dim, maxThreads); + + auto grid = dim3(utils::divUp(inputV.getSize(0), rowTileSize)); + auto block = dim3(numThreads); + + auto smem = sizeof(float) * rowTileSize * utils::divUp(numThreads, kWarpSize); + + RUN_L2_ROW_MAJOR(T, TVec, inputV); + } else { + // Can't load using the vectorized type + + auto dim = input.getSize(1); + bool normLoop = dim > maxThreads; + auto numThreads = min(dim, maxThreads); + + auto grid = dim3(utils::divUp(input.getSize(0), rowTileSize)); + auto block = dim3(numThreads); + + auto smem = sizeof(float) * rowTileSize * utils::divUp(numThreads, kWarpSize); + + RUN_L2_ROW_MAJOR(T, T, input); + } + } else { + // + // Column-major kernel + // + + // Just use a fixed-sized block, since the kernel threads are fully + // independent + auto block = 128; + + // Cap the grid size at 2^16 since there is a grid-stride loop to handle + // processing everything + auto grid = (int) + std::min(utils::divUp(input.getSize(1), (IndexType) block), + (IndexType) 65536); + + if (normSquared) { + l2NormColMajor<<>>( + input, output); + } else { + l2NormColMajor<<>>( + input, output); + } + } + +#undef RUN_L2 + + CUDA_TEST_ERROR(); +} + +void runL2Norm(Tensor& input, + bool inputRowMajor, + Tensor& output, + bool normSquared, + cudaStream_t stream) { + if (input.canUseIndexType()) { + runL2Norm( + input, inputRowMajor, output, normSquared, stream); + } else { + auto inputCast = input.castIndexType(); + auto outputCast = output.castIndexType(); + + runL2Norm( + inputCast, inputRowMajor, outputCast, normSquared, stream); + } +} + +#ifdef FAISS_USE_FLOAT16 +void runL2Norm(Tensor& input, + bool inputRowMajor, + Tensor& output, + bool normSquared, + cudaStream_t stream) { + if (input.canUseIndexType()) { + runL2Norm( + input, inputRowMajor, output, normSquared, stream); + } else { + auto inputCast = input.castIndexType(); + auto outputCast = output.castIndexType(); + + runL2Norm( + inputCast, inputRowMajor, outputCast, normSquared, stream); + } +} +#endif + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/impl/L2Norm.cuh b/core/src/index/thirdparty/faiss/gpu/impl/L2Norm.cuh new file mode 100644 index 0000000000..6df3dcea58 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/impl/L2Norm.cuh @@ -0,0 +1,29 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#pragma once + +#include + +namespace faiss { namespace gpu { + +void runL2Norm(Tensor& input, + bool inputRowMajor, + Tensor& output, + bool normSquared, + cudaStream_t stream); + +#ifdef FAISS_USE_FLOAT16 +void runL2Norm(Tensor& input, + bool inputRowMajor, + Tensor& output, + bool normSquared, + cudaStream_t stream); +#endif + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/impl/L2Select.cu b/core/src/index/thirdparty/faiss/gpu/impl/L2Select.cu new file mode 100644 index 0000000000..9ea70ec651 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/impl/L2Select.cu @@ -0,0 +1,264 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace faiss { namespace gpu { + +// L2 + select kernel for k == 1, implements re-use of ||c||^2 +template +__global__ void l2SelectMin1(Tensor productDistances, + Tensor centroidDistances, + Tensor bitset, + Tensor outDistances, + Tensor outIndices) { + // Each block handles kRowsPerBlock rows of the distances (results) + Pair threadMin[kRowsPerBlock]; + __shared__ Pair blockMin[kRowsPerBlock * (kBlockSize / kWarpSize)]; + + T distance[kRowsPerBlock]; + +#pragma unroll + for (int i = 0; i < kRowsPerBlock; ++i) { + threadMin[i].k = Limits::getMax(); + threadMin[i].v = -1; + } + + // blockIdx.x: which chunk of rows we are responsible for updating + int rowStart = blockIdx.x * kRowsPerBlock; + + // FIXME: if we have exact multiples, don't need this + bool endRow = (blockIdx.x == gridDim.x - 1); + + bool bitsetEmpty = (bitset.getSize(0) == 0); + + if (endRow) { + if (productDistances.getSize(0) % kRowsPerBlock == 0) { + endRow = false; + } + } + + if (endRow) { + for (int row = rowStart; row < productDistances.getSize(0); ++row) { + for (int col = threadIdx.x; col < productDistances.getSize(1); + col += blockDim.x) { + if (bitsetEmpty || (!(bitset[col >> 3] & (0x1 << (col & 0x7))))) { + distance[0] = Math::add(centroidDistances[col], + productDistances[row][col]); + } else { + distance[0] = (T)(1.0 / 0.0); + } + + if (Math::lt(distance[0], threadMin[0].k)) { + threadMin[0].k = distance[0]; + threadMin[0].v = col; + } + } + + // Reduce within the block + threadMin[0] = + blockReduceAll, Min>, false, false>( + threadMin[0], Min>(), blockMin); + + if (threadIdx.x == 0) { + outDistances[row][0] = threadMin[0].k; + outIndices[row][0] = threadMin[0].v; + } + + // so we can use the shared memory again + __syncthreads(); + + threadMin[0].k = Limits::getMax(); + threadMin[0].v = -1; + } + } else { + for (int col = threadIdx.x; col < productDistances.getSize(1); + col += blockDim.x) { + T centroidDistance = centroidDistances[col]; + +#pragma unroll + for (int row = 0; row < kRowsPerBlock; ++row) { + distance[row] = productDistances[rowStart + row][col]; + } + +#pragma unroll + for (int row = 0; row < kRowsPerBlock; ++row) { + distance[row] = Math::add(distance[row], centroidDistance); + } + +#pragma unroll + for (int row = 0; row < kRowsPerBlock; ++row) { + if (Math::lt(distance[row], threadMin[row].k)) { + threadMin[row].k = distance[row]; + threadMin[row].v = col; + } + } + } + + // Reduce within the block + blockReduceAll, + Min >, + false, + false>(threadMin, + Min >(), + blockMin); + + if (threadIdx.x == 0) { +#pragma unroll + for (int row = 0; row < kRowsPerBlock; ++row) { + outDistances[rowStart + row][0] = threadMin[row].k; + outIndices[rowStart + row][0] = threadMin[row].v; + } + } + } +} + +// With bitset included +// L2 + select kernel for k > 1, no re-use of ||c||^2 +template +__global__ void l2SelectMinK(Tensor productDistances, + Tensor centroidDistances, + Tensor bitset, + Tensor outDistances, + Tensor outIndices, + int k, T initK) { + // Each block handles a single row of the distances (results) + constexpr int kNumWarps = ThreadsPerBlock / kWarpSize; + + __shared__ T smemK[kNumWarps * NumWarpQ]; + __shared__ int smemV[kNumWarps * NumWarpQ]; + + BlockSelect, + NumWarpQ, NumThreadQ, ThreadsPerBlock> + heap(initK, -1, smemK, smemV, k); + + int row = blockIdx.x; + + // Whole warps must participate in the selection + int limit = utils::roundDown(productDistances.getSize(1), kWarpSize); + int i = threadIdx.x; + + bool bitsetEmpty = (bitset.getSize(0) == 0); + T v; + + for (; i < limit; i += blockDim.x) { + if (bitsetEmpty || (!(bitset[i >> 3] & (0x1 << (i & 0x7))))) { + v = Math::add(centroidDistances[i], + productDistances[row][i]); + heap.addThreadQ(v, i); + } + heap.checkThreadQ(); + } + + if (i < productDistances.getSize(1)) { + if (bitsetEmpty || (!(bitset[i >> 3] & (0x1 << (i & 0x7))))) { + v = Math::add(centroidDistances[i], + productDistances[row][i]); + heap.addThreadQ(v, i); + } + } + + heap.reduce(); + for (int i = threadIdx.x; i < k; i += blockDim.x) { + outDistances[row][i] = smemK[i]; + outIndices[row][i] = smemV[i]; + } +} + +template +void runL2SelectMin(Tensor& productDistances, + Tensor& centroidDistances, + Tensor& bitset, + Tensor& outDistances, + Tensor& outIndices, + int k, + cudaStream_t stream) { + FAISS_ASSERT(productDistances.getSize(0) == outDistances.getSize(0)); + FAISS_ASSERT(productDistances.getSize(0) == outIndices.getSize(0)); + FAISS_ASSERT(centroidDistances.getSize(0) == productDistances.getSize(1)); + FAISS_ASSERT(outDistances.getSize(1) == k); + FAISS_ASSERT(outIndices.getSize(1) == k); + FAISS_ASSERT(k <= GPU_MAX_SELECTION_K); + + if (k == 1) { + constexpr int kThreadsPerBlock = 256; + constexpr int kRowsPerBlock = 8; + + auto block = dim3(kThreadsPerBlock); + auto grid = dim3(utils::divUp(outDistances.getSize(0), kRowsPerBlock)); + + l2SelectMin1 + <<>>(productDistances, centroidDistances, bitset, + outDistances, outIndices); + } else { + auto grid = dim3(outDistances.getSize(0)); + +#define RUN_L2_SELECT(BLOCK, NUM_WARP_Q, NUM_THREAD_Q) \ + do { \ + l2SelectMinK \ + <<>>(productDistances, centroidDistances, bitset, \ + outDistances, outIndices, \ + k, Limits::getMax()); \ + } while (0) + + // block size 128 for everything <= 1024 + if (k <= 32) { + RUN_L2_SELECT(128, 32, 2); + } else if (k <= 64) { + RUN_L2_SELECT(128, 64, 3); + } else if (k <= 128) { + RUN_L2_SELECT(128, 128, 3); + } else if (k <= 256) { + RUN_L2_SELECT(128, 256, 4); + } else if (k <= 512) { + RUN_L2_SELECT(128, 512, 8); + } else if (k <= 1024) { + RUN_L2_SELECT(128, 1024, 8); + +#if GPU_MAX_SELECTION_K >= 2048 + } else if (k <= 2048) { + // smaller block for less shared memory + RUN_L2_SELECT(64, 2048, 8); +#endif + + } else { + FAISS_ASSERT(false); + } + } + + CUDA_TEST_ERROR(); +} + +void runL2SelectMin(Tensor& productDistances, + Tensor& centroidDistances, + Tensor& bitset, + Tensor& outDistances, + Tensor& outIndices, + int k, + cudaStream_t stream) { + runL2SelectMin(productDistances, + centroidDistances, + bitset, + outDistances, + outIndices, + k, + stream); +} + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/impl/L2Select.cuh b/core/src/index/thirdparty/faiss/gpu/impl/L2Select.cuh new file mode 100644 index 0000000000..b29552d786 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/impl/L2Select.cuh @@ -0,0 +1,23 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#pragma once + +#include + +namespace faiss { namespace gpu { + +void runL2SelectMin(Tensor& productDistances, + Tensor& centroidDistances, + Tensor& bitset, + Tensor& outDistances, + Tensor& outIndices, + int k, + cudaStream_t stream); + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/impl/Metrics.cuh b/core/src/index/thirdparty/faiss/gpu/impl/Metrics.cuh new file mode 100644 index 0000000000..5b9feac3ee --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/impl/Metrics.cuh @@ -0,0 +1,52 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +namespace faiss { namespace gpu { + +/// List of supported metrics +inline bool isMetricSupported(MetricType mt) { + switch (mt) { + case MetricType::METRIC_INNER_PRODUCT: + case MetricType::METRIC_L2: + return true; + default: + return false; + } +} + +/// Sort direction per each metric +inline bool metricToSortDirection(MetricType mt) { + switch (mt) { + case MetricType::METRIC_INNER_PRODUCT: + // highest + return true; + case MetricType::METRIC_L2: + // lowest + return false; + default: + // unhandled metric + FAISS_ASSERT(false); + return false; + } +} + +struct L2Metric { + static inline __device__ float distance(float a, float b) { + float d = a - b; + return d * d; + } +}; + +struct IPMetric { + static inline __device__ float distance(float a, float b) { + return a * b; + } +}; + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/impl/PQCodeDistances-inl.cuh b/core/src/index/thirdparty/faiss/gpu/impl/PQCodeDistances-inl.cuh new file mode 100644 index 0000000000..520a8bcafb --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/impl/PQCodeDistances-inl.cuh @@ -0,0 +1,574 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace faiss { namespace gpu { + +// Kernel responsible for calculating distance from residual vector to +// each product quantizer code centroid +template +__global__ void +__launch_bounds__(288, 4) +pqCodeDistances(Tensor queries, + int queriesPerBlock, + Tensor coarseCentroids, + Tensor pqCentroids, + Tensor topQueryToCentroid, + // (query id)(coarse)(subquantizer)(code) -> dist + Tensor outCodeDistances) { + const auto numSubQuantizers = pqCentroids.getSize(0); + const auto dimsPerSubQuantizer = pqCentroids.getSize(1); + assert(DimsPerSubQuantizer == dimsPerSubQuantizer); + const auto codesPerSubQuantizer = pqCentroids.getSize(2); + + bool isLoadingThread = threadIdx.x >= codesPerSubQuantizer; + int loadingThreadId = threadIdx.x - codesPerSubQuantizer; + + extern __shared__ float smem[]; + + // Each thread calculates a single code + float subQuantizerData[DimsPerSubQuantizer]; + + auto code = threadIdx.x; + auto subQuantizer = blockIdx.y; + + // Each thread will load the pq centroid data for the code that it + // is processing +#pragma unroll + for (int i = 0; i < DimsPerSubQuantizer; ++i) { + subQuantizerData[i] = pqCentroids[subQuantizer][i][code].ldg(); + } + + // Where we store our query vector + float* smemQuery = smem; + + // Where we store our residual vector; this is double buffered so we + // can be loading the next one while processing the current one + float* smemResidual1 = &smemQuery[DimsPerSubQuantizer]; + float* smemResidual2 = &smemResidual1[DimsPerSubQuantizer]; + + // Where we pre-load the coarse centroid IDs + int* coarseIds = (int*) &smemResidual2[DimsPerSubQuantizer]; + + // Each thread is calculating the distance for a single code, + // performing the reductions locally + + // Handle multiple queries per block + auto startQueryId = blockIdx.x * queriesPerBlock; + auto numQueries = queries.getSize(0) - startQueryId; + if (numQueries > queriesPerBlock) { + numQueries = queriesPerBlock; + } + + for (int query = 0; query < numQueries; ++query) { + auto queryId = startQueryId + query; + + auto querySubQuantizer = + queries[queryId][subQuantizer * DimsPerSubQuantizer].data(); + + // Load current query vector + for (int i = threadIdx.x; i < DimsPerSubQuantizer; i += blockDim.x) { + smemQuery[i] = querySubQuantizer[i]; + } + + // Load list of coarse centroids found + for (int i = threadIdx.x; + i < topQueryToCentroid.getSize(1); i += blockDim.x) { + coarseIds[i] = topQueryToCentroid[queryId][i]; + } + + // We need coarseIds below + // FIXME: investigate loading separately, so we don't need this + __syncthreads(); + + // Preload first buffer of residual data + if (isLoadingThread) { + for (int i = loadingThreadId; + i < DimsPerSubQuantizer; + i += blockDim.x - codesPerSubQuantizer) { + auto coarseId = coarseIds[0]; + // In case NaNs were in the original query data + coarseId = coarseId == -1 ? 0 : coarseId; + auto coarseCentroidSubQuantizer = + coarseCentroids[coarseId][subQuantizer * dimsPerSubQuantizer].data(); + + if (L2Distance) { + smemResidual1[i] = smemQuery[i] - + ConvertTo::to(coarseCentroidSubQuantizer[i]); + } else { + smemResidual1[i] = + ConvertTo::to(coarseCentroidSubQuantizer[i]); + } + } + } + + // The block walks the list for a single query + for (int coarse = 0; coarse < topQueryToCentroid.getSize(1); ++coarse) { + // Wait for smemResidual1 to be loaded + __syncthreads(); + + if (isLoadingThread) { + // Preload second buffer of residual data + for (int i = loadingThreadId; + i < DimsPerSubQuantizer; + i += blockDim.x - codesPerSubQuantizer) { + // FIXME: try always making this centroid id 0 so we can + // terminate + if (coarse != (topQueryToCentroid.getSize(1) - 1)) { + auto coarseId = coarseIds[coarse + 1]; + // In case NaNs were in the original query data + coarseId = coarseId == -1 ? 0 : coarseId; + + auto coarseCentroidSubQuantizer = + coarseCentroids[coarseId] + [subQuantizer * dimsPerSubQuantizer].data(); + + if (L2Distance) { + smemResidual2[i] = smemQuery[i] - + ConvertTo::to(coarseCentroidSubQuantizer[i]); + } else { + smemResidual2[i] = + ConvertTo::to(coarseCentroidSubQuantizer[i]); + } + } + } + } else { + // These are the processing threads + float dist = 0.0f; + + constexpr int kUnroll = 4; + constexpr int kRemainder = DimsPerSubQuantizer % kUnroll; + constexpr int kRemainderBase = DimsPerSubQuantizer - kRemainder; + float vals[kUnroll]; + + // Calculate residual - pqCentroid for each dim that we're + // processing + + // Unrolled loop + if (L2Distance) { +#pragma unroll + for (int i = 0; i < DimsPerSubQuantizer / kUnroll; ++i) { +#pragma unroll + for (int j = 0; j < kUnroll; ++j) { + vals[j] = smemResidual1[i * kUnroll + j]; + } + +#pragma unroll + for (int j = 0; j < kUnroll; ++j) { + vals[j] -= subQuantizerData[i * kUnroll + j]; + } + +#pragma unroll + for (int j = 0; j < kUnroll; ++j) { + vals[j] *= vals[j]; + } + +#pragma unroll + for (int j = 0; j < kUnroll; ++j) { + dist += vals[j]; + } + } + } else { + // Inner product: query slice against the reconstructed sub-quantizer + // for this coarse cell (query o (centroid + subQCentroid)) +#pragma unroll + for (int i = 0; i < DimsPerSubQuantizer / kUnroll; ++i) { +#pragma unroll + for (int j = 0; j < kUnroll; ++j) { + vals[j] = smemResidual1[i * kUnroll + j]; + } + +#pragma unroll + for (int j = 0; j < kUnroll; ++j) { + vals[j] += subQuantizerData[i * kUnroll + j]; + } + +#pragma unroll + for (int j = 0; j < kUnroll; ++j) { + vals[j] *= smemQuery[i * kUnroll + j]; + } + +#pragma unroll + for (int j = 0; j < kUnroll; ++j) { + dist += vals[j]; + } + } + } + + // Remainder loop + if (L2Distance) { +#pragma unroll + for (int j = 0; j < kRemainder; ++j) { + vals[j] = smemResidual1[kRemainderBase + j]; + } + +#pragma unroll + for (int j = 0; j < kRemainder; ++j) { + vals[j] -= subQuantizerData[kRemainderBase + j]; + } + +#pragma unroll + for (int j = 0; j < kRemainder; ++j) { + vals[j] *= vals[j]; + } + } else { + // Inner product + // Inner product: query slice against the reconstructed sub-quantizer + // for this coarse cell (query o (centroid + subQCentroid)) +#pragma unroll + for (int j = 0; j < kRemainder; ++j) { + vals[j] = smemResidual1[kRemainderBase + j]; + } + +#pragma unroll + for (int j = 0; j < kRemainder; ++j) { + vals[j] += subQuantizerData[kRemainderBase + j]; + } + +#pragma unroll + for (int j = 0; j < kRemainder; ++j) { + vals[j] *= smemQuery[kRemainderBase + j]; + } + } + +#pragma unroll + for (int j = 0; j < kRemainder; ++j) { + dist += vals[j]; + } + + // We have the distance for our code; write it out + outCodeDistances[queryId][coarse][subQuantizer][code] = + ConvertTo::to(dist); + } // !isLoadingThread + + // Swap residual buffers + float* tmp = smemResidual1; + smemResidual1 = smemResidual2; + smemResidual2 = tmp; + } + } +} + +template +__global__ void +residualVector(Tensor queries, + Tensor coarseCentroids, + Tensor topQueryToCentroid, + int numSubDim, + // output is transposed: + // (sub q)(query id)(centroid id)(sub dim) + Tensor residual) { + // block x is query id + // block y is centroid id + // thread x is dim + auto queryId = blockIdx.x; + auto centroidId = blockIdx.y; + + int realCentroidId = topQueryToCentroid[queryId][centroidId]; + + for (int dim = threadIdx.x; dim < queries.getSize(1); dim += blockDim.x) { + float q = queries[queryId][dim]; + float c = ConvertTo::to(coarseCentroids[realCentroidId][dim]); + + residual[dim / numSubDim][queryId][centroidId][dim % numSubDim] = q - c; + } +} + +template +void +runResidualVector(Tensor& pqCentroids, + Tensor& queries, + Tensor& coarseCentroids, + Tensor& topQueryToCentroid, + Tensor& residual, + cudaStream_t stream) { + auto grid = + dim3(topQueryToCentroid.getSize(0), topQueryToCentroid.getSize(1)); + auto block = dim3(std::min(queries.getSize(1), getMaxThreadsCurrentDevice())); + + residualVector<<>>( + queries, coarseCentroids, topQueryToCentroid, pqCentroids.getSize(1), + residual); + + CUDA_TEST_ERROR(); +} + +template +void +runPQCodeDistancesMM(Tensor& pqCentroids, + Tensor& queries, + Tensor& coarseCentroids, + Tensor& topQueryToCentroid, + NoTypeTensor<4, true>& outCodeDistances, + bool useFloat16Lookup, + DeviceMemory& mem, + cublasHandle_t handle, + cudaStream_t stream) { + // Calculate (q - c) residual vector + // (sub q)(query id)(centroid id)(sub dim) + DeviceTensor residual( + mem, + {pqCentroids.getSize(0), + topQueryToCentroid.getSize(0), + topQueryToCentroid.getSize(1), + pqCentroids.getSize(1)}, + stream); + + runResidualVector(pqCentroids, queries, + coarseCentroids, topQueryToCentroid, + residual, stream); + + // Calculate ||q - c||^2 + DeviceTensor residualNorms( + mem, + {pqCentroids.getSize(0) * + topQueryToCentroid.getSize(0) * + topQueryToCentroid.getSize(1)}, + stream); + + auto residualView2 = residual.view<2>( + {pqCentroids.getSize(0) * + topQueryToCentroid.getSize(0) * + topQueryToCentroid.getSize(1), + pqCentroids.getSize(1)}); + + runL2Norm(residualView2, true, residualNorms, true, stream); + + // Perform a batch MM: + // (sub q) x {(q * c)(sub dim) x (sub dim)(code)} => + // (sub q) x {(q * c)(code)} + auto residualView3 = residual.view<3>( + {pqCentroids.getSize(0), + topQueryToCentroid.getSize(0) * topQueryToCentroid.getSize(1), + pqCentroids.getSize(1)}); + + DeviceTensor residualDistance( + mem, + {pqCentroids.getSize(0), + topQueryToCentroid.getSize(0) * topQueryToCentroid.getSize(1), + pqCentroids.getSize(2)}, + stream); + + runIteratedMatrixMult(residualDistance, false, + residualView3, false, + pqCentroids, false, + -2.0f, 0.0f, + handle, + stream); + + // Sum ||q - c||^2 along rows + auto residualDistanceView2 = residualDistance.view<2>( + {pqCentroids.getSize(0) * + topQueryToCentroid.getSize(0) * + topQueryToCentroid.getSize(1), + pqCentroids.getSize(2)}); + + runSumAlongRows(residualNorms, residualDistanceView2, false, stream); + + Tensor outCodeDistancesF; + DeviceTensor outCodeDistancesFloatMem; + + if (useFloat16Lookup) { + outCodeDistancesFloatMem = DeviceTensor( + mem, {outCodeDistances.getSize(0), + outCodeDistances.getSize(1), + outCodeDistances.getSize(2), + outCodeDistances.getSize(3)}, + stream); + + outCodeDistancesF = outCodeDistancesFloatMem; + } else { + outCodeDistancesF = outCodeDistances.toTensor(); + } + + // Transpose -2(sub q)(q * c)(code) to -2(q * c)(sub q)(code) (which + // is where we build our output distances) + auto outCodeDistancesView = outCodeDistancesF.view<3>( + {topQueryToCentroid.getSize(0) * topQueryToCentroid.getSize(1), + outCodeDistances.getSize(2), + outCodeDistances.getSize(3)}); + + runTransposeAny(residualDistance, 0, 1, outCodeDistancesView, stream); + + // Calculate code norms per each sub-dim + // (sub q)(sub dim)(code) is pqCentroids + // transpose to (sub q)(code)(sub dim) + DeviceTensor pqCentroidsTranspose( + mem, + {pqCentroids.getSize(0), pqCentroids.getSize(2), pqCentroids.getSize(1)}, + stream); + + runTransposeAny(pqCentroids, 1, 2, pqCentroidsTranspose, stream); + + auto pqCentroidsTransposeView = pqCentroidsTranspose.view<2>( + {pqCentroids.getSize(0) * pqCentroids.getSize(2), + pqCentroids.getSize(1)}); + + DeviceTensor pqCentroidsNorm( + mem, + {pqCentroids.getSize(0) * pqCentroids.getSize(2)}, + stream); + + runL2Norm(pqCentroidsTransposeView, true, pqCentroidsNorm, true, stream); + + // View output as (q * c)(sub q * code), and add centroid norm to + // each row + auto outDistancesCodeViewCols = outCodeDistancesView.view<2>( + {topQueryToCentroid.getSize(0) * topQueryToCentroid.getSize(1), + outCodeDistances.getSize(2) * outCodeDistances.getSize(3)}); + + runSumAlongColumns(pqCentroidsNorm, outDistancesCodeViewCols, stream); + +#ifdef FAISS_USE_FLOAT16 + if (useFloat16Lookup) { + // Need to convert back + auto outCodeDistancesH = outCodeDistances.toTensor(); + convertTensor(stream, + outCodeDistancesF, + outCodeDistancesH); + } +#endif +} + +template +void +runPQCodeDistances(Tensor& pqCentroids, + Tensor& queries, + Tensor& coarseCentroids, + Tensor& topQueryToCentroid, + NoTypeTensor<4, true>& outCodeDistances, + bool l2Distance, + bool useFloat16Lookup, + cudaStream_t stream) { + const auto numSubQuantizers = pqCentroids.getSize(0); + const auto dimsPerSubQuantizer = pqCentroids.getSize(1); + const auto codesPerSubQuantizer = pqCentroids.getSize(2); + + // FIXME: tune + // Reuse of pq centroid data is based on both # of queries * nprobe, + // and we should really be tiling in both dimensions + constexpr int kQueriesPerBlock = 8; + + auto grid = dim3(utils::divUp(queries.getSize(0), kQueriesPerBlock), + numSubQuantizers); + + // Reserve one block of threads for double buffering + // FIXME: probably impractical for large # of dims? + auto loadingThreads = utils::roundUp(dimsPerSubQuantizer, kWarpSize); + auto block = dim3(codesPerSubQuantizer + loadingThreads); + + auto smem = (3 * dimsPerSubQuantizer) * sizeof(float) + + topQueryToCentroid.getSize(1) * sizeof(int); + +#ifdef FAISS_USE_FLOAT16 +#define RUN_CODE(DIMS, L2) \ + do { \ + if (useFloat16Lookup) { \ + auto outCodeDistancesT = outCodeDistances.toTensor(); \ + \ + pqCodeDistances<<>>( \ + queries, kQueriesPerBlock, \ + coarseCentroids, pqCentroids, \ + topQueryToCentroid, outCodeDistancesT); \ + } else { \ + auto outCodeDistancesT = outCodeDistances.toTensor(); \ + \ + pqCodeDistances<<>>( \ + queries, kQueriesPerBlock, \ + coarseCentroids, pqCentroids, \ + topQueryToCentroid, outCodeDistancesT); \ + } \ + } while (0) +#else +#define RUN_CODE(DIMS, L2) \ + do { \ + auto outCodeDistancesT = outCodeDistances.toTensor(); \ + pqCodeDistances<<>>( \ + queries, kQueriesPerBlock, \ + coarseCentroids, pqCentroids, \ + topQueryToCentroid, outCodeDistancesT); \ + } while (0) +#endif + +#define CODE_L2(DIMS) \ + do { \ + if (l2Distance) { \ + RUN_CODE(DIMS, true); \ + } else { \ + RUN_CODE(DIMS, false); \ + } \ + } while (0) + + switch (dimsPerSubQuantizer) { + case 1: + CODE_L2(1); + break; + case 2: + CODE_L2(2); + break; + case 3: + CODE_L2(3); + break; + case 4: + CODE_L2(4); + break; + case 6: + CODE_L2(6); + break; + case 8: + CODE_L2(8); + break; + case 10: + CODE_L2(10); + break; + case 12: + CODE_L2(12); + break; + case 16: + CODE_L2(16); + break; + case 20: + CODE_L2(20); + break; + case 24: + CODE_L2(24); + break; + case 28: + CODE_L2(28); + break; + case 32: + CODE_L2(32); + break; + // FIXME: larger sizes require too many registers - we need the + // MM implementation working + default: + FAISS_THROW_MSG("Too many dimensions (>32) per subquantizer " + "not currently supported"); + } + +#undef RUN_CODE +#undef CODE_L2 + + CUDA_TEST_ERROR(); +} + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/impl/PQCodeDistances.cu b/core/src/index/thirdparty/faiss/gpu/impl/PQCodeDistances.cu new file mode 100644 index 0000000000..eec8852310 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/impl/PQCodeDistances.cu @@ -0,0 +1,589 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace faiss { namespace gpu { + +template +struct Converter { +}; + +#ifdef FAISS_USE_FLOAT16 +template <> +struct Converter { + inline static __device__ half to(float v) { return __float2half(v); } +}; +#endif + +template <> +struct Converter { + inline static __device__ float to(float v) { return v; } +}; + +// Kernel responsible for calculating distance from residual vector to +// each product quantizer code centroid +template +__global__ void +__launch_bounds__(288, 4) +pqCodeDistances(Tensor queries, + int queriesPerBlock, + Tensor coarseCentroids, + Tensor pqCentroids, + Tensor topQueryToCentroid, + // (query id)(coarse)(subquantizer)(code) -> dist + Tensor outCodeDistances) { + const auto numSubQuantizers = pqCentroids.getSize(0); + const auto dimsPerSubQuantizer = pqCentroids.getSize(1); + assert(DimsPerSubQuantizer == dimsPerSubQuantizer); + const auto codesPerSubQuantizer = pqCentroids.getSize(2); + + bool isLoadingThread = threadIdx.x >= codesPerSubQuantizer; + int loadingThreadId = threadIdx.x - codesPerSubQuantizer; + + extern __shared__ float smem[]; + + // Each thread calculates a single code + float subQuantizerData[DimsPerSubQuantizer]; + + auto code = threadIdx.x; + auto subQuantizer = blockIdx.y; + + // Each thread will load the pq centroid data for the code that it + // is processing +#pragma unroll + for (int i = 0; i < DimsPerSubQuantizer; ++i) { + subQuantizerData[i] = pqCentroids[subQuantizer][i][code].ldg(); + } + + // Where we store our query vector + float* smemQuery = smem; + + // Where we store our residual vector; this is double buffered so we + // can be loading the next one while processing the current one + float* smemResidual1 = &smemQuery[DimsPerSubQuantizer]; + float* smemResidual2 = &smemResidual1[DimsPerSubQuantizer]; + + // Where we pre-load the coarse centroid IDs + int* coarseIds = (int*) &smemResidual2[DimsPerSubQuantizer]; + + // Each thread is calculating the distance for a single code, + // performing the reductions locally + + // Handle multiple queries per block + auto startQueryId = blockIdx.x * queriesPerBlock; + auto numQueries = queries.getSize(0) - startQueryId; + if (numQueries > queriesPerBlock) { + numQueries = queriesPerBlock; + } + + for (int query = 0; query < numQueries; ++query) { + auto queryId = startQueryId + query; + + auto querySubQuantizer = + queries[queryId][subQuantizer * DimsPerSubQuantizer].data(); + + // Load current query vector + for (int i = threadIdx.x; i < DimsPerSubQuantizer; i += blockDim.x) { + smemQuery[i] = querySubQuantizer[i]; + } + + // Load list of coarse centroids found + for (int i = threadIdx.x; + i < topQueryToCentroid.getSize(1); i += blockDim.x) { + coarseIds[i] = topQueryToCentroid[queryId][i]; + } + + // We need coarseIds below + // FIXME: investigate loading separately, so we don't need this + __syncthreads(); + + // Preload first buffer of residual data + if (isLoadingThread) { + for (int i = loadingThreadId; + i < DimsPerSubQuantizer; + i += blockDim.x - codesPerSubQuantizer) { + auto coarseId = coarseIds[0]; + // In case NaNs were in the original query data + coarseId = coarseId == -1 ? 0 : coarseId; + auto coarseCentroidSubQuantizer = + coarseCentroids[coarseId][subQuantizer * dimsPerSubQuantizer].data(); + + if (L2Distance) { + smemResidual1[i] = smemQuery[i] - coarseCentroidSubQuantizer[i]; + } else { + smemResidual1[i] = coarseCentroidSubQuantizer[i]; + } + } + } + + // The block walks the list for a single query + for (int coarse = 0; coarse < topQueryToCentroid.getSize(1); ++coarse) { + // Wait for smemResidual1 to be loaded + __syncthreads(); + + if (isLoadingThread) { + // Preload second buffer of residual data + for (int i = loadingThreadId; + i < DimsPerSubQuantizer; + i += blockDim.x - codesPerSubQuantizer) { + // FIXME: try always making this centroid id 0 so we can + // terminate + if (coarse != (topQueryToCentroid.getSize(1) - 1)) { + auto coarseId = coarseIds[coarse + 1]; + // In case NaNs were in the original query data + coarseId = coarseId == -1 ? 0 : coarseId; + + auto coarseCentroidSubQuantizer = + coarseCentroids[coarseId] + [subQuantizer * dimsPerSubQuantizer].data(); + + if (L2Distance) { + smemResidual2[i] = smemQuery[i] - coarseCentroidSubQuantizer[i]; + } else { + smemResidual2[i] = coarseCentroidSubQuantizer[i]; + } + } + } + } else { + // These are the processing threads + float dist = 0.0f; + + constexpr int kUnroll = 4; + constexpr int kRemainder = DimsPerSubQuantizer % kUnroll; + constexpr int kRemainderBase = DimsPerSubQuantizer - kRemainder; + float vals[kUnroll]; + + // Calculate residual - pqCentroid for each dim that we're + // processing + + // Unrolled loop + if (L2Distance) { +#pragma unroll + for (int i = 0; i < DimsPerSubQuantizer / kUnroll; ++i) { +#pragma unroll + for (int j = 0; j < kUnroll; ++j) { + vals[j] = smemResidual1[i * kUnroll + j]; + } + +#pragma unroll + for (int j = 0; j < kUnroll; ++j) { + vals[j] -= subQuantizerData[i * kUnroll + j]; + } + +#pragma unroll + for (int j = 0; j < kUnroll; ++j) { + vals[j] *= vals[j]; + } + +#pragma unroll + for (int j = 0; j < kUnroll; ++j) { + dist += vals[j]; + } + } + } else { + // Inner product: query slice against the reconstructed sub-quantizer + // for this coarse cell (query o (centroid + subQCentroid)) +#pragma unroll + for (int i = 0; i < DimsPerSubQuantizer / kUnroll; ++i) { +#pragma unroll + for (int j = 0; j < kUnroll; ++j) { + vals[j] = smemResidual1[i * kUnroll + j]; + } + +#pragma unroll + for (int j = 0; j < kUnroll; ++j) { + vals[j] += subQuantizerData[i * kUnroll + j]; + } + +#pragma unroll + for (int j = 0; j < kUnroll; ++j) { + vals[j] *= smemQuery[i * kUnroll + j]; + } + +#pragma unroll + for (int j = 0; j < kUnroll; ++j) { + dist += vals[j]; + } + } + } + + // Remainder loop + if (L2Distance) { +#pragma unroll + for (int j = 0; j < kRemainder; ++j) { + vals[j] = smemResidual1[kRemainderBase + j]; + } + +#pragma unroll + for (int j = 0; j < kRemainder; ++j) { + vals[j] -= subQuantizerData[kRemainderBase + j]; + } + +#pragma unroll + for (int j = 0; j < kRemainder; ++j) { + vals[j] *= vals[j]; + } + } else { + // Inner product + // Inner product: query slice against the reconstructed sub-quantizer + // for this coarse cell (query o (centroid + subQCentroid)) +#pragma unroll + for (int j = 0; j < kRemainder; ++j) { + vals[j] = smemResidual1[kRemainderBase + j]; + } + +#pragma unroll + for (int j = 0; j < kRemainder; ++j) { + vals[j] += subQuantizerData[kRemainderBase + j]; + } + +#pragma unroll + for (int j = 0; j < kRemainder; ++j) { + vals[j] *= smemQuery[kRemainderBase + j]; + } + } + +#pragma unroll + for (int j = 0; j < kRemainder; ++j) { + dist += vals[j]; + } + + // We have the distance for our code; write it out + outCodeDistances[queryId][coarse][subQuantizer][code] = + Converter::to(dist); + } // !isLoadingThread + + // Swap residual buffers + float* tmp = smemResidual1; + smemResidual1 = smemResidual2; + smemResidual2 = tmp; + } + } +} + +__global__ void +residualVector(Tensor queries, + Tensor coarseCentroids, + Tensor topQueryToCentroid, + int numSubDim, + // output is transposed: + // (sub q)(query id)(centroid id)(sub dim) + Tensor residual) { + // block x is query id + // block y is centroid id + // thread x is dim + auto queryId = blockIdx.x; + auto centroidId = blockIdx.y; + + int realCentroidId = topQueryToCentroid[queryId][centroidId]; + + for (int dim = threadIdx.x; dim < queries.getSize(1); dim += blockDim.x) { + float q = queries[queryId][dim]; + float c = coarseCentroids[realCentroidId][dim]; + + residual[dim / numSubDim][queryId][centroidId][dim % numSubDim] = + q - c; + } +} + +void +runResidualVector(Tensor& pqCentroids, + Tensor& queries, + Tensor& coarseCentroids, + Tensor& topQueryToCentroid, + Tensor& residual, + cudaStream_t stream) { + auto grid = + dim3(topQueryToCentroid.getSize(0), topQueryToCentroid.getSize(1)); + auto block = dim3(std::min(queries.getSize(1), getMaxThreadsCurrentDevice())); + + residualVector<<>>( + queries, coarseCentroids, topQueryToCentroid, pqCentroids.getSize(1), + residual); + + CUDA_TEST_ERROR(); +} + +void +runPQCodeDistancesMM(Tensor& pqCentroids, + Tensor& queries, + Tensor& coarseCentroids, + Tensor& topQueryToCentroid, + NoTypeTensor<4, true>& outCodeDistances, + bool useFloat16Lookup, + DeviceMemory& mem, + cublasHandle_t handle, + cudaStream_t stream) { + // Calculate (q - c) residual vector + // (sub q)(query id)(centroid id)(sub dim) + DeviceTensor residual( + mem, + {pqCentroids.getSize(0), + topQueryToCentroid.getSize(0), + topQueryToCentroid.getSize(1), + pqCentroids.getSize(1)}, + stream); + + runResidualVector(pqCentroids, queries, + coarseCentroids, topQueryToCentroid, + residual, stream); + + // Calculate ||q - c||^2 + DeviceTensor residualNorms( + mem, + {pqCentroids.getSize(0) * + topQueryToCentroid.getSize(0) * + topQueryToCentroid.getSize(1)}, + stream); + + auto residualView2 = residual.view<2>( + {pqCentroids.getSize(0) * + topQueryToCentroid.getSize(0) * + topQueryToCentroid.getSize(1), + pqCentroids.getSize(1)}); + + runL2Norm(residualView2, true, residualNorms, true, stream); + + // Perform a batch MM: + // (sub q) x {(q * c)(sub dim) x (sub dim)(code)} => + // (sub q) x {(q * c)(code)} + auto residualView3 = residual.view<3>( + {pqCentroids.getSize(0), + topQueryToCentroid.getSize(0) * topQueryToCentroid.getSize(1), + pqCentroids.getSize(1)}); + + DeviceTensor residualDistance( + mem, + {pqCentroids.getSize(0), + topQueryToCentroid.getSize(0) * topQueryToCentroid.getSize(1), + pqCentroids.getSize(2)}, + stream); + + runIteratedMatrixMult(residualDistance, false, + residualView3, false, + pqCentroids, false, + -2.0f, 0.0f, + handle, + stream); + + // Sum ||q - c||^2 along rows + auto residualDistanceView2 = residualDistance.view<2>( + {pqCentroids.getSize(0) * + topQueryToCentroid.getSize(0) * + topQueryToCentroid.getSize(1), + pqCentroids.getSize(2)}); + + runSumAlongRows(residualNorms, residualDistanceView2, false, stream); + + Tensor outCodeDistancesF; + DeviceTensor outCodeDistancesFloatMem; + +#ifdef FAISS_USE_FLOAT16 + if (useFloat16Lookup) { + outCodeDistancesFloatMem = DeviceTensor( + mem, {outCodeDistances.getSize(0), + outCodeDistances.getSize(1), + outCodeDistances.getSize(2), + outCodeDistances.getSize(3)}, + stream); + + outCodeDistancesF = outCodeDistancesFloatMem; + } else { + outCodeDistancesF = outCodeDistances.toTensor(); + } +#else + outCodeDistancesF = outCodeDistances.toTensor(); +#endif + + // Transpose -2(sub q)(q * c)(code) to -2(q * c)(sub q)(code) (which + // is where we build our output distances) + auto outCodeDistancesView = outCodeDistancesF.view<3>( + {topQueryToCentroid.getSize(0) * topQueryToCentroid.getSize(1), + outCodeDistances.getSize(2), + outCodeDistances.getSize(3)}); + + runTransposeAny(residualDistance, 0, 1, outCodeDistancesView, stream); + + // Calculate code norms per each sub-dim + // (sub q)(sub dim)(code) is pqCentroids + // transpose to (sub q)(code)(sub dim) + DeviceTensor pqCentroidsTranspose( + mem, + {pqCentroids.getSize(0), pqCentroids.getSize(2), pqCentroids.getSize(1)}, + stream); + + runTransposeAny(pqCentroids, 1, 2, pqCentroidsTranspose, stream); + + auto pqCentroidsTransposeView = pqCentroidsTranspose.view<2>( + {pqCentroids.getSize(0) * pqCentroids.getSize(2), + pqCentroids.getSize(1)}); + + DeviceTensor pqCentroidsNorm( + mem, + {pqCentroids.getSize(0) * pqCentroids.getSize(2)}, + stream); + + runL2Norm(pqCentroidsTransposeView, true, pqCentroidsNorm, true, stream); + + // View output as (q * c)(sub q * code), and add centroid norm to + // each row + auto outDistancesCodeViewCols = outCodeDistancesView.view<2>( + {topQueryToCentroid.getSize(0) * topQueryToCentroid.getSize(1), + outCodeDistances.getSize(2) * outCodeDistances.getSize(3)}); + + runSumAlongColumns(pqCentroidsNorm, outDistancesCodeViewCols, stream); + +#ifdef FAISS_USE_FLOAT16 + if (useFloat16Lookup) { + // Need to convert back + auto outCodeDistancesH = outCodeDistances.toTensor(); + convertTensor(stream, + outCodeDistancesF, + outCodeDistancesH); + } +#endif +} + +void +runPQCodeDistances(Tensor& pqCentroids, + Tensor& queries, + Tensor& coarseCentroids, + Tensor& topQueryToCentroid, + NoTypeTensor<4, true>& outCodeDistances, + bool l2Distance, + bool useFloat16Lookup, + cudaStream_t stream) { + const auto numSubQuantizers = pqCentroids.getSize(0); + const auto dimsPerSubQuantizer = pqCentroids.getSize(1); + const auto codesPerSubQuantizer = pqCentroids.getSize(2); + + // FIXME: tune + // Reuse of pq centroid data is based on both # of queries * nprobe, + // and we should really be tiling in both dimensions + constexpr int kQueriesPerBlock = 8; + + auto grid = dim3(utils::divUp(queries.getSize(0), kQueriesPerBlock), + numSubQuantizers); + + // Reserve one block of threads for double buffering + // FIXME: probably impractical for large # of dims? + auto loadingThreads = utils::roundUp(dimsPerSubQuantizer, kWarpSize); + auto block = dim3(codesPerSubQuantizer + loadingThreads); + + auto smem = (3 * dimsPerSubQuantizer) * sizeof(float) + + topQueryToCentroid.getSize(1) * sizeof(int); + +#ifdef FAISS_USE_FLOAT16 +#define RUN_CODE(DIMS, L2) \ + do { \ + if (useFloat16Lookup) { \ + auto outCodeDistancesT = outCodeDistances.toTensor(); \ + \ + pqCodeDistances<<>>( \ + queries, kQueriesPerBlock, \ + coarseCentroids, pqCentroids, \ + topQueryToCentroid, outCodeDistancesT); \ + } else { \ + auto outCodeDistancesT = outCodeDistances.toTensor(); \ + \ + pqCodeDistances<<>>( \ + queries, kQueriesPerBlock, \ + coarseCentroids, pqCentroids, \ + topQueryToCentroid, outCodeDistancesT); \ + } \ + } while (0) +#else +#define RUN_CODE(DIMS, L2) \ + do { \ + if(!useFloat16Lookup){ \ + auto outCodeDistancesT = outCodeDistances.toTensor(); \ + \ + pqCodeDistances<<>>( \ + queries, kQueriesPerBlock, \ + coarseCentroids, pqCentroids, \ + topQueryToCentroid, outCodeDistancesT); \ + } \ + } while (0) +#endif + +#define CODE_L2(DIMS) \ + do { \ + if (l2Distance) { \ + RUN_CODE(DIMS, true); \ + } else { \ + RUN_CODE(DIMS, false); \ + } \ + } while (0) + + switch (dimsPerSubQuantizer) { + case 1: + CODE_L2(1); + break; + case 2: + CODE_L2(2); + break; + case 3: + CODE_L2(3); + break; + case 4: + CODE_L2(4); + break; + case 6: + CODE_L2(6); + break; + case 8: + CODE_L2(8); + break; + case 10: + CODE_L2(10); + break; + case 12: + CODE_L2(12); + break; + case 16: + CODE_L2(16); + break; + case 20: + CODE_L2(20); + break; + case 24: + CODE_L2(24); + break; + case 28: + CODE_L2(28); + break; + case 32: + CODE_L2(32); + break; + // FIXME: larger sizes require too many registers - we need the + // MM implementation working + default: + FAISS_THROW_MSG("Too many dimensions (>32) per subquantizer " + "not currently supported"); + } + +#undef RUN_CODE +#undef CODE_L2 + + CUDA_TEST_ERROR(); +} + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/impl/PQCodeDistances.cuh b/core/src/index/thirdparty/faiss/gpu/impl/PQCodeDistances.cuh new file mode 100644 index 0000000000..0add947f2c --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/impl/PQCodeDistances.cuh @@ -0,0 +1,46 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#pragma once + +#include +#include +#include + +namespace faiss { namespace gpu { + +class DeviceMemory; + +/// pqCentroids is of the form (sub q)(sub dim)(code id) +/// Calculates the distance from the (query - centroid) residual to +/// each sub-code vector, for the given list of query results in +/// topQueryToCentroid +template +void runPQCodeDistances(Tensor& pqCentroids, + Tensor& queries, + Tensor& coarseCentroids, + Tensor& topQueryToCentroid, + NoTypeTensor<4, true>& outCodeDistances, + bool l2Distance, + bool useFloat16Lookup, + cudaStream_t stream); + +template +void runPQCodeDistancesMM(Tensor& pqCentroids, + Tensor& queries, + Tensor& coarseCentroids, + Tensor& topQueryToCentroid, + NoTypeTensor<4, true>& outCodeDistances, + bool useFloat16Lookup, + DeviceMemory& mem, + cublasHandle_t handle, + cudaStream_t stream); + +} } // namespace + +#include diff --git a/core/src/index/thirdparty/faiss/gpu/impl/PQCodeLoad.cuh b/core/src/index/thirdparty/faiss/gpu/impl/PQCodeLoad.cuh new file mode 100644 index 0000000000..da933b1d00 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/impl/PQCodeLoad.cuh @@ -0,0 +1,357 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#pragma once + +#include + +namespace faiss { namespace gpu { + +#if __CUDA_ARCH__ >= 350 +// Use the CC 3.5+ read-only texture cache (nc) +#define LD_NC_V1 "ld.global.cs.nc.u32" +#define LD_NC_V2 "ld.global.cs.nc.v2.u32" +#define LD_NC_V4 "ld.global.cs.nc.v4.u32" +#else +// Read normally +#define LD_NC_V1 "ld.global.cs.u32" +#define LD_NC_V2 "ld.global.cs.v2.u32" +#define LD_NC_V4 "ld.global.cs.v4.u32" +#endif // __CUDA_ARCH__ + +/// +/// This file contains loader functions for PQ codes of various byte +/// length. +/// + +// Type-specific wrappers around the PTX bfe.* instruction, for +// quantization code extraction +inline __device__ unsigned int getByte(unsigned char v, + int pos, + int width) { + return v; +} + +inline __device__ unsigned int getByte(unsigned short v, + int pos, + int width) { + return getBitfield((unsigned int) v, pos, width); +} + +inline __device__ unsigned int getByte(unsigned int v, + int pos, + int width) { + return getBitfield(v, pos, width); +} + +inline __device__ unsigned int getByte(unsigned long v, + int pos, + int width) { + return getBitfield(v, pos, width); +} + +template +struct LoadCode32 {}; + +template<> +struct LoadCode32<1> { + static inline __device__ void load(unsigned int code32[1], + unsigned char* p, + int offset) { + p += offset * 1; + asm("ld.global.cs.u8 {%0}, [%1];" : + "=r"(code32[0]) : "l"(p)); + } +}; + +template<> +struct LoadCode32<2> { + static inline __device__ void load(unsigned int code32[1], + unsigned char* p, + int offset) { + p += offset * 2; + asm("ld.global.cs.u16 {%0}, [%1];" : + "=r"(code32[0]) : "l"(p)); + } +}; + +template<> +struct LoadCode32<3> { + static inline __device__ void load(unsigned int code32[1], + unsigned char* p, + int offset) { + p += offset * 3; + unsigned int a; + unsigned int b; + unsigned int c; + + // FIXME: this is a non-coalesced, unaligned, non-vectorized load + // unfortunately need to reorganize memory layout by warp + asm("ld.global.cs.u8 {%0}, [%1 + 0];" : + "=r"(a) : "l"(p)); + asm("ld.global.cs.u8 {%0}, [%1 + 1];" : + "=r"(b) : "l"(p)); + asm("ld.global.cs.u8 {%0}, [%1 + 2];" : + "=r"(c) : "l"(p)); + + // FIXME: this is also slow, since we have to recover the + // individual bytes loaded + code32[0] = (c << 16) | (b << 8) | a; + } +}; + +template<> +struct LoadCode32<4> { + static inline __device__ void load(unsigned int code32[1], + unsigned char* p, + int offset) { + p += offset * 4; + asm("ld.global.cs.u32 {%0}, [%1];" : + "=r"(code32[0]) : "l"(p)); + } +}; + +template<> +struct LoadCode32<8> { + static inline __device__ void load(unsigned int code32[2], + unsigned char* p, + int offset) { + p += offset * 8; + asm("ld.global.cs.v2.u32 {%0, %1}, [%2];" : + "=r"(code32[0]), "=r"(code32[1]) : "l"(p)); + } +}; + +template<> +struct LoadCode32<12> { + static inline __device__ void load(unsigned int code32[3], + unsigned char* p, + int offset) { + p += offset * 12; + // FIXME: this is a non-coalesced, unaligned, non-vectorized load + // unfortunately need to reorganize memory layout by warp + asm(LD_NC_V1 " {%0}, [%1 + 0];" : + "=r"(code32[0]) : "l"(p)); + asm(LD_NC_V1 " {%0}, [%1 + 4];" : + "=r"(code32[1]) : "l"(p)); + asm(LD_NC_V1 " {%0}, [%1 + 8];" : + "=r"(code32[2]) : "l"(p)); + } +}; + +template<> +struct LoadCode32<16> { + static inline __device__ void load(unsigned int code32[4], + unsigned char* p, + int offset) { + p += offset * 16; + asm("ld.global.cs.v4.u32 {%0, %1, %2, %3}, [%4];" : + "=r"(code32[0]), "=r"(code32[1]), + "=r"(code32[2]), "=r"(code32[3]) : "l"(p)); + } +}; + +template<> +struct LoadCode32<20> { + static inline __device__ void load(unsigned int code32[5], + unsigned char* p, + int offset) { + p += offset * 20; + // FIXME: this is a non-coalesced, unaligned, non-vectorized load + // unfortunately need to reorganize memory layout by warp + asm(LD_NC_V1 " {%0}, [%1 + 0];" : + "=r"(code32[0]) : "l"(p)); + asm(LD_NC_V1 " {%0}, [%1 + 4];" : + "=r"(code32[1]) : "l"(p)); + asm(LD_NC_V1 " {%0}, [%1 + 8];" : + "=r"(code32[2]) : "l"(p)); + asm(LD_NC_V1 " {%0}, [%1 + 12];" : + "=r"(code32[3]) : "l"(p)); + asm(LD_NC_V1 " {%0}, [%1 + 16];" : + "=r"(code32[4]) : "l"(p)); + } +}; + +template<> +struct LoadCode32<24> { + static inline __device__ void load(unsigned int code32[6], + unsigned char* p, + int offset) { + p += offset * 24; + // FIXME: this is a non-coalesced, unaligned, 2-vectorized load + // unfortunately need to reorganize memory layout by warp + asm(LD_NC_V2 " {%0, %1}, [%2 + 0];" : + "=r"(code32[0]), "=r"(code32[1]) : "l"(p)); + asm(LD_NC_V2 " {%0, %1}, [%2 + 8];" : + "=r"(code32[2]), "=r"(code32[3]) : "l"(p)); + asm(LD_NC_V2 " {%0, %1}, [%2 + 16];" : + "=r"(code32[4]), "=r"(code32[5]) : "l"(p)); + } +}; + +template<> +struct LoadCode32<28> { + static inline __device__ void load(unsigned int code32[7], + unsigned char* p, + int offset) { + p += offset * 28; + // FIXME: this is a non-coalesced, unaligned, non-vectorized load + // unfortunately need to reorganize memory layout by warp + asm(LD_NC_V1 " {%0}, [%1 + 0];" : + "=r"(code32[0]) : "l"(p)); + asm(LD_NC_V1 " {%0}, [%1 + 4];" : + "=r"(code32[1]) : "l"(p)); + asm(LD_NC_V1 " {%0}, [%1 + 8];" : + "=r"(code32[2]) : "l"(p)); + asm(LD_NC_V1 " {%0}, [%1 + 12];" : + "=r"(code32[3]) : "l"(p)); + asm(LD_NC_V1 " {%0}, [%1 + 16];" : + "=r"(code32[4]) : "l"(p)); + asm(LD_NC_V1 " {%0}, [%1 + 20];" : + "=r"(code32[5]) : "l"(p)); + asm(LD_NC_V1 " {%0}, [%1 + 24];" : + "=r"(code32[6]) : "l"(p)); + } +}; + +template<> +struct LoadCode32<32> { + static inline __device__ void load(unsigned int code32[8], + unsigned char* p, + int offset) { + p += offset * 32; + // FIXME: this is a non-coalesced load + // unfortunately need to reorganize memory layout by warp + asm(LD_NC_V4 " {%0, %1, %2, %3}, [%4];" : + "=r"(code32[0]), "=r"(code32[1]), + "=r"(code32[2]), "=r"(code32[3]) : "l"(p)); + asm(LD_NC_V4 " {%0, %1, %2, %3}, [%4 + 16];" : + "=r"(code32[4]), "=r"(code32[5]), + "=r"(code32[6]), "=r"(code32[7]) : "l"(p)); + } +}; + +template<> +struct LoadCode32<40> { + static inline __device__ void load(unsigned int code32[10], + unsigned char* p, + int offset) { + p += offset * 40; + // FIXME: this is a non-coalesced, unaligned, 2-vectorized load + // unfortunately need to reorganize memory layout by warp + asm(LD_NC_V2 " {%0, %1}, [%2 + 0];" : + "=r"(code32[0]), "=r"(code32[1]) : "l"(p)); + asm(LD_NC_V2 " {%0, %1}, [%2 + 8];" : + "=r"(code32[2]), "=r"(code32[3]) : "l"(p)); + asm(LD_NC_V2 " {%0, %1}, [%2 + 16];" : + "=r"(code32[4]), "=r"(code32[5]) : "l"(p)); + asm(LD_NC_V2 " {%0, %1}, [%2 + 24];" : + "=r"(code32[6]), "=r"(code32[7]) : "l"(p)); + asm(LD_NC_V2 " {%0, %1}, [%2 + 32];" : + "=r"(code32[8]), "=r"(code32[9]) : "l"(p)); + } +}; + +template<> +struct LoadCode32<48> { + static inline __device__ void load(unsigned int code32[12], + unsigned char* p, + int offset) { + p += offset * 48; + // FIXME: this is a non-coalesced load + // unfortunately need to reorganize memory layout by warp + asm(LD_NC_V4 " {%0, %1, %2, %3}, [%4];" : + "=r"(code32[0]), "=r"(code32[1]), + "=r"(code32[2]), "=r"(code32[3]) : "l"(p)); + asm(LD_NC_V4 " {%0, %1, %2, %3}, [%4 + 16];" : + "=r"(code32[4]), "=r"(code32[5]), + "=r"(code32[6]), "=r"(code32[7]) : "l"(p)); + asm(LD_NC_V4 " {%0, %1, %2, %3}, [%4 + 32];" : + "=r"(code32[8]), "=r"(code32[9]), + "=r"(code32[10]), "=r"(code32[11]) : "l"(p)); + } +}; + +template<> +struct LoadCode32<56> { + static inline __device__ void load(unsigned int code32[14], + unsigned char* p, + int offset) { + p += offset * 56; + // FIXME: this is a non-coalesced, unaligned, 2-vectorized load + // unfortunately need to reorganize memory layout by warp + asm(LD_NC_V2 " {%0, %1}, [%2 + 0];" : + "=r"(code32[0]), "=r"(code32[1]) : "l"(p)); + asm(LD_NC_V2 " {%0, %1}, [%2 + 8];" : + "=r"(code32[2]), "=r"(code32[3]) : "l"(p)); + asm(LD_NC_V2 " {%0, %1}, [%2 + 16];" : + "=r"(code32[4]), "=r"(code32[5]) : "l"(p)); + asm(LD_NC_V2 " {%0, %1}, [%2 + 24];" : + "=r"(code32[6]), "=r"(code32[7]) : "l"(p)); + asm(LD_NC_V2 " {%0, %1}, [%2 + 32];" : + "=r"(code32[8]), "=r"(code32[9]) : "l"(p)); + asm(LD_NC_V2 " {%0, %1}, [%2 + 40];" : + "=r"(code32[10]), "=r"(code32[11]) : "l"(p)); + asm(LD_NC_V2 " {%0, %1}, [%2 + 48];" : + "=r"(code32[12]), "=r"(code32[13]) : "l"(p)); + } +}; + +template<> +struct LoadCode32<64> { + static inline __device__ void load(unsigned int code32[16], + unsigned char* p, + int offset) { + p += offset * 64; + // FIXME: this is a non-coalesced load + // unfortunately need to reorganize memory layout by warp + asm(LD_NC_V4 " {%0, %1, %2, %3}, [%4];" : + "=r"(code32[0]), "=r"(code32[1]), + "=r"(code32[2]), "=r"(code32[3]) : "l"(p)); + asm(LD_NC_V4 " {%0, %1, %2, %3}, [%4 + 16];" : + "=r"(code32[4]), "=r"(code32[5]), + "=r"(code32[6]), "=r"(code32[7]) : "l"(p)); + asm(LD_NC_V4 " {%0, %1, %2, %3}, [%4 + 32];" : + "=r"(code32[8]), "=r"(code32[9]), + "=r"(code32[10]), "=r"(code32[11]) : "l"(p)); + asm(LD_NC_V4 " {%0, %1, %2, %3}, [%4 + 48];" : + "=r"(code32[12]), "=r"(code32[13]), + "=r"(code32[14]), "=r"(code32[15]) : "l"(p)); + } +}; + +template<> +struct LoadCode32<96> { + static inline __device__ void load(unsigned int code32[24], + unsigned char* p, + int offset) { + p += offset * 96; + // FIXME: this is a non-coalesced load + // unfortunately need to reorganize memory layout by warp + asm(LD_NC_V4 " {%0, %1, %2, %3}, [%4];" : + "=r"(code32[0]), "=r"(code32[1]), + "=r"(code32[2]), "=r"(code32[3]) : "l"(p)); + asm(LD_NC_V4 " {%0, %1, %2, %3}, [%4 + 16];" : + "=r"(code32[4]), "=r"(code32[5]), + "=r"(code32[6]), "=r"(code32[7]) : "l"(p)); + asm(LD_NC_V4 " {%0, %1, %2, %3}, [%4 + 32];" : + "=r"(code32[8]), "=r"(code32[9]), + "=r"(code32[10]), "=r"(code32[11]) : "l"(p)); + asm(LD_NC_V4 " {%0, %1, %2, %3}, [%4 + 48];" : + "=r"(code32[12]), "=r"(code32[13]), + "=r"(code32[14]), "=r"(code32[15]) : "l"(p)); + asm(LD_NC_V4 " {%0, %1, %2, %3}, [%4 + 64];" : + "=r"(code32[16]), "=r"(code32[17]), + "=r"(code32[18]), "=r"(code32[19]) : "l"(p)); + asm(LD_NC_V4 " {%0, %1, %2, %3}, [%4 + 80];" : + "=r"(code32[20]), "=r"(code32[21]), + "=r"(code32[22]), "=r"(code32[23]) : "l"(p)); + } +}; + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/impl/PQScanMultiPassNoPrecomputed-inl.cuh b/core/src/index/thirdparty/faiss/gpu/impl/PQScanMultiPassNoPrecomputed-inl.cuh new file mode 100644 index 0000000000..a77e783d09 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/impl/PQScanMultiPassNoPrecomputed-inl.cuh @@ -0,0 +1,623 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace faiss { namespace gpu { + +// This must be kept in sync with PQCodeDistances.cu +inline bool isSupportedNoPrecomputedSubDimSize(int dims) { + switch (dims) { + case 1: + case 2: + case 3: + case 4: + case 6: + case 8: + case 10: + case 12: + case 16: + case 20: + case 24: + case 28: + case 32: + return true; + default: + // FIXME: larger sizes require too many registers - we need the + // MM implementation working + return false; + } +} + +template +struct LoadCodeDistances { + static inline __device__ void load(LookupT* smem, + LookupT* codes, + int numCodes) { + constexpr int kWordSize = sizeof(LookupVecT) / sizeof(LookupT); + + // We can only use the vector type if the data is guaranteed to be + // aligned. The codes are innermost, so if it is evenly divisible, + // then any slice will be aligned. + if (numCodes % kWordSize == 0) { + // Load the data by float4 for efficiency, and then handle any remainder + // limitVec is the number of whole vec words we can load, in terms + // of whole blocks performing the load + constexpr int kUnroll = 2; + int limitVec = numCodes / (kUnroll * kWordSize * blockDim.x); + limitVec *= kUnroll * blockDim.x; + + LookupVecT* smemV = (LookupVecT*) smem; + LookupVecT* codesV = (LookupVecT*) codes; + + for (int i = threadIdx.x; i < limitVec; i += kUnroll * blockDim.x) { + LookupVecT vals[kUnroll]; + +#pragma unroll + for (int j = 0; j < kUnroll; ++j) { + vals[j] = + LoadStore::load(&codesV[i + j * blockDim.x]); + } + +#pragma unroll + for (int j = 0; j < kUnroll; ++j) { + LoadStore::store(&smemV[i + j * blockDim.x], vals[j]); + } + } + + // This is where we start loading the remainder that does not evenly + // fit into kUnroll x blockDim.x + int remainder = limitVec * kWordSize; + + for (int i = remainder + threadIdx.x; i < numCodes; i += blockDim.x) { + smem[i] = codes[i]; + } + } else { + // Potential unaligned load + constexpr int kUnroll = 4; + + int limit = utils::roundDown(numCodes, kUnroll * blockDim.x); + + int i = threadIdx.x; + for (; i < limit; i += kUnroll * blockDim.x) { + LookupT vals[kUnroll]; + +#pragma unroll + for (int j = 0; j < kUnroll; ++j) { + vals[j] = codes[i + j * blockDim.x]; + } + +#pragma unroll + for (int j = 0; j < kUnroll; ++j) { + smem[i + j * blockDim.x] = vals[j]; + } + } + + for (; i < numCodes; i += blockDim.x) { + smem[i] = codes[i]; + } + } + } +}; + +template +__global__ void +pqScanNoPrecomputedMultiPass(Tensor queries, + Tensor pqCentroids, + Tensor topQueryToCentroid, + Tensor codeDistances, + void** listCodes, + int* listLengths, + Tensor prefixSumOffsets, + Tensor distance) { + const auto codesPerSubQuantizer = pqCentroids.getSize(2); + + // Where the pq code -> residual distance is stored + extern __shared__ char smemCodeDistances[]; + LookupT* codeDist = (LookupT*) smemCodeDistances; + + // Each block handles a single query + auto queryId = blockIdx.y; + auto probeId = blockIdx.x; + + // This is where we start writing out data + // We ensure that before the array (at offset -1), there is a 0 value + int outBase = *(prefixSumOffsets[queryId][probeId].data() - 1); + float* distanceOut = distance[outBase].data(); + + auto listId = topQueryToCentroid[queryId][probeId]; + // Safety guard in case NaNs in input cause no list ID to be generated + if (listId == -1) { + return; + } + + unsigned char* codeList = (unsigned char*) listCodes[listId]; + int limit = listLengths[listId]; + + constexpr int kNumCode32 = NumSubQuantizers <= 4 ? 1 : + (NumSubQuantizers / 4); + unsigned int code32[kNumCode32]; + unsigned int nextCode32[kNumCode32]; + + // We double-buffer the code loading, which improves memory utilization + if (threadIdx.x < limit) { + LoadCode32::load(code32, codeList, threadIdx.x); + } + + LoadCodeDistances::load( + codeDist, + codeDistances[queryId][probeId].data(), + codeDistances.getSize(2) * codeDistances.getSize(3)); + + // Prevent WAR dependencies + __syncthreads(); + + // Each thread handles one code element in the list, with a + // block-wide stride + for (int codeIndex = threadIdx.x; + codeIndex < limit; + codeIndex += blockDim.x) { + // Prefetch next codes + if (codeIndex + blockDim.x < limit) { + LoadCode32::load( + nextCode32, codeList, codeIndex + blockDim.x); + } + + float dist = 0.0f; + +#pragma unroll + for (int word = 0; word < kNumCode32; ++word) { + constexpr int kBytesPerCode32 = + NumSubQuantizers < 4 ? NumSubQuantizers : 4; + + if (kBytesPerCode32 == 1) { + auto code = code32[0]; + dist = ConvertTo::to(codeDist[code]); + + } else { +#pragma unroll + for (int byte = 0; byte < kBytesPerCode32; ++byte) { + auto code = getByte(code32[word], byte * 8, 8); + + auto offset = + codesPerSubQuantizer * (word * kBytesPerCode32 + byte); + + dist += ConvertTo::to(codeDist[offset + code]); + } + } + } + + // Write out intermediate distance result + // We do not maintain indices here, in order to reduce global + // memory traffic. Those are recovered in the final selection step. + distanceOut[codeIndex] = dist; + + // Rotate buffers +#pragma unroll + for (int word = 0; word < kNumCode32; ++word) { + code32[word] = nextCode32[word]; + } + } +} + +template +void +runMultiPassTile(Tensor& queries, + Tensor& centroids, + Tensor& pqCentroidsInnermostCode, + NoTypeTensor<4, true>& codeDistances, + Tensor& topQueryToCentroid, + Tensor& bitset, + bool useFloat16Lookup, + int bytesPerCode, + int numSubQuantizers, + int numSubQuantizerCodes, + thrust::device_vector& listCodes, + thrust::device_vector& listIndices, + IndicesOptions indicesOptions, + thrust::device_vector& listLengths, + Tensor& thrustMem, + Tensor& prefixSumOffsets, + Tensor& allDistances, + Tensor& heapDistances, + Tensor& heapIndices, + int k, + faiss::MetricType metric, + Tensor& outDistances, + Tensor& outIndices, + cudaStream_t stream) { + // We only support two metrics at the moment + FAISS_ASSERT(metric == MetricType::METRIC_INNER_PRODUCT || + metric == MetricType::METRIC_L2); + + bool l2Distance = metric == MetricType::METRIC_L2; + + // Calculate offset lengths, so we know where to write out + // intermediate results + runCalcListOffsets(topQueryToCentroid, listLengths, prefixSumOffsets, + thrustMem, stream); + + // Calculate residual code distances, since this is without + // precomputed codes + runPQCodeDistances(pqCentroidsInnermostCode, + queries, + centroids, + topQueryToCentroid, + codeDistances, + l2Distance, + useFloat16Lookup, + stream); + + // Convert all codes to a distance, and write out (distance, + // index) values for all intermediate results + { + auto kThreadsPerBlock = 256; + + auto grid = dim3(topQueryToCentroid.getSize(1), + topQueryToCentroid.getSize(0)); + auto block = dim3(kThreadsPerBlock); + + // pq centroid distances + +#ifdef FAISS_USE_FLOAT16 + auto smem = (sizeof(float)== useFloat16Lookup) ? sizeof(half) : sizeof(float); +#else + auto smem = sizeof(float); +#endif + + smem *= numSubQuantizers * numSubQuantizerCodes; + FAISS_ASSERT(smem <= getMaxSharedMemPerBlockCurrentDevice()); + +#define RUN_PQ_OPT(NUM_SUB_Q, LOOKUP_T, LOOKUP_VEC_T) \ + do { \ + auto codeDistancesT = codeDistances.toTensor(); \ + \ + pqScanNoPrecomputedMultiPass \ + <<>>( \ + queries, \ + pqCentroidsInnermostCode, \ + topQueryToCentroid, \ + codeDistancesT, \ + listCodes.data().get(), \ + listLengths.data().get(), \ + prefixSumOffsets, \ + allDistances); \ + } while (0) + +#ifdef FAISS_USE_FLOAT16 +#define RUN_PQ(NUM_SUB_Q) \ + do { \ + if (useFloat16Lookup) { \ + RUN_PQ_OPT(NUM_SUB_Q, half, Half8); \ + } else { \ + RUN_PQ_OPT(NUM_SUB_Q, float, float4); \ + } \ + } while (0) +#else +#define RUN_PQ(NUM_SUB_Q) \ + do { \ + RUN_PQ_OPT(NUM_SUB_Q, float, float4); \ + } while (0) +#endif + + switch (bytesPerCode) { + case 1: + RUN_PQ(1); + break; + case 2: + RUN_PQ(2); + break; + case 3: + RUN_PQ(3); + break; + case 4: + RUN_PQ(4); + break; + case 8: + RUN_PQ(8); + break; + case 12: + RUN_PQ(12); + break; + case 16: + RUN_PQ(16); + break; + case 20: + RUN_PQ(20); + break; + case 24: + RUN_PQ(24); + break; + case 28: + RUN_PQ(28); + break; + case 32: + RUN_PQ(32); + break; + case 40: + RUN_PQ(40); + break; + case 48: + RUN_PQ(48); + break; + case 56: + RUN_PQ(56); + break; + case 64: + RUN_PQ(64); + break; + case 96: + RUN_PQ(96); + break; + default: + FAISS_ASSERT(false); + break; + } + +#undef RUN_PQ +#undef RUN_PQ_OPT + } + + CUDA_TEST_ERROR(); + + // k-select the output in chunks, to increase parallelism + runPass1SelectLists(listIndices, + indicesOptions, + prefixSumOffsets, + topQueryToCentroid, + bitset, + allDistances, + topQueryToCentroid.getSize(1), + k, + !l2Distance, // L2 distance chooses smallest + heapDistances, + heapIndices, + stream); + + // k-select final output + auto flatHeapDistances = heapDistances.downcastInner<2>(); + auto flatHeapIndices = heapIndices.downcastInner<2>(); + + runPass2SelectLists(flatHeapDistances, + flatHeapIndices, + listIndices, + indicesOptions, + prefixSumOffsets, + topQueryToCentroid, + k, + !l2Distance, // L2 distance chooses smallest + outDistances, + outIndices, + stream); +} + +template +void +runPQScanMultiPassNoPrecomputed(Tensor& queries, + Tensor& centroids, + Tensor& pqCentroidsInnermostCode, + Tensor& topQueryToCentroid, + Tensor& bitset, + bool useFloat16Lookup, + int bytesPerCode, + int numSubQuantizers, + int numSubQuantizerCodes, + thrust::device_vector& listCodes, + thrust::device_vector& listIndices, + IndicesOptions indicesOptions, + thrust::device_vector& listLengths, + int maxListLength, + int k, + faiss::MetricType metric, + // output + Tensor& outDistances, + // output + Tensor& outIndices, + GpuResources* res) { + constexpr int kMinQueryTileSize = 8; + constexpr int kMaxQueryTileSize = 128; + constexpr int kThrustMemSize = 16384; + + int nprobe = topQueryToCentroid.getSize(1); + + auto& mem = res->getMemoryManagerCurrentDevice(); + auto stream = res->getDefaultStreamCurrentDevice(); + + // Make a reservation for Thrust to do its dirty work (global memory + // cross-block reduction space); hopefully this is large enough. + DeviceTensor thrustMem1( + mem, {kThrustMemSize}, stream); + DeviceTensor thrustMem2( + mem, {kThrustMemSize}, stream); + DeviceTensor* thrustMem[2] = + {&thrustMem1, &thrustMem2}; + + // How much temporary storage is available? + // If possible, we'd like to fit within the space available. + size_t sizeAvailable = mem.getSizeAvailable(); + + // We run two passes of heap selection + // This is the size of the first-level heap passes + constexpr int kNProbeSplit = 8; + int pass2Chunks = std::min(nprobe, kNProbeSplit); + + size_t sizeForFirstSelectPass = + pass2Chunks * k * (sizeof(float) + sizeof(int)); + + // How much temporary storage we need per each query + size_t sizePerQuery = + 2 * // streams + ((nprobe * sizeof(int) + sizeof(int)) + // prefixSumOffsets + nprobe * maxListLength * sizeof(float) + // allDistances + // residual distances + nprobe * numSubQuantizers * numSubQuantizerCodes * sizeof(float) + + sizeForFirstSelectPass); + + int queryTileSize = (int) (sizeAvailable / sizePerQuery); + + if (queryTileSize < kMinQueryTileSize) { + queryTileSize = kMinQueryTileSize; + } else if (queryTileSize > kMaxQueryTileSize) { + queryTileSize = kMaxQueryTileSize; + } + + // FIXME: we should adjust queryTileSize to deal with this, since + // indexing is in int32 + FAISS_ASSERT(queryTileSize * nprobe * maxListLength < + std::numeric_limits::max()); + + // Temporary memory buffers + // Make sure there is space prior to the start which will be 0, and + // will handle the boundary condition without branches + DeviceTensor prefixSumOffsetSpace1( + mem, {queryTileSize * nprobe + 1}, stream); + DeviceTensor prefixSumOffsetSpace2( + mem, {queryTileSize * nprobe + 1}, stream); + + DeviceTensor prefixSumOffsets1( + prefixSumOffsetSpace1[1].data(), + {queryTileSize, nprobe}); + DeviceTensor prefixSumOffsets2( + prefixSumOffsetSpace2[1].data(), + {queryTileSize, nprobe}); + DeviceTensor* prefixSumOffsets[2] = + {&prefixSumOffsets1, &prefixSumOffsets2}; + + // Make sure the element before prefixSumOffsets is 0, since we + // depend upon simple, boundary-less indexing to get proper results + CUDA_VERIFY(cudaMemsetAsync(prefixSumOffsetSpace1.data(), + 0, + sizeof(int), + stream)); + CUDA_VERIFY(cudaMemsetAsync(prefixSumOffsetSpace2.data(), + 0, + sizeof(int), + stream)); + + int codeDistanceTypeSize = sizeof(float); +#ifdef FAISS_USE_FLOAT16 + if (useFloat16Lookup) { + codeDistanceTypeSize = sizeof(half); + } +#endif + + int totalCodeDistancesSize = + queryTileSize * nprobe * numSubQuantizers * numSubQuantizerCodes * + codeDistanceTypeSize; + + DeviceTensor codeDistances1Mem( + mem, {totalCodeDistancesSize}, stream); + NoTypeTensor<4, true> codeDistances1( + codeDistances1Mem.data(), + codeDistanceTypeSize, + {queryTileSize, nprobe, numSubQuantizers, numSubQuantizerCodes}); + + DeviceTensor codeDistances2Mem( + mem, {totalCodeDistancesSize}, stream); + NoTypeTensor<4, true> codeDistances2( + codeDistances2Mem.data(), + codeDistanceTypeSize, + {queryTileSize, nprobe, numSubQuantizers, numSubQuantizerCodes}); + + NoTypeTensor<4, true>* codeDistances[2] = + {&codeDistances1, &codeDistances2}; + + DeviceTensor allDistances1( + mem, {queryTileSize * nprobe * maxListLength}, stream); + DeviceTensor allDistances2( + mem, {queryTileSize * nprobe * maxListLength}, stream); + DeviceTensor* allDistances[2] = + {&allDistances1, &allDistances2}; + + DeviceTensor heapDistances1( + mem, {queryTileSize, pass2Chunks, k}, stream); + DeviceTensor heapDistances2( + mem, {queryTileSize, pass2Chunks, k}, stream); + DeviceTensor* heapDistances[2] = + {&heapDistances1, &heapDistances2}; + + DeviceTensor heapIndices1( + mem, {queryTileSize, pass2Chunks, k}, stream); + DeviceTensor heapIndices2( + mem, {queryTileSize, pass2Chunks, k}, stream); + DeviceTensor* heapIndices[2] = + {&heapIndices1, &heapIndices2}; + + auto streams = res->getAlternateStreamsCurrentDevice(); + streamWait(streams, {stream}); + + int curStream = 0; + + for (int query = 0; query < queries.getSize(0); query += queryTileSize) { + int numQueriesInTile = + std::min(queryTileSize, queries.getSize(0) - query); + + auto prefixSumOffsetsView = + prefixSumOffsets[curStream]->narrowOutermost(0, numQueriesInTile); + + auto codeDistancesView = + codeDistances[curStream]->narrowOutermost(0, numQueriesInTile); + auto coarseIndicesView = + topQueryToCentroid.narrowOutermost(query, numQueriesInTile); + auto queryView = + queries.narrowOutermost(query, numQueriesInTile); + + auto heapDistancesView = + heapDistances[curStream]->narrowOutermost(0, numQueriesInTile); + auto heapIndicesView = + heapIndices[curStream]->narrowOutermost(0, numQueriesInTile); + + auto outDistanceView = + outDistances.narrowOutermost(query, numQueriesInTile); + auto outIndicesView = + outIndices.narrowOutermost(query, numQueriesInTile); + + runMultiPassTile(queryView, + centroids, + pqCentroidsInnermostCode, + codeDistancesView, + coarseIndicesView, + bitset, + useFloat16Lookup, + bytesPerCode, + numSubQuantizers, + numSubQuantizerCodes, + listCodes, + listIndices, + indicesOptions, + listLengths, + *thrustMem[curStream], + prefixSumOffsetsView, + *allDistances[curStream], + heapDistancesView, + heapIndicesView, + k, + metric, + outDistanceView, + outIndicesView, + streams[curStream]); + + curStream = (curStream + 1) % 2; + } + + streamWait({stream}, streams); +} + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/impl/PQScanMultiPassNoPrecomputed.cu b/core/src/index/thirdparty/faiss/gpu/impl/PQScanMultiPassNoPrecomputed.cu new file mode 100644 index 0000000000..b4934382cb --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/impl/PQScanMultiPassNoPrecomputed.cu @@ -0,0 +1,627 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace faiss { namespace gpu { + +//// This must be kept in sync with PQCodeDistances.cu +//bool isSupportedNoPrecomputedSubDimSize(int dims) { +// switch (dims) { +// case 1: +// case 2: +// case 3: +// case 4: +// case 6: +// case 8: +// case 10: +// case 12: +// case 16: +// case 20: +// case 24: +// case 28: +// case 32: +// return true; +// default: +// // FIXME: larger sizes require too many registers - we need the +// // MM implementation working +// return false; +// } +//} +// +//template +//struct LoadCodeDistances { +// static inline __device__ void load(LookupT* smem, +// LookupT* codes, +// int numCodes) { +// constexpr int kWordSize = sizeof(LookupVecT) / sizeof(LookupT); +// +// // We can only use the vector type if the data is guaranteed to be +// // aligned. The codes are innermost, so if it is evenly divisible, +// // then any slice will be aligned. +// if (numCodes % kWordSize == 0) { +// // Load the data by float4 for efficiency, and then handle any remainder +// // limitVec is the number of whole vec words we can load, in terms +// // of whole blocks performing the load +// constexpr int kUnroll = 2; +// int limitVec = numCodes / (kUnroll * kWordSize * blockDim.x); +// limitVec *= kUnroll * blockDim.x; +// +// LookupVecT* smemV = (LookupVecT*) smem; +// LookupVecT* codesV = (LookupVecT*) codes; +// +// for (int i = threadIdx.x; i < limitVec; i += kUnroll * blockDim.x) { +// LookupVecT vals[kUnroll]; +// +//#pragma unroll +// for (int j = 0; j < kUnroll; ++j) { +// vals[j] = +// LoadStore::load(&codesV[i + j * blockDim.x]); +// } +// +//#pragma unroll +// for (int j = 0; j < kUnroll; ++j) { +// LoadStore::store(&smemV[i + j * blockDim.x], vals[j]); +// } +// } +// +// // This is where we start loading the remainder that does not evenly +// // fit into kUnroll x blockDim.x +// int remainder = limitVec * kWordSize; +// +// for (int i = remainder + threadIdx.x; i < numCodes; i += blockDim.x) { +// smem[i] = codes[i]; +// } +// } else { +// // Potential unaligned load +// constexpr int kUnroll = 4; +// +// int limit = utils::roundDown(numCodes, kUnroll * blockDim.x); +// +// int i = threadIdx.x; +// for (; i < limit; i += kUnroll * blockDim.x) { +// LookupT vals[kUnroll]; +// +//#pragma unroll +// for (int j = 0; j < kUnroll; ++j) { +// vals[j] = codes[i + j * blockDim.x]; +// } +// +//#pragma unroll +// for (int j = 0; j < kUnroll; ++j) { +// smem[i + j * blockDim.x] = vals[j]; +// } +// } +// +// for (; i < numCodes; i += blockDim.x) { +// smem[i] = codes[i]; +// } +// } +// } +//}; +// +//template +//__global__ void +//pqScanNoPrecomputedMultiPass(Tensor queries, +// Tensor pqCentroids, +// Tensor topQueryToCentroid, +// Tensor codeDistances, +// void** listCodes, +// int* listLengths, +// Tensor prefixSumOffsets, +// Tensor distance) { +// const auto codesPerSubQuantizer = pqCentroids.getSize(2); +// +// // Where the pq code -> residual distance is stored +// extern __shared__ char smemCodeDistances[]; +// LookupT* codeDist = (LookupT*) smemCodeDistances; +// +// // Each block handles a single query +// auto queryId = blockIdx.y; +// auto probeId = blockIdx.x; +// +// // This is where we start writing out data +// // We ensure that before the array (at offset -1), there is a 0 value +// int outBase = *(prefixSumOffsets[queryId][probeId].data() - 1); +// float* distanceOut = distance[outBase].data(); +// +// auto listId = topQueryToCentroid[queryId][probeId]; +// // Safety guard in case NaNs in input cause no list ID to be generated +// if (listId == -1) { +// return; +// } +// +// unsigned char* codeList = (unsigned char*) listCodes[listId]; +// int limit = listLengths[listId]; +// +// constexpr int kNumCode32 = NumSubQuantizers <= 4 ? 1 : +// (NumSubQuantizers / 4); +// unsigned int code32[kNumCode32]; +// unsigned int nextCode32[kNumCode32]; +// +// // We double-buffer the code loading, which improves memory utilization +// if (threadIdx.x < limit) { +// LoadCode32::load(code32, codeList, threadIdx.x); +// } +// +// LoadCodeDistances::load( +// codeDist, +// codeDistances[queryId][probeId].data(), +// codeDistances.getSize(2) * codeDistances.getSize(3)); +// +// // Prevent WAR dependencies +// __syncthreads(); +// +// // Each thread handles one code element in the list, with a +// // block-wide stride +// for (int codeIndex = threadIdx.x; +// codeIndex < limit; +// codeIndex += blockDim.x) { +// // Prefetch next codes +// if (codeIndex + blockDim.x < limit) { +// LoadCode32::load( +// nextCode32, codeList, codeIndex + blockDim.x); +// } +// +// float dist = 0.0f; +// +//#pragma unroll +// for (int word = 0; word < kNumCode32; ++word) { +// constexpr int kBytesPerCode32 = +// NumSubQuantizers < 4 ? NumSubQuantizers : 4; +// +// if (kBytesPerCode32 == 1) { +// auto code = code32[0]; +// dist = ConvertTo::to(codeDist[code]); +// +// } else { +//#pragma unroll +// for (int byte = 0; byte < kBytesPerCode32; ++byte) { +// auto code = getByte(code32[word], byte * 8, 8); +// +// auto offset = +// codesPerSubQuantizer * (word * kBytesPerCode32 + byte); +// +// dist += ConvertTo::to(codeDist[offset + code]); +// } +// } +// } +// +// // Write out intermediate distance result +// // We do not maintain indices here, in order to reduce global +// // memory traffic. Those are recovered in the final selection step. +// distanceOut[codeIndex] = dist; +// +// // Rotate buffers +//#pragma unroll +// for (int word = 0; word < kNumCode32; ++word) { +// code32[word] = nextCode32[word]; +// } +// } +//} + +void +runMultiPassTile(Tensor& queries, + Tensor& centroids, + Tensor& pqCentroidsInnermostCode, + NoTypeTensor<4, true>& codeDistances, + Tensor& topQueryToCentroid, + Tensor& bitset, + bool useFloat16Lookup, + int bytesPerCode, + int numSubQuantizers, + int numSubQuantizerCodes, + thrust::device_vector& listCodes, + thrust::device_vector& listIndices, + IndicesOptions indicesOptions, + thrust::device_vector& listLengths, + Tensor& thrustMem, + Tensor& prefixSumOffsets, + Tensor& allDistances, + Tensor& heapDistances, + Tensor& heapIndices, + int k, + faiss::MetricType metric, + Tensor& outDistances, + Tensor& outIndices, + cudaStream_t stream) { + // We only support two metrics at the moment + FAISS_ASSERT(metric == MetricType::METRIC_INNER_PRODUCT || + metric == MetricType::METRIC_L2); + + bool l2Distance = metric == MetricType::METRIC_L2; + // Calculate offset lengths, so we know where to write out +#ifndef FAISS_USE_FLOAT16 + FAISS_ASSERT(!useFloat16Lookup); +#endif + // Calculate offset lengths, so we know where to write out + // intermediate results + runCalcListOffsets(topQueryToCentroid, listLengths, prefixSumOffsets, + thrustMem, stream); + + // Calculate residual code distances, since this is without + // precomputed codes + runPQCodeDistances(pqCentroidsInnermostCode, + queries, + centroids, + topQueryToCentroid, + codeDistances, + l2Distance, + useFloat16Lookup, + stream); + + // Convert all codes to a distance, and write out (distance, + // index) values for all intermediate results + { + auto kThreadsPerBlock = 256; + + auto grid = dim3(topQueryToCentroid.getSize(1), + topQueryToCentroid.getSize(0)); + auto block = dim3(kThreadsPerBlock); + + // pq centroid distances + //auto smem = useFloat16Lookup ? sizeof(half) : sizeof(float); + auto smem = sizeof(float); +#ifdef FAISS_USE_FLOAT16 + if (useFloat16Lookup) { + smem = sizeof(half); + } +#endif + + smem *= numSubQuantizers * numSubQuantizerCodes; + FAISS_ASSERT(smem <= getMaxSharedMemPerBlockCurrentDevice()); + +#define RUN_PQ_OPT(NUM_SUB_Q, LOOKUP_T, LOOKUP_VEC_T) \ + do { \ + auto codeDistancesT = codeDistances.toTensor(); \ + \ + pqScanNoPrecomputedMultiPass \ + <<>>( \ + queries, \ + pqCentroidsInnermostCode, \ + topQueryToCentroid, \ + codeDistancesT, \ + listCodes.data().get(), \ + listLengths.data().get(), \ + prefixSumOffsets, \ + allDistances); \ + } while (0) + +#ifdef FAISS_USE_FLOAT16 +#define RUN_PQ(NUM_SUB_Q) \ + do { \ + if (useFloat16Lookup) { \ + RUN_PQ_OPT(NUM_SUB_Q, half, Half8); \ + } else { \ + RUN_PQ_OPT(NUM_SUB_Q, float, float4); \ + } \ + } while (0) +#else +#define RUN_PQ(NUM_SUB_Q) \ + do { \ + RUN_PQ_OPT(NUM_SUB_Q, float, float4); \ + } while (0) +#endif // FAISS_USE_FLOAT16 + + switch (bytesPerCode) { + case 1: + RUN_PQ(1); + break; + case 2: + RUN_PQ(2); + break; + case 3: + RUN_PQ(3); + break; + case 4: + RUN_PQ(4); + break; + case 8: + RUN_PQ(8); + break; + case 12: + RUN_PQ(12); + break; + case 16: + RUN_PQ(16); + break; + case 20: + RUN_PQ(20); + break; + case 24: + RUN_PQ(24); + break; + case 28: + RUN_PQ(28); + break; + case 32: + RUN_PQ(32); + break; + case 40: + RUN_PQ(40); + break; + case 48: + RUN_PQ(48); + break; + case 56: + RUN_PQ(56); + break; + case 64: + RUN_PQ(64); + break; + case 96: + RUN_PQ(96); + break; + default: + FAISS_ASSERT(false); + break; + } + +#undef RUN_PQ +#undef RUN_PQ_OPT + } + + CUDA_TEST_ERROR(); + + // k-select the output in chunks, to increase parallelism + runPass1SelectLists(listIndices, + indicesOptions, + prefixSumOffsets, + topQueryToCentroid, + bitset, + allDistances, + topQueryToCentroid.getSize(1), + k, + !l2Distance, // L2 distance chooses smallest + heapDistances, + heapIndices, + stream); + + // k-select final output + auto flatHeapDistances = heapDistances.downcastInner<2>(); + auto flatHeapIndices = heapIndices.downcastInner<2>(); + + runPass2SelectLists(flatHeapDistances, + flatHeapIndices, + listIndices, + indicesOptions, + prefixSumOffsets, + topQueryToCentroid, + k, + !l2Distance, // L2 distance chooses smallest + outDistances, + outIndices, + stream); +} + +void runPQScanMultiPassNoPrecomputed(Tensor& queries, + Tensor& centroids, + Tensor& pqCentroidsInnermostCode, + Tensor& topQueryToCentroid, + Tensor& bitset, + bool useFloat16Lookup, + int bytesPerCode, + int numSubQuantizers, + int numSubQuantizerCodes, + thrust::device_vector& listCodes, + thrust::device_vector& listIndices, + IndicesOptions indicesOptions, + thrust::device_vector& listLengths, + int maxListLength, + int k, + faiss::MetricType metric, + // output + Tensor& outDistances, + // output + Tensor& outIndices, + GpuResources* res) { + constexpr int kMinQueryTileSize = 8; + constexpr int kMaxQueryTileSize = 128; + constexpr int kThrustMemSize = 16384; + + int nprobe = topQueryToCentroid.getSize(1); + + auto& mem = res->getMemoryManagerCurrentDevice(); + auto stream = res->getDefaultStreamCurrentDevice(); + + // Make a reservation for Thrust to do its dirty work (global memory + // cross-block reduction space); hopefully this is large enough. + DeviceTensor thrustMem1( + mem, {kThrustMemSize}, stream); + DeviceTensor thrustMem2( + mem, {kThrustMemSize}, stream); + DeviceTensor* thrustMem[2] = + {&thrustMem1, &thrustMem2}; + + // How much temporary storage is available? + // If possible, we'd like to fit within the space available. + size_t sizeAvailable = mem.getSizeAvailable(); + + // We run two passes of heap selection + // This is the size of the first-level heap passes + constexpr int kNProbeSplit = 8; + int pass2Chunks = std::min(nprobe, kNProbeSplit); + + size_t sizeForFirstSelectPass = + pass2Chunks * k * (sizeof(float) + sizeof(int)); + + // How much temporary storage we need per each query + size_t sizePerQuery = + 2 * // streams + ((nprobe * sizeof(int) + sizeof(int)) + // prefixSumOffsets + nprobe * maxListLength * sizeof(float) + // allDistances + // residual distances + nprobe * numSubQuantizers * numSubQuantizerCodes * sizeof(float) + + sizeForFirstSelectPass); + + int queryTileSize = (int) (sizeAvailable / sizePerQuery); + + if (queryTileSize < kMinQueryTileSize) { + queryTileSize = kMinQueryTileSize; + } else if (queryTileSize > kMaxQueryTileSize) { + queryTileSize = kMaxQueryTileSize; + } + + // FIXME: we should adjust queryTileSize to deal with this, since + // indexing is in int32 + FAISS_ASSERT(queryTileSize * nprobe * maxListLength < + std::numeric_limits::max()); + + // Temporary memory buffers + // Make sure there is space prior to the start which will be 0, and + // will handle the boundary condition without branches + DeviceTensor prefixSumOffsetSpace1( + mem, {queryTileSize * nprobe + 1}, stream); + DeviceTensor prefixSumOffsetSpace2( + mem, {queryTileSize * nprobe + 1}, stream); + + DeviceTensor prefixSumOffsets1( + prefixSumOffsetSpace1[1].data(), + {queryTileSize, nprobe}); + DeviceTensor prefixSumOffsets2( + prefixSumOffsetSpace2[1].data(), + {queryTileSize, nprobe}); + DeviceTensor* prefixSumOffsets[2] = + {&prefixSumOffsets1, &prefixSumOffsets2}; + + // Make sure the element before prefixSumOffsets is 0, since we + // depend upon simple, boundary-less indexing to get proper results + CUDA_VERIFY(cudaMemsetAsync(prefixSumOffsetSpace1.data(), + 0, + sizeof(int), + stream)); + CUDA_VERIFY(cudaMemsetAsync(prefixSumOffsetSpace2.data(), + 0, + sizeof(int), + stream)); + + int codeDistanceTypeSize = sizeof(float); +#ifdef FAISS_USE_FLOAT16 + if (useFloat16Lookup) { + codeDistanceTypeSize = sizeof(half); + } +#else + FAISS_ASSERT(!useFloat16Lookup); +#endif + + int totalCodeDistancesSize = + queryTileSize * nprobe * numSubQuantizers * numSubQuantizerCodes * + codeDistanceTypeSize; + + DeviceTensor codeDistances1Mem( + mem, {totalCodeDistancesSize}, stream); + NoTypeTensor<4, true> codeDistances1( + codeDistances1Mem.data(), + codeDistanceTypeSize, + {queryTileSize, nprobe, numSubQuantizers, numSubQuantizerCodes}); + + DeviceTensor codeDistances2Mem( + mem, {totalCodeDistancesSize}, stream); + NoTypeTensor<4, true> codeDistances2( + codeDistances2Mem.data(), + codeDistanceTypeSize, + {queryTileSize, nprobe, numSubQuantizers, numSubQuantizerCodes}); + + NoTypeTensor<4, true>* codeDistances[2] = + {&codeDistances1, &codeDistances2}; + + DeviceTensor allDistances1( + mem, {queryTileSize * nprobe * maxListLength}, stream); + DeviceTensor allDistances2( + mem, {queryTileSize * nprobe * maxListLength}, stream); + DeviceTensor* allDistances[2] = + {&allDistances1, &allDistances2}; + + DeviceTensor heapDistances1( + mem, {queryTileSize, pass2Chunks, k}, stream); + DeviceTensor heapDistances2( + mem, {queryTileSize, pass2Chunks, k}, stream); + DeviceTensor* heapDistances[2] = + {&heapDistances1, &heapDistances2}; + + DeviceTensor heapIndices1( + mem, {queryTileSize, pass2Chunks, k}, stream); + DeviceTensor heapIndices2( + mem, {queryTileSize, pass2Chunks, k}, stream); + DeviceTensor* heapIndices[2] = + {&heapIndices1, &heapIndices2}; + + auto streams = res->getAlternateStreamsCurrentDevice(); + streamWait(streams, {stream}); + + int curStream = 0; + + for (int query = 0; query < queries.getSize(0); query += queryTileSize) { + int numQueriesInTile = + std::min(queryTileSize, queries.getSize(0) - query); + + auto prefixSumOffsetsView = + prefixSumOffsets[curStream]->narrowOutermost(0, numQueriesInTile); + + auto codeDistancesView = + codeDistances[curStream]->narrowOutermost(0, numQueriesInTile); + auto coarseIndicesView = + topQueryToCentroid.narrowOutermost(query, numQueriesInTile); + auto queryView = + queries.narrowOutermost(query, numQueriesInTile); + + auto heapDistancesView = + heapDistances[curStream]->narrowOutermost(0, numQueriesInTile); + auto heapIndicesView = + heapIndices[curStream]->narrowOutermost(0, numQueriesInTile); + + auto outDistanceView = + outDistances.narrowOutermost(query, numQueriesInTile); + auto outIndicesView = + outIndices.narrowOutermost(query, numQueriesInTile); + + runMultiPassTile(queryView, + centroids, + pqCentroidsInnermostCode, + codeDistancesView, + coarseIndicesView, + bitset, + useFloat16Lookup, + bytesPerCode, + numSubQuantizers, + numSubQuantizerCodes, + listCodes, + listIndices, + indicesOptions, + listLengths, + *thrustMem[curStream], + prefixSumOffsetsView, + *allDistances[curStream], + heapDistancesView, + heapIndicesView, + k, + metric, + outDistanceView, + outIndicesView, + streams[curStream]); + + curStream = (curStream + 1) % 2; + } + + streamWait({stream}, streams); +} + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/impl/PQScanMultiPassNoPrecomputed.cuh b/core/src/index/thirdparty/faiss/gpu/impl/PQScanMultiPassNoPrecomputed.cuh new file mode 100644 index 0000000000..d3c0cc53d5 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/impl/PQScanMultiPassNoPrecomputed.cuh @@ -0,0 +1,49 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#pragma once + +#include +#include +#include +#include + +namespace faiss { namespace gpu { + +class GpuResources; + +/// For no precomputed codes, is this a supported number of dimensions +/// per subquantizer? +bool isSupportedNoPrecomputedSubDimSize(int dims); + +template +void runPQScanMultiPassNoPrecomputed(Tensor& queries, + Tensor& centroids, + Tensor& pqCentroidsInnermostCode, + Tensor& topQueryToCentroid, + Tensor& bitset, + bool useFloat16Lookup, + int bytesPerCode, + int numSubQuantizers, + int numSubQuantizerCodes, + thrust::device_vector& listCodes, + thrust::device_vector& listIndices, + IndicesOptions indicesOptions, + thrust::device_vector& listLengths, + int maxListLength, + int k, + faiss::MetricType metric, + // output + Tensor& outDistances, + // output + Tensor& outIndices, + GpuResources* res); + +} } // namespace + +#include diff --git a/core/src/index/thirdparty/faiss/gpu/impl/PQScanMultiPassPrecomputed.cu b/core/src/index/thirdparty/faiss/gpu/impl/PQScanMultiPassPrecomputed.cu new file mode 100644 index 0000000000..02e65ff32a --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/impl/PQScanMultiPassPrecomputed.cu @@ -0,0 +1,573 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace faiss { namespace gpu { + +// For precomputed codes, this calculates and loads code distances +// into smem +template +inline __device__ void +loadPrecomputedTerm(LookupT* smem, + LookupT* term2Start, + LookupT* term3Start, + int numCodes) { + constexpr int kWordSize = sizeof(LookupVecT) / sizeof(LookupT); + + // We can only use vector loads if the data is guaranteed to be + // aligned. The codes are innermost, so if it is evenly divisible, + // then any slice will be aligned. + if (numCodes % kWordSize == 0) { + constexpr int kUnroll = 2; + + // Load the data by float4 for efficiency, and then handle any remainder + // limitVec is the number of whole vec words we can load, in terms + // of whole blocks performing the load + int limitVec = numCodes / (kUnroll * kWordSize * blockDim.x); + limitVec *= kUnroll * blockDim.x; + + LookupVecT* smemV = (LookupVecT*) smem; + LookupVecT* term2StartV = (LookupVecT*) term2Start; + LookupVecT* term3StartV = (LookupVecT*) term3Start; + + for (int i = threadIdx.x; i < limitVec; i += kUnroll * blockDim.x) { + LookupVecT vals[kUnroll]; + +#pragma unroll + for (int j = 0; j < kUnroll; ++j) { + vals[j] = + LoadStore::load(&term2StartV[i + j * blockDim.x]); + } + +#pragma unroll + for (int j = 0; j < kUnroll; ++j) { + LookupVecT q = + LoadStore::load(&term3StartV[i + j * blockDim.x]); + + vals[j] = Math::add(vals[j], q); + } + +#pragma unroll + for (int j = 0; j < kUnroll; ++j) { + LoadStore::store(&smemV[i + j * blockDim.x], vals[j]); + } + } + + // This is where we start loading the remainder that does not evenly + // fit into kUnroll x blockDim.x + int remainder = limitVec * kWordSize; + + for (int i = remainder + threadIdx.x; i < numCodes; i += blockDim.x) { + smem[i] = Math::add(term2Start[i], term3Start[i]); + } + } else { + // Potential unaligned load + constexpr int kUnroll = 4; + + int limit = utils::roundDown(numCodes, kUnroll * blockDim.x); + + int i = threadIdx.x; + for (; i < limit; i += kUnroll * blockDim.x) { + LookupT vals[kUnroll]; + +#pragma unroll + for (int j = 0; j < kUnroll; ++j) { + vals[j] = term2Start[i + j * blockDim.x]; + } + +#pragma unroll + for (int j = 0; j < kUnroll; ++j) { + vals[j] = Math::add(vals[j], term3Start[i + j * blockDim.x]); + } + +#pragma unroll + for (int j = 0; j < kUnroll; ++j) { + smem[i + j * blockDim.x] = vals[j]; + } + } + + for (; i < numCodes; i += blockDim.x) { + smem[i] = Math::add(term2Start[i], term3Start[i]); + } + } +} + +template +__global__ void +pqScanPrecomputedMultiPass(Tensor queries, + Tensor precompTerm1, + Tensor precompTerm2, + Tensor precompTerm3, + Tensor topQueryToCentroid, + void** listCodes, + int* listLengths, + Tensor prefixSumOffsets, + Tensor distance) { + // precomputed term 2 + 3 storage + // (sub q)(code id) + extern __shared__ char smemTerm23[]; + LookupT* term23 = (LookupT*) smemTerm23; + + // Each block handles a single query + auto queryId = blockIdx.y; + auto probeId = blockIdx.x; + auto codesPerSubQuantizer = precompTerm2.getSize(2); + auto precompTermSize = precompTerm2.getSize(1) * codesPerSubQuantizer; + + // This is where we start writing out data + // We ensure that before the array (at offset -1), there is a 0 value + int outBase = *(prefixSumOffsets[queryId][probeId].data() - 1); + float* distanceOut = distance[outBase].data(); + + auto listId = topQueryToCentroid[queryId][probeId]; + // Safety guard in case NaNs in input cause no list ID to be generated + if (listId == -1) { + return; + } + + unsigned char* codeList = (unsigned char*) listCodes[listId]; + int limit = listLengths[listId]; + + constexpr int kNumCode32 = NumSubQuantizers <= 4 ? 1 : + (NumSubQuantizers / 4); + unsigned int code32[kNumCode32]; + unsigned int nextCode32[kNumCode32]; + + // We double-buffer the code loading, which improves memory utilization + if (threadIdx.x < limit) { + LoadCode32::load(code32, codeList, threadIdx.x); + } + + // Load precomputed terms 1, 2, 3 + float term1 = precompTerm1[queryId][probeId]; + loadPrecomputedTerm(term23, + precompTerm2[listId].data(), + precompTerm3[queryId].data(), + precompTermSize); + + // Prevent WAR dependencies + __syncthreads(); + + // Each thread handles one code element in the list, with a + // block-wide stride + for (int codeIndex = threadIdx.x; + codeIndex < limit; + codeIndex += blockDim.x) { + // Prefetch next codes + if (codeIndex + blockDim.x < limit) { + LoadCode32::load( + nextCode32, codeList, codeIndex + blockDim.x); + } + + float dist = term1; + +#pragma unroll + for (int word = 0; word < kNumCode32; ++word) { + constexpr int kBytesPerCode32 = + NumSubQuantizers < 4 ? NumSubQuantizers : 4; + + if (kBytesPerCode32 == 1) { + auto code = code32[0]; + dist = ConvertTo::to(term23[code]); + + } else { +#pragma unroll + for (int byte = 0; byte < kBytesPerCode32; ++byte) { + auto code = getByte(code32[word], byte * 8, 8); + + auto offset = + codesPerSubQuantizer * (word * kBytesPerCode32 + byte); + + dist += ConvertTo::to(term23[offset + code]); + } + } + } + + // Write out intermediate distance result + // We do not maintain indices here, in order to reduce global + // memory traffic. Those are recovered in the final selection step. + distanceOut[codeIndex] = dist; + + // Rotate buffers +#pragma unroll + for (int word = 0; word < kNumCode32; ++word) { + code32[word] = nextCode32[word]; + } + } +} + +void +runMultiPassTile(Tensor& queries, + Tensor& precompTerm1, + NoTypeTensor<3, true>& precompTerm2, + NoTypeTensor<3, true>& precompTerm3, + Tensor& topQueryToCentroid, + Tensor& bitset, + bool useFloat16Lookup, + int bytesPerCode, + int numSubQuantizers, + int numSubQuantizerCodes, + thrust::device_vector& listCodes, + thrust::device_vector& listIndices, + IndicesOptions indicesOptions, + thrust::device_vector& listLengths, + Tensor& thrustMem, + Tensor& prefixSumOffsets, + Tensor& allDistances, + Tensor& heapDistances, + Tensor& heapIndices, + int k, + Tensor& outDistances, + Tensor& outIndices, + cudaStream_t stream) { + // Calculate offset lengths, so we know where to write out + // intermediate results + runCalcListOffsets(topQueryToCentroid, listLengths, prefixSumOffsets, + thrustMem, stream); + + // Convert all codes to a distance, and write out (distance, + // index) values for all intermediate results + { + auto kThreadsPerBlock = 256; + + auto grid = dim3(topQueryToCentroid.getSize(1), + topQueryToCentroid.getSize(0)); + auto block = dim3(kThreadsPerBlock); + + // pq precomputed terms (2 + 3) + auto smem = sizeof(float); +#ifdef FAISS_USE_FLOAT16 + if (useFloat16Lookup) { + smem = sizeof(half); + } +#endif + + smem *= numSubQuantizers * numSubQuantizerCodes; + FAISS_ASSERT(smem <= getMaxSharedMemPerBlockCurrentDevice()); + +#define RUN_PQ_OPT(NUM_SUB_Q, LOOKUP_T, LOOKUP_VEC_T) \ + do { \ + auto precompTerm2T = precompTerm2.toTensor(); \ + auto precompTerm3T = precompTerm3.toTensor(); \ + \ + pqScanPrecomputedMultiPass \ + <<>>( \ + queries, \ + precompTerm1, \ + precompTerm2T, \ + precompTerm3T, \ + topQueryToCentroid, \ + listCodes.data().get(), \ + listLengths.data().get(), \ + prefixSumOffsets, \ + allDistances); \ + } while (0) + +#ifdef FAISS_USE_FLOAT16 +#define RUN_PQ(NUM_SUB_Q) \ + do { \ + if (useFloat16Lookup) { \ + RUN_PQ_OPT(NUM_SUB_Q, half, Half8); \ + } else { \ + RUN_PQ_OPT(NUM_SUB_Q, float, float4); \ + } \ + } while (0) +#else +#define RUN_PQ(NUM_SUB_Q) \ + do { \ + RUN_PQ_OPT(NUM_SUB_Q, float, float4); \ + } while (0) +#endif + + switch (bytesPerCode) { + case 1: + RUN_PQ(1); + break; + case 2: + RUN_PQ(2); + break; + case 3: + RUN_PQ(3); + break; + case 4: + RUN_PQ(4); + break; + case 8: + RUN_PQ(8); + break; + case 12: + RUN_PQ(12); + break; + case 16: + RUN_PQ(16); + break; + case 20: + RUN_PQ(20); + break; + case 24: + RUN_PQ(24); + break; + case 28: + RUN_PQ(28); + break; + case 32: + RUN_PQ(32); + break; + case 40: + RUN_PQ(40); + break; + case 48: + RUN_PQ(48); + break; + case 56: + RUN_PQ(56); + break; + case 64: + RUN_PQ(64); + break; + case 96: + RUN_PQ(96); + break; + default: + FAISS_ASSERT(false); + break; + } + + CUDA_TEST_ERROR(); + +#undef RUN_PQ +#undef RUN_PQ_OPT + } + + // k-select the output in chunks, to increase parallelism + runPass1SelectLists(listIndices, + indicesOptions, + prefixSumOffsets, + topQueryToCentroid, + bitset, + allDistances, + topQueryToCentroid.getSize(1), + k, + false, // L2 distance chooses smallest + heapDistances, + heapIndices, + stream); + + // k-select final output + auto flatHeapDistances = heapDistances.downcastInner<2>(); + auto flatHeapIndices = heapIndices.downcastInner<2>(); + + runPass2SelectLists(flatHeapDistances, + flatHeapIndices, + listIndices, + indicesOptions, + prefixSumOffsets, + topQueryToCentroid, + k, + false, // L2 distance chooses smallest + outDistances, + outIndices, + stream); + + CUDA_TEST_ERROR(); +} + +void runPQScanMultiPassPrecomputed(Tensor& queries, + Tensor& precompTerm1, + NoTypeTensor<3, true>& precompTerm2, + NoTypeTensor<3, true>& precompTerm3, + Tensor& topQueryToCentroid, + Tensor& bitset, + bool useFloat16Lookup, + int bytesPerCode, + int numSubQuantizers, + int numSubQuantizerCodes, + thrust::device_vector& listCodes, + thrust::device_vector& listIndices, + IndicesOptions indicesOptions, + thrust::device_vector& listLengths, + int maxListLength, + int k, + // output + Tensor& outDistances, + // output + Tensor& outIndices, + GpuResources* res) { + constexpr int kMinQueryTileSize = 8; + constexpr int kMaxQueryTileSize = 128; + constexpr int kThrustMemSize = 16384; + + int nprobe = topQueryToCentroid.getSize(1); + + auto& mem = res->getMemoryManagerCurrentDevice(); + auto stream = res->getDefaultStreamCurrentDevice(); + + // Make a reservation for Thrust to do its dirty work (global memory + // cross-block reduction space); hopefully this is large enough. + DeviceTensor thrustMem1( + mem, {kThrustMemSize}, stream); + DeviceTensor thrustMem2( + mem, {kThrustMemSize}, stream); + DeviceTensor* thrustMem[2] = + {&thrustMem1, &thrustMem2}; + + // How much temporary storage is available? + // If possible, we'd like to fit within the space available. + size_t sizeAvailable = mem.getSizeAvailable(); + + // We run two passes of heap selection + // This is the size of the first-level heap passes + constexpr int kNProbeSplit = 8; + int pass2Chunks = std::min(nprobe, kNProbeSplit); + + size_t sizeForFirstSelectPass = + pass2Chunks * k * (sizeof(float) + sizeof(int)); + + // How much temporary storage we need per each query + size_t sizePerQuery = + 2 * // # streams + ((nprobe * sizeof(int) + sizeof(int)) + // prefixSumOffsets + nprobe * maxListLength * sizeof(float) + // allDistances + sizeForFirstSelectPass); + + int queryTileSize = (int) (sizeAvailable / sizePerQuery); + + if (queryTileSize < kMinQueryTileSize) { + queryTileSize = kMinQueryTileSize; + } else if (queryTileSize > kMaxQueryTileSize) { + queryTileSize = kMaxQueryTileSize; + } + + // FIXME: we should adjust queryTileSize to deal with this, since + // indexing is in int32 + FAISS_ASSERT(queryTileSize * nprobe * maxListLength <= + std::numeric_limits::max()); + + // Temporary memory buffers + // Make sure there is space prior to the start which will be 0, and + // will handle the boundary condition without branches + DeviceTensor prefixSumOffsetSpace1( + mem, {queryTileSize * nprobe + 1}, stream); + DeviceTensor prefixSumOffsetSpace2( + mem, {queryTileSize * nprobe + 1}, stream); + + DeviceTensor prefixSumOffsets1( + prefixSumOffsetSpace1[1].data(), + {queryTileSize, nprobe}); + DeviceTensor prefixSumOffsets2( + prefixSumOffsetSpace2[1].data(), + {queryTileSize, nprobe}); + DeviceTensor* prefixSumOffsets[2] = + {&prefixSumOffsets1, &prefixSumOffsets2}; + + // Make sure the element before prefixSumOffsets is 0, since we + // depend upon simple, boundary-less indexing to get proper results + CUDA_VERIFY(cudaMemsetAsync(prefixSumOffsetSpace1.data(), + 0, + sizeof(int), + stream)); + CUDA_VERIFY(cudaMemsetAsync(prefixSumOffsetSpace2.data(), + 0, + sizeof(int), + stream)); + + DeviceTensor allDistances1( + mem, {queryTileSize * nprobe * maxListLength}, stream); + DeviceTensor allDistances2( + mem, {queryTileSize * nprobe * maxListLength}, stream); + DeviceTensor* allDistances[2] = + {&allDistances1, &allDistances2}; + + DeviceTensor heapDistances1( + mem, {queryTileSize, pass2Chunks, k}, stream); + DeviceTensor heapDistances2( + mem, {queryTileSize, pass2Chunks, k}, stream); + DeviceTensor* heapDistances[2] = + {&heapDistances1, &heapDistances2}; + + DeviceTensor heapIndices1( + mem, {queryTileSize, pass2Chunks, k}, stream); + DeviceTensor heapIndices2( + mem, {queryTileSize, pass2Chunks, k}, stream); + DeviceTensor* heapIndices[2] = + {&heapIndices1, &heapIndices2}; + + auto streams = res->getAlternateStreamsCurrentDevice(); + streamWait(streams, {stream}); + + int curStream = 0; + + for (int query = 0; query < queries.getSize(0); query += queryTileSize) { + int numQueriesInTile = + std::min(queryTileSize, queries.getSize(0) - query); + + auto prefixSumOffsetsView = + prefixSumOffsets[curStream]->narrowOutermost(0, numQueriesInTile); + + auto coarseIndicesView = + topQueryToCentroid.narrowOutermost(query, numQueriesInTile); + auto queryView = + queries.narrowOutermost(query, numQueriesInTile); + auto term1View = + precompTerm1.narrowOutermost(query, numQueriesInTile); + auto term3View = + precompTerm3.narrowOutermost(query, numQueriesInTile); + + auto heapDistancesView = + heapDistances[curStream]->narrowOutermost(0, numQueriesInTile); + auto heapIndicesView = + heapIndices[curStream]->narrowOutermost(0, numQueriesInTile); + + auto outDistanceView = + outDistances.narrowOutermost(query, numQueriesInTile); + auto outIndicesView = + outIndices.narrowOutermost(query, numQueriesInTile); + + runMultiPassTile(queryView, + term1View, + precompTerm2, + term3View, + coarseIndicesView, + bitset, + useFloat16Lookup, + bytesPerCode, + numSubQuantizers, + numSubQuantizerCodes, + listCodes, + listIndices, + indicesOptions, + listLengths, + *thrustMem[curStream], + prefixSumOffsetsView, + *allDistances[curStream], + heapDistancesView, + heapIndicesView, + k, + outDistanceView, + outIndicesView, + streams[curStream]); + + curStream = (curStream + 1) % 2; + } + + streamWait({stream}, streams); +} + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/impl/PQScanMultiPassPrecomputed.cuh b/core/src/index/thirdparty/faiss/gpu/impl/PQScanMultiPassPrecomputed.cuh new file mode 100644 index 0000000000..644ba7d99d --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/impl/PQScanMultiPassPrecomputed.cuh @@ -0,0 +1,42 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#pragma once + +#include +#include +#include +#include + +namespace faiss { namespace gpu { + +class GpuResources; + +void runPQScanMultiPassPrecomputed(Tensor& queries, + Tensor& precompTerm1, + NoTypeTensor<3, true>& precompTerm2, + NoTypeTensor<3, true>& precompTerm3, + Tensor& topQueryToCentroid, + Tensor& bitset, + bool useFloat16Lookup, + int bytesPerCode, + int numSubQuantizers, + int numSubQuantizerCodes, + thrust::device_vector& listCodes, + thrust::device_vector& listIndices, + IndicesOptions indicesOptions, + thrust::device_vector& listLengths, + int maxListLength, + int k, + // output + Tensor& outDistances, + // output + Tensor& outIndices, + GpuResources* res); + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/impl/RemapIndices.cpp b/core/src/index/thirdparty/faiss/gpu/impl/RemapIndices.cpp new file mode 100644 index 0000000000..a3df65c91c --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/impl/RemapIndices.cpp @@ -0,0 +1,43 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#include +#include + +namespace faiss { namespace gpu { + +// Utility function to translate (list id, offset) to a user index on +// the CPU. In a cpp in order to use OpenMP +void ivfOffsetToUserIndex( + long* indices, + int numLists, + int queries, + int k, + const std::vector>& listOffsetToUserIndex) { + FAISS_ASSERT(numLists == listOffsetToUserIndex.size()); + +#pragma omp parallel for + for (int q = 0; q < queries; ++q) { + for (int r = 0; r < k; ++r) { + long offsetIndex = indices[q * k + r]; + + if (offsetIndex < 0) continue; + + int listId = (int) (offsetIndex >> 32); + int listOffset = (int) (offsetIndex & 0xffffffff); + + FAISS_ASSERT(listId < numLists); + auto& listIndices = listOffsetToUserIndex[listId]; + + FAISS_ASSERT(listOffset < listIndices.size()); + indices[q * k + r] = listIndices[listOffset]; + } + } +} + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/impl/RemapIndices.h b/core/src/index/thirdparty/faiss/gpu/impl/RemapIndices.h new file mode 100644 index 0000000000..234148451f --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/impl/RemapIndices.h @@ -0,0 +1,24 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#pragma once + +#include + +namespace faiss { namespace gpu { + +/// Utility function to translate (list id, offset) to a user index on +/// the CPU. In a cpp in order to use OpenMP. +void ivfOffsetToUserIndex( + long* indices, + int numLists, + int queries, + int k, + const std::vector>& listOffsetToUserIndex); + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/impl/VectorResidual.cu b/core/src/index/thirdparty/faiss/gpu/impl/VectorResidual.cu new file mode 100644 index 0000000000..980b3c3979 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/impl/VectorResidual.cu @@ -0,0 +1,148 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include +#include +#include +#include // in CUDA SDK, for CUDART_NAN_F + +namespace faiss { namespace gpu { + +template +__global__ void calcResidual(Tensor vecs, + Tensor centroids, + Tensor vecToCentroid, + Tensor residuals) { + auto vec = vecs[blockIdx.x]; + auto residual = residuals[blockIdx.x]; + + int centroidId = vecToCentroid[blockIdx.x]; + // Vector could be invalid (containing NaNs), so -1 was the + // classified centroid + if (centroidId == -1) { + if (LargeDim) { + for (int i = threadIdx.x; i < vecs.getSize(1); i += blockDim.x) { + residual[i] = CUDART_NAN_F; + } + } else { + residual[threadIdx.x] = CUDART_NAN_F; + } + + return; + } + + auto centroid = centroids[centroidId]; + + if (LargeDim) { + for (int i = threadIdx.x; i < vecs.getSize(1); i += blockDim.x) { + residual[i] = vec[i] - ConvertTo::to(centroid[i]); + } + } else { + residual[threadIdx.x] = vec[threadIdx.x] - + ConvertTo::to(centroid[threadIdx.x]); + } +} + +template +__global__ void gatherReconstruct(Tensor listIds, + Tensor vecs, + Tensor out) { + auto id = listIds[blockIdx.x]; + auto vec = vecs[id]; + auto outVec = out[blockIdx.x]; + + Convert conv; + + for (int i = threadIdx.x; i < vecs.getSize(1); i += blockDim.x) { + outVec[i] = id == -1 ? 0.0f : conv(vec[i]); + } +} + +template +void calcResidual(Tensor& vecs, + Tensor& centroids, + Tensor& vecToCentroid, + Tensor& residuals, + cudaStream_t stream) { + FAISS_ASSERT(vecs.getSize(1) == centroids.getSize(1)); + FAISS_ASSERT(vecs.getSize(1) == residuals.getSize(1)); + FAISS_ASSERT(vecs.getSize(0) == vecToCentroid.getSize(0)); + FAISS_ASSERT(vecs.getSize(0) == residuals.getSize(0)); + + dim3 grid(vecs.getSize(0)); + + int maxThreads = getMaxThreadsCurrentDevice(); + bool largeDim = vecs.getSize(1) > maxThreads; + dim3 block(std::min(vecs.getSize(1), maxThreads)); + + if (largeDim) { + calcResidual<<>>( + vecs, centroids, vecToCentroid, residuals); + } else { + calcResidual<<>>( + vecs, centroids, vecToCentroid, residuals); + } + + CUDA_TEST_ERROR(); +} + +template +void gatherReconstruct(Tensor& listIds, + Tensor& vecs, + Tensor& out, + cudaStream_t stream) { + FAISS_ASSERT(listIds.getSize(0) == out.getSize(0)); + FAISS_ASSERT(vecs.getSize(1) == out.getSize(1)); + + dim3 grid(listIds.getSize(0)); + + int maxThreads = getMaxThreadsCurrentDevice(); + dim3 block(std::min(vecs.getSize(1), maxThreads)); + + gatherReconstruct<<>>(listIds, vecs, out); + + CUDA_TEST_ERROR(); +} + +void runCalcResidual(Tensor& vecs, + Tensor& centroids, + Tensor& vecToCentroid, + Tensor& residuals, + cudaStream_t stream) { + calcResidual(vecs, centroids, vecToCentroid, residuals, stream); +} + +#ifdef FAISS_USE_FLOAT16 +void runCalcResidual(Tensor& vecs, + Tensor& centroids, + Tensor& vecToCentroid, + Tensor& residuals, + cudaStream_t stream) { + calcResidual(vecs, centroids, vecToCentroid, residuals, stream); +} +#endif + +void runReconstruct(Tensor& listIds, + Tensor& vecs, + Tensor& out, + cudaStream_t stream) { + gatherReconstruct(listIds, vecs, out, stream); +} + +#ifdef FAISS_USE_FLOAT16 +void runReconstruct(Tensor& listIds, + Tensor& vecs, + Tensor& out, + cudaStream_t stream) { + gatherReconstruct(listIds, vecs, out, stream); +} +#endif + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/impl/VectorResidual.cuh b/core/src/index/thirdparty/faiss/gpu/impl/VectorResidual.cuh new file mode 100644 index 0000000000..8e8cd2e756 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/impl/VectorResidual.cuh @@ -0,0 +1,41 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#pragma once + +#include + +namespace faiss { namespace gpu { + +// Calculates residual v_i - c_j for all v_i in vecs where j = vecToCentroid[i] +void runCalcResidual(Tensor& vecs, + Tensor& centroids, + Tensor& vecToCentroid, + Tensor& residuals, + cudaStream_t stream); + +void runCalcResidual(Tensor& vecs, + Tensor& centroids, + Tensor& vecToCentroid, + Tensor& residuals, + cudaStream_t stream); + +// Gather vectors +void runReconstruct(Tensor& listIds, + Tensor& vecs, + Tensor& out, + cudaStream_t stream); + +#ifdef FAISS_USE_FLOAT16 +void runReconstruct(Tensor& listIds, + Tensor& vecs, + Tensor& out, + cudaStream_t stream); +# endif + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/perf/IndexWrapper-inl.h b/core/src/index/thirdparty/faiss/gpu/perf/IndexWrapper-inl.h new file mode 100644 index 0000000000..90eb629509 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/perf/IndexWrapper-inl.h @@ -0,0 +1,71 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#include + +namespace faiss { namespace gpu { + +template +IndexWrapper::IndexWrapper( + int numGpus, + std::function(GpuResources*, int)> init) { + FAISS_ASSERT(numGpus <= faiss::gpu::getNumDevices()); + for (int i = 0; i < numGpus; ++i) { + auto res = std::unique_ptr( + new StandardGpuResources); + + subIndex.emplace_back(init(res.get(), i)); + resources.emplace_back(std::move(res)); + } + + if (numGpus > 1) { + // create proxy + replicaIndex = + std::unique_ptr(new faiss::IndexReplicas); + + for (auto& index : subIndex) { + replicaIndex->addIndex(index.get()); + } + } +} + +template +faiss::Index* +IndexWrapper::getIndex() { + if ((bool) replicaIndex) { + return replicaIndex.get(); + } else { + FAISS_ASSERT(!subIndex.empty()); + return subIndex.front().get(); + } +} + +template +void +IndexWrapper::runOnIndices(std::function f) { + + if ((bool) replicaIndex) { + replicaIndex->runOnIndex( + [f](int, faiss::Index* index) { + f(dynamic_cast(index)); + }); + } else { + FAISS_ASSERT(!subIndex.empty()); + f(subIndex.front().get()); + } +} + +template +void +IndexWrapper::setNumProbes(int nprobe) { + runOnIndices([nprobe](GpuIndex* index) { + index->setNumProbes(nprobe); + }); +} + +} } diff --git a/core/src/index/thirdparty/faiss/gpu/perf/IndexWrapper.h b/core/src/index/thirdparty/faiss/gpu/perf/IndexWrapper.h new file mode 100644 index 0000000000..df36255a26 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/perf/IndexWrapper.h @@ -0,0 +1,39 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#pragma once + +#include +#include +#include +#include +#include + +namespace faiss { namespace gpu { + +// If we want to run multi-GPU, create a proxy to wrap the indices. +// If we don't want multi-GPU, don't involve the proxy, so it doesn't +// affect the timings. +template +struct IndexWrapper { + std::vector> resources; + std::vector> subIndex; + std::unique_ptr replicaIndex; + + IndexWrapper( + int numGpus, + std::function(GpuResources*, int)> init); + faiss::Index* getIndex(); + + void runOnIndices(std::function f); + void setNumProbes(int nprobe); +}; + +} } + +#include diff --git a/core/src/index/thirdparty/faiss/gpu/perf/PerfBinaryFlat.cu b/core/src/index/thirdparty/faiss/gpu/perf/PerfBinaryFlat.cu new file mode 100644 index 0000000000..3e921c50da --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/perf/PerfBinaryFlat.cu @@ -0,0 +1,125 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +DEFINE_int32(k, 3, "final number of closest results returned"); +DEFINE_int32(num, 128, "# of vecs"); +DEFINE_int32(dim, 128, "# of dimensions"); +DEFINE_int32(num_queries, 3, "number of query vectors"); +DEFINE_int64(seed, -1, "specify random seed"); +DEFINE_int64(pinned_mem, 0, "pinned memory allocation to use"); +DEFINE_bool(cpu, true, "run the CPU code for timing and comparison"); +DEFINE_bool(use_unified_mem, false, "use Pascal unified memory for the index"); + +using namespace faiss::gpu; + +int main(int argc, char** argv) { + gflags::ParseCommandLineFlags(&argc, &argv, true); + + cudaProfilerStop(); + + auto seed = FLAGS_seed != -1L ? FLAGS_seed : time(nullptr); + printf("using seed %ld\n", seed); + + auto numQueries = FLAGS_num_queries; + + auto index = std::unique_ptr( + new faiss::IndexBinaryFlat(FLAGS_dim)); + + HostTensor vecs({FLAGS_num, FLAGS_dim / 8}); + faiss::byte_rand(vecs.data(), vecs.numElements(), seed); + + index->add(FLAGS_num, vecs.data()); + + printf("Database: dim %d num vecs %d\n", FLAGS_dim, FLAGS_num); + printf("Hamming lookup: %d queries, total k %d\n", + numQueries, FLAGS_k); + + // Convert to GPU index + printf("Copying index to GPU...\n"); + + GpuIndexBinaryFlatConfig config; + config.memorySpace = FLAGS_use_unified_mem ? + MemorySpace::Unified : MemorySpace::Device; + + faiss::gpu::StandardGpuResources res; + + faiss::gpu::GpuIndexBinaryFlat gpuIndex(&res, + index.get(), + config); + printf("copy done\n"); + + // Build query vectors + HostTensor cpuQuery({numQueries, FLAGS_dim / 8}); + faiss::byte_rand(cpuQuery.data(), cpuQuery.numElements(), seed); + + // Time faiss CPU + HostTensor + cpuDistances({numQueries, FLAGS_k}); + HostTensor + cpuIndices({numQueries, FLAGS_k}); + + if (FLAGS_cpu) { + float cpuTime = 0.0f; + + CpuTimer timer; + index->search(numQueries, + cpuQuery.data(), + FLAGS_k, + cpuDistances.data(), + cpuIndices.data()); + + cpuTime = timer.elapsedMilliseconds(); + printf("CPU time %.3f ms\n", cpuTime); + } + + HostTensor gpuDistances({numQueries, FLAGS_k}); + HostTensor gpuIndices({numQueries, FLAGS_k}); + + CUDA_VERIFY(cudaProfilerStart()); + faiss::gpu::synchronizeAllDevices(); + + float gpuTime = 0.0f; + + // Time GPU + { + CpuTimer timer; + + gpuIndex.search(cpuQuery.getSize(0), + cpuQuery.data(), + FLAGS_k, + gpuDistances.data(), + gpuIndices.data()); + + // There is a device -> host copy above, so no need to time + // additional synchronization with the GPU + gpuTime = timer.elapsedMilliseconds(); + } + + CUDA_VERIFY(cudaProfilerStop()); + printf("GPU time %.3f ms\n", gpuTime); + + CUDA_VERIFY(cudaDeviceSynchronize()); + + return 0; +} diff --git a/core/src/index/thirdparty/faiss/gpu/perf/PerfClustering.cpp b/core/src/index/thirdparty/faiss/gpu/perf/PerfClustering.cpp new file mode 100644 index 0000000000..6171e77926 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/perf/PerfClustering.cpp @@ -0,0 +1,115 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +DEFINE_int32(num, 10000, "# of vecs"); +DEFINE_int32(k, 100, "# of clusters"); +DEFINE_int32(dim, 128, "# of dimensions"); +DEFINE_int32(niter, 10, "# of iterations"); +DEFINE_bool(L2_metric, true, "If true, use L2 metric. If false, use IP metric"); +DEFINE_bool(use_float16, false, "use float16 vectors and math"); +DEFINE_bool(transposed, false, "transposed vector storage"); +DEFINE_bool(verbose, false, "turn on clustering logging"); +DEFINE_int64(seed, -1, "specify random seed"); +DEFINE_int32(num_gpus, 1, "number of gpus to use"); +DEFINE_int64(min_paging_size, -1, "minimum size to use CPU -> GPU paged copies"); +DEFINE_int64(pinned_mem, -1, "pinned memory allocation to use"); +DEFINE_int32(max_points, -1, "max points per centroid"); + +using namespace faiss::gpu; + +int main(int argc, char** argv) { + gflags::ParseCommandLineFlags(&argc, &argv, true); + + cudaProfilerStop(); + + auto seed = FLAGS_seed != -1L ? FLAGS_seed : time(nullptr); + printf("using seed %ld\n", seed); + + std::vector vecs((size_t) FLAGS_num * FLAGS_dim); + faiss::float_rand(vecs.data(), vecs.size(), seed); + + printf("K-means metric %s dim %d centroids %d num train %d niter %d\n", + FLAGS_L2_metric ? "L2" : "IP", + FLAGS_dim, FLAGS_k, FLAGS_num, FLAGS_niter); + printf("float16 math %s\n", FLAGS_use_float16 ? "enabled" : "disabled"); + printf("transposed storage %s\n", FLAGS_transposed ? "enabled" : "disabled"); + printf("verbose %s\n", FLAGS_verbose ? "enabled" : "disabled"); + + auto initFn = [](faiss::gpu::GpuResources* res, int dev) -> + std::unique_ptr { + if (FLAGS_pinned_mem >= 0) { + ((faiss::gpu::StandardGpuResources*) res)->setPinnedMemory( + FLAGS_pinned_mem); + } + + GpuIndexFlatConfig config; + config.device = dev; + config.useFloat16 = FLAGS_use_float16; + config.storeTransposed = FLAGS_transposed; + + auto p = std::unique_ptr( + FLAGS_L2_metric ? + (faiss::gpu::GpuIndexFlat*) + new faiss::gpu::GpuIndexFlatL2(res, FLAGS_dim, config) : + (faiss::gpu::GpuIndexFlat*) + new faiss::gpu::GpuIndexFlatIP(res, FLAGS_dim, config)); + + if (FLAGS_min_paging_size >= 0) { + p->setMinPagingSize(FLAGS_min_paging_size); + } + return p; + }; + + IndexWrapper gpuIndex(FLAGS_num_gpus, initFn); + + CUDA_VERIFY(cudaProfilerStart()); + faiss::gpu::synchronizeAllDevices(); + + float gpuTime = 0.0f; + + faiss::ClusteringParameters cp; + cp.niter = FLAGS_niter; + cp.verbose = FLAGS_verbose; + + if (FLAGS_max_points > 0) { + cp.max_points_per_centroid = FLAGS_max_points; + } + + faiss::Clustering kmeans(FLAGS_dim, FLAGS_k, cp); + + // Time k-means + { + CpuTimer timer; + + kmeans.train(FLAGS_num, vecs.data(), *(gpuIndex.getIndex())); + + // There is a device -> host copy above, so no need to time + // additional synchronization with the GPU + gpuTime = timer.elapsedMilliseconds(); + } + + CUDA_VERIFY(cudaProfilerStop()); + printf("k-means time %.3f ms\n", gpuTime); + + CUDA_VERIFY(cudaDeviceSynchronize()); + + return 0; +} diff --git a/core/src/index/thirdparty/faiss/gpu/perf/PerfFlat.cu b/core/src/index/thirdparty/faiss/gpu/perf/PerfFlat.cu new file mode 100644 index 0000000000..20a16382f1 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/perf/PerfFlat.cu @@ -0,0 +1,149 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +DEFINE_bool(l2, true, "L2 or inner product"); +DEFINE_int32(k, 3, "final number of closest results returned"); +DEFINE_int32(num, 128, "# of vecs"); +DEFINE_int32(dim, 128, "# of dimensions"); +DEFINE_int32(num_queries, 3, "number of query vectors"); +DEFINE_bool(diff, true, "show exact distance + index output discrepancies"); +DEFINE_bool(use_float16, false, "use encodings in float16"); +DEFINE_bool(use_float16_math, false, "perform math in float16"); +DEFINE_bool(transposed, false, "store vectors transposed"); +DEFINE_int64(seed, -1, "specify random seed"); +DEFINE_int32(num_gpus, 1, "number of gpus to use"); +DEFINE_int64(pinned_mem, 0, "pinned memory allocation to use"); +DEFINE_bool(cpu, true, "run the CPU code for timing and comparison"); +DEFINE_bool(use_unified_mem, false, "use Pascal unified memory for the index"); + +using namespace faiss::gpu; + +int main(int argc, char** argv) { + gflags::ParseCommandLineFlags(&argc, &argv, true); + + cudaProfilerStop(); + + auto seed = FLAGS_seed != -1L ? FLAGS_seed : time(nullptr); + printf("using seed %ld\n", seed); + + auto numQueries = FLAGS_num_queries; + + auto index = std::unique_ptr( + new faiss::IndexFlat(FLAGS_dim, FLAGS_l2 ? + faiss::METRIC_L2 : faiss::METRIC_INNER_PRODUCT)); + + HostTensor vecs({FLAGS_num, FLAGS_dim}); + faiss::float_rand(vecs.data(), vecs.numElements(), seed); + + index->add(FLAGS_num, vecs.data()); + + printf("Database: dim %d num vecs %d\n", FLAGS_dim, FLAGS_num); + printf("%s lookup: %d queries, total k %d\n", + FLAGS_l2 ? "L2" : "IP", + numQueries, FLAGS_k); + printf("float16 encoding %s\n", FLAGS_use_float16 ? "enabled" : "disabled"); + printf("transposed storage %s\n", FLAGS_transposed ? "enabled" : "disabled"); + + // Convert to GPU index + printf("Copying index to %d GPU(s)...\n", FLAGS_num_gpus); + + auto initFn = [&index](faiss::gpu::GpuResources* res, int dev) -> + std::unique_ptr { + ((faiss::gpu::StandardGpuResources*) res)->setPinnedMemory( + FLAGS_pinned_mem); + + GpuIndexFlatConfig config; + config.device = dev; + config.useFloat16 = FLAGS_use_float16; + config.storeTransposed = FLAGS_transposed; + config.memorySpace = FLAGS_use_unified_mem ? + MemorySpace::Unified : MemorySpace::Device; + + auto p = std::unique_ptr( + new faiss::gpu::GpuIndexFlat(res, index.get(), config)); + return p; + }; + + IndexWrapper gpuIndex(FLAGS_num_gpus, initFn); + printf("copy done\n"); + + // Build query vectors + HostTensor cpuQuery({numQueries, FLAGS_dim}); + faiss::float_rand(cpuQuery.data(), cpuQuery.numElements(), seed); + + // Time faiss CPU + HostTensor cpuDistances({numQueries, FLAGS_k}); + HostTensor cpuIndices({numQueries, FLAGS_k}); + + if (FLAGS_cpu) { + float cpuTime = 0.0f; + + CpuTimer timer; + index->search(numQueries, + cpuQuery.data(), + FLAGS_k, + cpuDistances.data(), + cpuIndices.data()); + + cpuTime = timer.elapsedMilliseconds(); + printf("CPU time %.3f ms\n", cpuTime); + } + + HostTensor gpuDistances({numQueries, FLAGS_k}); + HostTensor gpuIndices({numQueries, FLAGS_k}); + + CUDA_VERIFY(cudaProfilerStart()); + faiss::gpu::synchronizeAllDevices(); + + float gpuTime = 0.0f; + + // Time GPU + { + CpuTimer timer; + + gpuIndex.getIndex()->search(cpuQuery.getSize(0), + cpuQuery.data(), + FLAGS_k, + gpuDistances.data(), + gpuIndices.data()); + + // There is a device -> host copy above, so no need to time + // additional synchronization with the GPU + gpuTime = timer.elapsedMilliseconds(); + } + + CUDA_VERIFY(cudaProfilerStop()); + printf("GPU time %.3f ms\n", gpuTime); + + if (FLAGS_cpu) { + compareLists(cpuDistances.data(), cpuIndices.data(), + gpuDistances.data(), gpuIndices.data(), + numQueries, FLAGS_k, + "", true, FLAGS_diff, false); + } + + CUDA_VERIFY(cudaDeviceSynchronize()); + + return 0; +} diff --git a/core/src/index/thirdparty/faiss/gpu/perf/PerfIVFFlat.cu b/core/src/index/thirdparty/faiss/gpu/perf/PerfIVFFlat.cu new file mode 100644 index 0000000000..8b51b90ecf --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/perf/PerfIVFFlat.cu @@ -0,0 +1,146 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +DEFINE_int32(nprobe, 5, "number of coarse centroids to probe"); +DEFINE_int32(k, 3, "final number of closest results returned"); +DEFINE_int32(num_queries, 3, "number of query vectors"); +DEFINE_string(in, "/home/jhj/local/index.out", "index file for input"); +DEFINE_bool(diff, true, "show exact distance + index output discrepancies"); +DEFINE_bool(use_float16_coarse, false, "coarse quantizer in float16"); +DEFINE_int64(seed, -1, "specify random seed"); +DEFINE_int32(num_gpus, 1, "number of gpus to use"); +DEFINE_int32(index, 2, "0 = no indices on GPU; 1 = 32 bit, 2 = 64 bit on GPU"); + +using namespace faiss::gpu; + +int main(int argc, char** argv) { + gflags::ParseCommandLineFlags(&argc, &argv, true); + + cudaProfilerStop(); + + auto seed = FLAGS_seed != -1L ? FLAGS_seed : time(nullptr); + printf("using seed %ld\n", seed); + + auto numQueries = FLAGS_num_queries; + + auto index = std::unique_ptr( + dynamic_cast(faiss::read_index(FLAGS_in.c_str()))); + FAISS_ASSERT((bool) index); + index->nprobe = FLAGS_nprobe; + + auto dim = index->d; + + printf("Database: dim %d num vecs %ld\n", dim, index->ntotal); + printf("Coarse centroids: %ld\n", index->quantizer->ntotal); + printf("L2 lookup: %d queries, nprobe %d, total k %d\n", + numQueries, FLAGS_nprobe, FLAGS_k); + printf("float16 coarse quantizer %s\n", + FLAGS_use_float16_coarse ? "enabled" : "disabled"); + + // Convert to GPU index + printf("Copying index to %d GPU(s)...\n", FLAGS_num_gpus); + + auto initFn = [&index](faiss::gpu::GpuResources* res, int dev) -> + std::unique_ptr { + GpuIndexIVFFlatConfig config; + config.device = dev; + config.indicesOptions = (faiss::gpu::IndicesOptions) FLAGS_index; + config.flatConfig.useFloat16 = FLAGS_use_float16_coarse; + + auto p = std::unique_ptr( + new faiss::gpu::GpuIndexIVFFlat(res, + index->d, + index->nlist, + index->metric_type, + config)); + p->copyFrom(index.get()); + return p; + }; + + IndexWrapper gpuIndex(FLAGS_num_gpus, initFn); + gpuIndex.setNumProbes(FLAGS_nprobe); + printf("copy done\n"); + + // Build query vectors + HostTensor cpuQuery({numQueries, dim}); + faiss::float_rand(cpuQuery.data(), cpuQuery.numElements(), seed); + + // Time faiss CPU + HostTensor cpuDistances({numQueries, FLAGS_k}); + HostTensor cpuIndices({numQueries, FLAGS_k}); + + float cpuTime = 0.0f; + + { + CpuTimer timer; + index->search(numQueries, + cpuQuery.data(), + FLAGS_k, + cpuDistances.data(), + cpuIndices.data()); + + cpuTime = timer.elapsedMilliseconds(); + } + + printf("CPU time %.3f ms\n", cpuTime); + + HostTensor gpuDistances({numQueries, FLAGS_k}); + HostTensor gpuIndices({numQueries, FLAGS_k}); + + CUDA_VERIFY(cudaProfilerStart()); + faiss::gpu::synchronizeAllDevices(); + + float gpuTime = 0.0f; + + // Time GPU + { + CpuTimer timer; + + gpuIndex.getIndex()->search(cpuQuery.getSize(0), + cpuQuery.data(), + FLAGS_k, + gpuDistances.data(), + gpuIndices.data()); + + // There is a device -> host copy above, so no need to time + // additional synchronization with the GPU + gpuTime = timer.elapsedMilliseconds(); + } + + CUDA_VERIFY(cudaProfilerStop()); + printf("GPU time %.3f ms\n", gpuTime); + + compareLists(cpuDistances.data(), cpuIndices.data(), + gpuDistances.data(), gpuIndices.data(), + numQueries, FLAGS_k, + "", true, FLAGS_diff, false); + + CUDA_VERIFY(cudaDeviceSynchronize()); + // printf("\ncudaMalloc usage %zd\n", + // resources.getMemoryManager().getHighWaterCudaMalloc()); + + return 0; +} diff --git a/core/src/index/thirdparty/faiss/gpu/perf/PerfIVFPQ.cu b/core/src/index/thirdparty/faiss/gpu/perf/PerfIVFPQ.cu new file mode 100644 index 0000000000..82eb648a1f --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/perf/PerfIVFPQ.cu @@ -0,0 +1,157 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +DEFINE_int32(nprobe, 5, "number of coarse centroids to probe"); +DEFINE_int32(k, 3, "final number of closest results returned"); +DEFINE_int32(num_queries, 3, "number of query vectors"); +DEFINE_string(in, "/home/jhj/local/index.out", "index file for input"); +DEFINE_bool(diff, true, "show exact distance + index output discrepancies"); +DEFINE_bool(use_precomputed, true, "enable or disable precomputed codes"); +DEFINE_bool(float16_lookup, false, "use float16 residual distance tables"); +DEFINE_int64(seed, -1, "specify random seed"); +DEFINE_int32(num_gpus, 1, "number of gpus to use"); +DEFINE_int32(index, 2, "0 = no indices on GPU; 1 = 32 bit, 2 = 64 bit on GPU"); + +using namespace faiss::gpu; + +int main(int argc, char** argv) { + gflags::ParseCommandLineFlags(&argc, &argv, true); + + CUDA_VERIFY(cudaProfilerStop()); + + auto seed = FLAGS_seed != -1L ? FLAGS_seed : time(nullptr); + printf("using seed %ld\n", seed); + + auto numQueries = FLAGS_num_queries; + + auto index = std::unique_ptr( + dynamic_cast(faiss::read_index(FLAGS_in.c_str()))); + FAISS_ASSERT((bool) index); + index->nprobe = FLAGS_nprobe; + + if (!FLAGS_use_precomputed) { + index->use_precomputed_table = 0; + } + + auto dim = index->d; + auto codes = index->pq.M; + auto bitsPerCode = index->pq.nbits; + + printf("Database: dim %d num vecs %ld\n", dim, index->ntotal); + printf("Coarse centroids: %ld\n", index->quantizer->ntotal); + printf("PQ centroids: codes %ld bits per code %ld\n", codes, bitsPerCode); + printf("L2 lookup: %d queries, nprobe %d, total k %d, " + "precomputed codes %d\n\n", + numQueries, FLAGS_nprobe, FLAGS_k, + FLAGS_use_precomputed); + + // Convert to GPU index + printf("Copying index to %d GPU(s)...\n", FLAGS_num_gpus); + + auto precomp = FLAGS_use_precomputed; + auto indicesOpt = (faiss::gpu::IndicesOptions) FLAGS_index; + auto useFloat16Lookup = FLAGS_float16_lookup; + + auto initFn = [precomp, indicesOpt, useFloat16Lookup, &index] + (faiss::gpu::GpuResources* res, int dev) -> + std::unique_ptr { + + faiss::gpu::GpuIndexIVFPQConfig config; + config.device = dev; + config.usePrecomputedTables = precomp; + config.indicesOptions = indicesOpt; + config.useFloat16LookupTables = useFloat16Lookup; + + auto p = std::unique_ptr( + new faiss::gpu::GpuIndexIVFPQ(res, index.get(), config)); + + return p; + }; + + IndexWrapper gpuIndex(FLAGS_num_gpus, initFn); + gpuIndex.setNumProbes(FLAGS_nprobe); + printf("copy done\n"); + + // Build query vectors + HostTensor cpuQuery({numQueries, dim}); + faiss::float_rand(cpuQuery.data(), cpuQuery.numElements(), seed); + + // Time faiss CPU + HostTensor cpuDistances({numQueries, FLAGS_k}); + HostTensor cpuIndices({numQueries, FLAGS_k}); + + float cpuTime = 0.0f; + + { + CpuTimer timer; + index->search(numQueries, + cpuQuery.data(), + FLAGS_k, + cpuDistances.data(), + cpuIndices.data()); + + cpuTime = timer.elapsedMilliseconds(); + } + + printf("CPU time %.3f ms\n", cpuTime); + + HostTensor gpuDistances({numQueries, FLAGS_k}); + HostTensor gpuIndices({numQueries, FLAGS_k}); + + CUDA_VERIFY(cudaProfilerStart()); + faiss::gpu::synchronizeAllDevices(); + + float gpuTime = 0.0f; + + // Time GPU + { + CpuTimer timer; + + gpuIndex.getIndex()->search(cpuQuery.getSize(0), + cpuQuery.data(), + FLAGS_k, + gpuDistances.data(), + gpuIndices.data()); + + // There is a device -> host copy above, so no need to time + // additional synchronization with the GPU + gpuTime = timer.elapsedMilliseconds(); + } + + CUDA_VERIFY(cudaProfilerStop()); + printf("GPU time %.3f ms\n", gpuTime); + + compareLists(cpuDistances.data(), cpuIndices.data(), + gpuDistances.data(), gpuIndices.data(), + numQueries, FLAGS_k, + "", true, FLAGS_diff, false); + + CUDA_VERIFY(cudaDeviceSynchronize()); + // printf("\ncudaMalloc usage %zd\n", + // resources.getMemoryManager().getHighWaterCudaMalloc()); + + return 0; +} diff --git a/core/src/index/thirdparty/faiss/gpu/perf/PerfIVFPQAdd.cpp b/core/src/index/thirdparty/faiss/gpu/perf/PerfIVFPQAdd.cpp new file mode 100644 index 0000000000..1e45d635a5 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/perf/PerfIVFPQAdd.cpp @@ -0,0 +1,139 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +DEFINE_int32(batches, 10, "number of batches of vectors to add"); +DEFINE_int32(batch_size, 10000, "number of vectors in each batch"); +DEFINE_int32(dim, 256, "dimension of vectors"); +DEFINE_int32(centroids, 4096, "num coarse centroids to use"); +DEFINE_int32(bytes_per_vec, 32, "bytes per encoded vector"); +DEFINE_int32(bits_per_code, 8, "bits per PQ code"); +DEFINE_int32(index, 2, "0 = no indices on GPU; 1 = 32 bit, 2 = 64 bit on GPU"); +DEFINE_bool(time_gpu, true, "time add to GPU"); +DEFINE_bool(time_cpu, false, "time add to CPU"); +DEFINE_bool(per_batch_time, false, "print per-batch times"); +DEFINE_bool(reserve_memory, false, "whether or not to pre-reserve memory"); + +int main(int argc, char** argv) { + gflags::ParseCommandLineFlags(&argc, &argv, true); + + cudaProfilerStop(); + + int dim = FLAGS_dim; + int numCentroids = FLAGS_centroids; + int bytesPerVec = FLAGS_bytes_per_vec; + int bitsPerCode = FLAGS_bits_per_code; + + faiss::gpu::StandardGpuResources res; + + // IndexIVFPQ will complain, but just give us enough to get through this + int numTrain = 4 * numCentroids; + std::vector trainVecs = faiss::gpu::randVecs(numTrain, dim); + + faiss::IndexFlatL2 coarseQuantizer(dim); + faiss::IndexIVFPQ cpuIndex(&coarseQuantizer, dim, numCentroids, + bytesPerVec, bitsPerCode); + if (FLAGS_time_cpu) { + cpuIndex.train(numTrain, trainVecs.data()); + } + + faiss::gpu::GpuIndexIVFPQConfig config; + config.device = 0; + config.indicesOptions = (faiss::gpu::IndicesOptions) FLAGS_index; + + faiss::gpu::GpuIndexIVFPQ gpuIndex( + &res, dim, numCentroids, bytesPerVec, bitsPerCode, + faiss::METRIC_L2, config); + + if (FLAGS_time_gpu) { + gpuIndex.train(numTrain, trainVecs.data()); + if (FLAGS_reserve_memory) { + size_t numVecs = (size_t) FLAGS_batches * (size_t) FLAGS_batch_size; + gpuIndex.reserveMemory(numVecs); + } + } + + cudaDeviceSynchronize(); + CUDA_VERIFY(cudaProfilerStart()); + + float totalGpuTime = 0.0f; + float totalCpuTime = 0.0f; + + for (int i = 0; i < FLAGS_batches; ++i) { + if (!FLAGS_per_batch_time) { + if (i % 10 == 0) { + printf("Adding batch %d\n", i + 1); + } + } + + auto addVecs = faiss::gpu::randVecs(FLAGS_batch_size, dim); + + if (FLAGS_time_gpu) { + faiss::gpu::CpuTimer timer; + gpuIndex.add(FLAGS_batch_size, addVecs.data()); + CUDA_VERIFY(cudaDeviceSynchronize()); + auto time = timer.elapsedMilliseconds(); + + totalGpuTime += time; + + if (FLAGS_per_batch_time) { + printf("Batch %d | GPU time to add %d vecs: %.3f ms (%.5f ms per)\n", + i + 1, FLAGS_batch_size, time, time / (float) FLAGS_batch_size); + } + } + + if (FLAGS_time_cpu) { + faiss::gpu::CpuTimer timer; + cpuIndex.add(FLAGS_batch_size, addVecs.data()); + auto time = timer.elapsedMilliseconds(); + + totalCpuTime += time; + + if (FLAGS_per_batch_time) { + printf("Batch %d | CPU time to add %d vecs: %.3f ms (%.5f ms per)\n", + i + 1, FLAGS_batch_size, time, time / (float) FLAGS_batch_size); + } + } + } + + CUDA_VERIFY(cudaProfilerStop()); + + int total = FLAGS_batch_size * FLAGS_batches; + + if (FLAGS_time_gpu) { + printf("%d dim, %d centroids, %d x %d encoding\n" + "GPU time to add %d vectors (%d batches, %d per batch): " + "%.3f ms (%.3f us per)\n", + dim, numCentroids, bytesPerVec, bitsPerCode, + total, FLAGS_batches, FLAGS_batch_size, + totalGpuTime, totalGpuTime * 1000.0f / (float) total); + } + + if (FLAGS_time_cpu) { + printf("%d dim, %d centroids, %d x %d encoding\n" + "CPU time to add %d vectors (%d batches, %d per batch): " + "%.3f ms (%.3f us per)\n", + dim, numCentroids, bytesPerVec, bitsPerCode, + total, FLAGS_batches, FLAGS_batch_size, + totalCpuTime, totalCpuTime * 1000.0f / (float) total); + } + + return 0; +} diff --git a/core/src/index/thirdparty/faiss/gpu/perf/PerfSelect.cu b/core/src/index/thirdparty/faiss/gpu/perf/PerfSelect.cu new file mode 100644 index 0000000000..5e2eb49f13 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/perf/PerfSelect.cu @@ -0,0 +1,71 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +DEFINE_int32(rows, 10000, "rows in matrix"); +DEFINE_int32(cols, 40000, "cols in matrix"); +DEFINE_int32(k, 100, "k"); +DEFINE_bool(dir, false, "direction of sort"); +DEFINE_bool(warp, false, "warp select"); +DEFINE_int32(iter, 5, "iterations to run"); +DEFINE_bool(k_powers, false, "test k powers of 2 from 1 -> max k"); + +int main(int argc, char** argv) { + gflags::ParseCommandLineFlags(&argc, &argv, true); + + std::vector v = faiss::gpu::randVecs(FLAGS_rows, FLAGS_cols); + faiss::gpu::HostTensor hostVal({FLAGS_rows, FLAGS_cols}); + + for (int r = 0; r < FLAGS_rows; ++r) { + for (int c = 0; c < FLAGS_cols; ++c) { + hostVal[r][c] = v[r * FLAGS_cols + c]; + } + } + + // Select top-k on GPU + faiss::gpu::DeviceTensor gpuVal(hostVal, 0); + + int startK = FLAGS_k; + int limitK = FLAGS_k; + + if (FLAGS_k_powers) { + startK = 1; + limitK = GPU_MAX_SELECTION_K; + } + + faiss::gpu::DeviceTensor bitset(nullptr, {0}); + for (int k = startK; k <= limitK; k *= 2) { + faiss::gpu::DeviceTensor gpuOutVal({FLAGS_rows, k}); + faiss::gpu::DeviceTensor gpuOutInd({FLAGS_rows, k}); + + for (int i = 0; i < FLAGS_iter; ++i) { + if (FLAGS_warp) { + faiss::gpu::runWarpSelect(gpuVal, gpuOutVal, gpuOutInd, + FLAGS_dir, k, 0); + } else { + faiss::gpu::runBlockSelect(gpuVal, bitset, gpuOutVal, gpuOutInd, + FLAGS_dir, k, 0); + } + } + } + + cudaDeviceSynchronize(); +} diff --git a/core/src/index/thirdparty/faiss/gpu/perf/WriteIndex.cpp b/core/src/index/thirdparty/faiss/gpu/perf/WriteIndex.cpp new file mode 100644 index 0000000000..af363787a9 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/perf/WriteIndex.cpp @@ -0,0 +1,102 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#include +#include +#include +#include +#include +#include +#include + +// For IVFPQ: +DEFINE_bool(ivfpq, false, "use IVFPQ encoding"); +DEFINE_int32(codes, 4, "number of PQ codes per vector"); +DEFINE_int32(bits_per_code, 8, "number of bits per PQ code"); + +// For IVFFlat: +DEFINE_bool(l2, true, "use L2 metric (versus IP metric)"); +DEFINE_bool(ivfflat, false, "use IVF flat encoding"); + +// For both: +DEFINE_string(out, "/home/jhj/local/index.out", "index file for output"); +DEFINE_int32(dim, 128, "vector dimension"); +DEFINE_int32(num_coarse, 100, "number of coarse centroids"); +DEFINE_int32(num, 100000, "total database size"); +DEFINE_int32(num_train, -1, "number of database vecs to train on"); + +template +void fillAndSave(T& index, int numTrain, int num, int dim) { + auto trainVecs = faiss::gpu::randVecs(numTrain, dim); + index.train(numTrain, trainVecs.data()); + + constexpr int kAddChunk = 1000000; + + for (int i = 0; i < num; i += kAddChunk) { + int numRemaining = (num - i) < kAddChunk ? (num - i) : kAddChunk; + auto vecs = faiss::gpu::randVecs(numRemaining, dim); + + printf("adding at %d: %d\n", i, numRemaining); + index.add(numRemaining, vecs.data()); + } + + faiss::write_index(&index, FLAGS_out.c_str()); +} + +int main(int argc, char** argv) { + gflags::ParseCommandLineFlags(&argc, &argv, true); + + // Either ivfpq or ivfflat must be set + if ((FLAGS_ivfpq && FLAGS_ivfflat) || + (!FLAGS_ivfpq && !FLAGS_ivfflat)) { + printf("must specify either ivfpq or ivfflat\n"); + return 1; + } + + auto dim = FLAGS_dim; + auto numCentroids = FLAGS_num_coarse; + auto num = FLAGS_num; + auto numTrain = FLAGS_num_train; + numTrain = numTrain == -1 ? std::max((num / 4), 1) : numTrain; + numTrain = std::min(num, numTrain); + + if (FLAGS_ivfpq) { + faiss::IndexFlatL2 quantizer(dim); + faiss::IndexIVFPQ index(&quantizer, dim, numCentroids, + FLAGS_codes, FLAGS_bits_per_code); + index.verbose = true; + + printf("IVFPQ: codes %d bits per code %d\n", + FLAGS_codes, FLAGS_bits_per_code); + printf("Lists: %d\n", numCentroids); + printf("Database: dim %d num vecs %d trained on %d\n", dim, num, numTrain); + printf("output file: %s\n", FLAGS_out.c_str()); + + fillAndSave(index, numTrain, num, dim); + } else if (FLAGS_ivfflat) { + faiss::IndexFlatL2 quantizerL2(dim); + faiss::IndexFlatIP quantizerIP(dim); + + faiss::IndexFlat* quantizer = FLAGS_l2 ? + (faiss::IndexFlat*) &quantizerL2 : + (faiss::IndexFlat*) &quantizerIP; + + faiss::IndexIVFFlat index(quantizer, dim, numCentroids, + FLAGS_l2 ? faiss::METRIC_L2 : + faiss::METRIC_INNER_PRODUCT); + + printf("IVFFlat: metric %s\n", FLAGS_l2 ? "L2" : "IP"); + printf("Lists: %d\n", numCentroids); + printf("Database: dim %d num vecs %d trained on %d\n", dim, num, numTrain); + printf("output file: %s\n", FLAGS_out.c_str()); + + fillAndSave(index, numTrain, num, dim); + } + + return 0; +} diff --git a/core/src/index/thirdparty/faiss/gpu/perf/slow.py b/core/src/index/thirdparty/faiss/gpu/perf/slow.py new file mode 100644 index 0000000000..a096311c4e --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/perf/slow.py @@ -0,0 +1,24 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +#! /usr/bin/env python3 +# this is a slow computation to test whether ctrl-C handling works +import faiss +import numpy as np + +def test_slow(): + d = 256 + index = faiss.index_cpu_to_gpu(faiss.StandardGpuResources(), + 0, faiss.IndexFlatL2(d)) + x = np.random.rand(10 ** 6, d).astype('float32') + print('add') + index.add(x) + print('search') + index.search(x, 10) + print('done') + + +if __name__ == '__main__': + test_slow() diff --git a/core/src/index/thirdparty/faiss/gpu/test/Makefile b/core/src/index/thirdparty/faiss/gpu/test/Makefile new file mode 100644 index 0000000000..6836314810 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/test/Makefile @@ -0,0 +1,51 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +-include ../../makefile.inc + +TESTS_SRC = TestGpuIndexFlat.cpp TestGpuIndexIVFPQ.cpp \ +TestGpuIndexBinaryFlat.cpp TestGpuIndexIVFFlat.cpp TestGpuMemoryException.cpp +CUDA_TESTS_SRC = TestGpuSelect.cu + +TESTS_OBJ = $(TESTS_SRC:.cpp=.o) +CUDA_TESTS_OBJ = $(CUDA_TESTS_SRC:.cu=.o) + +TESTS_BIN = $(TESTS_OBJ:.o=) $(CUDA_TESTS_OBJ:.o=) + + +# test_gpu_index.py test_pytorch_faiss.py + +run: $(TESTS_BIN) $(CUDA_TESTS_BIN) + for t in $(TESTS_BIN) $(CUDA_TESTS_BIN); do ./$$t || exit; done + +$(CUDA_TESTS_OBJ): %.o: %.cu gtest + $(NVCC) $(NVCCFLAGS) -g -O3 -o $@ -c $< -Igtest/include + +$(TESTS_OBJ): %.o: %.cpp gtest + $(CXX) $(CPPFLAGS) $(CXXFLAGS) $(CPUFLAGS) -o $@ -c $< -Igtest/include + +$(TESTS_BIN): %: %.o TestUtils.o ../../libfaiss.a gtest/make/gtest.a + $(CXX) -o $@ $^ $(LDFLAGS) $(LIBS) + +demo_ivfpq_indexing_gpu: demo_ivfpq_indexing_gpu.o ../../libfaiss.a + $(CXX) -o $@ $^ $(LDFLAGS) $(LIBS) + +demo_ivfpq_indexing_gpu.o: demo_ivfpq_indexing_gpu.cpp + $(CXX) $(CPPFLAGS) $(CXXFLAGS) $(CPUFLAGS) -o $@ -c $^ + +gtest/make/gtest.a: gtest + $(MAKE) -C gtest/make CXX="$(CXX)" CXXFLAGS="$(CXXFLAGS)" gtest.a + +gtest: + curl -L https://github.com/google/googletest/archive/release-1.8.0.tar.gz | tar xz && \ + mv googletest-release-1.8.0/googletest gtest && \ + rm -rf googletest-release-1.8.0 + +clean: + rm -f *.o $(TESTS_BIN) + rm -rf gtest + rm -f demo_ivfpq_indexing_gpu + +.PHONY: clean run diff --git a/core/src/index/thirdparty/faiss/gpu/test/TestGpuDistance.cu b/core/src/index/thirdparty/faiss/gpu/test/TestGpuDistance.cu new file mode 100644 index 0000000000..f188a1b7d3 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/test/TestGpuDistance.cu @@ -0,0 +1,180 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +void testTransposition(bool colMajorVecs, + bool colMajorQueries, + faiss::MetricType metric, + float metricArg = 0) { + int device = faiss::gpu::randVal(0, faiss::gpu::getNumDevices() - 1); + + faiss::gpu::StandardGpuResources res; + res.noTempMemory(); + + int dim = faiss::gpu::randVal(20, 150); + int numVecs = faiss::gpu::randVal(10, 30000); + int numQuery = faiss::gpu::randVal(1, 1024); + int k = std::min(numVecs, faiss::gpu::randVal(20, 70)); + + // Input data for CPU + std::vector vecs = faiss::gpu::randVecs(numVecs, dim); + std::vector queries = faiss::gpu::randVecs(numQuery, dim); + + if (metric == faiss::MetricType::METRIC_JensenShannon) { + // make values positive + for (auto& v : vecs) { + v = std::abs(v); + if (v == 0) { + v = 1e-6; + } + } + + for (auto& q : queries) { + q = std::abs(q); + if (q == 0) { + q = 1e-6; + } + } + } + + // The CPU index is our reference for the results + faiss::IndexFlat cpuIndex(dim, metric); + cpuIndex.metric_arg = metricArg; + cpuIndex.add(numVecs, vecs.data()); + + std::vector cpuDistance(numQuery * k, 0); + std::vector cpuIndices(numQuery * k, -1); + + cpuIndex.search(numQuery, queries.data(), k, + cpuDistance.data(), cpuIndices.data()); + + // The transpose and distance code assumes the desired device is already set + faiss::gpu::DeviceScope scope(device); + auto stream = res.getDefaultStream(device); + + // Copy input data to GPU, and pre-transpose both vectors and queries for + // passing + auto gpuVecs = faiss::gpu::toDevice( + nullptr, device, vecs.data(), stream, {numVecs, dim}); + auto gpuQueries = faiss::gpu::toDevice( + nullptr, device, queries.data(), stream, {numQuery, dim}); + + faiss::gpu::DeviceTensor vecsT({dim, numVecs}); + faiss::gpu::runTransposeAny(gpuVecs, 0, 1, vecsT, stream); + + faiss::gpu::DeviceTensor queriesT({dim, numQuery}); + faiss::gpu::runTransposeAny(gpuQueries, 0, 1, queriesT, stream); + + std::vector gpuDistance(numQuery * k, 0); + std::vector gpuIndices(numQuery * k, -1); + + faiss::gpu::GpuDistanceParams args; + args.metric = metric; + args.metricArg = metricArg; + args.k = k; + args.dims = dim; + args.vectors = colMajorVecs ? vecsT.data() : gpuVecs.data(); + args.vectorsRowMajor = !colMajorVecs; + args.numVectors = numVecs; + args.queries = colMajorQueries ? queriesT.data() : gpuQueries.data(); + args.queriesRowMajor = !colMajorQueries; + args.numQueries = numQuery; + args.outDistances = gpuDistance.data(); + args.outIndices = gpuIndices.data(); + + faiss::gpu::bfKnn(&res, args); + + std::stringstream str; + str << "metric " << metric + << " colMajorVecs " << colMajorVecs + << " colMajorQueries " << colMajorQueries; + + faiss::gpu::compareLists(cpuDistance.data(), + cpuIndices.data(), + gpuDistance.data(), + gpuIndices.data(), + numQuery, k, + str.str(), + false, false, true, + 6e-3f, 0.1f, 0.015f); +} + +// Test different memory layouts for brute-force k-NN +TEST(TestGpuDistance, Transposition_RR) { + testTransposition(false, false, faiss::MetricType::METRIC_L2); + testTransposition(false, false, faiss::MetricType::METRIC_INNER_PRODUCT); +} + +TEST(TestGpuDistance, Transposition_RC) { + testTransposition(false, true, faiss::MetricType::METRIC_L2); +} + +TEST(TestGpuDistance, Transposition_CR) { + testTransposition(true, false, faiss::MetricType::METRIC_L2); +} + +TEST(TestGpuDistance, Transposition_CC) { + testTransposition(true, true, faiss::MetricType::METRIC_L2); +} + +TEST(TestGpuDistance, L1) { + testTransposition(false, false, faiss::MetricType::METRIC_L1); +} + +// Test other transpositions with the general distance kernel +TEST(TestGpuDistance, L1_RC) { + testTransposition(false, true, faiss::MetricType::METRIC_L1); +} + +TEST(TestGpuDistance, L1_CR) { + testTransposition(true, false, faiss::MetricType::METRIC_L1); +} + +TEST(TestGpuDistance, L1_CC) { + testTransposition(true, true, faiss::MetricType::METRIC_L1); +} + +// Test remainder of metric types +TEST(TestGpuDistance, Linf) { + testTransposition(false, false, faiss::MetricType::METRIC_Linf); +} + +TEST(TestGpuDistance, Lp) { + testTransposition(false, false, faiss::MetricType::METRIC_Lp, 3); +} + +TEST(TestGpuDistance, Canberra) { + testTransposition(false, false, faiss::MetricType::METRIC_Canberra); +} + +TEST(TestGpuDistance, BrayCurtis) { + testTransposition(false, false, faiss::MetricType::METRIC_BrayCurtis); +} + +TEST(TestGpuDistance, JensenShannon) { + testTransposition(false, false, faiss::MetricType::METRIC_JensenShannon); +} + +int main(int argc, char** argv) { + testing::InitGoogleTest(&argc, argv); + + // just run with a fixed test seed + faiss::gpu::setTestSeed(100); + + return RUN_ALL_TESTS(); +} diff --git a/core/src/index/thirdparty/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp b/core/src/index/thirdparty/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp new file mode 100644 index 0000000000..14c28c155a --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp @@ -0,0 +1,130 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +void compareBinaryDist(const std::vector& cpuDist, + const std::vector& cpuLabels, + const std::vector& gpuDist, + const std::vector& gpuLabels, + int numQuery, + int k) { + for (int i = 0; i < numQuery; ++i) { + // The index order can be permuted within a group that has the same + // distance, since this is based on the order in which the algorithm + // encounters the values. The last set of equivalent distances seen in the + // min-k might be truncated, so we can't check that set, but all others we + // can check. + std::set cpuLabelSet; + std::set gpuLabelSet; + + int curDist = -1; + + for (int j = 0; j < k; ++j) { + int idx = i * k + j; + + if (curDist == -1) { + curDist = cpuDist[idx]; + } + + if (curDist != cpuDist[idx]) { + // Distances must be monotonically increasing + EXPECT_LT(curDist, cpuDist[idx]); + + // This is a new set of distances + EXPECT_EQ(cpuLabelSet, gpuLabelSet); + curDist = cpuDist[idx]; + cpuLabelSet.clear(); + gpuLabelSet.clear(); + } + + cpuLabelSet.insert(cpuLabels[idx]); + gpuLabelSet.insert(gpuLabels[idx]); + + // Because the distances are reproducible, they must be exactly the same + EXPECT_EQ(cpuDist[idx], gpuDist[idx]); + } + } +} + +template +void testGpuIndexBinaryFlat(int kOverride = -1) { + faiss::gpu::StandardGpuResources res; + res.noTempMemory(); + + faiss::gpu::GpuIndexBinaryFlatConfig config; + config.device = faiss::gpu::randVal(0, faiss::gpu::getNumDevices() - 1); + + // multiples of 8 and multiples of 32 use different implementations + int dims = faiss::gpu::randVal(1, 20) * DimMultiple; + faiss::gpu::GpuIndexBinaryFlat gpuIndex(&res, dims, config); + + faiss::IndexBinaryFlat cpuIndex(dims); + + int k = kOverride > 0 ? + kOverride : faiss::gpu::randVal(1, faiss::gpu::getMaxKSelection()); + int numVecs = faiss::gpu::randVal(k + 1, 20000); + int numQuery = faiss::gpu::randVal(1, 1000); + + auto data = faiss::gpu::randBinaryVecs(numVecs, dims); + gpuIndex.add(numVecs, data.data()); + cpuIndex.add(numVecs, data.data()); + + auto query = faiss::gpu::randBinaryVecs(numQuery, dims); + + std::vector cpuDist(numQuery * k); + std::vector cpuLabels(numQuery * k); + + cpuIndex.search(numQuery, + query.data(), + k, + cpuDist.data(), + cpuLabels.data()); + + std::vector gpuDist(numQuery * k); + std::vector gpuLabels(numQuery * k); + + gpuIndex.search(numQuery, + query.data(), + k, + gpuDist.data(), + gpuLabels.data()); + + compareBinaryDist(cpuDist, cpuLabels, + gpuDist, gpuLabels, + numQuery, k); +} + +TEST(TestGpuIndexBinaryFlat, Test8) { + for (int tries = 0; tries < 4; ++tries) { + testGpuIndexBinaryFlat<8>(); + } +} + +TEST(TestGpuIndexBinaryFlat, Test32) { + for (int tries = 0; tries < 4; ++tries) { + testGpuIndexBinaryFlat<32>(); + } +} + +int main(int argc, char** argv) { + testing::InitGoogleTest(&argc, argv); + + // just run with a fixed test seed + faiss::gpu::setTestSeed(100); + + return RUN_ALL_TESTS(); +} diff --git a/core/src/index/thirdparty/faiss/gpu/test/TestGpuIndexFlat.cpp b/core/src/index/thirdparty/faiss/gpu/test/TestGpuIndexFlat.cpp new file mode 100644 index 0000000000..73cfe20542 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/test/TestGpuIndexFlat.cpp @@ -0,0 +1,393 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#include +#include +#include +#include +#include +#include +#include +#include + +// FIXME: figure out a better way to test fp16 +constexpr float kF16MaxRelErr = 0.07f; +constexpr float kF32MaxRelErr = 6e-3f; + +struct TestFlatOptions { + TestFlatOptions() + : metric(faiss::MetricType::METRIC_L2), + metricArg(0), + useFloat16(false), + useTransposed(false), + numVecsOverride(-1), + numQueriesOverride(-1), + kOverride(-1), + dimOverride(-1) { + } + + faiss::MetricType metric; + float metricArg; + + bool useFloat16; + bool useTransposed; + int numVecsOverride; + int numQueriesOverride; + int kOverride; + int dimOverride; +}; + +void testFlat(const TestFlatOptions& opt) { + int numVecs = opt.numVecsOverride > 0 ? + opt.numVecsOverride : faiss::gpu::randVal(1000, 5000); + int dim = opt.dimOverride > 0 ? + opt.dimOverride : faiss::gpu::randVal(50, 800); + int numQuery = opt.numQueriesOverride > 0 ? + opt.numQueriesOverride : faiss::gpu::randVal(1, 512); + + // Due to loss of precision in a float16 accumulator, for large k, + // the number of differences is pretty huge. Restrict ourselves to a + // fairly small `k` for float16 + int k = opt.useFloat16 ? + std::min(faiss::gpu::randVal(1, 50), numVecs) : + std::min(faiss::gpu::randVal(1, faiss::gpu::getMaxKSelection()), numVecs); + if (opt.kOverride > 0) { + k = opt.kOverride; + } + + faiss::IndexFlat cpuIndex(dim, opt.metric); + cpuIndex.metric_arg = opt.metricArg; + + // Construct on a random device to test multi-device, if we have + // multiple devices + int device = faiss::gpu::randVal(0, faiss::gpu::getNumDevices() - 1); + + faiss::gpu::StandardGpuResources res; + res.noTempMemory(); + + faiss::gpu::GpuIndexFlatConfig config; + config.device = device; + config.useFloat16 = opt.useFloat16; + config.storeTransposed = opt.useTransposed; + + faiss::gpu::GpuIndexFlat gpuIndex(&res, dim, opt.metric, config); + gpuIndex.metric_arg = opt.metricArg; + + std::vector vecs = faiss::gpu::randVecs(numVecs, dim); + cpuIndex.add(numVecs, vecs.data()); + gpuIndex.add(numVecs, vecs.data()); + + std::stringstream str; + str << "metric " << opt.metric + << " marg " << opt.metricArg + << " numVecs " << numVecs + << " dim " << dim + << " useFloat16 " << opt.useFloat16 + << " transposed " << opt.useTransposed + << " numQuery " << numQuery + << " k " << k; + + // To some extent, we depend upon the relative error for the test + // for float16 + faiss::gpu::compareIndices(cpuIndex, gpuIndex, numQuery, dim, k, str.str(), + opt.useFloat16 ? kF16MaxRelErr : kF32MaxRelErr, + // FIXME: the fp16 bounds are + // useless when math (the accumulator) is + // in fp16. Figure out another way to test + opt.useFloat16 ? 0.99f : 0.1f, + opt.useFloat16 ? 0.65f : 0.015f); +} + +TEST(TestGpuIndexFlat, IP_Float32) { + for (int tries = 0; tries < 3; ++tries) { + TestFlatOptions opt; + opt.metric = faiss::MetricType::METRIC_INNER_PRODUCT; + opt.useFloat16 = false; + opt.useTransposed = false; + + testFlat(opt); + + opt.useTransposed = true; + testFlat(opt); + } +} + +TEST(TestGpuIndexFlat, L1_Float32) { + TestFlatOptions opt; + opt.metric = faiss::MetricType::METRIC_L1; + opt.useFloat16 = false; + opt.useTransposed = false; + + testFlat(opt); + + opt.useTransposed = true; + testFlat(opt); +} + +TEST(TestGpuIndexFlat, Lp_Float32) { + TestFlatOptions opt; + opt.metric = faiss::MetricType::METRIC_Lp; + opt.metricArg = 5; + opt.useFloat16 = false; + opt.useTransposed = false; + + testFlat(opt); + + // Don't bother testing the transposed version, the L1 test should be good + // enough for that +} + +TEST(TestGpuIndexFlat, L2_Float32) { + for (int tries = 0; tries < 3; ++tries) { + TestFlatOptions opt; + opt.metric = faiss::MetricType::METRIC_L2; + + opt.useFloat16 = false; + opt.useTransposed = false; + + testFlat(opt); + + opt.useTransposed = true; + testFlat(opt); + } +} + +// test specialized k == 1 codepath +TEST(TestGpuIndexFlat, L2_Float32_K1) { + for (int tries = 0; tries < 3; ++tries) { + TestFlatOptions opt; + opt.metric = faiss::MetricType::METRIC_L2; + opt.useFloat16 = false; + opt.useTransposed = false; + opt.kOverride = 1; + + testFlat(opt); + } +} + +TEST(TestGpuIndexFlat, IP_Float16) { + for (int tries = 0; tries < 3; ++tries) { + TestFlatOptions opt; + opt.metric = faiss::MetricType::METRIC_INNER_PRODUCT; + opt.useFloat16 = true; + opt.useTransposed = false; + + testFlat(opt); + + opt.useTransposed = true; + testFlat(opt); + } +} + +TEST(TestGpuIndexFlat, L2_Float16) { + for (int tries = 0; tries < 3; ++tries) { + TestFlatOptions opt; + opt.metric = faiss::MetricType::METRIC_L2; + opt.useFloat16 = true; + opt.useTransposed = false; + + testFlat(opt); + + opt.useTransposed = true; + testFlat(opt); + } +} + +// test specialized k == 1 codepath +TEST(TestGpuIndexFlat, L2_Float16_K1) { + for (int tries = 0; tries < 3; ++tries) { + TestFlatOptions opt; + opt.metric = faiss::MetricType::METRIC_L2; + opt.useFloat16 = true; + opt.useTransposed = false; + opt.kOverride = 1; + + testFlat(opt); + } +} + +// test tiling along a huge vector set +TEST(TestGpuIndexFlat, L2_Tiling) { + for (int tries = 0; tries < 2; ++tries) { + TestFlatOptions opt; + opt.metric = faiss::MetricType::METRIC_L2; + opt.useFloat16 = false; + opt.useTransposed = false; + opt.numVecsOverride = 1000000; + + // keep the rest of the problem reasonably small + opt.numQueriesOverride = 4; + opt.dimOverride = 64; + opt.kOverride = 64; + + testFlat(opt); + } +} + +TEST(TestGpuIndexFlat, QueryEmpty) { + faiss::gpu::StandardGpuResources res; + res.noTempMemory(); + + faiss::gpu::GpuIndexFlatConfig config; + config.device = 0; + config.useFloat16 = false; + config.storeTransposed = false; + + int dim = 128; + faiss::gpu::GpuIndexFlatL2 gpuIndex(&res, dim, config); + + // Querying an empty index should not blow up, and just return + // (FLT_MAX, -1) + int numQuery = 10; + int k = 50; + std::vector queries(numQuery * dim, 1.0f); + + std::vector dist(numQuery * k, 0); + std::vector ind(numQuery * k); + + gpuIndex.search(numQuery, queries.data(), k, dist.data(), ind.data()); + + for (auto d : dist) { + EXPECT_EQ(d, std::numeric_limits::max()); + } + + for (auto i : ind) { + EXPECT_EQ(i, -1); + } +} + +TEST(TestGpuIndexFlat, CopyFrom) { + int numVecs = faiss::gpu::randVal(100, 200); + int dim = faiss::gpu::randVal(1, 1000); + + faiss::IndexFlatL2 cpuIndex(dim); + + std::vector vecs = faiss::gpu::randVecs(numVecs, dim); + cpuIndex.add(numVecs, vecs.data()); + + faiss::gpu::StandardGpuResources res; + res.noTempMemory(); + + // Fill with garbage values + int device = faiss::gpu::randVal(0, faiss::gpu::getNumDevices() - 1); + + faiss::gpu::GpuIndexFlatConfig config; + config.device = 0; + config.useFloat16 = false; + config.storeTransposed = false; + + faiss::gpu::GpuIndexFlatL2 gpuIndex(&res, 2000, config); + gpuIndex.copyFrom(&cpuIndex); + + EXPECT_EQ(cpuIndex.ntotal, gpuIndex.ntotal); + EXPECT_EQ(gpuIndex.ntotal, numVecs); + + EXPECT_EQ(cpuIndex.d, gpuIndex.d); + EXPECT_EQ(cpuIndex.d, dim); + + int idx = faiss::gpu::randVal(0, numVecs - 1); + + std::vector gpuVals(dim); + gpuIndex.reconstruct(idx, gpuVals.data()); + + std::vector cpuVals(dim); + cpuIndex.reconstruct(idx, cpuVals.data()); + + EXPECT_EQ(gpuVals, cpuVals); +} + +TEST(TestGpuIndexFlat, CopyTo) { + faiss::gpu::StandardGpuResources res; + res.noTempMemory(); + + int numVecs = faiss::gpu::randVal(100, 200); + int dim = faiss::gpu::randVal(1, 1000); + + int device = faiss::gpu::randVal(0, faiss::gpu::getNumDevices() - 1); + + faiss::gpu::GpuIndexFlatConfig config; + config.device = device; + config.useFloat16 = false; + config.storeTransposed = false; + + faiss::gpu::GpuIndexFlatL2 gpuIndex(&res, dim, config); + + std::vector vecs = faiss::gpu::randVecs(numVecs, dim); + gpuIndex.add(numVecs, vecs.data()); + + // Fill with garbage values + faiss::IndexFlatL2 cpuIndex(2000); + gpuIndex.copyTo(&cpuIndex); + + EXPECT_EQ(cpuIndex.ntotal, gpuIndex.ntotal); + EXPECT_EQ(gpuIndex.ntotal, numVecs); + + EXPECT_EQ(cpuIndex.d, gpuIndex.d); + EXPECT_EQ(cpuIndex.d, dim); + + int idx = faiss::gpu::randVal(0, numVecs - 1); + + std::vector gpuVals(dim); + gpuIndex.reconstruct(idx, gpuVals.data()); + + std::vector cpuVals(dim); + cpuIndex.reconstruct(idx, cpuVals.data()); + + EXPECT_EQ(gpuVals, cpuVals); +} + +TEST(TestGpuIndexFlat, UnifiedMemory) { + // Construct on a random device to test multi-device, if we have + // multiple devices + int device = faiss::gpu::randVal(0, faiss::gpu::getNumDevices() - 1); + + if (!faiss::gpu::getFullUnifiedMemSupport(device)) { + return; + } + + int dim = 256; + + // FIXME: GpuIndexFlat doesn't support > 2^31 (vecs * dims) due to + // kernel indexing, so we can't test unified memory for memory + // oversubscription. + size_t numVecs = 50000; + int numQuery = 10; + int k = 10; + + faiss::IndexFlatL2 cpuIndexL2(dim); + + faiss::gpu::StandardGpuResources res; + res.noTempMemory(); + + faiss::gpu::GpuIndexFlatConfig config; + config.device = device; + config.memorySpace = faiss::gpu::MemorySpace::Unified; + + faiss::gpu::GpuIndexFlatL2 gpuIndexL2(&res, dim, config); + + std::vector vecs = faiss::gpu::randVecs(numVecs, dim); + cpuIndexL2.add(numVecs, vecs.data()); + gpuIndexL2.add(numVecs, vecs.data()); + + // To some extent, we depend upon the relative error for the test + // for float16 + faiss::gpu::compareIndices(cpuIndexL2, gpuIndexL2, + numQuery, dim, k, "Unified Memory", + kF32MaxRelErr, + 0.1f, + 0.015f); +} + +int main(int argc, char** argv) { + testing::InitGoogleTest(&argc, argv); + + // just run with a fixed test seed + faiss::gpu::setTestSeed(100); + + return RUN_ALL_TESTS(); +} diff --git a/core/src/index/thirdparty/faiss/gpu/test/TestGpuIndexIVFFlat.cpp b/core/src/index/thirdparty/faiss/gpu/test/TestGpuIndexIVFFlat.cpp new file mode 100644 index 0000000000..6304252e6b --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/test/TestGpuIndexIVFFlat.cpp @@ -0,0 +1,550 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// FIXME: figure out a better way to test fp16 +constexpr float kF16MaxRelErr = 0.3f; +constexpr float kF32MaxRelErr = 0.03f; + + +struct Options { + Options() { + numAdd = 2 * faiss::gpu::randVal(2000, 5000); + dim = faiss::gpu::randVal(64, 200); + + numCentroids = std::sqrt((float) numAdd / 2); + numTrain = numCentroids * 40; + nprobe = faiss::gpu::randVal(std::min(10, numCentroids), numCentroids); + numQuery = faiss::gpu::randVal(32, 100); + + // Due to the approximate nature of the query and of floating point + // differences between GPU and CPU, to stay within our error bounds, only + // use a small k + k = std::min(faiss::gpu::randVal(10, 30), numAdd / 40); + indicesOpt = faiss::gpu::randSelect({ + faiss::gpu::INDICES_CPU, + faiss::gpu::INDICES_32_BIT, + faiss::gpu::INDICES_64_BIT}); + + device = faiss::gpu::randVal(0, faiss::gpu::getNumDevices() - 1); + } + + std::string toString() const { + std::stringstream str; + str << "IVFFlat device " << device + << " numVecs " << numAdd + << " dim " << dim + << " numCentroids " << numCentroids + << " nprobe " << nprobe + << " numQuery " << numQuery + << " k " << k + << " indicesOpt " << indicesOpt; + + return str.str(); + } + + int numAdd; + int dim; + int numCentroids; + int numTrain; + int nprobe; + int numQuery; + int k; + int device; + faiss::gpu::IndicesOptions indicesOpt; +}; + +void queryTest(faiss::MetricType metricType, + bool useFloat16CoarseQuantizer, + int dimOverride = -1) { + for (int tries = 0; tries < 2; ++tries) { + Options opt; + opt.dim = dimOverride != -1 ? dimOverride : opt.dim; + + std::vector trainVecs = faiss::gpu::randVecs(opt.numTrain, opt.dim); + std::vector addVecs = faiss::gpu::randVecs(opt.numAdd, opt.dim); + + faiss::IndexFlatL2 quantizerL2(opt.dim); + faiss::IndexFlatIP quantizerIP(opt.dim); + faiss::Index* quantizer = + metricType == faiss::METRIC_L2 ? + (faiss::Index*) &quantizerL2 : (faiss::Index*) &quantizerIP; + + faiss::IndexIVFFlat cpuIndex(quantizer, + opt.dim, opt.numCentroids, metricType); + cpuIndex.train(opt.numTrain, trainVecs.data()); + cpuIndex.add(opt.numAdd, addVecs.data()); + cpuIndex.nprobe = opt.nprobe; + + faiss::gpu::StandardGpuResources res; + res.noTempMemory(); + + faiss::gpu::GpuIndexIVFFlatConfig config; + config.device = opt.device; + config.indicesOptions = opt.indicesOpt; + config.flatConfig.useFloat16 = useFloat16CoarseQuantizer; + + faiss::gpu::GpuIndexIVFFlat gpuIndex(&res, + cpuIndex.d, + cpuIndex.nlist, + cpuIndex.metric_type, + config); + gpuIndex.copyFrom(&cpuIndex); + gpuIndex.setNumProbes(opt.nprobe); + + bool compFloat16 = useFloat16CoarseQuantizer; + faiss::gpu::compareIndices(cpuIndex, gpuIndex, + opt.numQuery, opt.dim, opt.k, opt.toString(), + compFloat16 ? kF16MaxRelErr : kF32MaxRelErr, + // FIXME: the fp16 bounds are + // useless when math (the accumulator) is + // in fp16. Figure out another way to test + compFloat16 ? 0.70f : 0.1f, + compFloat16 ? 0.65f : 0.015f); + } +} + +void addTest(faiss::MetricType metricType, + bool useFloat16CoarseQuantizer) { + for (int tries = 0; tries < 2; ++tries) { + Options opt; + + std::vector trainVecs = faiss::gpu::randVecs(opt.numTrain, opt.dim); + std::vector addVecs = faiss::gpu::randVecs(opt.numAdd, opt.dim); + + faiss::IndexFlatL2 quantizerL2(opt.dim); + faiss::IndexFlatIP quantizerIP(opt.dim); + faiss::Index* quantizer = + metricType == faiss::METRIC_L2 ? + (faiss::Index*) &quantizerL2 : (faiss::Index*) &quantizerIP; + + faiss::IndexIVFFlat cpuIndex(quantizer, + opt.dim, + opt.numCentroids, + metricType); + cpuIndex.train(opt.numTrain, trainVecs.data()); + cpuIndex.nprobe = opt.nprobe; + + faiss::gpu::StandardGpuResources res; + res.noTempMemory(); + + faiss::gpu::GpuIndexIVFFlatConfig config; + config.device = opt.device; + config.indicesOptions = opt.indicesOpt; + config.flatConfig.useFloat16 = useFloat16CoarseQuantizer; + + faiss::gpu::GpuIndexIVFFlat gpuIndex(&res, + cpuIndex.d, + cpuIndex.nlist, + cpuIndex.metric_type, + config); + gpuIndex.copyFrom(&cpuIndex); + gpuIndex.setNumProbes(opt.nprobe); + + cpuIndex.add(opt.numAdd, addVecs.data()); + gpuIndex.add(opt.numAdd, addVecs.data()); + + bool compFloat16 = useFloat16CoarseQuantizer; + faiss::gpu::compareIndices(cpuIndex, gpuIndex, + opt.numQuery, opt.dim, opt.k, opt.toString(), + compFloat16 ? kF16MaxRelErr : kF32MaxRelErr, + compFloat16 ? 0.70f : 0.1f, + compFloat16 ? 0.30f : 0.015f); + } +} + +void copyToTest(bool useFloat16CoarseQuantizer) { + Options opt; + std::vector trainVecs = faiss::gpu::randVecs(opt.numTrain, opt.dim); + std::vector addVecs = faiss::gpu::randVecs(opt.numAdd, opt.dim); + + faiss::gpu::StandardGpuResources res; + res.noTempMemory(); + + faiss::gpu::GpuIndexIVFFlatConfig config; + config.device = opt.device; + config.indicesOptions = opt.indicesOpt; + config.flatConfig.useFloat16 = useFloat16CoarseQuantizer; + + faiss::gpu::GpuIndexIVFFlat gpuIndex(&res, + opt.dim, + opt.numCentroids, + faiss::METRIC_L2, + config); + gpuIndex.train(opt.numTrain, trainVecs.data()); + gpuIndex.add(opt.numAdd, addVecs.data()); + gpuIndex.setNumProbes(opt.nprobe); + + // use garbage values to see if we overwrite then + faiss::IndexFlatL2 cpuQuantizer(1); + faiss::IndexIVFFlat cpuIndex(&cpuQuantizer, 1, 1, faiss::METRIC_L2); + cpuIndex.nprobe = 1; + + gpuIndex.copyTo(&cpuIndex); + + EXPECT_EQ(cpuIndex.ntotal, gpuIndex.ntotal); + EXPECT_EQ(gpuIndex.ntotal, opt.numAdd); + + EXPECT_EQ(cpuIndex.d, gpuIndex.d); + EXPECT_EQ(cpuIndex.quantizer->d, gpuIndex.quantizer->d); + EXPECT_EQ(cpuIndex.d, opt.dim); + EXPECT_EQ(cpuIndex.nlist, gpuIndex.getNumLists()); + EXPECT_EQ(cpuIndex.nprobe, gpuIndex.getNumProbes()); + + // Query both objects; results should be equivalent + bool compFloat16 = useFloat16CoarseQuantizer; + faiss::gpu::compareIndices(cpuIndex, gpuIndex, + opt.numQuery, opt.dim, opt.k, opt.toString(), + compFloat16 ? kF16MaxRelErr : kF32MaxRelErr, + compFloat16 ? 0.70f : 0.1f, + compFloat16 ? 0.30f : 0.015f); +} + +void copyFromTest(bool useFloat16CoarseQuantizer) { + Options opt; + std::vector trainVecs = faiss::gpu::randVecs(opt.numTrain, opt.dim); + std::vector addVecs = faiss::gpu::randVecs(opt.numAdd, opt.dim); + + faiss::IndexFlatL2 cpuQuantizer(opt.dim); + faiss::IndexIVFFlat cpuIndex(&cpuQuantizer, + opt.dim, + opt.numCentroids, + faiss::METRIC_L2); + cpuIndex.nprobe = opt.nprobe; + cpuIndex.train(opt.numTrain, trainVecs.data()); + cpuIndex.add(opt.numAdd, addVecs.data()); + + // use garbage values to see if we overwrite then + faiss::gpu::StandardGpuResources res; + res.noTempMemory(); + + faiss::gpu::GpuIndexIVFFlatConfig config; + config.device = opt.device; + config.indicesOptions = opt.indicesOpt; + config.flatConfig.useFloat16 = useFloat16CoarseQuantizer; + + faiss::gpu::GpuIndexIVFFlat gpuIndex(&res, + 1, + 1, + faiss::METRIC_L2, + config); + gpuIndex.setNumProbes(1); + + gpuIndex.copyFrom(&cpuIndex); + + EXPECT_EQ(cpuIndex.ntotal, gpuIndex.ntotal); + EXPECT_EQ(gpuIndex.ntotal, opt.numAdd); + + EXPECT_EQ(cpuIndex.d, gpuIndex.d); + EXPECT_EQ(cpuIndex.d, opt.dim); + EXPECT_EQ(cpuIndex.nlist, gpuIndex.getNumLists()); + EXPECT_EQ(cpuIndex.nprobe, gpuIndex.getNumProbes()); + + // Query both objects; results should be equivalent + bool compFloat16 = useFloat16CoarseQuantizer; + faiss::gpu::compareIndices(cpuIndex, gpuIndex, + opt.numQuery, opt.dim, opt.k, opt.toString(), + compFloat16 ? kF16MaxRelErr : kF32MaxRelErr, + compFloat16 ? 0.70f : 0.1f, + compFloat16 ? 0.30f : 0.015f); +} + +TEST(TestGpuIndexIVFFlat, Float32_32_Add_L2) { + addTest(faiss::METRIC_L2, false); +} + +TEST(TestGpuIndexIVFFlat, Float32_32_Add_IP) { + addTest(faiss::METRIC_INNER_PRODUCT, false); +} + +TEST(TestGpuIndexIVFFlat, Float16_32_Add_L2) { + addTest(faiss::METRIC_L2, true); +} + +TEST(TestGpuIndexIVFFlat, Float16_32_Add_IP) { + addTest(faiss::METRIC_INNER_PRODUCT, true); +} + +// +// General query tests +// + +TEST(TestGpuIndexIVFFlat, Float32_Query_L2) { + queryTest(faiss::METRIC_L2, false); +} + +TEST(TestGpuIndexIVFFlat, Float32_Query_IP) { + queryTest(faiss::METRIC_INNER_PRODUCT, false); +} + +// float16 coarse quantizer + +TEST(TestGpuIndexIVFFlat, Float16_32_Query_L2) { + queryTest(faiss::METRIC_L2, true); +} + +TEST(TestGpuIndexIVFFlat, Float16_32_Query_IP) { + queryTest(faiss::METRIC_INNER_PRODUCT, true); +} + +// +// There are IVF list scanning specializations for 64-d and 128-d that we +// make sure we explicitly test here +// + +TEST(TestGpuIndexIVFFlat, Float32_Query_L2_64) { + queryTest(faiss::METRIC_L2, false, 64); +} + +TEST(TestGpuIndexIVFFlat, Float32_Query_IP_64) { + queryTest(faiss::METRIC_INNER_PRODUCT, false, 64); +} + +TEST(TestGpuIndexIVFFlat, Float32_Query_L2_128) { + queryTest(faiss::METRIC_L2, false, 128); +} + +TEST(TestGpuIndexIVFFlat, Float32_Query_IP_128) { + queryTest(faiss::METRIC_INNER_PRODUCT, false, 128); +} + +// +// Copy tests +// + +TEST(TestGpuIndexIVFFlat, Float32_32_CopyTo) { + copyToTest(false); +} + +TEST(TestGpuIndexIVFFlat, Float32_32_CopyFrom) { + copyFromTest(false); +} + +TEST(TestGpuIndexIVFFlat, Float32_negative) { + Options opt; + + auto trainVecs = faiss::gpu::randVecs(opt.numTrain, opt.dim); + auto addVecs = faiss::gpu::randVecs(opt.numAdd, opt.dim); + + // Put all vecs on negative side + for (auto& f : trainVecs) { + f = std::abs(f) * -1.0f; + } + + for (auto& f : addVecs) { + f *= std::abs(f) * -1.0f; + } + + faiss::IndexFlatIP quantizerIP(opt.dim); + faiss::Index* quantizer = (faiss::Index*) &quantizerIP; + + faiss::IndexIVFFlat cpuIndex(quantizer, + opt.dim, opt.numCentroids, + faiss::METRIC_INNER_PRODUCT); + cpuIndex.train(opt.numTrain, trainVecs.data()); + cpuIndex.add(opt.numAdd, addVecs.data()); + cpuIndex.nprobe = opt.nprobe; + + faiss::gpu::StandardGpuResources res; + res.noTempMemory(); + + faiss::gpu::GpuIndexIVFFlatConfig config; + config.device = opt.device; + config.indicesOptions = opt.indicesOpt; + + faiss::gpu::GpuIndexIVFFlat gpuIndex(&res, + cpuIndex.d, + cpuIndex.nlist, + cpuIndex.metric_type, + config); + gpuIndex.copyFrom(&cpuIndex); + gpuIndex.setNumProbes(opt.nprobe); + + // Construct a positive test set + auto queryVecs = faiss::gpu::randVecs(opt.numQuery, opt.dim); + + // Put all vecs on positive size + for (auto& f : queryVecs) { + f = std::abs(f); + } + + bool compFloat16 = false; + faiss::gpu::compareIndices(queryVecs, + cpuIndex, gpuIndex, + opt.numQuery, opt.dim, opt.k, opt.toString(), + compFloat16 ? kF16MaxRelErr : kF32MaxRelErr, + // FIXME: the fp16 bounds are + // useless when math (the accumulator) is + // in fp16. Figure out another way to test + compFloat16 ? 0.99f : 0.1f, + compFloat16 ? 0.65f : 0.015f); +} + +// +// NaN tests +// + +TEST(TestGpuIndexIVFFlat, QueryNaN) { + Options opt; + + std::vector trainVecs = faiss::gpu::randVecs(opt.numTrain, opt.dim); + std::vector addVecs = faiss::gpu::randVecs(opt.numAdd, opt.dim); + + faiss::gpu::StandardGpuResources res; + res.noTempMemory(); + + faiss::gpu::GpuIndexIVFFlatConfig config; + config.device = opt.device; + config.indicesOptions = opt.indicesOpt; + config.flatConfig.useFloat16 = faiss::gpu::randBool(); + + faiss::gpu::GpuIndexIVFFlat gpuIndex(&res, + opt.dim, + opt.numCentroids, + faiss::METRIC_L2, + config); + gpuIndex.setNumProbes(opt.nprobe); + + gpuIndex.train(opt.numTrain, trainVecs.data()); + gpuIndex.add(opt.numAdd, addVecs.data()); + + int numQuery = 10; + std::vector nans(numQuery * opt.dim, + std::numeric_limits::quiet_NaN()); + + std::vector distances(numQuery * opt.k, 0); + std::vector indices(numQuery * opt.k, 0); + + gpuIndex.search(numQuery, + nans.data(), + opt.k, + distances.data(), + indices.data()); + + for (int q = 0; q < numQuery; ++q) { + for (int k = 0; k < opt.k; ++k) { + EXPECT_EQ(indices[q * opt.k + k], -1); + EXPECT_EQ(distances[q * opt.k + k], std::numeric_limits::max()); + } + } +} + +TEST(TestGpuIndexIVFFlat, AddNaN) { + Options opt; + + faiss::gpu::StandardGpuResources res; + res.noTempMemory(); + + faiss::gpu::GpuIndexIVFFlatConfig config; + config.device = opt.device; + config.indicesOptions = opt.indicesOpt; + config.flatConfig.useFloat16 = faiss::gpu::randBool(); + + faiss::gpu::GpuIndexIVFFlat gpuIndex(&res, + opt.dim, + opt.numCentroids, + faiss::METRIC_L2, + config); + gpuIndex.setNumProbes(opt.nprobe); + + int numNans = 10; + std::vector nans(numNans * opt.dim, + std::numeric_limits::quiet_NaN()); + + // Make one vector valid, which should actually add + for (int i = 0; i < opt.dim; ++i) { + nans[i] = 0.0f; + } + + std::vector trainVecs = faiss::gpu::randVecs(opt.numTrain, opt.dim); + gpuIndex.train(opt.numTrain, trainVecs.data()); + + // should not crash + EXPECT_EQ(gpuIndex.ntotal, 0); + gpuIndex.add(numNans, nans.data()); + + std::vector queryVecs = faiss::gpu::randVecs(opt.numQuery, opt.dim); + std::vector distance(opt.numQuery * opt.k, 0); + std::vector indices(opt.numQuery * opt.k, 0); + + // should not crash + gpuIndex.search(opt.numQuery, queryVecs.data(), opt.k, + distance.data(), indices.data()); +} + +TEST(TestGpuIndexIVFFlat, UnifiedMemory) { + // Construct on a random device to test multi-device, if we have + // multiple devices + int device = faiss::gpu::randVal(0, faiss::gpu::getNumDevices() - 1); + + if (!faiss::gpu::getFullUnifiedMemSupport(device)) { + return; + } + + int dim = 128; + + int numCentroids = 256; + // Unfortunately it would take forever to add 24 GB in IVFPQ data, + // so just perform a small test with data allocated in the unified + // memory address space + size_t numAdd = 10000; + size_t numTrain = numCentroids * 40; + int numQuery = 10; + int k = 10; + int nprobe = 8; + + std::vector trainVecs = faiss::gpu::randVecs(numTrain, dim); + std::vector addVecs = faiss::gpu::randVecs(numAdd, dim); + + faiss::IndexFlatL2 quantizer(dim); + faiss::IndexIVFFlat cpuIndex(&quantizer, dim, numCentroids, faiss::METRIC_L2); + + cpuIndex.train(numTrain, trainVecs.data()); + cpuIndex.add(numAdd, addVecs.data()); + cpuIndex.nprobe = nprobe; + + faiss::gpu::StandardGpuResources res; + res.noTempMemory(); + + faiss::gpu::GpuIndexIVFFlatConfig config; + config.device = device; + config.memorySpace = faiss::gpu::MemorySpace::Unified; + + faiss::gpu::GpuIndexIVFFlat gpuIndex(&res, + dim, + numCentroids, + faiss::METRIC_L2, + config); + gpuIndex.copyFrom(&cpuIndex); + gpuIndex.setNumProbes(nprobe); + + faiss::gpu::compareIndices(cpuIndex, gpuIndex, + numQuery, dim, k, "Unified Memory", + kF32MaxRelErr, + 0.1f, + 0.015f); +} + +int main(int argc, char** argv) { + testing::InitGoogleTest(&argc, argv); + + // just run with a fixed test seed + faiss::gpu::setTestSeed(100); + + return RUN_ALL_TESTS(); +} diff --git a/core/src/index/thirdparty/faiss/gpu/test/TestGpuIndexIVFPQ.cpp b/core/src/index/thirdparty/faiss/gpu/test/TestGpuIndexIVFPQ.cpp new file mode 100644 index 0000000000..1bee6b4bbf --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/test/TestGpuIndexIVFPQ.cpp @@ -0,0 +1,558 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +void pickEncoding(int& codes, int& dim) { + std::vector codeSizes{ + 3, 4, 8, 12, 16, 20, 24, + 28, 32, 40, 48, 56, 64, 96 + }; + + // Above 32 doesn't work with no precomputed codes + std::vector dimSizes{4, 8, 10, 12, 16, 20, 24, 28, 32}; + + while (true) { + codes = codeSizes[faiss::gpu::randVal(0, codeSizes.size() - 1)]; + dim = codes * dimSizes[faiss::gpu::randVal(0, dimSizes.size() - 1)]; + + // for such a small test, super-low or high dim is more likely to + // generate comparison errors + if (dim < 256 && dim >= 64) { + return; + } + } +} + +struct Options { + Options() { + numAdd = faiss::gpu::randVal(2000, 5000); + numCentroids = std::sqrt((float) numAdd); + numTrain = numCentroids * 40; + + pickEncoding(codes, dim); + + // TODO: Change back to `faiss::gpu::randVal(3, 7)` when we officially + // support non-multiple of 8 subcodes for IVFPQ. + bitsPerCode = 8; + nprobe = std::min(faiss::gpu::randVal(40, 1000), numCentroids); + numQuery = faiss::gpu::randVal(1, 8); + + // Due to the approximate nature of the query and of floating point + // differences between GPU and CPU, to stay within our error bounds, only + // use a small k + k = std::min(faiss::gpu::randVal(5, 20), numAdd / 40); + usePrecomputed = faiss::gpu::randBool(); + indicesOpt = faiss::gpu::randSelect({ + faiss::gpu::INDICES_CPU, + faiss::gpu::INDICES_32_BIT, + faiss::gpu::INDICES_64_BIT}); + if (codes > 48) { + // large codes can only fit using float16 + useFloat16 = true; + } else { + useFloat16 = faiss::gpu::randBool(); + } + + device = faiss::gpu::randVal(0, faiss::gpu::getNumDevices() - 1); + } + + std::string toString() const { + std::stringstream str; + str << "IVFPQ device " << device + << " numVecs " << numAdd + << " dim " << dim + << " numCentroids " << numCentroids + << " codes " << codes + << " bitsPerCode " << bitsPerCode + << " nprobe " << nprobe + << " numQuery " << numQuery + << " k " << k + << " usePrecomputed " << usePrecomputed + << " indicesOpt " << indicesOpt + << " useFloat16 " << useFloat16; + + return str.str(); + } + + float getCompareEpsilon() const { + return 0.03f; + } + + float getPctMaxDiff1() const { + return useFloat16 ? 0.30f : 0.10f; + } + + float getPctMaxDiffN() const { + return useFloat16 ? 0.05f : 0.02f; + } + + int numAdd; + int numCentroids; + int numTrain; + int codes; + int dim; + int bitsPerCode; + int nprobe; + int numQuery; + int k; + bool usePrecomputed; + faiss::gpu::IndicesOptions indicesOpt; + bool useFloat16; + int device; +}; + +TEST(TestGpuIndexIVFPQ, Query_L2) { + for (int tries = 0; tries < 2; ++tries) { + Options opt; + + std::vector trainVecs = faiss::gpu::randVecs(opt.numTrain, opt.dim); + std::vector addVecs = faiss::gpu::randVecs(opt.numAdd, opt.dim); + + faiss::IndexFlatL2 coarseQuantizer(opt.dim); + faiss::IndexIVFPQ cpuIndex(&coarseQuantizer, opt.dim, opt.numCentroids, + opt.codes, opt.bitsPerCode); + cpuIndex.nprobe = opt.nprobe; + cpuIndex.train(opt.numTrain, trainVecs.data()); + cpuIndex.add(opt.numAdd, addVecs.data()); + + faiss::gpu::StandardGpuResources res; + res.noTempMemory(); + + faiss::gpu::GpuIndexIVFPQConfig config; + config.device = opt.device; + config.usePrecomputedTables = opt.usePrecomputed; + config.indicesOptions = opt.indicesOpt; + config.useFloat16LookupTables = opt.useFloat16; + + faiss::gpu::GpuIndexIVFPQ gpuIndex(&res, &cpuIndex, config); + gpuIndex.setNumProbes(opt.nprobe); + + faiss::gpu::compareIndices(cpuIndex, gpuIndex, + opt.numQuery, opt.dim, opt.k, opt.toString(), + opt.getCompareEpsilon(), + opt.getPctMaxDiff1(), + opt.getPctMaxDiffN()); + } +} + +TEST(TestGpuIndexIVFPQ, Query_IP) { + for (int tries = 0; tries < 2; ++tries) { + Options opt; + + std::vector trainVecs = faiss::gpu::randVecs(opt.numTrain, opt.dim); + std::vector addVecs = faiss::gpu::randVecs(opt.numAdd, opt.dim); + + faiss::IndexFlatIP coarseQuantizer(opt.dim); + faiss::IndexIVFPQ cpuIndex(&coarseQuantizer, opt.dim, opt.numCentroids, + opt.codes, opt.bitsPerCode); + cpuIndex.metric_type = faiss::MetricType::METRIC_INNER_PRODUCT; + + cpuIndex.nprobe = opt.nprobe; + cpuIndex.train(opt.numTrain, trainVecs.data()); + cpuIndex.add(opt.numAdd, addVecs.data()); + + faiss::gpu::StandardGpuResources res; + res.noTempMemory(); + + faiss::gpu::GpuIndexIVFPQConfig config; + config.device = opt.device; + config.usePrecomputedTables = false; // not supported/required for IP + config.indicesOptions = opt.indicesOpt; + config.useFloat16LookupTables = opt.useFloat16; + + faiss::gpu::GpuIndexIVFPQ gpuIndex(&res, &cpuIndex, config); + gpuIndex.setNumProbes(opt.nprobe); + + faiss::gpu::compareIndices(cpuIndex, gpuIndex, + opt.numQuery, opt.dim, opt.k, opt.toString(), + opt.getCompareEpsilon(), + opt.getPctMaxDiff1(), + opt.getPctMaxDiffN()); + } +} + +TEST(TestGpuIndexIVFPQ, Float16Coarse) { + Options opt; + + std::vector trainVecs = faiss::gpu::randVecs(opt.numTrain, opt.dim); + std::vector addVecs = faiss::gpu::randVecs(opt.numAdd, opt.dim); + + faiss::IndexFlatL2 coarseQuantizer(opt.dim); + faiss::IndexIVFPQ cpuIndex(&coarseQuantizer, opt.dim, opt.numCentroids, + opt.codes, opt.bitsPerCode); + cpuIndex.nprobe = opt.nprobe; + cpuIndex.train(opt.numTrain, trainVecs.data()); + + faiss::gpu::StandardGpuResources res; + res.noTempMemory(); + + faiss::gpu::GpuIndexIVFPQConfig config; + config.device = opt.device; + config.flatConfig.useFloat16 = true; + config.usePrecomputedTables = opt.usePrecomputed; + config.indicesOptions = opt.indicesOpt; + config.useFloat16LookupTables = opt.useFloat16; + + faiss::gpu::GpuIndexIVFPQ gpuIndex(&res, &cpuIndex, config); + gpuIndex.setNumProbes(opt.nprobe); + + gpuIndex.add(opt.numAdd, addVecs.data()); + cpuIndex.add(opt.numAdd, addVecs.data()); + + faiss::gpu::compareIndices(cpuIndex, gpuIndex, + opt.numQuery, opt.dim, opt.k, opt.toString(), + opt.getCompareEpsilon(), + opt.getPctMaxDiff1(), + opt.getPctMaxDiffN()); +} + +TEST(TestGpuIndexIVFPQ, Add_L2) { + for (int tries = 0; tries < 2; ++tries) { + Options opt; + + std::vector trainVecs = faiss::gpu::randVecs(opt.numTrain, opt.dim); + std::vector addVecs = faiss::gpu::randVecs(opt.numAdd, opt.dim); + + faiss::IndexFlatL2 coarseQuantizer(opt.dim); + faiss::IndexIVFPQ cpuIndex(&coarseQuantizer, opt.dim, opt.numCentroids, + opt.codes, opt.bitsPerCode); + cpuIndex.nprobe = opt.nprobe; + cpuIndex.train(opt.numTrain, trainVecs.data()); + + faiss::gpu::StandardGpuResources res; + res.noTempMemory(); + + faiss::gpu::GpuIndexIVFPQConfig config; + config.device = opt.device; + config.usePrecomputedTables = opt.usePrecomputed; + config.indicesOptions = opt.indicesOpt; + config.useFloat16LookupTables = opt.useFloat16; + + faiss::gpu::GpuIndexIVFPQ gpuIndex(&res, &cpuIndex, config); + gpuIndex.setNumProbes(opt.nprobe); + + gpuIndex.add(opt.numAdd, addVecs.data()); + cpuIndex.add(opt.numAdd, addVecs.data()); + + faiss::gpu::compareIndices(cpuIndex, gpuIndex, + opt.numQuery, opt.dim, opt.k, opt.toString(), + opt.getCompareEpsilon(), + opt.getPctMaxDiff1(), + opt.getPctMaxDiffN()); + } +} + +TEST(TestGpuIndexIVFPQ, Add_IP) { + for (int tries = 0; tries < 2; ++tries) { + Options opt; + + std::vector trainVecs = faiss::gpu::randVecs(opt.numTrain, opt.dim); + std::vector addVecs = faiss::gpu::randVecs(opt.numAdd, opt.dim); + + faiss::IndexFlatIP coarseQuantizer(opt.dim); + faiss::IndexIVFPQ cpuIndex(&coarseQuantizer, opt.dim, opt.numCentroids, + opt.codes, opt.bitsPerCode); + cpuIndex.metric_type = faiss::MetricType::METRIC_INNER_PRODUCT; + cpuIndex.nprobe = opt.nprobe; + cpuIndex.train(opt.numTrain, trainVecs.data()); + + faiss::gpu::StandardGpuResources res; + res.noTempMemory(); + + faiss::gpu::GpuIndexIVFPQConfig config; + config.device = opt.device; + config.usePrecomputedTables = opt.usePrecomputed; + config.indicesOptions = opt.indicesOpt; + config.useFloat16LookupTables = opt.useFloat16; + + faiss::gpu::GpuIndexIVFPQ gpuIndex(&res, &cpuIndex, config); + gpuIndex.setNumProbes(opt.nprobe); + + gpuIndex.add(opt.numAdd, addVecs.data()); + cpuIndex.add(opt.numAdd, addVecs.data()); + + faiss::gpu::compareIndices(cpuIndex, gpuIndex, + opt.numQuery, opt.dim, opt.k, opt.toString(), + opt.getCompareEpsilon(), + opt.getPctMaxDiff1(), + opt.getPctMaxDiffN()); + } +} + +TEST(TestGpuIndexIVFPQ, CopyTo) { + Options opt; + std::vector trainVecs = faiss::gpu::randVecs(opt.numTrain, opt.dim); + std::vector addVecs = faiss::gpu::randVecs(opt.numAdd, opt.dim); + + faiss::gpu::StandardGpuResources res; + res.noTempMemory(); + + faiss::gpu::GpuIndexIVFPQConfig config; + config.device = opt.device; + config.usePrecomputedTables = opt.usePrecomputed; + config.indicesOptions = opt.indicesOpt; + config.useFloat16LookupTables = opt.useFloat16; + + faiss::gpu::GpuIndexIVFPQ gpuIndex(&res, + opt.dim, + opt.numCentroids, + opt.codes, + opt.bitsPerCode, + faiss::METRIC_L2, + config); + gpuIndex.setNumProbes(opt.nprobe); + gpuIndex.train(opt.numTrain, trainVecs.data()); + gpuIndex.add(opt.numAdd, addVecs.data()); + + // Use garbage values to see if we overwrite them + faiss::IndexFlatL2 cpuQuantizer(1); + faiss::IndexIVFPQ cpuIndex(&cpuQuantizer, 1, 1, 1, 1); + + gpuIndex.copyTo(&cpuIndex); + + EXPECT_EQ(cpuIndex.ntotal, gpuIndex.ntotal); + EXPECT_EQ(gpuIndex.ntotal, opt.numAdd); + + EXPECT_EQ(cpuIndex.d, gpuIndex.d); + EXPECT_EQ(cpuIndex.d, opt.dim); + EXPECT_EQ(cpuIndex.nlist, gpuIndex.getNumLists()); + EXPECT_EQ(cpuIndex.nprobe, gpuIndex.getNumProbes()); + EXPECT_EQ(cpuIndex.pq.M, gpuIndex.getNumSubQuantizers()); + EXPECT_EQ(gpuIndex.getNumSubQuantizers(), opt.codes); + EXPECT_EQ(cpuIndex.pq.nbits, gpuIndex.getBitsPerCode()); + EXPECT_EQ(gpuIndex.getBitsPerCode(), opt.bitsPerCode); + + // Query both objects; results should be equivalent + faiss::gpu::compareIndices(cpuIndex, gpuIndex, + opt.numQuery, opt.dim, opt.k, opt.toString(), + opt.getCompareEpsilon(), + opt.getPctMaxDiff1(), + opt.getPctMaxDiffN()); +} + +TEST(TestGpuIndexIVFPQ, CopyFrom) { + Options opt; + std::vector trainVecs = faiss::gpu::randVecs(opt.numTrain, opt.dim); + std::vector addVecs = faiss::gpu::randVecs(opt.numAdd, opt.dim); + + faiss::IndexFlatL2 coarseQuantizer(opt.dim); + faiss::IndexIVFPQ cpuIndex(&coarseQuantizer, opt.dim, opt.numCentroids, + opt.codes, opt.bitsPerCode); + cpuIndex.nprobe = opt.nprobe; + cpuIndex.train(opt.numTrain, trainVecs.data()); + cpuIndex.add(opt.numAdd, addVecs.data()); + + // Use garbage values to see if we overwrite them + faiss::gpu::StandardGpuResources res; + res.noTempMemory(); + + faiss::gpu::GpuIndexIVFPQConfig config; + config.device = opt.device; + config.usePrecomputedTables = opt.usePrecomputed; + config.indicesOptions = opt.indicesOpt; + config.useFloat16LookupTables = opt.useFloat16; + + faiss::gpu::GpuIndexIVFPQ + gpuIndex(&res, 1, 1, 1, 1, faiss::METRIC_L2, config); + gpuIndex.setNumProbes(1); + + gpuIndex.copyFrom(&cpuIndex); + + // Make sure we are equivalent + EXPECT_EQ(cpuIndex.ntotal, gpuIndex.ntotal); + EXPECT_EQ(gpuIndex.ntotal, opt.numAdd); + + EXPECT_EQ(cpuIndex.d, gpuIndex.d); + EXPECT_EQ(cpuIndex.d, opt.dim); + EXPECT_EQ(cpuIndex.nlist, gpuIndex.getNumLists()); + EXPECT_EQ(cpuIndex.nprobe, gpuIndex.getNumProbes()); + EXPECT_EQ(cpuIndex.pq.M, gpuIndex.getNumSubQuantizers()); + EXPECT_EQ(gpuIndex.getNumSubQuantizers(), opt.codes); + EXPECT_EQ(cpuIndex.pq.nbits, gpuIndex.getBitsPerCode()); + EXPECT_EQ(gpuIndex.getBitsPerCode(), opt.bitsPerCode); + + // Query both objects; results should be equivalent + faiss::gpu::compareIndices(cpuIndex, gpuIndex, + opt.numQuery, opt.dim, opt.k, opt.toString(), + opt.getCompareEpsilon(), + opt.getPctMaxDiff1(), + opt.getPctMaxDiffN()); +} + +TEST(TestGpuIndexIVFPQ, QueryNaN) { + Options opt; + + std::vector trainVecs = faiss::gpu::randVecs(opt.numTrain, opt.dim); + std::vector addVecs = faiss::gpu::randVecs(opt.numAdd, opt.dim); + + faiss::gpu::StandardGpuResources res; + res.noTempMemory(); + + faiss::gpu::GpuIndexIVFPQConfig config; + config.device = opt.device; + config.usePrecomputedTables = opt.usePrecomputed; + config.indicesOptions = opt.indicesOpt; + config.useFloat16LookupTables = opt.useFloat16; + + faiss::gpu::GpuIndexIVFPQ gpuIndex(&res, + opt.dim, + opt.numCentroids, + opt.codes, + opt.bitsPerCode, + faiss::METRIC_L2, + config); + + gpuIndex.setNumProbes(opt.nprobe); + + gpuIndex.train(opt.numTrain, trainVecs.data()); + gpuIndex.add(opt.numAdd, addVecs.data()); + + int numQuery = 5; + std::vector nans(numQuery * opt.dim, + std::numeric_limits::quiet_NaN()); + + std::vector distances(numQuery * opt.k, 0); + std::vector indices(numQuery * opt.k, 0); + + gpuIndex.search(numQuery, + nans.data(), + opt.k, + distances.data(), + indices.data()); + + for (int q = 0; q < numQuery; ++q) { + for (int k = 0; k < opt.k; ++k) { + EXPECT_EQ(indices[q * opt.k + k], -1); + EXPECT_EQ(distances[q * opt.k + k], std::numeric_limits::max()); + } + } +} + +TEST(TestGpuIndexIVFPQ, AddNaN) { + Options opt; + + faiss::gpu::StandardGpuResources res; + res.noTempMemory(); + + faiss::gpu::GpuIndexIVFPQConfig config; + config.device = opt.device; + config.usePrecomputedTables = opt.usePrecomputed; + config.indicesOptions = opt.indicesOpt; + config.useFloat16LookupTables = opt.useFloat16; + + faiss::gpu::GpuIndexIVFPQ gpuIndex(&res, + opt.dim, + opt.numCentroids, + opt.codes, + opt.bitsPerCode, + faiss::METRIC_L2, + config); + + gpuIndex.setNumProbes(opt.nprobe); + + int numNans = 10; + std::vector nans(numNans * opt.dim, + std::numeric_limits::quiet_NaN()); + + // Make one vector valid, which should actually add + for (int i = 0; i < opt.dim; ++i) { + nans[i] = 0.0f; + } + + std::vector trainVecs = faiss::gpu::randVecs(opt.numTrain, opt.dim); + gpuIndex.train(opt.numTrain, trainVecs.data()); + + // should not crash + EXPECT_EQ(gpuIndex.ntotal, 0); + gpuIndex.add(numNans, nans.data()); + + std::vector queryVecs = faiss::gpu::randVecs(opt.numQuery, opt.dim); + std::vector distance(opt.numQuery * opt.k, 0); + std::vector indices(opt.numQuery * opt.k, 0); + + // should not crash + gpuIndex.search(opt.numQuery, queryVecs.data(), opt.k, + distance.data(), indices.data()); +} + +TEST(TestGpuIndexIVFPQ, UnifiedMemory) { + // Construct on a random device to test multi-device, if we have + // multiple devices + int device = faiss::gpu::randVal(0, faiss::gpu::getNumDevices() - 1); + + if (!faiss::gpu::getFullUnifiedMemSupport(device)) { + return; + } + + int dim = 128; + + int numCentroids = 256; + // Unfortunately it would take forever to add 24 GB in IVFPQ data, + // so just perform a small test with data allocated in the unified + // memory address space + size_t numAdd = 10000; + size_t numTrain = numCentroids * 40; + int numQuery = 10; + int k = 10; + int nprobe = 8; + int codes = 8; + int bitsPerCode = 8; + + std::vector trainVecs = faiss::gpu::randVecs(numTrain, dim); + std::vector addVecs = faiss::gpu::randVecs(numAdd, dim); + + faiss::IndexFlatL2 quantizer(dim); + faiss::IndexIVFPQ cpuIndex(&quantizer, dim, numCentroids, codes, bitsPerCode); + + cpuIndex.train(numTrain, trainVecs.data()); + cpuIndex.add(numAdd, addVecs.data()); + cpuIndex.nprobe = nprobe; + + faiss::gpu::StandardGpuResources res; + res.noTempMemory(); + + faiss::gpu::GpuIndexIVFPQConfig config; + config.device = device; + config.memorySpace = faiss::gpu::MemorySpace::Unified; + + faiss::gpu::GpuIndexIVFPQ gpuIndex(&res, + dim, + numCentroids, + codes, + bitsPerCode, + faiss::METRIC_L2, + config); + gpuIndex.copyFrom(&cpuIndex); + gpuIndex.setNumProbes(nprobe); + + faiss::gpu::compareIndices(cpuIndex, gpuIndex, + numQuery, dim, k, "Unified Memory", + 0.015f, + 0.1f, + 0.015f); +} + +int main(int argc, char** argv) { + testing::InitGoogleTest(&argc, argv); + + // just run with a fixed test seed + faiss::gpu::setTestSeed(100); + + return RUN_ALL_TESTS(); +} diff --git a/core/src/index/thirdparty/faiss/gpu/test/TestGpuMemoryException.cpp b/core/src/index/thirdparty/faiss/gpu/test/TestGpuMemoryException.cpp new file mode 100644 index 0000000000..e3bca1d86a --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/test/TestGpuMemoryException.cpp @@ -0,0 +1,84 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#include +#include +#include +#include +#include +#include + +// Test to see if we can recover after attempting to allocate too much GPU +// memory +TEST(TestGpuMemoryException, AddException) { + size_t numBrokenAdd = std::numeric_limits::max(); + size_t numRealAdd = 10000; + size_t devFree = 0; + size_t devTotal = 0; + + CUDA_VERIFY(cudaMemGetInfo(&devFree, &devTotal)); + + // Figure out the dimensionality needed to get at least greater than devTotal + size_t brokenAddDims = ((devTotal / sizeof(float)) / numBrokenAdd) + 1; + size_t realAddDims = 128; + + faiss::gpu::StandardGpuResources res; + + faiss::gpu::GpuIndexFlatConfig config; + config.device = faiss::gpu::randVal(0, faiss::gpu::getNumDevices() - 1); + + faiss::gpu::GpuIndexFlatL2 + gpuIndexL2Broken(&res, (int) brokenAddDims, config); + faiss::gpu::GpuIndexFlatL2 + gpuIndexL2(&res, (int) realAddDims, config); + faiss::IndexFlatL2 + cpuIndex((int) realAddDims); + + // Should throw on attempting to allocate too much data + { + // allocate memory without initialization + auto vecs = + std::unique_ptr(new float[numBrokenAdd * brokenAddDims]); + EXPECT_THROW(gpuIndexL2Broken.add(numBrokenAdd, vecs.get()), + faiss::FaissException); + } + + // Should be able to add a smaller set of data now + { + auto vecs = faiss::gpu::randVecs(numRealAdd, realAddDims); + EXPECT_NO_THROW(gpuIndexL2.add(numRealAdd, vecs.data())); + cpuIndex.add(numRealAdd, vecs.data()); + } + + // Should throw on attempting to allocate too much data + { + // allocate memory without initialization + auto vecs = + std::unique_ptr(new float[numBrokenAdd * brokenAddDims]); + EXPECT_THROW(gpuIndexL2Broken.add(numBrokenAdd, vecs.get()), + faiss::FaissException); + } + + // Should be able to query results from what we had before + { + size_t numQuery = 10; + auto vecs = faiss::gpu::randVecs(numQuery, realAddDims); + EXPECT_NO_THROW(compareIndices(vecs, cpuIndex, gpuIndexL2, + numQuery, realAddDims, 50, "", + 6e-3f, 0.1f, 0.015f)); + } +} + +int main(int argc, char** argv) { + testing::InitGoogleTest(&argc, argv); + + // just run with a fixed test seed + faiss::gpu::setTestSeed(100); + + return RUN_ALL_TESTS(); +} diff --git a/core/src/index/thirdparty/faiss/gpu/test/TestGpuSelect.cu b/core/src/index/thirdparty/faiss/gpu/test/TestGpuSelect.cu new file mode 100644 index 0000000000..eec621bd5c --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/test/TestGpuSelect.cu @@ -0,0 +1,193 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +void testForSize(int rows, int cols, int k, bool dir, bool warp) { + std::vector v = faiss::gpu::randVecs(rows, cols); + faiss::gpu::HostTensor hostVal({rows, cols}); + + for (int r = 0; r < rows; ++r) { + for (int c = 0; c < cols; ++c) { + hostVal[r][c] = v[r * cols + c]; + } + } + + faiss::gpu::DeviceTensor bitset(nullptr, {0}); + + // row -> (val -> idx) + std::unordered_map>> hostOutValAndInd; + for (int r = 0; r < rows; ++r) { + std::vector> closest; + + for (int c = 0; c < cols; ++c) { + closest.emplace_back(c, (float) hostVal[r][c]); + } + + auto dirFalseFn = + [](std::pair& a, std::pair& b) { + return a.second < b.second; + }; + auto dirTrueFn = + [](std::pair& a, std::pair& b) { + return a.second > b.second; + }; + + std::sort(closest.begin(), closest.end(), dir ? dirTrueFn : dirFalseFn); + hostOutValAndInd.emplace(r, closest); + } + + // Select top-k on GPU + faiss::gpu::DeviceTensor gpuVal(hostVal, 0); + faiss::gpu::DeviceTensor gpuOutVal({rows, k}); + faiss::gpu::DeviceTensor gpuOutInd({rows, k}); + + if (warp) { + faiss::gpu::runWarpSelect(gpuVal, gpuOutVal, gpuOutInd, dir, k, 0); + } else { + + faiss::gpu::runBlockSelect(gpuVal, bitset, gpuOutVal, gpuOutInd, dir, k, 0); + } + + // Copy back to CPU + faiss::gpu::HostTensor outVal(gpuOutVal, 0); + faiss::gpu::HostTensor outInd(gpuOutInd, 0); + + for (int r = 0; r < rows; ++r) { + std::unordered_map seenIndices; + + for (int i = 0; i < k; ++i) { + float gpuV = outVal[r][i]; + float cpuV = hostOutValAndInd[r][i].second; + + EXPECT_EQ(gpuV, cpuV) << + "rows " << rows << " cols " << cols << " k " << k << " dir " << dir + << " row " << r << " ind " << i; + + // If there are identical elements in a row that should be + // within the top-k, then it is possible that the index can + // differ, because the order in which the GPU will see the + // equivalent values is different than the CPU (and will remain + // unspecified, since this is affected by the choice of + // k-selection algorithm that we use) + int gpuInd = outInd[r][i]; + int cpuInd = hostOutValAndInd[r][i].first; + + // We should never see duplicate indices, however + auto itSeenIndex = seenIndices.find(gpuInd); + + EXPECT_EQ(itSeenIndex, seenIndices.end()) << + "Row " << r << " user index " << gpuInd << " was seen at both " << + itSeenIndex->second << " and " << i; + + seenIndices[gpuInd] = i; + + if (gpuInd != cpuInd) { + // Gather the values from the original data via index; the + // values should be the same + float gpuGatherV = hostVal[r][gpuInd]; + float cpuGatherV = hostVal[r][cpuInd]; + + EXPECT_EQ(gpuGatherV, cpuGatherV) << + "rows " << rows << " cols " << cols << " k " << k << " dir " << dir + << " row " << r << " ind " << i << " source ind " + << gpuInd << " " << cpuInd; + } + } + } +} + +// General test +TEST(TestGpuSelect, test) { + for (int i = 0; i < 10; ++i) { + int rows = faiss::gpu::randVal(10, 100); + int cols = faiss::gpu::randVal(1, 30000); + int k = std::min(cols, faiss::gpu::randVal(1, GPU_MAX_SELECTION_K)); + bool dir = faiss::gpu::randBool(); + + testForSize(rows, cols, k, dir, false); + } +} + +// Test for k = 1 +TEST(TestGpuSelect, test1) { + for (int i = 0; i < 5; ++i) { + int rows = faiss::gpu::randVal(10, 100); + int cols = faiss::gpu::randVal(1, 30000); + bool dir = faiss::gpu::randBool(); + + testForSize(rows, cols, 1, dir, false); + } +} + +// Test for where k = #cols exactly (we are returning all the values, +// just sorted) +TEST(TestGpuSelect, testExact) { + for (int i = 0; i < 5; ++i) { + int rows = faiss::gpu::randVal(10, 100); + int cols = faiss::gpu::randVal(1, GPU_MAX_SELECTION_K); + bool dir = faiss::gpu::randBool(); + + testForSize(rows, cols, cols, dir, false); + } +} + +// General test +TEST(TestGpuSelect, testWarp) { + for (int i = 0; i < 10; ++i) { + int rows = faiss::gpu::randVal(10, 100); + int cols = faiss::gpu::randVal(1, 30000); + int k = std::min(cols, faiss::gpu::randVal(1, GPU_MAX_SELECTION_K)); + bool dir = faiss::gpu::randBool(); + + testForSize(rows, cols, k, dir, true); + } +} + +// Test for k = 1 +TEST(TestGpuSelect, test1Warp) { + for (int i = 0; i < 5; ++i) { + int rows = faiss::gpu::randVal(10, 100); + int cols = faiss::gpu::randVal(1, 30000); + bool dir = faiss::gpu::randBool(); + + testForSize(rows, cols, 1, dir, true); + } +} + +// Test for where k = #cols exactly (we are returning all the values, +// just sorted) +TEST(TestGpuSelect, testExactWarp) { + for (int i = 0; i < 5; ++i) { + int rows = faiss::gpu::randVal(10, 100); + int cols = faiss::gpu::randVal(1, GPU_MAX_SELECTION_K); + bool dir = faiss::gpu::randBool(); + + testForSize(rows, cols, cols, dir, true); + } +} + +int main(int argc, char** argv) { + testing::InitGoogleTest(&argc, argv); + + // just run with a fixed test seed + faiss::gpu::setTestSeed(100); + + return RUN_ALL_TESTS(); +} diff --git a/core/src/index/thirdparty/faiss/gpu/test/TestUtils.cpp b/core/src/index/thirdparty/faiss/gpu/test/TestUtils.cpp new file mode 100644 index 0000000000..423d58b87d --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/test/TestUtils.cpp @@ -0,0 +1,315 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace faiss { namespace gpu { + +inline float relativeError(float a, float b) { + return std::abs(a - b) / (0.5f * (std::abs(a) + std::abs(b))); +} + +// This seed is also used for the faiss float_rand API; in a test it +// is all within a single thread, so it is ok +long s_seed = 1; + +void newTestSeed() { + struct timespec t; + clock_gettime(CLOCK_REALTIME, &t); + + setTestSeed(t.tv_nsec); +} + +void setTestSeed(long seed) { + printf("testing with random seed %ld\n", seed); + + srand48(seed); + s_seed = seed; +} + +int randVal(int a, int b) { + EXPECT_GE(a, 0); + EXPECT_LE(a, b); + + return a + (lrand48() % (b + 1 - a)); +} + +bool randBool() { + return randSelect({true, false}); +} + +std::vector randVecs(size_t num, size_t dim) { + std::vector v(num * dim); + + faiss::float_rand(v.data(), v.size(), s_seed); + // unfortunately we generate separate sets of vectors, and don't + // want the same values + ++s_seed; + + return v; +} + +std::vector randBinaryVecs(size_t num, size_t dim) { + std::vector v(num * (dim / 8)); + + faiss::byte_rand(v.data(), v.size(), s_seed); + // unfortunately we generate separate sets of vectors, and don't + // want the same values + ++s_seed; + + return v; +} + +void compareIndices( + const std::vector& queryVecs, + faiss::Index& refIndex, + faiss::Index& testIndex, + int numQuery, + int /*dim*/, + int k, + const std::string& configMsg, + float maxRelativeError, + float pctMaxDiff1, + float pctMaxDiffN) { + // Compare + std::vector refDistance(numQuery * k, 0); + std::vector refIndices(numQuery * k, -1); + refIndex.search(numQuery, queryVecs.data(), + k, refDistance.data(), refIndices.data()); + + std::vector testDistance(numQuery * k, 0); + std::vector testIndices(numQuery * k, -1); + testIndex.search(numQuery, queryVecs.data(), + k, testDistance.data(), testIndices.data()); + + faiss::gpu::compareLists(refDistance.data(), + refIndices.data(), + testDistance.data(), + testIndices.data(), + numQuery, k, + configMsg, + true, false, true, + maxRelativeError, pctMaxDiff1, pctMaxDiffN); +} + +void compareIndices(faiss::Index& refIndex, + faiss::Index& testIndex, + int numQuery, int dim, int k, + const std::string& configMsg, + float maxRelativeError, + float pctMaxDiff1, + float pctMaxDiffN) { + auto queryVecs = faiss::gpu::randVecs(numQuery, dim); + + compareIndices(queryVecs, + refIndex, + testIndex, + numQuery, dim, k, + configMsg, + maxRelativeError, + pctMaxDiff1, + pctMaxDiffN); +} + +template +inline T lookup(const T* p, int i, int j, int /*dim1*/, int dim2) { + return p[i * dim2 + j]; +} + +void compareLists(const float* refDist, + const faiss::Index::idx_t* refInd, + const float* testDist, + const faiss::Index::idx_t* testInd, + int dim1, int dim2, + const std::string& configMsg, + bool printBasicStats, bool printDiffs, bool assertOnErr, + float maxRelativeError, + float pctMaxDiff1, + float pctMaxDiffN) { + + float maxAbsErr = 0.0f; + for (int i = 0; i < dim1 * dim2; ++i) { + maxAbsErr = std::max(maxAbsErr, std::abs(refDist[i] - testDist[i])); + } + int numResults = dim1 * dim2; + + // query -> {index -> result position} + std::vector> refIndexMap; + + for (int query = 0; query < dim1; ++query) { + std::unordered_map indices; + + for (int result = 0; result < dim2; ++result) { + indices[lookup(refInd, query, result, dim1, dim2)] = result; + } + + refIndexMap.emplace_back(std::move(indices)); + } + + // See how far off the indices are + // Keep track of the difference for each entry + std::vector> indexDiffs; + + int diff1 = 0; // index differs by 1 + int diffN = 0; // index differs by >1 + int diffInf = 0; // index not found in the other + int nonUniqueIndices = 0; + + double avgDiff = 0.0; + int maxDiff = 0; + float maxRelErr = 0.0f; + + for (int query = 0; query < dim1; ++query) { + std::vector diffs; + std::set uniqueIndices; + + auto& indices = refIndexMap[query]; + + for (int result = 0; result < dim2; ++result) { + auto t = lookup(testInd, query, result, dim1, dim2); + + // All indices reported within a query should be unique; this is + // a serious error if is otherwise the case. + // If -1 is reported (no result due to IVF partitioning or not enough + // entries in the index), then duplicates are allowed, but both the + // reference and test must have -1 in the same position. + if (t == -1) { + EXPECT_EQ(lookup(refInd, query, result, dim1, dim2), t); + } else { + bool uniqueIndex = uniqueIndices.count(t) == 0; + if (assertOnErr) { + EXPECT_TRUE(uniqueIndex) << configMsg + << " " << query + << " " << result + << " " << t; + } + + if (!uniqueIndex) { + ++nonUniqueIndices; + } else { + uniqueIndices.insert(t); + } + + auto it = indices.find(t); + if (it != indices.end()) { + int diff = std::abs(result - it->second); + diffs.push_back(diff); + + if (diff == 1) { + ++diff1; + maxDiff = std::max(diff, maxDiff); + } else if (diff > 1) { + ++diffN; + maxDiff = std::max(diff, maxDiff); + } + + avgDiff += (double) diff; + } else { + ++diffInf; + diffs.push_back(-1); + // don't count this for maxDiff + } + } + + auto refD = lookup(refDist, query, result, dim1, dim2); + auto testD = lookup(testDist, query, result, dim1, dim2); + + float relErr = relativeError(refD, testD); + + if (assertOnErr) { + EXPECT_LE(relErr, maxRelativeError) << configMsg + << " (" << query << ", " << result + << ") refD: " << refD + << " testD: " << testD; + } + + maxRelErr = std::max(maxRelErr, relErr); + } + + indexDiffs.emplace_back(std::move(diffs)); + } + + if (assertOnErr) { + EXPECT_LE((float) (diff1 + diffN + diffInf), + (float) numResults * pctMaxDiff1) << configMsg; + + // Don't count diffInf because that could be diff1 as far as we + // know + EXPECT_LE((float) diffN, (float) numResults * pctMaxDiffN) << configMsg; + } + + avgDiff /= (double) numResults; + + if (printBasicStats) { + if (!configMsg.empty()) { + printf("Config\n" + "----------------------------\n" + "%s\n", + configMsg.c_str()); + } + + printf("Result error and differences\n" + "----------------------------\n" + "max abs diff %.7f rel diff %.7f\n" + "idx diff avg: %.5g max: %d\n" + "idx diff of 1: %d (%.3f%% of queries)\n" + "idx diff of >1: %d (%.3f%% of queries)\n" + "idx diff not found: %d (%.3f%% of queries)" + " [typically a last element inversion]\n" + "non-unique indices: %d (a serious error if >0)\n", + maxAbsErr, maxRelErr, + avgDiff, maxDiff, + diff1, 100.0f * (float) diff1 / (float) numResults, + diffN, 100.0f * (float) diffN / (float) numResults, + diffInf, 100.0f * (float) diffInf / (float) numResults, + nonUniqueIndices); + } + + if (printDiffs) { + printf("differences:\n"); + printf("==================\n"); + for (int query = 0; query < dim1; ++query) { + for (int result = 0; result < dim2; ++result) { + long refI = lookup(refInd, query, result, dim1, dim2); + long testI = lookup(testInd, query, result, dim1, dim2); + + if (refI != testI) { + float refD = lookup(refDist, query, result, dim1, dim2); + float testD = lookup(testDist, query, result, dim1, dim2); + + float maxDist = std::max(refD, testD); + float delta = std::abs(refD - testD); + + float relErr = delta / maxDist; + + if (refD == testD) { + printf("(%d, %d [%d]) (ref %ld tst %ld dist ==)\n", + query, result, + indexDiffs[query][result], + refI, testI); + } else { + printf("(%d, %d [%d]) (ref %ld tst %ld abs %.8f " + "rel %.8f ref %a tst %a)\n", + query, result, + indexDiffs[query][result], + refI, testI, delta, relErr, refD, testD); + } + } + } + } + } +} + +} } diff --git a/core/src/index/thirdparty/faiss/gpu/test/TestUtils.h b/core/src/index/thirdparty/faiss/gpu/test/TestUtils.h new file mode 100644 index 0000000000..c59a4ab0ae --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/test/TestUtils.h @@ -0,0 +1,93 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace faiss { namespace gpu { + +/// Generates and displays a new seed for the test +void newTestSeed(); + +/// Uses an explicit seed for the test +void setTestSeed(long seed); + +/// Returns the relative error in difference between a and b +/// (|a - b| / (0.5 * (|a| + |b|)) +float relativeError(float a, float b); + +/// Generates a random integer in the range [a, b] +int randVal(int a, int b); + +/// Generates a random bool +bool randBool(); + +/// Select a random value from the given list of values provided as an +/// initializer_list +template +T randSelect(std::initializer_list vals) { + FAISS_ASSERT(vals.size() > 0); + int sel = randVal(0, vals.size()); + + int i = 0; + for (auto v : vals) { + if (i++ == sel) { + return v; + } + } + + // should not get here + return *vals.begin(); +} + +/// Generates a collection of random vectors in the range [0, 1] +std::vector randVecs(size_t num, size_t dim); + +/// Generates a collection of random bit vectors +std::vector randBinaryVecs(size_t num, size_t dim); + +/// Compare two indices via query for similarity, with a user-specified set of +/// query vectors +void compareIndices(const std::vector& queryVecs, + faiss::Index& refIndex, + faiss::Index& testIndex, + int numQuery, int dim, int k, + const std::string& configMsg, + float maxRelativeError = 6e-5f, + float pctMaxDiff1 = 0.1f, + float pctMaxDiffN = 0.005f); + +/// Compare two indices via query for similarity, generating random query +/// vectors +void compareIndices(faiss::Index& refIndex, + faiss::Index& testIndex, + int numQuery, int dim, int k, + const std::string& configMsg, + float maxRelativeError = 6e-5f, + float pctMaxDiff1 = 0.1f, + float pctMaxDiffN = 0.005f); + +/// Display specific differences in the two (distance, index) lists +void compareLists(const float* refDist, + const faiss::Index::idx_t* refInd, + const float* testDist, + const faiss::Index::idx_t* testInd, + int dim1, int dim2, + const std::string& configMsg, + bool printBasicStats, bool printDiffs, bool assertOnErr, + float maxRelativeError = 6e-5f, + float pctMaxDiff1 = 0.1f, + float pctMaxDiffN = 0.005f); + +} } diff --git a/core/src/index/thirdparty/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp b/core/src/index/thirdparty/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp new file mode 100644 index 0000000000..852a43cbe9 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp @@ -0,0 +1,159 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// Copyright 2004-present Facebook. All Rights Reserved + + +#include +#include +#include + +#include + + +#include +#include + +#include +#include + +double elapsed () +{ + struct timeval tv; + gettimeofday (&tv, NULL); + return tv.tv_sec + tv.tv_usec * 1e-6; +} + + +int main () +{ + + double t0 = elapsed(); + + // dimension of the vectors to index + int d = 128; + + // size of the database we plan to index + size_t nb = 200 * 1000; + + // make a set of nt training vectors in the unit cube + // (could be the database) + size_t nt = 100 * 1000; + + int dev_no = 0; + /* + printf ("[%.3f s] Begin d=%d nb=%ld nt=%nt dev_no=%d\n", + elapsed() - t0, d, nb, nt, dev_no); + */ + // a reasonable number of centroids to index nb vectors + int ncentroids = int (4 * sqrt (nb)); + + faiss::gpu::StandardGpuResources resources; + + + // the coarse quantizer should not be dealloced before the index + // 4 = nb of bytes per code (d must be a multiple of this) + // 8 = nb of bits per sub-code (almost always 8) + faiss::gpu::GpuIndexIVFPQConfig config; + config.device = dev_no; + + faiss::gpu::GpuIndexIVFPQ index ( + &resources, d, ncentroids, 4, 8, faiss::METRIC_L2, config); + + { // training + printf ("[%.3f s] Generating %ld vectors in %dD for training\n", + elapsed() - t0, nt, d); + + std::vector trainvecs (nt * d); + for (size_t i = 0; i < nt * d; i++) { + trainvecs[i] = drand48(); + } + + printf ("[%.3f s] Training the index\n", + elapsed() - t0); + index.verbose = true; + + index.train (nt, trainvecs.data()); + } + + { // I/O demo + const char *outfilename = "/tmp/index_trained.faissindex"; + printf ("[%.3f s] storing the pre-trained index to %s\n", + elapsed() - t0, outfilename); + + faiss::Index * cpu_index = faiss::gpu::index_gpu_to_cpu (&index); + + write_index (cpu_index, outfilename); + + delete cpu_index; + } + + size_t nq; + std::vector queries; + + { // populating the database + printf ("[%.3f s] Building a dataset of %ld vectors to index\n", + elapsed() - t0, nb); + + std::vector database (nb * d); + for (size_t i = 0; i < nb * d; i++) { + database[i] = drand48(); + } + + printf ("[%.3f s] Adding the vectors to the index\n", + elapsed() - t0); + + index.add (nb, database.data()); + + printf ("[%.3f s] done\n", elapsed() - t0); + + // remember a few elements from the database as queries + int i0 = 1234; + int i1 = 1243; + + nq = i1 - i0; + queries.resize (nq * d); + for (int i = i0; i < i1; i++) { + for (int j = 0; j < d; j++) { + queries [(i - i0) * d + j] = database [i * d + j]; + } + } + + } + + { // searching the database + int k = 5; + printf ("[%.3f s] Searching the %d nearest neighbors " + "of %ld vectors in the index\n", + elapsed() - t0, k, nq); + + std::vector nns (k * nq); + std::vector dis (k * nq); + + index.search (nq, queries.data(), k, dis.data(), nns.data()); + + printf ("[%.3f s] Query results (vector ids, then distances):\n", + elapsed() - t0); + + for (int i = 0; i < nq; i++) { + printf ("query %2d: ", i); + for (int j = 0; j < k; j++) { + printf ("%7ld ", nns[j + i * k]); + } + printf ("\n dis: "); + for (int j = 0; j < k; j++) { + printf ("%7g ", dis[j + i * k]); + } + printf ("\n"); + } + + printf ("note that the nearest neighbor is not at " + "distance 0 due to quantization errors\n"); + } + + return 0; +} diff --git a/core/src/index/thirdparty/faiss/gpu/test/test_gpu_index.py b/core/src/index/thirdparty/faiss/gpu/test/test_gpu_index.py new file mode 100644 index 0000000000..8b17b4801f --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/test/test_gpu_index.py @@ -0,0 +1,314 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import absolute_import, division, print_function, unicode_literals + +import time +import unittest +import numpy as np +import faiss + + +class EvalIVFPQAccuracy(unittest.TestCase): + + def get_dataset(self, small_one=False): + if not small_one: + d = 128 + nb = 100000 + nt = 15000 + nq = 2000 + else: + d = 32 + nb = 1000 + nt = 1000 + nq = 200 + np.random.seed(123) + + # generate points in a low-dim subspace to make the resutls + # look better :-) + d1 = 16 + q, r = np.linalg.qr(np.random.randn(d, d)) + qc = q[:d1, :] + def make_mat(n): + return np.dot( + np.random.random(size=(nb, d1)), qc).astype('float32') + + return (make_mat(nt), make_mat(nb), make_mat(nq)) + + + def test_mm(self): + # trouble with MKL+fbmake that appears only at runtime. Check it here + x = np.random.random(size=(100, 20)).astype('float32') + mat = faiss.PCAMatrix(20, 10) + mat.train(x) + mat.apply_py(x) + + def do_cpu_to_gpu(self, index_key): + ts = [] + ts.append(time.time()) + (xt, xb, xq) = self.get_dataset(small_one=True) + nb, d = xb.shape + + index = faiss.index_factory(d, index_key) + if index.__class__ == faiss.IndexIVFPQ: + # speed up test + index.pq.cp.niter = 2 + index.do_polysemous_training = False + ts.append(time.time()) + + index.train(xt) + ts.append(time.time()) + + # adding some ids because there was a bug in this case + index.add_with_ids(xb, np.arange(nb) * 3 + 12345) + ts.append(time.time()) + + index.nprobe = 4 + D, Iref = index.search(xq, 10) + ts.append(time.time()) + + res = faiss.StandardGpuResources() + gpu_index = faiss.index_cpu_to_gpu(res, 0, index) + ts.append(time.time()) + + gpu_index.setNumProbes(4) + + D, Inew = gpu_index.search(xq, 10) + ts.append(time.time()) + print('times:', [t - ts[0] for t in ts]) + + self.assertGreaterEqual((Iref == Inew).sum(), Iref.size) + + if faiss.get_num_gpus() == 1: + return + + for shard in False, True: + + # test on just 2 GPUs + res = [faiss.StandardGpuResources() for i in range(2)] + co = faiss.GpuMultipleClonerOptions() + co.shard = shard + + gpu_index = faiss.index_cpu_to_gpu_multiple_py(res, index, co) + + faiss.GpuParameterSpace().set_index_parameter( + gpu_index, 'nprobe', 4) + + D, Inew = gpu_index.search(xq, 10) + + # 0.99: allow some tolerance in results otherwise test + # fails occasionally (not reproducible) + self.assertGreaterEqual((Iref == Inew).sum(), Iref.size * 0.99) + + def test_cpu_to_gpu_IVFPQ(self): + self.do_cpu_to_gpu('IVF128,PQ4') + + def test_cpu_to_gpu_IVFFlat(self): + self.do_cpu_to_gpu('IVF128,Flat') + + def test_set_gpu_param(self): + index = faiss.index_factory(12, "PCAR8,IVF10,PQ4") + res = faiss.StandardGpuResources() + gpu_index = faiss.index_cpu_to_gpu(res, 0, index) + faiss.GpuParameterSpace().set_index_parameter(gpu_index, "nprobe", 3) + + +class ReferencedObject(unittest.TestCase): + + d = 16 + xb = np.random.rand(256, d).astype('float32') + nlist = 128 + + d_bin = 256 + xb_bin = np.random.randint(256, size=(10000, d_bin // 8)).astype('uint8') + xq_bin = np.random.randint(256, size=(1000, d_bin // 8)).astype('uint8') + + def test_proxy(self): + index = faiss.IndexReplicas() + for _i in range(3): + sub_index = faiss.IndexFlatL2(self.d) + sub_index.add(self.xb) + index.addIndex(sub_index) + assert index.d == self.d + index.search(self.xb, 10) + + def test_resources(self): + # this used to crash! + index = faiss.index_cpu_to_gpu(faiss.StandardGpuResources(), 0, + faiss.IndexFlatL2(self.d)) + index.add(self.xb) + + def test_flat(self): + index = faiss.GpuIndexFlat(faiss.StandardGpuResources(), + self.d, faiss.METRIC_L2) + index.add(self.xb) + + def test_ivfflat(self): + index = faiss.GpuIndexIVFFlat( + faiss.StandardGpuResources(), + self.d, self.nlist, faiss.METRIC_L2) + index.train(self.xb) + + def test_ivfpq(self): + index_cpu = faiss.IndexIVFPQ( + faiss.IndexFlatL2(self.d), + self.d, self.nlist, 2, 8) + # speed up test + index_cpu.pq.cp.niter = 2 + index_cpu.do_polysemous_training = False + index_cpu.train(self.xb) + + index = faiss.GpuIndexIVFPQ( + faiss.StandardGpuResources(), index_cpu) + index.add(self.xb) + + def test_binary_flat(self): + k = 10 + + index_ref = faiss.IndexBinaryFlat(self.d_bin) + index_ref.add(self.xb_bin) + D_ref, I_ref = index_ref.search(self.xq_bin, k) + + index = faiss.GpuIndexBinaryFlat(faiss.StandardGpuResources(), + self.d_bin) + index.add(self.xb_bin) + D, I = index.search(self.xq_bin, k) + + for d_ref, i_ref, d_new, i_new in zip(D_ref, I_ref, D, I): + # exclude max distance + assert d_ref.max() == d_new.max() + dmax = d_ref.max() + + # sort by (distance, id) pairs to be reproducible + ref = [(d, i) for d, i in zip(d_ref, i_ref) if d < dmax] + ref.sort() + + new = [(d, i) for d, i in zip(d_new, i_new) if d < dmax] + new.sort() + + assert ref == new + + def test_stress(self): + # a mixture of the above, from issue #631 + target = np.random.rand(50, 16).astype('float32') + + index = faiss.IndexReplicas() + size, dim = target.shape + num_gpu = 4 + for _i in range(num_gpu): + config = faiss.GpuIndexFlatConfig() + config.device = 0 # simulate on a single GPU + sub_index = faiss.GpuIndexFlatIP(faiss.StandardGpuResources(), dim, config) + index.addIndex(sub_index) + + index = faiss.IndexIDMap(index) + ids = np.arange(size) + index.add_with_ids(target, ids) + + + +class TestShardedFlat(unittest.TestCase): + + def test_sharded(self): + d = 32 + nb = 1000 + nq = 200 + k = 10 + rs = np.random.RandomState(123) + xb = rs.rand(nb, d).astype('float32') + xq = rs.rand(nq, d).astype('float32') + + index_cpu = faiss.IndexFlatL2(d) + + assert faiss.get_num_gpus() > 1 + + co = faiss.GpuMultipleClonerOptions() + co.shard = True + index = faiss.index_cpu_to_all_gpus(index_cpu, co, ngpu=2) + + index.add(xb) + D, I = index.search(xq, k) + + index_cpu.add(xb) + D_ref, I_ref = index_cpu.search(xq, k) + + assert np.all(I == I_ref) + + del index + index2 = faiss.index_cpu_to_all_gpus(index_cpu, co, ngpu=2) + D2, I2 = index2.search(xq, k) + + assert np.all(I2 == I_ref) + + try: + index2.add(xb) + except RuntimeError: + pass + else: + assert False, "this call should fail!" + + +class TestGPUKmeans(unittest.TestCase): + + def test_kmeans(self): + d = 32 + nb = 1000 + k = 10 + rs = np.random.RandomState(123) + xb = rs.rand(nb, d).astype('float32') + + km1 = faiss.Kmeans(d, k) + obj1 = km1.train(xb) + + km2 = faiss.Kmeans(d, k, gpu=True) + obj2 = km2.train(xb) + + print(obj1, obj2) + assert np.allclose(obj1, obj2) + + +class TestAlternativeDistances(unittest.TestCase): + + def do_test(self, metric, metric_arg=0): + res = faiss.StandardGpuResources() + d = 32 + nb = 1000 + nq = 100 + + rs = np.random.RandomState(123) + xb = rs.rand(nb, d).astype('float32') + xq = rs.rand(nq, d).astype('float32') + + index_ref = faiss.IndexFlat(d, metric) + index_ref.metric_arg = metric_arg + index_ref.add(xb) + Dref, Iref = index_ref.search(xq, 10) + + # build from other index + index = faiss.GpuIndexFlat(res, index_ref) + Dnew, Inew = index.search(xq, 10) + np.testing.assert_array_equal(Inew, Iref) + np.testing.assert_allclose(Dnew, Dref, rtol=1e-6) + + # build from scratch + index = faiss.GpuIndexFlat(res, d, metric) + index.metric_arg = metric_arg + index.add(xb) + + Dnew, Inew = index.search(xq, 10) + np.testing.assert_array_equal(Inew, Iref) + + def test_L1(self): + self.do_test(faiss.METRIC_L1) + + def test_Linf(self): + self.do_test(faiss.METRIC_Linf) + + def test_Lp(self): + self.do_test(faiss.METRIC_Lp, 0.7) + + +if __name__ == '__main__': + unittest.main() diff --git a/core/src/index/thirdparty/faiss/gpu/test/test_gpu_index_ivfsq.py b/core/src/index/thirdparty/faiss/gpu/test/test_gpu_index_ivfsq.py new file mode 100644 index 0000000000..6c312af3e6 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/test/test_gpu_index_ivfsq.py @@ -0,0 +1,229 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +#! /usr/bin/env python3 + +from __future__ import print_function +import unittest +import numpy as np +import faiss + +def make_t(num, d, clamp=False): + rs = np.random.RandomState(123) + x = rs.rand(num, d).astype('float32') + if clamp: + x = (x * 255).astype('uint8').astype('float32') + return x + +def make_indices_copy_from_cpu(nlist, d, qtype, by_residual, metric, clamp): + to_train = make_t(10000, d, clamp) + + quantizer_cp = faiss.IndexFlat(d, metric) + idx_cpu = faiss.IndexIVFScalarQuantizer(quantizer_cp, d, nlist, + qtype, metric, by_residual) + + idx_cpu.train(to_train) + idx_cpu.add(to_train) + + res = faiss.StandardGpuResources() + res.noTempMemory() + idx_gpu = faiss.GpuIndexIVFScalarQuantizer(res, idx_cpu) + + return idx_cpu, idx_gpu + + +def make_indices_copy_from_gpu(nlist, d, qtype, by_residual, metric, clamp): + to_train = make_t(10000, d, clamp) + + res = faiss.StandardGpuResources() + res.noTempMemory() + idx_gpu = faiss.GpuIndexIVFScalarQuantizer(res, d, nlist, + qtype, metric, by_residual) + idx_gpu.train(to_train) + idx_gpu.add(to_train) + + quantizer_cp = faiss.IndexFlat(d, metric) + idx_cpu = faiss.IndexIVFScalarQuantizer(quantizer_cp, d, nlist, + qtype, metric, by_residual) + idx_gpu.copyTo(idx_cpu) + + return idx_cpu, idx_gpu + + +def make_indices_train(nlist, d, qtype, by_residual, metric, clamp): + to_train = make_t(10000, d, clamp) + + quantizer_cp = faiss.IndexFlat(d, metric) + idx_cpu = faiss.IndexIVFScalarQuantizer(quantizer_cp, d, nlist, + qtype, metric, by_residual) + assert(by_residual == idx_cpu.by_residual) + + idx_cpu.train(to_train) + idx_cpu.add(to_train) + + res = faiss.StandardGpuResources() + res.noTempMemory() + idx_gpu = faiss.GpuIndexIVFScalarQuantizer(res, d, nlist, + qtype, metric, by_residual) + assert(by_residual == idx_gpu.by_residual) + + idx_gpu.train(to_train) + idx_gpu.add(to_train) + + return idx_cpu, idx_gpu + +# +# Testing functions +# + +def summarize_results(dist, idx): + valid = [] + invalid = [] + for query in range(dist.shape[0]): + valid_sub = {} + invalid_sub = [] + + for order, (d, i) in enumerate(zip(dist[query], idx[query])): + if i == -1: + invalid_sub.append(order) + else: + valid_sub[i] = [order, d] + + valid.append(valid_sub) + invalid.append(invalid_sub) + + return valid, invalid + +def compare_results(d1, i1, d2, i2): + # Count number of index differences + idx_diffs = {} + idx_diffs_inf = 0 + idx_invalid = 0 + + valid1, invalid1 = summarize_results(d1, i1) + valid2, invalid2 = summarize_results(d2, i2) + + # Invalid results should be the same for both + # (except if we happen to hit different centroids) + for inv1, inv2 in zip(invalid1, invalid2): + if (len(inv1) != len(inv2)): + print('mismatch ', len(inv1), len(inv2), inv2[0]) + + assert(len(inv1) == len(inv2)) + idx_invalid += len(inv2) + for x1, x2 in zip(inv1, inv2): + assert(x1 == x2) + + for _, (query1, query2) in enumerate(zip(valid1, valid2)): + for idx1, order_d1 in query1.items(): + order_d2 = query2.get(idx1, None) + if order_d2: + idx_diff = order_d1[0] - order_d2[0] + + if idx_diff not in idx_diffs: + idx_diffs[idx_diff] = 1 + else: + idx_diffs[idx_diff] += 1 + else: + idx_diffs_inf += 1 + + return idx_diffs, idx_diffs_inf, idx_invalid + +def check_diffs(total_num, in_window_thresh, diffs, diff_inf, invalid): + # We require a certain fraction of results to be within +/- diff_window + # index differences + diff_window = 4 + in_window = 0 + + for diff in sorted(diffs): + if abs(diff) <= diff_window: + in_window += diffs[diff] / total_num + + if (in_window < in_window_thresh): + print('error {} {}'.format(in_window, in_window_thresh)) + assert(in_window >= in_window_thresh) + +def do_test_with_index(ci, gi, nprobe, k, clamp, in_window_thresh): + num_query = 11 + to_query = make_t(num_query, ci.d, clamp) + + ci.nprobe = ci.nprobe + gi.nprobe = gi.nprobe + + total_num = num_query * k + check_diffs(total_num, in_window_thresh, + *compare_results(*ci.search(to_query, k), + *gi.search(to_query, k))) + +def do_test(nlist, d, qtype, by_residual, metric, nprobe, k): + clamp = (qtype == faiss.ScalarQuantizer.QT_8bit_direct) + ci, gi = make_indices_copy_from_cpu(nlist, d, qtype, + by_residual, metric, clamp) + # A direct copy should be much more closely in agreement + # (except for fp accumulation order differences) + do_test_with_index(ci, gi, nprobe, k, clamp, 0.99) + + ci, gi = make_indices_copy_from_gpu(nlist, d, qtype, + by_residual, metric, clamp) + # A direct copy should be much more closely in agreement + # (except for fp accumulation order differences) + do_test_with_index(ci, gi, nprobe, k, clamp, 0.99) + + ci, gi = make_indices_train(nlist, d, qtype, + by_residual, metric, clamp) + # Separate training can produce a slightly different coarse quantizer + # and residuals + do_test_with_index(ci, gi, nprobe, k, clamp, 0.8) + +def do_multi_test(qtype): + nlist = 100 + nprobe = 10 + k = 50 + + for d in [11, 64]: + if (qtype != faiss.ScalarQuantizer.QT_8bit_direct): + # residual doesn't make sense here + do_test(nlist, d, qtype, True, + faiss.METRIC_L2, nprobe, k) + do_test(nlist, d, qtype, True, + faiss.METRIC_INNER_PRODUCT, nprobe, k) + do_test(nlist, d, qtype, False, faiss.METRIC_L2, nprobe, k) + do_test(nlist, d, qtype, False, faiss.METRIC_INNER_PRODUCT, nprobe, k) + +# +# Test +# + +class TestSQ(unittest.TestCase): + def test_fp16(self): + do_multi_test(faiss.ScalarQuantizer.QT_fp16) + + def test_8bit(self): + do_multi_test(faiss.ScalarQuantizer.QT_8bit) + + def test_8bit_uniform(self): + do_multi_test(faiss.ScalarQuantizer.QT_8bit_uniform) + + def test_6bit(self): + try: + do_multi_test(faiss.ScalarQuantizer.QT_6bit) + # should not reach here; QT_6bit is unimplemented + except: + print('QT_6bit exception thrown (is expected)') + else: + assert(False) + + def test_4bit(self): + do_multi_test(faiss.ScalarQuantizer.QT_4bit) + + def test_4bit_uniform(self): + do_multi_test(faiss.ScalarQuantizer.QT_4bit_uniform) + + def test_8bit_direct(self): + do_multi_test(faiss.ScalarQuantizer.QT_8bit_direct) + + +if __name__ == '__main__': + unittest.main() diff --git a/core/src/index/thirdparty/faiss/gpu/test/test_pytorch_faiss.py b/core/src/index/thirdparty/faiss/gpu/test/test_pytorch_faiss.py new file mode 100644 index 0000000000..f59f711b82 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/test/test_pytorch_faiss.py @@ -0,0 +1,215 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import absolute_import, division, print_function, unicode_literals + +import numpy as np +import unittest +import faiss +import torch + +def swig_ptr_from_FloatTensor(x): + assert x.is_contiguous() + assert x.dtype == torch.float32 + return faiss.cast_integer_to_float_ptr( + x.storage().data_ptr() + x.storage_offset() * 4) + +def swig_ptr_from_LongTensor(x): + assert x.is_contiguous() + assert x.dtype == torch.int64, 'dtype=%s' % x.dtype + return faiss.cast_integer_to_long_ptr( + x.storage().data_ptr() + x.storage_offset() * 8) + + + +def search_index_pytorch(index, x, k, D=None, I=None): + """call the search function of an index with pytorch tensor I/O (CPU + and GPU supported)""" + assert x.is_contiguous() + n, d = x.size() + assert d == index.d + + if D is None: + D = torch.empty((n, k), dtype=torch.float32, device=x.device) + else: + assert D.size() == (n, k) + + if I is None: + I = torch.empty((n, k), dtype=torch.int64, device=x.device) + else: + assert I.size() == (n, k) + torch.cuda.synchronize() + xptr = swig_ptr_from_FloatTensor(x) + Iptr = swig_ptr_from_LongTensor(I) + Dptr = swig_ptr_from_FloatTensor(D) + index.search_c(n, xptr, + k, Dptr, Iptr) + torch.cuda.synchronize() + return D, I + + +def search_raw_array_pytorch(res, xb, xq, k, D=None, I=None, + metric=faiss.METRIC_L2): + assert xb.device == xq.device + + nq, d = xq.size() + if xq.is_contiguous(): + xq_row_major = True + elif xq.t().is_contiguous(): + xq = xq.t() # I initially wrote xq:t(), Lua is still haunting me :-) + xq_row_major = False + else: + raise TypeError('matrix should be row or column-major') + + xq_ptr = swig_ptr_from_FloatTensor(xq) + + nb, d2 = xb.size() + assert d2 == d + if xb.is_contiguous(): + xb_row_major = True + elif xb.t().is_contiguous(): + xb = xb.t() + xb_row_major = False + else: + raise TypeError('matrix should be row or column-major') + xb_ptr = swig_ptr_from_FloatTensor(xb) + + if D is None: + D = torch.empty(nq, k, device=xb.device, dtype=torch.float32) + else: + assert D.shape == (nq, k) + assert D.device == xb.device + + if I is None: + I = torch.empty(nq, k, device=xb.device, dtype=torch.int64) + else: + assert I.shape == (nq, k) + assert I.device == xb.device + + D_ptr = swig_ptr_from_FloatTensor(D) + I_ptr = swig_ptr_from_LongTensor(I) + + faiss.bruteForceKnn(res, metric, + xb_ptr, xb_row_major, nb, + xq_ptr, xq_row_major, nq, + d, k, D_ptr, I_ptr) + + return D, I + +def to_column_major(x): + if hasattr(torch, 'contiguous_format'): + return x.t().clone(memory_format=torch.contiguous_format).t() + else: + # was default setting before memory_format was introduced + return x.t().clone().t() + +class PytorchFaissInterop(unittest.TestCase): + + def test_interop(self): + + d = 16 + nq = 5 + nb = 20 + + xq = faiss.randn(nq * d, 1234).reshape(nq, d) + xb = faiss.randn(nb * d, 1235).reshape(nb, d) + + res = faiss.StandardGpuResources() + index = faiss.GpuIndexFlatIP(res, d) + index.add(xb) + + # reference CPU result + Dref, Iref = index.search(xq, 5) + + # query is pytorch tensor (CPU) + xq_torch = torch.FloatTensor(xq) + + D2, I2 = search_index_pytorch(index, xq_torch, 5) + + assert np.all(Iref == I2.numpy()) + + # query is pytorch tensor (GPU) + xq_torch = xq_torch.cuda() + # no need for a sync here + + D3, I3 = search_index_pytorch(index, xq_torch, 5) + + # D3 and I3 are on torch tensors on GPU as well. + # this does a sync, which is useful because faiss and + # pytorch use different Cuda streams. + res.syncDefaultStreamCurrentDevice() + + assert np.all(Iref == I3.cpu().numpy()) + + def test_raw_array_search(self): + d = 32 + nb = 1024 + nq = 128 + k = 10 + + # make GT on Faiss CPU + + xq = faiss.randn(nq * d, 1234).reshape(nq, d) + xb = faiss.randn(nb * d, 1235).reshape(nb, d) + + index = faiss.IndexFlatL2(d) + index.add(xb) + gt_D, gt_I = index.search(xq, k) + + # resource object, can be re-used over calls + res = faiss.StandardGpuResources() + # put on same stream as pytorch to avoid synchronizing streams + res.setDefaultNullStreamAllDevices() + + for xq_row_major in True, False: + for xb_row_major in True, False: + + # move to pytorch & GPU + xq_t = torch.from_numpy(xq).cuda() + xb_t = torch.from_numpy(xb).cuda() + + if not xq_row_major: + xq_t = to_column_major(xq_t) + assert not xq_t.is_contiguous() + + if not xb_row_major: + xb_t = to_column_major(xb_t) + assert not xb_t.is_contiguous() + + D, I = search_raw_array_pytorch(res, xb_t, xq_t, k) + + # back to CPU for verification + D = D.cpu().numpy() + I = I.cpu().numpy() + + assert np.all(I == gt_I) + assert np.all(np.abs(D - gt_D).max() < 1e-4) + + + + # test on subset + try: + D, I = search_raw_array_pytorch(res, xb_t, xq_t[60:80], k) + except TypeError: + if not xq_row_major: + # then it is expected + continue + # otherwise it is an error + raise + + # back to CPU for verification + D = D.cpu().numpy() + I = I.cpu().numpy() + + assert np.all(I == gt_I[60:80]) + assert np.all(np.abs(D - gt_D[60:80]).max() < 1e-4) + + + + + + +if __name__ == '__main__': + unittest.main() diff --git a/core/src/index/thirdparty/faiss/gpu/utils/BlockSelectFloat.cu b/core/src/index/thirdparty/faiss/gpu/utils/BlockSelectFloat.cu new file mode 100644 index 0000000000..7f1febed3e --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/utils/BlockSelectFloat.cu @@ -0,0 +1,146 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +namespace faiss { namespace gpu { + +// warp Q to thread Q: +// 1, 1 +// 32, 2 +// 64, 3 +// 128, 3 +// 256, 4 +// 512, 8 +// 1024, 8 +// 2048, 8 + +BLOCK_SELECT_DECL(float, true, 1); +BLOCK_SELECT_DECL(float, true, 32); +BLOCK_SELECT_DECL(float, true, 64); +BLOCK_SELECT_DECL(float, true, 128); +BLOCK_SELECT_DECL(float, true, 256); +BLOCK_SELECT_DECL(float, true, 512); +BLOCK_SELECT_DECL(float, true, 1024); +#if GPU_MAX_SELECTION_K >= 2048 +BLOCK_SELECT_DECL(float, true, 2048); +#endif + +BLOCK_SELECT_DECL(float, false, 1); +BLOCK_SELECT_DECL(float, false, 32); +BLOCK_SELECT_DECL(float, false, 64); +BLOCK_SELECT_DECL(float, false, 128); +BLOCK_SELECT_DECL(float, false, 256); +BLOCK_SELECT_DECL(float, false, 512); +BLOCK_SELECT_DECL(float, false, 1024); +#if GPU_MAX_SELECTION_K >= 2048 +BLOCK_SELECT_DECL(float, false, 2048); +#endif + +void runBlockSelect(Tensor& in, + Tensor& bitset, + Tensor& outK, + Tensor& outV, + bool dir, int k, cudaStream_t stream) { + FAISS_ASSERT(k <= GPU_MAX_SELECTION_K); + + if (dir) { + if (k == 1) { + BLOCK_SELECT_CALL(float, true, 1); + } else if (k <= 32) { + BLOCK_SELECT_CALL(float, true, 32); + } else if (k <= 64) { + BLOCK_SELECT_CALL(float, true, 64); + } else if (k <= 128) { + BLOCK_SELECT_CALL(float, true, 128); + } else if (k <= 256) { + BLOCK_SELECT_CALL(float, true, 256); + } else if (k <= 512) { + BLOCK_SELECT_CALL(float, true, 512); + } else if (k <= 1024) { + BLOCK_SELECT_CALL(float, true, 1024); +#if GPU_MAX_SELECTION_K >= 2048 + } else if (k <= 2048) { + BLOCK_SELECT_CALL(float, true, 2048); +#endif + } + } else { + if (k == 1) { + BLOCK_SELECT_CALL(float, false, 1); + } else if (k <= 32) { + BLOCK_SELECT_CALL(float, false, 32); + } else if (k <= 64) { + BLOCK_SELECT_CALL(float, false, 64); + } else if (k <= 128) { + BLOCK_SELECT_CALL(float, false, 128); + } else if (k <= 256) { + BLOCK_SELECT_CALL(float, false, 256); + } else if (k <= 512) { + BLOCK_SELECT_CALL(float, false, 512); + } else if (k <= 1024) { + BLOCK_SELECT_CALL(float, false, 1024); +#if GPU_MAX_SELECTION_K >= 2048 + } else if (k <= 2048) { + BLOCK_SELECT_CALL(float, false, 2048); +#endif + } + } +} + +void runBlockSelectPair(Tensor& inK, + Tensor& inV, + Tensor& bitset, + Tensor& outK, + Tensor& outV, + bool dir, int k, cudaStream_t stream) { + FAISS_ASSERT(k <= GPU_MAX_SELECTION_K); + + if (dir) { + if (k == 1) { + BLOCK_SELECT_PAIR_CALL(float, true, 1); + } else if (k <= 32) { + BLOCK_SELECT_PAIR_CALL(float, true, 32); + } else if (k <= 64) { + BLOCK_SELECT_PAIR_CALL(float, true, 64); + } else if (k <= 128) { + BLOCK_SELECT_PAIR_CALL(float, true, 128); + } else if (k <= 256) { + BLOCK_SELECT_PAIR_CALL(float, true, 256); + } else if (k <= 512) { + BLOCK_SELECT_PAIR_CALL(float, true, 512); + } else if (k <= 1024) { + BLOCK_SELECT_PAIR_CALL(float, true, 1024); +#if GPU_MAX_SELECTION_K >= 2048 + } else if (k <= 2048) { + BLOCK_SELECT_PAIR_CALL(float, true, 2048); +#endif + } + } else { + if (k == 1) { + BLOCK_SELECT_PAIR_CALL(float, false, 1); + } else if (k <= 32) { + BLOCK_SELECT_PAIR_CALL(float, false, 32); + } else if (k <= 64) { + BLOCK_SELECT_PAIR_CALL(float, false, 64); + } else if (k <= 128) { + BLOCK_SELECT_PAIR_CALL(float, false, 128); + } else if (k <= 256) { + BLOCK_SELECT_PAIR_CALL(float, false, 256); + } else if (k <= 512) { + BLOCK_SELECT_PAIR_CALL(float, false, 512); + } else if (k <= 1024) { + BLOCK_SELECT_PAIR_CALL(float, false, 1024); +#if GPU_MAX_SELECTION_K >= 2048 + } else if (k <= 2048) { + BLOCK_SELECT_PAIR_CALL(float, false, 2048); +#endif + } + } +} + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/utils/BlockSelectHalf.cu b/core/src/index/thirdparty/faiss/gpu/utils/BlockSelectHalf.cu new file mode 100644 index 0000000000..f6989fc084 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/utils/BlockSelectHalf.cu @@ -0,0 +1,150 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +namespace faiss { namespace gpu { + +#ifdef FAISS_USE_FLOAT16 + +// warp Q to thread Q: +// 1, 1 +// 32, 2 +// 64, 3 +// 128, 3 +// 256, 4 +// 512, 8 +// 1024, 8 +// 2048, 8 + +BLOCK_SELECT_DECL(half, true, 1); +BLOCK_SELECT_DECL(half, true, 32); +BLOCK_SELECT_DECL(half, true, 64); +BLOCK_SELECT_DECL(half, true, 128); +BLOCK_SELECT_DECL(half, true, 256); +BLOCK_SELECT_DECL(half, true, 512); +BLOCK_SELECT_DECL(half, true, 1024); +#if GPU_MAX_SELECTION_K >= 2048 +BLOCK_SELECT_DECL(half, true, 2048); +#endif + +BLOCK_SELECT_DECL(half, false, 1); +BLOCK_SELECT_DECL(half, false, 32); +BLOCK_SELECT_DECL(half, false, 64); +BLOCK_SELECT_DECL(half, false, 128); +BLOCK_SELECT_DECL(half, false, 256); +BLOCK_SELECT_DECL(half, false, 512); +BLOCK_SELECT_DECL(half, false, 1024); +#if GPU_MAX_SELECTION_K >= 2048 +BLOCK_SELECT_DECL(half, false, 2048); +#endif + +void runBlockSelect(Tensor& in, + Tensor& bitset, + Tensor& outK, + Tensor& outV, + bool dir, int k, cudaStream_t stream) { + FAISS_ASSERT(k <= GPU_MAX_SELECTION_K); + + if (dir) { + if (k == 1) { + BLOCK_SELECT_CALL(half, true, 1); + } else if (k <= 32) { + BLOCK_SELECT_CALL(half, true, 32); + } else if (k <= 64) { + BLOCK_SELECT_CALL(half, true, 64); + } else if (k <= 128) { + BLOCK_SELECT_CALL(half, true, 128); + } else if (k <= 256) { + BLOCK_SELECT_CALL(half, true, 256); + } else if (k <= 512) { + BLOCK_SELECT_CALL(half, true, 512); + } else if (k <= 1024) { + BLOCK_SELECT_CALL(half, true, 1024); +#if GPU_MAX_SELECTION_K >= 2048 + } else if (k <= 2048) { + BLOCK_SELECT_CALL(half, true, 2048); +#endif + } + } else { + if (k == 1) { + BLOCK_SELECT_CALL(half, false, 1); + } else if (k <= 32) { + BLOCK_SELECT_CALL(half, false, 32); + } else if (k <= 64) { + BLOCK_SELECT_CALL(half, false, 64); + } else if (k <= 128) { + BLOCK_SELECT_CALL(half, false, 128); + } else if (k <= 256) { + BLOCK_SELECT_CALL(half, false, 256); + } else if (k <= 512) { + BLOCK_SELECT_CALL(half, false, 512); + } else if (k <= 1024) { + BLOCK_SELECT_CALL(half, false, 1024); +#if GPU_MAX_SELECTION_K >= 2048 + } else if (k <= 2048) { + BLOCK_SELECT_CALL(half, false, 2048); +#endif + } + } +} + +void runBlockSelectPair(Tensor& inK, + Tensor& inV, + Tensor& bitset, + Tensor& outK, + Tensor& outV, + bool dir, int k, cudaStream_t stream) { + FAISS_ASSERT(k <= GPU_MAX_SELECTION_K); + + if (dir) { + if (k == 1) { + BLOCK_SELECT_PAIR_CALL(half, true, 1); + } else if (k <= 32) { + BLOCK_SELECT_PAIR_CALL(half, true, 32); + } else if (k <= 64) { + BLOCK_SELECT_PAIR_CALL(half, true, 64); + } else if (k <= 128) { + BLOCK_SELECT_PAIR_CALL(half, true, 128); + } else if (k <= 256) { + BLOCK_SELECT_PAIR_CALL(half, true, 256); + } else if (k <= 512) { + BLOCK_SELECT_PAIR_CALL(half, true, 512); + } else if (k <= 1024) { + BLOCK_SELECT_PAIR_CALL(half, true, 1024); +#if GPU_MAX_SELECTION_K >= 2048 + } else if (k <= 2048) { + BLOCK_SELECT_PAIR_CALL(half, true, 2048); +#endif + } + } else { + if (k == 1) { + BLOCK_SELECT_PAIR_CALL(half, false, 1); + } else if (k <= 32) { + BLOCK_SELECT_PAIR_CALL(half, false, 32); + } else if (k <= 64) { + BLOCK_SELECT_PAIR_CALL(half, false, 64); + } else if (k <= 128) { + BLOCK_SELECT_PAIR_CALL(half, false, 128); + } else if (k <= 256) { + BLOCK_SELECT_PAIR_CALL(half, false, 256); + } else if (k <= 512) { + BLOCK_SELECT_PAIR_CALL(half, false, 512); + } else if (k <= 1024) { + BLOCK_SELECT_PAIR_CALL(half, false, 1024); +#if GPU_MAX_SELECTION_K >= 2048 + } else if (k <= 2048) { + BLOCK_SELECT_PAIR_CALL(half, false, 2048); +#endif + } + } +} + +#endif // FAISS_USE_FLOAT16 + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/utils/BlockSelectImpl.cuh b/core/src/index/thirdparty/faiss/gpu/utils/BlockSelectImpl.cuh new file mode 100644 index 0000000000..4c32b75194 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/utils/BlockSelectImpl.cuh @@ -0,0 +1,106 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +#define BLOCK_SELECT_DECL(TYPE, DIR, WARP_Q) \ + extern void runBlockSelect_ ## TYPE ## _ ## DIR ## _ ## WARP_Q ## _( \ + Tensor& in, \ + Tensor& bitset, \ + Tensor& outK, \ + Tensor& outV, \ + bool dir, \ + int k, \ + cudaStream_t stream); \ + \ + extern void runBlockSelectPair_ ## TYPE ## _ ## DIR ## _ ## WARP_Q ## _( \ + Tensor& inK, \ + Tensor& inV, \ + Tensor& bitset, \ + Tensor& outK, \ + Tensor& outV, \ + bool dir, \ + int k, \ + cudaStream_t stream); + +#define BLOCK_SELECT_IMPL(TYPE, DIR, WARP_Q, THREAD_Q) \ + void runBlockSelect_ ## TYPE ## _ ## DIR ## _ ## WARP_Q ## _( \ + Tensor& in, \ + Tensor& bitset, \ + Tensor& outK, \ + Tensor& outV, \ + bool dir, \ + int k, \ + cudaStream_t stream) { \ + FAISS_ASSERT(in.getSize(0) == outK.getSize(0)); \ + FAISS_ASSERT(in.getSize(0) == outV.getSize(0)); \ + FAISS_ASSERT(outK.getSize(1) == k); \ + FAISS_ASSERT(outV.getSize(1) == k); \ + \ + auto grid = dim3(in.getSize(0)); \ + \ + constexpr int kBlockSelectNumThreads = (WARP_Q <= 1024) ? 128 : 64; \ + auto block = dim3(kBlockSelectNumThreads); \ + \ + FAISS_ASSERT(k <= WARP_Q); \ + FAISS_ASSERT(dir == DIR); \ + \ + auto kInit = dir ? Limits::getMin() : Limits::getMax(); \ + auto vInit = -1; \ + \ + if (bitset.getSize(0) == 0) \ + blockSelect \ + <<>>(in, outK, outV, kInit, vInit, k); \ + else \ + blockSelect \ + <<>>(in, bitset, outK, outV, kInit, vInit, k); \ + CUDA_TEST_ERROR(); \ + } \ + \ + void runBlockSelectPair_ ## TYPE ## _ ## DIR ## _ ## WARP_Q ## _( \ + Tensor& inK, \ + Tensor& inV, \ + Tensor& bitset, \ + Tensor& outK, \ + Tensor& outV, \ + bool dir, \ + int k, \ + cudaStream_t stream) { \ + FAISS_ASSERT(inK.isSameSize(inV)); \ + FAISS_ASSERT(outK.isSameSize(outV)); \ + \ + auto grid = dim3(inK.getSize(0)); \ + \ + constexpr int kBlockSelectNumThreads = (WARP_Q <= 1024) ? 128 : 64; \ + auto block = dim3(kBlockSelectNumThreads); \ + \ + FAISS_ASSERT(k <= WARP_Q); \ + FAISS_ASSERT(dir == DIR); \ + \ + auto kInit = dir ? Limits::getMin() : Limits::getMax(); \ + auto vInit = -1; \ + \ + if (bitset.getSize(0) == 0) \ + blockSelectPair \ + <<>>(inK, inV, outK, outV, kInit, vInit, k); \ + else \ + blockSelectPair \ + <<>>(inK, inV, bitset, outK, outV, kInit, vInit, k); \ + CUDA_TEST_ERROR(); \ + } + + +#define BLOCK_SELECT_CALL(TYPE, DIR, WARP_Q) \ + runBlockSelect_ ## TYPE ## _ ## DIR ## _ ## WARP_Q ## _( \ + in, bitset, outK, outV, dir, k, stream) + +#define BLOCK_SELECT_PAIR_CALL(TYPE, DIR, WARP_Q) \ + runBlockSelectPair_ ## TYPE ## _ ## DIR ## _ ## WARP_Q ## _( \ + inK, inV, bitset, outK, outV, dir, k, stream) diff --git a/core/src/index/thirdparty/faiss/gpu/utils/BlockSelectKernel.cuh b/core/src/index/thirdparty/faiss/gpu/utils/BlockSelectKernel.cuh new file mode 100644 index 0000000000..f787335cdf --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/utils/BlockSelectKernel.cuh @@ -0,0 +1,259 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +namespace faiss { namespace gpu { + +template +__global__ void blockSelect(Tensor in, + Tensor outK, + Tensor outV, + K initK, + IndexType initV, + int k) { + constexpr int kNumWarps = ThreadsPerBlock / kWarpSize; + + __shared__ K smemK[kNumWarps * NumWarpQ]; + __shared__ IndexType smemV[kNumWarps * NumWarpQ]; + + BlockSelect, + NumWarpQ, NumThreadQ, ThreadsPerBlock> + heap(initK, initV, smemK, smemV, k); + + // Grid is exactly sized to rows available + int row = blockIdx.x; + + int i = threadIdx.x; + K* inStart = in[row][i].data(); + + // Whole warps must participate in the selection + int limit = utils::roundDown(in.getSize(1), kWarpSize); + + for (; i < limit; i += ThreadsPerBlock) { + heap.add(*inStart, (IndexType) i); + inStart += ThreadsPerBlock; + } + + // Handle last remainder fraction of a warp of elements + if (i < in.getSize(1)) { + heap.addThreadQ(*inStart, (IndexType) i); + } + + heap.reduce(); + + for (int i = threadIdx.x; i < k; i += ThreadsPerBlock) { + outK[row][i] = smemK[i]; + outV[row][i] = smemV[i]; + } +} + +template +__global__ void blockSelectPair(Tensor inK, + Tensor inV, + Tensor outK, + Tensor outV, + K initK, + IndexType initV, + int k) { + constexpr int kNumWarps = ThreadsPerBlock / kWarpSize; + + __shared__ K smemK[kNumWarps * NumWarpQ]; + __shared__ IndexType smemV[kNumWarps * NumWarpQ]; + + BlockSelect, + NumWarpQ, NumThreadQ, ThreadsPerBlock> + heap(initK, initV, smemK, smemV, k); + + // Grid is exactly sized to rows available + int row = blockIdx.x; + + int i = threadIdx.x; + K* inKStart = inK[row][i].data(); + IndexType* inVStart = inV[row][i].data(); + + // Whole warps must participate in the selection + int limit = utils::roundDown(inK.getSize(1), kWarpSize); + + for (; i < limit; i += ThreadsPerBlock) { + heap.add(*inKStart, *inVStart); + inKStart += ThreadsPerBlock; + inVStart += ThreadsPerBlock; + } + + // Handle last remainder fraction of a warp of elements + if (i < inK.getSize(1)) { + heap.addThreadQ(*inKStart, *inVStart); + } + + heap.reduce(); + + for (int i = threadIdx.x; i < k; i += ThreadsPerBlock) { + outK[row][i] = smemK[i]; + outV[row][i] = smemV[i]; + } +} + +// Bitset included +template +__global__ void blockSelect(Tensor in, + Tensor bitset, + Tensor outK, + Tensor outV, + K initK, + IndexType initV, + int k) { + constexpr int kNumWarps = ThreadsPerBlock / kWarpSize; + + __shared__ K smemK[kNumWarps * NumWarpQ]; + __shared__ IndexType smemV[kNumWarps * NumWarpQ]; + + BlockSelect, + NumWarpQ, NumThreadQ, ThreadsPerBlock> + heap(initK, initV, smemK, smemV, k); + + // Grid is exactly sized to rows available + int row = blockIdx.x; + + int i = threadIdx.x; + K* inStart = in[row][i].data(); + + // Whole warps must participate in the selection + int limit = utils::roundDown(in.getSize(1), kWarpSize); + + bool bitsetEmpty = (bitset.getSize(0) == 0); + + for (; i < limit; i += ThreadsPerBlock) { + if (bitsetEmpty || (!(bitset[i >> 3] & (0x1 << (i & 0x7))))) { + heap.addThreadQ(*inStart, (IndexType) i); + } + heap.checkThreadQ(); + + inStart += ThreadsPerBlock; + } + + // Handle last remainder fraction of a warp of elements + if (i < in.getSize(1)) { + if (bitsetEmpty || (!(bitset[i >> 3] & (0x1 << (i & 0x7))))) { + heap.addThreadQ(*inStart, (IndexType) i); + } + } + + heap.reduce(); + + for (int i = threadIdx.x; i < k; i += ThreadsPerBlock) { + outK[row][i] = smemK[i]; + outV[row][i] = smemV[i]; + } +} + +template +__global__ void blockSelectPair(Tensor inK, + Tensor inV, + Tensor bitset, + Tensor outK, + Tensor outV, + K initK, + IndexType initV, + int k) { + constexpr int kNumWarps = ThreadsPerBlock / kWarpSize; + + __shared__ K smemK[kNumWarps * NumWarpQ]; + __shared__ IndexType smemV[kNumWarps * NumWarpQ]; + + BlockSelect, + NumWarpQ, NumThreadQ, ThreadsPerBlock> + heap(initK, initV, smemK, smemV, k); + + // Grid is exactly sized to rows available + int row = blockIdx.x; + + int i = threadIdx.x; + K* inKStart = inK[row][i].data(); + IndexType* inVStart = inV[row][i].data(); + + // Whole warps must participate in the selection + int limit = utils::roundDown(inK.getSize(1), kWarpSize); + + bool bitsetEmpty = (bitset.getSize(0) == 0); + + for (; i < limit; i += ThreadsPerBlock) { + if (bitsetEmpty || (!(bitset[i >> 3] & (0x1 << (i & 0x7))))) { + heap.addThreadQ(*inKStart, *inVStart); + } + heap.checkThreadQ(); + + inKStart += ThreadsPerBlock; + inVStart += ThreadsPerBlock; + } + + // Handle last remainder fraction of a warp of elements + if (i < inK.getSize(1)) { + if (bitsetEmpty || (!(bitset[i >> 3] & (0x1 << (i & 0x7))))) { + heap.addThreadQ(*inKStart, *inVStart); + } + } + + heap.reduce(); + + for (int i = threadIdx.x; i < k; i += ThreadsPerBlock) { + outK[row][i] = smemK[i]; + outV[row][i] = smemV[i]; + } +} + +void runBlockSelect(Tensor& in, + Tensor& bitset, + Tensor& outKeys, + Tensor& outIndices, + bool dir, int k, cudaStream_t stream); + +void runBlockSelectPair(Tensor& inKeys, + Tensor& inIndices, + Tensor& bitset, + Tensor& outKeys, + Tensor& outIndices, + bool dir, int k, cudaStream_t stream); + +#ifdef FAISS_USE_FLOAT16 +void runBlockSelect(Tensor& in, + Tensor& bitset, + Tensor& outKeys, + Tensor& outIndices, + bool dir, int k, cudaStream_t stream); + +void runBlockSelectPair(Tensor& inKeys, + Tensor& inIndices, + Tensor& bitset, + Tensor& outKeys, + Tensor& outIndices, + bool dir, int k, cudaStream_t stream); +#endif + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/utils/Comparators.cuh b/core/src/index/thirdparty/faiss/gpu/utils/Comparators.cuh new file mode 100644 index 0000000000..5abfab6af5 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/utils/Comparators.cuh @@ -0,0 +1,46 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#pragma once + +#include +#include + +namespace faiss { namespace gpu { + +template +struct Comparator { + __device__ static inline bool lt(T a, T b) { + return a < b; + } + + __device__ static inline bool gt(T a, T b) { + return a > b; + } +}; + +template <> +struct Comparator { + __device__ static inline bool lt(half a, half b) { +#if FAISS_USE_FULL_FLOAT16 + return __hlt(a, b); +#else + return __half2float(a) < __half2float(b); +#endif // FAISS_USE_FULL_FLOAT16 + } + + __device__ static inline bool gt(half a, half b) { +#if FAISS_USE_FULL_FLOAT16 + return __hgt(a, b); +#else + return __half2float(a) > __half2float(b); +#endif // FAISS_USE_FULL_FLOAT16 + } +}; + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/utils/ConversionOperators.cuh b/core/src/index/thirdparty/faiss/gpu/utils/ConversionOperators.cuh new file mode 100644 index 0000000000..cf9b74c971 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/utils/ConversionOperators.cuh @@ -0,0 +1,138 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#pragma once + +#include +#include +#include + +#include +#include +#include + +namespace faiss { namespace gpu { + +// +// Conversion utilities +// + +template +struct Convert { + inline __device__ To operator()(From v) const { + return (To) v; + } +}; + +#ifdef FAISS_USE_FLOAT16 +template <> +struct Convert { + inline __device__ half operator()(float v) const { + return __float2half(v); + } +}; + +template <> +struct Convert { + inline __device__ float operator()(half v) const { + return __half2float(v); + } +}; +#endif + +template +struct ConvertTo { +}; + +template <> +struct ConvertTo { + static inline __device__ float to(float v) { return v; } +#ifdef FAISS_USE_FLOAT16 + static inline __device__ float to(half v) { return __half2float(v); } +#endif +}; + +template <> +struct ConvertTo { + static inline __device__ float2 to(float2 v) { return v; } +#ifdef FAISS_USE_FLOAT16 + static inline __device__ float2 to(half2 v) { return __half22float2(v); } +#endif +}; + +template <> +struct ConvertTo { + static inline __device__ float4 to(float4 v) { return v; } +#ifdef FAISS_USE_FLOAT16 + static inline __device__ float4 to(Half4 v) { return half4ToFloat4(v); } +#endif +}; + +#ifdef FAISS_USE_FLOAT16 +template <> +struct ConvertTo { + static inline __device__ half to(float v) { return __float2half(v); } + static inline __device__ half to(half v) { return v; } +}; +#endif + +#ifdef FAISS_USE_FLOAT16 +template <> +struct ConvertTo { + static inline __device__ half2 to(float2 v) { return __float22half2_rn(v); } + static inline __device__ half2 to(half2 v) { return v; } +}; +#endif + +#ifdef FAISS_USE_FLOAT16 +template <> +struct ConvertTo { + static inline __device__ Half4 to(float4 v) { return float4ToHalf4(v); } + static inline __device__ Half4 to(Half4 v) { return v; } +}; +#endif + +// Tensor conversion +template +void runConvert(const From* in, + To* out, + size_t num, + cudaStream_t stream) { + thrust::transform(thrust::cuda::par.on(stream), + in, in + num, out, Convert()); +} + +template +void convertTensor(cudaStream_t stream, + Tensor& in, + Tensor& out) { + FAISS_ASSERT(in.numElements() == out.numElements()); + + runConvert(in.data(), out.data(), in.numElements(), stream); +} + +template +DeviceTensor convertTensor(GpuResources* res, + cudaStream_t stream, + Tensor& in) { + DeviceTensor out; + + if (res) { + out = std::move(DeviceTensor( + res->getMemoryManagerCurrentDevice(), + in.sizes(), + stream)); + } else { + out = std::move(DeviceTensor(in.sizes())); + } + + convertTensor(stream, in, out); + return out; +} + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/utils/CopyUtils.cuh b/core/src/index/thirdparty/faiss/gpu/utils/CopyUtils.cuh new file mode 100644 index 0000000000..922ca4ed0e --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/utils/CopyUtils.cuh @@ -0,0 +1,107 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#pragma once + +#include +#include + +namespace faiss { namespace gpu { + +/// Ensure the memory at `p` is either on the given device, or copy it +/// to the device in a new allocation. +/// If `resources` is provided, then we will perform a temporary +/// memory allocation if needed. Otherwise, we will call cudaMalloc if +/// needed. +template +DeviceTensor toDevice(GpuResources* resources, + int dstDevice, + T* src, + cudaStream_t stream, + std::initializer_list sizes) { + int dev = getDeviceForAddress(src); + + if (dev == dstDevice) { + // On device we expect + return DeviceTensor(src, sizes); + } else { + // On different device or on host + DeviceScope scope(dstDevice); + + Tensor oldT(src, sizes); + + if (resources) { + DeviceTensor newT(resources->getMemoryManager(dstDevice), + sizes, + stream); + + newT.copyFrom(oldT, stream); + return newT; + } else { + DeviceTensor newT(sizes); + + newT.copyFrom(oldT, stream); + return newT; + } + } +} + +/// Copies data to the CPU, if it is not already on the CPU +template +HostTensor toHost(T* src, + cudaStream_t stream, + std::initializer_list sizes) { + int dev = getDeviceForAddress(src); + + if (dev == -1) { + // Already on the CPU, just wrap in a HostTensor that doesn't own this + // memory + return HostTensor(src, sizes); + } else { + HostTensor out(sizes); + Tensor devData(src, sizes); + out.copyFrom(devData, stream); + + return out; + } +} + +/// Copies a device array's allocation to an address, if necessary +template +inline void fromDevice(T* src, T* dst, size_t num, cudaStream_t stream) { + // It is possible that the array already represents memory at `p`, + // in which case no copy is needed + if (src == dst) { + return; + } + + int dev = getDeviceForAddress(dst); + + if (dev == -1) { + CUDA_VERIFY(cudaMemcpyAsync(dst, + src, + num * sizeof(T), + cudaMemcpyDeviceToHost, + stream)); + } else { + CUDA_VERIFY(cudaMemcpyAsync(dst, + src, + num * sizeof(T), + cudaMemcpyDeviceToDevice, + stream)); + } +} + +/// Copies a device array's allocation to an address, if necessary +template +void fromDevice(Tensor& src, T* dst, cudaStream_t stream) { + FAISS_ASSERT(src.isContiguous()); + fromDevice(src.data(), dst, src.numElements(), stream); +} + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/utils/DeviceDefs.cuh b/core/src/index/thirdparty/faiss/gpu/utils/DeviceDefs.cuh new file mode 100644 index 0000000000..89d3dda289 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/utils/DeviceDefs.cuh @@ -0,0 +1,48 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#pragma once + +#include + +namespace faiss { namespace gpu { + +#ifdef __CUDA_ARCH__ +#if __CUDA_ARCH__ <= 750 +constexpr int kWarpSize = 32; +#else +#error Unknown __CUDA_ARCH__; please define parameters for compute capability +#endif // __CUDA_ARCH__ types +#endif // __CUDA_ARCH__ + +#ifndef __CUDA_ARCH__ +// dummy value for host compiler +constexpr int kWarpSize = 32; +#endif // !__CUDA_ARCH__ + +// This is a memory barrier for intra-warp writes to shared memory. +__forceinline__ __device__ void warpFence() { + +#if CUDA_VERSION >= 9000 + __syncwarp(); +#else + // For the time being, assume synchronicity. + // __threadfence_block(); +#endif +} + +#if CUDA_VERSION > 9000 +// Based on the CUDA version (we assume what version of nvcc/ptxas we were +// compiled with), the register allocation algorithm is much better, so only +// enable the 2048 selection code if we are above 9.0 (9.2 seems to be ok) +#define GPU_MAX_SELECTION_K 2048 +#else +#define GPU_MAX_SELECTION_K 1024 +#endif + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/utils/DeviceMemory.cpp b/core/src/index/thirdparty/faiss/gpu/utils/DeviceMemory.cpp new file mode 100644 index 0000000000..2ce721986a --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/utils/DeviceMemory.cpp @@ -0,0 +1,77 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#include +#include +#include + +namespace faiss { namespace gpu { + +DeviceMemoryReservation::DeviceMemoryReservation() + : state_(nullptr), + device_(0), + data_(nullptr), + size_(0), + stream_(0) { +} + +DeviceMemoryReservation::DeviceMemoryReservation(DeviceMemory* state, + int device, + void* p, + size_t size, + cudaStream_t stream) + : state_(state), + device_(device), + data_(p), + size_(size), + stream_(stream) { +} + +DeviceMemoryReservation::DeviceMemoryReservation( + DeviceMemoryReservation&& m) noexcept { + + state_ = m.state_; + device_ = m.device_; + data_ = m.data_; + size_ = m.size_; + stream_ = m.stream_; + + m.data_ = nullptr; +} + +DeviceMemoryReservation::~DeviceMemoryReservation() { + if (data_) { + FAISS_ASSERT(state_); + state_->returnAllocation(*this); + } + + data_ = nullptr; +} + +DeviceMemoryReservation& +DeviceMemoryReservation::operator=(DeviceMemoryReservation&& m) { + if (data_) { + FAISS_ASSERT(state_); + state_->returnAllocation(*this); + } + + state_ = m.state_; + device_ = m.device_; + data_ = m.data_; + size_ = m.size_; + stream_ = m.stream_; + + m.data_ = nullptr; + + return *this; +} + +DeviceMemory::~DeviceMemory() { +} + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/utils/DeviceMemory.h b/core/src/index/thirdparty/faiss/gpu/utils/DeviceMemory.h new file mode 100644 index 0000000000..1bffdc00ac --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/utils/DeviceMemory.h @@ -0,0 +1,71 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#pragma once + +#include +#include + +namespace faiss { namespace gpu { + +class DeviceMemory; + +class DeviceMemoryReservation { + public: + DeviceMemoryReservation(); + DeviceMemoryReservation(DeviceMemory* state, + int device, void* p, size_t size, + cudaStream_t stream); + DeviceMemoryReservation(DeviceMemoryReservation&& m) noexcept; + ~DeviceMemoryReservation(); + + DeviceMemoryReservation& operator=(DeviceMemoryReservation&& m); + + int device() { return device_; } + void* get() { return data_; } + size_t size() { return size_; } + cudaStream_t stream() { return stream_; } + + private: + DeviceMemory* state_; + + int device_; + void* data_; + size_t size_; + cudaStream_t stream_; +}; + +/// Manages temporary memory allocations on a GPU device +class DeviceMemory { + public: + virtual ~DeviceMemory(); + + /// Returns the device we are managing memory for + virtual int getDevice() const = 0; + + /// Obtains a temporary memory allocation for our device, + /// whose usage is ordered with respect to the given stream. + virtual DeviceMemoryReservation getMemory(cudaStream_t stream, + size_t size) = 0; + + /// Returns the current size available without calling cudaMalloc + virtual size_t getSizeAvailable() const = 0; + + /// Returns a string containing our current memory manager state + virtual std::string toString() const = 0; + + /// Returns the high-water mark of cudaMalloc allocations for our + /// device + virtual size_t getHighWaterCudaMalloc() const = 0; + + protected: + friend class DeviceMemoryReservation; + virtual void returnAllocation(DeviceMemoryReservation& m) = 0; +}; + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/utils/DeviceTensor-inl.cuh b/core/src/index/thirdparty/faiss/gpu/utils/DeviceTensor-inl.cuh new file mode 100644 index 0000000000..cff5452989 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/utils/DeviceTensor-inl.cuh @@ -0,0 +1,228 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#include // std::move + +namespace faiss { namespace gpu { + +template class PtrTraits> +__host__ +DeviceTensor::DeviceTensor() : + Tensor(), + state_(AllocState::NotOwner), + space_(MemorySpace::Device) { +} + +template class PtrTraits> +__host__ +DeviceTensor::DeviceTensor( + DeviceTensor&& t) : + Tensor(), + state_(AllocState::NotOwner), + space_(MemorySpace::Device) { + this->operator=(std::move(t)); +} + +template class PtrTraits> +__host__ +DeviceTensor& +DeviceTensor::operator=( + DeviceTensor&& t) { + if (this->state_ == AllocState::Owner) { + CUDA_VERIFY(cudaFree(this->data_)); + } + + this->Tensor::operator=( + std::move(t)); + + this->state_ = t.state_; t.state_ = AllocState::NotOwner; + this->space_ = t.space_; + this->reservation_ = std::move(t.reservation_); + + return *this; +} + +template class PtrTraits> +__host__ +DeviceTensor::~DeviceTensor() { + if (state_ == AllocState::Owner) { + FAISS_ASSERT(this->data_ || (this->getSizeInBytes() == 0)); + CUDA_VERIFY(cudaFree(this->data_)); + this->data_ = nullptr; + } + + // Otherwise, if we have a temporary memory reservation, then its + // destructor will return the reservation +} + +template class PtrTraits> +__host__ +DeviceTensor::DeviceTensor( + const IndexT sizes[Dim], + MemorySpace space) : + Tensor(nullptr, sizes), + state_(AllocState::Owner), + space_(space) { + + allocMemorySpace(space, &this->data_, this->getSizeInBytes()); + FAISS_ASSERT(this->data_ || (this->getSizeInBytes() == 0)); +} + +template class PtrTraits> +__host__ +DeviceTensor::DeviceTensor( + std::initializer_list sizes, + MemorySpace space) : + Tensor(nullptr, sizes), + state_(AllocState::Owner), + space_(space) { + + allocMemorySpace(space, &this->data_, this->getSizeInBytes()); + FAISS_ASSERT(this->data_ || (this->getSizeInBytes() == 0)); +} + +// memory reservation constructor +template class PtrTraits> +__host__ +DeviceTensor::DeviceTensor( + DeviceMemory& m, + const IndexT sizes[Dim], + cudaStream_t stream, + MemorySpace space) : + Tensor(nullptr, sizes), + state_(AllocState::Reservation), + space_(space) { + + // FIXME: add MemorySpace to DeviceMemory + auto memory = m.getMemory(stream, this->getSizeInBytes()); + + this->data_ = (T*) memory.get(); + FAISS_ASSERT(this->data_ || (this->getSizeInBytes() == 0)); + reservation_ = std::move(memory); +} + +// memory reservation constructor +template class PtrTraits> +__host__ +DeviceTensor::DeviceTensor( + DeviceMemory& m, + std::initializer_list sizes, + cudaStream_t stream, + MemorySpace space) : + Tensor(nullptr, sizes), + state_(AllocState::Reservation), + space_(space) { + + // FIXME: add MemorySpace to DeviceMemory + auto memory = m.getMemory(stream, this->getSizeInBytes()); + + this->data_ = (T*) memory.get(); + FAISS_ASSERT(this->data_ || (this->getSizeInBytes() == 0)); + reservation_ = std::move(memory); +} + +template class PtrTraits> +__host__ +DeviceTensor::DeviceTensor( + DataPtrType data, + const IndexT sizes[Dim], + MemorySpace space) : + Tensor(data, sizes), + state_(AllocState::NotOwner), + space_(space) { +} + +template class PtrTraits> +__host__ +DeviceTensor::DeviceTensor( + DataPtrType data, + std::initializer_list sizes, + MemorySpace space) : + Tensor(data, sizes), + state_(AllocState::NotOwner), + space_(space) { +} + +template class PtrTraits> +__host__ +DeviceTensor::DeviceTensor( + DataPtrType data, + const IndexT sizes[Dim], + const IndexT strides[Dim], + MemorySpace space) : + Tensor(data, sizes, strides), + state_(AllocState::NotOwner), + space_(space) { +} + +template class PtrTraits> +__host__ +DeviceTensor::DeviceTensor( + Tensor& t, + cudaStream_t stream, + MemorySpace space) : + Tensor(nullptr, t.sizes(), t.strides()), + state_(AllocState::Owner), + space_(space) { + + allocMemorySpace(space_, &this->data_, this->getSizeInBytes()); + FAISS_ASSERT(this->data_ || (this->getSizeInBytes() == 0)); + this->copyFrom(t, stream); +} + +template class PtrTraits> +__host__ +DeviceTensor::DeviceTensor( + DeviceMemory& m, + Tensor& t, + cudaStream_t stream, + MemorySpace space) : + Tensor(nullptr, t.sizes(), t.strides()), + state_(AllocState::Reservation), + space_(space) { + + // FIXME: add MemorySpace to DeviceMemory + auto memory = m.getMemory(stream, this->getSizeInBytes()); + + this->data_ = (T*) memory.get(); + FAISS_ASSERT(this->data_ || (this->getSizeInBytes() == 0)); + reservation_ = std::move(memory); + + this->copyFrom(t, stream); +} + +template class PtrTraits> +__host__ DeviceTensor& +DeviceTensor::zero( + cudaStream_t stream) { + if (this->data_) { + // Region must be contiguous + FAISS_ASSERT(this->isContiguous()); + + CUDA_VERIFY(cudaMemsetAsync( + this->data_, 0, this->getSizeInBytes(), stream)); + } + + return *this; +} + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/utils/DeviceTensor.cuh b/core/src/index/thirdparty/faiss/gpu/utils/DeviceTensor.cuh new file mode 100644 index 0000000000..78039969c5 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/utils/DeviceTensor.cuh @@ -0,0 +1,113 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#pragma once + +#include +#include +#include + +namespace faiss { namespace gpu { + +template class PtrTraits = traits::DefaultPtrTraits> +class DeviceTensor : public Tensor { + public: + typedef IndexT IndexType; + typedef typename PtrTraits::PtrType DataPtrType; + + /// Default constructor + __host__ DeviceTensor(); + + /// Destructor + __host__ ~DeviceTensor(); + + /// Move constructor + __host__ DeviceTensor(DeviceTensor&& t); + + /// Move assignment + __host__ DeviceTensor& + operator=(DeviceTensor&& t); + + /// Constructs a tensor of the given size, allocating memory for it + /// locally + __host__ DeviceTensor(const IndexT sizes[Dim], + MemorySpace space = MemorySpace::Device); + __host__ DeviceTensor(std::initializer_list sizes, + MemorySpace space = MemorySpace::Device); + + /// Constructs a tensor of the given size, reserving a temporary + /// memory reservation via a memory manager. + /// The memory reservation should be ordered with respect to the + /// given stream. + __host__ DeviceTensor(DeviceMemory& m, + const IndexT sizes[Dim], + cudaStream_t stream, + MemorySpace space = MemorySpace::Device); + __host__ DeviceTensor(DeviceMemory& m, + std::initializer_list sizes, + cudaStream_t stream, + MemorySpace space = MemorySpace::Device); + + /// Constructs a tensor of the given size and stride, referencing a + /// memory region we do not own + __host__ DeviceTensor(DataPtrType data, + const IndexT sizes[Dim], + MemorySpace space = MemorySpace::Device); + __host__ DeviceTensor(DataPtrType data, + std::initializer_list sizes, + MemorySpace space = MemorySpace::Device); + + /// Constructs a tensor of the given size and stride, referencing a + /// memory region we do not own + __host__ DeviceTensor(DataPtrType data, + const IndexT sizes[Dim], + const IndexT strides[Dim], + MemorySpace space = MemorySpace::Device); + + /// Copies a tensor into ourselves, allocating memory for it locally + __host__ DeviceTensor(Tensor& t, + cudaStream_t stream, + MemorySpace space = MemorySpace::Device); + + /// Copies a tensor into ourselves, reserving a temporary + /// memory reservation via a memory manager. + __host__ DeviceTensor(DeviceMemory& m, + Tensor& t, + cudaStream_t stream, + MemorySpace space = MemorySpace::Device); + + /// Call to zero out memory + __host__ DeviceTensor& + zero(cudaStream_t stream); + + private: + enum AllocState { + /// This tensor itself owns the memory, which must be freed via + /// cudaFree + Owner, + + /// This tensor itself is not an owner of the memory; there is + /// nothing to free + NotOwner, + + /// This tensor has the memory via a temporary memory reservation + Reservation + }; + + AllocState state_; + MemorySpace space_; + DeviceMemoryReservation reservation_; +}; + +} } // namespace + +#include diff --git a/core/src/index/thirdparty/faiss/gpu/utils/DeviceUtils.cu b/core/src/index/thirdparty/faiss/gpu/utils/DeviceUtils.cu new file mode 100644 index 0000000000..a8195c9ca6 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/utils/DeviceUtils.cu @@ -0,0 +1,206 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#include +#include +#include +#include +#include +#include + +namespace faiss { namespace gpu { + +int getCurrentDevice() { + int dev = -1; + CUDA_VERIFY(cudaGetDevice(&dev)); + FAISS_ASSERT(dev != -1); + + return dev; +} + +void setCurrentDevice(int device) { + CUDA_VERIFY(cudaSetDevice(device)); +} + +int getNumDevices() { + int numDev = -1; + cudaError_t err = cudaGetDeviceCount(&numDev); + if (cudaErrorNoDevice == err) { + numDev = 0; + } else { + CUDA_VERIFY(err); + } + FAISS_ASSERT(numDev != -1); + + return numDev; +} + +void profilerStart() { + CUDA_VERIFY(cudaProfilerStart()); +} + +void profilerStop() { + CUDA_VERIFY(cudaProfilerStop()); +} + +void synchronizeAllDevices() { + for (int i = 0; i < getNumDevices(); ++i) { + DeviceScope scope(i); + + CUDA_VERIFY(cudaDeviceSynchronize()); + } +} + +const cudaDeviceProp& getDeviceProperties(int device) { + static std::mutex mutex; + static std::unordered_map properties; + + std::lock_guard guard(mutex); + + auto it = properties.find(device); + if (it == properties.end()) { + cudaDeviceProp prop; + CUDA_VERIFY(cudaGetDeviceProperties(&prop, device)); + + properties[device] = prop; + it = properties.find(device); + } + + return it->second; +} + +const cudaDeviceProp& getCurrentDeviceProperties() { + return getDeviceProperties(getCurrentDevice()); +} + +int getMaxThreads(int device) { + return getDeviceProperties(device).maxThreadsPerBlock; +} + +int getMaxThreadsCurrentDevice() { + return getMaxThreads(getCurrentDevice()); +} + +size_t getMaxSharedMemPerBlock(int device) { + return getDeviceProperties(device).sharedMemPerBlock; +} + +size_t getMaxSharedMemPerBlockCurrentDevice() { + return getMaxSharedMemPerBlock(getCurrentDevice()); +} + +int getDeviceForAddress(const void* p) { + if (!p) { + return -1; + } + + cudaPointerAttributes att; + cudaError_t err = cudaPointerGetAttributes(&att, p); + FAISS_ASSERT_FMT(err == cudaSuccess || + err == cudaErrorInvalidValue, + "unknown error %d", (int) err); + + if (err == cudaErrorInvalidValue) { + // Make sure the current thread error status has been reset + err = cudaGetLastError(); + FAISS_ASSERT_FMT(err == cudaErrorInvalidValue, + "unknown error %d", (int) err); + return -1; + } else if (att.memoryType == cudaMemoryTypeHost) { + return -1; + } else { + return att.device; + } +} + +bool getFullUnifiedMemSupport(int device) { + const auto& prop = getDeviceProperties(device); + return (prop.major >= 6); +} + +bool getFullUnifiedMemSupportCurrentDevice() { + return getFullUnifiedMemSupport(getCurrentDevice()); +} + +bool getTensorCoreSupport(int device) { + const auto& prop = getDeviceProperties(device); + return (prop.major >= 7); +} + +bool getTensorCoreSupportCurrentDevice() { + return getTensorCoreSupport(getCurrentDevice()); +} + +int getMaxKSelection() { + // Don't use the device at the moment, just base this based on the CUDA SDK + // that we were compiled with + return GPU_MAX_SELECTION_K; +} + +DeviceScope::DeviceScope(int device) { + prevDevice_ = getCurrentDevice(); + + if (prevDevice_ != device) { + setCurrentDevice(device); + } else { + prevDevice_ = -1; + } +} + +DeviceScope::~DeviceScope() { + if (prevDevice_ != -1) { + setCurrentDevice(prevDevice_); + } +} + +CublasHandleScope::CublasHandleScope() { + auto blasStatus = cublasCreate(&blasHandle_); + FAISS_ASSERT(blasStatus == CUBLAS_STATUS_SUCCESS); +} + +CublasHandleScope::~CublasHandleScope() { + auto blasStatus = cublasDestroy(blasHandle_); + FAISS_ASSERT(blasStatus == CUBLAS_STATUS_SUCCESS); +} + +CudaEvent::CudaEvent(cudaStream_t stream) + : event_(0) { + CUDA_VERIFY(cudaEventCreateWithFlags(&event_, cudaEventDisableTiming)); + CUDA_VERIFY(cudaEventRecord(event_, stream)); +} + +CudaEvent::CudaEvent(CudaEvent&& event) noexcept + : event_(std::move(event.event_)) { + event.event_ = 0; +} + +CudaEvent::~CudaEvent() { + if (event_) { + CUDA_VERIFY(cudaEventDestroy(event_)); + } +} + +CudaEvent& +CudaEvent::operator=(CudaEvent&& event) noexcept { + event_ = std::move(event.event_); + event.event_ = 0; + + return *this; +} + +void +CudaEvent::streamWaitOnEvent(cudaStream_t stream) { + CUDA_VERIFY(cudaStreamWaitEvent(stream, event_, 0)); +} + +void +CudaEvent::cpuWaitOnEvent() { + CUDA_VERIFY(cudaEventSynchronize(event_)); +} + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/utils/DeviceUtils.h b/core/src/index/thirdparty/faiss/gpu/utils/DeviceUtils.h new file mode 100644 index 0000000000..e9b5426ae4 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/utils/DeviceUtils.h @@ -0,0 +1,191 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#pragma once + +#include +#include +#include +#include + +namespace faiss { namespace gpu { + +/// Returns the current thread-local GPU device +int getCurrentDevice(); + +/// Sets the current thread-local GPU device +void setCurrentDevice(int device); + +/// Returns the number of available GPU devices +int getNumDevices(); + +/// Starts the CUDA profiler (exposed via SWIG) +void profilerStart(); + +/// Stops the CUDA profiler (exposed via SWIG) +void profilerStop(); + +/// Synchronizes the CPU against all devices (equivalent to +/// cudaDeviceSynchronize for each device) +void synchronizeAllDevices(); + +/// Returns a cached cudaDeviceProp for the given device +const cudaDeviceProp& getDeviceProperties(int device); + +/// Returns the cached cudaDeviceProp for the current device +const cudaDeviceProp& getCurrentDeviceProperties(); + +/// Returns the maximum number of threads available for the given GPU +/// device +int getMaxThreads(int device); + +/// Equivalent to getMaxThreads(getCurrentDevice()) +int getMaxThreadsCurrentDevice(); + +/// Returns the maximum smem available for the given GPU device +size_t getMaxSharedMemPerBlock(int device); + +/// Equivalent to getMaxSharedMemPerBlock(getCurrentDevice()) +size_t getMaxSharedMemPerBlockCurrentDevice(); + +/// For a given pointer, returns whether or not it is located on +/// a device (deviceId >= 0) or the host (-1). +int getDeviceForAddress(const void* p); + +/// Does the given device support full unified memory sharing host +/// memory? +bool getFullUnifiedMemSupport(int device); + +/// Equivalent to getFullUnifiedMemSupport(getCurrentDevice()) +bool getFullUnifiedMemSupportCurrentDevice(); + +/// Does the given device support tensor core operations? +bool getTensorCoreSupport(int device); + +/// Equivalent to getTensorCoreSupport(getCurrentDevice()) +bool getTensorCoreSupportCurrentDevice(); + +/// Returns the maximum k-selection value supported based on the CUDA SDK that +/// we were compiled with. .cu files can use DeviceDefs.cuh, but this is for +/// non-CUDA files +int getMaxKSelection(); + +/// RAII object to set the current device, and restore the previous +/// device upon destruction +class DeviceScope { + public: + explicit DeviceScope(int device); + ~DeviceScope(); + + private: + int prevDevice_; +}; + +/// RAII object to manage a cublasHandle_t +class CublasHandleScope { + public: + CublasHandleScope(); + ~CublasHandleScope(); + + cublasHandle_t get() { return blasHandle_; } + + private: + cublasHandle_t blasHandle_; +}; + +// RAII object to manage a cudaEvent_t +class CudaEvent { + public: + /// Creates an event and records it in this stream + explicit CudaEvent(cudaStream_t stream); + CudaEvent(const CudaEvent& event) = delete; + CudaEvent(CudaEvent&& event) noexcept; + ~CudaEvent(); + + inline cudaEvent_t get() { return event_; } + + /// Wait on this event in this stream + void streamWaitOnEvent(cudaStream_t stream); + + /// Have the CPU wait for the completion of this event + void cpuWaitOnEvent(); + + CudaEvent& operator=(CudaEvent&& event) noexcept; + CudaEvent& operator=(CudaEvent& event) = delete; + + private: + cudaEvent_t event_; +}; + +/// Wrapper to test return status of CUDA functions +#define CUDA_VERIFY(X) \ + do { \ + auto err__ = (X); \ + FAISS_ASSERT_FMT(err__ == cudaSuccess, "CUDA error %d %s", \ + (int) err__, cudaGetErrorString(err__)); \ + } while (0) + +/// Wrapper to synchronously probe for CUDA errors +// #define FAISS_GPU_SYNC_ERROR 1 + +#ifdef FAISS_GPU_SYNC_ERROR +#define CUDA_TEST_ERROR() \ + do { \ + CUDA_VERIFY(cudaDeviceSynchronize()); \ + } while (0) +#else +#define CUDA_TEST_ERROR() \ + do { \ + CUDA_VERIFY(cudaGetLastError()); \ + } while (0) +#endif + +/// Call for a collection of streams to wait on +template +void streamWaitBase(const L1& listWaiting, const L2& listWaitOn) { + // For all the streams we are waiting on, create an event + std::vector events; + for (auto& stream : listWaitOn) { + cudaEvent_t event; + CUDA_VERIFY(cudaEventCreateWithFlags(&event, cudaEventDisableTiming)); + CUDA_VERIFY(cudaEventRecord(event, stream)); + events.push_back(event); + } + + // For all the streams that are waiting, issue a wait + for (auto& stream : listWaiting) { + for (auto& event : events) { + CUDA_VERIFY(cudaStreamWaitEvent(stream, event, 0)); + } + } + + for (auto& event : events) { + CUDA_VERIFY(cudaEventDestroy(event)); + } +} + +/// These versions allow usage of initializer_list as arguments, since +/// otherwise {...} doesn't have a type +template +void streamWait(const L1& a, + const std::initializer_list& b) { + streamWaitBase(a, b); +} + +template +void streamWait(const std::initializer_list& a, + const L2& b) { + streamWaitBase(a, b); +} + +inline void streamWait(const std::initializer_list& a, + const std::initializer_list& b) { + streamWaitBase(a, b); +} + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/utils/DeviceVector.cuh b/core/src/index/thirdparty/faiss/gpu/utils/DeviceVector.cuh new file mode 100644 index 0000000000..dac73679fd --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/utils/DeviceVector.cuh @@ -0,0 +1,191 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +namespace faiss { namespace gpu { + +/// A simple version of thrust::device_vector, but has more control +/// over whether resize() initializes new space with T() (which we +/// don't want), and control on how much the reserved space grows by +/// upon resize/reserve. It is also meant for POD types only. +template +class DeviceVector { + public: + DeviceVector(MemorySpace space = MemorySpace::Device) + : data_(nullptr), + num_(0), + capacity_(0), + owner(true), + space_(space) { + } + + ~DeviceVector() { + clear(); + } + + void reset(T* data, size_t num, size_t capacity, MemorySpace space = MemorySpace::Device) { + FAISS_ASSERT(data != nullptr); + FAISS_ASSERT(capacity >= num); + clear(); + owner = false; + data_ = data; + num_ = num; + capacity_ = capacity_; + } + + // Clear all allocated memory; reset to zero size + void clear() { + if (owner) { + freeMemorySpace(space_, data_); + } + data_ = nullptr; + num_ = 0; + capacity_ = 0; + owner = true; + } + + size_t size() const { return num_; } + size_t capacity() const { return capacity_; } + T* data() { return data_; } + const T* data() const { return data_; } + + template + std::vector copyToHost(cudaStream_t stream) const { + FAISS_ASSERT(num_ * sizeof(T) % sizeof(OutT) == 0); + + std::vector out((num_ * sizeof(T)) / sizeof(OutT)); + CUDA_VERIFY(cudaMemcpyAsync(out.data(), data_, num_ * sizeof(T), + cudaMemcpyDeviceToHost, stream)); + + return out; + } + + // Returns true if we actually reallocated memory + // If `reserveExact` is true, then we reserve only the memory that + // we need for what we're appending + bool append(const T* d, + size_t n, + cudaStream_t stream, + bool reserveExact = false) { + bool mem = false; + + if (n > 0) { + size_t reserveSize = num_ + n; + if (!reserveExact) { + reserveSize = getNewCapacity_(reserveSize); + } + + mem = reserve(reserveSize, stream); + + int dev = getDeviceForAddress(d); + if (dev == -1) { + CUDA_VERIFY(cudaMemcpyAsync(data_ + num_, d, n * sizeof(T), + cudaMemcpyHostToDevice, stream)); + } else { + CUDA_VERIFY(cudaMemcpyAsync(data_ + num_, d, n * sizeof(T), + cudaMemcpyDeviceToDevice, stream)); + } + num_ += n; + } + + return mem; + } + + // Returns true if we actually reallocated memory + bool resize(size_t newSize, cudaStream_t stream) { + bool mem = false; + + if (num_ < newSize) { + mem = reserve(getNewCapacity_(newSize), stream); + } + + // Don't bother zero initializing the newly accessible memory + // (unlike thrust::device_vector) + num_ = newSize; + + return mem; + } + + // Clean up after oversized allocations, while leaving some space to + // remain for subsequent allocations (if `exact` false) or to + // exactly the space we need (if `exact` true); returns space + // reclaimed in bytes + size_t reclaim(bool exact, cudaStream_t stream) { + size_t free = capacity_ - num_; + + if (exact) { + realloc_(num_, stream); + return free * sizeof(T); + } + + // If more than 1/4th of the space is free, then we want to + // truncate to only having 1/8th of the space free; this still + // preserves some space for new elements, but won't force us to + // double our size right away + if (free > (capacity_ / 4)) { + size_t newFree = capacity_ / 8; + size_t newCapacity = num_ + newFree; + + size_t oldCapacity = capacity_; + FAISS_ASSERT(newCapacity < oldCapacity); + + realloc_(newCapacity, stream); + + return (oldCapacity - newCapacity) * sizeof(T); + } + + return 0; + } + + // Returns true if we actually reallocated memory + bool reserve(size_t newCapacity, cudaStream_t stream) { + if (newCapacity <= capacity_) { + return false; + } + + // Otherwise, we need new space. + realloc_(newCapacity, stream); + return true; + } + + private: + void realloc_(size_t newCapacity, cudaStream_t stream) { + FAISS_ASSERT(num_ <= newCapacity); + FAISS_ASSERT_MSG(owner, "Cannot realloc due to no ownership of mem"); + + T* newData = nullptr; + allocMemorySpace(space_, &newData, newCapacity * sizeof(T)); + CUDA_VERIFY(cudaMemcpyAsync(newData, data_, num_ * sizeof(T), + cudaMemcpyDeviceToDevice, stream)); + freeMemorySpace(space_, data_); + + data_ = newData; + capacity_ = newCapacity; + } + + size_t getNewCapacity_(size_t preferredSize) { + return utils::nextHighestPowerOf2(preferredSize); + } + + T* data_; + size_t num_; + size_t capacity_; + MemorySpace space_; + bool owner = true; +}; + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/utils/Float16.cu b/core/src/index/thirdparty/faiss/gpu/utils/Float16.cu new file mode 100644 index 0000000000..e1f5c09b9f --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/utils/Float16.cu @@ -0,0 +1,42 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#include +//#include +#include +#include +#include + +#ifdef FAISS_USE_FLOAT16 + +namespace faiss { namespace gpu { + +bool getDeviceSupportsFloat16Math(int device) { + const auto& prop = getDeviceProperties(device); + + return (prop.major >= 6 || + (prop.major == 5 && prop.minor >= 3)); +} + +__half hostFloat2Half(float a) { +#if CUDA_VERSION >= 9000 + __half_raw raw; + //raw.x = cpu_float2half_rn(a).x; + FAISS_ASSERT_FMT(false, "%s", "cpu_float2half_rn() not support"); + return __half(raw); +#else + __half h; + //h.x = cpu_float2half_rn(a).x; + FAISS_ASSERT_FMT(false, "%s", "cpu_float2half_rn() not support"); + return h; +#endif +} + +} } // namespace + +#endif // FAISS_USE_FLOAT16 diff --git a/core/src/index/thirdparty/faiss/gpu/utils/Float16.cuh b/core/src/index/thirdparty/faiss/gpu/utils/Float16.cuh new file mode 100644 index 0000000000..0af798ba80 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/utils/Float16.cuh @@ -0,0 +1,81 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#pragma once + +#include +#include +#include + +// We require at least CUDA 8.0 for compilation +#if CUDA_VERSION < 8000 +#error "CUDA >= 8.0 is required" +#endif + +// Some compute capabilities have full float16 ALUs. +#if __CUDA_ARCH__ >= 530 +#define FAISS_USE_FULL_FLOAT16 1 +#endif // __CUDA_ARCH__ types + +#ifdef FAISS_USE_FLOAT16 +#include +#endif + +namespace faiss { namespace gpu { + +#ifdef FAISS_USE_FLOAT16 + +// 64 bytes containing 4 half (float16) values +struct Half4 { + half2 a; + half2 b; +}; + +inline __device__ float4 half4ToFloat4(Half4 v) { + float2 a = __half22float2(v.a); + float2 b = __half22float2(v.b); + + float4 out; + out.x = a.x; + out.y = a.y; + out.z = b.x; + out.w = b.y; + + return out; +} + +inline __device__ Half4 float4ToHalf4(float4 v) { + float2 a; + a.x = v.x; + a.y = v.y; + + float2 b; + b.x = v.z; + b.y = v.w; + + Half4 out; + out.a = __float22half2_rn(a); + out.b = __float22half2_rn(b); + + return out; +} + +// 128 bytes containing 8 half (float16) values +struct Half8 { + Half4 a; + Half4 b; +}; + +/// Returns true if the given device supports native float16 math +bool getDeviceSupportsFloat16Math(int device); + +__half hostFloat2Half(float v); + +#endif // FAISS_USE_FLOAT16 + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/utils/HostTensor-inl.cuh b/core/src/index/thirdparty/faiss/gpu/utils/HostTensor-inl.cuh new file mode 100644 index 0000000000..37149fc936 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/utils/HostTensor-inl.cuh @@ -0,0 +1,180 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +namespace faiss { namespace gpu { + +template class PtrTraits> +__host__ +HostTensor::HostTensor() : + Tensor(), + state_(AllocState::NotOwner) { +} + +template class PtrTraits> +__host__ +HostTensor::~HostTensor() { + if (state_ == AllocState::Owner) { + FAISS_ASSERT(this->data_ != nullptr); + delete[] this->data_; + this->data_ = nullptr; + } +} + +template class PtrTraits> +__host__ +HostTensor::HostTensor( + HostTensor&& t) : + Tensor(), + state_(AllocState::NotOwner) { + this->operator=(std::move(t)); +} + +template class PtrTraits> +__host__ +HostTensor& +HostTensor::operator=( + HostTensor&& t) { + if (this->state_ == AllocState::Owner) { + FAISS_ASSERT(this->data_ != nullptr); + delete[] this->data_; + this->data_ = nullptr; + } + + this->Tensor::operator=( + std::move(t)); + + this->state_ = t.state_; t.state_ = AllocState::NotOwner; + + return *this; +} + +template class PtrTraits> +__host__ +HostTensor::HostTensor( + const IndexT sizes[Dim]) : + Tensor(nullptr, sizes), + state_(AllocState::Owner) { + + this->data_ = new T[this->numElements()]; + FAISS_ASSERT(this->data_ != nullptr); +} + +template class PtrTraits> +__host__ +HostTensor::HostTensor( + std::initializer_list sizes) : + Tensor(nullptr, sizes), + state_(AllocState::Owner) { + this->data_ = new T[this->numElements()]; + FAISS_ASSERT(this->data_ != nullptr); +} + +template class PtrTraits> +__host__ +HostTensor::HostTensor( + DataPtrType data, + const IndexT sizes[Dim]) : + Tensor(data, sizes), + state_(AllocState::NotOwner) { +} + +template class PtrTraits> +__host__ +HostTensor::HostTensor( + DataPtrType data, + std::initializer_list sizes) : + Tensor(data, sizes), + state_(AllocState::NotOwner) { +} + +template class PtrTraits> +__host__ +HostTensor::HostTensor( + DataPtrType data, + const IndexT sizes[Dim], + const IndexT strides[Dim]) : + Tensor(data, sizes, strides), + state_(AllocState::NotOwner) { +} + +template class PtrTraits> +__host__ +HostTensor::HostTensor( + Tensor& t, + cudaStream_t stream) : + Tensor(nullptr, t.sizes(), t.strides()), + state_(AllocState::Owner) { + // Only contiguous arrays handled for now + FAISS_ASSERT(t.isContiguous()); + + this->data_ = new T[t.numElements()]; + this->copyFrom(t, stream); +} + +/// Call to zero out memory +template class PtrTraits> +__host__ HostTensor& +HostTensor::zero() { + // Region must be contiguous + FAISS_ASSERT(this->isContiguous()); + + if (this->data_ != nullptr) { + memset(this->data_, 0, this->getSizeInBytes()); + } + + return *this; +} + +template class PtrTraits> +__host__ T +HostTensor::maxDiff( + const HostTensor& t) const { + auto size = this->numElements(); + + FAISS_ASSERT(size == t.numElements()); + FAISS_ASSERT(size > 0); + + if (InnerContig) { + auto a = this->data(); + auto b = t.data(); + + T maxDiff = a[0] - b[0]; + // FIXME: type-specific abs() + maxDiff = maxDiff < 0 ? maxDiff * (T) -1 : maxDiff; + + for (IndexT i = 1; i < size; ++i) { + auto diff = a[i] - b[i]; + // FIXME: type-specific abs + diff = diff < 0 ? diff * (T) -1 : diff; + if (diff > maxDiff) { + maxDiff = diff; + } + } + + return maxDiff; + } else { + // non-contiguous + // FIXME + FAISS_ASSERT(false); + return (T) 0; + } +} + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/utils/HostTensor.cuh b/core/src/index/thirdparty/faiss/gpu/utils/HostTensor.cuh new file mode 100644 index 0000000000..5b8758a8ce --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/utils/HostTensor.cuh @@ -0,0 +1,91 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#pragma once + +#include + +namespace faiss { namespace gpu { + +template class PtrTraits = traits::DefaultPtrTraits> +class HostTensor : public Tensor { + public: + typedef IndexT IndexType; + typedef typename PtrTraits::PtrType DataPtrType; + + /// Default constructor + __host__ HostTensor(); + + /// Destructor + __host__ ~HostTensor(); + + /// Move constructor + __host__ HostTensor(HostTensor&& t); + + /// Move assignment + __host__ HostTensor& + operator=(HostTensor&& t); + + /// Constructs a tensor of the given size, allocating memory for it + /// locally + __host__ HostTensor(const IndexT sizes[Dim]); + __host__ HostTensor(std::initializer_list sizes); + + /// Constructs a tensor of the given size and stride, referencing a + /// memory region we do not own + __host__ HostTensor(DataPtrType data, + const IndexT sizes[Dim]); + __host__ HostTensor(DataPtrType data, + std::initializer_list sizes); + + /// Constructs a tensor of the given size and stride, referencing a + /// memory region we do not own + __host__ HostTensor(DataPtrType data, + const IndexT sizes[Dim], + const IndexT strides[Dim]); + + /// Copies a tensor into ourselves, allocating memory for it + /// locally. If the tensor is on the GPU, then we will copy it to + /// ourselves wrt the given stream. + __host__ HostTensor(Tensor& t, + cudaStream_t stream); + + /// Call to zero out memory + __host__ HostTensor& zero(); + + /// Returns the maximum difference seen between two tensors + __host__ T + maxDiff(const HostTensor& t) const; + + /// Are the two tensors exactly equal? + __host__ bool + equal(const HostTensor& t) const { + return (maxDiff(t) == (T) 0); + } + + private: + enum AllocState { + /// This tensor itself owns the memory, which must be freed via + /// cudaFree + Owner, + + /// This tensor itself is not an owner of the memory; there is + /// nothing to free + NotOwner, + }; + + AllocState state_; +}; + +} } // namespace + +#include diff --git a/core/src/index/thirdparty/faiss/gpu/utils/Limits.cuh b/core/src/index/thirdparty/faiss/gpu/utils/Limits.cuh new file mode 100644 index 0000000000..7dfaa2e2ce --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/utils/Limits.cuh @@ -0,0 +1,82 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#pragma once + +#include +#include + +namespace faiss { namespace gpu { + +template +struct Limits { +}; + +// Unfortunately we can't use constexpr because there is no +// constexpr constructor for half +// FIXME: faiss CPU uses +/-FLT_MAX instead of +/-infinity +constexpr float kFloatMax = std::numeric_limits::max(); +constexpr float kFloatMin = std::numeric_limits::lowest(); + +template <> +struct Limits { + static __device__ __host__ inline float getMin() { + return kFloatMin; + } + static __device__ __host__ inline float getMax() { + return kFloatMax; + } +}; + +inline __device__ __host__ half kGetHalf(unsigned short v) { +#if CUDA_VERSION >= 9000 + __half_raw h; + h.x = v; + return __half(h); +#else + half h; + h.x = v; + return h; +#endif +} + +template <> +struct Limits { + static __device__ __host__ inline half getMin() { + return kGetHalf(0xfbffU); + } + static __device__ __host__ inline half getMax() { + return kGetHalf(0x7bffU); + } +}; + +constexpr int kIntMax = std::numeric_limits::max(); +constexpr int kIntMin = std::numeric_limits::lowest(); + +template <> +struct Limits { + static __device__ __host__ inline int getMin() { + return kIntMin; + } + static __device__ __host__ inline int getMax() { + return kIntMax; + } +}; + +template +struct Limits> { + static __device__ __host__ inline Pair getMin() { + return Pair(Limits::getMin(), Limits::getMin()); + } + + static __device__ __host__ inline Pair getMax() { + return Pair(Limits::getMax(), Limits::getMax()); + } +}; + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/utils/LoadStoreOperators.cuh b/core/src/index/thirdparty/faiss/gpu/utils/LoadStoreOperators.cuh new file mode 100644 index 0000000000..b49d634461 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/utils/LoadStoreOperators.cuh @@ -0,0 +1,94 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#pragma once + +#include + +#ifndef __HALF2_TO_UI +// cuda_fp16.hpp doesn't export this +#define __HALF2_TO_UI(var) *(reinterpret_cast(&(var))) +#endif + + +// +// Templated wrappers to express load/store for different scalar and vector +// types, so kernels can have the same written form but can operate +// over half and float, and on vector types transparently +// + +namespace faiss { namespace gpu { + +template +struct LoadStore { + static inline __device__ T load(void* p) { + return *((T*) p); + } + + static inline __device__ void store(void* p, const T& v) { + *((T*) p) = v; + } +}; + +#ifdef FAISS_USE_FLOAT16 + +template <> +struct LoadStore { + static inline __device__ Half4 load(void* p) { + Half4 out; +#if CUDA_VERSION >= 9000 + asm("ld.global.v2.u32 {%0, %1}, [%2];" : + "=r"(__HALF2_TO_UI(out.a)), "=r"(__HALF2_TO_UI(out.b)) : "l"(p)); +#else + asm("ld.global.v2.u32 {%0, %1}, [%2];" : + "=r"(out.a.x), "=r"(out.b.x) : "l"(p)); +#endif + return out; + } + + static inline __device__ void store(void* p, Half4& v) { +#if CUDA_VERSION >= 9000 + asm("st.v2.u32 [%0], {%1, %2};" : : "l"(p), + "r"(__HALF2_TO_UI(v.a)), "r"(__HALF2_TO_UI(v.b))); +#else + asm("st.v2.u32 [%0], {%1, %2};" : : "l"(p), "r"(v.a.x), "r"(v.b.x)); +#endif + } +}; + +template <> +struct LoadStore { + static inline __device__ Half8 load(void* p) { + Half8 out; +#if CUDA_VERSION >= 9000 + asm("ld.global.v4.u32 {%0, %1, %2, %3}, [%4];" : + "=r"(__HALF2_TO_UI(out.a.a)), "=r"(__HALF2_TO_UI(out.a.b)), + "=r"(__HALF2_TO_UI(out.b.a)), "=r"(__HALF2_TO_UI(out.b.b)) : "l"(p)); +#else + asm("ld.global.v4.u32 {%0, %1, %2, %3}, [%4];" : + "=r"(out.a.a.x), "=r"(out.a.b.x), + "=r"(out.b.a.x), "=r"(out.b.b.x) : "l"(p)); +#endif + return out; + } + + static inline __device__ void store(void* p, Half8& v) { +#if CUDA_VERSION >= 9000 + asm("st.v4.u32 [%0], {%1, %2, %3, %4};" + : : "l"(p), "r"(__HALF2_TO_UI(v.a.a)), "r"(__HALF2_TO_UI(v.a.b)), + "r"(__HALF2_TO_UI(v.b.a)), "r"(__HALF2_TO_UI(v.b.b))); +#else + asm("st.v4.u32 [%0], {%1, %2, %3, %4};" + : : "l"(p), "r"(v.a.a.x), "r"(v.a.b.x), "r"(v.b.a.x), "r"(v.b.b.x)); +#endif + } +}; + +#endif + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/utils/MathOperators.cuh b/core/src/index/thirdparty/faiss/gpu/utils/MathOperators.cuh new file mode 100644 index 0000000000..7e9f25a2a0 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/utils/MathOperators.cuh @@ -0,0 +1,561 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#pragma once + +#include +#include + +// +// Templated wrappers to express math for different scalar and vector +// types, so kernels can have the same written form but can operate +// over half and float, and on vector types transparently +// + +namespace faiss { namespace gpu { + +template +struct Math { + typedef T ScalarType; + + static inline __device__ T add(T a, T b) { + return a + b; + } + + static inline __device__ T sub(T a, T b) { + return a - b; + } + + static inline __device__ T mul(T a, T b) { + return a * b; + } + + static inline __device__ T neg(T v) { + return -v; + } + + /// For a vector type, this is a horizontal add, returning sum(v_i) + static inline __device__ float reduceAdd(T v) { + return ConvertTo::to(v); + } + + static inline __device__ bool lt(T a, T b) { + return a < b; + } + + static inline __device__ bool gt(T a, T b) { + return a > b; + } + + static inline __device__ bool eq(T a, T b) { + return a == b; + } + + static inline __device__ T zero() { + return (T) 0; + } +}; + +template <> +struct Math { + typedef float ScalarType; + + static inline __device__ float2 add(float2 a, float2 b) { + float2 v; + v.x = a.x + b.x; + v.y = a.y + b.y; + return v; + } + + static inline __device__ float2 sub(float2 a, float2 b) { + float2 v; + v.x = a.x - b.x; + v.y = a.y - b.y; + return v; + } + + static inline __device__ float2 add(float2 a, float b) { + float2 v; + v.x = a.x + b; + v.y = a.y + b; + return v; + } + + static inline __device__ float2 sub(float2 a, float b) { + float2 v; + v.x = a.x - b; + v.y = a.y - b; + return v; + } + + static inline __device__ float2 mul(float2 a, float2 b) { + float2 v; + v.x = a.x * b.x; + v.y = a.y * b.y; + return v; + } + + static inline __device__ float2 mul(float2 a, float b) { + float2 v; + v.x = a.x * b; + v.y = a.y * b; + return v; + } + + static inline __device__ float2 neg(float2 v) { + v.x = -v.x; + v.y = -v.y; + return v; + } + + /// For a vector type, this is a horizontal add, returning sum(v_i) + static inline __device__ float reduceAdd(float2 v) { + return v.x + v.y; + } + + // not implemented for vector types + // static inline __device__ bool lt(float2 a, float2 b); + // static inline __device__ bool gt(float2 a, float2 b); + // static inline __device__ bool eq(float2 a, float2 b); + + static inline __device__ float2 zero() { + float2 v; + v.x = 0.0f; + v.y = 0.0f; + return v; + } +}; + +template <> +struct Math { + typedef float ScalarType; + + static inline __device__ float4 add(float4 a, float4 b) { + float4 v; + v.x = a.x + b.x; + v.y = a.y + b.y; + v.z = a.z + b.z; + v.w = a.w + b.w; + return v; + } + + static inline __device__ float4 sub(float4 a, float4 b) { + float4 v; + v.x = a.x - b.x; + v.y = a.y - b.y; + v.z = a.z - b.z; + v.w = a.w - b.w; + return v; + } + + static inline __device__ float4 add(float4 a, float b) { + float4 v; + v.x = a.x + b; + v.y = a.y + b; + v.z = a.z + b; + v.w = a.w + b; + return v; + } + + static inline __device__ float4 sub(float4 a, float b) { + float4 v; + v.x = a.x - b; + v.y = a.y - b; + v.z = a.z - b; + v.w = a.w - b; + return v; + } + + static inline __device__ float4 mul(float4 a, float4 b) { + float4 v; + v.x = a.x * b.x; + v.y = a.y * b.y; + v.z = a.z * b.z; + v.w = a.w * b.w; + return v; + } + + static inline __device__ float4 mul(float4 a, float b) { + float4 v; + v.x = a.x * b; + v.y = a.y * b; + v.z = a.z * b; + v.w = a.w * b; + return v; + } + + static inline __device__ float4 neg(float4 v) { + v.x = -v.x; + v.y = -v.y; + v.z = -v.z; + v.w = -v.w; + return v; + } + + /// For a vector type, this is a horizontal add, returning sum(v_i) + static inline __device__ float reduceAdd(float4 v) { + return v.x + v.y + v.z + v.w; + } + + // not implemented for vector types + // static inline __device__ bool lt(float4 a, float4 b); + // static inline __device__ bool gt(float4 a, float4 b); + // static inline __device__ bool eq(float4 a, float4 b); + + static inline __device__ float4 zero() { + float4 v; + v.x = 0.0f; + v.y = 0.0f; + v.z = 0.0f; + v.w = 0.0f; + return v; + } +}; + +#ifdef FAISS_USE_FLOAT16 +template <> +struct Math { + typedef half ScalarType; + + static inline __device__ half add(half a, half b) { +#ifdef FAISS_USE_FULL_FLOAT16 + return __hadd(a, b); +#else + return __float2half(__half2float(a) + __half2float(b)); +#endif + } + + static inline __device__ half sub(half a, half b) { +#ifdef FAISS_USE_FULL_FLOAT16 + return __hsub(a, b); +#else + return __float2half(__half2float(a) - __half2float(b)); +#endif + } + + static inline __device__ half mul(half a, half b) { +#ifdef FAISS_USE_FULL_FLOAT16 + return __hmul(a, b); +#else + return __float2half(__half2float(a) * __half2float(b)); +#endif + } + + static inline __device__ half neg(half v) { +#ifdef FAISS_USE_FULL_FLOAT16 + return __hneg(v); +#else + return __float2half(-__half2float(v)); +#endif + } + + static inline __device__ float reduceAdd(half v) { + return ConvertTo::to(v); + } + + static inline __device__ bool lt(half a, half b) { +#ifdef FAISS_USE_FULL_FLOAT16 + return __hlt(a, b); +#else + return __half2float(a) < __half2float(b); +#endif + } + + static inline __device__ bool gt(half a, half b) { +#ifdef FAISS_USE_FULL_FLOAT16 + return __hgt(a, b); +#else + return __half2float(a) > __half2float(b); +#endif + } + + static inline __device__ bool eq(half a, half b) { +#ifdef FAISS_USE_FULL_FLOAT16 + return __heq(a, b); +#else + return __half2float(a) == __half2float(b); +#endif + } + + static inline __device__ half zero() { +#if CUDA_VERSION >= 9000 + return 0; +#else + half h; + h.x = 0; + return h; +#endif + } +}; + +template <> +struct Math { + typedef half ScalarType; + + static inline __device__ half2 add(half2 a, half2 b) { +#ifdef FAISS_USE_FULL_FLOAT16 + return __hadd2(a, b); +#else + float2 af = __half22float2(a); + float2 bf = __half22float2(b); + + af.x += bf.x; + af.y += bf.y; + + return __float22half2_rn(af); +#endif + } + + static inline __device__ half2 sub(half2 a, half2 b) { +#ifdef FAISS_USE_FULL_FLOAT16 + return __hsub2(a, b); +#else + float2 af = __half22float2(a); + float2 bf = __half22float2(b); + + af.x -= bf.x; + af.y -= bf.y; + + return __float22half2_rn(af); +#endif + } + + static inline __device__ half2 add(half2 a, half b) { +#ifdef FAISS_USE_FULL_FLOAT16 + half2 b2 = __half2half2(b); + return __hadd2(a, b2); +#else + float2 af = __half22float2(a); + float bf = __half2float(b); + + af.x += bf; + af.y += bf; + + return __float22half2_rn(af); +#endif + } + + static inline __device__ half2 sub(half2 a, half b) { +#ifdef FAISS_USE_FULL_FLOAT16 + half2 b2 = __half2half2(b); + return __hsub2(a, b2); +#else + float2 af = __half22float2(a); + float bf = __half2float(b); + + af.x -= bf; + af.y -= bf; + + return __float22half2_rn(af); +#endif + } + + static inline __device__ half2 mul(half2 a, half2 b) { +#ifdef FAISS_USE_FULL_FLOAT16 + return __hmul2(a, b); +#else + float2 af = __half22float2(a); + float2 bf = __half22float2(b); + + af.x *= bf.x; + af.y *= bf.y; + + return __float22half2_rn(af); +#endif + } + + static inline __device__ half2 mul(half2 a, half b) { +#ifdef FAISS_USE_FULL_FLOAT16 + half2 b2 = __half2half2(b); + return __hmul2(a, b2); +#else + float2 af = __half22float2(a); + float bf = __half2float(b); + + af.x *= bf; + af.y *= bf; + + return __float22half2_rn(af); +#endif + } + + static inline __device__ half2 neg(half2 v) { +#ifdef FAISS_USE_FULL_FLOAT16 + return __hneg2(v); +#else + float2 vf = __half22float2(v); + vf.x = -vf.x; + vf.y = -vf.y; + + return __float22half2_rn(vf); +#endif + } + + static inline __device__ float reduceAdd(half2 v) { + float2 vf = __half22float2(v); + vf.x += vf.y; + + return vf.x; + } + + // not implemented for vector types + // static inline __device__ bool lt(half2 a, half2 b); + // static inline __device__ bool gt(half2 a, half2 b); + // static inline __device__ bool eq(half2 a, half2 b); + + static inline __device__ half2 zero() { + return __half2half2(Math::zero()); + } +}; + +template <> +struct Math { + typedef half ScalarType; + + static inline __device__ Half4 add(Half4 a, Half4 b) { + Half4 h; + h.a = Math::add(a.a, b.a); + h.b = Math::add(a.b, b.b); + return h; + } + + static inline __device__ Half4 sub(Half4 a, Half4 b) { + Half4 h; + h.a = Math::sub(a.a, b.a); + h.b = Math::sub(a.b, b.b); + return h; + } + + static inline __device__ Half4 add(Half4 a, half b) { + Half4 h; + h.a = Math::add(a.a, b); + h.b = Math::add(a.b, b); + return h; + } + + static inline __device__ Half4 sub(Half4 a, half b) { + Half4 h; + h.a = Math::sub(a.a, b); + h.b = Math::sub(a.b, b); + return h; + } + + static inline __device__ Half4 mul(Half4 a, Half4 b) { + Half4 h; + h.a = Math::mul(a.a, b.a); + h.b = Math::mul(a.b, b.b); + return h; + } + + static inline __device__ Half4 mul(Half4 a, half b) { + Half4 h; + h.a = Math::mul(a.a, b); + h.b = Math::mul(a.b, b); + return h; + } + + static inline __device__ Half4 neg(Half4 v) { + Half4 h; + h.a = Math::neg(v.a); + h.b = Math::neg(v.b); + return h; + } + + static inline __device__ float reduceAdd(Half4 v) { + float x = Math::reduceAdd(v.a); + float y = Math::reduceAdd(v.b); + return x + y; + } + + // not implemented for vector types + // static inline __device__ bool lt(Half4 a, Half4 b); + // static inline __device__ bool gt(Half4 a, Half4 b); + // static inline __device__ bool eq(Half4 a, Half4 b); + + static inline __device__ Half4 zero() { + Half4 h; + h.a = Math::zero(); + h.b = Math::zero(); + return h; + } +}; + +template <> +struct Math { + typedef half ScalarType; + + static inline __device__ Half8 add(Half8 a, Half8 b) { + Half8 h; + h.a = Math::add(a.a, b.a); + h.b = Math::add(a.b, b.b); + return h; + } + + static inline __device__ Half8 sub(Half8 a, Half8 b) { + Half8 h; + h.a = Math::sub(a.a, b.a); + h.b = Math::sub(a.b, b.b); + return h; + } + + static inline __device__ Half8 add(Half8 a, half b) { + Half8 h; + h.a = Math::add(a.a, b); + h.b = Math::add(a.b, b); + return h; + } + + static inline __device__ Half8 sub(Half8 a, half b) { + Half8 h; + h.a = Math::sub(a.a, b); + h.b = Math::sub(a.b, b); + return h; + } + + static inline __device__ Half8 mul(Half8 a, Half8 b) { + Half8 h; + h.a = Math::mul(a.a, b.a); + h.b = Math::mul(a.b, b.b); + return h; + } + + static inline __device__ Half8 mul(Half8 a, half b) { + Half8 h; + h.a = Math::mul(a.a, b); + h.b = Math::mul(a.b, b); + return h; + } + + static inline __device__ Half8 neg(Half8 v) { + Half8 h; + h.a = Math::neg(v.a); + h.b = Math::neg(v.b); + return h; + } + + static inline __device__ half reduceAdd(Half8 v) { + float x = Math::reduceAdd(v.a); + float y = Math::reduceAdd(v.b); + return x + y; + } + + // not implemented for vector types + // static inline __device__ bool lt(Half8 a, Half8 b); + // static inline __device__ bool gt(Half8 a, Half8 b); + // static inline __device__ bool eq(Half8 a, Half8 b); + + static inline __device__ Half8 zero() { + Half8 h; + h.a = Math::zero(); + h.b = Math::zero(); + return h; + } +}; +#endif // FAISS_USE_FLOAT16 + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/utils/MatrixMult-inl.cuh b/core/src/index/thirdparty/faiss/gpu/utils/MatrixMult-inl.cuh new file mode 100644 index 0000000000..ede225e035 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/utils/MatrixMult-inl.cuh @@ -0,0 +1,160 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#pragma once + +#include +#include +#include +#include +#include + +namespace faiss { namespace gpu { + +class DeviceMemory; + +template +struct GetCudaType; + +template <> +struct GetCudaType { + static constexpr cudaDataType_t Type = CUDA_R_32F; +}; + +template <> +struct GetCudaType { + static constexpr cudaDataType_t Type = CUDA_R_16F; +}; + +template +cublasStatus_t +rawGemm(cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int m, + int n, + int k, + const float fAlpha, + const AT *A, + int lda, + const BT *B, + int ldb, + const float fBeta, + float *C, + int ldc) { + auto cAT = GetCudaType::Type; + auto cBT = GetCudaType::Type; + + // Always accumulate in f32 + return cublasSgemmEx(handle, transa, transb, m, n, k, + &fAlpha, A, cAT, lda, + B, cBT, ldb, + &fBeta, + C, CUDA_R_32F, ldc); +} + +template +void +runMatrixMult(Tensor& c, bool transC, + Tensor& a, bool transA, + Tensor& b, bool transB, + float alpha, + float beta, + cublasHandle_t handle, + cudaStream_t stream) { + cublasSetStream(handle, stream); + + // Check that we have (m x k) * (k x n) = (m x n) + // using the input row-major layout + int aM = transA ? a.getSize(1) : a.getSize(0); + int aK = transA ? a.getSize(0) : a.getSize(1); + + int bK = transB ? b.getSize(1) : b.getSize(0); + int bN = transB ? b.getSize(0) : b.getSize(1); + + int cM = transC ? c.getSize(1) : c.getSize(0); + int cN = transC ? c.getSize(0) : c.getSize(1); + + FAISS_ASSERT(aM == cM); + FAISS_ASSERT(aK == bK); + FAISS_ASSERT(bN == cN); + + FAISS_ASSERT(a.getStride(1) == 1); + FAISS_ASSERT(b.getStride(1) == 1); + FAISS_ASSERT(c.getStride(1) == 1); + + // Now, we have to represent the matrix multiplication in + // column-major layout + float* pC = c.data(); + + int m = c.getSize(1); // stride 1 size + int n = c.getSize(0); // other size + int k = transA ? a.getSize(0) : a.getSize(1); + + int lda = transC ? a.getStride(0) : b.getStride(0); + int ldb = transC ? b.getStride(0) : a.getStride(0); + int ldc = c.getStride(0); + + auto gemmTrA = transB ? CUBLAS_OP_T : CUBLAS_OP_N; + auto gemmTrB = transA ? CUBLAS_OP_T : CUBLAS_OP_N; + + if (transC) { + gemmTrA = transA ? CUBLAS_OP_N : CUBLAS_OP_T; + gemmTrB = transB ? CUBLAS_OP_N : CUBLAS_OP_T; + } + + cublasStatus_t err; + + if (transC) { + err = rawGemm(handle, + gemmTrA, gemmTrB, + m, n, k, alpha, + a.data(), lda, b.data(), ldb, beta, + pC, ldc); + } else { + err = rawGemm(handle, + gemmTrA, gemmTrB, + m, n, k, alpha, + b.data(), lda, a.data(), ldb, beta, + pC, ldc); + } + + FAISS_ASSERT_FMT(err == CUBLAS_STATUS_SUCCESS, + "cublas failed (%d): " + "(%d, %d)%s x (%d, %d)%s = (%d, %d)%s", + (int) err, + a.getSize(0), a.getSize(1), transA ? "'" : "", + b.getSize(0), b.getSize(1), transB ? "'" : "", + c.getSize(0), c.getSize(1), transC ? "'" : ""); + CUDA_TEST_ERROR(); +} + +template +void runIteratedMatrixMult(Tensor& c, bool transC, + Tensor& a, bool transA, + Tensor& b, bool transB, + float alpha, + float beta, + cublasHandle_t handle, + cudaStream_t stream) { + FAISS_ASSERT(c.getSize(0) == a.getSize(0)); + FAISS_ASSERT(a.getSize(0) == b.getSize(0)); + + for (int i = 0; i < a.getSize(0); ++i) { + auto cView = c[i].view(); + auto aView = a[i].view(); + auto bView = b[i].view(); + + runMatrixMult(cView, transC, + aView, transA, + bView, transB, + alpha, beta, handle, stream); + } +} + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/utils/MatrixMult.cu b/core/src/index/thirdparty/faiss/gpu/utils/MatrixMult.cu new file mode 100644 index 0000000000..2afb5017b2 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/utils/MatrixMult.cu @@ -0,0 +1,94 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#include +#include + +namespace faiss { namespace gpu { + +void +runBatchMatrixMult(Tensor& c, bool transC, + Tensor& a, bool transA, + Tensor& b, bool transB, + float alpha, + float beta, + DeviceMemory& mem, + cublasHandle_t handle, + cudaStream_t stream) { + FAISS_ASSERT(c.getSize(0) == a.getSize(0)); + FAISS_ASSERT(a.getSize(0) == b.getSize(0)); + cublasSetStream(handle, stream); + + // Check that we have (m x k) * (k x n) = (m x n) + // using the input row-major layout + int aM = transA ? a.getSize(2) : a.getSize(1); + int aK = transA ? a.getSize(1) : a.getSize(2); + + int bK = transB ? b.getSize(2) : b.getSize(1); + int bN = transB ? b.getSize(1) : b.getSize(2); + + int cM = transC ? c.getSize(2) : c.getSize(1); + int cN = transC ? c.getSize(1) : c.getSize(2); + + FAISS_ASSERT(aM == cM); + FAISS_ASSERT(aK == bK); + FAISS_ASSERT(bN == cN); + + // Now, we have to represent the matrix multiplication in + // column-major layout + float* pA = transC ? a.data() : b.data(); + float* pB = transC ? b.data() : a.data(); + float* pC = c.data(); + + int m = c.getSize(2); // stride 1 size + int n = c.getSize(1); // other size + int k = transA ? a.getSize(1) : a.getSize(2); + + int lda = transC ? a.getStride(1) : b.getStride(1); + int ldb = transC ? b.getStride(1) : a.getStride(1); + int ldc = c.getStride(1); + + auto gemmTrA = transB ? CUBLAS_OP_T : CUBLAS_OP_N; + auto gemmTrB = transA ? CUBLAS_OP_T : CUBLAS_OP_N; + + if (transC) { + gemmTrA = transA ? CUBLAS_OP_N : CUBLAS_OP_T; + gemmTrB = transB ? CUBLAS_OP_N : CUBLAS_OP_T; + } + + HostTensor hostA({a.getSize(0)}); + HostTensor hostB({b.getSize(0)}); + HostTensor hostC({c.getSize(0)}); + + size_t aOffset = a.getStride(0); + size_t bOffset = b.getStride(0); + size_t cOffset = c.getStride(0); + + for (int i = 0; i < a.getSize(0); ++i) { + hostA[i] = transC ? a.data() + i * aOffset : b.data() + i * bOffset; + hostB[i] = transC ? b.data() + i * bOffset : a.data() + i * aOffset; + hostC[i] = c.data() + i * cOffset; + } + + DeviceTensor deviceA(mem, hostA, stream); + DeviceTensor deviceB(mem, hostB, stream); + DeviceTensor deviceC(mem, hostC, stream); + + auto err = + cublasSgemmBatched(handle, + gemmTrA, gemmTrB, + m, n, k, &alpha, + (const float**) deviceA.data(), lda, + (const float**) deviceB.data(), ldb, &beta, + deviceC.data(), ldc, a.getSize(0)); + FAISS_ASSERT_FMT(err == CUBLAS_STATUS_SUCCESS, + "cublasSgemmBatched failed (%d)", (int) err); + CUDA_TEST_ERROR(); +} + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/utils/MatrixMult.cuh b/core/src/index/thirdparty/faiss/gpu/utils/MatrixMult.cuh new file mode 100644 index 0000000000..eeb11ccc5c --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/utils/MatrixMult.cuh @@ -0,0 +1,59 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#pragma once + +#include +#include +#include +#include +#include + +namespace faiss { namespace gpu { + +class DeviceMemory; + +/// C = alpha * A * B + beta * C +/// Expects row major layout, not fortran/blas column major! +template +void +runMatrixMult(Tensor& c, bool transC, + Tensor& a, bool transA, + Tensor& b, bool transB, + float alpha, + float beta, + cublasHandle_t handle, + cudaStream_t stream); + +/// C_i = alpha * A_i * B_i + beta * C_i +/// where `i` is the outermost dimension, via iterated gemm +/// Expects row major layout, not fortran/blas column major! +template +void runIteratedMatrixMult(Tensor& c, bool transC, + Tensor& a, bool transA, + Tensor& b, bool transB, + float alpha, + float beta, + cublasHandle_t handle, + cudaStream_t stream); + +/// C_i = alpha * A_i * B_i + beta * C_i +/// where `i` is the outermost dimension, via batched gemm +/// Expects row major layout, not fortran/blas column major! +void runBatchMatrixMult(Tensor& c, bool transC, + Tensor& a, bool transA, + Tensor& b, bool transB, + float alpha, + float beta, + DeviceMemory& mem, + cublasHandle_t handle, + cudaStream_t stream); + +} } // namespace + +#include diff --git a/core/src/index/thirdparty/faiss/gpu/utils/MemorySpace.cpp b/core/src/index/thirdparty/faiss/gpu/utils/MemorySpace.cpp new file mode 100644 index 0000000000..282f835784 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/utils/MemorySpace.cpp @@ -0,0 +1,89 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#include +#include +#include + +namespace faiss { namespace gpu { + +/// Allocates CUDA memory for a given memory space +void allocMemorySpaceV(MemorySpace space, void** p, size_t size) { + switch (space) { + case MemorySpace::Device: + { + auto err = cudaMalloc(p, size); + + // Throw if we fail to allocate + FAISS_THROW_IF_NOT_FMT( + err == cudaSuccess, + "failed to cudaMalloc %zu bytes (error %d %s)", + size, (int) err, cudaGetErrorString(err)); + } + break; + case MemorySpace::Unified: + { +#ifdef FAISS_UNIFIED_MEM + auto err = cudaMallocManaged(p, size); + + // Throw if we fail to allocate + FAISS_THROW_IF_NOT_FMT( + err == cudaSuccess, + "failed to cudaMallocManaged %zu bytes (error %d %s)", + size, (int) err, cudaGetErrorString(err)); +#else + FAISS_THROW_MSG("Attempting to allocate via cudaMallocManaged " + "without CUDA 8+ support"); +#endif + } + break; + case MemorySpace::HostPinned: + { + auto err = cudaHostAlloc(p, size, cudaHostAllocDefault); + + // Throw if we fail to allocate + FAISS_THROW_IF_NOT_FMT( + err == cudaSuccess, + "failed to cudaHostAlloc %zu bytes (error %d %s)", + size, (int) err, cudaGetErrorString(err)); + } + break; + default: + FAISS_ASSERT_FMT(false, "unknown MemorySpace %d", (int) space); + break; + } +} + +// We'll allow allocation to fail, but free should always succeed and be a +// fatal error if it doesn't free +void freeMemorySpace(MemorySpace space, void* p) { + switch (space) { + case MemorySpace::Device: + case MemorySpace::Unified: + { + auto err = cudaFree(p); + FAISS_ASSERT_FMT(err == cudaSuccess, + "Failed to cudaFree pointer %p (error %d %s)", + p, (int) err, cudaGetErrorString(err)); + } + break; + case MemorySpace::HostPinned: + { + auto err = cudaFreeHost(p); + FAISS_ASSERT_FMT(err == cudaSuccess, + "Failed to cudaFreeHost pointer %p (error %d %s)", + p, (int) err, cudaGetErrorString(err)); + } + break; + default: + FAISS_ASSERT_FMT(false, "unknown MemorySpace %d", (int) space); + break; + } +} + +} } diff --git a/core/src/index/thirdparty/faiss/gpu/utils/MemorySpace.h b/core/src/index/thirdparty/faiss/gpu/utils/MemorySpace.h new file mode 100644 index 0000000000..f269f06a39 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/utils/MemorySpace.h @@ -0,0 +1,44 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#pragma once + +#include + +#if CUDA_VERSION >= 8000 +// Whether or not we enable usage of CUDA Unified Memory +#define FAISS_UNIFIED_MEM 1 +#endif + +namespace faiss { namespace gpu { + +enum MemorySpace { + /// Managed using cudaMalloc/cudaFree + Device = 1, + /// Managed using cudaMallocManaged/cudaFree + Unified = 2, + /// Managed using cudaHostAlloc/cudaFreeHost + HostPinned = 3, +}; + +/// All memory allocations and de-allocations come through these functions + +/// Allocates CUDA memory for a given memory space (void pointer) +/// Throws a FaissException if we are unable to allocate the memory +void allocMemorySpaceV(MemorySpace space, void** p, size_t size); + +template +inline void allocMemorySpace(MemorySpace space, T** p, size_t size) { + allocMemorySpaceV(space, (void**)(void*) p, size); +} + +/// Frees CUDA memory for a given memory space +/// Asserts if we are unable to free the region +void freeMemorySpace(MemorySpace space, void* p); + +} } diff --git a/core/src/index/thirdparty/faiss/gpu/utils/MergeNetworkBlock.cuh b/core/src/index/thirdparty/faiss/gpu/utils/MergeNetworkBlock.cuh new file mode 100644 index 0000000000..2776258b57 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/utils/MergeNetworkBlock.cuh @@ -0,0 +1,289 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +namespace faiss { namespace gpu { + +// Merge pairs of lists smaller than blockDim.x (NumThreads) +template +inline __device__ void blockMergeSmall(K* listK, V* listV) { + static_assert(utils::isPowerOf2(L), "L must be a power-of-2"); + static_assert(utils::isPowerOf2(NumThreads), + "NumThreads must be a power-of-2"); + static_assert(L <= NumThreads, "merge list size must be <= NumThreads"); + + // Which pair of lists we are merging + int mergeId = threadIdx.x / L; + + // Which thread we are within the merge + int tid = threadIdx.x % L; + + // listK points to a region of size N * 2 * L + listK += 2 * L * mergeId; + listV += 2 * L * mergeId; + + // It's not a bitonic merge, both lists are in the same direction, + // so handle the first swap assuming the second list is reversed + int pos = L - 1 - tid; + int stride = 2 * tid + 1; + + if (AllThreads || (threadIdx.x < N * L)) { + K ka = listK[pos]; + K kb = listK[pos + stride]; + + bool swap = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb); + listK[pos] = swap ? kb : ka; + listK[pos + stride] = swap ? ka : kb; + + V va = listV[pos]; + V vb = listV[pos + stride]; + listV[pos] = swap ? vb : va; + listV[pos + stride] = swap ? va : vb; + + // FIXME: is this a CUDA 9 compiler bug? + // K& ka = listK[pos]; + // K& kb = listK[pos + stride]; + + // bool s = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb); + // swap(s, ka, kb); + + // V& va = listV[pos]; + // V& vb = listV[pos + stride]; + // swap(s, va, vb); + } + + __syncthreads(); + +#pragma unroll + for (int stride = L / 2; stride > 0; stride /= 2) { + int pos = 2 * tid - (tid & (stride - 1)); + + if (AllThreads || (threadIdx.x < N * L)) { + K ka = listK[pos]; + K kb = listK[pos + stride]; + + bool swap = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb); + listK[pos] = swap ? kb : ka; + listK[pos + stride] = swap ? ka : kb; + + V va = listV[pos]; + V vb = listV[pos + stride]; + listV[pos] = swap ? vb : va; + listV[pos + stride] = swap ? va : vb; + + // FIXME: is this a CUDA 9 compiler bug? + // K& ka = listK[pos]; + // K& kb = listK[pos + stride]; + + // bool s = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb); + // swap(s, ka, kb); + + // V& va = listV[pos]; + // V& vb = listV[pos + stride]; + // swap(s, va, vb); + } + + __syncthreads(); + } +} + +// Merge pairs of sorted lists larger than blockDim.x (NumThreads) +template +inline __device__ void blockMergeLarge(K* listK, V* listV) { + static_assert(utils::isPowerOf2(L), "L must be a power-of-2"); + static_assert(L >= kWarpSize, "merge list size must be >= 32"); + static_assert(utils::isPowerOf2(NumThreads), + "NumThreads must be a power-of-2"); + static_assert(L >= NumThreads, "merge list size must be >= NumThreads"); + + // For L > NumThreads, each thread has to perform more work + // per each stride. + constexpr int kLoopPerThread = L / NumThreads; + + // It's not a bitonic merge, both lists are in the same direction, + // so handle the first swap assuming the second list is reversed +#pragma unroll + for (int loop = 0; loop < kLoopPerThread; ++loop) { + int tid = loop * NumThreads + threadIdx.x; + int pos = L - 1 - tid; + int stride = 2 * tid + 1; + + K ka = listK[pos]; + K kb = listK[pos + stride]; + + bool swap = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb); + listK[pos] = swap ? kb : ka; + listK[pos + stride] = swap ? ka : kb; + + V va = listV[pos]; + V vb = listV[pos + stride]; + listV[pos] = swap ? vb : va; + listV[pos + stride] = swap ? va : vb; + + // FIXME: is this a CUDA 9 compiler bug? + // K& ka = listK[pos]; + // K& kb = listK[pos + stride]; + + // bool s = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb); + // swap(s, ka, kb); + + // V& va = listV[pos]; + // V& vb = listV[pos + stride]; + // swap(s, va, vb); + } + + __syncthreads(); + + constexpr int kSecondLoopPerThread = + FullMerge ? kLoopPerThread : kLoopPerThread / 2; + +#pragma unroll + for (int stride = L / 2; stride > 0; stride /= 2) { +#pragma unroll + for (int loop = 0; loop < kSecondLoopPerThread; ++loop) { + int tid = loop * NumThreads + threadIdx.x; + int pos = 2 * tid - (tid & (stride - 1)); + + K ka = listK[pos]; + K kb = listK[pos + stride]; + + bool swap = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb); + listK[pos] = swap ? kb : ka; + listK[pos + stride] = swap ? ka : kb; + + V va = listV[pos]; + V vb = listV[pos + stride]; + listV[pos] = swap ? vb : va; + listV[pos + stride] = swap ? va : vb; + + // FIXME: is this a CUDA 9 compiler bug? + // K& ka = listK[pos]; + // K& kb = listK[pos + stride]; + + // bool s = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb); + // swap(s, ka, kb); + + // V& va = listV[pos]; + // V& vb = listV[pos + stride]; + // swap(s, va, vb); + } + + __syncthreads(); + } +} + +/// Class template to prevent static_assert from firing for +/// mixing smaller/larger than block cases +template +struct BlockMerge { +}; + +/// Merging lists smaller than a block +template +struct BlockMerge { + static inline __device__ void merge(K* listK, V* listV) { + constexpr int kNumParallelMerges = NumThreads / L; + constexpr int kNumIterations = N / kNumParallelMerges; + + static_assert(L <= NumThreads, "list must be <= NumThreads"); + static_assert((N < kNumParallelMerges) || + (kNumIterations * kNumParallelMerges == N), + "improper selection of N and L"); + + if (N < kNumParallelMerges) { + // We only need L threads per each list to perform the merge + blockMergeSmall( + listK, listV); + } else { + // All threads participate +#pragma unroll + for (int i = 0; i < kNumIterations; ++i) { + int start = i * kNumParallelMerges * 2 * L; + + blockMergeSmall( + listK + start, listV + start); + } + } + } +}; + +/// Merging lists larger than a block +template +struct BlockMerge { + static inline __device__ void merge(K* listK, V* listV) { + // Each pair of lists is merged sequentially +#pragma unroll + for (int i = 0; i < N; ++i) { + int start = i * 2 * L; + + blockMergeLarge( + listK + start, listV + start); + } + } +}; + +template +inline __device__ void blockMerge(K* listK, V* listV) { + constexpr bool kSmallerThanBlock = (L <= NumThreads); + + BlockMerge:: + merge(listK, listV); +} + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/utils/MergeNetworkUtils.cuh b/core/src/index/thirdparty/faiss/gpu/utils/MergeNetworkUtils.cuh new file mode 100644 index 0000000000..6810345226 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/utils/MergeNetworkUtils.cuh @@ -0,0 +1,24 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +namespace faiss { namespace gpu { + +template +inline __device__ void swap(bool swap, T& x, T& y) { + T tmp = x; + x = swap ? y : x; + y = swap ? tmp : y; +} + +template +inline __device__ void assign(bool assign, T& x, T y) { + x = assign ? y : x; +} + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/utils/MergeNetworkWarp.cuh b/core/src/index/thirdparty/faiss/gpu/utils/MergeNetworkWarp.cuh new file mode 100644 index 0000000000..4e486b025f --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/utils/MergeNetworkWarp.cuh @@ -0,0 +1,510 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include +#include +#include + +namespace faiss { namespace gpu { + +// +// This file contains functions to: +// +// -perform bitonic merges on pairs of sorted lists, held in +// registers. Each list contains N * kWarpSize (multiple of 32) +// elements for some N. +// The bitonic merge is implemented for arbitrary sizes; +// sorted list A of size N1 * kWarpSize registers +// sorted list B of size N2 * kWarpSize registers => +// sorted list C if size (N1 + N2) * kWarpSize registers. N1 and N2 +// are >= 1 and don't have to be powers of 2. +// +// -perform bitonic sorts on a set of N * kWarpSize key/value pairs +// held in registers, by using the above bitonic merge as a +// primitive. +// N can be an arbitrary N >= 1; i.e., the bitonic sort here supports +// odd sizes and doesn't require the input to be a power of 2. +// +// The sort or merge network is completely statically instantiated via +// template specialization / expansion and constexpr, and it uses warp +// shuffles to exchange values between warp lanes. +// +// A note about comparsions: +// +// For a sorting network of keys only, we only need one +// comparison (a < b). However, what we really need to know is +// if one lane chooses to exchange a value, then the +// corresponding lane should also do the exchange. +// Thus, if one just uses the negation !(x < y) in the higher +// lane, this will also include the case where (x == y). Thus, one +// lane in fact performs an exchange and the other doesn't, but +// because the only value being exchanged is equivalent, nothing has +// changed. +// So, you can get away with just one comparison and its negation. +// +// If we're sorting keys and values, where equivalent keys can +// exist, then this is a problem, since we want to treat (x, v1) +// as not equivalent to (x, v2). +// +// To remedy this, you can either compare with a lexicographic +// ordering (a.k < b.k || (a.k == b.k && a.v < b.v)), which since +// we're predicating all of the choices results in 3 comparisons +// being executed, or we can invert the selection so that there is no +// middle choice of equality; the other lane will likewise +// check that (b.k > a.k) (the higher lane has the values +// swapped). Then, the first lane swaps if and only if the +// second lane swaps; if both lanes have equivalent keys, no +// swap will be performed. This results in only two comparisons +// being executed. +// +// If you don't consider values as well, then this does not produce a +// consistent ordering among (k, v) pairs with equivalent keys but +// different values; for us, we don't really care about ordering or +// stability here. +// +// I have tried both re-arranging the order in the higher lane to get +// away with one comparison or adding the value to the check; both +// result in greater register consumption or lower speed than just +// perfoming both < and > comparisons with the variables, so I just +// stick with this. + +// This function merges kWarpSize / 2L lists in parallel using warp +// shuffles. +// It works on at most size-16 lists, as we need 32 threads for this +// shuffle merge. +// +// If IsBitonic is false, the first stage is reversed, so we don't +// need to sort directionally. It's still technically a bitonic sort. +template +inline __device__ void warpBitonicMergeLE16(K& k, V& v) { + static_assert(utils::isPowerOf2(L), "L must be a power-of-2"); + static_assert(L <= kWarpSize / 2, "merge list size must be <= 16"); + + int laneId = getLaneId(); + + if (!IsBitonic) { + // Reverse the first comparison stage. + // For example, merging a list of size 8 has the exchanges: + // 0 <-> 15, 1 <-> 14, ... + K otherK = shfl_xor(k, 2 * L - 1); + V otherV = shfl_xor(v, 2 * L - 1); + + // Whether we are the lesser thread in the exchange + bool small = !(laneId & L); + + if (Dir) { + // See the comment above how performing both of these + // comparisons in the warp seems to win out over the + // alternatives in practice + bool s = small ? Comp::gt(k, otherK) : Comp::lt(k, otherK); + assign(s, k, otherK); + assign(s, v, otherV); + + } else { + bool s = small ? Comp::lt(k, otherK) : Comp::gt(k, otherK); + assign(s, k, otherK); + assign(s, v, otherV); + } + } + +#pragma unroll + for (int stride = IsBitonic ? L : L / 2; stride > 0; stride /= 2) { + K otherK = shfl_xor(k, stride); + V otherV = shfl_xor(v, stride); + + // Whether we are the lesser thread in the exchange + bool small = !(laneId & stride); + + if (Dir) { + bool s = small ? Comp::gt(k, otherK) : Comp::lt(k, otherK); + assign(s, k, otherK); + assign(s, v, otherV); + + } else { + bool s = small ? Comp::lt(k, otherK) : Comp::gt(k, otherK); + assign(s, k, otherK); + assign(s, v, otherV); + } + } +} + +// Template for performing a bitonic merge of an arbitrary set of +// registers +template +struct BitonicMergeStep { +}; + +// +// Power-of-2 merge specialization +// + +// All merges eventually call this +template +struct BitonicMergeStep { + static inline __device__ void merge(K k[1], V v[1]) { + // Use warp shuffles + warpBitonicMergeLE16(k[0], v[0]); + } +}; + +template +struct BitonicMergeStep { + static inline __device__ void merge(K k[N], V v[N]) { + static_assert(utils::isPowerOf2(N), "must be power of 2"); + static_assert(N > 1, "must be N > 1"); + +#pragma unroll + for (int i = 0; i < N / 2; ++i) { + K& ka = k[i]; + V& va = v[i]; + + K& kb = k[i + N / 2]; + V& vb = v[i + N / 2]; + + bool s = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb); + swap(s, ka, kb); + swap(s, va, vb); + } + + { + K newK[N / 2]; + V newV[N / 2]; + +#pragma unroll + for (int i = 0; i < N / 2; ++i) { + newK[i] = k[i]; + newV[i] = v[i]; + } + + BitonicMergeStep::merge(newK, newV); + +#pragma unroll + for (int i = 0; i < N / 2; ++i) { + k[i] = newK[i]; + v[i] = newV[i]; + } + } + + { + K newK[N / 2]; + V newV[N / 2]; + +#pragma unroll + for (int i = 0; i < N / 2; ++i) { + newK[i] = k[i + N / 2]; + newV[i] = v[i + N / 2]; + } + + BitonicMergeStep::merge(newK, newV); + +#pragma unroll + for (int i = 0; i < N / 2; ++i) { + k[i + N / 2] = newK[i]; + v[i + N / 2] = newV[i]; + } + } + } +}; + +// +// Non-power-of-2 merge specialization +// + +// Low recursion +template +struct BitonicMergeStep { + static inline __device__ void merge(K k[N], V v[N]) { + static_assert(!utils::isPowerOf2(N), "must be non-power-of-2"); + static_assert(N >= 3, "must be N >= 3"); + + constexpr int kNextHighestPowerOf2 = utils::nextHighestPowerOf2(N); + +#pragma unroll + for (int i = 0; i < N - kNextHighestPowerOf2 / 2; ++i) { + K& ka = k[i]; + V& va = v[i]; + + K& kb = k[i + kNextHighestPowerOf2 / 2]; + V& vb = v[i + kNextHighestPowerOf2 / 2]; + + bool s = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb); + swap(s, ka, kb); + swap(s, va, vb); + } + + constexpr int kLowSize = N - kNextHighestPowerOf2 / 2; + constexpr int kHighSize = kNextHighestPowerOf2 / 2; + { + K newK[kLowSize]; + V newV[kLowSize]; + +#pragma unroll + for (int i = 0; i < kLowSize; ++i) { + newK[i] = k[i]; + newV[i] = v[i]; + } + + constexpr bool kLowIsPowerOf2 = + utils::isPowerOf2(N - kNextHighestPowerOf2 / 2); + // FIXME: compiler doesn't like this expression? compiler bug? +// constexpr bool kLowIsPowerOf2 = utils::isPowerOf2(kLowSize); + BitonicMergeStep::merge(newK, newV); + +#pragma unroll + for (int i = 0; i < kLowSize; ++i) { + k[i] = newK[i]; + v[i] = newV[i]; + } + } + + { + K newK[kHighSize]; + V newV[kHighSize]; + +#pragma unroll + for (int i = 0; i < kHighSize; ++i) { + newK[i] = k[i + kLowSize]; + newV[i] = v[i + kLowSize]; + } + + constexpr bool kHighIsPowerOf2 = + utils::isPowerOf2(kNextHighestPowerOf2 / 2); + // FIXME: compiler doesn't like this expression? compiler bug? +// constexpr bool kHighIsPowerOf2 = utils::isPowerOf2(kHighSize); + BitonicMergeStep::merge(newK, newV); + +#pragma unroll + for (int i = 0; i < kHighSize; ++i) { + k[i + kLowSize] = newK[i]; + v[i + kLowSize] = newV[i]; + } + } + } +}; + +// High recursion +template +struct BitonicMergeStep { + static inline __device__ void merge(K k[N], V v[N]) { + static_assert(!utils::isPowerOf2(N), "must be non-power-of-2"); + static_assert(N >= 3, "must be N >= 3"); + + constexpr int kNextHighestPowerOf2 = utils::nextHighestPowerOf2(N); + +#pragma unroll + for (int i = 0; i < N - kNextHighestPowerOf2 / 2; ++i) { + K& ka = k[i]; + V& va = v[i]; + + K& kb = k[i + kNextHighestPowerOf2 / 2]; + V& vb = v[i + kNextHighestPowerOf2 / 2]; + + bool s = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb); + swap(s, ka, kb); + swap(s, va, vb); + } + + constexpr int kLowSize = kNextHighestPowerOf2 / 2; + constexpr int kHighSize = N - kNextHighestPowerOf2 / 2; + { + K newK[kLowSize]; + V newV[kLowSize]; + +#pragma unroll + for (int i = 0; i < kLowSize; ++i) { + newK[i] = k[i]; + newV[i] = v[i]; + } + + constexpr bool kLowIsPowerOf2 = + utils::isPowerOf2(kNextHighestPowerOf2 / 2); + // FIXME: compiler doesn't like this expression? compiler bug? +// constexpr bool kLowIsPowerOf2 = utils::isPowerOf2(kLowSize); + BitonicMergeStep::merge(newK, newV); + +#pragma unroll + for (int i = 0; i < kLowSize; ++i) { + k[i] = newK[i]; + v[i] = newV[i]; + } + } + + { + K newK[kHighSize]; + V newV[kHighSize]; + +#pragma unroll + for (int i = 0; i < kHighSize; ++i) { + newK[i] = k[i + kLowSize]; + newV[i] = v[i + kLowSize]; + } + + constexpr bool kHighIsPowerOf2 = + utils::isPowerOf2(N - kNextHighestPowerOf2 / 2); + // FIXME: compiler doesn't like this expression? compiler bug? +// constexpr bool kHighIsPowerOf2 = utils::isPowerOf2(kHighSize); + BitonicMergeStep::merge(newK, newV); + +#pragma unroll + for (int i = 0; i < kHighSize; ++i) { + k[i + kLowSize] = newK[i]; + v[i + kLowSize] = newV[i]; + } + } + } +}; + +/// Merges two sets of registers across the warp of any size; +/// i.e., merges a sorted k/v list of size kWarpSize * N1 with a +/// sorted k/v list of size kWarpSize * N2, where N1 and N2 are any +/// value >= 1 +template +inline __device__ void warpMergeAnyRegisters(K k1[N1], V v1[N1], + K k2[N2], V v2[N2]) { + constexpr int kSmallestN = N1 < N2 ? N1 : N2; + +#pragma unroll + for (int i = 0; i < kSmallestN; ++i) { + K& ka = k1[N1 - 1 - i]; + V& va = v1[N1 - 1 - i]; + + K& kb = k2[i]; + V& vb = v2[i]; + + K otherKa; + V otherVa; + + if (FullMerge) { + // We need the other values + otherKa = shfl_xor(ka, kWarpSize - 1); + otherVa = shfl_xor(va, kWarpSize - 1); + } + + K otherKb = shfl_xor(kb, kWarpSize - 1); + V otherVb = shfl_xor(vb, kWarpSize - 1); + + // ka is always first in the list, so we needn't use our lane + // in this comparison + bool swapa = Dir ? Comp::gt(ka, otherKb) : Comp::lt(ka, otherKb); + assign(swapa, ka, otherKb); + assign(swapa, va, otherVb); + + // kb is always second in the list, so we needn't use our lane + // in this comparison + if (FullMerge) { + bool swapb = Dir ? Comp::lt(kb, otherKa) : Comp::gt(kb, otherKa); + assign(swapb, kb, otherKa); + assign(swapb, vb, otherVa); + + } else { + // We don't care about updating elements in the second list + } + } + + BitonicMergeStep::merge(k1, v1); + if (FullMerge) { + // Only if we care about N2 do we need to bother merging it fully + BitonicMergeStep::merge(k2, v2); + } +} + +// Recursive template that uses the above bitonic merge to perform a +// bitonic sort +template +struct BitonicSortStep { + static inline __device__ void sort(K k[N], V v[N]) { + static_assert(N > 1, "did not hit specialized case"); + + // Sort recursively + constexpr int kSizeA = N / 2; + constexpr int kSizeB = N - kSizeA; + + K aK[kSizeA]; + V aV[kSizeA]; + +#pragma unroll + for (int i = 0; i < kSizeA; ++i) { + aK[i] = k[i]; + aV[i] = v[i]; + } + + BitonicSortStep::sort(aK, aV); + + K bK[kSizeB]; + V bV[kSizeB]; + +#pragma unroll + for (int i = 0; i < kSizeB; ++i) { + bK[i] = k[i + kSizeA]; + bV[i] = v[i + kSizeA]; + } + + BitonicSortStep::sort(bK, bV); + + // Merge halves + warpMergeAnyRegisters(aK, aV, bK, bV); + +#pragma unroll + for (int i = 0; i < kSizeA; ++i) { + k[i] = aK[i]; + v[i] = aV[i]; + } + +#pragma unroll + for (int i = 0; i < kSizeB; ++i) { + k[i + kSizeA] = bK[i]; + v[i + kSizeA] = bV[i]; + } + } +}; + +// Single warp (N == 1) sorting specialization +template +struct BitonicSortStep { + static inline __device__ void sort(K k[1], V v[1]) { + // Update this code if this changes + // should go from 1 -> kWarpSize in multiples of 2 + static_assert(kWarpSize == 32, "unexpected warp size"); + + warpBitonicMergeLE16(k[0], v[0]); + warpBitonicMergeLE16(k[0], v[0]); + warpBitonicMergeLE16(k[0], v[0]); + warpBitonicMergeLE16(k[0], v[0]); + warpBitonicMergeLE16(k[0], v[0]); + } +}; + +/// Sort a list of kWarpSize * N elements in registers, where N is an +/// arbitrary >= 1 +template +inline __device__ void warpSortAnyRegisters(K k[N], V v[N]) { + BitonicSortStep::sort(k, v); +} + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/utils/NoTypeTensor.cuh b/core/src/index/thirdparty/faiss/gpu/utils/NoTypeTensor.cuh new file mode 100644 index 0000000000..fdbc879f35 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/utils/NoTypeTensor.cuh @@ -0,0 +1,123 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#pragma once + +#include +#include +#include + +namespace faiss { namespace gpu { + +template +class NoTypeTensor { + public: + NoTypeTensor() + : mem_(nullptr), + typeSize_(0) { + } + + template + NoTypeTensor(Tensor& t) + : mem_(t.data()), + typeSize_(sizeof(T)) { + for (int i = 0; i < Dim; ++i) { + size_[i] = t.getSize(i); + stride_[i] = t.getStride(i); + } + } + + NoTypeTensor(void* mem, int typeSize, std::initializer_list sizes) + : mem_(mem), + typeSize_(typeSize) { + + int i = 0; + for (auto s : sizes) { + size_[i++] = s; + } + + stride_[Dim - 1] = (IndexT) 1; + for (int j = Dim - 2; j >= 0; --j) { + stride_[j] = stride_[j + 1] * size_[j + 1]; + } + } + + NoTypeTensor(void* mem, int typeSize, int sizes[Dim]) + : mem_(mem), + typeSize_(typeSize) { + for (int i = 0; i < Dim; ++i) { + size_[i] = sizes[i]; + } + + stride_[Dim - 1] = (IndexT) 1; + for (int i = Dim - 2; i >= 0; --i) { + stride_[i] = stride_[i + 1] * sizes[i + 1]; + } + } + + NoTypeTensor(void* mem, int typeSize, + IndexT sizes[Dim], IndexT strides[Dim]) + : mem_(mem), + typeSize_(typeSize) { + for (int i = 0; i < Dim; ++i) { + size_[i] = sizes[i]; + stride_[i] = strides[i]; + } + } + + int getTypeSize() const { + return typeSize_; + } + + IndexT getSize(int dim) const { + FAISS_ASSERT(dim < Dim); + return size_[dim]; + } + + IndexT getStride(int dim) const { + FAISS_ASSERT(dim < Dim); + return stride_[dim]; + } + + template + Tensor toTensor() { + FAISS_ASSERT(sizeof(T) == typeSize_); + + return Tensor((T*) mem_, size_, stride_); + } + + NoTypeTensor narrowOutermost(IndexT start, + IndexT size) { + char* newPtr = (char*) mem_; + + if (start > 0) { + newPtr += typeSize_ * start * stride_[0]; + } + + IndexT newSize[Dim]; + for (int i = 0; i < Dim; ++i) { + if (i == 0) { + assert(start + size <= size_[0]); + newSize[i] = size; + } else { + newSize[i] = size_[i]; + } + } + + return NoTypeTensor( + newPtr, typeSize_, newSize, stride_); + } + + private: + void* mem_; + int typeSize_; + IndexT size_[Dim]; + IndexT stride_[Dim]; +}; + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/utils/Pair.cuh b/core/src/index/thirdparty/faiss/gpu/utils/Pair.cuh new file mode 100644 index 0000000000..0162c91a70 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/utils/Pair.cuh @@ -0,0 +1,69 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#pragma once + +#include +#include +#include + +namespace faiss { namespace gpu { + +/// A simple pair type for CUDA device usage +template +struct Pair { + constexpr __device__ inline Pair() { + } + + constexpr __device__ inline Pair(K key, V value) + : k(key), v(value) { + } + + __device__ inline bool + operator==(const Pair& rhs) const { + return Math::eq(k, rhs.k) && Math::eq(v, rhs.v); + } + + __device__ inline bool + operator!=(const Pair& rhs) const { + return !operator==(rhs); + } + + __device__ inline bool + operator<(const Pair& rhs) const { + return Math::lt(k, rhs.k) || + (Math::eq(k, rhs.k) && Math::lt(v, rhs.v)); + } + + __device__ inline bool + operator>(const Pair& rhs) const { + return Math::gt(k, rhs.k) || + (Math::eq(k, rhs.k) && Math::gt(v, rhs.v)); + } + + K k; + V v; +}; + +template +inline __device__ Pair shfl_up(const Pair& pair, + unsigned int delta, + int width = kWarpSize) { + return Pair(shfl_up(pair.k, delta, width), + shfl_up(pair.v, delta, width)); +} + +template +inline __device__ Pair shfl_xor(const Pair& pair, + int laneMask, + int width = kWarpSize) { + return Pair(shfl_xor(pair.k, laneMask, width), + shfl_xor(pair.v, laneMask, width)); +} + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/utils/PtxUtils.cuh b/core/src/index/thirdparty/faiss/gpu/utils/PtxUtils.cuh new file mode 100644 index 0000000000..d1fad3905f --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/utils/PtxUtils.cuh @@ -0,0 +1,76 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#pragma once + +#include + +namespace faiss { namespace gpu { + +__device__ __forceinline__ +unsigned int getBitfield(unsigned int val, int pos, int len) { + unsigned int ret; + asm("bfe.u32 %0, %1, %2, %3;" : "=r"(ret) : "r"(val), "r"(pos), "r"(len)); + return ret; +} + +__device__ __forceinline__ +unsigned long getBitfield(unsigned long val, int pos, int len) { + unsigned long ret; + asm("bfe.u64 %0, %1, %2, %3;" : "=l"(ret) : "l"(val), "r"(pos), "r"(len)); + return ret; +} + +__device__ __forceinline__ +unsigned int setBitfield(unsigned int val, + unsigned int toInsert, int pos, int len) { + unsigned int ret; + asm("bfi.b32 %0, %1, %2, %3, %4;" : + "=r"(ret) : "r"(toInsert), "r"(val), "r"(pos), "r"(len)); + return ret; +} + +__device__ __forceinline__ int getLaneId() { + int laneId; + asm("mov.u32 %0, %laneid;" : "=r"(laneId) ); + return laneId; +} + +__device__ __forceinline__ unsigned getLaneMaskLt() { + unsigned mask; + asm("mov.u32 %0, %%lanemask_lt;" : "=r"(mask)); + return mask; +} + +__device__ __forceinline__ unsigned getLaneMaskLe() { + unsigned mask; + asm("mov.u32 %0, %%lanemask_le;" : "=r"(mask)); + return mask; +} + +__device__ __forceinline__ unsigned getLaneMaskGt() { + unsigned mask; + asm("mov.u32 %0, %%lanemask_gt;" : "=r"(mask)); + return mask; +} + +__device__ __forceinline__ unsigned getLaneMaskGe() { + unsigned mask; + asm("mov.u32 %0, %%lanemask_ge;" : "=r"(mask)); + return mask; +} + +__device__ __forceinline__ void namedBarrierWait(int name, int numThreads) { + asm volatile("bar.sync %0, %1;" : : "r"(name), "r"(numThreads) : "memory"); +} + +__device__ __forceinline__ void namedBarrierArrived(int name, int numThreads) { + asm volatile("bar.arrive %0, %1;" : : "r"(name), "r"(numThreads) : "memory"); +} + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/utils/ReductionOperators.cuh b/core/src/index/thirdparty/faiss/gpu/utils/ReductionOperators.cuh new file mode 100644 index 0000000000..b810fc66ea --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/utils/ReductionOperators.cuh @@ -0,0 +1,73 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#pragma once + +#include +#include +#include +#include + +namespace faiss { namespace gpu { + +template +struct Sum { + __device__ inline T operator()(T a, T b) const { + return Math::add(a, b); + } + + inline __device__ T identity() const { + return Math::zero(); + } +}; + +template +struct Min { + __device__ inline T operator()(T a, T b) const { + return Math::lt(a, b) ? a : b; + } + + inline __device__ T identity() const { + return Limits::getMax(); + } +}; + +template +struct Max { + __device__ inline T operator()(T a, T b) const { + return Math::gt(a, b) ? a : b; + } + + inline __device__ T identity() const { + return Limits::getMin(); + } +}; + +/// Used for producing segmented prefix scans; the value of the Pair +/// denotes the start of a new segment for the scan +template +struct SegmentedReduce { + inline __device__ SegmentedReduce(const ReduceOp& o) + : op(o) { + } + + __device__ + inline Pair + operator()(const Pair& a, const Pair& b) const { + return Pair(b.v ? b.k : op(a.k, b.k), + a.v || b.v); + } + + inline __device__ Pair identity() const { + return Pair(op.identity(), false); + } + + ReduceOp op; +}; + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/utils/Reductions.cuh b/core/src/index/thirdparty/faiss/gpu/utils/Reductions.cuh new file mode 100644 index 0000000000..e99b518630 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/utils/Reductions.cuh @@ -0,0 +1,142 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace faiss { namespace gpu { + +template +__device__ inline T warpReduceAll(T val, Op op) { +#pragma unroll + for (int mask = ReduceWidth / 2; mask > 0; mask >>= 1) { + val = op(val, shfl_xor(val, mask)); + } + + return val; +} + +/// Sums a register value across all warp threads +template +__device__ inline T warpReduceAllSum(T val) { + return warpReduceAll, ReduceWidth>(val, Sum()); +} + +/// Performs a block-wide reduction +template +__device__ inline T blockReduceAll(T val, Op op, T* smem) { + int laneId = getLaneId(); + int warpId = threadIdx.x / kWarpSize; + + val = warpReduceAll(val, op); + if (laneId == 0) { + smem[warpId] = val; + } + __syncthreads(); + + if (warpId == 0) { + val = laneId < utils::divUp(blockDim.x, kWarpSize) ? smem[laneId] : + op.identity(); + val = warpReduceAll(val, op); + + if (BroadcastAll) { + __threadfence_block(); + + if (laneId == 0) { + smem[0] = val; + } + } + } + + if (BroadcastAll) { + __syncthreads(); + val = smem[0]; + } + + if (KillWARDependency) { + __syncthreads(); + } + + return val; +} + +/// Performs a block-wide reduction of multiple values simultaneously +template +__device__ inline void blockReduceAll(T val[Num], Op op, T* smem) { + int laneId = getLaneId(); + int warpId = threadIdx.x / kWarpSize; + +#pragma unroll + for (int i = 0; i < Num; ++i) { + val[i] = warpReduceAll(val[i], op); + } + + if (laneId == 0) { +#pragma unroll + for (int i = 0; i < Num; ++i) { + smem[warpId * Num + i] = val[i]; + } + } + + __syncthreads(); + + if (warpId == 0) { +#pragma unroll + for (int i = 0; i < Num; ++i) { + val[i] = + laneId < utils::divUp(blockDim.x, kWarpSize) ? smem[laneId * Num + i] : + op.identity(); + val[i] = warpReduceAll(val[i], op); + } + + if (BroadcastAll) { + __threadfence_block(); + + if (laneId == 0) { +#pragma unroll + for (int i = 0; i < Num; ++i) { + smem[i] = val[i]; + } + } + } + } + + if (BroadcastAll) { + __syncthreads(); +#pragma unroll + for (int i = 0; i < Num; ++i) { + val[i] = smem[i]; + } + } + + if (KillWARDependency) { + __syncthreads(); + } +} + + +/// Sums a register value across the entire block +template +__device__ inline T blockReduceAllSum(T val, T* smem) { + return blockReduceAll, BroadcastAll, KillWARDependency>( + val, Sum(), smem); +} + +template +__device__ inline void blockReduceAllSum(T vals[Num], T* smem) { + return blockReduceAll, BroadcastAll, KillWARDependency>( + vals, Sum(), smem); +} + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/utils/Select.cuh b/core/src/index/thirdparty/faiss/gpu/utils/Select.cuh new file mode 100644 index 0000000000..0dad487140 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/utils/Select.cuh @@ -0,0 +1,563 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace faiss { namespace gpu { + +// Specialization for block-wide monotonic merges producing a merge sort +// since what we really want is a constexpr loop expansion +template +struct FinalBlockMerge { +}; + +template +struct FinalBlockMerge<1, NumThreads, K, V, NumWarpQ, Dir, Comp> { + static inline __device__ void merge(K* sharedK, V* sharedV) { + // no merge required; single warp + } +}; + +template +struct FinalBlockMerge<2, NumThreads, K, V, NumWarpQ, Dir, Comp> { + static inline __device__ void merge(K* sharedK, V* sharedV) { + // Final merge doesn't need to fully merge the second list + blockMerge(sharedK, sharedV); + } +}; + +template +struct FinalBlockMerge<4, NumThreads, K, V, NumWarpQ, Dir, Comp> { + static inline __device__ void merge(K* sharedK, V* sharedV) { + blockMerge(sharedK, sharedV); + // Final merge doesn't need to fully merge the second list + blockMerge(sharedK, sharedV); + } +}; + +template +struct FinalBlockMerge<8, NumThreads, K, V, NumWarpQ, Dir, Comp> { + static inline __device__ void merge(K* sharedK, V* sharedV) { + blockMerge(sharedK, sharedV); + blockMerge(sharedK, sharedV); + // Final merge doesn't need to fully merge the second list + blockMerge(sharedK, sharedV); + } +}; + +// `Dir` true, produce largest values. +// `Dir` false, produce smallest values. +template +struct BlockSelect { + static constexpr int kNumWarps = ThreadsPerBlock / kWarpSize; + static constexpr int kTotalWarpSortSize = NumWarpQ; + + __device__ inline BlockSelect(K initKVal, + V initVVal, + K* smemK, + V* smemV, + int k) : + initK(initKVal), + initV(initVVal), + numVals(0), + warpKTop(initKVal), + sharedK(smemK), + sharedV(smemV), + kMinus1(k - 1) { + static_assert(utils::isPowerOf2(ThreadsPerBlock), + "threads must be a power-of-2"); + static_assert(utils::isPowerOf2(NumWarpQ), + "warp queue must be power-of-2"); + + // Fill the per-thread queue keys with the default value +#pragma unroll + for (int i = 0; i < NumThreadQ; ++i) { + threadK[i] = initK; + threadV[i] = initV; + } + + int laneId = getLaneId(); + int warpId = threadIdx.x / kWarpSize; + warpK = sharedK + warpId * kTotalWarpSortSize; + warpV = sharedV + warpId * kTotalWarpSortSize; + + // Fill warp queue (only the actual queue space is fine, not where + // we write the per-thread queues for merging) + for (int i = laneId; i < NumWarpQ; i += kWarpSize) { + warpK[i] = initK; + warpV[i] = initV; + } + + warpFence(); + } + + __device__ inline void addThreadQ(K k, V v) { + if (Dir ? Comp::gt(k, warpKTop) : Comp::lt(k, warpKTop)) { + threadK[numVals] = k; + threadV[numVals ++] = v; + } + } + + __device__ inline void checkThreadQ() { + bool needSort = (numVals == NumThreadQ); + +#if CUDA_VERSION >= 9000 + needSort = __any_sync(0xffffffff, needSort); +#else + needSort = __any(needSort); +#endif + + if (!needSort) { + // no lanes have triggered a sort + return; + } + + // This has a trailing warpFence + mergeWarpQ(); + + // Any top-k elements have been merged into the warp queue; we're + // free to reset the thread queues + numVals = 0; + +#pragma unroll + for (int i = 0; i < NumThreadQ; ++i) { + threadK[i] = initK; + threadV[i] = initV; + } + + // We have to beat at least this element + warpKTop = warpK[kMinus1]; + + warpFence(); + } + + /// This function handles sorting and merging together the + /// per-thread queues with the warp-wide queue, creating a sorted + /// list across both + __device__ inline void mergeWarpQ() { + int laneId = getLaneId(); + + // Sort all of the per-thread queues + warpSortAnyRegisters(threadK, threadV); + + constexpr int kNumWarpQRegisters = NumWarpQ / kWarpSize; + K warpKRegisters[kNumWarpQRegisters]; + V warpVRegisters[kNumWarpQRegisters]; + +#pragma unroll + for (int i = 0; i < kNumWarpQRegisters; ++i) { + warpKRegisters[i] = warpK[i * kWarpSize + laneId]; + warpVRegisters[i] = warpV[i * kWarpSize + laneId]; + } + + warpFence(); + + // The warp queue is already sorted, and now that we've sorted the + // per-thread queue, merge both sorted lists together, producing + // one sorted list + warpMergeAnyRegisters( + warpKRegisters, warpVRegisters, threadK, threadV); + + // Write back out the warp queue +#pragma unroll + for (int i = 0; i < kNumWarpQRegisters; ++i) { + warpK[i * kWarpSize + laneId] = warpKRegisters[i]; + warpV[i * kWarpSize + laneId] = warpVRegisters[i]; + } + + warpFence(); + } + + /// WARNING: all threads in a warp must participate in this. + /// Otherwise, you must call the constituent parts separately. + __device__ inline void add(K k, V v) { + addThreadQ(k, v); + checkThreadQ(); + } + + __device__ inline void reduce() { + // Have all warps dump and merge their queues; this will produce + // the final per-warp results + mergeWarpQ(); + + // block-wide dep; thus far, all warps have been completely + // independent + __syncthreads(); + + // All warp queues are contiguous in smem. + // Now, we have kNumWarps lists of NumWarpQ elements. + // This is a power of 2. + FinalBlockMerge:: + merge(sharedK, sharedV); + + // The block-wide merge has a trailing syncthreads + } + + // Default element key + const K initK; + + // Default element value + const V initV; + + // Number of valid elements in our thread queue + int numVals; + + // The k-th highest (Dir) or lowest (!Dir) element + K warpKTop; + + // Thread queue values + K threadK[NumThreadQ]; + V threadV[NumThreadQ]; + + // Queues for all warps + K* sharedK; + V* sharedV; + + // Our warp's queue (points into sharedK/sharedV) + // warpK[0] is highest (Dir) or lowest (!Dir) + K* warpK; + V* warpV; + + // This is a cached k-1 value + int kMinus1; +}; + +/// Specialization for k == 1 (NumWarpQ == 1) +template +struct BlockSelect { + static constexpr int kNumWarps = ThreadsPerBlock / kWarpSize; + + __device__ inline BlockSelect(K initK, V initV, K* smemK, V* smemV, int k) : + sharedK(smemK), + sharedV(smemV), + threadK(initK), + threadV(initV) { + } + + __device__ inline void addThreadQ(K k, V v) { + bool swap = Dir ? Comp::gt(k, threadK) : Comp::lt(k, threadK); + threadK = swap ? k : threadK; + threadV = swap ? v : threadV; + } + + __device__ inline void checkThreadQ() { + // We don't need to do anything here, since the warp doesn't + // cooperate until the end + } + + __device__ inline void add(K k, V v) { + addThreadQ(k, v); + } + + __device__ inline void reduce() { + // Reduce within the warp + Pair pair(threadK, threadV); + + if (Dir) { + pair = + warpReduceAll, Max>>(pair, Max>()); + } else { + pair = + warpReduceAll, Min>>(pair, Min>()); + } + + // Each warp writes out a single value + int laneId = getLaneId(); + int warpId = threadIdx.x / kWarpSize; + + if (laneId == 0) { + sharedK[warpId] = pair.k; + sharedV[warpId] = pair.v; + } + + __syncthreads(); + + // We typically use this for small blocks (<= 128), just having the first + // thread in the block perform the reduction across warps is + // faster + if (threadIdx.x == 0) { + threadK = sharedK[0]; + threadV = sharedV[0]; + +#pragma unroll + for (int i = 1; i < kNumWarps; ++i) { + K k = sharedK[i]; + V v = sharedV[i]; + + bool swap = Dir ? Comp::gt(k, threadK) : Comp::lt(k, threadK); + threadK = swap ? k : threadK; + threadV = swap ? v : threadV; + } + + // Hopefully a thread's smem reads/writes are ordered wrt + // itself, so no barrier needed :) + sharedK[0] = threadK; + sharedV[0] = threadV; + } + + // In case other threads wish to read this value + __syncthreads(); + } + + // threadK is lowest (Dir) or highest (!Dir) + K threadK; + V threadV; + + // Where we reduce in smem + K* sharedK; + V* sharedV; +}; + +// +// per-warp WarpSelect +// + +// `Dir` true, produce largest values. +// `Dir` false, produce smallest values. +template +struct WarpSelect { + static constexpr int kNumWarpQRegisters = NumWarpQ / kWarpSize; + + __device__ inline WarpSelect(K initKVal, V initVVal, int k) : + initK(initKVal), + initV(initVVal), + numVals(0), + warpKTop(initKVal), + kLane((k - 1) % kWarpSize) { + static_assert(utils::isPowerOf2(ThreadsPerBlock), + "threads must be a power-of-2"); + static_assert(utils::isPowerOf2(NumWarpQ), + "warp queue must be power-of-2"); + + // Fill the per-thread queue keys with the default value +#pragma unroll + for (int i = 0; i < NumThreadQ; ++i) { + threadK[i] = initK; + threadV[i] = initV; + } + + // Fill the warp queue with the default value +#pragma unroll + for (int i = 0; i < kNumWarpQRegisters; ++i) { + warpK[i] = initK; + warpV[i] = initV; + } + } + + __device__ inline void addThreadQ(K k, V v) { + if (Dir ? Comp::gt(k, warpKTop) : Comp::lt(k, warpKTop)) { + // Rotate right +#pragma unroll + for (int i = NumThreadQ - 1; i > 0; --i) { + threadK[i] = threadK[i - 1]; + threadV[i] = threadV[i - 1]; + } + + threadK[0] = k; + threadV[0] = v; + ++numVals; + } + } + + __device__ inline void checkThreadQ() { + bool needSort = (numVals == NumThreadQ); + +#if CUDA_VERSION >= 9000 + needSort = __any_sync(0xffffffff, needSort); +#else + needSort = __any(needSort); +#endif + + if (!needSort) { + // no lanes have triggered a sort + return; + } + + mergeWarpQ(); + + // Any top-k elements have been merged into the warp queue; we're + // free to reset the thread queues + numVals = 0; + +#pragma unroll + for (int i = 0; i < NumThreadQ; ++i) { + threadK[i] = initK; + threadV[i] = initV; + } + + // We have to beat at least this element + warpKTop = shfl(warpK[kNumWarpQRegisters - 1], kLane); + } + + /// This function handles sorting and merging together the + /// per-thread queues with the warp-wide queue, creating a sorted + /// list across both + __device__ inline void mergeWarpQ() { + // Sort all of the per-thread queues + warpSortAnyRegisters(threadK, threadV); + + // The warp queue is already sorted, and now that we've sorted the + // per-thread queue, merge both sorted lists together, producing + // one sorted list + warpMergeAnyRegisters( + warpK, warpV, threadK, threadV); + } + + /// WARNING: all threads in a warp must participate in this. + /// Otherwise, you must call the constituent parts separately. + __device__ inline void add(K k, V v) { + addThreadQ(k, v); + checkThreadQ(); + } + + __device__ inline void reduce() { + // Have all warps dump and merge their queues; this will produce + // the final per-warp results + mergeWarpQ(); + } + + /// Dump final k selected values for this warp out + __device__ inline void writeOut(K* outK, V* outV, int k) { + int laneId = getLaneId(); + +#pragma unroll + for (int i = 0; i < kNumWarpQRegisters; ++i) { + int idx = i * kWarpSize + laneId; + + if (idx < k) { + outK[idx] = warpK[i]; + outV[idx] = warpV[i]; + } + } + } + + // Default element key + const K initK; + + // Default element value + const V initV; + + // Number of valid elements in our thread queue + int numVals; + + // The k-th highest (Dir) or lowest (!Dir) element + K warpKTop; + + // Thread queue values + K threadK[NumThreadQ]; + V threadV[NumThreadQ]; + + // warpK[0] is highest (Dir) or lowest (!Dir) + K warpK[kNumWarpQRegisters]; + V warpV[kNumWarpQRegisters]; + + // This is what lane we should load an approximation (>=k) to the + // kth element from the last register in the warp queue (i.e., + // warpK[kNumWarpQRegisters - 1]). + int kLane; +}; + +/// Specialization for k == 1 (NumWarpQ == 1) +template +struct WarpSelect { + static constexpr int kNumWarps = ThreadsPerBlock / kWarpSize; + + __device__ inline WarpSelect(K initK, V initV, int k) : + threadK(initK), + threadV(initV) { + } + + __device__ inline void addThreadQ(K k, V v) { + bool swap = Dir ? Comp::gt(k, threadK) : Comp::lt(k, threadK); + threadK = swap ? k : threadK; + threadV = swap ? v : threadV; + } + + __device__ inline void checkThreadQ() { + // We don't need to do anything here, since the warp doesn't + // cooperate until the end + } + + __device__ inline void add(K k, V v) { + addThreadQ(k, v); + } + + __device__ inline void reduce() { + // Reduce within the warp + Pair pair(threadK, threadV); + + if (Dir) { + pair = + warpReduceAll, Max>>(pair, Max>()); + } else { + pair = + warpReduceAll, Min>>(pair, Min>()); + } + + threadK = pair.k; + threadV = pair.v; + } + + /// Dump final k selected values for this warp out + __device__ inline void writeOut(K* outK, V* outV, int k) { + if (getLaneId() == 0) { + *outK = threadK; + *outV = threadV; + } + } + + // threadK is lowest (Dir) or highest (!Dir) + K threadK; + V threadV; +}; + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/utils/StackDeviceMemory.cpp b/core/src/index/thirdparty/faiss/gpu/utils/StackDeviceMemory.cpp new file mode 100644 index 0000000000..18b8e04cff --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/utils/StackDeviceMemory.cpp @@ -0,0 +1,239 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#include +#include +#include +#include +#include +#include +#include + +namespace faiss { namespace gpu { + +StackDeviceMemory::Stack::Stack(int d, size_t sz) + : device_(d), + isOwner_(true), + start_(nullptr), + end_(nullptr), + size_(sz), + head_(nullptr), + mallocCurrent_(0), + highWaterMemoryUsed_(0), + highWaterMalloc_(0), + cudaMallocWarning_(true) { + DeviceScope s(device_); + + allocMemorySpace(MemorySpace::Device, &start_, size_); + + head_ = start_; + end_ = start_ + size_; +} + +StackDeviceMemory::Stack::Stack(int d, void* p, size_t sz, bool isOwner) + : device_(d), + isOwner_(isOwner), + start_((char*) p), + end_(((char*) p) + sz), + size_(sz), + head_((char*) p), + mallocCurrent_(0), + highWaterMemoryUsed_(0), + highWaterMalloc_(0), + cudaMallocWarning_(true) { +} + +StackDeviceMemory::Stack::~Stack() { + if (isOwner_) { + DeviceScope s(device_); + + freeMemorySpace(MemorySpace::Device, start_); + } +} + +size_t +StackDeviceMemory::Stack::getSizeAvailable() const { + return (end_ - head_); +} + +char* +StackDeviceMemory::Stack::getAlloc(size_t size, + cudaStream_t stream) { + if (size > (end_ - head_)) { + // Too large for our stack + DeviceScope s(device_); + + if (cudaMallocWarning_) { + // Print our requested size before we attempt the allocation + fprintf(stderr, "WARN: increase temp memory to avoid cudaMalloc, " + "or decrease query/add size (alloc %zu B, highwater %zu B)\n", + size, highWaterMalloc_); + } + + char* p = nullptr; + allocMemorySpace(MemorySpace::Device, &p, size); + + mallocCurrent_ += size; + highWaterMalloc_ = std::max(highWaterMalloc_, mallocCurrent_); + + return p; + } else { + // We can make the allocation out of our stack + // Find all the ranges that we overlap that may have been + // previously allocated; our allocation will be [head, endAlloc) + char* startAlloc = head_; + char* endAlloc = head_ + size; + + while (lastUsers_.size() > 0) { + auto& prevUser = lastUsers_.back(); + + // Because there is a previous user, we must overlap it + FAISS_ASSERT(prevUser.start_ <= endAlloc && prevUser.end_ >= startAlloc); + + if (stream != prevUser.stream_) { + // Synchronization required + // FIXME + FAISS_ASSERT(false); + } + + if (endAlloc < prevUser.end_) { + // Update the previous user info + prevUser.start_ = endAlloc; + + break; + } + + // If we're the exact size of the previous request, then we + // don't need to continue + bool done = (prevUser.end_ == endAlloc); + + lastUsers_.pop_back(); + + if (done) { + break; + } + } + + head_ = endAlloc; + FAISS_ASSERT(head_ <= end_); + + highWaterMemoryUsed_ = std::max(highWaterMemoryUsed_, + (size_t) (head_ - start_)); + return startAlloc; + } +} + +void +StackDeviceMemory::Stack::returnAlloc(char* p, + size_t size, + cudaStream_t stream) { + if (p < start_ || p >= end_) { + // This is not on our stack; it was a one-off allocation + DeviceScope s(device_); + + freeMemorySpace(MemorySpace::Device, p); + + FAISS_ASSERT(mallocCurrent_ >= size); + mallocCurrent_ -= size; + } else { + // This is on our stack + // Allocations should be freed in the reverse order they are made + FAISS_ASSERT(p + size == head_); + + head_ = p; + lastUsers_.push_back(Range(p, p + size, stream)); + } +} + +std::string +StackDeviceMemory::Stack::toString() const { + std::stringstream s; + + s << "SDM device " << device_ << ": Total memory " << size_ << " [" + << (void*) start_ << ", " << (void*) end_ << ")\n"; + s << " Available memory " << (size_t) (end_ - head_) + << " [" << (void*) head_ << ", " << (void*) end_ << ")\n"; + s << " High water temp alloc " << highWaterMemoryUsed_ << "\n"; + s << " High water cudaMalloc " << highWaterMalloc_ << "\n"; + + int i = lastUsers_.size(); + for (auto it = lastUsers_.rbegin(); it != lastUsers_.rend(); ++it) { + s << i-- << ": size " << (size_t) (it->end_ - it->start_) + << " stream " << it->stream_ + << " [" << (void*) it->start_ << ", " << (void*) it->end_ << ")\n"; + } + + return s.str(); +} + +size_t +StackDeviceMemory::Stack::getHighWaterCudaMalloc() const { + return highWaterMalloc_; +} + +StackDeviceMemory::StackDeviceMemory(int device, size_t allocPerDevice) + : device_(device), + stack_(device, allocPerDevice) { +} + +StackDeviceMemory::StackDeviceMemory(int device, + void* p, size_t size, bool isOwner) + : device_(device), + stack_(device, p, size, isOwner) { +} + +StackDeviceMemory::~StackDeviceMemory() { +} + +void +StackDeviceMemory::setCudaMallocWarning(bool b) { + stack_.cudaMallocWarning_ = b; +} + +int +StackDeviceMemory::getDevice() const { + return device_; +} + +DeviceMemoryReservation +StackDeviceMemory::getMemory(cudaStream_t stream, size_t size) { + // We guarantee 16 byte alignment for allocations, so bump up `size` + // to the next highest multiple of 16 + size = utils::roundUp(size, (size_t) 16); + + return DeviceMemoryReservation(this, + device_, + stack_.getAlloc(size, stream), + size, + stream); +} + +size_t +StackDeviceMemory::getSizeAvailable() const { + return stack_.getSizeAvailable(); +} + +std::string +StackDeviceMemory::toString() const { + return stack_.toString(); +} + +size_t +StackDeviceMemory::getHighWaterCudaMalloc() const { + return stack_.getHighWaterCudaMalloc(); +} + +void +StackDeviceMemory::returnAllocation(DeviceMemoryReservation& m) { + FAISS_ASSERT(m.get()); + FAISS_ASSERT(device_ == m.device()); + + stack_.returnAlloc((char*) m.get(), m.size(), m.stream()); +} + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/utils/StackDeviceMemory.h b/core/src/index/thirdparty/faiss/gpu/utils/StackDeviceMemory.h new file mode 100644 index 0000000000..f7c3ea14e4 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/utils/StackDeviceMemory.h @@ -0,0 +1,129 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#pragma once + +#include +#include +#include +#include + +namespace faiss { namespace gpu { + +/// Device memory manager that provides temporary memory allocations +/// out of a region of memory +class StackDeviceMemory : public DeviceMemory { + public: + /// Allocate a new region of memory that we manage + explicit StackDeviceMemory(int device, size_t allocPerDevice); + + /// Manage a region of memory for a particular device, with or + /// without ownership + StackDeviceMemory(int device, void* p, size_t size, bool isOwner); + + ~StackDeviceMemory() override; + + /// Enable or disable the warning about not having enough temporary memory + /// when cudaMalloc gets called + void setCudaMallocWarning(bool b); + + int getDevice() const override; + + DeviceMemoryReservation getMemory(cudaStream_t stream, + size_t size) override; + + size_t getSizeAvailable() const override; + std::string toString() const override; + size_t getHighWaterCudaMalloc() const override; + + protected: + void returnAllocation(DeviceMemoryReservation& m) override; + + protected: + /// Previous allocation ranges and the streams for which + /// synchronization is required + struct Range { + inline Range(char* s, char* e, cudaStream_t str) : + start_(s), end_(e), stream_(str) { + } + + // References a memory range [start, end) + char* start_; + char* end_; + cudaStream_t stream_; + }; + + struct Stack { + /// Constructor that allocates memory via cudaMalloc + Stack(int device, size_t size); + + /// Constructor that references a pre-allocated region of memory + Stack(int device, void* p, size_t size, bool isOwner); + ~Stack(); + + /// Returns how much size is available for an allocation without + /// calling cudaMalloc + size_t getSizeAvailable() const; + + /// Obtains an allocation; all allocations are guaranteed to be 16 + /// byte aligned + char* getAlloc(size_t size, cudaStream_t stream); + + /// Returns an allocation + void returnAlloc(char* p, size_t size, cudaStream_t stream); + + /// Returns the stack state + std::string toString() const; + + /// Returns the high-water mark of cudaMalloc activity + size_t getHighWaterCudaMalloc() const; + + /// Device this allocation is on + int device_; + + /// Do we own our region of memory? + bool isOwner_; + + /// Where our allocation begins and ends + /// [start_, end_) is valid + char* start_; + char* end_; + + /// Total size end_ - start_ + size_t size_; + + /// Stack head within [start, end) + char* head_; + + /// List of previous last users of allocations on our stack, for + /// possible synchronization purposes + std::list lastUsers_; + + /// How much cudaMalloc memory is currently outstanding? + size_t mallocCurrent_; + + /// What's the high water mark in terms of memory used from the + /// temporary buffer? + size_t highWaterMemoryUsed_; + + /// What's the high water mark in terms of memory allocated via + /// cudaMalloc? + size_t highWaterMalloc_; + + /// Whether or not a warning upon cudaMalloc is generated + bool cudaMallocWarning_; + }; + + /// Our device + int device_; + + /// Memory stack + Stack stack_; +}; + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/utils/StaticUtils.h b/core/src/index/thirdparty/faiss/gpu/utils/StaticUtils.h new file mode 100644 index 0000000000..f6e5505afb --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/utils/StaticUtils.h @@ -0,0 +1,83 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#pragma once + +#include + +namespace faiss { namespace gpu { namespace utils { + +template +constexpr __host__ __device__ auto divDown(U a, V b) -> decltype(a + b) { + return (a / b); +} + +template +constexpr __host__ __device__ auto divUp(U a, V b) -> decltype(a + b) { + return (a + b - 1) / b; +} + +template +constexpr __host__ __device__ auto roundDown(U a, V b) -> decltype(a + b) { + return divDown(a, b) * b; +} + +template +constexpr __host__ __device__ auto roundUp(U a, V b) -> decltype(a + b) { + return divUp(a, b) * b; +} + +template +constexpr __host__ __device__ T pow(T n, T power) { + return (power > 0 ? n * pow(n, power - 1) : 1); +} + +template +constexpr __host__ __device__ T pow2(T n) { + return pow(2, (T) n); +} + +static_assert(pow2(8) == 256, "pow2"); + +template +constexpr __host__ __device__ int log2(T n, int p = 0) { + return (n <= 1) ? p : log2(n / 2, p + 1); +} + +static_assert(log2(2) == 1, "log2"); +static_assert(log2(3) == 1, "log2"); +static_assert(log2(4) == 2, "log2"); + +template +constexpr __host__ __device__ bool isPowerOf2(T v) { + return (v && !(v & (v - 1))); +} + +static_assert(isPowerOf2(2048), "isPowerOf2"); +static_assert(!isPowerOf2(3333), "isPowerOf2"); + +template +constexpr __host__ __device__ T nextHighestPowerOf2(T v) { + return (isPowerOf2(v) ? (T) 2 * v : ((T) 1 << (log2(v) + 1))); +} + +static_assert(nextHighestPowerOf2(1) == 2, "nextHighestPowerOf2"); +static_assert(nextHighestPowerOf2(2) == 4, "nextHighestPowerOf2"); +static_assert(nextHighestPowerOf2(3) == 4, "nextHighestPowerOf2"); +static_assert(nextHighestPowerOf2(4) == 8, "nextHighestPowerOf2"); + +static_assert(nextHighestPowerOf2(15) == 16, "nextHighestPowerOf2"); +static_assert(nextHighestPowerOf2(16) == 32, "nextHighestPowerOf2"); +static_assert(nextHighestPowerOf2(17) == 32, "nextHighestPowerOf2"); + +static_assert(nextHighestPowerOf2(1536000000u) == 2147483648u, + "nextHighestPowerOf2"); +static_assert(nextHighestPowerOf2((size_t) 2147483648ULL) == + (size_t) 4294967296ULL, "nextHighestPowerOf2"); + +} } } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/utils/Tensor-inl.cuh b/core/src/index/thirdparty/faiss/gpu/utils/Tensor-inl.cuh new file mode 100644 index 0000000000..964fbfb940 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/utils/Tensor-inl.cuh @@ -0,0 +1,746 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#include +#include +#include + +namespace faiss { namespace gpu { + +template class PtrTraits> +__host__ __device__ +Tensor::Tensor() + : data_(nullptr) { + static_assert(Dim > 0, "must have > 0 dimensions"); + + for (int i = 0; i < Dim; ++i) { + size_[i] = 0; + stride_[i] = (IndexT) 1; + } +} + +template class PtrTraits> +__host__ __device__ +Tensor::Tensor( + Tensor& t) { + this->operator=(t); +} + +template class PtrTraits> +__host__ __device__ +Tensor::Tensor( + Tensor&& t) { + this->operator=(std::move(t)); +} + +template class PtrTraits> +__host__ __device__ +Tensor& +Tensor::operator=( + Tensor& t) { + data_ = t.data_; + for (int i = 0; i < Dim; ++i) { + size_[i] = t.size_[i]; + stride_[i] = t.stride_[i]; + } + + return *this; +} + +template class PtrTraits> +__host__ __device__ +Tensor& +Tensor::operator=( + Tensor&& t) { + data_ = t.data_; t.data_ = nullptr; + for (int i = 0; i < Dim; ++i) { + stride_[i] = t.stride_[i]; t.stride_[i] = 0; + size_[i] = t.size_[i]; t.size_[i] = 0; + } + + return *this; +} + +template class PtrTraits> +__host__ __device__ +Tensor:: +Tensor(DataPtrType data, const IndexT sizes[Dim]) + : data_(data) { + static_assert(Dim > 0, "must have > 0 dimensions"); + + for (int i = 0; i < Dim; ++i) { + size_[i] = sizes[i]; + } + + stride_[Dim - 1] = (IndexT) 1; + for (int i = Dim - 2; i >= 0; --i) { + stride_[i] = stride_[i + 1] * sizes[i + 1]; + } +} + +template class PtrTraits> +__host__ __device__ +Tensor:: +Tensor(DataPtrType data, std::initializer_list sizes) + : data_(data) { + GPU_FAISS_ASSERT(sizes.size() == Dim); + static_assert(Dim > 0, "must have > 0 dimensions"); + + int i = 0; + for (auto s : sizes) { + size_[i++] = s; + } + + stride_[Dim - 1] = (IndexT) 1; + for (int j = Dim - 2; j >= 0; --j) { + stride_[j] = stride_[j + 1] * size_[j + 1]; + } +} + + +template class PtrTraits> +__host__ __device__ +Tensor::Tensor( + DataPtrType data, const IndexT sizes[Dim], const IndexT strides[Dim]) + : data_(data) { + static_assert(Dim > 0, "must have > 0 dimensions"); + + for (int i = 0; i < Dim; ++i) { + size_[i] = sizes[i]; + stride_[i] = strides[i]; + } +} + +template class PtrTraits> +__host__ void +Tensor::copyFrom( + Tensor& t, + cudaStream_t stream) { + // The tensor must be fully contiguous + GPU_FAISS_ASSERT(this->isContiguous()); + + // Size must be the same (since dimensions are checked and + // continuity is assumed, we need only check total number of + // elements + GPU_FAISS_ASSERT(this->numElements() == t.numElements()); + + if (t.numElements() > 0) { + GPU_FAISS_ASSERT(this->data_); + GPU_FAISS_ASSERT(t.data()); + + int ourDev = getDeviceForAddress(this->data_); + int tDev = getDeviceForAddress(t.data()); + + if (tDev == -1) { + CUDA_VERIFY(cudaMemcpyAsync(this->data_, + t.data(), + this->getSizeInBytes(), + ourDev == -1 ? cudaMemcpyHostToHost : + cudaMemcpyHostToDevice, + stream)); + } else { + CUDA_VERIFY(cudaMemcpyAsync(this->data_, + t.data(), + this->getSizeInBytes(), + ourDev == -1 ? cudaMemcpyDeviceToHost : + cudaMemcpyDeviceToDevice, + stream)); + } + } +} + +template class PtrTraits> +__host__ void +Tensor::copyTo( + Tensor& t, + cudaStream_t stream) { + // The tensor must be fully contiguous + GPU_FAISS_ASSERT(this->isContiguous()); + + // Size must be the same (since dimensions are checked and + // continuity is assumed, we need only check total number of + // elements + GPU_FAISS_ASSERT(this->numElements() == t.numElements()); + + if (t.numElements() > 0) { + GPU_FAISS_ASSERT(this->data_); + GPU_FAISS_ASSERT(t.data()); + + int ourDev = getDeviceForAddress(this->data_); + int tDev = getDeviceForAddress(t.data()); + + if (tDev == -1) { + CUDA_VERIFY(cudaMemcpyAsync(t.data(), + this->data_, + this->getSizeInBytes(), + ourDev == -1 ? cudaMemcpyHostToHost : + cudaMemcpyDeviceToHost, + stream)); + } else { + CUDA_VERIFY(cudaMemcpyAsync(t.data(), + this->data_, + this->getSizeInBytes(), + ourDev == -1 ? cudaMemcpyHostToDevice : + cudaMemcpyDeviceToDevice, + stream)); + } + } +} + +template class PtrTraits> +template +__host__ __device__ bool +Tensor::isSame( + const Tensor& rhs) const { + if (Dim != OtherDim) { + return false; + } + + for (int i = 0; i < Dim; ++i) { + if (this->getSize(i) != rhs.getSize(i)) { + return false; + } + + if (this->getStride(i) != rhs.getStride(i)) { + return false; + } + } + + return true; +} + +template class PtrTraits> +template +__host__ __device__ bool +Tensor::isSameSize( + const Tensor& rhs) const { + if (Dim != OtherDim) { + return false; + } + + for (int i = 0; i < Dim; ++i) { + if (this->getSize(i) != rhs.getSize(i)) { + return false; + } + } + + return true; +} + +template class PtrTraits> +template +__host__ __device__ Tensor +Tensor::cast() { + static_assert(sizeof(U) == sizeof(T), "cast must be to same size object"); + + return Tensor( + reinterpret_cast(data_), size_, stride_); +} + +template class PtrTraits> +template +__host__ __device__ const Tensor +Tensor::cast() const { + static_assert(sizeof(U) == sizeof(T), "cast must be to same size object"); + + return Tensor( + reinterpret_cast(data_), size_, stride_); +} + +template class PtrTraits> +template +__host__ __device__ Tensor +Tensor::castResize() { + static_assert(sizeof(U) >= sizeof(T), "only handles greater sizes"); + constexpr int kMultiple = sizeof(U) / sizeof(T); + + GPU_FAISS_ASSERT(canCastResize()); + + IndexT newSize[Dim]; + IndexT newStride[Dim]; + + for (int i = 0; i < Dim - 1; ++i) { + newSize[i] = size_[i]; + newStride[i] = stride_[i] / kMultiple; + } + + newStride[Dim - 1] = 1; // this is the same as the old stride + newSize[Dim - 1] = size_[Dim - 1] / kMultiple; + + return Tensor( + reinterpret_cast(data_), newSize, newStride); +} + +template class PtrTraits> +template +__host__ __device__ const Tensor +Tensor::castResize() const { + return const_cast*>(this)-> + castResize(); +} + +template class PtrTraits> +template +__host__ __device__ bool +Tensor::canCastResize() const { + static_assert(sizeof(U) >= sizeof(T), "only handles greater sizes"); + constexpr int kMultiple = sizeof(U) / sizeof(T); + + // Ensure that the base pointer is sizeof(U) aligned + if (((uintptr_t) data_) % sizeof(U) != 0) { + return false; + } + + // Check all outer strides + for (int i = 0; i < Dim - 1; ++i) { + if (stride_[i] % kMultiple != 0) { + return false; + } + } + + // Check inner size + if (size_[Dim - 1] % kMultiple != 0) { + return false; + } + + if (stride_[Dim - 1] != 1) { + return false; + } + + return true; +} + +template class PtrTraits> +template +__host__ Tensor +Tensor::castIndexType() const { + if (sizeof(NewIndexT) < sizeof(IndexT)) { + GPU_FAISS_ASSERT(this->canUseIndexType()); + } + + NewIndexT newSize[Dim]; + NewIndexT newStride[Dim]; + for (int i = 0; i < Dim; ++i) { + newSize[i] = (NewIndexT) size_[i]; + newStride[i] = (NewIndexT) stride_[i]; + } + + return Tensor( + data_, newSize, newStride); +} + +template class PtrTraits> +template +__host__ bool +Tensor::canUseIndexType() const { + static_assert(sizeof(size_t) >= sizeof(IndexT), + "index size too large"); + static_assert(sizeof(size_t) >= sizeof(NewIndexT), + "new index size too large"); + + // Find maximum offset that can be calculated + // FIXME: maybe also consider offset in bytes? multiply by sizeof(T)? + size_t maxOffset = 0; + + for (int i = 0; i < Dim; ++i) { + size_t curMaxOffset = (size_t) size_[i] * (size_t) stride_[i]; + if (curMaxOffset > maxOffset) { + maxOffset = curMaxOffset; + } + } + + if (maxOffset > (size_t) std::numeric_limits::max()) { + return false; + } + + return true; +} + +template class PtrTraits> +__host__ __device__ size_t +Tensor::numElements() const { + size_t size = (size_t) getSize(0); + + for (int i = 1; i < Dim; ++i) { + size *= (size_t) getSize(i); + } + + return size; +} + +template class PtrTraits> +__host__ __device__ bool +Tensor::isContiguous() const { + long prevSize = 1; + + for (int i = Dim - 1; i >= 0; --i) { + if (getSize(i) != (IndexT) 1) { + if (getStride(i) == prevSize) { + prevSize *= getSize(i); + } else { + return false; + } + } + } + + return true; +} + +template class PtrTraits> +__host__ __device__ bool +Tensor::isConsistentlySized(int i) const { + if (i == 0 && getStride(i) > 0 && getSize(i) > 0) { + return true; + } else if ((i > 0) && (i < Dim) && (getStride(i) > 0) && + ((getStride(i - 1) / getStride(i)) >= getSize(i))) { + return true; + } + + return false; +} + +template class PtrTraits> +__host__ __device__ bool +Tensor::isConsistentlySized() const { + for (int i = 0; i < Dim; ++i) { + if (!isConsistentlySized(i)) { + return false; + } + } + + return true; +} + +template class PtrTraits> +__host__ __device__ bool +Tensor::isContiguousDim(int i) const { + return (i == Dim - 1) || // just in case + ((i < Dim - 1) && + ((getStride(i) / getStride(i + 1)) == getSize(i + 1))); +} + +template class PtrTraits> +__host__ __device__ Tensor +Tensor::transpose(int dim1, + int dim2) const { + GPU_FAISS_ASSERT(dim1 >= 0 && dim1 < Dim); + GPU_FAISS_ASSERT(dim1 >= 0 && dim2 < Dim); + + // If a tensor is innermost contiguous, one cannot transpose the innermost + // dimension + if (InnerContig) { + GPU_FAISS_ASSERT(dim1 != Dim - 1 && dim2 != Dim - 1); + } + + IndexT newSize[Dim]; + IndexT newStride[Dim]; + + for (int i = 0; i < Dim; ++i) { + newSize[i] = size_[i]; + newStride[i] = stride_[i]; + } + + IndexT tmp = newSize[dim1]; + newSize[dim1] = newSize[dim2]; + newSize[dim2] = tmp; + + tmp = newStride[dim1]; + newStride[dim1] = newStride[dim2]; + newStride[dim2] = tmp; + + return Tensor(data_, newSize, newStride); +} + +template class PtrTraits> +__host__ __device__ Tensor +Tensor::transposeInnermost( + int dim1) const { + GPU_FAISS_ASSERT(dim1 >= 0 && dim1 < Dim); + + // We are exchanging with the innermost dimension + int dim2 = 1; + + IndexT newSize[Dim]; + IndexT newStride[Dim]; + + for (int i = 0; i < Dim; ++i) { + newSize[i] = size_[i]; + newStride[i] = stride_[i]; + } + + IndexT tmp = newSize[dim1]; + newSize[dim1] = newSize[dim2]; + newSize[dim2] = tmp; + + tmp = newStride[dim1]; + newStride[dim1] = newStride[dim2]; + newStride[dim2] = tmp; + + return Tensor(data_, newSize, newStride); +} + +template class PtrTraits> +template +__host__ __device__ Tensor +Tensor::upcastOuter() { + // Can only create tensors of greater dimension + static_assert(NewDim > Dim, "Can only upcast to greater dim"); + + IndexT newSize[NewDim]; + IndexT newStride[NewDim]; + + int shift = NewDim - Dim; + + for (int i = 0; i < NewDim; ++i) { + if (i < shift) { + // These are the extended dimensions + newSize[i] = (IndexT) 1; + newStride[i] = size_[0] * stride_[0]; + } else { + // Shift the remaining dimensions + newSize[i] = size_[i - shift]; + newStride[i] = stride_[i - shift]; + } + } + + return Tensor( + data_, newSize, newStride); +} + +template class PtrTraits> +template +__host__ __device__ Tensor +Tensor::upcastInner() { + // Can only create tensors of greater dimension + static_assert(NewDim > Dim, "Can only upcast to greater dim"); + + IndexT newSize[NewDim]; + IndexT newStride[NewDim]; + + for (int i = 0; i < NewDim; ++i) { + if (i < Dim) { + // Existing dimensions get copied over + newSize[i] = size_[i]; + newStride[i] = stride_[i]; + } else { + // Extended dimensions + newSize[i] = (IndexT) 1; + newStride[i] = (IndexT) 1; + } + } + + return Tensor( + data_, newSize, newStride); +} + +template class PtrTraits> +template +__host__ __device__ Tensor +Tensor::downcastOuter() { + // Can only create tensors of lesser dimension + static_assert(NewDim < Dim, "Can only downcast to lesser dim"); + + // We can't downcast non-contiguous tensors, since it leaves + // garbage data in the tensor. The tensor needs to be contiguous + // in all of the dimensions we are collapsing (no padding in + // them). + for (int i = 0; i < Dim - NewDim; ++i) { + bool cont = isContiguousDim(i); + GPU_FAISS_ASSERT(cont); + } + + IndexT newSize[NewDim]; + IndexT newStride[NewDim]; + + int ignoredDims = Dim - NewDim; + IndexT collapsedSize = 1; + + for (int i = 0; i < Dim; ++i) { + if (i < ignoredDims) { + // Collapse these dimensions + collapsedSize *= getSize(i); + } else { + // Non-collapsed dimensions + if (i == ignoredDims) { + // This is the first non-collapsed dimension + newSize[i - ignoredDims] = collapsedSize * getSize(i); + } else { + // Subsequent non-collapsed dimensions + newSize[i - ignoredDims] = getSize(i); + } + + newStride[i - ignoredDims] = getStride(i); + } + } + + return Tensor( + data_, newSize, newStride); +} + +template class PtrTraits> +template +__host__ __device__ Tensor +Tensor::downcastInner() { + // Can only create tensors of lesser dimension + static_assert(NewDim < Dim, "Can only downcast to lesser dim"); + + // We can't downcast non-contiguous tensors, since it leaves + // garbage data in the tensor. The tensor needs to be contiguous + // in all of the dimensions we are collapsing (no padding in + // them). + for (int i = NewDim; i < Dim; ++i) { + GPU_FAISS_ASSERT(isContiguousDim(i)); + } + + IndexT newSize[NewDim]; + IndexT newStride[NewDim]; + + IndexT collapsedSize = 1; + + for (int i = Dim - 1; i >= 0; --i) { + if (i >= NewDim) { + // Collapse these dimensions + collapsedSize *= getSize(i); + } else { + // Non-collapsed dimensions + if (i == NewDim - 1) { + // This is the first non-collapsed dimension + newSize[i] = collapsedSize * getSize(i); + newStride[i] = getStride(Dim - 1); + } else { + // Subsequent non-collapsed dimensions + newSize[i] = getSize(i); + newStride[i] = getStride(i); + } + } + } + + return Tensor( + data_, newSize, newStride); +} + +template class PtrTraits> +template +__host__ __device__ Tensor +Tensor::view(DataPtrType at) { + static_assert(SubDim >= 1 && SubDim < Dim, + "can only create view of lesser dim"); + + IndexT viewSizes[SubDim]; + IndexT viewStrides[SubDim]; + + for (int i = 0; i < SubDim; ++i) { + viewSizes[i] = size_[Dim - SubDim + i]; + viewStrides[i] = stride_[Dim - SubDim + i]; + } + + return Tensor( + at, viewSizes, viewStrides); +} + +template class PtrTraits> +template +__host__ __device__ Tensor +Tensor::view() { + return view(data_); +} + +template class PtrTraits> +__host__ __device__ Tensor +Tensor::narrowOutermost(IndexT start, + IndexT size) { + return this->narrow(0, start, size); +} + +template class PtrTraits> +__host__ __device__ Tensor +Tensor::narrow(int dim, + IndexT start, + IndexT size) { + DataPtrType newData = data_; + + GPU_FAISS_ASSERT(start >= 0 && + start < size_[dim] && + (start + size) <= size_[dim]); + + if (start > 0) { + newData += (size_t) start * stride_[dim]; + } + + IndexT newSize[Dim]; + for (int i = 0; i < Dim; ++i) { + if (i == dim) { + GPU_FAISS_ASSERT(start + size <= size_[dim]); + newSize[i] = size; + } else { + newSize[i] = size_[i]; + } + } + + // If we were innermost contiguous before, we are still innermost contiguous + return Tensor(newData, newSize, stride_); +} + +template class PtrTraits> +template +__host__ __device__ Tensor +Tensor::view( + std::initializer_list sizes) { + GPU_FAISS_ASSERT(this->isContiguous()); + + GPU_FAISS_ASSERT(sizes.size() == NewDim); + + // The total size of the new view must be the same as the total size + // of the old view + size_t curSize = numElements(); + size_t newSize = 1; + + for (auto s : sizes) { + newSize *= s; + } + + GPU_FAISS_ASSERT(curSize == newSize); + return Tensor(data(), sizes); +} + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/utils/Tensor.cuh b/core/src/index/thirdparty/faiss/gpu/utils/Tensor.cuh new file mode 100644 index 0000000000..bb3d956d6b --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/utils/Tensor.cuh @@ -0,0 +1,656 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#pragma once + +#include +#include +#include +#include + +/// Multi-dimensional array class for CUDA device and host usage. +/// Originally from Facebook's fbcunn, since added to the Torch GPU +/// library cutorch as well. + +namespace faiss { namespace gpu { + +/// Our tensor type +template class PtrTraits> +class Tensor; + +/// Type of a subspace of a tensor +namespace detail { +template class PtrTraits> +class SubTensor; +} + +namespace traits { + +template +struct RestrictPtrTraits { + typedef T* __restrict__ PtrType; +}; + +template +struct DefaultPtrTraits { + typedef T* PtrType; +}; + +} + +/** + Templated multi-dimensional array that supports strided access of + elements. Main access is through `operator[]`; e.g., + `tensor[x][y][z]`. + + - `T` is the contained type (e.g., `float`) + - `Dim` is the tensor rank + - If `InnerContig` is true, then the tensor is assumed to be innermost + - contiguous, and only operations that make sense on contiguous + - arrays are allowed (e.g., no transpose). Strides are still + - calculated, but innermost stride is assumed to be 1. + - `IndexT` is the integer type used for size/stride arrays, and for + - all indexing math. Default is `int`, but for large tensors, `long` + - can be used instead. + - `PtrTraits` are traits applied to our data pointer (T*). By default, + - this is just T*, but RestrictPtrTraits can be used to apply T* + - __restrict__ for alias-free analysis. +*/ +template class PtrTraits = traits::DefaultPtrTraits> +class Tensor { + public: + enum { NumDim = Dim }; + typedef T DataType; + typedef IndexT IndexType; + enum { IsInnerContig = InnerContig }; + typedef typename PtrTraits::PtrType DataPtrType; + typedef Tensor TensorType; + + /// Default constructor + __host__ __device__ Tensor(); + + /// Copy constructor + __host__ __device__ Tensor(Tensor& t); + + /// Move constructor + __host__ __device__ Tensor(Tensor&& t); + + /// Assignment + __host__ __device__ Tensor& + operator=(Tensor& t); + + /// Move assignment + __host__ __device__ Tensor& + operator=(Tensor&& t); + + /// Constructor that calculates strides with no padding + __host__ __device__ Tensor(DataPtrType data, + const IndexT sizes[Dim]); + __host__ __device__ Tensor(DataPtrType data, + std::initializer_list sizes); + + /// Constructor that takes arbitrary size/stride arrays. + /// Errors if you attempt to pass non-contiguous strides to a + /// contiguous tensor. + __host__ __device__ Tensor(DataPtrType data, + const IndexT sizes[Dim], + const IndexT strides[Dim]); + + /// Copies a tensor into ourselves; sizes must match + __host__ void copyFrom(Tensor& t, + cudaStream_t stream); + + /// Copies ourselves into a tensor; sizes must match + __host__ void copyTo(Tensor& t, + cudaStream_t stream); + + /// Returns true if the two tensors are of the same dimensionality, + /// size and stride. + template + __host__ __device__ bool + isSame(const Tensor& rhs) const; + + /// Returns true if the two tensors are of the same dimensionality and size + template + __host__ __device__ bool + isSameSize(const Tensor& rhs) const; + + /// Cast to a tensor of a different type of the same size and + /// stride. U and our type T must be of the same size + template + __host__ __device__ Tensor cast(); + + /// Const version of `cast` + template + __host__ __device__ + const Tensor cast() const; + + /// Cast to a tensor of a different type which is potentially a + /// different size than our type T. Tensor must be aligned and the + /// innermost dimension must be a size that is a multiple of + /// sizeof(U) / sizeof(T), and the stride of the innermost dimension + /// must be contiguous. The stride of all outer dimensions must be a + /// multiple of sizeof(U) / sizeof(T) as well. + template + __host__ __device__ Tensor castResize(); + + /// Const version of `castResize` + template + __host__ __device__ const Tensor + castResize() const; + + /// Returns true if we can castResize() this tensor to the new type + template + __host__ __device__ bool canCastResize() const; + + /// Attempts to cast this tensor to a tensor of a different IndexT. + /// Fails if size or stride entries are not representable in the new + /// IndexT. + template + __host__ Tensor + castIndexType() const; + + /// Returns true if we can use this indexing type to access all elements + /// index type + template + __host__ bool canUseIndexType() const; + + /// Returns a raw pointer to the start of our data. + __host__ __device__ inline DataPtrType data() { + return data_; + } + + /// Returns a raw pointer to the end of our data, assuming + /// continuity + __host__ __device__ inline DataPtrType end() { + return data() + numElements(); + } + + /// Returns a raw pointer to the start of our data (const). + __host__ __device__ inline + const DataPtrType data() const { + return data_; + } + + /// Returns a raw pointer to the end of our data, assuming + /// continuity (const) + __host__ __device__ inline DataPtrType end() const { + return data() + numElements(); + } + + /// Cast to a different datatype + template + __host__ __device__ inline + typename PtrTraits::PtrType dataAs() { + return reinterpret_cast::PtrType>(data_); + } + + /// Cast to a different datatype + template + __host__ __device__ inline + const typename PtrTraits::PtrType dataAs() const { + return reinterpret_cast::PtrType>(data_); + } + + /// Returns a read/write view of a portion of our tensor. + __host__ __device__ inline + detail::SubTensor + operator[](IndexT); + + /// Returns a read/write view of a portion of our tensor (const). + __host__ __device__ inline + const detail::SubTensor + operator[](IndexT) const; + + /// Returns the size of a given dimension, `[0, Dim - 1]`. No bounds + /// checking. + __host__ __device__ inline IndexT getSize(int i) const { + return size_[i]; + } + + /// Returns the stride of a given dimension, `[0, Dim - 1]`. No bounds + /// checking. + __host__ __device__ inline IndexT getStride(int i) const { + return stride_[i]; + } + + /// Returns the total number of elements contained within our data + /// (product of `getSize(i)`) + __host__ __device__ size_t numElements() const; + + /// If we are contiguous, returns the total size in bytes of our + /// data + __host__ __device__ size_t getSizeInBytes() const { + return numElements() * sizeof(T); + } + + /// Returns the size array. + __host__ __device__ inline const IndexT* sizes() const { + return size_; + } + + /// Returns the stride array. + __host__ __device__ inline const IndexT* strides() const { + return stride_; + } + + /// Returns true if there is no padding within the tensor and no + /// re-ordering of the dimensions. + /// ~~~ + /// (stride(i) == size(i + 1) * stride(i + 1)) && stride(dim - 1) == 0 + /// ~~~ + __host__ __device__ bool isContiguous() const; + + /// Returns whether a given dimension has only increasing stride + /// from the previous dimension. A tensor that was permuted by + /// exchanging size and stride only will fail this check. + /// If `i == 0` just check `size > 0`. Returns `false` if `stride` is `<= 0`. + __host__ __device__ bool isConsistentlySized(int i) const; + + // Returns whether at each dimension `stride <= size`. + // If this is not the case then iterating once over the size space will + // touch the same memory locations multiple times. + __host__ __device__ bool isConsistentlySized() const; + + /// Returns true if the given dimension index has no padding + __host__ __device__ bool isContiguousDim(int i) const; + + /// Returns a tensor of the same dimension after transposing the two + /// dimensions given. Does not actually move elements; transposition + /// is made by permuting the size/stride arrays. + /// If the dimensions are not valid, asserts. + __host__ __device__ Tensor + transpose(int dim1, int dim2) const; + + /// Transpose a tensor, exchanging a non-innermost dimension with the + /// innermost dimension, returning a no longer innermost contiguous tensor + __host__ __device__ Tensor + transposeInnermost(int dim1) const; + + /// Upcast a tensor of dimension `D` to some tensor of dimension + /// D' > D by padding the leading dimensions by 1 + /// e.g., upcasting a 2-d tensor `[2][3]` to a 4-d tensor `[1][1][2][3]` + template + __host__ __device__ Tensor + upcastOuter(); + + /// Upcast a tensor of dimension `D` to some tensor of dimension + /// D' > D by padding the lowest/most varying dimensions by 1 + /// e.g., upcasting a 2-d tensor `[2][3]` to a 4-d tensor `[2][3][1][1]` + template + __host__ __device__ Tensor + upcastInner(); + + /// Downcast a tensor of dimension `D` to some tensor of dimension + /// D' < D by collapsing the leading dimensions. asserts if there is + /// padding on the leading dimensions. + template + __host__ __device__ + Tensor downcastOuter(); + + /// Downcast a tensor of dimension `D` to some tensor of dimension + /// D' < D by collapsing the leading dimensions. asserts if there is + /// padding on the leading dimensions. + template + __host__ __device__ + Tensor downcastInner(); + + /// Returns a tensor that is a view of the `SubDim`-dimensional slice + /// of this tensor, starting at `at`. + template + __host__ __device__ Tensor + view(DataPtrType at); + + /// Returns a tensor that is a view of the `SubDim`-dimensional slice + /// of this tensor, starting where our data begins + template + __host__ __device__ Tensor + view(); + + /// Returns a tensor of the same dimension that is a view of the + /// original tensor with the specified dimension restricted to the + /// elements in the range [start, start + size) + __host__ __device__ Tensor + narrowOutermost(IndexT start, IndexT size); + + /// Returns a tensor of the same dimension that is a view of the + /// original tensor with the specified dimension restricted to the + /// elements in the range [start, start + size). + /// Can occur in an arbitrary dimension + __host__ __device__ Tensor + narrow(int dim, IndexT start, IndexT size); + + /// Returns a view of the given tensor expressed as a tensor of a + /// different number of dimensions. + /// Only works if we are contiguous. + template + __host__ __device__ Tensor + view(std::initializer_list sizes); + + protected: + /// Raw pointer to where the tensor data begins + DataPtrType data_; + + /// Array of strides (in sizeof(T) terms) per each dimension + IndexT stride_[Dim]; + + /// Size per each dimension + IndexT size_[Dim]; +}; + +// Utilities for checking a collection of tensors +namespace detail { + +template +bool canUseIndexType() { + return true; +} + +template +bool canUseIndexType(const T& arg, const U&... args) { + return arg.template canUseIndexType() && + canUseIndexType(args...); +} + +} // namespace detail + +template +bool canUseIndexType(const T&... args) { + return detail::canUseIndexType(args...); +} + +namespace detail { + +/// Specialization for a view of a single value (0-dimensional) +template class PtrTraits> +class SubTensor { + public: + __host__ __device__ SubTensor + operator=(typename TensorType::DataType val) { + *data_ = val; + return *this; + } + + // operator T& + __host__ __device__ operator typename TensorType::DataType&() { + return *data_; + } + + // const operator T& returning const T& + __host__ __device__ operator const typename TensorType::DataType&() const { + return *data_; + } + + // operator& returning T* + __host__ __device__ typename TensorType::DataType* operator&() { + return data_; + } + + // const operator& returning const T* + __host__ __device__ const typename TensorType::DataType* operator&() const { + return data_; + } + + /// Returns a raw accessor to our slice. + __host__ __device__ inline typename TensorType::DataPtrType data() { + return data_; + } + + /// Returns a raw accessor to our slice (const). + __host__ __device__ inline + const typename TensorType::DataPtrType data() const { + return data_; + } + + /// Cast to a different datatype. + template + __host__ __device__ T& as() { + return *dataAs(); + } + + /// Cast to a different datatype (const). + template + __host__ __device__ const T& as() const { + return *dataAs(); + } + + /// Cast to a different datatype + template + __host__ __device__ inline + typename PtrTraits::PtrType dataAs() { + return reinterpret_cast::PtrType>(data_); + } + + /// Cast to a different datatype (const) + template + __host__ __device__ inline + typename PtrTraits::PtrType dataAs() const { + return reinterpret_cast::PtrType>(data_); + } + + /// Use the texture cache for reads + __device__ inline typename TensorType::DataType ldg() const { +#if __CUDA_ARCH__ >= 350 + return __ldg(data_); +#else + return *data_; +#endif + } + + /// Use the texture cache for reads; cast as a particular type + template + __device__ inline T ldgAs() const { +#if __CUDA_ARCH__ >= 350 + return __ldg(dataAs()); +#else + return as(); +#endif + } + + protected: + /// One dimension greater can create us + friend class SubTensor; + + /// Our parent tensor can create us + friend class Tensor; + + __host__ __device__ inline SubTensor( + TensorType& t, + typename TensorType::DataPtrType data) + : tensor_(t), + data_(data) { + } + + /// The tensor we're referencing + TensorType& tensor_; + + /// Where our value is located + typename TensorType::DataPtrType const data_; +}; + +/// A `SubDim`-rank slice of a parent Tensor +template class PtrTraits> +class SubTensor { + public: + /// Returns a view of the data located at our offset (the dimension + /// `SubDim` - 1 tensor). + __host__ __device__ inline + SubTensor + operator[](typename TensorType::IndexType index) { + if (TensorType::IsInnerContig && SubDim == 1) { + // Innermost dimension is stride 1 for contiguous arrays + return SubTensor( + tensor_, data_ + index); + } else { + return SubTensor( + tensor_, + data_ + index * tensor_.getStride(TensorType::NumDim - SubDim)); + } + } + + /// Returns a view of the data located at our offset (the dimension + /// `SubDim` - 1 tensor) (const). + __host__ __device__ inline + const SubTensor + operator[](typename TensorType::IndexType index) const { + if (TensorType::IsInnerContig && SubDim == 1) { + // Innermost dimension is stride 1 for contiguous arrays + return SubTensor( + tensor_, data_ + index); + } else { + return SubTensor( + tensor_, + data_ + index * tensor_.getStride(TensorType::NumDim - SubDim)); + } + } + + // operator& returning T* + __host__ __device__ typename TensorType::DataType* operator&() { + return data_; + } + + // const operator& returning const T* + __host__ __device__ const typename TensorType::DataType* operator&() const { + return data_; + } + + /// Returns a raw accessor to our slice. + __host__ __device__ inline typename TensorType::DataPtrType data() { + return data_; + } + + /// Returns a raw accessor to our slice (const). + __host__ __device__ inline + const typename TensorType::DataPtrType data() const { + return data_; + } + + /// Cast to a different datatype. + template + __host__ __device__ T& as() { + return *dataAs(); + } + + /// Cast to a different datatype (const). + template + __host__ __device__ const T& as() const { + return *dataAs(); + } + + /// Cast to a different datatype + template + __host__ __device__ inline + typename PtrTraits::PtrType dataAs() { + return reinterpret_cast::PtrType>(data_); + } + + /// Cast to a different datatype (const) + template + __host__ __device__ inline + typename PtrTraits::PtrType dataAs() const { + return reinterpret_cast::PtrType>(data_); + } + + /// Use the texture cache for reads + __device__ inline typename TensorType::DataType ldg() const { +#if __CUDA_ARCH__ >= 350 + return __ldg(data_); +#else + return *data_; +#endif + } + + /// Use the texture cache for reads; cast as a particular type + template + __device__ inline T ldgAs() const { +#if __CUDA_ARCH__ >= 350 + return __ldg(dataAs()); +#else + return as(); +#endif + } + + /// Returns a tensor that is a view of the SubDim-dimensional slice + /// of this tensor, starting where our data begins + Tensor view() { + return tensor_.template view(data_); + } + + protected: + /// One dimension greater can create us + friend class SubTensor; + + /// Our parent tensor can create us + friend class + Tensor; + + __host__ __device__ inline SubTensor( + TensorType& t, + typename TensorType::DataPtrType data) + : tensor_(t), + data_(data) { + } + + /// The tensor we're referencing + TensorType& tensor_; + + /// The start of our sub-region + typename TensorType::DataPtrType const data_; +}; + +} // namespace detail + +template class PtrTraits> +__host__ __device__ inline +detail::SubTensor, + Dim - 1, PtrTraits> + Tensor::operator[](IndexT index) { + return detail::SubTensor( + detail::SubTensor( + *this, data_)[index]); +} + +template class PtrTraits> +__host__ __device__ inline +const detail::SubTensor, + Dim - 1, PtrTraits> + Tensor::operator[](IndexT index) const { + return detail::SubTensor( + detail::SubTensor( + const_cast(*this), data_)[index]); +} + +} } // namespace + +#include diff --git a/core/src/index/thirdparty/faiss/gpu/utils/ThrustAllocator.cuh b/core/src/index/thirdparty/faiss/gpu/utils/ThrustAllocator.cuh new file mode 100644 index 0000000000..4ca0415bfa --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/utils/ThrustAllocator.cuh @@ -0,0 +1,69 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#pragma once + +#include +#include +#include + +namespace faiss { namespace gpu { + +/// Allocator for Thrust that comes out of a specified memory space +class GpuResourcesThrustAllocator { + public: + typedef char value_type; + + GpuResourcesThrustAllocator(void* mem, size_t size) + : start_((char*) mem), + cur_((char*) mem), + end_((char*) mem + size) { + } + + ~GpuResourcesThrustAllocator() { + // In the case of an exception being thrown, we may not have called + // deallocate on all of our sub-allocations. Free them here + for (auto p : mallocAllocs_) { + freeMemorySpace(MemorySpace::Device, p); + } + } + + char* allocate(std::ptrdiff_t size) { + if (size <= (end_ - cur_)) { + char* p = cur_; + cur_ += size; + FAISS_ASSERT(cur_ <= end_); + + return p; + } else { + char* p = nullptr; + allocMemorySpace(MemorySpace::Device, &p, size); + mallocAllocs_.insert(p); + return p; + } + } + + void deallocate(char* p, size_t size) { + // Allocations could be returned out-of-order; ignore those we + // didn't cudaMalloc + auto it = mallocAllocs_.find(p); + if (it != mallocAllocs_.end()) { + freeMemorySpace(MemorySpace::Device, p); + mallocAllocs_.erase(it); + } + } + + private: + char* start_; + char* cur_; + char* end_; + std::unordered_set mallocAllocs_; +}; + + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/utils/Timer.cpp b/core/src/index/thirdparty/faiss/gpu/utils/Timer.cpp new file mode 100644 index 0000000000..1764fec10a --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/utils/Timer.cpp @@ -0,0 +1,60 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#include +#include +#include + +namespace faiss { namespace gpu { + +KernelTimer::KernelTimer(cudaStream_t stream) + : startEvent_(0), + stopEvent_(0), + stream_(stream), + valid_(true) { + CUDA_VERIFY(cudaEventCreate(&startEvent_)); + CUDA_VERIFY(cudaEventCreate(&stopEvent_)); + + CUDA_VERIFY(cudaEventRecord(startEvent_, stream_)); +} + +KernelTimer::~KernelTimer() { + CUDA_VERIFY(cudaEventDestroy(startEvent_)); + CUDA_VERIFY(cudaEventDestroy(stopEvent_)); +} + +float +KernelTimer::elapsedMilliseconds() { + FAISS_ASSERT(valid_); + + CUDA_VERIFY(cudaEventRecord(stopEvent_, stream_)); + CUDA_VERIFY(cudaEventSynchronize(stopEvent_)); + + auto time = 0.0f; + CUDA_VERIFY(cudaEventElapsedTime(&time, startEvent_, stopEvent_)); + valid_ = false; + + return time; +} + +CpuTimer::CpuTimer() { + clock_gettime(CLOCK_REALTIME, &start_); +} + +float +CpuTimer::elapsedMilliseconds() { + struct timespec end; + clock_gettime(CLOCK_REALTIME, &end); + + auto diffS = end.tv_sec - start_.tv_sec; + auto diffNs = end.tv_nsec - start_.tv_nsec; + + return 1000.0f * (float) diffS + ((float) diffNs) / 1000000.0f; +} + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/utils/Timer.h b/core/src/index/thirdparty/faiss/gpu/utils/Timer.h new file mode 100644 index 0000000000..ef2a161a32 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/utils/Timer.h @@ -0,0 +1,52 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#pragma once + +#include +#include + +namespace faiss { namespace gpu { + +/// Utility class for timing execution of a kernel +class KernelTimer { + public: + /// Constructor starts the timer and adds an event into the current + /// device stream + KernelTimer(cudaStream_t stream = 0); + + /// Destructor releases event resources + ~KernelTimer(); + + /// Adds a stop event then synchronizes on the stop event to get the + /// actual GPU-side kernel timings for any kernels launched in the + /// current stream. Returns the number of milliseconds elapsed. + /// Can only be called once. + float elapsedMilliseconds(); + + private: + cudaEvent_t startEvent_; + cudaEvent_t stopEvent_; + cudaStream_t stream_; + bool valid_; +}; + +/// CPU wallclock elapsed timer +class CpuTimer { + public: + /// Creates and starts a new timer + CpuTimer(); + + /// Returns elapsed time in milliseconds + float elapsedMilliseconds(); + + private: + struct timespec start_; +}; + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/utils/Transpose.cuh b/core/src/index/thirdparty/faiss/gpu/utils/Transpose.cuh new file mode 100644 index 0000000000..c6137d9f0d --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/utils/Transpose.cuh @@ -0,0 +1,154 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#pragma once + +#include +#include +#include +#include +#include + +namespace faiss { namespace gpu { + +template +struct TensorInfo { + static constexpr int kMaxDims = 8; + + T* data; + IndexT sizes[kMaxDims]; + IndexT strides[kMaxDims]; + int dims; +}; + +template +struct TensorInfoOffset { + __device__ inline static unsigned int get(const TensorInfo& info, + IndexT linearId) { + IndexT offset = 0; + +#pragma unroll + for (int i = Dim - 1; i >= 0; --i) { + IndexT curDimIndex = linearId % info.sizes[i]; + IndexT curDimOffset = curDimIndex * info.strides[i]; + + offset += curDimOffset; + + if (i > 0) { + linearId /= info.sizes[i]; + } + } + + return offset; + } +}; + +template +struct TensorInfoOffset { + __device__ inline static unsigned int get(const TensorInfo& info, + IndexT linearId) { + return linearId; + } +}; + +template +TensorInfo getTensorInfo(const Tensor& t) { + TensorInfo info; + + for (int i = 0; i < Dim; ++i) { + info.sizes[i] = (IndexT) t.getSize(i); + info.strides[i] = (IndexT) t.getStride(i); + } + + info.data = t.data(); + info.dims = Dim; + + return info; +} + +template +__global__ void transposeAny(TensorInfo input, + TensorInfo output, + IndexT totalSize) { + for (IndexT i = blockIdx.x * blockDim.x + threadIdx.x; + i < totalSize; + i += gridDim.x + blockDim.x) { + auto inputOffset = TensorInfoOffset::get(input, i); + auto outputOffset = TensorInfoOffset::get(output, i); + +#if __CUDA_ARCH__ >= 350 + output.data[outputOffset] = __ldg(&input.data[inputOffset]); +#else + output.data[outputOffset] = input.data[inputOffset]; +#endif + } +} + +/// Performs an out-of-place transposition between any two dimensions. +/// Best performance is if the transposed dimensions are not +/// innermost, since the reads and writes will be coalesced. +/// Could include a shared memory transposition if the dimensions +/// being transposed are innermost, but would require support for +/// arbitrary rectangular matrices. +/// This linearized implementation seems to perform well enough, +/// especially for cases that we care about (outer dimension +/// transpositions). +template +void runTransposeAny(Tensor& in, + int dim1, int dim2, + Tensor& out, + cudaStream_t stream) { + static_assert(Dim <= TensorInfo::kMaxDims, + "too many dimensions"); + + FAISS_ASSERT(dim1 != dim2); + FAISS_ASSERT(dim1 < Dim && dim2 < Dim); + + int outSize[Dim]; + + for (int i = 0; i < Dim; ++i) { + outSize[i] = in.getSize(i); + } + + std::swap(outSize[dim1], outSize[dim2]); + + for (int i = 0; i < Dim; ++i) { + FAISS_ASSERT(out.getSize(i) == outSize[i]); + } + + size_t totalSize = in.numElements(); + size_t block = std::min((size_t) getMaxThreadsCurrentDevice(), totalSize); + + if (totalSize <= (size_t) std::numeric_limits::max()) { + // div/mod seems faster with unsigned types + auto inInfo = getTensorInfo(in); + auto outInfo = getTensorInfo(out); + + std::swap(inInfo.sizes[dim1], inInfo.sizes[dim2]); + std::swap(inInfo.strides[dim1], inInfo.strides[dim2]); + + auto grid = std::min(utils::divUp(totalSize, block), (size_t) 4096); + + transposeAny + <<>>(inInfo, outInfo, totalSize); + } else { + auto inInfo = getTensorInfo(in); + auto outInfo = getTensorInfo(out); + + std::swap(inInfo.sizes[dim1], inInfo.sizes[dim2]); + std::swap(inInfo.strides[dim1], inInfo.strides[dim2]); + + auto grid = std::min(utils::divUp(totalSize, block), (size_t) 4096); + + transposeAny + <<>>(inInfo, outInfo, totalSize); + } + CUDA_TEST_ERROR(); +} + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/utils/WarpSelectFloat.cu b/core/src/index/thirdparty/faiss/gpu/utils/WarpSelectFloat.cu new file mode 100644 index 0000000000..4a03ab1311 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/utils/WarpSelectFloat.cu @@ -0,0 +1,94 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +namespace faiss { namespace gpu { + +// warp Q to thread Q: +// 1, 1 +// 32, 2 +// 64, 3 +// 128, 3 +// 256, 4 +// 512, 8 +// 1024, 8 +// 2048, 8 + +WARP_SELECT_DECL(float, true, 1); +WARP_SELECT_DECL(float, true, 32); +WARP_SELECT_DECL(float, true, 64); +WARP_SELECT_DECL(float, true, 128); +WARP_SELECT_DECL(float, true, 256); +WARP_SELECT_DECL(float, true, 512); +WARP_SELECT_DECL(float, true, 1024); +#if GPU_MAX_SELECTION_K >= 2048 +WARP_SELECT_DECL(float, true, 2048); +#endif + +WARP_SELECT_DECL(float, false, 1); +WARP_SELECT_DECL(float, false, 32); +WARP_SELECT_DECL(float, false, 64); +WARP_SELECT_DECL(float, false, 128); +WARP_SELECT_DECL(float, false, 256); +WARP_SELECT_DECL(float, false, 512); +WARP_SELECT_DECL(float, false, 1024); +#if GPU_MAX_SELECTION_K >= 2048 +WARP_SELECT_DECL(float, false, 2048); +#endif + +void runWarpSelect(Tensor& in, + Tensor& outK, + Tensor& outV, + bool dir, int k, cudaStream_t stream) { + FAISS_ASSERT(k <= 2048); + + if (dir) { + if (k == 1) { + WARP_SELECT_CALL(float, true, 1); + } else if (k <= 32) { + WARP_SELECT_CALL(float, true, 32); + } else if (k <= 64) { + WARP_SELECT_CALL(float, true, 64); + } else if (k <= 128) { + WARP_SELECT_CALL(float, true, 128); + } else if (k <= 256) { + WARP_SELECT_CALL(float, true, 256); + } else if (k <= 512) { + WARP_SELECT_CALL(float, true, 512); + } else if (k <= 1024) { + WARP_SELECT_CALL(float, true, 1024); +#if GPU_MAX_SELECTION_K >= 2048 + } else if (k <= 2048) { + WARP_SELECT_CALL(float, true, 2048); +#endif + } + } else { + if (k == 1) { + WARP_SELECT_CALL(float, false, 1); + } else if (k <= 32) { + WARP_SELECT_CALL(float, false, 32); + } else if (k <= 64) { + WARP_SELECT_CALL(float, false, 64); + } else if (k <= 128) { + WARP_SELECT_CALL(float, false, 128); + } else if (k <= 256) { + WARP_SELECT_CALL(float, false, 256); + } else if (k <= 512) { + WARP_SELECT_CALL(float, false, 512); + } else if (k <= 1024) { + WARP_SELECT_CALL(float, false, 1024); +#if GPU_MAX_SELECTION_K >= 2048 + } else if (k <= 2048) { + WARP_SELECT_CALL(float, false, 2048); +#endif + } + } +} + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/utils/WarpSelectHalf.cu b/core/src/index/thirdparty/faiss/gpu/utils/WarpSelectHalf.cu new file mode 100644 index 0000000000..d700ecaee7 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/utils/WarpSelectHalf.cu @@ -0,0 +1,98 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +namespace faiss { namespace gpu { + +#ifdef FAISS_USE_FLOAT16 + +// warp Q to thread Q: +// 1, 1 +// 32, 2 +// 64, 3 +// 128, 3 +// 256, 4 +// 512, 8 +// 1024, 8 +// 2048, 8 + +WARP_SELECT_DECL(half, true, 1); +WARP_SELECT_DECL(half, true, 32); +WARP_SELECT_DECL(half, true, 64); +WARP_SELECT_DECL(half, true, 128); +WARP_SELECT_DECL(half, true, 256); +WARP_SELECT_DECL(half, true, 512); +WARP_SELECT_DECL(half, true, 1024); +#if GPU_MAX_SELECTION_K >= 2048 +WARP_SELECT_DECL(half, true, 2048); +#endif + +WARP_SELECT_DECL(half, false, 1); +WARP_SELECT_DECL(half, false, 32); +WARP_SELECT_DECL(half, false, 64); +WARP_SELECT_DECL(half, false, 128); +WARP_SELECT_DECL(half, false, 256); +WARP_SELECT_DECL(half, false, 512); +WARP_SELECT_DECL(half, false, 1024); +#if GPU_MAX_SELECTION_K >= 2048 +WARP_SELECT_DECL(half, false, 2048); +#endif + +void runWarpSelect(Tensor& in, + Tensor& outK, + Tensor& outV, + bool dir, int k, cudaStream_t stream) { + FAISS_ASSERT(k <= 1024); + + if (dir) { + if (k == 1) { + WARP_SELECT_CALL(half, true, 1); + } else if (k <= 32) { + WARP_SELECT_CALL(half, true, 32); + } else if (k <= 64) { + WARP_SELECT_CALL(half, true, 64); + } else if (k <= 128) { + WARP_SELECT_CALL(half, true, 128); + } else if (k <= 256) { + WARP_SELECT_CALL(half, true, 256); + } else if (k <= 512) { + WARP_SELECT_CALL(half, true, 512); + } else if (k <= 1024) { + WARP_SELECT_CALL(half, true, 1024); +#if GPU_MAX_SELECTION_K >= 2048 + } else if (k <= 2048) { + WARP_SELECT_CALL(half, true, 2048); +#endif + } + } else { + if (k == 1) { + WARP_SELECT_CALL(half, false, 1); + } else if (k <= 32) { + WARP_SELECT_CALL(half, false, 32); + } else if (k <= 64) { + WARP_SELECT_CALL(half, false, 64); + } else if (k <= 128) { + WARP_SELECT_CALL(half, false, 128); + } else if (k <= 256) { + WARP_SELECT_CALL(half, false, 256); + } else if (k <= 512) { + WARP_SELECT_CALL(half, false, 512); + } else if (k <= 1024) { + WARP_SELECT_CALL(half, false, 1024); +#if GPU_MAX_SELECTION_K >= 2048 + } else if (k <= 2048) { + WARP_SELECT_CALL(half, false, 2048); +#endif + } + } +} + +#endif // FAISS_USE_FLOAT16 + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/utils/WarpSelectKernel.cuh b/core/src/index/thirdparty/faiss/gpu/utils/WarpSelectKernel.cuh new file mode 100644 index 0000000000..1b690b0306 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/utils/WarpSelectKernel.cuh @@ -0,0 +1,72 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +namespace faiss { namespace gpu { + +template +__global__ void warpSelect(Tensor in, + Tensor outK, + Tensor outV, + K initK, + IndexType initV, + int k) { + constexpr int kNumWarps = ThreadsPerBlock / kWarpSize; + + WarpSelect, + NumWarpQ, NumThreadQ, ThreadsPerBlock> + heap(initK, initV, k); + + int warpId = threadIdx.x / kWarpSize; + int row = blockIdx.x * kNumWarps + warpId; + + if (row >= in.getSize(0)) { + return; + } + + int i = getLaneId(); + K* inStart = in[row][i].data(); + + // Whole warps must participate in the selection + int limit = utils::roundDown(in.getSize(1), kWarpSize); + + for (; i < limit; i += kWarpSize) { + heap.add(*inStart, (IndexType) i); + inStart += kWarpSize; + } + + // Handle non-warp multiple remainder + if (i < in.getSize(1)) { + heap.addThreadQ(*inStart, (IndexType) i); + } + + heap.reduce(); + heap.writeOut(outK[row].data(), + outV[row].data(), k); +} + +void runWarpSelect(Tensor& in, + Tensor& outKeys, + Tensor& outIndices, + bool dir, int k, cudaStream_t stream); + +#ifdef FAISS_USE_FLOAT16 +void runWarpSelect(Tensor& in, + Tensor& outKeys, + Tensor& outIndices, + bool dir, int k, cudaStream_t stream); +#endif + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/utils/WarpShuffles.cuh b/core/src/index/thirdparty/faiss/gpu/utils/WarpShuffles.cuh new file mode 100644 index 0000000000..ec2e5b618c --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/utils/WarpShuffles.cuh @@ -0,0 +1,119 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#pragma once + +#include +#include + +namespace faiss { namespace gpu { + +template +inline __device__ T shfl(const T val, + int srcLane, int width = kWarpSize) { +#if CUDA_VERSION >= 9000 + return __shfl_sync(0xffffffff, val, srcLane, width); +#else + return __shfl(val, srcLane, width); +#endif +} + +// CUDA SDK does not provide specializations for T* +template +inline __device__ T* shfl(T* const val, + int srcLane, int width = kWarpSize) { + static_assert(sizeof(T*) == sizeof(long long), "pointer size"); + long long v = (long long) val; + + return (T*) shfl(v, srcLane, width); +} + +template +inline __device__ T shfl_up(const T val, + unsigned int delta, int width = kWarpSize) { +#if CUDA_VERSION >= 9000 + return __shfl_up_sync(0xffffffff, val, delta, width); +#else + return __shfl_up(val, delta, width); +#endif +} + +// CUDA SDK does not provide specializations for T* +template +inline __device__ T* shfl_up(T* const val, + unsigned int delta, int width = kWarpSize) { + static_assert(sizeof(T*) == sizeof(long long), "pointer size"); + long long v = (long long) val; + + return (T*) shfl_up(v, delta, width); +} + +template +inline __device__ T shfl_down(const T val, + unsigned int delta, int width = kWarpSize) { +#if CUDA_VERSION >= 9000 + return __shfl_down_sync(0xffffffff, val, delta, width); +#else + return __shfl_down(val, delta, width); +#endif +} + +// CUDA SDK does not provide specializations for T* +template +inline __device__ T* shfl_down(T* const val, + unsigned int delta, int width = kWarpSize) { + static_assert(sizeof(T*) == sizeof(long long), "pointer size"); + long long v = (long long) val; + return (T*) shfl_down(v, delta, width); +} + +template +inline __device__ T shfl_xor(const T val, + int laneMask, int width = kWarpSize) { +#if CUDA_VERSION >= 9000 + return __shfl_xor_sync(0xffffffff, val, laneMask, width); +#else + return __shfl_xor(val, laneMask, width); +#endif +} + +// CUDA SDK does not provide specializations for T* +template +inline __device__ T* shfl_xor(T* const val, + int laneMask, int width = kWarpSize) { + static_assert(sizeof(T*) == sizeof(long long), "pointer size"); + long long v = (long long) val; + return (T*) shfl_xor(v, laneMask, width); +} + +#ifdef FAISS_USE_FLOAT16 +// CUDA 9.0+ has half shuffle +#if CUDA_VERSION < 9000 +inline __device__ half shfl(half v, + int srcLane, int width = kWarpSize) { + unsigned int vu = v.x; + vu = __shfl(vu, srcLane, width); + + half h; + h.x = (unsigned short) vu; + return h; +} + +inline __device__ half shfl_xor(half v, + int laneMask, int width = kWarpSize) { + unsigned int vu = v.x; + vu = __shfl_xor(vu, laneMask, width); + + half h; + h.x = (unsigned short) vu; + return h; +} +#endif // CUDA_VERSION +#endif + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/utils/blockselect/BlockSelectFloat1.cu b/core/src/index/thirdparty/faiss/gpu/utils/blockselect/BlockSelectFloat1.cu new file mode 100644 index 0000000000..d53f4dc2aa --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/utils/blockselect/BlockSelectFloat1.cu @@ -0,0 +1,15 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +namespace faiss { namespace gpu { + +BLOCK_SELECT_IMPL(float, true, 1, 1); +BLOCK_SELECT_IMPL(float, false, 1, 1); + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/utils/blockselect/BlockSelectFloat128.cu b/core/src/index/thirdparty/faiss/gpu/utils/blockselect/BlockSelectFloat128.cu new file mode 100644 index 0000000000..2010034a18 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/utils/blockselect/BlockSelectFloat128.cu @@ -0,0 +1,15 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +namespace faiss { namespace gpu { + +BLOCK_SELECT_IMPL(float, true, 128, 3); +BLOCK_SELECT_IMPL(float, false, 128, 3); + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/utils/blockselect/BlockSelectFloat256.cu b/core/src/index/thirdparty/faiss/gpu/utils/blockselect/BlockSelectFloat256.cu new file mode 100644 index 0000000000..bcd93f3038 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/utils/blockselect/BlockSelectFloat256.cu @@ -0,0 +1,15 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +namespace faiss { namespace gpu { + +BLOCK_SELECT_IMPL(float, true, 256, 4); +BLOCK_SELECT_IMPL(float, false, 256, 4); + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/utils/blockselect/BlockSelectFloat32.cu b/core/src/index/thirdparty/faiss/gpu/utils/blockselect/BlockSelectFloat32.cu new file mode 100644 index 0000000000..35073dcfcd --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/utils/blockselect/BlockSelectFloat32.cu @@ -0,0 +1,15 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +namespace faiss { namespace gpu { + +BLOCK_SELECT_IMPL(float, true, 32, 2); +BLOCK_SELECT_IMPL(float, false, 32, 2); + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/utils/blockselect/BlockSelectFloat64.cu b/core/src/index/thirdparty/faiss/gpu/utils/blockselect/BlockSelectFloat64.cu new file mode 100644 index 0000000000..c2671068ee --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/utils/blockselect/BlockSelectFloat64.cu @@ -0,0 +1,15 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +namespace faiss { namespace gpu { + +BLOCK_SELECT_IMPL(float, true, 64, 3); +BLOCK_SELECT_IMPL(float, false, 64, 3); + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/utils/blockselect/BlockSelectFloatF1024.cu b/core/src/index/thirdparty/faiss/gpu/utils/blockselect/BlockSelectFloatF1024.cu new file mode 100644 index 0000000000..4c9c5188cb --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/utils/blockselect/BlockSelectFloatF1024.cu @@ -0,0 +1,14 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +namespace faiss { namespace gpu { + +BLOCK_SELECT_IMPL(float, false, 1024, 8); + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/utils/blockselect/BlockSelectFloatF2048.cu b/core/src/index/thirdparty/faiss/gpu/utils/blockselect/BlockSelectFloatF2048.cu new file mode 100644 index 0000000000..7828c2045d --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/utils/blockselect/BlockSelectFloatF2048.cu @@ -0,0 +1,17 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +namespace faiss { namespace gpu { + +#if GPU_MAX_SELECTION_K >= 2048 +BLOCK_SELECT_IMPL(float, false, 2048, 8); +#endif + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/utils/blockselect/BlockSelectFloatF512.cu b/core/src/index/thirdparty/faiss/gpu/utils/blockselect/BlockSelectFloatF512.cu new file mode 100644 index 0000000000..f24ee0bfa6 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/utils/blockselect/BlockSelectFloatF512.cu @@ -0,0 +1,14 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +namespace faiss { namespace gpu { + +BLOCK_SELECT_IMPL(float, false, 512, 8); + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/utils/blockselect/BlockSelectFloatT1024.cu b/core/src/index/thirdparty/faiss/gpu/utils/blockselect/BlockSelectFloatT1024.cu new file mode 100644 index 0000000000..1f84b371e3 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/utils/blockselect/BlockSelectFloatT1024.cu @@ -0,0 +1,14 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +namespace faiss { namespace gpu { + +BLOCK_SELECT_IMPL(float, true, 1024, 8); + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/utils/blockselect/BlockSelectFloatT2048.cu b/core/src/index/thirdparty/faiss/gpu/utils/blockselect/BlockSelectFloatT2048.cu new file mode 100644 index 0000000000..48037838a9 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/utils/blockselect/BlockSelectFloatT2048.cu @@ -0,0 +1,17 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +namespace faiss { namespace gpu { + +#if GPU_MAX_SELECTION_K >= 2048 +BLOCK_SELECT_IMPL(float, true, 2048, 8); +#endif + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/utils/blockselect/BlockSelectFloatT512.cu b/core/src/index/thirdparty/faiss/gpu/utils/blockselect/BlockSelectFloatT512.cu new file mode 100644 index 0000000000..3c93edfc09 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/utils/blockselect/BlockSelectFloatT512.cu @@ -0,0 +1,14 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +namespace faiss { namespace gpu { + +BLOCK_SELECT_IMPL(float, true, 512, 8); + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/utils/blockselect/BlockSelectHalf1.cu b/core/src/index/thirdparty/faiss/gpu/utils/blockselect/BlockSelectHalf1.cu new file mode 100644 index 0000000000..d2525935c2 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/utils/blockselect/BlockSelectHalf1.cu @@ -0,0 +1,17 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +namespace faiss { namespace gpu { + +#ifdef FAISS_USE_FLOAT16 +BLOCK_SELECT_IMPL(half, true, 1, 1); +BLOCK_SELECT_IMPL(half, false, 1, 1); +#endif + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/utils/blockselect/BlockSelectHalf128.cu b/core/src/index/thirdparty/faiss/gpu/utils/blockselect/BlockSelectHalf128.cu new file mode 100644 index 0000000000..3759af9342 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/utils/blockselect/BlockSelectHalf128.cu @@ -0,0 +1,17 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +namespace faiss { namespace gpu { + +#ifdef FAISS_USE_FLOAT16 +BLOCK_SELECT_IMPL(half, true, 128, 3); +BLOCK_SELECT_IMPL(half, false, 128, 3); +#endif + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/utils/blockselect/BlockSelectHalf256.cu b/core/src/index/thirdparty/faiss/gpu/utils/blockselect/BlockSelectHalf256.cu new file mode 100644 index 0000000000..a8a5cf13e9 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/utils/blockselect/BlockSelectHalf256.cu @@ -0,0 +1,17 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +namespace faiss { namespace gpu { + +#ifdef FAISS_USE_FLOAT16 +BLOCK_SELECT_IMPL(half, true, 256, 4); +BLOCK_SELECT_IMPL(half, false, 256, 4); +#endif + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/utils/blockselect/BlockSelectHalf32.cu b/core/src/index/thirdparty/faiss/gpu/utils/blockselect/BlockSelectHalf32.cu new file mode 100644 index 0000000000..18907c5119 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/utils/blockselect/BlockSelectHalf32.cu @@ -0,0 +1,17 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +namespace faiss { namespace gpu { + +#ifdef FAISS_USE_FLOAT16 +BLOCK_SELECT_IMPL(half, true, 32, 2); +BLOCK_SELECT_IMPL(half, false, 32, 2); +#endif + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/utils/blockselect/BlockSelectHalf64.cu b/core/src/index/thirdparty/faiss/gpu/utils/blockselect/BlockSelectHalf64.cu new file mode 100644 index 0000000000..81a9a84a9f --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/utils/blockselect/BlockSelectHalf64.cu @@ -0,0 +1,17 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +namespace faiss { namespace gpu { + +#ifdef FAISS_USE_FLOAT16 +BLOCK_SELECT_IMPL(half, true, 64, 3); +BLOCK_SELECT_IMPL(half, false, 64, 3); +#endif + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/utils/blockselect/BlockSelectHalfF1024.cu b/core/src/index/thirdparty/faiss/gpu/utils/blockselect/BlockSelectHalfF1024.cu new file mode 100644 index 0000000000..e83b615193 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/utils/blockselect/BlockSelectHalfF1024.cu @@ -0,0 +1,16 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +namespace faiss { namespace gpu { + +#ifdef FAISS_USE_FLOAT16 +BLOCK_SELECT_IMPL(half, false, 1024, 8); +#endif + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/utils/blockselect/BlockSelectHalfF2048.cu b/core/src/index/thirdparty/faiss/gpu/utils/blockselect/BlockSelectHalfF2048.cu new file mode 100644 index 0000000000..e06c334481 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/utils/blockselect/BlockSelectHalfF2048.cu @@ -0,0 +1,19 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +namespace faiss { namespace gpu { + +#if GPU_MAX_SELECTION_K >= 2048 +#ifdef FAISS_USE_FLOAT16 +BLOCK_SELECT_IMPL(half, false, 2048, 8); +#endif +#endif + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/utils/blockselect/BlockSelectHalfF512.cu b/core/src/index/thirdparty/faiss/gpu/utils/blockselect/BlockSelectHalfF512.cu new file mode 100644 index 0000000000..c1b67bd3de --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/utils/blockselect/BlockSelectHalfF512.cu @@ -0,0 +1,16 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +namespace faiss { namespace gpu { + +#ifdef FAISS_USE_FLOAT16 +BLOCK_SELECT_IMPL(half, false, 512, 8); +#endif + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/utils/blockselect/BlockSelectHalfT1024.cu b/core/src/index/thirdparty/faiss/gpu/utils/blockselect/BlockSelectHalfT1024.cu new file mode 100644 index 0000000000..2fd0dffa37 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/utils/blockselect/BlockSelectHalfT1024.cu @@ -0,0 +1,16 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +namespace faiss { namespace gpu { + +#ifdef FAISS_USE_FLOAT16 +BLOCK_SELECT_IMPL(half, true, 1024, 8); +#endif + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/utils/blockselect/BlockSelectHalfT2048.cu b/core/src/index/thirdparty/faiss/gpu/utils/blockselect/BlockSelectHalfT2048.cu new file mode 100644 index 0000000000..f91b6787e2 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/utils/blockselect/BlockSelectHalfT2048.cu @@ -0,0 +1,19 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +namespace faiss { namespace gpu { + +#if GPU_MAX_SELECTION_K >= 2048 +#ifdef FAISS_USE_FLOAT16 +BLOCK_SELECT_IMPL(half, true, 2048, 8); +#endif +#endif + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/utils/blockselect/BlockSelectHalfT512.cu b/core/src/index/thirdparty/faiss/gpu/utils/blockselect/BlockSelectHalfT512.cu new file mode 100644 index 0000000000..a2877db6ed --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/utils/blockselect/BlockSelectHalfT512.cu @@ -0,0 +1,16 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +namespace faiss { namespace gpu { + +#ifdef FAISS_USE_FLOAT16 +BLOCK_SELECT_IMPL(half, true, 512, 8); +#endif + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/utils/blockselect/BlockSelectImpl.cuh b/core/src/index/thirdparty/faiss/gpu/utils/blockselect/BlockSelectImpl.cuh new file mode 100644 index 0000000000..e7a5a03c22 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/utils/blockselect/BlockSelectImpl.cuh @@ -0,0 +1,106 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +#define BLOCK_SELECT_DECL(TYPE, DIR, WARP_Q) \ + extern void runBlockSelect_ ## TYPE ## _ ## DIR ## _ ## WARP_Q ## _( \ + Tensor& in, \ + Tensor& bitset, \ + Tensor& outK, \ + Tensor& outV, \ + bool dir, \ + int k, \ + cudaStream_t stream); \ + \ + extern void runBlockSelectPair_ ## TYPE ## _ ## DIR ## _ ## WARP_Q ## _( \ + Tensor& inK, \ + Tensor& inV, \ + Tensor& bitset, \ + Tensor& outK, \ + Tensor& outV, \ + bool dir, \ + int k, \ + cudaStream_t stream); + +#define BLOCK_SELECT_IMPL(TYPE, DIR, WARP_Q, THREAD_Q) \ + void runBlockSelect_ ## TYPE ## _ ## DIR ## _ ## WARP_Q ## _( \ + Tensor& in, \ + Tensor& bitset, \ + Tensor& outK, \ + Tensor& outV, \ + bool dir, \ + int k, \ + cudaStream_t stream) { \ + FAISS_ASSERT(in.getSize(0) == outK.getSize(0)); \ + FAISS_ASSERT(in.getSize(0) == outV.getSize(0)); \ + FAISS_ASSERT(outK.getSize(1) == k); \ + FAISS_ASSERT(outV.getSize(1) == k); \ + \ + auto grid = dim3(in.getSize(0)); \ + \ + constexpr int kBlockSelectNumThreads = (WARP_Q <= 1024) ? 128 : 64; \ + auto block = dim3(kBlockSelectNumThreads); \ + \ + FAISS_ASSERT(k <= WARP_Q); \ + FAISS_ASSERT(dir == DIR); \ + \ + auto kInit = dir ? Limits::getMin() : Limits::getMax(); \ + auto vInit = -1; \ + \ + if (bitset.getSize(0) == 0) \ + blockSelect \ + <<>>(in, outK, outV, kInit, vInit, k); \ + else \ + blockSelect \ + <<>>(in, bitset, outK, outV, kInit, vInit, k); \ + CUDA_TEST_ERROR(); \ + } \ + \ + void runBlockSelectPair_ ## TYPE ## _ ## DIR ## _ ## WARP_Q ## _( \ + Tensor& inK, \ + Tensor& inV, \ + Tensor& bitset, \ + Tensor& outK, \ + Tensor& outV, \ + bool dir, \ + int k, \ + cudaStream_t stream) { \ + FAISS_ASSERT(inK.isSameSize(inV)); \ + FAISS_ASSERT(outK.isSameSize(outV)); \ + \ + auto grid = dim3(inK.getSize(0)); \ + \ + constexpr int kBlockSelectNumThreads = (WARP_Q <= 1024) ? 128 : 64; \ + auto block = dim3(kBlockSelectNumThreads); \ + \ + FAISS_ASSERT(k <= WARP_Q); \ + FAISS_ASSERT(dir == DIR); \ + \ + auto kInit = dir ? Limits::getMin() : Limits::getMax(); \ + auto vInit = -1; \ + \ + if (bitset.getSize(0) == 0) \ + blockSelectPair \ + <<>>(inK, inV, outK, outV, kInit, vInit, k); \ + else \ + blockSelectPair \ + <<>>(inK, inV, bitset, outK, outV, kInit, vInit, k); \ + CUDA_TEST_ERROR(); \ + } + + +#define BLOCK_SELECT_CALL(TYPE, DIR, WARP_Q) \ + runBlockSelect_ ## TYPE ## _ ## DIR ## _ ## WARP_Q ## _( \ + in, bitset, outK, outV, dir, k, stream) + +#define BLOCK_SELECT_PAIR_CALL(TYPE, DIR, WARP_Q) \ + runBlockSelectPair_ ## TYPE ## _ ## DIR ## _ ## WARP_Q ## _( \ + inK, inV, bitset, outK, outV, dir, k, stream) diff --git a/core/src/index/thirdparty/faiss/gpu/utils/warpselect/WarpSelectFloat1.cu b/core/src/index/thirdparty/faiss/gpu/utils/warpselect/WarpSelectFloat1.cu new file mode 100644 index 0000000000..c641e50fdd --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/utils/warpselect/WarpSelectFloat1.cu @@ -0,0 +1,15 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +namespace faiss { namespace gpu { + +WARP_SELECT_IMPL(float, true, 1, 1); +WARP_SELECT_IMPL(float, false, 1, 1); + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/utils/warpselect/WarpSelectFloat128.cu b/core/src/index/thirdparty/faiss/gpu/utils/warpselect/WarpSelectFloat128.cu new file mode 100644 index 0000000000..76d98d1f20 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/utils/warpselect/WarpSelectFloat128.cu @@ -0,0 +1,15 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +namespace faiss { namespace gpu { + +WARP_SELECT_IMPL(float, true, 128, 3); +WARP_SELECT_IMPL(float, false, 128, 3); + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/utils/warpselect/WarpSelectFloat256.cu b/core/src/index/thirdparty/faiss/gpu/utils/warpselect/WarpSelectFloat256.cu new file mode 100644 index 0000000000..a0dd47feb1 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/utils/warpselect/WarpSelectFloat256.cu @@ -0,0 +1,15 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +namespace faiss { namespace gpu { + +WARP_SELECT_IMPL(float, true, 256, 4); +WARP_SELECT_IMPL(float, false, 256, 4); + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/utils/warpselect/WarpSelectFloat32.cu b/core/src/index/thirdparty/faiss/gpu/utils/warpselect/WarpSelectFloat32.cu new file mode 100644 index 0000000000..2461c94857 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/utils/warpselect/WarpSelectFloat32.cu @@ -0,0 +1,15 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +namespace faiss { namespace gpu { + +WARP_SELECT_IMPL(float, true, 32, 2); +WARP_SELECT_IMPL(float, false, 32, 2); + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/utils/warpselect/WarpSelectFloat64.cu b/core/src/index/thirdparty/faiss/gpu/utils/warpselect/WarpSelectFloat64.cu new file mode 100644 index 0000000000..a16c3830ca --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/utils/warpselect/WarpSelectFloat64.cu @@ -0,0 +1,15 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +namespace faiss { namespace gpu { + +WARP_SELECT_IMPL(float, true, 64, 3); +WARP_SELECT_IMPL(float, false, 64, 3); + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/utils/warpselect/WarpSelectFloatF1024.cu b/core/src/index/thirdparty/faiss/gpu/utils/warpselect/WarpSelectFloatF1024.cu new file mode 100644 index 0000000000..9effd9ee75 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/utils/warpselect/WarpSelectFloatF1024.cu @@ -0,0 +1,14 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +namespace faiss { namespace gpu { + +WARP_SELECT_IMPL(float, false, 1024, 8); + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/utils/warpselect/WarpSelectFloatF2048.cu b/core/src/index/thirdparty/faiss/gpu/utils/warpselect/WarpSelectFloatF2048.cu new file mode 100644 index 0000000000..3abc7e61f8 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/utils/warpselect/WarpSelectFloatF2048.cu @@ -0,0 +1,17 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +namespace faiss { namespace gpu { + +#if GPU_MAX_SELECTION_K >= 2048 +WARP_SELECT_IMPL(float, false, 2048, 8); +#endif + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/utils/warpselect/WarpSelectFloatF512.cu b/core/src/index/thirdparty/faiss/gpu/utils/warpselect/WarpSelectFloatF512.cu new file mode 100644 index 0000000000..0d92dc0361 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/utils/warpselect/WarpSelectFloatF512.cu @@ -0,0 +1,14 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +namespace faiss { namespace gpu { + +WARP_SELECT_IMPL(float, false, 512, 8); + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/utils/warpselect/WarpSelectFloatT1024.cu b/core/src/index/thirdparty/faiss/gpu/utils/warpselect/WarpSelectFloatT1024.cu new file mode 100644 index 0000000000..caae455f26 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/utils/warpselect/WarpSelectFloatT1024.cu @@ -0,0 +1,14 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +namespace faiss { namespace gpu { + +WARP_SELECT_IMPL(float, true, 1024, 8); + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/utils/warpselect/WarpSelectFloatT2048.cu b/core/src/index/thirdparty/faiss/gpu/utils/warpselect/WarpSelectFloatT2048.cu new file mode 100644 index 0000000000..b7cb048461 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/utils/warpselect/WarpSelectFloatT2048.cu @@ -0,0 +1,17 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +namespace faiss { namespace gpu { + +#if GPU_MAX_SELECTION_K >= 2048 +WARP_SELECT_IMPL(float, true, 2048, 8); +#endif + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/utils/warpselect/WarpSelectFloatT512.cu b/core/src/index/thirdparty/faiss/gpu/utils/warpselect/WarpSelectFloatT512.cu new file mode 100644 index 0000000000..c8de86a237 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/utils/warpselect/WarpSelectFloatT512.cu @@ -0,0 +1,14 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +namespace faiss { namespace gpu { + +WARP_SELECT_IMPL(float, true, 512, 8); + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/utils/warpselect/WarpSelectHalf1.cu b/core/src/index/thirdparty/faiss/gpu/utils/warpselect/WarpSelectHalf1.cu new file mode 100644 index 0000000000..da3206d454 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/utils/warpselect/WarpSelectHalf1.cu @@ -0,0 +1,17 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +namespace faiss { namespace gpu { + +#ifdef FAISS_USE_FLOAT16 +WARP_SELECT_IMPL(half, true, 1, 1); +WARP_SELECT_IMPL(half, false, 1, 1); +#endif + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/utils/warpselect/WarpSelectHalf128.cu b/core/src/index/thirdparty/faiss/gpu/utils/warpselect/WarpSelectHalf128.cu new file mode 100644 index 0000000000..8705e593c5 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/utils/warpselect/WarpSelectHalf128.cu @@ -0,0 +1,17 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +namespace faiss { namespace gpu { + +#ifdef FAISS_USE_FLOAT16 +WARP_SELECT_IMPL(half, true, 128, 3); +WARP_SELECT_IMPL(half, false, 128, 3); +#endif + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/utils/warpselect/WarpSelectHalf256.cu b/core/src/index/thirdparty/faiss/gpu/utils/warpselect/WarpSelectHalf256.cu new file mode 100644 index 0000000000..a7af219582 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/utils/warpselect/WarpSelectHalf256.cu @@ -0,0 +1,17 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +namespace faiss { namespace gpu { + +#ifdef FAISS_USE_FLOAT16 +WARP_SELECT_IMPL(half, true, 256, 4); +WARP_SELECT_IMPL(half, false, 256, 4); +#endif + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/utils/warpselect/WarpSelectHalf32.cu b/core/src/index/thirdparty/faiss/gpu/utils/warpselect/WarpSelectHalf32.cu new file mode 100644 index 0000000000..d7ed389aec --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/utils/warpselect/WarpSelectHalf32.cu @@ -0,0 +1,17 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +namespace faiss { namespace gpu { + +#ifdef FAISS_USE_FLOAT16 +WARP_SELECT_IMPL(half, true, 32, 2); +WARP_SELECT_IMPL(half, false, 32, 2); +#endif + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/utils/warpselect/WarpSelectHalf64.cu b/core/src/index/thirdparty/faiss/gpu/utils/warpselect/WarpSelectHalf64.cu new file mode 100644 index 0000000000..fea6c40b9c --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/utils/warpselect/WarpSelectHalf64.cu @@ -0,0 +1,17 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +namespace faiss { namespace gpu { + +#ifdef FAISS_USE_FLOAT16 +WARP_SELECT_IMPL(half, true, 64, 3); +WARP_SELECT_IMPL(half, false, 64, 3); +#endif + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/utils/warpselect/WarpSelectHalfF1024.cu b/core/src/index/thirdparty/faiss/gpu/utils/warpselect/WarpSelectHalfF1024.cu new file mode 100644 index 0000000000..d99eea9c7c --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/utils/warpselect/WarpSelectHalfF1024.cu @@ -0,0 +1,16 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +namespace faiss { namespace gpu { + +#ifdef FAISS_USE_FLOAT16 +WARP_SELECT_IMPL(half, false, 1024, 8); +#endif + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/utils/warpselect/WarpSelectHalfF2048.cu b/core/src/index/thirdparty/faiss/gpu/utils/warpselect/WarpSelectHalfF2048.cu new file mode 100644 index 0000000000..030d28e17f --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/utils/warpselect/WarpSelectHalfF2048.cu @@ -0,0 +1,19 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +namespace faiss { namespace gpu { + +#if GPU_MAX_SELECTION_K >= 2048 +#ifdef FAISS_USE_FLOAT16 +WARP_SELECT_IMPL(half, false, 2048, 8); +#endif +#endif + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/utils/warpselect/WarpSelectHalfF512.cu b/core/src/index/thirdparty/faiss/gpu/utils/warpselect/WarpSelectHalfF512.cu new file mode 100644 index 0000000000..651d727580 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/utils/warpselect/WarpSelectHalfF512.cu @@ -0,0 +1,16 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +namespace faiss { namespace gpu { + +#ifdef FAISS_USE_FLOAT16 +WARP_SELECT_IMPL(half, false, 512, 8); +#endif + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/utils/warpselect/WarpSelectHalfT1024.cu b/core/src/index/thirdparty/faiss/gpu/utils/warpselect/WarpSelectHalfT1024.cu new file mode 100644 index 0000000000..5a576d7c48 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/utils/warpselect/WarpSelectHalfT1024.cu @@ -0,0 +1,16 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +namespace faiss { namespace gpu { + +#ifdef FAISS_USE_FLOAT16 +WARP_SELECT_IMPL(half, true, 1024, 8); +#endif + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/utils/warpselect/WarpSelectHalfT2048.cu b/core/src/index/thirdparty/faiss/gpu/utils/warpselect/WarpSelectHalfT2048.cu new file mode 100644 index 0000000000..b5bd1f9e53 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/utils/warpselect/WarpSelectHalfT2048.cu @@ -0,0 +1,19 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +namespace faiss { namespace gpu { + +#if GPU_MAX_SELECTION_K >= 2048 +#ifdef FAISS_USE_FLOAT16 +WARP_SELECT_IMPL(half, true, 2048, 8); +#endif +#endif + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/utils/warpselect/WarpSelectHalfT512.cu b/core/src/index/thirdparty/faiss/gpu/utils/warpselect/WarpSelectHalfT512.cu new file mode 100644 index 0000000000..21b8660273 --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/utils/warpselect/WarpSelectHalfT512.cu @@ -0,0 +1,16 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +namespace faiss { namespace gpu { + +#ifdef FAISS_USE_FLOAT16 +WARP_SELECT_IMPL(half, true, 512, 8); +#endif + +} } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/utils/warpselect/WarpSelectImpl.cuh b/core/src/index/thirdparty/faiss/gpu/utils/warpselect/WarpSelectImpl.cuh new file mode 100644 index 0000000000..eee8ef0d5c --- /dev/null +++ b/core/src/index/thirdparty/faiss/gpu/utils/warpselect/WarpSelectImpl.cuh @@ -0,0 +1,47 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +#define WARP_SELECT_DECL(TYPE, DIR, WARP_Q) \ + extern void runWarpSelect_ ## TYPE ## _ ## DIR ## _ ## WARP_Q ## _( \ + Tensor& in, \ + Tensor& outK, \ + Tensor& outV, \ + bool dir, \ + int k, \ + cudaStream_t stream) + +#define WARP_SELECT_IMPL(TYPE, DIR, WARP_Q, THREAD_Q) \ + void runWarpSelect_ ## TYPE ## _ ## DIR ## _ ## WARP_Q ## _( \ + Tensor& in, \ + Tensor& outK, \ + Tensor& outV, \ + bool dir, \ + int k, \ + cudaStream_t stream) { \ + \ + constexpr int kWarpSelectNumThreads = 128; \ + auto grid = dim3(utils::divUp(in.getSize(0), \ + (kWarpSelectNumThreads / kWarpSize))); \ + auto block = dim3(kWarpSelectNumThreads); \ + \ + FAISS_ASSERT(k <= WARP_Q); \ + FAISS_ASSERT(dir == DIR); \ + \ + auto kInit = dir ? Limits::getMin() : Limits::getMax(); \ + auto vInit = -1; \ + \ + warpSelect \ + <<>>(in, outK, outV, kInit, vInit, k); \ + CUDA_TEST_ERROR(); \ + } + +#define WARP_SELECT_CALL(TYPE, DIR, WARP_Q) \ + runWarpSelect_ ## TYPE ## _ ## DIR ## _ ## WARP_Q ## _( \ + in, outK, outV, dir, k, stream) diff --git a/core/src/index/thirdparty/faiss/impl/AuxIndexStructures.cpp b/core/src/index/thirdparty/faiss/impl/AuxIndexStructures.cpp new file mode 100644 index 0000000000..7482fb7b3b --- /dev/null +++ b/core/src/index/thirdparty/faiss/impl/AuxIndexStructures.cpp @@ -0,0 +1,322 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +#include + +#include + +#include + + +namespace faiss { + + +/*********************************************************************** + * RangeSearchResult + ***********************************************************************/ + +RangeSearchResult::RangeSearchResult (idx_t nq, bool alloc_lims): nq (nq) { + if (alloc_lims) { + lims = new size_t [nq + 1]; + memset (lims, 0, sizeof(*lims) * (nq + 1)); + } else { + lims = nullptr; + } + labels = nullptr; + distances = nullptr; + buffer_size = 1024 * 256; +} + +/// called when lims contains the nb of elements result entries +/// for each query +void RangeSearchResult::do_allocation () { + size_t ofs = 0; + for (int i = 0; i < nq; i++) { + size_t n = lims[i]; + lims [i] = ofs; + ofs += n; + } + lims [nq] = ofs; + labels = new idx_t [ofs]; + distances = new float [ofs]; +} + +RangeSearchResult::~RangeSearchResult () { + delete [] labels; + delete [] distances; + delete [] lims; +} + + + + + +/*********************************************************************** + * BufferList + ***********************************************************************/ + + +BufferList::BufferList (size_t buffer_size): + buffer_size (buffer_size) +{ + wp = buffer_size; +} + +BufferList::~BufferList () +{ + for (int i = 0; i < buffers.size(); i++) { + delete [] buffers[i].ids; + delete [] buffers[i].dis; + } +} + +void BufferList::add (idx_t id, float dis) { + if (wp == buffer_size) { // need new buffer + append_buffer(); + } + Buffer & buf = buffers.back(); + buf.ids [wp] = id; + buf.dis [wp] = dis; + wp++; +} + + +void BufferList::append_buffer () +{ + Buffer buf = {new idx_t [buffer_size], new float [buffer_size]}; + buffers.push_back (buf); + wp = 0; +} + +/// copy elemnts ofs:ofs+n-1 seen as linear data in the buffers to +/// tables dest_ids, dest_dis +void BufferList::copy_range (size_t ofs, size_t n, + idx_t * dest_ids, float *dest_dis) +{ + size_t bno = ofs / buffer_size; + ofs -= bno * buffer_size; + while (n > 0) { + size_t ncopy = ofs + n < buffer_size ? n : buffer_size - ofs; + Buffer buf = buffers [bno]; + memcpy (dest_ids, buf.ids + ofs, ncopy * sizeof(*dest_ids)); + memcpy (dest_dis, buf.dis + ofs, ncopy * sizeof(*dest_dis)); + dest_ids += ncopy; + dest_dis += ncopy; + ofs = 0; + bno ++; + n -= ncopy; + } +} + + +/*********************************************************************** + * RangeSearchPartialResult + ***********************************************************************/ + +void RangeQueryResult::add (float dis, idx_t id) { + nres++; + pres->add (id, dis); +} + + + +RangeSearchPartialResult::RangeSearchPartialResult (RangeSearchResult * res_in): + BufferList(res_in->buffer_size), + res(res_in) +{} + + +/// begin a new result +RangeQueryResult & + RangeSearchPartialResult::new_result (idx_t qno) +{ + RangeQueryResult qres = {qno, 0, this}; + queries.push_back (qres); + return queries.back(); +} + + +void RangeSearchPartialResult::finalize () +{ + set_lims (); +#pragma omp barrier + +#pragma omp single + res->do_allocation (); + +#pragma omp barrier + copy_result (); +} + + +/// called by range_search before do_allocation +void RangeSearchPartialResult::set_lims () +{ + for (int i = 0; i < queries.size(); i++) { + RangeQueryResult & qres = queries[i]; + res->lims[qres.qno] = qres.nres; + } +} + +/// called by range_search after do_allocation +void RangeSearchPartialResult::copy_result (bool incremental) +{ + size_t ofs = 0; + for (int i = 0; i < queries.size(); i++) { + RangeQueryResult & qres = queries[i]; + + copy_range (ofs, qres.nres, + res->labels + res->lims[qres.qno], + res->distances + res->lims[qres.qno]); + if (incremental) { + res->lims[qres.qno] += qres.nres; + } + ofs += qres.nres; + } +} + +void RangeSearchPartialResult::merge (std::vector & + partial_results, bool do_delete) +{ + + int npres = partial_results.size(); + if (npres == 0) return; + RangeSearchResult *result = partial_results[0]->res; + size_t nx = result->nq; + + // count + for (const RangeSearchPartialResult * pres : partial_results) { + if (!pres) continue; + for (const RangeQueryResult &qres : pres->queries) { + result->lims[qres.qno] += qres.nres; + } + } + result->do_allocation (); + for (int j = 0; j < npres; j++) { + if (!partial_results[j]) continue; + partial_results[j]->copy_result (true); + if (do_delete) { + delete partial_results[j]; + partial_results[j] = nullptr; + } + } + + // reset the limits + for (size_t i = nx; i > 0; i--) { + result->lims [i] = result->lims [i - 1]; + } + result->lims [0] = 0; +} + +/*********************************************************************** + * IDSelectorRange + ***********************************************************************/ + +IDSelectorRange::IDSelectorRange (idx_t imin, idx_t imax): + imin (imin), imax (imax) +{ +} + +bool IDSelectorRange::is_member (idx_t id) const +{ + return id >= imin && id < imax; +} + +/*********************************************************************** + * IDSelectorArray + ***********************************************************************/ + +IDSelectorArray::IDSelectorArray (size_t n, const idx_t *ids): + n (n), ids(ids) +{ +} + +bool IDSelectorArray::is_member (idx_t id) const +{ + for (idx_t i = 0; i < n; i++) { + if (ids[i] == id) return true; + } + return false; +} + + +/*********************************************************************** + * IDSelectorBatch + ***********************************************************************/ + +IDSelectorBatch::IDSelectorBatch (size_t n, const idx_t *indices) +{ + nbits = 0; + while (n > (1L << nbits)) nbits++; + nbits += 5; + // for n = 1M, nbits = 25 is optimal, see P56659518 + + mask = (1L << nbits) - 1; + bloom.resize (1UL << (nbits - 3), 0); + for (long i = 0; i < n; i++) { + Index::idx_t id = indices[i]; + set.insert(id); + id &= mask; + bloom[id >> 3] |= 1 << (id & 7); + } +} + +bool IDSelectorBatch::is_member (idx_t i) const +{ + long im = i & mask; + if(!(bloom[im>>3] & (1 << (im & 7)))) { + return 0; + } + return set.count(i); +} + + +/*********************************************************** + * Interrupt callback + ***********************************************************/ + + +std::unique_ptr InterruptCallback::instance; + +std::mutex InterruptCallback::lock; + +void InterruptCallback::clear_instance () { + delete instance.release (); +} + +void InterruptCallback::check () { + if (!instance.get()) { + return; + } + if (instance->want_interrupt ()) { + FAISS_THROW_MSG ("computation interrupted"); + } +} + +bool InterruptCallback::is_interrupted () { + if (!instance.get()) { + return false; + } + std::lock_guard guard(lock); + return instance->want_interrupt(); +} + + +size_t InterruptCallback::get_period_hint (size_t flops) { + if (!instance.get()) { + return 1L << 30; // never check + } + // for 10M flops, it is reasonable to check once every 10 iterations + return std::max((size_t)10 * 10 * 1000 * 1000 / (flops + 1), (size_t)1); +} + + + + +} // namespace faiss diff --git a/core/src/index/thirdparty/faiss/impl/AuxIndexStructures.h b/core/src/index/thirdparty/faiss/impl/AuxIndexStructures.h new file mode 100644 index 0000000000..c82b9ed560 --- /dev/null +++ b/core/src/index/thirdparty/faiss/impl/AuxIndexStructures.h @@ -0,0 +1,257 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +// Auxiliary index structures, that are used in indexes but that can +// be forward-declared + +#ifndef FAISS_AUX_INDEX_STRUCTURES_H +#define FAISS_AUX_INDEX_STRUCTURES_H + +#include + +#include +#include +#include +#include + +#include + +namespace faiss { + +/** The objective is to have a simple result structure while + * minimizing the number of mem copies in the result. The method + * do_allocation can be overloaded to allocate the result tables in + * the matrix type of a scripting language like Lua or Python. */ +struct RangeSearchResult { + size_t nq; ///< nb of queries + size_t *lims; ///< size (nq + 1) + + typedef Index::idx_t idx_t; + + idx_t *labels; ///< result for query i is labels[lims[i]:lims[i+1]] + float *distances; ///< corresponding distances (not sorted) + + size_t buffer_size; ///< size of the result buffers used + + /// lims must be allocated on input to range_search. + explicit RangeSearchResult (idx_t nq, bool alloc_lims=true); + + /// called when lims contains the nb of elements result entries + /// for each query + + virtual void do_allocation (); + + virtual ~RangeSearchResult (); +}; + + +/** Encapsulates a set of ids to remove. */ +struct IDSelector { + typedef Index::idx_t idx_t; + virtual bool is_member (idx_t id) const = 0; + virtual ~IDSelector() {} +}; + + + +/** remove ids between [imni, imax) */ +struct IDSelectorRange: IDSelector { + idx_t imin, imax; + + IDSelectorRange (idx_t imin, idx_t imax); + bool is_member(idx_t id) const override; + ~IDSelectorRange() override {} +}; + +/** simple list of elements to remove + * + * this is inefficient in most cases, except for IndexIVF with + * maintain_direct_map + */ +struct IDSelectorArray: IDSelector { + size_t n; + const idx_t *ids; + + IDSelectorArray (size_t n, const idx_t *ids); + bool is_member(idx_t id) const override; + ~IDSelectorArray() override {} +}; + +/** Remove ids from a set. Repetitions of ids in the indices set + * passed to the constructor does not hurt performance. The hash + * function used for the bloom filter and GCC's implementation of + * unordered_set are just the least significant bits of the id. This + * works fine for random ids or ids in sequences but will produce many + * hash collisions if lsb's are always the same */ +struct IDSelectorBatch: IDSelector { + + std::unordered_set set; + + typedef unsigned char uint8_t; + std::vector bloom; // assumes low bits of id are a good hash value + int nbits; + idx_t mask; + + IDSelectorBatch (size_t n, const idx_t *indices); + bool is_member(idx_t id) const override; + ~IDSelectorBatch() override {} +}; + +/**************************************************************** + * Result structures for range search. + * + * The main constraint here is that we want to support parallel + * queries from different threads in various ways: 1 thread per query, + * several threads per query. We store the actual results in blocks of + * fixed size rather than exponentially increasing memory. At the end, + * we copy the block content to a linear result array. + *****************************************************************/ + +/** List of temporary buffers used to store results before they are + * copied to the RangeSearchResult object. */ +struct BufferList { + typedef Index::idx_t idx_t; + + // buffer sizes in # entries + size_t buffer_size; + + struct Buffer { + idx_t *ids; + float *dis; + }; + + std::vector buffers; + size_t wp; ///< write pointer in the last buffer. + + explicit BufferList (size_t buffer_size); + + ~BufferList (); + + /// create a new buffer + void append_buffer (); + + /// add one result, possibly appending a new buffer if needed + void add (idx_t id, float dis); + + /// copy elemnts ofs:ofs+n-1 seen as linear data in the buffers to + /// tables dest_ids, dest_dis + void copy_range (size_t ofs, size_t n, + idx_t * dest_ids, float *dest_dis); + +}; + +struct RangeSearchPartialResult; + +/// result structure for a single query +struct RangeQueryResult { + using idx_t = Index::idx_t; + idx_t qno; //< id of the query + size_t nres; //< nb of results for this query + RangeSearchPartialResult * pres; + + /// called by search function to report a new result + void add (float dis, idx_t id); +}; + +/// the entries in the buffers are split per query +struct RangeSearchPartialResult: BufferList { + RangeSearchResult * res; + + /// eventually the result will be stored in res_in + explicit RangeSearchPartialResult (RangeSearchResult * res_in); + + /// query ids + nb of results per query. + std::vector queries; + + /// begin a new result + RangeQueryResult & new_result (idx_t qno); + + /***************************************** + * functions used at the end of the search to merge the result + * lists */ + void finalize (); + + /// called by range_search before do_allocation + void set_lims (); + + /// called by range_search after do_allocation + void copy_result (bool incremental = false); + + /// merge a set of PartialResult's into one RangeSearchResult + /// on ouptut the partialresults are empty! + static void merge (std::vector & + partial_results, bool do_delete=true); + +}; + + +/*********************************************************** + * The distance computer maintains a current query and computes + * distances to elements in an index that supports random access. + * + * The DistanceComputer is not intended to be thread-safe (eg. because + * it maintains counters) so the distance functions are not const, + * instanciate one from each thread if needed. + ***********************************************************/ +struct DistanceComputer { + using idx_t = Index::idx_t; + + /// called before computing distances + virtual void set_query(const float *x) = 0; + + /// compute distance of vector i to current query + virtual float operator () (idx_t i) = 0; + + /// compute distance between two stored vectors + virtual float symmetric_dis (idx_t i, idx_t j) = 0; + + virtual ~DistanceComputer() {} +}; + +/*********************************************************** + * Interrupt callback + ***********************************************************/ + +struct InterruptCallback { + virtual bool want_interrupt () = 0; + virtual ~InterruptCallback() {} + + // lock that protects concurrent calls to is_interrupted + static std::mutex lock; + + static std::unique_ptr instance; + + static void clear_instance (); + + /** check if: + * - an interrupt callback is set + * - the callback retuns true + * if this is the case, then throw an exception. Should not be called + * from multiple threds. + */ + static void check (); + + /// same as check() but return true if is interrupted instead of + /// throwing. Can be called from multiple threads. + static bool is_interrupted (); + + /** assuming each iteration takes a certain number of flops, what + * is a reasonable interval to check for interrupts? + */ + static size_t get_period_hint (size_t flops); + +}; + + + +}; // namespace faiss + + + +#endif diff --git a/core/src/index/thirdparty/faiss/impl/FaissAssert.h b/core/src/index/thirdparty/faiss/impl/FaissAssert.h new file mode 100644 index 0000000000..f906589d46 --- /dev/null +++ b/core/src/index/thirdparty/faiss/impl/FaissAssert.h @@ -0,0 +1,95 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +#ifndef FAISS_ASSERT_INCLUDED +#define FAISS_ASSERT_INCLUDED + +#include +#include +#include +#include + +/// +/// Assertions +/// + +#define FAISS_ASSERT(X) \ + do { \ + if (! (X)) { \ + fprintf(stderr, "Faiss assertion '%s' failed in %s " \ + "at %s:%d\n", \ + #X, __PRETTY_FUNCTION__, __FILE__, __LINE__); \ + abort(); \ + } \ + } while (false) + +#define FAISS_ASSERT_MSG(X, MSG) \ + do { \ + if (! (X)) { \ + fprintf(stderr, "Faiss assertion '%s' failed in %s " \ + "at %s:%d; details: " MSG "\n", \ + #X, __PRETTY_FUNCTION__, __FILE__, __LINE__); \ + abort(); \ + } \ + } while (false) + +#define FAISS_ASSERT_FMT(X, FMT, ...) \ + do { \ + if (! (X)) { \ + fprintf(stderr, "Faiss assertion '%s' failed in %s " \ + "at %s:%d; details: " FMT "\n", \ + #X, __PRETTY_FUNCTION__, __FILE__, __LINE__, __VA_ARGS__); \ + abort(); \ + } \ + } while (false) + +/// +/// Exceptions for returning user errors +/// + +#define FAISS_THROW_MSG(MSG) \ + do { \ + throw faiss::FaissException(MSG, __PRETTY_FUNCTION__, __FILE__, __LINE__); \ + } while (false) + +#define FAISS_THROW_FMT(FMT, ...) \ + do { \ + std::string __s; \ + int __size = snprintf(nullptr, 0, FMT, __VA_ARGS__); \ + __s.resize(__size + 1); \ + snprintf(&__s[0], __s.size(), FMT, __VA_ARGS__); \ + throw faiss::FaissException(__s, __PRETTY_FUNCTION__, __FILE__, __LINE__); \ + } while (false) + +/// +/// Exceptions thrown upon a conditional failure +/// + +#define FAISS_THROW_IF_NOT(X) \ + do { \ + if (!(X)) { \ + FAISS_THROW_FMT("Error: '%s' failed", #X); \ + } \ + } while (false) + +#define FAISS_THROW_IF_NOT_MSG(X, MSG) \ + do { \ + if (!(X)) { \ + FAISS_THROW_FMT("Error: '%s' failed: " MSG, #X); \ + } \ + } while (false) + +#define FAISS_THROW_IF_NOT_FMT(X, FMT, ...) \ + do { \ + if (!(X)) { \ + FAISS_THROW_FMT("Error: '%s' failed: " FMT, #X, __VA_ARGS__); \ + } \ + } while (false) + +#endif diff --git a/core/src/index/thirdparty/faiss/impl/FaissException.cpp b/core/src/index/thirdparty/faiss/impl/FaissException.cpp new file mode 100644 index 0000000000..c79930e55e --- /dev/null +++ b/core/src/index/thirdparty/faiss/impl/FaissException.cpp @@ -0,0 +1,66 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +#include +#include + +namespace faiss { + +FaissException::FaissException(const std::string& m) + : msg(m) { +} + +FaissException::FaissException(const std::string& m, + const char* funcName, + const char* file, + int line) { + int size = snprintf(nullptr, 0, "Error in %s at %s:%d: %s", + funcName, file, line, m.c_str()); + msg.resize(size + 1); + snprintf(&msg[0], msg.size(), "Error in %s at %s:%d: %s", + funcName, file, line, m.c_str()); +} + +const char* +FaissException::what() const noexcept { + return msg.c_str(); +} + +void handleExceptions( + std::vector>& exceptions) { + if (exceptions.size() == 1) { + // throw the single received exception directly + std::rethrow_exception(exceptions.front().second); + + } else if (exceptions.size() > 1) { + // multiple exceptions; aggregate them and return a single exception + std::stringstream ss; + + for (auto& p : exceptions) { + try { + std::rethrow_exception(p.second); + } catch (std::exception& ex) { + if (ex.what()) { + // exception message available + ss << "Exception thrown from index " << p.first << ": " + << ex.what() << "\n"; + } else { + // No message available + ss << "Unknown exception thrown from index " << p.first << "\n"; + } + } catch (...) { + ss << "Unknown exception thrown from index " << p.first << "\n"; + } + } + + throw FaissException(ss.str()); + } +} + +} diff --git a/core/src/index/thirdparty/faiss/impl/FaissException.h b/core/src/index/thirdparty/faiss/impl/FaissException.h new file mode 100644 index 0000000000..9d54edbad5 --- /dev/null +++ b/core/src/index/thirdparty/faiss/impl/FaissException.h @@ -0,0 +1,71 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +#ifndef FAISS_EXCEPTION_INCLUDED +#define FAISS_EXCEPTION_INCLUDED + +#include +#include +#include +#include + +namespace faiss { + +/// Base class for Faiss exceptions +class FaissException : public std::exception { + public: + explicit FaissException(const std::string& msg); + + FaissException(const std::string& msg, + const char* funcName, + const char* file, + int line); + + /// from std::exception + const char* what() const noexcept override; + + std::string msg; +}; + +/// Handle multiple exceptions from worker threads, throwing an appropriate +/// exception that aggregates the information +/// The pair int is the thread that generated the exception +void +handleExceptions(std::vector>& exceptions); + +/** bare-bones unique_ptr + * this one deletes with delete [] */ +template +struct ScopeDeleter { + const T * ptr; + explicit ScopeDeleter (const T* ptr = nullptr): ptr (ptr) {} + void release () {ptr = nullptr; } + void set (const T * ptr_in) { ptr = ptr_in; } + void swap (ScopeDeleter &other) {std::swap (ptr, other.ptr); } + ~ScopeDeleter () { + delete [] ptr; + } +}; + +/** same but deletes with the simple delete (least common case) */ +template +struct ScopeDeleter1 { + const T * ptr; + explicit ScopeDeleter1 (const T* ptr = nullptr): ptr (ptr) {} + void release () {ptr = nullptr; } + void set (const T * ptr_in) { ptr = ptr_in; } + void swap (ScopeDeleter1 &other) {std::swap (ptr, other.ptr); } + ~ScopeDeleter1 () { + delete ptr; + } +}; + +} + +#endif diff --git a/core/src/index/thirdparty/faiss/impl/HNSW.cpp b/core/src/index/thirdparty/faiss/impl/HNSW.cpp new file mode 100644 index 0000000000..740ab0d136 --- /dev/null +++ b/core/src/index/thirdparty/faiss/impl/HNSW.cpp @@ -0,0 +1,817 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +#include + +#include + +#include + +namespace faiss { + + +/************************************************************** + * HNSW structure implementation + **************************************************************/ + +int HNSW::nb_neighbors(int layer_no) const +{ + return cum_nneighbor_per_level[layer_no + 1] - + cum_nneighbor_per_level[layer_no]; +} + +void HNSW::set_nb_neighbors(int level_no, int n) +{ + FAISS_THROW_IF_NOT(levels.size() == 0); + int cur_n = nb_neighbors(level_no); + for (int i = level_no + 1; i < cum_nneighbor_per_level.size(); i++) { + cum_nneighbor_per_level[i] += n - cur_n; + } +} + +int HNSW::cum_nb_neighbors(int layer_no) const +{ + return cum_nneighbor_per_level[layer_no]; +} + +void HNSW::neighbor_range(idx_t no, int layer_no, + size_t * begin, size_t * end) const +{ + size_t o = offsets[no]; + *begin = o + cum_nb_neighbors(layer_no); + *end = o + cum_nb_neighbors(layer_no + 1); +} + + + +HNSW::HNSW(int M) : rng(12345) { + set_default_probas(M, 1.0 / log(M)); + max_level = -1; + entry_point = -1; + efSearch = 16; + efConstruction = 40; + upper_beam = 1; + offsets.push_back(0); +} + + +int HNSW::random_level() +{ + double f = rng.rand_float(); + // could be a bit faster with bissection + for (int level = 0; level < assign_probas.size(); level++) { + if (f < assign_probas[level]) { + return level; + } + f -= assign_probas[level]; + } + // happens with exponentially low probability + return assign_probas.size() - 1; +} + +void HNSW::set_default_probas(int M, float levelMult) +{ + int nn = 0; + cum_nneighbor_per_level.push_back (0); + for (int level = 0; ;level++) { + float proba = exp(-level / levelMult) * (1 - exp(-1 / levelMult)); + if (proba < 1e-9) break; + assign_probas.push_back(proba); + nn += level == 0 ? M * 2 : M; + cum_nneighbor_per_level.push_back (nn); + } +} + +void HNSW::clear_neighbor_tables(int level) +{ + for (int i = 0; i < levels.size(); i++) { + size_t begin, end; + neighbor_range(i, level, &begin, &end); + for (size_t j = begin; j < end; j++) { + neighbors[j] = -1; + } + } +} + + +void HNSW::reset() { + max_level = -1; + entry_point = -1; + offsets.clear(); + offsets.push_back(0); + levels.clear(); + neighbors.clear(); +} + + + +void HNSW::print_neighbor_stats(int level) const +{ + FAISS_THROW_IF_NOT (level < cum_nneighbor_per_level.size()); + printf("stats on level %d, max %d neighbors per vertex:\n", + level, nb_neighbors(level)); + size_t tot_neigh = 0, tot_common = 0, tot_reciprocal = 0, n_node = 0; +#pragma omp parallel for reduction(+: tot_neigh) reduction(+: tot_common) \ + reduction(+: tot_reciprocal) reduction(+: n_node) + for (int i = 0; i < levels.size(); i++) { + if (levels[i] > level) { + n_node++; + size_t begin, end; + neighbor_range(i, level, &begin, &end); + std::unordered_set neighset; + for (size_t j = begin; j < end; j++) { + if (neighbors [j] < 0) break; + neighset.insert(neighbors[j]); + } + int n_neigh = neighset.size(); + int n_common = 0; + int n_reciprocal = 0; + for (size_t j = begin; j < end; j++) { + storage_idx_t i2 = neighbors[j]; + if (i2 < 0) break; + FAISS_ASSERT(i2 != i); + size_t begin2, end2; + neighbor_range(i2, level, &begin2, &end2); + for (size_t j2 = begin2; j2 < end2; j2++) { + storage_idx_t i3 = neighbors[j2]; + if (i3 < 0) break; + if (i3 == i) { + n_reciprocal++; + continue; + } + if (neighset.count(i3)) { + neighset.erase(i3); + n_common++; + } + } + } + tot_neigh += n_neigh; + tot_common += n_common; + tot_reciprocal += n_reciprocal; + } + } + float normalizer = n_node; + printf(" nb of nodes at that level %ld\n", n_node); + printf(" neighbors per node: %.2f (%ld)\n", + tot_neigh / normalizer, tot_neigh); + printf(" nb of reciprocal neighbors: %.2f\n", tot_reciprocal / normalizer); + printf(" nb of neighbors that are also neighbor-of-neighbors: %.2f (%ld)\n", + tot_common / normalizer, tot_common); + + + +} + + +void HNSW::fill_with_random_links(size_t n) +{ + int max_level = prepare_level_tab(n); + RandomGenerator rng2(456); + + for (int level = max_level - 1; level >= 0; --level) { + std::vector elts; + for (int i = 0; i < n; i++) { + if (levels[i] > level) { + elts.push_back(i); + } + } + printf ("linking %ld elements in level %d\n", + elts.size(), level); + + if (elts.size() == 1) continue; + + for (int ii = 0; ii < elts.size(); ii++) { + int i = elts[ii]; + size_t begin, end; + neighbor_range(i, 0, &begin, &end); + for (size_t j = begin; j < end; j++) { + int other = 0; + do { + other = elts[rng2.rand_int(elts.size())]; + } while(other == i); + + neighbors[j] = other; + } + } + } +} + + +int HNSW::prepare_level_tab(size_t n, bool preset_levels) +{ + size_t n0 = offsets.size() - 1; + + if (preset_levels) { + FAISS_ASSERT (n0 + n == levels.size()); + } else { + FAISS_ASSERT (n0 == levels.size()); + for (int i = 0; i < n; i++) { + int pt_level = random_level(); + levels.push_back(pt_level + 1); + } + } + + int max_level = 0; + for (int i = 0; i < n; i++) { + int pt_level = levels[i + n0] - 1; + if (pt_level > max_level) max_level = pt_level; + offsets.push_back(offsets.back() + + cum_nb_neighbors(pt_level + 1)); + neighbors.resize(offsets.back(), -1); + } + + return max_level; +} + + +/** Enumerate vertices from farthest to nearest from query, keep a + * neighbor only if there is no previous neighbor that is closer to + * that vertex than the query. + */ +void HNSW::shrink_neighbor_list( + DistanceComputer& qdis, + std::priority_queue& input, + std::vector& output, + int max_size) +{ + while (input.size() > 0) { + NodeDistFarther v1 = input.top(); + input.pop(); + float dist_v1_q = v1.d; + + bool good = true; + for (NodeDistFarther v2 : output) { + float dist_v1_v2 = qdis.symmetric_dis(v2.id, v1.id); + + if (dist_v1_v2 < dist_v1_q) { + good = false; + break; + } + } + + if (good) { + output.push_back(v1); + if (output.size() >= max_size) { + return; + } + } + } +} + + +namespace { + + +using storage_idx_t = HNSW::storage_idx_t; +using NodeDistCloser = HNSW::NodeDistCloser; +using NodeDistFarther = HNSW::NodeDistFarther; + + +/************************************************************** + * Addition subroutines + **************************************************************/ + + +/// remove neighbors from the list to make it smaller than max_size +void shrink_neighbor_list( + DistanceComputer& qdis, + std::priority_queue& resultSet1, + int max_size) +{ + if (resultSet1.size() < max_size) { + return; + } + std::priority_queue resultSet; + std::vector returnlist; + + while (resultSet1.size() > 0) { + resultSet.emplace(resultSet1.top().d, resultSet1.top().id); + resultSet1.pop(); + } + + HNSW::shrink_neighbor_list(qdis, resultSet, returnlist, max_size); + + for (NodeDistFarther curen2 : returnlist) { + resultSet1.emplace(curen2.d, curen2.id); + } + +} + + +/// add a link between two elements, possibly shrinking the list +/// of links to make room for it. +void add_link(HNSW& hnsw, + DistanceComputer& qdis, + storage_idx_t src, storage_idx_t dest, + int level) +{ + size_t begin, end; + hnsw.neighbor_range(src, level, &begin, &end); + if (hnsw.neighbors[end - 1] == -1) { + // there is enough room, find a slot to add it + size_t i = end; + while(i > begin) { + if (hnsw.neighbors[i - 1] != -1) break; + i--; + } + hnsw.neighbors[i] = dest; + return; + } + + // otherwise we let them fight out which to keep + + // copy to resultSet... + std::priority_queue resultSet; + resultSet.emplace(qdis.symmetric_dis(src, dest), dest); + for (size_t i = begin; i < end; i++) { // HERE WAS THE BUG + storage_idx_t neigh = hnsw.neighbors[i]; + resultSet.emplace(qdis.symmetric_dis(src, neigh), neigh); + } + + shrink_neighbor_list(qdis, resultSet, end - begin); + + // ...and back + size_t i = begin; + while (resultSet.size()) { + hnsw.neighbors[i++] = resultSet.top().id; + resultSet.pop(); + } + // they may have shrunk more than just by 1 element + while(i < end) { + hnsw.neighbors[i++] = -1; + } +} + +/// search neighbors on a single level, starting from an entry point +void search_neighbors_to_add( + HNSW& hnsw, + DistanceComputer& qdis, + std::priority_queue& results, + int entry_point, + float d_entry_point, + int level, + VisitedTable &vt) +{ + // top is nearest candidate + std::priority_queue candidates; + + NodeDistFarther ev(d_entry_point, entry_point); + candidates.push(ev); + results.emplace(d_entry_point, entry_point); + vt.set(entry_point); + + while (!candidates.empty()) { + // get nearest + const NodeDistFarther &currEv = candidates.top(); + + if (currEv.d > results.top().d) { + break; + } + int currNode = currEv.id; + candidates.pop(); + + // loop over neighbors + size_t begin, end; + hnsw.neighbor_range(currNode, level, &begin, &end); + for(size_t i = begin; i < end; i++) { + storage_idx_t nodeId = hnsw.neighbors[i]; + if (nodeId < 0) break; + if (vt.get(nodeId)) continue; + vt.set(nodeId); + + float dis = qdis(nodeId); + NodeDistFarther evE1(dis, nodeId); + + if (results.size() < hnsw.efConstruction || + results.top().d > dis) { + + results.emplace(dis, nodeId); + candidates.emplace(dis, nodeId); + if (results.size() > hnsw.efConstruction) { + results.pop(); + } + } + } + } + vt.advance(); +} + + +/************************************************************** + * Searching subroutines + **************************************************************/ + +/// greedily update a nearest vector at a given level +void greedy_update_nearest(const HNSW& hnsw, + DistanceComputer& qdis, + int level, + storage_idx_t& nearest, + float& d_nearest) +{ + for(;;) { + storage_idx_t prev_nearest = nearest; + + size_t begin, end; + hnsw.neighbor_range(nearest, level, &begin, &end); + for(size_t i = begin; i < end; i++) { + storage_idx_t v = hnsw.neighbors[i]; + if (v < 0) break; + float dis = qdis(v); + if (dis < d_nearest) { + nearest = v; + d_nearest = dis; + } + } + if (nearest == prev_nearest) { + return; + } + } +} + + +} // namespace + + +/// Finds neighbors and builds links with them, starting from an entry +/// point. The own neighbor list is assumed to be locked. +void HNSW::add_links_starting_from(DistanceComputer& ptdis, + storage_idx_t pt_id, + storage_idx_t nearest, + float d_nearest, + int level, + omp_lock_t *locks, + VisitedTable &vt) +{ + std::priority_queue link_targets; + + search_neighbors_to_add(*this, ptdis, link_targets, nearest, d_nearest, + level, vt); + + // but we can afford only this many neighbors + int M = nb_neighbors(level); + + ::faiss::shrink_neighbor_list(ptdis, link_targets, M); + + while (!link_targets.empty()) { + int other_id = link_targets.top().id; + + omp_set_lock(&locks[other_id]); + add_link(*this, ptdis, other_id, pt_id, level); + omp_unset_lock(&locks[other_id]); + + add_link(*this, ptdis, pt_id, other_id, level); + + link_targets.pop(); + } +} + + +/************************************************************** + * Building, parallel + **************************************************************/ + +void HNSW::add_with_locks(DistanceComputer& ptdis, int pt_level, int pt_id, + std::vector& locks, + VisitedTable& vt) +{ + // greedy search on upper levels + + storage_idx_t nearest; +#pragma omp critical + { + nearest = entry_point; + + if (nearest == -1) { + max_level = pt_level; + entry_point = pt_id; + } + } + + if (nearest < 0) { + return; + } + + omp_set_lock(&locks[pt_id]); + + int level = max_level; // level at which we start adding neighbors + float d_nearest = ptdis(nearest); + + for(; level > pt_level; level--) { + greedy_update_nearest(*this, ptdis, level, nearest, d_nearest); + } + + for(; level >= 0; level--) { + add_links_starting_from(ptdis, pt_id, nearest, d_nearest, + level, locks.data(), vt); + } + + omp_unset_lock(&locks[pt_id]); + + if (pt_level > max_level) { + max_level = pt_level; + entry_point = pt_id; + } +} + + +/** Do a BFS on the candidates list */ + +int HNSW::search_from_candidates( + DistanceComputer& qdis, int k, + idx_t *I, float *D, + MinimaxHeap& candidates, + VisitedTable& vt, + int level, int nres_in) const +{ + int nres = nres_in; + int ndis = 0; + for (int i = 0; i < candidates.size(); i++) { + idx_t v1 = candidates.ids[i]; + float d = candidates.dis[i]; + FAISS_ASSERT(v1 >= 0); + if (nres < k) { + faiss::maxheap_push(++nres, D, I, d, v1); + } else if (d < D[0]) { + faiss::maxheap_pop(nres--, D, I); + faiss::maxheap_push(++nres, D, I, d, v1); + } + vt.set(v1); + } + + bool do_dis_check = check_relative_distance; + int nstep = 0; + + while (candidates.size() > 0) { + float d0 = 0; + int v0 = candidates.pop_min(&d0); + + if (do_dis_check) { + // tricky stopping condition: there are more that ef + // distances that are processed already that are smaller + // than d0 + + int n_dis_below = candidates.count_below(d0); + if(n_dis_below >= efSearch) { + break; + } + } + + size_t begin, end; + neighbor_range(v0, level, &begin, &end); + + for (size_t j = begin; j < end; j++) { + int v1 = neighbors[j]; + if (v1 < 0) break; + if (vt.get(v1)) { + continue; + } + vt.set(v1); + ndis++; + float d = qdis(v1); + if (nres < k) { + faiss::maxheap_push(++nres, D, I, d, v1); + } else if (d < D[0]) { + faiss::maxheap_pop(nres--, D, I); + faiss::maxheap_push(++nres, D, I, d, v1); + } + candidates.push(v1, d); + } + + nstep++; + if (!do_dis_check && nstep > efSearch) { + break; + } + } + + if (level == 0) { +#pragma omp critical + { + hnsw_stats.n1 ++; + if (candidates.size() == 0) { + hnsw_stats.n2 ++; + } + hnsw_stats.n3 += ndis; + } + } + + return nres; +} + + +/************************************************************** + * Searching + **************************************************************/ + +std::priority_queue HNSW::search_from_candidate_unbounded( + const Node& node, + DistanceComputer& qdis, + int ef, + VisitedTable *vt) const +{ + int ndis = 0; + std::priority_queue top_candidates; + std::priority_queue, std::greater> candidates; + + top_candidates.push(node); + candidates.push(node); + + vt->set(node.second); + + while (!candidates.empty()) { + float d0; + storage_idx_t v0; + std::tie(d0, v0) = candidates.top(); + + if (d0 > top_candidates.top().first) { + break; + } + + candidates.pop(); + + size_t begin, end; + neighbor_range(v0, 0, &begin, &end); + + for (size_t j = begin; j < end; ++j) { + int v1 = neighbors[j]; + + if (v1 < 0) { + break; + } + if (vt->get(v1)) { + continue; + } + + vt->set(v1); + + float d1 = qdis(v1); + ++ndis; + + if (top_candidates.top().first > d1 || top_candidates.size() < ef) { + candidates.emplace(d1, v1); + top_candidates.emplace(d1, v1); + + if (top_candidates.size() > ef) { + top_candidates.pop(); + } + } + } + } + +#pragma omp critical + { + ++hnsw_stats.n1; + if (candidates.size() == 0) { + ++hnsw_stats.n2; + } + hnsw_stats.n3 += ndis; + } + + return top_candidates; +} + +void HNSW::search(DistanceComputer& qdis, int k, + idx_t *I, float *D, + VisitedTable& vt) const +{ + if (upper_beam == 1) { + + // greedy search on upper levels + storage_idx_t nearest = entry_point; + float d_nearest = qdis(nearest); + + for(int level = max_level; level >= 1; level--) { + greedy_update_nearest(*this, qdis, level, nearest, d_nearest); + } + + int ef = std::max(efSearch, k); + if (search_bounded_queue) { + MinimaxHeap candidates(ef); + + candidates.push(nearest, d_nearest); + + search_from_candidates(qdis, k, I, D, candidates, vt, 0); + } else { + std::priority_queue top_candidates = + search_from_candidate_unbounded(Node(d_nearest, nearest), + qdis, ef, &vt); + + while (top_candidates.size() > k) { + top_candidates.pop(); + } + + int nres = 0; + while (!top_candidates.empty()) { + float d; + storage_idx_t label; + std::tie(d, label) = top_candidates.top(); + faiss::maxheap_push(++nres, D, I, d, label); + top_candidates.pop(); + } + } + + vt.advance(); + + } else { + int candidates_size = upper_beam; + MinimaxHeap candidates(candidates_size); + + std::vector I_to_next(candidates_size); + std::vector D_to_next(candidates_size); + + int nres = 1; + I_to_next[0] = entry_point; + D_to_next[0] = qdis(entry_point); + + for(int level = max_level; level >= 0; level--) { + + // copy I, D -> candidates + + candidates.clear(); + + for (int i = 0; i < nres; i++) { + candidates.push(I_to_next[i], D_to_next[i]); + } + + if (level == 0) { + nres = search_from_candidates(qdis, k, I, D, candidates, vt, 0); + } else { + nres = search_from_candidates( + qdis, candidates_size, + I_to_next.data(), D_to_next.data(), + candidates, vt, level + ); + } + vt.advance(); + } + } +} + + +void HNSW::MinimaxHeap::push(storage_idx_t i, float v) { + if (k == n) { + if (v >= dis[0]) return; + faiss::heap_pop (k--, dis.data(), ids.data()); + --nvalid; + } + faiss::heap_push (++k, dis.data(), ids.data(), v, i); + ++nvalid; +} + +float HNSW::MinimaxHeap::max() const { + return dis[0]; +} + +int HNSW::MinimaxHeap::size() const { + return nvalid; +} + +void HNSW::MinimaxHeap::clear() { + nvalid = k = 0; +} + +int HNSW::MinimaxHeap::pop_min(float *vmin_out) { + assert(k > 0); + // returns min. This is an O(n) operation + int i = k - 1; + while (i >= 0) { + if (ids[i] != -1) break; + i--; + } + if (i == -1) return -1; + int imin = i; + float vmin = dis[i]; + i--; + while(i >= 0) { + if (ids[i] != -1 && dis[i] < vmin) { + vmin = dis[i]; + imin = i; + } + i--; + } + if (vmin_out) *vmin_out = vmin; + int ret = ids[imin]; + ids[imin] = -1; + --nvalid; + + return ret; +} + +int HNSW::MinimaxHeap::count_below(float thresh) { + int n_below = 0; + for(int i = 0; i < k; i++) { + if (dis[i] < thresh) { + n_below++; + } + } + + return n_below; +} + + +} // namespace faiss diff --git a/core/src/index/thirdparty/faiss/impl/HNSW.h b/core/src/index/thirdparty/faiss/impl/HNSW.h new file mode 100644 index 0000000000..cde99c1c29 --- /dev/null +++ b/core/src/index/thirdparty/faiss/impl/HNSW.h @@ -0,0 +1,275 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +#pragma once + +#include +#include +#include + +#include + +#include +#include +#include +#include + + +namespace faiss { + + +/** Implementation of the Hierarchical Navigable Small World + * datastructure. + * + * Efficient and robust approximate nearest neighbor search using + * Hierarchical Navigable Small World graphs + * + * Yu. A. Malkov, D. A. Yashunin, arXiv 2017 + * + * This implmentation is heavily influenced by the NMSlib + * implementation by Yury Malkov and Leonid Boystov + * (https://github.com/searchivarius/nmslib) + * + * The HNSW object stores only the neighbor link structure, see + * IndexHNSW.h for the full index object. + */ + + +struct VisitedTable; +struct DistanceComputer; // from AuxIndexStructures + +struct HNSW { + /// internal storage of vectors (32 bits: this is expensive) + typedef int storage_idx_t; + + /// Faiss results are 64-bit + typedef Index::idx_t idx_t; + + typedef std::pair Node; + + /** Heap structure that allows fast + */ + struct MinimaxHeap { + int n; + int k; + int nvalid; + + std::vector ids; + std::vector dis; + typedef faiss::CMax HC; + + explicit MinimaxHeap(int n): n(n), k(0), nvalid(0), ids(n), dis(n) {} + + void push(storage_idx_t i, float v); + + float max() const; + + int size() const; + + void clear(); + + int pop_min(float *vmin_out = nullptr); + + int count_below(float thresh); + }; + + + /// to sort pairs of (id, distance) from nearest to fathest or the reverse + struct NodeDistCloser { + float d; + int id; + NodeDistCloser(float d, int id): d(d), id(id) {} + bool operator < (const NodeDistCloser &obj1) const { return d < obj1.d; } + }; + + struct NodeDistFarther { + float d; + int id; + NodeDistFarther(float d, int id): d(d), id(id) {} + bool operator < (const NodeDistFarther &obj1) const { return d > obj1.d; } + }; + + + /// assignment probability to each layer (sum=1) + std::vector assign_probas; + + /// number of neighbors stored per layer (cumulative), should not + /// be changed after first add + std::vector cum_nneighbor_per_level; + + /// level of each vector (base level = 1), size = ntotal + std::vector levels; + + /// offsets[i] is the offset in the neighbors array where vector i is stored + /// size ntotal + 1 + std::vector offsets; + + /// neighbors[offsets[i]:offsets[i+1]] is the list of neighbors of vector i + /// for all levels. this is where all storage goes. + std::vector neighbors; + + /// entry point in the search structure (one of the points with maximum level + storage_idx_t entry_point; + + faiss::RandomGenerator rng; + + /// maximum level + int max_level; + + /// expansion factor at construction time + int efConstruction; + + /// expansion factor at search time + int efSearch; + + /// during search: do we check whether the next best distance is good enough? + bool check_relative_distance = true; + + /// number of entry points in levels > 0. + int upper_beam; + + /// use bounded queue during exploration + bool search_bounded_queue = true; + + // methods that initialize the tree sizes + + /// initialize the assign_probas and cum_nneighbor_per_level to + /// have 2*M links on level 0 and M links on levels > 0 + void set_default_probas(int M, float levelMult); + + /// set nb of neighbors for this level (before adding anything) + void set_nb_neighbors(int level_no, int n); + + // methods that access the tree sizes + + /// nb of neighbors for this level + int nb_neighbors(int layer_no) const; + + /// cumumlative nb up to (and excluding) this level + int cum_nb_neighbors(int layer_no) const; + + /// range of entries in the neighbors table of vertex no at layer_no + void neighbor_range(idx_t no, int layer_no, + size_t * begin, size_t * end) const; + + /// only mandatory parameter: nb of neighbors + explicit HNSW(int M = 32); + + /// pick a random level for a new point + int random_level(); + + /// add n random levels to table (for debugging...) + void fill_with_random_links(size_t n); + + void add_links_starting_from(DistanceComputer& ptdis, + storage_idx_t pt_id, + storage_idx_t nearest, + float d_nearest, + int level, + omp_lock_t *locks, + VisitedTable &vt); + + + /** add point pt_id on all levels <= pt_level and build the link + * structure for them. */ + void add_with_locks(DistanceComputer& ptdis, int pt_level, int pt_id, + std::vector& locks, + VisitedTable& vt); + + int search_from_candidates(DistanceComputer& qdis, int k, + idx_t *I, float *D, + MinimaxHeap& candidates, + VisitedTable &vt, + int level, int nres_in = 0) const; + + std::priority_queue search_from_candidate_unbounded( + const Node& node, + DistanceComputer& qdis, + int ef, + VisitedTable *vt + ) const; + + /// search interface + void search(DistanceComputer& qdis, int k, + idx_t *I, float *D, + VisitedTable& vt) const; + + void reset(); + + void clear_neighbor_tables(int level); + void print_neighbor_stats(int level) const; + + int prepare_level_tab(size_t n, bool preset_levels = false); + + static void shrink_neighbor_list( + DistanceComputer& qdis, + std::priority_queue& input, + std::vector& output, + int max_size); + +}; + + +/************************************************************** + * Auxiliary structures + **************************************************************/ + +/// set implementation optimized for fast access. +struct VisitedTable { + std::vector visited; + int visno; + + explicit VisitedTable(int size) + : visited(size), visno(1) {} + + /// set flog #no to true + void set(int no) { + visited[no] = visno; + } + + /// get flag #no + bool get(int no) const { + return visited[no] == visno; + } + + /// reset all flags to false + void advance() { + visno++; + if (visno == 250) { + // 250 rather than 255 because sometimes we use visno and visno+1 + memset(visited.data(), 0, sizeof(visited[0]) * visited.size()); + visno = 1; + } + } +}; + + +struct HNSWStats { + size_t n1, n2, n3; + size_t ndis; + size_t nreorder; + bool view; + + HNSWStats() { + reset(); + } + + void reset() { + n1 = n2 = n3 = 0; + ndis = 0; + nreorder = 0; + view = false; + } +}; + +// global var that collects them all +extern HNSWStats hnsw_stats; + + +} // namespace faiss diff --git a/core/src/index/thirdparty/faiss/impl/PolysemousTraining.cpp b/core/src/index/thirdparty/faiss/impl/PolysemousTraining.cpp new file mode 100644 index 0000000000..32166dce6a --- /dev/null +++ b/core/src/index/thirdparty/faiss/impl/PolysemousTraining.cpp @@ -0,0 +1,954 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +#include + +#include +#include +#include +#include + +#include + +#include +#include +#include +#include + +#include +#include + +/***************************************** + * Mixed PQ / Hamming + ******************************************/ + +namespace faiss { + + +/**************************************************** + * Optimization code + ****************************************************/ + +SimulatedAnnealingParameters::SimulatedAnnealingParameters () +{ + // set some reasonable defaults for the optimization + init_temperature = 0.7; + temperature_decay = pow (0.9, 1/500.); + // reduce by a factor 0.9 every 500 it + n_iter = 500000; + n_redo = 2; + seed = 123; + verbose = 0; + only_bit_flips = false; + init_random = false; +} + +// what would the cost update be if iw and jw were swapped? +// default implementation just computes both and computes the difference +double PermutationObjective::cost_update ( + const int *perm, int iw, int jw) const +{ + double orig_cost = compute_cost (perm); + + std::vector perm2 (n); + for (int i = 0; i < n; i++) + perm2[i] = perm[i]; + perm2[iw] = perm[jw]; + perm2[jw] = perm[iw]; + + double new_cost = compute_cost (perm2.data()); + return new_cost - orig_cost; +} + + + + +SimulatedAnnealingOptimizer::SimulatedAnnealingOptimizer ( + PermutationObjective *obj, + const SimulatedAnnealingParameters &p): + SimulatedAnnealingParameters (p), + obj (obj), + n(obj->n), + logfile (nullptr) +{ + rnd = new RandomGenerator (p.seed); + FAISS_THROW_IF_NOT (n < 100000 && n >=0 ); +} + +SimulatedAnnealingOptimizer::~SimulatedAnnealingOptimizer () +{ + delete rnd; +} + +// run the optimization and return the best result in best_perm +double SimulatedAnnealingOptimizer::run_optimization (int * best_perm) +{ + double min_cost = 1e30; + + // just do a few runs of the annealing and keep the lowest output cost + for (int it = 0; it < n_redo; it++) { + std::vector perm(n); + for (int i = 0; i < n; i++) + perm[i] = i; + if (init_random) { + for (int i = 0; i < n; i++) { + int j = i + rnd->rand_int (n - i); + std::swap (perm[i], perm[j]); + } + } + float cost = optimize (perm.data()); + if (logfile) fprintf (logfile, "\n"); + if(verbose > 1) { + printf (" optimization run %d: cost=%g %s\n", + it, cost, cost < min_cost ? "keep" : ""); + } + if (cost < min_cost) { + memcpy (best_perm, perm.data(), sizeof(perm[0]) * n); + min_cost = cost; + } + } + return min_cost; +} + +// perform the optimization loop, starting from and modifying +// permutation in-place +double SimulatedAnnealingOptimizer::optimize (int *perm) +{ + double cost = init_cost = obj->compute_cost (perm); + int log2n = 0; + while (!(n <= (1 << log2n))) log2n++; + double temperature = init_temperature; + int n_swap = 0, n_hot = 0; + for (int it = 0; it < n_iter; it++) { + temperature = temperature * temperature_decay; + int iw, jw; + if (only_bit_flips) { + iw = rnd->rand_int (n); + jw = iw ^ (1 << rnd->rand_int (log2n)); + } else { + iw = rnd->rand_int (n); + jw = rnd->rand_int (n - 1); + if (jw == iw) jw++; + } + double delta_cost = obj->cost_update (perm, iw, jw); + if (delta_cost < 0 || rnd->rand_float () < temperature) { + std::swap (perm[iw], perm[jw]); + cost += delta_cost; + n_swap++; + if (delta_cost >= 0) n_hot++; + } + if (verbose > 2 || (verbose > 1 && it % 10000 == 0)) { + printf (" iteration %d cost %g temp %g n_swap %d " + "(%d hot) \r", + it, cost, temperature, n_swap, n_hot); + fflush(stdout); + } + if (logfile) { + fprintf (logfile, "%d %g %g %d %d\n", + it, cost, temperature, n_swap, n_hot); + } + } + if (verbose > 1) printf("\n"); + return cost; +} + + + + + +/**************************************************** + * Cost functions: ReproduceDistanceTable + ****************************************************/ + + + + + + +static inline int hamming_dis (uint64_t a, uint64_t b) +{ + return __builtin_popcountl (a ^ b); +} + +namespace { + +/// optimize permutation to reproduce a distance table with Hamming distances +struct ReproduceWithHammingObjective : PermutationObjective { + int nbits; + double dis_weight_factor; + + static double sqr (double x) { return x * x; } + + + // weihgting of distances: it is more important to reproduce small + // distances well + double dis_weight (double x) const + { + return exp (-dis_weight_factor * x); + } + + std::vector target_dis; // wanted distances (size n^2) + std::vector weights; // weights for each distance (size n^2) + + // cost = quadratic difference between actual distance and Hamming distance + double compute_cost(const int* perm) const override { + double cost = 0; + for (int i = 0; i < n; i++) { + for (int j = 0; j < n; j++) { + double wanted = target_dis[i * n + j]; + double w = weights[i * n + j]; + double actual = hamming_dis(perm[i], perm[j]); + cost += w * sqr(wanted - actual); + } + } + return cost; + } + + + // what would the cost update be if iw and jw were swapped? + // computed in O(n) instead of O(n^2) for the full re-computation + double cost_update(const int* perm, int iw, int jw) const override { + double delta_cost = 0; + + for (int i = 0; i < n; i++) { + if (i == iw) { + for (int j = 0; j < n; j++) { + double wanted = target_dis[i * n + j], w = weights[i * n + j]; + double actual = hamming_dis(perm[i], perm[j]); + delta_cost -= w * sqr(wanted - actual); + double new_actual = + hamming_dis(perm[jw], perm[j == iw ? jw : j == jw ? iw : j]); + delta_cost += w * sqr(wanted - new_actual); + } + } else if (i == jw) { + for (int j = 0; j < n; j++) { + double wanted = target_dis[i * n + j], w = weights[i * n + j]; + double actual = hamming_dis(perm[i], perm[j]); + delta_cost -= w * sqr(wanted - actual); + double new_actual = + hamming_dis(perm[iw], perm[j == iw ? jw : j == jw ? iw : j]); + delta_cost += w * sqr(wanted - new_actual); + } + } else { + int j = iw; + { + double wanted = target_dis[i * n + j], w = weights[i * n + j]; + double actual = hamming_dis(perm[i], perm[j]); + delta_cost -= w * sqr(wanted - actual); + double new_actual = hamming_dis(perm[i], perm[jw]); + delta_cost += w * sqr(wanted - new_actual); + } + j = jw; + { + double wanted = target_dis[i * n + j], w = weights[i * n + j]; + double actual = hamming_dis(perm[i], perm[j]); + delta_cost -= w * sqr(wanted - actual); + double new_actual = hamming_dis(perm[i], perm[iw]); + delta_cost += w * sqr(wanted - new_actual); + } + } + } + + return delta_cost; + } + + + + ReproduceWithHammingObjective ( + int nbits, + const std::vector & dis_table, + double dis_weight_factor): + nbits (nbits), dis_weight_factor (dis_weight_factor) + { + n = 1 << nbits; + FAISS_THROW_IF_NOT (dis_table.size() == n * n); + set_affine_target_dis (dis_table); + } + + void set_affine_target_dis (const std::vector & dis_table) + { + double sum = 0, sum2 = 0; + int n2 = n * n; + for (int i = 0; i < n2; i++) { + sum += dis_table [i]; + sum2 += dis_table [i] * dis_table [i]; + } + double mean = sum / n2; + double stddev = sqrt(sum2 / n2 - (sum / n2) * (sum / n2)); + + target_dis.resize (n2); + + for (int i = 0; i < n2; i++) { + // the mapping function + double td = (dis_table [i] - mean) / stddev * sqrt(nbits / 4) + + nbits / 2; + target_dis[i] = td; + // compute a weight + weights.push_back (dis_weight (td)); + } + + } + + ~ReproduceWithHammingObjective() override {} +}; + +} // anonymous namespace + +// weihgting of distances: it is more important to reproduce small +// distances well +double ReproduceDistancesObjective::dis_weight (double x) const +{ + return exp (-dis_weight_factor * x); +} + + +double ReproduceDistancesObjective::get_source_dis (int i, int j) const +{ + return source_dis [i * n + j]; +} + +// cost = quadratic difference between actual distance and Hamming distance +double ReproduceDistancesObjective::compute_cost (const int *perm) const +{ + double cost = 0; + for (int i = 0; i < n; i++) { + for (int j = 0; j < n; j++) { + double wanted = target_dis [i * n + j]; + double w = weights [i * n + j]; + double actual = get_source_dis (perm[i], perm[j]); + cost += w * sqr (wanted - actual); + } + } + return cost; +} + +// what would the cost update be if iw and jw were swapped? +// computed in O(n) instead of O(n^2) for the full re-computation +double ReproduceDistancesObjective::cost_update( + const int *perm, int iw, int jw) const +{ + double delta_cost = 0; + for (int i = 0; i < n; i++) { + if (i == iw) { + for (int j = 0; j < n; j++) { + double wanted = target_dis [i * n + j], + w = weights [i * n + j]; + double actual = get_source_dis (perm[i], perm[j]); + delta_cost -= w * sqr (wanted - actual); + double new_actual = get_source_dis ( + perm[jw], + perm[j == iw ? jw : j == jw ? iw : j]); + delta_cost += w * sqr (wanted - new_actual); + } + } else if (i == jw) { + for (int j = 0; j < n; j++) { + double wanted = target_dis [i * n + j], + w = weights [i * n + j]; + double actual = get_source_dis (perm[i], perm[j]); + delta_cost -= w * sqr (wanted - actual); + double new_actual = get_source_dis ( + perm[iw], + perm[j == iw ? jw : j == jw ? iw : j]); + delta_cost += w * sqr (wanted - new_actual); + } + } else { + int j = iw; + { + double wanted = target_dis [i * n + j], + w = weights [i * n + j]; + double actual = get_source_dis (perm[i], perm[j]); + delta_cost -= w * sqr (wanted - actual); + double new_actual = get_source_dis (perm[i], perm[jw]); + delta_cost += w * sqr (wanted - new_actual); + } + j = jw; + { + double wanted = target_dis [i * n + j], + w = weights [i * n + j]; + double actual = get_source_dis (perm[i], perm[j]); + delta_cost -= w * sqr (wanted - actual); + double new_actual = get_source_dis (perm[i], perm[iw]); + delta_cost += w * sqr (wanted - new_actual); + } + } + } + return delta_cost; +} + + + +ReproduceDistancesObjective::ReproduceDistancesObjective ( + int n, + const double *source_dis_in, + const double *target_dis_in, + double dis_weight_factor): + dis_weight_factor (dis_weight_factor), + target_dis (target_dis_in) +{ + this->n = n; + set_affine_target_dis (source_dis_in); +} + +void ReproduceDistancesObjective::compute_mean_stdev ( + const double *tab, size_t n2, + double *mean_out, double *stddev_out) +{ + double sum = 0, sum2 = 0; + for (int i = 0; i < n2; i++) { + sum += tab [i]; + sum2 += tab [i] * tab [i]; + } + double mean = sum / n2; + double stddev = sqrt(sum2 / n2 - (sum / n2) * (sum / n2)); + *mean_out = mean; + *stddev_out = stddev; +} + +void ReproduceDistancesObjective::set_affine_target_dis ( + const double *source_dis_in) +{ + int n2 = n * n; + + double mean_src, stddev_src; + compute_mean_stdev (source_dis_in, n2, &mean_src, &stddev_src); + + double mean_target, stddev_target; + compute_mean_stdev (target_dis, n2, &mean_target, &stddev_target); + + printf ("map mean %g std %g -> mean %g std %g\n", + mean_src, stddev_src, mean_target, stddev_target); + + source_dis.resize (n2); + weights.resize (n2); + + for (int i = 0; i < n2; i++) { + // the mapping function + source_dis[i] = (source_dis_in[i] - mean_src) / stddev_src + * stddev_target + mean_target; + + // compute a weight + weights [i] = dis_weight (target_dis[i]); + } + +} + +/**************************************************** + * Cost functions: RankingScore + ****************************************************/ + +/// Maintains a 3D table of elementary costs. +/// Accumulates elements based on Hamming distance comparisons +template +struct Score3Computer: PermutationObjective { + + int nc; + + // cost matrix of size nc * nc *nc + // n_gt (i,j,k) = count of d_gt(x, y-) < d_gt(x, y+) + // where x has PQ code i, y- PQ code j and y+ PQ code k + std::vector n_gt; + + + /// the cost is a triple loop on the nc * nc * nc matrix of entries. + /// + Taccu compute (const int * perm) const + { + Taccu accu = 0; + const Ttab *p = n_gt.data(); + for (int i = 0; i < nc; i++) { + int ip = perm [i]; + for (int j = 0; j < nc; j++) { + int jp = perm [j]; + for (int k = 0; k < nc; k++) { + int kp = perm [k]; + if (hamming_dis (ip, jp) < + hamming_dis (ip, kp)) { + accu += *p; // n_gt [ ( i * nc + j) * nc + k]; + } + p++; + } + } + } + return accu; + } + + + /** cost update if entries iw and jw of the permutation would be + * swapped. + * + * The computation is optimized by avoiding elements in the + * nc*nc*nc cube that are known not to change. For nc=256, this + * reduces the nb of cells to visit to about 6/256 th of the + * cells. Practical speedup is about 8x, and the code is quite + * complex :-/ + */ + Taccu compute_update (const int *perm, int iw, int jw) const + { + assert (iw != jw); + if (iw > jw) std::swap (iw, jw); + + Taccu accu = 0; + const Ttab * n_gt_i = n_gt.data(); + for (int i = 0; i < nc; i++) { + int ip0 = perm [i]; + int ip = perm [i == iw ? jw : i == jw ? iw : i]; + + //accu += update_i (perm, iw, jw, ip0, ip, n_gt_i); + + accu += update_i_cross (perm, iw, jw, + ip0, ip, n_gt_i); + + if (ip != ip0) + accu += update_i_plane (perm, iw, jw, + ip0, ip, n_gt_i); + + n_gt_i += nc * nc; + } + + return accu; + } + + + Taccu update_i (const int *perm, int iw, int jw, + int ip0, int ip, const Ttab * n_gt_i) const + { + Taccu accu = 0; + const Ttab *n_gt_ij = n_gt_i; + for (int j = 0; j < nc; j++) { + int jp0 = perm[j]; + int jp = perm [j == iw ? jw : j == jw ? iw : j]; + for (int k = 0; k < nc; k++) { + int kp0 = perm [k]; + int kp = perm [k == iw ? jw : k == jw ? iw : k]; + int ng = n_gt_ij [k]; + if (hamming_dis (ip, jp) < hamming_dis (ip, kp)) { + accu += ng; + } + if (hamming_dis (ip0, jp0) < hamming_dis (ip0, kp0)) { + accu -= ng; + } + } + n_gt_ij += nc; + } + return accu; + } + + // 2 inner loops for the case ip0 != ip + Taccu update_i_plane (const int *perm, int iw, int jw, + int ip0, int ip, const Ttab * n_gt_i) const + { + Taccu accu = 0; + const Ttab *n_gt_ij = n_gt_i; + + for (int j = 0; j < nc; j++) { + if (j != iw && j != jw) { + int jp = perm[j]; + for (int k = 0; k < nc; k++) { + if (k != iw && k != jw) { + int kp = perm [k]; + Ttab ng = n_gt_ij [k]; + if (hamming_dis (ip, jp) < hamming_dis (ip, kp)) { + accu += ng; + } + if (hamming_dis (ip0, jp) < hamming_dis (ip0, kp)) { + accu -= ng; + } + } + } + } + n_gt_ij += nc; + } + return accu; + } + + /// used for the 8 cells were the 3 indices are swapped + inline Taccu update_k (const int *perm, int iw, int jw, + int ip0, int ip, int jp0, int jp, + int k, + const Ttab * n_gt_ij) const + { + Taccu accu = 0; + int kp0 = perm [k]; + int kp = perm [k == iw ? jw : k == jw ? iw : k]; + Ttab ng = n_gt_ij [k]; + if (hamming_dis (ip, jp) < hamming_dis (ip, kp)) { + accu += ng; + } + if (hamming_dis (ip0, jp0) < hamming_dis (ip0, kp0)) { + accu -= ng; + } + return accu; + } + + /// compute update on a line of k's, where i and j are swapped + Taccu update_j_line (const int *perm, int iw, int jw, + int ip0, int ip, int jp0, int jp, + const Ttab * n_gt_ij) const + { + Taccu accu = 0; + for (int k = 0; k < nc; k++) { + if (k == iw || k == jw) continue; + int kp = perm [k]; + Ttab ng = n_gt_ij [k]; + if (hamming_dis (ip, jp) < hamming_dis (ip, kp)) { + accu += ng; + } + if (hamming_dis (ip0, jp0) < hamming_dis (ip0, kp)) { + accu -= ng; + } + } + return accu; + } + + + /// considers the 2 pairs of crossing lines j=iw or jw and k = iw or kw + Taccu update_i_cross (const int *perm, int iw, int jw, + int ip0, int ip, const Ttab * n_gt_i) const + { + Taccu accu = 0; + const Ttab *n_gt_ij = n_gt_i; + + for (int j = 0; j < nc; j++) { + int jp0 = perm[j]; + int jp = perm [j == iw ? jw : j == jw ? iw : j]; + + accu += update_k (perm, iw, jw, ip0, ip, jp0, jp, iw, n_gt_ij); + accu += update_k (perm, iw, jw, ip0, ip, jp0, jp, jw, n_gt_ij); + + if (jp != jp0) + accu += update_j_line (perm, iw, jw, ip0, ip, jp0, jp, n_gt_ij); + + n_gt_ij += nc; + } + return accu; + } + + + /// PermutationObjective implementeation (just negates the scores + /// for minimization) + + double compute_cost(const int* perm) const override { + return -compute(perm); + } + + double cost_update(const int* perm, int iw, int jw) const override { + double ret = -compute_update(perm, iw, jw); + return ret; + } + + ~Score3Computer() override {} +}; + + + + + +struct IndirectSort { + const float *tab; + bool operator () (int a, int b) {return tab[a] < tab[b]; } +}; + + + +struct RankingScore2: Score3Computer { + int nbits; + int nq, nb; + const uint32_t *qcodes, *bcodes; + const float *gt_distances; + + RankingScore2 (int nbits, int nq, int nb, + const uint32_t *qcodes, const uint32_t *bcodes, + const float *gt_distances): + nbits(nbits), nq(nq), nb(nb), qcodes(qcodes), + bcodes(bcodes), gt_distances(gt_distances) + { + n = nc = 1 << nbits; + n_gt.resize (nc * nc * nc); + init_n_gt (); + } + + + double rank_weight (int r) + { + return 1.0 / (r + 1); + } + + /// count nb of i, j in a x b st. i < j + /// a and b should be sorted on input + /// they are the ranks of j and k respectively. + /// specific version for diff-of-rank weighting, cannot optimized + /// with a cumulative table + double accum_gt_weight_diff (const std::vector & a, + const std::vector & b) + { + int nb = b.size(), na = a.size(); + + double accu = 0; + int j = 0; + for (int i = 0; i < na; i++) { + int ai = a[i]; + while (j < nb && ai >= b[j]) j++; + + double accu_i = 0; + for (int k = j; k < b.size(); k++) + accu_i += rank_weight (b[k] - ai); + + accu += rank_weight (ai) * accu_i; + + } + return accu; + } + + void init_n_gt () + { + for (int q = 0; q < nq; q++) { + const float *gtd = gt_distances + q * nb; + const uint32_t *cb = bcodes;// all same codes + float * n_gt_q = & n_gt [qcodes[q] * nc * nc]; + + printf("init gt for q=%d/%d \r", q, nq); fflush(stdout); + + std::vector rankv (nb); + int * ranks = rankv.data(); + + // elements in each code bin, ordered by rank within each bin + std::vector > tab (nc); + + { // build rank table + IndirectSort s = {gtd}; + for (int j = 0; j < nb; j++) ranks[j] = j; + std::sort (ranks, ranks + nb, s); + } + + for (int rank = 0; rank < nb; rank++) { + int i = ranks [rank]; + tab [cb[i]].push_back (rank); + } + + + // this is very expensive. Any suggestion for improvement + // welcome. + for (int i = 0; i < nc; i++) { + std::vector & di = tab[i]; + for (int j = 0; j < nc; j++) { + std::vector & dj = tab[j]; + n_gt_q [i * nc + j] += accum_gt_weight_diff (di, dj); + + } + } + + } + + } + +}; + + +/***************************************** + * PolysemousTraining + ******************************************/ + + + +PolysemousTraining::PolysemousTraining () +{ + optimization_type = OT_ReproduceDistances_affine; + ntrain_permutation = 0; + dis_weight_factor = log(2); +} + + + +void PolysemousTraining::optimize_reproduce_distances ( + ProductQuantizer &pq) const +{ + + int dsub = pq.dsub; + + int n = pq.ksub; + int nbits = pq.nbits; + +#pragma omp parallel for + for (int m = 0; m < pq.M; m++) { + std::vector dis_table; + + // printf ("Optimizing quantizer %d\n", m); + + float * centroids = pq.get_centroids (m, 0); + + for (int i = 0; i < n; i++) { + for (int j = 0; j < n; j++) { + dis_table.push_back (fvec_L2sqr (centroids + i * dsub, + centroids + j * dsub, + dsub)); + } + } + + std::vector perm (n); + ReproduceWithHammingObjective obj ( + nbits, dis_table, + dis_weight_factor); + + + SimulatedAnnealingOptimizer optim (&obj, *this); + + if (log_pattern.size()) { + char fname[256]; + snprintf (fname, 256, log_pattern.c_str(), m); + printf ("opening log file %s\n", fname); + optim.logfile = fopen (fname, "w"); + FAISS_THROW_IF_NOT_MSG (optim.logfile, "could not open logfile"); + } + double final_cost = optim.run_optimization (perm.data()); + + if (verbose > 0) { + printf ("SimulatedAnnealingOptimizer for m=%d: %g -> %g\n", + m, optim.init_cost, final_cost); + } + + if (log_pattern.size()) fclose (optim.logfile); + + std::vector centroids_copy; + for (int i = 0; i < dsub * n; i++) + centroids_copy.push_back (centroids[i]); + + for (int i = 0; i < n; i++) + memcpy (centroids + perm[i] * dsub, + centroids_copy.data() + i * dsub, + dsub * sizeof(centroids[0])); + + } + +} + + +void PolysemousTraining::optimize_ranking ( + ProductQuantizer &pq, size_t n, const float *x) const +{ + + int dsub = pq.dsub; + + int nbits = pq.nbits; + + std::vector all_codes (pq.code_size * n); + + pq.compute_codes (x, all_codes.data(), n); + + FAISS_THROW_IF_NOT (pq.nbits == 8); + + if (n == 0) + pq.compute_sdc_table (); + +#pragma omp parallel for + for (int m = 0; m < pq.M; m++) { + size_t nq, nb; + std::vector codes; // query codes, then db codes + std::vector gt_distances; // nq * nb matrix of distances + + if (n > 0) { + std::vector xtrain (n * dsub); + for (int i = 0; i < n; i++) + memcpy (xtrain.data() + i * dsub, + x + i * pq.d + m * dsub, + sizeof(float) * dsub); + + codes.resize (n); + for (int i = 0; i < n; i++) + codes [i] = all_codes [i * pq.code_size + m]; + + nq = n / 4; nb = n - nq; + const float *xq = xtrain.data(); + const float *xb = xq + nq * dsub; + + gt_distances.resize (nq * nb); + + pairwise_L2sqr (dsub, + nq, xq, + nb, xb, + gt_distances.data()); + } else { + nq = nb = pq.ksub; + codes.resize (2 * nq); + for (int i = 0; i < nq; i++) + codes[i] = codes [i + nq] = i; + + gt_distances.resize (nq * nb); + + memcpy (gt_distances.data (), + pq.sdc_table.data () + m * nq * nb, + sizeof (float) * nq * nb); + } + + double t0 = getmillisecs (); + + PermutationObjective *obj = new RankingScore2 ( + nbits, nq, nb, + codes.data(), codes.data() + nq, + gt_distances.data ()); + ScopeDeleter1 del (obj); + + if (verbose > 0) { + printf(" m=%d, nq=%ld, nb=%ld, intialize RankingScore " + "in %.3f ms\n", + m, nq, nb, getmillisecs () - t0); + } + + SimulatedAnnealingOptimizer optim (obj, *this); + + if (log_pattern.size()) { + char fname[256]; + snprintf (fname, 256, log_pattern.c_str(), m); + printf ("opening log file %s\n", fname); + optim.logfile = fopen (fname, "w"); + FAISS_THROW_IF_NOT_FMT (optim.logfile, + "could not open logfile %s", fname); + } + + std::vector perm (pq.ksub); + + double final_cost = optim.run_optimization (perm.data()); + printf ("SimulatedAnnealingOptimizer for m=%d: %g -> %g\n", + m, optim.init_cost, final_cost); + + if (log_pattern.size()) fclose (optim.logfile); + + float * centroids = pq.get_centroids (m, 0); + + std::vector centroids_copy; + for (int i = 0; i < dsub * pq.ksub; i++) + centroids_copy.push_back (centroids[i]); + + for (int i = 0; i < pq.ksub; i++) + memcpy (centroids + perm[i] * dsub, + centroids_copy.data() + i * dsub, + dsub * sizeof(centroids[0])); + + } + +} + + + +void PolysemousTraining::optimize_pq_for_hamming (ProductQuantizer &pq, + size_t n, const float *x) const +{ + if (optimization_type == OT_None) { + + } else if (optimization_type == OT_ReproduceDistances_affine) { + optimize_reproduce_distances (pq); + } else { + optimize_ranking (pq, n, x); + } + + pq.compute_sdc_table (); + +} + + +} // namespace faiss diff --git a/core/src/index/thirdparty/faiss/impl/PolysemousTraining.h b/core/src/index/thirdparty/faiss/impl/PolysemousTraining.h new file mode 100644 index 0000000000..c27a48c999 --- /dev/null +++ b/core/src/index/thirdparty/faiss/impl/PolysemousTraining.h @@ -0,0 +1,158 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +#ifndef FAISS_POLYSEMOUS_TRAINING_INCLUDED +#define FAISS_POLYSEMOUS_TRAINING_INCLUDED + + +#include + + +namespace faiss { + + +/// parameters used for the simulated annealing method +struct SimulatedAnnealingParameters { + + // optimization parameters + double init_temperature; // init probaility of accepting a bad swap + double temperature_decay; // at each iteration the temp is multiplied by this + int n_iter; // nb of iterations + int n_redo; // nb of runs of the simulation + int seed; // random seed + int verbose; + bool only_bit_flips; // restrict permutation changes to bit flips + bool init_random; // intialize with a random permutation (not identity) + + // set reasonable defaults + SimulatedAnnealingParameters (); + +}; + + +/// abstract class for the loss function +struct PermutationObjective { + + int n; + + virtual double compute_cost (const int *perm) const = 0; + + // what would the cost update be if iw and jw were swapped? + // default implementation just computes both and computes the difference + virtual double cost_update (const int *perm, int iw, int jw) const; + + virtual ~PermutationObjective () {} +}; + + +struct ReproduceDistancesObjective : PermutationObjective { + + double dis_weight_factor; + + static double sqr (double x) { return x * x; } + + // weihgting of distances: it is more important to reproduce small + // distances well + double dis_weight (double x) const; + + std::vector source_dis; ///< "real" corrected distances (size n^2) + const double * target_dis; ///< wanted distances (size n^2) + std::vector weights; ///< weights for each distance (size n^2) + + double get_source_dis (int i, int j) const; + + // cost = quadratic difference between actual distance and Hamming distance + double compute_cost(const int* perm) const override; + + // what would the cost update be if iw and jw were swapped? + // computed in O(n) instead of O(n^2) for the full re-computation + double cost_update(const int* perm, int iw, int jw) const override; + + ReproduceDistancesObjective ( + int n, + const double *source_dis_in, + const double *target_dis_in, + double dis_weight_factor); + + static void compute_mean_stdev (const double *tab, size_t n2, + double *mean_out, double *stddev_out); + + void set_affine_target_dis (const double *source_dis_in); + + ~ReproduceDistancesObjective() override {} +}; + +struct RandomGenerator; + +/// Simulated annealing optimization algorithm for permutations. + struct SimulatedAnnealingOptimizer: SimulatedAnnealingParameters { + + PermutationObjective *obj; + int n; ///< size of the permutation + FILE *logfile; /// logs values of the cost function + + SimulatedAnnealingOptimizer (PermutationObjective *obj, + const SimulatedAnnealingParameters &p); + RandomGenerator *rnd; + + /// remember intial cost of optimization + double init_cost; + + // main entry point. Perform the optimization loop, starting from + // and modifying permutation in-place + double optimize (int *perm); + + // run the optimization and return the best result in best_perm + double run_optimization (int * best_perm); + + virtual ~SimulatedAnnealingOptimizer (); +}; + + + + +/// optimizes the order of indices in a ProductQuantizer +struct PolysemousTraining: SimulatedAnnealingParameters { + + enum Optimization_type_t { + OT_None, + OT_ReproduceDistances_affine, ///< default + OT_Ranking_weighted_diff ///< same as _2, but use rank of y+ - rank of y- + }; + Optimization_type_t optimization_type; + + /** use 1/4 of the training points for the optimization, with + * max. ntrain_permutation. If ntrain_permutation == 0: train on + * centroids */ + int ntrain_permutation; + double dis_weight_factor; ///< decay of exp that weights distance loss + + // filename pattern for the logging of iterations + std::string log_pattern; + + // sets default values + PolysemousTraining (); + + /// reorder the centroids so that the Hamming distace becomes a + /// good approximation of the SDC distance (called by train) + void optimize_pq_for_hamming (ProductQuantizer & pq, + size_t n, const float *x) const; + + /// called by optimize_pq_for_hamming + void optimize_ranking (ProductQuantizer &pq, size_t n, const float *x) const; + /// called by optimize_pq_for_hamming + void optimize_reproduce_distances (ProductQuantizer &pq) const; + +}; + + +} // namespace faiss + + +#endif diff --git a/core/src/index/thirdparty/faiss/impl/ProductQuantizer-inl.h b/core/src/index/thirdparty/faiss/impl/ProductQuantizer-inl.h new file mode 100644 index 0000000000..01937dca9f --- /dev/null +++ b/core/src/index/thirdparty/faiss/impl/ProductQuantizer-inl.h @@ -0,0 +1,138 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +namespace faiss { + +inline +PQEncoderGeneric::PQEncoderGeneric(uint8_t *code, int nbits, + uint8_t offset) + : code(code), offset(offset), nbits(nbits), reg(0) +{ + assert(nbits <= 64); + if (offset > 0) { + reg = (*code & ((1 << offset) - 1)); + } +} + +inline +void PQEncoderGeneric::encode(uint64_t x) +{ + reg |= (uint8_t)(x << offset); + x >>= (8 - offset); + if (offset + nbits >= 8) { + *code++ = reg; + + for (int i = 0; i < (nbits - (8 - offset)) / 8; ++i) { + *code++ = (uint8_t)x; + x >>= 8; + } + + offset += nbits; + offset &= 7; + reg = (uint8_t)x; + } else { + offset += nbits; + } +} + +inline +PQEncoderGeneric::~PQEncoderGeneric() +{ + if (offset > 0) { + *code = reg; + } +} + + +inline +PQEncoder8::PQEncoder8(uint8_t *code, int nbits) + : code(code) { + assert(8 == nbits); +} + +inline +void PQEncoder8::encode(uint64_t x) { + *code++ = (uint8_t)x; +} + +inline +PQEncoder16::PQEncoder16(uint8_t *code, int nbits) + : code((uint16_t *)code) { + assert(16 == nbits); +} + +inline +void PQEncoder16::encode(uint64_t x) { + *code++ = (uint16_t)x; +} + + +inline +PQDecoderGeneric::PQDecoderGeneric(const uint8_t *code, + int nbits) + : code(code), + offset(0), + nbits(nbits), + mask((1ull << nbits) - 1), + reg(0) { + assert(nbits <= 64); +} + +inline +uint64_t PQDecoderGeneric::decode() { + if (offset == 0) { + reg = *code; + } + uint64_t c = (reg >> offset); + + if (offset + nbits >= 8) { + uint64_t e = 8 - offset; + ++code; + for (int i = 0; i < (nbits - (8 - offset)) / 8; ++i) { + c |= ((uint64_t)(*code++) << e); + e += 8; + } + + offset += nbits; + offset &= 7; + if (offset > 0) { + reg = *code; + c |= ((uint64_t)reg << e); + } + } else { + offset += nbits; + } + + return c & mask; +} + + +inline +PQDecoder8::PQDecoder8(const uint8_t *code, int nbits) + : code(code) { + assert(8 == nbits); +} + +inline +uint64_t PQDecoder8::decode() { + return (uint64_t)(*code++); +} + + +inline +PQDecoder16::PQDecoder16(const uint8_t *code, int nbits) + : code((uint16_t *)code) { + assert(16 == nbits); +} + +inline +uint64_t PQDecoder16::decode() { + return (uint64_t)(*code++); +} + + +} // namespace faiss diff --git a/core/src/index/thirdparty/faiss/impl/ProductQuantizer.cpp b/core/src/index/thirdparty/faiss/impl/ProductQuantizer.cpp new file mode 100644 index 0000000000..a9658af46a --- /dev/null +++ b/core/src/index/thirdparty/faiss/impl/ProductQuantizer.cpp @@ -0,0 +1,759 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +#include + + +#include +#include +#include +#include + +#include + +#include +#include +#include +#include + + +extern "C" { + +/* declare BLAS functions, see http://www.netlib.org/clapack/cblas/ */ + +int sgemm_ (const char *transa, const char *transb, FINTEGER *m, FINTEGER * + n, FINTEGER *k, const float *alpha, const float *a, + FINTEGER *lda, const float *b, FINTEGER * + ldb, float *beta, float *c, FINTEGER *ldc); + +} + + +namespace faiss { + + +/* compute an estimator using look-up tables for typical values of M */ +template +void pq_estimators_from_tables_Mmul4 (int M, const CT * codes, + size_t ncodes, + const float * __restrict dis_table, + size_t ksub, + size_t k, + float * heap_dis, + int64_t * heap_ids) +{ + + for (size_t j = 0; j < ncodes; j++) { + float dis = 0; + const float *dt = dis_table; + + for (size_t m = 0; m < M; m+=4) { + float dism = 0; + dism = dt[*codes++]; dt += ksub; + dism += dt[*codes++]; dt += ksub; + dism += dt[*codes++]; dt += ksub; + dism += dt[*codes++]; dt += ksub; + dis += dism; + } + + if (C::cmp (heap_dis[0], dis)) { + heap_swap_top (k, heap_dis, heap_ids, dis, j); + } + } +} + + +template +void pq_estimators_from_tables_M4 (const CT * codes, + size_t ncodes, + const float * __restrict dis_table, + size_t ksub, + size_t k, + float * heap_dis, + int64_t * heap_ids) +{ + + for (size_t j = 0; j < ncodes; j++) { + float dis = 0; + const float *dt = dis_table; + dis = dt[*codes++]; dt += ksub; + dis += dt[*codes++]; dt += ksub; + dis += dt[*codes++]; dt += ksub; + dis += dt[*codes++]; + + if (C::cmp (heap_dis[0], dis)) { + heap_swap_top (k, heap_dis, heap_ids, dis, j); + } + } +} + + +template +static inline void pq_estimators_from_tables (const ProductQuantizer& pq, + const CT * codes, + size_t ncodes, + const float * dis_table, + size_t k, + float * heap_dis, + int64_t * heap_ids) +{ + + if (pq.M == 4) { + + pq_estimators_from_tables_M4 (codes, ncodes, + dis_table, pq.ksub, k, + heap_dis, heap_ids); + return; + } + + if (pq.M % 4 == 0) { + pq_estimators_from_tables_Mmul4 (pq.M, codes, ncodes, + dis_table, pq.ksub, k, + heap_dis, heap_ids); + return; + } + + /* Default is relatively slow */ + const size_t M = pq.M; + const size_t ksub = pq.ksub; + for (size_t j = 0; j < ncodes; j++) { + float dis = 0; + const float * __restrict dt = dis_table; + for (int m = 0; m < M; m++) { + dis += dt[*codes++]; + dt += ksub; + } + if (C::cmp (heap_dis[0], dis)) { + heap_swap_top (k, heap_dis, heap_ids, dis, j); + } + } +} + +template +static inline void pq_estimators_from_tables_generic(const ProductQuantizer& pq, + size_t nbits, + const uint8_t *codes, + size_t ncodes, + const float *dis_table, + size_t k, + float *heap_dis, + int64_t *heap_ids) +{ + const size_t M = pq.M; + const size_t ksub = pq.ksub; + for (size_t j = 0; j < ncodes; ++j) { + PQDecoderGeneric decoder( + codes + j * pq.code_size, nbits + ); + float dis = 0; + const float * __restrict dt = dis_table; + for (size_t m = 0; m < M; m++) { + uint64_t c = decoder.decode(); + dis += dt[c]; + dt += ksub; + } + + if (C::cmp(heap_dis[0], dis)) { + heap_swap_top(k, heap_dis, heap_ids, dis, j); + } + } +} + +/********************************************* + * PQ implementation + *********************************************/ + + + +ProductQuantizer::ProductQuantizer (size_t d, size_t M, size_t nbits): + d(d), M(M), nbits(nbits), assign_index(nullptr) +{ + set_derived_values (); +} + +ProductQuantizer::ProductQuantizer () + : ProductQuantizer(0, 1, 0) {} + +void ProductQuantizer::set_derived_values () { + // quite a few derived values + FAISS_THROW_IF_NOT (d % M == 0); + dsub = d / M; + code_size = (nbits * M + 7) / 8; + ksub = 1 << nbits; + centroids.resize (d * ksub); + verbose = false; + train_type = Train_default; +} + +void ProductQuantizer::set_params (const float * centroids_, int m) +{ + memcpy (get_centroids(m, 0), centroids_, + ksub * dsub * sizeof (centroids_[0])); +} + + +static void init_hypercube (int d, int nbits, + int n, const float * x, + float *centroids) +{ + + std::vector mean (d); + for (int i = 0; i < n; i++) + for (int j = 0; j < d; j++) + mean [j] += x[i * d + j]; + + float maxm = 0; + for (int j = 0; j < d; j++) { + mean [j] /= n; + if (fabs(mean[j]) > maxm) maxm = fabs(mean[j]); + } + + for (int i = 0; i < (1 << nbits); i++) { + float * cent = centroids + i * d; + for (int j = 0; j < nbits; j++) + cent[j] = mean [j] + (((i >> j) & 1) ? 1 : -1) * maxm; + for (int j = nbits; j < d; j++) + cent[j] = mean [j]; + } + + +} + +static void init_hypercube_pca (int d, int nbits, + int n, const float * x, + float *centroids) +{ + PCAMatrix pca (d, nbits); + pca.train (n, x); + + + for (int i = 0; i < (1 << nbits); i++) { + float * cent = centroids + i * d; + for (int j = 0; j < d; j++) { + cent[j] = pca.mean[j]; + float f = 1.0; + for (int k = 0; k < nbits; k++) + cent[j] += f * + sqrt (pca.eigenvalues [k]) * + (((i >> k) & 1) ? 1 : -1) * + pca.PCAMat [j + k * d]; + } + } + +} + +void ProductQuantizer::train (int n, const float * x) +{ + if (train_type != Train_shared) { + train_type_t final_train_type; + final_train_type = train_type; + if (train_type == Train_hypercube || + train_type == Train_hypercube_pca) { + if (dsub < nbits) { + final_train_type = Train_default; + printf ("cannot train hypercube: nbits=%ld > log2(d=%ld)\n", + nbits, dsub); + } + } + + float * xslice = new float[n * dsub]; + ScopeDeleter del (xslice); + for (int m = 0; m < M; m++) { + for (int j = 0; j < n; j++) + memcpy (xslice + j * dsub, + x + j * d + m * dsub, + dsub * sizeof(float)); + + Clustering clus (dsub, ksub, cp); + + // we have some initialization for the centroids + if (final_train_type != Train_default) { + clus.centroids.resize (dsub * ksub); + } + + switch (final_train_type) { + case Train_hypercube: + init_hypercube (dsub, nbits, n, xslice, + clus.centroids.data ()); + break; + case Train_hypercube_pca: + init_hypercube_pca (dsub, nbits, n, xslice, + clus.centroids.data ()); + break; + case Train_hot_start: + memcpy (clus.centroids.data(), + get_centroids (m, 0), + dsub * ksub * sizeof (float)); + break; + default: ; + } + + if(verbose) { + clus.verbose = true; + printf ("Training PQ slice %d/%zd\n", m, M); + } + IndexFlatL2 index (dsub); + clus.train (n, xslice, assign_index ? *assign_index : index); + set_params (clus.centroids.data(), m); + } + + + } else { + + Clustering clus (dsub, ksub, cp); + + if(verbose) { + clus.verbose = true; + printf ("Training all PQ slices at once\n"); + } + + IndexFlatL2 index (dsub); + + clus.train (n * M, x, assign_index ? *assign_index : index); + for (int m = 0; m < M; m++) { + set_params (clus.centroids.data(), m); + } + + } +} + +template +void compute_code(const ProductQuantizer& pq, const float *x, uint8_t *code) { + float distances [pq.ksub]; + PQEncoder encoder(code, pq.nbits); + for (size_t m = 0; m < pq.M; m++) { + float mindis = 1e20; + uint64_t idxm = 0; + const float * xsub = x + m * pq.dsub; + + fvec_L2sqr_ny(distances, xsub, pq.get_centroids(m, 0), pq.dsub, pq.ksub); + + /* Find best centroid */ + for (size_t i = 0; i < pq.ksub; i++) { + float dis = distances[i]; + if (dis < mindis) { + mindis = dis; + idxm = i; + } + } + + encoder.encode(idxm); + } +} + +void ProductQuantizer::compute_code(const float * x, uint8_t * code) const { + switch (nbits) { + case 8: + faiss::compute_code(*this, x, code); + break; + + case 16: + faiss::compute_code(*this, x, code); + break; + + default: + faiss::compute_code(*this, x, code); + break; + } +} + +template +void decode(const ProductQuantizer& pq, const uint8_t *code, float *x) +{ + PQDecoder decoder(code, pq.nbits); + for (size_t m = 0; m < pq.M; m++) { + uint64_t c = decoder.decode(); + memcpy(x + m * pq.dsub, pq.get_centroids(m, c), sizeof(float) * pq.dsub); + } +} + +void ProductQuantizer::decode (const uint8_t *code, float *x) const +{ + switch (nbits) { + case 8: + faiss::decode(*this, code, x); + break; + + case 16: + faiss::decode(*this, code, x); + break; + + default: + faiss::decode(*this, code, x); + break; + } +} + + +void ProductQuantizer::decode (const uint8_t *code, float *x, size_t n) const +{ + for (size_t i = 0; i < n; i++) { + this->decode (code + code_size * i, x + d * i); + } +} + + +void ProductQuantizer::compute_code_from_distance_table (const float *tab, + uint8_t *code) const +{ + PQEncoderGeneric encoder(code, nbits); + for (size_t m = 0; m < M; m++) { + float mindis = 1e20; + uint64_t idxm = 0; + + /* Find best centroid */ + for (size_t j = 0; j < ksub; j++) { + float dis = *tab++; + if (dis < mindis) { + mindis = dis; + idxm = j; + } + } + + encoder.encode(idxm); + } +} + +void ProductQuantizer::compute_codes_with_assign_index ( + const float * x, + uint8_t * codes, + size_t n) +{ + FAISS_THROW_IF_NOT (assign_index && assign_index->d == dsub); + + for (size_t m = 0; m < M; m++) { + assign_index->reset (); + assign_index->add (ksub, get_centroids (m, 0)); + size_t bs = 65536; + float * xslice = new float[bs * dsub]; + ScopeDeleter del (xslice); + idx_t *assign = new idx_t[bs]; + ScopeDeleter del2 (assign); + + for (size_t i0 = 0; i0 < n; i0 += bs) { + size_t i1 = std::min(i0 + bs, n); + + for (size_t i = i0; i < i1; i++) { + memcpy (xslice + (i - i0) * dsub, + x + i * d + m * dsub, + dsub * sizeof(float)); + } + + assign_index->assign (i1 - i0, xslice, assign); + + if (nbits == 8) { + uint8_t *c = codes + code_size * i0 + m; + for (size_t i = i0; i < i1; i++) { + *c = assign[i - i0]; + c += M; + } + } else if (nbits == 16) { + uint16_t *c = (uint16_t*)(codes + code_size * i0 + m * 2); + for (size_t i = i0; i < i1; i++) { + *c = assign[i - i0]; + c += M; + } + } else { + for (size_t i = i0; i < i1; ++i) { + uint8_t *c = codes + code_size * i + ((m * nbits) / 8); + uint8_t offset = (m * nbits) % 8; + uint64_t ass = assign[i - i0]; + + PQEncoderGeneric encoder(c, nbits, offset); + encoder.encode(ass); + } + } + + } + } + +} + +void ProductQuantizer::compute_codes (const float * x, + uint8_t * codes, + size_t n) const +{ + // process by blocks to avoid using too much RAM + size_t bs = 256 * 1024; + if (n > bs) { + for (size_t i0 = 0; i0 < n; i0 += bs) { + size_t i1 = std::min(i0 + bs, n); + compute_codes (x + d * i0, codes + code_size * i0, i1 - i0); + } + return; + } + + if (dsub < 16) { // simple direct computation + +#pragma omp parallel for + for (size_t i = 0; i < n; i++) + compute_code (x + i * d, codes + i * code_size); + + } else { // worthwile to use BLAS + float *dis_tables = new float [n * ksub * M]; + ScopeDeleter del (dis_tables); + compute_distance_tables (n, x, dis_tables); + +#pragma omp parallel for + for (size_t i = 0; i < n; i++) { + uint8_t * code = codes + i * code_size; + const float * tab = dis_tables + i * ksub * M; + compute_code_from_distance_table (tab, code); + } + } +} + + +void ProductQuantizer::compute_distance_table (const float * x, + float * dis_table) const +{ + size_t m; + + for (m = 0; m < M; m++) { + fvec_L2sqr_ny (dis_table + m * ksub, + x + m * dsub, + get_centroids(m, 0), + dsub, + ksub); + } +} + +void ProductQuantizer::compute_inner_prod_table (const float * x, + float * dis_table) const +{ + size_t m; + + for (m = 0; m < M; m++) { + fvec_inner_products_ny (dis_table + m * ksub, + x + m * dsub, + get_centroids(m, 0), + dsub, + ksub); + } +} + + +void ProductQuantizer::compute_distance_tables ( + size_t nx, + const float * x, + float * dis_tables) const +{ + + if (dsub < 16) { + +#pragma omp parallel for + for (size_t i = 0; i < nx; i++) { + compute_distance_table (x + i * d, dis_tables + i * ksub * M); + } + + } else { // use BLAS + + for (int m = 0; m < M; m++) { + pairwise_L2sqr (dsub, + nx, x + dsub * m, + ksub, centroids.data() + m * dsub * ksub, + dis_tables + ksub * m, + d, dsub, ksub * M); + } + } +} + +void ProductQuantizer::compute_inner_prod_tables ( + size_t nx, + const float * x, + float * dis_tables) const +{ + + if (dsub < 16) { + +#pragma omp parallel for + for (size_t i = 0; i < nx; i++) { + compute_inner_prod_table (x + i * d, dis_tables + i * ksub * M); + } + + } else { // use BLAS + + // compute distance tables + for (int m = 0; m < M; m++) { + FINTEGER ldc = ksub * M, nxi = nx, ksubi = ksub, + dsubi = dsub, di = d; + float one = 1.0, zero = 0; + + sgemm_ ("Transposed", "Not transposed", + &ksubi, &nxi, &dsubi, + &one, ¢roids [m * dsub * ksub], &dsubi, + x + dsub * m, &di, + &zero, dis_tables + ksub * m, &ldc); + } + + } +} + +template +static void pq_knn_search_with_tables ( + const ProductQuantizer& pq, + size_t nbits, + const float *dis_tables, + const uint8_t * codes, + const size_t ncodes, + HeapArray * res, + bool init_finalize_heap) +{ + size_t k = res->k, nx = res->nh; + size_t ksub = pq.ksub, M = pq.M; + + +#pragma omp parallel for + for (size_t i = 0; i < nx; i++) { + /* query preparation for asymmetric search: compute look-up tables */ + const float* dis_table = dis_tables + i * ksub * M; + + /* Compute distances and keep smallest values */ + int64_t * __restrict heap_ids = res->ids + i * k; + float * __restrict heap_dis = res->val + i * k; + + if (init_finalize_heap) { + heap_heapify (k, heap_dis, heap_ids); + } + + switch (nbits) { + case 8: + pq_estimators_from_tables (pq, + codes, ncodes, + dis_table, + k, heap_dis, heap_ids); + break; + + case 16: + pq_estimators_from_tables (pq, + (uint16_t*)codes, ncodes, + dis_table, + k, heap_dis, heap_ids); + break; + + default: + pq_estimators_from_tables_generic (pq, + nbits, + codes, ncodes, + dis_table, + k, heap_dis, heap_ids); + break; + } + + if (init_finalize_heap) { + heap_reorder (k, heap_dis, heap_ids); + } + } +} + +void ProductQuantizer::search (const float * __restrict x, + size_t nx, + const uint8_t * codes, + const size_t ncodes, + float_maxheap_array_t * res, + bool init_finalize_heap) const +{ + FAISS_THROW_IF_NOT (nx == res->nh); + std::unique_ptr dis_tables(new float [nx * ksub * M]); + compute_distance_tables (nx, x, dis_tables.get()); + + pq_knn_search_with_tables> ( + *this, nbits, dis_tables.get(), codes, ncodes, res, init_finalize_heap); +} + +void ProductQuantizer::search_ip (const float * __restrict x, + size_t nx, + const uint8_t * codes, + const size_t ncodes, + float_minheap_array_t * res, + bool init_finalize_heap) const +{ + FAISS_THROW_IF_NOT (nx == res->nh); + std::unique_ptr dis_tables(new float [nx * ksub * M]); + compute_inner_prod_tables (nx, x, dis_tables.get()); + + pq_knn_search_with_tables > ( + *this, nbits, dis_tables.get(), codes, ncodes, res, init_finalize_heap); +} + + + +static float sqr (float x) { + return x * x; +} + +void ProductQuantizer::compute_sdc_table () +{ + sdc_table.resize (M * ksub * ksub); + + for (int m = 0; m < M; m++) { + + const float *cents = centroids.data() + m * ksub * dsub; + float * dis_tab = sdc_table.data() + m * ksub * ksub; + + // TODO optimize with BLAS + for (int i = 0; i < ksub; i++) { + const float *centi = cents + i * dsub; + for (int j = 0; j < ksub; j++) { + float accu = 0; + const float *centj = cents + j * dsub; + for (int k = 0; k < dsub; k++) + accu += sqr (centi[k] - centj[k]); + dis_tab [i + j * ksub] = accu; + } + } + } +} + +void ProductQuantizer::search_sdc (const uint8_t * qcodes, + size_t nq, + const uint8_t * bcodes, + const size_t nb, + float_maxheap_array_t * res, + bool init_finalize_heap) const +{ + FAISS_THROW_IF_NOT (sdc_table.size() == M * ksub * ksub); + FAISS_THROW_IF_NOT (nbits == 8); + size_t k = res->k; + + +#pragma omp parallel for + for (size_t i = 0; i < nq; i++) { + + /* Compute distances and keep smallest values */ + idx_t * heap_ids = res->ids + i * k; + float * heap_dis = res->val + i * k; + const uint8_t * qcode = qcodes + i * code_size; + + if (init_finalize_heap) + maxheap_heapify (k, heap_dis, heap_ids); + + const uint8_t * bcode = bcodes; + for (size_t j = 0; j < nb; j++) { + float dis = 0; + const float * tab = sdc_table.data(); + for (int m = 0; m < M; m++) { + dis += tab[bcode[m] + qcode[m] * ksub]; + tab += ksub * ksub; + } + if (dis < heap_dis[0]) { + maxheap_swap_top (k, heap_dis, heap_ids, dis, j); + } + bcode += code_size; + } + + if (init_finalize_heap) + maxheap_reorder (k, heap_dis, heap_ids); + } + +} + + + +} // namespace faiss diff --git a/core/src/index/thirdparty/faiss/impl/ProductQuantizer.h b/core/src/index/thirdparty/faiss/impl/ProductQuantizer.h new file mode 100644 index 0000000000..6364be4eae --- /dev/null +++ b/core/src/index/thirdparty/faiss/impl/ProductQuantizer.h @@ -0,0 +1,238 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +#ifndef FAISS_PRODUCT_QUANTIZER_H +#define FAISS_PRODUCT_QUANTIZER_H + +#include + +#include + +#include +#include + +namespace faiss { + +/** Product Quantizer. Implemented only for METRIC_L2 */ +struct ProductQuantizer { + + using idx_t = Index::idx_t; + + size_t d; ///< size of the input vectors + size_t M; ///< number of subquantizers + size_t nbits; ///< number of bits per quantization index + + // values derived from the above + size_t dsub; ///< dimensionality of each subvector + size_t code_size; ///< bytes per indexed vector + size_t ksub; ///< number of centroids for each subquantizer + bool verbose; ///< verbose during training? + + /// initialization + enum train_type_t { + Train_default, + Train_hot_start, ///< the centroids are already initialized + Train_shared, ///< share dictionary accross PQ segments + Train_hypercube, ///< intialize centroids with nbits-D hypercube + Train_hypercube_pca, ///< intialize centroids with nbits-D hypercube + }; + train_type_t train_type; + + ClusteringParameters cp; ///< parameters used during clustering + + /// if non-NULL, use this index for assignment (should be of size + /// d / M) + Index *assign_index; + + /// Centroid table, size M * ksub * dsub + std::vector centroids; + + /// return the centroids associated with subvector m + float * get_centroids (size_t m, size_t i) { + return ¢roids [(m * ksub + i) * dsub]; + } + const float * get_centroids (size_t m, size_t i) const { + return ¢roids [(m * ksub + i) * dsub]; + } + + // Train the product quantizer on a set of points. A clustering + // can be set on input to define non-default clustering parameters + void train (int n, const float *x); + + ProductQuantizer(size_t d, /* dimensionality of the input vectors */ + size_t M, /* number of subquantizers */ + size_t nbits); /* number of bit per subvector index */ + + ProductQuantizer (); + + /// compute derived values when d, M and nbits have been set + void set_derived_values (); + + /// Define the centroids for subquantizer m + void set_params (const float * centroids, int m); + + /// Quantize one vector with the product quantizer + void compute_code (const float * x, uint8_t * code) const ; + + /// same as compute_code for several vectors + void compute_codes (const float * x, + uint8_t * codes, + size_t n) const ; + + /// speed up code assignment using assign_index + /// (non-const because the index is changed) + void compute_codes_with_assign_index ( + const float * x, + uint8_t * codes, + size_t n); + + /// decode a vector from a given code (or n vectors if third argument) + void decode (const uint8_t *code, float *x) const; + void decode (const uint8_t *code, float *x, size_t n) const; + + /// If we happen to have the distance tables precomputed, this is + /// more efficient to compute the codes. + void compute_code_from_distance_table (const float *tab, + uint8_t *code) const; + + + /** Compute distance table for one vector. + * + * The distance table for x = [x_0 x_1 .. x_(M-1)] is a M * ksub + * matrix that contains + * + * dis_table (m, j) = || x_m - c_(m, j)||^2 + * for m = 0..M-1 and j = 0 .. ksub - 1 + * + * where c_(m, j) is the centroid no j of sub-quantizer m. + * + * @param x input vector size d + * @param dis_table output table, size M * ksub + */ + void compute_distance_table (const float * x, + float * dis_table) const; + + void compute_inner_prod_table (const float * x, + float * dis_table) const; + + + /** compute distance table for several vectors + * @param nx nb of input vectors + * @param x input vector size nx * d + * @param dis_table output table, size nx * M * ksub + */ + void compute_distance_tables (size_t nx, + const float * x, + float * dis_tables) const; + + void compute_inner_prod_tables (size_t nx, + const float * x, + float * dis_tables) const; + + + /** perform a search (L2 distance) + * @param x query vectors, size nx * d + * @param nx nb of queries + * @param codes database codes, size ncodes * code_size + * @param ncodes nb of nb vectors + * @param res heap array to store results (nh == nx) + * @param init_finalize_heap initialize heap (input) and sort (output)? + */ + void search (const float * x, + size_t nx, + const uint8_t * codes, + const size_t ncodes, + float_maxheap_array_t *res, + bool init_finalize_heap = true) const; + + /** same as search, but with inner product similarity */ + void search_ip (const float * x, + size_t nx, + const uint8_t * codes, + const size_t ncodes, + float_minheap_array_t *res, + bool init_finalize_heap = true) const; + + + /// Symmetric Distance Table + std::vector sdc_table; + + // intitialize the SDC table from the centroids + void compute_sdc_table (); + + void search_sdc (const uint8_t * qcodes, + size_t nq, + const uint8_t * bcodes, + const size_t ncodes, + float_maxheap_array_t * res, + bool init_finalize_heap = true) const; + + size_t cal_size() { return sizeof(*this) + centroids.size() * sizeof(float); } +}; + + +/************************************************* + * Objects to encode / decode strings of bits + *************************************************/ + +struct PQEncoderGeneric { + uint8_t *code; ///< code for this vector + uint8_t offset; + const int nbits; ///< number of bits per subquantizer index + + uint8_t reg; + + PQEncoderGeneric(uint8_t *code, int nbits, uint8_t offset = 0); + + void encode(uint64_t x); + + ~PQEncoderGeneric(); +}; + + +struct PQEncoder8 { + uint8_t *code; + PQEncoder8(uint8_t *code, int nbits); + void encode(uint64_t x); +}; + +struct PQEncoder16 { + uint16_t *code; + PQEncoder16(uint8_t *code, int nbits); + void encode(uint64_t x); +}; + + +struct PQDecoderGeneric { + const uint8_t *code; + uint8_t offset; + const int nbits; + const uint64_t mask; + uint8_t reg; + PQDecoderGeneric(const uint8_t *code, int nbits); + uint64_t decode(); +}; + +struct PQDecoder8 { + const uint8_t *code; + PQDecoder8(const uint8_t *code, int nbits); + uint64_t decode(); +}; + +struct PQDecoder16 { + const uint16_t *code; + PQDecoder16(const uint8_t *code, int nbits); + uint64_t decode(); +}; + +} // namespace faiss + +#include + +#endif diff --git a/core/src/index/thirdparty/faiss/impl/RHNSW.cpp b/core/src/index/thirdparty/faiss/impl/RHNSW.cpp new file mode 100644 index 0000000000..4928910baa --- /dev/null +++ b/core/src/index/thirdparty/faiss/impl/RHNSW.cpp @@ -0,0 +1,441 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +#include + +#include + +#include + +namespace faiss { + + +/************************************************************** + * hnsw structure implementation + **************************************************************/ + +RHNSW::RHNSW(int M) : M(M), rng(12345) { + level_generator.seed(100); + max_level = -1; + entry_point = -1; + efSearch = 16; + efConstruction = 40; + upper_beam = 1; + level0_link_size = sizeof(int) * ((M << 1) | 1); + link_size = sizeof(int) * (M + 1); + level0_links = nullptr; + linkLists = nullptr; + level_constant = 1 / log(1.0 * M); + visited_list_pool = nullptr; +} + +void RHNSW::init(int ntotal) { + level_generator.seed(100); + if (visited_list_pool) delete visited_list_pool; + visited_list_pool = new VisitedListPool(1, ntotal); + std::vector(ntotal).swap(link_list_locks); +} + +RHNSW::~RHNSW() { + free(level0_links); + for (auto i = 0; i < levels.size(); ++ i) { + if (levels[i]) + free(linkLists[i]); + } + free(linkLists); + delete visited_list_pool; +} + +void RHNSW::reset() { + max_level = -1; + entry_point = -1; + levels.clear(); + free(level0_links); + for (auto i = 0; i < levels.size(); ++ i) { + if (levels[i]) + free(linkLists[i]); + } + free(linkLists); + level0_links = nullptr; + linkLists = nullptr; + level_constant = 1 / log(1.0 * M); +} + +int RHNSW::prepare_level_tab(size_t n, bool preset_levels) +{ + size_t n0 = levels.size(); + + std::vector level_stats(n); + if (preset_levels) { + FAISS_ASSERT (n0 + n == levels.size()); + } else { + FAISS_ASSERT (n0 == levels.size()); + for (int i = 0; i < n; i++) { + int pt_level = random_level(level_constant); + levels.push_back(pt_level); + } + } + + char *level0_links_new = (char*)malloc((n0 + n) * level0_link_size); + if (level0_links_new == nullptr) { + throw std::runtime_error("No enough memory 4 level0_links!"); + } + memset(level0_links_new, 0, (n0 + n) * level0_link_size); + if (level0_links) { + memcpy(level0_links_new, level0_links, n0 * level0_link_size); + free(level0_links); + } + level0_links = level0_links_new; + + char **linkLists_new = (char **)malloc(sizeof(void*) * (n0 + n)); + if (linkLists_new == nullptr) { + throw std::runtime_error("No enough memory 4 level0_links_new!"); + } + if (linkLists) { + memcpy(linkLists_new, linkLists, n0 * sizeof(void*)); + free(linkLists); + } + linkLists = linkLists_new; + + int max_level = 0; + int debug_space = 0; + for (int i = 0; i < n; i++) { + int pt_level = levels[i + n0]; + if (pt_level > max_level) max_level = pt_level; + if (pt_level) { + linkLists[n0 + i] = (char*) malloc(link_size * pt_level + 1); + if (linkLists[n0 + i] == nullptr) { + throw std::runtime_error("No enough memory 4 linkLists!"); + } + memset(linkLists[n0 + i], 0, link_size * pt_level + 1); + } + if (max_level >= level_stats.size()) { + level_stats.resize(max_level + 1); + } + level_stats[pt_level] ++; + } + +// printf("level stats:\n"); +// for (int i = 0; i <= max_level; ++ i) +// printf("level %d: %d points\n", i, level_stats[i]); +// printf("\n"); + std::vector(n0 + n).swap(link_list_locks); + if (visited_list_pool) delete visited_list_pool; + visited_list_pool = new VisitedListPool(1, n0 + n); + + return max_level; +} + + +/************************************************************** + * new implementation of hnsw ispired by hnswlib + * by cmli@zilliz July 30, 2020 + **************************************************************/ +using Node = faiss::RHNSW::Node; +using CompareByFirst = faiss::RHNSW::CompareByFirst; +void RHNSW::addPoint(DistanceComputer& ptdis, int pt_level, int pt_id) { + + std::unique_lock lock_el(link_list_locks[pt_id]); + std::unique_lock temp_lock(global); + int maxlevel_copy = max_level; + if (pt_level <= maxlevel_copy) + temp_lock.unlock(); + int currObj = entry_point; + int ep_copy = entry_point; + + if (currObj != -1) { + if (pt_level < maxlevel_copy) { + float curdist = ptdis(currObj); + for (int lev = maxlevel_copy; lev > pt_level; lev --) { + bool changed = true; + while (changed) { + changed = false; + std::unique_lock lk(link_list_locks[currObj]); + int *curObj_link = get_neighbor_link(currObj, lev); + auto curObj_nei_num = get_neighbors_num(curObj_link); + for (auto i = 1; i <= curObj_nei_num; ++ i) { + int cand = curObj_link[i]; + if (cand < 0 || cand > levels.size()) + throw std::runtime_error("cand error when addPoint"); + float d = ptdis(cand); + if (d < curdist) { + curdist = d; + currObj = cand; + changed = true; + } + } + } + } + } + + for (int lev = std::min(pt_level, maxlevel_copy); lev >= 0; -- lev) { + if (lev > maxlevel_copy || lev < 0) + throw std::runtime_error("Level error"); + + std::priority_queue, CompareByFirst> top_candidates = search_layer(ptdis, pt_id, currObj, lev); + currObj = top_candidates.top().second; + make_connection(ptdis, pt_id, top_candidates, lev); + } + } else { + entry_point = 0; + max_level = pt_level; + } + + if (pt_level > maxlevel_copy) { + entry_point = pt_id; + max_level = pt_level; + } + +} + +std::priority_queue, CompareByFirst> +RHNSW::search_layer(DistanceComputer& ptdis, + storage_idx_t pt_id, + storage_idx_t nearest, + int level) { + VisitedList *vl = visited_list_pool->getFreeVisitedList(); + vl_type *visited_array = vl->mass; + vl_type visited_array_tag = vl->curV; + + std::priority_queue, CompareByFirst> top_candidates; + std::priority_queue, CompareByFirst> candidate_set; + + float d_nearest = ptdis(nearest); + float lb = d_nearest; + top_candidates.emplace(d_nearest, nearest); + candidate_set.emplace(-d_nearest, nearest); + visited_array[nearest] = visited_array_tag; + + while (!candidate_set.empty()) { + Node currNode = candidate_set.top(); + if ((-currNode.first) > lb) + break; + candidate_set.pop(); + int cur_id = currNode.second; + std::unique_lock lk(link_list_locks[cur_id]); + int *cur_link = get_neighbor_link(cur_id, level); + auto cur_neighbor_num = get_neighbors_num(cur_link); + + for (auto i = 1; i <= cur_neighbor_num; ++ i) { + int candidate_id = cur_link[i]; + if (visited_array[candidate_id] == visited_array_tag) continue; + visited_array[candidate_id] = visited_array_tag; + float dcand = ptdis(candidate_id); + if (top_candidates.size() < efConstruction || lb > dcand) { + candidate_set.emplace(-dcand, candidate_id); + top_candidates.emplace(dcand, candidate_id); + if (top_candidates.size() > efConstruction) + top_candidates.pop(); + if (!top_candidates.empty()) + lb = top_candidates.top().first; + } + } + } + visited_list_pool->releaseVisitedList(vl); + return top_candidates; +} + +std::priority_queue, CompareByFirst> +RHNSW::search_base_layer(DistanceComputer& ptdis, + storage_idx_t nearest, + storage_idx_t ef, + float d_nearest, + ConcurrentBitsetPtr bitset) const { + VisitedList *vl = visited_list_pool->getFreeVisitedList(); + vl_type *visited_array = vl->mass; + vl_type visited_array_tag = vl->curV; + + std::priority_queue, CompareByFirst> top_candidates; + std::priority_queue, CompareByFirst> candidate_set; + + float lb; + if (bitset == nullptr || !bitset->test((faiss::ConcurrentBitset::id_type_t)(nearest))) { + lb = d_nearest; + top_candidates.emplace(d_nearest, nearest); + candidate_set.emplace(-d_nearest, nearest); + } else { + lb = std::numeric_limits::max(); + candidate_set.emplace(-lb, nearest); + } + visited_array[nearest] = visited_array_tag; + + while (!candidate_set.empty()) { + Node currNode = candidate_set.top(); + if ((-currNode.first) > lb) + break; + candidate_set.pop(); + int cur_id = currNode.second; + int *cur_link = get_neighbor_link(cur_id, 0); + auto cur_neighbor_num = get_neighbors_num(cur_link); + for (auto i = 1; i <= cur_neighbor_num; ++ i) { + int candidate_id = cur_link[i]; + if (visited_array[candidate_id] != visited_array_tag) { + visited_array[candidate_id] = visited_array_tag; + float dcand = ptdis(candidate_id); + if (top_candidates.size() < ef || lb > dcand) { + candidate_set.emplace(-dcand, candidate_id); + if (bitset == nullptr || !bitset->test((faiss::ConcurrentBitset::id_type_t)(candidate_id))) + top_candidates.emplace(dcand, candidate_id); + if (top_candidates.size() > ef) + top_candidates.pop(); + if (!top_candidates.empty()) + lb = top_candidates.top().first; + } + } + } + } + visited_list_pool->releaseVisitedList(vl); + return top_candidates; +} + +void +RHNSW::make_connection(DistanceComputer& ptdis, + storage_idx_t pt_id, + std::priority_queue, CompareByFirst> &cand, + int level) { + int maxM = level ? M : M << 1; + int *selectedNeighbors = (int*)malloc(sizeof(int) * maxM); + int selectedNeighborsNum = 0; + prune_neighbors(ptdis, cand, maxM, selectedNeighbors, selectedNeighborsNum); + if (selectedNeighborsNum > maxM) + throw std::runtime_error("Wrong size of candidates returned by prune_neighbors!"); + + int *cur_link = get_neighbor_link(pt_id, level); + if (*cur_link) + throw std::runtime_error("The newly inserted element should have blank link"); + + set_neighbors_num(cur_link, selectedNeighborsNum); + for (auto i = 1; i <= selectedNeighborsNum; ++ i) { + if (cur_link[i]) + throw std::runtime_error("Possible memory corruption."); + if (level > levels[selectedNeighbors[i - 1]]) + throw std::runtime_error("Trying to make a link on a non-exisitent level."); + cur_link[i] = selectedNeighbors[i - 1]; + } + + for (auto i = 0; i < selectedNeighborsNum; ++ i) { + std::unique_lock lk(link_list_locks[selectedNeighbors[i]]); + + int *selected_link = get_neighbor_link(selectedNeighbors[i], level); + auto selected_neighbor_num = get_neighbors_num(selected_link); + if (selected_neighbor_num > maxM) + throw std::runtime_error("Bad value of selected_neighbor_num."); + if (selectedNeighbors[i] == pt_id) + throw std::runtime_error("Trying to connect an element to itself."); + if (level > levels[selectedNeighbors[i]]) + throw std::runtime_error("Trying to make a link on a non-exisitent level."); + if (selected_neighbor_num < maxM) { + selected_link[selected_neighbor_num + 1] = pt_id; + set_neighbors_num(selected_link, selected_neighbor_num + 1); + } else { + double d_max = ptdis(selectedNeighbors[i]); + std::priority_queue, CompareByFirst> candi; + candi.emplace(d_max, pt_id); + for (auto j = 1; j <= selected_neighbor_num; ++ j) + candi.emplace(ptdis.symmetric_dis(selectedNeighbors[i], selected_link[j]), selected_link[j]); + int indx = 0; + prune_neighbors(ptdis, candi, maxM, selected_link + 1, indx); + set_neighbors_num(selected_link, indx); + } + } + + free(selectedNeighbors); +} + +void RHNSW::prune_neighbors(DistanceComputer& ptdis, + std::priority_queue, CompareByFirst> &cand, + const int maxM, int *ret, int &ret_len) { + if (cand.size() < maxM) { + while (!cand.empty()) { + ret[ret_len ++] = cand.top().second; + cand.pop(); + } + return; + } + std::priority_queue closest; + + while (!cand.empty()) { + closest.emplace(-cand.top().first, cand.top().second); + cand.pop(); + } + + while (closest.size()) { + if (ret_len >= maxM) + break; + Node curr = closest.top(); + float dist_to_query = -curr.first; + closest.pop(); + bool good = true; + for (auto i = 0; i < ret_len; ++ i) { + float cur_dist = ptdis.symmetric_dis(curr.second, ret[i]); + if (cur_dist < dist_to_query) { + good = false; + break; + } + } + if (good) { + ret[ret_len ++] = curr.second; + } + } +} + +void RHNSW::searchKnn(DistanceComputer& qdis, int k, + idx_t *I, float *D, + ConcurrentBitsetPtr bitset) const { + if (levels.size() == 0) + return; + int ep = entry_point; + float dist = qdis(ep); + + for (auto i = max_level; i > 0; -- i) { + bool good = true; + while (good) { + good = false; + int *ep_link = get_neighbor_link(ep, i); + auto ep_neighbors_cnt = get_neighbors_num(ep_link); + for (auto j = 1; j <= ep_neighbors_cnt; ++ j) { + int cand = ep_link[j]; + if (cand < 0 || cand > levels.size()) + throw std::runtime_error("cand error"); + float d = qdis(cand); + if (d < dist) { + dist = d; + ep = cand; + good = true; + } + } + } + } + std::priority_queue, CompareByFirst> top_candidates = search_base_layer(qdis, ep, std::max(efSearch, k), dist, bitset); + while (top_candidates.size() > k) + top_candidates.pop(); + int i = 0; + while (!top_candidates.empty()) { + I[i] = top_candidates.top().second; + D[i] = top_candidates.top().first; + i ++; + top_candidates.pop(); + } +} + +size_t RHNSW::cal_size() { + size_t ret = 0; + ret += sizeof(*this); + ret += visited_list_pool->GetSize(); + ret += link_list_locks.size() * sizeof(std::mutex); + ret += levels.size() * sizeof(int); + ret += levels.size() * level0_link_size; + ret += levels.size() * sizeof(void*); + for (auto i = 0; i < levels.size(); ++ i) { + ret += levels[i] ? link_size * levels[i] : 0; + } + return ret; +} + +} // namespace faiss diff --git a/core/src/index/thirdparty/faiss/impl/RHNSW.h b/core/src/index/thirdparty/faiss/impl/RHNSW.h new file mode 100644 index 0000000000..40ac9d68ef --- /dev/null +++ b/core/src/index/thirdparty/faiss/impl/RHNSW.h @@ -0,0 +1,367 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +#pragma once + +#include +#include +#include +#include + +#include + +#include +#include +#include +#include + + +namespace faiss { + + +/** Implementation of the Hierarchical Navigable Small World + * datastructure. + * + * Efficient and robust approximate nearest neighbor search using + * Hierarchical Navigable Small World graphs + * + * Yu. A. Malkov, D. A. Yashunin, arXiv 2017 + * + * This implmentation is heavily influenced by the hnswlib + * implementation by Yury Malkov and Leonid Boystov + * (https://github.com/searchivarius/nmslib/hnswlib) + * + * The HNSW object stores only the neighbor link structure, see + * IndexHNSW.h for the full index object. + */ + + +struct DistanceComputer; // from AuxIndexStructures +class VisitedListPool; + +struct RHNSW { + /// internal storage of vectors (32 bits: this is expensive) + typedef int storage_idx_t; + + /// Faiss results are 64-bit + typedef Index::idx_t idx_t; + + typedef std::pair Node; + + /** Heap structure that allows fast + */ + struct MinimaxHeap { + int n; + int k; + int nvalid; + + std::vector ids; + std::vector dis; + typedef faiss::CMax HC; + + explicit MinimaxHeap(int n): n(n), k(0), nvalid(0), ids(n), dis(n) {} + + void push(storage_idx_t i, float v) { + if (k == n) { + if (v >= dis[0]) return; + faiss::heap_pop (k--, dis.data(), ids.data()); + --nvalid; + } + faiss::heap_push (++k, dis.data(), ids.data(), v, i); + ++nvalid; + } + + float max() const { + return dis[0]; + } + + int size() const { + return nvalid; + } + + void clear() { + nvalid = k = 0; + } + + int pop_min(float *vmin_out = nullptr) { + assert(k > 0); + // returns min. This is an O(n) operation + int i = k - 1; + while (i >= 0) { + if (ids[i] != -1) break; + i--; + } + if (i == -1) return -1; + int imin = i; + float vmin = dis[i]; + i--; + while(i >= 0) { + if (ids[i] != -1 && dis[i] < vmin) { + vmin = dis[i]; + imin = i; + } + i--; + } + if (vmin_out) *vmin_out = vmin; + int ret = ids[imin]; + ids[imin] = -1; + --nvalid; + + return ret; + } + + int count_below(float thresh) { + int n_below = 0; + for(int i = 0; i < k; i++) { + if (dis[i] < thresh) { + n_below++; + } + } + + return n_below; + } + }; + + /// to sort pairs of (id, distance) from nearest to fathest or the reverse + struct NodeDistCloser { + float d; + int id; + NodeDistCloser(float d, int id): d(d), id(id) {} + bool operator < (const NodeDistCloser &obj1) const { return d < obj1.d; } + }; + + struct NodeDistFarther { + float d; + int id; + NodeDistFarther(float d, int id): d(d), id(id) {} + bool operator < (const NodeDistFarther &obj1) const { return d > obj1.d; } + }; + + struct CompareByFirst { + constexpr bool operator()(Node const &a, + Node const &b) const noexcept { + return a.first < b.first; + } + }; + + + /// level of each vector (base level = 1), size = ntotal + std::vector levels; + + /// number of entry points in levels > 0. + int upper_beam; + + /// entry point in the search structure (one of the points with maximum level + storage_idx_t entry_point; + + faiss::RandomGenerator rng; + std::default_random_engine level_generator; + + /// maximum level + int max_level; + int M; + char *level0_links; + char **linkLists; + size_t level0_link_size; + size_t link_size; + double level_constant; + VisitedListPool *visited_list_pool; + std::vector link_list_locks; + std::mutex global; + + /// expansion factor at construction time + int efConstruction; + + /// expansion factor at search time + int efSearch; + + /// range of entries in the neighbors table of vertex no at layer_no + storage_idx_t* get_neighbor_link(idx_t no, int layer_no) const { + return layer_no == 0 ? (int*)(level0_links + no * level0_link_size) : (int*)(linkLists[no] + (layer_no - 1) * link_size); + } + unsigned short int get_neighbors_num(int *p) const { + return *((unsigned short int*)p); + } + void set_neighbors_num(int *p, unsigned short int num) const { + *((unsigned short int*)(p)) = *((unsigned short int *)(&num)); + } + + /// only mandatory parameter: nb of neighbors + explicit RHNSW(int M = 32); + ~RHNSW(); + + void init(int ntotal); + /// pick a random level for a new point, arg = 1/log(M) + int random_level(double arg) { + std::uniform_real_distribution distribution(0.0, 1.0); + double r = -log(distribution(level_generator)) * arg; + return (int)r; + } + + void reset(); + + int prepare_level_tab(size_t n, bool preset_levels = false); + + // re-implementations inspired by hnswlib + /** add point pt_id on all levels <= pt_level and build the link + * structure for them. inspired by implementation of hnswlib */ + void addPoint(DistanceComputer& ptdis, int pt_level, int pt_id); + + std::priority_queue, CompareByFirst> + search_layer (DistanceComputer& ptdis, + storage_idx_t pt_id, + storage_idx_t nearest, + int level); + + std::priority_queue, CompareByFirst> + search_base_layer (DistanceComputer& ptdis, + storage_idx_t nearest, + storage_idx_t ef, + float d_nearest, + ConcurrentBitsetPtr bitset = nullptr) const; + + void make_connection(DistanceComputer& ptdis, + storage_idx_t pt_id, + std::priority_queue, CompareByFirst> &cand, + int level); + + void prune_neighbors(DistanceComputer& ptdis, + std::priority_queue, CompareByFirst> &cand, + const int maxM, int *ret, int &ret_len); + + /// search interface inspired by hnswlib + void searchKnn(DistanceComputer& qdis, int k, + idx_t *I, float *D, + ConcurrentBitsetPtr bitset = nullptr) const; + + size_t cal_size(); + +}; + + +/************************************************************** + * Auxiliary structures + **************************************************************/ + +typedef unsigned short int vl_type; + +class VisitedList { + public: + vl_type curV; + vl_type *mass; + unsigned int numelements; + + VisitedList(int numelements1) { + curV = -1; + numelements = numelements1; + mass = new vl_type[numelements]; + } + + void reset() { + curV++; + if (curV == 0) { + memset(mass, 0, sizeof(vl_type) * numelements); + curV++; + } + }; + + // keep compatibae with original version VisitedTable + /// set flog #no to true + void set(int no) { + mass[no] = curV; + } + + /// get flag #no + bool get(int no) const { + return mass[no] == curV; + } + + void advance() { + reset(); + } + + ~VisitedList() { delete[] mass; } +}; + +/////////////////////////////////////////////////////////// +// +// Class for multi-threaded pool-management of VisitedLists +// +///////////////////////////////////////////////////////// + +class VisitedListPool { + std::deque pool; + std::mutex poolguard; + int numelements; + + public: + VisitedListPool(int initmaxpools, int numelements1) { + numelements = numelements1; + for (int i = 0; i < initmaxpools; i++) + pool.push_front(new VisitedList(numelements)); + } + + VisitedList *getFreeVisitedList() { + VisitedList *rez; + { + std::unique_lock lock(poolguard); + if (pool.size() > 0) { + rez = pool.front(); + pool.pop_front(); + } else { + rez = new VisitedList(numelements); + } + } + rez->reset(); + return rez; + }; + + void releaseVisitedList(VisitedList *vl) { + std::unique_lock lock(poolguard); + pool.push_front(vl); + }; + + ~VisitedListPool() { + while (pool.size()) { + VisitedList *rez = pool.front(); + pool.pop_front(); + delete rez; + } + }; + + int64_t GetSize() { + auto visit_list_size = sizeof(VisitedList) + numelements * sizeof(vl_type); + auto pool_size = pool.size() * (sizeof(VisitedList *) + visit_list_size); + return pool_size + sizeof(*this); + } +}; + +struct RHNSWStats { + size_t n1, n2, n3; + size_t ndis; + size_t nreorder; + bool view; + + RHNSWStats() { + reset(); + } + + void reset() { + n1 = n2 = n3 = 0; + ndis = 0; + nreorder = 0; + view = false; + } +}; + +// global var that collects them all +extern RHNSWStats rhnsw_stats; + + +} // namespace faiss diff --git a/core/src/index/thirdparty/faiss/impl/ScalarQuantizer.cpp b/core/src/index/thirdparty/faiss/impl/ScalarQuantizer.cpp new file mode 100644 index 0000000000..35bc71ed73 --- /dev/null +++ b/core/src/index/thirdparty/faiss/impl/ScalarQuantizer.cpp @@ -0,0 +1,182 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +#include + +#include +#include +#include + +#include +#include +#include + +namespace faiss { + +/******************************************************************* + * ScalarQuantizer implementation + * + * The main source of complexity is to support combinations of 4 + * variants without incurring runtime tests or virtual function calls: + * + * - 4 / 8 bits per code component + * - uniform / non-uniform + * - IP / L2 distance search + * - scalar / AVX distance computation + * + * The appropriate Quantizer object is returned via select_quantizer + * that hides the template mess. + ********************************************************************/ + + + +/******************************************************************* + * ScalarQuantizer implementation + ********************************************************************/ + +ScalarQuantizer::ScalarQuantizer + (size_t d, QuantizerType qtype): + qtype (qtype), rangestat(RangeStat::RS_minmax), rangestat_arg(0), d (d) +{ + switch (qtype) { + case QuantizerType::QT_8bit: + case QuantizerType::QT_8bit_uniform: + case QuantizerType::QT_8bit_direct: + code_size = d; + break; + case QuantizerType::QT_4bit: + case QuantizerType::QT_4bit_uniform: + code_size = (d + 1) / 2; + break; + case QuantizerType::QT_6bit: + code_size = (d * 6 + 7) / 8; + break; + case QuantizerType::QT_fp16: + code_size = d * 2; + break; + } +} + +ScalarQuantizer::ScalarQuantizer (): + qtype(QuantizerType::QT_8bit), + rangestat(RangeStat::RS_minmax), rangestat_arg(0), d (0), code_size(0) +{} + +void ScalarQuantizer::train (size_t n, const float *x) +{ + int bit_per_dim = + qtype == QuantizerType::QT_4bit_uniform ? 4 : + qtype == QuantizerType::QT_4bit ? 4 : + qtype == QuantizerType::QT_6bit ? 6 : + qtype == QuantizerType::QT_8bit_uniform ? 8 : + qtype == QuantizerType::QT_8bit ? 8 : -1; + + switch (qtype) { + case QuantizerType::QT_4bit_uniform: + case QuantizerType::QT_8bit_uniform: + train_Uniform (rangestat, rangestat_arg, + n * d, 1 << bit_per_dim, x, trained); + break; + case QuantizerType::QT_4bit: + case QuantizerType::QT_8bit: + case QuantizerType::QT_6bit: + train_NonUniform (rangestat, rangestat_arg, + n, d, 1 << bit_per_dim, x, trained); + break; + case QuantizerType::QT_fp16: + case QuantizerType::QT_8bit_direct: + // no training necessary + break; + } +} + +void ScalarQuantizer::train_residual(size_t n, + const float *x, + Index *quantizer, + bool by_residual, + bool verbose) +{ + const float * x_in = x; + + // 100k points more than enough + x = fvecs_maybe_subsample ( + d, (size_t*)&n, 100000, + x, verbose, 1234); + + ScopeDeleter del_x (x_in == x ? nullptr : x); + + if (by_residual) { + std::vector idx(n); + quantizer->assign (n, x, idx.data()); + + std::vector residuals(n * d); + quantizer->compute_residual_n (n, x, residuals.data(), idx.data()); + + train (n, residuals.data()); + } else { + train (n, x); + } +} + + +Quantizer *ScalarQuantizer::select_quantizer () const +{ + /* use hook to decide use AVX512 or not */ + return sq_sel_quantizer(qtype, d, trained); +} + + +void ScalarQuantizer::compute_codes (const float * x, + uint8_t * codes, + size_t n) const +{ + std::unique_ptr squant(select_quantizer ()); + + memset (codes, 0, code_size * n); +#pragma omp parallel for + for (size_t i = 0; i < n; i++) + squant->encode_vector (x + i * d, codes + i * code_size); +} + +void ScalarQuantizer::decode (const uint8_t *codes, float *x, size_t n) const +{ + std::unique_ptr squant(select_quantizer ()); + +#pragma omp parallel for + for (size_t i = 0; i < n; i++) + squant->decode_vector (codes + i * code_size, x + i * d); +} + + +SQDistanceComputer * +ScalarQuantizer::get_distance_computer (MetricType metric) const +{ + FAISS_THROW_IF_NOT(metric == METRIC_L2 || metric == METRIC_INNER_PRODUCT); + /* use hook to decide use AVX512 or not */ + return sq_get_distance_computer(metric, qtype, d, trained); +} + + +/******************************************************************* + * IndexScalarQuantizer/IndexIVFScalarQuantizer scanner object + * + * It is an InvertedListScanner, but is designed to work with + * IndexScalarQuantizer as well. + ********************************************************************/ + +InvertedListScanner* ScalarQuantizer::select_InvertedListScanner + (MetricType mt, const Index *quantizer, + bool store_pairs, bool by_residual) const +{ + /* use hook to decide use AVX512 or not */ + return sq_sel_inv_list_scanner(mt, this, quantizer, d, store_pairs, by_residual); +} + + +} // namespace faiss diff --git a/core/src/index/thirdparty/faiss/impl/ScalarQuantizer.h b/core/src/index/thirdparty/faiss/impl/ScalarQuantizer.h new file mode 100644 index 0000000000..3dfa72333d --- /dev/null +++ b/core/src/index/thirdparty/faiss/impl/ScalarQuantizer.h @@ -0,0 +1,243 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +#pragma once + +#include +#include + +namespace faiss { + +/** + * The uniform quantizer has a range [vmin, vmax]. The range can be + * the same for all dimensions (uniform) or specific per dimension + * (default). + */ + +struct ScalarQuantizer { + + QuantizerType qtype; + + /** The uniform encoder can estimate the range of representable + * values of the unform encoder using different statistics. Here + * rs = rangestat_arg */ + + RangeStat rangestat; + float rangestat_arg; + + /// dimension of input vectors + size_t d; + + /// bytes per vector + size_t code_size; + + /// trained values (including the range) + std::vector trained; + + ScalarQuantizer (size_t d, QuantizerType qtype); + ScalarQuantizer (); + + void train (size_t n, const float *x); + + /// Used by an IVF index to train based on the residuals + void train_residual (size_t n, + const float *x, + Index *quantizer, + bool by_residual, + bool verbose); + + /// same as compute_code for several vectors + void compute_codes (const float * x, + uint8_t * codes, + size_t n) const ; + + /// decode a vector from a given code (or n vectors if third argument) + void decode (const uint8_t *code, float *x, size_t n) const; + + + /***************************************************** + * Objects that provide methods for encoding/decoding, distance + * computation and inverted list scanning + *****************************************************/ + + Quantizer * select_quantizer() const; + + SQDistanceComputer *get_distance_computer (MetricType metric = METRIC_L2) + const; + + InvertedListScanner *select_InvertedListScanner + (MetricType mt, const Index *quantizer, bool store_pairs, + bool by_residual=false) const; + + size_t cal_size() { return sizeof(*this) + trained.size() * sizeof(float); } +}; + +template +struct IVFSQScannerIP: InvertedListScanner { + DCClass dc; + bool store_pairs, by_residual; + + size_t code_size; + + idx_t list_no; /// current list (set to 0 for Flat index + float accu0; /// added to all distances + + IVFSQScannerIP(int d, const std::vector & trained, + size_t code_size, bool store_pairs, + bool by_residual): + dc(d, trained), store_pairs(store_pairs), + by_residual(by_residual), + code_size(code_size), list_no(0), accu0(0) + {} + + + void set_query (const float *query) override { + dc.set_query (query); + } + + void set_list (idx_t list_no, float coarse_dis) override { + this->list_no = list_no; + accu0 = by_residual ? coarse_dis : 0; + } + + float distance_to_code (const uint8_t *code) const final { + return accu0 + dc.query_to_code (code); + } + + size_t scan_codes (size_t list_size, + const uint8_t *codes, + const idx_t *ids, + float *simi, idx_t *idxi, + size_t k, + ConcurrentBitsetPtr bitset) const override + { + size_t nup = 0; + + for (size_t j = 0; j < list_size; j++) { + if(!bitset || !bitset->test(ids[j])){ + float accu = accu0 + dc.query_to_code (codes); + + if (accu > simi [0]) { + int64_t id = store_pairs ? (list_no << 32 | j) : ids[j]; + minheap_swap_top (k, simi, idxi, accu, id); + nup++; + } + } + codes += code_size; + } + return nup; + } + + void scan_codes_range (size_t list_size, + const uint8_t *codes, + const idx_t *ids, + float radius, + RangeQueryResult & res, + ConcurrentBitsetPtr bitset = nullptr) const override + { + for (size_t j = 0; j < list_size; j++) { + float accu = accu0 + dc.query_to_code (codes); + if (accu > radius) { + int64_t id = store_pairs ? (list_no << 32 | j) : ids[j]; + res.add (accu, id); + } + codes += code_size; + } + } +}; + + +template +struct IVFSQScannerL2: InvertedListScanner { + DCClass dc; + + bool store_pairs, by_residual; + size_t code_size; + const Index *quantizer; + idx_t list_no; /// current inverted list + const float *x; /// current query + + std::vector tmp; + + IVFSQScannerL2(int d, const std::vector & trained, + size_t code_size, const Index *quantizer, + bool store_pairs, bool by_residual): + dc(d, trained), store_pairs(store_pairs), by_residual(by_residual), + code_size(code_size), quantizer(quantizer), + list_no (0), x (nullptr), tmp (d) + { + } + + + void set_query (const float *query) override { + x = query; + if (!quantizer) { + dc.set_query (query); + } + } + + + void set_list (idx_t list_no, float /*coarse_dis*/) override { + if (by_residual) { + this->list_no = list_no; + // shift of x_in wrt centroid + quantizer->Index::compute_residual (x, tmp.data(), list_no); + dc.set_query (tmp.data ()); + } else { + dc.set_query (x); + } + } + + float distance_to_code (const uint8_t *code) const final { + return dc.query_to_code (code); + } + + size_t scan_codes (size_t list_size, + const uint8_t *codes, + const idx_t *ids, + float *simi, idx_t *idxi, + size_t k, + ConcurrentBitsetPtr bitset) const override + { + size_t nup = 0; + for (size_t j = 0; j < list_size; j++) { + if(!bitset || !bitset->test(ids[j])){ + float dis = dc.query_to_code (codes); + + if (dis < simi [0]) { + int64_t id = store_pairs ? (list_no << 32 | j) : ids[j]; + maxheap_swap_top (k, simi, idxi, dis, id); + nup++; + } + } + codes += code_size; + } + return nup; + } + + void scan_codes_range (size_t list_size, + const uint8_t *codes, + const idx_t *ids, + float radius, + RangeQueryResult & res, + ConcurrentBitsetPtr bitset = nullptr) const override + { + for (size_t j = 0; j < list_size; j++) { + float dis = dc.query_to_code (codes); + if (dis < radius) { + int64_t id = store_pairs ? (list_no << 32 | j) : ids[j]; + res.add (dis, id); + } + codes += code_size; + } + } +}; + + +} // namespace faiss diff --git a/core/src/index/thirdparty/faiss/impl/ScalarQuantizerCodec.h b/core/src/index/thirdparty/faiss/impl/ScalarQuantizerCodec.h new file mode 100644 index 0000000000..38abdc7e74 --- /dev/null +++ b/core/src/index/thirdparty/faiss/impl/ScalarQuantizerCodec.h @@ -0,0 +1,603 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +#pragma once + +#include +#include +#include + +#include +#include +#include +#include + +namespace faiss { + + +/******************************************************************* + * Codec: converts between values in [0, 1] and an index in a code + * array. The "i" parameter is the vector component index (not byte + * index). + */ + +struct Codec8bit { + static void encode_component (float x, uint8_t *code, int i) { + code[i] = (int)(255 * x); + } + + static float decode_component (const uint8_t *code, int i) { + return (code[i] + 0.5f) / 255.0f; + } +}; + + +struct Codec4bit { + static void encode_component (float x, uint8_t *code, int i) { + code [i / 2] |= (int)(x * 15.0) << ((i & 1) << 2); + } + + static float decode_component (const uint8_t *code, int i) { + return (((code[i / 2] >> ((i & 1) << 2)) & 0xf) + 0.5f) / 15.0f; + } +}; + +struct Codec6bit { + static void encode_component (float x, uint8_t *code, int i) { + int bits = (int)(x * 63.0); + code += (i >> 2) * 3; + switch(i & 3) { + case 0: + code[0] |= bits; + break; + case 1: + code[0] |= bits << 6; + code[1] |= bits >> 2; + break; + case 2: + code[1] |= bits << 4; + code[2] |= bits >> 4; + break; + case 3: + code[2] |= bits << 2; + break; + } + } + + static float decode_component (const uint8_t *code, int i) { + uint8_t bits; + code += (i >> 2) * 3; + switch(i & 3) { + case 0: + bits = code[0] & 0x3f; + break; + case 1: + bits = code[0] >> 6; + bits |= (code[1] & 0xf) << 2; + break; + case 2: + bits = code[1] >> 4; + bits |= (code[2] & 3) << 4; + break; + case 3: + bits = code[2] >> 2; + break; + } + return (bits + 0.5f) / 63.0f; + } +}; + + + +/******************************************************************* + * Quantizer: normalizes scalar vector components, then passes them + * through a codec + *******************************************************************/ + + +template +struct QuantizerTemplate {}; + + +template +struct QuantizerTemplate: Quantizer { + const size_t d; + const float vmin, vdiff; + + QuantizerTemplate(size_t d, const std::vector &trained): + d(d), vmin(trained[0]), vdiff(trained[1]) + { + } + + void encode_vector(const float* x, uint8_t* code) const final { + for (size_t i = 0; i < d; i++) { + float xi = (x[i] - vmin) / vdiff; + if (xi < 0) { + xi = 0; + } + if (xi > 1.0) { + xi = 1.0; + } + Codec::encode_component(xi, code, i); + } + } + + void decode_vector(const uint8_t* code, float* x) const final { + for (size_t i = 0; i < d; i++) { + float xi = Codec::decode_component(code, i); + x[i] = vmin + xi * vdiff; + } + } + + float reconstruct_component (const uint8_t * code, int i) const + { + float xi = Codec::decode_component (code, i); + return vmin + xi * vdiff; + } +}; + + +template +struct QuantizerTemplate: Quantizer { + const size_t d; + const float *vmin, *vdiff; + + QuantizerTemplate (size_t d, const std::vector &trained): + d(d), vmin(trained.data()), vdiff(trained.data() + d) {} + + void encode_vector(const float* x, uint8_t* code) const final { + for (size_t i = 0; i < d; i++) { + float xi = (x[i] - vmin[i]) / vdiff[i]; + if (xi < 0) + xi = 0; + if (xi > 1.0) + xi = 1.0; + Codec::encode_component(xi, code, i); + } + } + + void decode_vector(const uint8_t* code, float* x) const final { + for (size_t i = 0; i < d; i++) { + float xi = Codec::decode_component(code, i); + x[i] = vmin[i] + xi * vdiff[i]; + } + } + + float reconstruct_component (const uint8_t * code, int i) const + { + float xi = Codec::decode_component (code, i); + return vmin[i] + xi * vdiff[i]; + } +}; + + +/******************************************************************* + * FP16 quantizer + *******************************************************************/ + +template +struct QuantizerFP16 {}; + +template<> +struct QuantizerFP16<1>: Quantizer { + const size_t d; + + QuantizerFP16(size_t d, const std::vector & /* unused */): + d(d) {} + + void encode_vector(const float* x, uint8_t* code) const final { + for (size_t i = 0; i < d; i++) { + ((uint16_t*)code)[i] = encode_fp16(x[i]); + } + } + + void decode_vector(const uint8_t* code, float* x) const final { + for (size_t i = 0; i < d; i++) { + x[i] = decode_fp16(((uint16_t*)code)[i]); + } + } + + float reconstruct_component (const uint8_t * code, int i) const + { + return decode_fp16(((uint16_t*)code)[i]); + } +}; + + +/******************************************************************* + * 8bit_direct quantizer + *******************************************************************/ + +template +struct Quantizer8bitDirect {}; + +template<> +struct Quantizer8bitDirect<1>: Quantizer { + const size_t d; + + Quantizer8bitDirect(size_t d, const std::vector & /* unused */): + d(d) {} + + + void encode_vector(const float* x, uint8_t* code) const final { + for (size_t i = 0; i < d; i++) { + code[i] = (uint8_t)x[i]; + } + } + + void decode_vector(const uint8_t* code, float* x) const final { + for (size_t i = 0; i < d; i++) { + x[i] = code[i]; + } + } + + float reconstruct_component (const uint8_t * code, int i) const + { + return code[i]; + } +}; + + +template +Quantizer *select_quantizer_1 ( + QuantizerType qtype, + size_t d, const std::vector & trained) +{ + switch(qtype) { + case QuantizerType::QT_8bit: + return new QuantizerTemplate(d, trained); + case QuantizerType::QT_6bit: + return new QuantizerTemplate(d, trained); + case QuantizerType::QT_4bit: + return new QuantizerTemplate(d, trained); + case QuantizerType::QT_8bit_uniform: + return new QuantizerTemplate(d, trained); + case QuantizerType::QT_4bit_uniform: + return new QuantizerTemplate(d, trained); + case QuantizerType::QT_fp16: + return new QuantizerFP16 (d, trained); + case QuantizerType::QT_8bit_direct: + return new Quantizer8bitDirect (d, trained); + } + FAISS_THROW_MSG ("unknown qtype"); +} + + + +/******************************************************************* + * Similarity: gets vector components and computes a similarity wrt. a + * query vector stored in the object. The data fields just encapsulate + * an accumulator. + */ + +template +struct SimilarityL2 {}; + +template<> +struct SimilarityL2<1> { + static constexpr int simdwidth = 1; + static constexpr MetricType metric_type = METRIC_L2; + + const float *y, *yi; + + explicit SimilarityL2 (const float * y): y(y) {} + + /******* scalar accumulator *******/ + + float accu; + + void begin () { + accu = 0; + yi = y; + } + + void add_component (float x) { + float tmp = *yi++ - x; + accu += tmp * tmp; + } + + void add_component_2 (float x1, float x2) { + float tmp = x1 - x2; + accu += tmp * tmp; + } + + float result () { + return accu; + } +}; + + +template +struct SimilarityIP {}; + +template<> +struct SimilarityIP<1> { + static constexpr int simdwidth = 1; + static constexpr MetricType metric_type = METRIC_INNER_PRODUCT; + const float *y, *yi; + + float accu; + + explicit SimilarityIP (const float * y): + y (y) {} + + void begin () { + accu = 0; + yi = y; + } + + void add_component (float x) { + accu += *yi++ * x; + } + + void add_component_2 (float x1, float x2) { + accu += x1 * x2; + } + + float result () { + return accu; + } +}; + + +/******************************************************************* + * DistanceComputer: combines a similarity and a quantizer to do + * code-to-vector or code-to-code comparisons + *******************************************************************/ + +template +struct DCTemplate : SQDistanceComputer {}; + +template +struct DCTemplate : SQDistanceComputer +{ + using Sim = Similarity; + + Quantizer quant; + + DCTemplate(size_t d, const std::vector &trained): + quant(d, trained) + {} + + float compute_distance(const float* x, const uint8_t* code) const { + Similarity sim(x); + sim.begin(); + for (size_t i = 0; i < quant.d; i++) { + float xi = quant.reconstruct_component(code, i); + sim.add_component(xi); + } + return sim.result(); + } + + float compute_code_distance(const uint8_t* code1, const uint8_t* code2) + const { + Similarity sim(nullptr); + sim.begin(); + for (size_t i = 0; i < quant.d; i++) { + float x1 = quant.reconstruct_component(code1, i); + float x2 = quant.reconstruct_component(code2, i); + sim.add_component_2(x1, x2); + } + return sim.result(); + } + + void set_query (const float *x) final { + q = x; + } + + /// compute distance of vector i to current query + float operator () (idx_t i) final { + return compute_distance (q, codes + i * code_size); + } + + float symmetric_dis (idx_t i, idx_t j) override { + return compute_code_distance (codes + i * code_size, + codes + j * code_size); + } + + float query_to_code (const uint8_t * code) const { + return compute_distance (q, code); + } +}; + + +/******************************************************************* + * DistanceComputerByte: computes distances in the integer domain + *******************************************************************/ + +template +struct DistanceComputerByte : SQDistanceComputer {}; + +template +struct DistanceComputerByte : SQDistanceComputer { + using Sim = Similarity; + + int d; + std::vector tmp; + + DistanceComputerByte(int d, const std::vector &): d(d), tmp(d) { + } + + int compute_code_distance(const uint8_t* code1, const uint8_t* code2) + const { + int accu = 0; + for (int i = 0; i < d; i++) { + if (Sim::metric_type == METRIC_INNER_PRODUCT) { + accu += int(code1[i]) * code2[i]; + } else { + int diff = int(code1[i]) - code2[i]; + accu += diff * diff; + } + } + return accu; + } + + void set_query (const float *x) final { + for (int i = 0; i < d; i++) { + tmp[i] = int(x[i]); + } + } + + int compute_distance(const float* x, const uint8_t* code) { + set_query(x); + return compute_code_distance(tmp.data(), code); + } + + /// compute distance of vector i to current query + float operator () (idx_t i) final { + return compute_distance (q, codes + i * code_size); + } + + float symmetric_dis (idx_t i, idx_t j) override { + return compute_code_distance (codes + i * code_size, + codes + j * code_size); + } + + float query_to_code (const uint8_t * code) const { + return compute_code_distance (tmp.data(), code); + } +}; + + +/******************************************************************* + * select_distance_computer: runtime selection of template + * specialization + *******************************************************************/ + +template +SQDistanceComputer *select_distance_computer ( + QuantizerType qtype, + size_t d, const std::vector & trained) +{ + constexpr int SIMDWIDTH = Sim::simdwidth; + switch(qtype) { + case QuantizerType::QT_8bit_uniform: + return new DCTemplate, + Sim, SIMDWIDTH>(d, trained); + + case QuantizerType::QT_4bit_uniform: + return new DCTemplate, + Sim, SIMDWIDTH>(d, trained); + + case QuantizerType::QT_8bit: + return new DCTemplate, + Sim, SIMDWIDTH>(d, trained); + + case QuantizerType::QT_6bit: + return new DCTemplate, + Sim, SIMDWIDTH>(d, trained); + + case QuantizerType::QT_4bit: + return new DCTemplate, + Sim, SIMDWIDTH>(d, trained); + + case QuantizerType::QT_fp16: + return new DCTemplate + , Sim, SIMDWIDTH>(d, trained); + + case QuantizerType::QT_8bit_direct: + if (d % 16 == 0) { + return new DistanceComputerByte(d, trained); + } else { + return new DCTemplate + , Sim, SIMDWIDTH>(d, trained); + } + } + FAISS_THROW_MSG ("unknown qtype"); + return nullptr; +} + +template +InvertedListScanner* sel2_InvertedListScanner ( + const ScalarQuantizer *sq, + const Index *quantizer, bool store_pairs, bool r) +{ + if (DCClass::Sim::metric_type == METRIC_L2) { + return new IVFSQScannerL2(sq->d, sq->trained, sq->code_size, + quantizer, store_pairs, r); + } else if (DCClass::Sim::metric_type == METRIC_INNER_PRODUCT) { + return new IVFSQScannerIP(sq->d, sq->trained, sq->code_size, + store_pairs, r); + } else { + FAISS_THROW_MSG("unsupported metric type"); + } +} + +template +InvertedListScanner* sel12_InvertedListScanner ( + const ScalarQuantizer *sq, + const Index *quantizer, bool store_pairs, bool r) +{ + constexpr int SIMDWIDTH = Similarity::simdwidth; + using QuantizerClass = QuantizerTemplate; + using DCClass = DCTemplate; + return sel2_InvertedListScanner (sq, quantizer, store_pairs, r); +} + + +template +InvertedListScanner* sel1_InvertedListScanner ( + const ScalarQuantizer *sq, const Index *quantizer, + bool store_pairs, bool r) +{ + constexpr int SIMDWIDTH = Similarity::simdwidth; + switch(sq->qtype) { + case QuantizerType::QT_8bit_uniform: + return sel12_InvertedListScanner + (sq, quantizer, store_pairs, r); + case QuantizerType::QT_4bit_uniform: + return sel12_InvertedListScanner + (sq, quantizer, store_pairs, r); + case QuantizerType::QT_8bit: + return sel12_InvertedListScanner + (sq, quantizer, store_pairs, r); + case QuantizerType::QT_4bit: + return sel12_InvertedListScanner + (sq, quantizer, store_pairs, r); + case QuantizerType::QT_6bit: + return sel12_InvertedListScanner + (sq, quantizer, store_pairs, r); + case QuantizerType::QT_fp16: + return sel2_InvertedListScanner + , Similarity, SIMDWIDTH> > + (sq, quantizer, store_pairs, r); + case QuantizerType::QT_8bit_direct: + if (sq->d % 16 == 0) { + return sel2_InvertedListScanner + > + (sq, quantizer, store_pairs, r); + } else { + return sel2_InvertedListScanner + , + Similarity, SIMDWIDTH> > + (sq, quantizer, store_pairs, r); + } + } + + FAISS_THROW_MSG ("unknown qtype"); + return nullptr; +} + +template +InvertedListScanner* sel0_InvertedListScanner ( + MetricType mt, const ScalarQuantizer *sq, + const Index *quantizer, bool store_pairs, bool by_residual) +{ + if (mt == METRIC_L2) { + return sel1_InvertedListScanner > + (sq, quantizer, store_pairs, by_residual); + } else if (mt == METRIC_INNER_PRODUCT) { + return sel1_InvertedListScanner > + (sq, quantizer, store_pairs, by_residual); + } else { + FAISS_THROW_MSG("unsupported metric type"); + } +} + +} // namespace faiss diff --git a/core/src/index/thirdparty/faiss/impl/ScalarQuantizerCodec_avx.h b/core/src/index/thirdparty/faiss/impl/ScalarQuantizerCodec_avx.h new file mode 100644 index 0000000000..e361376b5f --- /dev/null +++ b/core/src/index/thirdparty/faiss/impl/ScalarQuantizerCodec_avx.h @@ -0,0 +1,576 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +#pragma once + +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +namespace faiss { + +/******************************************************************* + * Codec: converts between values in [0, 1] and an index in a code + * array. The "i" parameter is the vector component index (not byte + * index). + */ + +struct Codec8bit_avx : public Codec8bit { + static __m256 decode_8_components (const uint8_t *code, int i) { + uint64_t c8 = *(uint64_t*)(code + i); + __m128i c4lo = _mm_cvtepu8_epi32 (_mm_set1_epi32(c8)); + __m128i c4hi = _mm_cvtepu8_epi32 (_mm_set1_epi32(c8 >> 32)); + // __m256i i8 = _mm256_set_m128i(c4lo, c4hi); + __m256i i8 = _mm256_castsi128_si256 (c4lo); + i8 = _mm256_insertf128_si256 (i8, c4hi, 1); + __m256 f8 = _mm256_cvtepi32_ps (i8); + __m256 half = _mm256_set1_ps (0.5f); + f8 += half; + __m256 one_255 = _mm256_set1_ps (1.f / 255.f); + return f8 * one_255; + } +}; + +struct Codec4bit_avx : public Codec4bit { + static __m256 decode_8_components (const uint8_t *code, int i) { + uint32_t c4 = *(uint32_t*)(code + (i >> 1)); + uint32_t mask = 0x0f0f0f0f; + uint32_t c4ev = c4 & mask; + uint32_t c4od = (c4 >> 4) & mask; + + // the 8 lower bytes of c8 contain the values + __m128i c8 = _mm_unpacklo_epi8 (_mm_set1_epi32(c4ev), + _mm_set1_epi32(c4od)); + __m128i c4lo = _mm_cvtepu8_epi32 (c8); + __m128i c4hi = _mm_cvtepu8_epi32 (_mm_srli_si128(c8, 4)); + __m256i i8 = _mm256_castsi128_si256 (c4lo); + i8 = _mm256_insertf128_si256 (i8, c4hi, 1); + __m256 f8 = _mm256_cvtepi32_ps (i8); + __m256 half = _mm256_set1_ps (0.5f); + f8 += half; + __m256 one_255 = _mm256_set1_ps (1.f / 15.f); + return f8 * one_255; + } +}; + +struct Codec6bit_avx : public Codec6bit { + static __m256 decode_8_components (const uint8_t *code, int i) { + return _mm256_set_ps + (decode_component(code, i + 7), + decode_component(code, i + 6), + decode_component(code, i + 5), + decode_component(code, i + 4), + decode_component(code, i + 3), + decode_component(code, i + 2), + decode_component(code, i + 1), + decode_component(code, i + 0)); + } +}; + + +/******************************************************************* + * Quantizer: normalizes scalar vector components, then passes them + * through a codec + *******************************************************************/ + +template +struct QuantizerTemplate_avx {}; + +template +struct QuantizerTemplate_avx : public QuantizerTemplate { + QuantizerTemplate_avx(size_t d, const std::vector &trained) : + QuantizerTemplate (d, trained) {} +}; + +template +struct QuantizerTemplate_avx : public QuantizerTemplate { + QuantizerTemplate_avx (size_t d, const std::vector &trained) : + QuantizerTemplate (d, trained) {} + + __m256 reconstruct_8_components (const uint8_t * code, int i) const { + __m256 xi = Codec::decode_8_components (code, i); + return _mm256_set1_ps(this->vmin) + xi * _mm256_set1_ps (this->vdiff); + } +}; + +template +struct QuantizerTemplate_avx : public QuantizerTemplate { + QuantizerTemplate_avx (size_t d, const std::vector &trained) : + QuantizerTemplate (d, trained) {} +}; + +template +struct QuantizerTemplate_avx: public QuantizerTemplate { + QuantizerTemplate_avx (size_t d, const std::vector &trained) : + QuantizerTemplate (d, trained) {} + + __m256 reconstruct_8_components (const uint8_t * code, int i) const { + __m256 xi = Codec::decode_8_components (code, i); + return _mm256_loadu_ps (this->vmin + i) + xi * _mm256_loadu_ps (this->vdiff + i); + } +}; + + +/******************************************************************* + * FP16 quantizer + *******************************************************************/ + +template +struct QuantizerFP16_avx {}; + +template<> +struct QuantizerFP16_avx<1> : public QuantizerFP16<1> { + QuantizerFP16_avx (size_t d, const std::vector &unused) : + QuantizerFP16<1> (d, unused) {} +}; + +template<> +struct QuantizerFP16_avx<8>: public QuantizerFP16<1> { + QuantizerFP16_avx (size_t d, const std::vector &trained): + QuantizerFP16<1> (d, trained) {} + + __m256 reconstruct_8_components (const uint8_t * code, int i) const { + __m128i codei = _mm_loadu_si128 ((const __m128i*)(code + 2 * i)); + return _mm256_cvtph_ps (codei); + } +}; + + +/******************************************************************* + * 8bit_direct quantizer + *******************************************************************/ + +template +struct Quantizer8bitDirect_avx {}; + +template<> +struct Quantizer8bitDirect_avx<1> : public Quantizer8bitDirect<1> { + Quantizer8bitDirect_avx (size_t d, const std::vector &unused) : + Quantizer8bitDirect(d, unused) {} +}; + +template<> +struct Quantizer8bitDirect_avx<8>: public Quantizer8bitDirect<1> { + Quantizer8bitDirect_avx (size_t d, const std::vector &trained) : + Quantizer8bitDirect<1> (d, trained) {} + + __m256 reconstruct_8_components (const uint8_t * code, int i) const { + __m128i x8 = _mm_loadl_epi64((__m128i*)(code + i)); // 8 * int8 + __m256i y8 = _mm256_cvtepu8_epi32 (x8); // 8 * int32 + return _mm256_cvtepi32_ps (y8); // 8 * float32 + } +}; + +template +Quantizer *select_quantizer_1_avx (QuantizerType qtype, size_t d, + const std::vector & trained) { + switch(qtype) { + case QuantizerType::QT_8bit: + return new QuantizerTemplate_avx(d, trained); + case QuantizerType::QT_6bit: + return new QuantizerTemplate_avx(d, trained); + case QuantizerType::QT_4bit: + return new QuantizerTemplate_avx(d, trained); + case QuantizerType::QT_8bit_uniform: + return new QuantizerTemplate_avx(d, trained); + case QuantizerType::QT_4bit_uniform: + return new QuantizerTemplate_avx(d, trained); + case QuantizerType::QT_fp16: + return new QuantizerFP16_avx(d, trained); + case QuantizerType::QT_8bit_direct: + return new Quantizer8bitDirect_avx(d, trained); + } + FAISS_THROW_MSG ("unknown qtype"); +} + + +/******************************************************************* + * Similarity: gets vector components and computes a similarity wrt. a + * query vector stored in the object. The data fields just encapsulate + * an accumulator. + */ + +template +struct SimilarityL2_avx {}; + +template<> +struct SimilarityL2_avx<1> : public SimilarityL2<1> { + static constexpr int simdwidth = 1; + static constexpr MetricType metric_type = METRIC_L2; + + explicit SimilarityL2_avx (const float * y) : SimilarityL2<1>(y) {} +}; + +template<> +struct SimilarityL2_avx<8> { + static constexpr int simdwidth = 8; + static constexpr MetricType metric_type = METRIC_L2; + + const float *y, *yi; + + explicit SimilarityL2_avx (const float * y): y(y) {} + __m256 accu8; + + void begin_8 () { + accu8 = _mm256_setzero_ps(); + yi = y; + } + + void add_8_components (__m256 x) { + __m256 yiv = _mm256_loadu_ps (yi); + yi += 8; + __m256 tmp = yiv - x; + accu8 += tmp * tmp; + } + + void add_8_components_2 (__m256 x, __m256 y) { + __m256 tmp = y - x; + accu8 += tmp * tmp; + } + + float result_8 () { + __m256 sum = _mm256_hadd_ps(accu8, accu8); + __m256 sum2 = _mm256_hadd_ps(sum, sum); + // now add the 0th and 4th component + return + _mm_cvtss_f32 (_mm256_castps256_ps128(sum2)) + + _mm_cvtss_f32 (_mm256_extractf128_ps(sum2, 1)); + } +}; + + +template +struct SimilarityIP_avx {}; + +template<> +struct SimilarityIP_avx<1> : public SimilarityIP<1> { + static constexpr int simdwidth = 1; + static constexpr MetricType metric_type = METRIC_INNER_PRODUCT; + + explicit SimilarityIP_avx (const float * y) : SimilarityIP<1>(y) {} +}; + +template<> +struct SimilarityIP_avx<8> { + static constexpr int simdwidth = 8; + static constexpr MetricType metric_type = METRIC_INNER_PRODUCT; + + const float *y, *yi; + + float accu; + + explicit SimilarityIP_avx (const float * y): y (y) {} + + __m256 accu8; + + void begin_8 () { + accu8 = _mm256_setzero_ps(); + yi = y; + } + + void add_8_components (__m256 x) { + __m256 yiv = _mm256_loadu_ps (yi); + yi += 8; + accu8 += yiv * x; + } + + void add_8_components_2 (__m256 x1, __m256 x2) { + accu8 += x1 * x2; + } + + float result_8 () { + __m256 sum = _mm256_hadd_ps(accu8, accu8); + __m256 sum2 = _mm256_hadd_ps(sum, sum); + // now add the 0th and 4th component + return + _mm_cvtss_f32 (_mm256_castps256_ps128(sum2)) + + _mm_cvtss_f32 (_mm256_extractf128_ps(sum2, 1)); + } +}; + + +/******************************************************************* + * DistanceComputer: combines a similarity and a quantizer to do + * code-to-vector or code-to-code comparisons + *******************************************************************/ + +template +struct DCTemplate_avx : SQDistanceComputer {}; + +template +struct DCTemplate_avx : public DCTemplate { + DCTemplate_avx(size_t d, const std::vector &trained) : + DCTemplate(d, trained) {} +}; + +template +struct DCTemplate_avx : SQDistanceComputer { + using Sim = Similarity; + + Quantizer quant; + + DCTemplate_avx(size_t d, const std::vector &trained): + quant(d, trained) {} + + float compute_distance(const float* x, const uint8_t* code) const { + Similarity sim(x); + sim.begin_8(); + for (size_t i = 0; i < quant.d; i += 8) { + __m256 xi = quant.reconstruct_8_components(code, i); + sim.add_8_components(xi); + } + return sim.result_8(); + } + + float compute_code_distance(const uint8_t* code1, const uint8_t* code2) + const { + Similarity sim(nullptr); + sim.begin_8(); + for (size_t i = 0; i < quant.d; i += 8) { + __m256 x1 = quant.reconstruct_8_components(code1, i); + __m256 x2 = quant.reconstruct_8_components(code2, i); + sim.add_8_components_2(x1, x2); + } + return sim.result_8(); + } + + void set_query (const float *x) final { + q = x; + } + + /// compute distance of vector i to current query + float operator () (idx_t i) final { + return compute_distance (q, codes + i * code_size); + } + + float symmetric_dis (idx_t i, idx_t j) override { + return compute_code_distance (codes + i * code_size, + codes + j * code_size); + } + + float query_to_code (const uint8_t * code) const { + return compute_distance (q, code); + } +}; + + +/******************************************************************* + * DistanceComputerByte: computes distances in the integer domain + *******************************************************************/ + +template +struct DistanceComputerByte_avx : SQDistanceComputer {}; + +template +struct DistanceComputerByte_avx : public DistanceComputerByte { + DistanceComputerByte_avx(int d, const std::vector &unused) : + DistanceComputerByte(d, unused) {} +}; + +template +struct DistanceComputerByte_avx : SQDistanceComputer { + using Sim = Similarity; + + int d; + std::vector tmp; + + DistanceComputerByte_avx(int d, const std::vector &): d(d), tmp(d) {} + + int compute_code_distance(const uint8_t* code1, const uint8_t* code2) const { + // __m256i accu = _mm256_setzero_ps (); + __m256i accu = _mm256_setzero_si256 (); + for (int i = 0; i < d; i += 16) { + // load 16 bytes, convert to 16 uint16_t + __m256i c1 = _mm256_cvtepu8_epi16 + (_mm_loadu_si128((__m128i*)(code1 + i))); + __m256i c2 = _mm256_cvtepu8_epi16 + (_mm_loadu_si128((__m128i*)(code2 + i))); + __m256i prod32; + if (Sim::metric_type == METRIC_INNER_PRODUCT) { + prod32 = _mm256_madd_epi16(c1, c2); + } else { + __m256i diff = _mm256_sub_epi16(c1, c2); + prod32 = _mm256_madd_epi16(diff, diff); + } + accu = _mm256_add_epi32 (accu, prod32); + } + __m128i sum = _mm256_extractf128_si256(accu, 0); + sum = _mm_add_epi32 (sum, _mm256_extractf128_si256(accu, 1)); + sum = _mm_hadd_epi32 (sum, sum); + sum = _mm_hadd_epi32 (sum, sum); + return _mm_cvtsi128_si32 (sum); + } + + void set_query (const float *x) final { + /* + for (int i = 0; i < d; i += 8) { + __m256 xi = _mm256_loadu_ps (x + i); + __m256i ci = _mm256_cvtps_epi32(xi); + */ + for (int i = 0; i < d; i++) { + tmp[i] = int(x[i]); + } + } + + int compute_distance(const float* x, const uint8_t* code) { + set_query(x); + return compute_code_distance(tmp.data(), code); + } + + /// compute distance of vector i to current query + float operator () (idx_t i) final { + return compute_distance (q, codes + i * code_size); + } + + float symmetric_dis (idx_t i, idx_t j) override { + return compute_code_distance (codes + i * code_size, + codes + j * code_size); + } + + float query_to_code (const uint8_t * code) const { + return compute_code_distance (tmp.data(), code); + } +}; + + +/******************************************************************* + * select_distance_computer: runtime selection of template + * specialization + *******************************************************************/ + +template +SQDistanceComputer *select_distance_computer_avx ( + QuantizerType qtype, + size_t d, const std::vector & trained) +{ + constexpr int SIMDWIDTH = Sim::simdwidth; + switch(qtype) { + case QuantizerType::QT_8bit_uniform: + return new DCTemplate_avx, + Sim, SIMDWIDTH>(d, trained); + + case QuantizerType::QT_4bit_uniform: + return new DCTemplate_avx, + Sim, SIMDWIDTH>(d, trained); + + case QuantizerType::QT_8bit: + return new DCTemplate_avx, + Sim, SIMDWIDTH>(d, trained); + + case QuantizerType::QT_6bit: + return new DCTemplate_avx, + Sim, SIMDWIDTH>(d, trained); + + case QuantizerType::QT_4bit: + return new DCTemplate_avx, + Sim, SIMDWIDTH>(d, trained); + + case QuantizerType::QT_fp16: + return new DCTemplate_avx + , Sim, SIMDWIDTH>(d, trained); + + case QuantizerType::QT_8bit_direct: + if (d % 16 == 0) { + return new DistanceComputerByte_avx(d, trained); + } else { + return new DCTemplate_avx + , Sim, SIMDWIDTH>(d, trained); + } + } + FAISS_THROW_MSG ("unknown qtype"); + return nullptr; +} + +template +InvertedListScanner* sel2_InvertedListScanner_avx ( + const ScalarQuantizer *sq, + const Index *quantizer, bool store_pairs, bool r) +{ + return sel2_InvertedListScanner (sq, quantizer, store_pairs, r); +} + +template +InvertedListScanner* sel12_InvertedListScanner_avx ( + const ScalarQuantizer *sq, + const Index *quantizer, bool store_pairs, bool r) +{ + constexpr int SIMDWIDTH = Similarity::simdwidth; + using QuantizerClass = QuantizerTemplate_avx; + using DCClass = DCTemplate_avx; + return sel2_InvertedListScanner_avx (sq, quantizer, store_pairs, r); +} + + +template +InvertedListScanner* sel1_InvertedListScanner_avx ( + const ScalarQuantizer *sq, const Index *quantizer, + bool store_pairs, bool r) +{ + constexpr int SIMDWIDTH = Similarity::simdwidth; + switch(sq->qtype) { + case QuantizerType::QT_8bit_uniform: + return sel12_InvertedListScanner_avx + (sq, quantizer, store_pairs, r); + case QuantizerType::QT_4bit_uniform: + return sel12_InvertedListScanner_avx + (sq, quantizer, store_pairs, r); + case QuantizerType::QT_8bit: + return sel12_InvertedListScanner_avx + (sq, quantizer, store_pairs, r); + case QuantizerType::QT_4bit: + return sel12_InvertedListScanner_avx + (sq, quantizer, store_pairs, r); + case QuantizerType::QT_6bit: + return sel12_InvertedListScanner_avx + (sq, quantizer, store_pairs, r); + case QuantizerType::QT_fp16: + return sel2_InvertedListScanner_avx + , Similarity, SIMDWIDTH> > + (sq, quantizer, store_pairs, r); + case QuantizerType::QT_8bit_direct: + if (sq->d % 16 == 0) { + return sel2_InvertedListScanner_avx + > + (sq, quantizer, store_pairs, r); + } else { + return sel2_InvertedListScanner_avx + , + Similarity, SIMDWIDTH> > + (sq, quantizer, store_pairs, r); + } + } + + FAISS_THROW_MSG ("unknown qtype"); + return nullptr; +} + +template +InvertedListScanner* sel0_InvertedListScanner_avx ( + MetricType mt, const ScalarQuantizer *sq, + const Index *quantizer, bool store_pairs, bool by_residual) +{ + if (mt == METRIC_L2) { + return sel1_InvertedListScanner_avx > + (sq, quantizer, store_pairs, by_residual); + } else if (mt == METRIC_INNER_PRODUCT) { + return sel1_InvertedListScanner_avx > + (sq, quantizer, store_pairs, by_residual); + } else { + FAISS_THROW_MSG("unsupported metric type"); + } +} + +} // namespace faiss diff --git a/core/src/index/thirdparty/faiss/impl/ScalarQuantizerCodec_avx512.h b/core/src/index/thirdparty/faiss/impl/ScalarQuantizerCodec_avx512.h new file mode 100644 index 0000000000..3892f9dfdd --- /dev/null +++ b/core/src/index/thirdparty/faiss/impl/ScalarQuantizerCodec_avx512.h @@ -0,0 +1,661 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +#pragma once + +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +namespace faiss { + +/******************************************************************* + * ScalarQuantizer implementation + * + * The main source of complexity is to support combinations of 4 + * variants without incurring runtime tests or virtual function calls: + * + * - 4 / 8 bits per code component + * - uniform / non-uniform + * - IP / L2 distance search + * - scalar / AVX distance computation + * + * The appropriate Quantizer object is returned via select_quantizer + * that hides the template mess. + ********************************************************************/ + + +/******************************************************************* + * Codec: converts between values in [0, 1] and an index in a code + * array. The "i" parameter is the vector component index (not byte + * index). + */ + +struct Codec8bit_avx512 : public Codec8bit_avx { + static __m512 decode_16_components (const uint8_t *code, int i) { + uint64_t c8 = *(uint64_t*)(code + i); + __m256i c8lo = _mm256_cvtepu8_epi32 (_mm_set1_epi64x(c8)); + c8 = *(uint64_t*)(code + i + 8); + __m256i c8hi = _mm256_cvtepu8_epi32 (_mm_set1_epi64x(c8)); + // __m256i i8 = _mm256_set_m128i(c4lo, c4hi); + __m512i i16 = _mm512_castsi256_si512 (c8lo); + i16 = _mm512_inserti32x8 (i16, c8hi, 1); + __m512 f16 = _mm512_cvtepi32_ps (i16); + __m512 half = _mm512_set1_ps (0.5f); + f16 += half; + __m512 one_255 = _mm512_set1_ps (1.f / 255.f); + return f16 * one_255; + } +}; + +struct Codec4bit_avx512 : public Codec4bit_avx { + static __m512 decode_16_components (const uint8_t *code, int i) { + uint64_t c8 = *(uint64_t*)(code + (i >> 1)); + uint64_t mask = 0x0f0f0f0f0f0f0f0f; + uint64_t c8ev = c8 & mask; + uint64_t c8od = (c8 >> 4) & mask; + + // the 8 lower bytes of c8 contain the values + __m128i c16 = _mm_unpacklo_epi8 (_mm_set1_epi64x(c8ev), + _mm_set1_epi64x(c8od)); + __m256i c8lo = _mm256_cvtepu8_epi32 (c16); + __m256i c8hi = _mm256_cvtepu8_epi32 (_mm_srli_si128(c16, 4)); + __m512i i16 = _mm512_castsi256_si512 (c8lo); + i16 = _mm512_inserti32x8 (i16, c8hi, 1); + __m512 f16 = _mm512_cvtepi32_ps (i16); + __m512 half = _mm512_set1_ps (0.5f); + f16 += half; + __m512 one_255 = _mm512_set1_ps (1.f / 15.f); + return f16 * one_255; + } +}; + +struct Codec6bit_avx512 : public Codec6bit_avx { + static __m512 decode_16_components (const uint8_t *code, int i) { + return _mm512_set_ps + (decode_component(code, i + 15), + decode_component(code, i + 14), + decode_component(code, i + 13), + decode_component(code, i + 12), + decode_component(code, i + 11), + decode_component(code, i + 10), + decode_component(code, i + 9), + decode_component(code, i + 8), + decode_component(code, i + 7), + decode_component(code, i + 6), + decode_component(code, i + 5), + decode_component(code, i + 4), + decode_component(code, i + 3), + decode_component(code, i + 2), + decode_component(code, i + 1), + decode_component(code, i + 0)); + } +}; + + +/******************************************************************* + * Quantizer: normalizes scalar vector components, then passes them + * through a codec + *******************************************************************/ + +template +struct QuantizerTemplate_avx512 {}; + +template +struct QuantizerTemplate_avx512 : public QuantizerTemplate_avx { + QuantizerTemplate_avx512(size_t d, const std::vector &trained) : + QuantizerTemplate_avx (d, trained) {} +}; + +template +struct QuantizerTemplate_avx512 : public QuantizerTemplate_avx { + QuantizerTemplate_avx512 (size_t d, const std::vector &trained) : + QuantizerTemplate_avx (d, trained) {} +}; + +template +struct QuantizerTemplate_avx512 : public QuantizerTemplate_avx { + QuantizerTemplate_avx512 (size_t d, const std::vector &trained) : + QuantizerTemplate_avx (d, trained) {} + + __m512 reconstruct_16_components (const uint8_t * code, int i) const { + __m512 xi = Codec::decode_16_components (code, i); + return _mm512_set1_ps(this->vmin) + xi * _mm512_set1_ps (this->vdiff); + } +}; + + +template +struct QuantizerTemplate_avx512 : public QuantizerTemplate_avx { + QuantizerTemplate_avx512 (size_t d, const std::vector &trained) : + QuantizerTemplate_avx (d, trained) {} +}; + +template +struct QuantizerTemplate_avx512 : public QuantizerTemplate_avx { + QuantizerTemplate_avx512 (size_t d, const std::vector &trained): + QuantizerTemplate_avx (d, trained) {} +}; + +template +struct QuantizerTemplate_avx512: public QuantizerTemplate_avx { + QuantizerTemplate_avx512 (size_t d, const std::vector &trained): + QuantizerTemplate_avx (d, trained) {} + + __m512 reconstruct_16_components (const uint8_t * code, int i) const { + __m512 xi = Codec::decode_16_components (code, i); + return _mm512_loadu_ps (this->vmin + i) + xi * _mm512_loadu_ps (this->vdiff + i); + } +}; + +/******************************************************************* + * FP16 quantizer + *******************************************************************/ + +template +struct QuantizerFP16_avx512 {}; + +template<> +struct QuantizerFP16_avx512<1> : public QuantizerFP16_avx<1> { + QuantizerFP16_avx512(size_t d, const std::vector &unused) : + QuantizerFP16_avx<1> (d, unused) {} +}; + +template<> +struct QuantizerFP16_avx512<8> : public QuantizerFP16_avx<8> { + QuantizerFP16_avx512 (size_t d, const std::vector &trained) : + QuantizerFP16_avx<8> (d, trained) {} +}; + +template<> +struct QuantizerFP16_avx512<16>: public QuantizerFP16_avx<8> { + QuantizerFP16_avx512 (size_t d, const std::vector &trained): + QuantizerFP16_avx<8> (d, trained) {} + + __m512 reconstruct_16_components (const uint8_t * code, int i) const { + __m256i codei = _mm256_loadu_si256 ((const __m256i*)(code + 2 * i)); + return _mm512_cvtph_ps (codei); + } +}; + +/******************************************************************* + * 8bit_direct quantizer + *******************************************************************/ + +template +struct Quantizer8bitDirect_avx512 {}; + +template<> +struct Quantizer8bitDirect_avx512<1> : public Quantizer8bitDirect_avx<1> { + Quantizer8bitDirect_avx512(size_t d, const std::vector &unused) : + Quantizer8bitDirect_avx<1> (d, unused) {} +}; + +template<> +struct Quantizer8bitDirect_avx512<8> : public Quantizer8bitDirect_avx<8> { + Quantizer8bitDirect_avx512 (size_t d, const std::vector &trained): + Quantizer8bitDirect_avx<8> (d, trained) {} +}; + +template<> +struct Quantizer8bitDirect_avx512<16> : public Quantizer8bitDirect_avx<8> { + Quantizer8bitDirect_avx512 (size_t d, const std::vector &trained): + Quantizer8bitDirect_avx<8> (d, trained) {} + + __m512 reconstruct_16_components (const uint8_t * code, int i) const { + __m128i x8 = _mm_load_si128((__m128i*)(code + i)); // 16 * int8 + __m512i y8 = _mm512_cvtepu8_epi32 (x8); // 16 * int32 + return _mm512_cvtepi32_ps (y8); // 16 * float32 + } +}; + + +template +Quantizer *select_quantizer_1_avx512 (QuantizerType qtype, size_t d, + const std::vector & trained) +{ + switch(qtype) { + case QuantizerType::QT_8bit: + return new QuantizerTemplate_avx512(d, trained); + case QuantizerType::QT_6bit: + return new QuantizerTemplate_avx512(d, trained); + case QuantizerType::QT_4bit: + return new QuantizerTemplate_avx512(d, trained); + case QuantizerType::QT_8bit_uniform: + return new QuantizerTemplate_avx512(d, trained); + case QuantizerType::QT_4bit_uniform: + return new QuantizerTemplate_avx512(d, trained); + case QuantizerType::QT_fp16: + return new QuantizerFP16_avx512(d, trained); + case QuantizerType::QT_8bit_direct: + return new Quantizer8bitDirect_avx512(d, trained); + } + FAISS_THROW_MSG ("unknown qtype"); +} + + +/******************************************************************* + * Similarity: gets vector components and computes a similarity wrt. a + * query vector stored in the object. The data fields just encapsulate + * an accumulator. + */ + +template +struct SimilarityL2_avx512 {}; + + +template<> +struct SimilarityL2_avx512<1> : public SimilarityL2_avx<1> { + static constexpr int simdwidth = 1; + static constexpr MetricType metric_type = METRIC_L2; + + explicit SimilarityL2_avx512 (const float * y) : SimilarityL2_avx<1> (y) {} +}; + +template<> +struct SimilarityL2_avx512<8> : public SimilarityL2_avx<8> { + static constexpr int simdwidth = 8; + static constexpr MetricType metric_type = METRIC_L2; + + explicit SimilarityL2_avx512 (const float * y) : SimilarityL2_avx<8> (y) {} +}; + +template<> +struct SimilarityL2_avx512<16> { + static constexpr int simdwidth = 16; + static constexpr MetricType metric_type = METRIC_L2; + + const float *y, *yi; + + explicit SimilarityL2_avx512 (const float * y): y(y) {} + __m512 accu16; + + void begin_16 () { + accu16 = _mm512_setzero_ps(); + yi = y; + } + + void add_16_components (__m512 x) { + __m512 yiv = _mm512_loadu_ps (yi); + yi += 16; + __m512 tmp = yiv - x; + accu16 += tmp * tmp; + } + + void add_16_components_2 (__m512 x, __m512 y) { + __m512 tmp = y - x; + accu16 += tmp * tmp; + } + + float result_16 () { + __m256 sum0 = _mm512_extractf32x8_ps(accu16, 1) + _mm512_extractf32x8_ps(accu16, 0); + __m256 sum1 = _mm256_hadd_ps(sum0, sum0); + __m256 sum2 = _mm256_hadd_ps(sum1, sum1); + // now add the 0th and 4th component + return + _mm_cvtss_f32 (_mm256_castps256_ps128(sum2)) + + _mm_cvtss_f32 (_mm256_extractf128_ps(sum2, 1)); + } +}; + + +template +struct SimilarityIP_avx512 {}; + + +template<> +struct SimilarityIP_avx512<1> : public SimilarityIP_avx<1> { + static constexpr int simdwidth = 1; + static constexpr MetricType metric_type = METRIC_INNER_PRODUCT; + + explicit SimilarityIP_avx512 (const float * y) : SimilarityIP_avx<1> (y) {} +}; + +template<> +struct SimilarityIP_avx512<8> : public SimilarityIP_avx<8> { + static constexpr int simdwidth = 8; + static constexpr MetricType metric_type = METRIC_INNER_PRODUCT; + + explicit SimilarityIP_avx512 (const float * y) : SimilarityIP_avx<8> (y) {} +}; + +template<> +struct SimilarityIP_avx512<16> { + static constexpr int simdwidth = 16; + static constexpr MetricType metric_type = METRIC_INNER_PRODUCT; + + const float *y, *yi; + + float accu; + + explicit SimilarityIP_avx512 (const float * y) : y (y) {} + + __m512 accu16; + + void begin_16 () { + accu16 = _mm512_setzero_ps(); + yi = y; + } + + void add_16_components (__m512 x) { + __m512 yiv = _mm512_loadu_ps (yi); + yi += 16; + accu16 += yiv * x; + } + + void add_16_components_2 (__m512 x1, __m512 x2) { + accu16 += x1 * x2; + } + + float result_16 () { + __m256 sum0 = _mm512_extractf32x8_ps(accu16, 1) + _mm512_extractf32x8_ps(accu16, 0); + __m256 sum1 = _mm256_hadd_ps(sum0, sum0); + __m256 sum2 = _mm256_hadd_ps(sum1, sum1); + // now add the 0th and 4th component + return + _mm_cvtss_f32 (_mm256_castps256_ps128(sum2)) + + _mm_cvtss_f32 (_mm256_extractf128_ps(sum2, 1)); + } +}; + + +/******************************************************************* + * DistanceComputer: combines a similarity and a quantizer to do + * code-to-vector or code-to-code comparisons + *******************************************************************/ + +template +struct DCTemplate_avx512 : SQDistanceComputer {}; + +template +struct DCTemplate_avx512 : public DCTemplate_avx { + DCTemplate_avx512(size_t d, const std::vector &trained) : + DCTemplate_avx (d, trained) {} +}; + +template +struct DCTemplate_avx512 : public DCTemplate_avx { + DCTemplate_avx512(size_t d, const std::vector &trained) : + DCTemplate_avx (d, trained) {} +}; + +template +struct DCTemplate_avx512 : SQDistanceComputer { + using Sim = Similarity; + + Quantizer quant; + + DCTemplate_avx512(size_t d, const std::vector &trained): + quant(d, trained) + {} + + float compute_distance(const float* x, const uint8_t* code) const { + Similarity sim(x); + sim.begin_16(); + for (size_t i = 0; i < quant.d; i += 16) { + __m512 xi = quant.reconstruct_16_components(code, i); + sim.add_16_components(xi); + } + return sim.result_16(); + } + + float compute_code_distance(const uint8_t* code1, const uint8_t* code2) const { + Similarity sim(nullptr); + sim.begin_16(); + for (size_t i = 0; i < quant.d; i += 16) { + __m512 x1 = quant.reconstruct_16_components(code1, i); + __m512 x2 = quant.reconstruct_16_components(code2, i); + sim.add_16_components_2(x1, x2); + } + return sim.result_16(); + } + + void set_query (const float *x) final { + q = x; + } + + /// compute distance of vector i to current query + float operator () (idx_t i) final { + return compute_distance (q, codes + i * code_size); + } + + float symmetric_dis (idx_t i, idx_t j) override { + return compute_code_distance (codes + i * code_size, + codes + j * code_size); + } + + float query_to_code (const uint8_t * code) const { + return compute_distance (q, code); + } +}; + + +/******************************************************************* + * DistanceComputerByte: computes distances in the integer domain + *******************************************************************/ + +template +struct DistanceComputerByte_avx512 : SQDistanceComputer {}; + +template +struct DistanceComputerByte_avx512 : public DistanceComputerByte_avx { + DistanceComputerByte_avx512(int d, const std::vector &unused) : + DistanceComputerByte_avx (d, unused) {} +}; + +template +struct DistanceComputerByte_avx512 : public DistanceComputerByte_avx { + DistanceComputerByte_avx512(int d, const std::vector &unused) : + DistanceComputerByte_avx (d, unused) {} +}; + +template +struct DistanceComputerByte_avx512 : SQDistanceComputer { + using Sim = Similarity; + + int d; + std::vector tmp; + + DistanceComputerByte_avx512(int d, const std::vector &): d(d), tmp(d) {} + + int compute_code_distance(const uint8_t* code1, const uint8_t* code2) const { + // __m256i accu = _mm256_setzero_ps (); + __m512i accu = _mm512_setzero_si512 (); + for (int i = 0; i < d; i += 32) { + // load 32 bytes, convert to 16 uint16_t + __m512i c1 = _mm512_cvtepu8_epi16 + (_mm256_loadu_si256((__m256i*)(code1 + i))); + __m512i c2 = _mm512_cvtepu8_epi16 + (_mm256_loadu_si256((__m256i*)(code2 + i))); + __m512i prod32; + if (Sim::metric_type == METRIC_INNER_PRODUCT) { + prod32 = _mm512_madd_epi16(c1, c2); + } else { + __m512i diff = _mm512_sub_epi16(c1, c2); + prod32 = _mm512_madd_epi16(diff, diff); + } + accu = _mm512_add_epi32 (accu, prod32); + } + __m128i sum = _mm512_extracti32x4_epi32(accu, 0); + sum = _mm_add_epi32 (sum, _mm512_extracti32x4_epi32(accu, 1)); + sum = _mm_add_epi32 (sum, _mm512_extracti32x4_epi32(accu, 2)); + sum = _mm_add_epi32 (sum, _mm512_extracti32x4_epi32(accu, 3)); + sum = _mm_hadd_epi32 (sum, sum); + sum = _mm_hadd_epi32 (sum, sum); + return _mm_cvtsi128_si32 (sum); + } + + void set_query (const float *x) final { + /* + for (int i = 0; i < d; i += 8) { + __m256 xi = _mm256_loadu_ps (x + i); + __m256i ci = _mm256_cvtps_epi32(xi); + */ + for (int i = 0; i < d; i++) { + tmp[i] = int(x[i]); + } + } + + int compute_distance(const float* x, const uint8_t* code) { + set_query(x); + return compute_code_distance(tmp.data(), code); + } + + /// compute distance of vector i to current query + float operator () (idx_t i) final { + return compute_distance (q, codes + i * code_size); + } + + float symmetric_dis (idx_t i, idx_t j) override { + return compute_code_distance (codes + i * code_size, + codes + j * code_size); + } + + float query_to_code (const uint8_t * code) const { + return compute_code_distance (tmp.data(), code); + } +}; + + +/******************************************************************* + * select_distance_computer: runtime selection of template + * specialization + *******************************************************************/ + +template +SQDistanceComputer *select_distance_computer_avx512 ( + QuantizerType qtype, + size_t d, const std::vector & trained) +{ + constexpr int SIMDWIDTH = Sim::simdwidth; + switch(qtype) { + case QuantizerType::QT_8bit_uniform: + return new DCTemplate_avx512, + Sim, SIMDWIDTH>(d, trained); + + case QuantizerType::QT_4bit_uniform: + return new DCTemplate_avx512, + Sim, SIMDWIDTH>(d, trained); + + case QuantizerType::QT_8bit: + return new DCTemplate_avx512, + Sim, SIMDWIDTH>(d, trained); + + case QuantizerType::QT_6bit: + return new DCTemplate_avx512, + Sim, SIMDWIDTH>(d, trained); + + case QuantizerType::QT_4bit: + return new DCTemplate_avx512, + Sim, SIMDWIDTH>(d, trained); + + case QuantizerType::QT_fp16: + return new DCTemplate_avx512 + , Sim, SIMDWIDTH>(d, trained); + + case QuantizerType::QT_8bit_direct: + if (d % 16 == 0) { + return new DistanceComputerByte_avx512(d, trained); + } else { + return new DCTemplate_avx512 + , Sim, SIMDWIDTH>(d, trained); + } + } + FAISS_THROW_MSG ("unknown qtype"); + return nullptr; +} + + +template +InvertedListScanner* sel2_InvertedListScanner_avx512 ( + const ScalarQuantizer *sq, + const Index *quantizer, bool store_pairs, bool r) +{ + return sel2_InvertedListScanner (sq, quantizer, store_pairs, r); +} + +template +InvertedListScanner* sel12_InvertedListScanner_avx512 ( + const ScalarQuantizer *sq, + const Index *quantizer, bool store_pairs, bool r) +{ + constexpr int SIMDWIDTH = Similarity::simdwidth; + using QuantizerClass = QuantizerTemplate_avx512; + using DCClass = DCTemplate_avx512; + return sel2_InvertedListScanner_avx512 (sq, quantizer, store_pairs, r); +} + + +template +InvertedListScanner* sel1_InvertedListScanner_avx512 ( + const ScalarQuantizer *sq, const Index *quantizer, + bool store_pairs, bool r) +{ + constexpr int SIMDWIDTH = Similarity::simdwidth; + switch(sq->qtype) { + case QuantizerType::QT_8bit_uniform: + return sel12_InvertedListScanner_avx512 + (sq, quantizer, store_pairs, r); + case QuantizerType::QT_4bit_uniform: + return sel12_InvertedListScanner_avx512 + (sq, quantizer, store_pairs, r); + case QuantizerType::QT_8bit: + return sel12_InvertedListScanner_avx512 + (sq, quantizer, store_pairs, r); + case QuantizerType::QT_4bit: + return sel12_InvertedListScanner_avx512 + (sq, quantizer, store_pairs, r); + case QuantizerType::QT_6bit: + return sel12_InvertedListScanner_avx512 + (sq, quantizer, store_pairs, r); + case QuantizerType::QT_fp16: + return sel2_InvertedListScanner_avx512 + , Similarity, SIMDWIDTH> > + (sq, quantizer, store_pairs, r); + case QuantizerType::QT_8bit_direct: + if (sq->d % 16 == 0) { + return sel2_InvertedListScanner_avx512 + > + (sq, quantizer, store_pairs, r); + } else { + return sel2_InvertedListScanner_avx512 + , + Similarity, SIMDWIDTH> > + (sq, quantizer, store_pairs, r); + } + } + + FAISS_THROW_MSG ("unknown qtype"); + return nullptr; +} + +template +InvertedListScanner* sel0_InvertedListScanner_avx512 ( + MetricType mt, const ScalarQuantizer *sq, + const Index *quantizer, bool store_pairs, bool by_residual) +{ + if (mt == METRIC_L2) { + return sel1_InvertedListScanner_avx512 > + (sq, quantizer, store_pairs, by_residual); + } else if (mt == METRIC_INNER_PRODUCT) { + return sel1_InvertedListScanner_avx512 > + (sq, quantizer, store_pairs, by_residual); + } else { + FAISS_THROW_MSG("unsupported metric type"); + } +} + +} // namespace faiss diff --git a/core/src/index/thirdparty/faiss/impl/ScalarQuantizerDC.cpp b/core/src/index/thirdparty/faiss/impl/ScalarQuantizerDC.cpp new file mode 100644 index 0000000000..71fc4807b9 --- /dev/null +++ b/core/src/index/thirdparty/faiss/impl/ScalarQuantizerDC.cpp @@ -0,0 +1,39 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +#include +#include + +namespace faiss { + +/******************************************************************* + * ScalarQuantizer Distance Computer + ********************************************************************/ + +/* SSE */ +SQDistanceComputer * +sq_get_distance_computer_ref (MetricType metric, QuantizerType qtype, size_t dim, const std::vector& trained) { + if (metric == METRIC_L2) { + return select_distance_computer>(qtype, dim, trained); + } else { + return select_distance_computer>(qtype, dim, trained); + } +} + +Quantizer * +sq_select_quantizer_ref (QuantizerType qtype, size_t dim, const std::vector& trained) { + return select_quantizer_1<1> (qtype, dim, trained); +} + +InvertedListScanner* +sq_select_inverted_list_scanner_ref (MetricType mt, const ScalarQuantizer *sq, const Index *quantizer, size_t dim, bool store_pairs, bool by_residual) { + return sel0_InvertedListScanner<1> (mt, sq, quantizer, store_pairs, by_residual); +} + +} // namespace faiss diff --git a/core/src/index/thirdparty/faiss/impl/ScalarQuantizerDC.h b/core/src/index/thirdparty/faiss/impl/ScalarQuantizerDC.h new file mode 100644 index 0000000000..d088d54bc9 --- /dev/null +++ b/core/src/index/thirdparty/faiss/impl/ScalarQuantizerDC.h @@ -0,0 +1,41 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +#pragma once + +#include +#include +#include +#include + +namespace faiss { + +SQDistanceComputer * +sq_get_distance_computer_ref( + MetricType metric, + QuantizerType qtype, + size_t dim, + const std::vector& trained); + +Quantizer * +sq_select_quantizer_ref( + QuantizerType qtype, + size_t dim, + const std::vector& trained); + +InvertedListScanner* +sq_select_inverted_list_scanner_ref( + MetricType mt, + const ScalarQuantizer *sq, + const Index *quantizer, + size_t dim, + bool store_pairs, + bool by_residual); + +} // namespace faiss diff --git a/core/src/index/thirdparty/faiss/impl/ScalarQuantizerDC_avx.cpp b/core/src/index/thirdparty/faiss/impl/ScalarQuantizerDC_avx.cpp new file mode 100644 index 0000000000..2da2af6f60 --- /dev/null +++ b/core/src/index/thirdparty/faiss/impl/ScalarQuantizerDC_avx.cpp @@ -0,0 +1,54 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +#include +#include + +namespace faiss { + +/******************************************************************* + * ScalarQuantizer Distance Computer + ********************************************************************/ + +SQDistanceComputer * +sq_get_distance_computer_avx (MetricType metric, QuantizerType qtype, size_t dim, const std::vector& trained) { + if (metric == METRIC_L2) { + if (dim % 8 == 0) { + return select_distance_computer_avx>(qtype, dim, trained); + } else { + return select_distance_computer_avx>(qtype, dim, trained); + } + } else { + if (dim % 8 == 0) { + return select_distance_computer_avx>(qtype, dim, trained); + } else { + return select_distance_computer_avx>(qtype, dim, trained); + } + } +} + +Quantizer * +sq_select_quantizer_avx (QuantizerType qtype, size_t dim, const std::vector& trained) { + if (dim % 8 == 0) { + return select_quantizer_1_avx<8>(qtype, dim, trained); + } else { + return select_quantizer_1_avx<1> (qtype, dim, trained); + } +} + +InvertedListScanner* +sq_select_inverted_list_scanner_avx (MetricType mt, const ScalarQuantizer *sq, const Index *quantizer, size_t dim, bool store_pairs, bool by_residual) { + if (dim % 8 == 0) { + return sel0_InvertedListScanner_avx<8> (mt, sq, quantizer, store_pairs, by_residual); + } else { + return sel0_InvertedListScanner_avx<1> (mt, sq, quantizer, store_pairs, by_residual); + } +} + +} // namespace faiss diff --git a/core/src/index/thirdparty/faiss/impl/ScalarQuantizerDC_avx.h b/core/src/index/thirdparty/faiss/impl/ScalarQuantizerDC_avx.h new file mode 100644 index 0000000000..3b04aa4d2e --- /dev/null +++ b/core/src/index/thirdparty/faiss/impl/ScalarQuantizerDC_avx.h @@ -0,0 +1,41 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +#pragma once + +#include +#include +#include +#include + +namespace faiss { + +SQDistanceComputer * +sq_get_distance_computer_avx( + MetricType metric, + QuantizerType qtype, + size_t dim, + const std::vector& trained); + +Quantizer * +sq_select_quantizer_avx( + QuantizerType qtype, + size_t dim, + const std::vector& trained); + +InvertedListScanner* +sq_select_inverted_list_scanner_avx( + MetricType mt, + const ScalarQuantizer *sq, + const Index *quantizer, + size_t dim, + bool store_pairs, + bool by_residual); + +} // namespace faiss diff --git a/core/src/index/thirdparty/faiss/impl/ScalarQuantizerDC_avx512.cpp b/core/src/index/thirdparty/faiss/impl/ScalarQuantizerDC_avx512.cpp new file mode 100644 index 0000000000..6a62847c1d --- /dev/null +++ b/core/src/index/thirdparty/faiss/impl/ScalarQuantizerDC_avx512.cpp @@ -0,0 +1,62 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +#include +#include + +namespace faiss { + +/******************************************************************* + * ScalarQuantizer Distance Computer + ********************************************************************/ + +SQDistanceComputer * +sq_get_distance_computer_avx512 (MetricType metric, QuantizerType qtype, size_t dim, const std::vector& trained) { + if (metric == METRIC_L2) { + if (dim % 16 == 0) { + return select_distance_computer_avx512>(qtype, dim, trained); + } else if (dim % 8 == 0) { + return select_distance_computer_avx512>(qtype, dim, trained); + } else { + return select_distance_computer_avx512>(qtype, dim, trained); + } + } else { + if (dim % 16 == 0) { + return select_distance_computer_avx512>(qtype, dim, trained); + } else if (dim % 8 == 0) { + return select_distance_computer_avx512>(qtype, dim, trained); + } else { + return select_distance_computer_avx512>(qtype, dim, trained); + } + } +} + +Quantizer * +sq_select_quantizer_avx512 (QuantizerType qtype, size_t dim, const std::vector& trained) { + if (dim % 16 == 0) { + return select_quantizer_1_avx512<16> (qtype, dim, trained); + } else if (dim % 8 == 0) { + return select_quantizer_1_avx512<8> (qtype, dim, trained); + } else { + return select_quantizer_1_avx512<1> (qtype, dim, trained); + } +} + +InvertedListScanner* +sq_select_inverted_list_scanner_avx512 (MetricType mt, const ScalarQuantizer *sq, const Index *quantizer, size_t dim, bool store_pairs, bool by_residual) { + if (dim % 16 == 0) { + return sel0_InvertedListScanner_avx512<16> (mt, sq, quantizer, store_pairs, by_residual); + } else if (dim % 8 == 0) { + return sel0_InvertedListScanner_avx512<8> (mt, sq, quantizer, store_pairs, by_residual); + } else { + return sel0_InvertedListScanner_avx512<1> (mt, sq, quantizer, store_pairs, by_residual); + } +} + +} // namespace faiss diff --git a/core/src/index/thirdparty/faiss/impl/ScalarQuantizerDC_avx512.h b/core/src/index/thirdparty/faiss/impl/ScalarQuantizerDC_avx512.h new file mode 100644 index 0000000000..f1b03027a9 --- /dev/null +++ b/core/src/index/thirdparty/faiss/impl/ScalarQuantizerDC_avx512.h @@ -0,0 +1,41 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +#pragma once + +#include +#include +#include +#include + +namespace faiss { + +SQDistanceComputer * +sq_get_distance_computer_avx512( + MetricType metric, + QuantizerType qtype, + size_t dim, + const std::vector& trained); + +Quantizer * +sq_select_quantizer_avx512( + QuantizerType qtype, + size_t dim, + const std::vector& trained); + +InvertedListScanner* +sq_select_inverted_list_scanner_avx512( + MetricType mt, + const ScalarQuantizer *sq, + const Index *quantizer, + size_t dim, + bool store_pairs, + bool by_residual); + +} // namespace faiss diff --git a/core/src/index/thirdparty/faiss/impl/ScalarQuantizerOp.cpp b/core/src/index/thirdparty/faiss/impl/ScalarQuantizerOp.cpp new file mode 100644 index 0000000000..0b0f9aa92f --- /dev/null +++ b/core/src/index/thirdparty/faiss/impl/ScalarQuantizerOp.cpp @@ -0,0 +1,295 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +#include +#include + +#include + +#ifdef __SSE__ +#include +#endif + +#include +#include +#include + +namespace faiss { + +#ifdef __AVX__ +#define USE_AVX +#endif + + +#ifdef USE_AVX + +uint16_t encode_fp16 (float x) { + __m128 xf = _mm_set1_ps (x); + __m128i xi = _mm_cvtps_ph ( + xf, _MM_FROUND_TO_NEAREST_INT |_MM_FROUND_NO_EXC); + return _mm_cvtsi128_si32 (xi) & 0xffff; +} + + +float decode_fp16 (uint16_t x) { + __m128i xi = _mm_set1_epi16 (x); + __m128 xf = _mm_cvtph_ps (xi); + return _mm_cvtss_f32 (xf); +} + +#else + +// non-intrinsic FP16 <-> FP32 code adapted from +// https://github.com/ispc/ispc/blob/master/stdlib.ispc + +float floatbits (uint32_t x) { + void *xptr = &x; + return *(float*)xptr; +} + +uint32_t intbits (float f) { + void *fptr = &f; + return *(uint32_t*)fptr; +} + + +uint16_t encode_fp16 (float f) { + // via Fabian "ryg" Giesen. + // https://gist.github.com/2156668 + uint32_t sign_mask = 0x80000000u; + int32_t o; + + uint32_t fint = intbits(f); + uint32_t sign = fint & sign_mask; + fint ^= sign; + + // NOTE all the integer compares in this function can be safely + // compiled into signed compares since all operands are below + // 0x80000000. Important if you want fast straight SSE2 code (since + // there's no unsigned PCMPGTD). + + // Inf or NaN (all exponent bits set) + // NaN->qNaN and Inf->Inf + // unconditional assignment here, will override with right value for + // the regular case below. + uint32_t f32infty = 255u << 23; + o = (fint > f32infty) ? 0x7e00u : 0x7c00u; + + // (De)normalized number or zero + // update fint unconditionally to save the blending; we don't need it + // anymore for the Inf/NaN case anyway. + + const uint32_t round_mask = ~0xfffu; + const uint32_t magic = 15u << 23; + + // Shift exponent down, denormalize if necessary. + // NOTE This represents half-float denormals using single + // precision denormals. The main reason to do this is that + // there's no shift with per-lane variable shifts in SSE*, which + // we'd otherwise need. It has some funky side effects though: + // - This conversion will actually respect the FTZ (Flush To Zero) + // flag in MXCSR - if it's set, no half-float denormals will be + // generated. I'm honestly not sure whether this is good or + // bad. It's definitely interesting. + // - If the underlying HW doesn't support denormals (not an issue + // with Intel CPUs, but might be a problem on GPUs or PS3 SPUs), + // you will always get flush-to-zero behavior. This is bad, + // unless you're on a CPU where you don't care. + // - Denormals tend to be slow. FP32 denormals are rare in + // practice outside of things like recursive filters in DSP - + // not a typical half-float application. Whether FP16 denormals + // are rare in practice, I don't know. Whatever slow path your + // HW may or may not have for denormals, this may well hit it. + float fscale = floatbits(fint & round_mask) * floatbits(magic); + fscale = std::min(fscale, floatbits((31u << 23) - 0x1000u)); + int32_t fint2 = intbits(fscale) - round_mask; + + if (fint < f32infty) + o = fint2 >> 13; // Take the bits! + + return (o | (sign >> 16)); +} + +float decode_fp16 (uint16_t h) { + // https://gist.github.com/2144712 + // Fabian "ryg" Giesen. + + const uint32_t shifted_exp = 0x7c00u << 13; // exponent mask after shift + + int32_t o = ((int32_t)(h & 0x7fffu)) << 13; // exponent/mantissa bits + int32_t exp = shifted_exp & o; // just the exponent + o += (int32_t)(127 - 15) << 23; // exponent adjust + + int32_t infnan_val = o + ((int32_t)(128 - 16) << 23); + int32_t zerodenorm_val = intbits( + floatbits(o + (1u<<23)) - floatbits(113u << 23)); + int32_t reg_val = (exp == 0) ? zerodenorm_val : o; + + int32_t sign_bit = ((int32_t)(h & 0x8000u)) << 16; + return floatbits(((exp == shifted_exp) ? infnan_val : reg_val) | sign_bit); +} + +#endif + + +/******************************************************************* + * Quantizer range training + */ + +static float sqr (float x) { + return x * x; +} + + +void train_Uniform(RangeStat rs, float rs_arg, + idx_t n, int k, const float *x, + std::vector & trained) +{ + trained.resize (2); + float & vmin = trained[0]; + float & vmax = trained[1]; + + if (rs == RangeStat::RS_minmax) { + vmin = HUGE_VAL; vmax = -HUGE_VAL; + for (size_t i = 0; i < n; i++) { + if (x[i] < vmin) vmin = x[i]; + if (x[i] > vmax) vmax = x[i]; + } + float vexp = (vmax - vmin) * rs_arg; + vmin -= vexp; + vmax += vexp; + } else if (rs == RangeStat::RS_meanstd) { + double sum = 0, sum2 = 0; + for (size_t i = 0; i < n; i++) { + sum += x[i]; + sum2 += x[i] * x[i]; + } + float mean = sum / n; + float var = sum2 / n - mean * mean; + float std = var <= 0 ? 1.0 : sqrt(var); + + vmin = mean - std * rs_arg ; + vmax = mean + std * rs_arg ; + } else if (rs == RangeStat::RS_quantiles) { + std::vector x_copy(n); + memcpy(x_copy.data(), x, n * sizeof(*x)); + // TODO just do a qucikselect + std::sort(x_copy.begin(), x_copy.end()); + int o = int(rs_arg * n); + if (o < 0) o = 0; + if (o > n - o) o = n / 2; + vmin = x_copy[o]; + vmax = x_copy[n - 1 - o]; + + } else if (rs == RangeStat::RS_optim) { + float a, b; + float sx = 0; + { + vmin = HUGE_VAL, vmax = -HUGE_VAL; + for (size_t i = 0; i < n; i++) { + if (x[i] < vmin) vmin = x[i]; + if (x[i] > vmax) vmax = x[i]; + sx += x[i]; + } + b = vmin; + a = (vmax - vmin) / (k - 1); + } + int verbose = false; + int niter = 2000; + float last_err = -1; + int iter_last_err = 0; + for (int it = 0; it < niter; it++) { + float sn = 0, sn2 = 0, sxn = 0, err1 = 0; + + for (idx_t i = 0; i < n; i++) { + float xi = x[i]; + float ni = floor ((xi - b) / a + 0.5); + if (ni < 0) ni = 0; + if (ni >= k) ni = k - 1; + err1 += sqr (xi - (ni * a + b)); + sn += ni; + sn2 += ni * ni; + sxn += ni * xi; + } + + if (err1 == last_err) { + iter_last_err ++; + if (iter_last_err == 16) break; + } else { + last_err = err1; + iter_last_err = 0; + } + + float det = sqr (sn) - sn2 * n; + + b = (sn * sxn - sn2 * sx) / det; + a = (sn * sx - n * sxn) / det; + if (verbose) { + printf ("it %d, err1=%g \r", it, err1); + fflush(stdout); + } + } + if (verbose) printf("\n"); + + vmin = b; + vmax = b + a * (k - 1); + + } else { + FAISS_THROW_MSG ("Invalid qtype"); + } + vmax -= vmin; +} + +void train_NonUniform(RangeStat rs, float rs_arg, + idx_t n, int d, int k, const float *x, + std::vector & trained) +{ + trained.resize (2 * d); + float * vmin = trained.data(); + float * vmax = trained.data() + d; + if (rs == RangeStat::RS_minmax) { + memcpy (vmin, x, sizeof(*x) * d); + memcpy (vmax, x, sizeof(*x) * d); + for (size_t i = 1; i < n; i++) { + const float *xi = x + i * d; + for (size_t j = 0; j < d; j++) { + if (xi[j] < vmin[j]) vmin[j] = xi[j]; + if (xi[j] > vmax[j]) vmax[j] = xi[j]; + } + } + float *vdiff = vmax; + for (size_t j = 0; j < d; j++) { + float vexp = (vmax[j] - vmin[j]) * rs_arg; + vmin[j] -= vexp; + vmax[j] += vexp; + vdiff [j] = vmax[j] - vmin[j]; + } + } else { + // transpose + std::vector xt(n * d); + for (size_t i = 1; i < n; i++) { + const float *xi = x + i * d; + for (size_t j = 0; j < d; j++) { + xt[j * n + i] = xi[j]; + } + } + std::vector trained_d(2); +#pragma omp parallel for + for (size_t j = 0; j < d; j++) { + train_Uniform(rs, rs_arg, + n, k, xt.data() + j * n, + trained_d); + vmin[j] = trained_d[0]; + vmax[j] = trained_d[1]; + } + } +} + + +} // namespace faiss diff --git a/core/src/index/thirdparty/faiss/impl/ScalarQuantizerOp.h b/core/src/index/thirdparty/faiss/impl/ScalarQuantizerOp.h new file mode 100644 index 0000000000..c272d29bc4 --- /dev/null +++ b/core/src/index/thirdparty/faiss/impl/ScalarQuantizerOp.h @@ -0,0 +1,71 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +#pragma once + +#include +#include + +#include + +#include +#include +#include + +namespace faiss { + +typedef Index::idx_t idx_t; + +enum class QuantizerType { + QT_8bit = 0, ///< 8 bits per component + QT_4bit, ///< 4 bits per component + QT_8bit_uniform, ///< same, shared range for all dimensions + QT_4bit_uniform, + QT_fp16, + QT_8bit_direct, /// fast indexing of uint8s + QT_6bit, ///< 6 bits per component +}; + +// rangestat_arg. +enum class RangeStat { + RS_minmax = 0, ///< [min - rs*(max-min), max + rs*(max-min)] + RS_meanstd, ///< [mean - std * rs, mean + std * rs] + RS_quantiles, ///< [Q(rs), Q(1-rs)] + RS_optim, ///< alternate optimization of reconstruction error +}; + +struct Quantizer { + // encodes one vector. Assumes code is filled with 0s on input! + virtual void encode_vector(const float *x, uint8_t *code) const = 0; + virtual void decode_vector(const uint8_t *code, float *x) const = 0; + + virtual ~Quantizer() {} +}; + +struct SQDistanceComputer: DistanceComputer { + const float *q; + const uint8_t *codes; + size_t code_size; + + SQDistanceComputer (): q(nullptr), codes (nullptr), code_size (0) + {} +}; + +extern uint16_t encode_fp16 (float x); +extern float decode_fp16 (uint16_t x); + +extern void train_Uniform(RangeStat rs, float rs_arg, + idx_t n, int k, const float *x, + std::vector & trained); +extern void train_NonUniform(RangeStat rs, float rs_arg, + idx_t n, int d, int k, const float *x, + std::vector & trained); + + +} // namespace faiss diff --git a/core/src/index/thirdparty/faiss/impl/ThreadedIndex-inl.h b/core/src/index/thirdparty/faiss/impl/ThreadedIndex-inl.h new file mode 100644 index 0000000000..de549a0288 --- /dev/null +++ b/core/src/index/thirdparty/faiss/impl/ThreadedIndex-inl.h @@ -0,0 +1,192 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include + +namespace faiss { + +template +ThreadedIndex::ThreadedIndex(bool threaded) + // 0 is default dimension + : ThreadedIndex(0, threaded) { +} + +template +ThreadedIndex::ThreadedIndex(int d, bool threaded) + : IndexT(d), + own_fields(false), + isThreaded_(threaded) { + } + +template +ThreadedIndex::~ThreadedIndex() { + for (auto& p : indices_) { + if (isThreaded_) { + // should have worker thread + FAISS_ASSERT((bool) p.second); + + // This will also flush all pending work + p.second->stop(); + p.second->waitForThreadExit(); + } else { + // should not have worker thread + FAISS_ASSERT(!(bool) p.second); + } + + if (own_fields) { + delete p.first; + } + } +} + +template +void ThreadedIndex::addIndex(IndexT* index) { + // We inherit the dimension from the first index added to us if we don't have + // a set dimension + if (indices_.empty() && this->d == 0) { + this->d = index->d; + } + + // The new index must match our set dimension + FAISS_THROW_IF_NOT_FMT(this->d == index->d, + "addIndex: dimension mismatch for " + "newly added index; expecting dim %d, " + "new index has dim %d", + this->d, index->d); + + if (!indices_.empty()) { + auto& existing = indices_.front().first; + + FAISS_THROW_IF_NOT_MSG(index->metric_type == existing->metric_type, + "addIndex: newly added index is " + "of different metric type than old index"); + + // Make sure this index is not duplicated + for (auto& p : indices_) { + FAISS_THROW_IF_NOT_MSG(p.first != index, + "addIndex: attempting to add index " + "that is already in the collection"); + } + } + + indices_.emplace_back( + std::make_pair( + index, + std::unique_ptr(isThreaded_ ? + new WorkerThread : nullptr))); + + onAfterAddIndex(index); +} + +template +void ThreadedIndex::removeIndex(IndexT* index) { + for (auto it = indices_.begin(); it != indices_.end(); ++it) { + if (it->first == index) { + // This is our index; stop the worker thread before removing it, + // to ensure that it has finished before function exit + if (isThreaded_) { + // should have worker thread + FAISS_ASSERT((bool) it->second); + it->second->stop(); + it->second->waitForThreadExit(); + } else { + // should not have worker thread + FAISS_ASSERT(!(bool) it->second); + } + + indices_.erase(it); + onAfterRemoveIndex(index); + + if (own_fields) { + delete index; + } + + return; + } + } + + // could not find our index + FAISS_THROW_MSG("IndexReplicas::removeIndex: index not found"); +} + +template +void ThreadedIndex::runOnIndex(std::function f) { + if (isThreaded_) { + std::vector> v; + + for (int i = 0; i < this->indices_.size(); ++i) { + auto& p = this->indices_[i]; + auto indexPtr = p.first; + v.emplace_back(p.second->add([f, i, indexPtr](){ f(i, indexPtr); })); + } + + waitAndHandleFutures(v); + } else { + // Multiple exceptions may be thrown; gather them as we encounter them, + // while letting everything else run to completion + std::vector> exceptions; + + for (int i = 0; i < this->indices_.size(); ++i) { + auto& p = this->indices_[i]; + try { + f(i, p.first); + } catch (...) { + exceptions.emplace_back(std::make_pair(i, std::current_exception())); + } + } + + handleExceptions(exceptions); + } +} + +template +void ThreadedIndex::runOnIndex( + std::function f) const { + const_cast*>(this)->runOnIndex( + [f](int i, IndexT* idx){ f(i, idx); }); +} + +template +void ThreadedIndex::reset() { + runOnIndex([](int, IndexT* index){ index->reset(); }); + this->ntotal = 0; + this->is_trained = false; +} + +template +void +ThreadedIndex::onAfterAddIndex(IndexT* index) { +} + +template +void +ThreadedIndex::onAfterRemoveIndex(IndexT* index) { +} + +template +void +ThreadedIndex::waitAndHandleFutures(std::vector>& v) { + // Blocking wait for completion for all of the indices, capturing any + // exceptions that are generated + std::vector> exceptions; + + for (int i = 0; i < v.size(); ++i) { + auto& fut = v[i]; + + try { + fut.get(); + } catch (...) { + exceptions.emplace_back(std::make_pair(i, std::current_exception())); + } + } + + handleExceptions(exceptions); +} + +} // namespace diff --git a/core/src/index/thirdparty/faiss/impl/ThreadedIndex.h b/core/src/index/thirdparty/faiss/impl/ThreadedIndex.h new file mode 100644 index 0000000000..89f21486a6 --- /dev/null +++ b/core/src/index/thirdparty/faiss/impl/ThreadedIndex.h @@ -0,0 +1,80 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include +#include +#include + +namespace faiss { + +/// A holder of indices in a collection of threads +/// The interface to this class itself is not thread safe +template +class ThreadedIndex : public IndexT { + public: + explicit ThreadedIndex(bool threaded); + explicit ThreadedIndex(int d, bool threaded); + + ~ThreadedIndex() override; + + /// override an index that is managed by ourselves. + /// WARNING: once an index is added, it becomes unsafe to touch it from any + /// other thread than that on which is managing it, until we are shut + /// down. Use runOnIndex to perform work on it instead. + void addIndex(IndexT* index); + + /// Remove an index that is managed by ourselves. + /// This will flush all pending work on that index, and then shut + /// down its managing thread, and will remove the index. + void removeIndex(IndexT* index); + + /// Run a function on all indices, in the thread that the index is + /// managed in. + /// Function arguments are (index in collection, index pointer) + void runOnIndex(std::function f); + void runOnIndex(std::function f) const; + + /// faiss::Index API + /// All indices receive the same call + void reset() override; + + /// Returns the number of sub-indices + int count() const { return indices_.size(); } + + /// Returns the i-th sub-index + IndexT* at(int i) { return indices_[i].first; } + + /// Returns the i-th sub-index (const version) + const IndexT* at(int i) const { return indices_[i].first; } + + /// Whether or not we are responsible for deleting our contained indices + bool own_fields; + + protected: + /// Called just after an index is added + virtual void onAfterAddIndex(IndexT* index); + + /// Called just after an index is removed + virtual void onAfterRemoveIndex(IndexT* index); + +protected: + static void waitAndHandleFutures(std::vector>& v); + + /// Collection of Index instances, with their managing worker thread if any + std::vector>> indices_; + + /// Is this index multi-threaded? + bool isThreaded_; +}; + +} // namespace + +#include diff --git a/core/src/index/thirdparty/faiss/impl/index_read.cpp b/core/src/index/thirdparty/faiss/impl/index_read.cpp new file mode 100644 index 0000000000..9a3936a715 --- /dev/null +++ b/core/src/index/thirdparty/faiss/impl/index_read.cpp @@ -0,0 +1,1087 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +#include + +#include +#include + +#include +#include +#include +#include + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + + + +namespace faiss { + +/************************************************************* + * I/O macros + * + * we use macros so that we have a line number to report in abort + * (). This makes debugging a lot easier. The IOReader or IOWriter is + * always called f and thus is not passed in as a macro parameter. + **************************************************************/ + + +#define READANDCHECK(ptr, n) { \ + size_t ret = (*f)(ptr, sizeof(*(ptr)), n); \ + FAISS_THROW_IF_NOT_FMT(ret == (n), \ + "read error in %s: %ld != %ld (%s)", \ + f->name.c_str(), ret, size_t(n), strerror(errno)); \ + } + +#define READ1(x) READANDCHECK(&(x), 1) + +// will fail if we write 256G of data at once... +#define READVECTOR(vec) { \ + long size; \ + READANDCHECK (&size, 1); \ + FAISS_THROW_IF_NOT (size >= 0 && size < (1L << 40)); \ + (vec).resize (size); \ + READANDCHECK ((vec).data (), size); \ + } + + + +/************************************************************* + * Read + **************************************************************/ + +static void read_index_header (Index *idx, IOReader *f) { + READ1 (idx->d); + READ1 (idx->ntotal); + Index::idx_t dummy; + READ1 (dummy); + READ1 (dummy); + READ1 (idx->is_trained); + READ1 (idx->metric_type); + if (idx->metric_type > 1) { + READ1 (idx->metric_arg); + } + idx->verbose = false; +} + +VectorTransform* read_VectorTransform (IOReader *f) { + uint32_t h; + READ1 (h); + VectorTransform *vt = nullptr; + + if (h == fourcc ("rrot") || h == fourcc ("PCAm") || + h == fourcc ("LTra") || h == fourcc ("PcAm") || + h == fourcc ("Viqm")) { + LinearTransform *lt = nullptr; + if (h == fourcc ("rrot")) { + lt = new RandomRotationMatrix (); + } else if (h == fourcc ("PCAm") || + h == fourcc ("PcAm")) { + PCAMatrix * pca = new PCAMatrix (); + READ1 (pca->eigen_power); + READ1 (pca->random_rotation); + if (h == fourcc ("PcAm")) + READ1 (pca->balanced_bins); + READVECTOR (pca->mean); + READVECTOR (pca->eigenvalues); + READVECTOR (pca->PCAMat); + lt = pca; + } else if (h == fourcc ("Viqm")) { + ITQMatrix *itqm = new ITQMatrix (); + READ1 (itqm->max_iter); + READ1 (itqm->seed); + lt = itqm; + } else if (h == fourcc ("LTra")) { + lt = new LinearTransform (); + } + READ1 (lt->have_bias); + READVECTOR (lt->A); + READVECTOR (lt->b); + FAISS_THROW_IF_NOT (lt->A.size() >= lt->d_in * lt->d_out); + FAISS_THROW_IF_NOT (!lt->have_bias || lt->b.size() >= lt->d_out); + lt->set_is_orthonormal(); + vt = lt; + } else if (h == fourcc ("RmDT")) { + RemapDimensionsTransform *rdt = new RemapDimensionsTransform (); + READVECTOR (rdt->map); + vt = rdt; + } else if (h == fourcc ("VNrm")) { + NormalizationTransform *nt = new NormalizationTransform (); + READ1 (nt->norm); + vt = nt; + } else if (h == fourcc ("VCnt")) { + CenteringTransform *ct = new CenteringTransform (); + READVECTOR (ct->mean); + vt = ct; + } else if (h == fourcc ("Viqt")) { + ITQTransform *itqt = new ITQTransform (); + + READVECTOR (itqt->mean); + READ1 (itqt->do_pca); + { + ITQMatrix *itqm = dynamic_cast + (read_VectorTransform (f)); + FAISS_THROW_IF_NOT(itqm); + itqt->itq = *itqm; + delete itqm; + } + { + LinearTransform *pi = dynamic_cast + (read_VectorTransform (f)); + FAISS_THROW_IF_NOT (pi); + itqt->pca_then_itq = *pi; + delete pi; + } + vt = itqt; + } else { + FAISS_THROW_MSG("fourcc not recognized"); + } + READ1 (vt->d_in); + READ1 (vt->d_out); + READ1 (vt->is_trained); + return vt; +} + + +static void read_ArrayInvertedLists_sizes ( + IOReader *f, std::vector & sizes) +{ + uint32_t list_type; + READ1(list_type); + if (list_type == fourcc("full")) { + size_t os = sizes.size(); + READVECTOR (sizes); + FAISS_THROW_IF_NOT (os == sizes.size()); + } else if (list_type == fourcc("sprs")) { + std::vector idsizes; + READVECTOR (idsizes); + for (size_t j = 0; j < idsizes.size(); j += 2) { + FAISS_THROW_IF_NOT (idsizes[j] < sizes.size()); + sizes[idsizes[j]] = idsizes[j + 1]; + } + } else { + FAISS_THROW_MSG ("invalid list_type"); + } +} + +InvertedLists *read_InvertedLists (IOReader *f, int io_flags) { + uint32_t h; + READ1 (h); + if (h == fourcc ("il00")) { + fprintf(stderr, "read_InvertedLists:" + " WARN! inverted lists not stored with IVF object\n"); + return nullptr; + } else if (h == fourcc ("iloa") && !(io_flags & IO_FLAG_MMAP)) { + size_t nlist; + size_t code_size; + std::vector list_length; + READ1(nlist); + READ1(code_size); + READVECTOR(list_length); + auto ails = new ReadOnlyArrayInvertedLists(nlist, code_size, list_length); + size_t n; + READ1(n); +#ifdef USE_CPU + ails->readonly_ids.resize(n); + ails->readonly_codes.resize(n*code_size); + READANDCHECK(ails->readonly_ids.data(), n); + READANDCHECK(ails->readonly_codes.data(), n * code_size); +#else + ails->pin_readonly_ids = std::make_shared(n * sizeof(InvertedLists::idx_t)); + ails->pin_readonly_codes = std::make_shared(n * code_size * sizeof(uint8_t)); + READANDCHECK((InvertedLists::idx_t *) ails->pin_readonly_ids->data, n); + READANDCHECK((uint8_t *) ails->pin_readonly_codes->data, n * code_size); +#endif + return ails; + } else if (h == fourcc ("ilar") && !(io_flags & IO_FLAG_MMAP)) { + auto ails = new ArrayInvertedLists (0, 0); + READ1 (ails->nlist); + READ1 (ails->code_size); + ails->ids.resize (ails->nlist); + ails->codes.resize (ails->nlist); + std::vector sizes (ails->nlist); + read_ArrayInvertedLists_sizes (f, sizes); + for (size_t i = 0; i < ails->nlist; i++) { + ails->ids[i].resize (sizes[i]); + ails->codes[i].resize (sizes[i] * ails->code_size); + } + for (size_t i = 0; i < ails->nlist; i++) { + size_t n = ails->ids[i].size(); + if (n > 0) { + READANDCHECK (ails->codes[i].data(), n * ails->code_size); + READANDCHECK (ails->ids[i].data(), n); + } + } + return ails; + } else if (h == fourcc ("ilar") && (io_flags & IO_FLAG_MMAP)) { + // then we load it as an OnDiskInvertedLists + + FileIOReader *reader = dynamic_cast(f); + FAISS_THROW_IF_NOT_MSG(reader, "mmap only supported for File objects"); + FILE *fdesc = reader->f; + + auto ails = new OnDiskInvertedLists (); + READ1 (ails->nlist); + READ1 (ails->code_size); + ails->read_only = true; + ails->lists.resize (ails->nlist); + std::vector sizes (ails->nlist); + read_ArrayInvertedLists_sizes (f, sizes); + size_t o0 = ftell(fdesc), o = o0; + { // do the mmap + struct stat buf; + int ret = fstat (fileno(fdesc), &buf); + FAISS_THROW_IF_NOT_FMT (ret == 0, + "fstat failed: %s", strerror(errno)); + ails->totsize = buf.st_size; + ails->ptr = (uint8_t*)mmap (nullptr, ails->totsize, + PROT_READ, MAP_SHARED, + fileno(fdesc), 0); + FAISS_THROW_IF_NOT_FMT (ails->ptr != MAP_FAILED, + "could not mmap: %s", + strerror(errno)); + } + + for (size_t i = 0; i < ails->nlist; i++) { + OnDiskInvertedLists::List & l = ails->lists[i]; + l.size = l.capacity = sizes[i]; + l.offset = o; + o += l.size * (sizeof(OnDiskInvertedLists::idx_t) + + ails->code_size); + } + FAISS_THROW_IF_NOT(o <= ails->totsize); + // resume normal reading of file + fseek (fdesc, o, SEEK_SET); + return ails; + } else if (h == fourcc ("ilod")) { + OnDiskInvertedLists *od = new OnDiskInvertedLists(); + od->read_only = io_flags & IO_FLAG_READ_ONLY; + READ1 (od->nlist); + READ1 (od->code_size); + // this is a POD object + READVECTOR (od->lists); + { + std::vector v; + READVECTOR(v); + od->slots.assign(v.begin(), v.end()); + } + { + std::vector x; + READVECTOR(x); + od->filename.assign(x.begin(), x.end()); + + if (io_flags & IO_FLAG_ONDISK_SAME_DIR) { + FileIOReader *reader = dynamic_cast(f); + FAISS_THROW_IF_NOT_MSG ( + reader, "IO_FLAG_ONDISK_SAME_DIR only supported " + "when reading from file"); + std::string indexname = reader->name; + std::string dirname = "./"; + size_t slash = indexname.find_last_of('/'); + if (slash != std::string::npos) { + dirname = indexname.substr(0, slash + 1); + } + std::string filename = od->filename; + slash = filename.find_last_of('/'); + if (slash != std::string::npos) { + filename = filename.substr(slash + 1); + } + filename = dirname + filename; + printf("IO_FLAG_ONDISK_SAME_DIR: " + "updating ondisk filename from %s to %s\n", + od->filename.c_str(), filename.c_str()); + od->filename = filename; + } + + } + READ1(od->totsize); + od->do_mmap(); + return od; + } else { + FAISS_THROW_MSG ("read_InvertedLists: unsupported invlist type"); + } +} + +static void read_InvertedLists ( + IndexIVF *ivf, IOReader *f, int io_flags) { + InvertedLists *ils = read_InvertedLists (f, io_flags); + FAISS_THROW_IF_NOT (!ils || (ils->nlist == ivf->nlist && + ils->code_size == ivf->code_size)); + ivf->invlists = ils; + ivf->own_invlists = true; +} + +InvertedLists *read_InvertedLists_nm (IOReader *f, int io_flags) { + uint32_t h; + READ1 (h); + if (h == fourcc ("il00")) { + fprintf(stderr, "read_InvertedLists:" + " WARN! inverted lists not stored with IVF object\n"); + return nullptr; + } else if (h == fourcc ("iloa") && !(io_flags & IO_FLAG_MMAP)) { + // not going to happen + return nullptr; + } else if (h == fourcc ("ilar") && !(io_flags & IO_FLAG_MMAP)) { + auto ails = new ArrayInvertedLists (0, 0); + READ1 (ails->nlist); + READ1 (ails->code_size); + ails->ids.resize (ails->nlist); + std::vector sizes (ails->nlist); + read_ArrayInvertedLists_sizes (f, sizes); + for (size_t i = 0; i < ails->nlist; i++) { + ails->ids[i].resize (sizes[i]); + } + for (size_t i = 0; i < ails->nlist; i++) { + size_t n = ails->ids[i].size(); + if (n > 0) { + READANDCHECK (ails->ids[i].data(), n); + } + } + return ails; + } else if (h == fourcc ("ilar") && (io_flags & IO_FLAG_MMAP)) { + // then we load it as an OnDiskInvertedLists + FileIOReader *reader = dynamic_cast(f); + FAISS_THROW_IF_NOT_MSG(reader, "mmap only supported for File objects"); + FILE *fdesc = reader->f; + + auto ails = new OnDiskInvertedLists (); + READ1 (ails->nlist); + READ1 (ails->code_size); + ails->read_only = true; + ails->lists.resize (ails->nlist); + std::vector sizes (ails->nlist); + read_ArrayInvertedLists_sizes (f, sizes); + size_t o0 = ftell(fdesc), o = o0; + { // do the mmap + struct stat buf; + int ret = fstat (fileno(fdesc), &buf); + FAISS_THROW_IF_NOT_FMT (ret == 0, + "fstat failed: %s", strerror(errno)); + ails->totsize = buf.st_size; + ails->ptr = (uint8_t*)mmap (nullptr, ails->totsize, + PROT_READ, MAP_SHARED, + fileno(fdesc), 0); + FAISS_THROW_IF_NOT_FMT (ails->ptr != MAP_FAILED, + "could not mmap: %s", + strerror(errno)); + } + + for (size_t i = 0; i < ails->nlist; i++) { + OnDiskInvertedLists::List & l = ails->lists[i]; + l.size = l.capacity = sizes[i]; + l.offset = o; + o += l.size * (sizeof(OnDiskInvertedLists::idx_t) + + ails->code_size); + } + FAISS_THROW_IF_NOT(o <= ails->totsize); + // resume normal reading of file + fseek (fdesc, o, SEEK_SET); + return ails; + } else if (h == fourcc ("ilod")) { + // not going to happen + return nullptr; + } else { + FAISS_THROW_MSG ("read_InvertedLists: unsupported invlist type"); + } +} + +static void read_InvertedLists_nm ( + IndexIVF *ivf, IOReader *f, int io_flags) { + InvertedLists *ils = read_InvertedLists_nm (f, io_flags); + FAISS_THROW_IF_NOT (!ils || (ils->nlist == ivf->nlist && + ils->code_size == ivf->code_size)); + ivf->invlists = ils; + ivf->own_invlists = true; +} + +static void read_ProductQuantizer (ProductQuantizer *pq, IOReader *f) { + READ1 (pq->d); + READ1 (pq->M); + READ1 (pq->nbits); + pq->set_derived_values (); + READVECTOR (pq->centroids); +} + +static void read_ScalarQuantizer (ScalarQuantizer *ivsc, IOReader *f) { + READ1 (ivsc->qtype); + READ1 (ivsc->rangestat); + READ1 (ivsc->rangestat_arg); + READ1 (ivsc->d); + READ1 (ivsc->code_size); + READVECTOR (ivsc->trained); +} + + +static void read_HNSW (HNSW *hnsw, IOReader *f) { + READVECTOR (hnsw->assign_probas); + READVECTOR (hnsw->cum_nneighbor_per_level); + READVECTOR (hnsw->levels); + READVECTOR (hnsw->offsets); + READVECTOR (hnsw->neighbors); + + READ1 (hnsw->entry_point); + READ1 (hnsw->max_level); + READ1 (hnsw->efConstruction); + READ1 (hnsw->efSearch); + READ1 (hnsw->upper_beam); +} + +static void read_RHNSW (RHNSW *rhnsw, IOReader *f) { + READ1 (rhnsw->entry_point); + READ1 (rhnsw->max_level); + READ1 (rhnsw->M); + READ1 (rhnsw->level0_link_size); + READ1 (rhnsw->link_size); + READ1 (rhnsw->level_constant); + READ1 (rhnsw->efConstruction); + READ1 (rhnsw->efSearch); + + READVECTOR (rhnsw->levels); + auto ntotal = rhnsw->levels.size(); + rhnsw->level0_links = (char*) malloc(ntotal * rhnsw->level0_link_size); + READANDCHECK( rhnsw->level0_links, ntotal * rhnsw->level0_link_size); + rhnsw->linkLists = (char**) malloc(ntotal * sizeof(void*)); + for (auto i = 0; i < ntotal; ++ i) { + if (rhnsw->levels[i]) { + rhnsw->linkLists[i] = (char*)malloc(rhnsw->link_size * rhnsw->levels[i] + 1); + READANDCHECK( rhnsw->linkLists[i], rhnsw->link_size * rhnsw->levels[i] + 1); + } + } +} + +ProductQuantizer * read_ProductQuantizer (const char*fname) { + FileIOReader reader(fname); + return read_ProductQuantizer(&reader); +} + +ProductQuantizer * read_ProductQuantizer (IOReader *reader) { + ProductQuantizer *pq = new ProductQuantizer(); + ScopeDeleter1 del (pq); + + read_ProductQuantizer(pq, reader); + del.release (); + return pq; +} + +static void read_direct_map (DirectMap *dm, IOReader *f) { + char maintain_direct_map; + READ1 (maintain_direct_map); + dm->type = (DirectMap::Type)maintain_direct_map; + READVECTOR (dm->array); + if (dm->type == DirectMap::Hashtable) { + using idx_t = Index::idx_t; + std::vector> v; + READVECTOR (v); + std::unordered_map & map = dm->hashtable; + map.reserve (v.size()); + for (auto it: v) { + map [it.first] = it.second; + } + } + +} + + +static void read_ivf_header ( + IndexIVF *ivf, IOReader *f, + std::vector > *ids = nullptr) +{ + read_index_header (ivf, f); + READ1 (ivf->nlist); + READ1 (ivf->nprobe); + ivf->quantizer = read_index (f); + ivf->own_fields = true; + if (ids) { // used in legacy "Iv" formats + ids->resize (ivf->nlist); + for (size_t i = 0; i < ivf->nlist; i++) + READVECTOR ((*ids)[i]); + } + read_direct_map (&ivf->direct_map, f); +} + +// used for legacy formats +static ArrayInvertedLists *set_array_invlist( + IndexIVF *ivf, std::vector > &ids) +{ + ArrayInvertedLists *ail = new ArrayInvertedLists ( + ivf->nlist, ivf->code_size); + std::swap (ail->ids, ids); + ivf->invlists = ail; + ivf->own_invlists = true; + return ail; +} + +static IndexIVFPQ *read_ivfpq (IOReader *f, uint32_t h, int io_flags) +{ + bool legacy = h == fourcc ("IvQR") || h == fourcc ("IvPQ"); + + IndexIVFPQR *ivfpqr = + h == fourcc ("IvQR") || h == fourcc ("IwQR") ? + new IndexIVFPQR () : nullptr; + IndexIVFPQ * ivpq = ivfpqr ? ivfpqr : new IndexIVFPQ (); + + std::vector > ids; + read_ivf_header (ivpq, f, legacy ? &ids : nullptr); + READ1 (ivpq->by_residual); + READ1 (ivpq->code_size); + read_ProductQuantizer (&ivpq->pq, f); + + if (legacy) { + ArrayInvertedLists *ail = set_array_invlist (ivpq, ids); + for (size_t i = 0; i < ail->nlist; i++) + READVECTOR (ail->codes[i]); + } else { + read_InvertedLists (ivpq, f, io_flags); + } + + if (ivpq->is_trained) { + // precomputed table not stored. It is cheaper to recompute it + ivpq->use_precomputed_table = 0; + if (ivpq->by_residual) + ivpq->precompute_table (); + if (ivfpqr) { + read_ProductQuantizer (&ivfpqr->refine_pq, f); + READVECTOR (ivfpqr->refine_codes); + READ1 (ivfpqr->k_factor); + } + } + return ivpq; +} + +int read_old_fmt_hack = 0; + +Index *read_index (IOReader *f, int io_flags) { + Index * idx = nullptr; + uint32_t h; + READ1 (h); + if (h == fourcc ("IxFI") || h == fourcc ("IxF2") || h == fourcc("IxFl")) { + IndexFlat *idxf; + if (h == fourcc ("IxFI")) { + idxf = new IndexFlatIP (); + } else if (h == fourcc("IxF2")) { + idxf = new IndexFlatL2 (); + } else { + idxf = new IndexFlat (); + } + read_index_header (idxf, f); + READVECTOR (idxf->xb); + FAISS_THROW_IF_NOT (idxf->xb.size() == idxf->ntotal * idxf->d); + // leak! + idx = idxf; + } else if (h == fourcc("IxHE") || h == fourcc("IxHe")) { + IndexLSH * idxl = new IndexLSH (); + read_index_header (idxl, f); + READ1 (idxl->nbits); + READ1 (idxl->rotate_data); + READ1 (idxl->train_thresholds); + READVECTOR (idxl->thresholds); + READ1 (idxl->bytes_per_vec); + if (h == fourcc("IxHE")) { + FAISS_THROW_IF_NOT_FMT (idxl->nbits % 64 == 0, + "can only read old format IndexLSH with " + "nbits multiple of 64 (got %d)", + (int) idxl->nbits); + // leak + idxl->bytes_per_vec *= 8; + } + { + RandomRotationMatrix *rrot = dynamic_cast + (read_VectorTransform (f)); + FAISS_THROW_IF_NOT_MSG(rrot, "expected a random rotation"); + idxl->rrot = *rrot; + delete rrot; + } + READVECTOR (idxl->codes); + FAISS_THROW_IF_NOT (idxl->rrot.d_in == idxl->d && + idxl->rrot.d_out == idxl->nbits); + FAISS_THROW_IF_NOT ( + idxl->codes.size() == idxl->ntotal * idxl->bytes_per_vec); + idx = idxl; + } else if (h == fourcc ("IxPQ") || h == fourcc ("IxPo") || + h == fourcc ("IxPq")) { + // IxPQ and IxPo were merged into the same IndexPQ object + IndexPQ * idxp =new IndexPQ (); + read_index_header (idxp, f); + read_ProductQuantizer (&idxp->pq, f); + READVECTOR (idxp->codes); + if (h == fourcc ("IxPo") || h == fourcc ("IxPq")) { + READ1 (idxp->search_type); + READ1 (idxp->encode_signs); + READ1 (idxp->polysemous_ht); + } + // Old versoins of PQ all had metric_type set to INNER_PRODUCT + // when they were in fact using L2. Therefore, we force metric type + // to L2 when the old format is detected + if (h == fourcc ("IxPQ") || h == fourcc ("IxPo")) { + idxp->metric_type = METRIC_L2; + } + if (h == fourcc("IxPq")) { + idxp->pq.compute_sdc_table (); + } + idx = idxp; + } else if (h == fourcc ("IvFl") || h == fourcc("IvFL")) { // legacy + IndexIVFFlat * ivfl = new IndexIVFFlat (); + std::vector > ids; + read_ivf_header (ivfl, f, &ids); + ivfl->code_size = ivfl->d * sizeof(float); + ArrayInvertedLists *ail = set_array_invlist (ivfl, ids); + + if (h == fourcc ("IvFL")) { + for (size_t i = 0; i < ivfl->nlist; i++) { + READVECTOR (ail->codes[i]); + } + } else { // old format + for (size_t i = 0; i < ivfl->nlist; i++) { + std::vector vec; + READVECTOR (vec); + ail->codes[i].resize(vec.size() * sizeof(float)); + memcpy(ail->codes[i].data(), vec.data(), + ail->codes[i].size()); + } + } + idx = ivfl; + } else if (h == fourcc ("IwFd")) { + IndexIVFFlatDedup * ivfl = new IndexIVFFlatDedup (); + read_ivf_header (ivfl, f); + ivfl->code_size = ivfl->d * sizeof(float); + { + std::vector tab; + READVECTOR (tab); + for (long i = 0; i < tab.size(); i += 2) { + std::pair + pair (tab[i], tab[i + 1]); + ivfl->instances.insert (pair); + } + } + read_InvertedLists (ivfl, f, io_flags); + idx = ivfl; + } else if (h == fourcc ("IwFl")) { + IndexIVFFlat * ivfl = new IndexIVFFlat (); + read_ivf_header (ivfl, f); + ivfl->code_size = ivfl->d * sizeof(float); + read_InvertedLists (ivfl, f, io_flags); + idx = ivfl; + } else if (h == fourcc ("IxSQ")) { + IndexScalarQuantizer * idxs = new IndexScalarQuantizer (); + read_index_header (idxs, f); + read_ScalarQuantizer (&idxs->sq, f); + READVECTOR (idxs->codes); + idxs->code_size = idxs->sq.code_size; + idx = idxs; + } else if (h == fourcc ("IxLa")) { + int d, nsq, scale_nbit, r2; + READ1 (d); + READ1 (nsq); + READ1 (scale_nbit); + READ1 (r2); + IndexLattice *idxl = new IndexLattice (d, nsq, scale_nbit, r2); + read_index_header (idxl, f); + READVECTOR (idxl->trained); + idx = idxl; + } else if(h == fourcc ("IvSQ")) { // legacy + IndexIVFScalarQuantizer * ivsc = new IndexIVFScalarQuantizer(); + std::vector > ids; + read_ivf_header (ivsc, f, &ids); + read_ScalarQuantizer (&ivsc->sq, f); + READ1 (ivsc->code_size); + ArrayInvertedLists *ail = set_array_invlist (ivsc, ids); + for(int i = 0; i < ivsc->nlist; i++) + READVECTOR (ail->codes[i]); + idx = ivsc; + } else if(h == fourcc ("IwSQ") || h == fourcc ("IwSq")) { + IndexIVFScalarQuantizer * ivsc = new IndexIVFScalarQuantizer(); + read_ivf_header (ivsc, f); + read_ScalarQuantizer (&ivsc->sq, f); + READ1 (ivsc->code_size); + if (h == fourcc ("IwSQ")) { + ivsc->by_residual = true; + } else { + READ1 (ivsc->by_residual); + } + read_InvertedLists (ivsc, f, io_flags); + idx = ivsc; + } else if (h == fourcc("ISqH")) { + IndexIVFSQHybrid *ivfsqhbyrid = new IndexIVFSQHybrid(); + read_ivf_header(ivfsqhbyrid, f); + read_ScalarQuantizer(&ivfsqhbyrid->sq, f); + READ1 (ivfsqhbyrid->code_size); + READ1 (ivfsqhbyrid->by_residual); + read_InvertedLists(ivfsqhbyrid, f, io_flags); + idx = ivfsqhbyrid; + } else if(h == fourcc ("IwSh")) { + IndexIVFSpectralHash *ivsp = new IndexIVFSpectralHash (); + read_ivf_header (ivsp, f); + ivsp->vt = read_VectorTransform (f); + ivsp->own_fields = true; + READ1 (ivsp->nbit); + // not stored by write_ivf_header + ivsp->code_size = (ivsp->nbit + 7) / 8; + READ1 (ivsp->period); + READ1 (ivsp->threshold_type); + READVECTOR (ivsp->trained); + read_InvertedLists (ivsp, f, io_flags); + idx = ivsp; + } else if(h == fourcc ("IvPQ") || h == fourcc ("IvQR") || + h == fourcc ("IwPQ") || h == fourcc ("IwQR")) { + + idx = read_ivfpq (f, h, io_flags); + + } else if(h == fourcc ("IxPT")) { + IndexPreTransform * ixpt = new IndexPreTransform(); + ixpt->own_fields = true; + read_index_header (ixpt, f); + int nt; + if (read_old_fmt_hack == 2) { + nt = 1; + } else { + READ1 (nt); + } + for (int i = 0; i < nt; i++) { + ixpt->chain.push_back (read_VectorTransform (f)); + } + ixpt->index = read_index (f, io_flags); + idx = ixpt; + } else if(h == fourcc ("Imiq")) { + MultiIndexQuantizer * imiq = new MultiIndexQuantizer (); + read_index_header (imiq, f); + read_ProductQuantizer (&imiq->pq, f); + idx = imiq; + } else if(h == fourcc ("IxRF")) { + IndexRefineFlat *idxrf = new IndexRefineFlat (); + read_index_header (idxrf, f); + idxrf->base_index = read_index(f, io_flags); + idxrf->own_fields = true; + IndexFlat *rf = dynamic_cast (read_index (f, io_flags)); + std::swap (*rf, idxrf->refine_index); + delete rf; + READ1 (idxrf->k_factor); + idx = idxrf; + } else if(h == fourcc ("IxMp") || h == fourcc ("IxM2")) { + bool is_map2 = h == fourcc ("IxM2"); + IndexIDMap * idxmap = is_map2 ? new IndexIDMap2 () : new IndexIDMap (); + read_index_header (idxmap, f); + idxmap->index = read_index (f, io_flags); + idxmap->own_fields = true; + READVECTOR (idxmap->id_map); + if (is_map2) { + static_cast(idxmap)->construct_rev_map (); + } + idx = idxmap; + } else if (h == fourcc ("Ix2L")) { + Index2Layer * idxp = new Index2Layer (); + read_index_header (idxp, f); + idxp->q1.quantizer = read_index (f, io_flags); + READ1 (idxp->q1.nlist); + READ1 (idxp->q1.quantizer_trains_alone); + read_ProductQuantizer (&idxp->pq, f); + READ1 (idxp->code_size_1); + READ1 (idxp->code_size_2); + READ1 (idxp->code_size); + READVECTOR (idxp->codes); + idx = idxp; + } else if(h == fourcc("IHNf") || h == fourcc("IHNp") || + h == fourcc("IHNs") || h == fourcc("IHN2")) { + IndexHNSW *idxhnsw = nullptr; + if (h == fourcc("IHNf")) idxhnsw = new IndexHNSWFlat (); + if (h == fourcc("IHNp")) idxhnsw = new IndexHNSWPQ (); + if (h == fourcc("IHNs")) idxhnsw = new IndexHNSWSQ (); + if (h == fourcc("IHN2")) idxhnsw = new IndexHNSW2Level (); + read_index_header (idxhnsw, f); + read_HNSW (&idxhnsw->hnsw, f); + idxhnsw->storage = read_index (f, io_flags); + idxhnsw->own_fields = true; + if (h == fourcc("IHNp")) { + dynamic_cast(idxhnsw->storage)->pq.compute_sdc_table (); + } + idx = idxhnsw; + } else if(h == fourcc("IRHf") || h == fourcc("IRHp") || + h == fourcc("IRHs") || h == fourcc("IRH2")) { + IndexRHNSW *idxrhnsw = nullptr; + if (h == fourcc("IRHf")) idxrhnsw = new IndexRHNSWFlat (); + if (h == fourcc("IRHp")) idxrhnsw = new IndexRHNSWPQ (); + if (h == fourcc("IRHs")) idxrhnsw = new IndexRHNSWSQ (); + if (h == fourcc("IRH2")) idxrhnsw = new IndexRHNSW2Level (); + read_index_header (idxrhnsw, f); + read_RHNSW (&idxrhnsw->hnsw, f); + idxrhnsw->own_fields = true; + idx = idxrhnsw; + } else { + FAISS_THROW_FMT("Index type 0x%08x not supported\n", h); + idx = nullptr; + } + return idx; +} + + +Index *read_index (FILE * f, int io_flags) { + FileIOReader reader(f); + return read_index(&reader, io_flags); +} + +Index *read_index (const char *fname, int io_flags) { + FileIOReader reader(fname); + Index *idx = read_index (&reader, io_flags); + return idx; +} + +// read offset-only index +Index *read_index_nm (IOReader *f, int io_flags) { + Index * idx = nullptr; + uint32_t h; + READ1 (h); + if (h == fourcc ("IwFl")) { + IndexIVFFlat * ivfl = new IndexIVFFlat (); + read_ivf_header (ivfl, f); + ivfl->code_size = ivfl->d * sizeof(float); + read_InvertedLists_nm (ivfl, f, io_flags); + idx = ivfl; + } else if(h == fourcc ("IwSq")) { + IndexIVFScalarQuantizer * ivsc = new IndexIVFScalarQuantizer(); + read_ivf_header (ivsc, f); + read_ScalarQuantizer (&ivsc->sq, f); + READ1 (ivsc->code_size); + READ1 (ivsc->by_residual); + read_InvertedLists_nm (ivsc, f, io_flags); + idx = ivsc; + } else if (h == fourcc("ISqH")) { + IndexIVFSQHybrid *ivfsqhbyrid = new IndexIVFSQHybrid(); + read_ivf_header(ivfsqhbyrid, f); + read_ScalarQuantizer(&ivfsqhbyrid->sq, f); + READ1 (ivfsqhbyrid->code_size); + READ1 (ivfsqhbyrid->by_residual); + read_InvertedLists_nm(ivfsqhbyrid, f, io_flags); + idx = ivfsqhbyrid; + } else { + FAISS_THROW_FMT("Index type 0x%08x not supported\n", h); + idx = nullptr; + } + return idx; +} + + +Index *read_index_nm (FILE * f, int io_flags) { + FileIOReader reader(f); + return read_index_nm(&reader, io_flags); +} + +Index *read_index_nm (const char *fname, int io_flags) { + FileIOReader reader(fname); + Index *idx = read_index_nm (&reader, io_flags); + return idx; +} + +VectorTransform *read_VectorTransform (const char *fname) { + FileIOReader reader(fname); + VectorTransform *vt = read_VectorTransform (&reader); + return vt; +} + + + +/************************************************************* + * Read binary indexes + **************************************************************/ + +static void read_InvertedLists ( + IndexBinaryIVF *ivf, IOReader *f, int io_flags) { + InvertedLists *ils = read_InvertedLists (f, io_flags); + FAISS_THROW_IF_NOT (!ils || (ils->nlist == ivf->nlist && + ils->code_size == ivf->code_size)); + ivf->invlists = ils; + ivf->own_invlists = true; +} + + + +static void read_index_binary_header (IndexBinary *idx, IOReader *f) { + READ1 (idx->d); + READ1 (idx->code_size); + READ1 (idx->ntotal); + READ1 (idx->is_trained); + READ1 (idx->metric_type); + idx->verbose = false; +} + +static void read_binary_ivf_header ( + IndexBinaryIVF *ivf, IOReader *f, + std::vector > *ids = nullptr) +{ + read_index_binary_header (ivf, f); + READ1 (ivf->nlist); + READ1 (ivf->nprobe); + ivf->quantizer = read_index_binary (f); + ivf->own_fields = true; + if (ids) { // used in legacy "Iv" formats + ids->resize (ivf->nlist); + for (size_t i = 0; i < ivf->nlist; i++) + READVECTOR ((*ids)[i]); + } + read_direct_map (&ivf->direct_map, f); +} + +static void read_binary_hash_invlists ( + IndexBinaryHash::InvertedListMap &invlists, + int b, IOReader *f) +{ + size_t sz; + READ1 (sz); + int il_nbit = 0; + READ1 (il_nbit); + // buffer for bitstrings + std::vector buf((b + il_nbit) * sz); + READVECTOR (buf); + BitstringReader rd (buf.data(), buf.size()); + invlists.reserve (sz); + for (size_t i = 0; i < sz; i++) { + uint64_t hash = rd.read(b); + uint64_t ilsz = rd.read(il_nbit); + auto & il = invlists[hash]; + READVECTOR (il.ids); + FAISS_THROW_IF_NOT (il.ids.size() == ilsz); + READVECTOR (il.vecs); + } +} + +static void read_binary_multi_hash_map( + IndexBinaryMultiHash::Map &map, + int b, size_t ntotal, + IOReader *f) +{ + int id_bits; + size_t sz; + READ1 (id_bits); + READ1 (sz); + std::vector buf; + READVECTOR (buf); + size_t nbit = (b + id_bits) * sz + ntotal * id_bits; + FAISS_THROW_IF_NOT (buf.size() == (nbit + 7) / 8); + BitstringReader rd (buf.data(), buf.size()); + map.reserve (sz); + for (size_t i = 0; i < sz; i++) { + uint64_t hash = rd.read(b); + uint64_t ilsz = rd.read(id_bits); + auto & il = map[hash]; + for (size_t j = 0; j < ilsz; j++) { + il.push_back (rd.read (id_bits)); + } + } +} + + + +IndexBinary *read_index_binary (IOReader *f, int io_flags) { + IndexBinary * idx = nullptr; + uint32_t h; + READ1 (h); + if (h == fourcc ("IBxF")) { + IndexBinaryFlat *idxf = new IndexBinaryFlat (); + read_index_binary_header (idxf, f); + READVECTOR (idxf->xb); + FAISS_THROW_IF_NOT (idxf->xb.size() == idxf->ntotal * idxf->code_size); + // leak! + idx = idxf; + } else if (h == fourcc ("IBwF")) { + IndexBinaryIVF *ivf = new IndexBinaryIVF (); + read_binary_ivf_header (ivf, f); + read_InvertedLists (ivf, f, io_flags); + idx = ivf; + } else if (h == fourcc ("IBFf")) { + IndexBinaryFromFloat *idxff = new IndexBinaryFromFloat (); + read_index_binary_header (idxff, f); + idxff->own_fields = true; + idxff->index = read_index (f, io_flags); + idx = idxff; + } else if (h == fourcc ("IBHf")) { + IndexBinaryHNSW *idxhnsw = new IndexBinaryHNSW (); + read_index_binary_header (idxhnsw, f); + read_HNSW (&idxhnsw->hnsw, f); + idxhnsw->storage = read_index_binary (f, io_flags); + idxhnsw->own_fields = true; + idx = idxhnsw; + } else if(h == fourcc ("IBMp") || h == fourcc ("IBM2")) { + bool is_map2 = h == fourcc ("IBM2"); + IndexBinaryIDMap * idxmap = is_map2 ? + new IndexBinaryIDMap2 () : new IndexBinaryIDMap (); + read_index_binary_header (idxmap, f); + idxmap->index = read_index_binary (f, io_flags); + idxmap->own_fields = true; + READVECTOR (idxmap->id_map); + if (is_map2) { + static_cast(idxmap)->construct_rev_map (); + } + idx = idxmap; + } else if(h == fourcc("IBHh")) { + IndexBinaryHash *idxh = new IndexBinaryHash (); + read_index_binary_header (idxh, f); + READ1 (idxh->b); + READ1 (idxh->nflip); + read_binary_hash_invlists(idxh->invlists, idxh->b, f); + idx = idxh; + } else if(h == fourcc("IBHm")) { + IndexBinaryMultiHash* idxmh = new IndexBinaryMultiHash (); + read_index_binary_header (idxmh, f); + idxmh->storage = dynamic_cast (read_index_binary (f)); + FAISS_THROW_IF_NOT(idxmh->storage && idxmh->storage->ntotal == idxmh->ntotal); + idxmh->own_fields = true; + READ1 (idxmh->b); + READ1 (idxmh->nhash); + READ1 (idxmh->nflip); + idxmh->maps.resize (idxmh->nhash); + for (int i = 0; i < idxmh->nhash; i++) { + read_binary_multi_hash_map( + idxmh->maps[i], idxmh->b, idxmh->ntotal, f); + } + idx = idxmh; + } else { + FAISS_THROW_FMT("Index type 0x%08x not supported\n", h); + idx = nullptr; + } + return idx; +} + +IndexBinary *read_index_binary (FILE * f, int io_flags) { + FileIOReader reader(f); + return read_index_binary(&reader, io_flags); +} + +IndexBinary *read_index_binary (const char *fname, int io_flags) { + FileIOReader reader(fname); + IndexBinary *idx = read_index_binary (&reader, io_flags); + return idx; +} + + +} // namespace faiss \ No newline at end of file diff --git a/core/src/index/thirdparty/faiss/impl/index_write.cpp b/core/src/index/thirdparty/faiss/impl/index_write.cpp new file mode 100644 index 0000000000..9bbaa4e8bc --- /dev/null +++ b/core/src/index/thirdparty/faiss/impl/index_write.cpp @@ -0,0 +1,811 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +#include + +#include +#include + +#include +#include +#include +#include + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + + + +/************************************************************* + * The I/O format is the content of the class. For objects that are + * inherited, like Index, a 4-character-code (fourcc) indicates which + * child class this is an instance of. + * + * In this case, the fields of the parent class are written first, + * then the ones for the child classes. Note that this requires + * classes to be serialized to have a constructor without parameters, + * so that the fields can be filled in later. The default constructor + * should set reasonable defaults for all fields. + * + * The fourccs are assigned arbitrarily. When the class changed (added + * or deprecated fields), the fourcc can be replaced. New code should + * be able to read the old fourcc and fill in new classes. + * + * TODO: serialization to strings for use in Python pickle or Torch + * serialization. + * + * TODO: in this file, the read functions that encouter errors may + * leak memory. + **************************************************************/ + + + +namespace faiss { + + +/************************************************************* + * I/O macros + * + * we use macros so that we have a line number to report in abort + * (). This makes debugging a lot easier. The IOReader or IOWriter is + * always called f and thus is not passed in as a macro parameter. + **************************************************************/ + + +#define WRITEANDCHECK(ptr, n) { \ + size_t ret = (*f)(ptr, sizeof(*(ptr)), n); \ + FAISS_THROW_IF_NOT_FMT(ret == (n), \ + "write error in %s: %ld != %ld (%s)", \ + f->name.c_str(), ret, size_t(n), strerror(errno)); \ + } + +#define WRITE1(x) WRITEANDCHECK(&(x), 1) + +#define WRITEVECTOR(vec) { \ + size_t size = (vec).size (); \ + WRITEANDCHECK (&size, 1); \ + WRITEANDCHECK ((vec).data (), size); \ + } + + + +/************************************************************* + * Write + **************************************************************/ +static void write_index_header (const Index *idx, IOWriter *f) { + WRITE1 (idx->d); + WRITE1 (idx->ntotal); + Index::idx_t dummy = 1 << 20; + WRITE1 (dummy); + WRITE1 (dummy); + WRITE1 (idx->is_trained); + WRITE1 (idx->metric_type); + if (idx->metric_type > 1) { + WRITE1 (idx->metric_arg); + } +} + +void write_VectorTransform (const VectorTransform *vt, IOWriter *f) { + if (const LinearTransform * lt = + dynamic_cast < const LinearTransform *> (vt)) { + if (dynamic_cast(lt)) { + uint32_t h = fourcc ("rrot"); + WRITE1 (h); + } else if (const PCAMatrix * pca = + dynamic_cast(lt)) { + uint32_t h = fourcc ("PcAm"); + WRITE1 (h); + WRITE1 (pca->eigen_power); + WRITE1 (pca->random_rotation); + WRITE1 (pca->balanced_bins); + WRITEVECTOR (pca->mean); + WRITEVECTOR (pca->eigenvalues); + WRITEVECTOR (pca->PCAMat); + } else if (const ITQMatrix * itqm = + dynamic_cast(lt)) { + uint32_t h = fourcc ("Viqm"); + WRITE1 (h); + WRITE1 (itqm->max_iter); + WRITE1 (itqm->seed); + } else { + // generic LinearTransform (includes OPQ) + uint32_t h = fourcc ("LTra"); + WRITE1 (h); + } + WRITE1 (lt->have_bias); + WRITEVECTOR (lt->A); + WRITEVECTOR (lt->b); + } else if (const RemapDimensionsTransform *rdt = + dynamic_cast(vt)) { + uint32_t h = fourcc ("RmDT"); + WRITE1 (h); + WRITEVECTOR (rdt->map); + } else if (const NormalizationTransform *nt = + dynamic_cast(vt)) { + uint32_t h = fourcc ("VNrm"); + WRITE1 (h); + WRITE1 (nt->norm); + } else if (const CenteringTransform *ct = + dynamic_cast(vt)) { + uint32_t h = fourcc ("VCnt"); + WRITE1 (h); + WRITEVECTOR (ct->mean); + } else if (const ITQTransform *itqt = + dynamic_cast (vt)) { + uint32_t h = fourcc ("Viqt"); + WRITE1 (h); + WRITEVECTOR (itqt->mean); + WRITE1 (itqt->do_pca); + write_VectorTransform (&itqt->itq, f); + write_VectorTransform (&itqt->pca_then_itq, f); + } else { + FAISS_THROW_MSG ("cannot serialize this"); + } + // common fields + WRITE1 (vt->d_in); + WRITE1 (vt->d_out); + WRITE1 (vt->is_trained); +} + +void write_ProductQuantizer (const ProductQuantizer *pq, IOWriter *f) { + WRITE1 (pq->d); + WRITE1 (pq->M); + WRITE1 (pq->nbits); + WRITEVECTOR (pq->centroids); +} + +static void write_ScalarQuantizer ( + const ScalarQuantizer *ivsc, IOWriter *f) { + WRITE1 (ivsc->qtype); + WRITE1 (ivsc->rangestat); + WRITE1 (ivsc->rangestat_arg); + WRITE1 (ivsc->d); + WRITE1 (ivsc->code_size); + WRITEVECTOR (ivsc->trained); +} + +void write_InvertedLists (const InvertedLists *ils, IOWriter *f) { + if (ils == nullptr) { + uint32_t h = fourcc ("il00"); + WRITE1 (h); + } else if (const auto & ails = + dynamic_cast(ils)) { + uint32_t h = fourcc ("ilar"); + WRITE1 (h); + WRITE1 (ails->nlist); + WRITE1 (ails->code_size); + // here we store either as a full or a sparse data buffer + size_t n_non0 = 0; + for (size_t i = 0; i < ails->nlist; i++) { + if (ails->ids[i].size() > 0) + n_non0++; + } + if (n_non0 > ails->nlist / 2) { + uint32_t list_type = fourcc("full"); + WRITE1 (list_type); + std::vector sizes; + for (size_t i = 0; i < ails->nlist; i++) { + sizes.push_back (ails->ids[i].size()); + } + WRITEVECTOR (sizes); + } else { + int list_type = fourcc("sprs"); // sparse + WRITE1 (list_type); + std::vector sizes; + for (size_t i = 0; i < ails->nlist; i++) { + size_t n = ails->ids[i].size(); + if (n > 0) { + sizes.push_back (i); + sizes.push_back (n); + } + } + WRITEVECTOR (sizes); + } + // make a single contiguous data buffer (useful for mmapping) + for (size_t i = 0; i < ails->nlist; i++) { + size_t n = ails->ids[i].size(); + if (n > 0) { + WRITEANDCHECK (ails->codes[i].data(), n * ails->code_size); + WRITEANDCHECK (ails->ids[i].data(), n); + } + } + } else if (const auto & oa = + dynamic_cast(ils)) { + uint32_t h = fourcc("iloa"); + WRITE1 (h); + WRITE1 (oa->nlist); + WRITE1 (oa->code_size); + WRITEVECTOR(oa->readonly_length); +#ifdef USE_CPU + size_t n = oa->readonly_ids.size(); + WRITE1(n); + WRITEANDCHECK(oa->readonly_ids.data(), n); + WRITEANDCHECK(oa->readonly_codes.data(), n * oa->code_size); +#else + size_t n = oa->pin_readonly_ids->size() / sizeof(InvertedLists::idx_t); + WRITE1(n); + WRITEANDCHECK((InvertedLists::idx_t *) oa->pin_readonly_ids->data, n); + WRITEANDCHECK((uint8_t *) oa->pin_readonly_codes->data, n * oa->code_size); +#endif + } else if (const auto & od = + dynamic_cast(ils)) { + uint32_t h = fourcc ("ilod"); + WRITE1 (h); + WRITE1 (ils->nlist); + WRITE1 (ils->code_size); + // this is a POD object + WRITEVECTOR (od->lists); + + { + std::vector v( + od->slots.begin(), od->slots.end()); + WRITEVECTOR(v); + } + { + std::vector x(od->filename.begin(), od->filename.end()); + WRITEVECTOR(x); + } + WRITE1(od->totsize); + + } else { + fprintf(stderr, "WARN! write_InvertedLists: unsupported invlist type, " + "saving null invlist\n"); + uint32_t h = fourcc ("il00"); + WRITE1 (h); + } +} + +// write inverted lists for offset-only index +void write_InvertedLists_nm (const InvertedLists *ils, IOWriter *f) { + if (ils == nullptr) { + uint32_t h = fourcc ("il00"); + WRITE1 (h); + } else if (const auto & ails = + dynamic_cast(ils)) { + uint32_t h = fourcc ("ilar"); + WRITE1 (h); + WRITE1 (ails->nlist); + WRITE1 (ails->code_size); + // here we store either as a full or a sparse data buffer + size_t n_non0 = 0; + for (size_t i = 0; i < ails->nlist; i++) { + if (ails->ids[i].size() > 0) + n_non0++; + } + if (n_non0 > ails->nlist / 2) { + uint32_t list_type = fourcc("full"); + WRITE1 (list_type); + std::vector sizes; + for (size_t i = 0; i < ails->nlist; i++) { + sizes.push_back (ails->ids[i].size()); + } + WRITEVECTOR (sizes); + } else { + int list_type = fourcc("sprs"); // sparse + WRITE1 (list_type); + std::vector sizes; + for (size_t i = 0; i < ails->nlist; i++) { + size_t n = ails->ids[i].size(); + if (n > 0) { + sizes.push_back (i); + sizes.push_back (n); + } + } + WRITEVECTOR (sizes); + } + // make a single contiguous data buffer (useful for mmapping) + for (size_t i = 0; i < ails->nlist; i++) { + size_t n = ails->ids[i].size(); + if (n > 0) { + // WRITEANDCHECK (ails->codes[i].data(), n * ails->code_size); + WRITEANDCHECK (ails->ids[i].data(), n); + } + } + } else if (const auto & oa = + dynamic_cast(ils)) { + // not going to happen + } else { + fprintf(stderr, "WARN! write_InvertedLists: unsupported invlist type, " + "saving null invlist\n"); + uint32_t h = fourcc ("il00"); + WRITE1 (h); + } +} + + +void write_ProductQuantizer (const ProductQuantizer*pq, const char *fname) { + FileIOWriter writer(fname); + write_ProductQuantizer (pq, &writer); +} + +static void write_HNSW (const HNSW *hnsw, IOWriter *f) { + + WRITEVECTOR (hnsw->assign_probas); + WRITEVECTOR (hnsw->cum_nneighbor_per_level); + WRITEVECTOR (hnsw->levels); + WRITEVECTOR (hnsw->offsets); + WRITEVECTOR (hnsw->neighbors); + + WRITE1 (hnsw->entry_point); + WRITE1 (hnsw->max_level); + WRITE1 (hnsw->efConstruction); + WRITE1 (hnsw->efSearch); + WRITE1 (hnsw->upper_beam); +} + +static void write_RHNSW (const RHNSW *rhnsw, IOWriter *f) { + WRITE1 (rhnsw->entry_point); + WRITE1 (rhnsw->max_level); + WRITE1 (rhnsw->M); + WRITE1 (rhnsw->level0_link_size); + WRITE1 (rhnsw->link_size); + WRITE1 (rhnsw->level_constant); + WRITE1 (rhnsw->efConstruction); + WRITE1 (rhnsw->efSearch); + + WRITEVECTOR (rhnsw->levels); + WRITEANDCHECK (rhnsw->level0_links, rhnsw->level0_link_size * rhnsw->levels.size()); + for (auto i = 0; i < rhnsw->levels.size(); ++ i) { + if (rhnsw->levels[i]) + WRITEANDCHECK (rhnsw->linkLists[i], rhnsw->link_size * rhnsw->levels[i] + 1); + } +} + +static void write_direct_map (const DirectMap *dm, IOWriter *f) { + char maintain_direct_map = (char)dm->type; // for backwards compatibility with bool + WRITE1 (maintain_direct_map); + WRITEVECTOR (dm->array); + if (dm->type == DirectMap::Hashtable) { + using idx_t = Index::idx_t; + std::vector> v; + const std::unordered_map & map = dm->hashtable; + v.resize (map.size()); + std::copy(map.begin(), map.end(), v.begin()); + WRITEVECTOR (v); + } +} + +static void write_ivf_header (const IndexIVF *ivf, IOWriter *f) { + write_index_header (ivf, f); + WRITE1 (ivf->nlist); + WRITE1 (ivf->nprobe); + write_index (ivf->quantizer, f); + write_direct_map (&ivf->direct_map, f); +} + +void write_index (const Index *idx, IOWriter *f) { + if (const IndexFlat * idxf = dynamic_cast (idx)) { + uint32_t h = fourcc ( + idxf->metric_type == METRIC_INNER_PRODUCT ? "IxFI" : + idxf->metric_type == METRIC_L2 ? "IxF2" : "IxFl"); + WRITE1 (h); + write_index_header (idx, f); + WRITEVECTOR (idxf->xb); + } else if(const IndexLSH * idxl = dynamic_cast (idx)) { + uint32_t h = fourcc ("IxHe"); + WRITE1 (h); + write_index_header (idx, f); + WRITE1 (idxl->nbits); + WRITE1 (idxl->rotate_data); + WRITE1 (idxl->train_thresholds); + WRITEVECTOR (idxl->thresholds); + WRITE1 (idxl->bytes_per_vec); + write_VectorTransform (&idxl->rrot, f); + WRITEVECTOR (idxl->codes); + } else if(const IndexPQ * idxp = dynamic_cast (idx)) { + uint32_t h = fourcc ("IxPq"); + WRITE1 (h); + write_index_header (idx, f); + write_ProductQuantizer (&idxp->pq, f); + WRITEVECTOR (idxp->codes); + // search params -- maybe not useful to store? + WRITE1 (idxp->search_type); + WRITE1 (idxp->encode_signs); + WRITE1 (idxp->polysemous_ht); + } else if(const Index2Layer * idxp = + dynamic_cast (idx)) { + uint32_t h = fourcc ("Ix2L"); + WRITE1 (h); + write_index_header (idx, f); + write_index (idxp->q1.quantizer, f); + WRITE1 (idxp->q1.nlist); + WRITE1 (idxp->q1.quantizer_trains_alone); + write_ProductQuantizer (&idxp->pq, f); + WRITE1 (idxp->code_size_1); + WRITE1 (idxp->code_size_2); + WRITE1 (idxp->code_size); + WRITEVECTOR (idxp->codes); + } else if(const IndexScalarQuantizer * idxs = + dynamic_cast (idx)) { + uint32_t h = fourcc ("IxSQ"); + WRITE1 (h); + write_index_header (idx, f); + write_ScalarQuantizer (&idxs->sq, f); + WRITEVECTOR (idxs->codes); + } else if(const IndexLattice * idxl = + dynamic_cast (idx)) { + uint32_t h = fourcc ("IxLa"); + WRITE1 (h); + WRITE1 (idxl->d); + WRITE1 (idxl->nsq); + WRITE1 (idxl->scale_nbit); + WRITE1 (idxl->zn_sphere_codec.r2); + write_index_header (idx, f); + WRITEVECTOR (idxl->trained); + } else if(const IndexIVFFlatDedup * ivfl = + dynamic_cast (idx)) { + uint32_t h = fourcc ("IwFd"); + WRITE1 (h); + write_ivf_header (ivfl, f); + { + std::vector tab (2 * ivfl->instances.size()); + long i = 0; + for (auto it = ivfl->instances.begin(); + it != ivfl->instances.end(); ++it) { + tab[i++] = it->first; + tab[i++] = it->second; + } + WRITEVECTOR (tab); + } + write_InvertedLists (ivfl->invlists, f); + } else if(const IndexIVFFlat * ivfl = + dynamic_cast (idx)) { + uint32_t h = fourcc ("IwFl"); + WRITE1 (h); + write_ivf_header (ivfl, f); + write_InvertedLists (ivfl->invlists, f); + } else if(const IndexIVFScalarQuantizer * ivsc = + dynamic_cast (idx)) { + uint32_t h = fourcc ("IwSq"); + WRITE1 (h); + write_ivf_header (ivsc, f); + write_ScalarQuantizer (&ivsc->sq, f); + WRITE1 (ivsc->code_size); + WRITE1 (ivsc->by_residual); + write_InvertedLists (ivsc->invlists, f); + } else if(const IndexIVFSQHybrid *ivfsqhbyrid = + dynamic_cast(idx)) { + uint32_t h = fourcc ("ISqH"); + WRITE1 (h); + write_ivf_header (ivfsqhbyrid, f); + write_ScalarQuantizer (&ivfsqhbyrid->sq, f); + WRITE1 (ivfsqhbyrid->code_size); + WRITE1 (ivfsqhbyrid->by_residual); + write_InvertedLists (ivfsqhbyrid->invlists, f); + } else if(const IndexIVFSpectralHash *ivsp = + dynamic_cast(idx)) { + uint32_t h = fourcc ("IwSh"); + WRITE1 (h); + write_ivf_header (ivsp, f); + write_VectorTransform (ivsp->vt, f); + WRITE1 (ivsp->nbit); + WRITE1 (ivsp->period); + WRITE1 (ivsp->threshold_type); + WRITEVECTOR (ivsp->trained); + write_InvertedLists (ivsp->invlists, f); + } else if(const IndexIVFPQ * ivpq = + dynamic_cast (idx)) { + const IndexIVFPQR * ivfpqr = dynamic_cast (idx); + + uint32_t h = fourcc (ivfpqr ? "IwQR" : "IwPQ"); + WRITE1 (h); + write_ivf_header (ivpq, f); + WRITE1 (ivpq->by_residual); + WRITE1 (ivpq->code_size); + write_ProductQuantizer (&ivpq->pq, f); + write_InvertedLists (ivpq->invlists, f); + if (ivfpqr) { + write_ProductQuantizer (&ivfpqr->refine_pq, f); + WRITEVECTOR (ivfpqr->refine_codes); + WRITE1 (ivfpqr->k_factor); + } + + } else if(const IndexPreTransform * ixpt = + dynamic_cast (idx)) { + uint32_t h = fourcc ("IxPT"); + WRITE1 (h); + write_index_header (ixpt, f); + int nt = ixpt->chain.size(); + WRITE1 (nt); + for (int i = 0; i < nt; i++) + write_VectorTransform (ixpt->chain[i], f); + write_index (ixpt->index, f); + } else if(const MultiIndexQuantizer * imiq = + dynamic_cast (idx)) { + uint32_t h = fourcc ("Imiq"); + WRITE1 (h); + write_index_header (imiq, f); + write_ProductQuantizer (&imiq->pq, f); + } else if(const IndexRefineFlat * idxrf = + dynamic_cast (idx)) { + uint32_t h = fourcc ("IxRF"); + WRITE1 (h); + write_index_header (idxrf, f); + write_index (idxrf->base_index, f); + write_index (&idxrf->refine_index, f); + WRITE1 (idxrf->k_factor); + } else if(const IndexIDMap * idxmap = + dynamic_cast (idx)) { + uint32_t h = + dynamic_cast (idx) ? fourcc ("IxM2") : + fourcc ("IxMp"); + // no need to store additional info for IndexIDMap2 + WRITE1 (h); + write_index_header (idxmap, f); + write_index (idxmap->index, f); + WRITEVECTOR (idxmap->id_map); + } else if(const IndexHNSW * idxhnsw = + dynamic_cast (idx)) { + uint32_t h = + dynamic_cast(idx) ? fourcc("IHNf") : + dynamic_cast(idx) ? fourcc("IHNp") : + dynamic_cast(idx) ? fourcc("IHNs") : + dynamic_cast(idx) ? fourcc("IHN2") : + 0; + FAISS_THROW_IF_NOT (h != 0); + WRITE1 (h); + write_index_header (idxhnsw, f); + write_HNSW (&idxhnsw->hnsw, f); + write_index (idxhnsw->storage, f); + } else if (const IndexRHNSW * idxrhnsw = + dynamic_cast(idx)) { + uint32_t h = + dynamic_cast(idx) ? fourcc("IRHf") : + dynamic_cast(idx) ? fourcc("IRHp") : + dynamic_cast(idx) ? fourcc("IRHs") : + dynamic_cast(idx) ? fourcc("IRH2") : + 0; + FAISS_THROW_IF_NOT (h != 0); + WRITE1 (h); + write_index_header (idxrhnsw, f); + write_RHNSW (&idxrhnsw->hnsw, f); + } else { + FAISS_THROW_MSG ("don't know how to serialize this type of index"); + } +} + +void write_index (const Index *idx, FILE *f) { + FileIOWriter writer(f); + write_index (idx, &writer); +} + +void write_index (const Index *idx, const char *fname) { + FileIOWriter writer(fname); + write_index (idx, &writer); +} + +// write index for offset-only index +void write_index_nm (const Index *idx, IOWriter *f) { + if(const IndexIVFFlat * ivfl = + dynamic_cast (idx)) { + uint32_t h = fourcc ("IwFl"); + WRITE1 (h); + write_ivf_header (ivfl, f); + write_InvertedLists_nm (ivfl->invlists, f); + } else if(const IndexIVFScalarQuantizer * ivsc = + dynamic_cast (idx)) { + uint32_t h = fourcc ("IwSq"); + WRITE1 (h); + write_ivf_header (ivsc, f); + write_ScalarQuantizer (&ivsc->sq, f); + WRITE1 (ivsc->code_size); + WRITE1 (ivsc->by_residual); + write_InvertedLists_nm (ivsc->invlists, f); + } else if(const IndexIVFSQHybrid *ivfsqhbyrid = + dynamic_cast(idx)) { + uint32_t h = fourcc ("ISqH"); + WRITE1 (h); + write_ivf_header (ivfsqhbyrid, f); + write_ScalarQuantizer (&ivfsqhbyrid->sq, f); + WRITE1 (ivfsqhbyrid->code_size); + WRITE1 (ivfsqhbyrid->by_residual); + write_InvertedLists_nm (ivfsqhbyrid->invlists, f); + } else { + FAISS_THROW_MSG ("don't know how to serialize this type of index"); + } +} + +void write_index_nm (const Index *idx, FILE *f) { + FileIOWriter writer(f); + write_index_nm (idx, &writer); +} + +void write_index_nm (const Index *idx, const char *fname) { + FileIOWriter writer(fname); + write_index_nm (idx, &writer); +} + +void write_VectorTransform (const VectorTransform *vt, const char *fname) { + FileIOWriter writer(fname); + write_VectorTransform (vt, &writer); +} + + +/************************************************************* + * Write binary indexes + **************************************************************/ + + +static void write_index_binary_header (const IndexBinary *idx, IOWriter *f) { + WRITE1 (idx->d); + WRITE1 (idx->code_size); + WRITE1 (idx->ntotal); + WRITE1 (idx->is_trained); + WRITE1 (idx->metric_type); +} + +static void write_binary_ivf_header (const IndexBinaryIVF *ivf, IOWriter *f) { + write_index_binary_header (ivf, f); + WRITE1 (ivf->nlist); + WRITE1 (ivf->nprobe); + write_index_binary (ivf->quantizer, f); + write_direct_map (&ivf->direct_map, f); +} + +static void write_binary_hash_invlists ( + const IndexBinaryHash::InvertedListMap &invlists, + int b, IOWriter *f) +{ + size_t sz = invlists.size(); + WRITE1 (sz); + size_t maxil = 0; + for (auto it = invlists.begin(); it != invlists.end(); ++it) { + if(it->second.ids.size() > maxil) { + maxil = it->second.ids.size(); + } + } + int il_nbit = 0; + while(maxil >= ((uint64_t)1 << il_nbit)) { + il_nbit++; + } + WRITE1(il_nbit); + + // first write sizes then data, may be useful if we want to + // memmap it at some point + + // buffer for bitstrings + std::vector buf (((b + il_nbit) * sz + 7) / 8); + BitstringWriter wr (buf.data(), buf.size()); + for (auto it = invlists.begin(); it != invlists.end(); ++it) { + wr.write (it->first, b); + wr.write (it->second.ids.size(), il_nbit); + } + WRITEVECTOR (buf); + + for (auto it = invlists.begin(); it != invlists.end(); ++it) { + WRITEVECTOR (it->second.ids); + WRITEVECTOR (it->second.vecs); + } +} + +static void write_binary_multi_hash_map( + const IndexBinaryMultiHash::Map &map, + int b, size_t ntotal, + IOWriter *f) +{ + int id_bits = 0; + while ((ntotal > ((Index::idx_t)1 << id_bits))) { + id_bits++; + } + WRITE1(id_bits); + size_t sz = map.size(); + WRITE1(sz); + size_t nbit = (b + id_bits) * sz + ntotal * id_bits; + std::vector buf((nbit + 7) / 8); + BitstringWriter wr (buf.data(), buf.size()); + for (auto it = map.begin(); it != map.end(); ++it) { + wr.write(it->first, b); + wr.write(it->second.size(), id_bits); + for (auto id : it->second) { + wr.write(id, id_bits); + } + } + WRITEVECTOR (buf); +} + +void write_index_binary (const IndexBinary *idx, IOWriter *f) { + if (const IndexBinaryFlat *idxf = + dynamic_cast (idx)) { + uint32_t h = fourcc ("IBxF"); + WRITE1 (h); + write_index_binary_header (idx, f); + WRITEVECTOR (idxf->xb); + } else if (const IndexBinaryIVF *ivf = + dynamic_cast (idx)) { + uint32_t h = fourcc ("IBwF"); + WRITE1 (h); + write_binary_ivf_header (ivf, f); + write_InvertedLists (ivf->invlists, f); + } else if(const IndexBinaryFromFloat * idxff = + dynamic_cast (idx)) { + uint32_t h = fourcc ("IBFf"); + WRITE1 (h); + write_index_binary_header (idxff, f); + write_index (idxff->index, f); + } else if (const IndexBinaryHNSW *idxhnsw = + dynamic_cast (idx)) { + uint32_t h = fourcc ("IBHf"); + WRITE1 (h); + write_index_binary_header (idxhnsw, f); + write_HNSW (&idxhnsw->hnsw, f); + write_index_binary (idxhnsw->storage, f); + } else if(const IndexBinaryIDMap * idxmap = + dynamic_cast (idx)) { + uint32_t h = + dynamic_cast (idx) ? fourcc ("IBM2") : + fourcc ("IBMp"); + // no need to store additional info for IndexIDMap2 + WRITE1 (h); + write_index_binary_header (idxmap, f); + write_index_binary (idxmap->index, f); + WRITEVECTOR (idxmap->id_map); + } else if (const IndexBinaryHash *idxh = + dynamic_cast (idx)) { + uint32_t h = fourcc ("IBHh"); + WRITE1 (h); + write_index_binary_header (idxh, f); + WRITE1 (idxh->b); + WRITE1 (idxh->nflip); + write_binary_hash_invlists(idxh->invlists, idxh->b, f); + } else if (const IndexBinaryMultiHash *idxmh = + dynamic_cast (idx)) { + uint32_t h = fourcc ("IBHm"); + WRITE1 (h); + write_index_binary_header (idxmh, f); + write_index_binary (idxmh->storage, f); + WRITE1 (idxmh->b); + WRITE1 (idxmh->nhash); + WRITE1 (idxmh->nflip); + for (int i = 0; i < idxmh->nhash; i++) { + write_binary_multi_hash_map( + idxmh->maps[i], idxmh->b, idxmh->ntotal, f); + } + } else { + FAISS_THROW_MSG ("don't know how to serialize this type of index"); + } +} + +void write_index_binary (const IndexBinary *idx, FILE *f) { + FileIOWriter writer(f); + write_index_binary(idx, &writer); +} + +void write_index_binary (const IndexBinary *idx, const char *fname) { + FileIOWriter writer(fname); + write_index_binary (idx, &writer); +} + + +} // namespace faiss diff --git a/core/src/index/thirdparty/faiss/impl/io.cpp b/core/src/index/thirdparty/faiss/impl/io.cpp new file mode 100644 index 0000000000..0954f3c1fc --- /dev/null +++ b/core/src/index/thirdparty/faiss/impl/io.cpp @@ -0,0 +1,252 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +#include +#include + +#include +#include + + +namespace faiss { + + +/*********************************************************************** + * IO functions + ***********************************************************************/ + + +int IOReader::fileno () +{ + FAISS_THROW_MSG ("IOReader does not support memory mapping"); +} + +int IOWriter::fileno () +{ + FAISS_THROW_MSG ("IOWriter does not support memory mapping"); +} + +/*********************************************************************** + * IO Vector + ***********************************************************************/ + + +size_t VectorIOWriter::operator()( + const void *ptr, size_t size, size_t nitems) +{ + size_t bytes = size * nitems; + if (bytes > 0) { + size_t o = data.size(); + data.resize(o + bytes); + memcpy (&data[o], ptr, size * nitems); + } + return nitems; +} + +size_t VectorIOReader::operator()( + void *ptr, size_t size, size_t nitems) +{ + if (rp >= data.size()) return 0; + size_t nremain = (data.size() - rp) / size; + if (nremain < nitems) nitems = nremain; + if (size * nitems > 0) { + memcpy (ptr, &data[rp], size * nitems); + rp += size * nitems; + } + return nitems; +} + + + + +/*********************************************************************** + * IO File + ***********************************************************************/ + + + +FileIOReader::FileIOReader(FILE *rf): f(rf) {} + +FileIOReader::FileIOReader(const char * fname) +{ + name = fname; + f = fopen(fname, "rb"); + FAISS_THROW_IF_NOT_FMT (f, "could not open %s for reading: %s", + fname, strerror(errno)); + need_close = true; +} + +FileIOReader::~FileIOReader() { + if (need_close) { + int ret = fclose(f); + if (ret != 0) {// we cannot raise and exception in the destructor + fprintf(stderr, "file %s close error: %s", + name.c_str(), strerror(errno)); + } + } +} + +size_t FileIOReader::operator()(void *ptr, size_t size, size_t nitems) { + return fread(ptr, size, nitems, f); +} + +int FileIOReader::fileno() { + return ::fileno (f); +} + + +FileIOWriter::FileIOWriter(FILE *wf): f(wf) {} + +FileIOWriter::FileIOWriter(const char * fname) +{ + name = fname; + f = fopen(fname, "wb"); + FAISS_THROW_IF_NOT_FMT (f, "could not open %s for writing: %s", + fname, strerror(errno)); + need_close = true; +} + +FileIOWriter::~FileIOWriter() { + if (need_close) { + int ret = fclose(f); + if (ret != 0) { + // we cannot raise and exception in the destructor + fprintf(stderr, "file %s close error: %s", + name.c_str(), strerror(errno)); + } + } +} + +size_t FileIOWriter::operator()(const void *ptr, size_t size, size_t nitems) { + return fwrite(ptr, size, nitems, f); +} + +int FileIOWriter::fileno() { + return ::fileno (f); +} + +/*********************************************************************** + * IO buffer + ***********************************************************************/ + +BufferedIOReader::BufferedIOReader(IOReader *reader, size_t bsz, size_t totsz): + reader(reader), bsz(bsz), totsz(totsz), ofs(0), b0(0), b1(0), buffer(bsz) +{ +} + + +size_t BufferedIOReader::operator()(void *ptr, size_t unitsize, size_t nitems) +{ + size_t size = unitsize * nitems; + if (size == 0) return 0; + char * dst = (char*)ptr; + size_t nb; + + { // first copy available bytes + nb = std::min(b1 - b0, size); + memcpy (dst, buffer.data() + b0, nb); + b0 += nb; + dst += nb; + size -= nb; + } + + if (size > totsz - ofs) { + size = totsz - ofs; + } + // while we would like to have more data + while (size > 0) { + assert (b0 == b1); // buffer empty on input + // try to read from main reader + b0 = 0; + b1 = (*reader)(buffer.data(), 1, std::min(bsz, size)); + + if (b1 == 0) { + // no more bytes available + break; + } + ofs += b1; + + // copy remaining bytes + size_t nb2 = std::min(b1, size); + memcpy (dst, buffer.data(), nb2); + b0 = nb2; + nb += nb2; + dst += nb2; + size -= nb2; + } + return nb / unitsize; +} + + +BufferedIOWriter::BufferedIOWriter(IOWriter *writer, size_t bsz): + writer(writer), bsz(bsz), b0(0), buffer(bsz) +{ +} + +size_t BufferedIOWriter::operator()(const void *ptr, size_t unitsize, size_t nitems) +{ + size_t size = unitsize * nitems; + if (size == 0) return 0; + const char * src = (const char*)ptr; + size_t nb; + + { // copy as many bytes as possible to buffer + nb = std::min(bsz - b0, size); + memcpy (buffer.data() + b0, src, nb); + b0 += nb; + src += nb; + size -= nb; + } + while (size > 0) { + assert(b0 == bsz); + // now we need to flush to add more bytes + size_t ofs = 0; + do { + assert (ofs < 10000000); + size_t written = (*writer)(buffer.data() + ofs, 1, bsz - ofs); + FAISS_THROW_IF_NOT(written > 0); + ofs += written; + } while(ofs != bsz); + + // copy src to buffer + size_t nb1 = std::min(bsz, size); + memcpy (buffer.data(), src, nb1); + b0 = nb1; + nb += nb1; + src += nb1; + size -= nb1; + } + + return nb / unitsize; +} + +BufferedIOWriter::~BufferedIOWriter() +{ + size_t ofs = 0; + while(ofs != b0) { + printf("Destructor write %ld \n", b0 - ofs); + size_t written = (*writer)(buffer.data() + ofs, 1, b0 - ofs); + FAISS_THROW_IF_NOT(written > 0); + ofs += written; + } + +} + + + + + +uint32_t fourcc (const char sx[4]) { + assert(4 == strlen(sx)); + const unsigned char *x = (unsigned char*)sx; + return x[0] | x[1] << 8 | x[2] << 16 | x[3] << 24; +} + + +} // namespace faiss diff --git a/core/src/index/thirdparty/faiss/impl/io.h b/core/src/index/thirdparty/faiss/impl/io.h new file mode 100644 index 0000000000..a3a565af26 --- /dev/null +++ b/core/src/index/thirdparty/faiss/impl/io.h @@ -0,0 +1,136 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +/*********************************************************** + * Abstract I/O objects + * + * I/O is always sequential, seek does not need to be supported + * (indexes could be read or written to a pipe). + ***********************************************************/ + +#pragma once + +#include +#include +#include + +#include + +namespace faiss { + + +struct IOReader { + // name that can be used in error messages + std::string name; + + // fread + virtual size_t operator()( + void *ptr, size_t size, size_t nitems) = 0; + + // return a file number that can be memory-mapped + virtual int fileno (); + + virtual ~IOReader() {} +}; + +struct IOWriter { + // name that can be used in error messages + std::string name; + + // fwrite + virtual size_t operator()( + const void *ptr, size_t size, size_t nitems) = 0; + + // return a file number that can be memory-mapped + virtual int fileno (); + + virtual ~IOWriter() {} +}; + + +struct VectorIOReader:IOReader { + std::vector data; + size_t rp = 0; + size_t operator()(void *ptr, size_t size, size_t nitems) override; +}; + +struct VectorIOWriter:IOWriter { + std::vector data; + size_t operator()(const void *ptr, size_t size, size_t nitems) override; +}; + +struct FileIOReader: IOReader { + FILE *f = nullptr; + bool need_close = false; + + FileIOReader(FILE *rf); + + FileIOReader(const char * fname); + + ~FileIOReader() override; + + size_t operator()(void *ptr, size_t size, size_t nitems) override; + + int fileno() override; +}; + +struct FileIOWriter: IOWriter { + FILE *f = nullptr; + bool need_close = false; + + FileIOWriter(FILE *wf); + + FileIOWriter(const char * fname); + + ~FileIOWriter() override; + + size_t operator()(const void *ptr, size_t size, size_t nitems) override; + + int fileno() override; +}; + +/******************************************************* + * Buffered reader + writer + *******************************************************/ + + + +/** wraps an ioreader to make buffered reads to avoid too small reads */ +struct BufferedIOReader: IOReader { + + IOReader *reader; + size_t bsz, totsz, ofs; + size_t b0, b1; ///< range of available bytes in the buffer + std::vector buffer; + + BufferedIOReader(IOReader *reader, size_t bsz, + size_t totsz=(size_t)(-1)); + + size_t operator()(void *ptr, size_t size, size_t nitems) override; +}; + +struct BufferedIOWriter: IOWriter { + + IOWriter *writer; + size_t bsz, ofs; + size_t b0; ///< amount of data in buffer + std::vector buffer; + + BufferedIOWriter(IOWriter *writer, size_t bsz); + + size_t operator()(const void *ptr, size_t size, size_t nitems) override; + + // flushes + ~BufferedIOWriter(); +}; + +/// cast a 4-character string to a uint32_t that can be written and read easily +uint32_t fourcc (const char sx[4]); + +} // namespace faiss diff --git a/core/src/index/thirdparty/faiss/impl/lattice_Zn.cpp b/core/src/index/thirdparty/faiss/impl/lattice_Zn.cpp new file mode 100644 index 0000000000..3e8458aa94 --- /dev/null +++ b/core/src/index/thirdparty/faiss/impl/lattice_Zn.cpp @@ -0,0 +1,713 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +#include + +#include +#include +#include +#include + +#include +#include +#include +#include + +#include +#include + +namespace faiss { + +/******************************************** + * small utility functions + ********************************************/ + +namespace { + +inline float sqr(float x) { + return x * x; +} + + +typedef std::vector point_list_t; + +struct Comb { + std::vector tab; // Pascal's triangle + int nmax; + + explicit Comb(int nmax): nmax(nmax) { + tab.resize(nmax * nmax, 0); + tab[0] = 1; + for(int i = 1; i < nmax; i++) { + tab[i * nmax] = 1; + for(int j = 1; j <= i; j++) { + tab[i * nmax + j] = + tab[(i - 1) * nmax + j] + + tab[(i - 1) * nmax + (j - 1)]; + } + + } + } + + uint64_t operator()(int n, int p) const { + assert (n < nmax && p < nmax); + if (p > n) return 0; + return tab[n * nmax + p]; + } +}; + +Comb comb(100); + + + +// compute combinations of n integer values <= v that sum up to total (squared) +point_list_t sum_of_sq (float total, int v, int n, float add = 0) { + if (total < 0) { + return point_list_t(); + } else if (n == 1) { + while (sqr(v + add) > total) v--; + if (sqr(v + add) == total) { + return point_list_t(1, v + add); + } else { + return point_list_t(); + } + } else { + point_list_t res; + while (v >= 0) { + point_list_t sub_points = + sum_of_sq (total - sqr(v + add), v, n - 1, add); + for (size_t i = 0; i < sub_points.size(); i += n - 1) { + res.push_back (v + add); + for (int j = 0; j < n - 1; j++) { + res.push_back(sub_points[i + j]); + } + } + v--; + } + return res; + } +} + +int decode_comb_1 (uint64_t *n, int k1, int r) { + while (comb(r, k1) > *n) { + r--; + } + *n -= comb(r, k1); + return r; +} + +// optimized version for < 64 bits +long repeats_encode_64 ( + const std::vector & repeats, + int dim, const float *c) +{ + uint64_t coded = 0; + int nfree = dim; + uint64_t code = 0, shift = 1; + for (auto r = repeats.begin(); r != repeats.end(); ++r) { + int rank = 0, occ = 0; + uint64_t code_comb = 0; + uint64_t tosee = ~coded; + for(;;) { + // directly jump to next available slot. + int i = __builtin_ctzl(tosee); + tosee &= ~(1UL << i) ; + if (c[i] == r->val) { + code_comb += comb(rank, occ + 1); + occ++; + coded |= 1UL << i; + if (occ == r->n) break; + } + rank++; + } + uint64_t max_comb = comb(nfree, r->n); + code += shift * code_comb; + shift *= max_comb; + nfree -= r->n; + } + return code; +} + + +void repeats_decode_64( + const std::vector & repeats, + int dim, uint64_t code, float *c) +{ + uint64_t decoded = 0; + int nfree = dim; + for (auto r = repeats.begin(); r != repeats.end(); ++r) { + uint64_t max_comb = comb(nfree, r->n); + uint64_t code_comb = code % max_comb; + code /= max_comb; + + int occ = 0; + int rank = nfree; + int next_rank = decode_comb_1 (&code_comb, r->n, rank); + uint64_t tosee = ((1UL << dim) - 1) ^ decoded; + for(;;) { + int i = 63 - __builtin_clzl(tosee); + tosee &= ~(1UL << i); + rank--; + if (rank == next_rank) { + decoded |= 1UL << i; + c[i] = r->val; + occ++; + if (occ == r->n) break; + next_rank = decode_comb_1 ( + &code_comb, r->n - occ, next_rank); + } + } + nfree -= r->n; + } + +} + + + +} // anonymous namespace + +Repeats::Repeats (int dim, const float *c): dim(dim) +{ + for(int i = 0; i < dim; i++) { + int j = 0; + for(;;) { + if (j == repeats.size()) { + repeats.push_back(Repeat{c[i], 1}); + break; + } + if (repeats[j].val == c[i]) { + repeats[j].n++; + break; + } + j++; + } + } +} + + +long Repeats::count () const +{ + long accu = 1; + int remain = dim; + for (int i = 0; i < repeats.size(); i++) { + accu *= comb(remain, repeats[i].n); + remain -= repeats[i].n; + } + return accu; +} + + + +// version with a bool vector that works for > 64 dim +long Repeats::encode(const float *c) const +{ + if (dim < 64) { + return repeats_encode_64 (repeats, dim, c); + } + std::vector coded(dim, false); + int nfree = dim; + uint64_t code = 0, shift = 1; + for (auto r = repeats.begin(); r != repeats.end(); ++r) { + int rank = 0, occ = 0; + uint64_t code_comb = 0; + for (int i = 0; i < dim; i++) { + if (!coded[i]) { + if (c[i] == r->val) { + code_comb += comb(rank, occ + 1); + occ++; + coded[i] = true; + if (occ == r->n) break; + } + rank++; + } + } + uint64_t max_comb = comb(nfree, r->n); + code += shift * code_comb; + shift *= max_comb; + nfree -= r->n; + } + return code; +} + + + +void Repeats::decode(uint64_t code, float *c) const +{ + if (dim < 64) { + repeats_decode_64 (repeats, dim, code, c); + return; + } + + std::vector decoded(dim, false); + int nfree = dim; + for (auto r = repeats.begin(); r != repeats.end(); ++r) { + uint64_t max_comb = comb(nfree, r->n); + uint64_t code_comb = code % max_comb; + code /= max_comb; + + int occ = 0; + int rank = nfree; + int next_rank = decode_comb_1 (&code_comb, r->n, rank); + for (int i = dim - 1; i >= 0; i--) { + if (!decoded[i]) { + rank--; + if (rank == next_rank) { + decoded[i] = true; + c[i] = r->val; + occ++; + if (occ == r->n) break; + next_rank = decode_comb_1 ( + &code_comb, r->n - occ, next_rank); + } + } + } + nfree -= r->n; + } + +} + + + +/******************************************** + * EnumeratedVectors functions + ********************************************/ + + +void EnumeratedVectors::encode_multi(size_t n, const float *c, + uint64_t * codes) const +{ +#pragma omp parallel if (n > 1000) + { +#pragma omp for + for(int i = 0; i < n; i++) { + codes[i] = encode(c + i * dim); + } + } +} + + +void EnumeratedVectors::decode_multi(size_t n, const uint64_t * codes, + float *c) const +{ +#pragma omp parallel if (n > 1000) + { +#pragma omp for + for(int i = 0; i < n; i++) { + decode(codes[i], c + i * dim); + } + } +} + +void EnumeratedVectors::find_nn ( + size_t nc, const uint64_t * codes, + size_t nq, const float *xq, + long *labels, float *distances) +{ + for (long i = 0; i < nq; i++) { + distances[i] = -1e20; + labels[i] = -1; + } + + float c[dim]; + for(long i = 0; i < nc; i++) { + uint64_t code = codes[nc]; + decode(code, c); + for (long j = 0; j < nq; j++) { + const float *x = xq + j * dim; + float dis = fvec_inner_product(x, c, dim); + if (dis > distances[j]) { + distances[j] = dis; + labels[j] = i; + } + } + } + +} + + +/********************************************************** + * ZnSphereSearch + **********************************************************/ + + +ZnSphereSearch::ZnSphereSearch(int dim, int r2): dimS(dim), r2(r2) { + voc = sum_of_sq(r2, int(ceil(sqrt(r2)) + 1), dim); + natom = voc.size() / dim; +} + +float ZnSphereSearch::search(const float *x, float *c) const { + float tmp[dimS * 2]; + int tmp_int[dimS]; + return search(x, c, tmp, tmp_int); +} + +float ZnSphereSearch::search(const float *x, float *c, + float *tmp, // size 2 *dim + int *tmp_int, // size dim + int *ibest_out + ) const { + int dim = dimS; + assert (natom > 0); + int *o = tmp_int; + float *xabs = tmp; + float *xperm = tmp + dim; + + // argsort + for (int i = 0; i < dim; i++) { + o[i] = i; + xabs[i] = fabsf(x[i]); + } + std::sort(o, o + dim, [xabs](int a, int b) { + return xabs[a] > xabs[b]; + }); + for (int i = 0; i < dim; i++) { + xperm[i] = xabs[o[i]]; + } + // find best + int ibest = -1; + float dpbest = -100; + for (int i = 0; i < natom; i++) { + float dp = fvec_inner_product (voc.data() + i * dim, xperm, dim); + if (dp > dpbest) { + dpbest = dp; + ibest = i; + } + } + // revert sort + const float *cin = voc.data() + ibest * dim; + for (int i = 0; i < dim; i++) { + c[o[i]] = copysignf (cin[i], x[o[i]]); + } + if (ibest_out) { + *ibest_out = ibest; + } + return dpbest; +} + +void ZnSphereSearch::search_multi(int n, const float *x, + float *c_out, + float *dp_out) { +#pragma omp parallel if (n > 1000) + { +#pragma omp for + for(int i = 0; i < n; i++) { + dp_out[i] = search(x + i * dimS, c_out + i * dimS); + } + } +} + + +/********************************************************** + * ZnSphereCodec + **********************************************************/ + +ZnSphereCodec::ZnSphereCodec(int dim, int r2): + ZnSphereSearch(dim, r2), + EnumeratedVectors(dim) +{ + nv = 0; + for (int i = 0; i < natom; i++) { + Repeats repeats(dim, &voc[i * dim]); + CodeSegment cs(repeats); + cs.c0 = nv; + Repeat &br = repeats.repeats.back(); + cs.signbits = br.val == 0 ? dim - br.n : dim; + code_segments.push_back(cs); + nv += repeats.count() << cs.signbits; + } + + uint64_t nvx = nv; + code_size = 0; + while (nvx > 0) { + nvx >>= 8; + code_size++; + } +} + +uint64_t ZnSphereCodec::search_and_encode(const float *x) const { + float tmp[dim * 2]; + int tmp_int[dim]; + int ano; // atom number + float c[dim]; + search(x, c, tmp, tmp_int, &ano); + uint64_t signs = 0; + float cabs[dim]; + int nnz = 0; + for (int i = 0; i < dim; i++) { + cabs[i] = fabs(c[i]); + if (c[i] != 0) { + if (c[i] < 0) { + signs |= 1UL << nnz; + } + nnz ++; + } + } + const CodeSegment &cs = code_segments[ano]; + assert(nnz == cs.signbits); + uint64_t code = cs.c0 + signs; + code += cs.encode(cabs) << cs.signbits; + return code; +} + +uint64_t ZnSphereCodec::encode(const float *x) const +{ + return search_and_encode(x); +} + + +void ZnSphereCodec::decode(uint64_t code, float *c) const { + int i0 = 0, i1 = natom; + while (i0 + 1 < i1) { + int imed = (i0 + i1) / 2; + if (code_segments[imed].c0 <= code) i0 = imed; + else i1 = imed; + } + const CodeSegment &cs = code_segments[i0]; + code -= cs.c0; + uint64_t signs = code; + code >>= cs.signbits; + cs.decode(code, c); + + int nnz = 0; + for (int i = 0; i < dim; i++) { + if (c[i] != 0) { + if (signs & (1UL << nnz)) { + c[i] = -c[i]; + } + nnz ++; + } + } +} + + +/************************************************************** + * ZnSphereCodecRec + **************************************************************/ + +uint64_t ZnSphereCodecRec::get_nv(int ld, int r2a) const +{ + return all_nv[ld * (r2 + 1) + r2a]; +} + + +uint64_t ZnSphereCodecRec::get_nv_cum(int ld, int r2t, int r2a) const +{ + return all_nv_cum[(ld * (r2 + 1) + r2t) * (r2 + 1) + r2a]; +} + +void ZnSphereCodecRec::set_nv_cum(int ld, int r2t, int r2a, uint64_t cum) +{ + all_nv_cum[(ld * (r2 + 1) + r2t) * (r2 + 1) + r2a] = cum; +} + + +ZnSphereCodecRec::ZnSphereCodecRec(int dim, int r2): + EnumeratedVectors(dim), r2(r2) +{ + log2_dim = 0; + while (dim > (1 << log2_dim)) { + log2_dim++; + } + assert(dim == (1 << log2_dim) || + !"dimension must be a power of 2"); + + all_nv.resize((log2_dim + 1) * (r2 + 1)); + all_nv_cum.resize((log2_dim + 1) * (r2 + 1) * (r2 + 1)); + + for (int r2a = 0; r2a <= r2; r2a++) { + int r = int(sqrt(r2a)); + if (r * r == r2a) { + all_nv[r2a] = r == 0 ? 1 : 2; + } else { + all_nv[r2a] = 0; + } + } + + for (int ld = 1; ld <= log2_dim; ld++) { + + for (int r2sub = 0; r2sub <= r2; r2sub++) { + uint64_t nv = 0; + for (int r2a = 0; r2a <= r2sub; r2a++) { + int r2b = r2sub - r2a; + set_nv_cum(ld, r2sub, r2a, nv); + nv += get_nv(ld - 1, r2a) * get_nv(ld - 1, r2b); + } + all_nv[ld * (r2 + 1) + r2sub] = nv; + } + } + nv = get_nv(log2_dim, r2); + + uint64_t nvx = nv; + code_size = 0; + while (nvx > 0) { + nvx >>= 8; + code_size++; + } + + int cache_level = std::min(3, log2_dim - 1); + decode_cache_ld = 0; + assert(cache_level <= log2_dim); + decode_cache.resize((r2 + 1)); + + for (int r2sub = 0; r2sub <= r2; r2sub++) { + int ld = cache_level; + uint64_t nvi = get_nv(ld, r2sub); + std::vector &cache = decode_cache[r2sub]; + int dimsub = (1 << cache_level); + cache.resize (nvi * dimsub); + float c[dim]; + uint64_t code0 = get_nv_cum(cache_level + 1, r2, + r2 - r2sub); + for (int i = 0; i < nvi; i++) { + decode(i + code0, c); + memcpy(&cache[i * dimsub], c + dim - dimsub, + dimsub * sizeof(*c)); + } + } + decode_cache_ld = cache_level; +} + +uint64_t ZnSphereCodecRec::encode(const float *c) const +{ + return encode_centroid(c); +} + + + +uint64_t ZnSphereCodecRec::encode_centroid(const float *c) const +{ + uint64_t codes[dim]; + int norm2s[dim]; + for(int i = 0; i < dim; i++) { + if (c[i] == 0) { + codes[i] = 0; + norm2s[i] = 0; + } else { + int r2i = int(c[i] * c[i]); + norm2s[i] = r2i; + codes[i] = c[i] >= 0 ? 0 : 1; + } + } + int dim2 = dim / 2; + for(int ld = 1; ld <= log2_dim; ld++) { + for (int i = 0; i < dim2; i++) { + int r2a = norm2s[2 * i]; + int r2b = norm2s[2 * i + 1]; + + uint64_t code_a = codes[2 * i]; + uint64_t code_b = codes[2 * i + 1]; + + codes[i] = + get_nv_cum(ld, r2a + r2b, r2a) + + code_a * get_nv(ld - 1, r2b) + + code_b; + norm2s[i] = r2a + r2b; + } + dim2 /= 2; + } + return codes[0]; +} + + + +void ZnSphereCodecRec::decode(uint64_t code, float *c) const +{ + uint64_t codes[dim]; + int norm2s[dim]; + codes[0] = code; + norm2s[0] = r2; + + int dim2 = 1; + for(int ld = log2_dim; ld > decode_cache_ld; ld--) { + for (int i = dim2 - 1; i >= 0; i--) { + int r2sub = norm2s[i]; + int i0 = 0, i1 = r2sub + 1; + uint64_t codei = codes[i]; + const uint64_t *cum = + &all_nv_cum[(ld * (r2 + 1) + r2sub) * (r2 + 1)]; + while (i1 > i0 + 1) { + int imed = (i0 + i1) / 2; + if (cum[imed] <= codei) + i0 = imed; + else + i1 = imed; + } + int r2a = i0, r2b = r2sub - i0; + codei -= cum[r2a]; + norm2s[2 * i] = r2a; + norm2s[2 * i + 1] = r2b; + + uint64_t code_a = codei / get_nv(ld - 1, r2b); + uint64_t code_b = codei % get_nv(ld - 1, r2b); + + codes[2 * i] = code_a; + codes[2 * i + 1] = code_b; + + } + dim2 *= 2; + } + + if (decode_cache_ld == 0) { + for(int i = 0; i < dim; i++) { + if (norm2s[i] == 0) { + c[i] = 0; + } else { + float r = sqrt(norm2s[i]); + assert(r * r == norm2s[i]); + c[i] = codes[i] == 0 ? r : -r; + } + } + } else { + int subdim = 1 << decode_cache_ld; + assert ((dim2 * subdim) == dim); + + for(int i = 0; i < dim2; i++) { + + const std::vector & cache = + decode_cache[norm2s[i]]; + assert(codes[i] < cache.size()); + memcpy(c + i * subdim, + &cache[codes[i] * subdim], + sizeof(*c)* subdim); + } + } +} + +// if not use_rec, instanciate an arbitrary harmless znc_rec +ZnSphereCodecAlt::ZnSphereCodecAlt (int dim, int r2): + ZnSphereCodec (dim, r2), + use_rec ((dim & (dim - 1)) == 0), + znc_rec (use_rec ? dim : 8, + use_rec ? r2 : 14) +{} + +uint64_t ZnSphereCodecAlt::encode(const float *x) const +{ + if (!use_rec) { + // it's ok if the vector is not normalized + return ZnSphereCodec::encode(x); + } else { + // find nearest centroid + std::vector centroid(dim); + search (x, centroid.data()); + return znc_rec.encode(centroid.data()); + } +} + +void ZnSphereCodecAlt::decode(uint64_t code, float *c) const +{ + if (!use_rec) { + ZnSphereCodec::decode (code, c); + } else { + znc_rec.decode (code, c); + } +} + + +} // namespace faiss diff --git a/core/src/index/thirdparty/faiss/impl/lattice_Zn.h b/core/src/index/thirdparty/faiss/impl/lattice_Zn.h new file mode 100644 index 0000000000..f346d1e4c5 --- /dev/null +++ b/core/src/index/thirdparty/faiss/impl/lattice_Zn.h @@ -0,0 +1,199 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- +#ifndef FAISS_LATTICE_ZN_H +#define FAISS_LATTICE_ZN_H + +#include +#include +#include + +namespace faiss { + +/** returns the nearest vertex in the sphere to a query. Returns only + * the coordinates, not an id. + * + * Algorithm: all points are derived from a one atom vector up to a + * permutation and sign changes. The search function finds the most + * appropriate atom and transformation. + */ +struct ZnSphereSearch { + int dimS, r2; + int natom; + + /// size dim * ntatom + std::vector voc; + + ZnSphereSearch(int dim, int r2); + + /// find nearest centroid. x does not need to be normalized + float search(const float *x, float *c) const; + + /// full call. Requires externally-allocated temp space + float search(const float *x, float *c, + float *tmp, // size 2 *dim + int *tmp_int, // size dim + int *ibest_out = nullptr + ) const; + + // multi-threaded + void search_multi(int n, const float *x, + float *c_out, + float *dp_out); + +}; + + +/*************************************************************************** + * Support ids as well. + * + * Limitations: ids are limited to 64 bit + ***************************************************************************/ + +struct EnumeratedVectors { + /// size of the collection + uint64_t nv; + int dim; + + explicit EnumeratedVectors(int dim): nv(0), dim(dim) {} + + /// encode a vector from a collection + virtual uint64_t encode(const float *x) const = 0; + + /// decode it + virtual void decode(uint64_t code, float *c) const = 0; + + // call encode on nc vectors + void encode_multi (size_t nc, const float *c, + uint64_t * codes) const; + + // call decode on nc codes + void decode_multi (size_t nc, const uint64_t * codes, + float *c) const; + + // find the nearest neighbor of each xq + // (decodes and computes distances) + void find_nn (size_t n, const uint64_t * codes, + size_t nq, const float *xq, + long *idx, float *dis); + + virtual ~EnumeratedVectors() {} + +}; + +struct Repeat { + float val; + int n; +}; + +/** Repeats: used to encode a vector that has n occurrences of + * val. Encodes the signs and permutation of the vector. Useful for + * atoms. + */ +struct Repeats { + int dim; + std::vector repeats; + + // initialize from a template of the atom. + Repeats(int dim = 0, const float *c = nullptr); + + // count number of possible codes for this atom + long count() const; + + long encode(const float *c) const; + + void decode(uint64_t code, float *c) const; +}; + + +/** codec that can return ids for the encoded vectors + * + * uses the ZnSphereSearch to encode the vector by encoding the + * permutation and signs. Depends on ZnSphereSearch because it uses + * the atom numbers */ +struct ZnSphereCodec: ZnSphereSearch, EnumeratedVectors { + + struct CodeSegment:Repeats { + explicit CodeSegment(const Repeats & r): Repeats(r) {} + uint64_t c0; // first code assigned to segment + int signbits; + }; + + std::vector code_segments; + uint64_t nv; + size_t code_size; + + ZnSphereCodec(int dim, int r2); + + uint64_t search_and_encode(const float *x) const; + + void decode(uint64_t code, float *c) const override; + + /// takes vectors that do not need to be centroids + uint64_t encode(const float *x) const override; + +}; + +/** recursive sphere codec + * + * Uses a recursive decomposition on the dimensions to encode + * centroids found by the ZnSphereSearch. The codes are *not* + * compatible with the ones of ZnSpehreCodec + */ +struct ZnSphereCodecRec: EnumeratedVectors { + + int r2; + + int log2_dim; + int code_size; + + ZnSphereCodecRec(int dim, int r2); + + uint64_t encode_centroid(const float *c) const; + + void decode(uint64_t code, float *c) const override; + + /// vectors need to be centroids (does not work on arbitrary + /// vectors) + uint64_t encode(const float *x) const override; + + std::vector all_nv; + std::vector all_nv_cum; + + int decode_cache_ld; + std::vector > decode_cache; + + // nb of vectors in the sphere in dim 2^ld with r2 radius + uint64_t get_nv(int ld, int r2a) const; + + // cumulative version + uint64_t get_nv_cum(int ld, int r2t, int r2a) const; + void set_nv_cum(int ld, int r2t, int r2a, uint64_t v); + +}; + + +/** Codec that uses the recursive codec if dim is a power of 2 and + * the regular one otherwise */ +struct ZnSphereCodecAlt: ZnSphereCodec { + bool use_rec; + ZnSphereCodecRec znc_rec; + + ZnSphereCodecAlt (int dim, int r2); + + uint64_t encode(const float *x) const override; + + void decode(uint64_t code, float *c) const override; + +}; + + +}; + + +#endif diff --git a/core/src/index/thirdparty/faiss/index_factory.cpp b/core/src/index/thirdparty/faiss/index_factory.cpp new file mode 100644 index 0000000000..456b8e5356 --- /dev/null +++ b/core/src/index/thirdparty/faiss/index_factory.cpp @@ -0,0 +1,425 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +/* + * implementation of Hyper-parameter auto-tuning + */ + +#include + +#include + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace faiss { + + +/*************************************************************** + * index_factory + ***************************************************************/ + +namespace { + +struct VTChain { + std::vector chain; + ~VTChain () { + for (int i = 0; i < chain.size(); i++) { + delete chain[i]; + } + } +}; + + +/// what kind of training does this coarse quantizer require? +char get_trains_alone(const Index *coarse_quantizer) { + return + dynamic_cast(coarse_quantizer) ? 1 : + dynamic_cast(coarse_quantizer) ? 2 : + 0; +} + + +} + +Index *index_factory (int d, const char *description_in, MetricType metric) +{ + FAISS_THROW_IF_NOT(metric == METRIC_L2 || + metric == METRIC_INNER_PRODUCT); + VTChain vts; + Index *coarse_quantizer = nullptr; + Index *index = nullptr; + bool add_idmap = false; + bool make_IndexRefineFlat = false; + + ScopeDeleter1 del_coarse_quantizer, del_index; + + char description[strlen(description_in) + 1]; + char *ptr; + memcpy (description, description_in, strlen(description_in) + 1); + + int64_t ncentroids = -1; + bool use_2layer = false; + + for (char *tok = strtok_r (description, " ,", &ptr); + tok; + tok = strtok_r (nullptr, " ,", &ptr)) { + int d_out, opq_M, nbit, M, M2, pq_m, ncent, r2; + std::string stok(tok); + nbit = 8; + + // to avoid mem leaks with exceptions: + // do all tests before any instanciation + + VectorTransform *vt_1 = nullptr; + Index *coarse_quantizer_1 = nullptr; + Index *index_1 = nullptr; + + // VectorTransforms + if (sscanf (tok, "PCA%d", &d_out) == 1) { + vt_1 = new PCAMatrix (d, d_out); + d = d_out; + } else if (sscanf (tok, "PCAR%d", &d_out) == 1) { + vt_1 = new PCAMatrix (d, d_out, 0, true); + d = d_out; + } else if (sscanf (tok, "RR%d", &d_out) == 1) { + vt_1 = new RandomRotationMatrix (d, d_out); + d = d_out; + } else if (sscanf (tok, "PCAW%d", &d_out) == 1) { + vt_1 = new PCAMatrix (d, d_out, -0.5, false); + d = d_out; + } else if (sscanf (tok, "PCAWR%d", &d_out) == 1) { + vt_1 = new PCAMatrix (d, d_out, -0.5, true); + d = d_out; + } else if (sscanf (tok, "OPQ%d_%d", &opq_M, &d_out) == 2) { + vt_1 = new OPQMatrix (d, opq_M, d_out); + d = d_out; + } else if (sscanf (tok, "OPQ%d", &opq_M) == 1) { + vt_1 = new OPQMatrix (d, opq_M); + } else if (sscanf (tok, "ITQ%d", &d_out) == 1) { + vt_1 = new ITQTransform (d, d_out, true); + d = d_out; + } else if (stok == "ITQ") { + vt_1 = new ITQTransform (d, d, false); + } else if (sscanf (tok, "Pad%d", &d_out) == 1) { + if (d_out > d) { + vt_1 = new RemapDimensionsTransform (d, d_out, false); + d = d_out; + } + } else if (stok == "L2norm") { + vt_1 = new NormalizationTransform (d, 2.0); + + // coarse quantizers + } else if (!coarse_quantizer && + sscanf (tok, "IVF%ld_HNSW%d", &ncentroids, &M) == 2) { + FAISS_THROW_IF_NOT (metric == METRIC_L2); + coarse_quantizer_1 = new IndexHNSWFlat (d, M); + + } else if (!coarse_quantizer && + sscanf (tok, "IVF%ld", &ncentroids) == 1) { + if (metric == METRIC_L2) { + coarse_quantizer_1 = new IndexFlatL2 (d); + } else { + coarse_quantizer_1 = new IndexFlatIP (d); + } + } else if (!coarse_quantizer && sscanf (tok, "IMI2x%d", &nbit) == 1) { + FAISS_THROW_IF_NOT_MSG (metric == METRIC_L2, + "MultiIndex not implemented for inner prod search"); + coarse_quantizer_1 = new MultiIndexQuantizer (d, 2, nbit); + ncentroids = 1 << (2 * nbit); + + } else if (!coarse_quantizer && + sscanf (tok, "Residual%dx%d", &M, &nbit) == 2) { + FAISS_THROW_IF_NOT_MSG (metric == METRIC_L2, + "MultiIndex not implemented for inner prod search"); + coarse_quantizer_1 = new MultiIndexQuantizer (d, M, nbit); + ncentroids = int64_t(1) << (M * nbit); + use_2layer = true; + + } else if (!coarse_quantizer && + sscanf (tok, "Residual%ld", &ncentroids) == 1) { + coarse_quantizer_1 = new IndexFlatL2 (d); + use_2layer = true; + + } else if (stok == "IDMap") { + add_idmap = true; + + // IVFs + } else if (!index && (stok == "Flat" || stok == "FlatDedup")) { + if (coarse_quantizer) { + // if there was an IVF in front, then it is an IVFFlat + IndexIVF *index_ivf = stok == "Flat" ? + new IndexIVFFlat ( + coarse_quantizer, d, ncentroids, metric) : + new IndexIVFFlatDedup ( + coarse_quantizer, d, ncentroids, metric); + index_ivf->quantizer_trains_alone = + get_trains_alone (coarse_quantizer); + index_ivf->cp.spherical = metric == METRIC_INNER_PRODUCT; + del_coarse_quantizer.release (); + index_ivf->own_fields = true; + index_1 = index_ivf; + } else { + FAISS_THROW_IF_NOT_MSG (stok != "FlatDedup", + "dedup supported only for IVFFlat"); + index_1 = new IndexFlat (d, metric); + } + } else if (!index && (stok == "SQ8" || stok == "SQ4" || stok == "SQ6" || + stok == "SQfp16")) { + QuantizerType qt = + stok == "SQ8" ? QuantizerType::QT_8bit : + stok == "SQ6" ? QuantizerType::QT_6bit : + stok == "SQ4" ? QuantizerType::QT_4bit : + stok == "SQfp16" ? QuantizerType::QT_fp16 : + QuantizerType::QT_4bit; + if (coarse_quantizer) { + FAISS_THROW_IF_NOT (!use_2layer); + IndexIVFScalarQuantizer *index_ivf = + new IndexIVFScalarQuantizer ( + coarse_quantizer, d, ncentroids, qt, metric); + index_ivf->quantizer_trains_alone = + get_trains_alone (coarse_quantizer); + del_coarse_quantizer.release (); + index_ivf->own_fields = true; + index_1 = index_ivf; + } else { + index_1 = new IndexScalarQuantizer (d, qt, metric); + } + } else if (!index && (stok == "SQ8Hybrid" || stok == "SQ4Hybrid" || stok == "SQ6Hybrid" || + stok == "SQfp16Hybrid")) { + QuantizerType qt = + stok == "SQ8Hybrid" ? QuantizerType::QT_8bit : + stok == "SQ6Hybrid" ? QuantizerType::QT_6bit : + stok == "SQ4Hybrid" ? QuantizerType::QT_4bit : + stok == "SQfp16Hybrid" ? QuantizerType::QT_fp16 : + QuantizerType::QT_4bit; + FAISS_THROW_IF_NOT_MSG(coarse_quantizer, + "SQ Hybrid only with an IVF"); + FAISS_THROW_IF_NOT (!use_2layer); + IndexIVFSQHybrid *index_ivf = + new IndexIVFSQHybrid ( + coarse_quantizer, d, ncentroids, qt, metric); + index_ivf->quantizer_trains_alone = + get_trains_alone (coarse_quantizer); + del_coarse_quantizer.release (); + index_ivf->own_fields = true; + index_1 = index_ivf; + } else if (!index && sscanf (tok, "PQ%d+%d", &M, &M2) == 2) { + FAISS_THROW_IF_NOT_MSG(coarse_quantizer, + "PQ with + works only with an IVF"); + FAISS_THROW_IF_NOT_MSG(metric == METRIC_L2, + "IVFPQR not implemented for inner product search"); + IndexIVFPQR *index_ivf = new IndexIVFPQR ( + coarse_quantizer, d, ncentroids, M, 8, M2, 8); + index_ivf->quantizer_trains_alone = + get_trains_alone (coarse_quantizer); + del_coarse_quantizer.release (); + index_ivf->own_fields = true; + index_1 = index_ivf; + } else if (!index && (sscanf (tok, "PQ%dx%d", &M, &nbit) == 2 || + sscanf (tok, "PQ%d", &M) == 1 || + sscanf (tok, "PQ%dnp", &M) == 1)) { + bool do_polysemous_training = stok.find("np") == std::string::npos; + if (coarse_quantizer) { + if (!use_2layer) { + IndexIVFPQ *index_ivf = new IndexIVFPQ ( + coarse_quantizer, d, ncentroids, M, nbit); + index_ivf->quantizer_trains_alone = + get_trains_alone (coarse_quantizer); + index_ivf->metric_type = metric; + index_ivf->cp.spherical = metric == METRIC_INNER_PRODUCT; + del_coarse_quantizer.release (); + index_ivf->own_fields = true; + index_ivf->do_polysemous_training = do_polysemous_training; + index_1 = index_ivf; + } else { + Index2Layer *index_2l = new Index2Layer + (coarse_quantizer, ncentroids, M, nbit); + index_2l->q1.quantizer_trains_alone = + get_trains_alone (coarse_quantizer); + index_2l->q1.own_fields = true; + index_1 = index_2l; + } + } else { + IndexPQ *index_pq = new IndexPQ (d, M, nbit, metric); + index_pq->do_polysemous_training = do_polysemous_training; + index_1 = index_pq; + } + } else if (!index && + sscanf (tok, "HNSW%d_%d+PQ%d", &M, &ncent, &pq_m) == 3) { + Index * quant = new IndexFlatL2 (d); + IndexHNSW2Level * hidx2l = new IndexHNSW2Level (quant, ncent, pq_m, M); + Index2Layer * idx2l = dynamic_cast(hidx2l->storage); + idx2l->q1.own_fields = true; + index_1 = hidx2l; + } else if (!index && + sscanf (tok, "HNSW%d_2x%d+PQ%d", &M, &nbit, &pq_m) == 3) { + Index * quant = new MultiIndexQuantizer (d, 2, nbit); + IndexHNSW2Level * hidx2l = + new IndexHNSW2Level (quant, 1 << (2 * nbit), pq_m, M); + Index2Layer * idx2l = dynamic_cast(hidx2l->storage); + idx2l->q1.own_fields = true; + idx2l->q1.quantizer_trains_alone = 1; + index_1 = hidx2l; + } else if (!index && + sscanf (tok, "HNSW%d_PQ%d", &M, &pq_m) == 2) { + index_1 = new IndexHNSWPQ (d, pq_m, M); + } else if (!index && + sscanf (tok, "HNSW%d", &M) == 1) { + index_1 = new IndexHNSWFlat (d, M); + } else if (!index && + sscanf (tok, "HNSW%d_SQ%d", &M, &pq_m) == 2 && + pq_m == 8) { + index_1 = new IndexHNSWSQ (d, QuantizerType::QT_8bit, M); + } else if (!index && (stok == "LSH" || stok == "LSHr" || + stok == "LSHrt" || stok == "LSHt")) { + bool rotate_data = strstr(tok, "r") != nullptr; + bool train_thresholds = strstr(tok, "t") != nullptr; + index_1 = new IndexLSH (d, d, rotate_data, train_thresholds); + } else if (!index && + sscanf (tok, "ZnLattice%dx%d_%d", &M, &r2, &nbit) == 3) { + FAISS_THROW_IF_NOT(!coarse_quantizer); + index_1 = new IndexLattice(d, M, nbit, r2); + } else if (stok == "RFlat") { + make_IndexRefineFlat = true; + } else { + FAISS_THROW_FMT( "could not parse token \"%s\" in %s\n", + tok, description_in); + } + + if (index_1 && add_idmap) { + IndexIDMap *idmap = new IndexIDMap(index_1); + del_index.set (idmap); + idmap->own_fields = true; + index_1 = idmap; + add_idmap = false; + } + + if (vt_1) { + vts.chain.push_back (vt_1); + } + + if (coarse_quantizer_1) { + coarse_quantizer = coarse_quantizer_1; + del_coarse_quantizer.set (coarse_quantizer); + } + + if (index_1) { + index = index_1; + del_index.set (index); + } + } + + FAISS_THROW_IF_NOT_FMT(index, "description %s did not generate an index", + description_in); + + // nothing can go wrong now + del_index.release (); + del_coarse_quantizer.release (); + + if (add_idmap) { + fprintf(stderr, "index_factory: WARNING: " + "IDMap option not used\n"); + } + + if (vts.chain.size() > 0) { + IndexPreTransform *index_pt = new IndexPreTransform (index); + index_pt->own_fields = true; + // add from back + while (vts.chain.size() > 0) { + index_pt->prepend_transform (vts.chain.back ()); + vts.chain.pop_back (); + } + index = index_pt; + } + + if (make_IndexRefineFlat) { + IndexRefineFlat *index_rf = new IndexRefineFlat (index); + index_rf->own_fields = true; + index = index_rf; + } + + return index; +} + +IndexBinary *index_binary_factory(int d, const char *description, MetricType metric = METRIC_L2) +{ + IndexBinary *index = nullptr; + + int ncentroids = -1; + int M; + + ScopeDeleter1 del_index; + if (sscanf(description, "BIVF%d_HNSW%d", &ncentroids, &M) == 2) { + IndexBinaryIVF *index_ivf = new IndexBinaryIVF( + new IndexBinaryHNSW(d, M), d, ncentroids + ); + index_ivf->own_fields = true; + index = index_ivf; + + } else if (sscanf(description, "BIVF%d", &ncentroids) == 1) { + IndexBinaryIVF *index_ivf = new IndexBinaryIVF( + new IndexBinaryFlat(d), d, ncentroids + ); + index_ivf->own_fields = true; + index = index_ivf; + + } else if (sscanf(description, "BHNSW%d", &M) == 1) { + IndexBinaryHNSW *index_hnsw = new IndexBinaryHNSW(d, M); + index = index_hnsw; + + } else if (std::string(description) == "BFlat") { + IndexBinary* index_x = new IndexBinaryFlat(d, metric); + + { + IndexBinaryIDMap *idmap = new IndexBinaryIDMap(index_x); + del_index.set (idmap); + idmap->own_fields = true; + index_x = idmap; + } + + if (index_x) { + index = index_x; + del_index.set(index); + } + + } else { + FAISS_THROW_IF_NOT_FMT(index, "description %s did not generate an index", + description); + } + + del_index.release(); + + return index; +} + + + +} // namespace faiss diff --git a/core/src/index/thirdparty/faiss/index_factory.h b/core/src/index/thirdparty/faiss/index_factory.h new file mode 100644 index 0000000000..ce62734298 --- /dev/null +++ b/core/src/index/thirdparty/faiss/index_factory.h @@ -0,0 +1,24 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +#pragma once + +#include +#include + +namespace faiss { + +/** Build and index with the sequence of processing steps described in + * the string. */ +Index *index_factory (int d, const char *description, + MetricType metric = METRIC_L2); + +IndexBinary *index_binary_factory (int d, const char *description, MetricType metric); + +} diff --git a/core/src/index/thirdparty/faiss/index_io.h b/core/src/index/thirdparty/faiss/index_io.h new file mode 100644 index 0000000000..ac686da71c --- /dev/null +++ b/core/src/index/thirdparty/faiss/index_io.h @@ -0,0 +1,86 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +// I/O code for indexes + +#ifndef FAISS_INDEX_IO_H +#define FAISS_INDEX_IO_H + + +#include + +/** I/O functions can read/write to a filename, a file handle or to an + * object that abstracts the medium. + * + * The read functions return objects that should be deallocated with + * delete. All references within these objectes are owned by the + * object. + */ + +namespace faiss { + +struct Index; +struct IndexBinary; +struct VectorTransform; +struct ProductQuantizer; +struct IOReader; +struct IOWriter; +struct InvertedLists; + +void write_index (const Index *idx, const char *fname); +void write_index (const Index *idx, FILE *f); +void write_index (const Index *idx, IOWriter *writer); + +void write_index_nm (const Index *idx, const char *fname); +void write_index_nm (const Index *idx, FILE *f); +void write_index_nm (const Index *idx, IOWriter *writer); + +void write_index_binary (const IndexBinary *idx, const char *fname); +void write_index_binary (const IndexBinary *idx, FILE *f); +void write_index_binary (const IndexBinary *idx, IOWriter *writer); + +// The read_index flags are implemented only for a subset of index types. +const int IO_FLAG_MMAP = 1; // try to memmap if possible +const int IO_FLAG_READ_ONLY = 2; +// strip directory component from ondisk filename, and assume it's in +// the same directory as the index file +const int IO_FLAG_ONDISK_SAME_DIR = 4; + +Index *read_index (const char *fname, int io_flags = 0); +Index *read_index (FILE * f, int io_flags = 0); +Index *read_index (IOReader *reader, int io_flags = 0); + +Index *read_index_nm (const char *fname, int io_flags = 0); +Index *read_index_nm (FILE * f, int io_flags = 0); +Index *read_index_nm (IOReader *reader, int io_flags = 0); + +IndexBinary *read_index_binary (const char *fname, int io_flags = 0); +IndexBinary *read_index_binary (FILE * f, int io_flags = 0); +IndexBinary *read_index_binary (IOReader *reader, int io_flags = 0); + +void write_VectorTransform (const VectorTransform *vt, const char *fname); +VectorTransform *read_VectorTransform (const char *fname); + +ProductQuantizer * read_ProductQuantizer (const char*fname); +ProductQuantizer * read_ProductQuantizer (IOReader *reader); + +void write_ProductQuantizer (const ProductQuantizer*pq, const char *fname); +void write_ProductQuantizer (const ProductQuantizer*pq, IOWriter *f); + +void write_InvertedLists (const InvertedLists *ils, IOWriter *f); +InvertedLists *read_InvertedLists (IOReader *reader, int io_flags = 0); + +void write_InvertedLists_nm (const InvertedLists *ils, IOWriter *f); +InvertedLists *read_InvertedLists_nm (IOReader *reader, int io_flags = 0); + + +} // namespace faiss + + +#endif diff --git a/core/src/index/thirdparty/faiss/makefile.inc.in b/core/src/index/thirdparty/faiss/makefile.inc.in new file mode 100644 index 0000000000..244f94a17c --- /dev/null +++ b/core/src/index/thirdparty/faiss/makefile.inc.in @@ -0,0 +1,46 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +CXX = @CXX@ +CXXCPP = @CXXCPP@ +CPPFLAGS = -DFINTEGER=int @CPPFLAGS@ @OPENMP_CXXFLAGS@ @NVCC_CPPFLAGS@ +CXXFLAGS = -fPIC @ARCH_CXXFLAGS@ -Wno-sign-compare @CXXFLAGS@ +CPUFLAGS = @ARCH_CPUFLAGS@ +LDFLAGS = @OPENMP_LDFLAGS@ @LDFLAGS@ @NVCC_LDFLAGS@ +LIBS = @BLAS_LIBS@ @LAPACK_LIBS@ @LIBS@ @NVCC_LIBS@ +PYTHONCFLAGS = @PYTHON_CFLAGS@ -I@NUMPY_INCLUDE@ +SWIGFLAGS = -DSWIGWORDSIZE64 + +NVCC = @NVCC@ +CUDA_ROOT = @CUDA_PREFIX@ +CUDA_ARCH = @CUDA_ARCH@ +NVCCFLAGS = -I $(CUDA_ROOT)/targets/x86_64-linux/include/ \ +-O3 \ +-Xcompiler -fPIC \ +-Xcudafe --diag_suppress=unrecognized_attribute \ +$(CUDA_ARCH) \ +-lineinfo \ +-ccbin $(CXX) + +OS = $(shell uname -s) + +SHAREDEXT = so +SHAREDFLAGS = -shared + +ifeq ($(OS),Darwin) + SHAREDEXT = dylib + SHAREDFLAGS = -dynamiclib -undefined dynamic_lookup + SWIGFLAGS = +endif + +MKDIR_P = @MKDIR_P@ +PYTHON = @PYTHON@ +SWIG = @SWIG@ +AR ?= ar + +prefix ?= @prefix@ +exec_prefix ?= @exec_prefix@ +libdir = @libdir@ +includedir = @includedir@ diff --git a/core/src/index/thirdparty/faiss/misc/test_blas.cpp b/core/src/index/thirdparty/faiss/misc/test_blas.cpp new file mode 100644 index 0000000000..be2536497e --- /dev/null +++ b/core/src/index/thirdparty/faiss/misc/test_blas.cpp @@ -0,0 +1,84 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +#undef FINTEGER +#define FINTEGER long + + +extern "C" { + +/* declare BLAS functions, see http://www.netlib.org/clapack/cblas/ */ + +int sgemm_ (const char *transa, const char *transb, FINTEGER *m, FINTEGER * + n, FINTEGER *k, const float *alpha, const float *a, + FINTEGER *lda, const float *b, FINTEGER * + ldb, float *beta, float *c, FINTEGER *ldc); + +/* Lapack functions, see http://www.netlib.org/clapack/old/single/sgeqrf.c */ + +int sgeqrf_ (FINTEGER *m, FINTEGER *n, float *a, FINTEGER *lda, + float *tau, float *work, FINTEGER *lwork, FINTEGER *info); + +} + +float *new_random_vec(int size) +{ + float *x = new float[size]; + for (int i = 0; i < size; i++) + x[i] = drand48(); + return x; +} + + +int main() { + + FINTEGER m = 10, n = 20, k = 30; + float *a = new_random_vec(m * k), *b = new_random_vec(n * k), *c = new float[n * m]; + float one = 1.0, zero = 0.0; + + printf("BLAS test\n"); + + sgemm_("Not transposed", "Not transposed", + &m, &n, &k, &one, a, &m, b, &k, &zero, c, &m); + + printf("errors=\n"); + + for (int i = 0; i < m; i++) { + for (int j = 0; j < n; j++) { + float accu = 0; + for (int l = 0; l < k; l++) + accu += a[i + l * m] * b[l + j * k]; + printf ("%6.3f ", accu - c[i + j * m]); + } + printf("\n"); + } + + long info = 0x64bL << 32; + long mi = 0x64bL << 32 | m; + float *tau = new float[m]; + FINTEGER lwork = -1; + + float work1; + + printf("Intentional Lapack error (appears only for 64-bit INTEGER):\n"); + sgeqrf_ (&mi, &n, c, &m, tau, &work1, &lwork, (FINTEGER*)&info); + + // sgeqrf_ (&m, &n, c, &zeroi, tau, &work1, &lwork, (FINTEGER*)&info); + printf("info=%016lx\n", info); + + if(info >> 32 == 0x64b) { + printf("Lapack uses 32-bit integers\n"); + } else { + printf("Lapack uses 64-bit integers\n"); + } + + + return 0; +} diff --git a/core/src/index/thirdparty/faiss/python/Makefile b/core/src/index/thirdparty/faiss/python/Makefile new file mode 100644 index 0000000000..2836568253 --- /dev/null +++ b/core/src/index/thirdparty/faiss/python/Makefile @@ -0,0 +1,39 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +-include ../makefile.inc + +ifneq ($(strip $(NVCC)),) + SWIGFLAGS += -DGPU_WRAPPER +endif + +all: build + +# Also silently generates swigfaiss.py. +swigfaiss.cpp: swigfaiss.swig ../libfaiss.a + $(SWIG) -python -c++ -Doverride= -I../ $(SWIGFLAGS) -o $@ $< + +swigfaiss_avx2.cpp: swigfaiss.swig ../libfaiss.a + $(SWIG) -python -c++ -Doverride= -module swigfaiss_avx2 -I../ $(SWIGFLAGS) -o $@ $< + +%.o: %.cpp + $(CXX) $(CPPFLAGS) $(CXXFLAGS) $(CPUFLAGS) $(PYTHONCFLAGS) \ + -I../ -c $< -o $@ + +# Extension is .so even on OSX. +_%.so: %.o ../libfaiss.a + $(CXX) $(SHAREDFLAGS) $(LDFLAGS) -o $@ $^ $(LIBS) + +build: _swigfaiss.so faiss.py + $(PYTHON) setup.py build + +install: build + $(PYTHON) setup.py install + +clean: + rm -f swigfaiss*.cpp swigfaiss*.o swigfaiss*.py _swigfaiss*.so + rm -rf build/ + +.PHONY: all build clean install diff --git a/core/src/index/thirdparty/faiss/python/faiss.py b/core/src/index/thirdparty/faiss/python/faiss.py new file mode 100644 index 0000000000..2d58b7f708 --- /dev/null +++ b/core/src/index/thirdparty/faiss/python/faiss.py @@ -0,0 +1,812 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +#@nolint + +# not linting this file because it imports * form swigfaiss, which +# causes a ton of useless warnings. + +import numpy as np +import sys +import inspect +import pdb +import platform +import subprocess +import logging + + +logger = logging.getLogger(__name__) + + +def instruction_set(): + if platform.system() == "Darwin": + if subprocess.check_output(["/usr/sbin/sysctl", "hw.optional.avx2_0"])[-1] == '1': + return "AVX2" + else: + return "default" + elif platform.system() == "Linux": + import numpy.distutils.cpuinfo + if "avx2" in numpy.distutils.cpuinfo.cpu.info[0].get('flags', ""): + return "AVX2" + else: + return "default" + + +try: + instr_set = instruction_set() + if instr_set == "AVX2": + logger.info("Loading faiss with AVX2 support.") + from .swigfaiss_avx2 import * + else: + logger.info("Loading faiss.") + from .swigfaiss import * + +except ImportError: + # we import * so that the symbol X can be accessed as faiss.X + logger.info("Loading faiss.") + from .swigfaiss import * + + +__version__ = "%d.%d.%d" % (FAISS_VERSION_MAJOR, + FAISS_VERSION_MINOR, + FAISS_VERSION_PATCH) + +################################################################## +# The functions below add or replace some methods for classes +# this is to be able to pass in numpy arrays directly +# The C++ version of the classnames will be suffixed with _c +################################################################## + + +def replace_method(the_class, name, replacement, ignore_missing=False): + try: + orig_method = getattr(the_class, name) + except AttributeError: + if ignore_missing: + return + raise + if orig_method.__name__ == 'replacement_' + name: + # replacement was done in parent class + return + setattr(the_class, name + '_c', orig_method) + setattr(the_class, name, replacement) + + +def handle_Clustering(): + def replacement_train(self, x, index, weights=None): + n, d = x.shape + assert d == self.d + if weights is not None: + assert weights.shape == (n, ) + self.train_c(n, swig_ptr(x), index, swig_ptr(weights)) + else: + self.train_c(n, swig_ptr(x), index) + def replacement_train_encoded(self, x, codec, index, weights=None): + n, d = x.shape + assert d == codec.sa_code_size() + assert codec.d == index.d + if weights is not None: + assert weights.shape == (n, ) + self.train_encoded_c(n, swig_ptr(x), codec, index, swig_ptr(weights)) + else: + self.train_encoded_c(n, swig_ptr(x), codec, index) + replace_method(Clustering, 'train', replacement_train) + replace_method(Clustering, 'train_encoded', replacement_train_encoded) + + +handle_Clustering() + + +def handle_Quantizer(the_class): + + def replacement_train(self, x): + n, d = x.shape + assert d == self.d + self.train_c(n, swig_ptr(x)) + + def replacement_compute_codes(self, x): + n, d = x.shape + assert d == self.d + codes = np.empty((n, self.code_size), dtype='uint8') + self.compute_codes_c(swig_ptr(x), swig_ptr(codes), n) + return codes + + def replacement_decode(self, codes): + n, cs = codes.shape + assert cs == self.code_size + x = np.empty((n, self.d), dtype='float32') + self.decode_c(swig_ptr(codes), swig_ptr(x), n) + return x + + replace_method(the_class, 'train', replacement_train) + replace_method(the_class, 'compute_codes', replacement_compute_codes) + replace_method(the_class, 'decode', replacement_decode) + + +handle_Quantizer(ProductQuantizer) +handle_Quantizer(ScalarQuantizer) + + +def handle_Index(the_class): + + def replacement_add(self, x): + assert x.flags.contiguous + n, d = x.shape + assert d == self.d + self.add_c(n, swig_ptr(x)) + + def replacement_add_with_ids(self, x, ids): + n, d = x.shape + assert d == self.d + assert ids.shape == (n, ), 'not same nb of vectors as ids' + self.add_with_ids_c(n, swig_ptr(x), swig_ptr(ids)) + + def replacement_assign(self, x, k): + n, d = x.shape + assert d == self.d + labels = np.empty((n, k), dtype=np.int64) + self.assign_c(n, swig_ptr(x), swig_ptr(labels), k) + return labels + + def replacement_train(self, x): + assert x.flags.contiguous + n, d = x.shape + assert d == self.d + self.train_c(n, swig_ptr(x)) + + def replacement_search(self, x, k): + n, d = x.shape + assert d == self.d + distances = np.empty((n, k), dtype=np.float32) + labels = np.empty((n, k), dtype=np.int64) + self.search_c(n, swig_ptr(x), + k, swig_ptr(distances), + swig_ptr(labels)) + return distances, labels + + def replacement_search_and_reconstruct(self, x, k): + n, d = x.shape + assert d == self.d + distances = np.empty((n, k), dtype=np.float32) + labels = np.empty((n, k), dtype=np.int64) + recons = np.empty((n, k, d), dtype=np.float32) + self.search_and_reconstruct_c(n, swig_ptr(x), + k, swig_ptr(distances), + swig_ptr(labels), + swig_ptr(recons)) + return distances, labels, recons + + def replacement_remove_ids(self, x): + if isinstance(x, IDSelector): + sel = x + else: + assert x.ndim == 1 + index_ivf = try_extract_index_ivf (self) + if index_ivf and index_ivf.direct_map.type == DirectMap.Hashtable: + sel = IDSelectorArray(x.size, swig_ptr(x)) + else: + sel = IDSelectorBatch(x.size, swig_ptr(x)) + return self.remove_ids_c(sel) + + def replacement_reconstruct(self, key): + x = np.empty(self.d, dtype=np.float32) + self.reconstruct_c(key, swig_ptr(x)) + return x + + def replacement_reconstruct_n(self, n0, ni): + x = np.empty((ni, self.d), dtype=np.float32) + self.reconstruct_n_c(n0, ni, swig_ptr(x)) + return x + + def replacement_update_vectors(self, keys, x): + n = keys.size + assert keys.shape == (n, ) + assert x.shape == (n, self.d) + self.update_vectors_c(n, swig_ptr(keys), swig_ptr(x)) + + def replacement_range_search(self, x, thresh): + n, d = x.shape + assert d == self.d + res = RangeSearchResult(n) + self.range_search_c(n, swig_ptr(x), thresh, res) + # get pointers and copy them + lims = rev_swig_ptr(res.lims, n + 1).copy() + nd = int(lims[-1]) + D = rev_swig_ptr(res.distances, nd).copy() + I = rev_swig_ptr(res.labels, nd).copy() + return lims, D, I + + def replacement_sa_encode(self, x): + n, d = x.shape + assert d == self.d + codes = np.empty((n, self.sa_code_size()), dtype='uint8') + self.sa_encode_c(n, swig_ptr(x), swig_ptr(codes)) + return codes + + def replacement_sa_decode(self, codes): + n, cs = codes.shape + assert cs == self.sa_code_size() + x = np.empty((n, self.d), dtype='float32') + self.sa_decode_c(n, swig_ptr(codes), swig_ptr(x)) + return x + + replace_method(the_class, 'add', replacement_add) + replace_method(the_class, 'add_with_ids', replacement_add_with_ids) + replace_method(the_class, 'assign', replacement_assign) + replace_method(the_class, 'train', replacement_train) + replace_method(the_class, 'search', replacement_search) + replace_method(the_class, 'remove_ids', replacement_remove_ids) + replace_method(the_class, 'reconstruct', replacement_reconstruct) + replace_method(the_class, 'reconstruct_n', replacement_reconstruct_n) + replace_method(the_class, 'range_search', replacement_range_search) + replace_method(the_class, 'update_vectors', replacement_update_vectors, + ignore_missing=True) + replace_method(the_class, 'search_and_reconstruct', + replacement_search_and_reconstruct, ignore_missing=True) + replace_method(the_class, 'sa_encode', replacement_sa_encode) + replace_method(the_class, 'sa_decode', replacement_sa_decode) + +def handle_IndexBinary(the_class): + + def replacement_add(self, x): + assert x.flags.contiguous + n, d = x.shape + assert d * 8 == self.d + self.add_c(n, swig_ptr(x)) + + def replacement_add_with_ids(self, x, ids): + n, d = x.shape + assert d * 8 == self.d + assert ids.shape == (n, ), 'not same nb of vectors as ids' + self.add_with_ids_c(n, swig_ptr(x), swig_ptr(ids)) + + def replacement_train(self, x): + assert x.flags.contiguous + n, d = x.shape + assert d * 8 == self.d + self.train_c(n, swig_ptr(x)) + + def replacement_reconstruct(self, key): + x = np.empty(self.d // 8, dtype=np.uint8) + self.reconstruct_c(key, swig_ptr(x)) + return x + + def replacement_search(self, x, k): + n, d = x.shape + assert d * 8 == self.d + distances = np.empty((n, k), dtype=np.int32) + labels = np.empty((n, k), dtype=np.int64) + self.search_c(n, swig_ptr(x), + k, swig_ptr(distances), + swig_ptr(labels)) + return distances, labels + + def replacement_range_search(self, x, thresh): + n, d = x.shape + assert d * 8 == self.d + res = RangeSearchResult(n) + self.range_search_c(n, swig_ptr(x), thresh, res) + # get pointers and copy them + lims = rev_swig_ptr(res.lims, n + 1).copy() + nd = int(lims[-1]) + D = rev_swig_ptr(res.distances, nd).copy() + I = rev_swig_ptr(res.labels, nd).copy() + return lims, D, I + + def replacement_remove_ids(self, x): + if isinstance(x, IDSelector): + sel = x + else: + assert x.ndim == 1 + sel = IDSelectorBatch(x.size, swig_ptr(x)) + return self.remove_ids_c(sel) + + replace_method(the_class, 'add', replacement_add) + replace_method(the_class, 'add_with_ids', replacement_add_with_ids) + replace_method(the_class, 'train', replacement_train) + replace_method(the_class, 'search', replacement_search) + replace_method(the_class, 'range_search', replacement_range_search) + replace_method(the_class, 'reconstruct', replacement_reconstruct) + replace_method(the_class, 'remove_ids', replacement_remove_ids) + + +def handle_VectorTransform(the_class): + + def apply_method(self, x): + assert x.flags.contiguous + n, d = x.shape + assert d == self.d_in + y = np.empty((n, self.d_out), dtype=np.float32) + self.apply_noalloc(n, swig_ptr(x), swig_ptr(y)) + return y + + def replacement_reverse_transform(self, x): + n, d = x.shape + assert d == self.d_out + y = np.empty((n, self.d_in), dtype=np.float32) + self.reverse_transform_c(n, swig_ptr(x), swig_ptr(y)) + return y + + def replacement_vt_train(self, x): + assert x.flags.contiguous + n, d = x.shape + assert d == self.d_in + self.train_c(n, swig_ptr(x)) + + replace_method(the_class, 'train', replacement_vt_train) + # apply is reserved in Pyton... + the_class.apply_py = apply_method + replace_method(the_class, 'reverse_transform', + replacement_reverse_transform) + + +def handle_AutoTuneCriterion(the_class): + def replacement_set_groundtruth(self, D, I): + if D: + assert I.shape == D.shape + self.nq, self.gt_nnn = I.shape + self.set_groundtruth_c( + self.gt_nnn, swig_ptr(D) if D else None, swig_ptr(I)) + + def replacement_evaluate(self, D, I): + assert I.shape == D.shape + assert I.shape == (self.nq, self.nnn) + return self.evaluate_c(swig_ptr(D), swig_ptr(I)) + + replace_method(the_class, 'set_groundtruth', replacement_set_groundtruth) + replace_method(the_class, 'evaluate', replacement_evaluate) + + +def handle_ParameterSpace(the_class): + def replacement_explore(self, index, xq, crit): + assert xq.shape == (crit.nq, index.d) + ops = OperatingPoints() + self.explore_c(index, crit.nq, swig_ptr(xq), + crit, ops) + return ops + replace_method(the_class, 'explore', replacement_explore) + + +def handle_MatrixStats(the_class): + original_init = the_class.__init__ + + def replacement_init(self, m): + assert len(m.shape) == 2 + original_init(self, m.shape[0], m.shape[1], swig_ptr(m)) + + the_class.__init__ = replacement_init + +handle_MatrixStats(MatrixStats) + + +this_module = sys.modules[__name__] + + +for symbol in dir(this_module): + obj = getattr(this_module, symbol) + # print symbol, isinstance(obj, (type, types.ClassType)) + if inspect.isclass(obj): + the_class = obj + if issubclass(the_class, Index): + handle_Index(the_class) + + if issubclass(the_class, IndexBinary): + handle_IndexBinary(the_class) + + if issubclass(the_class, VectorTransform): + handle_VectorTransform(the_class) + + if issubclass(the_class, AutoTuneCriterion): + handle_AutoTuneCriterion(the_class) + + if issubclass(the_class, ParameterSpace): + handle_ParameterSpace(the_class) + + +########################################### +# Add Python references to objects +# we do this at the Python class wrapper level. +########################################### + +def add_ref_in_constructor(the_class, parameter_no): + # adds a reference to parameter parameter_no in self + # so that that parameter does not get deallocated before self + original_init = the_class.__init__ + + def replacement_init(self, *args): + original_init(self, *args) + self.referenced_objects = [args[parameter_no]] + + def replacement_init_multiple(self, *args): + original_init(self, *args) + pset = parameter_no[len(args)] + self.referenced_objects = [args[no] for no in pset] + + if type(parameter_no) == dict: + # a list of parameters to keep, depending on the number of arguments + the_class.__init__ = replacement_init_multiple + else: + the_class.__init__ = replacement_init + +def add_ref_in_method(the_class, method_name, parameter_no): + original_method = getattr(the_class, method_name) + def replacement_method(self, *args): + ref = args[parameter_no] + if not hasattr(self, 'referenced_objects'): + self.referenced_objects = [ref] + else: + self.referenced_objects.append(ref) + return original_method(self, *args) + setattr(the_class, method_name, replacement_method) + +def add_ref_in_function(function_name, parameter_no): + # assumes the function returns an object + original_function = getattr(this_module, function_name) + def replacement_function(*args): + result = original_function(*args) + ref = args[parameter_no] + result.referenced_objects = [ref] + return result + setattr(this_module, function_name, replacement_function) + +add_ref_in_constructor(IndexIVFFlat, 0) +add_ref_in_constructor(IndexIVFFlatDedup, 0) +add_ref_in_constructor(IndexPreTransform, {2: [0, 1], 1: [0]}) +add_ref_in_method(IndexPreTransform, 'prepend_transform', 0) +add_ref_in_constructor(IndexIVFPQ, 0) +add_ref_in_constructor(IndexIVFPQR, 0) +add_ref_in_constructor(Index2Layer, 0) +add_ref_in_constructor(Level1Quantizer, 0) +add_ref_in_constructor(IndexIVFScalarQuantizer, 0) +add_ref_in_constructor(IndexIDMap, 0) +add_ref_in_constructor(IndexIDMap2, 0) +add_ref_in_constructor(IndexHNSW, 0) +add_ref_in_method(IndexShards, 'add_shard', 0) +add_ref_in_method(IndexBinaryShards, 'add_shard', 0) +add_ref_in_constructor(IndexRefineFlat, 0) +add_ref_in_constructor(IndexBinaryIVF, 0) +add_ref_in_constructor(IndexBinaryFromFloat, 0) +add_ref_in_constructor(IndexBinaryIDMap, 0) +add_ref_in_constructor(IndexBinaryIDMap2, 0) + +add_ref_in_method(IndexReplicas, 'addIndex', 0) +add_ref_in_method(IndexBinaryReplicas, 'addIndex', 0) + +add_ref_in_constructor(BufferedIOWriter, 0) +add_ref_in_constructor(BufferedIOReader, 0) + +# seems really marginal... +# remove_ref_from_method(IndexReplicas, 'removeIndex', 0) + +if hasattr(this_module, 'GpuIndexFlat'): + # handle all the GPUResources refs + add_ref_in_function('index_cpu_to_gpu', 0) + add_ref_in_constructor(GpuIndexFlat, 0) + add_ref_in_constructor(GpuIndexFlatIP, 0) + add_ref_in_constructor(GpuIndexFlatL2, 0) + add_ref_in_constructor(GpuIndexIVFFlat, 0) + add_ref_in_constructor(GpuIndexIVFScalarQuantizer, 0) + add_ref_in_constructor(GpuIndexIVFPQ, 0) + add_ref_in_constructor(GpuIndexBinaryFlat, 0) + + + +########################################### +# GPU functions +########################################### + + +def index_cpu_to_gpu_multiple_py(resources, index, co=None, gpus=None): + """ builds the C++ vectors for the GPU indices and the + resources. Handles the case where the resources are assigned to + the list of GPUs """ + if gpus is None: + gpus = range(len(resources)) + vres = GpuResourcesVector() + vdev = IntVector() + for i, res in zip(gpus, resources): + vdev.push_back(i) + vres.push_back(res) + index = index_cpu_to_gpu_multiple(vres, vdev, index, co) + index.referenced_objects = resources + return index + + +def index_cpu_to_all_gpus(index, co=None, ngpu=-1): + index_gpu = index_cpu_to_gpus_list(index, co=co, gpus=None, ngpu=ngpu) + return index_gpu + + +def index_cpu_to_gpus_list(index, co=None, gpus=None, ngpu=-1): + """ Here we can pass list of GPU ids as a parameter or ngpu to + use first n GPU's. gpus mut be a list or None""" + if (gpus is None) and (ngpu == -1): # All blank + gpus = range(get_num_gpus()) + elif (gpus is None) and (ngpu != -1): # Get number of GPU's only + gpus = range(ngpu) + res = [StandardGpuResources() for _ in gpus] + index_gpu = index_cpu_to_gpu_multiple_py(res, index, co, gpus) + return index_gpu + + +########################################### +# numpy array / std::vector conversions +########################################### + +# mapping from vector names in swigfaiss.swig and the numpy dtype names +vector_name_map = { + 'Float': 'float32', + 'Byte': 'uint8', + 'Char': 'int8', + 'Uint64': 'uint64', + 'Long': 'int64', + 'Int': 'int32', + 'Double': 'float64' + } + +def vector_to_array(v): + """ convert a C++ vector to a numpy array """ + classname = v.__class__.__name__ + assert classname.endswith('Vector') + dtype = np.dtype(vector_name_map[classname[:-6]]) + a = np.empty(v.size(), dtype=dtype) + if v.size() > 0: + memcpy(swig_ptr(a), v.data(), a.nbytes) + return a + + +def vector_float_to_array(v): + return vector_to_array(v) + + +def copy_array_to_vector(a, v): + """ copy a numpy array to a vector """ + n, = a.shape + classname = v.__class__.__name__ + assert classname.endswith('Vector') + dtype = np.dtype(vector_name_map[classname[:-6]]) + assert dtype == a.dtype, ( + 'cannot copy a %s array to a %s (should be %s)' % ( + a.dtype, classname, dtype)) + v.resize(n) + if n > 0: + memcpy(v.data(), swig_ptr(a), a.nbytes) + + +########################################### +# Wrapper for a few functions +########################################### + +def kmin(array, k): + """return k smallest values (and their indices) of the lines of a + float32 array""" + m, n = array.shape + I = np.zeros((m, k), dtype='int64') + D = np.zeros((m, k), dtype='float32') + ha = float_maxheap_array_t() + ha.ids = swig_ptr(I) + ha.val = swig_ptr(D) + ha.nh = m + ha.k = k + ha.heapify() + ha.addn(n, swig_ptr(array)) + ha.reorder() + return D, I + + +def kmax(array, k): + """return k largest values (and their indices) of the lines of a + float32 array""" + m, n = array.shape + I = np.zeros((m, k), dtype='int64') + D = np.zeros((m, k), dtype='float32') + ha = float_minheap_array_t() + ha.ids = swig_ptr(I) + ha.val = swig_ptr(D) + ha.nh = m + ha.k = k + ha.heapify() + ha.addn(n, swig_ptr(array)) + ha.reorder() + return D, I + + +def pairwise_distances(xq, xb, mt=METRIC_L2, metric_arg=0): + """compute the whole pairwise distance matrix between two sets of + vectors""" + nq, d = xq.shape + nb, d2 = xb.shape + assert d == d2 + dis = np.empty((nq, nb), dtype='float32') + if mt == METRIC_L2: + pairwise_L2sqr( + d, nq, swig_ptr(xq), + nb, swig_ptr(xb), + swig_ptr(dis)) + else: + pairwise_extra_distances( + d, nq, swig_ptr(xq), + nb, swig_ptr(xb), + mt, metric_arg, + swig_ptr(dis)) + return dis + + + + +def rand(n, seed=12345): + res = np.empty(n, dtype='float32') + float_rand(swig_ptr(res), res.size, seed) + return res + + +def randint(n, seed=12345, vmax=None): + res = np.empty(n, dtype='int64') + if vmax is None: + int64_rand(swig_ptr(res), res.size, seed) + else: + int64_rand_max(swig_ptr(res), res.size, vmax, seed) + return res + +lrand = randint + +def randn(n, seed=12345): + res = np.empty(n, dtype='float32') + float_randn(swig_ptr(res), res.size, seed) + return res + + +def eval_intersection(I1, I2): + """ size of intersection between each line of two result tables""" + n = I1.shape[0] + assert I2.shape[0] == n + k1, k2 = I1.shape[1], I2.shape[1] + ninter = 0 + for i in range(n): + ninter += ranklist_intersection_size( + k1, swig_ptr(I1[i]), k2, swig_ptr(I2[i])) + return ninter + + +def normalize_L2(x): + fvec_renorm_L2(x.shape[1], x.shape[0], swig_ptr(x)) + +# MapLong2Long interface + +def replacement_map_add(self, keys, vals): + n, = keys.shape + assert (n,) == keys.shape + self.add_c(n, swig_ptr(keys), swig_ptr(vals)) + +def replacement_map_search_multiple(self, keys): + n, = keys.shape + vals = np.empty(n, dtype='int64') + self.search_multiple_c(n, swig_ptr(keys), swig_ptr(vals)) + return vals + +replace_method(MapLong2Long, 'add', replacement_map_add) +replace_method(MapLong2Long, 'search_multiple', replacement_map_search_multiple) + + +########################################### +# Kmeans object +########################################### + + +class Kmeans: + """shallow wrapper around the Clustering object. The important method + is train().""" + + def __init__(self, d, k, **kwargs): + """d: input dimension, k: nb of centroids. Additional + parameters are passed on the ClusteringParameters object, + including niter=25, verbose=False, spherical = False + """ + self.d = d + self.k = k + self.gpu = False + self.cp = ClusteringParameters() + for k, v in kwargs.items(): + if k == 'gpu': + self.gpu = v + else: + # if this raises an exception, it means that it is a non-existent field + getattr(self.cp, k) + setattr(self.cp, k, v) + self.centroids = None + + def train(self, x, weights=None): + n, d = x.shape + assert d == self.d + clus = Clustering(d, self.k, self.cp) + if self.cp.spherical: + self.index = IndexFlatIP(d) + else: + self.index = IndexFlatL2(d) + if self.gpu: + if self.gpu == True: + ngpu = -1 + else: + ngpu = self.gpu + self.index = index_cpu_to_all_gpus(self.index, ngpu=ngpu) + clus.train(x, self.index, weights) + centroids = vector_float_to_array(clus.centroids) + self.centroids = centroids.reshape(self.k, d) + stats = clus.iteration_stats + self.obj = np.array([ + stats.at(i).obj for i in range(stats.size()) + ]) + return self.obj[-1] if self.obj.size > 0 else 0.0 + + def assign(self, x): + assert self.centroids is not None, "should train before assigning" + self.index.reset() + self.index.add(self.centroids) + D, I = self.index.search(x, 1) + return D.ravel(), I.ravel() + +# IndexProxy was renamed to IndexReplicas, remap the old name for any old code +# people may have +IndexProxy = IndexReplicas +ConcatenatedInvertedLists = HStackInvertedLists + +########################################### +# serialization of indexes to byte arrays +########################################### + +def serialize_index(index): + """ convert an index to a numpy uint8 array """ + writer = VectorIOWriter() + write_index(index, writer) + return vector_to_array(writer.data) + +def deserialize_index(data): + reader = VectorIOReader() + copy_array_to_vector(data, reader.data) + return read_index(reader) + +def serialize_index_binary(index): + """ convert an index to a numpy uint8 array """ + writer = VectorIOWriter() + write_index_binary(index, writer) + return vector_to_array(writer.data) + +def deserialize_index_binary(data): + reader = VectorIOReader() + copy_array_to_vector(data, reader.data) + return read_index_binary(reader) + + +########################################### +# ResultHeap +########################################### + +class ResultHeap: + """Accumulate query results from a sliced dataset. The final result will + be in self.D, self.I.""" + + def __init__(self, nq, k): + " nq: number of query vectors, k: number of results per query " + self.I = np.zeros((nq, k), dtype='int64') + self.D = np.zeros((nq, k), dtype='float32') + self.nq, self.k = nq, k + heaps = float_maxheap_array_t() + heaps.k = k + heaps.nh = nq + heaps.val = swig_ptr(self.D) + heaps.ids = swig_ptr(self.I) + heaps.heapify() + self.heaps = heaps + + def add_result(self, D, I): + """D, I do not need to be in a particular order (heap or sorted)""" + assert D.shape == (self.nq, self.k) + assert I.shape == (self.nq, self.k) + self.heaps.addn_with_ids( + self.k, faiss.swig_ptr(D), + faiss.swig_ptr(I), self.k) + + def finalize(self): + self.heaps.reorder() diff --git a/core/src/index/thirdparty/faiss/python/setup.py b/core/src/index/thirdparty/faiss/python/setup.py new file mode 100644 index 0000000000..89b6d398cb --- /dev/null +++ b/core/src/index/thirdparty/faiss/python/setup.py @@ -0,0 +1,50 @@ +from __future__ import print_function +from setuptools import setup, find_packages +import os +import shutil + +here = os.path.abspath(os.path.dirname(__file__)) + +check_fpath = os.path.join("_swigfaiss.so") +if not os.path.exists(check_fpath): + print("Could not find {}".format(check_fpath)) + print("Have you run `make` and `make -C python`?") + +# make the faiss python package dir +shutil.rmtree("faiss", ignore_errors=True) +os.mkdir("faiss") +shutil.copyfile("faiss.py", "faiss/__init__.py") +shutil.copyfile("swigfaiss.py", "faiss/swigfaiss.py") +shutil.copyfile("_swigfaiss.so", "faiss/_swigfaiss.so") +try: + shutil.copyfile("swigfaiss_avx2.py", "faiss/swigfaiss_avx2.py") + shutil.copyfile("_swigfaiss_avx2.so", "faiss/_swigfaiss_avx2.so") +except: + pass + +long_description=""" +Faiss is a library for efficient similarity search and clustering of dense +vectors. It contains algorithms that search in sets of vectors of any size, + up to ones that possibly do not fit in RAM. It also contains supporting +code for evaluation and parameter tuning. Faiss is written in C++ with +complete wrappers for Python/numpy. Some of the most useful algorithms +are implemented on the GPU. It is developed by Facebook AI Research. +""" +setup( + name='faiss', + version='1.6.3', + description='A library for efficient similarity search and clustering of dense vectors', + long_description=long_description, + url='https://github.com/facebookresearch/faiss', + author='Matthijs Douze, Jeff Johnson, Herve Jegou, Lucas Hosseini', + author_email='matthijs@fb.com', + license='MIT', + keywords='search nearest neighbors', + + install_requires=['numpy'], + packages=['faiss'], + package_data={ + 'faiss': ['*.so'], + }, + +) diff --git a/core/src/index/thirdparty/faiss/python/swigfaiss.swig b/core/src/index/thirdparty/faiss/python/swigfaiss.swig new file mode 100644 index 0000000000..b0d8b8173d --- /dev/null +++ b/core/src/index/thirdparty/faiss/python/swigfaiss.swig @@ -0,0 +1,1109 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- C++ -*- + +// This file describes the C++-scripting language bridge for both Lua +// and Python It contains mainly includes and a few macros. There are +// 3 preprocessor macros of interest: +// SWIGLUA: Lua-specific code +// SWIGPYTHON: Python-specific code +// GPU_WRAPPER: also compile interfaces for GPU. + +%module swigfaiss; + +// fbode SWIG fails on warnings, so make them non fatal +#pragma SWIG nowarn=321 +#pragma SWIG nowarn=403 +#pragma SWIG nowarn=325 +#pragma SWIG nowarn=389 +#pragma SWIG nowarn=341 +#pragma SWIG nowarn=512 + +%include + +typedef uint64_t size_t; + +#define __restrict + + +/******************************************************************* + * Copied verbatim to wrapper. Contains the C++-visible includes, and + * the language includes for their respective matrix libraries. + *******************************************************************/ + +%{ + + +#include +#include + + +#ifdef SWIGLUA + +#include + +extern "C" { + +#include +#include +#undef THTensor + +} + +#endif + + +#ifdef SWIGPYTHON + +#undef popcount64 + +#define SWIG_FILE_WITH_INIT +#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION +#include + +#endif + + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include + +#include +#include +#include + +#include +#include + + +%} + +/******************************************************** + * GIL manipulation and exception handling + ********************************************************/ + +#ifdef SWIGPYTHON +// %catches(faiss::FaissException); + + +// Python-specific: release GIL by default for all functions +%exception { + Py_BEGIN_ALLOW_THREADS + try { + $action + } catch(faiss::FaissException & e) { + PyEval_RestoreThread(_save); + + if (PyErr_Occurred()) { + // some previous code already set the error type. + } else { + PyErr_SetString(PyExc_RuntimeError, e.what()); + } + SWIG_fail; + } catch(std::bad_alloc & ba) { + PyEval_RestoreThread(_save); + PyErr_SetString(PyExc_MemoryError, "std::bad_alloc"); + SWIG_fail; + } + Py_END_ALLOW_THREADS +} + +#endif + +#ifdef SWIGLUA + +%exception { + try { + $action + } catch(faiss::FaissException & e) { + SWIG_Lua_pushferrstring(L, "C++ exception: %s", e.what()); \ + goto fail; + } +} + +#endif + + +/******************************************************************* + * Types of vectors we want to manipulate at the scripting language + * level. + *******************************************************************/ + +// simplified interface for vector +namespace std { + + template + class vector { + public: + vector(); + void push_back(T); + void clear(); + T * data(); + size_t size(); + T at (size_t n) const; + void resize (size_t n); + void swap (vector & other); + }; +}; + + + +%template(FloatVector) std::vector; +%template(DoubleVector) std::vector; +%template(ByteVector) std::vector; +%template(CharVector) std::vector; +// NOTE(hoss): Using unsigned long instead of uint64_t because OSX defines +// uint64_t as unsigned long long, which SWIG is not aware of. +%template(Uint64Vector) std::vector; +%template(LongVector) std::vector; +%template(LongLongVector) std::vector; +%template(IntVector) std::vector; +%template(FloatVectorVector) std::vector >; +%template(ByteVectorVector) std::vector >; +%template(LongVectorVector) std::vector >; +%template(VectorTransformVector) std::vector; +%template(OperatingPointVector) std::vector; +%template(InvertedListsPtrVector) std::vector; +%template(RepeatVector) std::vector; +%template(ClusteringIterationStatsVector) std::vector; + +#ifdef GPU_WRAPPER +%template(GpuResourcesVector) std::vector; +#endif + +%include + +// produces an error on the Mac +%ignore faiss::hamming; + +/******************************************************************* + * Parse headers + *******************************************************************/ + + +%ignore *::cmp; + +%include +%include + +int get_num_gpus(); +void gpu_profiler_start(); +void gpu_profiler_stop(); +void gpu_sync_all_devices(); + +#ifdef GPU_WRAPPER + +%{ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +int get_num_gpus() +{ + return faiss::gpu::getNumDevices(); +} + +void gpu_profiler_start() +{ + return faiss::gpu::profilerStart(); +} + +void gpu_profiler_stop() +{ + return faiss::gpu::profilerStop(); +} + +void gpu_sync_all_devices() +{ + return faiss::gpu::synchronizeAllDevices(); +} + +%} + +// causes weird wrapper bug +%ignore *::getMemoryManager; +%ignore *::getMemoryManagerCurrentDevice; + +%include +%include + +#else + +%{ +int get_num_gpus() +{ + return 0; +} + +void gpu_profiler_start() +{ +} + +void gpu_profiler_stop() +{ +} + +void gpu_sync_all_devices() +{ +} +%} + + +#endif + +// order matters because includes are not recursive + +%include +%include +%include + +%include +%include +%include + +%include + +%ignore faiss::ProductQuantizer::get_centroids(size_t,size_t) const; + +%include + +%include +%include +%include +%include +%include +%include +%include +%include +%ignore InvertedListScanner; +%ignore BinaryInvertedListScanner; +%include +// NOTE(hoss): SWIG (wrongly) believes the overloaded const version shadows the +// non-const one. +%warnfilter(509) extract_index_ivf; +%warnfilter(509) try_extract_index_ivf; +%include +%include +%include +%include +%include +%include +%include +%include + +%include +%include + +%ignore faiss::IndexIVFPQ::alloc_type; +%include +%include +%include + +%include +%include +%include +%include +%include +%include + + + + // %ignore faiss::IndexReplicas::at(int) const; + +%include +%template(ThreadedIndexBase) faiss::ThreadedIndex; +%template(ThreadedIndexBaseBinary) faiss::ThreadedIndex; + +%include +%template(IndexShards) faiss::IndexShardsTemplate; +%template(IndexBinaryShards) faiss::IndexShardsTemplate; + +%include +%template(IndexReplicas) faiss::IndexReplicasTemplate; +%template(IndexBinaryReplicas) faiss::IndexReplicasTemplate; + +%include +%template(IndexIDMap) faiss::IndexIDMapTemplate; +%template(IndexBinaryIDMap) faiss::IndexIDMapTemplate; +%template(IndexIDMap2) faiss::IndexIDMap2Template; +%template(IndexBinaryIDMap2) faiss::IndexIDMap2Template; + +#ifdef GPU_WRAPPER + +// quiet SWIG warnings +%ignore faiss::gpu::GpuIndexIVF::GpuIndexIVF; + +%include +%include +%include +%include +%include +%include +%include +%include +%include +%include +%include + +#ifdef SWIGLUA + +/// in Lua, swigfaiss_gpu is known as swigfaiss +%luacode { +local swigfaiss = swigfaiss_gpu +} + +#endif + + +#endif + + + + +/******************************************************************* + * Lua-specific: support async execution of searches in an index + * Python equivalent is just to use Python threads. + *******************************************************************/ + + +#ifdef SWIGLUA + +%{ + + +namespace faiss { + +struct AsyncIndexSearchC { + typedef Index::idx_t idx_t; + const Index * index; + + idx_t n; + const float *x; + idx_t k; + float *distances; + idx_t *labels; + + bool is_finished; + + pthread_t thread; + + + AsyncIndexSearchC (const Index *index, + idx_t n, const float *x, idx_t k, + float *distances, idx_t *labels): + index(index), n(n), x(x), k(k), distances(distances), + labels(labels) + { + is_finished = false; + pthread_create (&thread, NULL, &AsyncIndexSearchC::callback, + this); + } + + static void *callback (void *arg) + { + AsyncIndexSearchC *aidx = (AsyncIndexSearchC *)arg; + aidx->do_search(); + return NULL; + } + + void do_search () + { + index->search (n, x, k, distances, labels); + } + void join () + { + pthread_join (thread, NULL); + } + +}; + +} + +%} + +// re-decrlare only what we need +namespace faiss { + +struct AsyncIndexSearchC { + typedef Index::idx_t idx_t; + bool is_finished; + AsyncIndexSearchC (const Index *index, + idx_t n, const float *x, idx_t k, + float *distances, idx_t *labels); + + + void join (); +}; + +} + + +#endif + + + + +/******************************************************************* + * downcast return of some functions so that the sub-class is used + * instead of the generic upper-class. + *******************************************************************/ + +#ifdef SWIGLUA + +%define DOWNCAST(subclass) + if (dynamic_cast ($1)) { + SWIG_NewPointerObj(L,$1,SWIGTYPE_p_faiss__ ## subclass, $owner); + } else +%enddef + +%define DOWNCAST2(subclass, longname) + if (dynamic_cast ($1)) { + SWIG_NewPointerObj(L,$1,SWIGTYPE_p_faiss__ ## longname, $owner); + } else +%enddef + +%define DOWNCAST_GPU(subclass) + if (dynamic_cast ($1)) { + SWIG_NewPointerObj(L,$1,SWIGTYPE_p_faiss__gpu__ ## subclass, $owner); + } else +%enddef + +#endif + + +#ifdef SWIGPYTHON + +%define DOWNCAST(subclass) + if (dynamic_cast ($1)) { + $result = SWIG_NewPointerObj($1,SWIGTYPE_p_faiss__ ## subclass,$owner); + } else +%enddef + +%define DOWNCAST2(subclass, longname) + if (dynamic_cast ($1)) { + $result = SWIG_NewPointerObj($1,SWIGTYPE_p_faiss__ ## longname,$owner); + } else +%enddef + +%define DOWNCAST_GPU(subclass) + if (dynamic_cast ($1)) { + $result = SWIG_NewPointerObj($1,SWIGTYPE_p_faiss__gpu__ ## subclass,$owner); + } else +%enddef + +#endif + +%newobject read_index; +%newobject read_index_binary; +%newobject read_VectorTransform; +%newobject read_ProductQuantizer; +%newobject clone_index; +%newobject clone_VectorTransform; + +// Subclasses should appear before their parent +%typemap(out) faiss::Index * { + DOWNCAST2 ( IndexIDMap, IndexIDMapTemplateT_faiss__Index_t ) + DOWNCAST2 ( IndexIDMap2, IndexIDMap2TemplateT_faiss__Index_t ) + DOWNCAST2 ( IndexShards, IndexShardsTemplateT_faiss__Index_t ) + DOWNCAST2 ( IndexReplicas, IndexReplicasTemplateT_faiss__Index_t ) + DOWNCAST ( IndexIVFPQR ) + DOWNCAST ( IndexIVFPQ ) + DOWNCAST ( IndexIVFSpectralHash ) + DOWNCAST ( IndexIVFScalarQuantizer ) + DOWNCAST ( IndexIVFFlatDedup ) + DOWNCAST ( IndexIVFFlat ) + DOWNCAST ( IndexIVF ) + DOWNCAST ( IndexFlat ) + DOWNCAST ( IndexPQ ) + DOWNCAST ( IndexScalarQuantizer ) + DOWNCAST ( IndexLSH ) + DOWNCAST ( IndexLattice ) + DOWNCAST ( IndexPreTransform ) + DOWNCAST ( MultiIndexQuantizer ) + DOWNCAST ( IndexHNSWFlat ) + DOWNCAST ( IndexHNSWPQ ) + DOWNCAST ( IndexHNSWSQ ) + DOWNCAST ( IndexHNSW2Level ) + DOWNCAST ( Index2Layer ) +#ifdef GPU_WRAPPER + DOWNCAST_GPU ( GpuIndexIVFPQ ) + DOWNCAST_GPU ( GpuIndexIVFFlat ) + DOWNCAST_GPU ( GpuIndexIVFScalarQuantizer ) + DOWNCAST_GPU ( GpuIndexFlat ) +#endif + // default for non-recognized classes + DOWNCAST ( Index ) + if ($1 == NULL) + { +#ifdef SWIGPYTHON + $result = SWIG_Py_Void(); +#endif + // Lua does not need a push for nil + } else { + assert(false); + } +#ifdef SWIGLUA + SWIG_arg++; +#endif +} + +%typemap(out) faiss::IndexBinary * { + DOWNCAST2 ( IndexBinaryReplicas, IndexReplicasTemplateT_faiss__IndexBinary_t ) + DOWNCAST2 ( IndexBinaryIDMap, IndexIDMapTemplateT_faiss__IndexBinary_t ) + DOWNCAST2 ( IndexBinaryIDMap2, IndexIDMap2TemplateT_faiss__IndexBinary_t ) + DOWNCAST ( IndexBinaryIVF ) + DOWNCAST ( IndexBinaryFlat ) + DOWNCAST ( IndexBinaryFromFloat ) + DOWNCAST ( IndexBinaryHNSW ) +#ifdef GPU_WRAPPER + DOWNCAST_GPU ( GpuIndexBinaryFlat ) +#endif + // default for non-recognized classes + DOWNCAST ( IndexBinary ) + if ($1 == NULL) + { +#ifdef SWIGPYTHON + $result = SWIG_Py_Void(); +#endif + // Lua does not need a push for nil + } else { + assert(false); + } +#ifdef SWIGLUA + SWIG_arg++; +#endif +} + +%typemap(out) faiss::VectorTransform * { + DOWNCAST (RemapDimensionsTransform) + DOWNCAST (OPQMatrix) + DOWNCAST (PCAMatrix) + DOWNCAST (RandomRotationMatrix) + DOWNCAST (LinearTransform) + DOWNCAST (NormalizationTransform) + DOWNCAST (CenteringTransform) + DOWNCAST (VectorTransform) + { + assert(false); + } +#ifdef SWIGLUA + SWIG_arg++; +#endif +} + +%typemap(out) faiss::InvertedLists * { + DOWNCAST (ArrayInvertedLists) + DOWNCAST (OnDiskInvertedLists) + DOWNCAST (VStackInvertedLists) + DOWNCAST (HStackInvertedLists) + DOWNCAST (MaskedInvertedLists) + DOWNCAST (InvertedLists) + { + assert(false); + } +#ifdef SWIGLUA + SWIG_arg++; +#endif +} + +// just to downcast pointers that come from elsewhere (eg. direct +// access to object fields) +%inline %{ +faiss::Index * downcast_index (faiss::Index *index) +{ + return index; +} +faiss::VectorTransform * downcast_VectorTransform (faiss::VectorTransform *vt) +{ + return vt; +} +faiss::IndexBinary * downcast_IndexBinary (faiss::IndexBinary *index) +{ + return index; +} +faiss::InvertedLists * downcast_InvertedLists (faiss::InvertedLists *il) +{ + return il; +} +%} + +%include +%include +%include + +%newobject index_factory; +%newobject index_binary_factory; + +%include +%include +%include + + +#ifdef GPU_WRAPPER + +%include + +%newobject index_gpu_to_cpu; +%newobject index_cpu_to_gpu; +%newobject index_cpu_to_gpu_multiple; + +%include + +#endif + +// Python-specific: do not release GIL any more, as functions below +// use the Python/C API +#ifdef SWIGPYTHON +%exception; +#endif + + + + + +/******************************************************************* + * Python specific: numpy array <-> C++ pointer interface + *******************************************************************/ + +#ifdef SWIGPYTHON + +%{ +PyObject *swig_ptr (PyObject *a) +{ + if(!PyArray_Check(a)) { + PyErr_SetString(PyExc_ValueError, "input not a numpy array"); + return NULL; + } + PyArrayObject *ao = (PyArrayObject *)a; + + if(!PyArray_ISCONTIGUOUS(ao)) { + PyErr_SetString(PyExc_ValueError, "array is not C-contiguous"); + return NULL; + } + void * data = PyArray_DATA(ao); + if(PyArray_TYPE(ao) == NPY_FLOAT32) { + return SWIG_NewPointerObj(data, SWIGTYPE_p_float, 0); + } + if(PyArray_TYPE(ao) == NPY_FLOAT64) { + return SWIG_NewPointerObj(data, SWIGTYPE_p_double, 0); + } + if(PyArray_TYPE(ao) == NPY_INT32) { + return SWIG_NewPointerObj(data, SWIGTYPE_p_int, 0); + } + if(PyArray_TYPE(ao) == NPY_UINT8) { + return SWIG_NewPointerObj(data, SWIGTYPE_p_unsigned_char, 0); + } + if(PyArray_TYPE(ao) == NPY_INT8) { + return SWIG_NewPointerObj(data, SWIGTYPE_p_char, 0); + } + if(PyArray_TYPE(ao) == NPY_UINT64) { +#ifdef SWIGWORDSIZE64 + return SWIG_NewPointerObj(data, SWIGTYPE_p_unsigned_long, 0); +#else + return SWIG_NewPointerObj(data, SWIGTYPE_p_unsigned_long_long, 0); +#endif + } + if(PyArray_TYPE(ao) == NPY_INT64) { +#ifdef SWIGWORDSIZE64 + return SWIG_NewPointerObj(data, SWIGTYPE_p_long, 0); +#else + return SWIG_NewPointerObj(data, SWIGTYPE_p_long_long, 0); +#endif + } + PyErr_SetString(PyExc_ValueError, "did not recognize array type"); + return NULL; +} + + +struct PythonInterruptCallback: faiss::InterruptCallback { + + bool want_interrupt () override { + int err; + { + PyGILState_STATE gstate; + gstate = PyGILState_Ensure(); + err = PyErr_CheckSignals(); + PyGILState_Release(gstate); + } + return err == -1; + } + +}; + + +%} + + +%init %{ + /* needed, else crash at runtime */ + import_array(); + + faiss::InterruptCallback::instance.reset(new PythonInterruptCallback()); + +%} + +// return a pointer usable as input for functions that expect pointers +PyObject *swig_ptr (PyObject *a); + +%define REV_SWIG_PTR(ctype, numpytype) + +%{ +PyObject * rev_swig_ptr(ctype *src, npy_intp size) { + return PyArray_SimpleNewFromData(1, &size, numpytype, src); +} +%} + +PyObject * rev_swig_ptr(ctype *src, size_t size); + +%enddef + +REV_SWIG_PTR(float, NPY_FLOAT32); +REV_SWIG_PTR(int, NPY_INT32); +REV_SWIG_PTR(unsigned char, NPY_UINT8); +REV_SWIG_PTR(int64_t, NPY_INT64); +REV_SWIG_PTR(uint64_t, NPY_UINT64); + +#endif + + + +/******************************************************************* + * Lua specific: Torch tensor <-> C++ pointer interface + *******************************************************************/ + +#ifdef SWIGLUA + + +// provide a XXX_ptr function to convert Lua XXXTensor -> C++ XXX* + +%define TYPE_CONVERSION(ctype, tensortype) + +// typemap for the *_ptr_from_cdata function +%typemap(in) ctype** { + if(lua_type(L, $input) != 10) { + fprintf(stderr, "not cdata input\n"); + SWIG_fail; + } + $1 = (ctype**)lua_topointer(L, $input); +} + + +// SWIG and C declaration for the *_ptr_from_cdata function +%{ +ctype * ctype ## _ptr_from_cdata(ctype **x, long ofs) { + return *x + ofs; +} +%} +ctype * ctype ## _ptr_from_cdata(ctype **x, long ofs); + +// the *_ptr function +%luacode { + +function swigfaiss. ctype ## _ptr(tensor) + assert(tensor:type() == "torch." .. # tensortype, "need a " .. # tensortype) + assert(tensor:isContiguous(), "requires contiguous tensor") + return swigfaiss. ctype ## _ptr_from_cdata( + tensor:storage():data(), + tensor:storageOffset() - 1) +end + +} + +%enddef + +TYPE_CONVERSION (int, IntTensor) +TYPE_CONVERSION (float, FloatTensor) +TYPE_CONVERSION (long, LongTensor) +TYPE_CONVERSION (uint64_t, LongTensor) +TYPE_CONVERSION (uint8_t, ByteTensor) + +#endif + +/******************************************************************* + * How should the template objects apprear in the scripting language? + *******************************************************************/ + +// answer: the same as the C++ typedefs, but we still have to redefine them + +%template() faiss::CMin; +%template() faiss::CMin; +%template() faiss::CMax; +%template() faiss::CMax; + +%template(float_minheap_array_t) faiss::HeapArray >; +%template(int_minheap_array_t) faiss::HeapArray >; + +%template(float_maxheap_array_t) faiss::HeapArray >; +%template(int_maxheap_array_t) faiss::HeapArray >; + + +/******************************************************************* + * Expose a few basic functions + *******************************************************************/ + + +void omp_set_num_threads (int num_threads); +int omp_get_max_threads (); +void *memcpy(void *dest, const void *src, size_t n); + + +/******************************************************************* + * For Faiss/Pytorch interop via pointers encoded as longs + *******************************************************************/ + +%inline %{ +float * cast_integer_to_float_ptr (long x) { + return (float*)x; +} + +long * cast_integer_to_long_ptr (long x) { + return (long*)x; +} + +int * cast_integer_to_int_ptr (long x) { + return (int*)x; +} + +%} + + + +/******************************************************************* + * Range search interface + *******************************************************************/ + +%ignore faiss::BufferList::Buffer; +%ignore faiss::RangeSearchPartialResult::QueryResult; +%ignore faiss::IDSelectorBatch::set; +%ignore faiss::IDSelectorBatch::bloom; + +%ignore faiss::InterruptCallback::instance; +%ignore faiss::InterruptCallback::lock; +%include + +%{ +// may be useful for lua code launched in background from shell + +#include +void ignore_SIGTTIN() { + signal(SIGTTIN, SIG_IGN); +} +%} + +void ignore_SIGTTIN(); + + +%inline %{ + +// numpy misses a hash table implementation, hence this class. It +// represents not found values as -1 like in the Index implementation + +struct MapLong2Long { + std::unordered_map map; + + void add(size_t n, const int64_t *keys, const int64_t *vals) { + map.reserve(map.size() + n); + for (size_t i = 0; i < n; i++) { + map[keys[i]] = vals[i]; + } + } + + long search(int64_t key) { + if (map.count(key) == 0) { + return -1; + } else { + return map[key]; + } + } + + void search_multiple(size_t n, int64_t *keys, int64_t * vals) { + for (size_t i = 0; i < n; i++) { + vals[i] = search(keys[i]); + } + } +}; + +%} + +/******************************************************************* + * Support I/O to arbitrary functions + *******************************************************************/ + + +%inline %{ + +#ifdef SWIGPYTHON + + +struct PyCallbackIOWriter: faiss::IOWriter { + + PyObject * callback; + size_t bs; // maximum write size + + PyCallbackIOWriter(PyObject *callback, + size_t bs = 1024 * 1024): + callback(callback), bs(bs) { + Py_INCREF(callback); + name = "PyCallbackIOWriter"; + } + + size_t operator()(const void *ptrv, size_t size, size_t nitems) override { + size_t ws = size * nitems; + const char *ptr = (const char*)ptrv; + PyGILState_STATE gstate; + gstate = PyGILState_Ensure(); + while(ws > 0) { + size_t wi = ws > bs ? bs : ws; + PyObject* bo = PyBytes_FromStringAndSize(ptr, wi); + PyObject *arglist = Py_BuildValue("(N)", bo); + if(!arglist) { + PyGILState_Release(gstate); + return 0; + } + ptr += wi; + ws -= wi; + PyObject * result = PyObject_CallObject(callback, arglist); + Py_DECREF(arglist); + if (result == NULL) { + PyGILState_Release(gstate); + return 0; + } + Py_DECREF(result); + } + PyGILState_Release(gstate); + return nitems; + } + + ~PyCallbackIOWriter() { + Py_DECREF(callback); + } + +}; + +struct PyCallbackIOReader: faiss::IOReader { + + PyObject * callback; + size_t bs; // maximum buffer size + + PyCallbackIOReader(PyObject *callback, + size_t bs = 1024 * 1024): + callback(callback), bs(bs) { + Py_INCREF(callback); + name = "PyCallbackIOReader"; + } + + size_t operator()(void *ptrv, size_t size, size_t nitems) override { + size_t rs = size * nitems; + char *ptr = (char*)ptrv; + PyGILState_STATE gstate; + gstate = PyGILState_Ensure(); + while(rs > 0) { + size_t ri = rs > bs ? bs : rs; + PyObject *arglist = Py_BuildValue("(n)", ri); + PyObject * result = PyObject_CallObject(callback, arglist); + Py_DECREF(arglist); + if (result == NULL) { + PyGILState_Release(gstate); + return 0; + } + if(!PyBytes_Check(result)) { + Py_DECREF(result); + PyErr_SetString(PyExc_RuntimeError, + "read callback did not return a bytes object"); + PyGILState_Release(gstate); + throw faiss::FaissException("reader error"); + } + size_t sz = PyBytes_Size(result); + if (sz == 0 || sz > rs) { + Py_DECREF(result); + PyErr_Format(PyExc_RuntimeError, + "read callback returned %ld bytes (asked %ld)", + sz, rs); + PyGILState_Release(gstate); + throw faiss::FaissException("reader error"); + } + memcpy(ptr, PyBytes_AsString(result), sz); + Py_DECREF(result); + ptr += sz; + rs -= sz; + } + PyGILState_Release(gstate); + return nitems; + } + + ~PyCallbackIOReader() { + Py_DECREF(callback); + } + +}; + +#endif + +%} + + + +%inline %{ + void wait() { + // in gdb, use return to get out of this function + for(int i = 0; i == 0; i += 0); + } + %} + +// End of file... diff --git a/core/src/index/thirdparty/faiss/tests/Makefile b/core/src/index/thirdparty/faiss/tests/Makefile new file mode 100644 index 0000000000..684100de70 --- /dev/null +++ b/core/src/index/thirdparty/faiss/tests/Makefile @@ -0,0 +1,37 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +-include ../makefile.inc + +TESTS_SRC = $(wildcard *.cpp) +TESTS_OBJ = $(TESTS_SRC:.cpp=.o) + + +all: run + +run: tests + ./tests + +tests: $(TESTS_OBJ) ../libfaiss.a gtest/make/gtest_main.a + $(CXX) -o $@ $^ $(LDFLAGS) $(LIBS) + +%.o: %.cpp gtest + $(CXX) $(CPPFLAGS) $(CXXFLAGS) $(CPUFLAGS) -c -o $@ $< -Igtest/include -I.. + +gtest/make/gtest_main.a: gtest + $(MAKE) -C gtest/make CXX="$(CXX)" CXXFLAGS="$(CXXFLAGS)" gtest_main.a + +gtest: + curl -L https://github.com/google/googletest/archive/release-1.8.0.tar.gz | tar xz && \ + mv googletest-release-1.8.0/googletest gtest && \ + rm -rf googletest-release-1.8.0 + +clean: + rm -f tests + rm -f $(TESTS_OBJ) + rm -rf gtest + + +.PHONY: all clean run diff --git a/core/src/index/thirdparty/faiss/tests/common.py b/core/src/index/thirdparty/faiss/tests/common.py new file mode 100644 index 0000000000..8621dd822a --- /dev/null +++ b/core/src/index/thirdparty/faiss/tests/common.py @@ -0,0 +1,128 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# a few common functions for the tests + +from __future__ import absolute_import, division, print_function, unicode_literals + +import numpy as np +import faiss + +# reduce number of threads to avoid excessive nb of threads in opt +# mode (recuces runtime from 100s to 4s!) +faiss.omp_set_num_threads(4) + + +def random_unitary(n, d, seed): + x = faiss.randn(n * d, seed).reshape(n, d) + faiss.normalize_L2(x) + return x + + +class Randu10k: + + def __init__(self): + self.nb = 10000 + self.nq = 1000 + self.nt = 10000 + self.d = 128 + + self.xb = random_unitary(self.nb, self.d, 1) + self.xt = random_unitary(self.nt, self.d, 2) + self.xq = random_unitary(self.nq, self.d, 3) + + dotprods = np.dot(self.xq, self.xb.T) + self.gt = dotprods.argmax(1) + self.k = 100 + + def launch(self, name, index): + if not index.is_trained: + index.train(self.xt) + index.add(self.xb) + return index.search(self.xq, self.k) + + def evalres(self, DI): + D, I = DI + e = {} + for rank in 1, 10, 100: + e[rank] = ((I[:, :rank] == self.gt.reshape(-1, 1)).sum() / + float(self.nq)) + print("1-recalls: %s" % e) + return e + + +class Randu10kUnbalanced(Randu10k): + + def __init__(self): + Randu10k.__init__(self) + + weights = 0.95 ** np.arange(self.d) + rs = np.random.RandomState(123) + weights = weights[rs.permutation(self.d)] + self.xb *= weights + self.xb /= np.linalg.norm(self.xb, axis=1)[:, np.newaxis] + self.xq *= weights + self.xq /= np.linalg.norm(self.xq, axis=1)[:, np.newaxis] + self.xt *= weights + self.xt /= np.linalg.norm(self.xt, axis=1)[:, np.newaxis] + + dotprods = np.dot(self.xq, self.xb.T) + self.gt = dotprods.argmax(1) + self.k = 100 + + +def get_dataset(d, nb, nt, nq): + rs = np.random.RandomState(123) + xb = rs.rand(nb, d).astype('float32') + xt = rs.rand(nt, d).astype('float32') + xq = rs.rand(nq, d).astype('float32') + + return (xt, xb, xq) + + +def get_dataset_2(d, nt, nb, nq): + """A dataset that is not completely random but still challenging to + index + """ + d1 = 10 # intrinsic dimension (more or less) + n = nb + nt + nq + rs = np.random.RandomState(1338) + x = rs.normal(size=(n, d1)) + x = np.dot(x, rs.rand(d1, d)) + # now we have a d1-dim ellipsoid in d-dimensional space + # higher factor (>4) -> higher frequency -> less linear + x = x * (rs.rand(d) * 4 + 0.1) + x = np.sin(x) + x = x.astype('float32') + return x[:nt], x[nt:nt + nb], x[nt + nb:] + + +def make_binary_dataset(d, nt, nb, nq): + assert d % 8 == 0 + rs = np.random.RandomState(123) + x = rs.randint(256, size=(nb + nq + nt, int(d / 8))).astype('uint8') + return x[:nt], x[nt:-nq], x[-nq:] + + +def compare_binary_result_lists(D1, I1, D2, I2): + """comparing result lists is difficult because there are many + ties. Here we sort by (distance, index) pairs and ignore the largest + distance of each result. Compatible result lists should pass this.""" + assert D1.shape == I1.shape == D2.shape == I2.shape + n, k = D1.shape + ndiff = (D1 != D2).sum() + assert ndiff == 0, '%d differences in distance matrix %s' % ( + ndiff, D1.shape) + + def normalize_DI(D, I): + norm = I.max() + 1.0 + Dr = D.astype('float64') + I / norm + # ignore -1s and elements on last column + Dr[I1 == -1] = 1e20 + Dr[D == D[:, -1:]] = 1e20 + Dr.sort(axis=1) + return Dr + ndiff = (normalize_DI(D1, I1) != normalize_DI(D2, I2)).sum() + assert ndiff == 0, '%d differences in normalized D matrix' % ndiff diff --git a/core/src/index/thirdparty/faiss/tests/test_binary_factory.py b/core/src/index/thirdparty/faiss/tests/test_binary_factory.py new file mode 100644 index 0000000000..70ddbb6e99 --- /dev/null +++ b/core/src/index/thirdparty/faiss/tests/test_binary_factory.py @@ -0,0 +1,35 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import absolute_import, division, print_function + +import unittest +import faiss + + +class TestBinaryFactory(unittest.TestCase): + + def test_factory_IVF(self): + + index = faiss.index_binary_factory(16, "BIVF10") + assert index.invlists is not None + assert index.nlist == 10 + assert index.code_size == 2 + + def test_factory_Flat(self): + + index = faiss.index_binary_factory(16, "BFlat") + assert index.code_size == 2 + + def test_factory_HNSW(self): + + index = faiss.index_binary_factory(256, "BHNSW32") + assert index.code_size == 32 + + def test_factory_IVF_HNSW(self): + + index = faiss.index_binary_factory(256, "BIVF1024_BHNSW32") + assert index.code_size == 32 + assert index.nlist == 1024 diff --git a/core/src/index/thirdparty/faiss/tests/test_binary_flat.cpp b/core/src/index/thirdparty/faiss/tests/test_binary_flat.cpp new file mode 100644 index 0000000000..eb20cee87b --- /dev/null +++ b/core/src/index/thirdparty/faiss/tests/test_binary_flat.cpp @@ -0,0 +1,64 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +#include + +#include +#include + +TEST(BinaryFlat, accuracy) { + // dimension of the vectors to index + int d = 64; + + // size of the database we plan to index + size_t nb = 1000; + + // make the index object and train it + faiss::IndexBinaryFlat index(d); + + srand(35); + + std::vector database(nb * (d / 8)); + for (size_t i = 0; i < nb * (d / 8); i++) { + database[i] = rand() % 0x100; + } + + { // populating the database + index.add(nb, database.data()); + } + + size_t nq = 200; + + { // searching the database + + std::vector queries(nq * (d / 8)); + for (size_t i = 0; i < nq * (d / 8); i++) { + queries[i] = rand() % 0x100; + } + + int k = 5; + std::vector nns(k * nq); + std::vector dis(k * nq); + + index.search(nq, queries.data(), k, dis.data(), nns.data()); + + for (size_t i = 0; i < nq; ++i) { + faiss::HammingComputer8 hc(queries.data() + i * (d / 8), d / 8); + hamdis_t dist_min = hc.hamming(database.data()); + for (size_t j = 1; j < nb; ++j) { + hamdis_t dist = hc.hamming(database.data() + j * (d / 8)); + if (dist < dist_min) { + dist_min = dist; + } + } + EXPECT_EQ(dist_min, dis[k * i]); + } + } +} diff --git a/core/src/index/thirdparty/faiss/tests/test_binary_hashindex.py b/core/src/index/thirdparty/faiss/tests/test_binary_hashindex.py new file mode 100644 index 0000000000..1ee5a5f7da --- /dev/null +++ b/core/src/index/thirdparty/faiss/tests/test_binary_hashindex.py @@ -0,0 +1,183 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +#!/usr/bin/env python3 + +import unittest +import numpy as np +import faiss + +from common import make_binary_dataset + + +def bitvec_shuffle(a, order): + n, d = a.shape + db, = order.shape + b = np.empty((n, db // 8), dtype='uint8') + faiss.bitvec_shuffle( + n, d * 8, db, + faiss.swig_ptr(order), + faiss.swig_ptr(a), faiss.swig_ptr(b)) + return b + + +class TestSmallFuncs(unittest.TestCase): + + def test_shuffle(self): + d = 256 + n = 1000 + rs = np.random.RandomState(123) + o = rs.permutation(d).astype('int32') + + x = rs.randint(256, size=(n, d // 8)).astype('uint8') + + y1 = bitvec_shuffle(x, o[:128]) + y2 = bitvec_shuffle(x, o[128:]) + y = np.hstack((y1, y2)) + + oinv = np.empty(d, dtype='int32') + oinv[o] = np.arange(d) + z = bitvec_shuffle(y, oinv) + + np.testing.assert_array_equal(x, z) + + +class TestRange(unittest.TestCase): + + def test_hash(self): + d = 128 + nq = 100 + nb = 2000 + + (_, xb, xq) = make_binary_dataset(d, 0, nb, nq) + + index_ref = faiss.IndexBinaryFlat(d) + index_ref.add(xb) + + radius = 55 + + Lref, Dref, Iref = index_ref.range_search(xq, radius) + + print("nb res: ", Lref[-1]) + + index = faiss.IndexBinaryHash(d, 10) + index.add(xb) + # index.display() + nfound = [] + ndis = [] + stats = faiss.cvar.indexBinaryHash_stats + for n_bitflips in range(index.b + 1): + index.nflip = n_bitflips + stats.reset() + Lnew, Dnew, Inew = index.range_search(xq, radius) + for i in range(nq): + ref = Iref[Lref[i]:Lref[i + 1]] + new = Inew[Lnew[i]:Lnew[i + 1]] + snew = set(new) + # no duplicates + self.assertTrue(len(new) == len(snew)) + # subset of real results + self.assertTrue(snew <= set(ref)) + nfound.append(Lnew[-1]) + ndis.append(stats.ndis) + print('nfound=', nfound) + print('ndis=', ndis) + nfound = np.array(nfound) + self.assertTrue(nfound[-1] == Lref[-1]) + self.assertTrue(np.all(nfound[1:] >= nfound[:-1])) + + def test_multihash(self): + d = 128 + nq = 100 + nb = 2000 + + (_, xb, xq) = make_binary_dataset(d, 0, nb, nq) + + index_ref = faiss.IndexBinaryFlat(d) + index_ref.add(xb) + + radius = 55 + + Lref, Dref, Iref = index_ref.range_search(xq, radius) + + print("nb res: ", Lref[-1]) + + nfound = [] + ndis = [] + + for nh in 1, 3, 5: + index = faiss.IndexBinaryMultiHash(d, nh, 10) + index.add(xb) + # index.display() + stats = faiss.cvar.indexBinaryHash_stats + index.nflip = 2 + stats.reset() + Lnew, Dnew, Inew = index.range_search(xq, radius) + for i in range(nq): + ref = Iref[Lref[i]:Lref[i + 1]] + new = Inew[Lnew[i]:Lnew[i + 1]] + snew = set(new) + # no duplicates + self.assertTrue(len(new) == len(snew)) + # subset of real results + self.assertTrue(snew <= set(ref)) + nfound.append(Lnew[-1]) + ndis.append(stats.ndis) + print('nfound=', nfound) + print('ndis=', ndis) + nfound = np.array(nfound) + # self.assertTrue(nfound[-1] == Lref[-1]) + self.assertTrue(np.all(nfound[1:] >= nfound[:-1])) + + +class TestKnn(unittest.TestCase): + + def test_hash_and_multihash(self): + d = 128 + nq = 100 + nb = 2000 + + (_, xb, xq) = make_binary_dataset(d, 0, nb, nq) + + index_ref = faiss.IndexBinaryFlat(d) + index_ref.add(xb) + k = 10 + Dref, Iref = index_ref.search(xq, k) + + nfound = {} + for nh in 0, 1, 3, 5: + + for nbit in 4, 7: + if nh == 0: + index = faiss.IndexBinaryHash(d, nbit) + else: + index = faiss.IndexBinaryMultiHash(d, nh, nbit) + index.add(xb) + index.nflip = 2 + Dnew, Inew = index.search(xq, k) + nf = 0 + for i in range(nq): + ref = Iref[i] + new = Inew[i] + snew = set(new) + # no duplicates + self.assertTrue(len(new) == len(snew)) + nf += len(set(ref) & snew) + print('nfound', nh, nbit, nf) + nfound[(nh, nbit)] = nf + self.assertGreater(nfound[(nh, 4)], nfound[(nh, 7)]) + + # test serialization + index2 = faiss.deserialize_index_binary( + faiss.serialize_index_binary(index)) + + D2, I2 = index2.search(xq, k) + np.testing.assert_array_equal(Inew, I2) + np.testing.assert_array_equal(Dnew, D2) + + print('nfound=', nfound) + self.assertGreater(3, abs(nfound[(0, 7)] - nfound[(1, 7)])) + self.assertGreater(nfound[(3, 7)], nfound[(1, 7)]) + self.assertGreater(nfound[(5, 7)], nfound[(3, 7)]) diff --git a/core/src/index/thirdparty/faiss/tests/test_binary_io.py b/core/src/index/thirdparty/faiss/tests/test_binary_io.py new file mode 100644 index 0000000000..4af7dab9ca --- /dev/null +++ b/core/src/index/thirdparty/faiss/tests/test_binary_io.py @@ -0,0 +1,217 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +"""Binary indexes (de)serialization""" + +from __future__ import absolute_import, division, print_function, unicode_literals + +import numpy as np +import unittest +import faiss +import os +import tempfile + +def make_binary_dataset(d, nb, nt, nq): + assert d % 8 == 0 + x = np.random.randint(256, size=(nb + nq + nt, int(d / 8))).astype('uint8') + return x[:nt], x[nt:-nq], x[-nq:] + + +class TestBinaryFlat(unittest.TestCase): + + def __init__(self, *args, **kwargs): + unittest.TestCase.__init__(self, *args, **kwargs) + d = 32 + nt = 0 + nb = 1500 + nq = 500 + + (_, self.xb, self.xq) = make_binary_dataset(d, nb, nt, nq) + + def test_flat(self): + d = self.xq.shape[1] * 8 + + index = faiss.IndexBinaryFlat(d) + index.add(self.xb) + D, I = index.search(self.xq, 3) + + _, tmpnam = tempfile.mkstemp() + try: + faiss.write_index_binary(index, tmpnam) + + index2 = faiss.read_index_binary(tmpnam) + + D2, I2 = index2.search(self.xq, 3) + + assert (I2 == I).all() + assert (D2 == D).all() + + finally: + os.remove(tmpnam) + + +class TestBinaryIVF(unittest.TestCase): + + def __init__(self, *args, **kwargs): + unittest.TestCase.__init__(self, *args, **kwargs) + d = 32 + nt = 200 + nb = 1500 + nq = 500 + + (self.xt, self.xb, self.xq) = make_binary_dataset(d, nb, nt, nq) + + def test_ivf_flat(self): + d = self.xq.shape[1] * 8 + + quantizer = faiss.IndexBinaryFlat(d) + index = faiss.IndexBinaryIVF(quantizer, d, 8) + index.cp.min_points_per_centroid = 5 # quiet warning + index.nprobe = 4 + index.train(self.xt) + index.add(self.xb) + D, I = index.search(self.xq, 3) + + _, tmpnam = tempfile.mkstemp() + + try: + faiss.write_index_binary(index, tmpnam) + + index2 = faiss.read_index_binary(tmpnam) + + D2, I2 = index2.search(self.xq, 3) + + assert (I2 == I).all() + assert (D2 == D).all() + + finally: + os.remove(tmpnam) + + +class TestObjectOwnership(unittest.TestCase): + + def __init__(self, *args, **kwargs): + unittest.TestCase.__init__(self, *args, **kwargs) + d = 32 + nt = 200 + nb = 1500 + nq = 500 + + (self.xt, self.xb, self.xq) = make_binary_dataset(d, nb, nt, nq) + + def test_read_index_ownership(self): + d = self.xq.shape[1] * 8 + + index = faiss.IndexBinaryFlat(d) + index.add(self.xb) + + _, tmpnam = tempfile.mkstemp() + try: + faiss.write_index_binary(index, tmpnam) + + index2 = faiss.read_index_binary(tmpnam) + + assert index2.thisown + finally: + os.remove(tmpnam) + + +class TestBinaryFromFloat(unittest.TestCase): + + def __init__(self, *args, **kwargs): + unittest.TestCase.__init__(self, *args, **kwargs) + d = 32 + nt = 200 + nb = 1500 + nq = 500 + + (self.xt, self.xb, self.xq) = make_binary_dataset(d, nb, nt, nq) + + def test_binary_from_float(self): + d = self.xq.shape[1] * 8 + + float_index = faiss.IndexHNSWFlat(d, 16) + index = faiss.IndexBinaryFromFloat(float_index) + index.add(self.xb) + D, I = index.search(self.xq, 3) + + _, tmpnam = tempfile.mkstemp() + + try: + faiss.write_index_binary(index, tmpnam) + + index2 = faiss.read_index_binary(tmpnam) + + D2, I2 = index2.search(self.xq, 3) + + assert (I2 == I).all() + assert (D2 == D).all() + + finally: + os.remove(tmpnam) + + +class TestBinaryHNSW(unittest.TestCase): + + def __init__(self, *args, **kwargs): + unittest.TestCase.__init__(self, *args, **kwargs) + d = 32 + nt = 200 + nb = 1500 + nq = 500 + + (self.xt, self.xb, self.xq) = make_binary_dataset(d, nb, nt, nq) + + def test_hnsw(self): + d = self.xq.shape[1] * 8 + + index = faiss.IndexBinaryHNSW(d) + index.add(self.xb) + D, I = index.search(self.xq, 3) + + _, tmpnam = tempfile.mkstemp() + + try: + faiss.write_index_binary(index, tmpnam) + + index2 = faiss.read_index_binary(tmpnam) + + D2, I2 = index2.search(self.xq, 3) + + assert (I2 == I).all() + assert (D2 == D).all() + + finally: + os.remove(tmpnam) + + def test_ivf_hnsw(self): + d = self.xq.shape[1] * 8 + + quantizer = faiss.IndexBinaryHNSW(d) + index = faiss.IndexBinaryIVF(quantizer, d, 8) + index.cp.min_points_per_centroid = 5 # quiet warning + index.nprobe = 4 + index.train(self.xt) + index.add(self.xb) + D, I = index.search(self.xq, 3) + + _, tmpnam = tempfile.mkstemp() + + try: + faiss.write_index_binary(index, tmpnam) + + index2 = faiss.read_index_binary(tmpnam) + + D2, I2 = index2.search(self.xq, 3) + + assert (I2 == I).all() + assert (D2 == D).all() + + finally: + os.remove(tmpnam) + + +if __name__ == '__main__': + unittest.main() diff --git a/core/src/index/thirdparty/faiss/tests/test_build_blocks.py b/core/src/index/thirdparty/faiss/tests/test_build_blocks.py new file mode 100644 index 0000000000..d1ce73cd1b --- /dev/null +++ b/core/src/index/thirdparty/faiss/tests/test_build_blocks.py @@ -0,0 +1,600 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import absolute_import, division, print_function + +import numpy as np + +import faiss +import unittest + +from common import get_dataset_2 + + + +class TestClustering(unittest.TestCase): + + def test_clustering(self): + d = 64 + n = 1000 + rs = np.random.RandomState(123) + x = rs.uniform(size=(n, d)).astype('float32') + + x *= 10 + + km = faiss.Kmeans(d, 32, niter=10) + err32 = km.train(x) + + # check that objective is decreasing + prev = 1e50 + for o in km.obj: + self.assertGreater(prev, o) + prev = o + + km = faiss.Kmeans(d, 64, niter=10) + err64 = km.train(x) + + # check that 64 centroids give a lower quantization error than 32 + self.assertGreater(err32, err64) + + km = faiss.Kmeans(d, 32, niter=10, int_centroids=True) + err_int = km.train(x) + + # check that integer centoids are not as good as float ones + self.assertGreater(err_int, err32) + self.assertTrue(np.all(km.centroids == np.floor(km.centroids))) + + + def test_nasty_clustering(self): + d = 2 + rs = np.random.RandomState(123) + x = np.zeros((100, d), dtype='float32') + for i in range(5): + x[i * 20:i * 20 + 20] = rs.uniform(size=d) + + # we have 5 distinct points but ask for 10 centroids... + km = faiss.Kmeans(d, 10, niter=10, verbose=True) + km.train(x) + + def test_redo(self): + d = 64 + n = 1000 + + rs = np.random.RandomState(123) + x = rs.uniform(size=(n, d)).astype('float32') + + # make sure that doing 10 redos yields a better objective than just 1 + + clus = faiss.Clustering(d, 20) + clus.nredo = 1 + clus.train(x, faiss.IndexFlatL2(d)) + obj1 = clus.iteration_stats.at(clus.iteration_stats.size() - 1).obj + + clus = faiss.Clustering(d, 20) + clus.nredo = 10 + clus.train(x, faiss.IndexFlatL2(d)) + obj10 = clus.iteration_stats.at(clus.iteration_stats.size() - 1).obj + + self.assertGreater(obj1, obj10) + + def test_1ptpercluster(self): + # https://github.com/facebookresearch/faiss/issues/842 + X = np.random.randint(0, 1, (5, 10)).astype('float32') + k = 5 + niter = 10 + verbose = True + kmeans = faiss.Kmeans(X.shape[1], k, niter=niter, verbose=verbose) + kmeans.train(X) + l2_distances, I = kmeans.index.search(X, 1) + + def test_weighted(self): + d = 32 + sigma = 0.1 + + # Data is naturally clustered in 10 clusters. + # 5 clusters have 100 points + # 5 clusters have 10 points + # run k-means with 5 clusters + + ccent = faiss.randn((10, d), 123) + faiss.normalize_L2(ccent) + x = [ccent[i] + sigma * faiss.randn((100, d), 1234 + i) for i in range(5)] + x += [ccent[i] + sigma * faiss.randn((10, d), 1234 + i) for i in range(5, 10)] + x = np.vstack(x) + + clus = faiss.Clustering(d, 5) + index = faiss.IndexFlatL2(d) + clus.train(x, index) + cdis1, perm1 = index.search(ccent, 1) + + # distance^2 of ground-truth centroids to clusters + cdis1_first = cdis1[:5].sum() + cdis1_last = cdis1[5:].sum() + + # now assign weight 0.1 to the 5 first clusters and weight 10 + # to the 5 last ones and re-run k-means + weights = np.ones(100 * 5 + 10 * 5, dtype='float32') + weights[:100 * 5] = 0.1 + weights[100 * 5:] = 10 + + clus = faiss.Clustering(d, 5) + index = faiss.IndexFlatL2(d) + clus.train(x, index, weights=weights) + cdis2, perm2 = index.search(ccent, 1) + + # distance^2 of ground-truth centroids to clusters + cdis2_first = cdis2[:5].sum() + cdis2_last = cdis2[5:].sum() + + print(cdis1_first, cdis1_last) + print(cdis2_first, cdis2_last) + + # with the new clustering, the last should be much (*2) closer + # to their centroids + self.assertGreater(cdis1_last, cdis1_first * 2) + self.assertGreater(cdis2_first, cdis2_last * 2) + + def test_encoded(self): + d = 32 + k = 5 + xt, xb, xq = get_dataset_2(d, 1000, 0, 0) + + # make sure that training on a compressed then decompressed + # dataset gives the same result as decompressing on-the-fly + + codec = faiss.IndexScalarQuantizer(d, faiss.ScalarQuantizer.QT_4bit) + codec.train(xt) + codes = codec.sa_encode(xt) + + xt2 = codec.sa_decode(codes) + + clus = faiss.Clustering(d, k) + # clus.verbose = True + clus.niter = 0 + index = faiss.IndexFlatL2(d) + clus.train(xt2, index) + ref_centroids = faiss.vector_to_array(clus.centroids).reshape(-1, d) + + _, ref_errs = index.search(xt2, 1) + + clus = faiss.Clustering(d, k) + # clus.verbose = True + clus.niter = 0 + clus.decode_block_size = 120 + index = faiss.IndexFlatL2(d) + clus.train_encoded(codes, codec, index) + new_centroids = faiss.vector_to_array(clus.centroids).reshape(-1, d) + + _, new_errs = index.search(xt2, 1) + + # It's the same operation, so should be bit-exact the same + self.assertTrue(np.all(ref_centroids == new_centroids)) + + +class TestPCA(unittest.TestCase): + + def test_pca(self): + d = 64 + n = 1000 + np.random.seed(123) + x = np.random.random(size=(n, d)).astype('float32') + + pca = faiss.PCAMatrix(d, 10) + pca.train(x) + y = pca.apply_py(x) + + # check that energy per component is decreasing + column_norm2 = (y**2).sum(0) + + prev = 1e50 + for o in column_norm2: + self.assertGreater(prev, o) + prev = o + + +class TestProductQuantizer(unittest.TestCase): + + def test_pq(self): + d = 64 + n = 2000 + cs = 4 + np.random.seed(123) + x = np.random.random(size=(n, d)).astype('float32') + pq = faiss.ProductQuantizer(d, cs, 8) + pq.train(x) + codes = pq.compute_codes(x) + x2 = pq.decode(codes) + diff = ((x - x2)**2).sum() + + # print("diff=", diff) + # diff= 4418.0562 + self.assertGreater(5000, diff) + + pq10 = faiss.ProductQuantizer(d, cs, 10) + assert pq10.code_size == 5 + pq10.verbose = True + pq10.cp.verbose = True + pq10.train(x) + codes = pq10.compute_codes(x) + + x10 = pq10.decode(codes) + diff10 = ((x - x10)**2).sum() + self.assertGreater(diff, diff10) + + def do_test_codec(self, nbit): + pq = faiss.ProductQuantizer(16, 2, nbit) + + # simulate training + rs = np.random.RandomState(123) + centroids = rs.rand(2, 1 << nbit, 8).astype('float32') + faiss.copy_array_to_vector(centroids.ravel(), pq.centroids) + + idx = rs.randint(1 << nbit, size=(100, 2)) + # can be encoded exactly + x = np.hstack(( + centroids[0, idx[:, 0]], + centroids[1, idx[:, 1]] + )) + + # encode / decode + codes = pq.compute_codes(x) + xr = pq.decode(codes) + assert np.all(xr == x) + + # encode w/ external index + assign_index = faiss.IndexFlatL2(8) + pq.assign_index = assign_index + codes2 = np.empty((100, pq.code_size), dtype='uint8') + pq.compute_codes_with_assign_index( + faiss.swig_ptr(x), faiss.swig_ptr(codes2), 100) + assert np.all(codes == codes2) + + def test_codec(self): + for i in range(16): + print("Testing nbits=%d" % (i + 1)) + self.do_test_codec(i + 1) + + +class TestRevSwigPtr(unittest.TestCase): + + def test_rev_swig_ptr(self): + + index = faiss.IndexFlatL2(4) + xb0 = np.vstack([ + i * 10 + np.array([1, 2, 3, 4], dtype='float32') + for i in range(5)]) + index.add(xb0) + xb = faiss.rev_swig_ptr(index.xb.data(), 4 * 5).reshape(5, 4) + self.assertEqual(np.abs(xb0 - xb).sum(), 0) + + +class TestException(unittest.TestCase): + + def test_exception(self): + + index = faiss.IndexFlatL2(10) + + a = np.zeros((5, 10), dtype='float32') + b = np.zeros(5, dtype='int64') + + # an unsupported operation for IndexFlat + self.assertRaises( + RuntimeError, + index.add_with_ids, a, b + ) + # assert 'add_with_ids not implemented' in str(e) + + def test_exception_2(self): + self.assertRaises( + RuntimeError, + faiss.index_factory, 12, 'IVF256,Flat,PQ8' + ) + # assert 'could not parse' in str(e) + + +class TestMapLong2Long(unittest.TestCase): + + def test_maplong2long(self): + keys = np.array([13, 45, 67]) + vals = np.array([3, 8, 2]) + + m = faiss.MapLong2Long() + m.add(keys, vals) + + assert np.all(m.search_multiple(keys) == vals) + + assert m.search(12343) == -1 + + +class TestOrthognalReconstruct(unittest.TestCase): + + def test_recons_orthonormal(self): + lt = faiss.LinearTransform(20, 10, True) + rs = np.random.RandomState(10) + A, _ = np.linalg.qr(rs.randn(20, 20)) + A = A[:10].astype('float32') + faiss.copy_array_to_vector(A.ravel(), lt.A) + faiss.copy_array_to_vector(rs.randn(10).astype('float32'), lt.b) + + lt.set_is_orthonormal() + lt.is_trained = True + assert lt.is_orthonormal + + x = rs.rand(30, 20).astype('float32') + xt = lt.apply_py(x) + xtt = lt.reverse_transform(xt) + xttt = lt.apply_py(xtt) + + err = ((xt - xttt)**2).sum() + + self.assertGreater(1e-5, err) + + def test_recons_orthogona_impossible(self): + lt = faiss.LinearTransform(20, 10, True) + rs = np.random.RandomState(10) + A = rs.randn(10 * 20).astype('float32') + faiss.copy_array_to_vector(A.ravel(), lt.A) + faiss.copy_array_to_vector(rs.randn(10).astype('float32'), lt.b) + lt.is_trained = True + + lt.set_is_orthonormal() + assert not lt.is_orthonormal + + x = rs.rand(30, 20).astype('float32') + xt = lt.apply_py(x) + try: + lt.reverse_transform(xt) + except Exception: + pass + else: + self.assertFalse('should do an exception') + + +class TestMAdd(unittest.TestCase): + + def test_1(self): + # try with dimensions that are multiples of 16 or not + rs = np.random.RandomState(123) + swig_ptr = faiss.swig_ptr + for dim in 16, 32, 20, 25: + for _repeat in 1, 2, 3, 4, 5: + a = rs.rand(dim).astype('float32') + b = rs.rand(dim).astype('float32') + c = np.zeros(dim, dtype='float32') + bf = rs.uniform(5.0) - 2.5 + idx = faiss.fvec_madd_and_argmin( + dim, swig_ptr(a), bf, swig_ptr(b), + swig_ptr(c)) + ref_c = a + b * bf + assert np.abs(c - ref_c).max() < 1e-5 + assert idx == ref_c.argmin() + + +class TestNyFuncs(unittest.TestCase): + + def test_l2(self): + rs = np.random.RandomState(123) + swig_ptr = faiss.swig_ptr + for d in 1, 2, 4, 8, 12, 16: + x = rs.rand(d).astype('float32') + for ny in 128, 129, 130: + print("d=%d ny=%d" % (d, ny)) + y = rs.rand(ny, d).astype('float32') + ref = ((x - y) ** 2).sum(1) + new = np.zeros(ny, dtype='float32') + faiss.fvec_L2sqr_ny(swig_ptr(new), swig_ptr(x), + swig_ptr(y), d, ny) + assert np.abs(ref - new).max() < 1e-4 + + def test_IP(self): + # this one is not optimized with SIMD but just in case + rs = np.random.RandomState(123) + swig_ptr = faiss.swig_ptr + for d in 1, 2, 4, 8, 12, 16: + x = rs.rand(d).astype('float32') + for ny in 128, 129, 130: + print("d=%d ny=%d" % (d, ny)) + y = rs.rand(ny, d).astype('float32') + ref = (x * y).sum(1) + new = np.zeros(ny, dtype='float32') + faiss.fvec_inner_products_ny( + swig_ptr(new), swig_ptr(x), swig_ptr(y), d, ny) + assert np.abs(ref - new).max() < 1e-4 + + +class TestMatrixStats(unittest.TestCase): + + def test_0s(self): + rs = np.random.RandomState(123) + m = rs.rand(40, 20).astype('float32') + m[5:10] = 0 + comments = faiss.MatrixStats(m).comments + print(comments) + assert 'has 5 copies' in comments + assert '5 null vectors' in comments + + def test_copies(self): + rs = np.random.RandomState(123) + m = rs.rand(40, 20).astype('float32') + m[::2] = m[1::2] + comments = faiss.MatrixStats(m).comments + print(comments) + assert '20 vectors are distinct' in comments + + def test_dead_dims(self): + rs = np.random.RandomState(123) + m = rs.rand(40, 20).astype('float32') + m[:, 5:10] = 0 + comments = faiss.MatrixStats(m).comments + print(comments) + assert '5 dimensions are constant' in comments + + def test_rogue_means(self): + rs = np.random.RandomState(123) + m = rs.rand(40, 20).astype('float32') + m[:, 5:10] += 12345 + comments = faiss.MatrixStats(m).comments + print(comments) + assert '5 dimensions are too large wrt. their variance' in comments + + def test_normalized(self): + rs = np.random.RandomState(123) + m = rs.rand(40, 20).astype('float32') + faiss.normalize_L2(m) + comments = faiss.MatrixStats(m).comments + print(comments) + assert 'vectors are normalized' in comments + + +class TestScalarQuantizer(unittest.TestCase): + + def test_8bit_equiv(self): + rs = np.random.RandomState(123) + for _it in range(20): + for d in 13, 16, 24: + x = np.floor(rs.rand(5, d) * 256).astype('float32') + x[0] = 0 + x[1] = 255 + + # make sure to test extreme cases + x[2, 0] = 0 + x[3, 0] = 255 + x[2, 1] = 255 + x[3, 1] = 0 + + ref_index = faiss.IndexScalarQuantizer( + d, faiss.ScalarQuantizer.QT_8bit) + ref_index.train(x[:2]) + ref_index.add(x[2:3]) + + index = faiss.IndexScalarQuantizer( + d, faiss.ScalarQuantizer.QT_8bit_direct) + assert index.is_trained + index.add(x[2:3]) + + assert np.all( + faiss.vector_to_array(ref_index.codes) == + faiss.vector_to_array(index.codes)) + + # Note that distances are not the same because ref_index + # reconstructs x as x + 0.5 + D, I = index.search(x[3:], 1) + + # assert D[0, 0] == Dref[0, 0] + print(D[0, 0], ((x[3] - x[2]) ** 2).sum()) + assert D[0, 0] == ((x[3] - x[2]) ** 2).sum() + + def test_6bit_equiv(self): + rs = np.random.RandomState(123) + for d in 3, 6, 8, 16, 36: + trainset = np.zeros((2, d), dtype='float32') + trainset[0, :] = 0 + trainset[0, :] = 63 + + index = faiss.IndexScalarQuantizer( + d, faiss.ScalarQuantizer.QT_6bit) + index.train(trainset) + + print('cs=', index.code_size) + + x = rs.randint(64, size=(100, d)).astype('float32') + + # verify encoder / decoder + index.add(x) + x2 = index.reconstruct_n(0, x.shape[0]) + assert np.all(x == x2 - 0.5) + + # verify AVX decoder (used only for search) + y = 63 * rs.rand(20, d).astype('float32') + + D, I = index.search(y, 10) + for i in range(20): + for j in range(10): + dis = ((y[i] - x2[I[i, j]]) ** 2).sum() + print(dis, D[i, j]) + assert abs(D[i, j] - dis) / dis < 1e-5 + +class TestRandom(unittest.TestCase): + + def test_rand(self): + x = faiss.rand(2000) + assert np.all(x >= 0) and np.all(x < 1) + h, _ = np.histogram(x, np.arange(0, 1, 0.1)) + assert h.min() > 160 and h.max() < 240 + + def test_randint(self): + x = faiss.randint(20000, vmax=100) + assert np.all(x >= 0) and np.all(x < 100) + c = np.bincount(x, minlength=100) + print(c) + assert c.max() - c.min() < 50 * 2 + + +class TestPairwiseDis(unittest.TestCase): + + def test_L2(self): + swig_ptr = faiss.swig_ptr + x = faiss.rand((100, 10), seed=1) + y = faiss.rand((200, 10), seed=2) + ix = faiss.randint(50, vmax=100) + iy = faiss.randint(50, vmax=200) + dis = np.empty(50, dtype='float32') + faiss.pairwise_indexed_L2sqr( + 10, 50, + swig_ptr(x), swig_ptr(ix), + swig_ptr(y), swig_ptr(iy), + swig_ptr(dis)) + + for i in range(50): + assert np.allclose( + dis[i], ((x[ix[i]] - y[iy[i]]) ** 2).sum()) + + def test_IP(self): + swig_ptr = faiss.swig_ptr + x = faiss.rand((100, 10), seed=1) + y = faiss.rand((200, 10), seed=2) + ix = faiss.randint(50, vmax=100) + iy = faiss.randint(50, vmax=200) + dis = np.empty(50, dtype='float32') + faiss.pairwise_indexed_inner_product( + 10, 50, + swig_ptr(x), swig_ptr(ix), + swig_ptr(y), swig_ptr(iy), + swig_ptr(dis)) + + for i in range(50): + assert np.allclose( + dis[i], np.dot(x[ix[i]], y[iy[i]])) + + +class TestSWIGWrap(unittest.TestCase): + """ various regressions with the SWIG wrapper """ + + def test_size_t_ptr(self): + # issue 1064 + index = faiss.IndexHNSWFlat(10, 32) + + hnsw = index.hnsw + index.add(np.random.rand(100, 10).astype('float32')) + be = np.empty(2, 'uint64') + hnsw.neighbor_range(23, 0, faiss.swig_ptr(be), faiss.swig_ptr(be[1:])) + + def test_id_map_at(self): + # issue 1020 + n_features = 100 + feature_dims = 10 + + features = np.random.random((n_features, feature_dims)).astype(np.float32) + idx = np.arange(n_features).astype(np.int64) + + index = faiss.IndexFlatL2(feature_dims) + index = faiss.IndexIDMap2(index) + index.add_with_ids(features, idx) + + [index.id_map.at(int(i)) for i in range(index.ntotal)] + + +if __name__ == '__main__': + unittest.main() diff --git a/core/src/index/thirdparty/faiss/tests/test_dealloc_invlists.cpp b/core/src/index/thirdparty/faiss/tests/test_dealloc_invlists.cpp new file mode 100644 index 0000000000..d77cd242ac --- /dev/null +++ b/core/src/index/thirdparty/faiss/tests/test_dealloc_invlists.cpp @@ -0,0 +1,183 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +#include +#include + +#include + +#include +#include +#include +#include +#include + +using namespace faiss; + +namespace { + +typedef Index::idx_t idx_t; + + +// dimension of the vectors to index +int d = 32; + +// nb of training vectors +size_t nt = 5000; + +// size of the database points per window step +size_t nb = 1000; + +// nb of queries +size_t nq = 200; + + +std::vector make_data(size_t n) +{ + std::vector database (n * d); + for (size_t i = 0; i < n * d; i++) { + database[i] = drand48(); + } + return database; +} + +std::unique_ptr make_trained_index(const char *index_type) +{ + auto index = std::unique_ptr(index_factory(d, index_type)); + auto xt = make_data(nt * d); + index->train(nt, xt.data()); + ParameterSpace().set_index_parameter (index.get(), "nprobe", 4); + return index; +} + +std::vector search_index(Index *index, const float *xq) { + int k = 10; + std::vector I(k * nq); + std::vector D(k * nq); + index->search (nq, xq, k, D.data(), I.data()); + return I; +} + + + + + +/************************************************************* + * Test functions for a given index type + *************************************************************/ + +struct EncapsulateInvertedLists: InvertedLists { + + const InvertedLists *il; + + EncapsulateInvertedLists(const InvertedLists *il): + InvertedLists(il->nlist, il->code_size), + il(il) + {} + + static void * memdup (const void *m, size_t size) { + if (size == 0) return nullptr; + return memcpy (malloc(size), m, size); + } + + size_t list_size(size_t list_no) const override { + return il->list_size (list_no); + } + + const uint8_t * get_codes (size_t list_no) const override { + return (uint8_t*)memdup (il->get_codes(list_no), + list_size(list_no) * code_size); + } + + const idx_t * get_ids (size_t list_no) const override { + return (idx_t*)memdup (il->get_ids(list_no), + list_size(list_no) * sizeof(idx_t)); + } + + void release_codes (size_t, const uint8_t *codes) const override { + free ((void*)codes); + } + + void release_ids (size_t, const idx_t *ids) const override { + free ((void*)ids); + } + + const uint8_t * get_single_code (size_t list_no, size_t offset) + const override { + return (uint8_t*)memdup (il->get_single_code(list_no, offset), + code_size); + } + + size_t add_entries(size_t, size_t, const idx_t*, const uint8_t*) override { + assert(!"not implemented"); + return 0; + } + + void update_entries(size_t, size_t, size_t, const idx_t*, const uint8_t*) + override { + assert(!"not implemented"); + } + + void resize(size_t, size_t) override { + assert(!"not implemented"); + } + + ~EncapsulateInvertedLists() override {} +}; + + + +int test_dealloc_invlists (const char *index_key) { + + std::unique_ptr index = make_trained_index(index_key); + IndexIVF * index_ivf = ivflib::extract_index_ivf (index.get()); + + auto xb = make_data (nb * d); + index->add(nb, xb.data()); + + auto xq = make_data (nq * d); + + auto ref_res = search_index (index.get(), xq.data()); + + EncapsulateInvertedLists eil(index_ivf->invlists); + + index_ivf->own_invlists = false; + index_ivf->replace_invlists (&eil, false); + + // TEST: this could crash or leak mem + auto new_res = search_index (index.get(), xq.data()); + + // delete explicitly + delete eil.il; + + // just to make sure + EXPECT_EQ (ref_res, new_res); + return 0; +} + +} // anonymous namespace + + + +/************************************************************* + * Test entry points + *************************************************************/ + +TEST(TestIvlistDealloc, IVFFlat) { + test_dealloc_invlists ("IVF32,Flat"); +} + +TEST(TestIvlistDealloc, IVFSQ) { + test_dealloc_invlists ("IVF32,SQ8"); +} + +TEST(TestIvlistDealloc, IVFPQ) { + test_dealloc_invlists ("IVF32,PQ4np"); +} diff --git a/core/src/index/thirdparty/faiss/tests/test_extra_distances.py b/core/src/index/thirdparty/faiss/tests/test_extra_distances.py new file mode 100644 index 0000000000..3977075879 --- /dev/null +++ b/core/src/index/thirdparty/faiss/tests/test_extra_distances.py @@ -0,0 +1,143 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +#! /usr/bin/env python2 +# noqa E741 + +import numpy as np + +import faiss +import unittest + +from common import get_dataset_2 + +import scipy.spatial.distance + + +class TestExtraDistances(unittest.TestCase): + """ check wrt. the scipy implementation """ + + def make_example(self): + rs = np.random.RandomState(123) + x = rs.rand(5, 32).astype('float32') + y = rs.rand(3, 32).astype('float32') + return x, y + + def run_simple_dis_test(self, ref_func, metric_type): + xq, yb = self.make_example() + ref_dis = np.array([ + [ref_func(x, y) for y in yb] + for x in xq + ]) + new_dis = faiss.pairwise_distances(xq, yb, metric_type) + self.assertTrue(np.allclose(ref_dis, new_dis)) + + def test_L1(self): + self.run_simple_dis_test(scipy.spatial.distance.cityblock, + faiss.METRIC_L1) + + def test_Linf(self): + self.run_simple_dis_test(scipy.spatial.distance.chebyshev, + faiss.METRIC_Linf) + + def test_L2(self): + xq, yb = self.make_example() + ref_dis = np.array([ + [scipy.spatial.distance.sqeuclidean(x, y) for y in yb] + for x in xq + ]) + new_dis = faiss.pairwise_distances(xq, yb, faiss.METRIC_L2) + self.assertTrue(np.allclose(ref_dis, new_dis)) + + ref_dis = np.array([ + [scipy.spatial.distance.euclidean(x, y) for y in yb] + for x in xq + ]) + new_dis = np.sqrt(new_dis) # post processing + self.assertTrue(np.allclose(ref_dis, new_dis)) + + def test_Lp(self): + p = 1.5 + xq, yb = self.make_example() + ref_dis = np.array([ + [scipy.spatial.distance.minkowski(x, y, p) for y in yb] + for x in xq + ]) + new_dis = faiss.pairwise_distances(xq, yb, faiss.METRIC_Lp, p) + new_dis = new_dis ** (1 / p) # post processing + self.assertTrue(np.allclose(ref_dis, new_dis)) + + def test_canberra(self): + self.run_simple_dis_test(scipy.spatial.distance.canberra, + faiss.METRIC_Canberra) + + def test_braycurtis(self): + self.run_simple_dis_test(scipy.spatial.distance.braycurtis, + faiss.METRIC_BrayCurtis) + + def xx_test_jensenshannon(self): + # this distance does not seem to be implemented in scipy + # vectors should probably be L1 normalized + self.run_simple_dis_test(scipy.spatial.distance.jensenshannon, + faiss.METRIC_JensenShannon) + + +class TestKNN(unittest.TestCase): + """ test that the knn search gives the same as distance matrix + argmin """ + + def do_test_knn(self, mt): + d = 10 + nb = 100 + nq = 50 + nt = 0 + xt, xb, xq = get_dataset_2(d, nt, nb, nq) + + index = faiss.IndexFlat(d, mt) + index.add(xb) + + D, I = index.search(xq, 10) + + dis = faiss.pairwise_distances(xq, xb, mt) + o = dis.argsort(axis=1) + assert np.all(I == o[:, :10]) + + for q in range(nq): + assert np.all(D[q] == dis[q, I[q]]) + + index2 = faiss.deserialize_index(faiss.serialize_index(index)) + + D2, I2 = index2.search(xq, 10) + + self.assertTrue(np.all(I == I2)) + + def test_L1(self): + self.do_test_knn(faiss.METRIC_L1) + + def test_Linf(self): + self.do_test_knn(faiss.METRIC_Linf) + + +class TestHNSW(unittest.TestCase): + """ since it has a distance computer, HNSW should work """ + + def test_hnsw(self): + + d = 10 + nb = 1000 + nq = 100 + nt = 0 + xt, xb, xq = get_dataset_2(d, nt, nb, nq) + + mt = faiss.METRIC_L1 + + index = faiss.IndexHNSW(faiss.IndexFlat(d, mt)) + index.add(xb) + + D, I = index.search(xq, 10) + + dis = faiss.pairwise_distances(xq, xb, mt) + + for q in range(nq): + assert np.all(D[q] == dis[q, I[q]]) diff --git a/core/src/index/thirdparty/faiss/tests/test_factory.py b/core/src/index/thirdparty/faiss/tests/test_factory.py new file mode 100644 index 0000000000..e08b0ca850 --- /dev/null +++ b/core/src/index/thirdparty/faiss/tests/test_factory.py @@ -0,0 +1,64 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import absolute_import, division, print_function + +import numpy as np +import unittest +import faiss + + +class TestFactory(unittest.TestCase): + + def test_factory_1(self): + + index = faiss.index_factory(12, "IVF10,PQ4") + assert index.do_polysemous_training + + index = faiss.index_factory(12, "IVF10,PQ4np") + assert not index.do_polysemous_training + + index = faiss.index_factory(12, "PQ4") + assert index.do_polysemous_training + + index = faiss.index_factory(12, "PQ4np") + assert not index.do_polysemous_training + + try: + index = faiss.index_factory(10, "PQ4") + except RuntimeError: + pass + else: + assert False, "should do a runtime error" + + def test_factory_2(self): + + index = faiss.index_factory(12, "SQ8") + assert index.code_size == 12 + + def test_factory_3(self): + + index = faiss.index_factory(12, "IVF10,PQ4") + faiss.ParameterSpace().set_index_parameter(index, "nprobe", 3) + assert index.nprobe == 3 + + index = faiss.index_factory(12, "PCAR8,IVF10,PQ4") + faiss.ParameterSpace().set_index_parameter(index, "nprobe", 3) + assert faiss.downcast_index(index.index).nprobe == 3 + + def test_factory_4(self): + index = faiss.index_factory(12, "IVF10,FlatDedup") + assert index.instances is not None + + +class TestCloneSize(unittest.TestCase): + + def test_clone_size(self): + index = faiss.index_factory(20, 'PCA10,Flat') + xb = faiss.rand((100, 20)) + index.train(xb) + index.add(xb) + index2 = faiss.clone_index(index) + assert index2.ntotal == 100 diff --git a/core/src/index/thirdparty/faiss/tests/test_index.py b/core/src/index/thirdparty/faiss/tests/test_index.py new file mode 100644 index 0000000000..c41f7f8c0b --- /dev/null +++ b/core/src/index/thirdparty/faiss/tests/test_index.py @@ -0,0 +1,639 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +"""this is a basic test script for simple indices work""" +from __future__ import absolute_import, division, print_function +# no unicode_literals because it messes up in py2 + +import numpy as np +import unittest +import faiss +import tempfile +import os +import re +import warnings + +from common import get_dataset, get_dataset_2 + +class TestModuleInterface(unittest.TestCase): + + def test_version_attribute(self): + assert hasattr(faiss, '__version__') + assert re.match('^\\d+\\.\\d+\\.\\d+$', faiss.__version__) + + +class EvalIVFPQAccuracy(unittest.TestCase): + + def test_IndexIVFPQ(self): + d = 32 + nb = 1000 + nt = 1500 + nq = 200 + + (xt, xb, xq) = get_dataset_2(d, nt, nb, nq) + d = xt.shape[1] + + gt_index = faiss.IndexFlatL2(d) + gt_index.add(xb) + D, gt_nns = gt_index.search(xq, 1) + + coarse_quantizer = faiss.IndexFlatL2(d) + index = faiss.IndexIVFPQ(coarse_quantizer, d, 32, 8, 8) + index.cp.min_points_per_centroid = 5 # quiet warning + index.train(xt) + index.add(xb) + index.nprobe = 4 + D, nns = index.search(xq, 10) + n_ok = (nns == gt_nns).sum() + nq = xq.shape[0] + + self.assertGreater(n_ok, nq * 0.66) + + # check that and Index2Layer gives the same reconstruction + # this is a bit fragile: it assumes 2 runs of training give + # the exact same result. + index2 = faiss.Index2Layer(coarse_quantizer, 32, 8) + if True: + index2.train(xt) + else: + index2.pq = index.pq + index2.is_trained = True + index2.add(xb) + ref_recons = index.reconstruct_n(0, nb) + new_recons = index2.reconstruct_n(0, nb) + self.assertTrue(np.all(ref_recons == new_recons)) + + + def test_IMI(self): + d = 32 + nb = 1000 + nt = 1500 + nq = 200 + + (xt, xb, xq) = get_dataset_2(d, nt, nb, nq) + d = xt.shape[1] + + gt_index = faiss.IndexFlatL2(d) + gt_index.add(xb) + D, gt_nns = gt_index.search(xq, 1) + + nbits = 5 + coarse_quantizer = faiss.MultiIndexQuantizer(d, 2, nbits) + index = faiss.IndexIVFPQ(coarse_quantizer, d, (1 << nbits) ** 2, 8, 8) + index.quantizer_trains_alone = 1 + index.train(xt) + index.add(xb) + index.nprobe = 100 + D, nns = index.search(xq, 10) + n_ok = (nns == gt_nns).sum() + + # Should return 166 on mac, and 170 on linux. + self.assertGreater(n_ok, 165) + + ############# replace with explicit assignment indexes + nbits = 5 + pq = coarse_quantizer.pq + centroids = faiss.vector_to_array(pq.centroids) + centroids = centroids.reshape(pq.M, pq.ksub, pq.dsub) + ai0 = faiss.IndexFlatL2(pq.dsub) + ai0.add(centroids[0]) + ai1 = faiss.IndexFlatL2(pq.dsub) + ai1.add(centroids[1]) + + coarse_quantizer_2 = faiss.MultiIndexQuantizer2(d, nbits, ai0, ai1) + coarse_quantizer_2.pq = pq + coarse_quantizer_2.is_trained = True + + index.quantizer = coarse_quantizer_2 + + index.reset() + index.add(xb) + + D, nns = index.search(xq, 10) + n_ok = (nns == gt_nns).sum() + + # should return the same result + self.assertGreater(n_ok, 165) + + + def test_IMI_2(self): + d = 32 + nb = 1000 + nt = 1500 + nq = 200 + + (xt, xb, xq) = get_dataset_2(d, nt, nb, nq) + d = xt.shape[1] + + gt_index = faiss.IndexFlatL2(d) + gt_index.add(xb) + D, gt_nns = gt_index.search(xq, 1) + + ############# redo including training + nbits = 5 + ai0 = faiss.IndexFlatL2(int(d / 2)) + ai1 = faiss.IndexFlatL2(int(d / 2)) + + coarse_quantizer = faiss.MultiIndexQuantizer2(d, nbits, ai0, ai1) + index = faiss.IndexIVFPQ(coarse_quantizer, d, (1 << nbits) ** 2, 8, 8) + index.quantizer_trains_alone = 1 + index.train(xt) + index.add(xb) + index.nprobe = 100 + D, nns = index.search(xq, 10) + n_ok = (nns == gt_nns).sum() + + # should return the same result + self.assertGreater(n_ok, 165) + + + + + +class TestMultiIndexQuantizer(unittest.TestCase): + + def test_search_k1(self): + + # verify codepath for k = 1 and k > 1 + + d = 64 + nb = 0 + nt = 1500 + nq = 200 + + (xt, xb, xq) = get_dataset(d, nb, nt, nq) + + miq = faiss.MultiIndexQuantizer(d, 2, 6) + + miq.train(xt) + + D1, I1 = miq.search(xq, 1) + + D5, I5 = miq.search(xq, 5) + + self.assertEqual(np.abs(I1[:, :1] - I5[:, :1]).max(), 0) + self.assertEqual(np.abs(D1[:, :1] - D5[:, :1]).max(), 0) + + +class TestScalarQuantizer(unittest.TestCase): + + def test_4variants_ivf(self): + d = 32 + nt = 2500 + nq = 400 + nb = 5000 + + (xt, xb, xq) = get_dataset_2(d, nt, nb, nq) + + # common quantizer + quantizer = faiss.IndexFlatL2(d) + + ncent = 64 + + index_gt = faiss.IndexFlatL2(d) + index_gt.add(xb) + D, I_ref = index_gt.search(xq, 10) + + nok = {} + + index = faiss.IndexIVFFlat(quantizer, d, ncent, + faiss.METRIC_L2) + index.cp.min_points_per_centroid = 5 # quiet warning + index.nprobe = 4 + index.train(xt) + index.add(xb) + D, I = index.search(xq, 10) + nok['flat'] = (I[:, 0] == I_ref[:, 0]).sum() + + for qname in "QT_4bit QT_4bit_uniform QT_8bit QT_8bit_uniform QT_fp16".split(): + qtype = getattr(faiss.ScalarQuantizer, qname) + index = faiss.IndexIVFScalarQuantizer(quantizer, d, ncent, + qtype, faiss.METRIC_L2) + + index.nprobe = 4 + index.train(xt) + index.add(xb) + D, I = index.search(xq, 10) + + nok[qname] = (I[:, 0] == I_ref[:, 0]).sum() + print(nok, nq) + + self.assertGreaterEqual(nok['flat'], nq * 0.6) + # The tests below are a bit fragile, it happens that the + # ordering between uniform and non-uniform are reverted, + # probably because the dataset is small, which introduces + # jitter + self.assertGreaterEqual(nok['flat'], nok['QT_8bit']) + self.assertGreaterEqual(nok['QT_8bit'], nok['QT_4bit']) + self.assertGreaterEqual(nok['QT_8bit'], nok['QT_8bit_uniform']) + self.assertGreaterEqual(nok['QT_4bit'], nok['QT_4bit_uniform']) + self.assertGreaterEqual(nok['QT_fp16'], nok['QT_8bit']) + + def test_4variants(self): + d = 32 + nt = 2500 + nq = 400 + nb = 5000 + + (xt, xb, xq) = get_dataset(d, nb, nt, nq) + + index_gt = faiss.IndexFlatL2(d) + index_gt.add(xb) + D_ref, I_ref = index_gt.search(xq, 10) + + nok = {} + + for qname in "QT_4bit QT_4bit_uniform QT_8bit QT_8bit_uniform QT_fp16".split(): + qtype = getattr(faiss.ScalarQuantizer, qname) + index = faiss.IndexScalarQuantizer(d, qtype, faiss.METRIC_L2) + index.train(xt) + index.add(xb) + D, I = index.search(xq, 10) + nok[qname] = (I[:, 0] == I_ref[:, 0]).sum() + + print(nok, nq) + + self.assertGreaterEqual(nok['QT_8bit'], nq * 0.9) + self.assertGreaterEqual(nok['QT_8bit'], nok['QT_4bit']) + self.assertGreaterEqual(nok['QT_8bit'], nok['QT_8bit_uniform']) + self.assertGreaterEqual(nok['QT_4bit'], nok['QT_4bit_uniform']) + self.assertGreaterEqual(nok['QT_fp16'], nok['QT_8bit']) + + +class TestRangeSearch(unittest.TestCase): + + def test_range_search(self): + d = 4 + nt = 100 + nq = 10 + nb = 50 + + (xt, xb, xq) = get_dataset(d, nb, nt, nq) + + index = faiss.IndexFlatL2(d) + index.add(xb) + + Dref, Iref = index.search(xq, 5) + + thresh = 0.1 # *squared* distance + lims, D, I = index.range_search(xq, thresh) + + for i in range(nq): + Iline = I[lims[i]:lims[i + 1]] + Dline = D[lims[i]:lims[i + 1]] + for j, dis in zip(Iref[i], Dref[i]): + if dis < thresh: + li, = np.where(Iline == j) + self.assertTrue(li.size == 1) + idx = li[0] + self.assertGreaterEqual(1e-4, abs(Dline[idx] - dis)) + + +class TestSearchAndReconstruct(unittest.TestCase): + + def run_search_and_reconstruct(self, index, xb, xq, k=10, eps=None): + n, d = xb.shape + assert xq.shape[1] == d + assert index.d == d + + D_ref, I_ref = index.search(xq, k) + R_ref = index.reconstruct_n(0, n) + D, I, R = index.search_and_reconstruct(xq, k) + + self.assertTrue((D == D_ref).all()) + self.assertTrue((I == I_ref).all()) + self.assertEqual(R.shape[:2], I.shape) + self.assertEqual(R.shape[2], d) + + # (n, k, ..) -> (n * k, ..) + I_flat = I.reshape(-1) + R_flat = R.reshape(-1, d) + # Filter out -1s when not enough results + R_flat = R_flat[I_flat >= 0] + I_flat = I_flat[I_flat >= 0] + + recons_ref_err = np.mean(np.linalg.norm(R_flat - R_ref[I_flat])) + self.assertLessEqual(recons_ref_err, 1e-6) + + def norm1(x): + return np.sqrt((x ** 2).sum(axis=1)) + + recons_err = np.mean(norm1(R_flat - xb[I_flat])) + + print('Reconstruction error = %.3f' % recons_err) + if eps is not None: + self.assertLessEqual(recons_err, eps) + + return D, I, R + + def test_IndexFlat(self): + d = 32 + nb = 1000 + nt = 1500 + nq = 200 + + (xt, xb, xq) = get_dataset(d, nb, nt, nq) + + index = faiss.IndexFlatL2(d) + index.add(xb) + + self.run_search_and_reconstruct(index, xb, xq, eps=0.0) + + def test_IndexIVFFlat(self): + d = 32 + nb = 1000 + nt = 1500 + nq = 200 + + (xt, xb, xq) = get_dataset(d, nb, nt, nq) + + quantizer = faiss.IndexFlatL2(d) + index = faiss.IndexIVFFlat(quantizer, d, 32, faiss.METRIC_L2) + index.cp.min_points_per_centroid = 5 # quiet warning + index.nprobe = 4 + index.train(xt) + index.add(xb) + + self.run_search_and_reconstruct(index, xb, xq, eps=0.0) + + def test_IndexIVFPQ(self): + d = 32 + nb = 1000 + nt = 1500 + nq = 200 + + (xt, xb, xq) = get_dataset(d, nb, nt, nq) + + quantizer = faiss.IndexFlatL2(d) + index = faiss.IndexIVFPQ(quantizer, d, 32, 8, 8) + index.cp.min_points_per_centroid = 5 # quiet warning + index.nprobe = 4 + index.train(xt) + index.add(xb) + + self.run_search_and_reconstruct(index, xb, xq, eps=1.0) + + def test_MultiIndex(self): + d = 32 + nb = 1000 + nt = 1500 + nq = 200 + + (xt, xb, xq) = get_dataset(d, nb, nt, nq) + + index = faiss.index_factory(d, "IMI2x5,PQ8np") + faiss.ParameterSpace().set_index_parameter(index, "nprobe", 4) + index.train(xt) + index.add(xb) + + self.run_search_and_reconstruct(index, xb, xq, eps=1.0) + + def test_IndexTransform(self): + d = 32 + nb = 1000 + nt = 1500 + nq = 200 + + (xt, xb, xq) = get_dataset(d, nb, nt, nq) + + index = faiss.index_factory(d, "L2norm,PCA8,IVF32,PQ8np") + faiss.ParameterSpace().set_index_parameter(index, "nprobe", 4) + index.train(xt) + index.add(xb) + + self.run_search_and_reconstruct(index, xb, xq) + + +class TestHNSW(unittest.TestCase): + + def __init__(self, *args, **kwargs): + unittest.TestCase.__init__(self, *args, **kwargs) + d = 32 + nt = 0 + nb = 1500 + nq = 500 + + (_, self.xb, self.xq) = get_dataset_2(d, nt, nb, nq) + index = faiss.IndexFlatL2(d) + index.add(self.xb) + Dref, Iref = index.search(self.xq, 1) + self.Iref = Iref + + def test_hnsw(self): + d = self.xq.shape[1] + + index = faiss.IndexHNSWFlat(d, 16) + index.add(self.xb) + Dhnsw, Ihnsw = index.search(self.xq, 1) + + self.assertGreaterEqual((self.Iref == Ihnsw).sum(), 460) + + self.io_and_retest(index, Dhnsw, Ihnsw) + + def test_hnsw_unbounded_queue(self): + d = self.xq.shape[1] + + index = faiss.IndexHNSWFlat(d, 16) + index.add(self.xb) + index.search_bounded_queue = False + Dhnsw, Ihnsw = index.search(self.xq, 1) + + self.assertGreaterEqual((self.Iref == Ihnsw).sum(), 460) + + self.io_and_retest(index, Dhnsw, Ihnsw) + + def io_and_retest(self, index, Dhnsw, Ihnsw): + _, tmpfile = tempfile.mkstemp() + try: + faiss.write_index(index, tmpfile) + index2 = faiss.read_index(tmpfile) + finally: + if os.path.exists(tmpfile): + os.unlink(tmpfile) + + Dhnsw2, Ihnsw2 = index2.search(self.xq, 1) + + self.assertTrue(np.all(Dhnsw2 == Dhnsw)) + self.assertTrue(np.all(Ihnsw2 == Ihnsw)) + + # also test clone + index3 = faiss.clone_index(index) + Dhnsw3, Ihnsw3 = index3.search(self.xq, 1) + + self.assertTrue(np.all(Dhnsw3 == Dhnsw)) + self.assertTrue(np.all(Ihnsw3 == Ihnsw)) + + + def test_hnsw_2level(self): + d = self.xq.shape[1] + + quant = faiss.IndexFlatL2(d) + + index = faiss.IndexHNSW2Level(quant, 256, 8, 8) + index.train(self.xb) + index.add(self.xb) + Dhnsw, Ihnsw = index.search(self.xq, 1) + + self.assertGreaterEqual((self.Iref == Ihnsw).sum(), 310) + + self.io_and_retest(index, Dhnsw, Ihnsw) + + def test_add_0_vecs(self): + index = faiss.IndexHNSWFlat(10, 16) + zero_vecs = np.zeros((0, 10), dtype='float32') + # infinite loop + index.add(zero_vecs) + + def test_hnsw_IP(self): + d = self.xq.shape[1] + + index_IP = faiss.IndexFlatIP(d) + index_IP.add(self.xb) + Dref, Iref = index_IP.search(self.xq, 1) + + index = faiss.IndexHNSWFlat(d, 16, faiss.METRIC_INNER_PRODUCT) + index.add(self.xb) + Dhnsw, Ihnsw = index.search(self.xq, 1) + + print('nb equal: ', (Iref == Ihnsw).sum()) + + self.assertGreaterEqual((Iref == Ihnsw).sum(), 480) + + mask = Iref[:, 0] == Ihnsw[:, 0] + assert np.allclose(Dref[mask, 0], Dhnsw[mask, 0]) + + + + +class TestDistancesPositive(unittest.TestCase): + + def test_l2_pos(self): + """ + roundoff errors occur only with the L2 decomposition used + with BLAS, ie. in IndexFlatL2 and with + n > distance_compute_blas_threshold = 20 + """ + + d = 128 + n = 100 + + rs = np.random.RandomState(1234) + x = rs.rand(n, d).astype('float32') + + index = faiss.IndexFlatL2(d) + index.add(x) + + D, I = index.search(x, 10) + + assert np.all(D >= 0) + + +class TestReconsException(unittest.TestCase): + + def test_recons_exception(self): + + d = 64 # dimension + nb = 1000 + rs = np.random.RandomState(1234) + xb = rs.rand(nb, d).astype('float32') + nlist = 10 + quantizer = faiss.IndexFlatL2(d) # the other index + index = faiss.IndexIVFFlat(quantizer, d, nlist) + index.train(xb) + index.add(xb) + index.make_direct_map() + + index.reconstruct(9) + + self.assertRaises( + RuntimeError, + index.reconstruct, 100001 + ) + + def test_reconstuct_after_add(self): + index = faiss.index_factory(10, 'IVF5,SQfp16') + index.train(faiss.randn((100, 10), 123)) + index.add(faiss.randn((100, 10), 345)) + index.make_direct_map() + index.add(faiss.randn((100, 10), 678)) + + # should not raise an exception + index.reconstruct(5) + print(index.ntotal) + index.reconstruct(150) + + +class TestReconsHash(unittest.TestCase): + + def do_test(self, index_key): + d = 32 + index = faiss.index_factory(d, index_key) + index.train(faiss.randn((100, d), 123)) + + # reference reconstruction + index.add(faiss.randn((100, d), 345)) + index.add(faiss.randn((100, d), 678)) + ref_recons = index.reconstruct_n(0, 200) + + # with lookup + index.reset() + rs = np.random.RandomState(123) + ids = rs.choice(10000, size=200, replace=False) + index.add_with_ids(faiss.randn((100, d), 345), ids[:100]) + index.set_direct_map_type(faiss.DirectMap.Hashtable) + index.add_with_ids(faiss.randn((100, d), 678), ids[100:]) + + # compare + for i in range(0, 200, 13): + recons = index.reconstruct(int(ids[i])) + self.assertTrue(np.all(recons == ref_recons[i])) + + # test I/O + buf = faiss.serialize_index(index) + index2 = faiss.deserialize_index(buf) + + # compare + for i in range(0, 200, 13): + recons = index2.reconstruct(int(ids[i])) + self.assertTrue(np.all(recons == ref_recons[i])) + + # remove + toremove = np.ascontiguousarray(ids[0:200:3]) + + sel = faiss.IDSelectorArray(50, faiss.swig_ptr(toremove[:50])) + + # test both ways of removing elements + nremove = index2.remove_ids(sel) + nremove += index2.remove_ids(toremove[50:]) + + self.assertEqual(nremove, len(toremove)) + + for i in range(0, 200, 13): + if i % 3 == 0: + self.assertRaises( + RuntimeError, + index2.reconstruct, int(ids[i]) + ) + else: + recons = index2.reconstruct(int(ids[i])) + self.assertTrue(np.all(recons == ref_recons[i])) + + # index error should raise + self.assertRaises( + RuntimeError, + index.reconstruct, 20000 + ) + + def test_IVFFlat(self): + self.do_test("IVF5,Flat") + + def test_IVFSQ(self): + self.do_test("IVF5,SQfp16") + + def test_IVFPQ(self): + self.do_test("IVF5,PQ4x4np") + +if __name__ == '__main__': + unittest.main() diff --git a/core/src/index/thirdparty/faiss/tests/test_index_accuracy.py b/core/src/index/thirdparty/faiss/tests/test_index_accuracy.py new file mode 100644 index 0000000000..d97362f843 --- /dev/null +++ b/core/src/index/thirdparty/faiss/tests/test_index_accuracy.py @@ -0,0 +1,673 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import absolute_import, division, print_function, unicode_literals +# noqa E741 +# translation of test_knn.lua + +import numpy as np +import unittest +import faiss + +from common import Randu10k, get_dataset_2, Randu10kUnbalanced + +ev = Randu10k() + +d = ev.d + +# Parameters inverted indexes +ncentroids = int(4 * np.sqrt(ev.nb)) +kprobe = int(np.sqrt(ncentroids)) + +# Parameters for LSH +nbits = d + +# Parameters for indexes involving PQ +M = int(d / 8) # for PQ: #subquantizers +nbits_per_index = 8 # for PQ + + +class IndexAccuracy(unittest.TestCase): + + def test_IndexFlatIP(self): + q = faiss.IndexFlatIP(d) # Ask inner product + res = ev.launch('FLAT / IP', q) + e = ev.evalres(res) + assert e[1] == 1.0 + + def test_IndexFlatL2(self): + q = faiss.IndexFlatL2(d) + res = ev.launch('FLAT / L2', q) + e = ev.evalres(res) + assert e[1] == 1.0 + + def test_ivf_kmeans(self): + ivfk = faiss.IndexIVFFlat(faiss.IndexFlatL2(d), d, ncentroids) + ivfk.nprobe = kprobe + res = ev.launch('IndexIVFFlat', ivfk) + e = ev.evalres(res) + # should give 0.260 0.260 0.260 + assert e[1] > 0.2 + + # test parallel mode + Dref, Iref = ivfk.search(ev.xq, 100) + ivfk.parallel_mode = 1 + Dnew, Inew = ivfk.search(ev.xq, 100) + print((Iref != Inew).sum(), Iref.size) + assert (Iref != Inew).sum() < Iref.size / 5000.0 + assert np.all(Dref == Dnew) + + def test_indexLSH(self): + q = faiss.IndexLSH(d, nbits) + res = ev.launch('FLAT / LSH Cosine', q) + e = ev.evalres(res) + # should give 0.070 0.250 0.580 + assert e[10] > 0.2 + + def test_IndexLSH_32_48(self): + # CHECK: the difference between 32 and 48 does not make much sense + for nbits2 in 32, 48: + q = faiss.IndexLSH(d, nbits2) + res = ev.launch('LSH half size', q) + e = ev.evalres(res) + # should give 0.003 0.019 0.108 + assert e[10] > 0.018 + + def test_IndexPQ(self): + q = faiss.IndexPQ(d, M, nbits_per_index) + res = ev.launch('FLAT / PQ L2', q) + e = ev.evalres(res) + # should give 0.070 0.230 0.260 + assert e[10] > 0.2 + + # Approximate search module: PQ with inner product distance + def test_IndexPQ_ip(self): + q = faiss.IndexPQ(d, M, nbits_per_index, faiss.METRIC_INNER_PRODUCT) + res = ev.launch('FLAT / PQ IP', q) + e = ev.evalres(res) + # should give 0.070 0.230 0.260 + #(same result as regular PQ on normalized distances) + assert e[10] > 0.2 + + def test_IndexIVFPQ(self): + ivfpq = faiss.IndexIVFPQ(faiss.IndexFlatL2(d), d, ncentroids, M, 8) + ivfpq.nprobe = kprobe + res = ev.launch('IVF PQ', ivfpq) + e = ev.evalres(res) + # should give 0.070 0.230 0.260 + assert e[10] > 0.2 + + # TODO: translate evaluation of nested + + # Approximate search: PQ with full vector refinement + def test_IndexPQ_refined(self): + q = faiss.IndexPQ(d, M, nbits_per_index) + res = ev.launch('PQ non-refined', q) + e = ev.evalres(res) + q.reset() + + rq = faiss.IndexRefineFlat(q) + res = ev.launch('PQ refined', rq) + e2 = ev.evalres(res) + assert e2[10] >= e[10] + rq.k_factor = 4 + + res = ev.launch('PQ refined*4', rq) + e3 = ev.evalres(res) + assert e3[10] >= e2[10] + + def test_polysemous(self): + index = faiss.IndexPQ(d, M, nbits_per_index) + index.do_polysemous_training = True + # reduce nb iterations to speed up training for the test + index.polysemous_training.n_iter = 50000 + index.polysemous_training.n_redo = 1 + res = ev.launch('normal PQ', index) + e_baseline = ev.evalres(res) + index.search_type = faiss.IndexPQ.ST_polysemous + + index.polysemous_ht = int(M / 16. * 58) + + stats = faiss.cvar.indexPQ_stats + stats.reset() + + res = ev.launch('Polysemous ht=%d' % index.polysemous_ht, + index) + e_polysemous = ev.evalres(res) + print(e_baseline, e_polysemous, index.polysemous_ht) + print(stats.n_hamming_pass, stats.ncode) + # The randu dataset is difficult, so we are not too picky on + # the results. Here we assert that we have < 10 % loss when + # computing full PQ on fewer than 20% of the data. + assert stats.n_hamming_pass < stats.ncode / 5 + # Test disabled because difference is 0.17 on aarch64 + # TODO check why??? + # assert e_polysemous[10] > e_baseline[10] - 0.1 + + def test_ScalarQuantizer(self): + quantizer = faiss.IndexFlatL2(d) + ivfpq = faiss.IndexIVFScalarQuantizer( + quantizer, d, ncentroids, + faiss.ScalarQuantizer.QT_8bit) + ivfpq.nprobe = kprobe + res = ev.launch('IVF SQ', ivfpq) + e = ev.evalres(res) + # should give 0.234 0.236 0.236 + assert e[10] > 0.235 + + + +class TestSQFlavors(unittest.TestCase): + """ tests IP in addition to L2, non multiple of 8 dimensions + """ + + def add2columns(self, x): + return np.hstack(( + x, np.zeros((x.shape[0], 2), dtype='float32') + )) + + def subtest_add2col(self, xb, xq, index, qname): + """Test with 2 additional dimensions to take also the non-SIMD + codepath. We don't retrain anything but add 2 dims to the + queries, the centroids and the trained ScalarQuantizer. + """ + nb, d = xb.shape + + d2 = d + 2 + xb2 = self.add2columns(xb) + xq2 = self.add2columns(xq) + + nlist = index.nlist + quantizer = faiss.downcast_index(index.quantizer) + quantizer2 = faiss.IndexFlat(d2, index.metric_type) + centroids = faiss.vector_to_array(quantizer.xb).reshape(nlist, d) + centroids2 = self.add2columns(centroids) + quantizer2.add(centroids2) + index2 = faiss.IndexIVFScalarQuantizer( + quantizer2, d2, index.nlist, index.sq.qtype, + index.metric_type) + index2.nprobe = 4 + if qname in ('8bit', '4bit'): + trained = faiss.vector_to_array(index.sq.trained).reshape(2, -1) + nt = trained.shape[1] + # 2 lines: vmins and vdiffs + new_nt = int(nt * d2 / d) + trained2 = np.hstack(( + trained, + np.zeros((2, new_nt - nt), dtype='float32') + )) + trained2[1, nt:] = 1.0 # set vdiff to 1 to avoid div by 0 + faiss.copy_array_to_vector(trained2.ravel(), index2.sq.trained) + else: + index2.sq.trained = index.sq.trained + + index2.is_trained = True + index2.add(xb2) + return index2.search(xq2, 10) + + + # run on Sept 18, 2018 with nprobe=4 + 4 bit bugfix + ref_results = { + (0, '8bit'): 984, + (0, '4bit'): 978, + (0, '8bit_uniform'): 985, + (0, '4bit_uniform'): 979, + (0, 'fp16'): 985, + (1, '8bit'): 979, + (1, '4bit'): 973, + (1, '8bit_uniform'): 979, + (1, '4bit_uniform'): 972, + (1, 'fp16'): 979, + # added 2019-06-26 + (0, '6bit'): 985, + (1, '6bit'): 987, + } + + def subtest(self, mt): + d = 32 + xt, xb, xq = get_dataset_2(d, 2000, 1000, 200) + nlist = 64 + + gt_index = faiss.IndexFlat(d, mt) + gt_index.add(xb) + gt_D, gt_I = gt_index.search(xq, 10) + quantizer = faiss.IndexFlat(d, mt) + for qname in '8bit 4bit 8bit_uniform 4bit_uniform fp16 6bit'.split(): + qtype = getattr(faiss.ScalarQuantizer, 'QT_' + qname) + index = faiss.IndexIVFScalarQuantizer( + quantizer, d, nlist, qtype, mt) + index.train(xt) + index.add(xb) + index.nprobe = 4 # hopefully more robust than 1 + D, I = index.search(xq, 10) + ninter = faiss.eval_intersection(I, gt_I) + print('(%d, %s): %d, ' % (mt, repr(qname), ninter)) + assert abs(ninter - self.ref_results[(mt, qname)]) <= 10 + + if qname == '6bit': + # the test below fails triggers ASAN. TODO check what's wrong + continue + + D2, I2 = self.subtest_add2col(xb, xq, index, qname) + assert np.all(I2 == I) + + # also test range search + + if mt == faiss.METRIC_INNER_PRODUCT: + radius = float(D[:, -1].max()) + else: + radius = float(D[:, -1].min()) + print('radius', radius) + + lims, D3, I3 = index.range_search(xq, radius) + ntot = ndiff = 0 + for i in range(len(xq)): + l0, l1 = lims[i], lims[i + 1] + Inew = set(I3[l0:l1]) + if mt == faiss.METRIC_INNER_PRODUCT: + mask = D2[i] > radius + else: + mask = D2[i] < radius + Iref = set(I2[i, mask]) + ndiff += len(Inew ^ Iref) + ntot += len(Iref) + print('ndiff %d / %d' % (ndiff, ntot)) + assert ndiff < ntot * 0.01 + + for pm in 1, 2: + print('parallel_mode=%d' % pm) + index.parallel_mode = pm + lims4, D4, I4 = index.range_search(xq, radius) + print('sizes', lims4[1:] - lims4[:-1]) + for qno in range(len(lims) - 1): + Iref = I3[lims[qno]: lims[qno+1]] + Inew = I4[lims4[qno]: lims4[qno+1]] + assert set(Iref) == set(Inew), "q %d ref %s new %s" % ( + qno, Iref, Inew) + + def test_SQ_IP(self): + self.subtest(faiss.METRIC_INNER_PRODUCT) + + def test_SQ_L2(self): + self.subtest(faiss.METRIC_L2) + + +class TestSQByte(unittest.TestCase): + + def subtest_8bit_direct(self, metric_type, d): + xt, xb, xq = get_dataset_2(d, 500, 1000, 30) + + # rescale everything to get integer + tmin, tmax = xt.min(), xt.max() + + def rescale(x): + x = np.floor((x - tmin) * 256 / (tmax - tmin)) + x[x < 0] = 0 + x[x > 255] = 255 + return x + + xt = rescale(xt) + xb = rescale(xb) + xq = rescale(xq) + + gt_index = faiss.IndexFlat(d, metric_type) + gt_index.add(xb) + Dref, Iref = gt_index.search(xq, 10) + + index = faiss.IndexScalarQuantizer( + d, faiss.ScalarQuantizer.QT_8bit_direct, metric_type) + index.add(xb) + D, I = index.search(xq, 10) + + assert np.all(I == Iref) + assert np.all(D == Dref) + + # same, with IVF + + nlist = 64 + quantizer = faiss.IndexFlat(d, metric_type) + + gt_index = faiss.IndexIVFFlat(quantizer, d, nlist, metric_type) + gt_index.nprobe = 4 + gt_index.train(xt) + gt_index.add(xb) + Dref, Iref = gt_index.search(xq, 10) + + index = faiss.IndexIVFScalarQuantizer( + quantizer, d, nlist, + faiss.ScalarQuantizer.QT_8bit_direct, metric_type) + index.nprobe = 4 + index.by_residual = False + index.train(xt) + index.add(xb) + D, I = index.search(xq, 10) + + assert np.all(I == Iref) + assert np.all(D == Dref) + + def test_8bit_direct(self): + for d in 13, 16, 24: + for metric_type in faiss.METRIC_L2, faiss.METRIC_INNER_PRODUCT: + self.subtest_8bit_direct(metric_type, d) + + + +class TestPQFlavors(unittest.TestCase): + + # run on Dec 14, 2018 + ref_results = { + (1, True): 800, + (1, True, 20): 794, + (1, False): 769, + (0, True): 831, + (0, True, 20): 828, + (0, False): 829, + } + + def test_IVFPQ_IP(self): + self.subtest(faiss.METRIC_INNER_PRODUCT) + + def test_IVFPQ_L2(self): + self.subtest(faiss.METRIC_L2) + + def subtest(self, mt): + d = 32 + xt, xb, xq = get_dataset_2(d, 2000, 1000, 200) + nlist = 64 + + gt_index = faiss.IndexFlat(d, mt) + gt_index.add(xb) + gt_D, gt_I = gt_index.search(xq, 10) + quantizer = faiss.IndexFlat(d, mt) + for by_residual in True, False: + + index = faiss.IndexIVFPQ( + quantizer, d, nlist, 4, 8) + index.metric_type = mt + index.by_residual = by_residual + if by_residual: + # perform cheap polysemous training + index.do_polysemous_training = True + pt = faiss.PolysemousTraining() + pt.n_iter = 50000 + pt.n_redo = 1 + index.polysemous_training = pt + + index.train(xt) + index.add(xb) + index.nprobe = 4 + D, I = index.search(xq, 10) + + ninter = faiss.eval_intersection(I, gt_I) + print('(%d, %s): %d, ' % (mt, by_residual, ninter)) + + assert abs(ninter - self.ref_results[mt, by_residual]) <= 3 + + index.use_precomputed_table = 0 + D2, I2 = index.search(xq, 10) + assert np.all(I == I2) + + if by_residual: + + index.use_precomputed_table = 1 + index.polysemous_ht = 20 + D, I = index.search(xq, 10) + ninter = faiss.eval_intersection(I, gt_I) + print('(%d, %s, %d): %d, ' % ( + mt, by_residual, index.polysemous_ht, ninter)) + + # polysemous behaves bizarrely on ARM + assert (ninter >= self.ref_results[ + mt, by_residual, index.polysemous_ht] - 4) + + # also test range search + + if mt == faiss.METRIC_INNER_PRODUCT: + radius = float(D[:, -1].max()) + else: + radius = float(D[:, -1].min()) + print('radius', radius) + + lims, D3, I3 = index.range_search(xq, radius) + ntot = ndiff = 0 + for i in range(len(xq)): + l0, l1 = lims[i], lims[i + 1] + Inew = set(I3[l0:l1]) + if mt == faiss.METRIC_INNER_PRODUCT: + mask = D2[i] > radius + else: + mask = D2[i] < radius + Iref = set(I2[i, mask]) + ndiff += len(Inew ^ Iref) + ntot += len(Iref) + print('ndiff %d / %d' % (ndiff, ntot)) + assert ndiff < ntot * 0.02 + + def test_IVFPQ_non8bit(self): + d = 16 + xt, xb, xq = get_dataset_2(d, 10000, 2000, 200) + nlist = 64 + + gt_index = faiss.IndexFlat(d) + gt_index.add(xb) + gt_D, gt_I = gt_index.search(xq, 10) + + quantizer = faiss.IndexFlat(d) + ninter = {} + for v in '2x8', '8x2': + if v == '8x2': + index = faiss.IndexIVFPQ( + quantizer, d, nlist, 2, 8) + else: + index = faiss.IndexIVFPQ( + quantizer, d, nlist, 8, 2) + index.train(xt) + index.add(xb) + index.npobe = 16 + + D, I = index.search(xq, 10) + ninter[v] = faiss.eval_intersection(I, gt_I) + print('ninter=', ninter) + # this should be the case but we don't observe + # that... Probavly too few test points + # assert ninter['2x8'] > ninter['8x2'] + # ref numbers on 2019-11-02 + assert abs(ninter['2x8'] - 458) < 4 + assert abs(ninter['8x2'] - 465) < 4 + + +class TestFlat1D(unittest.TestCase): + + def test_flat_1d(self): + rs = np.random.RandomState(123545) + k = 10 + xb = rs.uniform(size=(100, 1)).astype('float32') + # make sure to test below and above + xq = rs.uniform(size=(1000, 1)).astype('float32') * 1.1 - 0.05 + + ref = faiss.IndexFlatL2(1) + ref.add(xb) + ref_D, ref_I = ref.search(xq, k) + + new = faiss.IndexFlat1D() + new.add(xb) + + new_D, new_I = new.search(xq, 10) + + ndiff = (np.abs(ref_I - new_I) != 0).sum() + + assert(ndiff < 100) + new_D = new_D ** 2 + max_diff_D = np.abs(ref_D - new_D).max() + assert(max_diff_D < 1e-5) + + +class OPQRelativeAccuracy(unittest.TestCase): + # translated from test_opq.lua + + def test_OPQ(self): + + M = 4 + + ev = Randu10kUnbalanced() + d = ev.d + index = faiss.IndexPQ(d, M, 8) + + res = ev.launch('PQ', index) + e_pq = ev.evalres(res) + + index_pq = faiss.IndexPQ(d, M, 8) + opq_matrix = faiss.OPQMatrix(d, M) + # opq_matrix.verbose = true + opq_matrix.niter = 10 + opq_matrix.niter_pq = 4 + index = faiss.IndexPreTransform(opq_matrix, index_pq) + + res = ev.launch('OPQ', index) + e_opq = ev.evalres(res) + + print('e_pq=%s' % e_pq) + print('e_opq=%s' % e_opq) + + # verify that OPQ better than PQ + for r in 1, 10, 100: + assert(e_opq[r] > e_pq[r]) + + def test_OIVFPQ(self): + # Parameters inverted indexes + ncentroids = 50 + M = 4 + + ev = Randu10kUnbalanced() + d = ev.d + quantizer = faiss.IndexFlatL2(d) + index = faiss.IndexIVFPQ(quantizer, d, ncentroids, M, 8) + index.nprobe = 5 + + res = ev.launch('IVFPQ', index) + e_ivfpq = ev.evalres(res) + + quantizer = faiss.IndexFlatL2(d) + index_ivfpq = faiss.IndexIVFPQ(quantizer, d, ncentroids, M, 8) + index_ivfpq.nprobe = 5 + opq_matrix = faiss.OPQMatrix(d, M) + opq_matrix.niter = 10 + index = faiss.IndexPreTransform(opq_matrix, index_ivfpq) + + res = ev.launch('O+IVFPQ', index) + e_oivfpq = ev.evalres(res) + + # verify same on OIVFPQ + for r in 1, 10, 100: + print(e_oivfpq[r], e_ivfpq[r]) + assert(e_oivfpq[r] >= e_ivfpq[r]) + + +class TestRoundoff(unittest.TestCase): + + def test_roundoff(self): + # params that force use of BLAS implementation + nb = 100 + nq = 25 + d = 4 + xb = np.zeros((nb, d), dtype='float32') + + xb[:, 0] = np.arange(nb) + 12345 + xq = xb[:nq] + 0.3 + + index = faiss.IndexFlat(d) + index.add(xb) + + D, I = index.search(xq, 1) + + # this does not work + assert not np.all(I.ravel() == np.arange(nq)) + + index = faiss.IndexPreTransform( + faiss.CenteringTransform(d), + faiss.IndexFlat(d)) + + index.train(xb) + index.add(xb) + + D, I = index.search(xq, 1) + + # this works + assert np.all(I.ravel() == np.arange(nq)) + + +class TestSpectralHash(unittest.TestCase): + + # run on 2019-04-02 + ref_results = { + (32, 'global', 10): 505, + (32, 'centroid', 10): 524, + (32, 'centroid_half', 10): 21, + (32, 'median', 10): 510, + (32, 'global', 1): 8, + (32, 'centroid', 1): 20, + (32, 'centroid_half', 1): 26, + (32, 'median', 1): 14, + (64, 'global', 10): 768, + (64, 'centroid', 10): 767, + (64, 'centroid_half', 10): 21, + (64, 'median', 10): 765, + (64, 'global', 1): 28, + (64, 'centroid', 1): 21, + (64, 'centroid_half', 1): 20, + (64, 'median', 1): 29, + (128, 'global', 10): 968, + (128, 'centroid', 10): 945, + (128, 'centroid_half', 10): 21, + (128, 'median', 10): 958, + (128, 'global', 1): 271, + (128, 'centroid', 1): 279, + (128, 'centroid_half', 1): 171, + (128, 'median', 1): 253, + } + + def test_sh(self): + d = 32 + xt, xb, xq = get_dataset_2(d, 2000, 1000, 200) + nlist, nprobe = 1, 1 + + gt_index = faiss.IndexFlatL2(d) + gt_index.add(xb) + gt_D, gt_I = gt_index.search(xq, 10) + + for nbit in 32, 64, 128: + quantizer = faiss.IndexFlatL2(d) + + index_lsh = faiss.IndexLSH(d, nbit, True) + index_lsh.add(xb) + D, I = index_lsh.search(xq, 10) + ninter = faiss.eval_intersection(I, gt_I) + + print('LSH baseline: %d' % ninter) + + for period in 10.0, 1.0: + + for tt in 'global centroid centroid_half median'.split(): + index = faiss.IndexIVFSpectralHash(quantizer, d, nlist, + nbit, period) + index.nprobe = nprobe + index.threshold_type = getattr( + faiss.IndexIVFSpectralHash, + 'Thresh_' + tt + ) + + index.train(xt) + index.add(xb) + D, I = index.search(xq, 10) + + ninter = faiss.eval_intersection(I, gt_I) + key = (nbit, tt, period) + + print('(%d, %s, %g): %d, ' % (nbit, repr(tt), period, ninter)) + assert abs(ninter - self.ref_results[key]) <= 4 + + +if __name__ == '__main__': + unittest.main() diff --git a/core/src/index/thirdparty/faiss/tests/test_index_binary.py b/core/src/index/thirdparty/faiss/tests/test_index_binary.py new file mode 100644 index 0000000000..c61e2fa5df --- /dev/null +++ b/core/src/index/thirdparty/faiss/tests/test_index_binary.py @@ -0,0 +1,375 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +"""this is a basic test script for simple indices work""" +from __future__ import absolute_import, division, print_function, unicode_literals + +import numpy as np +import unittest +import faiss + +from common import compare_binary_result_lists, make_binary_dataset + + + +def binary_to_float(x): + n, d = x.shape + x8 = x.reshape(n * d, -1) + c8 = 2 * ((x8 >> np.arange(8)) & 1).astype('int8') - 1 + return c8.astype('float32').reshape(n, d * 8) + + +def binary_dis(x, y): + return sum(faiss.popcount64(int(xi ^ yi)) for xi, yi in zip(x, y)) + + +class TestBinaryPQ(unittest.TestCase): + """ Use a PQ that mimicks a binary encoder """ + + def test_encode_to_binary(self): + d = 256 + nt = 256 + nb = 1500 + nq = 500 + (xt, xb, xq) = make_binary_dataset(d, nt, nb, nq) + pq = faiss.ProductQuantizer(d, int(d / 8), 8) + + centroids = binary_to_float( + np.tile(np.arange(256), int(d / 8)).astype('uint8').reshape(-1, 1)) + + faiss.copy_array_to_vector(centroids.ravel(), pq.centroids) + pq.is_trained = True + + codes = pq.compute_codes(binary_to_float(xb)) + + assert np.all(codes == xb) + + indexpq = faiss.IndexPQ(d, int(d / 8), 8) + indexpq.pq = pq + indexpq.is_trained = True + + indexpq.add(binary_to_float(xb)) + D, I = indexpq.search(binary_to_float(xq), 3) + + for i in range(nq): + for j, dj in zip(I[i], D[i]): + ref_dis = binary_dis(xq[i], xb[j]) + assert 4 * ref_dis == dj + + nlist = 32 + quantizer = faiss.IndexFlatL2(d) + # pretext class for training + iflat = faiss.IndexIVFFlat(quantizer, d, nlist) + iflat.train(binary_to_float(xt)) + + indexivfpq = faiss.IndexIVFPQ(quantizer, d, nlist, int(d / 8), 8) + + indexivfpq.pq = pq + indexivfpq.is_trained = True + indexivfpq.by_residual = False + + indexivfpq.add(binary_to_float(xb)) + indexivfpq.nprobe = 4 + + D, I = indexivfpq.search(binary_to_float(xq), 3) + + for i in range(nq): + for j, dj in zip(I[i], D[i]): + ref_dis = binary_dis(xq[i], xb[j]) + assert 4 * ref_dis == dj + + +class TestBinaryFlat(unittest.TestCase): + + def __init__(self, *args, **kwargs): + unittest.TestCase.__init__(self, *args, **kwargs) + d = 32 + nt = 0 + nb = 1500 + nq = 500 + + (_, self.xb, self.xq) = make_binary_dataset(d, nt, nb, nq) + + def test_flat(self): + d = self.xq.shape[1] * 8 + nq = self.xq.shape[0] + + index = faiss.IndexBinaryFlat(d) + index.add(self.xb) + D, I = index.search(self.xq, 3) + + for i in range(nq): + for j, dj in zip(I[i], D[i]): + ref_dis = binary_dis(self.xq[i], self.xb[j]) + assert dj == ref_dis + + # test reconstruction + assert np.all(index.reconstruct(12) == self.xb[12]) + + def test_empty_flat(self): + d = self.xq.shape[1] * 8 + + index = faiss.IndexBinaryFlat(d) + + for use_heap in [True, False]: + index.use_heap = use_heap + Dflat, Iflat = index.search(self.xq, 10) + + assert(np.all(Iflat == -1)) + assert(np.all(Dflat == 2147483647)) # NOTE(hoss): int32_t max + + def test_range_search(self): + d = self.xq.shape[1] * 8 + + index = faiss.IndexBinaryFlat(d) + index.add(self.xb) + D, I = index.search(self.xq, 10) + thresh = int(np.median(D[:, -1])) + + lims, D2, I2 = index.range_search(self.xq, thresh) + nt1 = nt2 = 0 + for i in range(len(self.xq)): + range_res = I2[lims[i]:lims[i + 1]] + if thresh > D[i, -1]: + self.assertTrue(set(I[i]) <= set(range_res)) + nt1 += 1 + elif thresh < D[i, -1]: + self.assertTrue(set(range_res) <= set(I[i])) + nt2 += 1 + # in case of equality we have a problem with ties + print('nb tests', nt1, nt2) + # nb tests is actually low... + self.assertTrue(nt1 > 19 and nt2 > 19) + + +class TestBinaryIVF(unittest.TestCase): + + def __init__(self, *args, **kwargs): + unittest.TestCase.__init__(self, *args, **kwargs) + d = 32 + nt = 200 + nb = 1500 + nq = 500 + + (self.xt, self.xb, self.xq) = make_binary_dataset(d, nt, nb, nq) + index = faiss.IndexBinaryFlat(d) + index.add(self.xb) + Dref, Iref = index.search(self.xq, 10) + self.Dref = Dref + + def test_ivf_flat_exhaustive(self): + d = self.xq.shape[1] * 8 + + quantizer = faiss.IndexBinaryFlat(d) + index = faiss.IndexBinaryIVF(quantizer, d, 8) + index.cp.min_points_per_centroid = 5 # quiet warning + index.nprobe = 8 + index.train(self.xt) + index.add(self.xb) + Divfflat, _ = index.search(self.xq, 10) + + np.testing.assert_array_equal(self.Dref, Divfflat) + + def test_ivf_flat2(self): + d = self.xq.shape[1] * 8 + + quantizer = faiss.IndexBinaryFlat(d) + index = faiss.IndexBinaryIVF(quantizer, d, 8) + index.cp.min_points_per_centroid = 5 # quiet warning + index.nprobe = 4 + index.train(self.xt) + index.add(self.xb) + Divfflat, _ = index.search(self.xq, 10) + + self.assertEqual((self.Dref == Divfflat).sum(), 4122) + + def test_ivf_range(self): + d = self.xq.shape[1] * 8 + + quantizer = faiss.IndexBinaryFlat(d) + index = faiss.IndexBinaryIVF(quantizer, d, 8) + index.cp.min_points_per_centroid = 5 # quiet warning + index.nprobe = 4 + index.train(self.xt) + index.add(self.xb) + D, I = index.search(self.xq, 10) + + radius = int(np.median(D[:, -1]) + 1) + Lr, Dr, Ir = index.range_search(self.xq, radius) + + for i in range(len(self.xq)): + res = Ir[Lr[i]:Lr[i + 1]] + if D[i, -1] < radius: + self.assertTrue(set(I[i]) <= set(res)) + else: + subset = I[i, D[i, :] < radius] + self.assertTrue(set(subset) == set(res)) + + + def test_ivf_flat_empty(self): + d = self.xq.shape[1] * 8 + + index = faiss.IndexBinaryIVF(faiss.IndexBinaryFlat(d), d, 8) + index.train(self.xt) + + for use_heap in [True, False]: + index.use_heap = use_heap + Divfflat, Iivfflat = index.search(self.xq, 10) + + assert(np.all(Iivfflat == -1)) + assert(np.all(Divfflat == 2147483647)) # NOTE(hoss): int32_t max + + def test_ivf_reconstruction(self): + d = self.xq.shape[1] * 8 + quantizer = faiss.IndexBinaryFlat(d) + index = faiss.IndexBinaryIVF(quantizer, d, 8) + index.cp.min_points_per_centroid = 5 # quiet warning + index.nprobe = 4 + index.train(self.xt) + + index.add(self.xb) + index.set_direct_map_type(faiss.DirectMap.Array) + + for i in range(0, len(self.xb), 13): + np.testing.assert_array_equal( + index.reconstruct(i), + self.xb[i] + ) + + # try w/ hashtable + index = faiss.IndexBinaryIVF(quantizer, d, 8) + rs = np.random.RandomState(123) + ids = rs.choice(10000, size=len(self.xb), replace=False) + index.add_with_ids(self.xb, ids) + index.set_direct_map_type(faiss.DirectMap.Hashtable) + + for i in range(0, len(self.xb), 13): + np.testing.assert_array_equal( + index.reconstruct(int(ids[i])), + self.xb[i] + ) + + +class TestHNSW(unittest.TestCase): + + def __init__(self, *args, **kwargs): + unittest.TestCase.__init__(self, *args, **kwargs) + d = 32 + nt = 0 + nb = 1500 + nq = 500 + + (_, self.xb, self.xq) = make_binary_dataset(d, nt, nb, nq) + + def test_hnsw_exact_distances(self): + d = self.xq.shape[1] * 8 + nq = self.xq.shape[0] + + index = faiss.IndexBinaryHNSW(d, 16) + index.add(self.xb) + Dists, Ids = index.search(self.xq, 3) + + for i in range(nq): + for j, dj in zip(Ids[i], Dists[i]): + ref_dis = binary_dis(self.xq[i], self.xb[j]) + self.assertEqual(dj, ref_dis) + + def test_hnsw(self): + d = self.xq.shape[1] * 8 + + # NOTE(hoss): Ensure the HNSW construction is deterministic. + nthreads = faiss.omp_get_max_threads() + faiss.omp_set_num_threads(1) + + index_hnsw_float = faiss.IndexHNSWFlat(d, 16) + index_hnsw_ref = faiss.IndexBinaryFromFloat(index_hnsw_float) + + index_hnsw_bin = faiss.IndexBinaryHNSW(d, 16) + + index_hnsw_ref.add(self.xb) + index_hnsw_bin.add(self.xb) + + faiss.omp_set_num_threads(nthreads) + + Dref, Iref = index_hnsw_ref.search(self.xq, 3) + Dbin, Ibin = index_hnsw_bin.search(self.xq, 3) + + self.assertTrue((Dref == Dbin).all()) + + + +class TestReplicasAndShards(unittest.TestCase): + + def test_replicas(self): + d = 32 + nq = 100 + nb = 200 + + (_, xb, xq) = make_binary_dataset(d, 0, nb, nq) + + index_ref = faiss.IndexBinaryFlat(d) + index_ref.add(xb) + + Dref, Iref = index_ref.search(xq, 10) + + nrep = 5 + index = faiss.IndexBinaryReplicas() + for _i in range(nrep): + sub_idx = faiss.IndexBinaryFlat(d) + sub_idx.add(xb) + index.addIndex(sub_idx) + + D, I = index.search(xq, 10) + + self.assertTrue((Dref == D).all()) + self.assertTrue((Iref == I).all()) + + index2 = faiss.IndexBinaryReplicas() + for _i in range(nrep): + sub_idx = faiss.IndexBinaryFlat(d) + index2.addIndex(sub_idx) + + index2.add(xb) + D2, I2 = index2.search(xq, 10) + + self.assertTrue((Dref == D2).all()) + self.assertTrue((Iref == I2).all()) + + def test_shards(self): + d = 32 + nq = 100 + nb = 200 + + (_, xb, xq) = make_binary_dataset(d, 0, nb, nq) + + index_ref = faiss.IndexBinaryFlat(d) + index_ref.add(xb) + + Dref, Iref = index_ref.search(xq, 10) + + nrep = 5 + index = faiss.IndexBinaryShards(d) + for i in range(nrep): + sub_idx = faiss.IndexBinaryFlat(d) + sub_idx.add(xb[i * nb // nrep : (i + 1) * nb // nrep]) + index.add_shard(sub_idx) + + D, I = index.search(xq, 10) + + compare_binary_result_lists(Dref, Iref, D, I) + + index2 = faiss.IndexBinaryShards(d) + for _i in range(nrep): + sub_idx = faiss.IndexBinaryFlat(d) + index2.add_shard(sub_idx) + + index2.add(xb) + D2, I2 = index2.search(xq, 10) + + compare_binary_result_lists(Dref, Iref, D2, I2) + + +if __name__ == '__main__': + unittest.main() diff --git a/core/src/index/thirdparty/faiss/tests/test_index_binary_from_float.py b/core/src/index/thirdparty/faiss/tests/test_index_binary_from_float.py new file mode 100644 index 0000000000..73d6c726d4 --- /dev/null +++ b/core/src/index/thirdparty/faiss/tests/test_index_binary_from_float.py @@ -0,0 +1,200 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import absolute_import, division, print_function + +import numpy as np +import unittest +import faiss + + +def make_binary_dataset(d, nb, nt, nq): + assert d % 8 == 0 + rs = np.random.RandomState(123) + x = rs.randint(256, size=(nb + nq + nt, int(d / 8))).astype('uint8') + return x[:nt], x[nt:-nq], x[-nq:] + + +def binary_to_float(x): + n, d = x.shape + x8 = x.reshape(n * d, -1) + c8 = 2 * ((x8 >> np.arange(8)) & 1).astype('int8') - 1 + return c8.astype('float32').reshape(n, d * 8) + + +class TestIndexBinaryFromFloat(unittest.TestCase): + """Use a binary index backed by a float index""" + + def test_index_from_float(self): + d = 256 + nt = 0 + nb = 1500 + nq = 500 + (xt, xb, xq) = make_binary_dataset(d, nb, nt, nq) + + index_ref = faiss.IndexFlatL2(d) + index_ref.add(binary_to_float(xb)) + + index = faiss.IndexFlatL2(d) + index_bin = faiss.IndexBinaryFromFloat(index) + index_bin.add(xb) + + D_ref, I_ref = index_ref.search(binary_to_float(xq), 10) + D, I = index_bin.search(xq, 10) + + np.testing.assert_allclose((D_ref / 4.0).astype('int32'), D) + + def test_wrapped_quantizer(self): + d = 256 + nt = 150 + nb = 1500 + nq = 500 + (xt, xb, xq) = make_binary_dataset(d, nb, nt, nq) + + nlist = 16 + quantizer_ref = faiss.IndexBinaryFlat(d) + index_ref = faiss.IndexBinaryIVF(quantizer_ref, d, nlist) + index_ref.train(xt) + + index_ref.add(xb) + + unwrapped_quantizer = faiss.IndexFlatL2(d) + quantizer = faiss.IndexBinaryFromFloat(unwrapped_quantizer) + index = faiss.IndexBinaryIVF(quantizer, d, nlist) + + index.train(xt) + + index.add(xb) + + D_ref, I_ref = index_ref.search(xq, 10) + D, I = index.search(xq, 10) + + np.testing.assert_array_equal(D_ref, D) + + def test_wrapped_quantizer_IMI(self): + d = 256 + nt = 3500 + nb = 10000 + nq = 500 + (xt, xb, xq) = make_binary_dataset(d, nb, nt, nq) + + index_ref = faiss.IndexBinaryFlat(d) + + index_ref.add(xb) + + nlist_exp = 6 + nlist = 2 ** (2 * nlist_exp) + float_quantizer = faiss.MultiIndexQuantizer(d, 2, nlist_exp) + wrapped_quantizer = faiss.IndexBinaryFromFloat(float_quantizer) + wrapped_quantizer.train(xt) + + assert nlist == float_quantizer.ntotal + + index = faiss.IndexBinaryIVF(wrapped_quantizer, d, + float_quantizer.ntotal) + index.nprobe = 2048 + assert index.is_trained + + index.add(xb) + + D_ref, I_ref = index_ref.search(xq, 10) + D, I = index.search(xq, 10) + + recall = sum(gti[0] in Di[:10] for gti, Di in zip(D_ref, D)) \ + / float(D_ref.shape[0]) + + assert recall > 0.82, "recall = %g" % recall + + def test_wrapped_quantizer_HNSW(self): + faiss.omp_set_num_threads(1) + + def bin2float(v): + def byte2float(byte): + return np.array([-1.0 + 2.0 * (byte & (1 << b) != 0) + for b in range(0, 8)]) + + return np.hstack([byte2float(byte) for byte in v]).astype('float32') + + def floatvec2nparray(v): + return np.array([np.float32(v.at(i)) for i in range(0, v.size())]) \ + .reshape(-1, d) + + d = 256 + nt = 12800 + nb = 10000 + nq = 500 + (xt, xb, xq) = make_binary_dataset(d, nb, nt, nq) + + index_ref = faiss.IndexBinaryFlat(d) + + index_ref.add(xb) + + nlist = 256 + clus = faiss.Clustering(d, nlist) + clus_index = faiss.IndexFlatL2(d) + + xt_f = np.array([bin2float(v) for v in xt]) + clus.train(xt_f, clus_index) + + centroids = floatvec2nparray(clus.centroids) + hnsw_quantizer = faiss.IndexHNSWFlat(d, 32) + hnsw_quantizer.add(centroids) + hnsw_quantizer.is_trained = True + wrapped_quantizer = faiss.IndexBinaryFromFloat(hnsw_quantizer) + + assert nlist == hnsw_quantizer.ntotal + assert nlist == wrapped_quantizer.ntotal + assert wrapped_quantizer.is_trained + + index = faiss.IndexBinaryIVF(wrapped_quantizer, d, + hnsw_quantizer.ntotal) + index.nprobe = 128 + + assert index.is_trained + + index.add(xb) + + D_ref, I_ref = index_ref.search(xq, 10) + D, I = index.search(xq, 10) + + recall = sum(gti[0] in Di[:10] for gti, Di in zip(D_ref, D)) \ + / float(D_ref.shape[0]) + + assert recall > 0.77, "recall = %g" % recall + + +class TestOverrideKmeansQuantizer(unittest.TestCase): + + def test_override(self): + d = 256 + nt = 3500 + nb = 10000 + nq = 500 + (xt, xb, xq) = make_binary_dataset(d, nb, nt, nq) + + def train_and_get_centroids(override_kmeans_index): + index = faiss.index_binary_factory(d, "BIVF10") + index.verbose = True + + if override_kmeans_index is not None: + index.clustering_index = override_kmeans_index + + index.train(xt) + + centroids = faiss.downcast_IndexBinary(index.quantizer).xb + return faiss.vector_to_array(centroids).reshape(-1, d // 8) + + centroids_ref = train_and_get_centroids(None) + + # should do the exact same thing + centroids_new = train_and_get_centroids(faiss.IndexFlatL2(d)) + + assert np.all(centroids_ref == centroids_new) + + # will do less accurate assignment... Sanity check that the + # index is indeed used by kmeans + centroids_new = train_and_get_centroids(faiss.IndexLSH(d, 16)) + + assert not np.all(centroids_ref == centroids_new) diff --git a/core/src/index/thirdparty/faiss/tests/test_index_composite.py b/core/src/index/thirdparty/faiss/tests/test_index_composite.py new file mode 100644 index 0000000000..55230f9d9b --- /dev/null +++ b/core/src/index/thirdparty/faiss/tests/test_index_composite.py @@ -0,0 +1,571 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" more elaborate that test_index.py """ +from __future__ import absolute_import, division, print_function + +import numpy as np +import unittest +import faiss +import os +import shutil +import tempfile + +from common import get_dataset_2 + +class TestRemove(unittest.TestCase): + + def do_merge_then_remove(self, ondisk): + d = 10 + nb = 1000 + nq = 200 + nt = 200 + + xt, xb, xq = get_dataset_2(d, nt, nb, nq) + + quantizer = faiss.IndexFlatL2(d) + + index1 = faiss.IndexIVFFlat(quantizer, d, 20) + index1.train(xt) + + filename = None + if ondisk: + filename = tempfile.mkstemp()[1] + invlists = faiss.OnDiskInvertedLists( + index1.nlist, index1.code_size, + filename) + index1.replace_invlists(invlists) + + index1.add(xb[:int(nb / 2)]) + + index2 = faiss.IndexIVFFlat(quantizer, d, 20) + assert index2.is_trained + index2.add(xb[int(nb / 2):]) + + Dref, Iref = index1.search(xq, 10) + index1.merge_from(index2, int(nb / 2)) + + assert index1.ntotal == nb + + index1.remove_ids(faiss.IDSelectorRange(int(nb / 2), nb)) + + assert index1.ntotal == int(nb / 2) + Dnew, Inew = index1.search(xq, 10) + + assert np.all(Dnew == Dref) + assert np.all(Inew == Iref) + + if filename is not None: + os.unlink(filename) + + def test_remove_regular(self): + self.do_merge_then_remove(False) + + def test_remove_ondisk(self): + self.do_merge_then_remove(True) + + def test_remove(self): + # only tests the python interface + + index = faiss.IndexFlat(5) + xb = np.zeros((10, 5), dtype='float32') + xb[:, 0] = np.arange(10) + 1000 + index.add(xb) + index.remove_ids(np.arange(5) * 2) + xb2 = faiss.vector_float_to_array(index.xb).reshape(5, 5) + assert np.all(xb2[:, 0] == xb[np.arange(5) * 2 + 1, 0]) + + def test_remove_id_map(self): + sub_index = faiss.IndexFlat(5) + xb = np.zeros((10, 5), dtype='float32') + xb[:, 0] = np.arange(10) + 1000 + index = faiss.IndexIDMap2(sub_index) + index.add_with_ids(xb, np.arange(10) + 100) + assert index.reconstruct(104)[0] == 1004 + index.remove_ids(np.array([103])) + assert index.reconstruct(104)[0] == 1004 + try: + index.reconstruct(103) + except RuntimeError: + pass + else: + assert False, 'should have raised an exception' + + def test_remove_id_map_2(self): + # from https://github.com/facebookresearch/faiss/issues/255 + rs = np.random.RandomState(1234) + X = rs.randn(10, 10).astype(np.float32) + idx = np.array([0, 10, 20, 30, 40, 5, 15, 25, 35, 45], np.int64) + remove_set = np.array([10, 30], dtype=np.int64) + index = faiss.index_factory(10, 'IDMap,Flat') + index.add_with_ids(X[:5, :], idx[:5]) + index.remove_ids(remove_set) + index.add_with_ids(X[5:, :], idx[5:]) + + print (index.search(X, 1)) + + for i in range(10): + _, searchres = index.search(X[i:i + 1, :], 1) + if idx[i] in remove_set: + assert searchres[0] != idx[i] + else: + assert searchres[0] == idx[i] + + def test_remove_id_map_binary(self): + sub_index = faiss.IndexBinaryFlat(40) + xb = np.zeros((10, 5), dtype='uint8') + xb[:, 0] = np.arange(10) + 100 + index = faiss.IndexBinaryIDMap2(sub_index) + index.add_with_ids(xb, np.arange(10) + 1000) + assert index.reconstruct(1004)[0] == 104 + index.remove_ids(np.array([1003])) + assert index.reconstruct(1004)[0] == 104 + try: + index.reconstruct(1003) + except RuntimeError: + pass + else: + assert False, 'should have raised an exception' + + # while we are there, let's test I/O as well... + _, tmpnam = tempfile.mkstemp() + try: + faiss.write_index_binary(index, tmpnam) + index = faiss.read_index_binary(tmpnam) + finally: + os.remove(tmpnam) + + assert index.reconstruct(1004)[0] == 104 + try: + index.reconstruct(1003) + except RuntimeError: + pass + else: + assert False, 'should have raised an exception' + + + +class TestRangeSearch(unittest.TestCase): + + def test_range_search_id_map(self): + sub_index = faiss.IndexFlat(5, 1) # L2 search instead of inner product + xb = np.zeros((10, 5), dtype='float32') + xb[:, 0] = np.arange(10) + 1000 + index = faiss.IndexIDMap2(sub_index) + index.add_with_ids(xb, np.arange(10) + 100) + dist = float(np.linalg.norm(xb[3] - xb[0])) * 0.99 + res_subindex = sub_index.range_search(xb[[0], :], dist) + res_index = index.range_search(xb[[0], :], dist) + assert len(res_subindex[2]) == 2 + np.testing.assert_array_equal(res_subindex[2] + 100, res_index[2]) + + +class TestUpdate(unittest.TestCase): + + def test_update(self): + d = 64 + nb = 1000 + nt = 1500 + nq = 100 + np.random.seed(123) + xb = np.random.random(size=(nb, d)).astype('float32') + xt = np.random.random(size=(nt, d)).astype('float32') + xq = np.random.random(size=(nq, d)).astype('float32') + + index = faiss.index_factory(d, "IVF64,Flat") + index.train(xt) + index.add(xb) + index.nprobe = 32 + D, I = index.search(xq, 5) + + index.make_direct_map() + recons_before = np.vstack([index.reconstruct(i) for i in range(nb)]) + + # revert order of the 200 first vectors + nu = 200 + index.update_vectors(np.arange(nu), xb[nu - 1::-1].copy()) + + recons_after = np.vstack([index.reconstruct(i) for i in range(nb)]) + + # make sure reconstructions remain the same + diff_recons = recons_before[:nu] - recons_after[nu - 1::-1] + assert np.abs(diff_recons).max() == 0 + + D2, I2 = index.search(xq, 5) + + assert np.all(D == D2) + + gt_map = np.arange(nb) + gt_map[:nu] = np.arange(nu, 0, -1) - 1 + eqs = I.ravel() == gt_map[I2.ravel()] + + assert np.all(eqs) + + +class TestPCAWhite(unittest.TestCase): + + def test_white(self): + + # generate data + d = 4 + nt = 1000 + nb = 200 + nq = 200 + + # normal distribition + x = faiss.randn((nt + nb + nq) * d, 1234).reshape(nt + nb + nq, d) + + index = faiss.index_factory(d, 'Flat') + + xt = x[:nt] + xb = x[nt:-nq] + xq = x[-nq:] + + # NN search on normal distribution + index.add(xb) + Do, Io = index.search(xq, 5) + + # make distribution very skewed + x *= [10, 4, 1, 0.5] + rr, _ = np.linalg.qr(faiss.randn(d * d).reshape(d, d)) + x = np.dot(x, rr).astype('float32') + + xt = x[:nt] + xb = x[nt:-nq] + xq = x[-nq:] + + # L2 search on skewed distribution + index = faiss.index_factory(d, 'Flat') + + index.add(xb) + Dl2, Il2 = index.search(xq, 5) + + # whiten + L2 search on L2 distribution + index = faiss.index_factory(d, 'PCAW%d,Flat' % d) + + index.train(xt) + index.add(xb) + Dw, Iw = index.search(xq, 5) + + # make sure correlation of whitened results with original + # results is much better than simple L2 distances + # should be 961 vs. 264 + assert (faiss.eval_intersection(Io, Iw) > + 2 * faiss.eval_intersection(Io, Il2)) + + +class TestTransformChain(unittest.TestCase): + + def test_chain(self): + + # generate data + d = 4 + nt = 1000 + nb = 200 + nq = 200 + + # normal distribition + x = faiss.randn((nt + nb + nq) * d, 1234).reshape(nt + nb + nq, d) + + # make distribution very skewed + x *= [10, 4, 1, 0.5] + rr, _ = np.linalg.qr(faiss.randn(d * d).reshape(d, d)) + x = np.dot(x, rr).astype('float32') + + xt = x[:nt] + xb = x[nt:-nq] + xq = x[-nq:] + + index = faiss.index_factory(d, "L2norm,PCA2,L2norm,Flat") + + assert index.chain.size() == 3 + l2_1 = faiss.downcast_VectorTransform(index.chain.at(0)) + assert l2_1.norm == 2 + pca = faiss.downcast_VectorTransform(index.chain.at(1)) + assert not pca.is_trained + index.train(xt) + assert pca.is_trained + + index.add(xb) + D, I = index.search(xq, 5) + + # do the computation manually and check if we get the same result + def manual_trans(x): + x = x.copy() + faiss.normalize_L2(x) + x = pca.apply_py(x) + faiss.normalize_L2(x) + return x + + index2 = faiss.IndexFlatL2(2) + index2.add(manual_trans(xb)) + D2, I2 = index2.search(manual_trans(xq), 5) + + assert np.all(I == I2) + +class TestRareIO(unittest.TestCase): + + def compare_results(self, index1, index2, xq): + + Dref, Iref = index1.search(xq, 5) + Dnew, Inew = index2.search(xq, 5) + + assert np.all(Dref == Dnew) + assert np.all(Iref == Inew) + + def do_mmappedIO(self, sparse, in_pretransform=False): + d = 10 + nb = 1000 + nq = 200 + nt = 200 + xt, xb, xq = get_dataset_2(d, nt, nb, nq) + + quantizer = faiss.IndexFlatL2(d) + index1 = faiss.IndexIVFFlat(quantizer, d, 20) + if sparse: + # makes the inverted lists sparse because all elements get + # assigned to the same invlist + xt += (np.ones(10) * 1000).astype('float32') + + if in_pretransform: + # make sure it still works when wrapped in an IndexPreTransform + index1 = faiss.IndexPreTransform(index1) + + index1.train(xt) + index1.add(xb) + + _, fname = tempfile.mkstemp() + try: + + faiss.write_index(index1, fname) + + index2 = faiss.read_index(fname) + self.compare_results(index1, index2, xq) + + index3 = faiss.read_index(fname, faiss.IO_FLAG_MMAP) + self.compare_results(index1, index3, xq) + finally: + if os.path.exists(fname): + os.unlink(fname) + + def test_mmappedIO_sparse(self): + self.do_mmappedIO(True) + + def test_mmappedIO_full(self): + self.do_mmappedIO(False) + + def test_mmappedIO_pretrans(self): + self.do_mmappedIO(False, True) + + +class TestIVFFlatDedup(unittest.TestCase): + + def normalize_res(self, D, I): + dmax = D[-1] + res = [(d, i) for d, i in zip(D, I) if d < dmax] + res.sort() + return res + + def test_dedup(self): + d = 10 + nb = 1000 + nq = 200 + nt = 500 + xt, xb, xq = get_dataset_2(d, nt, nb, nq) + + # introduce duplicates + xb[500:900:2] = xb[501:901:2] + xb[901::4] = xb[900::4] + xb[902::4] = xb[900::4] + xb[903::4] = xb[900::4] + + # also in the train set + xt[201::2] = xt[200::2] + + quantizer = faiss.IndexFlatL2(d) + index_new = faiss.IndexIVFFlatDedup(quantizer, d, 20) + + index_new.verbose = True + # should display + # IndexIVFFlatDedup::train: train on 350 points after dedup (was 500 points) + index_new.train(xt) + + index_ref = faiss.IndexIVFFlat(quantizer, d, 20) + assert index_ref.is_trained + + index_ref.nprobe = 5 + index_ref.add(xb) + index_new.nprobe = 5 + index_new.add(xb) + + Dref, Iref = index_ref.search(xq, 20) + Dnew, Inew = index_new.search(xq, 20) + + for i in range(nq): + ref = self.normalize_res(Dref[i], Iref[i]) + new = self.normalize_res(Dnew[i], Inew[i]) + assert ref == new + + # test I/O + _, tmpfile = tempfile.mkstemp() + try: + faiss.write_index(index_new, tmpfile) + index_st = faiss.read_index(tmpfile) + finally: + if os.path.exists(tmpfile): + os.unlink(tmpfile) + Dst, Ist = index_st.search(xq, 20) + + for i in range(nq): + new = self.normalize_res(Dnew[i], Inew[i]) + st = self.normalize_res(Dst[i], Ist[i]) + assert st == new + + # test remove + toremove = np.hstack((np.arange(3, 1000, 5), np.arange(850, 950))) + index_ref.remove_ids(toremove) + index_new.remove_ids(toremove) + + Dref, Iref = index_ref.search(xq, 20) + Dnew, Inew = index_new.search(xq, 20) + + for i in range(nq): + ref = self.normalize_res(Dref[i], Iref[i]) + new = self.normalize_res(Dnew[i], Inew[i]) + assert ref == new + + +class TestSerialize(unittest.TestCase): + + def test_serialize_to_vector(self): + d = 10 + nb = 1000 + nq = 200 + nt = 500 + xt, xb, xq = get_dataset_2(d, nt, nb, nq) + + index = faiss.IndexFlatL2(d) + index.add(xb) + + Dref, Iref = index.search(xq, 5) + + writer = faiss.VectorIOWriter() + faiss.write_index(index, writer) + + ar_data = faiss.vector_to_array(writer.data) + + # direct transfer of vector + reader = faiss.VectorIOReader() + reader.data.swap(writer.data) + + index2 = faiss.read_index(reader) + + Dnew, Inew = index2.search(xq, 5) + assert np.all(Dnew == Dref) and np.all(Inew == Iref) + + # from intermediate numpy array + reader = faiss.VectorIOReader() + faiss.copy_array_to_vector(ar_data, reader.data) + + index3 = faiss.read_index(reader) + + Dnew, Inew = index3.search(xq, 5) + assert np.all(Dnew == Dref) and np.all(Inew == Iref) + + +class TestRenameOndisk(unittest.TestCase): + + def test_rename(self): + d = 10 + nb = 500 + nq = 100 + nt = 100 + + xt, xb, xq = get_dataset_2(d, nt, nb, nq) + + quantizer = faiss.IndexFlatL2(d) + + index1 = faiss.IndexIVFFlat(quantizer, d, 20) + index1.train(xt) + + dirname = tempfile.mkdtemp() + + try: + + # make an index with ondisk invlists + invlists = faiss.OnDiskInvertedLists( + index1.nlist, index1.code_size, + dirname + '/aa.ondisk') + index1.replace_invlists(invlists) + index1.add(xb) + D1, I1 = index1.search(xq, 10) + faiss.write_index(index1, dirname + '/aa.ivf') + + # move the index elsewhere + os.mkdir(dirname + '/1') + for fname in 'aa.ondisk', 'aa.ivf': + os.rename(dirname + '/' + fname, + dirname + '/1/' + fname) + + # try to read it: fails! + try: + index2 = faiss.read_index(dirname + '/1/aa.ivf') + except RuntimeError: + pass # normal + else: + assert False + + # read it with magic flag + index2 = faiss.read_index(dirname + '/1/aa.ivf', + faiss.IO_FLAG_ONDISK_SAME_DIR) + D2, I2 = index2.search(xq, 10) + assert np.all(I1 == I2) + + finally: + shutil.rmtree(dirname) + + +class TestInvlistMeta(unittest.TestCase): + + def test_slice_vstack(self): + d = 10 + nb = 1000 + nq = 100 + nt = 200 + + xt, xb, xq = get_dataset_2(d, nt, nb, nq) + + quantizer = faiss.IndexFlatL2(d) + index = faiss.IndexIVFFlat(quantizer, d, 30) + + index.train(xt) + index.add(xb) + Dref, Iref = index.search(xq, 10) + + # faiss.wait() + + il0 = index.invlists + ils = [] + ilv = faiss.InvertedListsPtrVector() + for sl in 0, 1, 2: + il = faiss.SliceInvertedLists(il0, sl * 10, sl * 10 + 10) + ils.append(il) + ilv.push_back(il) + + il2 = faiss.VStackInvertedLists(ilv.size(), ilv.data()) + + index2 = faiss.IndexIVFFlat(quantizer, d, 30) + index2.replace_invlists(il2) + index2.ntotal = index.ntotal + + D, I = index2.search(xq, 10) + assert np.all(D == Dref) + assert np.all(I == Iref) + + + + +if __name__ == '__main__': + unittest.main() diff --git a/core/src/index/thirdparty/faiss/tests/test_io.py b/core/src/index/thirdparty/faiss/tests/test_io.py new file mode 100644 index 0000000000..7e3d6edf59 --- /dev/null +++ b/core/src/index/thirdparty/faiss/tests/test_io.py @@ -0,0 +1,220 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +#!/usr/bin/env python3 + +import numpy as np +import unittest +import faiss +import tempfile +import os +import io +import sys +import warnings +from multiprocessing.dummy import Pool as ThreadPool + +from common import get_dataset, get_dataset_2 + + +class TestIOVariants(unittest.TestCase): + + def test_io_error(self): + d, n = 32, 1000 + x = np.random.uniform(size=(n, d)).astype('float32') + index = faiss.IndexFlatL2(d) + index.add(x) + _, fname = tempfile.mkstemp() + try: + faiss.write_index(index, fname) + + # should be fine + faiss.read_index(fname) + + # now damage file + data = open(fname, 'rb').read() + data = data[:int(len(data) / 2)] + open(fname, 'wb').write(data) + + # should make a nice readable exception that mentions the filename + try: + faiss.read_index(fname) + except RuntimeError as e: + if fname not in str(e): + raise + else: + raise + + finally: + if os.path.exists(fname): + os.unlink(fname) + + +class TestCallbacks(unittest.TestCase): + + def do_write_callback(self, bsz): + d, n = 32, 1000 + x = np.random.uniform(size=(n, d)).astype('float32') + index = faiss.IndexFlatL2(d) + index.add(x) + + f = io.BytesIO() + # test with small block size + writer = faiss.PyCallbackIOWriter(f.write, 1234) + + if bsz > 0: + writer = faiss.BufferedIOWriter(writer, bsz) + + faiss.write_index(index, writer) + del writer # make sure all writes committed + + if sys.version_info[0] < 3: + buf = f.getvalue() + else: + buf = f.getbuffer() + + index2 = faiss.deserialize_index(np.frombuffer(buf, dtype='uint8')) + + self.assertEqual(index.d, index2.d) + self.assertTrue(np.all( + faiss.vector_to_array(index.xb) == faiss.vector_to_array(index2.xb) + )) + + # This is not a callable function: shoudl raise an exception + writer = faiss.PyCallbackIOWriter("blabla") + self.assertRaises( + Exception, + faiss.write_index, index, writer + ) + + def test_buf_read(self): + x = np.random.uniform(size=20) + + _, fname = tempfile.mkstemp() + try: + x.tofile(fname) + + f = open(fname, 'rb') + reader = faiss.PyCallbackIOReader(f.read, 1234) + + bsz = 123 + reader = faiss.BufferedIOReader(reader, bsz) + + y = np.zeros_like(x) + print('nbytes=', y.nbytes) + reader(faiss.swig_ptr(y), y.nbytes, 1) + + np.testing.assert_array_equal(x, y) + finally: + if os.path.exists(fname): + os.unlink(fname) + + def do_read_callback(self, bsz): + d, n = 32, 1000 + x = np.random.uniform(size=(n, d)).astype('float32') + index = faiss.IndexFlatL2(d) + index.add(x) + + _, fname = tempfile.mkstemp() + try: + faiss.write_index(index, fname) + + f = open(fname, 'rb') + + reader = faiss.PyCallbackIOReader(f.read, 1234) + + if bsz > 0: + reader = faiss.BufferedIOReader(reader, bsz) + + index2 = faiss.read_index(reader) + + self.assertEqual(index.d, index2.d) + np.testing.assert_array_equal( + faiss.vector_to_array(index.xb), + faiss.vector_to_array(index2.xb) + ) + + # This is not a callable function: should raise an exception + reader = faiss.PyCallbackIOReader("blabla") + self.assertRaises( + Exception, + faiss.read_index, reader + ) + finally: + if os.path.exists(fname): + os.unlink(fname) + + def test_write_callback(self): + self.do_write_callback(0) + + def test_write_buffer(self): + self.do_write_callback(123) + self.do_write_callback(2345) + + def test_read_callback(self): + self.do_read_callback(0) + + def test_read_callback_buffered(self): + self.do_read_callback(123) + self.do_read_callback(12345) + + def test_read_buffer(self): + d, n = 32, 1000 + x = np.random.uniform(size=(n, d)).astype('float32') + index = faiss.IndexFlatL2(d) + index.add(x) + + _, fname = tempfile.mkstemp() + try: + faiss.write_index(index, fname) + + reader = faiss.BufferedIOReader( + faiss.FileIOReader(fname), 1234) + + index2 = faiss.read_index(reader) + + self.assertEqual(index.d, index2.d) + np.testing.assert_array_equal( + faiss.vector_to_array(index.xb), + faiss.vector_to_array(index2.xb) + ) + + finally: + if os.path.exists(fname): + os.unlink(fname) + + + def test_transfer_pipe(self): + """ transfer an index through a Unix pipe """ + + d, n = 32, 1000 + x = np.random.uniform(size=(n, d)).astype('float32') + index = faiss.IndexFlatL2(d) + index.add(x) + Dref, Iref = index.search(x, 10) + + rf, wf = os.pipe() + + # start thread that will decompress the index + + def index_from_pipe(): + reader = faiss.PyCallbackIOReader(lambda size: os.read(rf, size)) + return faiss.read_index(reader) + + fut = ThreadPool(1).apply_async(index_from_pipe, ()) + + # write to pipe + writer = faiss.PyCallbackIOWriter(lambda b: os.write(wf, b)) + faiss.write_index(index, writer) + + index2 = fut.get() + + # closing is not really useful but it does not hurt + os.close(wf) + os.close(rf) + + Dnew, Inew = index2.search(x, 10) + + np.testing.assert_array_equal(Iref, Inew) + np.testing.assert_array_equal(Dref, Dnew) diff --git a/core/src/index/thirdparty/faiss/tests/test_ivflib.py b/core/src/index/thirdparty/faiss/tests/test_ivflib.py new file mode 100644 index 0000000000..0166013c08 --- /dev/null +++ b/core/src/index/thirdparty/faiss/tests/test_ivflib.py @@ -0,0 +1,78 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import absolute_import, division, print_function + +import unittest +import faiss +import numpy as np + +class TestIVFlib(unittest.TestCase): + + def test_methods_exported(self): + methods = ['check_compatible_for_merge', 'extract_index_ivf', + 'merge_into', 'search_centroid', + 'search_and_return_centroids', 'get_invlist_range', + 'set_invlist_range', 'search_with_parameters'] + + for method in methods: + assert callable(getattr(faiss, method, None)) + + +def search_single_scan(index, xq, k, bs=128): + """performs a search so that the inverted lists are accessed + sequentially by blocks of size bs""" + + # handle pretransform + if isinstance(index, faiss.IndexPreTransform): + xq = index.apply_py(xq) + index = faiss.downcast_index(index.index) + + # coarse assignment + coarse_dis, assign = index.quantizer.search(xq, index.nprobe) + nlist = index.nlist + assign_buckets = assign // bs + nq = len(xq) + + rh = faiss.ResultHeap(nq, k) + index.parallel_mode |= index.PARALLEL_MODE_NO_HEAP_INIT + + for l0 in range(0, nlist, bs): + bucket_no = l0 // bs + skip_rows, skip_cols = np.where(assign_buckets != bucket_no) + sub_assign = assign.copy() + sub_assign[skip_rows, skip_cols] = -1 + + index.search_preassigned( + nq, faiss.swig_ptr(xq), k, + faiss.swig_ptr(sub_assign), faiss.swig_ptr(coarse_dis), + faiss.swig_ptr(rh.D), faiss.swig_ptr(rh.I), + False, None + ) + + rh.finalize() + + return rh.D, rh.I + + +class TestSequentialScan(unittest.TestCase): + + def test_sequential_scan(self): + d = 20 + index = faiss.index_factory(d, 'IVF100,SQ8') + + rs = np.random.RandomState(123) + xt = rs.rand(5000, d).astype('float32') + xb = rs.rand(10000, d).astype('float32') + index.train(xt) + index.add(xb) + k = 15 + xq = rs.rand(200, d).astype('float32') + + ref_D, ref_I = index.search(xq, k) + D, I = search_single_scan(index, xq, k, bs=10) + + assert np.all(D == ref_D) + assert np.all(I == ref_I) diff --git a/core/src/index/thirdparty/faiss/tests/test_ivfpq_codec.cpp b/core/src/index/thirdparty/faiss/tests/test_ivfpq_codec.cpp new file mode 100644 index 0000000000..8d18ac0ad9 --- /dev/null +++ b/core/src/index/thirdparty/faiss/tests/test_ivfpq_codec.cpp @@ -0,0 +1,67 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +#include + +#include +#include +#include +#include + + +namespace { + +// dimension of the vectors to index +int d = 64; + +// size of the database we plan to index +size_t nb = 8000; + + +double eval_codec_error (long ncentroids, long m, const std::vector &v) +{ + faiss::IndexFlatL2 coarse_quantizer (d); + faiss::IndexIVFPQ index (&coarse_quantizer, d, + ncentroids, m, 8); + index.pq.cp.niter = 10; // speed up train + index.train (nb, v.data()); + + // encode and decode to compute reconstruction error + + std::vector keys (nb); + std::vector codes (nb * m); + index.encode_multiple (nb, keys.data(), v.data(), codes.data(), true); + + std::vector v2 (nb * d); + index.decode_multiple (nb, keys.data(), codes.data(), v2.data()); + + return faiss::fvec_L2sqr (v.data(), v2.data(), nb * d); +} + +} // namespace + + +TEST(IVFPQ, codec) { + + std::vector database (nb * d); + for (size_t i = 0; i < nb * d; i++) { + database[i] = drand48(); + } + + double err0 = eval_codec_error(16, 8, database); + + // should be more accurate as there are more coarse centroids + double err1 = eval_codec_error(128, 8, database); + EXPECT_GT(err0, err1); + + // should be more accurate as there are more PQ codes + double err2 = eval_codec_error(16, 16, database); + EXPECT_GT(err0, err2); +} diff --git a/core/src/index/thirdparty/faiss/tests/test_ivfpq_indexing.cpp b/core/src/index/thirdparty/faiss/tests/test_ivfpq_indexing.cpp new file mode 100644 index 0000000000..9f4bbcd2ca --- /dev/null +++ b/core/src/index/thirdparty/faiss/tests/test_ivfpq_indexing.cpp @@ -0,0 +1,98 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#include +#include + +#include + +#include +#include +#include + +TEST(IVFPQ, accuracy) { + + // dimension of the vectors to index + int d = 64; + + // size of the database we plan to index + size_t nb = 1000; + + // make a set of nt training vectors in the unit cube + // (could be the database) + size_t nt = 1500; + + // make the index object and train it + faiss::IndexFlatL2 coarse_quantizer (d); + + // a reasonable number of cetroids to index nb vectors + int ncentroids = 25; + + faiss::IndexIVFPQ index (&coarse_quantizer, d, + ncentroids, 16, 8); + + // index that gives the ground-truth + faiss::IndexFlatL2 index_gt (d); + + srand48 (35); + + { // training + + std::vector trainvecs (nt * d); + for (size_t i = 0; i < nt * d; i++) { + trainvecs[i] = drand48(); + } + index.verbose = true; + index.train (nt, trainvecs.data()); + } + + { // populating the database + + std::vector database (nb * d); + for (size_t i = 0; i < nb * d; i++) { + database[i] = drand48(); + } + + index.add (nb, database.data()); + index_gt.add (nb, database.data()); + } + + int nq = 200; + int n_ok; + + { // searching the database + + std::vector queries (nq * d); + for (size_t i = 0; i < nq * d; i++) { + queries[i] = drand48(); + } + + std::vector gt_nns (nq); + std::vector gt_dis (nq); + + index_gt.search (nq, queries.data(), 1, + gt_dis.data(), gt_nns.data()); + + index.nprobe = 5; + int k = 5; + std::vector nns (k * nq); + std::vector dis (k * nq); + + index.search (nq, queries.data(), k, dis.data(), nns.data()); + + n_ok = 0; + for (int q = 0; q < nq; q++) { + + for (int i = 0; i < k; i++) + if (nns[q * k + i] == gt_nns[q]) + n_ok++; + } + EXPECT_GT(n_ok, nq * 0.4); + } + +} diff --git a/core/src/index/thirdparty/faiss/tests/test_lowlevel_ivf.cpp b/core/src/index/thirdparty/faiss/tests/test_lowlevel_ivf.cpp new file mode 100644 index 0000000000..7baf801b7b --- /dev/null +++ b/core/src/index/thirdparty/faiss/tests/test_lowlevel_ivf.cpp @@ -0,0 +1,566 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +#include +#include +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace faiss; + +namespace { + +typedef Index::idx_t idx_t; + + +// dimension of the vectors to index +int d = 32; + +// nb of training vectors +size_t nt = 5000; + +// size of the database points per window step +size_t nb = 1000; + +// nb of queries +size_t nq = 200; + +int k = 10; + + +std::vector make_data(size_t n) +{ + std::vector database (n * d); + for (size_t i = 0; i < n * d; i++) { + database[i] = drand48(); + } + return database; +} + +std::unique_ptr make_trained_index(const char *index_type, + MetricType metric_type) +{ + auto index = std::unique_ptr(index_factory( + d, index_type, metric_type)); + auto xt = make_data(nt); + index->train(nt, xt.data()); + ParameterSpace().set_index_parameter (index.get(), "nprobe", 4); + return index; +} + +std::vector search_index(Index *index, const float *xq) { + std::vector I(k * nq); + std::vector D(k * nq); + index->search (nq, xq, k, D.data(), I.data()); + return I; +} + + + + +/************************************************************* + * Test functions for a given index type + *************************************************************/ + + + +void test_lowlevel_access (const char *index_key, MetricType metric) { + std::unique_ptr index = make_trained_index(index_key, metric); + + auto xb = make_data (nb); + index->add(nb, xb.data()); + + /** handle the case if we have a preprocessor */ + + const IndexPreTransform *index_pt = + dynamic_cast (index.get()); + + int dt = index->d; + const float * xbt = xb.data(); + std::unique_ptr del_xbt; + + if (index_pt) { + dt = index_pt->index->d; + xbt = index_pt->apply_chain (nb, xb.data()); + if (xbt != xb.data()) { + del_xbt.reset((float*)xbt); + } + } + + IndexIVF * index_ivf = ivflib::extract_index_ivf (index.get()); + + /** Test independent encoding + * + * Makes it possible to do additions on a custom inverted list + * implementation. From a set of vectors, computes the inverted + * list ids + the codes corresponding to each vector. + */ + + std::vector list_nos (nb); + std::vector codes (index_ivf->code_size * nb); + index_ivf->quantizer->assign(nb, xbt, list_nos.data()); + index_ivf->encode_vectors (nb, xbt, list_nos.data(), codes.data()); + + // compare with normal IVF addition + + const InvertedLists *il = index_ivf->invlists; + + for (int list_no = 0; list_no < index_ivf->nlist; list_no++) { + InvertedLists::ScopedCodes ivf_codes (il, list_no); + InvertedLists::ScopedIds ivf_ids (il, list_no); + size_t list_size = il->list_size (list_no); + for (int i = 0; i < list_size; i++) { + const uint8_t *ref_code = ivf_codes.get() + i * il->code_size; + const uint8_t *new_code = + codes.data() + ivf_ids[i] * il->code_size; + EXPECT_EQ (memcmp(ref_code, new_code, il->code_size), 0); + } + } + + /** Test independent search + * + * Manually scans through inverted lists, computing distances and + * ordering results organized in a heap. + */ + + // sample some example queries and get reference search results. + auto xq = make_data (nq); + auto ref_I = search_index (index.get(), xq.data()); + + // handle preprocessing + const float * xqt = xq.data(); + std::unique_ptr del_xqt; + + if (index_pt) { + xqt = index_pt->apply_chain (nq, xq.data()); + if (xqt != xq.data()) { + del_xqt.reset((float*)xqt); + } + } + + // quantize the queries to get the inverted list ids to visit. + int nprobe = index_ivf->nprobe; + + std::vector q_lists (nq * nprobe); + std::vector q_dis (nq * nprobe); + + index_ivf->quantizer->search (nq, xqt, nprobe, + q_dis.data(), q_lists.data()); + + // object that does the scanning and distance computations. + std::unique_ptr scanner ( + index_ivf->get_InvertedListScanner()); + + for (int i = 0; i < nq; i++) { + std::vector I (k, -1); + float default_dis = metric == METRIC_L2 ? HUGE_VAL : -HUGE_VAL; + std::vector D (k, default_dis); + + scanner->set_query (xqt + i * dt); + + for (int j = 0; j < nprobe; j++) { + int list_no = q_lists[i * nprobe + j]; + if (list_no < 0) continue; + scanner->set_list (list_no, q_dis[i * nprobe + j]); + + // here we get the inverted lists from the InvertedLists + // object but they could come from anywhere + + scanner->scan_codes ( + il->list_size (list_no), + InvertedLists::ScopedCodes(il, list_no).get(), + InvertedLists::ScopedIds(il, list_no).get(), + D.data(), I.data(), k); + + if (j == 0) { + // all results so far come from list_no, so let's check if + // the distance function works + for (int jj = 0; jj < k; jj++) { + int vno = I[jj]; + if (vno < 0) break; // heap is not full yet + + // we have the codes from the addition test + float computed_D = scanner->distance_to_code ( + codes.data() + vno * il->code_size); + + EXPECT_EQ (computed_D, D[jj]); + } + } + } + + // re-order heap + if (metric == METRIC_L2) { + maxheap_reorder (k, D.data(), I.data()); + } else { + minheap_reorder (k, D.data(), I.data()); + } + + // check that we have the same results as the reference search + for (int j = 0; j < k; j++) { + EXPECT_EQ (I[j], ref_I[i * k + j]); + } + } + + +} + +} // anonymous namespace + + + +/************************************************************* + * Test entry points + *************************************************************/ + +TEST(TestLowLevelIVF, IVFFlatL2) { + test_lowlevel_access ("IVF32,Flat", METRIC_L2); +} + +TEST(TestLowLevelIVF, PCAIVFFlatL2) { + test_lowlevel_access ("PCAR16,IVF32,Flat", METRIC_L2); +} + +TEST(TestLowLevelIVF, IVFFlatIP) { + test_lowlevel_access ("IVF32,Flat", METRIC_INNER_PRODUCT); +} + +TEST(TestLowLevelIVF, IVFSQL2) { + test_lowlevel_access ("IVF32,SQ8", METRIC_L2); +} + +TEST(TestLowLevelIVF, IVFSQIP) { + test_lowlevel_access ("IVF32,SQ8", METRIC_INNER_PRODUCT); +} + + +TEST(TestLowLevelIVF, IVFPQL2) { + test_lowlevel_access ("IVF32,PQ4np", METRIC_L2); +} + +TEST(TestLowLevelIVF, IVFPQIP) { + test_lowlevel_access ("IVF32,PQ4np", METRIC_INNER_PRODUCT); +} + + +/************************************************************* + * Same for binary (a bit simpler) + *************************************************************/ + +namespace { + +int nbit = 256; + +// here d is used the number of ints -> d=32 means 128 bits + +std::vector make_data_binary(size_t n) +{ + + std::vector database (n * nbit / 8); + for (size_t i = 0; i < n * d; i++) { + database[i] = lrand48(); + } + return database; +} + +std::unique_ptr make_trained_index_binary(const char *index_type) +{ + auto index = std::unique_ptr(index_binary_factory( + nbit, index_type)); + auto xt = make_data_binary (nt); + index->train(nt, xt.data()); + return index; +} + + +void test_lowlevel_access_binary (const char *index_key) { + std::unique_ptr index = + make_trained_index_binary (index_key); + + IndexBinaryIVF * index_ivf = dynamic_cast + (index.get()); + assert (index_ivf); + + index_ivf->nprobe = 4; + + auto xb = make_data_binary (nb); + index->add(nb, xb.data()); + + std::vector list_nos (nb); + index_ivf->quantizer->assign(nb, xb.data(), list_nos.data()); + + /* For binary there is no test for encoding because binary vectors + * are copied verbatim to the inverted lists */ + + const InvertedLists *il = index_ivf->invlists; + + /** Test independent search + * + * Manually scans through inverted lists, computing distances and + * ordering results organized in a heap. + */ + + // sample some example queries and get reference search results. + auto xq = make_data_binary (nq); + + std::vector I_ref(k * nq); + std::vector D_ref(k * nq); + index->search (nq, xq.data(), k, D_ref.data(), I_ref.data()); + + // quantize the queries to get the inverted list ids to visit. + int nprobe = index_ivf->nprobe; + + std::vector q_lists (nq * nprobe); + std::vector q_dis (nq * nprobe); + + // quantize queries + index_ivf->quantizer->search (nq, xq.data(), nprobe, + q_dis.data(), q_lists.data()); + + // object that does the scanning and distance computations. + std::unique_ptr scanner ( + index_ivf->get_InvertedListScanner()); + + for (int i = 0; i < nq; i++) { + std::vector I (k, -1); + uint32_t default_dis = 1 << 30; + std::vector D (k, default_dis); + + scanner->set_query (xq.data() + i * index_ivf->code_size); + + for (int j = 0; j < nprobe; j++) { + int list_no = q_lists[i * nprobe + j]; + if (list_no < 0) continue; + scanner->set_list (list_no, q_dis[i * nprobe + j]); + + // here we get the inverted lists from the InvertedLists + // object but they could come from anywhere + + scanner->scan_codes ( + il->list_size (list_no), + InvertedLists::ScopedCodes(il, list_no).get(), + InvertedLists::ScopedIds(il, list_no).get(), + D.data(), I.data(), k); + + if (j == 0) { + // all results so far come from list_no, so let's check if + // the distance function works + for (int jj = 0; jj < k; jj++) { + int vno = I[jj]; + if (vno < 0) break; // heap is not full yet + + // we have the codes from the addition test + float computed_D = scanner->distance_to_code ( + xb.data() + vno * il->code_size); + + EXPECT_EQ (computed_D, D[jj]); + } + } + } + + printf("new before reroder: ["); + for (int j = 0; j < k; j++) + printf("%ld,%d ", I[j], D[j]); + printf("]\n"); + + // re-order heap + heap_reorder > (k, D.data(), I.data()); + + printf("ref: ["); + for (int j = 0; j < k; j++) + printf("%ld,%d ", I_ref[j], D_ref[j]); + printf("]\nnew: ["); + for (int j = 0; j < k; j++) + printf("%ld,%d ", I[j], D[j]); + printf("]\n"); + + // check that we have the same results as the reference search + for (int j = 0; j < k; j++) { + // here the order is not guaranteed to be the same + // so we scan through ref results + // EXPECT_EQ (I[j], I_ref[i * k + j]); + EXPECT_LE (D[j], D_ref[i * k + k - 1]); + if (D[j] < D_ref[i * k + k - 1]) { + int j2 = 0; + while (j2 < k) { + if (I[j] == I_ref[i * k + j2]) break; + j2++; + } + EXPECT_LT(j2, k); // it was found + if (j2 < k) { + EXPECT_EQ(D[j], D_ref[i * k + j2]); + } + } + + } + + } + + +} + +} // anonymous namespace + + +TEST(TestLowLevelIVF, IVFBinary) { + test_lowlevel_access_binary ("BIVF32"); +} + + +namespace { + +void test_threaded_search (const char *index_key, MetricType metric) { + std::unique_ptr index = make_trained_index(index_key, metric); + + auto xb = make_data (nb); + index->add(nb, xb.data()); + + /** handle the case if we have a preprocessor */ + + const IndexPreTransform *index_pt = + dynamic_cast (index.get()); + + int dt = index->d; + const float * xbt = xb.data(); + std::unique_ptr del_xbt; + + if (index_pt) { + dt = index_pt->index->d; + xbt = index_pt->apply_chain (nb, xb.data()); + if (xbt != xb.data()) { + del_xbt.reset((float*)xbt); + } + } + + IndexIVF * index_ivf = ivflib::extract_index_ivf (index.get()); + + /** Test independent search + * + * Manually scans through inverted lists, computing distances and + * ordering results organized in a heap. + */ + + // sample some example queries and get reference search results. + auto xq = make_data (nq); + auto ref_I = search_index (index.get(), xq.data()); + + // handle preprocessing + const float * xqt = xq.data(); + std::unique_ptr del_xqt; + + if (index_pt) { + xqt = index_pt->apply_chain (nq, xq.data()); + if (xqt != xq.data()) { + del_xqt.reset((float*)xqt); + } + } + + // quantize the queries to get the inverted list ids to visit. + int nprobe = index_ivf->nprobe; + + std::vector q_lists (nq * nprobe); + std::vector q_dis (nq * nprobe); + + index_ivf->quantizer->search (nq, xqt, nprobe, + q_dis.data(), q_lists.data()); + + // now run search in this many threads + int nproc = 3; + + + for (int i = 0; i < nq; i++) { + + // one result table per thread + std::vector I (k * nproc, -1); + float default_dis = metric == METRIC_L2 ? HUGE_VAL : -HUGE_VAL; + std::vector D (k * nproc, default_dis); + + auto search_function = [index_ivf, &I, &D, dt, i, nproc, + xqt, nprobe, &q_dis, &q_lists] + (int rank) { + const InvertedLists *il = index_ivf->invlists; + + // object that does the scanning and distance computations. + std::unique_ptr scanner ( + index_ivf->get_InvertedListScanner()); + + idx_t *local_I = I.data() + rank * k; + float *local_D = D.data() + rank * k; + + scanner->set_query (xqt + i * dt); + + for (int j = rank; j < nprobe; j += nproc) { + int list_no = q_lists[i * nprobe + j]; + if (list_no < 0) continue; + scanner->set_list (list_no, q_dis[i * nprobe + j]); + + scanner->scan_codes ( + il->list_size (list_no), + InvertedLists::ScopedCodes(il, list_no).get(), + InvertedLists::ScopedIds(il, list_no).get(), + local_D, local_I, k); + } + }; + + // start the threads. Threads are numbered rank=0..nproc-1 (a la MPI) + // thread rank takes care of inverted lists + // rank, rank+nproc, rank+2*nproc,... + std::vector threads; + for (int rank = 0; rank < nproc; rank++) { + threads.emplace_back(search_function, rank); + } + + // join threads, merge heaps + for (int rank = 0; rank < nproc; rank++) { + threads[rank].join(); + if (rank == 0) continue; // nothing to merge + // merge into first result + if (metric == METRIC_L2) { + maxheap_addn (k, D.data(), I.data(), + D.data() + rank * k, + I.data() + rank * k, k); + } else { + minheap_addn (k, D.data(), I.data(), + D.data() + rank * k, + I.data() + rank * k, k); + } + } + + // re-order heap + if (metric == METRIC_L2) { + maxheap_reorder (k, D.data(), I.data()); + } else { + minheap_reorder (k, D.data(), I.data()); + } + + // check that we have the same results as the reference search + for (int j = 0; j < k; j++) { + EXPECT_EQ (I[j], ref_I[i * k + j]); + } + } + + +} + +} // anonymous namepace + + +TEST(TestLowLevelIVF, ThreadedSearch) { + test_threaded_search ("IVF32,Flat", METRIC_L2); +} diff --git a/core/src/index/thirdparty/faiss/tests/test_merge.cpp b/core/src/index/thirdparty/faiss/tests/test_merge.cpp new file mode 100644 index 0000000000..47af106149 --- /dev/null +++ b/core/src/index/thirdparty/faiss/tests/test_merge.cpp @@ -0,0 +1,257 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +#include + +#include +#include +#include +#include +#include +#include +#include + + +namespace { + + +struct Tempfilename { + + static pthread_mutex_t mutex; + + std::string filename; + + Tempfilename (const char *prefix = nullptr) { + pthread_mutex_lock (&mutex); + char *cfname = tempnam (nullptr, prefix); + filename = cfname; + free(cfname); + pthread_mutex_unlock (&mutex); + } + + ~Tempfilename () { + if (access (filename.c_str(), F_OK)) { + unlink (filename.c_str()); + } + } + + const char *c_str() { + return filename.c_str(); + } + +}; + +pthread_mutex_t Tempfilename::mutex = PTHREAD_MUTEX_INITIALIZER; + + +typedef faiss::Index::idx_t idx_t; + +// parameters to use for the test +int d = 64; +size_t nb = 1000; +size_t nq = 100; +int nindex = 4; +int k = 10; +int nlist = 40; + +struct CommonData { + + std::vector database; + std::vector queries; + std::vector ids; + faiss::IndexFlatL2 quantizer; + + CommonData(): database (nb * d), queries (nq * d), ids(nb), quantizer (d) { + + for (size_t i = 0; i < nb * d; i++) { + database[i] = drand48(); + } + for (size_t i = 0; i < nq * d; i++) { + queries[i] = drand48(); + } + for (int i = 0; i < nb; i++) { + ids[i] = 123 + 456 * i; + } + { // just to train the quantizer + faiss::IndexIVFFlat iflat (&quantizer, d, nlist); + iflat.train(nb, database.data()); + } + } +}; + +CommonData cd; + +/// perform a search on shards, then merge and search again and +/// compare results. +int compare_merged (faiss::IndexShards *index_shards, bool shift_ids, + bool standard_merge = true) +{ + + std::vector refI(k * nq); + std::vector refD(k * nq); + + index_shards->search(nq, cd.queries.data(), k, refD.data(), refI.data()); + Tempfilename filename; + + std::vector newI(k * nq); + std::vector newD(k * nq); + + if (standard_merge) { + + for (int i = 1; i < nindex; i++) { + faiss::ivflib::merge_into( + index_shards->at(0), index_shards->at(i), + shift_ids); + } + + index_shards->sync_with_shard_indexes(); + } else { + std::vector lists; + faiss::IndexIVF *index0 = nullptr; + size_t ntotal = 0; + for (int i = 0; i < nindex; i++) { + auto index_ivf = dynamic_cast(index_shards->at(i)); + assert (index_ivf); + if (i == 0) { + index0 = index_ivf; + } + lists.push_back (index_ivf->invlists); + ntotal += index_ivf->ntotal; + } + + auto il = new faiss::OnDiskInvertedLists( + index0->nlist, index0->code_size, + filename.c_str()); + + il->merge_from(lists.data(), lists.size()); + + index0->replace_invlists(il, true); + index0->ntotal = ntotal; + } + // search only on first index + index_shards->at(0)->search(nq, cd.queries.data(), + k, newD.data(), newI.data()); + + size_t ndiff = 0; + for (size_t i = 0; i < k * nq; i++) { + if (refI[i] != newI[i]) { + ndiff ++; + } + } + return ndiff; +} + +} // namespace + + +// test on IVFFlat with implicit numbering +TEST(MERGE, merge_flat_no_ids) { + faiss::IndexShards index_shards(d); + index_shards.own_fields = true; + for (int i = 0; i < nindex; i++) { + index_shards.add_shard ( + new faiss::IndexIVFFlat (&cd.quantizer, d, nlist)); + } + EXPECT_TRUE(index_shards.is_trained); + index_shards.add(nb, cd.database.data()); + size_t prev_ntotal = index_shards.ntotal; + int ndiff = compare_merged(&index_shards, true); + EXPECT_EQ (prev_ntotal, index_shards.ntotal); + EXPECT_EQ(0, ndiff); +} + + +// test on IVFFlat, explicit ids +TEST(MERGE, merge_flat) { + faiss::IndexShards index_shards(d, false, false); + index_shards.own_fields = true; + + for (int i = 0; i < nindex; i++) { + index_shards.add_shard ( + new faiss::IndexIVFFlat (&cd.quantizer, d, nlist)); + } + + EXPECT_TRUE(index_shards.is_trained); + index_shards.add_with_ids(nb, cd.database.data(), cd.ids.data()); + int ndiff = compare_merged(&index_shards, false); + EXPECT_GE(0, ndiff); +} + +// test on IVFFlat and a VectorTransform +TEST(MERGE, merge_flat_vt) { + faiss::IndexShards index_shards(d, false, false); + index_shards.own_fields = true; + + // here we have to retrain because of the vectorTransform + faiss::RandomRotationMatrix rot(d, d); + rot.init(1234); + faiss::IndexFlatL2 quantizer (d); + + { // just to train the quantizer + faiss::IndexIVFFlat iflat (&quantizer, d, nlist); + faiss::IndexPreTransform ipt (&rot, &iflat); + ipt.train(nb, cd.database.data()); + } + + for (int i = 0; i < nindex; i++) { + faiss::IndexPreTransform * ipt = new faiss::IndexPreTransform ( + new faiss::RandomRotationMatrix (rot), + new faiss::IndexIVFFlat (&quantizer, d, nlist) + ); + ipt->own_fields = true; + index_shards.add_shard (ipt); + } + EXPECT_TRUE(index_shards.is_trained); + index_shards.add_with_ids(nb, cd.database.data(), cd.ids.data()); + size_t prev_ntotal = index_shards.ntotal; + int ndiff = compare_merged(&index_shards, false); + EXPECT_EQ (prev_ntotal, index_shards.ntotal); + EXPECT_GE(0, ndiff); +} + + +// put the merged invfile on disk +TEST(MERGE, merge_flat_ondisk) { + faiss::IndexShards index_shards(d, false, false); + index_shards.own_fields = true; + Tempfilename filename; + + for (int i = 0; i < nindex; i++) { + auto ivf = new faiss::IndexIVFFlat (&cd.quantizer, d, nlist); + if (i == 0) { + auto il = new faiss::OnDiskInvertedLists ( + ivf->nlist, ivf->code_size, + filename.c_str()); + ivf->replace_invlists(il, true); + } + index_shards.add_shard (ivf); + } + + EXPECT_TRUE(index_shards.is_trained); + index_shards.add_with_ids(nb, cd.database.data(), cd.ids.data()); + int ndiff = compare_merged(&index_shards, false); + + EXPECT_EQ(ndiff, 0); +} + +// now use ondisk specific merge +TEST(MERGE, merge_flat_ondisk_2) { + faiss::IndexShards index_shards(d, false, false); + index_shards.own_fields = true; + + for (int i = 0; i < nindex; i++) { + index_shards.add_shard ( + new faiss::IndexIVFFlat (&cd.quantizer, d, nlist)); + } + EXPECT_TRUE(index_shards.is_trained); + index_shards.add_with_ids(nb, cd.database.data(), cd.ids.data()); + int ndiff = compare_merged(&index_shards, false, false); + EXPECT_GE(0, ndiff); +} diff --git a/core/src/index/thirdparty/faiss/tests/test_meta_index.py b/core/src/index/thirdparty/faiss/tests/test_meta_index.py new file mode 100644 index 0000000000..137efc2aeb --- /dev/null +++ b/core/src/index/thirdparty/faiss/tests/test_meta_index.py @@ -0,0 +1,264 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import absolute_import, division, print_function, unicode_literals + +# translation of test_meta_index.lua + +import numpy as np +import faiss +import unittest + +from common import Randu10k + +ru = Randu10k() + +xb = ru.xb +xt = ru.xt +xq = ru.xq +nb, d = xb.shape +nq, d = xq.shape + + +class IDRemap(unittest.TestCase): + + def test_id_remap_idmap(self): + # reference: index without remapping + + index = faiss.IndexPQ(d, 8, 8) + k = 10 + index.train(xt) + index.add(xb) + _Dref, Iref = index.search(xq, k) + + # try a remapping + ids = np.arange(nb)[::-1].copy() + + sub_index = faiss.IndexPQ(d, 8, 8) + index2 = faiss.IndexIDMap(sub_index) + + index2.train(xt) + index2.add_with_ids(xb, ids) + + _D, I = index2.search(xq, k) + + assert np.all(I == nb - 1 - Iref) + + def test_id_remap_ivf(self): + # coarse quantizer in common + coarse_quantizer = faiss.IndexFlatIP(d) + ncentroids = 25 + + # reference: index without remapping + + index = faiss.IndexIVFPQ(coarse_quantizer, d, + ncentroids, 8, 8) + index.nprobe = 5 + k = 10 + index.train(xt) + index.add(xb) + _Dref, Iref = index.search(xq, k) + + # try a remapping + ids = np.arange(nb)[::-1].copy() + + index2 = faiss.IndexIVFPQ(coarse_quantizer, d, + ncentroids, 8, 8) + index2.nprobe = 5 + + index2.train(xt) + index2.add_with_ids(xb, ids) + + _D, I = index2.search(xq, k) + assert np.all(I == nb - 1 - Iref) + + +class Shards(unittest.TestCase): + + def test_shards(self): + k = 32 + ref_index = faiss.IndexFlatL2(d) + + print('ref search') + ref_index.add(xb) + _Dref, Iref = ref_index.search(xq, k) + print(Iref[:5, :6]) + + shard_index = faiss.IndexShards(d) + shard_index_2 = faiss.IndexShards(d, True, False) + + ni = 3 + for i in range(ni): + i0 = int(i * nb / ni) + i1 = int((i + 1) * nb / ni) + index = faiss.IndexFlatL2(d) + index.add(xb[i0:i1]) + shard_index.add_shard(index) + + index_2 = faiss.IndexFlatL2(d) + irm = faiss.IndexIDMap(index_2) + shard_index_2.add_shard(irm) + + # test parallel add + shard_index_2.verbose = True + shard_index_2.add(xb) + + for test_no in range(3): + with_threads = test_no == 1 + + print('shard search test_no = %d' % test_no) + if with_threads: + remember_nt = faiss.omp_get_max_threads() + faiss.omp_set_num_threads(1) + shard_index.threaded = True + else: + shard_index.threaded = False + + if test_no != 2: + _D, I = shard_index.search(xq, k) + else: + _D, I = shard_index_2.search(xq, k) + + print(I[:5, :6]) + + if with_threads: + faiss.omp_set_num_threads(remember_nt) + + ndiff = (I != Iref).sum() + + print('%d / %d differences' % (ndiff, nq * k)) + assert(ndiff < nq * k / 1000.) + + +class Merge(unittest.TestCase): + + def make_index_for_merge(self, quant, index_type, master_index): + ncent = 40 + if index_type == 1: + index = faiss.IndexIVFFlat(quant, d, ncent, faiss.METRIC_L2) + if master_index: + index.is_trained = True + elif index_type == 2: + index = faiss.IndexIVFPQ(quant, d, ncent, 4, 8) + if master_index: + index.pq = master_index.pq + index.is_trained = True + elif index_type == 3: + index = faiss.IndexIVFPQR(quant, d, ncent, 4, 8, 8, 8) + if master_index: + index.pq = master_index.pq + index.refine_pq = master_index.refine_pq + index.is_trained = True + elif index_type == 4: + # quant used as the actual index + index = faiss.IndexIDMap(quant) + return index + + def do_test_merge(self, index_type): + k = 16 + quant = faiss.IndexFlatL2(d) + ref_index = self.make_index_for_merge(quant, index_type, False) + + # trains the quantizer + ref_index.train(xt) + + print('ref search') + ref_index.add(xb) + _Dref, Iref = ref_index.search(xq, k) + print(Iref[:5, :6]) + + indexes = [] + ni = 3 + for i in range(ni): + i0 = int(i * nb / ni) + i1 = int((i + 1) * nb / ni) + index = self.make_index_for_merge(quant, index_type, ref_index) + index.is_trained = True + index.add(xb[i0:i1]) + indexes.append(index) + + index = indexes[0] + + for i in range(1, ni): + print('merge ntotal=%d other.ntotal=%d ' % ( + index.ntotal, indexes[i].ntotal)) + index.merge_from(indexes[i], index.ntotal) + + _D, I = index.search(xq, k) + print(I[:5, :6]) + + ndiff = (I != Iref).sum() + print('%d / %d differences' % (ndiff, nq * k)) + assert(ndiff < nq * k / 1000.) + + def test_merge(self): + self.do_test_merge(1) + self.do_test_merge(2) + self.do_test_merge(3) + + def do_test_remove(self, index_type): + k = 16 + quant = faiss.IndexFlatL2(d) + index = self.make_index_for_merge(quant, index_type, None) + + # trains the quantizer + index.train(xt) + + if index_type < 4: + index.add(xb) + else: + gen = np.random.RandomState(1234) + id_list = gen.permutation(nb * 7)[:nb] + index.add_with_ids(xb, id_list) + + + print('ref search ntotal=%d' % index.ntotal) + Dref, Iref = index.search(xq, k) + + toremove = np.zeros(nq * k, dtype=int) + nr = 0 + for i in range(nq): + for j in range(k): + # remove all even results (it's ok if there are duplicates + # in the list of ids) + if Iref[i, j] % 2 == 0: + nr = nr + 1 + toremove[nr] = Iref[i, j] + + print('nr=', nr) + + idsel = faiss.IDSelectorBatch( + nr, faiss.swig_ptr(toremove)) + + for i in range(nr): + assert(idsel.is_member(int(toremove[i]))) + + nremoved = index.remove_ids(idsel) + + print('nremoved=%d ntotal=%d' % (nremoved, index.ntotal)) + + D, I = index.search(xq, k) + + # make sure results are in the same order with even ones removed + for i in range(nq): + j2 = 0 + for j in range(k): + if Iref[i, j] % 2 != 0: + assert I[i, j2] == Iref[i, j] + assert abs(D[i, j2] - Dref[i, j]) < 1e-5 + j2 += 1 + + def test_remove(self): + self.do_test_remove(1) + self.do_test_remove(2) + self.do_test_remove(4) + + + + + + +if __name__ == '__main__': + unittest.main() diff --git a/core/src/index/thirdparty/faiss/tests/test_omp_threads.cpp b/core/src/index/thirdparty/faiss/tests/test_omp_threads.cpp new file mode 100644 index 0000000000..216a89dde1 --- /dev/null +++ b/core/src/index/thirdparty/faiss/tests/test_omp_threads.cpp @@ -0,0 +1,14 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +TEST(Threading, openmp) { + EXPECT_TRUE(faiss::check_openmp()); +} diff --git a/core/src/index/thirdparty/faiss/tests/test_omp_threads_py.py b/core/src/index/thirdparty/faiss/tests/test_omp_threads_py.py new file mode 100644 index 0000000000..c96494dc1f --- /dev/null +++ b/core/src/index/thirdparty/faiss/tests/test_omp_threads_py.py @@ -0,0 +1,15 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import absolute_import, division, print_function, unicode_literals + +import faiss +import unittest + + +class TestOpenMP(unittest.TestCase): + + def test_openmp(self): + assert faiss.check_openmp() diff --git a/core/src/index/thirdparty/faiss/tests/test_ondisk_ivf.cpp b/core/src/index/thirdparty/faiss/tests/test_ondisk_ivf.cpp new file mode 100644 index 0000000000..c7f717fafe --- /dev/null +++ b/core/src/index/thirdparty/faiss/tests/test_ondisk_ivf.cpp @@ -0,0 +1,220 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +#include + +#include +#include + +#include + +#include +#include +#include +#include +#include + + +namespace { + +struct Tempfilename { + + static pthread_mutex_t mutex; + + std::string filename; + + Tempfilename (const char *prefix = nullptr) { + pthread_mutex_lock (&mutex); + char *cfname = tempnam (nullptr, prefix); + filename = cfname; + free(cfname); + pthread_mutex_unlock (&mutex); + } + + ~Tempfilename () { + if (access (filename.c_str(), F_OK)) { + unlink (filename.c_str()); + } + } + + const char *c_str() { + return filename.c_str(); + } + +}; + +pthread_mutex_t Tempfilename::mutex = PTHREAD_MUTEX_INITIALIZER; + +} // namespace + + +TEST(ONDISK, make_invlists) { + int nlist = 100; + int code_size = 32; + int nadd = 1000000; + std::unordered_map listnos; + + Tempfilename filename; + + faiss::OnDiskInvertedLists ivf ( + nlist, code_size, + filename.c_str()); + + { + std::vector code(32); + for (int i = 0; i < nadd; i++) { + double d = drand48(); + int list_no = int(nlist * d * d); // skewed distribution + int * ar = (int*)code.data(); + ar[0] = i; + ar[1] = list_no; + ivf.add_entry (list_no, i, code.data()); + listnos[i] = list_no; + } + } + + int ntot = 0; + for (int i = 0; i < nlist; i++) { + int size = ivf.list_size(i); + const faiss::Index::idx_t *ids = ivf.get_ids (i); + const uint8_t *codes = ivf.get_codes (i); + for (int j = 0; j < size; j++) { + faiss::Index::idx_t id = ids[j]; + const int * ar = (const int*)&codes[code_size * j]; + EXPECT_EQ (ar[0], id); + EXPECT_EQ (ar[1], i); + EXPECT_EQ (listnos[id], i); + ntot ++; + } + } + EXPECT_EQ (ntot, nadd); +}; + + +TEST(ONDISK, test_add) { + int d = 8; + int nlist = 30, nq = 200, nb = 1500, k = 10; + faiss::IndexFlatL2 quantizer(d); + { + std::vector x(d * nlist); + faiss::float_rand(x.data(), d * nlist, 12345); + quantizer.add(nlist, x.data()); + } + std::vector xb(d * nb); + faiss::float_rand(xb.data(), d * nb, 23456); + + faiss::IndexIVFFlat index(&quantizer, d, nlist); + index.add(nb, xb.data()); + + std::vector xq(d * nb); + faiss::float_rand(xq.data(), d * nq, 34567); + + std::vector ref_D (nq * k); + std::vector ref_I (nq * k); + + index.search (nq, xq.data(), k, + ref_D.data(), ref_I.data()); + + Tempfilename filename, filename2; + + // test add + search + { + faiss::IndexIVFFlat index2(&quantizer, d, nlist); + + faiss::OnDiskInvertedLists ivf ( + index.nlist, index.code_size, + filename.c_str()); + + index2.replace_invlists(&ivf); + + index2.add(nb, xb.data()); + + std::vector new_D (nq * k); + std::vector new_I (nq * k); + + index2.search (nq, xq.data(), k, + new_D.data(), new_I.data()); + + EXPECT_EQ (ref_D, new_D); + EXPECT_EQ (ref_I, new_I); + + write_index(&index2, filename2.c_str()); + + } + + // test io + { + faiss::Index *index3 = faiss::read_index(filename2.c_str()); + + std::vector new_D (nq * k); + std::vector new_I (nq * k); + + index3->search (nq, xq.data(), k, + new_D.data(), new_I.data()); + + EXPECT_EQ (ref_D, new_D); + EXPECT_EQ (ref_I, new_I); + + delete index3; + } + +}; + + + +// WARN this thest will run multithreaded only in opt mode +TEST(ONDISK, make_invlists_threaded) { + int nlist = 100; + int code_size = 32; + int nadd = 1000000; + + Tempfilename filename; + + faiss::OnDiskInvertedLists ivf ( + nlist, code_size, + filename.c_str()); + + std::vector list_nos (nadd); + + for (int i = 0; i < nadd; i++) { + double d = drand48(); + list_nos[i] = int(nlist * d * d); // skewed distribution + } + +#pragma omp parallel + { + std::vector code(32); +#pragma omp for + for (int i = 0; i < nadd; i++) { + int list_no = list_nos[i]; + int * ar = (int*)code.data(); + ar[0] = i; + ar[1] = list_no; + ivf.add_entry (list_no, i, code.data()); + } + } + + int ntot = 0; + for (int i = 0; i < nlist; i++) { + int size = ivf.list_size(i); + const faiss::Index::idx_t *ids = ivf.get_ids (i); + const uint8_t *codes = ivf.get_codes (i); + for (int j = 0; j < size; j++) { + faiss::Index::idx_t id = ids[j]; + const int * ar = (const int*)&codes[code_size * j]; + EXPECT_EQ (ar[0], id); + EXPECT_EQ (ar[1], i); + EXPECT_EQ (list_nos[id], i); + ntot ++; + } + } + EXPECT_EQ (ntot, nadd); + +}; diff --git a/core/src/index/thirdparty/faiss/tests/test_oom_exception.py b/core/src/index/thirdparty/faiss/tests/test_oom_exception.py new file mode 100644 index 0000000000..72dfdc7e47 --- /dev/null +++ b/core/src/index/thirdparty/faiss/tests/test_oom_exception.py @@ -0,0 +1,37 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +#! /usr/bin/env python2 + +import sys +import faiss +import unittest +import resource + +class TestOOMException(unittest.TestCase): + + def test_outrageous_alloc(self): + # Disable test on OSX. + if sys.platform == "darwin": + return + + # https://github.com/facebookresearch/faiss/issues/758 + soft_as, hard_as = resource.getrlimit(resource.RLIMIT_AS) + # make sure that allocing more than 10G will fail + resource.setrlimit(resource.RLIMIT_AS, (10 * 1024 * 1024, hard_as)) + try: + x = faiss.IntVector() + try: + x.resize(10**11) # 400 G of RAM + except MemoryError: + pass # good, that's what we expect + else: + assert False, "should raise exception" + finally: + resource.setrlimit(resource.RLIMIT_AS, (soft_as, hard_as)) + + +if __name__ == '__main__': + unittest.main() diff --git a/core/src/index/thirdparty/faiss/tests/test_pairs_decoding.cpp b/core/src/index/thirdparty/faiss/tests/test_pairs_decoding.cpp new file mode 100644 index 0000000000..7857d0fb50 --- /dev/null +++ b/core/src/index/thirdparty/faiss/tests/test_pairs_decoding.cpp @@ -0,0 +1,189 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +#include +#include + +#include + +#include +#include +#include +#include + + +namespace { + +typedef faiss::Index::idx_t idx_t; + +/************************************************************* + * Test utils + *************************************************************/ + + +// dimension of the vectors to index +int d = 64; + +// size of the database we plan to index +size_t nb = 8000; + +// nb of queries +size_t nq = 200; + +std::vector make_data(size_t n) +{ + std::vector database (n * d); + for (size_t i = 0; i < n * d; i++) { + database[i] = drand48(); + } + return database; +} + +std::unique_ptr make_index(const char *index_type, + const std::vector & x) { + + auto index = std::unique_ptr ( + faiss::index_factory(d, index_type)); + index->train(nb, x.data()); + index->add(nb, x.data()); + return index; +} + +/************************************************************* + * Test functions for a given index type + *************************************************************/ + +bool test_search_centroid(const char *index_key) { + std::vector xb = make_data(nb); // database vectors + auto index = make_index(index_key, xb); + + /* First test: find the centroids associated to the database + vectors and make sure that each vector does indeed appear in + the inverted list corresponding to its centroid */ + + std::vector centroid_ids (nb); + faiss::ivflib::search_centroid( + index.get(), xb.data(), nb, centroid_ids.data()); + + const faiss::IndexIVF * ivf = faiss::ivflib::extract_index_ivf + (index.get()); + + for(int i = 0; i < nb; i++) { + bool found = false; + int list_no = centroid_ids[i]; + int list_size = ivf->invlists->list_size (list_no); + auto * list = ivf->invlists->get_ids (list_no); + + for(int j = 0; j < list_size; j++) { + if (list[j] == i) { + found = true; + break; + } + } + if(!found) return false; + } + return true; +} + +int test_search_and_return_centroids(const char *index_key) { + std::vector xb = make_data(nb); // database vectors + auto index = make_index(index_key, xb); + + std::vector centroid_ids (nb); + faiss::ivflib::search_centroid(index.get(), xb.data(), + nb, centroid_ids.data()); + + faiss::IndexIVF * ivf = + faiss::ivflib::extract_index_ivf (index.get()); + ivf->nprobe = 4; + + std::vector xq = make_data(nq); // database vectors + + int k = 5; + + // compute a reference search result + + std::vector refI (nq * k); + std::vector refD (nq * k); + index->search (nq, xq.data(), k, refD.data(), refI.data()); + + // compute search result + + std::vector newI (nq * k); + std::vector newD (nq * k); + + std::vector query_centroid_ids (nq); + std::vector result_centroid_ids (nq * k); + + faiss::ivflib::search_and_return_centroids(index.get(), + nq, xq.data(), k, + newD.data(), newI.data(), + query_centroid_ids.data(), + result_centroid_ids.data()); + + // first verify that we have the same result as the standard search + + if (newI != refI) { + return 1; + } + + // then check if the result ids are indeed in the inverted list + // they are supposed to be in + + for(int i = 0; i < nq * k; i++) { + int list_no = result_centroid_ids[i]; + int result_no = newI[i]; + + if (result_no < 0) continue; + + bool found = false; + + int list_size = ivf->invlists->list_size (list_no); + auto * list = ivf->invlists->get_ids (list_no); + + for(int j = 0; j < list_size; j++) { + if (list[j] == result_no) { + found = true; + break; + } + } + if(!found) return 2; + } + return 0; +} + +} // namespace + + +/************************************************************* + * Test entry points + *************************************************************/ + +TEST(test_search_centroid, IVFFlat) { + bool ok = test_search_centroid("IVF32,Flat"); + EXPECT_TRUE(ok); +} + +TEST(test_search_centroid, PCAIVFFlat) { + bool ok = test_search_centroid("PCA16,IVF32,Flat"); + EXPECT_TRUE(ok); +} + +TEST(test_search_and_return_centroids, IVFFlat) { + int err = test_search_and_return_centroids("IVF32,Flat"); + EXPECT_NE(err, 1); + EXPECT_NE(err, 2); +} + +TEST(test_search_and_return_centroids, PCAIVFFlat) { + int err = test_search_and_return_centroids("PCA16,IVF32,Flat"); + EXPECT_NE(err, 1); + EXPECT_NE(err, 2); +} diff --git a/core/src/index/thirdparty/faiss/tests/test_params_override.cpp b/core/src/index/thirdparty/faiss/tests/test_params_override.cpp new file mode 100644 index 0000000000..d6df2a4efe --- /dev/null +++ b/core/src/index/thirdparty/faiss/tests/test_params_override.cpp @@ -0,0 +1,231 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +#include +#include + +#include + +#include +#include +#include +#include +#include + +using namespace faiss; + +namespace { + +typedef Index::idx_t idx_t; + + +// dimension of the vectors to index +int d = 32; + +// size of the database we plan to index +size_t nb = 1000; + +// nb of queries +size_t nq = 200; + + + +std::vector make_data(size_t n) +{ + std::vector database (n * d); + for (size_t i = 0; i < n * d; i++) { + database[i] = drand48(); + } + return database; +} + +std::unique_ptr make_index(const char *index_type, + MetricType metric, + const std::vector & x) +{ + std::unique_ptr index(index_factory(d, index_type, metric)); + index->train(nb, x.data()); + index->add(nb, x.data()); + return index; +} + +std::vector search_index(Index *index, const float *xq) { + int k = 10; + std::vector I(k * nq); + std::vector D(k * nq); + index->search (nq, xq, k, D.data(), I.data()); + return I; +} + +std::vector search_index_with_params( + Index *index, const float *xq, IVFSearchParameters *params) { + int k = 10; + std::vector I(k * nq); + std::vector D(k * nq); + ivflib::search_with_parameters (index, nq, xq, k, + D.data(), I.data(), params); + return I; +} + + + + +/************************************************************* + * Test functions for a given index type + *************************************************************/ + +int test_params_override (const char *index_key, MetricType metric) { + std::vector xb = make_data(nb); // database vectors + auto index = make_index(index_key, metric, xb); + //index->train(nb, xb.data()); + // index->add(nb, xb.data()); + std::vector xq = make_data(nq); + ParameterSpace ps; + ps.set_index_parameter(index.get(), "nprobe", 2); + auto res2ref = search_index(index.get(), xq.data()); + ps.set_index_parameter(index.get(), "nprobe", 9); + auto res9ref = search_index(index.get(), xq.data()); + ps.set_index_parameter(index.get(), "nprobe", 1); + + IVFSearchParameters params; + params.max_codes = 0; + params.nprobe = 2; + auto res2new = search_index_with_params(index.get(), xq.data(), ¶ms); + params.nprobe = 9; + auto res9new = search_index_with_params(index.get(), xq.data(), ¶ms); + + if (res2ref != res2new) + return 2; + + if (res9ref != res9new) + return 9; + + return 0; +} + + +} // namespace + + +/************************************************************* + * Test entry points + *************************************************************/ + +TEST(TPO, IVFFlat) { + int err1 = test_params_override ("IVF32,Flat", METRIC_L2); + EXPECT_EQ(err1, 0); + int err2 = test_params_override ("IVF32,Flat", METRIC_INNER_PRODUCT); + EXPECT_EQ(err2, 0); +} + +TEST(TPO, IVFPQ) { + int err1 = test_params_override ("IVF32,PQ8np", METRIC_L2); + EXPECT_EQ(err1, 0); + int err2 = test_params_override ("IVF32,PQ8np", METRIC_INNER_PRODUCT); + EXPECT_EQ(err2, 0); +} + +TEST(TPO, IVFSQ) { + int err1 = test_params_override ("IVF32,SQ8", METRIC_L2); + EXPECT_EQ(err1, 0); + int err2 = test_params_override ("IVF32,SQ8", METRIC_INNER_PRODUCT); + EXPECT_EQ(err2, 0); +} + +TEST(TPO, IVFFlatPP) { + int err1 = test_params_override ("PCA16,IVF32,SQ8", METRIC_L2); + EXPECT_EQ(err1, 0); + int err2 = test_params_override ("PCA16,IVF32,SQ8", METRIC_INNER_PRODUCT); + EXPECT_EQ(err2, 0); +} + + + +/************************************************************* + * Same for binary indexes + *************************************************************/ + + +std::vector make_data_binary(size_t n) { + std::vector database (n * d / 8); + for (size_t i = 0; i < n * d / 8; i++) { + database[i] = lrand48(); + } + return database; +} + +std::unique_ptr make_index(const char *index_type, + const std::vector & x) +{ + + auto index = std::unique_ptr + (dynamic_cast(index_binary_factory (d, index_type))); + index->train(nb, x.data()); + index->add(nb, x.data()); + return index; +} + +std::vector search_index(IndexBinaryIVF *index, const uint8_t *xq) { + int k = 10; + std::vector I(k * nq); + std::vector D(k * nq); + index->search (nq, xq, k, D.data(), I.data()); + return I; +} + +std::vector search_index_with_params( + IndexBinaryIVF *index, const uint8_t *xq, IVFSearchParameters *params) { + int k = 10; + std::vector I(k * nq); + std::vector D(k * nq); + + std::vector Iq(params->nprobe * nq); + std::vector Dq(params->nprobe * nq); + + index->quantizer->search(nq, xq, params->nprobe, + Dq.data(), Iq.data()); + index->search_preassigned(nq, xq, k, Iq.data(), Dq.data(), + D.data(), I.data(), + false, params); + return I; +} + +int test_params_override_binary (const char *index_key) { + std::vector xb = make_data_binary(nb); // database vectors + auto index = make_index (index_key, xb); + index->train(nb, xb.data()); + index->add(nb, xb.data()); + std::vector xq = make_data_binary(nq); + index->nprobe = 2; + auto res2ref = search_index(index.get(), xq.data()); + index->nprobe = 9; + auto res9ref = search_index(index.get(), xq.data()); + index->nprobe = 1; + + IVFSearchParameters params; + params.max_codes = 0; + params.nprobe = 2; + auto res2new = search_index_with_params(index.get(), xq.data(), ¶ms); + params.nprobe = 9; + auto res9new = search_index_with_params(index.get(), xq.data(), ¶ms); + + if (res2ref != res2new) + return 2; + + if (res9ref != res9new) + return 9; + + return 0; +} + +TEST(TPOB, IVF) { + int err1 = test_params_override_binary ("BIVF32"); + EXPECT_EQ(err1, 0); +} diff --git a/core/src/index/thirdparty/faiss/tests/test_pq_encoding.cpp b/core/src/index/thirdparty/faiss/tests/test_pq_encoding.cpp new file mode 100644 index 0000000000..214e925d15 --- /dev/null +++ b/core/src/index/thirdparty/faiss/tests/test_pq_encoding.cpp @@ -0,0 +1,98 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#include +#include +#include + +#include + +#include + + +namespace { + +const std::vector random_vector(size_t s) { + std::vector v(s, 0); + for (size_t i = 0; i < s; ++i) { + v[i] = rand(); + } + + return v; +} + +} // namespace + + +TEST(PQEncoderGeneric, encode) { + const int nsubcodes = 97; + const int minbits = 1; + const int maxbits = 24; + const std::vector values = random_vector(nsubcodes); + + for(int nbits = minbits; nbits <= maxbits; ++nbits) { + std::cerr << "nbits = " << nbits << std::endl; + + const uint64_t mask = (1ull << nbits) - 1; + std::unique_ptr codes( + new uint8_t[(nsubcodes * maxbits + 7) / 8] + ); + + // NOTE(hoss): Necessary scope to ensure trailing bits are flushed to mem. + { + faiss::PQEncoderGeneric encoder(codes.get(), nbits); + for (const auto& v : values) { + encoder.encode(v & mask); + } + } + + faiss::PQDecoderGeneric decoder(codes.get(), nbits); + for (int i = 0; i < nsubcodes; ++i) { + uint64_t v = decoder.decode(); + EXPECT_EQ(values[i] & mask, v); + } + } +} + + +TEST(PQEncoder8, encode) { + const int nsubcodes = 100; + const std::vector values = random_vector(nsubcodes); + const uint64_t mask = 0xFF; + std::unique_ptr codes(new uint8_t[nsubcodes]); + + faiss::PQEncoder8 encoder(codes.get(), 8); + for (const auto& v : values) { + encoder.encode(v & mask); + } + + faiss::PQDecoder8 decoder(codes.get(), 8); + for (int i = 0; i < nsubcodes; ++i) { + uint64_t v = decoder.decode(); + EXPECT_EQ(values[i] & mask, v); + } +} + + +TEST(PQEncoder16, encode) { + const int nsubcodes = 100; + const std::vector values = random_vector(nsubcodes); + const uint64_t mask = 0xFFFF; + std::unique_ptr codes(new uint8_t[2 * nsubcodes]); + + faiss::PQEncoder16 encoder(codes.get(), 16); + for (const auto& v : values) { + encoder.encode(v & mask); + } + + faiss::PQDecoder16 decoder(codes.get(), 16); + for (int i = 0; i < nsubcodes; ++i) { + uint64_t v = decoder.decode(); + EXPECT_EQ(values[i] & mask, v); + } +} diff --git a/core/src/index/thirdparty/faiss/tests/test_referenced_objects.py b/core/src/index/thirdparty/faiss/tests/test_referenced_objects.py new file mode 100644 index 0000000000..35bf0f8eaa --- /dev/null +++ b/core/src/index/thirdparty/faiss/tests/test_referenced_objects.py @@ -0,0 +1,106 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +"""make sure that the referenced objects are kept""" +from __future__ import absolute_import, division, print_function, unicode_literals + +import numpy as np +import unittest +import faiss +import sys +import gc + +d = 10 +xt = np.random.rand(100, d).astype('float32') +xb = np.random.rand(20, d).astype('float32') + + +class TestReferenced(unittest.TestCase): + + def test_IndexIVF(self): + quantizer = faiss.IndexFlatL2(d) + index = faiss.IndexIVFFlat(quantizer, d, 10) + index.train(xt) + index.add(xb) + del quantizer + gc.collect() + index.add(xb) + + def test_count_refs(self): + quantizer = faiss.IndexFlatL2(d) + index = faiss.IndexIVFFlat(quantizer, d, 10) + refc1 = sys.getrefcount(quantizer) + del index + gc.collect() + refc2 = sys.getrefcount(quantizer) + assert refc2 == refc1 - 1 + + def test_IndexIVF_2(self): + index = faiss.IndexIVFFlat(faiss.IndexFlatL2(d), d, 10) + index.train(xt) + index.add(xb) + + def test_IndexPreTransform(self): + ltrans = faiss.NormalizationTransform(d) + sub_index = faiss.IndexFlatL2(d) + index = faiss.IndexPreTransform(ltrans, sub_index) + index.add(xb) + del ltrans + gc.collect() + index.add(xb) + del sub_index + gc.collect() + index.add(xb) + + def test_IndexPreTransform_2(self): + sub_index = faiss.IndexFlatL2(d) + index = faiss.IndexPreTransform(sub_index) + ltrans = faiss.NormalizationTransform(d) + index.prepend_transform(ltrans) + index.add(xb) + del ltrans + gc.collect() + index.add(xb) + del sub_index + gc.collect() + index.add(xb) + + def test_IDMap(self): + sub_index = faiss.IndexFlatL2(d) + index = faiss.IndexIDMap(sub_index) + index.add_with_ids(xb, np.arange(len(xb))) + del sub_index + gc.collect() + index.add_with_ids(xb, np.arange(len(xb))) + + def test_shards(self): + index = faiss.IndexShards(d) + for _i in range(3): + sub_index = faiss.IndexFlatL2(d) + sub_index.add(xb) + index.add_shard(sub_index) + gc.collect() + index.search(xb, 10) + + +dbin = 32 +xtbin = np.random.randint(256, size=(100, int(dbin / 8))).astype('uint8') +xbbin = np.random.randint(256, size=(20, int(dbin / 8))).astype('uint8') + + +class TestReferencedBinary(unittest.TestCase): + + def test_binary_ivf(self): + index = faiss.IndexBinaryIVF(faiss.IndexBinaryFlat(dbin), dbin, 10) + gc.collect() + index.train(xtbin) + + def test_wrap(self): + index = faiss.IndexBinaryFromFloat(faiss.IndexFlatL2(dbin)) + gc.collect() + index.add(xbbin) + +if __name__ == '__main__': + unittest.main() diff --git a/core/src/index/thirdparty/faiss/tests/test_sliding_ivf.cpp b/core/src/index/thirdparty/faiss/tests/test_sliding_ivf.cpp new file mode 100644 index 0000000000..90ab516c83 --- /dev/null +++ b/core/src/index/thirdparty/faiss/tests/test_sliding_ivf.cpp @@ -0,0 +1,240 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +#include +#include + +#include + +#include +#include +#include +#include +#include + +using namespace faiss; + +typedef Index::idx_t idx_t; + + +// dimension of the vectors to index +int d = 32; + +// nb of training vectors +size_t nt = 5000; + +// size of the database points per window step +size_t nb = 1000; + +// nb of queries +size_t nq = 200; + + +int total_size = 40; +int window_size = 10; + + + + + +std::vector make_data(size_t n) +{ + std::vector database (n * d); + for (size_t i = 0; i < n * d; i++) { + database[i] = drand48(); + } + return database; +} + +std::unique_ptr make_trained_index(const char *index_type) +{ + auto index = std::unique_ptr(index_factory(d, index_type)); + auto xt = make_data(nt * d); + index->train(nt, xt.data()); + ParameterSpace().set_index_parameter (index.get(), "nprobe", 4); + return index; +} + +std::vector search_index(Index *index, const float *xq) { + int k = 10; + std::vector I(k * nq); + std::vector D(k * nq); + index->search (nq, xq, k, D.data(), I.data()); + return I; +} + + + + + +/************************************************************* + * Test functions for a given index type + *************************************************************/ + + +// make a few slices of indexes that can be merged +void make_index_slices (const Index* trained_index, + std::vector > & sub_indexes) { + + for (int i = 0; i < total_size; i++) { + sub_indexes.emplace_back (clone_index (trained_index)); + + printf ("preparing sub-index # %d\n", i); + + Index * index = sub_indexes.back().get(); + + auto xb = make_data(nb * d); + std::vector ids (nb); + for (int j = 0; j < nb; j++) { + ids[j] = lrand48(); + } + index->add_with_ids (nb, xb.data(), ids.data()); + } + +} + +// build merged index explicitly at sliding window position i +Index *make_merged_index( + const Index* trained_index, + const std::vector > & sub_indexes, + int i) { + + Index * merged_index = clone_index (trained_index); + for (int j = i - window_size + 1; j <= i; j++) { + if (j < 0 || j >= total_size) continue; + std::unique_ptr sub_index ( + clone_index (sub_indexes[j].get())); + IndexIVF *ivf0 = ivflib::extract_index_ivf (merged_index); + IndexIVF *ivf1 = ivflib::extract_index_ivf (sub_index.get()); + ivf0->merge_from (*ivf1, 0); + merged_index->ntotal = ivf0->ntotal; + } + return merged_index; +} + +int test_sliding_window (const char *index_key) { + + std::unique_ptr trained_index = make_trained_index(index_key); + + // make the index slices + std::vector > sub_indexes; + + make_index_slices (trained_index.get(), sub_indexes); + + // now slide over the windows + std::unique_ptr index (clone_index (trained_index.get())); + ivflib::SlidingIndexWindow window (index.get()); + + auto xq = make_data (nq * d); + + for (int i = 0; i < total_size + window_size; i++) { + + printf ("doing step %d / %d\n", i, total_size + window_size); + + // update the index + window.step (i < total_size ? sub_indexes[i].get() : nullptr, + i >= window_size); + printf (" current n_slice = %d\n", window.n_slice); + + auto new_res = search_index (index.get(), xq.data()); + + std::unique_ptr merged_index ( + make_merged_index (trained_index.get(), sub_indexes, i)); + + auto ref_res = search_index (merged_index.get(), xq.data ()); + + EXPECT_EQ (ref_res.size(), new_res.size()); + + EXPECT_EQ (ref_res, new_res); + } + return 0; +} + + +int test_sliding_invlists (const char *index_key) { + + std::unique_ptr trained_index = make_trained_index(index_key); + + // make the index slices + std::vector > sub_indexes; + + make_index_slices (trained_index.get(), sub_indexes); + + // now slide over the windows + std::unique_ptr index (clone_index (trained_index.get())); + IndexIVF * index_ivf = ivflib::extract_index_ivf (index.get()); + + auto xq = make_data (nq * d); + + for (int i = 0; i < total_size + window_size; i++) { + + printf ("doing step %d / %d\n", i, total_size + window_size); + + // update the index + std::vector ils; + for (int j = i - window_size + 1; j <= i; j++) { + if (j < 0 || j >= total_size) continue; + ils.push_back (ivflib::extract_index_ivf ( + sub_indexes[j].get())->invlists); + } + if (ils.size() == 0) continue; + + ConcatenatedInvertedLists *ci = + new ConcatenatedInvertedLists (ils.size(), ils.data()); + + // will be deleted by the index + index_ivf->replace_invlists (ci, true); + + printf (" nb invlists = %ld\n", ils.size()); + + auto new_res = search_index (index.get(), xq.data()); + + std::unique_ptr merged_index ( + make_merged_index (trained_index.get(), sub_indexes, i)); + + auto ref_res = search_index (merged_index.get(), xq.data ()); + + EXPECT_EQ (ref_res.size(), new_res.size()); + + size_t ndiff = 0; + for (size_t j = 0; j < ref_res.size(); j++) { + if (ref_res[j] != new_res[j]) + ndiff++; + } + printf(" nb differences: %ld / %ld\n", + ndiff, ref_res.size()); + EXPECT_EQ (ref_res, new_res); + } + return 0; +} + + + + + +/************************************************************* + * Test entry points + *************************************************************/ + +TEST(SlidingWindow, IVFFlat) { + test_sliding_window ("IVF32,Flat"); +} + +TEST(SlidingWindow, PCAIVFFlat) { + test_sliding_window ("PCA24,IVF32,Flat"); +} + +TEST(SlidingInvlists, IVFFlat) { + test_sliding_invlists ("IVF32,Flat"); +} + +TEST(SlidingInvlists, PCAIVFFlat) { + test_sliding_invlists ("PCA24,IVF32,Flat"); +} diff --git a/core/src/index/thirdparty/faiss/tests/test_standalone_codec.py b/core/src/index/thirdparty/faiss/tests/test_standalone_codec.py new file mode 100644 index 0000000000..95dc58c998 --- /dev/null +++ b/core/src/index/thirdparty/faiss/tests/test_standalone_codec.py @@ -0,0 +1,314 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +#! /usr/bin/env python2 + +""" test byte codecs """ + +from __future__ import print_function +import numpy as np +import unittest +import faiss +import tempfile +import os + +from common import get_dataset_2 + + +class TestEncodeDecode(unittest.TestCase): + + def do_encode_twice(self, factory_key): + d = 96 + nb = 1000 + nq = 0 + nt = 2000 + + xt, x, _ = get_dataset_2(d, nt, nb, nq) + + assert x.size > 0 + + codec = faiss.index_factory(d, factory_key) + + codec.train(xt) + + codes = codec.sa_encode(x) + x2 = codec.sa_decode(codes) + + codes2 = codec.sa_encode(x2) + + if 'IVF' not in factory_key: + self.assertTrue(np.all(codes == codes2)) + else: + # some rows are not reconstructed exactly because they + # flip into another quantization cell + nrowdiff = (codes != codes2).any(axis=1).sum() + self.assertTrue(nrowdiff < 10) + + x3 = codec.sa_decode(codes2) + if 'IVF' not in factory_key: + self.assertTrue(np.allclose(x2, x3)) + else: + diffs = np.abs(x2 - x3).sum(axis=1) + avg = np.abs(x2).sum(axis=1).mean() + diffs.sort() + assert diffs[-10] < avg * 1e-5 + + def test_SQ8(self): + self.do_encode_twice('SQ8') + + def test_IVFSQ8(self): + self.do_encode_twice('IVF256,SQ8') + + def test_PCAIVFSQ8(self): + self.do_encode_twice('PCAR32,IVF256,SQ8') + + def test_PQ6x8(self): + self.do_encode_twice('PQ6np') + + def test_PQ6x6(self): + self.do_encode_twice('PQ6x6np') + + def test_IVFPQ6x8np(self): + self.do_encode_twice('IVF512,PQ6np') + + def test_LSH(self): + self.do_encode_twice('LSHrt') + + +class TestIndexEquiv(unittest.TestCase): + + def do_test(self, key1, key2): + d = 96 + nb = 1000 + nq = 0 + nt = 2000 + + xt, x, _ = get_dataset_2(d, nt, nb, nq) + + codec_ref = faiss.index_factory(d, key1) + codec_ref.train(xt) + + code_ref = codec_ref.sa_encode(x) + x_recons_ref = codec_ref.sa_decode(code_ref) + + codec_new = faiss.index_factory(d, key2) + codec_new.pq = codec_ref.pq + + # replace quantizer, avoiding mem leak + oldq = codec_new.q1.quantizer + oldq.this.own() + codec_new.q1.own_fields = False + codec_new.q1.quantizer = codec_ref.quantizer + codec_new.is_trained = True + + code_new = codec_new.sa_encode(x) + x_recons_new = codec_new.sa_decode(code_new) + + self.assertTrue(np.all(code_new == code_ref)) + self.assertTrue(np.all(x_recons_new == x_recons_ref)) + + codec_new_2 = faiss.deserialize_index( + faiss.serialize_index(codec_new)) + + code_new = codec_new_2.sa_encode(x) + x_recons_new = codec_new_2.sa_decode(code_new) + + self.assertTrue(np.all(code_new == code_ref)) + self.assertTrue(np.all(x_recons_new == x_recons_ref)) + + def test_IVFPQ(self): + self.do_test("IVF512,PQ6np", "Residual512,PQ6") + + def test_IMI(self): + self.do_test("IMI2x5,PQ6np", "Residual2x5,PQ6") + + +class TestAccuracy(unittest.TestCase): + """ comparative accuracy of a few types of indexes """ + + def compare_accuracy(self, lowac, highac, max_errs=(1e10, 1e10)): + d = 96 + nb = 1000 + nq = 0 + nt = 2000 + + xt, x, _ = get_dataset_2(d, nt, nb, nq) + + errs = [] + + for factory_string in lowac, highac: + + codec = faiss.index_factory(d, factory_string) + print('sa codec: code size %d' % codec.sa_code_size()) + codec.train(xt) + + codes = codec.sa_encode(x) + x2 = codec.sa_decode(codes) + + err = ((x - x2) ** 2).sum() + errs.append(err) + + print(errs) + self.assertGreater(errs[0], errs[1]) + + self.assertGreater(max_errs[0], errs[0]) + self.assertGreater(max_errs[1], errs[1]) + + # just a small IndexLattice I/O test + if 'Lattice' in highac: + codec2 = faiss.deserialize_index( + faiss.serialize_index(codec)) + codes = codec.sa_encode(x) + x3 = codec.sa_decode(codes) + self.assertTrue(np.all(x2 == x3)) + + def test_SQ(self): + self.compare_accuracy('SQ4', 'SQ8') + + def test_SQ2(self): + self.compare_accuracy('SQ6', 'SQ8') + + def test_SQ3(self): + self.compare_accuracy('SQ8', 'SQfp16') + + def test_PQ(self): + self.compare_accuracy('PQ6x8np', 'PQ8x8np') + + def test_PQ2(self): + self.compare_accuracy('PQ8x6np', 'PQ8x8np') + + def test_IVFvsPQ(self): + self.compare_accuracy('PQ8np', 'IVF256,PQ8np') + + def test_Lattice(self): + # measured low/high: 20946.244, 5277.483 + self.compare_accuracy('ZnLattice3x10_4', + 'ZnLattice3x20_4', + (22000, 5400)) + + def test_Lattice2(self): + # here the difference is actually tiny + # measured errs: [16403.072, 15967.735] + self.compare_accuracy('ZnLattice3x12_1', + 'ZnLattice3x12_7', + (18000, 16000)) + + +swig_ptr = faiss.swig_ptr + + +class LatticeTest(unittest.TestCase): + """ Low-level lattice tests """ + + def test_repeats(self): + rs = np.random.RandomState(123) + dim = 32 + for i in range(1000): + vec = np.floor((rs.rand(dim) ** 7) * 3).astype('float32') + vecs = vec.copy() + vecs.sort() + repeats = faiss.Repeats(dim, swig_ptr(vecs)) + rr = [repeats.repeats.at(i) for i in range(repeats.repeats.size())] + # print([(r.val, r.n) for r in rr]) + code = repeats.encode(swig_ptr(vec)) + #print(vec, code) + vec2 = np.zeros(dim, dtype='float32') + repeats.decode(code, swig_ptr(vec2)) + # print(vec2) + assert np.all(vec == vec2) + + def test_ZnSphereCodec_encode_centroid(self): + dim = 8 + r2 = 5 + ref_codec = faiss.ZnSphereCodec(dim, r2) + codec = faiss.ZnSphereCodecRec(dim, r2) + # print(ref_codec.nv, codec.nv) + assert ref_codec.nv == codec.nv + s = set() + for i in range(ref_codec.nv): + c = np.zeros(dim, dtype='float32') + ref_codec.decode(i, swig_ptr(c)) + code = codec.encode_centroid(swig_ptr(c)) + assert 0 <= code < codec.nv + s.add(code) + assert len(s) == codec.nv + + def test_ZnSphereCodecRec(self): + dim = 16 + r2 = 6 + codec = faiss.ZnSphereCodecRec(dim, r2) + # print("nv=", codec.nv) + for i in range(codec.nv): + c = np.zeros(dim, dtype='float32') + codec.decode(i, swig_ptr(c)) + code = codec.encode_centroid(swig_ptr(c)) + assert code == i + + def run_ZnSphereCodecAlt(self, dim, r2): + # dim = 32 + # r2 = 14 + codec = faiss.ZnSphereCodecAlt(dim, r2) + rs = np.random.RandomState(123) + n = 100 + codes = rs.randint(codec.nv, size=n).astype('uint64') + x = np.empty((n, dim), dtype='float32') + codec.decode_multi(n, swig_ptr(codes), swig_ptr(x)) + codes2 = np.empty(n, dtype='uint64') + codec.encode_multi(n, swig_ptr(x), swig_ptr(codes2)) + + assert np.all(codes == codes2) + + def test_ZnSphereCodecAlt32(self): + self.run_ZnSphereCodecAlt(32, 14) + + def test_ZnSphereCodecAlt24(self): + self.run_ZnSphereCodecAlt(24, 14) + + +class TestBitstring(unittest.TestCase): + """ Low-level bit string tests """ + + def test_rw(self): + rs = np.random.RandomState(1234) + nbyte = 1000 + sz = 0 + + bs = np.ones(nbyte, dtype='uint8') + bw = faiss.BitstringWriter(swig_ptr(bs), nbyte) + + if False: + ctrl = [(7, 0x35), (13, 0x1d74)] + for nbit, x in ctrl: + bw.write(x, nbit) + else: + ctrl = [] + while True: + nbit = int(1 + 62 * rs.rand() ** 4) + if sz + nbit > nbyte * 8: + break + x = rs.randint(1 << nbit) + bw.write(x, nbit) + ctrl.append((nbit, x)) + sz += nbit + + bignum = 0 + sz = 0 + for nbit, x in ctrl: + bignum |= x << sz + sz += nbit + + for i in range(nbyte): + self.assertTrue(((bignum >> (i * 8)) & 255) == bs[i]) + + for i in range(nbyte): + print(bin(bs[i] + 256)[3:], end=' ') + print() + + br = faiss.BitstringReader(swig_ptr(bs), nbyte) + + for nbit, xref in ctrl: + xnew = br.read(nbit) + print('nbit %d xref %x xnew %x' % (nbit, xref, xnew)) + self.assertTrue(xnew == xref) diff --git a/core/src/index/thirdparty/faiss/tests/test_threaded_index.cpp b/core/src/index/thirdparty/faiss/tests/test_threaded_index.cpp new file mode 100644 index 0000000000..7cad760c09 --- /dev/null +++ b/core/src/index/thirdparty/faiss/tests/test_threaded_index.cpp @@ -0,0 +1,253 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include + +#include +#include +#include +#include +#include + +namespace { + +struct TestException : public std::exception { }; + +struct MockIndex : public faiss::Index { + explicit MockIndex(idx_t d) : + faiss::Index(d) { + resetMock(); + } + + void resetMock() { + flag = false; + nCalled = 0; + xCalled = nullptr; + kCalled = 0; + distancesCalled = nullptr; + labelsCalled = nullptr; + } + + void add(idx_t n, const float* x) override { + nCalled = n; + xCalled = x; + } + + void search(idx_t n, + const float* x, + idx_t k, + float* distances, + idx_t* labels) const override { + nCalled = n; + xCalled = x; + kCalled = k; + distancesCalled = distances; + labelsCalled = labels; + } + + void reset() override { } + + bool flag; + + mutable idx_t nCalled; + mutable const float* xCalled; + mutable idx_t kCalled; + mutable float* distancesCalled; + mutable idx_t* labelsCalled; +}; + +template +struct MockThreadedIndex : public faiss::ThreadedIndex { + using idx_t = faiss::Index::idx_t; + + explicit MockThreadedIndex(bool threaded) + : faiss::ThreadedIndex(threaded) { + } + + void add(idx_t, const float*) override { } + void search(idx_t, const float*, idx_t, float*, idx_t*) const override {} + void reset() override {} +}; + +} + +TEST(ThreadedIndex, SingleException) { + std::vector> idxs; + + for (int i = 0; i < 3; ++i) { + idxs.emplace_back(new MockIndex(1)); + } + + auto fn = + [](int i, MockIndex* index) { + if (i == 1) { + throw TestException(); + } else { + std::this_thread::sleep_for(std::chrono::milliseconds(i * 250)); + + index->flag = true; + } + }; + + // Try with threading and without + for (bool threaded : {true, false}) { + // clear flags + for (auto& idx : idxs) { + idx->resetMock(); + } + + MockThreadedIndex ti(threaded); + for (auto& idx : idxs) { + ti.addIndex(idx.get()); + } + + // The second index should throw + EXPECT_THROW(ti.runOnIndex(fn), TestException); + + // Index 0 and 2 should have processed + EXPECT_TRUE(idxs[0]->flag); + EXPECT_TRUE(idxs[2]->flag); + } +} + +TEST(ThreadedIndex, MultipleException) { + std::vector> idxs; + + for (int i = 0; i < 3; ++i) { + idxs.emplace_back(new MockIndex(1)); + } + + auto fn = + [](int i, MockIndex* index) { + if (i < 2) { + throw TestException(); + } else { + std::this_thread::sleep_for(std::chrono::milliseconds(i * 250)); + + index->flag = true; + } + }; + + // Try with threading and without + for (bool threaded : {true, false}) { + // clear flags + for (auto& idx : idxs) { + idx->resetMock(); + } + + MockThreadedIndex ti(threaded); + for (auto& idx : idxs) { + ti.addIndex(idx.get()); + } + + // Multiple indices threw an exception that was aggregated into a + // FaissException + EXPECT_THROW(ti.runOnIndex(fn), faiss::FaissException); + + // Index 2 should have processed + EXPECT_TRUE(idxs[2]->flag); + } +} + +TEST(ThreadedIndex, TestReplica) { + int numReplicas = 5; + int n = 10 * numReplicas; + int d = 3; + int k = 6; + + // Try with threading and without + for (bool threaded : {true, false}) { + std::vector> idxs; + faiss::IndexReplicas replica(d); + + for (int i = 0; i < numReplicas; ++i) { + idxs.emplace_back(new MockIndex(d)); + replica.addIndex(idxs.back().get()); + } + + std::vector x(n * d); + std::vector distances(n * k); + std::vector labels(n * k); + + replica.add(n, x.data()); + + for (int i = 0; i < idxs.size(); ++i) { + EXPECT_EQ(idxs[i]->nCalled, n); + EXPECT_EQ(idxs[i]->xCalled, x.data()); + } + + for (auto& idx : idxs) { + idx->resetMock(); + } + + replica.search(n, x.data(), k, distances.data(), labels.data()); + + for (int i = 0; i < idxs.size(); ++i) { + auto perReplica = n / idxs.size(); + + EXPECT_EQ(idxs[i]->nCalled, perReplica); + EXPECT_EQ(idxs[i]->xCalled, x.data() + i * perReplica * d); + EXPECT_EQ(idxs[i]->kCalled, k); + EXPECT_EQ(idxs[i]->distancesCalled, + distances.data() + (i * perReplica) * k); + EXPECT_EQ(idxs[i]->labelsCalled, + labels.data() + (i * perReplica) * k); + } + } +} + +TEST(ThreadedIndex, TestShards) { + int numShards = 7; + int d = 3; + int n = 10 * numShards; + int k = 6; + + // Try with threading and without + for (bool threaded : {true, false}) { + std::vector> idxs; + faiss::IndexShards shards(d, threaded); + + for (int i = 0; i < numShards; ++i) { + idxs.emplace_back(new MockIndex(d)); + shards.addIndex(idxs.back().get()); + } + + std::vector x(n * d); + std::vector distances(n * k); + std::vector labels(n * k); + + shards.add(n, x.data()); + + for (int i = 0; i < idxs.size(); ++i) { + auto perShard = n / idxs.size(); + + EXPECT_EQ(idxs[i]->nCalled, perShard); + EXPECT_EQ(idxs[i]->xCalled, x.data() + i * perShard * d); + } + + for (auto& idx : idxs) { + idx->resetMock(); + } + + shards.search(n, x.data(), k, distances.data(), labels.data()); + + for (int i = 0; i < idxs.size(); ++i) { + auto perShard = n / idxs.size(); + + EXPECT_EQ(idxs[i]->nCalled, n); + EXPECT_EQ(idxs[i]->xCalled, x.data()); + EXPECT_EQ(idxs[i]->kCalled, k); + // There is a temporary buffer used for shards + EXPECT_EQ(idxs[i]->distancesCalled, + idxs[0]->distancesCalled + i * k * n); + EXPECT_EQ(idxs[i]->labelsCalled, + idxs[0]->labelsCalled + i * k * n); + } + } +} diff --git a/core/src/index/thirdparty/faiss/tests/test_transfer_invlists.cpp b/core/src/index/thirdparty/faiss/tests/test_transfer_invlists.cpp new file mode 100644 index 0000000000..8766d88e6f --- /dev/null +++ b/core/src/index/thirdparty/faiss/tests/test_transfer_invlists.cpp @@ -0,0 +1,159 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace { + +// parameters to use for the test +int d = 64; +size_t nb = 1000; +size_t nq = 100; +size_t nt = 500; +int k = 10; +int nlist = 40; + +using namespace faiss; + +typedef faiss::Index::idx_t idx_t; + + +std::vector get_data (size_t nb, int seed) { + std::vector x (nb * d); + float_randn (x.data(), nb * d, seed); + return x; +} + + +void test_index_type(const char *factory_string) { + + // transfer inverted lists in nslice slices + int nslice = 3; + + /**************************************************************** + * trained reference index + ****************************************************************/ + + std::unique_ptr trained (index_factory (d, factory_string)); + + { + auto xt = get_data (nt, 123); + trained->train (nt, xt.data()); + } + + // sample nq query vectors to check if results are the same + auto xq = get_data (nq, 818); + + + /**************************************************************** + * source index + ***************************************************************/ + std::unique_ptr src_index (clone_index (trained.get())); + + { // add some data to source index + auto xb = get_data (nb, 245); + src_index->add (nb, xb.data()); + } + + ParameterSpace().set_index_parameter (src_index.get(), "nprobe", 4); + + // remember reference search result on source index + std::vector Iref (nq * k); + std::vector Dref (nq * k); + src_index->search (nq, xq.data(), k, Dref.data(), Iref.data()); + + /**************************************************************** + * destination index -- should be replaced by source index + ***************************************************************/ + + std::unique_ptr dst_index (clone_index (trained.get())); + + { // initial state: filled in with some garbage + int nb2 = nb + 10; + auto xb = get_data (nb2, 366); + dst_index->add (nb2, xb.data()); + } + + std::vector Inew (nq * k); + std::vector Dnew (nq * k); + + ParameterSpace().set_index_parameter (dst_index.get(), "nprobe", 4); + + // transfer from source to destination in nslice slices + for (int sl = 0; sl < nslice; sl++) { + + // so far, the indexes are different + dst_index->search (nq, xq.data(), k, Dnew.data(), Inew.data()); + EXPECT_TRUE (Iref != Inew); + EXPECT_TRUE (Dref != Dnew); + + // range of inverted list indices to transfer + long i0 = sl * nlist / nslice; + long i1 = (sl + 1) * nlist / nslice; + + std::vector data_to_transfer; + { + std::unique_ptr il + (ivflib::get_invlist_range (src_index.get(), i0, i1)); + // serialize inverted lists + VectorIOWriter wr; + write_InvertedLists (il.get(), &wr); + data_to_transfer.swap (wr.data); + } + + // transfer data here from source machine to dest machine + + { + VectorIOReader reader; + reader.data.swap (data_to_transfer); + + // deserialize inverted lists + std::unique_ptr il + (dynamic_cast + (read_InvertedLists (&reader))); + + // swap inverted lists. Block searches here! + { + ivflib::set_invlist_range (dst_index.get(), i0, i1, il.get()); + } + } + + } + EXPECT_EQ (dst_index->ntotal, src_index->ntotal); + + // now, the indexes are the same + dst_index->search (nq, xq.data(), k, Dnew.data(), Inew.data()); + EXPECT_TRUE (Iref == Inew); + EXPECT_TRUE (Dref == Dnew); + +} + +} // namespace + + +TEST(TRANS, IVFFlat) { + test_index_type ("IVF40,Flat"); +} + +TEST(TRANS, IVFFlatPreproc) { + test_index_type ("PCAR32,IVF40,Flat"); +} diff --git a/core/src/index/thirdparty/faiss/tutorial/cpp/1-Flat.cpp b/core/src/index/thirdparty/faiss/tutorial/cpp/1-Flat.cpp new file mode 100644 index 0000000000..f8632bb6c8 --- /dev/null +++ b/core/src/index/thirdparty/faiss/tutorial/cpp/1-Flat.cpp @@ -0,0 +1,98 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +#include + + +int main() { + int d = 64; // dimension + int nb = 100000; // database size + int nq = 10000; // nb of queries + + float *xb = new float[d * nb]; + float *xq = new float[d * nq]; + + for(int i = 0; i < nb; i++) { + for(int j = 0; j < d; j++) + xb[d * i + j] = drand48(); + xb[d * i] += i / 1000.; + } + + for(int i = 0; i < nq; i++) { + for(int j = 0; j < d; j++) + xq[d * i + j] = drand48(); + xq[d * i] += i / 1000.; + } + + faiss::IndexFlatL2 index(d); // call constructor + printf("is_trained = %s\n", index.is_trained ? "true" : "false"); + index.add(nb, xb); // add vectors to the index + printf("ntotal = %ld\n", index.ntotal); + + int k = 4; + + { // sanity check: search 5 first vectors of xb + long *I = new long[k * 5]; + float *D = new float[k * 5]; + + index.search(5, xb, k, D, I); + + // print results + printf("I=\n"); + for(int i = 0; i < 5; i++) { + for(int j = 0; j < k; j++) + printf("%5ld ", I[i * k + j]); + printf("\n"); + } + + printf("D=\n"); + for(int i = 0; i < 5; i++) { + for(int j = 0; j < k; j++) + printf("%7g ", D[i * k + j]); + printf("\n"); + } + + delete [] I; + delete [] D; + } + + + { // search xq + long *I = new long[k * nq]; + float *D = new float[k * nq]; + + index.search(nq, xq, k, D, I); + + // print results + printf("I (5 first results)=\n"); + for(int i = 0; i < 5; i++) { + for(int j = 0; j < k; j++) + printf("%5ld ", I[i * k + j]); + printf("\n"); + } + + printf("I (5 last results)=\n"); + for(int i = nq - 5; i < nq; i++) { + for(int j = 0; j < k; j++) + printf("%5ld ", I[i * k + j]); + printf("\n"); + } + + delete [] I; + delete [] D; + } + + + + delete [] xb; + delete [] xq; + + return 0; +} diff --git a/core/src/index/thirdparty/faiss/tutorial/cpp/2-IVFFlat.cpp b/core/src/index/thirdparty/faiss/tutorial/cpp/2-IVFFlat.cpp new file mode 100644 index 0000000000..ce13f1d1ae --- /dev/null +++ b/core/src/index/thirdparty/faiss/tutorial/cpp/2-IVFFlat.cpp @@ -0,0 +1,81 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include + +#include +#include + + +int main() { + int d = 64; // dimension + int nb = 100000; // database size + int nq = 10000; // nb of queries + + float *xb = new float[d * nb]; + float *xq = new float[d * nq]; + + for(int i = 0; i < nb; i++) { + for(int j = 0; j < d; j++) + xb[d * i + j] = drand48(); + xb[d * i] += i / 1000.; + } + + for(int i = 0; i < nq; i++) { + for(int j = 0; j < d; j++) + xq[d * i + j] = drand48(); + xq[d * i] += i / 1000.; + } + + + int nlist = 100; + int k = 4; + + faiss::IndexFlatL2 quantizer(d); // the other index + faiss::IndexIVFFlat index(&quantizer, d, nlist, faiss::METRIC_L2); + // here we specify METRIC_L2, by default it performs inner-product search + assert(!index.is_trained); + index.train(nb, xb); + assert(index.is_trained); + index.add(nb, xb); + + { // search xq + long *I = new long[k * nq]; + float *D = new float[k * nq]; + + index.search(nq, xq, k, D, I); + + printf("I=\n"); + for(int i = nq - 5; i < nq; i++) { + for(int j = 0; j < k; j++) + printf("%5ld ", I[i * k + j]); + printf("\n"); + } + + index.nprobe = 10; + index.search(nq, xq, k, D, I); + + printf("I=\n"); + for(int i = nq - 5; i < nq; i++) { + for(int j = 0; j < k; j++) + printf("%5ld ", I[i * k + j]); + printf("\n"); + } + + delete [] I; + delete [] D; + } + + + + delete [] xb; + delete [] xq; + + return 0; +} diff --git a/core/src/index/thirdparty/faiss/tutorial/cpp/3-IVFPQ.cpp b/core/src/index/thirdparty/faiss/tutorial/cpp/3-IVFPQ.cpp new file mode 100644 index 0000000000..85a2de0578 --- /dev/null +++ b/core/src/index/thirdparty/faiss/tutorial/cpp/3-IVFPQ.cpp @@ -0,0 +1,122 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +#include +#include +#include "../../utils/ConcurrentBitset.h" + + +int main() { + int d = 64; // dimension + int nb = 100000; // database size + int nq = 10;//10000; // nb of queries + faiss::ConcurrentBitsetPtr bitset = std::make_shared(nb); + + float *xb = new float[d * nb]; + float *xq = new float[d * nq]; + + for(int i = 0; i < nb; i++) { + for(int j = 0; j < d; j++) + xb[d * i + j] = drand48(); + xb[d * i] += i / 1000.; + } + + srand((unsigned)time(NULL)); + printf("delete ids: \n"); + for(int i = 0; i < nq; i++) { + auto tmp = rand()%nb; + bitset->set(tmp); + printf("%d ", tmp); + for(int j = 0; j < d; j++) + xq[d * i + j] = xb[d * tmp + j]; +// xq[d * i] += i / 1000.; + } + printf("\n"); + + + int nlist = 100; + int k = 4; + int m = 8; // bytes per vector + faiss::IndexFlatL2 quantizer(d); // the other index + faiss::IndexIVFPQ index(&quantizer, d, nlist, m, 8); + // here we specify METRIC_L2, by default it performs inner-product search + index.train(nb, xb); + index.add(nb, xb); + + printf("------------sanity check----------------\n"); + { // sanity check + long *I = new long[k * 5]; + float *D = new float[k * 5]; + + index.search(5, xb, k, D, I); + + printf("I=\n"); + for(int i = 0; i < 5; i++) { + for(int j = 0; j < k; j++) + printf("%5ld ", I[i * k + j]); + printf("\n"); + } + + printf("D=\n"); + for(int i = 0; i < 5; i++) { + for(int j = 0; j < k; j++) + printf("%7g ", D[i * k + j]); + printf("\n"); + } + + delete [] I; + delete [] D; + } + + printf("---------------search xq-------------\n"); + { // search xq + long *I = new long[k * nq]; + float *D = new float[k * nq]; + + index.nprobe = 10; + index.search(nq, xq, k, D, I); + + printf("I=\n"); + for(int i = 0; i < nq; i++) { + for(int j = 0; j < k; j++) + printf("%5ld ", I[i * k + j]); + printf("\n"); + } + + delete [] I; + delete [] D; + } + + printf("----------------search xq with delete------------\n"); + { // search xq with delete + long *I = new long[k * nq]; + float *D = new float[k * nq]; + + index.nprobe = 10; + index.search(nq, xq, k, D, I, bitset); + + printf("I=\n"); + for(int i = 0; i < nq; i++) { + for(int j = 0; j < k; j++) + printf("%5ld ", I[i * k + j]); + printf("\n"); + } + + delete [] I; + delete [] D; + } + + + + delete [] xb; + delete [] xq; + + return 0; +} diff --git a/core/src/index/thirdparty/faiss/tutorial/cpp/4-GPU.cpp b/core/src/index/thirdparty/faiss/tutorial/cpp/4-GPU.cpp new file mode 100644 index 0000000000..49c5c8a06e --- /dev/null +++ b/core/src/index/thirdparty/faiss/tutorial/cpp/4-GPU.cpp @@ -0,0 +1,119 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include + +#include +#include +#include +#include + + +int main() { + int d = 64; // dimension + int nb = 100000; // database size + int nq = 10000; // nb of queries + + float *xb = new float[d * nb]; + float *xq = new float[d * nq]; + + for(int i = 0; i < nb; i++) { + for(int j = 0; j < d; j++) + xb[d * i + j] = drand48(); + xb[d * i] += i / 1000.; + } + + for(int i = 0; i < nq; i++) { + for(int j = 0; j < d; j++) + xq[d * i + j] = drand48(); + xq[d * i] += i / 1000.; + } + + faiss::gpu::StandardGpuResources res; + + // Using a flat index + + faiss::gpu::GpuIndexFlatL2 index_flat(&res, d); + + printf("is_trained = %s\n", index_flat.is_trained ? "true" : "false"); + index_flat.add(nb, xb); // add vectors to the index + printf("ntotal = %ld\n", index_flat.ntotal); + + int k = 4; + + { // search xq + long *I = new long[k * nq]; + float *D = new float[k * nq]; + + index_flat.search(nq, xq, k, D, I); + + // print results + printf("I (5 first results)=\n"); + for(int i = 0; i < 5; i++) { + for(int j = 0; j < k; j++) + printf("%5ld ", I[i * k + j]); + printf("\n"); + } + + printf("I (5 last results)=\n"); + for(int i = nq - 5; i < nq; i++) { + for(int j = 0; j < k; j++) + printf("%5ld ", I[i * k + j]); + printf("\n"); + } + + delete [] I; + delete [] D; + } + + // Using an IVF index + + int nlist = 100; + faiss::gpu::GpuIndexIVFFlat index_ivf(&res, d, nlist, faiss::METRIC_L2); + // here we specify METRIC_L2, by default it performs inner-product search + + assert(!index_ivf.is_trained); + index_ivf.train(nb, xb); + assert(index_ivf.is_trained); + index_ivf.add(nb, xb); // add vectors to the index + + printf("is_trained = %s\n", index_ivf.is_trained ? "true" : "false"); + printf("ntotal = %ld\n", index_ivf.ntotal); + + { // search xq + long *I = new long[k * nq]; + float *D = new float[k * nq]; + + index_ivf.search(nq, xq, k, D, I); + + // print results + printf("I (5 first results)=\n"); + for(int i = 0; i < 5; i++) { + for(int j = 0; j < k; j++) + printf("%5ld ", I[i * k + j]); + printf("\n"); + } + + printf("I (5 last results)=\n"); + for(int i = nq - 5; i < nq; i++) { + for(int j = 0; j < k; j++) + printf("%5ld ", I[i * k + j]); + printf("\n"); + } + + delete [] I; + delete [] D; + } + + + delete [] xb; + delete [] xq; + + return 0; +} diff --git a/core/src/index/thirdparty/faiss/tutorial/cpp/5-GPU.cpp b/core/src/index/thirdparty/faiss/tutorial/cpp/5-GPU.cpp new file mode 100644 index 0000000000..212fb53f1c --- /dev/null +++ b/core/src/index/thirdparty/faiss/tutorial/cpp/5-GPU.cpp @@ -0,0 +1,234 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include + +#include + +#include "faiss/IndexIVF.h" +#include "faiss/IndexFlat.h" +#include "faiss/index_io.h" +#include "faiss/gpu/GpuIndexFlat.h" +#include "faiss/gpu/StandardGpuResources.h" +#include "faiss/gpu/GpuAutoTune.h" +#include "faiss/gpu/GpuCloner.h" +#include "faiss/gpu/GpuClonerOptions.h" +#include "faiss/gpu/GpuIndexIVF.h" + +#include "faiss/impl/FaissAssert.h" +#include "faiss/impl/AuxIndexStructures.h" + +#include "faiss/IndexFlat.h" +#include "faiss/VectorTransform.h" +#include "faiss/IndexLSH.h" +#include "faiss/IndexPQ.h" + +#include "faiss/IndexIVFPQ.h" +#include "faiss/IndexIVFFlat.h" +#include "faiss/IndexIVFSpectralHash.h" +#include "faiss/MetaIndexes.h" +#include "faiss/IndexScalarQuantizer.h" +#include "faiss/IndexHNSW.h" +#include "faiss/OnDiskInvertedLists.h" +#include "faiss/IndexBinaryFlat.h" +#include "faiss/IndexBinaryFromFloat.h" +#include "faiss/IndexBinaryHNSW.h" +#include "faiss/IndexBinaryIVF.h" +#include "faiss/utils/distances.h" +#include "faiss/index_factory.h" + +using namespace faiss; + +#define PRINT_RESULT 0 + +void print_result(const char* unit, long number, long k, long nq, long *I) { + printf("%s: I (2 first results)=\n", unit); + for(int i = 0; i < number; i++) { + for(int j = 0; j < k; j++) + printf("%5ld ", I[i * k + j]); + printf("\n"); + } + + printf("%s: I (2 last results)=\n", unit); + for(int i = nq - number; i < nq; i++) { + for(int j = 0; j < k; j++) + printf("%5ld ", I[i * k + j]); + printf("\n"); + } +} + + +int main() { + const char* filename = "index500k.index"; + +#if PRINT_RESULT + int number = 8; +#endif + + int d = 512; // dimension + int nq = 10; // nb of queries + int nprobe = 1; + float *xq = new float[d * nq]; + for(int i = 0; i < nq; i++) { + for(int j = 0; j < d; j++) { + xq[d * i + j] = drand48(); + } + } + faiss::distance_compute_blas_threshold = 800; + + faiss::gpu::StandardGpuResources res; + + int k = 8; + std::shared_ptr gpu_index_ivf_ptr; + + const char* index_description = "IVF16384,SQ8"; +// const char* index_description = "IVF3276,SQ8"; + + faiss::Index *cpu_index = nullptr; + faiss::IndexIVF* cpu_ivf_index = nullptr; + if((access(filename,F_OK))==-1) { + // create database + long nb = 500000; // database size +// printf("-----------------------\n"); + long size = d * nb; + float *xb = new float[size]; + memset(xb, 0, size * sizeof(float)); + printf("size: %ld\n", (size * sizeof(float)) ); + for(long i = 0; i < nb; i++) { + for(long j = 0; j < d; j++) { + float rand = drand48(); + xb[d * i + j] = rand; + } + } + + faiss::Index *ori_index = faiss::index_factory(d, index_description, faiss::METRIC_L2); + auto device_index = faiss::gpu::index_cpu_to_gpu(&res, 0, ori_index); + + gpu_index_ivf_ptr = std::shared_ptr(device_index); + + assert(!device_index->is_trained); + device_index->train(nb, xb); + assert(device_index->is_trained); + device_index->add(nb, xb); // add vectors to the index + + printf("is_trained = %s\n", device_index->is_trained ? "true" : "false"); + printf("ntotal = %ld\n", device_index->ntotal); + + cpu_index = faiss::gpu::index_gpu_to_cpu ((device_index)); + faiss::write_index(cpu_index, filename); + printf("index.index is stored successfully.\n"); + delete [] xb; + } else { + cpu_index = faiss::read_index(filename); + } + + cpu_ivf_index = dynamic_cast(cpu_index); + if(cpu_ivf_index != nullptr) { + cpu_ivf_index->to_readonly(); + } + + auto init_gpu =[&](int device_id, faiss::gpu::GpuClonerOptions* option) { + option->allInGpu = true; + faiss::Index* tmp_index = faiss::gpu::index_cpu_to_gpu(&res, device_id, cpu_index, option); + delete tmp_index; + }; + + auto gpu_executor = [&](int device_id, faiss::gpu::GpuClonerOptions* option) { + auto tmp_index = faiss::gpu::index_cpu_to_gpu(&res, device_id, cpu_index, option); + delete tmp_index; + double t0 = getmillisecs (); + { + // cpu to gpu + option->allInGpu = true; + + tmp_index = faiss::gpu::index_cpu_to_gpu(&res, device_id, cpu_index, option); + gpu_index_ivf_ptr = std::shared_ptr(tmp_index); + } + double t1 = getmillisecs (); + printf("CPU to GPU loading time: %0.2f\n", t1 - t0); + + { + long *I = new long[k * nq]; + float *D = new float[k * nq]; + if(option->allInGpu) { + faiss::gpu::GpuIndexIVF* gpu_index_ivf = + dynamic_cast(gpu_index_ivf_ptr.get()); + gpu_index_ivf->setNumProbes(nprobe); + for(long i = 0; i < 1; ++ i) { + double t2 = getmillisecs(); + gpu_index_ivf_ptr->search(nq, xq, k, D, I); + double t3 = getmillisecs(); + printf("* GPU: %d, execution time: %0.2f\n", device_id, t3 - t2); + } + } else { + faiss::IndexIVFScalarQuantizer* index_ivf = + dynamic_cast(gpu_index_ivf_ptr.get()); + index_ivf->nprobe = nprobe; + for(long i = 0; i < 1; ++ i) { + double t2 = getmillisecs(); + index_ivf->search(nq, xq, k, D, I); + double t3 = getmillisecs(); + printf("- GPU: %d, execution time: %0.2f\n", device_id, t3 - t2); + } + } + + // print results +#if PRINT_RESULT + print_result("GPU", number, k, nq, I); +#endif + delete [] I; + delete [] D; + } + double t4 = getmillisecs(); + + printf("GPU:%d total time: %0.2f\n", device_id, t4 - t0); + + }; + printf("----------------------------------\n"); + auto cpu_executor = [&]() { // search xq + printf("CPU: \n"); + long *I = new long[k * nq]; + float *D = new float[k * nq]; + + double t4 = getmillisecs(); + faiss::IndexIVF* ivf_index = + dynamic_cast(cpu_index); + ivf_index->nprobe = nprobe; + cpu_index->search(nq, xq, k, D, I); + double t5 = getmillisecs(); + printf("CPU execution time: %0.2f\n", t5 - t4); +#if PRINT_RESULT + print_result("CPU", number, k, nq, I); +#endif + delete [] I; + delete [] D; + }; + + for(long i = 0; i < 1; ++ i) { + cpu_executor(); + } + + faiss::gpu::GpuClonerOptions option0; + faiss::gpu::GpuClonerOptions option1; + +// init_gpu(0, &option0); +// init_gpu(1, &option1); + +// double tx = getmillisecs(); + std::thread t1(gpu_executor, 0, &option0); + std::thread t2(gpu_executor, 1, &option1); + t1.join(); + t2.join(); +// double ty = getmillisecs(); +// printf("Total GPU execution time: %0.2f\n", ty - tx); + + delete [] xq; + return 0; +} diff --git a/core/src/index/thirdparty/faiss/tutorial/cpp/5-Multiple-GPUs.cpp b/core/src/index/thirdparty/faiss/tutorial/cpp/5-Multiple-GPUs.cpp new file mode 100644 index 0000000000..3152b731a1 --- /dev/null +++ b/core/src/index/thirdparty/faiss/tutorial/cpp/5-Multiple-GPUs.cpp @@ -0,0 +1,100 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +#include +#include +#include +#include +#include +#include + + +int main() { + int d = 64; // dimension + int nb = 100000; // database size + int nq = 10000; // nb of queries + + float *xb = new float[d * nb]; + float *xq = new float[d * nq]; + + for(int i = 0; i < nb; i++) { + for(int j = 0; j < d; j++) + xb[d * i + j] = drand48(); + xb[d * i] += i / 1000.; + } + + for(int i = 0; i < nq; i++) { + for(int j = 0; j < d; j++) + xq[d * i + j] = drand48(); + xq[d * i] += i / 1000.; + } + + int ngpus = faiss::gpu::getNumDevices(); + + printf("Number of GPUs: %d\n", ngpus); + + std::vector res; + std::vector devs; + for(int i = 0; i < ngpus; i++) { + res.push_back(new faiss::gpu::StandardGpuResources); + devs.push_back(i); + } + + faiss::IndexFlatL2 cpu_index(d); + + faiss::Index *gpu_index = + faiss::gpu::index_cpu_to_gpu_multiple( + res, + devs, + &cpu_index + ); + + printf("is_trained = %s\n", gpu_index->is_trained ? "true" : "false"); + gpu_index->add(nb, xb); // add vectors to the index + printf("ntotal = %ld\n", gpu_index->ntotal); + + int k = 4; + + { // search xq + long *I = new long[k * nq]; + float *D = new float[k * nq]; + + gpu_index->search(nq, xq, k, D, I); + + // print results + printf("I (5 first results)=\n"); + for(int i = 0; i < 5; i++) { + for(int j = 0; j < k; j++) + printf("%5ld ", I[i * k + j]); + printf("\n"); + } + + printf("I (5 last results)=\n"); + for(int i = nq - 5; i < nq; i++) { + for(int j = 0; j < k; j++) + printf("%5ld ", I[i * k + j]); + printf("\n"); + } + + delete [] I; + delete [] D; + } + + delete gpu_index; + + for(int i = 0; i < ngpus; i++) { + delete res[i]; + } + + delete [] xb; + delete [] xq; + + return 0; +} diff --git a/core/src/index/thirdparty/faiss/tutorial/cpp/6-GPU.cpp b/core/src/index/thirdparty/faiss/tutorial/cpp/6-GPU.cpp new file mode 100644 index 0000000000..f992884cba --- /dev/null +++ b/core/src/index/thirdparty/faiss/tutorial/cpp/6-GPU.cpp @@ -0,0 +1,255 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include + +#include + +#include "faiss/IndexIVF.h" +#include "faiss/IndexFlat.h" +#include "faiss/index_io.h" +#include "faiss/gpu/GpuIndexFlat.h" +#include "faiss/gpu/StandardGpuResources.h" +#include "faiss/gpu/GpuAutoTune.h" +#include "faiss/gpu/GpuCloner.h" +#include "faiss/gpu/GpuClonerOptions.h" +#include "faiss/gpu/GpuIndexIVF.h" +#include "faiss/gpu/GpuIndexIVFSQHybrid.h" + +#include "faiss/impl/FaissAssert.h" +#include "faiss/impl/AuxIndexStructures.h" + +#include "faiss/IndexFlat.h" +#include "faiss/VectorTransform.h" +#include "faiss/IndexLSH.h" +#include "faiss/IndexPQ.h" + +#include "faiss/IndexIVFPQ.h" +#include "faiss/IndexIVFFlat.h" +#include "faiss/IndexIVFSpectralHash.h" +#include "faiss/MetaIndexes.h" +#include "faiss/IndexSQHybrid.h" +#include "faiss/IndexHNSW.h" +#include "faiss/OnDiskInvertedLists.h" +#include "faiss/IndexBinaryFlat.h" +#include "faiss/IndexBinaryFromFloat.h" +#include "faiss/IndexBinaryHNSW.h" +#include "faiss/IndexBinaryIVF.h" +#include "faiss/utils/distances.h" +#include "faiss/index_factory.h" + +using namespace faiss; + +#define PRINT_RESULT 0 + +void print_result(const char* unit, long number, long k, long nq, long *I) { + printf("%s: I (2 first results)=\n", unit); + for(int i = 0; i < number; i++) { + for(int j = 0; j < k; j++) + printf("%5ld ", I[i * k + j]); + printf("\n"); + } + + printf("%s: I (2 last results)=\n", unit); + for(int i = nq - number; i < nq; i++) { + for(int j = 0; j < k; j++) + printf("%5ld ", I[i * k + j]); + printf("\n"); + } +} + + +int main() { + const char* filename = "index500k-h.index"; + +#if PRINT_RESULT + int number = 8; +#endif + + int d = 512; // dimension + int nq = 10; // nb of queries + int nprobe = 1; + float *xq = new float[d * nq]; + for(int i = 0; i < nq; i++) { + for(int j = 0; j < d; j++) { + xq[d * i + j] = drand48(); + } + } + faiss::distance_compute_blas_threshold = 800; + + faiss::gpu::StandardGpuResources res; + + int k = 8; + std::shared_ptr gpu_index_ivf_ptr; + + const char* index_description = "IVF16384,SQ8Hybrid"; +// const char* index_description = "IVF3276,SQ8"; + + faiss::Index *cpu_index = nullptr; + faiss::IndexIVF* cpu_ivf_index = nullptr; + if((access(filename,F_OK))==-1) { + // create database + long nb = 500000; // database size +// printf("-----------------------\n"); + long size = d * nb; + float *xb = new float[size]; + memset(xb, 0, size * sizeof(float)); + printf("size: %ld\n", (size * sizeof(float)) ); + for(long i = 0; i < nb; i++) { + for(long j = 0; j < d; j++) { + float rand = drand48(); + xb[d * i + j] = rand; + } + } + + faiss::Index *ori_index = faiss::index_factory(d, index_description, faiss::METRIC_L2); + auto device_index = faiss::gpu::index_cpu_to_gpu(&res, 0, ori_index); + + gpu_index_ivf_ptr = std::shared_ptr(device_index); + + assert(!device_index->is_trained); + device_index->train(nb, xb); + assert(device_index->is_trained); + device_index->add(nb, xb); // add vectors to the index + + printf("is_trained = %s\n", device_index->is_trained ? "true" : "false"); + printf("ntotal = %ld\n", device_index->ntotal); + + cpu_index = faiss::gpu::index_gpu_to_cpu ((device_index)); + faiss::write_index(cpu_index, filename); + printf("index.index is stored successfully.\n"); + delete [] xb; + } else { + cpu_index = faiss::read_index(filename); + } + + cpu_ivf_index = dynamic_cast(cpu_index); + if(cpu_ivf_index != nullptr) { + cpu_ivf_index->to_readonly(); + } + + auto gpu_executor = [&](int device_id, faiss::gpu::GpuClonerOptions* option, faiss::IndexComposition* index_composition) { + auto tmp_index = faiss::gpu::index_cpu_to_gpu(&res, device_id, index_composition, option); + delete tmp_index; + double t0 = getmillisecs (); + { + // cpu to gpu + tmp_index = faiss::gpu::index_cpu_to_gpu(&res, device_id, index_composition, option); + gpu_index_ivf_ptr = std::shared_ptr(tmp_index); + } + double t1 = getmillisecs (); + printf("CPU to GPU loading time: %0.2f\n", t1 - t0); + + { + long *I = new long[k * nq]; + float *D = new float[k * nq]; + + faiss::gpu::GpuIndexIVFSQHybrid* gpu_index_ivf_hybrid = + dynamic_cast(gpu_index_ivf_ptr.get()); + gpu_index_ivf_hybrid->setNumProbes(nprobe); + for(long i = 0; i < 1; ++ i) { + double t2 = getmillisecs(); + gpu_index_ivf_ptr->search(nq, xq, k, D, I); + double t3 = getmillisecs(); + printf("* GPU: %d, execution time: %0.2f\n", device_id, t3 - t2); + } + + // print results +#if PRINT_RESULT + print_result("GPU", number, k, nq, I); +#endif + delete [] I; + delete [] D; + } + double t4 = getmillisecs(); + + printf("GPU:%d total time: %0.2f\n", device_id, t4 - t0); + + }; + printf("----------------------------------\n"); + auto cpu_executor = [&](faiss::IndexComposition* index_composition) { // search xq + printf("CPU: \n"); + long *I = new long[k * nq]; + float *D = new float[k * nq]; + + double t4 = getmillisecs(); + faiss::IndexIVF* ivf_index = + dynamic_cast(cpu_index); + ivf_index->nprobe = nprobe; + + faiss::gpu::GpuIndexFlat* is_gpu_flat_index = dynamic_cast(ivf_index->quantizer); + if(is_gpu_flat_index == nullptr) { + delete ivf_index->quantizer; + ivf_index->quantizer = index_composition->quantizer; + } + + cpu_index->search(nq, xq, k, D, I); + double t5 = getmillisecs(); + printf("CPU execution time: %0.2f\n", t5 - t4); +#if PRINT_RESULT + print_result("CPU", number, k, nq, I); +#endif + delete [] I; + delete [] D; + }; + + + faiss::gpu::GpuClonerOptions option0; + faiss::gpu::GpuClonerOptions option1; + + faiss::IndexComposition index_composition0; + index_composition0.index = cpu_index; + index_composition0.quantizer = nullptr; + index_composition0.mode = 0; // only quantizer + + // Copy quantizer to GPU 0 + auto index1 = faiss::gpu::index_cpu_to_gpu(&res, 0, &index_composition0, &option0); + delete index1; + + faiss::IndexComposition index_composition1; + index_composition1.index = cpu_index; + index_composition1.quantizer = nullptr; + index_composition1.mode = 0; // only quantizer + + // Copy quantizer to GPU 1 + index1 = faiss::gpu::index_cpu_to_gpu(&res, 1, &index_composition1, &option1); + delete index1; + + std::thread t_cpu1(cpu_executor, &index_composition0); + t_cpu1.join(); + std::thread t_cpu2(cpu_executor, &index_composition1); + t_cpu2.join(); + + index_composition0.mode = 2; // only data + index_composition1.mode = 2; // only data + + index1 = faiss::gpu::index_cpu_to_gpu(&res, 0, &index_composition0, &option0); + delete index1; + index1 = faiss::gpu::index_cpu_to_gpu(&res, 1, &index_composition1, &option1); + delete index1; + +// double tx = getmillisecs(); + std::thread t1(gpu_executor, 0, &option0, &index_composition0); + std::thread t2(gpu_executor, 1, &option1, &index_composition1); + t1.join(); + t2.join(); + +// std::thread t3(gpu_executor, 0, &option0, &index_composition0); +// std::thread t4(gpu_executor, 1, &option1, &index_composition1); +// t3.join(); +// t4.join(); +// double ty = getmillisecs(); +// printf("Total GPU execution time: %0.2f\n", ty - tx); + cpu_executor(&index_composition0); + cpu_executor(&index_composition1); + + delete [] xq; + return 0; +} diff --git a/core/src/index/thirdparty/faiss/tutorial/cpp/6-RUN.cpp b/core/src/index/thirdparty/faiss/tutorial/cpp/6-RUN.cpp new file mode 100644 index 0000000000..2c09fef266 --- /dev/null +++ b/core/src/index/thirdparty/faiss/tutorial/cpp/6-RUN.cpp @@ -0,0 +1,247 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include + +#include + +#include "faiss/IndexIVF.h" +#include "faiss/IndexFlat.h" +#include "faiss/index_io.h" +#include "faiss/gpu/GpuIndexFlat.h" +#include "faiss/gpu/StandardGpuResources.h" +#include "faiss/gpu/GpuAutoTune.h" +#include "faiss/gpu/GpuCloner.h" +#include "faiss/gpu/GpuClonerOptions.h" +#include "faiss/gpu/GpuIndexIVF.h" +#include "faiss/gpu/GpuIndexIVFSQHybrid.h" + +#include "faiss/impl/FaissAssert.h" +#include "faiss/impl/AuxIndexStructures.h" + +#include "faiss/IndexFlat.h" +#include "faiss/VectorTransform.h" +#include "faiss/IndexLSH.h" +#include "faiss/IndexPQ.h" + +#include "faiss/IndexIVFPQ.h" +#include "faiss/IndexIVFFlat.h" +#include "faiss/IndexIVFSpectralHash.h" +#include "faiss/MetaIndexes.h" +#include "faiss/IndexSQHybrid.h" +#include "faiss/IndexHNSW.h" +#include "faiss/OnDiskInvertedLists.h" +#include "faiss/IndexBinaryFlat.h" +#include "faiss/IndexBinaryFromFloat.h" +#include "faiss/IndexBinaryHNSW.h" +#include "faiss/IndexBinaryIVF.h" +#include "faiss/utils/distances.h" +#include "faiss/index_factory.h" + +using namespace faiss; + +#define PRINT_RESULT 0 +std::shared_ptr gpu_index_ivf_ptr; +const int d = 512; // dimension +const int nq = 1000; // nb of queries +const int nprobe = 1; +int k = 8; + +void +print_result(const char* unit, long number, long k, long nq, long* I) { + printf("%s: I (2 first results)=\n", unit); + for (int i = 0; i < number; i++) { + for (int j = 0; j < k; j++) + printf("%5ld ", I[i * k + j]); + printf("\n"); + } + + printf("%s: I (2 last results)=\n", unit); + for (int i = nq - number; i < nq; i++) { + for (int j = 0; j < k; j++) + printf("%5ld ", I[i * k + j]); + printf("\n"); + } +} + +void +cpu_executor(faiss::Index* cpu_index, float*& xq) { // search xq + printf("CPU: \n"); + long* I = new long[k * nq]; + float* D = new float[k * nq]; + + double t4 = getmillisecs(); + faiss::IndexIVF* ivf_index = + dynamic_cast(cpu_index); + ivf_index->nprobe = nprobe; + cpu_index->search(nq, xq, k, D, I); + double t5 = getmillisecs(); + printf("CPU execution time: %0.2f\n", t5 - t4); +#if PRINT_RESULT + print_result("CPU", number, k, nq, I); +#endif + delete[] I; + delete[] D; +}; + +void +hybrid_executor(faiss::Index* cpu_index, + faiss::IndexComposition* index_composition, + float*& xq) { // search xq + printf("HYBRID: \n"); + long* I = new long[k * nq]; + float* D = new float[k * nq]; + + double t4 = getmillisecs(); + faiss::IndexIVF* ivf_index = dynamic_cast(cpu_index); + ivf_index->nprobe = nprobe; + + faiss::gpu::GpuIndexFlat* is_gpu_flat_index = dynamic_cast(ivf_index->quantizer); + if (is_gpu_flat_index == nullptr) { + delete ivf_index->quantizer; + ivf_index->quantizer = index_composition->quantizer; + } + + cpu_index->search(nq, xq, k, D, I); + double t5 = getmillisecs(); + printf("HYBRID execution time: %0.2f\n", t5 - t4); +#if PRINT_RESULT + print_result("HYBRID", number, k, nq, I); +#endif + delete[] I; + delete[] D; +}; + +void +gpu_executor(faiss::gpu::StandardGpuResources& res, + int device_id, + faiss::gpu::GpuClonerOptions* option, + faiss::IndexComposition* index_composition, + float*& xq) { + auto tmp_index = faiss::gpu::index_cpu_to_gpu(&res, device_id, index_composition, option); + delete tmp_index; + double t0 = getmillisecs(); + { + // cpu to gpu + tmp_index = faiss::gpu::index_cpu_to_gpu(&res, device_id, index_composition, option); + gpu_index_ivf_ptr = std::shared_ptr(tmp_index); + } + double t1 = getmillisecs(); + printf("CPU to GPU loading time: %0.2f\n", t1 - t0); + + { + long* I = new long[k * nq]; + float* D = new float[k * nq]; + + faiss::gpu::GpuIndexIVFSQHybrid + * gpu_index_ivf_hybrid = dynamic_cast(gpu_index_ivf_ptr.get()); + gpu_index_ivf_hybrid->setNumProbes(nprobe); + for (long i = 0; i < 1; ++i) { + double t2 = getmillisecs(); + gpu_index_ivf_ptr->search(nq, xq, k, D, I); + double t3 = getmillisecs(); + printf("* GPU: %d, execution time: %0.2f\n", device_id, t3 - t2); + } + + // print results +#if PRINT_RESULT + print_result("GPU", number, k, nq, I); +#endif + delete[] I; + delete[] D; + } + double t4 = getmillisecs(); + + printf("GPU:%d total time: %0.2f\n", device_id, t4 - t0); + +}; + +int +main() { + const char* filename = "index500k-h.index"; + faiss::gpu::StandardGpuResources res; + +#if PRINT_RESULT + int number = 8; +#endif + + float* xq = new float[d * nq]; + for (int i = 0; i < nq; i++) { + for (int j = 0; j < d; j++) { + xq[d * i + j] = drand48(); + } + } + faiss::distance_compute_blas_threshold = 800; + + faiss::Index* cpu_index = nullptr; + faiss::IndexIVF* cpu_ivf_index = nullptr; + if ((access(filename, F_OK)) == -1) { + printf("index file not found."); + exit(-1); + } else { + cpu_index = faiss::read_index(filename); + } + + cpu_ivf_index = dynamic_cast(cpu_index); + if (cpu_ivf_index != nullptr) { + cpu_ivf_index->to_readonly(); + } + + printf("============================\n"); + cpu_executor(cpu_index, xq); + cpu_executor(cpu_index, xq); + printf("============================\n"); + + faiss::gpu::GpuClonerOptions option0; + faiss::gpu::GpuClonerOptions option1; + + faiss::IndexComposition index_composition0; + index_composition0.index = cpu_index; + index_composition0.quantizer = nullptr; + index_composition0.mode = 0; // only quantizer + + // Copy quantizer to GPU 0 + auto index1 = faiss::gpu::index_cpu_to_gpu(&res, 0, &index_composition0, &option0); + delete index1; + + faiss::IndexComposition index_composition1; + index_composition1.index = cpu_index; + index_composition1.quantizer = nullptr; + index_composition1.mode = 0; // only quantizer + + // Copy quantizer to GPU 1 + index1 = faiss::gpu::index_cpu_to_gpu(&res, 1, &index_composition1, &option1); + delete index1; + + hybrid_executor(cpu_index, &index_composition0, xq); + hybrid_executor(cpu_index, &index_composition1, xq); + + printf("============================\n"); + + index_composition0.mode = 2; // only data + index_composition1.mode = 2; // only data + + index1 = faiss::gpu::index_cpu_to_gpu(&res, 0, &index_composition0, &option0); + delete index1; + index1 = faiss::gpu::index_cpu_to_gpu(&res, 1, &index_composition1, &option1); + delete index1; + + gpu_executor(res, 0, &option0, &index_composition0, xq); + gpu_executor(res, 1, &option1, &index_composition1, xq); + + printf("============================\n"); + + hybrid_executor(cpu_index, &index_composition0, xq); + hybrid_executor(cpu_index, &index_composition1, xq); + + delete[] xq; + gpu_index_ivf_ptr = nullptr; + return 0; +} diff --git a/core/src/index/thirdparty/faiss/tutorial/cpp/7-GPU.cpp b/core/src/index/thirdparty/faiss/tutorial/cpp/7-GPU.cpp new file mode 100644 index 0000000000..4ab91f27db --- /dev/null +++ b/core/src/index/thirdparty/faiss/tutorial/cpp/7-GPU.cpp @@ -0,0 +1,347 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include + +#include + +#include "faiss/IndexIVF.h" +#include "faiss/IndexFlat.h" +#include "faiss/index_io.h" +#include "faiss/gpu/GpuIndexFlat.h" +#include "faiss/gpu/StandardGpuResources.h" +#include "faiss/gpu/GpuAutoTune.h" +#include "faiss/gpu/GpuClonerOptions.h" +#include "faiss/gpu/GpuCloner.h" +#include "faiss/gpu/GpuIndexIVF.h" +#include "faiss/gpu/GpuIndexIVFSQHybrid.h" + + +#include "faiss/IndexFlat.h" +#include "faiss/VectorTransform.h" +#include "faiss/IndexLSH.h" +#include "faiss/IndexPQ.h" + +#include "faiss/IndexIVFPQ.h" +#include "faiss/IndexIVFFlat.h" +#include "faiss/IndexIVFSpectralHash.h" +#include "faiss/MetaIndexes.h" +#include "faiss/IndexSQHybrid.h" +#include "faiss/IndexHNSW.h" +#include "faiss/OnDiskInvertedLists.h" +#include "faiss/IndexBinaryFlat.h" +#include "faiss/IndexBinaryFromFloat.h" +#include "faiss/IndexBinaryHNSW.h" +#include "faiss/IndexBinaryIVF.h" +#include "faiss/utils/distances.h" +#include "faiss/clone_index.h" +#include "faiss/index_factory.h" + +using namespace faiss; + +#define PRINT_RESULT 0 + +void print_result(const char* unit, long number, long k, long nq, long *I) { + printf("%s: I (2 first results)=\n", unit); + for(int i = 0; i < number; i++) { + for(int j = 0; j < k; j++) + printf("%5ld ", I[i * k + j]); + printf("\n"); + } + + printf("%s: I (2 last results)=\n", unit); + for(int i = nq - number; i < nq; i++) { + for(int j = 0; j < k; j++) + printf("%5ld ", I[i * k + j]); + printf("\n"); + } +} + +void +GpuLoad(faiss::gpu::StandardGpuResources* res, + int device_id, + faiss::gpu::GpuClonerOptions* option, + faiss::IndexComposition* index_composition, + std::shared_ptr& gpu_index_ivf_ptr + ) { + + double t0 = getmillisecs (); + + auto tmp_index = faiss::gpu::index_cpu_to_gpu(res, device_id, index_composition, option); + gpu_index_ivf_ptr = std::shared_ptr(tmp_index); + + double t1 = getmillisecs (); + printf("CPU to GPU loading time: %0.2f\n", t1 - t0); +} + +void +GpuExecutor( + std::shared_ptr& gpu_index_ivf_ptr, + faiss::gpu::StandardGpuResources& res, + int device_id, + faiss::gpu::GpuClonerOptions* option, + faiss::IndexComposition* index_composition, + int nq, + int nprobe, + int k, + float* xq) { + double t0 = getmillisecs (); + { + long *I = new long[k * nq]; + float *D = new float[k * nq]; + + faiss::gpu::GpuIndexIVFSQHybrid* gpu_index_ivf_hybrid = + dynamic_cast(gpu_index_ivf_ptr.get()); + gpu_index_ivf_hybrid->setNumProbes(nprobe); + for(long i = 0; i < 4; ++ i) { + double t2 = getmillisecs(); + gpu_index_ivf_ptr->search(nq, xq, k, D, I); + double t3 = getmillisecs(); + printf("* GPU: %d, execution time: %0.2f\n", device_id, t3 - t2); + } + + // print results +#if PRINT_RESULT + print_result("GPU", number, k, nq, I); +#endif + delete [] I; + delete [] D; + gpu_index_ivf_ptr = nullptr; + } + double t4 = getmillisecs(); + + printf("GPU:%d total time: %0.2f\n", device_id, t4 - t0); +} + + +void +GpuExecutor( + faiss::gpu::StandardGpuResources& res, + int device_id, + faiss::gpu::GpuClonerOptions* option, + faiss::IndexComposition* index_composition, + int nq, + int nprobe, + int k, + float* xq) { + auto tmp_index = faiss::gpu::index_cpu_to_gpu(&res, device_id, index_composition, option); + delete tmp_index; + double t0 = getmillisecs (); + // cpu to gpu + tmp_index = faiss::gpu::index_cpu_to_gpu(&res, device_id, index_composition, option); + auto gpu_index_ivf_ptr = std::shared_ptr(tmp_index); + + double t1 = getmillisecs (); + printf("CPU to GPU loading time: %0.2f\n", t1 - t0); + + { + long *I = new long[k * nq]; + float *D = new float[k * nq]; + + faiss::gpu::GpuIndexIVFSQHybrid* gpu_index_ivf_hybrid = + dynamic_cast(gpu_index_ivf_ptr.get()); + gpu_index_ivf_hybrid->setNumProbes(nprobe); + for(long i = 0; i < 4; ++ i) { + double t2 = getmillisecs(); + gpu_index_ivf_ptr->search(nq, xq, k, D, I); + double t3 = getmillisecs(); + printf("* GPU: %d, execution time: %0.2f\n", device_id, t3 - t2); + } + + // print results +#if PRINT_RESULT + print_result("GPU", number, k, nq, I); +#endif + delete [] I; + delete [] D; + gpu_index_ivf_ptr = nullptr; + } + double t4 = getmillisecs(); + + printf("GPU:%d total time: %0.2f\n", device_id, t4 - t0); +} + +void +CpuExecutor( + faiss::IndexComposition* index_composition, + int nq, + int nprobe, + int k, + float* xq, + faiss::Index *cpu_index) { + printf("CPU: \n"); + long *I = new long[k * nq]; + float *D = new float[k * nq]; + + double t4 = getmillisecs(); + faiss::IndexIVF* ivf_index = + dynamic_cast(cpu_index); + ivf_index->nprobe = nprobe; + + faiss::gpu::GpuIndexFlat* is_gpu_flat_index = dynamic_cast(ivf_index->quantizer); + if(is_gpu_flat_index == nullptr) { + delete ivf_index->quantizer; + ivf_index->quantizer = index_composition->quantizer; + } + + cpu_index->search(nq, xq, k, D, I); + double t5 = getmillisecs(); + printf("CPU execution time: %0.2f\n", t5 - t4); +#if PRINT_RESULT + print_result("CPU", number, k, nq, I); +#endif + delete [] I; + delete [] D; +} + +int main() { + const char* filename = "index500k-h.index"; + +#if PRINT_RESULT + int number = 8; +#endif + + int d = 512; // dimension + int nq = 1000; // nb of queries + int nprobe = 8; + float *xq = new float[d * nq]; + for(int i = 0; i < nq; i++) { + for(int j = 0; j < d; j++) { + xq[d * i + j] = drand48(); + } + } + faiss::distance_compute_blas_threshold = 800; + + faiss::gpu::StandardGpuResources res; + + int k = 1000; + std::shared_ptr gpu_index_ivf_ptr; + + const char* index_description = "IVF16384,SQ8Hybrid"; +// const char* index_description = "IVF3276,SQ8"; + + faiss::Index *cpu_index = nullptr; + faiss::IndexIVF* cpu_ivf_index = nullptr; + if((access(filename,F_OK))==-1) { + // create database + long nb = 500000; // database size +// printf("-----------------------\n"); + long size = d * nb; + float *xb = new float[size]; + memset(xb, 0, size * sizeof(float)); + printf("size: %ld\n", (size * sizeof(float)) ); + for(long i = 0; i < nb; i++) { + for(long j = 0; j < d; j++) { + float rand = drand48(); + xb[d * i + j] = rand; + } + } + + faiss::Index *ori_index = faiss::index_factory(d, index_description, faiss::METRIC_L2); + auto device_index = faiss::gpu::index_cpu_to_gpu(&res, 0, ori_index); + + gpu_index_ivf_ptr = std::shared_ptr(device_index); + + assert(!device_index->is_trained); + device_index->train(nb, xb); + assert(device_index->is_trained); + device_index->add(nb, xb); // add vectors to the index + + printf("is_trained = %s\n", device_index->is_trained ? "true" : "false"); + printf("ntotal = %ld\n", device_index->ntotal); + + cpu_index = faiss::gpu::index_gpu_to_cpu ((device_index)); + faiss::write_index(cpu_index, filename); + printf("index.index is stored successfully.\n"); + delete [] xb; + } else { + cpu_index = faiss::read_index(filename); + } + + cpu_ivf_index = dynamic_cast(cpu_index); + if(cpu_ivf_index != nullptr) { + cpu_ivf_index->to_readonly(); + } + + faiss::gpu::GpuClonerOptions option0; + faiss::gpu::GpuClonerOptions option1; + + option0.allInGpu = true; + option1.allInGpu = true; + + faiss::IndexComposition index_composition0; + index_composition0.index = cpu_index; + index_composition0.quantizer = nullptr; + index_composition0.mode = 1; // only quantizer + + // Copy quantizer to GPU 0 + auto index1 = faiss::gpu::index_cpu_to_gpu(&res, 0, &index_composition0, &option0); + delete index1; + + faiss::IndexComposition index_composition1; + index_composition1.index = cpu_index; + index_composition1.quantizer = nullptr; + index_composition1.mode = 1; // only quantizer + + // Copy quantizer to GPU 1 + index1 = faiss::gpu::index_cpu_to_gpu(&res, 1, &index_composition1, &option1); + delete index1; + +// std::thread t_cpu1(cpu_executor, &index_composition0); +// t_cpu1.join(); +// std::thread t_cpu2(cpu_executor, &index_composition1); +// t_cpu2.join(); + +// index_composition0.mode = 2; // only data +// index_composition1.mode = 2; // only data +// +// index1 = faiss::gpu::index_cpu_to_gpu(&res, 0, &index_composition0, &option0); +// delete index1; +// index1 = faiss::gpu::index_cpu_to_gpu(&res, 1, &index_composition1, &option1); +// delete index1; + +// double tx = getmillisecs(); +// std::thread t1(gpu_executor, 0, &option0, &index_composition0); +// std::thread t2(gpu_executor, 1, &option1, &index_composition1); +// t1.join(); +// t2.join(); +// for(long i = 0; i < 10; ++ i) { +// std::shared_ptr gpu_index_ptr00; +// std::shared_ptr gpu_index_ptr01; +// +// std::thread t00(GpuLoad, &res, 0, &option0, &index_composition0, std::ref(gpu_index_ptr00)); +//// std::thread t2(GpuLoad, &res, 1, &option1, &index_composition1, std::ref(gpu_index_ptr1)); +// std::thread t01(GpuLoad, &res, 0, &option0, &index_composition0, std::ref(gpu_index_ptr01)); +// +// t00.join(); +// +// GpuExecutor(gpu_index_ptr00, res, 0, &option0, &index_composition0, nq, nprobe, k, xq); +// +// t01.join(); +//// t2.join(); +// GpuExecutor(gpu_index_ptr01, res, 0, &option0, &index_composition0, nq, nprobe, k, xq); +//// GpuExecutor(gpu_index_ptr1, res, 1, &option1, &index_composition1, nq, nprobe, k, xq); +// } + +// std::thread t3(gpu_executor, 0, &option0, &index_composition0); +// std::thread t4(gpu_executor, 1, &option1, &index_composition1); +// t3.join(); +// t4.join(); +// double ty = getmillisecs(); +// printf("Total GPU execution time: %0.2f\n", ty - tx); + + CpuExecutor(&index_composition0, nq, nprobe, k, xq, cpu_index); + CpuExecutor(&index_composition1, nq, nprobe, k, xq, cpu_index); + + ///// + delete [] xq; + return 0; +} + diff --git a/core/src/index/thirdparty/faiss/tutorial/cpp/8-GPU.cpp b/core/src/index/thirdparty/faiss/tutorial/cpp/8-GPU.cpp new file mode 100644 index 0000000000..11f49a09cc --- /dev/null +++ b/core/src/index/thirdparty/faiss/tutorial/cpp/8-GPU.cpp @@ -0,0 +1,479 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include + +#include + +#include "faiss/IndexIVF.h" +#include "faiss/IndexFlat.h" +#include "faiss/index_io.h" +#include "faiss/gpu/GpuIndexFlat.h" +#include "faiss/gpu/StandardGpuResources.h" +#include "faiss/gpu/GpuAutoTune.h" +#include "faiss/gpu/GpuClonerOptions.h" +#include "faiss/gpu/GpuCloner.h" +#include "faiss/gpu/GpuIndexIVF.h" +#include "faiss/gpu/GpuIndexIVFSQHybrid.h" + +#include "faiss/IndexFlat.h" +#include "faiss/VectorTransform.h" +#include "faiss/IndexLSH.h" +#include "faiss/IndexPQ.h" + +#include "faiss/IndexIVFPQ.h" +#include "faiss/IndexIVFFlat.h" +#include "faiss/IndexIVFSpectralHash.h" +#include "faiss/MetaIndexes.h" +#include "faiss/IndexSQHybrid.h" +#include "faiss/IndexHNSW.h" +#include "faiss/OnDiskInvertedLists.h" +#include "faiss/IndexBinaryFlat.h" +#include "faiss/IndexBinaryFromFloat.h" +#include "faiss/IndexBinaryHNSW.h" +#include "faiss/IndexBinaryIVF.h" +#include "faiss/utils/distances.h" +#include "faiss/clone_index.h" +#include "faiss/index_factory.h" + +using namespace faiss; + +#define PRINT_RESULT 0 + +void print_result(const char* unit, long number, long k, long nq, long *I) { + printf("%s: I (2 first results)=\n", unit); + for(int i = 0; i < number; i++) { + for(int j = 0; j < k; j++) + printf("%5ld ", I[i * k + j]); + printf("\n"); + } + + printf("%s: I (2 last results)=\n", unit); + for(int i = nq - number; i < nq; i++) { + for(int j = 0; j < k; j++) + printf("%5ld ", I[i * k + j]); + printf("\n"); + } +} + +void +GpuLoad(faiss::gpu::StandardGpuResources* res, + int device_id, + faiss::gpu::GpuClonerOptions* option, + faiss::IndexComposition* index_composition, + std::shared_ptr& gpu_index_ivf_ptr + ) { + + double t0 = getmillisecs (); + + auto tmp_index = faiss::gpu::index_cpu_to_gpu(res, device_id, index_composition, option); + gpu_index_ivf_ptr = std::shared_ptr(tmp_index); + + double t1 = getmillisecs (); + printf("CPU to GPU loading time: %0.2f\n", t1 - t0); +} + +void +GpuExecutor( + std::shared_ptr& gpu_index_ivf_ptr, + faiss::gpu::StandardGpuResources& res, + int device_id, + faiss::gpu::GpuClonerOptions* option, + faiss::IndexComposition* index_composition, + int nq, + int nprobe, + int k, + float* xq) { + double t0 = getmillisecs (); + { + long *I = new long[k * nq]; + float *D = new float[k * nq]; + + faiss::gpu::GpuIndexIVFSQHybrid* gpu_index_ivf_hybrid = + dynamic_cast(gpu_index_ivf_ptr.get()); + gpu_index_ivf_hybrid->setNumProbes(nprobe); + for(long i = 0; i < 4; ++ i) { + double t2 = getmillisecs(); + gpu_index_ivf_ptr->search(nq, xq, k, D, I); + double t3 = getmillisecs(); + printf("* GPU: %d, execution time: %0.2f\n", device_id, t3 - t2); + } + + // print results +#if PRINT_RESULT + print_result("GPU", number, k, nq, I); +#endif + delete [] I; + delete [] D; + gpu_index_ivf_ptr = nullptr; + } + double t4 = getmillisecs(); + + printf("GPU:%d total time: %0.2f\n", device_id, t4 - t0); +} + + +void +GpuExecutor( + faiss::gpu::StandardGpuResources& res, + int device_id, + faiss::gpu::GpuClonerOptions* option, + faiss::IndexComposition* index_composition, + int nq, + int nprobe, + int k, + float* xq) { + auto tmp_index = faiss::gpu::index_cpu_to_gpu(&res, device_id, index_composition, option); + delete tmp_index; + double t0 = getmillisecs (); + // cpu to gpu + tmp_index = faiss::gpu::index_cpu_to_gpu(&res, device_id, index_composition, option); + auto gpu_index_ivf_ptr = std::shared_ptr(tmp_index); + + double t1 = getmillisecs (); + printf("CPU to GPU loading time: %0.2f\n", t1 - t0); + + { + long *I = new long[k * nq]; + float *D = new float[k * nq]; + + faiss::gpu::GpuIndexIVFSQHybrid* gpu_index_ivf_hybrid = + dynamic_cast(gpu_index_ivf_ptr.get()); + gpu_index_ivf_hybrid->setNumProbes(nprobe); + for(long i = 0; i < 4; ++ i) { + double t2 = getmillisecs(); + gpu_index_ivf_ptr->search(nq, xq, k, D, I); + double t3 = getmillisecs(); + printf("* GPU: %d, execution time: %0.2f\n", device_id, t3 - t2); + } + + // print results +#if PRINT_RESULT + print_result("GPU", number, k, nq, I); +#endif + delete [] I; + delete [] D; + gpu_index_ivf_ptr = nullptr; + } + double t4 = getmillisecs(); + + printf("GPU:%d total time: %0.2f\n", device_id, t4 - t0); +} + +void +CpuExecutor( + faiss::IndexComposition* index_composition, + int nq, + int nprobe, + int k, + float* xq, + faiss::Index *cpu_index) { + printf("CPU: \n"); + long *I = new long[k * nq]; + float *D = new float[k * nq]; + + double t4 = getmillisecs(); + faiss::IndexIVF* ivf_index = + dynamic_cast(cpu_index); + ivf_index->nprobe = nprobe; + + faiss::gpu::GpuIndexFlat* is_gpu_flat_index = dynamic_cast(ivf_index->quantizer); + if(is_gpu_flat_index == nullptr) { + delete ivf_index->quantizer; + ivf_index->quantizer = index_composition->quantizer; + } + + cpu_index->search(nq, xq, k, D, I); + double t5 = getmillisecs(); + printf("CPU execution time: %0.2f\n", t5 - t4); +#if PRINT_RESULT + print_result("CPU", number, k, nq, I); +#endif + delete [] I; + delete [] D; +} + +void create_index(const char* filename, const char* index_description, long db_size, long d) { + faiss::gpu::StandardGpuResources res; + if((access(filename,F_OK))==-1) { + // create database + long size = d * db_size; + float *xb = new float[size]; + memset(xb, 0, size * sizeof(float)); + printf("size: %ld\n", (size * sizeof(float)) ); + for(long i = 0; i < db_size; i++) { + for(long j = 0; j < d; j++) { + float rand = drand48(); + xb[d * i + j] = rand; + } + } + + faiss::Index *ori_index = faiss::index_factory(d, index_description, faiss::METRIC_INNER_PRODUCT); + auto device_index = faiss::gpu::index_cpu_to_gpu(&res, 0, ori_index); + + std::shared_ptr gpu_index_ivf_ptr = std::shared_ptr(device_index); + + assert(!device_index->is_trained); + device_index->train(db_size, xb); + assert(device_index->is_trained); + device_index->add(db_size, xb); // add vectors to the index + + printf("is_trained = %s\n", device_index->is_trained ? "true" : "false"); + printf("ntotal = %ld\n", device_index->ntotal); + + faiss::Index *cpu_index = faiss::gpu::index_gpu_to_cpu ((device_index)); + faiss::write_index(cpu_index, filename); + printf("index.index is stored successfully.\n"); + delete [] xb; + } +} + +void execute_index(const char* filename, int d, int nq, int nprobe, int k, float* xq) { + faiss::gpu::StandardGpuResources res; + faiss::Index* cpu_index = faiss::read_index(filename); + faiss::IndexIVF* cpu_ivf_index = dynamic_cast(cpu_index); + + if(cpu_ivf_index != nullptr) { + cpu_ivf_index->to_readonly(); + } + + faiss::gpu::GpuClonerOptions option0; + faiss::gpu::GpuClonerOptions option1; + + option0.allInGpu = true; + option1.allInGpu = true; + + faiss::IndexComposition index_composition0; + index_composition0.index = cpu_index; + index_composition0.quantizer = nullptr; + index_composition0.mode = 1; // only quantizer + + // Copy quantizer to GPU 0 + auto index1 = faiss::gpu::index_cpu_to_gpu(&res, 0, &index_composition0, &option0); + delete index1; + + faiss::IndexComposition index_composition1; + index_composition1.index = cpu_index; + index_composition1.quantizer = nullptr; + index_composition1.mode = 1; // only quantizer + + // Copy quantizer to GPU 1 + index1 = faiss::gpu::index_cpu_to_gpu(&res, 1, &index_composition1, &option1); + delete index1; + + // std::thread t_cpu1(cpu_executor, &index_composition0); + // t_cpu1.join(); + // std::thread t_cpu2(cpu_executor, &index_composition1); + // t_cpu2.join(); + + index_composition0.mode = 2; // only data + index_composition1.mode = 2; // only data + + // index1 = faiss::gpu::index_cpu_to_gpu(&res, 0, &index_composition0, &option0); + // delete index1; + // index1 = faiss::gpu::index_cpu_to_gpu(&res, 1, &index_composition1, &option1); + // delete index1; + + // double tx = getmillisecs(); + // std::thread t1(gpu_executor, 0, &option0, &index_composition0); + // std::thread t2(gpu_executor, 1, &option1, &index_composition1); + // t1.join(); + // t2.join(); + for(long i = 0; i < 1; ++ i) { + std::shared_ptr gpu_index_ptr00; + std::shared_ptr gpu_index_ptr01; + + std::thread t00(GpuLoad, &res, 0, &option0, &index_composition0, std::ref(gpu_index_ptr00)); + // std::thread t2(GpuLoad, &res, 1, &option1, &index_composition1, std::ref(gpu_index_ptr1)); + std::thread t01(GpuLoad, &res, 0, &option0, &index_composition0, std::ref(gpu_index_ptr01)); + + t00.join(); + + GpuExecutor(gpu_index_ptr00, res, 0, &option0, &index_composition0, nq, nprobe, k, xq); + + t01.join(); + // t2.join(); + GpuExecutor(gpu_index_ptr01, res, 0, &option0, &index_composition0, nq, nprobe, k, xq); + // GpuExecutor(gpu_index_ptr1, res, 1, &option1, &index_composition1, nq, nprobe, k, xq); + } + + delete index_composition0.quantizer; + delete index_composition1.quantizer; + delete cpu_index; +} + +int main() { + const char* filename = "index500k-h.index"; + int d = 512; // dimension + int nq = 1000; // nb of queries + int nprobe = 16; + int k = 1000; + float *xq = new float[d * nq]; + for(int i = 0; i < nq; i++) { + for(int j = 0; j < d; j++) { + xq[d * i + j] = drand48(); + } + } + + long db_size = 500000; + const char* index_description = "IVF16384,SQ8Hybrid"; + create_index(filename, index_description, db_size, d); + for(long i = 0; i < 1000; ++ i) { + execute_index(filename, d, nq, nprobe, k, xq); + } + delete[] xq; + xq = nullptr; + return 0; +} + +/* +int main() { + const char* filename = "index500k-h.index"; + +#if PRINT_RESULT + int number = 8; +#endif + + int d = 512; // dimension + int nq = 1000; // nb of queries + int nprobe = 16; + float *xq = new float[d * nq]; + for(int i = 0; i < nq; i++) { + for(int j = 0; j < d; j++) { + xq[d * i + j] = drand48(); + } + } + faiss::distance_compute_blas_threshold = 800; + + faiss::gpu::StandardGpuResources res; + + int k = 1000; + std::shared_ptr gpu_index_ivf_ptr; + + const char* index_description = "IVF16384,SQ8Hybrid"; +// const char* index_description = "IVF3276,SQ8"; + + faiss::Index *cpu_index = nullptr; + faiss::IndexIVF* cpu_ivf_index = nullptr; + if((access(filename,F_OK))==-1) { + // create database + long nb = 500000; // database size +// printf("-----------------------\n"); + long size = d * nb; + float *xb = new float[size]; + memset(xb, 0, size * sizeof(float)); + printf("size: %ld\n", (size * sizeof(float)) ); + for(long i = 0; i < nb; i++) { + for(long j = 0; j < d; j++) { + float rand = drand48(); + xb[d * i + j] = rand; + } + } + + faiss::Index *ori_index = faiss::index_factory(d, index_description, faiss::METRIC_INNER_PRODUCT); + auto device_index = faiss::gpu::index_cpu_to_gpu(&res, 0, ori_index); + + gpu_index_ivf_ptr = std::shared_ptr(device_index); + + assert(!device_index->is_trained); + device_index->train(nb, xb); + assert(device_index->is_trained); + device_index->add(nb, xb); // add vectors to the index + + printf("is_trained = %s\n", device_index->is_trained ? "true" : "false"); + printf("ntotal = %ld\n", device_index->ntotal); + + cpu_index = faiss::gpu::index_gpu_to_cpu ((device_index)); + faiss::write_index(cpu_index, filename); + printf("index.index is stored successfully.\n"); + delete [] xb; + } else { + cpu_index = faiss::read_index(filename); + } + + cpu_ivf_index = dynamic_cast(cpu_index); + if(cpu_ivf_index != nullptr) { + cpu_ivf_index->to_readonly(); + } + + faiss::gpu::GpuClonerOptions option0; + faiss::gpu::GpuClonerOptions option1; + + option0.allInGpu = true; + option1.allInGpu = true; + + faiss::IndexComposition index_composition0; + index_composition0.index = cpu_index; + index_composition0.quantizer = nullptr; + index_composition0.mode = 1; // only quantizer + + // Copy quantizer to GPU 0 + auto index1 = faiss::gpu::index_cpu_to_gpu(&res, 0, &index_composition0, &option0); + delete index1; + + faiss::IndexComposition index_composition1; + index_composition1.index = cpu_index; + index_composition1.quantizer = nullptr; + index_composition1.mode = 1; // only quantizer + + // Copy quantizer to GPU 1 + index1 = faiss::gpu::index_cpu_to_gpu(&res, 1, &index_composition1, &option1); + delete index1; + +// std::thread t_cpu1(cpu_executor, &index_composition0); +// t_cpu1.join(); +// std::thread t_cpu2(cpu_executor, &index_composition1); +// t_cpu2.join(); + + index_composition0.mode = 2; // only data + index_composition1.mode = 2; // only data + + index1 = faiss::gpu::index_cpu_to_gpu(&res, 0, &index_composition0, &option0); + delete index1; + index1 = faiss::gpu::index_cpu_to_gpu(&res, 1, &index_composition1, &option1); + delete index1; + +// double tx = getmillisecs(); +// std::thread t1(gpu_executor, 0, &option0, &index_composition0); +// std::thread t2(gpu_executor, 1, &option1, &index_composition1); +// t1.join(); +// t2.join(); + for(long i = 0; i < 10; ++ i) { + std::shared_ptr gpu_index_ptr00; + std::shared_ptr gpu_index_ptr01; + + std::thread t00(GpuLoad, &res, 0, &option0, &index_composition0, std::ref(gpu_index_ptr00)); +// std::thread t2(GpuLoad, &res, 1, &option1, &index_composition1, std::ref(gpu_index_ptr1)); + std::thread t01(GpuLoad, &res, 0, &option0, &index_composition0, std::ref(gpu_index_ptr01)); + + t00.join(); + + GpuExecutor(gpu_index_ptr00, res, 0, &option0, &index_composition0, nq, nprobe, k, xq); + + t01.join(); +// t2.join(); + GpuExecutor(gpu_index_ptr01, res, 0, &option0, &index_composition0, nq, nprobe, k, xq); +// GpuExecutor(gpu_index_ptr1, res, 1, &option1, &index_composition1, nq, nprobe, k, xq); + } + +// std::thread t3(gpu_executor, 0, &option0, &index_composition0); +// std::thread t4(gpu_executor, 1, &option1, &index_composition1); +// t3.join(); +// t4.join(); +// double ty = getmillisecs(); +// printf("Total GPU execution time: %0.2f\n", ty - tx); +// CpuExecutor(&index_composition0, nq, nprobe, k, xq, cpu_index); +// CpuExecutor(&index_composition1, nq, nprobe, k, xq, cpu_index); + + ///// + delete [] xq; + return 0; +} +*/ diff --git a/core/src/index/thirdparty/faiss/tutorial/cpp/9-BinaryFlat.cpp b/core/src/index/thirdparty/faiss/tutorial/cpp/9-BinaryFlat.cpp new file mode 100644 index 0000000000..547cc6d88d --- /dev/null +++ b/core/src/index/thirdparty/faiss/tutorial/cpp/9-BinaryFlat.cpp @@ -0,0 +1,115 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +#include +#include +#include + +// #define TEST_HAMMING + +long int getTime(timeval end, timeval start) { + return 1000*(end.tv_sec - start.tv_sec) + (end.tv_usec - start.tv_usec)/1000; +} + +int main() { + // freopen("0.txt", "w", stdout); + + size_t d = 128; // dimension + size_t nb = 40000000; // database size + size_t nq = 10; // nb of queries + + uint8_t *xb = new uint8_t[d * nb / sizeof(uint8_t)]; + uint8_t *xq = new uint8_t[d * nq / sizeof(uint8_t)]; + + // skip 0 + lrand48(); + + size_t size_to_long = d * nb / sizeof(int32_t); + for(size_t i = 0; i < size_to_long; i++) { + ((int32_t*)xb)[i] = lrand48(); + } + + size_to_long = d * nq / sizeof(long int); + for(size_t i = 0; i < size_to_long; i++) { + ((int32_t*)xq)[i] = lrand48(); + } +#ifdef TEST_HAMMING + printf("test haming\n"); + faiss::IndexBinaryFlat index(d, faiss::MetricType::METRIC_Hamming); +#else + faiss::IndexBinaryFlat index(d, faiss::MetricType::METRIC_Jaccard); +#endif + index.add(nb, xb); + printf("ntotal = %ld d = %d\n", index.ntotal, index.d); + + int k = 10; + +#if 0 + { // sanity check: search 5 first vectors of xb + int64_t *I = new int64_t[k * 5]; + int32_t *D = new int32_t[k * 5]; + float *d_float = reinterpret_cast(D); + + index.search(5, xb, k, D, I); + + // print results + for(int i = 0; i < 5; i++) { + for(int j = 0; j < k; j++) +#ifdef TEST_HAMMING + printf("%8ld %d\n", I[i * k + j], D[i * k + j]); +#else + printf("%8ld %.08f\n", I[i * k + j], d_float[i * k + j]); +#endif + printf("\n"); + } + + delete [] I; + delete [] D; + } +#endif + + { // search xq + int64_t *I = new int64_t[k * nq]; + int32_t *D = new int32_t[k * nq]; + float *d_float = reinterpret_cast(D); + + for (int loop = 1; loop <= nq; loop ++) { + timeval t0; + gettimeofday(&t0, 0); + + index.search(loop, xq, k, D, I); + + timeval t1; + gettimeofday(&t1, 0); + printf("search nq %d time %ldms\n", loop, getTime(t1,t0)); +#if 0 + for (int i = 0; i < loop; i++) { + for(int j = 0; j < k; j++) +#ifdef TEST_HAMMING + printf("%8ld %d\n", I[i * k + j], D[i * k + j]); +#else + printf("%8ld %.08f\n", I[j + i * k], d_float[j + i * k]); +#endif + printf("\n"); + } +#endif + } + + delete [] I; + delete [] D; + } + + delete [] xb; + delete [] xq; + + return 0; +} + + diff --git a/core/src/index/thirdparty/faiss/tutorial/cpp/Makefile b/core/src/index/thirdparty/faiss/tutorial/cpp/Makefile new file mode 100644 index 0000000000..472975f1d9 --- /dev/null +++ b/core/src/index/thirdparty/faiss/tutorial/cpp/Makefile @@ -0,0 +1,25 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +-include ../../makefile.inc + +CPU_TARGETS = 1-Flat 2-IVFFlat 3-IVFPQ 9-BinaryFlat +GPU_TARGETS = 4-GPU 5-Multiple-GPUs + +default: cpu + +all: cpu gpu + +cpu: $(CPU_TARGETS) + +gpu: $(GPU_TARGETS) + +%: %.cpp ../../libfaiss.a + $(CXX) $(CXXFLAGS) $(CPPFLAGS) -o $@ $^ $(LDFLAGS) -I../../ $(LIBS) + +clean: + rm -f $(CPU_TARGETS) $(GPU_TARGETS) + +.PHONY: all cpu default gpu clean diff --git a/core/src/index/thirdparty/faiss/tutorial/cpp/faiss_test.cpp b/core/src/index/thirdparty/faiss/tutorial/cpp/faiss_test.cpp new file mode 100644 index 0000000000..6377f133c2 --- /dev/null +++ b/core/src/index/thirdparty/faiss/tutorial/cpp/faiss_test.cpp @@ -0,0 +1,378 @@ +#include +#include +#include +#include + +#include + +#include +#include +#include +#include +#include + + +#include "faiss/FaissAssert.h" +#include "faiss/AuxIndexStructures.h" + +#include "faiss/IndexFlat.h" +#include "faiss/VectorTransform.h" +#include "faiss/IndexLSH.h" +#include "faiss/IndexPQ.h" +#include "faiss/IndexIVF.h" +#include "faiss/IndexIVFPQ.h" +#include "faiss/IndexIVFFlat.h" +#include "faiss/IndexIVFSpectralHash.h" +#include "faiss/MetaIndexes.h" +#include "faiss/IndexScalarQuantizer.h" +#include "faiss/IndexHNSW.h" +#include "faiss/OnDiskInvertedLists.h" +#include "faiss/IndexBinaryFlat.h" +#include "faiss/IndexBinaryFromFloat.h" +#include "faiss/IndexBinaryHNSW.h" +#include "faiss/IndexBinaryIVF.h" +#include "faiss/gpu/GpuIndexIVFSQ.h" +#include "faiss/utils.h" + + +using namespace faiss; + +void +generate_file(const char *filename, + long nb, + long dimension, + std::string index_desc, + faiss::gpu::StandardGpuResources &res) { + long size = dimension * nb; + float *xb = new float[size]; + printf("size: %lf(GB)\n", (size * sizeof(float)) / (3 * 1024.0 * 1024 * 1024)); + for (long i = 0; i < nb; i++) { + for (long j = 0; j < dimension; j++) { + float rand = drand48(); + xb[dimension * i + j] = rand; + } + } + + faiss::Index *ori_index = faiss::index_factory(dimension, index_desc.c_str(), faiss::METRIC_L2); + auto device_index = faiss::gpu::index_cpu_to_gpu(&res, 0, ori_index); + + assert(!device_index->is_trained); + device_index->train(nb, xb); + assert(device_index->is_trained); + device_index->add(nb, xb); + + faiss::Index *cpu_index = faiss::gpu::index_gpu_to_cpu((device_index)); + faiss::write_index(cpu_index, filename); + printf("index: %s is stored successfully.\n", filename); + delete[] xb; + + return; +} + +faiss::Index * +get_index(const char *filename) { + return faiss::read_index(filename); +} + +void +execute_on_gpu(faiss::Index *index, float *xq, long nq, long k, long nprobe, + faiss::gpu::StandardGpuResources &res, long* I, float* D) { + + double t0 = getmillisecs(); + + faiss::gpu::CpuToGpuClonerOptions option; + option.readonly = true; + faiss::Index *tmp_index = faiss::gpu::cpu_to_gpu(&res, 0, index, &option); + std::shared_ptr gpu_index_ivf_ptr = std::shared_ptr(tmp_index); + + double t1 = getmillisecs(); + printf("CPU to GPU loading time: %0.2f\n", t1 - t0); + + + double t2 = getmillisecs(); + faiss::gpu::GpuIndexIVF *gpu_index_ivf = + dynamic_cast(gpu_index_ivf_ptr.get()); + gpu_index_ivf->setNumProbes(nprobe); + + gpu_index_ivf_ptr->search(nq, xq, k, D, I); + double t3 = getmillisecs(); + printf("GPU execution time: %0.2f\n", t3 - t2); +} + +void execute_on_cpu(faiss::Index *index, float* xq, long nq, long k, long nprobe, long* I, float* D) { + faiss::IndexIVF* ivf_index = + dynamic_cast(index); + ivf_index->nprobe = nprobe; + index->search(nq, xq, k, D, I); +} + +float *construct_queries(long nq, long dimension) { + float *xq = new float[dimension * nq]; + for (int i = 0; i < nq; i++) { + for (int j = 0; j < dimension; j++) { + xq[dimension * i + j] = drand48(); + } + } + return xq; +} + +void print_result(long number, long nq, long k, long *I, float *D) { + printf("I (%ld first results)=\n", number); + for (int i = 0; i < number; i++) { + for (int j = 0; j < k; j++) + printf("%5ld ", I[i * k + j]); + printf("\n"); + } + + printf("I (%ld last results)=\n", number); + for (int i = nq - number; i < nq; i++) { + for (int j = 0; j < k; j++) + printf("%5ld ", I[i * k + j]); + printf("\n"); + } +} + +void faiss_setting() { + faiss::distance_compute_blas_threshold = 800; +} + +int main() { + const char *filename = "index5.index"; + +#if 0 + long dimension = 512; + long nb = 6000000; + long nq = 1000; + long topk = 16; + long print_number = 8; + long nprobe = 32; + + std::string index_desc = "IVF16384,SQ8"; + faiss::gpu::StandardGpuResources res; + if ((access(filename, F_OK)) == -1) { + printf("file doesn't exist, create one\n"); + generate_file(filename, nb, dimension, index_desc, res); + } + + // Construct queries + float *xq = construct_queries(nq, dimension); + + // Read index + faiss::Index *index = get_index(filename); + + // Execute on GPU + long *I = new long[topk * nq]; + float *D = new float[topk * nq]; + execute_on_gpu(index, xq, nq, topk, nprobe, res, I, D); + + // Print results + print_result(print_number, nq, topk, I, D); + delete[] I; I = nullptr; + delete[] D; D = nullptr; + + // Execute on CPU + I = new long[topk * nq]; + D = new float[topk * nq]; + execute_on_cpu(index, xq, nq, topk, nprobe, I, D); + + // Print results + print_result(print_number, nq, topk, I, D); + delete[] I; + delete[] D; + + return 0; +#else + int number = 8; + int d = 512; // dimension + int nq = 1000; // nb of queries + int nprobe = 16; + float *xq = new float[d * nq]; + for(int i = 0; i < nq; i++) { + for(int j = 0; j < d; j++) { + xq[d * i + j] = drand48(); +// printf("%lf ", xq[d * i + j]); + } +// xq[d * i] += i / 1000.; +// printf("\n"); + } + faiss::distance_compute_blas_threshold = 800; + + faiss::gpu::StandardGpuResources res; + + int k = 16; + std::shared_ptr gpu_index_ivf_ptr; + + const char* index_description = "IVF16384,SQ8"; + // const char* index_description = "IVF3276,Flat"; +// Index *index_factory (int d, const char *description, +// MetricType metric = METRIC_L2); + + faiss::Index *cpu_index = nullptr; + if((access(filename,F_OK))==-1) { + long nb = 6000000; + long dimension = d; + printf("file doesn't exist, create one\n"); + generate_file(filename, nb, dimension, index_description, res); + /* + // create database + // database size +// printf("-----------------------\n"); + long size = d * nb; + float *xb = new float[size]; + memset(xb, 0, size * sizeof(float)); + printf("size: %ld\n", (size * sizeof(float)) ); + for(long i = 0; i < nb; i++) { + for(long j = 0; j < d; j++) { + float rand = drand48(); + xb[d * i + j] = rand; +// printf("%lf ", xb[d * i + j]); + } +// xb[d * i] += i / 1000.; +// printf("\n"); + } + + // Using an IVF index + // here we specify METRIC_L2, by default it performs inner-product search + + faiss::Index *ori_index = faiss::index_factory(d, index_description, faiss::METRIC_L2); + auto device_index = faiss::gpu::index_cpu_to_gpu(&res, 0, ori_index); + + gpu_index_ivf_ptr = std::shared_ptr(device_index); + + assert(!device_index->is_trained); + device_index->train(nb, xb); + assert(device_index->is_trained); + device_index->add(nb, xb); // add vectors to the index + + printf("is_trained = %s\n", device_index->is_trained ? "true" : "false"); + printf("ntotal = %ld\n", device_index->ntotal); + + cpu_index = faiss::gpu::index_gpu_to_cpu ((device_index)); + faiss::write_index(cpu_index, filename); + printf("index.index is stored successfully.\n"); + delete [] xb; + */ + } else { + cpu_index = get_index(filename); + } + + { + // cpu to gpu + double t0 = getmillisecs (); + faiss::gpu::CpuToGpuClonerOptions option; + option.readonly = true; + faiss::Index* tmp_index = faiss::gpu::cpu_to_gpu(&res, 0, cpu_index, &option); + + gpu_index_ivf_ptr = std::shared_ptr(tmp_index); + + // Gpu index dump + + auto gpu_index_ivf_sq_ptr = dynamic_cast(tmp_index); +// gpu_index_ivf_sq_ptr->dump(); + double t1 = getmillisecs (); + printf("CPU to GPU loading time: %0.2f\n", t1 - t0); + // // Cpu index dump + // auto cpu_index_ivf_sq_ptr = dynamic_cast(cpu_index); + // cpu_index_ivf_sq_ptr->dump(); + } + + + { // search xq + long *I = new long[k * nq]; + float *D = new float[k * nq]; + double t2 = getmillisecs(); + faiss::gpu::GpuIndexIVF* gpu_index_ivf = + dynamic_cast(gpu_index_ivf_ptr.get()); + gpu_index_ivf->setNumProbes(nprobe); + + gpu_index_ivf_ptr->search(nq, xq, k, D, I); + double t3 = getmillisecs(); + printf("GPU execution time: %0.2f\n", t3 - t2); + + // print results + printf("GPU: \n"); +#if 0 + printf("GPU: I (2 first results)=\n"); + for(int i = 0; i < number; i++) { + for(int j = 0; j < k; j++) + printf("GPU: %5ld(%f) ", I[i * k + j], D[i * k + j]); + printf("\n"); + } + + printf("GPU: I (2 last results)=\n"); + for(int i = nq - number; i < nq; i++) { + for(int j = 0; j < k; j++) + printf("GPU: %5ld(%f) ", I[i * k + j], D[i * k + j]); + printf("\n"); + } +#else + printf("I (2 first results)=\n"); + for(int i = 0; i < number; i++) { + for(int j = 0; j < k; j++) + printf("%5ld ", I[i * k + j]); + printf("\n"); + } + + printf("I (2 last results)=\n"); + for(int i = nq - number; i < nq; i++) { + for(int j = 0; j < k; j++) + printf("%5ld ", I[i * k + j]); + printf("\n"); + } +#endif + delete [] I; + delete [] D; + } + printf("----------------------------------\n"); + { // search xq + printf("CPU: \n"); + long *I = new long[k * nq]; + float *D = new float[k * nq]; + + double t4 = getmillisecs(); + faiss::IndexIVF* ivf_index = + dynamic_cast(cpu_index); + ivf_index->nprobe = nprobe; + cpu_index->search(nq, xq, k, D, I); + double t5 = getmillisecs(); + printf("CPU execution time: %0.2f\n", t5 - t4); +#if 0 + // print results + printf("CPU: I (2 first results)=\n"); + for(int i = 0; i < number; i++) { + for(int j = 0; j < k; j++) + printf("CPU: %5ld(%f) ", I[i * k + j], D[i * k + j]); + printf("\n"); + } + + printf("CPU: I (2 last results)=\n"); + for(int i = nq - number; i < nq; i++) { + for(int j = 0; j < k; j++) + printf("CPU: %5ld(%f) ", I[i * k + j], D[i * k + j]); + printf("\n"); + } +#else + // print results + printf("I (2 first results)=\n"); + for(int i = 0; i < number; i++) { + for(int j = 0; j < k; j++) + printf("%5ld ", I[i * k + j]); + printf("\n"); + } + + printf("I (2 last results)=\n"); + for(int i = nq - number; i < nq; i++) { + for(int j = 0; j < k; j++) + printf("%5ld ", I[i * k + j]); + printf("\n"); + } +#endif + delete [] I; + delete [] D; + } + + + delete [] xq; + return 0; +#endif +} \ No newline at end of file diff --git a/core/src/index/thirdparty/faiss/tutorial/python/1-Flat.py b/core/src/index/thirdparty/faiss/tutorial/python/1-Flat.py new file mode 100644 index 0000000000..584c7bc703 --- /dev/null +++ b/core/src/index/thirdparty/faiss/tutorial/python/1-Flat.py @@ -0,0 +1,29 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np + +d = 64 # dimension +nb = 100000 # database size +nq = 10000 # nb of queries +np.random.seed(1234) # make reproducible +xb = np.random.random((nb, d)).astype('float32') +xb[:, 0] += np.arange(nb) / 1000. +xq = np.random.random((nq, d)).astype('float32') +xq[:, 0] += np.arange(nq) / 1000. + +import faiss # make faiss available +index = faiss.IndexFlatL2(d) # build the index +print(index.is_trained) +index.add(xb) # add vectors to the index +print(index.ntotal) + +k = 4 # we want to see 4 nearest neighbors +D, I = index.search(xb[:5], k) # sanity check +print(I) +print(D) +D, I = index.search(xq, k) # actual search +print(I[:5]) # neighbors of the 5 first queries +print(I[-5:]) # neighbors of the 5 last queries diff --git a/core/src/index/thirdparty/faiss/tutorial/python/2-IVFFlat.py b/core/src/index/thirdparty/faiss/tutorial/python/2-IVFFlat.py new file mode 100644 index 0000000000..a4ac0c4d1f --- /dev/null +++ b/core/src/index/thirdparty/faiss/tutorial/python/2-IVFFlat.py @@ -0,0 +1,34 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np + +d = 64 # dimension +nb = 100000 # database size +nq = 10000 # nb of queries +np.random.seed(1234) # make reproducible +xb = np.random.random((nb, d)).astype('float32') +xb[:, 0] += np.arange(nb) / 1000. +xq = np.random.random((nq, d)).astype('float32') +xq[:, 0] += np.arange(nq) / 1000. + +import faiss + +nlist = 100 +k = 4 +quantizer = faiss.IndexFlatL2(d) # the other index +index = faiss.IndexIVFFlat(quantizer, d, nlist, faiss.METRIC_L2) +# here we specify METRIC_L2, by default it performs inner-product search + +assert not index.is_trained +index.train(xb) +assert index.is_trained + +index.add(xb) # add may be a bit slower as well +D, I = index.search(xq, k) # actual search +print(I[-5:]) # neighbors of the 5 last queries +index.nprobe = 10 # default nprobe is 1, try a few more +D, I = index.search(xq, k) +print(I[-5:]) # neighbors of the 5 last queries diff --git a/core/src/index/thirdparty/faiss/tutorial/python/3-IVFPQ.py b/core/src/index/thirdparty/faiss/tutorial/python/3-IVFPQ.py new file mode 100644 index 0000000000..e502239ca4 --- /dev/null +++ b/core/src/index/thirdparty/faiss/tutorial/python/3-IVFPQ.py @@ -0,0 +1,32 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np + +d = 64 # dimension +nb = 100000 # database size +nq = 10000 # nb of queries +np.random.seed(1234) # make reproducible +xb = np.random.random((nb, d)).astype('float32') +xb[:, 0] += np.arange(nb) / 1000. +xq = np.random.random((nq, d)).astype('float32') +xq[:, 0] += np.arange(nq) / 1000. + +import faiss + +nlist = 100 +m = 8 +k = 4 +quantizer = faiss.IndexFlatL2(d) # this remains the same +index = faiss.IndexIVFPQ(quantizer, d, nlist, m, 8) + # 8 specifies that each sub-vector is encoded as 8 bits +index.train(xb) +index.add(xb) +D, I = index.search(xb[:5], k) # sanity check +print(I) +print(D) +index.nprobe = 10 # make comparable with experiment above +D, I = index.search(xq, k) # search +print(I[-5:]) diff --git a/core/src/index/thirdparty/faiss/tutorial/python/4-GPU.py b/core/src/index/thirdparty/faiss/tutorial/python/4-GPU.py new file mode 100644 index 0000000000..6f5e37e535 --- /dev/null +++ b/core/src/index/thirdparty/faiss/tutorial/python/4-GPU.py @@ -0,0 +1,57 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np + +d = 64 # dimension +nb = 100000 # database size +nq = 10000 # nb of queries +np.random.seed(1234) # make reproducible +xb = np.random.random((nb, d)).astype('float32') +xb[:, 0] += np.arange(nb) / 1000. +xq = np.random.random((nq, d)).astype('float32') +xq[:, 0] += np.arange(nq) / 1000. + +import faiss # make faiss available + +res = faiss.StandardGpuResources() # use a single GPU + +## Using a flat index + +index_flat = faiss.IndexFlatL2(d) # build a flat (CPU) index + +# make it a flat GPU index +gpu_index_flat = faiss.index_cpu_to_gpu(res, 0, index_flat) + +gpu_index_flat.add(xb) # add vectors to the index +print(gpu_index_flat.ntotal) + +k = 4 # we want to see 4 nearest neighbors +D, I = gpu_index_flat.search(xq, k) # actual search +print(I[:5]) # neighbors of the 5 first queries +print(I[-5:]) # neighbors of the 5 last queries + + +## Using an IVF index + +nlist = 100 +quantizer = faiss.IndexFlatL2(d) # the other index +index_ivf = faiss.IndexIVFFlat(quantizer, d, nlist, faiss.METRIC_L2) +# here we specify METRIC_L2, by default it performs inner-product search + +# make it an IVF GPU index +gpu_index_ivf = faiss.index_cpu_to_gpu(res, 0, index_ivf) + +assert not gpu_index_ivf.is_trained +gpu_index_ivf.train(xb) # add vectors to the index +assert gpu_index_ivf.is_trained + +gpu_index_ivf.add(xb) # add vectors to the index +print(gpu_index_ivf.ntotal) + +k = 4 # we want to see 4 nearest neighbors +D, I = gpu_index_ivf.search(xq, k) # actual search +print(I[:5]) # neighbors of the 5 first queries +print(I[-5:]) # neighbors of the 5 last queries diff --git a/core/src/index/thirdparty/faiss/tutorial/python/5-Multiple-GPUs.py b/core/src/index/thirdparty/faiss/tutorial/python/5-Multiple-GPUs.py new file mode 100644 index 0000000000..c458587ce9 --- /dev/null +++ b/core/src/index/thirdparty/faiss/tutorial/python/5-Multiple-GPUs.py @@ -0,0 +1,35 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np + +d = 64 # dimension +nb = 100000 # database size +nq = 10000 # nb of queries +np.random.seed(1234) # make reproducible +xb = np.random.random((nb, d)).astype('float32') +xb[:, 0] += np.arange(nb) / 1000. +xq = np.random.random((nq, d)).astype('float32') +xq[:, 0] += np.arange(nq) / 1000. + +import faiss # make faiss available + +ngpus = faiss.get_num_gpus() + +print("number of GPUs:", ngpus) + +cpu_index = faiss.IndexFlatL2(d) + +gpu_index = faiss.index_cpu_to_all_gpus( # build the index + cpu_index +) + +gpu_index.add(xb) # add vectors to the index +print(gpu_index.ntotal) + +k = 4 # we want to see 4 nearest neighbors +D, I = gpu_index.search(xq, k) # actual search +print(I[:5]) # neighbors of the 5 first queries +print(I[-5:]) # neighbors of the 5 last queries diff --git a/core/src/index/thirdparty/faiss/utils/BinaryDistance.cpp b/core/src/index/thirdparty/faiss/utils/BinaryDistance.cpp new file mode 100644 index 0000000000..d6ebaa44f0 --- /dev/null +++ b/core/src/index/thirdparty/faiss/utils/BinaryDistance.cpp @@ -0,0 +1,337 @@ +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace faiss { + +static const size_t size_1M = 1 * 1024 * 1024; +static const size_t batch_size = 65536; + +template +static +void binary_distence_knn_hc( + int bytes_per_code, + float_maxheap_array_t * ha, + const uint8_t * bs1, + const uint8_t * bs2, + size_t n2, + bool order = true, + bool init_heap = true, + ConcurrentBitsetPtr bitset = nullptr) +{ + size_t k = ha->k; + + if ((bytes_per_code + k * (sizeof(float) + sizeof(int64_t))) * ha->nh < size_1M) { + int thread_max_num = omp_get_max_threads(); + // init heap + size_t thread_heap_size = ha->nh * k; + size_t all_heap_size = thread_heap_size * thread_max_num; + float *value = new float[all_heap_size]; + int64_t *labels = new int64_t[all_heap_size]; + for (int i = 0; i < all_heap_size; i++) { + value[i] = 1.0 / 0.0; + labels[i] = -1; + } + + T *hc = new T[ha->nh]; + for (size_t i = 0; i < ha->nh; i++) { + hc[i].set(bs1 + i * bytes_per_code, bytes_per_code); + } + +#pragma omp parallel for + for (size_t j = 0; j < n2; j++) { + if(!bitset || !bitset->test(j)) { + int thread_no = omp_get_thread_num(); + + const uint8_t * bs2_ = bs2 + j * bytes_per_code; + for (size_t i = 0; i < ha->nh; i++) { + tadis_t dis = hc[i].compute (bs2_); + + float * val_ = value + thread_no * thread_heap_size + i * k; + int64_t * ids_ = labels + thread_no * thread_heap_size + i * k; + if (dis < val_[0]) { + faiss::maxheap_swap_top (k, val_, ids_, dis, j); + } + } + } + } + + for (size_t t = 1; t < thread_max_num; t++) { + // merge heap + for (size_t i = 0; i < ha->nh; i++) { + float * __restrict value_x = value + i * k; + int64_t * __restrict labels_x = labels + i * k; + float *value_x_t = value_x + t * thread_heap_size; + int64_t *labels_x_t = labels_x + t * thread_heap_size; + for (size_t j = 0; j < k; j++) { + if (value_x_t[j] < value_x[0]) { + faiss::maxheap_swap_top (k, value_x, labels_x, value_x_t[j], labels_x_t[j]); + } + } + } + } + + // copy result + memcpy(ha->val, value, thread_heap_size * sizeof(float)); + memcpy(ha->ids, labels, thread_heap_size * sizeof(int64_t)); + + delete[] hc; + delete[] value; + delete[] labels; + + } else { + if (init_heap) ha->heapify (); + + const size_t block_size = batch_size; + for (size_t j0 = 0; j0 < n2; j0 += block_size) { + const size_t j1 = std::min(j0 + block_size, n2); +#pragma omp parallel for + for (size_t i = 0; i < ha->nh; i++) { + T hc (bs1 + i * bytes_per_code, bytes_per_code); + + const uint8_t * bs2_ = bs2 + j0 * bytes_per_code; + tadis_t dis; + tadis_t * __restrict bh_val_ = ha->val + i * k; + int64_t * __restrict bh_ids_ = ha->ids + i * k; + size_t j; + for (j = j0; j < j1; j++, bs2_+= bytes_per_code) { + if(!bitset || !bitset->test(j)){ + dis = hc.compute (bs2_); + if (dis < bh_val_[0]) { + faiss::maxheap_swap_top (k, bh_val_, bh_ids_, dis, j); + } + } + } + + } + } + } + + if (order) ha->reorder (); +} + +void binary_distence_knn_hc ( + MetricType metric_type, + float_maxheap_array_t * ha, + const uint8_t * a, + const uint8_t * b, + size_t nb, + size_t ncodes, + int order, + ConcurrentBitsetPtr bitset) +{ + switch (metric_type) { + case METRIC_Jaccard: + case METRIC_Tanimoto: + switch (ncodes) { +#define binary_distence_knn_hc_jaccard(ncodes) \ + case ncodes: \ + binary_distence_knn_hc \ + (ncodes, ha, a, b, nb, order, true, bitset); \ + break; + binary_distence_knn_hc_jaccard(8); + binary_distence_knn_hc_jaccard(16); + binary_distence_knn_hc_jaccard(32); + binary_distence_knn_hc_jaccard(64); + binary_distence_knn_hc_jaccard(128); + binary_distence_knn_hc_jaccard(256); + binary_distence_knn_hc_jaccard(512); +#undef binary_distence_knn_hc_jaccard + default: + binary_distence_knn_hc + (ncodes, ha, a, b, nb, order, true, bitset); + break; + } + break; + + default: + break; + } +} + +template +static +void binary_distence_knn_mc( + int bytes_per_code, + const uint8_t * bs1, + const uint8_t * bs2, + size_t n1, + size_t n2, + size_t k, + float *distances, + int64_t *labels, + ConcurrentBitsetPtr bitset) +{ + if ((bytes_per_code + sizeof(size_t) + k * sizeof(int64_t)) * n1 < size_1M) { + int thread_max_num = omp_get_max_threads(); + + size_t group_num = n1 * thread_max_num; + size_t *match_num = new size_t[group_num]; + int64_t *match_data = new int64_t[group_num * k]; + for (size_t i = 0; i < group_num; i++) { + match_num[i] = 0; + } + + T *hc = new T[n1]; + for (size_t i = 0; i < n1; i++) { + hc[i].set(bs1 + i * bytes_per_code, bytes_per_code); + } + +#pragma omp parallel for + for (size_t j = 0; j < n2; j++) { + if(!bitset || !bitset->test(j)) { + int thread_no = omp_get_thread_num(); + + const uint8_t * bs2_ = bs2 + j * bytes_per_code; + for (size_t i = 0; i < n1; i++) { + if (hc[i].compute(bs2_)) { + size_t match_index = thread_no * n1 + i; + size_t &index = match_num[match_index]; + if (index < k) { + match_data[match_index * k + index] = j; + index++; + } + } + } + } + } + for (size_t i = 0; i < n1; i++) { + size_t n_i = 0; + float *distances_i = distances + i * k; + int64_t *labels_i = labels + i * k; + + for (size_t t = 0; t < thread_max_num && n_i < k; t++) { + size_t match_index = t * n1 + i; + size_t copy_num = std::min(k - n_i, match_num[match_index]); + memcpy(labels_i + n_i, match_data + match_index * k, copy_num * sizeof(int64_t)); + memset(distances_i + n_i, 0, copy_num * sizeof(float)); + n_i += copy_num; + } + for (; n_i < k; n_i++) { + distances_i[n_i] = 1.0 / 0.0; + labels_i[n_i] = -1; + } + } + + delete[] hc; + delete[] match_num; + delete[] match_data; + + } else { + size_t *num = new size_t[n1]; + for (size_t i = 0; i < n1; i++) { + num[i] = 0; + } + + const size_t block_size = batch_size; + for (size_t j0 = 0; j0 < n2; j0 += block_size) { + const size_t j1 = std::min(j0 + block_size, n2); +#pragma omp parallel for + for (size_t i = 0; i < n1; i++) { + size_t num_i = num[i]; + if (num_i == k) continue; + float * dis = distances + i * k; + int64_t * lab = labels + i * k; + + T hc (bs1 + i * bytes_per_code, bytes_per_code); + const uint8_t * bs2_ = bs2 + j0 * bytes_per_code; + for (size_t j = j0; j < j1; j++, bs2_ += bytes_per_code) { + if(!bitset || !bitset->test(j)){ + if (hc.compute (bs2_)) { + dis[num_i] = 0; + lab[num_i] = j; + if (++num_i == k) break; + } + } + } + num[i] = num_i; + } + } + + for (size_t i = 0; i < n1; i++) { + float * dis = distances + i * k; + int64_t * lab = labels + i * k; + for (size_t num_i = num[i]; num_i < k; num_i++) { + dis[num_i] = 1.0 / 0.0; + lab[num_i] = -1; + } + } + + delete[] num; + } +} + +void binary_distence_knn_mc ( + MetricType metric_type, + const uint8_t * a, + const uint8_t * b, + size_t na, + size_t nb, + size_t k, + size_t ncodes, + float *distances, + int64_t *labels, + ConcurrentBitsetPtr bitset) { + + switch (metric_type) { + case METRIC_Substructure: + switch (ncodes) { +#define binary_distence_knn_mc_Substructure(ncodes) \ + case ncodes: \ + binary_distence_knn_mc \ + (ncodes, a, b, na, nb, k, distances, labels, bitset); \ + break; + binary_distence_knn_mc_Substructure(8); + binary_distence_knn_mc_Substructure(16); + binary_distence_knn_mc_Substructure(32); + binary_distence_knn_mc_Substructure(64); + binary_distence_knn_mc_Substructure(128); + binary_distence_knn_mc_Substructure(256); + binary_distence_knn_mc_Substructure(512); +#undef binary_distence_knn_mc_Substructure + default: + binary_distence_knn_mc + (ncodes, a, b, na, nb, k, distances, labels, bitset); + break; + } + break; + + case METRIC_Superstructure: + switch (ncodes) { +#define binary_distence_knn_mc_Superstructure(ncodes) \ + case ncodes: \ + binary_distence_knn_mc \ + (ncodes, a, b, na, nb, k, distances, labels, bitset); \ + break; + binary_distence_knn_mc_Superstructure(8); + binary_distence_knn_mc_Superstructure(16); + binary_distence_knn_mc_Superstructure(32); + binary_distence_knn_mc_Superstructure(64); + binary_distence_knn_mc_Superstructure(128); + binary_distence_knn_mc_Superstructure(256); + binary_distence_knn_mc_Superstructure(512); +#undef binary_distence_knn_mc_Superstructure + default: + binary_distence_knn_mc + (ncodes, a, b, na, nb, k, distances, labels, bitset); + break; + } + break; + + default: + break; + } +} + +} // namespace faiss diff --git a/core/src/index/thirdparty/faiss/utils/BinaryDistance.h b/core/src/index/thirdparty/faiss/utils/BinaryDistance.h new file mode 100644 index 0000000000..fccdfd3674 --- /dev/null +++ b/core/src/index/thirdparty/faiss/utils/BinaryDistance.h @@ -0,0 +1,62 @@ +#ifndef FAISS_BINARY_DISTANCE_H +#define FAISS_BINARY_DISTANCE_H + +#include "faiss/Index.h" + +#include + +#include + +#include + +/* The binary distance type */ +typedef float tadis_t; + +namespace faiss { + +/** Return the k smallest distances for a set of binary query vectors, + * using a max heap. + * @param a queries, size ha->nh * ncodes + * @param b database, size nb * ncodes + * @param nb number of database vectors + * @param ncodes size of the binary codes (bytes) + * @param ordered if != 0: order the results by decreasing distance + * (may be bottleneck for k/n > 0.01) */ + void binary_distence_knn_hc ( + MetricType metric_type, + float_maxheap_array_t * ha, + const uint8_t * a, + const uint8_t * b, + size_t nb, + size_t ncodes, + int ordered, + ConcurrentBitsetPtr bitset = nullptr); + + /** Return the k matched distances for a set of binary query vectors, + * using a max heap. + * @param a queries, size ha->nh * ncodes + * @param b database, size nb * ncodes + * @param na number of queries vectors + * @param nb number of database vectors + * @param k number of the matched vectors to return + * @param ncodes size of the binary codes (bytes) + */ + void binary_distence_knn_mc ( + MetricType metric_type, + const uint8_t * a, + const uint8_t * b, + size_t na, + size_t nb, + size_t k, + size_t ncodes, + float *distances, + int64_t *labels, + ConcurrentBitsetPtr bitset); + +} // namespace faiss + +#include +#include +#include + +#endif // FAISS_BINARY_DISTANCE_H diff --git a/core/src/index/thirdparty/faiss/utils/ConcurrentBitset.cpp b/core/src/index/thirdparty/faiss/utils/ConcurrentBitset.cpp new file mode 100644 index 0000000000..2bdd404ee9 --- /dev/null +++ b/core/src/index/thirdparty/faiss/utils/ConcurrentBitset.cpp @@ -0,0 +1,213 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include +#include "ConcurrentBitset.h" + +namespace faiss { + +ConcurrentBitset::ConcurrentBitset(id_type_t capacity, uint8_t init_value) : capacity_(capacity), bitset_(((capacity + 8 - 1) >> 3)) { + if (init_value) { + memset(mutable_data(), init_value, (capacity + 8 - 1) >> 3); + } +} + +std::vector>& +ConcurrentBitset::bitset() { + return bitset_; +} + +ConcurrentBitset& +ConcurrentBitset::operator&=(ConcurrentBitset& bitset) { + // for (id_type_t i = 0; i < ((capacity_ + 8 -1) >> 3); ++i) { + // bitset_[i].fetch_and(bitset.bitset()[i].load()); + // } + + auto u8_1 = const_cast(data()); + auto u8_2 = const_cast(bitset.data()); + auto u64_1 = reinterpret_cast(u8_1); + auto u64_2 = reinterpret_cast(u8_2); + + size_t n8 = bitset_.size(); + size_t n64 = n8 / 8; + + for (size_t i = 0; i < n64; i++) { + u64_1[i] &= u64_2[i]; + } + + size_t remain = n8 % 8; + u8_1 += n64 * 8; + u8_2 += n64 * 8; + for (size_t i = 0; i < remain; i++) { + u8_1[i] &= u8_2[i]; + } + + return *this; +} + +std::shared_ptr +ConcurrentBitset::operator&(const std::shared_ptr& bitset) { + auto result_bitset = std::make_shared(bitset->capacity()); + + auto result_8 = const_cast(result_bitset->data()); + auto result_64 = reinterpret_cast(result_8); + + auto u8_1 = const_cast(data()); + auto u8_2 = const_cast(bitset->data()); + auto u64_1 = reinterpret_cast(u8_1); + auto u64_2 = reinterpret_cast(u8_2); + + size_t n8 = bitset_.size(); + size_t n64 = n8 / 8; + + for (size_t i = 0; i < n64; i++) { + result_64[i] = u64_1[i] & u64_2[i]; + } + + size_t remain = n8 % 8; + u8_1 += n64 * 8; + u8_2 += n64 * 8; + result_8 += n64 * 8; + for (size_t i = 0; i < remain; i++) { + result_8[i] = u8_1[i] & u8_2[i]; + } + + + return result_bitset; +} + +ConcurrentBitset& +ConcurrentBitset::operator|=(ConcurrentBitset& bitset) { + // for (id_type_t i = 0; i < ((capacity_ + 8 -1) >> 3); ++i) { + // bitset_[i].fetch_or(bitset.bitset()[i].load()); + // } + + auto u8_1 = const_cast(data()); + auto u8_2 = const_cast(bitset.data()); + auto u64_1 = reinterpret_cast(u8_1); + auto u64_2 = reinterpret_cast(u8_2); + + size_t n8 = bitset_.size(); + size_t n64 = n8 / 8; + + for (size_t i = 0; i < n64; i++) { + u64_1[i] |= u64_2[i]; + } + + size_t remain = n8 % 8; + u8_1 += n64 * 8; + u8_2 += n64 * 8; + for (size_t i = 0; i < remain; i++) { + u8_1[i] |= u8_2[i]; + } + + return *this; +} + +std::shared_ptr +ConcurrentBitset::operator|(const std::shared_ptr& bitset) { + auto result_bitset = std::make_shared(bitset->capacity()); + + auto result_8 = const_cast(result_bitset->data()); + auto result_64 = reinterpret_cast(result_8); + + auto u8_1 = const_cast(data()); + auto u8_2 = const_cast(bitset->data()); + auto u64_1 = reinterpret_cast(u8_1); + auto u64_2 = reinterpret_cast(u8_2); + + size_t n8 = bitset_.size(); + size_t n64 = n8 / 8; + + for (size_t i = 0; i < n64; i++) { + result_64[i] = u64_1[i] | u64_2[i]; + } + + size_t remain = n8 % 8; + u8_1 += n64 * 8; + u8_2 += n64 * 8; + result_8 += n64 * 8; + for (size_t i = 0; i < remain; i++) { + result_8[i] = u8_1[i] | u8_2[i]; + } + + return result_bitset; +} + +ConcurrentBitset& +ConcurrentBitset::operator^=(ConcurrentBitset& bitset) { + // for (id_type_t i = 0; i < ((capacity_ + 8 -1) >> 3); ++i) { + // bitset_[i].fetch_xor(bitset.bitset()[i].load()); + // } + + auto u8_1 = const_cast(data()); + auto u8_2 = const_cast(bitset.data()); + auto u64_1 = reinterpret_cast(u8_1); + auto u64_2 = reinterpret_cast(u8_2); + + size_t n8 = bitset_.size(); + size_t n64 = n8 / 8; + + for (size_t i = 0; i < n64; i++) { + u64_1[i] &= u64_2[i]; + } + + size_t remain = n8 % 8; + u8_1 += n64 * 8; + u8_2 += n64 * 8; + for (size_t i = 0; i < remain; i++) { + u8_1[i] ^= u8_2[i]; + } + + return *this; +} + +bool +ConcurrentBitset::test(id_type_t id) { + return bitset_[id >> 3].load() & (0x1 << (id & 0x7)); +} + +void +ConcurrentBitset::set(id_type_t id) { + bitset_[id >> 3].fetch_or(0x1 << (id & 0x7)); +} + +void +ConcurrentBitset::clear(id_type_t id) { + bitset_[id >> 3].fetch_and(~(0x1 << (id & 0x7))); +} + +size_t +ConcurrentBitset::capacity() { + return capacity_; +} + +size_t +ConcurrentBitset::u8size() { + return ((capacity_ + 8 - 1) >> 3); +} + +const uint8_t* +ConcurrentBitset::data() { + return reinterpret_cast(bitset_.data()); +} + +uint8_t* +ConcurrentBitset::mutable_data() { + return reinterpret_cast(bitset_.data()); +} +} // namespace faiss diff --git a/core/src/index/thirdparty/faiss/utils/ConcurrentBitset.h b/core/src/index/thirdparty/faiss/utils/ConcurrentBitset.h new file mode 100644 index 0000000000..5959aa34cf --- /dev/null +++ b/core/src/index/thirdparty/faiss/utils/ConcurrentBitset.h @@ -0,0 +1,82 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include + +namespace faiss { + +class ConcurrentBitset { + public: + using id_type_t = int64_t; + + explicit ConcurrentBitset(id_type_t size, uint8_t init_value = 0); + + // ConcurrentBitset(const ConcurrentBitset&) = delete; + // ConcurrentBitset& + // operator=(const ConcurrentBitset&) = delete; + + std::vector>& + bitset(); + + ConcurrentBitset& + operator&=(ConcurrentBitset& bitset); + + std::shared_ptr + operator&(const std::shared_ptr& bitset); + + ConcurrentBitset& + operator|=(ConcurrentBitset& bitset); + + std::shared_ptr + operator|(const std::shared_ptr& bitset); + + ConcurrentBitset& + operator^=(ConcurrentBitset& bitset); + + bool + test(id_type_t id); + + void + set(id_type_t id); + + void + clear(id_type_t id); + + size_t + capacity(); + + const uint8_t* + data(); + + uint8_t* + mutable_data(); + + size_t + u8size(); + + private: + size_t capacity_; + std::vector> bitset_; +}; + +using ConcurrentBitsetPtr = std::shared_ptr; + +} // namespace faiss diff --git a/core/src/index/thirdparty/faiss/utils/Heap.cpp b/core/src/index/thirdparty/faiss/utils/Heap.cpp new file mode 100644 index 0000000000..0b7cfab547 --- /dev/null +++ b/core/src/index/thirdparty/faiss/utils/Heap.cpp @@ -0,0 +1,120 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +/* Function for soft heap */ + +#include + + +namespace faiss { + + +template +void HeapArray::heapify () +{ +#pragma omp parallel for + for (size_t j = 0; j < nh; j++) + heap_heapify (k, val + j * k, ids + j * k); +} + +template +void HeapArray::reorder () +{ +#pragma omp parallel for + for (size_t j = 0; j < nh; j++) + heap_reorder (k, val + j * k, ids + j * k); +} + +template +void HeapArray::addn (size_t nj, const T *vin, TI j0, + size_t i0, int64_t ni) +{ + if (ni == -1) ni = nh; + assert (i0 >= 0 && i0 + ni <= nh); +#pragma omp parallel for + for (size_t i = i0; i < i0 + ni; i++) { + T * __restrict simi = get_val(i); + TI * __restrict idxi = get_ids (i); + const T *ip_line = vin + (i - i0) * nj; + + for (size_t j = 0; j < nj; j++) { + T ip = ip_line [j]; + if (C::cmp(simi[0], ip)) { + heap_swap_top (k, simi, idxi, ip, j + j0); + } + } + } +} + +template +void HeapArray::addn_with_ids ( + size_t nj, const T *vin, const TI *id_in, + int64_t id_stride, size_t i0, int64_t ni) +{ + if (id_in == nullptr) { + addn (nj, vin, 0, i0, ni); + return; + } + if (ni == -1) ni = nh; + assert (i0 >= 0 && i0 + ni <= nh); +#pragma omp parallel for + for (size_t i = i0; i < i0 + ni; i++) { + T * __restrict simi = get_val(i); + TI * __restrict idxi = get_ids (i); + const T *ip_line = vin + (i - i0) * nj; + const TI *id_line = id_in + (i - i0) * id_stride; + + for (size_t j = 0; j < nj; j++) { + T ip = ip_line [j]; + if (C::cmp(simi[0], ip)) { + heap_swap_top (k, simi, idxi, ip, id_line [j]); + } + } + } +} + +template +void HeapArray::per_line_extrema ( + T * out_val, + TI * out_ids) const +{ +#pragma omp parallel for + for (size_t j = 0; j < nh; j++) { + int64_t imin = -1; + typename C::T xval = C::Crev::neutral (); + const typename C::T * x_ = val + j * k; + for (size_t i = 0; i < k; i++) + if (C::cmp (x_[i], xval)) { + xval = x_[i]; + imin = i; + } + if (out_val) + out_val[j] = xval; + + if (out_ids) { + if (ids && imin != -1) + out_ids[j] = ids [j * k + imin]; + else + out_ids[j] = imin; + } + } +} + + + + +// explicit instanciations + +template struct HeapArray >; +template struct HeapArray >; +template struct HeapArray >; +template struct HeapArray >; + + +} // END namespace fasis diff --git a/core/src/index/thirdparty/faiss/utils/Heap.h b/core/src/index/thirdparty/faiss/utils/Heap.h new file mode 100644 index 0000000000..9962cbc112 --- /dev/null +++ b/core/src/index/thirdparty/faiss/utils/Heap.h @@ -0,0 +1,543 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +/* + * C++ support for heaps. The set of functions is tailored for + * efficient similarity search. + * + * There is no specific object for a heap, and the functions that + * operate on a signle heap are inlined, because heaps are often + * small. More complex functions are implemented in Heaps.cpp + * + */ + + +#ifndef FAISS_Heap_h +#define FAISS_Heap_h + +#include +#include +#include + +#include +#include +#include + +#include + + +namespace faiss { + +/******************************************************************* + * C object: uniform handling of min and max heap + *******************************************************************/ + +/** The C object gives the type T of the values in the heap, the type + * of the keys, TI and the comparison that is done: > for the minheap + * and < for the maxheap. The neutral value will always be dropped in + * favor of any other value in the heap. + */ + +template +struct CMax; + +// traits of minheaps = heaps where the minimum value is stored on top +// useful to find the *max* values of an array +template +struct CMin { + typedef T_ T; + typedef TI_ TI; + typedef CMax Crev; + inline static bool cmp (T a, T b) { + return a < b; + } + // value that will be popped first -> must be smaller than all others + // for int types this is not strictly the smallest val (-max - 1) + inline static T neutral () { + return -std::numeric_limits::max(); + } +}; + + +template +struct CMax { + typedef T_ T; + typedef TI_ TI; + typedef CMin Crev; + inline static bool cmp (T a, T b) { + return a > b; + } + inline static T neutral () { + return std::numeric_limits::max(); + } +}; + + +/******************************************************************* + * Basic heap ops: push and pop + *******************************************************************/ + +/** Pops the top element from the heap defined by bh_val[0..k-1] and + * bh_ids[0..k-1]. on output the element at k-1 is undefined. + */ +template inline +void heap_swap_top (size_t k, + typename C::T * bh_val, typename C::TI * bh_ids, + typename C::T val, typename C::TI ids) +{ + bh_val--; /* Use 1-based indexing for easier node->child translation */ + bh_ids--; + size_t i = 1, i1, i2; + while (1) { + i1 = i << 1; + i2 = i1 + 1; + if (i1 > k) + break; + if (i2 == k + 1 || C::cmp(bh_val[i1], bh_val[i2])) { + if (C::cmp(val, bh_val[i1])) + break; + bh_val[i] = bh_val[i1]; + bh_ids[i] = bh_ids[i1]; + i = i1; + } + else { + if (C::cmp(val, bh_val[i2])) + break; + bh_val[i] = bh_val[i2]; + bh_ids[i] = bh_ids[i2]; + i = i2; + } + } + bh_val[i] = val; + bh_ids[i] = ids; +} + + +/** Pops the top element from the heap defined by bh_val[0..k-1] and + * bh_ids[0..k-1]. on output the element at k-1 is undefined. + */ +template inline +void heap_pop (size_t k, typename C::T * bh_val, typename C::TI * bh_ids) +{ + bh_val--; /* Use 1-based indexing for easier node->child translation */ + bh_ids--; + typename C::T val = bh_val[k]; + size_t i = 1, i1, i2; + while (1) { + i1 = i << 1; + i2 = i1 + 1; + if (i1 > k) + break; + if (i2 == k + 1 || C::cmp(bh_val[i1], bh_val[i2])) { + if (C::cmp(val, bh_val[i1])) + break; + bh_val[i] = bh_val[i1]; + bh_ids[i] = bh_ids[i1]; + i = i1; + } + else { + if (C::cmp(val, bh_val[i2])) + break; + bh_val[i] = bh_val[i2]; + bh_ids[i] = bh_ids[i2]; + i = i2; + } + } + bh_val[i] = bh_val[k]; + bh_ids[i] = bh_ids[k]; +} + + + +/** Pushes the element (val, ids) into the heap bh_val[0..k-2] and + * bh_ids[0..k-2]. on output the element at k-1 is defined. + */ +template inline +void heap_push (size_t k, + typename C::T * bh_val, typename C::TI * bh_ids, + typename C::T val, typename C::TI ids) +{ + bh_val--; /* Use 1-based indexing for easier node->child translation */ + bh_ids--; + size_t i = k, i_father; + while (i > 1) { + i_father = i >> 1; + if (!C::cmp (val, bh_val[i_father])) /* the heap structure is ok */ + break; + bh_val[i] = bh_val[i_father]; + bh_ids[i] = bh_ids[i_father]; + i = i_father; + } + bh_val[i] = val; + bh_ids[i] = ids; +} + + + +/* Partial instanciation for heaps with TI = int64_t */ + +template inline +void minheap_swap_top (size_t k, T * bh_val, int64_t * bh_ids, T val, int64_t ids) +{ + heap_swap_top > (k, bh_val, bh_ids, val, ids); +} + + +template inline +void minheap_pop (size_t k, T * bh_val, int64_t * bh_ids) +{ + heap_pop > (k, bh_val, bh_ids); +} + + +template inline +void minheap_push (size_t k, T * bh_val, int64_t * bh_ids, T val, int64_t ids) +{ + heap_push > (k, bh_val, bh_ids, val, ids); +} + + +template inline +void maxheap_swap_top (size_t k, T * bh_val, int64_t * bh_ids, T val, int64_t ids) +{ + heap_swap_top > (k, bh_val, bh_ids, val, ids); +} + + +template inline +void maxheap_pop (size_t k, T * bh_val, int64_t * bh_ids) +{ + heap_pop > (k, bh_val, bh_ids); +} + + +template inline +void maxheap_push (size_t k, T * bh_val, int64_t * bh_ids, T val, int64_t ids) +{ + heap_push > (k, bh_val, bh_ids, val, ids); +} + + + +/******************************************************************* + * Heap initialization + *******************************************************************/ + +/* Initialization phase for the heap (with unconditionnal pushes). + * Store k0 elements in a heap containing up to k values. Note that + * (bh_val, bh_ids) can be the same as (x, ids) */ +template inline +void heap_heapify ( + size_t k, + typename C::T * bh_val, + typename C::TI * bh_ids, + const typename C::T * x = nullptr, + const typename C::TI * ids = nullptr, + size_t k0 = 0) +{ + if (k0 > 0) assert (x); + + if (ids) { + for (size_t i = 0; i < k0; i++) + heap_push (i+1, bh_val, bh_ids, x[i], ids[i]); + } else { + for (size_t i = 0; i < k0; i++) + heap_push (i+1, bh_val, bh_ids, x[i], i); + } + + for (size_t i = k0; i < k; i++) { + bh_val[i] = C::neutral(); + bh_ids[i] = -1; + } + +} + +template inline +void minheap_heapify ( + size_t k, T * bh_val, + int64_t * bh_ids, + const T * x = nullptr, + const int64_t * ids = nullptr, + size_t k0 = 0) +{ + heap_heapify< CMin > (k, bh_val, bh_ids, x, ids, k0); +} + + +template inline +void maxheap_heapify ( + size_t k, + T * bh_val, + int64_t * bh_ids, + const T * x = nullptr, + const int64_t * ids = nullptr, + size_t k0 = 0) +{ + heap_heapify< CMax > (k, bh_val, bh_ids, x, ids, k0); +} + + + +/******************************************************************* + * Add n elements to the heap + *******************************************************************/ + + +/* Add some elements to the heap */ +template inline +void heap_addn (size_t k, + typename C::T * bh_val, typename C::TI * bh_ids, + const typename C::T * x, + const typename C::TI * ids, + size_t n) +{ + size_t i; + if (ids) + for (i = 0; i < n; i++) { + if (C::cmp (bh_val[0], x[i])) { + heap_swap_top (k, bh_val, bh_ids, x[i], ids[i]); + } + } + else + for (i = 0; i < n; i++) { + if (C::cmp (bh_val[0], x[i])) { + heap_swap_top (k, bh_val, bh_ids, x[i], i); + } + } +} + + +/* Partial instanciation for heaps with TI = int64_t */ + +template inline +void minheap_addn (size_t k, T * bh_val, int64_t * bh_ids, + const T * x, const int64_t * ids, size_t n) +{ + heap_addn > (k, bh_val, bh_ids, x, ids, n); +} + +template inline +void maxheap_addn (size_t k, T * bh_val, int64_t * bh_ids, + const T * x, const int64_t * ids, size_t n) +{ + heap_addn > (k, bh_val, bh_ids, x, ids, n); +} + + + + + + +/******************************************************************* + * Heap finalization (reorder elements) + *******************************************************************/ + + +/* This function maps a binary heap into an sorted structure. + It returns the number */ +template inline +size_t heap_reorder (size_t k, typename C::T * bh_val, typename C::TI * bh_ids) +{ + size_t i, ii; + + for (i = 0, ii = 0; i < k; i++) { + /* top element should be put at the end of the list */ + typename C::T val = bh_val[0]; + typename C::TI id = bh_ids[0]; + + /* boundary case: we will over-ride this value if not a true element */ + heap_pop (k-i, bh_val, bh_ids); + bh_val[k-ii-1] = val; + bh_ids[k-ii-1] = id; + if (id != -1) ii++; + } + /* Count the number of elements which are effectively returned */ + size_t nel = ii; + + memmove (bh_val, bh_val+k-ii, ii * sizeof(*bh_val)); + memmove (bh_ids, bh_ids+k-ii, ii * sizeof(*bh_ids)); + + for (; ii < k; ii++) { + bh_val[ii] = C::neutral(); + bh_ids[ii] = -1; + } + return nel; +} + +template inline +size_t minheap_reorder (size_t k, T * bh_val, int64_t * bh_ids) +{ + return heap_reorder< CMin > (k, bh_val, bh_ids); +} + +template inline +size_t maxheap_reorder (size_t k, T * bh_val, int64_t * bh_ids) +{ + return heap_reorder< CMax > (k, bh_val, bh_ids); +} + + + + + +/******************************************************************* + * Operations on heap arrays + *******************************************************************/ + +/** a template structure for a set of [min|max]-heaps it is tailored + * so that the actual data of the heaps can just live in compact + * arrays. + */ +template +struct HeapArray { + typedef typename C::TI TI; + typedef typename C::T T; + + size_t nh; ///< number of heaps + size_t k; ///< allocated size per heap + TI * ids; ///< identifiers (size nh * k) + T * val; ///< values (distances or similarities), size nh * k + + /// Return the list of values for a heap + T * get_val (size_t key) { return val + key * k; } + + /// Correspponding identifiers + TI * get_ids (size_t key) { return ids + key * k; } + + /// prepare all the heaps before adding + void heapify (); + + /** add nj elements to heaps i0:i0+ni, with sequential ids + * + * @param nj nb of elements to add to each heap + * @param vin elements to add, size ni * nj + * @param j0 add this to the ids that are added + * @param i0 first heap to update + * @param ni nb of elements to update (-1 = use nh) + */ + void addn (size_t nj, const T *vin, TI j0 = 0, + size_t i0 = 0, int64_t ni = -1); + + /** same as addn + * + * @param id_in ids of the elements to add, size ni * nj + * @param id_stride stride for id_in + */ + void addn_with_ids ( + size_t nj, const T *vin, const TI *id_in = nullptr, + int64_t id_stride = 0, size_t i0 = 0, int64_t ni = -1); + + /// reorder all the heaps + void reorder (); + + /** this is not really a heap function. It just finds the per-line + * extrema of each line of array D + * @param vals_out extreme value of each line (size nh, or NULL) + * @param idx_out index of extreme value (size nh or NULL) + */ + void per_line_extrema (T *vals_out, TI *idx_out) const; + +}; + + +/* Define useful heaps */ +typedef HeapArray > float_minheap_array_t; +typedef HeapArray > int_minheap_array_t; + +typedef HeapArray > float_maxheap_array_t; +typedef HeapArray > int_maxheap_array_t; + +// The heap templates are instanciated explicitly in Heap.cpp + + + + + + + + + + + + + + + + + + + +/********************************************************************* + * Indirect heaps: instead of having + * + * node i = (bh_ids[i], bh_val[i]), + * + * in indirect heaps, + * + * node i = (bh_ids[i], bh_val[bh_ids[i]]), + * + *********************************************************************/ + + +template +inline +void indirect_heap_pop ( + size_t k, + const typename C::T * bh_val, + typename C::TI * bh_ids) +{ + bh_ids--; /* Use 1-based indexing for easier node->child translation */ + typename C::T val = bh_val[bh_ids[k]]; + size_t i = 1; + while (1) { + size_t i1 = i << 1; + size_t i2 = i1 + 1; + if (i1 > k) + break; + typename C::TI id1 = bh_ids[i1], id2 = bh_ids[i2]; + if (i2 == k + 1 || C::cmp(bh_val[id1], bh_val[id2])) { + if (C::cmp(val, bh_val[id1])) + break; + bh_ids[i] = id1; + i = i1; + } else { + if (C::cmp(val, bh_val[id2])) + break; + bh_ids[i] = id2; + i = i2; + } + } + bh_ids[i] = bh_ids[k]; +} + + + +template +inline +void indirect_heap_push (size_t k, + const typename C::T * bh_val, typename C::TI * bh_ids, + typename C::TI id) +{ + bh_ids--; /* Use 1-based indexing for easier node->child translation */ + typename C::T val = bh_val[id]; + size_t i = k; + while (i > 1) { + size_t i_father = i >> 1; + if (!C::cmp (val, bh_val[bh_ids[i_father]])) + break; + bh_ids[i] = bh_ids[i_father]; + i = i_father; + } + bh_ids[i] = id; +} + + +} // namespace faiss + +#endif /* FAISS_Heap_h */ diff --git a/core/src/index/thirdparty/faiss/utils/WorkerThread.cpp b/core/src/index/thirdparty/faiss/utils/WorkerThread.cpp new file mode 100644 index 0000000000..83b5c97e47 --- /dev/null +++ b/core/src/index/thirdparty/faiss/utils/WorkerThread.cpp @@ -0,0 +1,126 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#include +#include +#include + +namespace faiss { + +namespace { + +// Captures any exceptions thrown by the lambda and returns them via the promise +void runCallback(std::function& fn, + std::promise& promise) { + try { + fn(); + promise.set_value(true); + } catch (...) { + promise.set_exception(std::current_exception()); + } +} + +} // namespace + +WorkerThread::WorkerThread() : + wantStop_(false) { + startThread(); + + // Make sure that the thread has started before continuing + add([](){}).get(); +} + +WorkerThread::~WorkerThread() { + stop(); + waitForThreadExit(); +} + +void +WorkerThread::startThread() { + thread_ = std::thread([this](){ threadMain(); }); +} + +void +WorkerThread::stop() { + std::lock_guard guard(mutex_); + + wantStop_ = true; + monitor_.notify_one(); +} + +std::future +WorkerThread::add(std::function f) { + std::lock_guard guard(mutex_); + + if (wantStop_) { + // The timer thread has been stopped, or we want to stop; we can't + // schedule anything else + std::promise p; + auto fut = p.get_future(); + + // did not execute + p.set_value(false); + return fut; + } + + auto pr = std::promise(); + auto fut = pr.get_future(); + + queue_.emplace_back(std::make_pair(std::move(f), std::move(pr))); + + // Wake up our thread + monitor_.notify_one(); + return fut; +} + +void +WorkerThread::threadMain() { + threadLoop(); + + // Call all pending tasks + FAISS_ASSERT(wantStop_); + + // flush all pending operations + for (auto& f : queue_) { + runCallback(f.first, f.second); + } +} + +void +WorkerThread::threadLoop() { + while (true) { + std::pair, std::promise> data; + + { + std::unique_lock lock(mutex_); + + while (!wantStop_ && queue_.empty()) { + monitor_.wait(lock); + } + + if (wantStop_) { + return; + } + + data = std::move(queue_.front()); + queue_.pop_front(); + } + + runCallback(data.first, data.second); + } +} + +void +WorkerThread::waitForThreadExit() { + try { + thread_.join(); + } catch (...) { + } +} + +} // namespace diff --git a/core/src/index/thirdparty/faiss/utils/WorkerThread.h b/core/src/index/thirdparty/faiss/utils/WorkerThread.h new file mode 100644 index 0000000000..7ab21e9f90 --- /dev/null +++ b/core/src/index/thirdparty/faiss/utils/WorkerThread.h @@ -0,0 +1,61 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#pragma once + +#include +#include +#include +#include + +namespace faiss { + +class WorkerThread { + public: + WorkerThread(); + + /// Stops and waits for the worker thread to exit, flushing all + /// pending lambdas + ~WorkerThread(); + + /// Request that the worker thread stop itself + void stop(); + + /// Blocking waits in the current thread for the worker thread to + /// stop + void waitForThreadExit(); + + /// Adds a lambda to run on the worker thread; returns a future that + /// can be used to block on its completion. + /// Future status is `true` if the lambda was run in the worker + /// thread; `false` if it was not run, because the worker thread is + /// exiting or has exited. + std::future add(std::function f); + + private: + void startThread(); + void threadMain(); + void threadLoop(); + + /// Thread that all queued lambdas are run on + std::thread thread_; + + /// Mutex for the queue and exit status + std::mutex mutex_; + + /// Monitor for the exit status and the queue + std::condition_variable monitor_; + + /// Whether or not we want the thread to exit + bool wantStop_; + + /// Queue of pending lambdas to call + std::deque, std::promise>> queue_; +}; + +} // namespace diff --git a/core/src/index/thirdparty/faiss/utils/distances.cpp b/core/src/index/thirdparty/faiss/utils/distances.cpp new file mode 100644 index 0000000000..e97e873614 --- /dev/null +++ b/core/src/index/thirdparty/faiss/utils/distances.cpp @@ -0,0 +1,1073 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +#include + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + + +#ifndef FINTEGER +#define FINTEGER long +#endif + + +extern "C" { + +/* declare BLAS functions, see http://www.netlib.org/clapack/cblas/ */ + +int sgemm_ (const char *transa, const char *transb, FINTEGER *m, FINTEGER * + n, FINTEGER *k, const float *alpha, const float *a, + FINTEGER *lda, const float *b, FINTEGER * + ldb, float *beta, float *c, FINTEGER *ldc); + +/* Lapack functions, see http://www.netlib.org/clapack/old/single/sgeqrf.c */ + +int sgeqrf_ (FINTEGER *m, FINTEGER *n, float *a, FINTEGER *lda, + float *tau, float *work, FINTEGER *lwork, FINTEGER *info); + +int sgemv_(const char *trans, FINTEGER *m, FINTEGER *n, float *alpha, + const float *a, FINTEGER *lda, const float *x, FINTEGER *incx, + float *beta, float *y, FINTEGER *incy); + +} + + +namespace faiss { + + +/*************************************************************************** + * Matrix/vector ops + ***************************************************************************/ + + + +/* Compute the inner product between a vector x and + a set of ny vectors y. + These functions are not intended to replace BLAS matrix-matrix, as they + would be significantly less efficient in this case. */ +void fvec_inner_products_ny (float * ip, + const float * x, + const float * y, + size_t d, size_t ny) +{ + // Not sure which one is fastest +#if 0 + { + FINTEGER di = d; + FINTEGER nyi = ny; + float one = 1.0, zero = 0.0; + FINTEGER onei = 1; + sgemv_ ("T", &di, &nyi, &one, y, &di, x, &onei, &zero, ip, &onei); + } +#endif + for (size_t i = 0; i < ny; i++) { + ip[i] = fvec_inner_product (x, y, d); + y += d; + } +} + + + + + +/* Compute the L2 norm of a set of nx vectors */ +void fvec_norms_L2 (float * __restrict nr, + const float * __restrict x, + size_t d, size_t nx) +{ + +#pragma omp parallel for + for (size_t i = 0; i < nx; i++) { + nr[i] = sqrtf (fvec_norm_L2sqr (x + i * d, d)); + } +} + +void fvec_norms_L2sqr (float * __restrict nr, + const float * __restrict x, + size_t d, size_t nx) +{ +#pragma omp parallel for + for (size_t i = 0; i < nx; i++) + nr[i] = fvec_norm_L2sqr (x + i * d, d); +} + + + +void fvec_renorm_L2 (size_t d, size_t nx, float * __restrict x) +{ +#pragma omp parallel for + for (size_t i = 0; i < nx; i++) { + float * __restrict xi = x + i * d; + + float nr = fvec_norm_L2sqr (xi, d); + + if (nr > 0) { + size_t j; + const float inv_nr = 1.0 / sqrtf (nr); + for (j = 0; j < d; j++) + xi[j] *= inv_nr; + } + } +} + + + + + + + +/*************************************************************************** + * KNN functions + ***************************************************************************/ + +int parallel_policy_threshold = 65535; + +/* Find the nearest neighbors for nx queries in a set of ny vectors */ +static void knn_inner_product_sse (const float * x, + const float * y, + size_t d, size_t nx, size_t ny, + float_minheap_array_t * res, + ConcurrentBitsetPtr bitset = nullptr) +{ + size_t k = res->k; + size_t thread_max_num = omp_get_max_threads(); + + if (ny > parallel_policy_threshold || (nx < thread_max_num / 2 && ny >= thread_max_num * 32)) { + size_t block_x = std::min( + get_L3_Size() / (d * sizeof(float) + thread_max_num * k * (sizeof(float) + sizeof(int64_t))), + nx); + + size_t all_heap_size = block_x * k * thread_max_num; + float *value = new float[all_heap_size]; + int64_t *labels = new int64_t[all_heap_size]; + + for (size_t x_from = 0, x_to; x_from < nx; x_from = x_to) { + x_to = std::min(nx, x_from + block_x); + int size = x_to - x_from; + int thread_heap_size = size * k; + + // init heap + for (size_t i = 0; i < all_heap_size; i++) { + value[i] = -1.0 / 0.0; + labels[i] = -1; + } + +#pragma omp parallel for schedule(static) + for (size_t j = 0; j < ny; j++) { + if(!bitset || !bitset->test(j)) { + size_t thread_no = omp_get_thread_num(); + const float *y_j = y + j * d; + const float *x_i = x + x_from * d; + for (size_t i = 0; i < size; i++) { + float disij = fvec_inner_product (x_i, y_j, d); + float * val_ = value + thread_no * thread_heap_size + i * k; + int64_t * ids_ = labels + thread_no * thread_heap_size + i * k; + if (disij > val_[0]) { + minheap_swap_top (k, val_, ids_, disij, j); + } + x_i += d; + } + } + } + + // merge heap + for (size_t t = 1; t < thread_max_num; t++) { + for (size_t i = 0; i < size; i++) { + float * __restrict value_x = value + i * k; + int64_t * __restrict labels_x = labels + i * k; + float *value_x_t = value_x + t * thread_heap_size; + int64_t *labels_x_t = labels_x + t * thread_heap_size; + for (size_t j = 0; j < k; j++) { + if (value_x_t[j] > value_x[0]) { + minheap_swap_top (k, value_x, labels_x, value_x_t[j], labels_x_t[j]); + } + } + } + } + + // sort + for (size_t i = 0; i < size; i++) { + float * value_x = value + i * k; + int64_t * labels_x = labels + i * k; + minheap_reorder (k, value_x, labels_x); + } + + // copy result + memcpy(res->val + x_from * k, value, thread_heap_size * sizeof(float)); + memcpy(res->ids + x_from * k, labels, thread_heap_size * sizeof(int64_t)); + } + delete[] value; + delete[] labels; + + } else { + float * value = res->val; + int64_t * labels = res->ids; + +#pragma omp parallel for + for (size_t i = 0; i < nx; i++) { + const float *x_i = x + i * d; + const float *y_j = y; + + float * __restrict val_ = value + i * k; + int64_t * __restrict ids_ = labels + i * k; + + for (size_t j = 0; j < k; j++) { + val_[j] = -1.0 / 0.0; + ids_[j] = -1; + } + + for (size_t j = 0; j < ny; j++) { + if (!bitset || !bitset->test(j)) { + float disij = fvec_inner_product (x_i, y_j, d); + if (disij > val_[0]) { + minheap_swap_top (k, val_, ids_, disij, j); + } + } + y_j += d; + } + + minheap_reorder (k, val_, ids_); + } + } +} + +static void knn_L2sqr_sse ( + const float * x, + const float * y, + size_t d, size_t nx, size_t ny, + float_maxheap_array_t * res, + ConcurrentBitsetPtr bitset = nullptr) +{ + size_t k = res->k; + size_t thread_max_num = omp_get_max_threads(); + + if (ny > parallel_policy_threshold || (nx < thread_max_num / 2 && ny >= thread_max_num * 32)) { + size_t block_x = std::min( + get_L3_Size() / (d * sizeof(float) + thread_max_num * k * (sizeof(float) + sizeof(int64_t))), + nx); + + size_t all_heap_size = block_x * k * thread_max_num; + float *value = new float[all_heap_size]; + int64_t *labels = new int64_t[all_heap_size]; + + for (size_t x_from = 0, x_to; x_from < nx; x_from = x_to) { + x_to = std::min(nx, x_from + block_x); + int size = x_to - x_from; + int thread_heap_size = size * k; + + // init heap + for (size_t i = 0; i < all_heap_size; i++) { + value[i] = 1.0 / 0.0; + labels[i] = -1; + } + +#pragma omp parallel for schedule(static) + for (size_t j = 0; j < ny; j++) { + if(!bitset || !bitset->test(j)) { + size_t thread_no = omp_get_thread_num(); + const float *y_j = y + j * d; + const float *x_i = x + x_from * d; + for (size_t i = 0; i < size; i++) { + float disij = fvec_L2sqr (x_i, y_j, d); + float * val_ = value + thread_no * thread_heap_size + i * k; + int64_t * ids_ = labels + thread_no * thread_heap_size + i * k; + if (disij < val_[0]) { + maxheap_swap_top (k, val_, ids_, disij, j); + } + x_i += d; + } + } + } + + // merge heap + for (size_t t = 1; t < thread_max_num; t++) { + for (size_t i = 0; i < size; i++) { + float * __restrict value_x = value + i * k; + int64_t * __restrict labels_x = labels + i * k; + float *value_x_t = value_x + t * thread_heap_size; + int64_t *labels_x_t = labels_x + t * thread_heap_size; + for (size_t j = 0; j < k; j++) { + if (value_x_t[j] < value_x[0]) { + maxheap_swap_top (k, value_x, labels_x, value_x_t[j], labels_x_t[j]); + } + } + } + } + + // sort + for (size_t i = 0; i < size; i++) { + float * value_x = value + i * k; + int64_t * labels_x = labels + i * k; + maxheap_reorder (k, value_x, labels_x); + } + + // copy result + memcpy(res->val + x_from * k, value, thread_heap_size * sizeof(float)); + memcpy(res->ids + x_from * k, labels, thread_heap_size * sizeof(int64_t)); + } + delete[] value; + delete[] labels; + + } else { + + float * value = res->val; + int64_t * labels = res->ids; + +#pragma omp parallel for + for (size_t i = 0; i < nx; i++) { + const float *x_i = x + i * d; + const float *y_j = y; + + float * __restrict val_ = value + i * k; + int64_t * __restrict ids_ = labels + i * k; + + for (size_t j = 0; j < k; j++) { + val_[j] = 1.0 / 0.0; + ids_[j] = -1; + } + + for (size_t j = 0; j < ny; j++) { + if (!bitset || !bitset->test(j)) { + float disij = fvec_L2sqr (x_i, y_j, d); + if (disij < val_[0]) { + maxheap_swap_top (k, val_, ids_, disij, j); + } + } + y_j += d; + } + + maxheap_reorder (k, val_, ids_); + } + } +} + +/** Find the nearest neighbors for nx queries in a set of ny vectors */ +static void knn_inner_product_blas ( + const float * x, + const float * y, + size_t d, size_t nx, size_t ny, + float_minheap_array_t * res, + ConcurrentBitsetPtr bitset = nullptr) +{ + res->heapify (); + + // BLAS does not like empty matrices + if (nx == 0 || ny == 0) return; + + size_t k = res->k; + + /* block sizes */ + const size_t bs_x = 4096, bs_y = 1024; + // const size_t bs_x = 16, bs_y = 16; + float *ip_block = new float[bs_x * bs_y]; + ScopeDeleter del1(ip_block);; + + for (size_t i0 = 0; i0 < nx; i0 += bs_x) { + size_t i1 = i0 + bs_x; + if(i1 > nx) i1 = nx; + + for (size_t j0 = 0; j0 < ny; j0 += bs_y) { + size_t j1 = j0 + bs_y; + if (j1 > ny) j1 = ny; + /* compute the actual dot products */ + { + float one = 1, zero = 0; + FINTEGER nyi = j1 - j0, nxi = i1 - i0, di = d; + sgemm_ ("Transpose", "Not transpose", &nyi, &nxi, &di, &one, + y + j0 * d, &di, + x + i0 * d, &di, &zero, + ip_block, &nyi); + } + + /* collect maxima */ +#pragma omp parallel for + for(size_t i = i0; i < i1; i++){ + float * __restrict simi = res->get_val(i); + int64_t * __restrict idxi = res->get_ids (i); + const float *ip_line = ip_block + (i - i0) * (j1 - j0); + + for(size_t j = j0; j < j1; j++){ + if(!bitset || !bitset->test(j)){ + float dis = *ip_line; + + if(dis > simi[0]){ + minheap_swap_top(k, simi, idxi, dis, j); + } + } + ip_line++; + } + } + } + InterruptCallback::check (); + } + res->reorder (); +} + +// distance correction is an operator that can be applied to transform +// the distances +template +static void knn_L2sqr_blas (const float * x, + const float * y, + size_t d, size_t nx, size_t ny, + float_maxheap_array_t * res, + const DistanceCorrection &corr, + ConcurrentBitsetPtr bitset = nullptr) +{ + res->heapify (); + + // BLAS does not like empty matrices + if (nx == 0 || ny == 0) return; + + size_t k = res->k; + + /* block sizes */ + const size_t bs_x = 4096, bs_y = 1024; + // const size_t bs_x = 16, bs_y = 16; + float *ip_block = new float[bs_x * bs_y]; + float *x_norms = new float[nx]; + float *y_norms = new float[ny]; + ScopeDeleter del1(ip_block), del3(x_norms), del2(y_norms); + + fvec_norms_L2sqr (x_norms, x, d, nx); + fvec_norms_L2sqr (y_norms, y, d, ny); + + + for (size_t i0 = 0; i0 < nx; i0 += bs_x) { + size_t i1 = i0 + bs_x; + if(i1 > nx) i1 = nx; + + for (size_t j0 = 0; j0 < ny; j0 += bs_y) { + size_t j1 = j0 + bs_y; + if (j1 > ny) j1 = ny; + /* compute the actual dot products */ + { + float one = 1, zero = 0; + FINTEGER nyi = j1 - j0, nxi = i1 - i0, di = d; + sgemm_ ("Transpose", "Not transpose", &nyi, &nxi, &di, &one, + y + j0 * d, &di, + x + i0 * d, &di, &zero, + ip_block, &nyi); + } + + /* collect minima */ +#pragma omp parallel for + for (size_t i = i0; i < i1; i++) { + float * __restrict simi = res->get_val(i); + int64_t * __restrict idxi = res->get_ids (i); + const float *ip_line = ip_block + (i - i0) * (j1 - j0); + + for (size_t j = j0; j < j1; j++) { + if(!bitset || !bitset->test(j)){ + float ip = *ip_line; + float dis = x_norms[i] + y_norms[j] - 2 * ip; + + // negative values can occur for identical vectors + // due to roundoff errors + if (dis < 0) dis = 0; + + dis = corr (dis, i, j); + + if (dis < simi[0]) { + maxheap_swap_top (k, simi, idxi, dis, j); + } + } + ip_line++; + } + } + } + InterruptCallback::check (); + } + res->reorder (); + +} + +template +static void knn_jaccard_blas (const float * x, + const float * y, + size_t d, size_t nx, size_t ny, + float_maxheap_array_t * res, + const DistanceCorrection &corr, + ConcurrentBitsetPtr bitset = nullptr) +{ + res->heapify (); + + // BLAS does not like empty matrices + if (nx == 0 || ny == 0) return; + + size_t k = res->k; + + /* block sizes */ + const size_t bs_x = 4096, bs_y = 1024; + // const size_t bs_x = 16, bs_y = 16; + float *ip_block = new float[bs_x * bs_y]; + float *x_norms = new float[nx]; + float *y_norms = new float[ny]; + ScopeDeleter del1(ip_block), del3(x_norms), del2(y_norms); + + fvec_norms_L2sqr (x_norms, x, d, nx); + fvec_norms_L2sqr (y_norms, y, d, ny); + + + for (size_t i0 = 0; i0 < nx; i0 += bs_x) { + size_t i1 = i0 + bs_x; + if(i1 > nx) i1 = nx; + + for (size_t j0 = 0; j0 < ny; j0 += bs_y) { + size_t j1 = j0 + bs_y; + if (j1 > ny) j1 = ny; + /* compute the actual dot products */ + { + float one = 1, zero = 0; + FINTEGER nyi = j1 - j0, nxi = i1 - i0, di = d; + sgemm_ ("Transpose", "Not transpose", &nyi, &nxi, &di, &one, + y + j0 * d, &di, + x + i0 * d, &di, &zero, + ip_block, &nyi); + } + + /* collect minima */ +#pragma omp parallel for + for (size_t i = i0; i < i1; i++) { + float * __restrict simi = res->get_val(i); + int64_t * __restrict idxi = res->get_ids (i); + const float *ip_line = ip_block + (i - i0) * (j1 - j0); + + for (size_t j = j0; j < j1; j++) { + if(!bitset || !bitset->test(j)){ + float ip = *ip_line; + float dis = 1.0 - ip / (x_norms[i] + y_norms[j] - ip); + + // negative values can occur for identical vectors + // due to roundoff errors + if (dis < 0) dis = 0; + + dis = corr (dis, i, j); + + if (dis < simi[0]) { + maxheap_swap_top (k, simi, idxi, dis, j); + } + } + ip_line++; + } + } + } + InterruptCallback::check (); + } + res->reorder (); +} + + + + + + + +/******************************************************* + * KNN driver functions + *******************************************************/ + +int distance_compute_blas_threshold = 20; + +void knn_inner_product (const float * x, + const float * y, + size_t d, size_t nx, size_t ny, + float_minheap_array_t * res, + ConcurrentBitsetPtr bitset) +{ + if (nx < distance_compute_blas_threshold) { + knn_inner_product_sse (x, y, d, nx, ny, res, bitset); + } else { + knn_inner_product_blas (x, y, d, nx, ny, res, bitset); + } +} + + + +struct NopDistanceCorrection { + float operator()(float dis, size_t /*qno*/, size_t /*bno*/) const { + return dis; + } +}; + +void knn_L2sqr (const float * x, + const float * y, + size_t d, size_t nx, size_t ny, + float_maxheap_array_t * res, + ConcurrentBitsetPtr bitset) +{ + if (nx < distance_compute_blas_threshold) { + knn_L2sqr_sse (x, y, d, nx, ny, res, bitset); + } else { + NopDistanceCorrection nop; + knn_L2sqr_blas (x, y, d, nx, ny, res, nop, bitset); + } +} + +void knn_jaccard (const float * x, + const float * y, + size_t d, size_t nx, size_t ny, + float_maxheap_array_t * res, + ConcurrentBitsetPtr bitset) +{ + if (d % 4 == 0 && nx < distance_compute_blas_threshold) { +// knn_jaccard_sse (x, y, d, nx, ny, res); + printf("jaccard sse not implemented!\n"); + } else { + NopDistanceCorrection nop; + knn_jaccard_blas (x, y, d, nx, ny, res, nop, bitset); + } +} + +struct BaseShiftDistanceCorrection { + const float *base_shift; + float operator()(float dis, size_t /*qno*/, size_t bno) const { + return dis - base_shift[bno]; + } +}; + +void knn_L2sqr_base_shift ( + const float * x, + const float * y, + size_t d, size_t nx, size_t ny, + float_maxheap_array_t * res, + const float *base_shift) +{ + BaseShiftDistanceCorrection corr = {base_shift}; + knn_L2sqr_blas (x, y, d, nx, ny, res, corr); +} + + + +/*************************************************************************** + * compute a subset of distances + ***************************************************************************/ + +/* compute the inner product between x and a subset y of ny vectors, + whose indices are given by idy. */ +void fvec_inner_products_by_idx (float * __restrict ip, + const float * x, + const float * y, + const int64_t * __restrict ids, /* for y vecs */ + size_t d, size_t nx, size_t ny) +{ +#pragma omp parallel for + for (size_t j = 0; j < nx; j++) { + const int64_t * __restrict idsj = ids + j * ny; + const float * xj = x + j * d; + float * __restrict ipj = ip + j * ny; + for (size_t i = 0; i < ny; i++) { + if (idsj[i] < 0) + continue; + ipj[i] = fvec_inner_product (xj, y + d * idsj[i], d); + } + } +} + + + +/* compute the inner product between x and a subset y of ny vectors, + whose indices are given by idy. */ +void fvec_L2sqr_by_idx (float * __restrict dis, + const float * x, + const float * y, + const int64_t * __restrict ids, /* ids of y vecs */ + size_t d, size_t nx, size_t ny) +{ +#pragma omp parallel for + for (size_t j = 0; j < nx; j++) { + const int64_t * __restrict idsj = ids + j * ny; + const float * xj = x + j * d; + float * __restrict disj = dis + j * ny; + for (size_t i = 0; i < ny; i++) { + if (idsj[i] < 0) + continue; + disj[i] = fvec_L2sqr (xj, y + d * idsj[i], d); + } + } +} + +void pairwise_indexed_L2sqr ( + size_t d, size_t n, + const float * x, const int64_t *ix, + const float * y, const int64_t *iy, + float *dis) +{ +#pragma omp parallel for + for (size_t j = 0; j < n; j++) { + if (ix[j] >= 0 && iy[j] >= 0) { + dis[j] = fvec_L2sqr (x + d * ix[j], y + d * iy[j], d); + } + } +} + +void pairwise_indexed_inner_product ( + size_t d, size_t n, + const float * x, const int64_t *ix, + const float * y, const int64_t *iy, + float *dis) +{ +#pragma omp parallel for + for (size_t j = 0; j < n; j++) { + if (ix[j] >= 0 && iy[j] >= 0) { + dis[j] = fvec_inner_product (x + d * ix[j], y + d * iy[j], d); + } + } +} + + +/* Find the nearest neighbors for nx queries in a set of ny vectors + indexed by ids. May be useful for re-ranking a pre-selected vector list */ +void knn_inner_products_by_idx (const float * x, + const float * y, + const int64_t * ids, + size_t d, size_t nx, size_t ny, + float_minheap_array_t * res) +{ + size_t k = res->k; + +#pragma omp parallel for + for (size_t i = 0; i < nx; i++) { + const float * x_ = x + i * d; + const int64_t * idsi = ids + i * ny; + size_t j; + float * __restrict simi = res->get_val(i); + int64_t * __restrict idxi = res->get_ids (i); + minheap_heapify (k, simi, idxi); + + for (j = 0; j < ny; j++) { + if (idsi[j] < 0) break; + float ip = fvec_inner_product (x_, y + d * idsi[j], d); + + if (ip > simi[0]) { + minheap_swap_top (k, simi, idxi, ip, idsi[j]); + } + } + minheap_reorder (k, simi, idxi); + } + +} + +void knn_L2sqr_by_idx (const float * x, + const float * y, + const int64_t * __restrict ids, + size_t d, size_t nx, size_t ny, + float_maxheap_array_t * res) +{ + size_t k = res->k; + +#pragma omp parallel for + for (size_t i = 0; i < nx; i++) { + const float * x_ = x + i * d; + const int64_t * __restrict idsi = ids + i * ny; + float * __restrict simi = res->get_val(i); + int64_t * __restrict idxi = res->get_ids (i); + maxheap_heapify (res->k, simi, idxi); + for (size_t j = 0; j < ny; j++) { + float disij = fvec_L2sqr (x_, y + d * idsi[j], d); + + if (disij < simi[0]) { + maxheap_swap_top (k, simi, idxi, disij, idsi[j]); + } + } + maxheap_reorder (res->k, simi, idxi); + } + +} + + + + + +/*************************************************************************** + * Range search + ***************************************************************************/ + +/** Find the nearest neighbors for nx queries in a set of ny vectors + * compute_l2 = compute pairwise squared L2 distance rather than inner prod + */ +template +static void range_search_blas ( + const float * x, + const float * y, + size_t d, size_t nx, size_t ny, + float radius, + RangeSearchResult *result) +{ + + // BLAS does not like empty matrices + if (nx == 0 || ny == 0) return; + + /* block sizes */ + const size_t bs_x = 4096, bs_y = 1024; + // const size_t bs_x = 16, bs_y = 16; + float *ip_block = new float[bs_x * bs_y]; + ScopeDeleter del0(ip_block); + + float *x_norms = nullptr, *y_norms = nullptr; + ScopeDeleter del1, del2; + if (compute_l2) { + x_norms = new float[nx]; + del1.set (x_norms); + fvec_norms_L2sqr (x_norms, x, d, nx); + + y_norms = new float[ny]; + del2.set (y_norms); + fvec_norms_L2sqr (y_norms, y, d, ny); + } + + std::vector partial_results; + + for (size_t j0 = 0; j0 < ny; j0 += bs_y) { + size_t j1 = j0 + bs_y; + if (j1 > ny) j1 = ny; + RangeSearchPartialResult * pres = new RangeSearchPartialResult (result); + partial_results.push_back (pres); + + for (size_t i0 = 0; i0 < nx; i0 += bs_x) { + size_t i1 = i0 + bs_x; + if(i1 > nx) i1 = nx; + + /* compute the actual dot products */ + { + float one = 1, zero = 0; + FINTEGER nyi = j1 - j0, nxi = i1 - i0, di = d; + sgemm_ ("Transpose", "Not transpose", &nyi, &nxi, &di, &one, + y + j0 * d, &di, + x + i0 * d, &di, &zero, + ip_block, &nyi); + } + + + for (size_t i = i0; i < i1; i++) { + const float *ip_line = ip_block + (i - i0) * (j1 - j0); + + RangeQueryResult & qres = pres->new_result (i); + + for (size_t j = j0; j < j1; j++) { + float ip = *ip_line++; + if (compute_l2) { + float dis = x_norms[i] + y_norms[j] - 2 * ip; + if (dis < radius) { + qres.add (dis, j); + } + } else { + if (ip > radius) { + qres.add (ip, j); + } + } + } + } + } + InterruptCallback::check (); + } + + RangeSearchPartialResult::merge (partial_results); +} + + +template +static void range_search_sse (const float * x, + const float * y, + size_t d, size_t nx, size_t ny, + float radius, + RangeSearchResult *res) +{ + +#pragma omp parallel + { + RangeSearchPartialResult pres (res); + +#pragma omp for + for (size_t i = 0; i < nx; i++) { + const float * x_ = x + i * d; + const float * y_ = y; + size_t j; + + RangeQueryResult & qres = pres.new_result (i); + + for (j = 0; j < ny; j++) { + if (compute_l2) { + float disij = fvec_L2sqr (x_, y_, d); + if (disij < radius) { + qres.add (disij, j); + } + } else { + float ip = fvec_inner_product (x_, y_, d); + if (ip > radius) { + qres.add (ip, j); + } + } + y_ += d; + } + + } + pres.finalize (); + } + + // check just at the end because the use case is typically just + // when the nb of queries is low. + InterruptCallback::check(); +} + + + + + +void range_search_L2sqr ( + const float * x, + const float * y, + size_t d, size_t nx, size_t ny, + float radius, + RangeSearchResult *res) +{ + + if (nx < distance_compute_blas_threshold) { + range_search_sse (x, y, d, nx, ny, radius, res); + } else { + range_search_blas (x, y, d, nx, ny, radius, res); + } +} + +void range_search_inner_product ( + const float * x, + const float * y, + size_t d, size_t nx, size_t ny, + float radius, + RangeSearchResult *res) +{ + + if (nx < distance_compute_blas_threshold) { + range_search_sse (x, y, d, nx, ny, radius, res); + } else { + range_search_blas (x, y, d, nx, ny, radius, res); + } +} + + +void pairwise_L2sqr (int64_t d, + int64_t nq, const float *xq, + int64_t nb, const float *xb, + float *dis, + int64_t ldq, int64_t ldb, int64_t ldd) +{ + if (nq == 0 || nb == 0) return; + if (ldq == -1) ldq = d; + if (ldb == -1) ldb = d; + if (ldd == -1) ldd = nb; + + // store in beginning of distance matrix to avoid malloc + float *b_norms = dis; + +#pragma omp parallel for + for (int64_t i = 0; i < nb; i++) + b_norms [i] = fvec_norm_L2sqr (xb + i * ldb, d); + +#pragma omp parallel for + for (int64_t i = 1; i < nq; i++) { + float q_norm = fvec_norm_L2sqr (xq + i * ldq, d); + for (int64_t j = 0; j < nb; j++) + dis[i * ldd + j] = q_norm + b_norms [j]; + } + + { + float q_norm = fvec_norm_L2sqr (xq, d); + for (int64_t j = 0; j < nb; j++) + dis[j] += q_norm; + } + + { + FINTEGER nbi = nb, nqi = nq, di = d, ldqi = ldq, ldbi = ldb, lddi = ldd; + float one = 1.0, minus_2 = -2.0; + + sgemm_ ("Transposed", "Not transposed", + &nbi, &nqi, &di, + &minus_2, + xb, &ldbi, + xq, &ldqi, + &one, dis, &lddi); + } + +} + +void elkan_L2_sse ( + const float * x, + const float * y, + size_t d, size_t nx, size_t ny, + int64_t *ids, float *val) { + + if (nx == 0 || ny == 0) { + return; + } + + const size_t bs_y = 1024; + float *data = (float *) malloc((bs_y * (bs_y - 1) / 2) * sizeof (float)); + + for (size_t j0 = 0; j0 < ny; j0 += bs_y) { + BuilderSuspend::check_wait(); + + size_t j1 = j0 + bs_y; + if (j1 > ny) j1 = ny; + + auto Y = [&](size_t i, size_t j) -> float& { + assert(i != j); + i -= j0, j -= j0; + return (i > j) ? data[j + i * (i - 1) / 2] : data[i + j * (j - 1) / 2]; + }; + +#pragma omp parallel for + for (size_t i = j0 + 1; i < j1; i++) { + const float *y_i = y + i * d; + for (size_t j = j0; j < i; j++) { + const float *y_j = y + j * d; + Y(i, j) = sqrt(fvec_L2sqr(y_i, y_j, d)); + } + } + +#pragma omp parallel for + for (size_t i = 0; i < nx; i++) { + const float *x_i = x + i * d; + + int64_t ids_i = j0; + float val_i = sqrt(fvec_L2sqr(x_i, y + j0 * d, d)); + float val_i_2 = val_i * 2; + for (size_t j = j0 + 1; j < j1; j++) { + if (val_i_2 <= Y(ids_i, j)) { + continue; + } + const float *y_j = y + j * d; + float disij = sqrt(fvec_L2sqr(x_i, y_j, d)); + if (disij < val_i) { + ids_i = j; + val_i = disij; + val_i_2 = val_i * 2; + } + } + + if (j0 == 0 || val[i] > val_i) { + val[i] = val_i; + ids[i] = ids_i; + } + } + } + + free(data); +} + + +} // namespace faiss diff --git a/core/src/index/thirdparty/faiss/utils/distances.h b/core/src/index/thirdparty/faiss/utils/distances.h new file mode 100644 index 0000000000..b4311d09c6 --- /dev/null +++ b/core/src/index/thirdparty/faiss/utils/distances.h @@ -0,0 +1,271 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +/* All distance functions for L2 and IP distances. + * The actual functions are implemented in distances.cpp and distances_simd.cpp */ + +#pragma once + +#include + +#include +#include + + +namespace faiss { + + /********************************************************* + * Optimized distance/norm/inner prod computations + *********************************************************/ + +#ifdef __SSE__ +float fvec_L2sqr_sse ( + const float * x, + const float * y, + size_t d); + +float fvec_inner_product_sse ( + const float * x, + const float * y, + size_t d); + +float fvec_L1_sse ( + const float * x, + const float * y, + size_t d); + +float fvec_Linf_sse ( + const float * x, + const float * y, + size_t d); +#endif + +float fvec_jaccard ( + const float * x, + const float * y, + size_t d); + +/** Compute pairwise distances between sets of vectors + * + * @param d dimension of the vectors + * @param nq nb of query vectors + * @param nb nb of database vectors + * @param xq query vectors (size nq * d) + * @param xb database vectros (size nb * d) + * @param dis output distances (size nq * nb) + * @param ldq,ldb, ldd strides for the matrices + */ +void pairwise_L2sqr (int64_t d, + int64_t nq, const float *xq, + int64_t nb, const float *xb, + float *dis, + int64_t ldq = -1, int64_t ldb = -1, int64_t ldd = -1); + +/* compute the inner product between nx vectors x and one y */ +void fvec_inner_products_ny ( + float * ip, /* output inner product */ + const float * x, + const float * y, + size_t d, size_t ny); + +/* compute ny square L2 distance bewteen x and a set of contiguous y vectors */ +void fvec_L2sqr_ny ( + float * dis, + const float * x, + const float * y, + size_t d, size_t ny); + + +/** squared norm of a vector */ +float fvec_norm_L2sqr (const float * x, + size_t d); + +/** compute the L2 norms for a set of vectors + * + * @param ip output norms, size nx + * @param x set of vectors, size nx * d + */ +void fvec_norms_L2 (float * ip, const float * x, size_t d, size_t nx); + +/// same as fvec_norms_L2, but computes square norms +void fvec_norms_L2sqr (float * ip, const float * x, size_t d, size_t nx); + +/* L2-renormalize a set of vector. Nothing done if the vector is 0-normed */ +void fvec_renorm_L2 (size_t d, size_t nx, float * x); + + +/* This function exists because the Torch counterpart is extremly slow + (not multi-threaded + unexpected overhead even in single thread). + It is here to implement the usual property |x-y|^2=|x|^2+|y|^2-2 */ +void inner_product_to_L2sqr (float * dis, + const float * nr1, + const float * nr2, + size_t n1, size_t n2); + +/*************************************************************************** + * Compute a subset of distances + ***************************************************************************/ + + /* compute the inner product between x and a subset y of ny vectors, + whose indices are given by idy. */ +void fvec_inner_products_by_idx ( + float * ip, + const float * x, + const float * y, + const int64_t *ids, + size_t d, size_t nx, size_t ny); + +/* same but for a subset in y indexed by idsy (ny vectors in total) */ +void fvec_L2sqr_by_idx ( + float * dis, + const float * x, + const float * y, + const int64_t *ids, /* ids of y vecs */ + size_t d, size_t nx, size_t ny); + + +/** compute dis[j] = L2sqr(x[ix[j]], y[iy[j]]) forall j=0..n-1 + * + * @param x size (max(ix) + 1, d) + * @param y size (max(iy) + 1, d) + * @param ix size n + * @param iy size n + * @param dis size n + */ +void pairwise_indexed_L2sqr ( + size_t d, size_t n, + const float * x, const int64_t *ix, + const float * y, const int64_t *iy, + float *dis); + +/* same for inner product */ +void pairwise_indexed_inner_product ( + size_t d, size_t n, + const float * x, const int64_t *ix, + const float * y, const int64_t *iy, + float *dis); + +/*************************************************************************** + * KNN functions + ***************************************************************************/ + +// threshold on nx above which we switch to BLAS to compute distances +extern int distance_compute_blas_threshold; + +// threshold on nx above which we switch to compute parallel on ny +extern int parallel_policy_threshold; + +/** Return the k nearest neighors of each of the nx vectors x among the ny + * vector y, w.r.t to max inner product + * + * @param x query vectors, size nx * d + * @param y database vectors, size ny * d + * @param res result array, which also provides k. Sorted on output + */ +void knn_inner_product ( + const float * x, + const float * y, + size_t d, size_t nx, size_t ny, + float_minheap_array_t * res, + ConcurrentBitsetPtr bitset = nullptr); + +/** Same as knn_inner_product, for the L2 distance */ +void knn_L2sqr ( + const float * x, + const float * y, + size_t d, size_t nx, size_t ny, + float_maxheap_array_t * res, + ConcurrentBitsetPtr bitset = nullptr); + +void knn_jaccard ( + const float * x, + const float * y, + size_t d, size_t nx, size_t ny, + float_maxheap_array_t * res, + ConcurrentBitsetPtr bitset = nullptr); + +/** same as knn_L2sqr, but base_shift[bno] is subtracted to all + * computed distances. + * + * @param base_shift size ny + */ +void knn_L2sqr_base_shift ( + const float * x, + const float * y, + size_t d, size_t nx, size_t ny, + float_maxheap_array_t * res, + const float *base_shift); + +/* Find the nearest neighbors for nx queries in a set of ny vectors + * indexed by ids. May be useful for re-ranking a pre-selected vector list + */ +void knn_inner_products_by_idx ( + const float * x, + const float * y, + const int64_t * ids, + size_t d, size_t nx, size_t ny, + float_minheap_array_t * res); + +void knn_L2sqr_by_idx (const float * x, + const float * y, + const int64_t * ids, + size_t d, size_t nx, size_t ny, + float_maxheap_array_t * res); + +/*************************************************************************** + * Range search + ***************************************************************************/ + + + +/// Forward declaration, see AuxIndexStructures.h +struct RangeSearchResult; + +/** Return the k nearest neighors of each of the nx vectors x among the ny + * vector y, w.r.t to max inner product + * + * @param x query vectors, size nx * d + * @param y database vectors, size ny * d + * @param radius search radius around the x vectors + * @param result result structure + */ +void range_search_L2sqr ( + const float * x, + const float * y, + size_t d, size_t nx, size_t ny, + float radius, + RangeSearchResult *result); + +/// same as range_search_L2sqr for the inner product similarity +void range_search_inner_product ( + const float * x, + const float * y, + size_t d, size_t nx, size_t ny, + float radius, + RangeSearchResult *result); + + +/*************************************************************************** + * elkan + ***************************************************************************/ + +/** Return the nearest neighors of each of the nx vectors x among the ny + * + * @param x query vectors, size nx * d + * @param y database vectors, size ny * d + * @param ids result array ids + * @param val result array value + */ +void elkan_L2_sse ( + const float * x, + const float * y, + size_t d, size_t nx, size_t ny, + int64_t *ids, float *val); + +} // namespace faiss diff --git a/core/src/index/thirdparty/faiss/utils/distances_avx.h b/core/src/index/thirdparty/faiss/utils/distances_avx.h new file mode 100644 index 0000000000..734c38ebe7 --- /dev/null +++ b/core/src/index/thirdparty/faiss/utils/distances_avx.h @@ -0,0 +1,32 @@ + +// -*- c++ -*- + +/* All distance functions for L2 and IP distances. + * The actual functions are implemented in distances_simd_avx512.cpp */ + +#pragma once + +#include + +namespace faiss { + +/********************************************************* + * Optimized distance/norm/inner prod computations + *********************************************************/ + +/// Squared L2 distance between two vectors +float +fvec_L2sqr_avx(const float* x, const float* y, size_t d); + +/// inner product +float +fvec_inner_product_avx(const float* x, const float* y, size_t d); + +/// L1 distance +float +fvec_L1_avx(const float* x, const float* y, size_t d); + +float +fvec_Linf_avx(const float* x, const float* y, size_t d); + +} // namespace faiss diff --git a/core/src/index/thirdparty/faiss/utils/distances_avx512.h b/core/src/index/thirdparty/faiss/utils/distances_avx512.h new file mode 100644 index 0000000000..d410f3e821 --- /dev/null +++ b/core/src/index/thirdparty/faiss/utils/distances_avx512.h @@ -0,0 +1,32 @@ + +// -*- c++ -*- + +/* All distance functions for L2 and IP distances. + * The actual functions are implemented in distances_simd_avx512.cpp */ + +#pragma once + +#include + +namespace faiss { + +/********************************************************* + * Optimized distance/norm/inner prod computations + *********************************************************/ + +/// Squared L2 distance between two vectors +float +fvec_L2sqr_avx512(const float* x, const float* y, size_t d); + +/// inner product +float +fvec_inner_product_avx512(const float * x, const float * y, size_t d); + +/// L1 distance +float +fvec_L1_avx512(const float* x, const float* y, size_t d); + +float +fvec_Linf_avx512(const float* x, const float* y, size_t d); + +} // namespace faiss diff --git a/core/src/index/thirdparty/faiss/utils/distances_simd.cpp b/core/src/index/thirdparty/faiss/utils/distances_simd.cpp new file mode 100644 index 0000000000..e33967d5e6 --- /dev/null +++ b/core/src/index/thirdparty/faiss/utils/distances_simd.cpp @@ -0,0 +1,624 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +#include +#include + +#include +#include +#include +#include + +#ifdef __SSE__ +#include +#endif + +#ifdef __aarch64__ +#include +#endif + +#include + +namespace faiss { + +/********************************************************* + * Optimized distance computations + *********************************************************/ + + +/* Functions to compute: + - L2 distance between 2 vectors + - inner product between 2 vectors + - L2 norm of a vector + + The functions should probably not be invoked when a large number of + vectors are be processed in batch (in which case Matrix multiply + is faster), but may be useful for comparing vectors isolated in + memory. + + Works with any vectors of any dimension, even unaligned (in which + case they are slower). + +*/ + + +/********************************************************* + * Reference implementations + */ + + +float fvec_L2sqr_ref (const float * x, + const float * y, + size_t d) +{ + size_t i; + float res = 0; + for (i = 0; i < d; i++) { + const float tmp = x[i] - y[i]; + res += tmp * tmp; + } + return res; +} + +float fvec_L1_ref (const float * x, + const float * y, + size_t d) +{ + size_t i; + float res = 0; + for (i = 0; i < d; i++) { + const float tmp = x[i] - y[i]; + res += fabs(tmp); + } + return res; +} + +float fvec_Linf_ref (const float * x, + const float * y, + size_t d) +{ + size_t i; + float res = 0; + for (i = 0; i < d; i++) { + res = fmax(res, fabs(x[i] - y[i])); + } + return res; +} + +float fvec_inner_product_ref (const float * x, + const float * y, + size_t d) +{ + size_t i; + float res = 0; + for (i = 0; i < d; i++) + res += x[i] * y[i]; + return res; +} + +float fvec_norm_L2sqr_ref (const float *x, size_t d) +{ + size_t i; + double res = 0; + for (i = 0; i < d; i++) + res += x[i] * x[i]; + return res; +} + + +void fvec_L2sqr_ny_ref (float * dis, + const float * x, + const float * y, + size_t d, size_t ny) +{ + for (size_t i = 0; i < ny; i++) { + dis[i] = fvec_L2sqr (x, y, d); + y += d; + } +} + + + + +/********************************************************* + * SSE and AVX implementations + */ + +#ifdef __SSE__ + +// reads 0 <= d < 4 floats as __m128 +static inline __m128 masked_read (int d, const float *x) +{ + assert (0 <= d && d < 4); + __attribute__((__aligned__(16))) float buf[4] = {0, 0, 0, 0}; + switch (d) { + case 3: + buf[2] = x[2]; + case 2: + buf[1] = x[1]; + case 1: + buf[0] = x[0]; + } + return _mm_load_ps (buf); + // cannot use AVX2 _mm_mask_set1_epi32 +} + +float fvec_norm_L2sqr (const float * x, + size_t d) +{ + __m128 mx; + __m128 msum1 = _mm_setzero_ps(); + + while (d >= 4) { + mx = _mm_loadu_ps (x); x += 4; + msum1 = _mm_add_ps (msum1, _mm_mul_ps (mx, mx)); + d -= 4; + } + + mx = masked_read (d, x); + msum1 = _mm_add_ps (msum1, _mm_mul_ps (mx, mx)); + + msum1 = _mm_hadd_ps (msum1, msum1); + msum1 = _mm_hadd_ps (msum1, msum1); + return _mm_cvtss_f32 (msum1); +} + +namespace { + +float sqr (float x) { + return x * x; +} + + +void fvec_L2sqr_ny_D1 (float * dis, const float * x, + const float * y, size_t ny) +{ + float x0s = x[0]; + __m128 x0 = _mm_set_ps (x0s, x0s, x0s, x0s); + + size_t i; + for (i = 0; i + 3 < ny; i += 4) { + __m128 tmp, accu; + tmp = x0 - _mm_loadu_ps (y); y += 4; + accu = tmp * tmp; + dis[i] = _mm_cvtss_f32 (accu); + tmp = _mm_shuffle_ps (accu, accu, 1); + dis[i + 1] = _mm_cvtss_f32 (tmp); + tmp = _mm_shuffle_ps (accu, accu, 2); + dis[i + 2] = _mm_cvtss_f32 (tmp); + tmp = _mm_shuffle_ps (accu, accu, 3); + dis[i + 3] = _mm_cvtss_f32 (tmp); + } + while (i < ny) { // handle non-multiple-of-4 case + dis[i++] = sqr(x0s - *y++); + } +} + + +void fvec_L2sqr_ny_D2 (float * dis, const float * x, + const float * y, size_t ny) +{ + __m128 x0 = _mm_set_ps (x[1], x[0], x[1], x[0]); + + size_t i; + for (i = 0; i + 1 < ny; i += 2) { + __m128 tmp, accu; + tmp = x0 - _mm_loadu_ps (y); y += 4; + accu = tmp * tmp; + accu = _mm_hadd_ps (accu, accu); + dis[i] = _mm_cvtss_f32 (accu); + accu = _mm_shuffle_ps (accu, accu, 3); + dis[i + 1] = _mm_cvtss_f32 (accu); + } + if (i < ny) { // handle odd case + dis[i] = sqr(x[0] - y[0]) + sqr(x[1] - y[1]); + } +} + + + +void fvec_L2sqr_ny_D4 (float * dis, const float * x, + const float * y, size_t ny) +{ + __m128 x0 = _mm_loadu_ps(x); + + for (size_t i = 0; i < ny; i++) { + __m128 tmp, accu; + tmp = x0 - _mm_loadu_ps (y); y += 4; + accu = tmp * tmp; + accu = _mm_hadd_ps (accu, accu); + accu = _mm_hadd_ps (accu, accu); + dis[i] = _mm_cvtss_f32 (accu); + } +} + + +void fvec_L2sqr_ny_D8 (float * dis, const float * x, + const float * y, size_t ny) +{ + __m128 x0 = _mm_loadu_ps(x); + __m128 x1 = _mm_loadu_ps(x + 4); + + for (size_t i = 0; i < ny; i++) { + __m128 tmp, accu; + tmp = x0 - _mm_loadu_ps (y); y += 4; + accu = tmp * tmp; + tmp = x1 - _mm_loadu_ps (y); y += 4; + accu += tmp * tmp; + accu = _mm_hadd_ps (accu, accu); + accu = _mm_hadd_ps (accu, accu); + dis[i] = _mm_cvtss_f32 (accu); + } +} + + +void fvec_L2sqr_ny_D12 (float * dis, const float * x, + const float * y, size_t ny) +{ + __m128 x0 = _mm_loadu_ps(x); + __m128 x1 = _mm_loadu_ps(x + 4); + __m128 x2 = _mm_loadu_ps(x + 8); + + for (size_t i = 0; i < ny; i++) { + __m128 tmp, accu; + tmp = x0 - _mm_loadu_ps (y); y += 4; + accu = tmp * tmp; + tmp = x1 - _mm_loadu_ps (y); y += 4; + accu += tmp * tmp; + tmp = x2 - _mm_loadu_ps (y); y += 4; + accu += tmp * tmp; + accu = _mm_hadd_ps (accu, accu); + accu = _mm_hadd_ps (accu, accu); + dis[i] = _mm_cvtss_f32 (accu); + } +} + + +} // anonymous namespace + +void fvec_L2sqr_ny (float * dis, const float * x, + const float * y, size_t d, size_t ny) { + // optimized for a few special cases + switch(d) { + case 1: + fvec_L2sqr_ny_D1 (dis, x, y, ny); + return; + case 2: + fvec_L2sqr_ny_D2 (dis, x, y, ny); + return; + case 4: + fvec_L2sqr_ny_D4 (dis, x, y, ny); + return; + case 8: + fvec_L2sqr_ny_D8 (dis, x, y, ny); + return; + case 12: + fvec_L2sqr_ny_D12 (dis, x, y, ny); + return; + default: + fvec_L2sqr_ny_ref (dis, x, y, d, ny); + return; + } +} + +#endif + +#if defined(__SSE__) // But not AVX + +float fvec_L1_sse (const float * x, const float * y, size_t d) +{ + return fvec_L1_ref (x, y, d); +} + +float fvec_Linf_sse (const float * x, const float * y, size_t d) +{ + return fvec_Linf_ref (x, y, d); +} + + +float fvec_L2sqr_sse (const float * x, + const float * y, + size_t d) +{ + __m128 msum1 = _mm_setzero_ps(); + + while (d >= 4) { + __m128 mx = _mm_loadu_ps (x); x += 4; + __m128 my = _mm_loadu_ps (y); y += 4; + const __m128 a_m_b1 = mx - my; + msum1 += a_m_b1 * a_m_b1; + d -= 4; + } + + if (d > 0) { + // add the last 1, 2 or 3 values + __m128 mx = masked_read (d, x); + __m128 my = masked_read (d, y); + __m128 a_m_b1 = mx - my; + msum1 += a_m_b1 * a_m_b1; + } + + msum1 = _mm_hadd_ps (msum1, msum1); + msum1 = _mm_hadd_ps (msum1, msum1); + return _mm_cvtss_f32 (msum1); +} + + +float fvec_inner_product_sse (const float * x, + const float * y, + size_t d) +{ + __m128 mx, my; + __m128 msum1 = _mm_setzero_ps(); + + while (d >= 4) { + mx = _mm_loadu_ps (x); x += 4; + my = _mm_loadu_ps (y); y += 4; + msum1 = _mm_add_ps (msum1, _mm_mul_ps (mx, my)); + d -= 4; + } + + // add the last 1, 2, or 3 values + mx = masked_read (d, x); + my = masked_read (d, y); + __m128 prod = _mm_mul_ps (mx, my); + + msum1 = _mm_add_ps (msum1, prod); + + msum1 = _mm_hadd_ps (msum1, msum1); + msum1 = _mm_hadd_ps (msum1, msum1); + return _mm_cvtss_f32 (msum1); +} + +#endif /* defined(__SSE__) */ + +//#elif defined(__aarch64__) +// +//float fvec_L2sqr (const float * x, +// const float * y, +// size_t d) +//{ +// if (d & 3) return fvec_L2sqr_ref (x, y, d); +// float32x4_t accu = vdupq_n_f32 (0); +// for (size_t i = 0; i < d; i += 4) { +// float32x4_t xi = vld1q_f32 (x + i); +// float32x4_t yi = vld1q_f32 (y + i); +// float32x4_t sq = vsubq_f32 (xi, yi); +// accu = vfmaq_f32 (accu, sq, sq); +// } +// float32x4_t a2 = vpaddq_f32 (accu, accu); +// return vdups_laneq_f32 (a2, 0) + vdups_laneq_f32 (a2, 1); +//} +// +//float fvec_inner_product (const float * x, +// const float * y, +// size_t d) +//{ +// if (d & 3) return fvec_inner_product_ref (x, y, d); +// float32x4_t accu = vdupq_n_f32 (0); +// for (size_t i = 0; i < d; i += 4) { +// float32x4_t xi = vld1q_f32 (x + i); +// float32x4_t yi = vld1q_f32 (y + i); +// accu = vfmaq_f32 (accu, xi, yi); +// } +// float32x4_t a2 = vpaddq_f32 (accu, accu); +// return vdups_laneq_f32 (a2, 0) + vdups_laneq_f32 (a2, 1); +//} +// +//float fvec_norm_L2sqr (const float *x, size_t d) +//{ +// if (d & 3) return fvec_norm_L2sqr_ref (x, d); +// float32x4_t accu = vdupq_n_f32 (0); +// for (size_t i = 0; i < d; i += 4) { +// float32x4_t xi = vld1q_f32 (x + i); +// accu = vfmaq_f32 (accu, xi, xi); +// } +// float32x4_t a2 = vpaddq_f32 (accu, accu); +// return vdups_laneq_f32 (a2, 0) + vdups_laneq_f32 (a2, 1); +//} +// +//// not optimized for ARM +//void fvec_L2sqr_ny (float * dis, const float * x, +// const float * y, size_t d, size_t ny) { +// fvec_L2sqr_ny_ref (dis, x, y, d, ny); +//} +// +//float fvec_L1 (const float * x, const float * y, size_t d) +//{ +// return fvec_L1_ref (x, y, d); +//} +// +//float fvec_Linf (const float * x, const float * y, size_t d) +//{ +// return fvec_Linf_ref (x, y, d); +//} +// +// +//#else +// scalar implementation + +//float fvec_L2sqr (const float * x, +// const float * y, +// size_t d) +//{ +// return fvec_L2sqr_ref (x, y, d); +//} +// +//float fvec_L1 (const float * x, const float * y, size_t d) +//{ +// return fvec_L1_ref (x, y, d); +//} +// +//float fvec_Linf (const float * x, const float * y, size_t d) +//{ +// return fvec_Linf_ref (x, y, d); +//} +// +//float fvec_inner_product (const float * x, +// const float * y, +// size_t d) +//{ +// return fvec_inner_product_ref (x, y, d); +//} +// +//float fvec_norm_L2sqr (const float *x, size_t d) +//{ +// return fvec_norm_L2sqr_ref (x, d); +//} +// +//void fvec_L2sqr_ny (float * dis, const float * x, +// const float * y, size_t d, size_t ny) { +// fvec_L2sqr_ny_ref (dis, x, y, d, ny); +//} +// +//#endif + + + + +/*************************************************************************** + * heavily optimized table computations + ***************************************************************************/ + + +static inline void fvec_madd_ref (size_t n, const float *a, + float bf, const float *b, float *c) { + for (size_t i = 0; i < n; i++) + c[i] = a[i] + bf * b[i]; +} + +#ifdef __SSE__ + +static inline void fvec_madd_sse (size_t n, const float *a, + float bf, const float *b, float *c) { + n >>= 2; + __m128 bf4 = _mm_set_ps1 (bf); + __m128 * a4 = (__m128*)a; + __m128 * b4 = (__m128*)b; + __m128 * c4 = (__m128*)c; + + while (n--) { + *c4 = _mm_add_ps (*a4, _mm_mul_ps (bf4, *b4)); + b4++; + a4++; + c4++; + } +} + +void fvec_madd (size_t n, const float *a, + float bf, const float *b, float *c) +{ + if ((n & 3) == 0 && + ((((long)a) | ((long)b) | ((long)c)) & 15) == 0) + fvec_madd_sse (n, a, bf, b, c); + else + fvec_madd_ref (n, a, bf, b, c); +} + +#else + +void fvec_madd (size_t n, const float *a, + float bf, const float *b, float *c) +{ + fvec_madd_ref (n, a, bf, b, c); +} + +#endif + +static inline int fvec_madd_and_argmin_ref (size_t n, const float *a, + float bf, const float *b, float *c) { + float vmin = 1e20; + int imin = -1; + + for (size_t i = 0; i < n; i++) { + c[i] = a[i] + bf * b[i]; + if (c[i] < vmin) { + vmin = c[i]; + imin = i; + } + } + return imin; +} + +#ifdef __SSE__ + +static inline int fvec_madd_and_argmin_sse ( + size_t n, const float *a, + float bf, const float *b, float *c) { + n >>= 2; + __m128 bf4 = _mm_set_ps1 (bf); + __m128 vmin4 = _mm_set_ps1 (1e20); + __m128i imin4 = _mm_set1_epi32 (-1); + __m128i idx4 = _mm_set_epi32 (3, 2, 1, 0); + __m128i inc4 = _mm_set1_epi32 (4); + __m128 * a4 = (__m128*)a; + __m128 * b4 = (__m128*)b; + __m128 * c4 = (__m128*)c; + + while (n--) { + __m128 vc4 = _mm_add_ps (*a4, _mm_mul_ps (bf4, *b4)); + *c4 = vc4; + __m128i mask = (__m128i)_mm_cmpgt_ps (vmin4, vc4); + // imin4 = _mm_blendv_epi8 (imin4, idx4, mask); // slower! + + imin4 = _mm_or_si128 (_mm_and_si128 (mask, idx4), + _mm_andnot_si128 (mask, imin4)); + vmin4 = _mm_min_ps (vmin4, vc4); + b4++; + a4++; + c4++; + idx4 = _mm_add_epi32 (idx4, inc4); + } + + // 4 values -> 2 + { + idx4 = _mm_shuffle_epi32 (imin4, 3 << 2 | 2); + __m128 vc4 = _mm_shuffle_ps (vmin4, vmin4, 3 << 2 | 2); + __m128i mask = (__m128i)_mm_cmpgt_ps (vmin4, vc4); + imin4 = _mm_or_si128 (_mm_and_si128 (mask, idx4), + _mm_andnot_si128 (mask, imin4)); + vmin4 = _mm_min_ps (vmin4, vc4); + } + // 2 values -> 1 + { + idx4 = _mm_shuffle_epi32 (imin4, 1); + __m128 vc4 = _mm_shuffle_ps (vmin4, vmin4, 1); + __m128i mask = (__m128i)_mm_cmpgt_ps (vmin4, vc4); + imin4 = _mm_or_si128 (_mm_and_si128 (mask, idx4), + _mm_andnot_si128 (mask, imin4)); + // vmin4 = _mm_min_ps (vmin4, vc4); + } + return _mm_cvtsi128_si32 (imin4); +} + + +int fvec_madd_and_argmin (size_t n, const float *a, + float bf, const float *b, float *c) +{ + if ((n & 3) == 0 && + ((((long)a) | ((long)b) | ((long)c)) & 15) == 0) + return fvec_madd_and_argmin_sse (n, a, bf, b, c); + else + return fvec_madd_and_argmin_ref (n, a, bf, b, c); +} + +#else + +int fvec_madd_and_argmin (size_t n, const float *a, + float bf, const float *b, float *c) +{ + return fvec_madd_and_argmin_ref (n, a, bf, b, c); +} + +#endif + + + + +} // namespace faiss diff --git a/core/src/index/thirdparty/faiss/utils/distances_simd_avx.cpp b/core/src/index/thirdparty/faiss/utils/distances_simd_avx.cpp new file mode 100644 index 0000000000..4a3c83a89e --- /dev/null +++ b/core/src/index/thirdparty/faiss/utils/distances_simd_avx.cpp @@ -0,0 +1,213 @@ + +// -*- c++ -*- + +#include +#include + +#include +#include +#include +#include + +#include + +namespace faiss { + +#ifdef __SSE__ +// reads 0 <= d < 4 floats as __m128 +static inline __m128 masked_read (int d, const float *x) { + assert (0 <= d && d < 4); + __attribute__((__aligned__(16))) float buf[4] = {0, 0, 0, 0}; + switch (d) { + case 3: + buf[2] = x[2]; + case 2: + buf[1] = x[1]; + case 1: + buf[0] = x[0]; + } + return _mm_load_ps(buf); + // cannot use AVX2 _mm_mask_set1_epi32 +} +#endif + +#ifdef __AVX__ + +// reads 0 <= d < 8 floats as __m256 +static inline __m256 masked_read_8 (int d, const float* x) { + assert (0 <= d && d < 8); + if (d < 4) { + __m256 res = _mm256_setzero_ps (); + res = _mm256_insertf128_ps (res, masked_read (d, x), 0); + return res; + } else { + __m256 res = _mm256_setzero_ps (); + res = _mm256_insertf128_ps (res, _mm_loadu_ps (x), 0); + res = _mm256_insertf128_ps (res, masked_read (d - 4, x + 4), 1); + return res; + } +} + +float fvec_inner_product_avx (const float* x, const float* y, size_t d) { + __m256 msum1 = _mm256_setzero_ps(); + + while (d >= 8) { + __m256 mx = _mm256_loadu_ps (x); x += 8; + __m256 my = _mm256_loadu_ps (y); y += 8; + msum1 = _mm256_add_ps (msum1, _mm256_mul_ps (mx, my)); + d -= 8; + } + + __m128 msum2 = _mm256_extractf128_ps(msum1, 1); + msum2 += _mm256_extractf128_ps(msum1, 0); + + if (d >= 4) { + __m128 mx = _mm_loadu_ps (x); x += 4; + __m128 my = _mm_loadu_ps (y); y += 4; + msum2 = _mm_add_ps (msum2, _mm_mul_ps (mx, my)); + d -= 4; + } + + if (d > 0) { + __m128 mx = masked_read (d, x); + __m128 my = masked_read (d, y); + msum2 = _mm_add_ps (msum2, _mm_mul_ps (mx, my)); + } + + msum2 = _mm_hadd_ps (msum2, msum2); + msum2 = _mm_hadd_ps (msum2, msum2); + return _mm_cvtss_f32 (msum2); +} + +float fvec_L2sqr_avx (const float* x, const float* y, size_t d) { + __m256 msum1 = _mm256_setzero_ps(); + + while (d >= 8) { + __m256 mx = _mm256_loadu_ps (x); x += 8; + __m256 my = _mm256_loadu_ps (y); y += 8; + const __m256 a_m_b1 = mx - my; + msum1 += a_m_b1 * a_m_b1; + d -= 8; + } + + __m128 msum2 = _mm256_extractf128_ps(msum1, 1); + msum2 += _mm256_extractf128_ps(msum1, 0); + + if (d >= 4) { + __m128 mx = _mm_loadu_ps (x); x += 4; + __m128 my = _mm_loadu_ps (y); y += 4; + const __m128 a_m_b1 = mx - my; + msum2 += a_m_b1 * a_m_b1; + d -= 4; + } + + if (d > 0) { + __m128 mx = masked_read (d, x); + __m128 my = masked_read (d, y); + __m128 a_m_b1 = mx - my; + msum2 += a_m_b1 * a_m_b1; + } + + msum2 = _mm_hadd_ps (msum2, msum2); + msum2 = _mm_hadd_ps (msum2, msum2); + return _mm_cvtss_f32 (msum2); +} + +float fvec_L1_avx (const float * x, const float * y, size_t d) +{ + __m256 msum1 = _mm256_setzero_ps(); + __m256 signmask = __m256(_mm256_set1_epi32 (0x7fffffffUL)); + + while (d >= 8) { + __m256 mx = _mm256_loadu_ps (x); x += 8; + __m256 my = _mm256_loadu_ps (y); y += 8; + const __m256 a_m_b = mx - my; + msum1 += _mm256_and_ps(signmask, a_m_b); + d -= 8; + } + + __m128 msum2 = _mm256_extractf128_ps(msum1, 1); + msum2 += _mm256_extractf128_ps(msum1, 0); + __m128 signmask2 = __m128(_mm_set1_epi32 (0x7fffffffUL)); + + if (d >= 4) { + __m128 mx = _mm_loadu_ps (x); x += 4; + __m128 my = _mm_loadu_ps (y); y += 4; + const __m128 a_m_b = mx - my; + msum2 += _mm_and_ps(signmask2, a_m_b); + d -= 4; + } + + if (d > 0) { + __m128 mx = masked_read (d, x); + __m128 my = masked_read (d, y); + __m128 a_m_b = mx - my; + msum2 += _mm_and_ps(signmask2, a_m_b); + } + + msum2 = _mm_hadd_ps (msum2, msum2); + msum2 = _mm_hadd_ps (msum2, msum2); + return _mm_cvtss_f32 (msum2); +} + +float fvec_Linf_avx (const float* x, const float* y, size_t d) { + __m256 msum1 = _mm256_setzero_ps(); + __m256 signmask = __m256(_mm256_set1_epi32 (0x7fffffffUL)); + + while (d >= 8) { + __m256 mx = _mm256_loadu_ps (x); x += 8; + __m256 my = _mm256_loadu_ps (y); y += 8; + const __m256 a_m_b = mx - my; + msum1 = _mm256_max_ps(msum1, _mm256_and_ps(signmask, a_m_b)); + d -= 8; + } + + __m128 msum2 = _mm256_extractf128_ps(msum1, 1); + msum2 = _mm_max_ps (msum2, _mm256_extractf128_ps(msum1, 0)); + __m128 signmask2 = __m128(_mm_set1_epi32 (0x7fffffffUL)); + + if (d >= 4) { + __m128 mx = _mm_loadu_ps (x); x += 4; + __m128 my = _mm_loadu_ps (y); y += 4; + const __m128 a_m_b = mx - my; + msum2 = _mm_max_ps(msum2, _mm_and_ps(signmask2, a_m_b)); + d -= 4; + } + + if (d > 0) { + __m128 mx = masked_read (d, x); + __m128 my = masked_read (d, y); + __m128 a_m_b = mx - my; + msum2 = _mm_max_ps(msum2, _mm_and_ps(signmask2, a_m_b)); + } + + msum2 = _mm_max_ps(_mm_movehl_ps(msum2, msum2), msum2); + msum2 = _mm_max_ps(msum2, _mm_shuffle_ps (msum2, msum2, 1)); + return _mm_cvtss_f32 (msum2); +} + +#else + +float fvec_inner_product_avx(const float* x, const float* y, size_t d) { + FAISS_ASSERT(false); + return 0.0; +} + +float fvec_L2sqr_avx(const float* x, const float* y, size_t d) { + FAISS_ASSERT(false); + return 0.0; +} + +float fvec_L1_avx(const float* x, const float* y, size_t d) { + FAISS_ASSERT(false); + return 0.0; +} + +float fvec_Linf_avx (const float* x, const float* y, size_t d) { + FAISS_ASSERT(false); + return 0.0; +} + +#endif + +} // namespace faiss diff --git a/core/src/index/thirdparty/faiss/utils/distances_simd_avx512.cpp b/core/src/index/thirdparty/faiss/utils/distances_simd_avx512.cpp new file mode 100644 index 0000000000..a73d9b7da9 --- /dev/null +++ b/core/src/index/thirdparty/faiss/utils/distances_simd_avx512.cpp @@ -0,0 +1,250 @@ + +// -*- c++ -*- + +#include +#include + +#include +#include +#include +#include + +#include + +namespace faiss { + +#ifdef __SSE__ +// reads 0 <= d < 4 floats as __m128 +static inline __m128 masked_read (int d, const float *x) { + assert (0 <= d && d < 4); + __attribute__((__aligned__(16))) float buf[4] = {0, 0, 0, 0}; + switch (d) { + case 3: + buf[2] = x[2]; + case 2: + buf[1] = x[1]; + case 1: + buf[0] = x[0]; + } + return _mm_load_ps(buf); + // cannot use AVX2 _mm_mask_set1_epi32 +} +#endif + +#if (defined(__AVX512F__) && defined(__AVX512DQ__)) + +float +fvec_inner_product_avx512(const float* x, const float* y, size_t d) { + __m512 msum0 = _mm512_setzero_ps(); + + while (d >= 16) { + __m512 mx = _mm512_loadu_ps (x); x += 16; + __m512 my = _mm512_loadu_ps (y); y += 16; + msum0 = _mm512_add_ps (msum0, _mm512_mul_ps (mx, my)); + d -= 16; + } + + __m256 msum1 = _mm512_extractf32x8_ps(msum0, 1); + msum1 += _mm512_extractf32x8_ps(msum0, 0); + + if (d >= 8) { + __m256 mx = _mm256_loadu_ps (x); x += 8; + __m256 my = _mm256_loadu_ps (y); y += 8; + msum1 = _mm256_add_ps (msum1, _mm256_mul_ps (mx, my)); + d -= 8; + } + + __m128 msum2 = _mm256_extractf128_ps(msum1, 1); + msum2 += _mm256_extractf128_ps(msum1, 0); + + if (d >= 4) { + __m128 mx = _mm_loadu_ps (x); x += 4; + __m128 my = _mm_loadu_ps (y); y += 4; + msum2 = _mm_add_ps (msum2, _mm_mul_ps (mx, my)); + d -= 4; + } + + if (d > 0) { + __m128 mx = masked_read (d, x); + __m128 my = masked_read (d, y); + msum2 = _mm_add_ps (msum2, _mm_mul_ps (mx, my)); + } + + msum2 = _mm_hadd_ps (msum2, msum2); + msum2 = _mm_hadd_ps (msum2, msum2); + return _mm_cvtss_f32 (msum2); +} + +float +fvec_L2sqr_avx512(const float* x, const float* y, size_t d) { + __m512 msum0 = _mm512_setzero_ps(); + + while (d >= 16) { + __m512 mx = _mm512_loadu_ps (x); x += 16; + __m512 my = _mm512_loadu_ps (y); y += 16; + const __m512 a_m_b1 = mx - my; + msum0 += a_m_b1 * a_m_b1; + d -= 16; + } + + __m256 msum1 = _mm512_extractf32x8_ps(msum0, 1); + msum1 += _mm512_extractf32x8_ps(msum0, 0); + + if (d >= 8) { + __m256 mx = _mm256_loadu_ps (x); x += 8; + __m256 my = _mm256_loadu_ps (y); y += 8; + const __m256 a_m_b1 = mx - my; + msum1 += a_m_b1 * a_m_b1; + d -= 8; + } + + __m128 msum2 = _mm256_extractf128_ps(msum1, 1); + msum2 += _mm256_extractf128_ps(msum1, 0); + + if (d >= 4) { + __m128 mx = _mm_loadu_ps (x); x += 4; + __m128 my = _mm_loadu_ps (y); y += 4; + const __m128 a_m_b1 = mx - my; + msum2 += a_m_b1 * a_m_b1; + d -= 4; + } + + if (d > 0) { + __m128 mx = masked_read (d, x); + __m128 my = masked_read (d, y); + __m128 a_m_b1 = mx - my; + msum2 += a_m_b1 * a_m_b1; + } + + msum2 = _mm_hadd_ps (msum2, msum2); + msum2 = _mm_hadd_ps (msum2, msum2); + return _mm_cvtss_f32 (msum2); +} + +float +fvec_L1_avx512(const float* x, const float* y, size_t d) { + __m512 msum0 = _mm512_setzero_ps(); + __m512 signmask0 = __m512(_mm512_set1_epi32 (0x7fffffffUL)); + + while (d >= 16) { + __m512 mx = _mm512_loadu_ps (x); x += 16; + __m512 my = _mm512_loadu_ps (y); y += 16; + const __m512 a_m_b = mx - my; + msum0 += _mm512_and_ps(signmask0, a_m_b); + d -= 16; + } + + __m256 msum1 = _mm512_extractf32x8_ps(msum0, 1); + msum1 += _mm512_extractf32x8_ps(msum0, 0); + __m256 signmask1 = __m256(_mm256_set1_epi32 (0x7fffffffUL)); + + if (d >= 8) { + __m256 mx = _mm256_loadu_ps (x); x += 8; + __m256 my = _mm256_loadu_ps (y); y += 8; + const __m256 a_m_b = mx - my; + msum1 += _mm256_and_ps(signmask1, a_m_b); + d -= 8; + } + + __m128 msum2 = _mm256_extractf128_ps(msum1, 1); + msum2 += _mm256_extractf128_ps(msum1, 0); + __m128 signmask2 = __m128(_mm_set1_epi32 (0x7fffffffUL)); + + if (d >= 4) { + __m128 mx = _mm_loadu_ps (x); x += 4; + __m128 my = _mm_loadu_ps (y); y += 4; + const __m128 a_m_b = mx - my; + msum2 += _mm_and_ps(signmask2, a_m_b); + d -= 4; + } + + if (d > 0) { + __m128 mx = masked_read (d, x); + __m128 my = masked_read (d, y); + __m128 a_m_b = mx - my; + msum2 += _mm_and_ps(signmask2, a_m_b); + } + + msum2 = _mm_hadd_ps (msum2, msum2); + msum2 = _mm_hadd_ps (msum2, msum2); + return _mm_cvtss_f32 (msum2); +} + +float +fvec_Linf_avx512(const float* x, const float* y, size_t d) { + __m512 msum0 = _mm512_setzero_ps(); + __m512 signmask0 = __m512(_mm512_set1_epi32 (0x7fffffffUL)); + + while (d >= 16) { + __m512 mx = _mm512_loadu_ps (x); x += 16; + __m512 my = _mm512_loadu_ps (y); y += 16; + const __m512 a_m_b = mx - my; + msum0 = _mm512_max_ps(msum0, _mm512_and_ps(signmask0, a_m_b)); + d -= 16; + } + + __m256 msum1 = _mm512_extractf32x8_ps(msum0, 1); + msum1 = _mm256_max_ps (msum1, _mm512_extractf32x8_ps(msum0, 0)); + __m256 signmask1 = __m256(_mm256_set1_epi32 (0x7fffffffUL)); + + if (d >= 8) { + __m256 mx = _mm256_loadu_ps (x); x += 8; + __m256 my = _mm256_loadu_ps (y); y += 8; + const __m256 a_m_b = mx - my; + msum1 = _mm256_max_ps(msum1, _mm256_and_ps(signmask1, a_m_b)); + d -= 8; + } + + __m128 msum2 = _mm256_extractf128_ps(msum1, 1); + msum2 = _mm_max_ps (msum2, _mm256_extractf128_ps(msum1, 0)); + __m128 signmask2 = __m128(_mm_set1_epi32 (0x7fffffffUL)); + + if (d >= 4) { + __m128 mx = _mm_loadu_ps (x); x += 4; + __m128 my = _mm_loadu_ps (y); y += 4; + const __m128 a_m_b = mx - my; + msum2 = _mm_max_ps(msum2, _mm_and_ps(signmask2, a_m_b)); + d -= 4; + } + + if (d > 0) { + __m128 mx = masked_read (d, x); + __m128 my = masked_read (d, y); + __m128 a_m_b = mx - my; + msum2 = _mm_max_ps(msum2, _mm_and_ps(signmask2, a_m_b)); + } + + msum2 = _mm_max_ps(_mm_movehl_ps(msum2, msum2), msum2); + msum2 = _mm_max_ps(msum2, _mm_shuffle_ps (msum2, msum2, 1)); + return _mm_cvtss_f32 (msum2); +} + +#else + +float +fvec_inner_product_avx512(const float* x, const float* y, size_t d) { + FAISS_ASSERT(false); + return 0.0; +} + +float +fvec_L2sqr_avx512(const float* x, const float* y, size_t d) { + FAISS_ASSERT(false); + return 0.0; +} + +float +fvec_L1_avx512(const float* x, const float* y, size_t d) { + FAISS_ASSERT(false); + return 0.0; +} + +float +fvec_Linf_avx512(const float* x, const float* y, size_t d) { + FAISS_ASSERT(false); + return 0.0; +} + +#endif + +} // namespace faiss diff --git a/core/src/index/thirdparty/faiss/utils/extra_distances.cpp b/core/src/index/thirdparty/faiss/utils/extra_distances.cpp new file mode 100644 index 0000000000..de03b013ac --- /dev/null +++ b/core/src/index/thirdparty/faiss/utils/extra_distances.cpp @@ -0,0 +1,374 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +#include + +#include +#include + +#include +#include +#include +#include +#include + +namespace faiss { + +/*************************************************************************** + * Distance functions (other than L2 and IP) + ***************************************************************************/ + +struct VectorDistanceL2 { + size_t d; + + float operator () (const float *x, const float *y) const { + return fvec_L2sqr (x, y, d); + } +}; + +struct VectorDistanceL1 { + size_t d; + + float operator () (const float *x, const float *y) const { + return fvec_L1 (x, y, d); + } +}; + +struct VectorDistanceLinf { + size_t d; + + float operator () (const float *x, const float *y) const { + return fvec_Linf (x, y, d); + /* + float vmax = 0; + for (size_t i = 0; i < d; i++) { + float diff = fabs (x[i] - y[i]); + if (diff > vmax) vmax = diff; + } + return vmax;*/ + } +}; + +struct VectorDistanceLp { + size_t d; + const float p; + + float operator () (const float *x, const float *y) const { + float accu = 0; + for (size_t i = 0; i < d; i++) { + float diff = fabs (x[i] - y[i]); + accu += powf (diff, p); + } + return accu; + } +}; + +struct VectorDistanceCanberra { + size_t d; + + float operator () (const float *x, const float *y) const { + float accu = 0; + for (size_t i = 0; i < d; i++) { + float xi = x[i], yi = y[i]; + accu += fabs (xi - yi) / (fabs(xi) + fabs(yi)); + } + return accu; + } +}; + +struct VectorDistanceBrayCurtis { + size_t d; + + float operator () (const float *x, const float *y) const { + float accu_num = 0, accu_den = 0; + for (size_t i = 0; i < d; i++) { + float xi = x[i], yi = y[i]; + accu_num += fabs (xi - yi); + accu_den += fabs (xi + yi); + } + return accu_num / accu_den; + } +}; + +struct VectorDistanceJensenShannon { + size_t d; + + float operator () (const float *x, const float *y) const { + float accu = 0; + + for (size_t i = 0; i < d; i++) { + float xi = x[i], yi = y[i]; + float mi = 0.5 * (xi + yi); + float kl1 = - xi * log(mi / xi); + float kl2 = - yi * log(mi / yi); + accu += kl1 + kl2; + } + return 0.5 * accu; + } +}; + +struct VectorDistanceJaccard { + size_t d; + + float operator () (const float *x, const float *y) const { + float accu_num = 0, accu_den = 0; + const float EPSILON = 0.000001; + for (size_t i = 0; i < d; i++) { + float xi = x[i], yi = y[i]; + if (fabs (xi - yi) < EPSILON) { + accu_num += xi; + accu_den += xi; + } else { + accu_den += xi; + accu_den += yi; + } + } + return 1 - accu_num / accu_den; + } +}; + +struct VectorDistanceTanimoto { + size_t d; + + float operator () (const float *x, const float *y) const { + float accu_num = 0, accu_den = 0; + for (size_t i = 0; i < d; i++) { + float xi = x[i], yi = y[i]; + accu_num += xi * yi; + accu_den += xi * xi + yi * yi - xi * yi; + } + return -log2(accu_num / accu_den) ; + } +}; + + + +namespace { + +template +void pairwise_extra_distances_template ( + VD vd, + int64_t nq, const float *xq, + int64_t nb, const float *xb, + float *dis, + int64_t ldq, int64_t ldb, int64_t ldd) +{ + +#pragma omp parallel for if(nq > 10) + for (int64_t i = 0; i < nq; i++) { + const float *xqi = xq + i * ldq; + const float *xbj = xb; + float *disi = dis + ldd * i; + + for (int64_t j = 0; j < nb; j++) { + disi[j] = vd (xqi, xbj); + xbj += ldb; + } + } +} + + +template +void knn_extra_metrics_template ( + VD vd, + const float * x, + const float * y, + size_t nx, size_t ny, + float_maxheap_array_t * res, + ConcurrentBitsetPtr bitset = nullptr) +{ + size_t k = res->k; + size_t d = vd.d; + size_t check_period = InterruptCallback::get_period_hint (ny * d); + check_period *= omp_get_max_threads(); + + for (size_t i0 = 0; i0 < nx; i0 += check_period) { + size_t i1 = std::min(i0 + check_period, nx); + +#pragma omp parallel for + for (size_t i = i0; i < i1; i++) { + const float * x_i = x + i * d; + const float * y_j = y; + size_t j; + float * simi = res->get_val(i); + int64_t * idxi = res->get_ids (i); + + maxheap_heapify (k, simi, idxi); + for (j = 0; j < ny; j++) { + if (!bitset || !bitset->test(j)) { + float disij = vd (x_i, y_j); + if (disij < simi[0]) { + maxheap_pop (k, simi, idxi); + maxheap_push (k, simi, idxi, disij, j); + } + } + y_j += d; + } + maxheap_reorder (k, simi, idxi); + } + InterruptCallback::check (); + } + +} + + +template +struct ExtraDistanceComputer : DistanceComputer { + VD vd; + Index::idx_t nb; + const float *q; + const float *b; + + float operator () (idx_t i) override { + return vd (q, b + i * vd.d); + } + + float symmetric_dis(idx_t i, idx_t j) override { + return vd (b + j * vd.d, b + i * vd.d); + } + + ExtraDistanceComputer(const VD & vd, const float *xb, + size_t nb, const float *q = nullptr) + : vd(vd), nb(nb), q(q), b(xb) {} + + void set_query(const float *x) override { + q = x; + } +}; + +} // anonymous namespace + +void pairwise_extra_distances ( + int64_t d, + int64_t nq, const float *xq, + int64_t nb, const float *xb, + MetricType mt, float metric_arg, + float *dis, + int64_t ldq, int64_t ldb, int64_t ldd) +{ + if (nq == 0 || nb == 0) return; + if (ldq == -1) ldq = d; + if (ldb == -1) ldb = d; + if (ldd == -1) ldd = nb; + + switch(mt) { +#define HANDLE_VAR(kw) \ + case METRIC_ ## kw: { \ + VectorDistance ## kw vd({(size_t)d}); \ + pairwise_extra_distances_template (vd, nq, xq, nb, xb, \ + dis, ldq, ldb, ldd); \ + break; \ + } + HANDLE_VAR(L2); + HANDLE_VAR(L1); + HANDLE_VAR(Linf); + HANDLE_VAR(Canberra); + HANDLE_VAR(BrayCurtis); + HANDLE_VAR(JensenShannon); +#undef HANDLE_VAR + case METRIC_Lp: { + VectorDistanceLp vd({(size_t)d, metric_arg}); + pairwise_extra_distances_template (vd, nq, xq, nb, xb, + dis, ldq, ldb, ldd); + break; + } + case METRIC_Jaccard: { + VectorDistanceJaccard vd({(size_t) d}); + pairwise_extra_distances_template(vd, nq, xq, nb, xb, + dis, ldq, ldb, ldd); + break; + } + case METRIC_Tanimoto: { + VectorDistanceTanimoto vd({(size_t) d}); + pairwise_extra_distances_template(vd, nq, xq, nb, xb, + dis, ldq, ldb, ldd); + break; + } + default: + FAISS_THROW_MSG ("metric type not implemented"); + } + +} + +void knn_extra_metrics ( + const float * x, + const float * y, + size_t d, size_t nx, size_t ny, + MetricType mt, float metric_arg, + float_maxheap_array_t * res, + ConcurrentBitsetPtr bitset) +{ + + switch(mt) { +#define HANDLE_VAR(kw) \ + case METRIC_ ## kw: { \ + VectorDistance ## kw vd({(size_t)d}); \ + knn_extra_metrics_template (vd, x, y, nx, ny, res, bitset); \ + break; \ + } + HANDLE_VAR(L2); + HANDLE_VAR(L1); + HANDLE_VAR(Linf); + HANDLE_VAR(Canberra); + HANDLE_VAR(BrayCurtis); + HANDLE_VAR(JensenShannon); +#undef HANDLE_VAR + case METRIC_Lp: { + VectorDistanceLp vd({(size_t)d, metric_arg}); + knn_extra_metrics_template (vd, x, y, nx, ny, res, bitset); + break; + } + case METRIC_Jaccard: { + VectorDistanceJaccard vd({(size_t) d}); + knn_extra_metrics_template(vd, x, y, nx, ny, res, bitset); + break; + } + case METRIC_Tanimoto: { + VectorDistanceTanimoto vd({(size_t) d}); + knn_extra_metrics_template(vd, x, y, nx, ny, res, bitset); + break; + } + default: + FAISS_THROW_MSG ("metric type not implemented"); + } + +} + +DistanceComputer *get_extra_distance_computer ( + size_t d, + MetricType mt, float metric_arg, + size_t nb, const float *xb) +{ + + switch(mt) { +#define HANDLE_VAR(kw) \ + case METRIC_ ## kw: { \ + VectorDistance ## kw vd({(size_t)d}); \ + return new ExtraDistanceComputer(vd, xb, nb); \ + } + HANDLE_VAR(L2); + HANDLE_VAR(L1); + HANDLE_VAR(Linf); + HANDLE_VAR(Canberra); + HANDLE_VAR(BrayCurtis); + HANDLE_VAR(JensenShannon); +#undef HANDLE_VAR + case METRIC_Lp: { + VectorDistanceLp vd({(size_t)d, metric_arg}); + return new ExtraDistanceComputer (vd, xb, nb); + break; + } + default: + FAISS_THROW_MSG ("metric type not implemented"); + } + +} + + +} // namespace faiss diff --git a/core/src/index/thirdparty/faiss/utils/extra_distances.h b/core/src/index/thirdparty/faiss/utils/extra_distances.h new file mode 100644 index 0000000000..2ac60aa6ac --- /dev/null +++ b/core/src/index/thirdparty/faiss/utils/extra_distances.h @@ -0,0 +1,55 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +#ifndef FAISS_distances_h +#define FAISS_distances_h + +/** In this file are the implementations of extra metrics beyond L2 + * and inner product */ + +#include + +#include + +#include + + + +namespace faiss { + + +void pairwise_extra_distances ( + int64_t d, + int64_t nq, const float *xq, + int64_t nb, const float *xb, + MetricType mt, float metric_arg, + float *dis, + int64_t ldq = -1, int64_t ldb = -1, int64_t ldd = -1); + + +void knn_extra_metrics ( + const float * x, + const float * y, + size_t d, size_t nx, size_t ny, + MetricType mt, float metric_arg, + float_maxheap_array_t * res, + ConcurrentBitsetPtr bitset = nullptr); + + +/** get a DistanceComputer that refers to this type of distance and + * indexes a flat array of size nb */ +DistanceComputer *get_extra_distance_computer ( + size_t d, + MetricType mt, float metric_arg, + size_t nb, const float *xb); + +} + + +#endif diff --git a/core/src/index/thirdparty/faiss/utils/hamming-inl.h b/core/src/index/thirdparty/faiss/utils/hamming-inl.h new file mode 100644 index 0000000000..861e1f4308 --- /dev/null +++ b/core/src/index/thirdparty/faiss/utils/hamming-inl.h @@ -0,0 +1,472 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + + +namespace faiss { + + +inline BitstringWriter::BitstringWriter(uint8_t *code, int code_size): + code (code), code_size (code_size), i(0) +{ + bzero (code, code_size); +} + +inline void BitstringWriter::write(uint64_t x, int nbit) { + assert (code_size * 8 >= nbit + i); + // nb of available bits in i / 8 + int na = 8 - (i & 7); + + if (nbit <= na) { + code[i >> 3] |= x << (i & 7); + i += nbit; + return; + } else { + int j = i >> 3; + code[j++] |= x << (i & 7); + i += nbit; + x >>= na; + while (x != 0) { + code[j++] |= x; + x >>= 8; + } + } +} + + +inline BitstringReader::BitstringReader(const uint8_t *code, int code_size): + code (code), code_size (code_size), i(0) +{} + +inline uint64_t BitstringReader::read(int nbit) { + assert (code_size * 8 >= nbit + i); + // nb of available bits in i / 8 + int na = 8 - (i & 7); + // get available bits in current byte + uint64_t res = code[i >> 3] >> (i & 7); + if (nbit <= na) { + res &= (1 << nbit) - 1; + i += nbit; + return res; + } else { + int ofs = na; + int j = (i >> 3) + 1; + i += nbit; + nbit -= na; + while (nbit > 8) { + res |= ((uint64_t)code[j++]) << ofs; + ofs += 8; + nbit -= 8; // TODO remove nbit + } + uint64_t last_byte = code[j]; + last_byte &= (1 << nbit) - 1; + res |= last_byte << ofs; + return res; + } +} + + +/****************************************************************** + * The HammingComputer series of classes compares a single code of + * size 4 to 32 to incoming codes. They are intended for use as a + * template class where it would be inefficient to switch on the code + * size in the inner loop. Hopefully the compiler will inline the + * hamming() functions and put the a0, a1, ... in registers. + ******************************************************************/ + + +struct HammingComputer4 { + uint32_t a0; + + HammingComputer4 () {} + + HammingComputer4 (const uint8_t *a, int code_size) { + set (a, code_size); + } + + void set (const uint8_t *a, int code_size) { + assert (code_size == 4); + a0 = *(uint32_t *)a; + } + + inline int hamming (const uint8_t *b) const { + return popcount64 (*(uint32_t *)b ^ a0); + } + +}; + +struct HammingComputer8 { + uint64_t a0; + + HammingComputer8 () {} + + HammingComputer8 (const uint8_t *a, int code_size) { + set (a, code_size); + } + + void set (const uint8_t *a, int code_size) { + assert (code_size == 8); + a0 = *(uint64_t *)a; + } + + inline int hamming (const uint8_t *b) const { + return popcount64 (*(uint64_t *)b ^ a0); + } + +}; + + +struct HammingComputer16 { + uint64_t a0, a1; + + HammingComputer16 () {} + + HammingComputer16 (const uint8_t *a8, int code_size) { + set (a8, code_size); + } + + void set (const uint8_t *a8, int code_size) { + assert (code_size == 16); + const uint64_t *a = (uint64_t *)a8; + a0 = a[0]; a1 = a[1]; + } + + inline int hamming (const uint8_t *b8) const { + const uint64_t *b = (uint64_t *)b8; + return popcount64 (b[0] ^ a0) + popcount64 (b[1] ^ a1); + } + +}; + +// when applied to an array, 1/2 of the 64-bit accesses are unaligned. +// This incurs a penalty of ~10% wrt. fully aligned accesses. +struct HammingComputer20 { + uint64_t a0, a1; + uint32_t a2; + + HammingComputer20 () {} + + HammingComputer20 (const uint8_t *a8, int code_size) { + set (a8, code_size); + } + + void set (const uint8_t *a8, int code_size) { + assert (code_size == 20); + const uint64_t *a = (uint64_t *)a8; + a0 = a[0]; a1 = a[1]; a2 = a[2]; + } + + inline int hamming (const uint8_t *b8) const { + const uint64_t *b = (uint64_t *)b8; + return popcount64 (b[0] ^ a0) + popcount64 (b[1] ^ a1) + + popcount64 (*(uint32_t*)(b + 2) ^ a2); + } +}; + +struct HammingComputer32 { + uint64_t a0, a1, a2, a3; + + HammingComputer32 () {} + + HammingComputer32 (const uint8_t *a8, int code_size) { + set (a8, code_size); + } + + void set (const uint8_t *a8, int code_size) { + assert (code_size == 32); + const uint64_t *a = (uint64_t *)a8; + a0 = a[0]; a1 = a[1]; a2 = a[2]; a3 = a[3]; + } + + inline int hamming (const uint8_t *b8) const { + const uint64_t *b = (uint64_t *)b8; + return popcount64 (b[0] ^ a0) + popcount64 (b[1] ^ a1) + + popcount64 (b[2] ^ a2) + popcount64 (b[3] ^ a3); + } + +}; + +struct HammingComputer64 { + uint64_t a0, a1, a2, a3, a4, a5, a6, a7; + + HammingComputer64 () {} + + HammingComputer64 (const uint8_t *a8, int code_size) { + set (a8, code_size); + } + + void set (const uint8_t *a8, int code_size) { + assert (code_size == 64); + const uint64_t *a = (uint64_t *)a8; + a0 = a[0]; a1 = a[1]; a2 = a[2]; a3 = a[3]; + a4 = a[4]; a5 = a[5]; a6 = a[6]; a7 = a[7]; + } + + inline int hamming (const uint8_t *b8) const { + const uint64_t *b = (uint64_t *)b8; + return popcount64 (b[0] ^ a0) + popcount64 (b[1] ^ a1) + + popcount64 (b[2] ^ a2) + popcount64 (b[3] ^ a3) + + popcount64 (b[4] ^ a4) + popcount64 (b[5] ^ a5) + + popcount64 (b[6] ^ a6) + popcount64 (b[7] ^ a7); + } + +}; + +// very inefficient... +struct HammingComputerDefault { + const uint8_t *a; + int n; + + HammingComputerDefault () {} + + HammingComputerDefault (const uint8_t *a8, int code_size) { + set (a8, code_size); + } + + void set (const uint8_t *a8, int code_size) { + a = a8; + n = code_size; + } + + int hamming (const uint8_t *b8) const { + int accu = 0; + for (int i = 0; i < n; i++) + accu += popcount64 (a[i] ^ b8[i]); + return accu; + } + +}; + +struct HammingComputerM8 { + const uint64_t *a; + int n; + + HammingComputerM8 () {} + + HammingComputerM8 (const uint8_t *a8, int code_size) { + set (a8, code_size); + } + + void set (const uint8_t *a8, int code_size) { + assert (code_size % 8 == 0); + a = (uint64_t *)a8; + n = code_size / 8; + } + + int hamming (const uint8_t *b8) const { + const uint64_t *b = (uint64_t *)b8; + int accu = 0; + for (int i = 0; i < n; i++) + accu += popcount64 (a[i] ^ b[i]); + return accu; + } + +}; + +// even more inefficient! +struct HammingComputerM4 { + const uint32_t *a; + int n; + + HammingComputerM4 () {} + + HammingComputerM4 (const uint8_t *a4, int code_size) { + set (a4, code_size); + } + + void set (const uint8_t *a4, int code_size) { + assert (code_size % 4 == 0); + a = (uint32_t *)a4; + n = code_size / 4; + } + + int hamming (const uint8_t *b8) const { + const uint32_t *b = (uint32_t *)b8; + int accu = 0; + for (int i = 0; i < n; i++) + accu += popcount64 (a[i] ^ b[i]); + return accu; + } + +}; + +/*************************************************************************** + * Equivalence with a template class when code size is known at compile time + **************************************************************************/ + +// default template +template +struct HammingComputer: HammingComputerM8 { + HammingComputer (const uint8_t *a, int code_size): + HammingComputerM8(a, code_size) {} +}; + +#define SPECIALIZED_HC(CODE_SIZE) \ + template<> struct HammingComputer: \ + HammingComputer ## CODE_SIZE { \ + HammingComputer (const uint8_t *a): \ + HammingComputer ## CODE_SIZE(a, CODE_SIZE) {} \ + } + +SPECIALIZED_HC(4); +SPECIALIZED_HC(8); +SPECIALIZED_HC(16); +SPECIALIZED_HC(20); +SPECIALIZED_HC(32); +SPECIALIZED_HC(64); + +#undef SPECIALIZED_HC + + +/*************************************************************************** + * generalized Hamming = number of bytes that are different between + * two codes. + ***************************************************************************/ + + +inline int generalized_hamming_64 (uint64_t a) { + a |= a >> 1; + a |= a >> 2; + a |= a >> 4; + a &= 0x0101010101010101UL; + return popcount64 (a); +} + + +struct GenHammingComputer8 { + uint64_t a0; + + GenHammingComputer8 (const uint8_t *a, int code_size) { + assert (code_size == 8); + a0 = *(uint64_t *)a; + } + + inline int hamming (const uint8_t *b) const { + return generalized_hamming_64 (*(uint64_t *)b ^ a0); + } + +}; + + +struct GenHammingComputer16 { + uint64_t a0, a1; + GenHammingComputer16 (const uint8_t *a8, int code_size) { + assert (code_size == 16); + const uint64_t *a = (uint64_t *)a8; + a0 = a[0]; a1 = a[1]; + } + + inline int hamming (const uint8_t *b8) const { + const uint64_t *b = (uint64_t *)b8; + return generalized_hamming_64 (b[0] ^ a0) + + generalized_hamming_64 (b[1] ^ a1); + } + +}; + +struct GenHammingComputer32 { + uint64_t a0, a1, a2, a3; + + GenHammingComputer32 (const uint8_t *a8, int code_size) { + assert (code_size == 32); + const uint64_t *a = (uint64_t *)a8; + a0 = a[0]; a1 = a[1]; a2 = a[2]; a3 = a[3]; + } + + inline int hamming (const uint8_t *b8) const { + const uint64_t *b = (uint64_t *)b8; + return generalized_hamming_64 (b[0] ^ a0) + + generalized_hamming_64 (b[1] ^ a1) + + generalized_hamming_64 (b[2] ^ a2) + + generalized_hamming_64 (b[3] ^ a3); + } + +}; + +struct GenHammingComputerM8 { + const uint64_t *a; + int n; + + GenHammingComputerM8 (const uint8_t *a8, int code_size) { + assert (code_size % 8 == 0); + a = (uint64_t *)a8; + n = code_size / 8; + } + + int hamming (const uint8_t *b8) const { + const uint64_t *b = (uint64_t *)b8; + int accu = 0; + for (int i = 0; i < n; i++) + accu += generalized_hamming_64 (a[i] ^ b[i]); + return accu; + } + +}; + + +/** generalized Hamming distances (= count number of code bytes that + are the same) */ +void generalized_hammings_knn_hc ( + int_maxheap_array_t * ha, + const uint8_t * a, + const uint8_t * b, + size_t nb, + size_t code_size, + int ordered = true); + + + +/** This class maintains a list of best distances seen so far. + * + * Since the distances are in a limited range (0 to nbit), the + * object maintains one list per possible distance, and fills + * in only the n-first lists, such that the sum of sizes of the + * n lists is below k. + */ +template +struct HCounterState { + int *counters; + int64_t *ids_per_dis; + + HammingComputer hc; + int thres; + int count_lt; + int count_eq; + int k; + + HCounterState(int *counters, int64_t *ids_per_dis, + const uint8_t *x, int d, int k) + : counters(counters), + ids_per_dis(ids_per_dis), + hc(x, d / 8), + thres(d + 1), + count_lt(0), + count_eq(0), + k(k) {} + + void update_counter(const uint8_t *y, size_t j) { + int32_t dis = hc.hamming(y); + + if (dis <= thres) { + if (dis < thres) { + ids_per_dis[dis * k + counters[dis]++] = j; + ++count_lt; + while (count_lt == k && thres > 0) { + --thres; + count_eq = counters[thres]; + count_lt -= count_eq; + } + } else if (count_eq < k) { + ids_per_dis[dis * k + count_eq++] = j; + counters[dis] = count_eq; + } + } + } +}; + + +} // namespace faiss diff --git a/core/src/index/thirdparty/faiss/utils/hamming.cpp b/core/src/index/thirdparty/faiss/utils/hamming.cpp new file mode 100644 index 0000000000..2e714c34b4 --- /dev/null +++ b/core/src/index/thirdparty/faiss/utils/hamming.cpp @@ -0,0 +1,981 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +/* + * Implementation of Hamming related functions (distances, smallest distance + * selection with regular heap|radix and probabilistic heap|radix. + * + * IMPLEMENTATION NOTES + * Bitvectors are generally assumed to be multiples of 64 bits. + * + * hamdis_t is used for distances because at this time + * it is not clear how we will need to balance + * - flexibility in vector size (unclear more than 2^16 or even 2^8 bitvectors) + * - memory usage + * - cache-misses when dealing with large volumes of data (lower bits is better) + * + * The hamdis_t should optimally be compatibe with one of the Torch Storage + * (Byte,Short,Long) and therefore should be signed for 2-bytes and 4-bytes +*/ + +#include + +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +static const size_t BLOCKSIZE_QUERY = 8192; +static const size_t size_1M = 1 * 1024 * 1024; + +namespace faiss { + +size_t hamming_batch_size = 65536; + +static const uint8_t hamdis_tab_ham_bytes[256] = { + 0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4, + 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, + 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, + 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, + 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, + 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, + 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, + 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7, + 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, + 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, + 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, + 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7, + 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, + 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7, + 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7, + 4, 5, 5, 6, 5, 6, 6, 7, 5, 6, 6, 7, 6, 7, 7, 8 +}; + + +/* Elementary Hamming distance computation: unoptimized */ +template +T hamming (const uint8_t *bs1, + const uint8_t *bs2) +{ + const size_t nbytes = nbits / 8; + size_t i; + T h = 0; + for (i = 0; i < nbytes; i++) + h += (T) hamdis_tab_ham_bytes[bs1[i]^bs2[i]]; + return h; +} + + +/* Hamming distances for multiples of 64 bits */ +template +hamdis_t hamming (const uint64_t * bs1, const uint64_t * bs2) +{ + const size_t nwords = nbits / 64; + size_t i; + hamdis_t h = 0; + for (i = 0; i < nwords; i++) + h += popcount64 (bs1[i] ^ bs2[i]); + return h; +} + + + +/* specialized (optimized) functions */ +template <> +hamdis_t hamming<64> (const uint64_t * pa, const uint64_t * pb) +{ + return popcount64 (pa[0] ^ pb[0]); +} + + +template <> +hamdis_t hamming<128> (const uint64_t *pa, const uint64_t *pb) +{ + return popcount64 (pa[0] ^ pb[0]) + popcount64(pa[1] ^ pb[1]); +} + + +template <> +hamdis_t hamming<256> (const uint64_t * pa, const uint64_t * pb) +{ + return popcount64 (pa[0] ^ pb[0]) + + popcount64 (pa[1] ^ pb[1]) + + popcount64 (pa[2] ^ pb[2]) + + popcount64 (pa[3] ^ pb[3]); +} + + +/* Hamming distances for multiple of 64 bits */ +hamdis_t hamming ( + const uint64_t * bs1, + const uint64_t * bs2, + size_t nwords) +{ + size_t i; + hamdis_t h = 0; + for (i = 0; i < nwords; i++) + h += popcount64 (bs1[i] ^ bs2[i]); + return h; +} + + + +template +void hammings ( + const uint64_t * bs1, + const uint64_t * bs2, + size_t n1, size_t n2, + hamdis_t * dis) + +{ + size_t i, j; + const size_t nwords = nbits / 64; + for (i = 0; i < n1; i++) { + const uint64_t * __restrict bs1_ = bs1 + i * nwords; + hamdis_t * __restrict dis_ = dis + i * n2; + for (j = 0; j < n2; j++) + dis_[j] = hamming(bs1_, bs2 + j * nwords); + } +} + + + +void hammings ( + const uint64_t * bs1, + const uint64_t * bs2, + size_t n1, + size_t n2, + size_t nwords, + hamdis_t * __restrict dis) +{ + size_t i, j; + n1 *= nwords; + n2 *= nwords; + for (i = 0; i < n1; i+=nwords) { + const uint64_t * bs1_ = bs1+i; + for (j = 0; j < n2; j+=nwords) + dis[j] = hamming (bs1_, bs2+j, nwords); + } +} + + + + +/* Count number of matches given a max threshold */ +template +void hamming_count_thres ( + const uint64_t * bs1, + const uint64_t * bs2, + size_t n1, + size_t n2, + hamdis_t ht, + size_t * nptr) +{ + const size_t nwords = nbits / 64; + size_t i, j, posm = 0; + const uint64_t * bs2_ = bs2; + + for (i = 0; i < n1; i++) { + bs2 = bs2_; + for (j = 0; j < n2; j++) { + /* collect the match only if this satisfies the threshold */ + if (hamming (bs1, bs2) <= ht) + posm++; + bs2 += nwords; + } + bs1 += nwords; /* next signature */ + } + *nptr = posm; +} + + +template +void crosshamming_count_thres ( + const uint64_t * dbs, + size_t n, + int ht, + size_t * nptr) +{ + const size_t nwords = nbits / 64; + size_t i, j, posm = 0; + const uint64_t * bs1 = dbs; + for (i = 0; i < n; i++) { + const uint64_t * bs2 = bs1 + 2; + for (j = i + 1; j < n; j++) { + /* collect the match only if this satisfies the threshold */ + if (hamming (bs1, bs2) <= ht) + posm++; + bs2 += nwords; + } + bs1 += nwords; + } + *nptr = posm; +} + + +template +size_t match_hamming_thres ( + const uint64_t * bs1, + const uint64_t * bs2, + size_t n1, + size_t n2, + int ht, + int64_t * idx, + hamdis_t * hams) +{ + const size_t nwords = nbits / 64; + size_t i, j, posm = 0; + hamdis_t h; + const uint64_t * bs2_ = bs2; + for (i = 0; i < n1; i++) { + bs2 = bs2_; + for (j = 0; j < n2; j++) { + /* Here perform the real work of computing the distance */ + h = hamming (bs1, bs2); + + /* collect the match only if this satisfies the threshold */ + if (h <= ht) { + /* Enough space to store another match ? */ + *idx = i; idx++; + *idx = j; idx++; + *hams = h; + hams++; + posm++; + } + bs2+=nwords; /* next signature */ + } + bs1+=nwords; + } + return posm; +} + + +/* Return closest neighbors w.r.t Hamming distance, using a heap. */ +template +static +void hammings_knn_hc ( + int bytes_per_code, + int_maxheap_array_t * ha, + const uint8_t * bs1, + const uint8_t * bs2, + size_t n2, + bool order = true, + bool init_heap = true, + ConcurrentBitsetPtr bitset = nullptr) +{ + size_t k = ha->k; + + if ((bytes_per_code + k * (sizeof(hamdis_t) + sizeof(int64_t))) * ha->nh < size_1M) { + int thread_max_num = omp_get_max_threads(); + // init heap + size_t thread_heap_size = ha->nh * k; + size_t all_heap_size = thread_heap_size * thread_max_num; + hamdis_t *value = new hamdis_t[all_heap_size]; + int64_t *labels = new int64_t[all_heap_size]; + for (int i = 0; i < all_heap_size; i++) { + value[i] = 0x7fffffff; + labels[i] = -1; + } + + HammingComputer *hc = new HammingComputer[ha->nh]; + for (size_t i = 0; i < ha->nh; i++) { + hc[i].set(bs1 + i * bytes_per_code, bytes_per_code); + } + +#pragma omp parallel for + for (size_t j = 0; j < n2; j++) { + if(!bitset || !bitset->test(j)) { + int thread_no = omp_get_thread_num(); + + const uint8_t * bs2_ = bs2 + j * bytes_per_code; + for (size_t i = 0; i < ha->nh; i++) { + hamdis_t dis = hc[i].hamming (bs2_); + + hamdis_t * val_ = value + thread_no * thread_heap_size + i * k; + int64_t * ids_ = labels + thread_no * thread_heap_size + i * k; + if (dis < val_[0]) { + faiss::maxheap_swap_top (k, val_, ids_, dis, j); + } + } + } + } + + for (size_t t = 1; t < thread_max_num; t++) { + // merge heap + for (size_t i = 0; i < ha->nh; i++) { + hamdis_t * __restrict value_x = value + i * k; + int64_t * __restrict labels_x = labels + i * k; + hamdis_t *value_x_t = value_x + t * thread_heap_size; + int64_t *labels_x_t = labels_x + t * thread_heap_size; + for (size_t j = 0; j < k; j++) { + if (value_x_t[j] < value_x[0]) { + faiss::maxheap_swap_top (k, value_x, labels_x, value_x_t[j], labels_x_t[j]); + } + } + } + } + + // copy result + memcpy(ha->val, value, thread_heap_size * sizeof(hamdis_t)); + memcpy(ha->ids, labels, thread_heap_size * sizeof(int64_t)); + + delete[] hc; + delete[] value; + delete[] labels; + + } else { + if (init_heap) ha->heapify (); + const size_t block_size = hamming_batch_size; + for (size_t j0 = 0; j0 < n2; j0 += block_size) { + const size_t j1 = std::min(j0 + block_size, n2); +#pragma omp parallel for + for (size_t i = 0; i < ha->nh; i++) { + HammingComputer hc (bs1 + i * bytes_per_code, bytes_per_code); + + const uint8_t * bs2_ = bs2 + j0 * bytes_per_code; + hamdis_t dis; + hamdis_t * __restrict bh_val_ = ha->val + i * k; + int64_t * __restrict bh_ids_ = ha->ids + i * k; + size_t j; + for (j = j0; j < j1; j++, bs2_+= bytes_per_code) { + if(!bitset || !bitset->test(j)){ + dis = hc.hamming (bs2_); + if (dis < bh_val_[0]) { + faiss::maxheap_swap_top (k, bh_val_, bh_ids_, dis, j); + } + } + } + } + } + } + if (order) ha->reorder (); +} + +/* Return closest neighbors w.r.t Hamming distance, using max count. */ +template +static +void hammings_knn_mc ( + int bytes_per_code, + const uint8_t *a, + const uint8_t *b, + size_t na, + size_t nb, + size_t k, + int32_t *distances, + int64_t *labels, + ConcurrentBitsetPtr bitset = nullptr) +{ + const int nBuckets = bytes_per_code * 8 + 1; + std::vector all_counters(na * nBuckets, 0); + std::unique_ptr all_ids_per_dis(new int64_t[na * nBuckets * k]); + + std::vector> cs; + for (size_t i = 0; i < na; ++i) { + cs.push_back(HCounterState( + all_counters.data() + i * nBuckets, + all_ids_per_dis.get() + i * nBuckets * k, + a + i * bytes_per_code, + 8 * bytes_per_code, + k + )); + } + + const size_t block_size = hamming_batch_size; + for (size_t j0 = 0; j0 < nb; j0 += block_size) { + const size_t j1 = std::min(j0 + block_size, nb); +#pragma omp parallel for + for (size_t i = 0; i < na; ++i) { + for (size_t j = j0; j < j1; ++j) { + if (!bitset || !bitset->test(j)) { + cs[i].update_counter(b + j * bytes_per_code, j); + } + } + } + } + + for (size_t i = 0; i < na; ++i) { + HCounterState& csi = cs[i]; + + int nres = 0; + for (int b = 0; b < nBuckets && nres < k; b++) { + for (int l = 0; l < csi.counters[b] && nres < k; l++) { + labels[i * k + nres] = csi.ids_per_dis[b * k + l]; + distances[i * k + nres] = b; + nres++; + } + } + while (nres < k) { + labels[i * k + nres] = -1; + distances[i * k + nres] = std::numeric_limits::max(); + ++nres; + } + } +} + + + +// works faster than the template version +static +void hammings_knn_hc_1 ( + int_maxheap_array_t * ha, + const uint64_t * bs1, + const uint64_t * bs2, + size_t n2, + bool order = true, + bool init_heap = true, + ConcurrentBitsetPtr bitset = nullptr) +{ + const size_t nwords = 1; + size_t k = ha->k; + + if (init_heap) { + ha->heapify (); + } + + int thread_max_num = omp_get_max_threads(); + if (ha->nh == 1) { + // omp for n2 + int all_heap_size = thread_max_num * k; + hamdis_t *value = new hamdis_t[all_heap_size]; + int64_t *labels = new int64_t[all_heap_size]; + + // init heap + for (int i = 0; i < all_heap_size; i++) { + value[i] = 0x7fffffff; + } + const uint64_t bs1_ = bs1[0]; +#pragma omp parallel for + for (size_t j = 0; j < n2; j++) { + if(!bitset || !bitset->test(j)) { + hamdis_t dis = popcount64 (bs1_ ^ bs2[j]); + + int thread_no = omp_get_thread_num(); + hamdis_t * __restrict val_ = value + thread_no * k; + int64_t * __restrict ids_ = labels + thread_no * k; + if (dis < val_[0]) { + faiss::maxheap_swap_top (k, val_, ids_, dis, j); + } + } + } + // merge heap + hamdis_t * __restrict bh_val_ = ha->val; + int64_t * __restrict bh_ids_ = ha->ids; + for (int i = 0; i < all_heap_size; i++) { + if (value[i] < bh_val_[0]) { + faiss::maxheap_swap_top (k, bh_val_, bh_ids_, value[i], labels[i]); + } + } + + delete[] value; + delete[] labels; + + } else { +#pragma omp parallel for + for (size_t i = 0; i < ha->nh; i++) { + const uint64_t bs1_ = bs1 [i]; + const uint64_t * bs2_ = bs2; + hamdis_t dis; + hamdis_t * bh_val_ = ha->val + i * k; + hamdis_t bh_val_0 = bh_val_[0]; + int64_t * bh_ids_ = ha->ids + i * k; + size_t j; + for (j = 0; j < n2; j++, bs2_+= nwords) { + if(!bitset || !bitset->test(j)){ + dis = popcount64 (bs1_ ^ *bs2_); + if (dis < bh_val_0) { + faiss::maxheap_swap_top (k, bh_val_, bh_ids_, dis, j); + bh_val_0 = bh_val_[0]; + } + } + } + } + } + if (order) { + ha->reorder (); + } +} + + + + +/* Functions to maps vectors to bits. Assume proper allocation done beforehand, + meaning that b should be be able to receive as many bits as x may produce. */ + +/* + * dimension 0 corresponds to the least significant bit of b[0], or + * equivalently to the lsb of the first byte that is stored. + */ +void fvec2bitvec (const float * x, uint8_t * b, size_t d) +{ + for (int i = 0; i < d; i += 8) { + uint8_t w = 0; + uint8_t mask = 1; + int nj = i + 8 <= d ? 8 : d - i; + for (int j = 0; j < nj; j++) { + if (x[i + j] >= 0) + w |= mask; + mask <<= 1; + } + *b = w; + b++; + } +} + + + +/* Same but for n vectors. + Ensure that the ouptut b is byte-aligned (pad with 0s). */ +void fvecs2bitvecs (const float * x, uint8_t * b, size_t d, size_t n) +{ + const int64_t ncodes = ((d + 7) / 8); +#pragma omp parallel for if(n > 100000) + for (size_t i = 0; i < n; i++) + fvec2bitvec (x + i * d, b + i * ncodes, d); +} + + + +void bitvecs2fvecs ( + const uint8_t * b, + float * x, + size_t d, + size_t n) { + + const int64_t ncodes = ((d + 7) / 8); +#pragma omp parallel for if(n > 100000) + for (size_t i = 0; i < n; i++) { + binary_to_real (d, b + i * ncodes, x + i * d); + } +} + + +/* Reverse bit (NOT a optimized function, only used for print purpose) */ +static uint64_t uint64_reverse_bits (uint64_t b) +{ + int i; + uint64_t revb = 0; + for (i = 0; i < 64; i++) { + revb <<= 1; + revb |= b & 1; + b >>= 1; + } + return revb; +} + + +/* print the bit vector */ +void bitvec_print (const uint8_t * b, size_t d) +{ + size_t i, j; + for (i = 0; i < d; ) { + uint64_t brev = uint64_reverse_bits (* (uint64_t *) b); + for (j = 0; j < 64 && i < d; j++, i++) { + printf ("%d", (int) (brev & 1)); + brev >>= 1; + } + b += 8; + printf (" "); + } +} + + +void bitvec_shuffle (size_t n, size_t da, size_t db, + const int *order, + const uint8_t *a, + uint8_t *b) +{ + for(size_t i = 0; i < db; i++) { + FAISS_THROW_IF_NOT (order[i] >= 0 && order[i] < da); + } + size_t lda = (da + 7) / 8; + size_t ldb = (db + 7) / 8; + +#pragma omp parallel for if(n > 10000) + for (size_t i = 0; i < n; i++) { + const uint8_t *ai = a + i * lda; + uint8_t *bi = b + i * ldb; + memset (bi, 0, ldb); + for(size_t i = 0; i < db; i++) { + int o = order[i]; + uint8_t the_bit = (ai[o >> 3] >> (o & 7)) & 1; + bi[i >> 3] |= the_bit << (i & 7); + } + } + +} + + + +/*----------------------------------------*/ +/* Hamming distance computation and k-nn */ + + +#define C64(x) ((uint64_t *)x) + + +/* Compute a set of Hamming distances */ +void hammings ( + const uint8_t * a, + const uint8_t * b, + size_t na, size_t nb, + size_t ncodes, + hamdis_t * __restrict dis) +{ + FAISS_THROW_IF_NOT (ncodes % 8 == 0); + switch (ncodes) { + case 8: + faiss::hammings <64> (C64(a), C64(b), na, nb, dis); return; + case 16: + faiss::hammings <128> (C64(a), C64(b), na, nb, dis); return; + case 32: + faiss::hammings <256> (C64(a), C64(b), na, nb, dis); return; + case 64: + faiss::hammings <512> (C64(a), C64(b), na, nb, dis); return; + default: + faiss::hammings (C64(a), C64(b), na, nb, ncodes * 8, dis); return; + } +} + +void hammings_knn( + int_maxheap_array_t *ha, + const uint8_t *a, + const uint8_t *b, + size_t nb, + size_t ncodes, + int order) +{ + hammings_knn_hc(ha, a, b, nb, ncodes, order); +} + +void hammings_knn_hc ( + int_maxheap_array_t * ha, + const uint8_t * a, + const uint8_t * b, + size_t nb, + size_t ncodes, + int order, + ConcurrentBitsetPtr bitset) +{ + switch (ncodes) { + case 4: + hammings_knn_hc + (4, ha, a, b, nb, order, true, bitset); + break; + case 8: + hammings_knn_hc_1 (ha, C64(a), C64(b), nb, order, true, bitset); + // hammings_knn_hc + // (8, ha, a, b, nb, order, true); + break; + case 16: + hammings_knn_hc + (16, ha, a, b, nb, order, true, bitset); + break; + case 32: + hammings_knn_hc + (32, ha, a, b, nb, order, true, bitset); + break; + default: + if(ncodes % 8 == 0) { + hammings_knn_hc + (ncodes, ha, a, b, nb, order, true, bitset); + } else { + hammings_knn_hc + (ncodes, ha, a, b, nb, order, true, bitset); + + } + } +} + +void hammings_knn_mc( + const uint8_t * a, + const uint8_t * b, + size_t na, + size_t nb, + size_t k, + size_t ncodes, + int32_t *distances, + int64_t *labels, + ConcurrentBitsetPtr bitset) +{ + switch (ncodes) { + case 4: + hammings_knn_mc( + 4, a, b, na, nb, k, distances, labels, bitset + ); + break; + case 8: + // TODO(hoss): Write analog to hammings_knn_hc_1 + // hammings_knn_hc_1 (ha, C64(a), C64(b), nb, order, true); + hammings_knn_mc( + 8, a, b, na, nb, k, distances, labels, bitset + ); + break; + case 16: + hammings_knn_mc( + 16, a, b, na, nb, k, distances, labels, bitset + ); + break; + case 32: + hammings_knn_mc( + 32, a, b, na, nb, k, distances, labels, bitset + ); + break; + default: + if(ncodes % 8 == 0) { + hammings_knn_mc( + ncodes, a, b, na, nb, k, distances, labels, bitset + ); + } else { + hammings_knn_mc( + ncodes, a, b, na, nb, k, distances, labels, bitset + ); + } + } +} +template +static +void hamming_range_search_template ( + const uint8_t * a, + const uint8_t * b, + size_t na, + size_t nb, + int radius, + size_t code_size, + RangeSearchResult *res) +{ + +#pragma omp parallel + { + RangeSearchPartialResult pres (res); + +#pragma omp for + for (size_t i = 0; i < na; i++) { + HammingComputer hc (a + i * code_size, code_size); + const uint8_t * yi = b; + RangeQueryResult & qres = pres.new_result (i); + + for (size_t j = 0; j < nb; j++) { + int dis = hc.hamming (yi); + if (dis < radius) { + qres.add(dis, j); + } + yi += code_size; + } + } + pres.finalize (); + } +} + +void hamming_range_search ( + const uint8_t * a, + const uint8_t * b, + size_t na, + size_t nb, + int radius, + size_t code_size, + RangeSearchResult *result) +{ + +#define HC(name) hamming_range_search_template (a, b, na, nb, radius, code_size, result) + + switch(code_size) { + case 4: HC(HammingComputer4); break; + case 8: HC(HammingComputer8); break; + case 16: HC(HammingComputer16); break; + case 32: HC(HammingComputer32); break; + default: + if (code_size % 8 == 0) { + HC(HammingComputerM8); + } else { + HC(HammingComputerDefault); + } + } +#undef HC +} + + + +/* Count number of matches given a max threshold */ +void hamming_count_thres ( + const uint8_t * bs1, + const uint8_t * bs2, + size_t n1, + size_t n2, + hamdis_t ht, + size_t ncodes, + size_t * nptr) +{ + switch (ncodes) { + case 8: + faiss::hamming_count_thres <64> (C64(bs1), C64(bs2), + n1, n2, ht, nptr); + return; + case 16: + faiss::hamming_count_thres <128> (C64(bs1), C64(bs2), + n1, n2, ht, nptr); + return; + case 32: + faiss::hamming_count_thres <256> (C64(bs1), C64(bs2), + n1, n2, ht, nptr); + return; + case 64: + faiss::hamming_count_thres <512> (C64(bs1), C64(bs2), + n1, n2, ht, nptr); + return; + default: + FAISS_THROW_FMT ("not implemented for %zu bits", ncodes); + } +} + + +/* Count number of cross-matches given a threshold */ +void crosshamming_count_thres ( + const uint8_t * dbs, + size_t n, + hamdis_t ht, + size_t ncodes, + size_t * nptr) +{ + switch (ncodes) { + case 8: + faiss::crosshamming_count_thres <64> (C64(dbs), n, ht, nptr); + return; + case 16: + faiss::crosshamming_count_thres <128> (C64(dbs), n, ht, nptr); + return; + case 32: + faiss::crosshamming_count_thres <256> (C64(dbs), n, ht, nptr); + return; + case 64: + faiss::crosshamming_count_thres <512> (C64(dbs), n, ht, nptr); + return; + default: + FAISS_THROW_FMT ("not implemented for %zu bits", ncodes); + } +} + + +/* Returns all matches given a threshold */ +size_t match_hamming_thres ( + const uint8_t * bs1, + const uint8_t * bs2, + size_t n1, + size_t n2, + hamdis_t ht, + size_t ncodes, + int64_t * idx, + hamdis_t * dis) +{ + switch (ncodes) { + case 8: + return faiss::match_hamming_thres <64> (C64(bs1), C64(bs2), + n1, n2, ht, idx, dis); + case 16: + return faiss::match_hamming_thres <128> (C64(bs1), C64(bs2), + n1, n2, ht, idx, dis); + case 32: + return faiss::match_hamming_thres <256> (C64(bs1), C64(bs2), + n1, n2, ht, idx, dis); + case 64: + return faiss::match_hamming_thres <512> (C64(bs1), C64(bs2), + n1, n2, ht, idx, dis); + default: + FAISS_THROW_FMT ("not implemented for %zu bits", ncodes); + return 0; + } +} + + +#undef C64 + + + +/************************************* + * generalized Hamming distances + ************************************/ + + + +template +static void hamming_dis_inner_loop ( + const uint8_t *ca, + const uint8_t *cb, + size_t nb, + size_t code_size, + int k, + hamdis_t * bh_val_, + int64_t * bh_ids_) +{ + + HammingComputer hc (ca, code_size); + + for (size_t j = 0; j < nb; j++) { + int ndiff = hc.hamming (cb); + cb += code_size; + if (ndiff < bh_val_[0]) { + maxheap_swap_top (k, bh_val_, bh_ids_, ndiff, j); + } + } +} + +void generalized_hammings_knn_hc ( + int_maxheap_array_t * ha, + const uint8_t * a, + const uint8_t * b, + size_t nb, + size_t code_size, + int ordered) +{ + int na = ha->nh; + int k = ha->k; + + if (ordered) + ha->heapify (); + +#pragma omp parallel for + for (int i = 0; i < na; i++) { + const uint8_t *ca = a + i * code_size; + const uint8_t *cb = b; + + hamdis_t * bh_val_ = ha->val + i * k; + int64_t * bh_ids_ = ha->ids + i * k; + + switch (code_size) { + case 8: + hamming_dis_inner_loop + (ca, cb, nb, 8, k, bh_val_, bh_ids_); + break; + case 16: + hamming_dis_inner_loop + (ca, cb, nb, 16, k, bh_val_, bh_ids_); + break; + case 32: + hamming_dis_inner_loop + (ca, cb, nb, 32, k, bh_val_, bh_ids_); + break; + default: + hamming_dis_inner_loop + (ca, cb, nb, code_size, k, bh_val_, bh_ids_); + break; + } + } + + if (ordered) + ha->reorder (); + +} + + +} // namespace faiss diff --git a/core/src/index/thirdparty/faiss/utils/hamming.h b/core/src/index/thirdparty/faiss/utils/hamming.h new file mode 100644 index 0000000000..38bdd651f2 --- /dev/null +++ b/core/src/index/thirdparty/faiss/utils/hamming.h @@ -0,0 +1,243 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +/* + * Hamming distances. The binary vector dimensionality should be a + * multiple of 8, as the elementary operations operate on bytes. If + * you need other sizes, just pad with 0s (this is done by function + * fvecs2bitvecs). + * + * User-defined type hamdis_t is used for distances because at this time + * it is still uncler clear how we will need to balance + * - flexibility in vector size (may need 16- or even 8-bit vectors) + * - memory usage + * - cache-misses when dealing with large volumes of data (fewer bits is better) + * + */ + +#ifndef FAISS_hamming_h +#define FAISS_hamming_h + + +#include + +#include +#include + +/* The Hamming distance type */ +typedef int32_t hamdis_t; + +namespace faiss { + +/************************************************** + * General bit vector functions + **************************************************/ + +struct RangeSearchResult; + +void bitvec_print (const uint8_t * b, size_t d); + + +/* Functions for casting vectors of regular types to compact bits. + They assume proper allocation done beforehand, meaning that b + should be be able to receive as many bits as x may produce. */ + +/* Makes an array of bits from the signs of a float array. The length + of the output array b is rounded up to byte size (allocate + accordingly) */ +void fvecs2bitvecs ( + const float * x, + uint8_t * b, + size_t d, + size_t n); + +void bitvecs2fvecs ( + const uint8_t * b, + float * x, + size_t d, + size_t n); + + +void fvec2bitvec (const float * x, uint8_t * b, size_t d); + +/** Shuffle the bits from b(i, j) := a(i, order[j]) + */ +void bitvec_shuffle (size_t n, size_t da, size_t db, + const int *order, + const uint8_t *a, + uint8_t *b); + + +/*********************************************** + * Generic reader/writer for bit strings + ***********************************************/ + + +struct BitstringWriter { + uint8_t *code; + size_t code_size; + size_t i; // current bit offset + + // code_size in bytes + BitstringWriter(uint8_t *code, int code_size); + + // write the nbit low bits of x + void write(uint64_t x, int nbit); +}; + +struct BitstringReader { + const uint8_t *code; + size_t code_size; + size_t i; + + // code_size in bytes + BitstringReader(const uint8_t *code, int code_size); + + // read nbit bits from the code + uint64_t read(int nbit); +}; + +/************************************************** + * Hamming distance computation functions + **************************************************/ + + + +extern size_t hamming_batch_size; + +inline int popcount64(uint64_t x) { + return __builtin_popcountl(x); +} + + +/** Compute a set of Hamming distances between na and nb binary vectors + * + * @param a size na * nbytespercode + * @param b size nb * nbytespercode + * @param nbytespercode should be multiple of 8 + * @param dis output distances, size na * nb + */ +void hammings ( + const uint8_t * a, + const uint8_t * b, + size_t na, size_t nb, + size_t nbytespercode, + hamdis_t * dis); + + + + +/** Return the k smallest Hamming distances for a set of binary query vectors, + * using a max heap. + * @param a queries, size ha->nh * ncodes + * @param b database, size nb * ncodes + * @param nb number of database vectors + * @param ncodes size of the binary codes (bytes) + * @param ordered if != 0: order the results by decreasing distance + * (may be bottleneck for k/n > 0.01) */ +void hammings_knn_hc ( + int_maxheap_array_t * ha, + const uint8_t * a, + const uint8_t * b, + size_t nb, + size_t ncodes, + int ordered, + ConcurrentBitsetPtr bitset = nullptr); + +/* Legacy alias to hammings_knn_hc. */ +void hammings_knn ( + int_maxheap_array_t * ha, + const uint8_t * a, + const uint8_t * b, + size_t nb, + size_t ncodes, + int ordered, + ConcurrentBitsetPtr bitset = nullptr); + +/** Return the k smallest Hamming distances for a set of binary query vectors, + * using counting max. + * @param a queries, size na * ncodes + * @param b database, size nb * ncodes + * @param na number of query vectors + * @param nb number of database vectors + * @param k number of vectors/distances to return + * @param ncodes size of the binary codes (bytes) + * @param distances output distances from each query vector to its k nearest + * neighbors + * @param labels output ids of the k nearest neighbors to each query vector + */ +void hammings_knn_mc ( + const uint8_t * a, + const uint8_t * b, + size_t na, + size_t nb, + size_t k, + size_t ncodes, + int32_t *distances, + int64_t *labels, + ConcurrentBitsetPtr bitset = nullptr); + +/** same as hammings_knn except we are doing a range search with radius */ +void hamming_range_search ( + const uint8_t * a, + const uint8_t * b, + size_t na, + size_t nb, + int radius, + size_t ncodes, + RangeSearchResult *result); + + +/* Counting the number of matches or of cross-matches (without returning them) + For use with function that assume pre-allocated memory */ +void hamming_count_thres ( + const uint8_t * bs1, + const uint8_t * bs2, + size_t n1, + size_t n2, + hamdis_t ht, + size_t ncodes, + size_t * nptr); + +/* Return all Hamming distances/index passing a thres. Pre-allocation of output + is required. Use hamming_count_thres to determine the proper size. */ +size_t match_hamming_thres ( + const uint8_t * bs1, + const uint8_t * bs2, + size_t n1, + size_t n2, + hamdis_t ht, + size_t ncodes, + int64_t * idx, + hamdis_t * dis); + +/* Cross-matching in a set of vectors */ +void crosshamming_count_thres ( + const uint8_t * dbs, + size_t n, + hamdis_t ht, + size_t ncodes, + size_t * nptr); + + +/* compute the Hamming distances between two codewords of nwords*64 bits */ +hamdis_t hamming ( + const uint64_t * bs1, + const uint64_t * bs2, + size_t nwords); + + + +} // namespace faiss + +// inlined definitions of HammingComputerXX and GenHammingComputerXX + +#include + +#endif /* FAISS_hamming_h */ diff --git a/core/src/index/thirdparty/faiss/utils/instruction_set.h b/core/src/index/thirdparty/faiss/utils/instruction_set.h new file mode 100644 index 0000000000..c2fe84ca70 --- /dev/null +++ b/core/src/index/thirdparty/faiss/utils/instruction_set.h @@ -0,0 +1,355 @@ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +namespace faiss { + +class InstructionSet { + public: + static InstructionSet& + GetInstance() { + static InstructionSet inst; + return inst; + } + + private: + InstructionSet() + : nIds_{0}, + nExIds_{0}, + isIntel_{false}, + isAMD_{false}, + f_1_ECX_{0}, + f_1_EDX_{0}, + f_7_EBX_{0}, + f_7_ECX_{0}, + f_81_ECX_{0}, + f_81_EDX_{0}, + data_{}, + extdata_{} { + std::array cpui; + + // Calling __cpuid with 0x0 as the function_id argument + // gets the number of the highest valid function ID. + __cpuid(0, cpui[0], cpui[1], cpui[2], cpui[3]); + nIds_ = cpui[0]; + + for (int i = 0; i <= nIds_; ++i) { + __cpuid_count(i, 0, cpui[0], cpui[1], cpui[2], cpui[3]); + data_.push_back(cpui); + } + + // Capture vendor string + char vendor[0x20]; + memset(vendor, 0, sizeof(vendor)); + *reinterpret_cast(vendor) = data_[0][1]; + *reinterpret_cast(vendor + 4) = data_[0][3]; + *reinterpret_cast(vendor + 8) = data_[0][2]; + vendor_ = vendor; + if (vendor_ == "GenuineIntel") { + isIntel_ = true; + } else if (vendor_ == "AuthenticAMD") { + isAMD_ = true; + } + + // load bitset with flags for function 0x00000001 + if (nIds_ >= 1) { + f_1_ECX_ = data_[1][2]; + f_1_EDX_ = data_[1][3]; + } + + // load bitset with flags for function 0x00000007 + if (nIds_ >= 7) { + f_7_EBX_ = data_[7][1]; + f_7_ECX_ = data_[7][2]; + } + + // Calling __cpuid with 0x80000000 as the function_id argument + // gets the number of the highest valid extended ID. + __cpuid(0x80000000, cpui[0], cpui[1], cpui[2], cpui[3]); + nExIds_ = cpui[0]; + + char brand[0x40]; + memset(brand, 0, sizeof(brand)); + + for (int i = 0x80000000; i <= nExIds_; ++i) { + __cpuid_count(i, 0, cpui[0], cpui[1], cpui[2], cpui[3]); + extdata_.push_back(cpui); + } + + // load bitset with flags for function 0x80000001 + if (nExIds_ >= (int)0x80000001) { + f_81_ECX_ = extdata_[1][2]; + f_81_EDX_ = extdata_[1][3]; + } + + // Interpret CPU brand string if reported + if (nExIds_ >= (int)0x80000004) { + memcpy(brand, extdata_[2].data(), sizeof(cpui)); + memcpy(brand + 16, extdata_[3].data(), sizeof(cpui)); + memcpy(brand + 32, extdata_[4].data(), sizeof(cpui)); + brand_ = brand; + } + }; + + public: + // getters + std::string + Vendor(void) { + return vendor_; + } + std::string + Brand(void) { + return brand_; + } + + bool + SSE3(void) { + return f_1_ECX_[0]; + } + bool + PCLMULQDQ(void) { + return f_1_ECX_[1]; + } + bool + MONITOR(void) { + return f_1_ECX_[3]; + } + bool + SSSE3(void) { + return f_1_ECX_[9]; + } + bool + FMA(void) { + return f_1_ECX_[12]; + } + bool + CMPXCHG16B(void) { + return f_1_ECX_[13]; + } + bool + SSE41(void) { + return f_1_ECX_[19]; + } + bool + SSE42(void) { + return f_1_ECX_[20]; + } + bool + MOVBE(void) { + return f_1_ECX_[22]; + } + bool + POPCNT(void) { + return f_1_ECX_[23]; + } + bool + AES(void) { + return f_1_ECX_[25]; + } + bool + XSAVE(void) { + return f_1_ECX_[26]; + } + bool + OSXSAVE(void) { + return f_1_ECX_[27]; + } + bool + AVX(void) { + return f_1_ECX_[28]; + } + bool + F16C(void) { + return f_1_ECX_[29]; + } + bool + RDRAND(void) { + return f_1_ECX_[30]; + } + + bool + MSR(void) { + return f_1_EDX_[5]; + } + bool + CX8(void) { + return f_1_EDX_[8]; + } + bool + SEP(void) { + return f_1_EDX_[11]; + } + bool + CMOV(void) { + return f_1_EDX_[15]; + } + bool + CLFSH(void) { + return f_1_EDX_[19]; + } + bool + MMX(void) { + return f_1_EDX_[23]; + } + bool + FXSR(void) { + return f_1_EDX_[24]; + } + bool + SSE(void) { + return f_1_EDX_[25]; + } + bool + SSE2(void) { + return f_1_EDX_[26]; + } + + bool + FSGSBASE(void) { + return f_7_EBX_[0]; + } + bool + BMI1(void) { + return f_7_EBX_[3]; + } + bool + HLE(void) { + return isIntel_ && f_7_EBX_[4]; + } + bool + AVX2(void) { + return f_7_EBX_[5]; + } + bool + BMI2(void) { + return f_7_EBX_[8]; + } + bool + ERMS(void) { + return f_7_EBX_[9]; + } + bool + INVPCID(void) { + return f_7_EBX_[10]; + } + bool + RTM(void) { + return isIntel_ && f_7_EBX_[11]; + } + bool + AVX512F(void) { + return f_7_EBX_[16]; + } + bool + AVX512DQ(void) { + return f_7_EBX_[17]; + } + bool + RDSEED(void) { + return f_7_EBX_[18]; + } + bool + ADX(void) { + return f_7_EBX_[19]; + } + bool + AVX512PF(void) { + return f_7_EBX_[26]; + } + bool + AVX512ER(void) { + return f_7_EBX_[27]; + } + bool + AVX512CD(void) { + return f_7_EBX_[28]; + } + bool + SHA(void) { + return f_7_EBX_[29]; + } + bool + AVX512BW(void) { + return f_7_EBX_[30]; + } + bool + AVX512VL(void) { + return f_7_EBX_[31]; + } + + bool + PREFETCHWT1(void) { + return f_7_ECX_[0]; + } + + bool + LAHF(void) { + return f_81_ECX_[0]; + } + bool + LZCNT(void) { + return isIntel_ && f_81_ECX_[5]; + } + bool + ABM(void) { + return isAMD_ && f_81_ECX_[5]; + } + bool + SSE4a(void) { + return isAMD_ && f_81_ECX_[6]; + } + bool + XOP(void) { + return isAMD_ && f_81_ECX_[11]; + } + bool + TBM(void) { + return isAMD_ && f_81_ECX_[21]; + } + + bool + SYSCALL(void) { + return isIntel_ && f_81_EDX_[11]; + } + bool + MMXEXT(void) { + return isAMD_ && f_81_EDX_[22]; + } + bool + RDTSCP(void) { + return isIntel_ && f_81_EDX_[27]; + } + bool + _3DNOWEXT(void) { + return isAMD_ && f_81_EDX_[30]; + } + bool + _3DNOW(void) { + return isAMD_ && f_81_EDX_[31]; + } + + private: + int nIds_; + int nExIds_; + std::string vendor_; + std::string brand_; + bool isIntel_; + bool isAMD_; + std::bitset<32> f_1_ECX_; + std::bitset<32> f_1_EDX_; + std::bitset<32> f_7_EBX_; + std::bitset<32> f_7_ECX_; + std::bitset<32> f_81_ECX_; + std::bitset<32> f_81_EDX_; + std::vector> data_; + std::vector> extdata_; +}; + +} // namespace faiss diff --git a/core/src/index/thirdparty/faiss/utils/jaccard-inl.h b/core/src/index/thirdparty/faiss/utils/jaccard-inl.h new file mode 100644 index 0000000000..9aa7fbd924 --- /dev/null +++ b/core/src/index/thirdparty/faiss/utils/jaccard-inl.h @@ -0,0 +1,389 @@ +namespace faiss { + + struct JaccardComputer8 { + uint64_t a0; + + JaccardComputer8 () {} + + JaccardComputer8 (const uint8_t *a8, int code_size) { + set (a8, code_size); + } + + void set (const uint8_t *a8, int code_size) { + assert (code_size == 8); + const uint64_t *a = (uint64_t *)a8; + a0 = a[0]; + } + + inline float compute (const uint8_t *b8) const { + const uint64_t *b = (uint64_t *)b8; + int accu_num = popcount64 (b[0] & a0); + int accu_den = popcount64 (b[0] | a0); + if (accu_num == 0) + return 1.0; + return 1.0 - (float)(accu_num) / (float)(accu_den); + } + + }; + + struct JaccardComputer16 { + uint64_t a0, a1; + + JaccardComputer16 () {} + + JaccardComputer16 (const uint8_t *a8, int code_size) { + set (a8, code_size); + } + + void set (const uint8_t *a8, int code_size) { + assert (code_size == 16); + const uint64_t *a = (uint64_t *)a8; + a0 = a[0]; a1 = a[1]; + } + + inline float compute (const uint8_t *b8) const { + const uint64_t *b = (uint64_t *)b8; + int accu_num = popcount64 (b[0] & a0) + popcount64 (b[1] & a1); + int accu_den = popcount64 (b[0] | a0) + popcount64 (b[1] | a1); + if (accu_num == 0) + return 1.0; + return 1.0 - (float)(accu_num) / (float)(accu_den); + } + + }; + + struct JaccardComputer32 { + uint64_t a0, a1, a2, a3; + + JaccardComputer32 () {} + + JaccardComputer32 (const uint8_t *a8, int code_size) { + set (a8, code_size); + } + + void set (const uint8_t *a8, int code_size) { + assert (code_size == 32); + const uint64_t *a = (uint64_t *)a8; + a0 = a[0]; a1 = a[1]; a2 = a[2]; a3 = a[3]; + } + + inline float compute (const uint8_t *b8) const { + const uint64_t *b = (uint64_t *)b8; + int accu_num = popcount64 (b[0] & a0) + popcount64 (b[1] & a1) + + popcount64 (b[2] & a2) + popcount64 (b[3] & a3); + int accu_den = popcount64 (b[0] | a0) + popcount64 (b[1] | a1) + + popcount64 (b[2] | a2) + popcount64 (b[3] | a3); + if (accu_num == 0) + return 1.0; + return 1.0 - (float)(accu_num) / (float)(accu_den); + } + + }; + + struct JaccardComputer64 { + uint64_t a0, a1, a2, a3, a4, a5, a6, a7; + + JaccardComputer64 () {} + + JaccardComputer64 (const uint8_t *a8, int code_size) { + set (a8, code_size); + } + + void set (const uint8_t *a8, int code_size) { + assert (code_size == 64); + const uint64_t *a = (uint64_t *)a8; + a0 = a[0]; a1 = a[1]; a2 = a[2]; a3 = a[3]; + a4 = a[4]; a5 = a[5]; a6 = a[6]; a7 = a[7]; + } + + inline float compute (const uint8_t *b8) const { + const uint64_t *b = (uint64_t *)b8; + int accu_num = popcount64 (b[0] & a0) + popcount64 (b[1] & a1) + + popcount64 (b[2] & a2) + popcount64 (b[3] & a3) + + popcount64 (b[4] & a4) + popcount64 (b[5] & a5) + + popcount64 (b[6] & a6) + popcount64 (b[7] & a7); + int accu_den = popcount64 (b[0] | a0) + popcount64 (b[1] | a1) + + popcount64 (b[2] | a2) + popcount64 (b[3] | a3) + + popcount64 (b[4] | a4) + popcount64 (b[5] | a5) + + popcount64 (b[6] | a6) + popcount64 (b[7] | a7); + if (accu_num == 0) + return 1.0; + return 1.0 - (float)(accu_num) / (float)(accu_den); + } + + }; + + struct JaccardComputer128 { + uint64_t a0, a1, a2, a3, a4, a5, a6, a7, + a8, a9, a10, a11, a12, a13, a14, a15; + + JaccardComputer128 () {} + + JaccardComputer128 (const uint8_t *a8, int code_size) { + set (a8, code_size); + } + + void set (const uint8_t *au8, int code_size) { + assert (code_size == 128); + const uint64_t *a = (uint64_t *)au8; + a0 = a[0]; a1 = a[1]; a2 = a[2]; a3 = a[3]; + a4 = a[4]; a5 = a[5]; a6 = a[6]; a7 = a[7]; + a8 = a[8]; a9 = a[9]; a10 = a[10]; a11 = a[11]; + a12 = a[12]; a13 = a[13]; a14 = a[14]; a15 = a[15]; + } + + inline float compute (const uint8_t *b16) const { + const uint64_t *b = (uint64_t *)b16; + int accu_num = popcount64 (b[0] & a0) + popcount64 (b[1] & a1) + + popcount64 (b[2] & a2) + popcount64 (b[3] & a3) + + popcount64 (b[4] & a4) + popcount64 (b[5] & a5) + + popcount64 (b[6] & a6) + popcount64 (b[7] & a7) + + popcount64 (b[8] & a8) + popcount64 (b[9] & a9) + + popcount64 (b[10] & a10) + popcount64 (b[11] & a11) + + popcount64 (b[12] & a12) + popcount64 (b[13] & a13) + + popcount64 (b[14] & a14) + popcount64 (b[15] & a15); + int accu_den = popcount64 (b[0] | a0) + popcount64 (b[1] | a1) + + popcount64 (b[2] | a2) + popcount64 (b[3] | a3) + + popcount64 (b[4] | a4) + popcount64 (b[5] | a5) + + popcount64 (b[6] | a6) + popcount64 (b[7] | a7) + + popcount64 (b[8] | a8) + popcount64 (b[9] | a9) + + popcount64 (b[10] | a10) + popcount64 (b[11] | a11) + + popcount64 (b[12] | a12) + popcount64 (b[13] | a13) + + popcount64 (b[14] | a14) + popcount64 (b[15] | a15); + if (accu_num == 0) + return 1.0; + return 1.0 - (float)(accu_num) / (float)(accu_den); + } + + }; + +struct JaccardComputer256 { + uint64_t a0,a1,a2,a3,a4,a5,a6,a7, + a8,a9,a10,a11,a12,a13,a14,a15, + a16,a17,a18,a19,a20,a21,a22,a23, + a24,a25,a26,a27,a28,a29,a30,a31; + + JaccardComputer256 () {} + + JaccardComputer256 (const uint8_t *a8, int code_size) { + set (a8, code_size); + } + + void set (const uint8_t *au8, int code_size) { + assert (code_size == 256); + const uint64_t *a = (uint64_t *)au8; + a0 = a[0]; a1 = a[1]; a2 = a[2]; a3 = a[3]; + a4 = a[4]; a5 = a[5]; a6 = a[6]; a7 = a[7]; + a8 = a[8]; a9 = a[9]; a10 = a[10]; a11 = a[11]; + a12 = a[12]; a13 = a[13]; a14 = a[14]; a15 = a[15]; + a16 = a[16]; a17 = a[17]; a18 = a[18]; a19 = a[19]; + a20 = a[20]; a21 = a[21]; a22 = a[22]; a23 = a[23]; + a24 = a[24]; a25 = a[25]; a26 = a[26]; a27 = a[27]; + a28 = a[28]; a29 = a[29]; a30 = a[30]; a31 = a[31]; + } + + inline float compute (const uint8_t *b16) const { + const uint64_t *b = (uint64_t *)b16; + int accu_num = popcount64 (b[0] & a0) + popcount64 (b[1] & a1) + + popcount64 (b[2] & a2) + popcount64 (b[3] & a3) + + popcount64 (b[4] & a4) + popcount64 (b[5] & a5) + + popcount64 (b[6] & a6) + popcount64 (b[7] & a7) + + popcount64 (b[8] & a8) + popcount64 (b[9] & a9) + + popcount64 (b[10] & a10) + popcount64 (b[11] & a11) + + popcount64 (b[12] & a12) + popcount64 (b[13] & a13) + + popcount64 (b[14] & a14) + popcount64 (b[15] & a15) + + popcount64 (b[16] & a16) + popcount64 (b[17] & a17) + + popcount64 (b[18] & a18) + popcount64 (b[19] & a19) + + popcount64 (b[20] & a20) + popcount64 (b[21] & a21) + + popcount64 (b[22] & a22) + popcount64 (b[23] & a23) + + popcount64 (b[24] & a24) + popcount64 (b[25] & a25) + + popcount64 (b[26] & a26) + popcount64 (b[27] & a27) + + popcount64 (b[28] & a28) + popcount64 (b[29] & a29) + + popcount64 (b[30] & a30) + popcount64 (b[31] & a31); + int accu_den = popcount64 (b[0] | a0) + popcount64 (b[1] | a1) + + popcount64 (b[2] | a2) + popcount64 (b[3] | a3) + + popcount64 (b[4] | a4) + popcount64 (b[5] | a5) + + popcount64 (b[6] | a6) + popcount64 (b[7] | a7) + + popcount64 (b[8] | a8) + popcount64 (b[9] | a9) + + popcount64 (b[10] | a10) + popcount64 (b[11] | a11) + + popcount64 (b[12] | a12) + popcount64 (b[13] | a13) + + popcount64 (b[14] | a14) + popcount64 (b[15] | a15) + + popcount64 (b[16] | a16) + popcount64 (b[17] | a17) + + popcount64 (b[18] | a18) + popcount64 (b[19] | a19) + + popcount64 (b[20] | a20) + popcount64 (b[21] | a21) + + popcount64 (b[22] | a22) + popcount64 (b[23] | a23) + + popcount64 (b[24] | a24) + popcount64 (b[25] | a25) + + popcount64 (b[26] | a26) + popcount64 (b[27] | a27) + + popcount64 (b[28] | a28) + popcount64 (b[29] | a29) + + popcount64 (b[30] | a30) + popcount64 (b[31] | a31); + if (accu_num == 0) + return 1.0; + return 1.0 - (float)(accu_num) / (float)(accu_den); + } + + }; + + struct JaccardComputer512 { + uint64_t a0,a1,a2,a3,a4,a5,a6,a7, + a8,a9,a10,a11,a12,a13,a14,a15, + a16,a17,a18,a19,a20,a21,a22,a23, + a24,a25,a26,a27,a28,a29,a30,a31, + a32,a33,a34,a35,a36,a37,a38,a39, + a40,a41,a42,a43,a44,a45,a46,a47, + a48,a49,a50,a51,a52,a53,a54,a55, + a56,a57,a58,a59,a60,a61,a62,a63; + + JaccardComputer512 () {} + + JaccardComputer512 (const uint8_t *a8, int code_size) { + set (a8, code_size); + } + + void set (const uint8_t *au8, int code_size) { + assert (code_size == 512); + const uint64_t *a = (uint64_t *)au8; + a0 = a[0]; a1 = a[1]; a2 = a[2]; a3 = a[3]; + a4 = a[4]; a5 = a[5]; a6 = a[6]; a7 = a[7]; + a8 = a[8]; a9 = a[9]; a10 = a[10]; a11 = a[11]; + a12 = a[12]; a13 = a[13]; a14 = a[14]; a15 = a[15]; + a16 = a[16]; a17 = a[17]; a18 = a[18]; a19 = a[19]; + a20 = a[20]; a21 = a[21]; a22 = a[22]; a23 = a[23]; + a24 = a[24]; a25 = a[25]; a26 = a[26]; a27 = a[27]; + a28 = a[28]; a29 = a[29]; a30 = a[30]; a31 = a[31]; + a32 = a[32]; a33 = a[33]; a34 = a[34]; a35 = a[35]; + a36 = a[36]; a37 = a[37]; a38 = a[38]; a39 = a[39]; + a40 = a[40]; a41 = a[41]; a42 = a[42]; a43 = a[43]; + a44 = a[44]; a45 = a[45]; a46 = a[46]; a47 = a[47]; + a48 = a[48]; a49 = a[49]; a50 = a[50]; a51 = a[51]; + a52 = a[52]; a53 = a[53]; a54 = a[54]; a55 = a[55]; + a56 = a[56]; a57 = a[57]; a58 = a[58]; a59 = a[59]; + a60 = a[60]; a61 = a[61]; a62 = a[62]; a63 = a[63]; + } + + inline float compute (const uint8_t *b16) const { + const uint64_t *b = (uint64_t *)b16; + int accu_num = popcount64 (b[0] & a0) + popcount64 (b[1] & a1) + + popcount64 (b[2] & a2) + popcount64 (b[3] & a3) + + popcount64 (b[4] & a4) + popcount64 (b[5] & a5) + + popcount64 (b[6] & a6) + popcount64 (b[7] & a7) + + popcount64 (b[8] & a8) + popcount64 (b[9] & a9) + + popcount64 (b[10] & a10) + popcount64 (b[11] & a11) + + popcount64 (b[12] & a12) + popcount64 (b[13] & a13) + + popcount64 (b[14] & a14) + popcount64 (b[15] & a15) + + popcount64 (b[16] & a16) + popcount64 (b[17] & a17) + + popcount64 (b[18] & a18) + popcount64 (b[19] & a19) + + popcount64 (b[20] & a20) + popcount64 (b[21] & a21) + + popcount64 (b[22] & a22) + popcount64 (b[23] & a23) + + popcount64 (b[24] & a24) + popcount64 (b[25] & a25) + + popcount64 (b[26] & a26) + popcount64 (b[27] & a27) + + popcount64 (b[28] & a28) + popcount64 (b[29] & a29) + + popcount64 (b[30] & a30) + popcount64 (b[31] & a31) + + popcount64 (b[32] & a32) + popcount64 (b[33] & a33) + + popcount64 (b[34] & a34) + popcount64 (b[35] & a35) + + popcount64 (b[36] & a36) + popcount64 (b[37] & a37) + + popcount64 (b[38] & a38) + popcount64 (b[39] & a39) + + popcount64 (b[40] & a40) + popcount64 (b[41] & a41) + + popcount64 (b[42] & a42) + popcount64 (b[43] & a43) + + popcount64 (b[44] & a44) + popcount64 (b[45] & a45) + + popcount64 (b[46] & a46) + popcount64 (b[47] & a47) + + popcount64 (b[48] & a48) + popcount64 (b[49] & a49) + + popcount64 (b[50] & a50) + popcount64 (b[51] & a51) + + popcount64 (b[52] & a52) + popcount64 (b[53] & a53) + + popcount64 (b[54] & a54) + popcount64 (b[55] & a55) + + popcount64 (b[56] & a56) + popcount64 (b[57] & a57) + + popcount64 (b[58] & a58) + popcount64 (b[59] & a59) + + popcount64 (b[60] & a60) + popcount64 (b[61] & a61) + + popcount64 (b[62] & a62) + popcount64 (b[63] & a63); + int accu_den = popcount64 (b[0] | a0) + popcount64 (b[1] | a1) + + popcount64 (b[2] | a2) + popcount64 (b[3] | a3) + + popcount64 (b[4] | a4) + popcount64 (b[5] | a5) + + popcount64 (b[6] | a6) + popcount64 (b[7] | a7) + + popcount64 (b[8] | a8) + popcount64 (b[9] | a9) + + popcount64 (b[10] | a10) + popcount64 (b[11] | a11) + + popcount64 (b[12] | a12) + popcount64 (b[13] | a13) + + popcount64 (b[14] | a14) + popcount64 (b[15] | a15) + + popcount64 (b[16] | a16) + popcount64 (b[17] | a17) + + popcount64 (b[18] | a18) + popcount64 (b[19] | a19) + + popcount64 (b[20] | a20) + popcount64 (b[21] | a21) + + popcount64 (b[22] | a22) + popcount64 (b[23] | a23) + + popcount64 (b[24] | a24) + popcount64 (b[25] | a25) + + popcount64 (b[26] | a26) + popcount64 (b[27] | a27) + + popcount64 (b[28] | a28) + popcount64 (b[29] | a29) + + popcount64 (b[30] | a30) + popcount64 (b[31] | a31) + + popcount64 (b[32] | a32) + popcount64 (b[33] | a33) + + popcount64 (b[34] | a34) + popcount64 (b[35] | a35) + + popcount64 (b[36] | a36) + popcount64 (b[37] | a37) + + popcount64 (b[38] | a38) + popcount64 (b[39] | a39) + + popcount64 (b[40] | a40) + popcount64 (b[41] | a41) + + popcount64 (b[42] | a42) + popcount64 (b[43] | a43) + + popcount64 (b[44] | a44) + popcount64 (b[45] | a45) + + popcount64 (b[46] | a46) + popcount64 (b[47] | a47) + + popcount64 (b[48] | a48) + popcount64 (b[49] | a49) + + popcount64 (b[50] | a50) + popcount64 (b[51] | a51) + + popcount64 (b[52] | a52) + popcount64 (b[53] | a53) + + popcount64 (b[54] | a54) + popcount64 (b[55] | a55) + + popcount64 (b[56] | a56) + popcount64 (b[57] | a57) + + popcount64 (b[58] | a58) + popcount64 (b[59] | a59) + + popcount64 (b[60] | a60) + popcount64 (b[61] | a61) + + popcount64 (b[62] | a62) + popcount64 (b[63] | a63); + if (accu_num == 0) + return 1.0; + return 1.0 - (float)(accu_num) / (float)(accu_den); + } + + }; + + struct JaccardComputerDefault { + const uint8_t *a; + int n; + + JaccardComputerDefault () {} + + JaccardComputerDefault (const uint8_t *a8, int code_size) { + set (a8, code_size); + } + + void set (const uint8_t *a8, int code_size) { + a = a8; + n = code_size; + } + + float compute (const uint8_t *b8) const { + int accu_num = 0; + int accu_den = 0; + for (int i = 0; i < n; i++) { + accu_num += popcount64(a[i] & b8[i]); + accu_den += popcount64(a[i] | b8[i]); + } + if (accu_num == 0) + return 1.0; + return 1.0 - (float)(accu_num) / (float)(accu_den); + } + + }; + +// default template + template + struct JaccardComputer: JaccardComputerDefault { + JaccardComputer (const uint8_t *a, int code_size): + JaccardComputerDefault(a, code_size) {} + }; + +#define SPECIALIZED_HC(CODE_SIZE) \ + template<> struct JaccardComputer: \ + JaccardComputer ## CODE_SIZE { \ + JaccardComputer (const uint8_t *a): \ + JaccardComputer ## CODE_SIZE(a, CODE_SIZE) {} \ + } + + SPECIALIZED_HC(8); + SPECIALIZED_HC(16); + SPECIALIZED_HC(32); + SPECIALIZED_HC(64); + SPECIALIZED_HC(128); + SPECIALIZED_HC(256); + SPECIALIZED_HC(512); + +#undef SPECIALIZED_HC + +} diff --git a/core/src/index/thirdparty/faiss/utils/random.cpp b/core/src/index/thirdparty/faiss/utils/random.cpp new file mode 100644 index 0000000000..7f50e0eb1c --- /dev/null +++ b/core/src/index/thirdparty/faiss/utils/random.cpp @@ -0,0 +1,192 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +#include + +namespace faiss { + +/************************************************** + * Random data generation functions + **************************************************/ + +RandomGenerator::RandomGenerator (int64_t seed) + : mt((unsigned int)seed) {} + +int RandomGenerator::rand_int () +{ + return mt() & 0x7fffffff; +} + +int64_t RandomGenerator::rand_int64 () +{ + return int64_t(rand_int()) | int64_t(rand_int()) << 31; +} + +int RandomGenerator::rand_int (int max) +{ + return mt() % max; +} + +float RandomGenerator::rand_float () +{ + return mt() / float(mt.max()); +} + +double RandomGenerator::rand_double () +{ + return mt() / double(mt.max()); +} + + +/*********************************************************************** + * Random functions in this C file only exist because Torch + * counterparts are slow and not multi-threaded. Typical use is for + * more than 1-100 billion values. */ + + +/* Generate a set of random floating point values such that x[i] in [0,1] + multi-threading. For this reason, we rely on re-entreant functions. */ +void float_rand (float * x, size_t n, int64_t seed) +{ + // only try to parallelize on large enough arrays + const size_t nblock = n < 1024 ? 1 : 1024; + + RandomGenerator rng0 (seed); + int a0 = rng0.rand_int (), b0 = rng0.rand_int (); + +#pragma omp parallel for + for (size_t j = 0; j < nblock; j++) { + + RandomGenerator rng (a0 + j * b0); + + const size_t istart = j * n / nblock; + const size_t iend = (j + 1) * n / nblock; + + for (size_t i = istart; i < iend; i++) + x[i] = rng.rand_float (); + } +} + + +void float_randn (float * x, size_t n, int64_t seed) +{ + // only try to parallelize on large enough arrays + const size_t nblock = n < 1024 ? 1 : 1024; + + RandomGenerator rng0 (seed); + int a0 = rng0.rand_int (), b0 = rng0.rand_int (); + +#pragma omp parallel for + for (size_t j = 0; j < nblock; j++) { + RandomGenerator rng (a0 + j * b0); + + double a = 0, b = 0, s = 0; + int state = 0; /* generate two number per "do-while" loop */ + + const size_t istart = j * n / nblock; + const size_t iend = (j + 1) * n / nblock; + + for (size_t i = istart; i < iend; i++) { + /* Marsaglia's method (see Knuth) */ + if (state == 0) { + do { + a = 2.0 * rng.rand_double () - 1; + b = 2.0 * rng.rand_double () - 1; + s = a * a + b * b; + } while (s >= 1.0); + x[i] = a * sqrt(-2.0 * log(s) / s); + } + else + x[i] = b * sqrt(-2.0 * log(s) / s); + state = 1 - state; + } + } +} + + +/* Integer versions */ +void int64_rand (int64_t * x, size_t n, int64_t seed) +{ + // only try to parallelize on large enough arrays + const size_t nblock = n < 1024 ? 1 : 1024; + + RandomGenerator rng0 (seed); + int a0 = rng0.rand_int (), b0 = rng0.rand_int (); + +#pragma omp parallel for + for (size_t j = 0; j < nblock; j++) { + + RandomGenerator rng (a0 + j * b0); + + const size_t istart = j * n / nblock; + const size_t iend = (j + 1) * n / nblock; + for (size_t i = istart; i < iend; i++) + x[i] = rng.rand_int64 (); + } +} + +void int64_rand_max (int64_t * x, size_t n, uint64_t max, int64_t seed) +{ + // only try to parallelize on large enough arrays + const size_t nblock = n < 1024 ? 1 : 1024; + + RandomGenerator rng0 (seed); + int a0 = rng0.rand_int (), b0 = rng0.rand_int (); + +#pragma omp parallel for + for (size_t j = 0; j < nblock; j++) { + + RandomGenerator rng (a0 + j * b0); + + const size_t istart = j * n / nblock; + const size_t iend = (j + 1) * n / nblock; + for (size_t i = istart; i < iend; i++) + x[i] = rng.rand_int64 () % max; + } +} + + +void rand_perm (int *perm, size_t n, int64_t seed) +{ + for (size_t i = 0; i < n; i++) perm[i] = i; + + RandomGenerator rng (seed); + + for (size_t i = 0; i + 1 < n; i++) { + int i2 = i + rng.rand_int (n - i); + std::swap(perm[i], perm[i2]); + } +} + + + + +void byte_rand (uint8_t * x, size_t n, int64_t seed) +{ + // only try to parallelize on large enough arrays + const size_t nblock = n < 1024 ? 1 : 1024; + + RandomGenerator rng0 (seed); + int a0 = rng0.rand_int (), b0 = rng0.rand_int (); + +#pragma omp parallel for + for (size_t j = 0; j < nblock; j++) { + + RandomGenerator rng (a0 + j * b0); + + const size_t istart = j * n / nblock; + const size_t iend = (j + 1) * n / nblock; + + size_t i; + for (i = istart; i < iend; i++) + x[i] = rng.rand_int64 (); + } +} + +} // namespace faiss diff --git a/core/src/index/thirdparty/faiss/utils/random.h b/core/src/index/thirdparty/faiss/utils/random.h new file mode 100644 index 0000000000..e94ac068cf --- /dev/null +++ b/core/src/index/thirdparty/faiss/utils/random.h @@ -0,0 +1,60 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +/* Random generators. Implemented here for speed and to make + * sequences reproducible. + */ + +#pragma once + +#include +#include + + +namespace faiss { + +/************************************************** + * Random data generation functions + **************************************************/ + +/// random generator that can be used in multithreaded contexts +struct RandomGenerator { + + std::mt19937 mt; + + /// random positive integer + int rand_int (); + + /// random int64_t + int64_t rand_int64 (); + + /// generate random integer between 0 and max-1 + int rand_int (int max); + + /// between 0 and 1 + float rand_float (); + + double rand_double (); + + explicit RandomGenerator (int64_t seed = 1234); +}; + +/* Generate an array of uniform random floats / multi-threaded implementation */ +void float_rand (float * x, size_t n, int64_t seed); +void float_randn (float * x, size_t n, int64_t seed); +void int64_rand (int64_t * x, size_t n, int64_t seed); +void byte_rand (uint8_t * x, size_t n, int64_t seed); +// max is actually the maximum value + 1 +void int64_rand_max (int64_t * x, size_t n, uint64_t max, int64_t seed); + +/* random permutation */ +void rand_perm (int * perm, size_t n, int64_t seed); + + +} // namespace faiss diff --git a/core/src/index/thirdparty/faiss/utils/substructure-inl.h b/core/src/index/thirdparty/faiss/utils/substructure-inl.h new file mode 100644 index 0000000000..aa57a5a646 --- /dev/null +++ b/core/src/index/thirdparty/faiss/utils/substructure-inl.h @@ -0,0 +1,302 @@ +namespace faiss { + + struct SubstructureComputer8 { + uint64_t a0; + + SubstructureComputer8 () {} + + SubstructureComputer8 (const uint8_t *a8, int code_size) { + set (a8, code_size); + } + + void set (const uint8_t *a8, int code_size) { + assert (code_size == 8); + const uint64_t *a = (uint64_t *)a8; + a0 = a[0]; + } + + inline bool compute (const uint8_t *b8) const { + const uint64_t *b = (uint64_t *)b8; + return (a0 & b[0]) == a0; + } + + }; + + struct SubstructureComputer16 { + uint64_t a0, a1; + + SubstructureComputer16 () {} + + SubstructureComputer16 (const uint8_t *a8, int code_size) { + set (a8, code_size); + } + + void set (const uint8_t *a8, int code_size) { + assert (code_size == 16); + const uint64_t *a = (uint64_t *)a8; + a0 = a[0]; a1 = a[1]; + } + + inline bool compute (const uint8_t *b8) const { + const uint64_t *b = (uint64_t *)b8; + return (a0 & b[0]) == a0 && (a1 & b[1]) == a1; + } + + }; + + struct SubstructureComputer32 { + uint64_t a0, a1, a2, a3; + + SubstructureComputer32 () {} + + SubstructureComputer32 (const uint8_t *a8, int code_size) { + set (a8, code_size); + } + + void set (const uint8_t *a8, int code_size) { + assert (code_size == 32); + const uint64_t *a = (uint64_t *)a8; + a0 = a[0]; a1 = a[1]; a2 = a[2]; a3 = a[3]; + } + + inline bool compute (const uint8_t *b8) const { + const uint64_t *b = (uint64_t *)b8; + return (a0 & b[0]) == a0 && (a1 & b[1]) == a1 && + (a2 & b[2]) == a2 && (a3 & b[3]) == a3; + } + + }; + + struct SubstructureComputer64 { + uint64_t a0, a1, a2, a3, a4, a5, a6, a7; + + SubstructureComputer64 () {} + + SubstructureComputer64 (const uint8_t *a8, int code_size) { + set (a8, code_size); + } + + void set (const uint8_t *a8, int code_size) { + assert (code_size == 64); + const uint64_t *a = (uint64_t *)a8; + a0 = a[0]; a1 = a[1]; a2 = a[2]; a3 = a[3]; + a4 = a[4]; a5 = a[5]; a6 = a[6]; a7 = a[7]; + } + + inline bool compute (const uint8_t *b8) const { + const uint64_t *b = (uint64_t *)b8; + return (a0 & b[0]) == a0 && (a1 & b[1]) == a1 && + (a2 & b[2]) == a2 && (a3 & b[3]) == a3 && + (a4 & b[4]) == a4 && (a5 & b[5]) == a5 && + (a6 & b[6]) == a6 && (a7 & b[7]) == a7; + } + + }; + + struct SubstructureComputer128 { + uint64_t a0, a1, a2, a3, a4, a5, a6, a7, + a8, a9, a10, a11, a12, a13, a14, a15; + + SubstructureComputer128 () {} + + SubstructureComputer128 (const uint8_t *a8, int code_size) { + set (a8, code_size); + } + + void set (const uint8_t *au8, int code_size) { + assert (code_size == 128); + const uint64_t *a = (uint64_t *)au8; + a0 = a[0]; a1 = a[1]; a2 = a[2]; a3 = a[3]; + a4 = a[4]; a5 = a[5]; a6 = a[6]; a7 = a[7]; + a8 = a[8]; a9 = a[9]; a10 = a[10]; a11 = a[11]; + a12 = a[12]; a13 = a[13]; a14 = a[14]; a15 = a[15]; + } + + inline bool compute (const uint8_t *b16) const { + const uint64_t *b = (uint64_t *)b16; + return (a0 & b[0]) == a0 && (a1 & b[1]) == a1 && + (a2 & b[2]) == a2 && (a3 & b[3]) == a3 && + (a4 & b[4]) == a4 && (a5 & b[5]) == a5 && + (a6 & b[6]) == a6 && (a7 & b[7]) == a7 && + (a8 & b[8]) == a8 && (a9 & b[9]) == a9 && + (a10 & b[10]) == a10 && (a11 & b[11]) == a11 && + (a12 & b[12]) == a12 && (a13 & b[13]) == a13 && + (a14 & b[14]) == a14 && (a15 & b[15]) == a15; + } + + }; + + struct SubstructureComputer256 { + uint64_t a0,a1,a2,a3,a4,a5,a6,a7, + a8,a9,a10,a11,a12,a13,a14,a15, + a16,a17,a18,a19,a20,a21,a22,a23, + a24,a25,a26,a27,a28,a29,a30,a31; + + SubstructureComputer256 () {} + + SubstructureComputer256 (const uint8_t *a8, int code_size) { + set (a8, code_size); + } + + void set (const uint8_t *au8, int code_size) { + assert (code_size == 256); + const uint64_t *a = (uint64_t *)au8; + a0 = a[0]; a1 = a[1]; a2 = a[2]; a3 = a[3]; + a4 = a[4]; a5 = a[5]; a6 = a[6]; a7 = a[7]; + a8 = a[8]; a9 = a[9]; a10 = a[10]; a11 = a[11]; + a12 = a[12]; a13 = a[13]; a14 = a[14]; a15 = a[15]; + a16 = a[16]; a17 = a[17]; a18 = a[18]; a19 = a[19]; + a20 = a[20]; a21 = a[21]; a22 = a[22]; a23 = a[23]; + a24 = a[24]; a25 = a[25]; a26 = a[26]; a27 = a[27]; + a28 = a[28]; a29 = a[29]; a30 = a[30]; a31 = a[31]; + } + + inline bool compute (const uint8_t *b16) const { + const uint64_t *b = (uint64_t *)b16; + return (a0 & b[0]) == a0 && (a1 & b[1]) == a1 && + (a2 & b[2]) == a2 && (a3 & b[3]) == a3 && + (a4 & b[4]) == a4 && (a5 & b[5]) == a5 && + (a6 & b[6]) == a6 && (a7 & b[7]) == a7 && + (a8 & b[8]) == a8 && (a9 & b[9]) == a9 && + (a10 & b[10]) == a10 && (a11 & b[11]) == a11 && + (a12 & b[12]) == a12 && (a13 & b[13]) == a13 && + (a14 & b[14]) == a14 && (a15 & b[15]) == a15 && + (a16 & b[16]) == a16 && (a17 & b[17]) == a17 && + (a18 & b[18]) == a18 && (a19 & b[19]) == a19 && + (a20 & b[20]) == a20 && (a21 & b[21]) == a21 && + (a22 & b[22]) == a22 && (a23 & b[23]) == a23 && + (a24 & b[24]) == a24 && (a25 & b[25]) == a25 && + (a26 & b[26]) == a26 && (a27 & b[27]) == a27 && + (a28 & b[28]) == a28 && (a29 & b[29]) == a29 && + (a30 & b[30]) == a30 && (a31 & b[31]) == a31; + } + + }; + + struct SubstructureComputer512 { + uint64_t a0,a1,a2,a3,a4,a5,a6,a7, + a8,a9,a10,a11,a12,a13,a14,a15, + a16,a17,a18,a19,a20,a21,a22,a23, + a24,a25,a26,a27,a28,a29,a30,a31, + a32,a33,a34,a35,a36,a37,a38,a39, + a40,a41,a42,a43,a44,a45,a46,a47, + a48,a49,a50,a51,a52,a53,a54,a55, + a56,a57,a58,a59,a60,a61,a62,a63; + + SubstructureComputer512 () {} + + SubstructureComputer512 (const uint8_t *a8, int code_size) { + set (a8, code_size); + } + + void set (const uint8_t *au8, int code_size) { + assert (code_size == 512); + const uint64_t *a = (uint64_t *)au8; + a0 = a[0]; a1 = a[1]; a2 = a[2]; a3 = a[3]; + a4 = a[4]; a5 = a[5]; a6 = a[6]; a7 = a[7]; + a8 = a[8]; a9 = a[9]; a10 = a[10]; a11 = a[11]; + a12 = a[12]; a13 = a[13]; a14 = a[14]; a15 = a[15]; + a16 = a[16]; a17 = a[17]; a18 = a[18]; a19 = a[19]; + a20 = a[20]; a21 = a[21]; a22 = a[22]; a23 = a[23]; + a24 = a[24]; a25 = a[25]; a26 = a[26]; a27 = a[27]; + a28 = a[28]; a29 = a[29]; a30 = a[30]; a31 = a[31]; + a32 = a[32]; a33 = a[33]; a34 = a[34]; a35 = a[35]; + a36 = a[36]; a37 = a[37]; a38 = a[38]; a39 = a[39]; + a40 = a[40]; a41 = a[41]; a42 = a[42]; a43 = a[43]; + a44 = a[44]; a45 = a[45]; a46 = a[46]; a47 = a[47]; + a48 = a[48]; a49 = a[49]; a50 = a[50]; a51 = a[51]; + a52 = a[52]; a53 = a[53]; a54 = a[54]; a55 = a[55]; + a56 = a[56]; a57 = a[57]; a58 = a[58]; a59 = a[59]; + a60 = a[60]; a61 = a[61]; a62 = a[62]; a63 = a[63]; + } + + inline bool compute (const uint8_t *b16) const { + const uint64_t *b = (uint64_t *)b16; + return (a0 & b[0]) == a0 && (a1 & b[1]) == a1 && + (a2 & b[2]) == a2 && (a3 & b[3]) == a3 && + (a4 & b[4]) == a4 && (a5 & b[5]) == a5 && + (a6 & b[6]) == a6 && (a7 & b[7]) == a7 && + (a8 & b[8]) == a8 && (a9 & b[9]) == a9 && + (a10 & b[10]) == a10 && (a11 & b[11]) == a11 && + (a12 & b[12]) == a12 && (a13 & b[13]) == a13 && + (a14 & b[14]) == a14 && (a15 & b[15]) == a15 && + (a16 & b[16]) == a16 && (a17 & b[17]) == a17 && + (a18 & b[18]) == a18 && (a19 & b[19]) == a19 && + (a20 & b[20]) == a20 && (a21 & b[21]) == a21 && + (a22 & b[22]) == a22 && (a23 & b[23]) == a23 && + (a24 & b[24]) == a24 && (a25 & b[25]) == a25 && + (a26 & b[26]) == a26 && (a27 & b[27]) == a27 && + (a28 & b[28]) == a28 && (a29 & b[29]) == a29 && + (a30 & b[30]) == a30 && (a31 & b[31]) == a31 && + (a32 & b[32]) == a32 && (a33 & b[33]) == a33 && + (a34 & b[34]) == a34 && (a35 & b[35]) == a35 && + (a36 & b[36]) == a36 && (a37 & b[37]) == a37 && + (a38 & b[38]) == a38 && (a39 & b[39]) == a39 && + (a40 & b[40]) == a40 && (a41 & b[41]) == a41 && + (a42 & b[42]) == a42 && (a43 & b[43]) == a43 && + (a44 & b[44]) == a44 && (a45 & b[45]) == a45 && + (a46 & b[46]) == a46 && (a47 & b[47]) == a47 && + (a48 & b[48]) == a48 && (a49 & b[49]) == a49 && + (a50 & b[50]) == a50 && (a51 & b[51]) == a51 && + (a52 & b[52]) == a52 && (a53 & b[53]) == a53 && + (a54 & b[54]) == a54 && (a55 & b[55]) == a55 && + (a56 & b[56]) == a56 && (a57 & b[57]) == a57 && + (a58 & b[58]) == a58 && (a59 & b[59]) == a59 && + (a60 & b[60]) == a60 && (a61 & b[61]) == a61 && + (a62 & b[62]) == a62 && (a63 & b[63]) == a63; + } + + }; + + struct SubstructureComputerDefault { + const uint8_t *a; + int n; + + SubstructureComputerDefault () {} + + SubstructureComputerDefault (const uint8_t *a8, int code_size) { + set (a8, code_size); + } + + void set (const uint8_t *a8, int code_size) { + a = a8; + n = code_size; + } + + bool compute (const uint8_t *b8) const { + const uint64_t *b = (uint64_t *)b8; + for (int i = 0; i < n; i++) { + if ((a[i] & b[i]) != a[i]) { + return false; + } + } + return true; + } + + }; + +// default template + template + struct SubstructureComputer: SubstructureComputerDefault { + SubstructureComputer (const uint8_t *a, int code_size): + SubstructureComputerDefault(a, code_size) {} + }; + +#define SPECIALIZED_HC(CODE_SIZE) \ + template<> struct SubstructureComputer: \ + SubstructureComputer ## CODE_SIZE { \ + SubstructureComputer (const uint8_t *a): \ + SubstructureComputer ## CODE_SIZE(a, CODE_SIZE) {} \ + } + + SPECIALIZED_HC(8); + SPECIALIZED_HC(16); + SPECIALIZED_HC(32); + SPECIALIZED_HC(64); + SPECIALIZED_HC(128); + SPECIALIZED_HC(256); + SPECIALIZED_HC(512); + +#undef SPECIALIZED_HC + +} diff --git a/core/src/index/thirdparty/faiss/utils/superstructure-inl.h b/core/src/index/thirdparty/faiss/utils/superstructure-inl.h new file mode 100644 index 0000000000..e8b384e75f --- /dev/null +++ b/core/src/index/thirdparty/faiss/utils/superstructure-inl.h @@ -0,0 +1,302 @@ +namespace faiss { + + struct SuperstructureComputer8 { + uint64_t a0; + + SuperstructureComputer8 () {} + + SuperstructureComputer8 (const uint8_t *a8, int code_size) { + set (a8, code_size); + } + + void set (const uint8_t *a8, int code_size) { + assert (code_size == 8); + const uint64_t *a = (uint64_t *)a8; + a0 = a[0]; + } + + inline bool compute (const uint8_t *b8) const { + const uint64_t *b = (uint64_t *)b8; + return (a0 & b[0]) == b[0]; + } + + }; + + struct SuperstructureComputer16 { + uint64_t a0, a1; + + SuperstructureComputer16 () {} + + SuperstructureComputer16 (const uint8_t *a8, int code_size) { + set (a8, code_size); + } + + void set (const uint8_t *a8, int code_size) { + assert (code_size == 16); + const uint64_t *a = (uint64_t *)a8; + a0 = a[0]; a1 = a[1]; + } + + inline bool compute (const uint8_t *b8) const { + const uint64_t *b = (uint64_t *)b8; + return (a0 & b[0]) == b[0] && (a1 & b[1]) == b[1]; + } + + }; + + struct SuperstructureComputer32 { + uint64_t a0, a1, a2, a3; + + SuperstructureComputer32 () {} + + SuperstructureComputer32 (const uint8_t *a8, int code_size) { + set (a8, code_size); + } + + void set (const uint8_t *a8, int code_size) { + assert (code_size == 32); + const uint64_t *a = (uint64_t *)a8; + a0 = a[0]; a1 = a[1]; a2 = a[2]; a3 = a[3]; + } + + inline bool compute (const uint8_t *b8) const { + const uint64_t *b = (uint64_t *)b8; + return (a0 & b[0]) == b[0] && (a1 & b[1]) == b[1] && + (a2 & b[2]) == b[2] && (a3 & b[3]) == b[3]; + } + + }; + + struct SuperstructureComputer64 { + uint64_t a0, a1, a2, a3, a4, a5, a6, a7; + + SuperstructureComputer64 () {} + + SuperstructureComputer64 (const uint8_t *a8, int code_size) { + set (a8, code_size); + } + + void set (const uint8_t *a8, int code_size) { + assert (code_size == 64); + const uint64_t *a = (uint64_t *)a8; + a0 = a[0]; a1 = a[1]; a2 = a[2]; a3 = a[3]; + a4 = a[4]; a5 = a[5]; a6 = a[6]; a7 = a[7]; + } + + inline bool compute (const uint8_t *b8) const { + const uint64_t *b = (uint64_t *)b8; + return (a0 & b[0]) == b[0] && (a1 & b[1]) == b[1] && + (a2 & b[2]) == b[2] && (a3 & b[3]) == b[3] && + (a4 & b[4]) == b[4] && (a5 & b[5]) == b[5] && + (a6 & b[6]) == b[6] && (a7 & b[7]) == b[7]; + } + + }; + + struct SuperstructureComputer128 { + uint64_t a0, a1, a2, a3, a4, a5, a6, a7, + a8, a9, a10, a11, a12, a13, a14, a15; + + SuperstructureComputer128 () {} + + SuperstructureComputer128 (const uint8_t *a8, int code_size) { + set (a8, code_size); + } + + void set (const uint8_t *au8, int code_size) { + assert (code_size == 128); + const uint64_t *a = (uint64_t *)au8; + a0 = a[0]; a1 = a[1]; a2 = a[2]; a3 = a[3]; + a4 = a[4]; a5 = a[5]; a6 = a[6]; a7 = a[7]; + a8 = a[8]; a9 = a[9]; a10 = a[10]; a11 = a[11]; + a12 = a[12]; a13 = a[13]; a14 = a[14]; a15 = a[15]; + } + + inline float compute (const uint8_t *b8) const { + const uint64_t *b = (uint64_t *)b8; + return (a0 & b[0]) == b[0] && (a1 & b[1]) == b[1] && + (a2 & b[2]) == b[2] && (a3 & b[3]) == b[3] && + (a4 & b[4]) == b[4] && (a5 & b[5]) == b[5] && + (a6 & b[6]) == b[6] && (a7 & b[7]) == b[7] && + (a8 & b[8]) == b[8] && (a9 & b[9]) == b[9] && + (a10 & b[10]) == b[10] && (a11 & b[11]) == b[11] && + (a12 & b[12]) == b[12] && (a13 & b[13]) == b[13] && + (a14 & b[14]) == b[14] && (a15 & b[15]) == b[15]; + } + + }; + + struct SuperstructureComputer256 { + uint64_t a0,a1,a2,a3,a4,a5,a6,a7, + a8,a9,a10,a11,a12,a13,a14,a15, + a16,a17,a18,a19,a20,a21,a22,a23, + a24,a25,a26,a27,a28,a29,a30,a31; + + SuperstructureComputer256 () {} + + SuperstructureComputer256 (const uint8_t *a8, int code_size) { + set (a8, code_size); + } + + void set (const uint8_t *au8, int code_size) { + assert (code_size == 256); + const uint64_t *a = (uint64_t *)au8; + a0 = a[0]; a1 = a[1]; a2 = a[2]; a3 = a[3]; + a4 = a[4]; a5 = a[5]; a6 = a[6]; a7 = a[7]; + a8 = a[8]; a9 = a[9]; a10 = a[10]; a11 = a[11]; + a12 = a[12]; a13 = a[13]; a14 = a[14]; a15 = a[15]; + a16 = a[16]; a17 = a[17]; a18 = a[18]; a19 = a[19]; + a20 = a[20]; a21 = a[21]; a22 = a[22]; a23 = a[23]; + a24 = a[24]; a25 = a[25]; a26 = a[26]; a27 = a[27]; + a28 = a[28]; a29 = a[29]; a30 = a[30]; a31 = a[31]; + } + + inline float compute (const uint8_t *b8) const { + const uint64_t *b = (uint64_t *)b8; + return (a0 & b[0]) == b[0] && (a1 & b[1]) == b[1] && + (a2 & b[2]) == b[2] && (a3 & b[3]) == b[3] && + (a4 & b[4]) == b[4] && (a5 & b[5]) == b[5] && + (a6 & b[6]) == b[6] && (a7 & b[7]) == b[7] && + (a8 & b[8]) == b[8] && (a9 & b[9]) == b[9] && + (a10 & b[10]) == b[10] && (a11 & b[11]) == b[11] && + (a12 & b[12]) == b[12] && (a13 & b[13]) == b[13] && + (a14 & b[14]) == b[14] && (a15 & b[15]) == b[15] && + (a16 & b[16]) == b[16] && (a17 & b[17]) == b[17] && + (a18 & b[18]) == b[18] && (a19 & b[19]) == b[19] && + (a20 & b[20]) == b[20] && (a21 & b[21]) == b[21] && + (a22 & b[22]) == b[22] && (a23 & b[23]) == b[23] && + (a24 & b[24]) == b[24] && (a25 & b[25]) == b[25] && + (a26 & b[26]) == b[26] && (a27 & b[27]) == b[27] && + (a28 & b[28]) == b[28] && (a29 & b[29]) == b[29] && + (a30 & b[30]) == b[30] && (a31 & b[31]) == b[31]; + } + + }; + + struct SuperstructureComputer512 { + uint64_t a0,a1,a2,a3,a4,a5,a6,a7, + a8,a9,a10,a11,a12,a13,a14,a15, + a16,a17,a18,a19,a20,a21,a22,a23, + a24,a25,a26,a27,a28,a29,a30,a31, + a32,a33,a34,a35,a36,a37,a38,a39, + a40,a41,a42,a43,a44,a45,a46,a47, + a48,a49,a50,a51,a52,a53,a54,a55, + a56,a57,a58,a59,a60,a61,a62,a63; + + SuperstructureComputer512 () {} + + SuperstructureComputer512 (const uint8_t *a8, int code_size) { + set (a8, code_size); + } + + void set (const uint8_t *au8, int code_size) { + assert (code_size == 512); + const uint64_t *a = (uint64_t *)au8; + a0 = a[0]; a1 = a[1]; a2 = a[2]; a3 = a[3]; + a4 = a[4]; a5 = a[5]; a6 = a[6]; a7 = a[7]; + a8 = a[8]; a9 = a[9]; a10 = a[10]; a11 = a[11]; + a12 = a[12]; a13 = a[13]; a14 = a[14]; a15 = a[15]; + a16 = a[16]; a17 = a[17]; a18 = a[18]; a19 = a[19]; + a20 = a[20]; a21 = a[21]; a22 = a[22]; a23 = a[23]; + a24 = a[24]; a25 = a[25]; a26 = a[26]; a27 = a[27]; + a28 = a[28]; a29 = a[29]; a30 = a[30]; a31 = a[31]; + a32 = a[32]; a33 = a[33]; a34 = a[34]; a35 = a[35]; + a36 = a[36]; a37 = a[37]; a38 = a[38]; a39 = a[39]; + a40 = a[40]; a41 = a[41]; a42 = a[42]; a43 = a[43]; + a44 = a[44]; a45 = a[45]; a46 = a[46]; a47 = a[47]; + a48 = a[48]; a49 = a[49]; a50 = a[50]; a51 = a[51]; + a52 = a[52]; a53 = a[53]; a54 = a[54]; a55 = a[55]; + a56 = a[56]; a57 = a[57]; a58 = a[58]; a59 = a[59]; + a60 = a[60]; a61 = a[61]; a62 = a[62]; a63 = a[63]; + } + + inline bool compute (const uint8_t *b8) const { + const uint64_t *b = (uint64_t *)b8; + return (a0 & b[0]) == b[0] && (a1 & b[1]) == b[1] && + (a2 & b[2]) == b[2] && (a3 & b[3]) == b[3] && + (a4 & b[4]) == b[4] && (a5 & b[5]) == b[5] && + (a6 & b[6]) == b[6] && (a7 & b[7]) == b[7] && + (a8 & b[8]) == b[8] && (a9 & b[9]) == b[9] && + (a10 & b[10]) == b[10] && (a11 & b[11]) == b[11] && + (a12 & b[12]) == b[12] && (a13 & b[13]) == b[13] && + (a14 & b[14]) == b[14] && (a15 & b[15]) == b[15] && + (a16 & b[16]) == b[16] && (a17 & b[17]) == b[17] && + (a18 & b[18]) == b[18] && (a19 & b[19]) == b[19] && + (a20 & b[20]) == b[20] && (a21 & b[21]) == b[21] && + (a22 & b[22]) == b[22] && (a23 & b[23]) == b[23] && + (a24 & b[24]) == b[24] && (a25 & b[25]) == b[25] && + (a26 & b[26]) == b[26] && (a27 & b[27]) == b[27] && + (a28 & b[28]) == b[28] && (a29 & b[29]) == b[29] && + (a30 & b[30]) == b[30] && (a31 & b[31]) == b[31] && + (a32 & b[32]) == b[32] && (a33 & b[33]) == b[33] && + (a34 & b[34]) == b[34] && (a35 & b[35]) == b[35] && + (a36 & b[36]) == b[36] && (a37 & b[37]) == b[37] && + (a38 & b[38]) == b[38] && (a39 & b[39]) == b[39] && + (a40 & b[40]) == b[40] && (a41 & b[41]) == b[41] && + (a42 & b[42]) == b[42] && (a43 & b[43]) == b[43] && + (a44 & b[44]) == b[44] && (a45 & b[45]) == b[45] && + (a46 & b[46]) == b[46] && (a47 & b[47]) == b[47] && + (a48 & b[48]) == b[48] && (a49 & b[49]) == b[49] && + (a50 & b[50]) == b[50] && (a51 & b[51]) == b[51] && + (a52 & b[52]) == b[52] && (a53 & b[53]) == b[53] && + (a54 & b[54]) == b[54] && (a55 & b[55]) == b[55] && + (a56 & b[56]) == b[56] && (a57 & b[57]) == b[57] && + (a58 & b[58]) == b[58] && (a59 & b[59]) == b[59] && + (a60 & b[60]) == b[60] && (a61 & b[61]) == b[61] && + (a62 & b[62]) == b[62] && (a63 & b[63]) == b[63]; + } + + }; + + struct SuperstructureComputerDefault { + const uint8_t *a; + int n; + + SuperstructureComputerDefault () {} + + SuperstructureComputerDefault (const uint8_t *a8, int code_size) { + set (a8, code_size); + } + + void set (const uint8_t *a8, int code_size) { + a = a8; + n = code_size; + } + + bool compute (const uint8_t *b8) const { + const uint64_t *b = (uint64_t *)b8; + for (int i = 0; i < n; i++) { + if ((a[i] & b[i]) != b[i]) { + return false; + } + } + return true; + } + + }; + +// default template + template + struct SuperstructureComputer: SuperstructureComputerDefault { + SuperstructureComputer (const uint8_t *a, int code_size): + SuperstructureComputerDefault(a, code_size) {} + }; + +#define SPECIALIZED_HC(CODE_SIZE) \ + template<> struct SuperstructureComputer: \ + SuperstructureComputer ## CODE_SIZE { \ + SuperstructureComputer (const uint8_t *a): \ + SuperstructureComputer ## CODE_SIZE(a, CODE_SIZE) {} \ + } + + SPECIALIZED_HC(8); + SPECIALIZED_HC(16); + SPECIALIZED_HC(32); + SPECIALIZED_HC(64); + SPECIALIZED_HC(128); + SPECIALIZED_HC(256); + SPECIALIZED_HC(512); + +#undef SPECIALIZED_HC + +} diff --git a/core/src/index/thirdparty/faiss/utils/utils.cpp b/core/src/index/thirdparty/faiss/utils/utils.cpp new file mode 100644 index 0000000000..20b2e29553 --- /dev/null +++ b/core/src/index/thirdparty/faiss/utils/utils.cpp @@ -0,0 +1,723 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +#include + +#include +#include +#include +#include + +#include +#include +#include + +#include + +#include +#include + +#include +#include +#include + + + +#ifndef FINTEGER +#define FINTEGER long +#endif + + +extern "C" { + +/* declare BLAS functions, see http://www.netlib.org/clapack/cblas/ */ + +int sgemm_ (const char *transa, const char *transb, FINTEGER *m, FINTEGER * + n, FINTEGER *k, const float *alpha, const float *a, + FINTEGER *lda, const float *b, FINTEGER * + ldb, float *beta, float *c, FINTEGER *ldc); + +/* Lapack functions, see http://www.netlib.org/clapack/old/single/sgeqrf.c */ + +int sgeqrf_ (FINTEGER *m, FINTEGER *n, float *a, FINTEGER *lda, + float *tau, float *work, FINTEGER *lwork, FINTEGER *info); + +int sorgqr_(FINTEGER *m, FINTEGER *n, FINTEGER *k, float *a, + FINTEGER *lda, float *tau, float *work, + FINTEGER *lwork, FINTEGER *info); + +int sgemv_(const char *trans, FINTEGER *m, FINTEGER *n, float *alpha, + const float *a, FINTEGER *lda, const float *x, FINTEGER *incx, + float *beta, float *y, FINTEGER *incy); + +} + + +/************************************************** + * Get some stats about the system + **************************************************/ + +namespace faiss { + +double getmillisecs () { + struct timeval tv; + gettimeofday (&tv, nullptr); + return tv.tv_sec * 1e3 + tv.tv_usec * 1e-3; +} + +uint64_t get_cycles () { +#ifdef __x86_64__ + uint32_t high, low; + asm volatile("rdtsc \n\t" + : "=a" (low), + "=d" (high)); + return ((uint64_t)high << 32) | (low); +#else + return 0; +#endif +} + + +#ifdef __linux__ + +size_t get_mem_usage_kb () +{ + int pid = getpid (); + char fname[256]; + snprintf (fname, 256, "/proc/%d/status", pid); + FILE * f = fopen (fname, "r"); + FAISS_THROW_IF_NOT_MSG (f, "cannot open proc status file"); + size_t sz = 0; + for (;;) { + char buf [256]; + if (!fgets (buf, 256, f)) break; + if (sscanf (buf, "VmRSS: %ld kB", &sz) == 1) break; + } + fclose (f); + return sz; +} + +#elif __APPLE__ + +size_t get_mem_usage_kb () +{ + fprintf(stderr, "WARN: get_mem_usage_kb not implemented on the mac\n"); + return 0; +} + +#endif + + + + + +void reflection (const float * __restrict u, + float * __restrict x, + size_t n, size_t d, size_t nu) +{ + size_t i, j, l; + for (i = 0; i < n; i++) { + const float * up = u; + for (l = 0; l < nu; l++) { + float ip1 = 0, ip2 = 0; + + for (j = 0; j < d; j+=2) { + ip1 += up[j] * x[j]; + ip2 += up[j+1] * x[j+1]; + } + float ip = 2 * (ip1 + ip2); + + for (j = 0; j < d; j++) + x[j] -= ip * up[j]; + up += d; + } + x += d; + } +} + + +/* Reference implementation (slower) */ +void reflection_ref (const float * u, float * x, size_t n, size_t d, size_t nu) +{ + size_t i, j, l; + for (i = 0; i < n; i++) { + const float * up = u; + for (l = 0; l < nu; l++) { + double ip = 0; + + for (j = 0; j < d; j++) + ip += up[j] * x[j]; + ip *= 2; + + for (j = 0; j < d; j++) + x[j] -= ip * up[j]; + + up += d; + } + x += d; + } +} + + + + + + +/*************************************************************************** + * Some matrix manipulation functions + ***************************************************************************/ + + +/* This function exists because the Torch counterpart is extremly slow + (not multi-threaded + unexpected overhead even in single thread). + It is here to implement the usual property |x-y|^2=|x|^2+|y|^2-2 */ +void inner_product_to_L2sqr (float * __restrict dis, + const float * nr1, + const float * nr2, + size_t n1, size_t n2) +{ + +#pragma omp parallel for + for (size_t j = 0 ; j < n1 ; j++) { + float * disj = dis + j * n2; + for (size_t i = 0 ; i < n2 ; i++) + disj[i] = nr1[j] + nr2[i] - 2 * disj[i]; + } +} + + +void matrix_qr (int m, int n, float *a) +{ + FAISS_THROW_IF_NOT (m >= n); + FINTEGER mi = m, ni = n, ki = mi < ni ? mi : ni; + std::vector tau (ki); + FINTEGER lwork = -1, info; + float work_size; + + sgeqrf_ (&mi, &ni, a, &mi, tau.data(), + &work_size, &lwork, &info); + lwork = size_t(work_size); + std::vector work (lwork); + + sgeqrf_ (&mi, &ni, a, &mi, + tau.data(), work.data(), &lwork, &info); + + sorgqr_ (&mi, &ni, &ki, a, &mi, tau.data(), + work.data(), &lwork, &info); + +} + + + + +/*************************************************************************** + * Result list routines + ***************************************************************************/ + + +void ranklist_handle_ties (int k, int64_t *idx, const float *dis) +{ + float prev_dis = -1e38; + int prev_i = -1; + for (int i = 0; i < k; i++) { + if (dis[i] != prev_dis) { + if (i > prev_i + 1) { + // sort between prev_i and i - 1 + std::sort (idx + prev_i, idx + i); + } + prev_i = i; + prev_dis = dis[i]; + } + } +} + +size_t merge_result_table_with (size_t n, size_t k, + int64_t *I0, float *D0, + const int64_t *I1, const float *D1, + bool keep_min, + int64_t translation) +{ + size_t n1 = 0; + +#pragma omp parallel reduction(+:n1) + { + std::vector tmpI (k); + std::vector tmpD (k); + +#pragma omp for + for (size_t i = 0; i < n; i++) { + int64_t *lI0 = I0 + i * k; + float *lD0 = D0 + i * k; + const int64_t *lI1 = I1 + i * k; + const float *lD1 = D1 + i * k; + size_t r0 = 0; + size_t r1 = 0; + + if (keep_min) { + for (size_t j = 0; j < k; j++) { + + if (lI0[r0] >= 0 && lD0[r0] < lD1[r1]) { + tmpD[j] = lD0[r0]; + tmpI[j] = lI0[r0]; + r0++; + } else if (lD1[r1] >= 0) { + tmpD[j] = lD1[r1]; + tmpI[j] = lI1[r1] + translation; + r1++; + } else { // both are NaNs + tmpD[j] = NAN; + tmpI[j] = -1; + } + } + } else { + for (size_t j = 0; j < k; j++) { + if (lI0[r0] >= 0 && lD0[r0] > lD1[r1]) { + tmpD[j] = lD0[r0]; + tmpI[j] = lI0[r0]; + r0++; + } else if (lD1[r1] >= 0) { + tmpD[j] = lD1[r1]; + tmpI[j] = lI1[r1] + translation; + r1++; + } else { // both are NaNs + tmpD[j] = NAN; + tmpI[j] = -1; + } + } + } + n1 += r1; + memcpy (lD0, tmpD.data(), sizeof (lD0[0]) * k); + memcpy (lI0, tmpI.data(), sizeof (lI0[0]) * k); + } + } + + return n1; +} + + + +size_t ranklist_intersection_size (size_t k1, const int64_t *v1, + size_t k2, const int64_t *v2_in) +{ + if (k2 > k1) return ranklist_intersection_size (k2, v2_in, k1, v1); + int64_t *v2 = new int64_t [k2]; + memcpy (v2, v2_in, sizeof (int64_t) * k2); + std::sort (v2, v2 + k2); + { // de-dup v2 + int64_t prev = -1; + size_t wp = 0; + for (size_t i = 0; i < k2; i++) { + if (v2 [i] != prev) { + v2[wp++] = prev = v2 [i]; + } + } + k2 = wp; + } + const int64_t seen_flag = 1L << 60; + size_t count = 0; + for (size_t i = 0; i < k1; i++) { + int64_t q = v1 [i]; + size_t i0 = 0, i1 = k2; + while (i0 + 1 < i1) { + size_t imed = (i1 + i0) / 2; + int64_t piv = v2 [imed] & ~seen_flag; + if (piv <= q) i0 = imed; + else i1 = imed; + } + if (v2 [i0] == q) { + count++; + v2 [i0] |= seen_flag; + } + } + delete [] v2; + + return count; +} + +double imbalance_factor (int k, const int *hist) { + double tot = 0, uf = 0; + + for (int i = 0 ; i < k ; i++) { + tot += hist[i]; + uf += hist[i] * (double) hist[i]; + } + uf = uf * k / (tot * tot); + + return uf; +} + + +double imbalance_factor (int n, int k, const int64_t *assign) { + std::vector hist(k, 0); + for (int i = 0; i < n; i++) { + hist[assign[i]]++; + } + + return imbalance_factor (k, hist.data()); +} + + + +int ivec_hist (size_t n, const int * v, int vmax, int *hist) { + memset (hist, 0, sizeof(hist[0]) * vmax); + int nout = 0; + while (n--) { + if (v[n] < 0 || v[n] >= vmax) nout++; + else hist[v[n]]++; + } + return nout; +} + + +void bincode_hist(size_t n, size_t nbits, const uint8_t *codes, int *hist) +{ + FAISS_THROW_IF_NOT (nbits % 8 == 0); + size_t d = nbits / 8; + std::vector accu(d * 256); + const uint8_t *c = codes; + for (size_t i = 0; i < n; i++) + for(int j = 0; j < d; j++) + accu[j * 256 + *c++]++; + memset (hist, 0, sizeof(*hist) * nbits); + for (int i = 0; i < d; i++) { + const int *ai = accu.data() + i * 256; + int * hi = hist + i * 8; + for (int j = 0; j < 256; j++) + for (int k = 0; k < 8; k++) + if ((j >> k) & 1) + hi[k] += ai[j]; + } + +} + + + +size_t ivec_checksum (size_t n, const int *a) +{ + size_t cs = 112909; + while (n--) cs = cs * 65713 + a[n] * 1686049; + return cs; +} + + +namespace { + struct ArgsortComparator { + const float *vals; + bool operator() (const size_t a, const size_t b) const { + return vals[a] < vals[b]; + } + }; + + struct SegmentS { + size_t i0; // begin pointer in the permutation array + size_t i1; // end + size_t len() const { + return i1 - i0; + } + }; + + // see https://en.wikipedia.org/wiki/Merge_algorithm#Parallel_merge + // extended to > 1 merge thread + + // merges 2 ranges that should be consecutive on the source into + // the union of the two on the destination + template + void parallel_merge (const T *src, T *dst, + SegmentS &s1, SegmentS & s2, int nt, + const ArgsortComparator & comp) { + if (s2.len() > s1.len()) { // make sure that s1 larger than s2 + std::swap(s1, s2); + } + + // compute sub-ranges for each thread + SegmentS s1s[nt], s2s[nt], sws[nt]; + s2s[0].i0 = s2.i0; + s2s[nt - 1].i1 = s2.i1; + + // not sure parallel actually helps here +#pragma omp parallel for num_threads(nt) + for (int t = 0; t < nt; t++) { + s1s[t].i0 = s1.i0 + s1.len() * t / nt; + s1s[t].i1 = s1.i0 + s1.len() * (t + 1) / nt; + + if (t + 1 < nt) { + T pivot = src[s1s[t].i1]; + size_t i0 = s2.i0, i1 = s2.i1; + while (i0 + 1 < i1) { + size_t imed = (i1 + i0) / 2; + if (comp (pivot, src[imed])) {i1 = imed; } + else {i0 = imed; } + } + s2s[t].i1 = s2s[t + 1].i0 = i1; + } + } + s1.i0 = std::min(s1.i0, s2.i0); + s1.i1 = std::max(s1.i1, s2.i1); + s2 = s1; + sws[0].i0 = s1.i0; + for (int t = 0; t < nt; t++) { + sws[t].i1 = sws[t].i0 + s1s[t].len() + s2s[t].len(); + if (t + 1 < nt) { + sws[t + 1].i0 = sws[t].i1; + } + } + assert(sws[nt - 1].i1 == s1.i1); + + // do the actual merging +#pragma omp parallel for num_threads(nt) + for (int t = 0; t < nt; t++) { + SegmentS sw = sws[t]; + SegmentS s1t = s1s[t]; + SegmentS s2t = s2s[t]; + if (s1t.i0 < s1t.i1 && s2t.i0 < s2t.i1) { + for (;;) { + // assert (sw.len() == s1t.len() + s2t.len()); + if (comp(src[s1t.i0], src[s2t.i0])) { + dst[sw.i0++] = src[s1t.i0++]; + if (s1t.i0 == s1t.i1) break; + } else { + dst[sw.i0++] = src[s2t.i0++]; + if (s2t.i0 == s2t.i1) break; + } + } + } + if (s1t.len() > 0) { + assert(s1t.len() == sw.len()); + memcpy(dst + sw.i0, src + s1t.i0, s1t.len() * sizeof(dst[0])); + } else if (s2t.len() > 0) { + assert(s2t.len() == sw.len()); + memcpy(dst + sw.i0, src + s2t.i0, s2t.len() * sizeof(dst[0])); + } + } + } + +}; + +void fvec_argsort (size_t n, const float *vals, + size_t *perm) +{ + for (size_t i = 0; i < n; i++) perm[i] = i; + ArgsortComparator comp = {vals}; + std::sort (perm, perm + n, comp); +} + +void fvec_argsort_parallel (size_t n, const float *vals, + size_t *perm) +{ + size_t * perm2 = new size_t[n]; + // 2 result tables, during merging, flip between them + size_t *permB = perm2, *permA = perm; + + int nt = omp_get_max_threads(); + { // prepare correct permutation so that the result ends in perm + // at final iteration + int nseg = nt; + while (nseg > 1) { + nseg = (nseg + 1) / 2; + std::swap (permA, permB); + } + } + +#pragma omp parallel + for (size_t i = 0; i < n; i++) permA[i] = i; + + ArgsortComparator comp = {vals}; + + SegmentS segs[nt]; + + // independent sorts +#pragma omp parallel for + for (int t = 0; t < nt; t++) { + size_t i0 = t * n / nt; + size_t i1 = (t + 1) * n / nt; + SegmentS seg = {i0, i1}; + std::sort (permA + seg.i0, permA + seg.i1, comp); + segs[t] = seg; + } + int prev_nested = omp_get_nested(); + omp_set_nested(1); + + int nseg = nt; + while (nseg > 1) { + int nseg1 = (nseg + 1) / 2; + int sub_nt = nseg % 2 == 0 ? nt : nt - 1; + int sub_nseg1 = nseg / 2; + +#pragma omp parallel for num_threads(nseg1) + for (int s = 0; s < nseg; s += 2) { + if (s + 1 == nseg) { // otherwise isolated segment + memcpy(permB + segs[s].i0, permA + segs[s].i0, + segs[s].len() * sizeof(size_t)); + } else { + int t0 = s * sub_nt / sub_nseg1; + int t1 = (s + 1) * sub_nt / sub_nseg1; + printf("merge %d %d, %d threads\n", s, s + 1, t1 - t0); + parallel_merge(permA, permB, segs[s], segs[s + 1], + t1 - t0, comp); + } + } + for (int s = 0; s < nseg; s += 2) + segs[s / 2] = segs[s]; + nseg = nseg1; + std::swap (permA, permB); + } + assert (permA == perm); + omp_set_nested(prev_nested); + delete [] perm2; +} + + + + + + + + + + + + + + + + + + +const float *fvecs_maybe_subsample ( + size_t d, size_t *n, size_t nmax, const float *x, + bool verbose, int64_t seed) +{ + + if (*n <= nmax) return x; // nothing to do + + size_t n2 = nmax; + if (verbose) { + printf (" Input training set too big (max size is %ld), sampling " + "%ld / %ld vectors\n", nmax, n2, *n); + } + std::vector subset (*n); + rand_perm (subset.data (), *n, seed); + float *x_subset = new float[n2 * d]; + for (int64_t i = 0; i < n2; i++) + memcpy (&x_subset[i * d], + &x[subset[i] * size_t(d)], + sizeof (x[0]) * d); + *n = n2; + return x_subset; +} + + +void binary_to_real(size_t d, const uint8_t *x_in, float *x_out) { + for (size_t i = 0; i < d; ++i) { + x_out[i] = 2 * ((x_in[i >> 3] >> (i & 7)) & 1) - 1; + } +} + +void real_to_binary(size_t d, const float *x_in, uint8_t *x_out) { + for (size_t i = 0; i < d / 8; ++i) { + uint8_t b = 0; + for (int j = 0; j < 8; ++j) { + if (x_in[8 * i + j] > 0) { + b |= (1 << j); + } + } + x_out[i] = b; + } +} + + +// from Python's stringobject.c +uint64_t hash_bytes (const uint8_t *bytes, int64_t n) { + const uint8_t *p = bytes; + uint64_t x = (uint64_t)(*p) << 7; + int64_t len = n; + while (--len >= 0) { + x = (1000003*x) ^ *p++; + } + x ^= n; + return x; +} + + +bool check_openmp() { + omp_set_num_threads(10); + + if (omp_get_max_threads() != 10) { + return false; + } + + std::vector nt_per_thread(10); + size_t sum = 0; + bool in_parallel = true; +#pragma omp parallel reduction(+: sum) + { + if (!omp_in_parallel()) { + in_parallel = false; + } + + int nt = omp_get_num_threads(); + int rank = omp_get_thread_num(); + + nt_per_thread[rank] = nt; +#pragma omp for + for(int i = 0; i < 1000 * 1000 * 10; i++) { + sum += i; + } + } + + if (!in_parallel) { + return false; + } + if (nt_per_thread[0] != 10) { + return false; + } + if (sum == 0) { + return false; + } + + return true; +} + +int64_t get_L3_Size() { + static int64_t l3_size = -1; + constexpr int64_t KB = 1024; + if (l3_size == -1) { + + FILE* file = fopen("/sys/devices/system/cpu/cpu0/cache/index3/size","r"); + int64_t result = 0; + constexpr int64_t line_length = 128; + char line[line_length]; + if (file){ + char* ret = fgets(line, sizeof(line) - 1, file); + + sscanf(line, "%luK", &result); + l3_size = result * KB; + + fclose(file); + } else { + l3_size = 12 * KB * KB; // 12M + } + + } + return l3_size; +} + +void (*LOG_TRACE_)(const std::string&); + +void (*LOG_DEBUG_)(const std::string&); + +void (*LOG_INFO_)(const std::string&); + +void (*LOG_WARNING_)(const std::string&); + +void (*LOG_FATAL_)(const std::string&); + +void (*LOG_ERROR_)(const std::string&); + +} // namespace faiss diff --git a/core/src/index/thirdparty/faiss/utils/utils.h b/core/src/index/thirdparty/faiss/utils/utils.h new file mode 100644 index 0000000000..9be65b10a0 --- /dev/null +++ b/core/src/index/thirdparty/faiss/utils/utils.h @@ -0,0 +1,182 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +/* + * A few utilitary functions for similarity search: + * - optimized exhaustive distance and knn search functions + * - some functions reimplemented from torch for speed + */ + +#ifndef FAISS_utils_h +#define FAISS_utils_h + +#include +#include + +#include + + +namespace faiss { + + +/************************************************** + * Get some stats about the system +**************************************************/ + + +/// ms elapsed since some arbitrary epoch +double getmillisecs (); + +/// get current RSS usage in kB +size_t get_mem_usage_kb (); + + +uint64_t get_cycles (); + +/*************************************************************************** + * Misc matrix and vector manipulation functions + ***************************************************************************/ + + +/** compute c := a + bf * b for a, b and c tables + * + * @param n size of the tables + * @param a size n + * @param b size n + * @param c restult table, size n + */ +void fvec_madd (size_t n, const float *a, + float bf, const float *b, float *c); + + +/** same as fvec_madd, also return index of the min of the result table + * @return index of the min of table c + */ +int fvec_madd_and_argmin (size_t n, const float *a, + float bf, const float *b, float *c); + + +/* perform a reflection (not an efficient implementation, just for test ) */ +void reflection (const float * u, float * x, size_t n, size_t d, size_t nu); + + +/** compute the Q of the QR decomposition for m > n + * @param a size n * m: input matrix and output Q + */ +void matrix_qr (int m, int n, float *a); + +/** distances are supposed to be sorted. Sorts indices with same distance*/ +void ranklist_handle_ties (int k, int64_t *idx, const float *dis); + +/** count the number of comon elements between v1 and v2 + * algorithm = sorting + bissection to avoid double-counting duplicates + */ +size_t ranklist_intersection_size (size_t k1, const int64_t *v1, + size_t k2, const int64_t *v2); + +/** merge a result table into another one + * + * @param I0, D0 first result table, size (n, k) + * @param I1, D1 second result table, size (n, k) + * @param keep_min if true, keep min values, otherwise keep max + * @param translation add this value to all I1's indexes + * @return nb of values that were taken from the second table + */ +size_t merge_result_table_with (size_t n, size_t k, + int64_t *I0, float *D0, + const int64_t *I1, const float *D1, + bool keep_min = true, + int64_t translation = 0); + + +/// a balanced assignment has a IF of 1 +double imbalance_factor (int n, int k, const int64_t *assign); + +/// same, takes a histogram as input +double imbalance_factor (int k, const int *hist); + + +void fvec_argsort (size_t n, const float *vals, + size_t *perm); + +void fvec_argsort_parallel (size_t n, const float *vals, + size_t *perm); + + +/// compute histogram on v +int ivec_hist (size_t n, const int * v, int vmax, int *hist); + +/** Compute histogram of bits on a code array + * + * @param codes size(n, nbits / 8) + * @param hist size(nbits): nb of 1s in the array of codes + */ +void bincode_hist(size_t n, size_t nbits, const uint8_t *codes, int *hist); + + +/// compute a checksum on a table. +size_t ivec_checksum (size_t n, const int *a); + + +/** random subsamples a set of vectors if there are too many of them + * + * @param d dimension of the vectors + * @param n on input: nb of input vectors, output: nb of output vectors + * @param nmax max nb of vectors to keep + * @param x input array, size *n-by-d + * @param seed random seed to use for sampling + * @return x or an array allocated with new [] with *n vectors + */ +const float *fvecs_maybe_subsample ( + size_t d, size_t *n, size_t nmax, const float *x, + bool verbose = false, int64_t seed = 1234); + +/** Convert binary vector to +1/-1 valued float vector. + * + * @param d dimension of the vector (multiple of 8) + * @param x_in input binary vector (uint8_t table of size d / 8) + * @param x_out output float vector (float table of size d) + */ +void binary_to_real(size_t d, const uint8_t *x_in, float *x_out); + +/** Convert float vector to binary vector. Components > 0 are converted to 1, + * others to 0. + * + * @param d dimension of the vector (multiple of 8) + * @param x_in input float vector (float table of size d) + * @param x_out output binary vector (uint8_t table of size d / 8) + */ +void real_to_binary(size_t d, const float *x_in, uint8_t *x_out); + + +/** A reasonable hashing function */ +uint64_t hash_bytes (const uint8_t *bytes, int64_t n); + +/** Whether OpenMP annotations were respected. */ +bool check_openmp(); + +/** get the size of L3 cache */ +int64_t get_L3_Size(); + +extern void (*LOG_TRACE_)(const std::string&); + +extern void (*LOG_DEBUG_)(const std::string&); + +extern void (*LOG_INFO_)(const std::string&); + +extern void (*LOG_WARNING_)(const std::string&); + +extern void (*LOG_FATAL_)(const std::string&); + +extern void (*LOG_ERROR_)(const std::string&); + +} // namspace faiss + + +#endif /* FAISS_utils_h */ diff --git a/core/src/index/thirdparty/hnswlib/bruteforce.h b/core/src/index/thirdparty/hnswlib/bruteforce.h new file mode 100644 index 0000000000..ae2fa6a8f6 --- /dev/null +++ b/core/src/index/thirdparty/hnswlib/bruteforce.h @@ -0,0 +1,164 @@ +#pragma once +#include +#include +#include +#include + +namespace hnswlib { + +template +class BruteforceSearch : public AlgorithmInterface { + public: + BruteforceSearch(SpaceInterface *s) { + + } + BruteforceSearch(SpaceInterface *s, const std::string &location) { + loadIndex(location, s); + } + + BruteforceSearch(SpaceInterface *s, size_t maxElements) { + maxelements_ = maxElements; + data_size_ = s->get_data_size(); + fstdistfunc_ = s->get_dist_func(); + dist_func_param_ = s->get_dist_func_param(); + size_per_element_ = data_size_ + sizeof(labeltype); + data_ = (char *) malloc(maxElements * size_per_element_); + if (data_ == nullptr) + std::runtime_error("Not enough memory: BruteforceSearch failed to allocate data"); + cur_element_count = 0; + } + + ~BruteforceSearch() { + free(data_); + } + + char *data_; + size_t maxelements_; + size_t cur_element_count; + size_t size_per_element_; + + size_t data_size_; + DISTFUNC fstdistfunc_; + void *dist_func_param_; + std::mutex index_lock; + + std::unordered_map dict_external_to_internal; + + void addPoint(const void *datapoint, labeltype label) { + + size_t idx; + { + std::unique_lock lock(index_lock); + + + + auto search=dict_external_to_internal.find(label); + if (search != dict_external_to_internal.end()) { + idx=search->second; + } + else{ + if (cur_element_count >= maxelements_) { + throw std::runtime_error("The number of elements exceeds the specified limit\n"); + } + idx=cur_element_count; + dict_external_to_internal[label] = idx; + cur_element_count++; + } + } + memcpy(data_ + size_per_element_ * idx + data_size_, &label, sizeof(labeltype)); + memcpy(data_ + size_per_element_ * idx, datapoint, data_size_); + + + + + }; + + void removePoint(labeltype cur_external) { + size_t cur_c=dict_external_to_internal[cur_external]; + + dict_external_to_internal.erase(cur_external); + + labeltype label=*((labeltype*)(data_ + size_per_element_ * (cur_element_count-1) + data_size_)); + dict_external_to_internal[label]=cur_c; + memcpy(data_ + size_per_element_ * cur_c, + data_ + size_per_element_ * (cur_element_count-1), + data_size_+sizeof(labeltype)); + cur_element_count--; + } + + std::priority_queue> + searchKnn(const void *query_data, size_t k) const { + std::priority_queue> topResults; + if (cur_element_count == 0) return topResults; + for (size_t i = 0; i < k; i++) { + dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_); + topResults.push(std::pair(dist, *((labeltype *) (data_ + size_per_element_ * i + + data_size_)))); + } + dist_t lastdist = topResults.top().first; + for (size_t i = k; i < cur_element_count; i++) { + dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_); + if (dist <= lastdist) { + topResults.push(std::pair(dist, *((labeltype *) (data_ + size_per_element_ * i + + data_size_)))); + if (topResults.size() > k) + topResults.pop(); + lastdist = topResults.top().first; + } + } + return topResults; + }; + + template + std::vector> + searchKnn(const void* query_data, size_t k, Comp comp) { + std::vector> result; + if (cur_element_count == 0) return result; + + auto ret = searchKnn(query_data, k); + + while (!ret.empty()) { + result.push_back(ret.top()); + ret.pop(); + } + + std::sort(result.begin(), result.end(), comp); + + return result; + } + + void saveIndex(const std::string &location) { + std::ofstream output(location, std::ios::binary); + std::streampos position; + + writeBinaryPOD(output, maxelements_); + writeBinaryPOD(output, size_per_element_); + writeBinaryPOD(output, cur_element_count); + + output.write(data_, maxelements_ * size_per_element_); + + output.close(); + } + + void loadIndex(const std::string &location, SpaceInterface *s) { + std::ifstream input(location, std::ios::binary); + std::streampos position; + + readBinaryPOD(input, maxelements_); + readBinaryPOD(input, size_per_element_); + readBinaryPOD(input, cur_element_count); + + data_size_ = s->get_data_size(); + fstdistfunc_ = s->get_dist_func(); + dist_func_param_ = s->get_dist_func_param(); + size_per_element_ = data_size_ + sizeof(labeltype); + data_ = (char *) malloc(maxelements_ * size_per_element_); + if (data_ == nullptr) + std::runtime_error("Not enough memory: loadIndex failed to allocate data"); + + input.read(data_, maxelements_ * size_per_element_); + + input.close(); + } +}; +} diff --git a/core/src/index/thirdparty/hnswlib/hnswalg.h b/core/src/index/thirdparty/hnswlib/hnswalg.h new file mode 100644 index 0000000000..61158a61d8 --- /dev/null +++ b/core/src/index/thirdparty/hnswlib/hnswalg.h @@ -0,0 +1,1156 @@ +#pragma once + +#include "visited_list_pool.h" +#include "hnswlib.h" +#include +#include +#include +#include + +#include "knowhere/index/vector_index/helpers/FaissIO.h" + +namespace hnswlib { + +typedef unsigned int tableint; +typedef unsigned int linklistsizeint; + +template +class HierarchicalNSW : public AlgorithmInterface { + public: + HierarchicalNSW(SpaceInterface *s) { + } + + HierarchicalNSW(SpaceInterface *s, const std::string &location, bool nmslib = false, size_t max_elements=0) { + loadIndex(location, s, max_elements); + } + + HierarchicalNSW(SpaceInterface *s, size_t max_elements, size_t M = 16, size_t ef_construction = 200, size_t random_seed = 100) : + link_list_locks_(max_elements), element_levels_(max_elements) { + // linxj + space = s; + if (auto x = dynamic_cast(s)) { + metric_type_ = 0; + } else if (auto x = dynamic_cast(s)) { + metric_type_ = 1; + } else { + metric_type_ = 100; + } + + max_elements_ = max_elements; + + has_deletions_=false; + data_size_ = s->get_data_size(); + fstdistfunc_ = s->get_dist_func(); + dist_func_param_ = s->get_dist_func_param(); + M_ = M; + maxM_ = M_; + maxM0_ = M_ * 2; + ef_construction_ = std::max(ef_construction,M_); + ef_ = 10; + + level_generator_.seed(random_seed); + + size_links_level0_ = maxM0_ * sizeof(tableint) + sizeof(linklistsizeint); + size_data_per_element_ = size_links_level0_ + data_size_ + sizeof(labeltype); + offsetData_ = size_links_level0_; + label_offset_ = size_links_level0_ + data_size_; + offsetLevel0_ = 0; + + data_level0_memory_ = (char *) malloc(max_elements_ * size_data_per_element_); + if (data_level0_memory_ == nullptr) + throw std::runtime_error("Not enough memory"); + + cur_element_count = 0; + + visited_list_pool_ = new VisitedListPool(1, max_elements); + + + + //initializations for special treatment of the first node + enterpoint_node_ = -1; + maxlevel_ = -1; + + linkLists_ = (char **) malloc(sizeof(void *) * max_elements_); + if (linkLists_ == nullptr) + throw std::runtime_error("Not enough memory: HierarchicalNSW failed to allocate linklists"); + size_links_per_element_ = maxM_ * sizeof(tableint) + sizeof(linklistsizeint); + mult_ = 1 / log(1.0 * M_); + revSize_ = 1.0 / mult_; + } + + struct CompareByFirst { + constexpr bool operator()(std::pair const &a, + std::pair const &b) const noexcept { + return a.first < b.first; + } + }; + + ~HierarchicalNSW() { + + free(data_level0_memory_); + for (tableint i = 0; i < cur_element_count; i++) { + if (element_levels_[i] > 0) + free(linkLists_[i]); + } + free(linkLists_); + delete visited_list_pool_; + + // linxj: delete + delete space; + } + + // linxj: use for free resource + SpaceInterface *space; + size_t metric_type_; // 0:l2, 1:ip + + size_t max_elements_; + size_t cur_element_count; + size_t size_data_per_element_; + size_t size_links_per_element_; + + size_t M_; + size_t maxM_; + size_t maxM0_; + size_t ef_construction_; + + double mult_, revSize_; + int maxlevel_; + + + VisitedListPool *visited_list_pool_; + std::mutex cur_element_count_guard_; + + std::vector link_list_locks_; + tableint enterpoint_node_; + + + size_t size_links_level0_; + size_t offsetData_, offsetLevel0_; + + + char *data_level0_memory_; + char **linkLists_; + std::vector element_levels_; + + size_t data_size_; + + bool has_deletions_; + + + size_t label_offset_; + DISTFUNC fstdistfunc_; + void *dist_func_param_; + std::unordered_map label_lookup_; + + std::default_random_engine level_generator_; + + inline labeltype getExternalLabel(tableint internal_id) const { + labeltype return_label; + memcpy(&return_label,(data_level0_memory_ + internal_id * size_data_per_element_ + label_offset_), sizeof(labeltype)); + return return_label; + } + + inline void setExternalLabel(tableint internal_id, labeltype label) const { + memcpy((data_level0_memory_ + internal_id * size_data_per_element_ + label_offset_), &label, sizeof(labeltype)); + } + + inline labeltype *getExternalLabeLp(tableint internal_id) const { + return (labeltype *) (data_level0_memory_ + internal_id * size_data_per_element_ + label_offset_); + } + + inline char *getDataByInternalId(tableint internal_id) const { + return (data_level0_memory_ + internal_id * size_data_per_element_ + offsetData_); + } + + int getRandomLevel(double reverse_size) { + std::uniform_real_distribution distribution(0.0, 1.0); + double r = -log(distribution(level_generator_)) * reverse_size; + return (int) r; + } + + std::priority_queue, std::vector>, CompareByFirst> + searchBaseLayer(tableint ep_id, const void *data_point, int layer) { + VisitedList *vl = visited_list_pool_->getFreeVisitedList(); + vl_type *visited_array = vl->mass; + vl_type visited_array_tag = vl->curV; + + std::priority_queue, std::vector>, CompareByFirst> top_candidates; + std::priority_queue, std::vector>, CompareByFirst> candidateSet; + + dist_t lowerBound; + if (!isMarkedDeleted(ep_id)) { + dist_t dist = fstdistfunc_(data_point, getDataByInternalId(ep_id), dist_func_param_); + top_candidates.emplace(dist, ep_id); + lowerBound = dist; + candidateSet.emplace(-dist, ep_id); + } else { + lowerBound = std::numeric_limits::max(); + candidateSet.emplace(-lowerBound, ep_id); + } + visited_array[ep_id] = visited_array_tag; + + while (!candidateSet.empty()) { + std::pair curr_el_pair = candidateSet.top(); + if ((-curr_el_pair.first) > lowerBound) { + break; + } + candidateSet.pop(); + + tableint curNodeNum = curr_el_pair.second; + + std::unique_lock lock(link_list_locks_[curNodeNum]); + + int *data;// = (int *)(linkList0_ + curNodeNum * size_links_per_element0_); + if (layer == 0) { + data = (int*)get_linklist0(curNodeNum); + } else { + data = (int*)get_linklist(curNodeNum, layer); + // data = (int *) (linkLists_[curNodeNum] + (layer - 1) * size_links_per_element_); + } + size_t size = getListCount((linklistsizeint*)data); + tableint *datal = (tableint *) (data + 1); +#ifdef USE_SSE + _mm_prefetch((char *) (visited_array + *(data + 1)), _MM_HINT_T0); + _mm_prefetch((char *) (visited_array + *(data + 1) + 64), _MM_HINT_T0); + _mm_prefetch(getDataByInternalId(*datal), _MM_HINT_T0); + _mm_prefetch(getDataByInternalId(*(datal + 1)), _MM_HINT_T0); +#endif + + for (size_t j = 0; j < size; j++) { + tableint candidate_id = *(datal + j); + // if (candidate_id == 0) continue; +#ifdef USE_SSE + _mm_prefetch((char *) (visited_array + *(datal + j + 1)), _MM_HINT_T0); + _mm_prefetch(getDataByInternalId(*(datal + j + 1)), _MM_HINT_T0); +#endif + if (visited_array[candidate_id] == visited_array_tag) continue; + visited_array[candidate_id] = visited_array_tag; + char *currObj1 = (getDataByInternalId(candidate_id)); + + dist_t dist1 = fstdistfunc_(data_point, currObj1, dist_func_param_); + if (top_candidates.size() < ef_construction_ || lowerBound > dist1) { + candidateSet.emplace(-dist1, candidate_id); +#ifdef USE_SSE + _mm_prefetch(getDataByInternalId(candidateSet.top().second), _MM_HINT_T0); +#endif + + if (!isMarkedDeleted(candidate_id)) + top_candidates.emplace(dist1, candidate_id); + + if (top_candidates.size() > ef_construction_) + top_candidates.pop(); + + if (!top_candidates.empty()) + lowerBound = top_candidates.top().first; + } + } + } + visited_list_pool_->releaseVisitedList(vl); + + return top_candidates; + } + + template + std::priority_queue, std::vector>, CompareByFirst> + searchBaseLayerST(tableint ep_id, const void *data_point, size_t ef, faiss::ConcurrentBitsetPtr bitset) const { + VisitedList *vl = visited_list_pool_->getFreeVisitedList(); + vl_type *visited_array = vl->mass; + vl_type visited_array_tag = vl->curV; + + std::priority_queue, std::vector>, CompareByFirst> top_candidates; + std::priority_queue, std::vector>, CompareByFirst> candidate_set; + + dist_t lowerBound; +// if (!has_deletions || !isMarkedDeleted(ep_id)) { + if (!has_deletions || !bitset->test((faiss::ConcurrentBitset::id_type_t)getExternalLabel(ep_id))) { + dist_t dist = fstdistfunc_(data_point, getDataByInternalId(ep_id), dist_func_param_); + lowerBound = dist; + top_candidates.emplace(dist, ep_id); + candidate_set.emplace(-dist, ep_id); + } else { + lowerBound = std::numeric_limits::max(); + candidate_set.emplace(-lowerBound, ep_id); + } + + visited_array[ep_id] = visited_array_tag; + + while (!candidate_set.empty()) { + + std::pair current_node_pair = candidate_set.top(); + + if ((-current_node_pair.first) > lowerBound) { + break; + } + candidate_set.pop(); + + tableint current_node_id = current_node_pair.second; + int *data = (int *) get_linklist0(current_node_id); + size_t size = getListCount((linklistsizeint*)data); + // bool cur_node_deleted = isMarkedDeleted(current_node_id); + +#ifdef USE_SSE + _mm_prefetch((char *) (visited_array + *(data + 1)), _MM_HINT_T0); + _mm_prefetch((char *) (visited_array + *(data + 1) + 64), _MM_HINT_T0); + _mm_prefetch(data_level0_memory_ + (*(data + 1)) * size_data_per_element_ + offsetData_, _MM_HINT_T0); + _mm_prefetch((char *) (data + 2), _MM_HINT_T0); +#endif + + for (size_t j = 1; j <= size; j++) { + int candidate_id = *(data + j); + // if (candidate_id == 0) continue; +#ifdef USE_SSE + _mm_prefetch((char *) (visited_array + *(data + j + 1)), _MM_HINT_T0); + _mm_prefetch(data_level0_memory_ + (*(data + j + 1)) * size_data_per_element_ + offsetData_, + _MM_HINT_T0);//////////// +#endif + if (!(visited_array[candidate_id] == visited_array_tag)) { + + visited_array[candidate_id] = visited_array_tag; + + char *currObj1 = (getDataByInternalId(candidate_id)); + dist_t dist = fstdistfunc_(data_point, currObj1, dist_func_param_); + + if (top_candidates.size() < ef || lowerBound > dist) { + candidate_set.emplace(-dist, candidate_id); +#ifdef USE_SSE + _mm_prefetch(data_level0_memory_ + candidate_set.top().second * size_data_per_element_ + + offsetLevel0_,/////////// + _MM_HINT_T0);//////////////////////// +#endif + +// if (!has_deletions || !isMarkedDeleted(candidate_id)) + if (!has_deletions || (!bitset->test((faiss::ConcurrentBitset::id_type_t)getExternalLabel(candidate_id)))) + top_candidates.emplace(dist, candidate_id); + + if (top_candidates.size() > ef) + top_candidates.pop(); + + if (!top_candidates.empty()) + lowerBound = top_candidates.top().first; + } + } + } + } + + visited_list_pool_->releaseVisitedList(vl); + return top_candidates; + } + + void getNeighborsByHeuristic2( + std::priority_queue, std::vector>, CompareByFirst> &top_candidates, + const size_t M) { + if (top_candidates.size() < M) { + return; + } + std::priority_queue> queue_closest; + std::vector> return_list; + while (top_candidates.size() > 0) { + queue_closest.emplace(-top_candidates.top().first, top_candidates.top().second); + top_candidates.pop(); + } + + while (queue_closest.size()) { + if (return_list.size() >= M) + break; + std::pair curent_pair = queue_closest.top(); + dist_t dist_to_query = -curent_pair.first; + queue_closest.pop(); + bool good = true; + for (std::pair second_pair : return_list) { + dist_t curdist = + fstdistfunc_(getDataByInternalId(second_pair.second), + getDataByInternalId(curent_pair.second), + dist_func_param_);; + if (curdist < dist_to_query) { + good = false; + break; + } + } + if (good) { + return_list.push_back(curent_pair); + } + + + } + + for (std::pair curent_pair : return_list) { + + top_candidates.emplace(-curent_pair.first, curent_pair.second); + } + } + + + linklistsizeint *get_linklist0(tableint internal_id) const { + return (linklistsizeint *) (data_level0_memory_ + internal_id * size_data_per_element_ + offsetLevel0_); + }; + + linklistsizeint *get_linklist0(tableint internal_id, char *data_level0_memory_) const { + return (linklistsizeint *) (data_level0_memory_ + internal_id * size_data_per_element_ + offsetLevel0_); + }; + + linklistsizeint *get_linklist(tableint internal_id, int level) const { + return (linklistsizeint *) (linkLists_[internal_id] + (level - 1) * size_links_per_element_); + }; + + void mutuallyConnectNewElement(const void *data_point, tableint cur_c, + std::priority_queue, std::vector>, CompareByFirst> top_candidates, + int level) { + + size_t Mcurmax = level ? maxM_ : maxM0_; + getNeighborsByHeuristic2(top_candidates, M_); + if (top_candidates.size() > M_) + throw std::runtime_error("Should be not be more than M_ candidates returned by the heuristic"); + + std::vector selectedNeighbors; + selectedNeighbors.reserve(M_); + while (top_candidates.size() > 0) { + selectedNeighbors.push_back(top_candidates.top().second); + top_candidates.pop(); + } + + { + linklistsizeint *ll_cur; + if (level == 0) + ll_cur = get_linklist0(cur_c); + else + ll_cur = get_linklist(cur_c, level); + + if (*ll_cur) { + throw std::runtime_error("The newly inserted element should have blank link list"); + } + setListCount(ll_cur,selectedNeighbors.size()); + tableint *data = (tableint *) (ll_cur + 1); + + + for (size_t idx = 0; idx < selectedNeighbors.size(); idx++) { + if (data[idx]) + throw std::runtime_error("Possible memory corruption"); + if (level > element_levels_[selectedNeighbors[idx]]) + throw std::runtime_error("Trying to make a link on a non-existent level"); + + data[idx] = selectedNeighbors[idx]; + + } + } + for (size_t idx = 0; idx < selectedNeighbors.size(); idx++) { + + std::unique_lock lock(link_list_locks_[selectedNeighbors[idx]]); + + + linklistsizeint *ll_other; + if (level == 0) + ll_other = get_linklist0(selectedNeighbors[idx]); + else + ll_other = get_linklist(selectedNeighbors[idx], level); + + size_t sz_link_list_other = getListCount(ll_other); + + if (sz_link_list_other > Mcurmax) + throw std::runtime_error("Bad value of sz_link_list_other"); + if (selectedNeighbors[idx] == cur_c) + throw std::runtime_error("Trying to connect an element to itself"); + if (level > element_levels_[selectedNeighbors[idx]]) + throw std::runtime_error("Trying to make a link on a non-existent level"); + + tableint *data = (tableint *) (ll_other + 1); + if (sz_link_list_other < Mcurmax) { + data[sz_link_list_other] = cur_c; + setListCount(ll_other, sz_link_list_other + 1); + } else { + // finding the "weakest" element to replace it with the new one + dist_t d_max = fstdistfunc_(getDataByInternalId(cur_c), getDataByInternalId(selectedNeighbors[idx]), + dist_func_param_); + // Heuristic: + std::priority_queue, std::vector>, CompareByFirst> candidates; + candidates.emplace(d_max, cur_c); + + for (size_t j = 0; j < sz_link_list_other; j++) { + candidates.emplace( + fstdistfunc_(getDataByInternalId(data[j]), getDataByInternalId(selectedNeighbors[idx]), + dist_func_param_), data[j]); + } + + getNeighborsByHeuristic2(candidates, Mcurmax); + + int indx = 0; + while (candidates.size() > 0) { + data[indx] = candidates.top().second; + candidates.pop(); + indx++; + } + setListCount(ll_other, indx); + // Nearest K: + /*int indx = -1; + for (int j = 0; j < sz_link_list_other; j++) { + dist_t d = fstdistfunc_(getDataByInternalId(data[j]), getDataByInternalId(rez[idx]), dist_func_param_); + if (d > d_max) { + indx = j; + d_max = d; + } + } + if (indx >= 0) { + data[indx] = cur_c; + } */ + } + + } + } + + std::mutex global; + size_t ef_; + + void setEf(size_t ef) { + ef_ = ef; + } + + + std::priority_queue> searchKnnInternal(void *query_data, int k) { + std::priority_queue> top_candidates; + if (cur_element_count == 0) return top_candidates; + tableint currObj = enterpoint_node_; + dist_t curdist = fstdistfunc_(query_data, getDataByInternalId(enterpoint_node_), dist_func_param_); + + for (size_t level = maxlevel_; level > 0; level--) { + bool changed = true; + while (changed) { + changed = false; + int *data; + data = (int *) get_linklist(currObj,level); + int size = getListCount(data); + tableint *datal = (tableint *) (data + 1); + for (int i = 0; i < size; i++) { + tableint cand = datal[i]; + if (cand < 0 || cand > max_elements_) + throw std::runtime_error("cand error"); + dist_t d = fstdistfunc_(query_data, getDataByInternalId(cand), dist_func_param_); + + if (d < curdist) { + curdist = d; + currObj = cand; + changed = true; + } + } + } + } + + if (has_deletions_) { + std::priority_queue> top_candidates1=searchBaseLayerST(currObj, query_data, + ef_); + top_candidates.swap(top_candidates1); + } + else{ + std::priority_queue> top_candidates1=searchBaseLayerST(currObj, query_data, + ef_); + top_candidates.swap(top_candidates1); + } + + while (top_candidates.size() > k) { + top_candidates.pop(); + } + return top_candidates; + }; + + void resizeIndex(size_t new_max_elements){ + if (new_max_elements(new_max_elements).swap(link_list_locks_); + + + // Reallocate base layer + char * data_level0_memory_new = (char *) malloc(new_max_elements * size_data_per_element_); + if (data_level0_memory_new == nullptr) + throw std::runtime_error("Not enough memory: resizeIndex failed to allocate base layer"); + memcpy(data_level0_memory_new, data_level0_memory_,cur_element_count * size_data_per_element_); + free(data_level0_memory_); + data_level0_memory_=data_level0_memory_new; + + // Reallocate all other layers + char ** linkLists_new = (char **) malloc(sizeof(void *) * new_max_elements); + if (linkLists_new == nullptr) + throw std::runtime_error("Not enough memory: resizeIndex failed to allocate other layers"); + memcpy(linkLists_new, linkLists_,cur_element_count * sizeof(void *)); + free(linkLists_); + linkLists_=linkLists_new; + + max_elements_=new_max_elements; + + } + + void saveIndex(milvus::knowhere::MemoryIOWriter& output) { + // write l2/ip calculator + writeBinaryPOD(output, metric_type_); + writeBinaryPOD(output, data_size_); + writeBinaryPOD(output, *((size_t *) dist_func_param_)); + + writeBinaryPOD(output, offsetLevel0_); + writeBinaryPOD(output, max_elements_); + writeBinaryPOD(output, cur_element_count); + writeBinaryPOD(output, size_data_per_element_); + writeBinaryPOD(output, label_offset_); + writeBinaryPOD(output, offsetData_); + writeBinaryPOD(output, maxlevel_); + writeBinaryPOD(output, enterpoint_node_); + writeBinaryPOD(output, maxM_); + + writeBinaryPOD(output, maxM0_); + writeBinaryPOD(output, M_); + writeBinaryPOD(output, mult_); + writeBinaryPOD(output, ef_construction_); + + output.write(data_level0_memory_, cur_element_count * size_data_per_element_); + + for (size_t i = 0; i < cur_element_count; i++) { + unsigned int linkListSize = element_levels_[i] > 0 ? size_links_per_element_ * element_levels_[i] : 0; + writeBinaryPOD(output, linkListSize); + if (linkListSize) + output.write(linkLists_[i], linkListSize); + } + // output.close(); + } + + void loadIndex(milvus::knowhere::MemoryIOReader& input, size_t max_elements_i = 0) { + // linxj: init with metrictype + size_t dim = 100; + readBinaryPOD(input, metric_type_); + readBinaryPOD(input, data_size_); + readBinaryPOD(input, dim); + if (metric_type_ == 0) { + space = new hnswlib::L2Space(dim); + } else if (metric_type_ == 1) { + space = new hnswlib::InnerProductSpace(dim); + } else { + // throw exception + } + fstdistfunc_ = space->get_dist_func(); + dist_func_param_ = space->get_dist_func_param(); + + readBinaryPOD(input, offsetLevel0_); + readBinaryPOD(input, max_elements_); + readBinaryPOD(input, cur_element_count); + + size_t max_elements=max_elements_i; + if(max_elements < cur_element_count) + max_elements = max_elements_; + max_elements_ = max_elements; + readBinaryPOD(input, size_data_per_element_); + readBinaryPOD(input, label_offset_); + readBinaryPOD(input, offsetData_); + readBinaryPOD(input, maxlevel_); + readBinaryPOD(input, enterpoint_node_); + + readBinaryPOD(input, maxM_); + readBinaryPOD(input, maxM0_); + readBinaryPOD(input, M_); + readBinaryPOD(input, mult_); + readBinaryPOD(input, ef_construction_); + + + // data_size_ = s->get_data_size(); + // fstdistfunc_ = s->get_dist_func(); + // dist_func_param_ = s->get_dist_func_param(); + + // auto pos= input.rp; + + + // /// Optional - check if index is ok: + // + // input.seekg(cur_element_count * size_data_per_element_,input.cur); + // for (size_t i = 0; i < cur_element_count; i++) { + // if(input.tellg() < 0 || input.tellg()>=total_filesize){ + // throw std::runtime_error("Index seems to be corrupted or unsupported"); + // } + // + // unsigned int linkListSize; + // readBinaryPOD(input, linkListSize); + // if (linkListSize != 0) { + // input.seekg(linkListSize,input.cur); + // } + // } + // + // // throw exception if it either corrupted or old index + // if(input.tellg()!=total_filesize) + // throw std::runtime_error("Index seems to be corrupted or unsupported"); + // + // input.clear(); + // + // /// Optional check end + // + // input.seekg(pos,input.beg); + + + data_level0_memory_ = (char *) malloc(max_elements * size_data_per_element_); + if (data_level0_memory_ == nullptr) + throw std::runtime_error("Not enough memory: loadIndex failed to allocate level0"); + input.read(data_level0_memory_, cur_element_count * size_data_per_element_); + + + + + size_links_per_element_ = maxM_ * sizeof(tableint) + sizeof(linklistsizeint); + + + size_links_level0_ = maxM0_ * sizeof(tableint) + sizeof(linklistsizeint); + std::vector(max_elements).swap(link_list_locks_); + + + visited_list_pool_ = new VisitedListPool(1, max_elements); + + + linkLists_ = (char **) malloc(sizeof(void *) * max_elements); + if (linkLists_ == nullptr) + throw std::runtime_error("Not enough memory: loadIndex failed to allocate linklists"); + element_levels_ = std::vector(max_elements); + revSize_ = 1.0 / mult_; + ef_ = 10; + for (size_t i = 0; i < cur_element_count; i++) { + label_lookup_[getExternalLabel(i)]=i; + unsigned int linkListSize; + readBinaryPOD(input, linkListSize); + if (linkListSize == 0) { + element_levels_[i] = 0; + + linkLists_[i] = nullptr; + } else { + element_levels_[i] = linkListSize / size_links_per_element_; + linkLists_[i] = (char *) malloc(linkListSize); + if (linkLists_[i] == nullptr) + throw std::runtime_error("Not enough memory: loadIndex failed to allocate linklist"); + input.read(linkLists_[i], linkListSize); + } + } + + has_deletions_=false; + + for (size_t i = 0; i < cur_element_count; i++) { + if(isMarkedDeleted(i)) + has_deletions_=true; + } + + return; + } + + void saveIndex(const std::string &location) { + std::ofstream output(location, std::ios::binary); + std::streampos position; + + writeBinaryPOD(output, offsetLevel0_); + writeBinaryPOD(output, max_elements_); + writeBinaryPOD(output, cur_element_count); + writeBinaryPOD(output, size_data_per_element_); + writeBinaryPOD(output, label_offset_); + writeBinaryPOD(output, offsetData_); + writeBinaryPOD(output, maxlevel_); + writeBinaryPOD(output, enterpoint_node_); + writeBinaryPOD(output, maxM_); + + writeBinaryPOD(output, maxM0_); + writeBinaryPOD(output, M_); + writeBinaryPOD(output, mult_); + writeBinaryPOD(output, ef_construction_); + + output.write(data_level0_memory_, cur_element_count * size_data_per_element_); + + for (size_t i = 0; i < cur_element_count; i++) { + unsigned int linkListSize = element_levels_[i] > 0 ? size_links_per_element_ * element_levels_[i] : 0; + writeBinaryPOD(output, linkListSize); + if (linkListSize) + output.write(linkLists_[i], linkListSize); + } + output.close(); + } + + void loadIndex(const std::string &location, SpaceInterface *s, size_t max_elements_i=0) { + std::ifstream input(location, std::ios::binary); + + if (!input.is_open()) + throw std::runtime_error("Cannot open file"); + + // get file size: + input.seekg(0,input.end); + std::streampos total_filesize=input.tellg(); + input.seekg(0,input.beg); + + readBinaryPOD(input, offsetLevel0_); + readBinaryPOD(input, max_elements_); + readBinaryPOD(input, cur_element_count); + + size_t max_elements=max_elements_i; + if(max_elements < cur_element_count) + max_elements = max_elements_; + max_elements_ = max_elements; + readBinaryPOD(input, size_data_per_element_); + readBinaryPOD(input, label_offset_); + readBinaryPOD(input, offsetData_); + readBinaryPOD(input, maxlevel_); + readBinaryPOD(input, enterpoint_node_); + + readBinaryPOD(input, maxM_); + readBinaryPOD(input, maxM0_); + readBinaryPOD(input, M_); + readBinaryPOD(input, mult_); + readBinaryPOD(input, ef_construction_); + + data_size_ = s->get_data_size(); + fstdistfunc_ = s->get_dist_func(); + dist_func_param_ = s->get_dist_func_param(); + + auto pos=input.tellg(); + + /// Optional - check if index is ok: + + input.seekg(cur_element_count * size_data_per_element_,input.cur); + for (size_t i = 0; i < cur_element_count; i++) { + if(input.tellg() < 0 || input.tellg()>=total_filesize){ + throw std::runtime_error("Index seems to be corrupted or unsupported"); + } + + unsigned int linkListSize; + readBinaryPOD(input, linkListSize); + if (linkListSize != 0) { + input.seekg(linkListSize,input.cur); + } + } + + // throw exception if it either corrupted or old index + if(input.tellg()!=total_filesize) + throw std::runtime_error("Index seems to be corrupted or unsupported"); + + input.clear(); + + /// Optional check end + + input.seekg(pos,input.beg); + + data_level0_memory_ = (char *) malloc(max_elements * size_data_per_element_); + if (data_level0_memory_ == nullptr) + throw std::runtime_error("Not enough memory: loadIndex failed to allocate level0"); + input.read(data_level0_memory_, cur_element_count * size_data_per_element_); + + size_links_per_element_ = maxM_ * sizeof(tableint) + sizeof(linklistsizeint); + + size_links_level0_ = maxM0_ * sizeof(tableint) + sizeof(linklistsizeint); + std::vector(max_elements).swap(link_list_locks_); + + visited_list_pool_ = new VisitedListPool(1, max_elements); + + linkLists_ = (char **) malloc(sizeof(void *) * max_elements); + if (linkLists_ == nullptr) + throw std::runtime_error("Not enough memory: loadIndex failed to allocate linklists"); + element_levels_ = std::vector(max_elements); + revSize_ = 1.0 / mult_; + ef_ = 10; + for (size_t i = 0; i < cur_element_count; i++) { + label_lookup_[getExternalLabel(i)]=i; + unsigned int linkListSize; + readBinaryPOD(input, linkListSize); + if (linkListSize == 0) { + element_levels_[i] = 0; + linkLists_[i] = nullptr; + } else { + element_levels_[i] = linkListSize / size_links_per_element_; + linkLists_[i] = (char *) malloc(linkListSize); + if (linkLists_[i] == nullptr) + throw std::runtime_error("Not enough memory: loadIndex failed to allocate linklist"); + input.read(linkLists_[i], linkListSize); + } + } + + has_deletions_=false; + + for (size_t i = 0; i < cur_element_count; i++) { + if(isMarkedDeleted(i)) + has_deletions_=true; + } + + input.close(); + return; + } + + template + std::vector getDataByLabel(labeltype label) { + tableint label_c; + auto search = label_lookup_.find(label); + if (search == label_lookup_.end() || isMarkedDeleted(search->second)) { + throw std::runtime_error("Label not found"); + } + label_c = search->second; + + char* data_ptrv = getDataByInternalId(label_c); + size_t dim = *((size_t *) dist_func_param_); + std::vector data; + data_t* data_ptr = (data_t*) data_ptrv; + for (int i = 0; i < dim; i++) { + data.push_back(*data_ptr); + data_ptr += 1; + } + return data; + } + + static const unsigned char DELETE_MARK = 0x01; + // static const unsigned char REUSE_MARK = 0x10; + /** + * Marks an element with the given label deleted, does NOT really change the current graph. + * @param label + */ + void markDelete(labeltype label) + { + has_deletions_=true; + auto search = label_lookup_.find(label); + if (search == label_lookup_.end()) { + throw std::runtime_error("Label not found"); + } + markDeletedInternal(search->second); + } + + /** + * Uses the first 8 bits of the memory for the linked list to store the mark, + * whereas maxM0_ has to be limited to the lower 24 bits, however, still large enough in almost all cases. + * @param internalId + */ + void markDeletedInternal(tableint internalId) { + unsigned char *ll_cur = ((unsigned char *)get_linklist0(internalId))+2; + *ll_cur |= DELETE_MARK; + } + + /** + * Remove the deleted mark of the node. + * @param internalId + */ + void unmarkDeletedInternal(tableint internalId) { + unsigned char *ll_cur = ((unsigned char *)get_linklist0(internalId))+2; + *ll_cur &= ~DELETE_MARK; + } + + /** + * Checks the first 8 bits of the memory to see if the element is marked deleted. + * @param internalId + * @return + */ + bool isMarkedDeleted(tableint internalId) const { + unsigned char *ll_cur = ((unsigned char*)get_linklist0(internalId))+2; + return *ll_cur & DELETE_MARK; + } + + unsigned short int getListCount(linklistsizeint * ptr) const { + return *((unsigned short int *)ptr); + } + + void setListCount(linklistsizeint * ptr, unsigned short int size) const { + *((unsigned short int*)(ptr))=*((unsigned short int *)&size); + } + + void addPoint(const void *data_point, labeltype label) { + addPoint(data_point, label,-1); + } + + tableint addPoint(const void *data_point, labeltype label, int level) { + tableint cur_c = 0; + { + std::unique_lock lock(cur_element_count_guard_); + if (cur_element_count >= max_elements_) { + throw std::runtime_error("The number of elements exceeds the specified limit"); + }; + + cur_c = cur_element_count; + cur_element_count++; + + auto search = label_lookup_.find(label); + if (search != label_lookup_.end()) { + std::unique_lock lock_el(link_list_locks_[search->second]); + has_deletions_ = true; + markDeletedInternal(search->second); + } + label_lookup_[label] = cur_c; + } + + std::unique_lock lock_el(link_list_locks_[cur_c]); + int curlevel = getRandomLevel(mult_); + if (level > 0) + curlevel = level; + + element_levels_[cur_c] = curlevel; + + std::unique_lock templock(global); + int maxlevelcopy = maxlevel_; + if (curlevel <= maxlevelcopy) + templock.unlock(); + tableint currObj = enterpoint_node_; + tableint enterpoint_copy = enterpoint_node_; + + memset(data_level0_memory_ + cur_c * size_data_per_element_ + offsetLevel0_, 0, size_data_per_element_); + + // Initialisation of the data and label + memcpy(getExternalLabeLp(cur_c), &label, sizeof(labeltype)); + memcpy(getDataByInternalId(cur_c), data_point, data_size_); + + if (curlevel) { + linkLists_[cur_c] = (char *) malloc(size_links_per_element_ * curlevel + 1); + if (linkLists_[cur_c] == nullptr) + throw std::runtime_error("Not enough memory: addPoint failed to allocate linklist"); + memset(linkLists_[cur_c], 0, size_links_per_element_ * curlevel + 1); + } + + if ((signed)currObj != -1) { + + if (curlevel < maxlevelcopy) { + + dist_t curdist = fstdistfunc_(data_point, getDataByInternalId(currObj), dist_func_param_); + for (int level = maxlevelcopy; level > curlevel; level--) { + bool changed = true; + while (changed) { + changed = false; + unsigned int *data; + std::unique_lock lock(link_list_locks_[currObj]); + data = get_linklist(currObj,level); + int size = getListCount(data); + + tableint *datal = (tableint *) (data + 1); + for (int i = 0; i < size; i++) { + tableint cand = datal[i]; + if (cand < 0 || cand > max_elements_) + throw std::runtime_error("cand error"); + dist_t d = fstdistfunc_(data_point, getDataByInternalId(cand), dist_func_param_); + if (d < curdist) { + curdist = d; + currObj = cand; + changed = true; + } + } + } + } + } + + bool epDeleted = isMarkedDeleted(enterpoint_copy); + for (int level = std::min(curlevel, maxlevelcopy); level >= 0; level--) { + if (level > maxlevelcopy || level < 0) // possible? + throw std::runtime_error("Level error"); + + std::priority_queue, std::vector>, CompareByFirst> top_candidates = searchBaseLayer( + currObj, data_point, level); + if (epDeleted) { + top_candidates.emplace(fstdistfunc_(data_point, getDataByInternalId(enterpoint_copy), dist_func_param_), enterpoint_copy); + if (top_candidates.size() > ef_construction_) + top_candidates.pop(); + } + mutuallyConnectNewElement(data_point, cur_c, top_candidates, level); + + currObj = top_candidates.top().second; + } + } else { + // Do nothing for the first element + enterpoint_node_ = 0; + maxlevel_ = curlevel; + } + + //Releasing lock for the maximum level + if (curlevel > maxlevelcopy) { + enterpoint_node_ = cur_c; + maxlevel_ = curlevel; + } + return cur_c; + }; + + std::priority_queue> + searchKnn(const void *query_data, size_t k, faiss::ConcurrentBitsetPtr bitset) const { + std::priority_queue> result; + if (cur_element_count == 0) return result; + + tableint currObj = enterpoint_node_; + dist_t curdist = fstdistfunc_(query_data, getDataByInternalId(enterpoint_node_), dist_func_param_); + + for (int level = maxlevel_; level > 0; level--) { + bool changed = true; + while (changed) { + changed = false; + unsigned int *data; + + data = (unsigned int *) get_linklist(currObj, level); + int size = getListCount(data); + tableint *datal = (tableint *) (data + 1); + for (int i = 0; i < size; i++) { + tableint cand = datal[i]; + if (cand < 0 || cand > max_elements_) + throw std::runtime_error("cand error"); + dist_t d = fstdistfunc_(query_data, getDataByInternalId(cand), dist_func_param_); + + if (d < curdist) { + curdist = d; + currObj = cand; + changed = true; + } + } + } + } + + std::priority_queue, std::vector>, CompareByFirst> top_candidates; + if (bitset != nullptr) { + std::priority_queue, std::vector>, CompareByFirst> + top_candidates1 = searchBaseLayerST(currObj, query_data, std::max(ef_, k), bitset); + top_candidates.swap(top_candidates1); + } + else{ + std::priority_queue, std::vector>, CompareByFirst> + top_candidates1 = searchBaseLayerST(currObj, query_data, std::max(ef_, k), bitset); + top_candidates.swap(top_candidates1); + } + while (top_candidates.size() > k) { + top_candidates.pop(); + } + while (top_candidates.size() > 0) { + std::pair rez = top_candidates.top(); + result.push(std::pair(rez.first, getExternalLabel(rez.second))); + top_candidates.pop(); + } + return result; + }; + + template + std::vector> + searchKnn(const void* query_data, size_t k, Comp comp, faiss::ConcurrentBitsetPtr bitset) { + std::vector> result; + if (cur_element_count == 0) return result; + + auto ret = searchKnn(query_data, k, bitset); + + while (!ret.empty()) { + result.push_back(ret.top()); + ret.pop(); + } + + std::sort(result.begin(), result.end(), comp); + + return result; + } + + int64_t cal_size() { + int64_t ret = 0; + ret += sizeof(*this); + ret += sizeof(*space); + ret += visited_list_pool_->GetSize(); + ret += link_list_locks_.size() * sizeof(std::mutex); + ret += element_levels_.size() * sizeof(int); + ret += max_elements_ * size_data_per_element_; + ret += max_elements_ * sizeof(void*); + for (size_t i = 0; i < max_elements_; ++ i) { + ret += linkLists_[i] ? size_links_per_element_ * element_levels_[i] : 0; + } + return ret; + } + + }; + +} + + + + + diff --git a/core/src/index/thirdparty/hnswlib/hnswalg_nm.h b/core/src/index/thirdparty/hnswlib/hnswalg_nm.h new file mode 100644 index 0000000000..a39563d2b2 --- /dev/null +++ b/core/src/index/thirdparty/hnswlib/hnswalg_nm.h @@ -0,0 +1,1227 @@ +#pragma once + +#include "visited_list_pool.h" +#include "hnswlib_nm.h" +#include +#include +#include +#include + +#include "knowhere/index/vector_index/helpers/FaissIO.h" +#include "faiss/impl/ScalarQuantizer.h" +#include "faiss/impl/ScalarQuantizerCodec.h" + +namespace hnswlib_nm { + + typedef unsigned int tableint; + typedef unsigned int linklistsizeint; + + using QuantizerClass = faiss::QuantizerTemplate; + using DCClassIP = faiss::DCTemplate, 1>; + using DCClassL2 = faiss::DCTemplate, 1>; + + template + class HierarchicalNSW_NM : public AlgorithmInterface { + public: + HierarchicalNSW_NM(SpaceInterface *s) { + } + + HierarchicalNSW_NM(SpaceInterface *s, const std::string &location, bool nmslib = false, size_t max_elements=0) { + loadIndex(location, s, max_elements); + } + + HierarchicalNSW_NM(SpaceInterface *s, size_t max_elements, size_t M = 16, size_t ef_construction = 200, size_t random_seed = 100) : + link_list_locks_(max_elements), element_levels_(max_elements) { + // linxj + space = s; + if (auto x = dynamic_cast(s)) { + metric_type_ = 0; + } else if (auto x = dynamic_cast(s)) { + metric_type_ = 1; + } else { + metric_type_ = 100; + } + + max_elements_ = max_elements; + + has_deletions_=false; + data_size_ = s->get_data_size(); + fstdistfunc_ = s->get_dist_func(); + dist_func_param_ = s->get_dist_func_param(); + M_ = M; + maxM_ = M_; + maxM0_ = M_ * 2; + ef_construction_ = std::max(ef_construction,M_); + ef_ = 10; + + is_sq8_ = false; + sq_ = nullptr; + + level_generator_.seed(random_seed); + + size_links_level0_ = maxM0_ * sizeof(tableint) + sizeof(linklistsizeint); + size_data_per_element_ = size_links_level0_; // + sizeof(labeltype); + data_size_;; +// label_offset_ = size_links_level0_; + + data_level0_memory_ = (char *) malloc(max_elements_ * size_data_per_element_); + if (data_level0_memory_ == nullptr) + throw std::runtime_error("Not enough memory"); + + cur_element_count = 0; + + visited_list_pool_ = new hnswlib_nm::VisitedListPool(1, max_elements); + + + + //initializations for special treatment of the first node + enterpoint_node_ = -1; + maxlevel_ = -1; + + linkLists_ = (char **) malloc(sizeof(void *) * max_elements_); + if (linkLists_ == nullptr) + throw std::runtime_error("Not enough memory: HierarchicalNSW_NM failed to allocate linklists"); + size_links_per_element_ = maxM_ * sizeof(tableint) + sizeof(linklistsizeint); + mult_ = 1 / log(1.0 * M_); + revSize_ = 1.0 / mult_; + } + + struct CompareByFirst { + constexpr bool operator()(std::pair const &a, + std::pair const &b) const noexcept { + return a.first < b.first; + } + }; + + ~HierarchicalNSW_NM() { + + free(data_level0_memory_); + for (tableint i = 0; i < cur_element_count; i++) { + if (element_levels_[i] > 0) + free(linkLists_[i]); + } + free(linkLists_); + delete visited_list_pool_; + + if (sq_) delete sq_; + + // linxj: delete + delete space; + } + + // linxj: use for free resource + SpaceInterface *space; + size_t metric_type_; // 0:l2, 1:ip + + size_t max_elements_; + size_t cur_element_count; + size_t size_data_per_element_; + size_t size_links_per_element_; + + size_t M_; + size_t maxM_; + size_t maxM0_; + size_t ef_construction_; + + bool is_sq8_ = false; + faiss::ScalarQuantizer *sq_ = nullptr; + + double mult_, revSize_; + int maxlevel_; + + + VisitedListPool *visited_list_pool_; + std::mutex cur_element_count_guard_; + + std::vector link_list_locks_; + tableint enterpoint_node_; + + + size_t size_links_level0_; + + + char *data_level0_memory_; + char **linkLists_; + std::vector element_levels_; + + size_t data_size_; + + bool has_deletions_; + + + DISTFUNC fstdistfunc_; + void *dist_func_param_; + + std::default_random_engine level_generator_; + + inline char *getDataByInternalId(void *pdata, tableint offset) const { + return ((char*)pdata + offset * data_size_); + } + + void SetSq8(const float *trained) { + if (!trained) + throw std::runtime_error("trained sq8 data cannot be null in SetSq8!"); + if (sq_) delete sq_; + is_sq8_ = true; + sq_ = new faiss::ScalarQuantizer(*(size_t*)dist_func_param_, faiss::QuantizerType::QT_8bit); // hard code + sq_->trained.resize((sq_->d) << 1); + memcpy(sq_->trained.data(), trained, sq_->trained.size() * sizeof(float)); + } + + void sq_train(size_t nb, const float *xb, uint8_t *p_codes) { + if (!p_codes) + throw std::runtime_error("p_codes cannot be null in sq_train!"); + if (!xb) + throw std::runtime_error("base vector cannot be null in sq_train!"); + if (sq_) delete sq_; + is_sq8_ = true; + sq_ = new faiss::ScalarQuantizer(*(size_t*)dist_func_param_, faiss::QuantizerType::QT_8bit); // hard code + sq_->train(nb, xb); + sq_->compute_codes(xb, p_codes, nb); + memcpy(p_codes + *(size_t*)dist_func_param_ * nb, sq_->trained.data(), *(size_t*)dist_func_param_ * sizeof(float) * 2); + } + + int getRandomLevel(double reverse_size) { + std::uniform_real_distribution distribution(0.0, 1.0); + double r = -log(distribution(level_generator_)) * reverse_size; + return (int) r; + } + + std::priority_queue, std::vector>, CompareByFirst> + searchBaseLayer(tableint ep_id, const void *data_point, int layer, void *pdata) { + VisitedList *vl = visited_list_pool_->getFreeVisitedList(); + vl_type *visited_array = vl->mass; + vl_type visited_array_tag = vl->curV; + + std::priority_queue, std::vector>, CompareByFirst> top_candidates; + std::priority_queue, std::vector>, CompareByFirst> candidateSet; + + dist_t lowerBound; + if (!isMarkedDeleted(ep_id)) { + dist_t dist = fstdistfunc_(data_point, getDataByInternalId(pdata, ep_id), dist_func_param_); + top_candidates.emplace(dist, ep_id); + lowerBound = dist; + candidateSet.emplace(-dist, ep_id); + } else { + lowerBound = std::numeric_limits::max(); + candidateSet.emplace(-lowerBound, ep_id); + } + visited_array[ep_id] = visited_array_tag; + + while (!candidateSet.empty()) { + std::pair curr_el_pair = candidateSet.top(); + if ((-curr_el_pair.first) > lowerBound) { + break; + } + candidateSet.pop(); + + tableint curNodeNum = curr_el_pair.second; + + std::unique_lock lock(link_list_locks_[curNodeNum]); + + int *data;// = (int *)(linkList0_ + curNodeNum * size_links_per_element0_); + if (layer == 0) { + data = (int*)get_linklist0(curNodeNum); + } else { + data = (int*)get_linklist(curNodeNum, layer); + // data = (int *) (linkLists_[curNodeNum] + (layer - 1) * size_links_per_element_); + } + size_t size = getListCount((linklistsizeint*)data); + tableint *datal = (tableint *) (data + 1); +#ifdef USE_SSE + _mm_prefetch((char *) (visited_array + *(data + 1)), _MM_HINT_T0); + _mm_prefetch((char *) (visited_array + *(data + 1) + 64), _MM_HINT_T0); + _mm_prefetch(getDataByInternalId(pdata, *datal), _MM_HINT_T0); + _mm_prefetch(getDataByInternalId(pdata, *(datal + 1)), _MM_HINT_T0); +#endif + + for (size_t j = 0; j < size; j++) { + tableint candidate_id = *(datal + j); + // if (candidate_id == 0) continue; +#ifdef USE_SSE + _mm_prefetch((char *) (visited_array + *(datal + j + 1)), _MM_HINT_T0); + _mm_prefetch(getDataByInternalId(pdata, *(datal + j + 1)), _MM_HINT_T0); +#endif + if (visited_array[candidate_id] == visited_array_tag) continue; + visited_array[candidate_id] = visited_array_tag; + char *currObj1 = (getDataByInternalId(pdata, candidate_id)); + + dist_t dist1 = fstdistfunc_(data_point, currObj1, dist_func_param_); + if (top_candidates.size() < ef_construction_ || lowerBound > dist1) { + candidateSet.emplace(-dist1, candidate_id); +#ifdef USE_SSE + _mm_prefetch(getDataByInternalId(pdata, candidateSet.top().second), _MM_HINT_T0); +#endif + + if (!isMarkedDeleted(candidate_id)) + top_candidates.emplace(dist1, candidate_id); + + if (top_candidates.size() > ef_construction_) + top_candidates.pop(); + + if (!top_candidates.empty()) + lowerBound = top_candidates.top().first; + } + } + } + visited_list_pool_->releaseVisitedList(vl); + + return top_candidates; + } + + template + std::priority_queue, std::vector>, CompareByFirst> + searchBaseLayerST(tableint ep_id, const void *data_point, size_t ef, faiss::ConcurrentBitsetPtr bitset, void *pdata) const { + VisitedList *vl = visited_list_pool_->getFreeVisitedList(); + vl_type *visited_array = vl->mass; + vl_type visited_array_tag = vl->curV; + + faiss::SQDistanceComputer *sqdc = nullptr; + if (is_sq8_) { + if (metric_type_ == 0) { // L2 + sqdc = new DCClassL2(sq_->d, sq_->trained); + } else if (metric_type_ == 1) { // IP + sqdc = new DCClassIP(sq_->d, sq_->trained); + } else { + throw std::runtime_error("unsupported metric_type, it must be 0(L2) or 1(IP)!"); + } + sqdc->code_size = sq_->code_size; + sqdc->codes = (uint8_t*)pdata; + sqdc->set_query((const float*)data_point); + } + + std::priority_queue, std::vector>, CompareByFirst> top_candidates; + std::priority_queue, std::vector>, CompareByFirst> candidate_set; + + dist_t lowerBound; +// if (!has_deletions || !isMarkedDeleted(ep_id)) { + if (!has_deletions || !bitset->test((faiss::ConcurrentBitset::id_type_t)(ep_id))) { + dist_t dist; + if (is_sq8_) { + dist = (*sqdc)(ep_id); + } else { + dist = fstdistfunc_(data_point, getDataByInternalId(pdata, ep_id), dist_func_param_); + } + lowerBound = dist; + top_candidates.emplace(dist, ep_id); + candidate_set.emplace(-dist, ep_id); + } else { + lowerBound = std::numeric_limits::max(); + candidate_set.emplace(-lowerBound, ep_id); + } + + visited_array[ep_id] = visited_array_tag; + + while (!candidate_set.empty()) { + + std::pair current_node_pair = candidate_set.top(); + + if ((-current_node_pair.first) > lowerBound) { + break; + } + candidate_set.pop(); + + tableint current_node_id = current_node_pair.second; + int *data = (int *) get_linklist0(current_node_id); + size_t size = getListCount((linklistsizeint*)data); + // bool cur_node_deleted = isMarkedDeleted(current_node_id); + +#ifdef USE_SSE + _mm_prefetch((char *) (visited_array + *(data + 1)), _MM_HINT_T0); + _mm_prefetch((char *) (visited_array + *(data + 1) + 64), _MM_HINT_T0); +// _mm_prefetch(data_level0_memory_ + (*(data + 1)) * size_data_per_element_ + offsetData_, _MM_HINT_T0); + _mm_prefetch(getDataByInternalId(pdata, *(data + 1)), _MM_HINT_T0); + _mm_prefetch((char *) (data + 2), _MM_HINT_T0); +#endif + + for (size_t j = 1; j <= size; j++) { + int candidate_id = *(data + j); + // if (candidate_id == 0) continue; +#ifdef USE_SSE + _mm_prefetch((char *) (visited_array + *(data + j + 1)), _MM_HINT_T0); + _mm_prefetch(getDataByInternalId(pdata, *(data + j + 1)), + _MM_HINT_T0);//////////// +#endif + if (!(visited_array[candidate_id] == visited_array_tag)) { + + visited_array[candidate_id] = visited_array_tag; + + dist_t dist; + if (is_sq8_) { + dist = (*sqdc)(candidate_id); + } else { + char *currObj1 = (getDataByInternalId(pdata, candidate_id)); + dist = fstdistfunc_(data_point, currObj1, dist_func_param_); + } + + if (top_candidates.size() < ef || lowerBound > dist) { + candidate_set.emplace(-dist, candidate_id); +#ifdef USE_SSE + _mm_prefetch(data_level0_memory_ + candidate_set.top().second * size_data_per_element_,/////////// + _MM_HINT_T0);//////////////////////// +#endif + +// if (!has_deletions || !isMarkedDeleted(candidate_id)) + if (!has_deletions || (!bitset->test((faiss::ConcurrentBitset::id_type_t)(candidate_id)))) + top_candidates.emplace(dist, candidate_id); + + if (top_candidates.size() > ef) + top_candidates.pop(); + + if (!top_candidates.empty()) + lowerBound = top_candidates.top().first; + } + } + } + } + + visited_list_pool_->releaseVisitedList(vl); + if (is_sq8_) delete sqdc; + return top_candidates; + } + + void getNeighborsByHeuristic2( + std::priority_queue, std::vector>, CompareByFirst> &top_candidates, + const size_t M, tableint *ret, size_t &ret_len, void *pdata) { + if (top_candidates.size() < M) { + while (top_candidates.size() > 0) { + ret[ret_len ++] = top_candidates.top().second; + top_candidates.pop(); + } + return; + } + std::priority_queue> queue_closest; + std::vector> return_list; + while (top_candidates.size() > 0) { + queue_closest.emplace(-top_candidates.top().first, top_candidates.top().second); + top_candidates.pop(); + } + + while (queue_closest.size()) { + if (return_list.size() >= M) + break; + std::pair curent_pair = queue_closest.top(); + dist_t dist_to_query = -curent_pair.first; + queue_closest.pop(); + bool good = true; + for (std::pair second_pair : return_list) { + dist_t curdist = + fstdistfunc_(getDataByInternalId(pdata, second_pair.second), + getDataByInternalId(pdata, curent_pair.second), + dist_func_param_);; + if (curdist < dist_to_query) { + good = false; + break; + } + } + if (good) { + return_list.push_back(curent_pair); + ret[ret_len ++] = curent_pair.second; + } + + + } + +// for (std::pair curent_pair : return_list) { +// +// top_candidates.emplace(-curent_pair.first, curent_pair.second); +// } + } + + + linklistsizeint *get_linklist0(tableint internal_id) const { + return (linklistsizeint *) (data_level0_memory_ + internal_id * size_data_per_element_); + }; + + linklistsizeint *get_linklist0(tableint internal_id, char *data_level0_memory_) const { + return (linklistsizeint *) (data_level0_memory_ + internal_id * size_data_per_element_); + }; + + linklistsizeint *get_linklist(tableint internal_id, int level) const { + return (linklistsizeint *) (linkLists_[internal_id] + (level - 1) * size_links_per_element_); + }; + + void mutuallyConnectNewElement(const void *data_point, tableint cur_c, + std::priority_queue, std::vector>, CompareByFirst> &top_candidates, + int level, void *pdata) { + + size_t Mcurmax = level ? maxM_ : maxM0_; +// std::vector selectedNeighbors; +// selectedNeighbors.reserve(M_); + tableint *selectedNeighbors = (tableint*)malloc(sizeof(tableint) * M_); + size_t selectedNeighbors_size = 0; + getNeighborsByHeuristic2(top_candidates, M_, selectedNeighbors, selectedNeighbors_size, pdata); + if (selectedNeighbors_size > M_) + throw std::runtime_error("Should be not be more than M_ candidates returned by the heuristic"); + +// while (top_candidates.size() > 0) { +// selectedNeighbors.push_back(top_candidates.top().second); +// top_candidates.pop(); +// } + + { + linklistsizeint *ll_cur; + if (level == 0) + ll_cur = get_linklist0(cur_c); + else + ll_cur = get_linklist(cur_c, level); + + if (*ll_cur) { + throw std::runtime_error("The newly inserted element should have blank link list"); + } + setListCount(ll_cur,(unsigned short)selectedNeighbors_size); + tableint *data = (tableint *) (ll_cur + 1); + + + for (size_t idx = 0; idx < selectedNeighbors_size; idx++) { + if (data[idx]) + throw std::runtime_error("Possible memory corruption"); + if (level > element_levels_[selectedNeighbors[idx]]) + throw std::runtime_error("Trying to make a link on a non-existent level"); + + data[idx] = selectedNeighbors[idx]; + + } + } + for (size_t idx = 0; idx < selectedNeighbors_size; idx++) { + + std::unique_lock lock(link_list_locks_[selectedNeighbors[idx]]); + + + linklistsizeint *ll_other; + if (level == 0) + ll_other = get_linklist0(selectedNeighbors[idx]); + else + ll_other = get_linklist(selectedNeighbors[idx], level); + + size_t sz_link_list_other = getListCount(ll_other); + + if (sz_link_list_other > Mcurmax) + throw std::runtime_error("Bad value of sz_link_list_other"); + if (selectedNeighbors[idx] == cur_c) + throw std::runtime_error("Trying to connect an element to itself"); + if (level > element_levels_[selectedNeighbors[idx]]) + throw std::runtime_error("Trying to make a link on a non-existent level"); + + tableint *data = (tableint *) (ll_other + 1); + if (sz_link_list_other < Mcurmax) { + data[sz_link_list_other] = cur_c; + setListCount(ll_other, sz_link_list_other + 1); + } else { + // finding the "weakest" element to replace it with the new one + dist_t d_max = fstdistfunc_(getDataByInternalId(pdata, cur_c), getDataByInternalId(pdata, selectedNeighbors[idx]), + dist_func_param_); + // Heuristic: + std::priority_queue, std::vector>, CompareByFirst> candidates; + candidates.emplace(d_max, cur_c); + + for (size_t j = 0; j < sz_link_list_other; j++) { + candidates.emplace( + fstdistfunc_(getDataByInternalId(pdata, data[j]), getDataByInternalId(pdata, selectedNeighbors[idx]), + dist_func_param_), data[j]); + } + + size_t indx = 0; + getNeighborsByHeuristic2(candidates, Mcurmax, data, indx, pdata); + +// while (candidates.size() > 0) { +// data[indx] = candidates.top().second; +// candidates.pop(); +// indx++; +// } + setListCount(ll_other, indx); + // Nearest K: + /*int indx = -1; + for (int j = 0; j < sz_link_list_other; j++) { + dist_t d = fstdistfunc_(getDataByInternalId(data[j]), getDataByInternalId(rez[idx]), dist_func_param_); + if (d > d_max) { + indx = j; + d_max = d; + } + } + if (indx >= 0) { + data[indx] = cur_c; + } */ + } + + } + } + + std::mutex global; + size_t ef_; + + void setEf(size_t ef) { + ef_ = ef; + } + + + std::priority_queue> searchKnnInternal(void *query_data, int k, dist_t *pdata) { + std::priority_queue> top_candidates; + if (cur_element_count == 0) return top_candidates; + tableint currObj = enterpoint_node_; + dist_t curdist = fstdistfunc_(query_data, getDataByInternalId(pdata, enterpoint_node_), dist_func_param_); + + for (size_t level = maxlevel_; level > 0; level--) { + bool changed = true; + while (changed) { + changed = false; + int *data; + data = (int *) get_linklist(currObj,level); + int size = getListCount(data); + tableint *datal = (tableint *) (data + 1); + for (int i = 0; i < size; i++) { + tableint cand = datal[i]; + if (cand < 0 || cand > max_elements_) + throw std::runtime_error("cand error"); + dist_t d = fstdistfunc_(query_data, getDataByInternalId(pdata, cand), dist_func_param_); + + if (d < curdist) { + curdist = d; + currObj = cand; + changed = true; + } + } + } + } + + if (has_deletions_) { + std::priority_queue> top_candidates1=searchBaseLayerST(currObj, query_data, + ef_, pdata); + top_candidates.swap(top_candidates1); + } + else{ + std::priority_queue> top_candidates1=searchBaseLayerST(currObj, query_data, + ef_, pdata); + top_candidates.swap(top_candidates1); + } + + while (top_candidates.size() > k) { + top_candidates.pop(); + } + return top_candidates; + }; + + void resizeIndex(size_t new_max_elements){ + if (new_max_elements(new_max_elements).swap(link_list_locks_); + + // Reallocate base layer + char * data_level0_memory_new = (char *) malloc(new_max_elements * size_data_per_element_); + if (data_level0_memory_new == nullptr) + throw std::runtime_error("Not enough memory: resizeIndex failed to allocate base layer"); + memcpy(data_level0_memory_new, data_level0_memory_,cur_element_count * size_data_per_element_); + free(data_level0_memory_); + data_level0_memory_=data_level0_memory_new; + + // Reallocate all other layers + char ** linkLists_new = (char **) malloc(sizeof(void *) * new_max_elements); + if (linkLists_new == nullptr) + throw std::runtime_error("Not enough memory: resizeIndex failed to allocate other layers"); + memcpy(linkLists_new, linkLists_,cur_element_count * sizeof(void *)); + free(linkLists_); + linkLists_=linkLists_new; + + max_elements_=new_max_elements; + + } + + void saveIndex(milvus::knowhere::MemoryIOWriter& output) { + // write l2/ip calculator + writeBinaryPOD(output, metric_type_); + writeBinaryPOD(output, data_size_); + writeBinaryPOD(output, *((size_t *) dist_func_param_)); + +// writeBinaryPOD(output, offsetLevel0_); + writeBinaryPOD(output, max_elements_); + writeBinaryPOD(output, cur_element_count); + writeBinaryPOD(output, size_data_per_element_); +// writeBinaryPOD(output, label_offset_); +// writeBinaryPOD(output, offsetData_); + writeBinaryPOD(output, maxlevel_); + writeBinaryPOD(output, enterpoint_node_); + writeBinaryPOD(output, maxM_); + + writeBinaryPOD(output, maxM0_); + writeBinaryPOD(output, M_); + writeBinaryPOD(output, mult_); + writeBinaryPOD(output, ef_construction_); + + output.write(data_level0_memory_, cur_element_count * size_data_per_element_); + + for (size_t i = 0; i < cur_element_count; i++) { + unsigned int linkListSize = element_levels_[i] > 0 ? size_links_per_element_ * element_levels_[i] : 0; + writeBinaryPOD(output, linkListSize); + if (linkListSize) + output.write(linkLists_[i], linkListSize); + } + // output.close(); + } + + void loadIndex(milvus::knowhere::MemoryIOReader& input, size_t max_elements_i = 0) { + // linxj: init with metrictype + size_t dim = 100; + readBinaryPOD(input, metric_type_); + readBinaryPOD(input, data_size_); + readBinaryPOD(input, dim); + if (metric_type_ == 0) { + space = new L2Space(dim); + } else if (metric_type_ == 1) { + space = new InnerProductSpace(dim); + } else { + // throw exception + } + fstdistfunc_ = space->get_dist_func(); + dist_func_param_ = space->get_dist_func_param(); + +// readBinaryPOD(input, offsetLevel0_); + readBinaryPOD(input, max_elements_); + readBinaryPOD(input, cur_element_count); + + size_t max_elements=max_elements_i; + if(max_elements < cur_element_count) + max_elements = max_elements_; + max_elements_ = max_elements; + readBinaryPOD(input, size_data_per_element_); +// readBinaryPOD(input, label_offset_); +// readBinaryPOD(input, offsetData_); + readBinaryPOD(input, maxlevel_); + readBinaryPOD(input, enterpoint_node_); + + readBinaryPOD(input, maxM_); + readBinaryPOD(input, maxM0_); + readBinaryPOD(input, M_); + readBinaryPOD(input, mult_); + readBinaryPOD(input, ef_construction_); + + + // data_size_ = s->get_data_size(); + // fstdistfunc_ = s->get_dist_func(); + // dist_func_param_ = s->get_dist_func_param(); + + // auto pos= input.rp; + + + // /// Optional - check if index is ok: + // + // input.seekg(cur_element_count * size_data_per_element_,input.cur); + // for (size_t i = 0; i < cur_element_count; i++) { + // if(input.tellg() < 0 || input.tellg()>=total_filesize){ + // throw std::runtime_error("Index seems to be corrupted or unsupported"); + // } + // + // unsigned int linkListSize; + // readBinaryPOD(input, linkListSize); + // if (linkListSize != 0) { + // input.seekg(linkListSize,input.cur); + // } + // } + // + // // throw exception if it either corrupted or old index + // if(input.tellg()!=total_filesize) + // throw std::runtime_error("Index seems to be corrupted or unsupported"); + // + // input.clear(); + // + // /// Optional check end + // + // input.seekg(pos,input.beg); + + + data_level0_memory_ = (char *) malloc(max_elements * size_data_per_element_); + if (data_level0_memory_ == nullptr) + throw std::runtime_error("Not enough memory: loadIndex failed to allocate level0"); + input.read(data_level0_memory_, cur_element_count * size_data_per_element_); + + + + + size_links_per_element_ = maxM_ * sizeof(tableint) + sizeof(linklistsizeint); + + + size_links_level0_ = maxM0_ * sizeof(tableint) + sizeof(linklistsizeint); + std::vector(max_elements).swap(link_list_locks_); + + + visited_list_pool_ = new VisitedListPool(1, max_elements); + + + linkLists_ = (char **) malloc(sizeof(void *) * max_elements); + if (linkLists_ == nullptr) + throw std::runtime_error("Not enough memory: loadIndex failed to allocate linklists"); + element_levels_ = std::vector(max_elements); + revSize_ = 1.0 / mult_; + ef_ = 10; + for (size_t i = 0; i < cur_element_count; i++) { +// label_lookup_[getExternalLabel(i)]=i; + unsigned int linkListSize; + readBinaryPOD(input, linkListSize); + if (linkListSize == 0) { + element_levels_[i] = 0; + + linkLists_[i] = nullptr; + } else { + element_levels_[i] = linkListSize / size_links_per_element_; + linkLists_[i] = (char *) malloc(linkListSize); + if (linkLists_[i] == nullptr) + throw std::runtime_error("Not enough memory: loadIndex failed to allocate linklist"); + input.read(linkLists_[i], linkListSize); + } + } + + has_deletions_=false; + + for (size_t i = 0; i < cur_element_count; i++) { + if(isMarkedDeleted(i)) + has_deletions_=true; + } + + return; + } + + void saveIndex(const std::string &location) { + std::ofstream output(location, std::ios::binary); + std::streampos position; + +// writeBinaryPOD(output, offsetLevel0_); + writeBinaryPOD(output, max_elements_); + writeBinaryPOD(output, cur_element_count); + writeBinaryPOD(output, size_data_per_element_); +// writeBinaryPOD(output, label_offset_); +// writeBinaryPOD(output, offsetData_); + writeBinaryPOD(output, maxlevel_); + writeBinaryPOD(output, enterpoint_node_); + writeBinaryPOD(output, maxM_); + + writeBinaryPOD(output, maxM0_); + writeBinaryPOD(output, M_); + writeBinaryPOD(output, mult_); + writeBinaryPOD(output, ef_construction_); + + output.write(data_level0_memory_, cur_element_count * size_data_per_element_); + + for (size_t i = 0; i < cur_element_count; i++) { + unsigned int linkListSize = element_levels_[i] > 0 ? size_links_per_element_ * element_levels_[i] : 0; + writeBinaryPOD(output, linkListSize); + if (linkListSize) + output.write(linkLists_[i], linkListSize); + } + output.close(); + } + + void loadIndex(const std::string &location, SpaceInterface *s, size_t max_elements_i=0) { + std::ifstream input(location, std::ios::binary); + + if (!input.is_open()) + throw std::runtime_error("Cannot open file"); + + // get file size: + input.seekg(0,input.end); + std::streampos total_filesize=input.tellg(); + input.seekg(0,input.beg); + +// readBinaryPOD(input, offsetLevel0_); + readBinaryPOD(input, max_elements_); + readBinaryPOD(input, cur_element_count); + + size_t max_elements=max_elements_i; + if(max_elements < cur_element_count) + max_elements = max_elements_; + max_elements_ = max_elements; + readBinaryPOD(input, size_data_per_element_); +// readBinaryPOD(input, label_offset_); +// readBinaryPOD(input, offsetData_); + readBinaryPOD(input, maxlevel_); + readBinaryPOD(input, enterpoint_node_); + + readBinaryPOD(input, maxM_); + readBinaryPOD(input, maxM0_); + readBinaryPOD(input, M_); + readBinaryPOD(input, mult_); + readBinaryPOD(input, ef_construction_); + + data_size_ = s->get_data_size(); + fstdistfunc_ = s->get_dist_func(); + dist_func_param_ = s->get_dist_func_param(); + + auto pos=input.tellg(); + + /// Optional - check if index is ok: + + input.seekg(cur_element_count * size_data_per_element_,input.cur); + for (size_t i = 0; i < cur_element_count; i++) { + if(input.tellg() < 0 || input.tellg()>=total_filesize){ + throw std::runtime_error("Index seems to be corrupted or unsupported"); + } + + unsigned int linkListSize; + readBinaryPOD(input, linkListSize); + if (linkListSize != 0) { + input.seekg(linkListSize,input.cur); + } + } + + // throw exception if it either corrupted or old index + if(input.tellg()!=total_filesize) + throw std::runtime_error("Index seems to be corrupted or unsupported"); + + input.clear(); + + /// Optional check end + + input.seekg(pos,input.beg); + + data_level0_memory_ = (char *) malloc(max_elements * size_data_per_element_); + if (data_level0_memory_ == nullptr) + throw std::runtime_error("Not enough memory: loadIndex failed to allocate level0"); + input.read(data_level0_memory_, cur_element_count * size_data_per_element_); + + size_links_per_element_ = maxM_ * sizeof(tableint) + sizeof(linklistsizeint); + + size_links_level0_ = maxM0_ * sizeof(tableint) + sizeof(linklistsizeint); + std::vector(max_elements).swap(link_list_locks_); + + visited_list_pool_ = new VisitedListPool(1, max_elements); + + linkLists_ = (char **) malloc(sizeof(void *) * max_elements); + if (linkLists_ == nullptr) + throw std::runtime_error("Not enough memory: loadIndex failed to allocate linklists"); + element_levels_ = std::vector(max_elements); + revSize_ = 1.0 / mult_; + ef_ = 10; + for (size_t i = 0; i < cur_element_count; i++) { +// label_lookup_[getExternalLabel(i)]=i; + unsigned int linkListSize; + readBinaryPOD(input, linkListSize); + if (linkListSize == 0) { + element_levels_[i] = 0; + linkLists_[i] = nullptr; + } else { + element_levels_[i] = linkListSize / size_links_per_element_; + linkLists_[i] = (char *) malloc(linkListSize); + if (linkLists_[i] == nullptr) + throw std::runtime_error("Not enough memory: loadIndex failed to allocate linklist"); + input.read(linkLists_[i], linkListSize); + } + } + + has_deletions_=false; + + for (size_t i = 0; i < cur_element_count; i++) { + if(isMarkedDeleted(i)) + has_deletions_=true; + } + + input.close(); + return; + } + + /* + template + std::vector getDataByLabel(tableint internal_id, dist_t *pdata) { + // tableint label_c; + // auto search = label_lookup_.find(label); + // if (search == label_lookup_.end() || isMarkedDeleted(search->second)) { + // throw std::runtime_error("Label not found"); + // } + // label_c = search->second; + + char* data_ptrv = getDataByInternalId(pdata, internal_id); + size_t dim = *((size_t *) dist_func_param_); + std::vector data; + data_t* data_ptr = (data_t*) data_ptrv; + for (int i = 0; i < dim; i++) { + data.push_back(*data_ptr); + data_ptr += 1; + } + return data; + } + */ + + static const unsigned char DELETE_MARK = 0x01; + // static const unsigned char REUSE_MARK = 0x10; + /** + * Marks an element with the given label deleted, does NOT really change the current graph. + * @param label + */ + void markDelete(labeltype label) + { + has_deletions_=true; +// auto search = label_lookup_.find(label); +// if (search == label_lookup_.end()) { +// throw std::runtime_error("Label not found"); +// } +// markDeletedInternal(search->second); + markDeletedInternal(label); + } + + /** + * Uses the first 8 bits of the memory for the linked list to store the mark, + * whereas maxM0_ has to be limited to the lower 24 bits, however, still large enough in almost all cases. + * @param internalId + */ + void markDeletedInternal(tableint internalId) { + unsigned char *ll_cur = ((unsigned char *)get_linklist0(internalId))+2; + *ll_cur |= DELETE_MARK; + } + + /** + * Remove the deleted mark of the node. + * @param internalId + */ + void unmarkDeletedInternal(tableint internalId) { + unsigned char *ll_cur = ((unsigned char *)get_linklist0(internalId))+2; + *ll_cur &= ~DELETE_MARK; + } + + /** + * Checks the first 8 bits of the memory to see if the element is marked deleted. + * @param internalId + * @return + */ + bool isMarkedDeleted(tableint internalId) const { + unsigned char *ll_cur = ((unsigned char*)get_linklist0(internalId))+2; + return *ll_cur & DELETE_MARK; + } + + unsigned short int getListCount(linklistsizeint * ptr) const { + return *((unsigned short int *)ptr); + } + + void setListCount(linklistsizeint * ptr, unsigned short int size) const { + *((unsigned short int*)(ptr))=*((unsigned short int *)&size); + } + + size_t getCurrentElementCount() { + return cur_element_count; + } + + void addPoint(void *data_point, labeltype label, size_t base, size_t offset) { + addPoint(data_point, label,-1, base, offset); + } + + tableint addPoint(void *data_point, labeltype label, int level, size_t base, size_t offset) { + tableint cur_c = 0; + { + std::unique_lock lock(cur_element_count_guard_); + if (cur_element_count >= max_elements_) { + throw std::runtime_error("The number of elements exceeds the specified limit"); + }; + +// cur_c = cur_element_count; + cur_c = tableint(base + offset); + cur_element_count++; + +// auto search = label_lookup_.find(label); +// if (search != label_lookup_.end()) { +// std::unique_lock lock_el(link_list_locks_[search->second]); +// has_deletions_ = true; +// markDeletedInternal(search->second); +// } +// label_lookup_[label] = cur_c; + } + + std::unique_lock lock_el(link_list_locks_[cur_c]); + int curlevel = getRandomLevel(mult_); + if (level > 0) + curlevel = level; + + element_levels_[cur_c] = curlevel; + + // prepose non-concurrent operation + memset(data_level0_memory_ + cur_c * size_data_per_element_, 0, size_data_per_element_); +// setExternalLabel(cur_c, label); +// memcpy(getDataByInternalId(cur_c), data_point, data_size_); + if (curlevel) { + linkLists_[cur_c] = (char *) malloc(size_links_per_element_ * curlevel + 1); + if (linkLists_[cur_c] == nullptr) + throw std::runtime_error("Not enough memory: addPoint failed to allocate linklist"); + memset(linkLists_[cur_c], 0, size_links_per_element_ * curlevel + 1); + } + + + std::unique_lock templock(global); + int maxlevelcopy = maxlevel_; + if (curlevel <= maxlevelcopy) + templock.unlock(); + tableint currObj = enterpoint_node_; + tableint enterpoint_copy = enterpoint_node_; + + if ((signed)currObj != -1) { + + if (curlevel < maxlevelcopy) { + + dist_t curdist = fstdistfunc_(getDataByInternalId(data_point, (tableint)offset), getDataByInternalId(data_point, currObj), dist_func_param_); + for (int level = maxlevelcopy; level > curlevel; level--) { + bool changed = true; + while (changed) { + changed = false; + unsigned int *data; + std::unique_lock lock(link_list_locks_[currObj]); + data = get_linklist(currObj,level); + int size = getListCount(data); + + tableint *datal = (tableint *) (data + 1); + for (int i = 0; i < size; i++) { + tableint cand = datal[i]; + if (cand < 0 || cand > max_elements_) + throw std::runtime_error("cand error"); + dist_t d = fstdistfunc_(getDataByInternalId(data_point, tableint(offset)), getDataByInternalId(data_point, cand), dist_func_param_); + if (d < curdist) { + curdist = d; + currObj = cand; + changed = true; + } + } + } + } + } + + bool epDeleted = isMarkedDeleted(enterpoint_copy); + for (int level = std::min(curlevel, maxlevelcopy); level >= 0; level--) { + if (level > maxlevelcopy || level < 0) // possible? + throw std::runtime_error("Level error"); + + std::priority_queue, std::vector>, CompareByFirst> top_candidates = searchBaseLayer( + currObj, getDataByInternalId(data_point, (tableint)offset), level, data_point); + if (epDeleted) { + top_candidates.emplace(fstdistfunc_(getDataByInternalId(data_point, (tableint)offset), getDataByInternalId(data_point, enterpoint_copy), dist_func_param_), enterpoint_copy); + if (top_candidates.size() > ef_construction_) + top_candidates.pop(); + } + currObj = top_candidates.top().second; + + mutuallyConnectNewElement(getDataByInternalId(data_point, (tableint)offset), cur_c, top_candidates, level, data_point); + } + } else { + // Do nothing for the first element + enterpoint_node_ = 0; + maxlevel_ = curlevel; + } + + //Releasing lock for the maximum level + if (curlevel > maxlevelcopy) { + enterpoint_node_ = cur_c; + maxlevel_ = curlevel; + } + return cur_c; + }; + + std::priority_queue> + searchKnn_NM(const void *query_data, size_t k, faiss::ConcurrentBitsetPtr bitset, dist_t *pdata) const { + std::priority_queue> result; + if (cur_element_count == 0) return result; + + tableint currObj = enterpoint_node_; + dist_t curdist; + faiss::SQDistanceComputer *sqdc = nullptr; + if (is_sq8_) { + if (metric_type_ == 0) { // L2 + sqdc = new DCClassL2(sq_->d, sq_->trained); + } else if (metric_type_ == 1) { // IP + sqdc = new DCClassIP(sq_->d, sq_->trained); + } else { + throw std::runtime_error("unsupported metric_type, it must be 0(L2) or 1(IP)!"); + } + sqdc->code_size = sq_->code_size; + sqdc->set_query((const float*)query_data); + sqdc->codes = (uint8_t*)pdata; + curdist = (*sqdc)(currObj); + } else { + curdist = fstdistfunc_(query_data, getDataByInternalId(pdata, enterpoint_node_), dist_func_param_); + } + + for (int level = maxlevel_; level > 0; level--) { + bool changed = true; + while (changed) { + changed = false; + unsigned int *data; + + data = (unsigned int *) get_linklist(currObj, level); + int size = getListCount(data); + tableint *datal = (tableint *) (data + 1); + for (int i = 0; i < size; i++) { + tableint cand = datal[i]; + if (cand < 0 || cand > max_elements_) + throw std::runtime_error("cand error"); + dist_t d; + if (is_sq8_) { + d = (*sqdc)(cand); + } else { + d = fstdistfunc_(query_data, getDataByInternalId(pdata, cand), dist_func_param_); + } + + if (d < curdist) { + curdist = d; + currObj = cand; + changed = true; + } + } + } + } + + std::priority_queue, std::vector>, CompareByFirst> top_candidates; + if (bitset != nullptr) { + std::priority_queue, std::vector>, CompareByFirst> + top_candidates1 = searchBaseLayerST(currObj, query_data, std::max(ef_, k), bitset, pdata); + top_candidates.swap(top_candidates1); + } + else{ + std::priority_queue, std::vector>, CompareByFirst> + top_candidates1 = searchBaseLayerST(currObj, query_data, std::max(ef_, k), bitset, pdata); + top_candidates.swap(top_candidates1); + } + while (top_candidates.size() > k) { + top_candidates.pop(); + } + while (top_candidates.size() > 0) { + std::pair rez = top_candidates.top(); +// result.push(std::pair(rez.first, getExternalLabel(rez.second))); + result.push(std::pair(rez.first, rez.second)); + top_candidates.pop(); + } + if (is_sq8_) delete sqdc; + return result; + }; + + template + std::vector> + searchKnn_NM(const void* query_data, size_t k, Comp comp, faiss::ConcurrentBitsetPtr bitset, dist_t *pdata) { + std::vector> result; + if (cur_element_count == 0) return result; + + auto ret = searchKnn_NM(query_data, k, bitset, pdata); + + while (!ret.empty()) { + result.push_back(ret.top()); + ret.pop(); + } + + std::sort(result.begin(), result.end(), comp); + + return result; + } + + int64_t cal_size() { + int64_t ret = 0; + ret += sizeof(*this); + ret += sizeof(*space); + ret += visited_list_pool_->GetSize(); + ret += link_list_locks_.size() * sizeof(std::mutex); + ret += element_levels_.size() * sizeof(int); + ret += max_elements_ * size_data_per_element_; + ret += max_elements_ * sizeof(void*); + for (auto i = 0; i < max_elements_; ++ i) { + ret += linkLists_[i] ? size_links_per_element_ * element_levels_[i] : 0; + } + return ret; + } + }; + +} diff --git a/core/src/index/thirdparty/hnswlib/hnswlib.h b/core/src/index/thirdparty/hnswlib/hnswlib.h new file mode 100644 index 0000000000..84c92436c6 --- /dev/null +++ b/core/src/index/thirdparty/hnswlib/hnswlib.h @@ -0,0 +1,101 @@ +#pragma once +#ifndef NO_MANUAL_VECTORIZATION +#ifdef __SSE__ +#define USE_SSE +#ifdef __AVX__ +#define USE_AVX +#endif +#endif +#endif + +#if defined(USE_AVX) || defined(USE_SSE) +#ifdef _MSC_VER +#include +#include +#else +#include +#endif + +#if defined(__GNUC__) +#define PORTABLE_ALIGN32 __attribute__((aligned(32))) +#else +#define PORTABLE_ALIGN32 __declspec(align(32)) +#endif +#endif + +#include +#include +#include + +#include +#include + +namespace hnswlib { + typedef int64_t labeltype; + + template + class pairGreater { + public: + bool operator()(const T& p1, const T& p2) { + return p1.first > p2.first; + } + }; + + template + static void writeBinaryPOD(std::ostream &out, const T &podRef) { + out.write((char *) &podRef, sizeof(T)); + } + + template + static void readBinaryPOD(std::istream &in, T &podRef) { + in.read((char *) &podRef, sizeof(T)); + } + + template + static void writeBinaryPOD(W &out, const T &podRef) { + out.write((char *) &podRef, sizeof(T)); + } + + template + static void readBinaryPOD(R &in, T &podRef) { + in.read((char *) &podRef, sizeof(T)); + } + + template + using DISTFUNC = MTYPE(*)(const void *, const void *, const void *); + + + template + class SpaceInterface { + public: + //virtual void search(void *); + virtual size_t get_data_size() = 0; + + virtual DISTFUNC get_dist_func() = 0; + + virtual void *get_dist_func_param() = 0; + + SpaceInterface() =default; + + virtual ~SpaceInterface() =default; + }; + + template + class AlgorithmInterface { + public: + virtual void addPoint(const void *datapoint, labeltype label)=0; + virtual std::priority_queue> searchKnn(const void *, size_t, faiss::ConcurrentBitsetPtr bitset) const = 0; + template + std::vector> searchKnn(const void*, size_t, Comp, faiss::ConcurrentBitsetPtr bitset) { + } + virtual void saveIndex(const std::string &location)=0; + virtual ~AlgorithmInterface(){ + } + }; +} + +#include "space_l2.h" +#include "space_ip.h" +#include "bruteforce.h" +#include "hnswalg.h" + diff --git a/core/src/index/thirdparty/hnswlib/hnswlib_nm.h b/core/src/index/thirdparty/hnswlib/hnswlib_nm.h new file mode 100644 index 0000000000..31142568ee --- /dev/null +++ b/core/src/index/thirdparty/hnswlib/hnswlib_nm.h @@ -0,0 +1,98 @@ +#pragma once +#ifndef NO_MANUAL_VECTORIZATION +#ifdef __SSE__ +#define USE_SSE +#ifdef __AVX__ +#define USE_AVX +#endif +#endif +#endif + +#if defined(USE_AVX) || defined(USE_SSE) +#ifdef _MSC_VER +#include +#include +#else +#include +#endif + +#if defined(__GNUC__) +#define PORTABLE_ALIGN32 __attribute__((aligned(32))) +#else +#define PORTABLE_ALIGN32 __declspec(align(32)) +#endif +#endif + +#include +#include +#include + +#include +#include + +namespace hnswlib_nm { + typedef int64_t labeltype; + + template + class pairGreater { + public: + bool operator()(const T& p1, const T& p2) { + return p1.first > p2.first; + } + }; + + template + static void writeBinaryPOD(std::ostream &out, const T &podRef) { + out.write((char *) &podRef, sizeof(T)); + } + + template + static void readBinaryPOD(std::istream &in, T &podRef) { + in.read((char *) &podRef, sizeof(T)); + } + + template + static void writeBinaryPOD(W &out, const T &podRef) { + out.write((char *) &podRef, sizeof(T)); + } + + template + static void readBinaryPOD(R &in, T &podRef) { + in.read((char *) &podRef, sizeof(T)); + } + + template + using DISTFUNC = MTYPE(*)(const void *, const void *, const void *); + + + template + class SpaceInterface { + public: + //virtual void search(void *); + virtual size_t get_data_size() = 0; + + virtual DISTFUNC get_dist_func() = 0; + + virtual void *get_dist_func_param() = 0; + + virtual ~SpaceInterface() {} + }; + + template + class AlgorithmInterface { + public: + virtual void addPoint(void *datapoint, labeltype label, size_t base, size_t offset)=0; + virtual std::priority_queue> searchKnn_NM(const void *, size_t, faiss::ConcurrentBitsetPtr bitset, dist_t *pdata) const = 0; + template + std::vector> searchKnn_NM(const void*, size_t, Comp, faiss::ConcurrentBitsetPtr bitset, dist_t *pdata) { + } + virtual void saveIndex(const std::string &location)=0; + virtual ~AlgorithmInterface(){ + } + }; +} + +#include "space_l2.h" +#include "space_ip.h" +#include "bruteforce.h" +#include "hnswalg_nm.h" \ No newline at end of file diff --git a/core/src/index/thirdparty/hnswlib/space_ip.h b/core/src/index/thirdparty/hnswlib/space_ip.h new file mode 100644 index 0000000000..fc25485e7d --- /dev/null +++ b/core/src/index/thirdparty/hnswlib/space_ip.h @@ -0,0 +1,253 @@ +#pragma once +#include "hnswlib.h" +#include + +namespace hnswlib { + +static float +InnerProduct(const void *pVect1, const void *pVect2, const void *qty_ptr) { +#if 0 /* use FAISS distance calculation algorithm instead */ + size_t qty = *((size_t *) qty_ptr); + float res = 0; + for (unsigned i = 0; i < qty; i++) { + res += ((float *) pVect1)[i] * ((float *) pVect2)[i]; + } + return (1.0f - res); +#else + return (1.0f - faiss::fvec_inner_product((const float*)pVect1, (const float*)pVect2, *((size_t*)qty_ptr))); +#endif +} + +#if 0 /* use FAISS distance calculation algorithm instead */ +#if defined(USE_AVX) + +// Favor using AVX if available. +static float +InnerProductSIMD4Ext(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { + float PORTABLE_ALIGN32 TmpRes[8]; + float *pVect1 = (float *) pVect1v; + float *pVect2 = (float *) pVect2v; + size_t qty = *((size_t *) qty_ptr); + + size_t qty16 = qty / 16; + size_t qty4 = qty / 4; + + const float *pEnd1 = pVect1 + 16 * qty16; + const float *pEnd2 = pVect1 + 4 * qty4; + + __m256 sum256 = _mm256_set1_ps(0); + + while (pVect1 < pEnd1) { + //_mm_prefetch((char*)(pVect2 + 16), _MM_HINT_T0); + + __m256 v1 = _mm256_loadu_ps(pVect1); + pVect1 += 8; + __m256 v2 = _mm256_loadu_ps(pVect2); + pVect2 += 8; + sum256 = _mm256_add_ps(sum256, _mm256_mul_ps(v1, v2)); + + v1 = _mm256_loadu_ps(pVect1); + pVect1 += 8; + v2 = _mm256_loadu_ps(pVect2); + pVect2 += 8; + sum256 = _mm256_add_ps(sum256, _mm256_mul_ps(v1, v2)); + } + + __m128 v1, v2; + __m128 sum_prod = _mm_add_ps(_mm256_extractf128_ps(sum256, 0), _mm256_extractf128_ps(sum256, 1)); + + while (pVect1 < pEnd2) { + v1 = _mm_loadu_ps(pVect1); + pVect1 += 4; + v2 = _mm_loadu_ps(pVect2); + pVect2 += 4; + sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); + } + + _mm_store_ps(TmpRes, sum_prod); + float sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3];; + return 1.0f - sum; +} + +#elif defined(USE_SSE) + +static float +InnerProductSIMD4Ext(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { + float PORTABLE_ALIGN32 TmpRes[8]; + float *pVect1 = (float *) pVect1v; + float *pVect2 = (float *) pVect2v; + size_t qty = *((size_t *) qty_ptr); + + size_t qty16 = qty / 16; + size_t qty4 = qty / 4; + + const float *pEnd1 = pVect1 + 16 * qty16; + const float *pEnd2 = pVect1 + 4 * qty4; + + __m128 v1, v2; + __m128 sum_prod = _mm_set1_ps(0); + + while (pVect1 < pEnd1) { + v1 = _mm_loadu_ps(pVect1); + pVect1 += 4; + v2 = _mm_loadu_ps(pVect2); + pVect2 += 4; + sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); + + v1 = _mm_loadu_ps(pVect1); + pVect1 += 4; + v2 = _mm_loadu_ps(pVect2); + pVect2 += 4; + sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); + + v1 = _mm_loadu_ps(pVect1); + pVect1 += 4; + v2 = _mm_loadu_ps(pVect2); + pVect2 += 4; + sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); + + v1 = _mm_loadu_ps(pVect1); + pVect1 += 4; + v2 = _mm_loadu_ps(pVect2); + pVect2 += 4; + sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); + } + + while (pVect1 < pEnd2) { + v1 = _mm_loadu_ps(pVect1); + pVect1 += 4; + v2 = _mm_loadu_ps(pVect2); + pVect2 += 4; + sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); + } + + _mm_store_ps(TmpRes, sum_prod); + float sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3]; + + return 1.0f - sum; +} + +#endif + +#if defined(USE_AVX) + +static float +InnerProductSIMD16Ext(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { + float PORTABLE_ALIGN32 TmpRes[8]; + float *pVect1 = (float *) pVect1v; + float *pVect2 = (float *) pVect2v; + size_t qty = *((size_t *) qty_ptr); + + size_t qty16 = qty / 16; + + + const float *pEnd1 = pVect1 + 16 * qty16; + + __m256 sum256 = _mm256_set1_ps(0); + + while (pVect1 < pEnd1) { + //_mm_prefetch((char*)(pVect2 + 16), _MM_HINT_T0); + + __m256 v1 = _mm256_loadu_ps(pVect1); + pVect1 += 8; + __m256 v2 = _mm256_loadu_ps(pVect2); + pVect2 += 8; + sum256 = _mm256_add_ps(sum256, _mm256_mul_ps(v1, v2)); + + v1 = _mm256_loadu_ps(pVect1); + pVect1 += 8; + v2 = _mm256_loadu_ps(pVect2); + pVect2 += 8; + sum256 = _mm256_add_ps(sum256, _mm256_mul_ps(v1, v2)); + } + + _mm256_store_ps(TmpRes, sum256); + float sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3] + TmpRes[4] + TmpRes[5] + TmpRes[6] + TmpRes[7]; + + return 1.0f - sum; +} + +#elif defined(USE_SSE) + +static float +InnerProductSIMD16Ext(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { + float PORTABLE_ALIGN32 TmpRes[8]; + float *pVect1 = (float *) pVect1v; + float *pVect2 = (float *) pVect2v; + size_t qty = *((size_t *) qty_ptr); + + size_t qty16 = qty / 16; + + const float *pEnd1 = pVect1 + 16 * qty16; + + __m128 v1, v2; + __m128 sum_prod = _mm_set1_ps(0); + + while (pVect1 < pEnd1) { + v1 = _mm_loadu_ps(pVect1); + pVect1 += 4; + v2 = _mm_loadu_ps(pVect2); + pVect2 += 4; + sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); + + v1 = _mm_loadu_ps(pVect1); + pVect1 += 4; + v2 = _mm_loadu_ps(pVect2); + pVect2 += 4; + sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); + + v1 = _mm_loadu_ps(pVect1); + pVect1 += 4; + v2 = _mm_loadu_ps(pVect2); + pVect2 += 4; + sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); + + v1 = _mm_loadu_ps(pVect1); + pVect1 += 4; + v2 = _mm_loadu_ps(pVect2); + pVect2 += 4; + sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); + } + _mm_store_ps(TmpRes, sum_prod); + float sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3]; + + return 1.0f - sum; +} + +#endif +#endif + +class InnerProductSpace : public SpaceInterface { + DISTFUNC fstdistfunc_; + size_t data_size_; + size_t dim_; + public: + InnerProductSpace(size_t dim) { + fstdistfunc_ = InnerProduct; +#if 0 /* use FAISS distance calculation algorithm instead */ +#if defined(USE_AVX) || defined(USE_SSE) + if (dim % 4 == 0) + fstdistfunc_ = InnerProductSIMD4Ext; + if (dim % 16 == 0) + fstdistfunc_ = InnerProductSIMD16Ext; +#endif +#endif + dim_ = dim; + data_size_ = dim * sizeof(float); + } + + size_t get_data_size() { + return data_size_; + } + + DISTFUNC get_dist_func() { + return fstdistfunc_; + } + + void *get_dist_func_param() { + return &dim_; + } + + ~InnerProductSpace() {} +}; +} diff --git a/core/src/index/thirdparty/hnswlib/space_l2.h b/core/src/index/thirdparty/hnswlib/space_l2.h new file mode 100644 index 0000000000..3fd0d2da2c --- /dev/null +++ b/core/src/index/thirdparty/hnswlib/space_l2.h @@ -0,0 +1,245 @@ +#pragma once +#include "hnswlib.h" +#include + +namespace hnswlib { + +static float +L2Sqr(const void *pVect1, const void *pVect2, const void *qty_ptr) { +#if 0 /* use FAISS distance calculation algorithm instead */ + //return *((float *)pVect2); + size_t qty = *((size_t *) qty_ptr); + float res = 0; + for (unsigned i = 0; i < qty; i++) { + float t = ((float *) pVect1)[i] - ((float *) pVect2)[i]; + res += t * t; + } + return (res); +#else + return faiss::fvec_L2sqr((const float*)pVect1, (const float*)pVect2, *((size_t*)qty_ptr)); +#endif +} + +#if 0 /* use FAISS distance calculation algorithm instead */ +#if defined(USE_AVX) + +// Favor using AVX if available. +static float +L2SqrSIMD16Ext(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { + float *pVect1 = (float *) pVect1v; + float *pVect2 = (float *) pVect2v; + size_t qty = *((size_t *) qty_ptr); + float PORTABLE_ALIGN32 TmpRes[8]; + size_t qty16 = qty >> 4; + + const float *pEnd1 = pVect1 + (qty16 << 4); + + __m256 diff, v1, v2; + __m256 sum = _mm256_set1_ps(0); + + while (pVect1 < pEnd1) { + v1 = _mm256_loadu_ps(pVect1); + pVect1 += 8; + v2 = _mm256_loadu_ps(pVect2); + pVect2 += 8; + diff = _mm256_sub_ps(v1, v2); + sum = _mm256_add_ps(sum, _mm256_mul_ps(diff, diff)); + + v1 = _mm256_loadu_ps(pVect1); + pVect1 += 8; + v2 = _mm256_loadu_ps(pVect2); + pVect2 += 8; + diff = _mm256_sub_ps(v1, v2); + sum = _mm256_add_ps(sum, _mm256_mul_ps(diff, diff)); + } + + _mm256_store_ps(TmpRes, sum); + float res = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3] + TmpRes[4] + TmpRes[5] + TmpRes[6] + TmpRes[7]; + + return (res); +} + +#elif defined(USE_SSE) + +static float +L2SqrSIMD16Ext(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { + float *pVect1 = (float *) pVect1v; + float *pVect2 = (float *) pVect2v; + size_t qty = *((size_t *) qty_ptr); + float PORTABLE_ALIGN32 TmpRes[8]; + // size_t qty4 = qty >> 2; + size_t qty16 = qty >> 4; + + const float *pEnd1 = pVect1 + (qty16 << 4); + // const float* pEnd2 = pVect1 + (qty4 << 2); + // const float* pEnd3 = pVect1 + qty; + + __m128 diff, v1, v2; + __m128 sum = _mm_set1_ps(0); + + while (pVect1 < pEnd1) { + //_mm_prefetch((char*)(pVect2 + 16), _MM_HINT_T0); + v1 = _mm_loadu_ps(pVect1); + pVect1 += 4; + v2 = _mm_loadu_ps(pVect2); + pVect2 += 4; + diff = _mm_sub_ps(v1, v2); + sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff)); + + v1 = _mm_loadu_ps(pVect1); + pVect1 += 4; + v2 = _mm_loadu_ps(pVect2); + pVect2 += 4; + diff = _mm_sub_ps(v1, v2); + sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff)); + + v1 = _mm_loadu_ps(pVect1); + pVect1 += 4; + v2 = _mm_loadu_ps(pVect2); + pVect2 += 4; + diff = _mm_sub_ps(v1, v2); + sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff)); + + v1 = _mm_loadu_ps(pVect1); + pVect1 += 4; + v2 = _mm_loadu_ps(pVect2); + pVect2 += 4; + diff = _mm_sub_ps(v1, v2); + sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff)); + } + _mm_store_ps(TmpRes, sum); + float res = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3]; + + return (res); +} +#endif + + +#ifdef USE_SSE +static float +L2SqrSIMD4Ext(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { + float PORTABLE_ALIGN32 TmpRes[8]; + float *pVect1 = (float *) pVect1v; + float *pVect2 = (float *) pVect2v; + size_t qty = *((size_t *) qty_ptr); + + + // size_t qty4 = qty >> 2; + size_t qty16 = qty >> 2; + + const float *pEnd1 = pVect1 + (qty16 << 2); + + __m128 diff, v1, v2; + __m128 sum = _mm_set1_ps(0); + + while (pVect1 < pEnd1) { + v1 = _mm_loadu_ps(pVect1); + pVect1 += 4; + v2 = _mm_loadu_ps(pVect2); + pVect2 += 4; + diff = _mm_sub_ps(v1, v2); + sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff)); + } + _mm_store_ps(TmpRes, sum); + float res = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3]; + + return (res); +} +#endif +#endif + +class L2Space : public SpaceInterface { + DISTFUNC fstdistfunc_; + size_t data_size_; + size_t dim_; + public: + L2Space(size_t dim) { + fstdistfunc_ = L2Sqr; +#if 0 /* use FAISS distance calculation algorithm instead */ +#if defined(USE_SSE) || defined(USE_AVX) + if (dim % 4 == 0) + fstdistfunc_ = L2SqrSIMD4Ext; + if (dim % 16 == 0) + fstdistfunc_ = L2SqrSIMD16Ext; + /*else{ + throw runtime_error("Data type not supported!"); + }*/ +#endif +#endif + dim_ = dim; + data_size_ = dim * sizeof(float); + } + + size_t get_data_size() { + return data_size_; + } + + DISTFUNC get_dist_func() { + return fstdistfunc_; + } + + void *get_dist_func_param() { + return &dim_; + } + + ~L2Space() {} +}; + +static int +L2SqrI(const void *__restrict pVect1, const void *__restrict pVect2, const void *__restrict qty_ptr) { + size_t qty = *((size_t *) qty_ptr); + int res = 0; + unsigned char *a = (unsigned char *) pVect1; + unsigned char *b = (unsigned char *) pVect2; + /*for (int i = 0; i < qty; i++) { + int t = int((a)[i]) - int((b)[i]); + res += t*t; + }*/ + + qty = qty >> 2; + for (size_t i = 0; i < qty; i++) { + + res += ((*a) - (*b)) * ((*a) - (*b)); + a++; + b++; + res += ((*a) - (*b)) * ((*a) - (*b)); + a++; + b++; + res += ((*a) - (*b)) * ((*a) - (*b)); + a++; + b++; + res += ((*a) - (*b)) * ((*a) - (*b)); + a++; + b++; + } + + return (res); +} + +class L2SpaceI : public SpaceInterface { + DISTFUNC fstdistfunc_; + size_t data_size_; + size_t dim_; + public: + L2SpaceI(size_t dim) { + fstdistfunc_ = L2SqrI; + dim_ = dim; + data_size_ = dim * sizeof(unsigned char); + } + + size_t get_data_size() { + return data_size_; + } + + DISTFUNC get_dist_func() { + return fstdistfunc_; + } + + void *get_dist_func_param() { + return &dim_; + } + + ~L2SpaceI() {} +}; + +} diff --git a/core/src/index/thirdparty/hnswlib/visited_list_pool.h b/core/src/index/thirdparty/hnswlib/visited_list_pool.h new file mode 100644 index 0000000000..c4f7d4bf6f --- /dev/null +++ b/core/src/index/thirdparty/hnswlib/visited_list_pool.h @@ -0,0 +1,87 @@ +#pragma once + +#include +#include +#include + +namespace hnswlib { +typedef unsigned short int vl_type; + +class VisitedList { + public: + vl_type curV; + vl_type *mass; + unsigned int numelements; + + VisitedList(int numelements1) { + curV = -1; + numelements = numelements1; + mass = new vl_type[numelements]; + } + + void reset() { + curV++; + if (curV == 0) { + memset(mass, 0, sizeof(vl_type) * numelements); + curV++; + } + }; + + + ~VisitedList() { delete[] mass; } +}; + +/////////////////////////////////////////////////////////// +// +// Class for multi-threaded pool-management of VisitedLists +// +///////////////////////////////////////////////////////// + +class VisitedListPool { + std::deque pool; + std::mutex poolguard; + int numelements; + + public: + VisitedListPool(int initmaxpools, int numelements1) { + numelements = numelements1; + for (int i = 0; i < initmaxpools; i++) + pool.push_front(new VisitedList(numelements)); + } + + VisitedList *getFreeVisitedList() { + VisitedList *rez; + { + std::unique_lock lock(poolguard); + if (pool.size() > 0) { + rez = pool.front(); + pool.pop_front(); + } else { + rez = new VisitedList(numelements); + } + } + rez->reset(); + return rez; + }; + + void releaseVisitedList(VisitedList *vl) { + std::unique_lock lock(poolguard); + pool.push_front(vl); + }; + + ~VisitedListPool() { + while (pool.size()) { + VisitedList *rez = pool.front(); + pool.pop_front(); + delete rez; + } + }; + + int64_t GetSize() { + auto visit_list_size = sizeof(VisitedList) + numelements * sizeof(vl_type); + auto pool_size = pool.size() * (sizeof(VisitedList *) + visit_list_size); + return pool_size + sizeof(*this); + } +}; +} + diff --git a/core/src/index/thirdparty/versions.txt b/core/src/index/thirdparty/versions.txt new file mode 100644 index 0000000000..c5cea80234 --- /dev/null +++ b/core/src/index/thirdparty/versions.txt @@ -0,0 +1,6 @@ +ARROW_VERSION=apache-arrow-0.15.1 +BOOST_VERSION=1.70.0 +GTEST_VERSION=1.8.1 +LAPACK_VERSION=v3.8.0 +OPENBLAS_VERSION=0.3.9 +MKL_VERSION=2019.5.281 diff --git a/core/src/index/unittest/CMakeLists.txt b/core/src/index/unittest/CMakeLists.txt new file mode 100644 index 0000000000..bbdd3183a6 --- /dev/null +++ b/core/src/index/unittest/CMakeLists.txt @@ -0,0 +1,269 @@ +include_directories(${INDEX_SOURCE_DIR}/thirdparty) +include_directories(${INDEX_SOURCE_DIR}/thirdparty/SPTAG/AnnService) +include_directories(${INDEX_SOURCE_DIR}/knowhere) +include_directories(${INDEX_SOURCE_DIR}) + +set(depend_libs + gtest gmock gtest_main gmock_main + faiss + ) +if (FAISS_WITH_MKL) + set(depend_libs ${depend_libs} + "-Wl,--start-group \ + ${MKL_LIB_PATH}/libmkl_intel_ilp64.a \ + ${MKL_LIB_PATH}/libmkl_gnu_thread.a \ + ${MKL_LIB_PATH}/libmkl_core.a \ + -Wl,--end-group -lgomp -lpthread -lm -ldl" + ) +else () + set(depend_libs ${depend_libs} + ${OpenBLAS_LIBRARIES} + ${LAPACK_LIBRARIES} + ) +endif () + +set(basic_libs + gomp gfortran pthread + ) + +set(util_srcs + ${MILVUS_THIRDPARTY_SRC}/easyloggingpp/easylogging++.cc + ${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/adapter/VectorAdapter.cpp + ${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/helpers/FaissIO.cpp + ${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/helpers/IndexParameter.cpp + ${INDEX_SOURCE_DIR}/knowhere/knowhere/index/IndexType.cpp + ${INDEX_SOURCE_DIR}/knowhere/knowhere/common/Exception.cpp + ${INDEX_SOURCE_DIR}/knowhere/knowhere/common/Log.cpp + ${INDEX_SOURCE_DIR}/knowhere/knowhere/common/Timer.cpp + ${INDEX_SOURCE_DIR}/unittest/utils.cpp + ) + +if (MILVUS_GPU_VERSION) + include_directories(${CUDA_INCLUDE_DIRS}) + link_directories("${CUDA_TOOLKIT_ROOT_DIR}/lib64") + set(cuda_lib + cudart + cublas + ) + set(basic_libs ${basic_libs} + ${cuda_lib} + ) + set(util_srcs ${util_srcs} + ${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/helpers/FaissGpuResourceMgr.cpp + ) +endif () + +set(faiss_srcs + ${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/FaissBaseIndex.cpp + ${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/FaissBaseBinaryIndex.cpp + ${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/IndexBinaryIDMAP.cpp + ${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/IndexBinaryIVF.cpp + ${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/IndexIDMAP.cpp + ${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/IndexIVF.cpp + ${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/IndexIVFSQ.cpp + ${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/IndexIVFPQ.cpp + ${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_offset_index/OffsetBaseIndex.cpp + ${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_offset_index/IndexIVF_NM.cpp + ) +if (MILVUS_GPU_VERSION) +set(faiss_srcs ${faiss_srcs} + ${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/helpers/Cloner.cpp + ${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/gpu/IndexGPUIDMAP.cpp + ${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/gpu/IndexGPUIVF.cpp + ${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/gpu/IndexGPUIVFSQ.cpp + ${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/gpu/IndexGPUIVFPQ.cpp + ${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/gpu/IndexIVFSQHybrid.cpp + ${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_offset_index/gpu/IndexGPUIVF_NM.cpp + ) +endif () + +################################################################################ +# +if (NOT TARGET test_instructionset) + add_executable(test_instructionset test_instructionset.cpp) +endif () +target_link_libraries(test_instructionset ${depend_libs} ${unittest_libs}) +install(TARGETS test_instructionset DESTINATION unittest) + +################################################################################ +# +if (NOT TARGET test_knowhere_common) + add_executable(test_knowhere_common test_common.cpp ${util_srcs}) +endif () +target_link_libraries(test_knowhere_common ${depend_libs} ${unittest_libs} ${basic_libs}) +install(TARGETS test_knowhere_common DESTINATION unittest) + +if (MILVUS_GPU_VERSION) +################################################################################ +# +add_executable(test_gpuresource test_gpuresource.cpp ${util_srcs} ${faiss_srcs}) +target_link_libraries(test_gpuresource ${depend_libs} ${unittest_libs} ${basic_libs}) +install(TARGETS test_gpuresource DESTINATION unittest) + +################################################################################ +# +add_executable(test_customized_index test_customized_index.cpp ${util_srcs} ${faiss_srcs}) +target_link_libraries(test_customized_index ${depend_libs} ${unittest_libs} ${basic_libs}) +install(TARGETS test_customized_index DESTINATION unittest) +endif () + +################################################################################ +# +if (NOT TARGET test_idmap) + add_executable(test_idmap test_idmap.cpp ${faiss_srcs} ${util_srcs}) +endif () +target_link_libraries(test_idmap ${depend_libs} ${unittest_libs} ${basic_libs}) +install(TARGETS test_idmap DESTINATION unittest) + +################################################################################ +# +if (NOT TARGET test_ivf) + add_executable(test_ivf test_ivf.cpp ${faiss_srcs} ${util_srcs}) +endif () +target_link_libraries(test_ivf ${depend_libs} ${unittest_libs} ${basic_libs}) +install(TARGETS test_ivf DESTINATION unittest) + +################################################################################ +# +if (NOT TARGET test_ivf_cpu_nm) + add_executable(test_ivf_cpu_nm test_ivf_cpu_nm.cpp ${faiss_srcs} ${util_srcs}) +endif () +target_link_libraries(test_ivf_cpu_nm ${depend_libs} ${unittest_libs} ${basic_libs}) +install(TARGETS test_ivf_cpu_nm DESTINATION unittest) + +################################################################################ +# +if (NOT TARGET test_ivf_gpu_nm) + add_executable(test_ivf_gpu_nm test_ivf_gpu_nm.cpp ${faiss_srcs} ${util_srcs}) +endif () +target_link_libraries(test_ivf_gpu_nm ${depend_libs} ${unittest_libs} ${basic_libs}) +install(TARGETS test_ivf_gpu_nm DESTINATION unittest) + +################################################################################ +# +if (NOT TARGET test_binaryidmap) + add_executable(test_binaryidmap test_binaryidmap.cpp ${faiss_srcs} ${util_srcs}) +endif () +target_link_libraries(test_binaryidmap ${depend_libs} ${unittest_libs} ${basic_libs}) +install(TARGETS test_binaryidmap DESTINATION unittest) + +################################################################################ +# +if (NOT TARGET test_binaryivf) + add_executable(test_binaryivf test_binaryivf.cpp ${faiss_srcs} ${util_srcs}) +endif () +target_link_libraries(test_binaryivf ${depend_libs} ${unittest_libs} ${basic_libs}) +install(TARGETS test_binaryivf DESTINATION unittest) + + +################################################################################ +# +add_definitions(-std=c++11 -O3 -march=native -Werror -DINFO) + +find_package(OpenMP REQUIRED) +if (OpenMP_FOUND) + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}") +else () + message(FATAL_ERROR "no OpenMP supprot") +endif () + +include_directories(${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/impl/nsg) +aux_source_directory(${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/impl/nsg nsg_src) +set(interface_src + ${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_offset_index/IndexNSG_NM.cpp + ) +if (NOT TARGET test_nsg) + add_executable(test_nsg test_nsg.cpp ${interface_src} ${nsg_src} ${util_srcs} ${faiss_srcs}) +endif () +target_link_libraries(test_nsg ${depend_libs} ${unittest_libs} ${basic_libs}) +install(TARGETS test_nsg DESTINATION unittest) + +################################################################################ +# +set(hnsw_srcs + ${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/IndexHNSW.cpp + ) +if (NOT TARGET test_hnsw) + add_executable(test_hnsw test_hnsw.cpp ${hnsw_srcs} ${util_srcs}) +endif () +target_link_libraries(test_hnsw ${depend_libs} ${unittest_libs} ${basic_libs}) +install(TARGETS test_hnsw DESTINATION unittest) + +################################################################################ +# +set(rhnsw_flat_srcs + ${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/IndexRHNSW.cpp + ${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/IndexRHNSWFlat.cpp + ) +if (NOT TARGET test_rhnsw_flat) + add_executable(test_rhnsw_flat test_rhnsw_flat.cpp ${rhnsw_flat_srcs} ${util_srcs} ${faiss_srcs}) +endif () +target_link_libraries(test_rhnsw_flat ${depend_libs} ${unittest_libs} ${basic_libs}) +install(TARGETS test_rhnsw_flat DESTINATION unittest) + +################################################################################ +# +set(rhnsw_pq_srcs + ${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/IndexRHNSW.cpp + ${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/IndexRHNSWPQ.cpp + ) +if (NOT TARGET test_rhnsw_pq) + add_executable(test_rhnsw_pq test_rhnsw_pq.cpp ${rhnsw_pq_srcs} ${util_srcs} ${faiss_srcs}) +endif () +target_link_libraries(test_rhnsw_pq ${depend_libs} ${unittest_libs} ${basic_libs}) +install(TARGETS test_rhnsw_pq DESTINATION unittest) + +################################################################################ +# +set(rhnsw_sq8_srcs + ${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/IndexRHNSW.cpp + ${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/IndexRHNSWSQ.cpp + ) +if (NOT TARGET test_rhnsw_sq8) + add_executable(test_rhnsw_sq8 test_rhnsw_sq8.cpp ${rhnsw_sq8_srcs} ${util_srcs} ${faiss_srcs}) +endif () +target_link_libraries(test_rhnsw_sq8 ${depend_libs} ${unittest_libs} ${basic_libs}) +install(TARGETS test_rhnsw_sq8 DESTINATION unittest) + +################################################################################ +# +if (MILVUS_SUPPORT_SPTAG) + set(sptag_srcs + ${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/adapter/SptagAdapter.cpp + ${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/helpers/SPTAGParameterMgr.cpp + ${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/IndexSPTAG.cpp + ) + if (NOT TARGET test_sptag) + add_executable(test_sptag test_sptag.cpp ${sptag_srcs} ${util_srcs}) + endif () + target_link_libraries(test_sptag + SPTAGLibStatic + ${depend_libs} ${unittest_libs} ${basic_libs}) + install(TARGETS test_sptag DESTINATION unittest) +endif () + +################################################################################ +# +set(annoy_srcs + ${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/IndexAnnoy.cpp + ) +if (NOT TARGET test_annoy) + add_executable(test_annoy test_annoy.cpp ${annoy_srcs} ${util_srcs}) +endif () +target_link_libraries(test_annoy ${depend_libs} ${unittest_libs} ${basic_libs}) +install(TARGETS test_annoy DESTINATION unittest) + +################################################################################ +# +set(structured_index_sort_srcs + ${INDEX_SOURCE_DIR}/knowhere/knowhere/index/structured_index/StructuredIndexSort-inl.h + ) +if (NOT TARGET test_structured_index_sort) + add_executable(test_structured_index_sort test_structured_index_sort.cpp ${structured_index_sort_srcs} ${util_srcs}) +endif () +target_link_libraries(test_structured_index_sort ${depend_libs} ${unittest_libs} ${basic_libs}) +install(TARGETS test_structured_index_sort DESTINATION unittest) + +#add_subdirectory(faiss_benchmark) +#add_subdirectory(metric_alg_benchmark) diff --git a/core/src/index/unittest/Helper.h b/core/src/index/unittest/Helper.h new file mode 100644 index 0000000000..cd4808c9c5 --- /dev/null +++ b/core/src/index/unittest/Helper.h @@ -0,0 +1,148 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include +#include + +#include "knowhere/index/IndexType.h" +#include "knowhere/index/vector_index/IndexIVF.h" +#include "knowhere/index/vector_index/IndexIVFPQ.h" +#include "knowhere/index/vector_index/IndexIVFSQ.h" +#include "knowhere/index/vector_index/helpers/IndexParameter.h" +#include "knowhere/index/vector_offset_index/IndexIVF_NM.h" + +#ifdef MILVUS_GPU_VERSION +#include "knowhere/index/vector_index/gpu/IndexGPUIVF.h" +#include "knowhere/index/vector_index/gpu/IndexGPUIVFPQ.h" +#include "knowhere/index/vector_index/gpu/IndexGPUIVFSQ.h" +#include "knowhere/index/vector_index/gpu/IndexIVFSQHybrid.h" +#include "knowhere/index/vector_offset_index/gpu/IndexGPUIVF_NM.h" +#endif + +int DEVICEID = 0; +constexpr int64_t DIM = 128; +constexpr int64_t NB = 10000; +constexpr int64_t NQ = 10; +constexpr int64_t K = 10; +constexpr int64_t PINMEM = 1024 * 1024 * 200; +constexpr int64_t TEMPMEM = 1024 * 1024 * 300; +constexpr int64_t RESNUM = 2; + +milvus::knowhere::IVFPtr +IndexFactory(const milvus::knowhere::IndexType& type, const milvus::knowhere::IndexMode mode) { + if (mode == milvus::knowhere::IndexMode::MODE_CPU) { + if (type == milvus::knowhere::IndexEnum::INDEX_FAISS_IVFFLAT) { + return std::make_shared(); + } else if (type == milvus::knowhere::IndexEnum::INDEX_FAISS_IVFPQ) { + return std::make_shared(); + } else if (type == milvus::knowhere::IndexEnum::INDEX_FAISS_IVFSQ8) { + return std::make_shared(); + } else if (type == milvus::knowhere::IndexEnum::INDEX_FAISS_IVFSQ8H) { + std::cout << "IVFSQ8H does not support MODE_CPU" << std::endl; + } else { + std::cout << "Invalid IndexType " << type << std::endl; + } +#ifdef MILVUS_GPU_VERSION + } else { + if (type == milvus::knowhere::IndexEnum::INDEX_FAISS_IVFFLAT) { + return std::make_shared(DEVICEID); + } else if (type == milvus::knowhere::IndexEnum::INDEX_FAISS_IVFPQ) { + return std::make_shared(DEVICEID); + } else if (type == milvus::knowhere::IndexEnum::INDEX_FAISS_IVFSQ8) { + return std::make_shared(DEVICEID); + } else if (type == milvus::knowhere::IndexEnum::INDEX_FAISS_IVFSQ8H) { + return std::make_shared(DEVICEID); + } else { + std::cout << "Invalid IndexType " << type << std::endl; + } +#endif + } + return nullptr; +} + +milvus::knowhere::IVFNMPtr +IndexFactoryNM(const milvus::knowhere::IndexType& type, const milvus::knowhere::IndexMode mode) { + if (mode == milvus::knowhere::IndexMode::MODE_CPU) { + if (type == milvus::knowhere::IndexEnum::INDEX_FAISS_IVFFLAT) { + return std::make_shared(); + } else { + std::cout << "Invalid IndexType " << type << std::endl; + } + } + return nullptr; +} + +class ParamGenerator { + public: + static ParamGenerator& + GetInstance() { + static ParamGenerator instance; + return instance; + } + + milvus::knowhere::Config + Gen(const milvus::knowhere::IndexType& type) { + if (type == milvus::knowhere::IndexEnum::INDEX_FAISS_IVFFLAT) { + return milvus::knowhere::Config{ + {milvus::knowhere::meta::DIM, DIM}, + {milvus::knowhere::meta::TOPK, K}, + {milvus::knowhere::IndexParams::nlist, 100}, + {milvus::knowhere::IndexParams::nprobe, 4}, + {milvus::knowhere::Metric::TYPE, milvus::knowhere::Metric::L2}, + {milvus::knowhere::meta::DEVICEID, DEVICEID}, + }; + } else if (type == milvus::knowhere::IndexEnum::INDEX_FAISS_IVFPQ) { + return milvus::knowhere::Config{ + {milvus::knowhere::meta::DIM, DIM}, + {milvus::knowhere::meta::TOPK, K}, + {milvus::knowhere::IndexParams::nlist, 100}, + {milvus::knowhere::IndexParams::nprobe, 4}, + {milvus::knowhere::IndexParams::m, 4}, + {milvus::knowhere::IndexParams::nbits, 8}, + {milvus::knowhere::Metric::TYPE, milvus::knowhere::Metric::L2}, + {milvus::knowhere::meta::DEVICEID, DEVICEID}, + }; + } else if (type == milvus::knowhere::IndexEnum::INDEX_FAISS_IVFSQ8 || + type == milvus::knowhere::IndexEnum::INDEX_FAISS_IVFSQ8H) { + return milvus::knowhere::Config{ + {milvus::knowhere::meta::DIM, DIM}, + {milvus::knowhere::meta::TOPK, K}, + {milvus::knowhere::IndexParams::nlist, 100}, + {milvus::knowhere::IndexParams::nprobe, 4}, + {milvus::knowhere::IndexParams::nbits, 8}, + {milvus::knowhere::Metric::TYPE, milvus::knowhere::Metric::L2}, + {milvus::knowhere::meta::DEVICEID, DEVICEID}, + }; + } else { + std::cout << "Invalid index type " << type << std::endl; + } + return milvus::knowhere::Config(); + } +}; + +#include + +class TestGpuIndexBase : public ::testing::Test { + protected: + void + SetUp() override { +#ifdef MILVUS_GPU_VERSION + milvus::knowhere::FaissGpuResourceMgr::GetInstance().InitDevice(DEVICEID, PINMEM, TEMPMEM, RESNUM); +#endif + } + + void + TearDown() override { +#ifdef MILVUS_GPU_VERSION + milvus::knowhere::FaissGpuResourceMgr::GetInstance().Free(); +#endif + } +}; diff --git a/core/src/index/unittest/SPTAG.cpp b/core/src/index/unittest/SPTAG.cpp new file mode 100644 index 0000000000..c12a8060b7 --- /dev/null +++ b/core/src/index/unittest/SPTAG.cpp @@ -0,0 +1,44 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. +// +//#include +//#include +//#include +//#include +//#include +// +// int +// main(int argc, char* argv[]) { +// using namespace SPTAG; +// const int d = 128; +// const int n = 100; +// +// auto p_data = new float[n * d]; +// +// auto index = VectorIndex::CreateInstance(IndexAlgoType::KDT, VectorValueType::Float); +// +// std::random_device rd; +// std::mt19937 mt(rd()); +// std::uniform_real_distribution dist(1.0, 2.0); +// +// for (auto i = 0; i < n; i++) { +// for (auto j = 0; j < d; j++) { +// p_data[i * d + j] = dist(mt) - 1; +// } +// } +// std::cout << "generate random n * d finished."; +// ByteArray data((uint8_t*)p_data, n * d * sizeof(float), true); +// +// auto vectorset = std::make_shared(data, VectorValueType::Float, d, n); +// index->BuildIndex(vectorset, nullptr); +// +// std::cout << index->GetFeatureDim(); +//} diff --git a/core/src/index/unittest/faiss_benchmark/CMakeLists.txt b/core/src/index/unittest/faiss_benchmark/CMakeLists.txt new file mode 100644 index 0000000000..c0a53fbd94 --- /dev/null +++ b/core/src/index/unittest/faiss_benchmark/CMakeLists.txt @@ -0,0 +1,53 @@ +if (MILVUS_GPU_VERSION) + + include_directories(${INDEX_SOURCE_DIR}/thirdparty) + include_directories(${INDEX_SOURCE_DIR}/include) + include_directories(/usr/local/cuda/include) + include_directories(/usr/local/hdf5/include) + + link_directories(/usr/local/cuda/lib64) + link_directories(/usr/local/hdf5/lib) + + set(unittest_libs + gtest gmock gtest_main gmock_main) + + set(depend_libs + faiss hdf5 + ) + if (FAISS_WITH_MKL) + set(depend_libs ${depend_libs} + "-Wl,--start-group \ + ${MKL_LIB_PATH}/libmkl_intel_ilp64.a \ + ${MKL_LIB_PATH}/libmkl_gnu_thread.a \ + ${MKL_LIB_PATH}/libmkl_core.a \ + -Wl,--end-group -lgomp -lpthread -lm -ldl" + ) + else () + set(depend_libs ${depend_libs} + ${OpenBLAS_LIBRARIES} + ${LAPACK_LIBRARIES} + ) + endif () + + set(basic_libs + gomp gfortran pthread + ) + + include_directories(${CUDA_INCLUDE_DIRS}) + link_directories("${CUDA_TOOLKIT_ROOT_DIR}/lib64") + set(cuda_lib + cudart + cublas + ) + set(basic_libs ${basic_libs} + ${cuda_lib} + ) + + add_executable(test_faiss_benchmark faiss_benchmark_test.cpp) + target_link_libraries(test_faiss_benchmark ${depend_libs} ${unittest_libs} ${basic_libs}) + install(TARGETS test_faiss_benchmark DESTINATION unittest) + + add_executable(test_faiss_bitset faiss_bitset_test.cpp) + target_link_libraries(test_faiss_bitset ${depend_libs} ${unittest_libs} ${basic_libs}) + install(TARGETS test_faiss_bitset DESTINATION unittest) +endif () diff --git a/core/src/index/unittest/faiss_benchmark/README.md b/core/src/index/unittest/faiss_benchmark/README.md new file mode 100644 index 0000000000..c451ac13b0 --- /dev/null +++ b/core/src/index/unittest/faiss_benchmark/README.md @@ -0,0 +1,25 @@ +### To run this FAISS benchmark, please follow these steps: + +#### Step 1: +Download the HDF5 source from: + https://support.hdfgroup.org/ftp/HDF5/releases/ +and build/install to "/usr/local/hdf5". + +#### Step 2: +Download HDF5 data files from: + https://github.com/erikbern/ann-benchmarks + +#### Step 3: +Update 'milvus/core/src/index/unittest/CMakeLists.txt', +uncomment "#add_subdirectory(faiss_benchmark)". + +#### Step 4: +Build Milvus with unittest enabled: "./build.sh -t Release -u", +binary 'test_faiss_benchmark' will be generated. + +#### Step 5: +Put HDF5 data files into the same directory with binary 'test_faiss_benchmark'. + +#### Step 6: +Run test binary 'test_faiss_benchmark'. + diff --git a/core/src/index/unittest/faiss_benchmark/faiss_benchmark_test.cpp b/core/src/index/unittest/faiss_benchmark/faiss_benchmark_test.cpp new file mode 100644 index 0000000000..c9be654fec --- /dev/null +++ b/core/src/index/unittest/faiss_benchmark/faiss_benchmark_test.cpp @@ -0,0 +1,572 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +/***************************************************** + * To run this test, please download the HDF5 from + * https://support.hdfgroup.org/ftp/HDF5/releases/ + * and install it to /usr/local/hdf5 . + *****************************************************/ +#define DEBUG_VERBOSE 0 + +const char HDF5_POSTFIX[] = ".hdf5"; +const char HDF5_DATASET_TRAIN[] = "train"; +const char HDF5_DATASET_TEST[] = "test"; +const char HDF5_DATASET_NEIGHBORS[] = "neighbors"; +const char HDF5_DATASET_DISTANCES[] = "distances"; + +const int32_t GPU_DEVICE_IDX = 0; + +enum QueryMode { MODE_CPU = 0, MODE_MIX, MODE_GPU }; + +double +elapsed() { + struct timeval tv; + gettimeofday(&tv, nullptr); + return tv.tv_sec + tv.tv_usec * 1e-6; +} + +void +normalize(float* arr, int32_t nq, int32_t dim) { + for (int32_t i = 0; i < nq; i++) { + double vecLen = 0.0, inv_vecLen = 0.0; + for (int32_t j = 0; j < dim; j++) { + double val = arr[i * dim + j]; + vecLen += val * val; + } + inv_vecLen = 1.0 / std::sqrt(vecLen); + for (int32_t j = 0; j < dim; j++) { + arr[i * dim + j] = (float)(arr[i * dim + j] * inv_vecLen); + } + } +} + +void* +hdf5_read(const std::string& file_name, const std::string& dataset_name, H5T_class_t dataset_class, int32_t& d_out, + int32_t& n_out) { + hid_t file, dataset, datatype, dataspace, memspace; + H5T_class_t t_class; /* data type class */ + hsize_t dimsm[3]; /* memory space dimensions */ + hsize_t dims_out[2]; /* dataset dimensions */ + hsize_t count[2]; /* size of the hyperslab in the file */ + hsize_t offset[2]; /* hyperslab offset in the file */ + hsize_t count_out[3]; /* size of the hyperslab in memory */ + hsize_t offset_out[3]; /* hyperslab offset in memory */ + void* data_out = nullptr; /* output buffer */ + + /* Open the file and the dataset. */ + file = H5Fopen(file_name.c_str(), H5F_ACC_RDONLY, H5P_DEFAULT); + dataset = H5Dopen2(file, dataset_name.c_str(), H5P_DEFAULT); + + /* Get datatype and dataspace handles and then query + * dataset class, order, size, rank and dimensions. */ + datatype = H5Dget_type(dataset); /* datatype handle */ + t_class = H5Tget_class(datatype); + assert(t_class == dataset_class || !"Illegal dataset class type"); + + dataspace = H5Dget_space(dataset); /* dataspace handle */ + H5Sget_simple_extent_dims(dataspace, dims_out, nullptr); + n_out = dims_out[0]; + d_out = dims_out[1]; + + /* Define hyperslab in the dataset. */ + offset[0] = offset[1] = 0; + count[0] = dims_out[0]; + count[1] = dims_out[1]; + H5Sselect_hyperslab(dataspace, H5S_SELECT_SET, offset, nullptr, count, nullptr); + + /* Define the memory dataspace. */ + dimsm[0] = dims_out[0]; + dimsm[1] = dims_out[1]; + dimsm[2] = 1; + memspace = H5Screate_simple(3, dimsm, nullptr); + + /* Define memory hyperslab. */ + offset_out[0] = offset_out[1] = offset_out[2] = 0; + count_out[0] = dims_out[0]; + count_out[1] = dims_out[1]; + count_out[2] = 1; + H5Sselect_hyperslab(memspace, H5S_SELECT_SET, offset_out, nullptr, count_out, nullptr); + + /* Read data from hyperslab in the file into the hyperslab in memory and display. */ + switch (t_class) { + case H5T_INTEGER: + data_out = new int[dims_out[0] * dims_out[1]]; + H5Dread(dataset, H5T_NATIVE_INT, memspace, dataspace, H5P_DEFAULT, data_out); + break; + case H5T_FLOAT: + data_out = new float[dims_out[0] * dims_out[1]]; + H5Dread(dataset, H5T_NATIVE_FLOAT, memspace, dataspace, H5P_DEFAULT, data_out); + break; + default: + printf("Illegal dataset class type\n"); + break; + } + + /* Close/release resources. */ + H5Tclose(datatype); + H5Dclose(dataset); + H5Sclose(dataspace); + H5Sclose(memspace); + H5Fclose(file); + + return data_out; +} + +std::string +get_index_file_name(const std::string& ann_test_name, const std::string& index_key, int32_t data_loops) { + size_t pos = index_key.find_first_of(',', 0); + std::string file_name = ann_test_name; + file_name = file_name + "_" + index_key.substr(0, pos) + "_" + index_key.substr(pos + 1); + file_name = file_name + "_" + std::to_string(data_loops) + ".index"; + return file_name; +} + +bool +parse_ann_test_name(const std::string& ann_test_name, int32_t& dim, faiss::MetricType& metric_type) { + size_t pos1, pos2; + + if (ann_test_name.empty()) + return false; + + pos1 = ann_test_name.find_first_of('-', 0); + if (pos1 == std::string::npos) + return false; + pos2 = ann_test_name.find_first_of('-', pos1 + 1); + if (pos2 == std::string::npos) + return false; + + dim = std::stoi(ann_test_name.substr(pos1 + 1, pos2 - pos1 - 1)); + std::string metric_str = ann_test_name.substr(pos2 + 1); + if (metric_str == "angular") { + metric_type = faiss::METRIC_INNER_PRODUCT; + } else if (metric_str == "euclidean") { + metric_type = faiss::METRIC_L2; + } else { + return false; + } + + return true; +} + +int32_t +GetResultHitCount(const faiss::Index::idx_t* ground_index, const faiss::Index::idx_t* index, int32_t ground_k, + int32_t k, int32_t nq, int32_t index_add_loops) { + int32_t min_k = std::min(ground_k, k); + int hit = 0; + for (int32_t i = 0; i < nq; i++) { + std::set ground(ground_index + i * ground_k, + ground_index + i * ground_k + min_k / index_add_loops); + for (int32_t j = 0; j < min_k; j++) { + faiss::Index::idx_t id = index[i * k + j]; + if (ground.count(id) > 0) { + hit++; + } + } + } + return hit; +} + +#if DEBUG_VERBOSE +void +print_array(const char* header, bool is_integer, const void* arr, int32_t nq, int32_t k) { + const int ROW = 10; + const int COL = 10; + assert(ROW <= nq); + assert(COL <= k); + printf("%s\n", header); + printf("==============================================\n"); + for (int i = 0; i < 10; i++) { + for (int j = 0; j < 10; j++) { + if (is_integer) { + printf("%7ld ", ((int64_t*)arr)[i * k + j]); + } else { + printf("%.6f ", ((float*)arr)[i * k + j]); + } + } + printf("\n"); + } + printf("\n"); +} +#endif + +void +load_base_data(faiss::Index*& index, const std::string& ann_test_name, const std::string& index_key, + faiss::gpu::StandardGpuResources& res, const faiss::MetricType metric_type, const int32_t dim, + int32_t index_add_loops, QueryMode mode = MODE_CPU) { + double t0 = elapsed(); + + const std::string ann_file_name = ann_test_name + HDF5_POSTFIX; + + faiss::Index *cpu_index = nullptr, *gpu_index = nullptr; + faiss::distance_compute_blas_threshold = 800; + + std::string index_file_name = get_index_file_name(ann_test_name, index_key, index_add_loops); + + try { + printf("[%.3f s] Reading index file: %s\n", elapsed() - t0, index_file_name.c_str()); + cpu_index = faiss::read_index(index_file_name.c_str()); + } catch (...) { + int32_t nb, d; + printf("[%.3f s] Loading HDF5 file: %s\n", elapsed() - t0, ann_file_name.c_str()); + float* xb = (float*)hdf5_read(ann_file_name, HDF5_DATASET_TRAIN, H5T_FLOAT, d, nb); + assert(d == dim || !"dataset does not have correct dimension"); + + if (metric_type == faiss::METRIC_INNER_PRODUCT) { + printf("[%.3f s] Normalizing base data set \n", elapsed() - t0); + normalize(xb, nb, d); + } + + printf("[%.3f s] Creating CPU index \"%s\" d=%d\n", elapsed() - t0, index_key.c_str(), d); + cpu_index = faiss::index_factory(d, index_key.c_str(), metric_type); + + printf("[%.3f s] Cloning CPU index to GPU\n", elapsed() - t0); + gpu_index = faiss::gpu::index_cpu_to_gpu(&res, GPU_DEVICE_IDX, cpu_index); + delete cpu_index; + + printf("[%.3f s] Training on %d vectors\n", elapsed() - t0, nb); + gpu_index->train(nb, xb); + + // add index multiple times to get ~1G data set + for (int i = 0; i < index_add_loops; i++) { + printf("[%.3f s] No.%d Indexing database, size %d*%d\n", elapsed() - t0, i, nb, d); + std::vector xids(nb); + for (int32_t t = 0; t < nb; t++) { + xids[t] = i * nb + t; + } + gpu_index->add_with_ids(nb, xb, xids.data()); + } + + printf("[%.3f s] Coping GPU index to CPU\n", elapsed() - t0); + + cpu_index = faiss::gpu::index_gpu_to_cpu(gpu_index); + delete gpu_index; + + faiss::IndexIVF* cpu_ivf_index = dynamic_cast(cpu_index); + if (cpu_ivf_index != nullptr) { + cpu_ivf_index->to_readonly(); + } + + printf("[%.3f s] Writing index file: %s\n", elapsed() - t0, index_file_name.c_str()); + faiss::write_index(cpu_index, index_file_name.c_str()); + + delete[] xb; + } + + index = cpu_index; +} + +void +load_query_data(faiss::Index::distance_t*& xq, int32_t& nq, const std::string& ann_test_name, + const faiss::MetricType metric_type, const int32_t dim) { + double t0 = elapsed(); + int32_t d; + + const std::string ann_file_name = ann_test_name + HDF5_POSTFIX; + + xq = (float*)hdf5_read(ann_file_name, HDF5_DATASET_TEST, H5T_FLOAT, d, nq); + assert(d == dim || !"query does not have same dimension as train set"); + + if (metric_type == faiss::METRIC_INNER_PRODUCT) { + printf("[%.3f s] Normalizing query data \n", elapsed() - t0); + normalize(xq, nq, d); + } +} + +void +load_ground_truth(faiss::Index::idx_t*& gt, int32_t& k, const std::string& ann_test_name, const int32_t nq) { + const std::string ann_file_name = ann_test_name + HDF5_POSTFIX; + + // load ground-truth and convert int to long + int32_t nq2; + int* gt_int = (int*)hdf5_read(ann_file_name, HDF5_DATASET_NEIGHBORS, H5T_INTEGER, k, nq2); + assert(nq2 == nq || !"incorrect nb of ground truth index"); + + gt = new faiss::Index::idx_t[k * nq]; + for (int32_t i = 0; i < k * nq; i++) { + gt[i] = gt_int[i]; + } + delete[] gt_int; + +#if DEBUG_VERBOSE + faiss::Index::distance_t* gt_dist; // nq * k matrix of ground-truth nearest-neighbors distances + gt_dist = (float*)hdf5_read(ann_file_name, HDF5_DATASET_DISTANCES, H5T_FLOAT, k, nq2); + assert(nq2 == nq || !"incorrect nb of ground truth distance"); + + std::string str; + str = ann_test_name + " ground truth index"; + print_array(str.c_str(), true, gt, nq, k); + str = ann_test_name + " ground truth distance"; + print_array(str.c_str(), false, gt_dist, nq, k); + + delete gt_dist; +#endif +} + +void +test_with_nprobes(const std::string& ann_test_name, const std::string& index_key, faiss::Index* cpu_index, + faiss::gpu::StandardGpuResources& res, const QueryMode query_mode, const faiss::Index::distance_t* xq, + const faiss::Index::idx_t* gt, const std::vector& nprobes, const int32_t index_add_loops, + const int32_t search_loops) { + double t0 = elapsed(); + + const std::vector NQ = {10, 100}; + const std::vector K = {10, 100, 1000}; + const int32_t GK = 100; // topk of ground truth + + std::unordered_map mode_str_map = { + {MODE_CPU, "MODE_CPU"}, {MODE_MIX, "MODE_MIX"}, {MODE_GPU, "MODE_GPU"}}; + + faiss::Index *gpu_index = nullptr, *index = nullptr; + if (query_mode != MODE_CPU) { + faiss::gpu::GpuClonerOptions option; + option.allInGpu = true; + double copy_time = 0.0; + + faiss::IndexComposition index_composition; + index_composition.index = cpu_index; + index_composition.quantizer = nullptr; + switch (query_mode) { + case MODE_MIX: { + index_composition.mode = 1; // 0: all data, 1: copy quantizer, 2: copy data + + // warm up the transmission + gpu_index = faiss::gpu::index_cpu_to_gpu(&res, GPU_DEVICE_IDX, &index_composition, &option); + delete gpu_index; + + copy_time = elapsed(); + gpu_index = faiss::gpu::index_cpu_to_gpu(&res, GPU_DEVICE_IDX, &index_composition, &option); + delete gpu_index; + copy_time = elapsed() - copy_time; + printf("[%.3f s] Copy quantizer completed, cost %f s\n", elapsed() - t0, copy_time); + + auto ivf_index = dynamic_cast(cpu_index); + auto is_gpu_flat_index = dynamic_cast(ivf_index->quantizer); + if (is_gpu_flat_index == nullptr) { + delete ivf_index->quantizer; + ivf_index->quantizer = index_composition.quantizer; + } + index = cpu_index; + break; + } + case MODE_GPU: +#if 1 + index_composition.mode = 0; // 0: all data, 1: copy quantizer, 2: copy data + + // warm up the transmission + gpu_index = faiss::gpu::index_cpu_to_gpu(&res, GPU_DEVICE_IDX, &index_composition, &option); + delete gpu_index; + + copy_time = elapsed(); + gpu_index = faiss::gpu::index_cpu_to_gpu(&res, GPU_DEVICE_IDX, &index_composition, &option); +#else + // warm up the transmission + gpu_index = faiss::gpu::index_cpu_to_gpu(&res, GPU_DEVICE_IDX, cpu_index, &option); + delete gpu_index; + + copy_time = elapsed(); + gpu_index = faiss::gpu::index_cpu_to_gpu(&res, GPU_DEVICE_IDX, cpu_index, &option); +#endif + copy_time = elapsed() - copy_time; + printf("[%.3f s] Copy data completed, cost %f s\n", elapsed() - t0, copy_time); + + delete cpu_index; + index = gpu_index; + break; + default: + break; + } + } else { + index = cpu_index; + } + + for (auto nprobe : nprobes) { + // brute-force need not set nprobe + if (index_key.find("IDMap") == std::string::npos) { + switch (query_mode) { + case MODE_CPU: + case MODE_MIX: { + faiss::ParameterSpace params; + std::string nprobe_str = "nprobe=" + std::to_string(nprobe); + params.set_index_parameters(index, nprobe_str.c_str()); + break; + } + case MODE_GPU: { + faiss::gpu::GpuIndexIVF* gpu_index_ivf = dynamic_cast(index); + gpu_index_ivf->setNumProbes(nprobe); + } + } + } + + // output buffers + faiss::Index::idx_t* I = new faiss::Index::idx_t[NQ.back() * K.back()]; + faiss::Index::distance_t* D = new faiss::Index::distance_t[NQ.back() * K.back()]; + + printf("\n%s | %s - %s | nprobe=%d\n", ann_test_name.c_str(), index_key.c_str(), + mode_str_map[query_mode].c_str(), nprobe); + printf("======================================================================================\n"); + for (size_t j = 0; j < K.size(); j++) { + int32_t t_k = K[j]; + for (size_t i = 0; i < NQ.size(); i++) { + int32_t t_nq = NQ[i]; + faiss::indexIVF_stats.quantization_time = 0.0; + faiss::indexIVF_stats.search_time = 0.0; + + double t_start = elapsed(), t_end; + for (int s = 0; s < search_loops; s++) { + index->search(t_nq, xq, t_k, D, I); + } + t_end = elapsed(); + +#if DEBUG_VERBOSE + std::string str; + str = "I (" + index_key + ", nq=" + std::to_string(t_nq) + ", k=" + std::to_string(t_k) + ")"; + print_array(str.c_str(), true, I, t_nq, t_k); + str = "D (" + index_key + ", nq=" + std::to_string(t_nq) + ", k=" + std::to_string(t_k) + ")"; + print_array(str.c_str(), false, D, t_nq, t_k); +#endif + + // k = 100 for ground truth + int32_t hit = GetResultHitCount(gt, I, GK, t_k, t_nq, index_add_loops); + + printf("nq = %4d, k = %4d, elapse = %.4fs (quant = %.4fs, search = %.4fs), R@ = %.4f\n", t_nq, t_k, + (t_end - t_start) / search_loops, faiss::indexIVF_stats.quantization_time / 1000 / search_loops, + faiss::indexIVF_stats.search_time / 1000 / search_loops, + (hit / float(t_nq * std::min(GK, t_k) / index_add_loops))); + } + } + printf("======================================================================================\n"); + + delete[] I; + delete[] D; + } + + delete index; +} + +void +test_ann_hdf5(const std::string& ann_test_name, const std::string& cluster_type, const std::string& index_type, + const QueryMode query_mode, int32_t index_add_loops, const std::vector& nprobes, + int32_t search_loops) { + double t0 = elapsed(); + + faiss::gpu::StandardGpuResources res; + + faiss::MetricType metric_type; + int32_t dim; + + if (query_mode == MODE_MIX && index_type != "SQ8Hybrid") { + assert(index_type == "SQ8Hybrid" || !"Only SQ8Hybrid support MODE_MIX"); + return; + } + + std::string index_key = cluster_type + "," + index_type; + + if (!parse_ann_test_name(ann_test_name, dim, metric_type)) { + printf("Invalid ann test name: %s\n", ann_test_name.c_str()); + return; + } + + int32_t nq, k; + faiss::Index* index; + faiss::Index::distance_t* xq; + faiss::Index::idx_t* gt; // ground-truth index + + printf("[%.3f s] Loading base data\n", elapsed() - t0); + load_base_data(index, ann_test_name, index_key, res, metric_type, dim, index_add_loops, query_mode); + + printf("[%.3f s] Loading queries\n", elapsed() - t0); + load_query_data(xq, nq, ann_test_name, metric_type, dim); + + printf("[%.3f s] Loading ground truth for %d queries\n", elapsed() - t0, nq); + load_ground_truth(gt, k, ann_test_name, nq); + + test_with_nprobes(ann_test_name, index_key, index, res, query_mode, xq, gt, nprobes, index_add_loops, search_loops); + printf("[%.3f s] Search test done\n\n", elapsed() - t0); + + delete[] xq; + delete[] gt; +} + +/************************************************************************************ + * https://github.com/erikbern/ann-benchmarks + * + * Dataset Dimensions Train_size Test_size Neighbors Distance Download + * Fashion- + MNIST 784 60,000 10,000 100 Euclidean HDF5 (217MB) + * GIST 960 1,000,000 1,000 100 Euclidean HDF5 (3.6GB) + * GloVe 100 1,183,514 10,000 100 Angular HDF5 (463MB) + * GloVe 200 1,183,514 10,000 100 Angular HDF5 (918MB) + * MNIST 784 60,000 10,000 100 Euclidean HDF5 (217MB) + * NYTimes 256 290,000 10,000 100 Angular HDF5 (301MB) + * SIFT 128 1,000,000 10,000 100 Euclidean HDF5 (501MB) + *************************************************************************************/ + +TEST(FAISSTEST, BENCHMARK) { + std::vector param_nprobes = {8, 128}; + const int32_t SEARCH_LOOPS = 5; + + /////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + const int32_t SIFT_INSERT_LOOPS = 2; // insert twice to get ~1G data set + + test_ann_hdf5("sift-128-euclidean", "IDMap", "Flat", MODE_CPU, SIFT_INSERT_LOOPS, param_nprobes, SEARCH_LOOPS); + test_ann_hdf5("sift-128-euclidean", "IDMap", "Flat", MODE_GPU, SIFT_INSERT_LOOPS, param_nprobes, SEARCH_LOOPS); + + test_ann_hdf5("sift-128-euclidean", "IVF16384", "Flat", MODE_CPU, SIFT_INSERT_LOOPS, param_nprobes, SEARCH_LOOPS); + test_ann_hdf5("sift-128-euclidean", "IVF16384", "Flat", MODE_GPU, SIFT_INSERT_LOOPS, param_nprobes, SEARCH_LOOPS); + + test_ann_hdf5("sift-128-euclidean", "IVF16384", "SQ8", MODE_CPU, SIFT_INSERT_LOOPS, param_nprobes, SEARCH_LOOPS); + test_ann_hdf5("sift-128-euclidean", "IVF16384", "SQ8", MODE_GPU, SIFT_INSERT_LOOPS, param_nprobes, SEARCH_LOOPS); + + test_ann_hdf5("sift-128-euclidean", "IVF16384", "SQ8Hybrid", MODE_CPU, SIFT_INSERT_LOOPS, param_nprobes, + SEARCH_LOOPS); + test_ann_hdf5("sift-128-euclidean", "IVF16384", "SQ8Hybrid", MODE_MIX, SIFT_INSERT_LOOPS, param_nprobes, + SEARCH_LOOPS); + test_ann_hdf5("sift-128-euclidean", "IVF16384", "SQ8Hybrid", MODE_GPU, SIFT_INSERT_LOOPS, param_nprobes, + SEARCH_LOOPS); + + /////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + const int32_t GLOVE_INSERT_LOOPS = 1; + + test_ann_hdf5("glove-200-angular", "IVF16384", "Flat", MODE_CPU, GLOVE_INSERT_LOOPS, param_nprobes, SEARCH_LOOPS); + test_ann_hdf5("glove-200-angular", "IVF16384", "Flat", MODE_GPU, GLOVE_INSERT_LOOPS, param_nprobes, SEARCH_LOOPS); + + test_ann_hdf5("glove-200-angular", "IVF16384", "SQ8", MODE_CPU, GLOVE_INSERT_LOOPS, param_nprobes, SEARCH_LOOPS); + test_ann_hdf5("glove-200-angular", "IVF16384", "SQ8", MODE_GPU, GLOVE_INSERT_LOOPS, param_nprobes, SEARCH_LOOPS); + + test_ann_hdf5("glove-200-angular", "IVF16384", "SQ8Hybrid", MODE_CPU, GLOVE_INSERT_LOOPS, param_nprobes, + SEARCH_LOOPS); + test_ann_hdf5("glove-200-angular", "IVF16384", "SQ8Hybrid", MODE_MIX, GLOVE_INSERT_LOOPS, param_nprobes, + SEARCH_LOOPS); + test_ann_hdf5("glove-200-angular", "IVF16384", "SQ8Hybrid", MODE_GPU, GLOVE_INSERT_LOOPS, param_nprobes, + SEARCH_LOOPS); +} diff --git a/core/src/index/unittest/faiss_benchmark/faiss_bitset_test.cpp b/core/src/index/unittest/faiss_benchmark/faiss_bitset_test.cpp new file mode 100644 index 0000000000..1d87f29f6c --- /dev/null +++ b/core/src/index/unittest/faiss_benchmark/faiss_bitset_test.cpp @@ -0,0 +1,575 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +/***************************************************** + * To run this test, please download the HDF5 from + * https://support.hdfgroup.org/ftp/HDF5/releases/ + * and install it to /usr/local/hdf5 . + *****************************************************/ +#define DEBUG_VERBOSE 0 + +const char HDF5_POSTFIX[] = ".hdf5"; +const char HDF5_DATASET_TRAIN[] = "train"; +const char HDF5_DATASET_TEST[] = "test"; +const char HDF5_DATASET_NEIGHBORS[] = "neighbors"; +const char HDF5_DATASET_DISTANCES[] = "distances"; + +const int32_t GPU_DEVICE_IDX = 0; + +enum QueryMode { MODE_CPU = 0, MODE_MIX, MODE_GPU }; + +double +elapsed() { + struct timeval tv; + gettimeofday(&tv, nullptr); + return tv.tv_sec + tv.tv_usec * 1e-6; +} + +void +normalize(float* arr, int32_t nq, int32_t dim) { + for (int32_t i = 0; i < nq; i++) { + double vecLen = 0.0, inv_vecLen = 0.0; + for (int32_t j = 0; j < dim; j++) { + double val = arr[i * dim + j]; + vecLen += val * val; + } + inv_vecLen = 1.0 / std::sqrt(vecLen); + for (int32_t j = 0; j < dim; j++) { + arr[i * dim + j] = (float)(arr[i * dim + j] * inv_vecLen); + } + } +} + +void* +hdf5_read(const std::string& file_name, const std::string& dataset_name, H5T_class_t dataset_class, int32_t& d_out, + int32_t& n_out) { + hid_t file, dataset, datatype, dataspace, memspace; + H5T_class_t t_class; /* data type class */ + hsize_t dimsm[3]; /* memory space dimensions */ + hsize_t dims_out[2]; /* dataset dimensions */ + hsize_t count[2]; /* size of the hyperslab in the file */ + hsize_t offset[2]; /* hyperslab offset in the file */ + hsize_t count_out[3]; /* size of the hyperslab in memory */ + hsize_t offset_out[3]; /* hyperslab offset in memory */ + void* data_out = nullptr; /* output buffer */ + + /* Open the file and the dataset. */ + file = H5Fopen(file_name.c_str(), H5F_ACC_RDONLY, H5P_DEFAULT); + dataset = H5Dopen2(file, dataset_name.c_str(), H5P_DEFAULT); + + /* Get datatype and dataspace handles and then query + * dataset class, order, size, rank and dimensions. */ + datatype = H5Dget_type(dataset); /* datatype handle */ + t_class = H5Tget_class(datatype); + assert(t_class == dataset_class || !"Illegal dataset class type"); + + dataspace = H5Dget_space(dataset); /* dataspace handle */ + H5Sget_simple_extent_dims(dataspace, dims_out, nullptr); + n_out = dims_out[0]; + d_out = dims_out[1]; + + /* Define hyperslab in the dataset. */ + offset[0] = offset[1] = 0; + count[0] = dims_out[0]; + count[1] = dims_out[1]; + H5Sselect_hyperslab(dataspace, H5S_SELECT_SET, offset, nullptr, count, nullptr); + + /* Define the memory dataspace. */ + dimsm[0] = dims_out[0]; + dimsm[1] = dims_out[1]; + dimsm[2] = 1; + memspace = H5Screate_simple(3, dimsm, nullptr); + + /* Define memory hyperslab. */ + offset_out[0] = offset_out[1] = offset_out[2] = 0; + count_out[0] = dims_out[0]; + count_out[1] = dims_out[1]; + count_out[2] = 1; + H5Sselect_hyperslab(memspace, H5S_SELECT_SET, offset_out, nullptr, count_out, nullptr); + + /* Read data from hyperslab in the file into the hyperslab in memory and display. */ + switch (t_class) { + case H5T_INTEGER: + data_out = new int[dims_out[0] * dims_out[1]]; + H5Dread(dataset, H5T_NATIVE_INT, memspace, dataspace, H5P_DEFAULT, data_out); + break; + case H5T_FLOAT: + data_out = new float[dims_out[0] * dims_out[1]]; + H5Dread(dataset, H5T_NATIVE_FLOAT, memspace, dataspace, H5P_DEFAULT, data_out); + break; + default: + printf("Illegal dataset class type\n"); + break; + } + + /* Close/release resources. */ + H5Tclose(datatype); + H5Dclose(dataset); + H5Sclose(dataspace); + H5Sclose(memspace); + H5Fclose(file); + + return data_out; +} + +std::string +get_index_file_name(const std::string& ann_test_name, const std::string& index_key, int32_t data_loops) { + size_t pos = index_key.find_first_of(',', 0); + std::string file_name = ann_test_name; + file_name = file_name + "_" + index_key.substr(0, pos) + "_" + index_key.substr(pos + 1); + file_name = file_name + "_" + std::to_string(data_loops) + ".index"; + return file_name; +} + +bool +parse_ann_test_name(const std::string& ann_test_name, int32_t& dim, faiss::MetricType& metric_type) { + size_t pos1, pos2; + + if (ann_test_name.empty()) + return false; + + pos1 = ann_test_name.find_first_of('-', 0); + if (pos1 == std::string::npos) + return false; + pos2 = ann_test_name.find_first_of('-', pos1 + 1); + if (pos2 == std::string::npos) + return false; + + dim = std::stoi(ann_test_name.substr(pos1 + 1, pos2 - pos1 - 1)); + std::string metric_str = ann_test_name.substr(pos2 + 1); + if (metric_str == "angular") { + metric_type = faiss::METRIC_INNER_PRODUCT; + } else if (metric_str == "euclidean") { + metric_type = faiss::METRIC_L2; + } else { + return false; + } + + return true; +} + +int32_t +GetResultHitCount(const faiss::Index::idx_t* ground_index, const faiss::Index::idx_t* index, int32_t ground_k, + int32_t k, int32_t nq, int32_t index_add_loops) { + int32_t min_k = std::min(ground_k, k); + int hit = 0; + for (int32_t i = 0; i < nq; i++) { + std::set ground(ground_index + i * ground_k, + ground_index + i * ground_k + min_k / index_add_loops); + for (int32_t j = 0; j < min_k; j++) { + faiss::Index::idx_t id = index[i * k + j]; + if (ground.count(id) > 0) { + hit++; + } + } + } + return hit; +} + +#if DEBUG_VERBOSE +void +print_array(const char* header, bool is_integer, const void* arr, int32_t nq, int32_t k) { + const int ROW = 10; + const int COL = 10; + assert(ROW <= nq); + assert(COL <= k); + printf("%s\n", header); + printf("==============================================\n"); + for (int i = 0; i < 10; i++) { + for (int j = 0; j < 10; j++) { + if (is_integer) { + printf("%7ld ", ((int64_t*)arr)[i * k + j]); + } else { + printf("%.6f ", ((float*)arr)[i * k + j]); + } + } + printf("\n"); + } + printf("\n"); +} +#endif + +void +load_base_data(faiss::Index*& index, const std::string& ann_test_name, const std::string& index_key, + faiss::gpu::StandardGpuResources& res, const faiss::MetricType metric_type, const int32_t dim, + int32_t index_add_loops, QueryMode mode = MODE_CPU) { + double t0 = elapsed(); + + const std::string ann_file_name = ann_test_name + HDF5_POSTFIX; + + faiss::Index *cpu_index = nullptr, *gpu_index = nullptr; + faiss::distance_compute_blas_threshold = 800; + + std::string index_file_name = get_index_file_name(ann_test_name, index_key, index_add_loops); + + try { + printf("[%.3f s] Reading index file: %s\n", elapsed() - t0, index_file_name.c_str()); + cpu_index = faiss::read_index(index_file_name.c_str()); + } catch (...) { + int32_t nb, d; + printf("[%.3f s] Loading HDF5 file: %s\n", elapsed() - t0, ann_file_name.c_str()); + float* xb = (float*)hdf5_read(ann_file_name, HDF5_DATASET_TRAIN, H5T_FLOAT, d, nb); + assert(d == dim || !"dataset does not have correct dimension"); + + if (metric_type == faiss::METRIC_INNER_PRODUCT) { + printf("[%.3f s] Normalizing base data set \n", elapsed() - t0); + normalize(xb, nb, d); + } + + printf("[%.3f s] Creating CPU index \"%s\" d=%d\n", elapsed() - t0, index_key.c_str(), d); + cpu_index = faiss::index_factory(d, index_key.c_str(), metric_type); + + printf("[%.3f s] Cloning CPU index to GPU\n", elapsed() - t0); + gpu_index = faiss::gpu::index_cpu_to_gpu(&res, GPU_DEVICE_IDX, cpu_index); + delete cpu_index; + + printf("[%.3f s] Training on %d vectors\n", elapsed() - t0, nb); + gpu_index->train(nb, xb); + + // add index multiple times to get ~1G data set + for (int i = 0; i < index_add_loops; i++) { + printf("[%.3f s] No.%d Indexing database, size %d*%d\n", elapsed() - t0, i, nb, d); + std::vector xids(nb); + for (int32_t t = 0; t < nb; t++) { + xids[t] = i * nb + t; + } + gpu_index->add_with_ids(nb, xb, xids.data()); + } + + printf("[%.3f s] Coping GPU index to CPU\n", elapsed() - t0); + + cpu_index = faiss::gpu::index_gpu_to_cpu(gpu_index); + delete gpu_index; + + faiss::IndexIVF* cpu_ivf_index = dynamic_cast(cpu_index); + if (cpu_ivf_index != nullptr) { + cpu_ivf_index->to_readonly(); + } + + printf("[%.3f s] Writing index file: %s\n", elapsed() - t0, index_file_name.c_str()); + faiss::write_index(cpu_index, index_file_name.c_str()); + + delete[] xb; + } + + index = cpu_index; +} + +void +load_query_data(faiss::Index::distance_t*& xq, int32_t& nq, const std::string& ann_test_name, + const faiss::MetricType metric_type, const int32_t dim) { + double t0 = elapsed(); + int32_t d; + + const std::string ann_file_name = ann_test_name + HDF5_POSTFIX; + + xq = (float*)hdf5_read(ann_file_name, HDF5_DATASET_TEST, H5T_FLOAT, d, nq); + assert(d == dim || !"query does not have same dimension as train set"); + + if (metric_type == faiss::METRIC_INNER_PRODUCT) { + printf("[%.3f s] Normalizing query data \n", elapsed() - t0); + normalize(xq, nq, d); + } +} + +void +load_ground_truth(faiss::Index::idx_t*& gt, int32_t& k, const std::string& ann_test_name, const int32_t nq) { + const std::string ann_file_name = ann_test_name + HDF5_POSTFIX; + + // load ground-truth and convert int to long + int32_t nq2; + int* gt_int = (int*)hdf5_read(ann_file_name, HDF5_DATASET_NEIGHBORS, H5T_INTEGER, k, nq2); + assert(nq2 == nq || !"incorrect nb of ground truth index"); + + gt = new faiss::Index::idx_t[k * nq]; + for (int32_t i = 0; i < k * nq; i++) { + gt[i] = gt_int[i]; + } + delete[] gt_int; + +#if DEBUG_VERBOSE + faiss::Index::distance_t* gt_dist; // nq * k matrix of ground-truth nearest-neighbors distances + gt_dist = (float*)hdf5_read(ann_file_name, HDF5_DATASET_DISTANCES, H5T_FLOAT, k, nq2); + assert(nq2 == nq || !"incorrect nb of ground truth distance"); + + std::string str; + str = ann_test_name + " ground truth index"; + print_array(str.c_str(), true, gt, nq, k); + str = ann_test_name + " ground truth distance"; + print_array(str.c_str(), false, gt_dist, nq, k); + + delete gt_dist; +#endif +} + +faiss::ConcurrentBitsetPtr +CreateBitset(int32_t size, int32_t percentage) { + if (percentage < 0 || percentage > 100) { + assert(false); + } + + faiss::ConcurrentBitsetPtr bitset_ptr = std::make_shared(size); + if (percentage != 0) { + int32_t step = 100 / percentage; + for (int32_t i = 0; i < size; i += step) { + bitset_ptr->set(i); + } + } + return bitset_ptr; +} + +void +test_with_nprobes(const std::string& ann_test_name, const std::string& index_key, faiss::Index* cpu_index, + faiss::gpu::StandardGpuResources& res, const QueryMode query_mode, const faiss::Index::distance_t* xq, + const faiss::Index::idx_t* gt, const std::vector& nprobes, const int32_t index_add_loops, + const int32_t search_loops) { + double t0 = elapsed(); + + const std::vector NQ = {100}; + const std::vector K = {100}; + const int32_t GK = 100; // topk of ground truth + + std::unordered_map mode_str_map = { + {MODE_CPU, "MODE_CPU"}, {MODE_MIX, "MODE_MIX"}, {MODE_GPU, "MODE_GPU"}}; + + faiss::Index *gpu_index = nullptr, *index = nullptr; + if (query_mode != MODE_CPU) { + faiss::gpu::GpuClonerOptions option; + option.allInGpu = true; + double copy_time = 0.0; + + faiss::IndexComposition index_composition; + index_composition.index = cpu_index; + index_composition.quantizer = nullptr; + switch (query_mode) { + case MODE_MIX: { + index_composition.mode = 1; // 0: all data, 1: copy quantizer, 2: copy data + + // warm up the transmission + gpu_index = faiss::gpu::index_cpu_to_gpu(&res, GPU_DEVICE_IDX, &index_composition, &option); + delete gpu_index; + + copy_time = elapsed(); + gpu_index = faiss::gpu::index_cpu_to_gpu(&res, GPU_DEVICE_IDX, &index_composition, &option); + delete gpu_index; + copy_time = elapsed() - copy_time; + printf("[%.3f s] Copy quantizer completed, cost %f s\n", elapsed() - t0, copy_time); + + auto ivf_index = dynamic_cast(cpu_index); + auto is_gpu_flat_index = dynamic_cast(ivf_index->quantizer); + if (is_gpu_flat_index == nullptr) { + delete ivf_index->quantizer; + ivf_index->quantizer = index_composition.quantizer; + } + index = cpu_index; + break; + } + case MODE_GPU: +#if 1 + index_composition.mode = 0; // 0: all data, 1: copy quantizer, 2: copy data + + // warm up the transmission + gpu_index = faiss::gpu::index_cpu_to_gpu(&res, GPU_DEVICE_IDX, &index_composition, &option); + delete gpu_index; + + copy_time = elapsed(); + gpu_index = faiss::gpu::index_cpu_to_gpu(&res, GPU_DEVICE_IDX, &index_composition, &option); +#else + // warm up the transmission + gpu_index = faiss::gpu::index_cpu_to_gpu(&res, GPU_DEVICE_IDX, cpu_index, &option); + delete gpu_index; + + copy_time = elapsed(); + gpu_index = faiss::gpu::index_cpu_to_gpu(&res, GPU_DEVICE_IDX, cpu_index, &option); +#endif + copy_time = elapsed() - copy_time; + printf("[%.3f s] Copy data completed, cost %f s\n", elapsed() - t0, copy_time); + + delete cpu_index; + index = gpu_index; + break; + default: + break; + } + } else { + index = cpu_index; + } + + std::vector> bitset_array; + bitset_array.push_back(std::make_pair("nil", nullptr)); + bitset_array.push_back(std::make_pair("0", CreateBitset(index->ntotal, 0))); + bitset_array.push_back(std::make_pair("5", CreateBitset(index->ntotal, 5))); + bitset_array.push_back(std::make_pair("50", CreateBitset(index->ntotal, 50))); + bitset_array.push_back(std::make_pair("100", CreateBitset(index->ntotal, 100))); + + for (auto nprobe : nprobes) { + // brute-force need not set nprobe + if (index_key.find("IDMap") == std::string::npos) { + switch (query_mode) { + case MODE_CPU: + case MODE_MIX: { + faiss::ParameterSpace params; + std::string nprobe_str = "nprobe=" + std::to_string(nprobe); + params.set_index_parameters(index, nprobe_str.c_str()); + break; + } + case MODE_GPU: { + faiss::gpu::GpuIndexIVF* gpu_index_ivf = dynamic_cast(index); + gpu_index_ivf->setNumProbes(nprobe); + } + } + } + + // output buffers + faiss::Index::idx_t* I = new faiss::Index::idx_t[NQ.back() * K.back()]; + faiss::Index::distance_t* D = new faiss::Index::distance_t[NQ.back() * K.back()]; + + for (size_t j = 0; j < K.size(); j++) { + int32_t t_k = K[j]; + for (size_t i = 0; i < NQ.size(); i++) { + int32_t t_nq = NQ[i]; + + printf("\n%s | %s - %s | nq = %4d, k = %4d, nprobe=%d\n", ann_test_name.c_str(), index_key.c_str(), + mode_str_map[query_mode].c_str(), t_nq, t_k, nprobe); + printf("================================================================================\n"); + for (size_t s = 0; s < bitset_array.size(); s++) { + faiss::indexIVF_stats.quantization_time = 0.0; + faiss::indexIVF_stats.search_time = 0.0; + + double t_start = elapsed(), t_end; + for (int loop = 0; loop < search_loops; loop++) { + index->search(t_nq, xq, t_k, D, I, bitset_array[s].second); + } + t_end = elapsed(); + +#if DEBUG_VERBOSE + std::string str; + str = "I (" + index_key + ", nq=" + std::to_string(t_nq) + ", k=" + std::to_string(t_k) + ")"; + print_array(str.c_str(), true, I, t_nq, t_k); + str = "D (" + index_key + ", nq=" + std::to_string(t_nq) + ", k=" + std::to_string(t_k) + ")"; + print_array(str.c_str(), false, D, t_nq, t_k); +#endif + + // k = 100 for ground truth + int32_t hit = GetResultHitCount(gt, I, GK, t_k, t_nq, index_add_loops); + + printf("bitset = %3s%%, elapse = %.4fs (quant = %.4fs, search = %.4fs), R@ = %.4f\n", + bitset_array[s].first.c_str(), (t_end - t_start) / search_loops, + faiss::indexIVF_stats.quantization_time / 1000 / search_loops, + faiss::indexIVF_stats.search_time / 1000 / search_loops, + (hit / float(t_nq * std::min(GK, t_k) / index_add_loops))); + } + printf("================================================================================\n"); + } + } + + delete[] I; + delete[] D; + } + + delete index; +} + +void +test_ann_hdf5(const std::string& ann_test_name, const std::string& cluster_type, const std::string& index_type, + const QueryMode query_mode, int32_t index_add_loops, const std::vector& nprobes, + int32_t search_loops) { + double t0 = elapsed(); + + faiss::gpu::StandardGpuResources res; + + faiss::MetricType metric_type; + int32_t dim; + + if (query_mode == MODE_MIX && index_type != "SQ8Hybrid") { + assert(index_type == "SQ8Hybrid" || !"Only SQ8Hybrid support MODE_MIX"); + return; + } + + std::string index_key = cluster_type + "," + index_type; + + if (!parse_ann_test_name(ann_test_name, dim, metric_type)) { + printf("Invalid ann test name: %s\n", ann_test_name.c_str()); + return; + } + + int32_t nq, k; + faiss::Index* index; + faiss::Index::distance_t* xq; + faiss::Index::idx_t* gt; // ground-truth index + + printf("[%.3f s] Loading base data\n", elapsed() - t0); + load_base_data(index, ann_test_name, index_key, res, metric_type, dim, index_add_loops, query_mode); + + printf("[%.3f s] Loading queries\n", elapsed() - t0); + load_query_data(xq, nq, ann_test_name, metric_type, dim); + + printf("[%.3f s] Loading ground truth for %d queries\n", elapsed() - t0, nq); + load_ground_truth(gt, k, ann_test_name, nq); + + test_with_nprobes(ann_test_name, index_key, index, res, query_mode, xq, gt, nprobes, index_add_loops, search_loops); + printf("[%.3f s] Search test done\n\n", elapsed() - t0); + + delete[] xq; + delete[] gt; +} + +/************************************************************************************ + * https://github.com/erikbern/ann-benchmarks + * + * Dataset Dimensions Train_size Test_size Neighbors Distance Download + * Fashion- + MNIST 784 60,000 10,000 100 Euclidean HDF5 (217MB) + * GIST 960 1,000,000 1,000 100 Euclidean HDF5 (3.6GB) + * GloVe 100 1,183,514 10,000 100 Angular HDF5 (463MB) + * GloVe 200 1,183,514 10,000 100 Angular HDF5 (918MB) + * MNIST 784 60,000 10,000 100 Euclidean HDF5 (217MB) + * NYTimes 256 290,000 10,000 100 Angular HDF5 (301MB) + * SIFT 128 1,000,000 10,000 100 Euclidean HDF5 (501MB) + *************************************************************************************/ + +TEST(FAISSTEST, BENCHMARK) { + std::vector param_nprobes = {32, 128}; + const int32_t SEARCH_LOOPS = 1; + const int32_t SIFT_INSERT_LOOPS = 1; + + test_ann_hdf5("sift-128-euclidean", "IVF128", "Flat", MODE_CPU, SIFT_INSERT_LOOPS, param_nprobes, SEARCH_LOOPS); +} diff --git a/core/src/index/unittest/kdtree.cpp b/core/src/index/unittest/kdtree.cpp new file mode 100644 index 0000000000..9da3023309 --- /dev/null +++ b/core/src/index/unittest/kdtree.cpp @@ -0,0 +1,138 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. +// +//#include +//#include +//#include "knowhere/adapter/sptag.h" +//#include "knowhere/adapter/structure.h" +//#include "knowhere/index/vector_index/cpu_kdt_rng.h" +//#include "knowhere/index/vector_index/definitions.h" +// +// +// knowhere::DatasetPtr +// generate_dataset(int64_t n, int64_t d, int64_t base) { +// auto elems = n * d; +// auto p_data = (float*)malloc(elems * sizeof(float)); +// auto p_id = (int64_t*)malloc(elems * sizeof(int64_t)); +// assert(p_data != nullptr && p_id != nullptr); +// +// for (auto i = 0; i < n; ++i) { +// for (auto j = 0; j < d; ++j) { +// p_data[i * d + j] = float(base + i); +// } +// p_id[i] = i; +// } +// +// std::vector shape{n, d}; +// auto tensor = ConstructFloatTensorSmart((uint8_t*)p_data, elems * sizeof(float), shape); +// std::vector tensors{tensor}; +// std::vector tensor_fields{ConstructFloatField("data")}; +// auto tensor_schema = std::make_shared(tensor_fields); +// +// auto id_array = ConstructInt64ArraySmart((uint8_t*)p_id, n * sizeof(int64_t)); +// std::vector arrays{id_array}; +// std::vector array_fields{ConstructInt64Field("id")}; +// auto array_schema = std::make_shared(tensor_fields); +// +// auto dataset = std::make_shared(std::move(arrays), array_schema, std::move(tensors), tensor_schema); +// +// return dataset; +//} +// +// knowhere::DatasetPtr +// generate_queries(int64_t n, int64_t d, int64_t k, int64_t base) { +// size_t size = sizeof(float) * n * d; +// auto v = (float*)malloc(size); +// // TODO(lxj): check malloc +// for (auto i = 0; i < n; ++i) { +// for (auto j = 0; j < d; ++j) { +// v[i * d + j] = float(base + i); +// } +// } +// +// std::vector data; +// auto buffer = MakeMutableBufferSmart((uint8_t*)v, size); +// std::vector shape{n, d}; +// auto float_type = std::make_shared(); +// auto tensor = std::make_shared(float_type, buffer, shape); +// data.push_back(tensor); +// +// Config meta; +// meta[META_ROWS] = int64_t(n); +// meta[META_DIM] = int64_t(d); +// meta[META_K] = int64_t(k); +// +// auto type = std::make_shared(); +// auto field = std::make_shared("data", type); +// std::vector fields{field}; +// auto schema = std::make_shared(fields); +// +// return std::make_shared(data, schema); +//} +// +// int +// main(int argc, char* argv[]) { +// auto kdt_index = std::make_shared(); +// +// const auto d = 10; +// const auto k = 3; +// const auto nquery = 10; +// +// // ID [0, 99] +// auto train = generate_dataset(100, d, 0); +// // ID [100] +// auto base = generate_dataset(1, d, 0); +// auto queries = generate_queries(nquery, d, k, 0); +// +// // Build Preprocessor +// auto preprocessor = kdt_index->BuildPreprocessor(train, Config()); +// +// // Set Preprocessor +// kdt_index->set_preprocessor(preprocessor); +// +// Config train_config; +// train_config["TPTNumber"] = "64"; +// // Train +// kdt_index->Train(train, train_config); +// +// // Add +// kdt_index->Add(base, Config()); +// +// auto binary = kdt_index->Serialize(); +// auto new_index = std::make_shared(); +// new_index->Load(binary); +// // auto new_index = kdt_index; +// +// Config search_config; +// search_config[META_K] = int64_t(k); +// +// // Search +// auto result = new_index->Search(queries, search_config); +// +// // Print Result +// { +// auto ids = result->array()[0]; +// auto dists = result->array()[1]; +// +// std::stringstream ss_id; +// std::stringstream ss_dist; +// for (auto i = 0; i < nquery; i++) { +// for (auto j = 0; j < k; ++j) { +// ss_id << *ids->data()->GetValues(1, i * k + j) << " "; +// ss_dist << *dists->data()->GetValues(1, i * k + j) << " "; +// } +// ss_id << std::endl; +// ss_dist << std::endl; +// } +// std::cout << "id\n" << ss_id.str() << std::endl; +// std::cout << "dist\n" << ss_dist.str() << std::endl; +// } +//} diff --git a/core/src/index/unittest/metric_alg_benchmark/CMakeLists.txt b/core/src/index/unittest/metric_alg_benchmark/CMakeLists.txt new file mode 100644 index 0000000000..702e72935f --- /dev/null +++ b/core/src/index/unittest/metric_alg_benchmark/CMakeLists.txt @@ -0,0 +1,17 @@ +# Copyright (C) 2019-2020 Zilliz. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License +# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing permissions and limitations under the License. + +set(unittest_libs + gtest gmock gtest_main gmock_main) + +add_executable(test_metric_benchmark metric_benchmark_test.cpp) +target_link_libraries(test_metric_benchmark ${unittest_libs}) +install(TARGETS test_metric_benchmark DESTINATION unittest) diff --git a/core/src/index/unittest/metric_alg_benchmark/metric_benchmark_test.cpp b/core/src/index/unittest/metric_alg_benchmark/metric_benchmark_test.cpp new file mode 100644 index 0000000000..ea9e22943c --- /dev/null +++ b/core/src/index/unittest/metric_alg_benchmark/metric_benchmark_test.cpp @@ -0,0 +1,415 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include +#include +#include +#include +#include +#include + +typedef float (*metric_func_ptr)(const float*, const float*, size_t); + +constexpr int64_t DIM = 512; +constexpr int64_t NB = 10000; +constexpr int64_t NQ = 5; +constexpr int64_t LOOP = 5; + +void +GenerateData(const int64_t dim, const int64_t n, float* x) { + for (int64_t i = 0; i < n; ++i) { + for (int64_t j = 0; j < dim; ++j) { + x[i * dim + j] = drand48(); + } + } +} + +void +TestMetricAlg(std::unordered_map& func_map, const std::string& key, int64_t loop, + float* distance, const int64_t nb, const float* xb, const int64_t nq, const float* xq, + const int64_t dim) { + int64_t diff = 0; + for (int64_t i = 0; i < loop; i++) { + auto t0 = std::chrono::system_clock::now(); + for (int64_t i = 0; i < nb; i++) { + for (int64_t j = 0; j < nq; j++) { + distance[i * NQ + j] = func_map[key](xb + i * dim, xq + j * dim, dim); + } + } + auto t1 = std::chrono::system_clock::now(); + diff += std::chrono::duration_cast(t1 - t0).count(); + } + std::cout << key << " takes average " << diff / loop << "ms" << std::endl; +} + +void +CheckResult(const float* result1, const float* result2, const size_t size) { + for (size_t i = 0; i < size; i++) { + ASSERT_FLOAT_EQ(result1[i], result2[i]); + } +} + +/////////////////////////////////////////////////////////////////////////////// +/* from faiss/utils/distances_simd.cpp */ +namespace FAISS { +// reads 0 <= d < 4 floats as __m128 +static inline __m128 +masked_read(int d, const float* x) { + assert(0 <= d && d < 4); + __attribute__((__aligned__(16))) float buf[4] = {0, 0, 0, 0}; + switch (d) { + case 3: + buf[2] = x[2]; + case 2: + buf[1] = x[1]; + case 1: + buf[0] = x[0]; + } + return _mm_load_ps(buf); + // cannot use AVX2 _mm_mask_set1_epi32 +} + +static inline __m256 +masked_read_8(int d, const float* x) { + assert(0 <= d && d < 8); + if (d < 4) { + __m256 res = _mm256_setzero_ps(); + res = _mm256_insertf128_ps(res, masked_read(d, x), 0); + return res; + } else { + __m256 res = _mm256_setzero_ps(); + res = _mm256_insertf128_ps(res, _mm_loadu_ps(x), 0); + res = _mm256_insertf128_ps(res, masked_read(d - 4, x + 4), 1); + return res; + } +} + +float +fvec_inner_product_avx(const float* x, const float* y, size_t d) { + __m256 msum1 = _mm256_setzero_ps(); + + while (d >= 8) { + __m256 mx = _mm256_loadu_ps(x); + x += 8; + __m256 my = _mm256_loadu_ps(y); + y += 8; + msum1 = _mm256_add_ps(msum1, _mm256_mul_ps(mx, my)); + d -= 8; + } + + __m128 msum2 = _mm256_extractf128_ps(msum1, 1); + msum2 += _mm256_extractf128_ps(msum1, 0); + + if (d >= 4) { + __m128 mx = _mm_loadu_ps(x); + x += 4; + __m128 my = _mm_loadu_ps(y); + y += 4; + msum2 = _mm_add_ps(msum2, _mm_mul_ps(mx, my)); + d -= 4; + } + + if (d > 0) { + __m128 mx = masked_read(d, x); + __m128 my = masked_read(d, y); + msum2 = _mm_add_ps(msum2, _mm_mul_ps(mx, my)); + } + + msum2 = _mm_hadd_ps(msum2, msum2); + msum2 = _mm_hadd_ps(msum2, msum2); + return _mm_cvtss_f32(msum2); +} + +float +fvec_L2sqr_avx(const float* x, const float* y, size_t d) { + __m256 msum1 = _mm256_setzero_ps(); + + while (d >= 8) { + __m256 mx = _mm256_loadu_ps(x); + x += 8; + __m256 my = _mm256_loadu_ps(y); + y += 8; + const __m256 a_m_b1 = mx - my; + msum1 += a_m_b1 * a_m_b1; + d -= 8; + } + + __m128 msum2 = _mm256_extractf128_ps(msum1, 1); + msum2 += _mm256_extractf128_ps(msum1, 0); + + if (d >= 4) { + __m128 mx = _mm_loadu_ps(x); + x += 4; + __m128 my = _mm_loadu_ps(y); + y += 4; + const __m128 a_m_b1 = mx - my; + msum2 += a_m_b1 * a_m_b1; + d -= 4; + } + + if (d > 0) { + __m128 mx = masked_read(d, x); + __m128 my = masked_read(d, y); + __m128 a_m_b1 = mx - my; + msum2 += a_m_b1 * a_m_b1; + } + + msum2 = _mm_hadd_ps(msum2, msum2); + msum2 = _mm_hadd_ps(msum2, msum2); + return _mm_cvtss_f32(msum2); +} +} // namespace FAISS + +/////////////////////////////////////////////////////////////////////////////// +/* from knowhere/index/vector_index/impl/nsg/Distance.cpp */ +namespace NSG { +float +DistanceL2_Compare(const float* a, const float* b, size_t size) { + float result = 0; + +#define AVX_L2SQR(addr1, addr2, dest, tmp1, tmp2) \ + tmp1 = _mm256_loadu_ps(addr1); \ + tmp2 = _mm256_loadu_ps(addr2); \ + tmp1 = _mm256_sub_ps(tmp1, tmp2); \ + tmp1 = _mm256_mul_ps(tmp1, tmp1); \ + dest = _mm256_add_ps(dest, tmp1); + + __m256 sum; + __m256 l0, l1; + __m256 r0, r1; + unsigned D = (size + 7) & ~7U; + unsigned DR = D % 16; + unsigned DD = D - DR; + const float* l = a; + const float* r = b; + const float* e_l = l + DD; + const float* e_r = r + DD; + float unpack[8] __attribute__((aligned(32))) = {0, 0, 0, 0, 0, 0, 0, 0}; + + sum = _mm256_loadu_ps(unpack); + if (DR) { + AVX_L2SQR(e_l, e_r, sum, l0, r0); + } + + for (unsigned i = 0; i < DD; i += 16, l += 16, r += 16) { + AVX_L2SQR(l, r, sum, l0, r0); + AVX_L2SQR(l + 8, r + 8, sum, l1, r1); + } + _mm256_storeu_ps(unpack, sum); + result = unpack[0] + unpack[1] + unpack[2] + unpack[3] + unpack[4] + unpack[5] + unpack[6] + unpack[7]; + + return result; +} + +float +DistanceIP_Compare(const float* a, const float* b, size_t size) { + float result = 0; + +#define AVX_DOT(addr1, addr2, dest, tmp1, tmp2) \ + tmp1 = _mm256_loadu_ps(addr1); \ + tmp2 = _mm256_loadu_ps(addr2); \ + tmp1 = _mm256_mul_ps(tmp1, tmp2); \ + dest = _mm256_add_ps(dest, tmp1); + + __m256 sum; + __m256 l0, l1; + __m256 r0, r1; + unsigned D = (size + 7) & ~7U; + unsigned DR = D % 16; + unsigned DD = D - DR; + const float* l = a; + const float* r = b; + const float* e_l = l + DD; + const float* e_r = r + DD; + float unpack[8] __attribute__((aligned(32))) = {0, 0, 0, 0, 0, 0, 0, 0}; + + sum = _mm256_loadu_ps(unpack); + if (DR) { + AVX_DOT(e_l, e_r, sum, l0, r0); + } + + for (unsigned i = 0; i < DD; i += 16, l += 16, r += 16) { + AVX_DOT(l, r, sum, l0, r0); + AVX_DOT(l + 8, r + 8, sum, l1, r1); + } + _mm256_storeu_ps(unpack, sum); + result = unpack[0] + unpack[1] + unpack[2] + unpack[3] + unpack[4] + unpack[5] + unpack[6] + unpack[7]; + + return result; +} +} // namespace NSG + +/////////////////////////////////////////////////////////////////////////////// +/* from index/thirdparty/annoy/src/annoylib.h */ +namespace ANNOY { +inline float +hsum256_ps_avx(__m256 v) { + const __m128 x128 = _mm_add_ps(_mm256_extractf128_ps(v, 1), _mm256_castps256_ps128(v)); + const __m128 x64 = _mm_add_ps(x128, _mm_movehl_ps(x128, x128)); + const __m128 x32 = _mm_add_ss(x64, _mm_shuffle_ps(x64, x64, 0x55)); + return _mm_cvtss_f32(x32); +} + +inline float +euclidean_distance(const float* x, const float* y, size_t f) { + float result = 0; + if (f > 7) { + __m256 d = _mm256_setzero_ps(); + for (; f > 7; f -= 8) { + const __m256 diff = _mm256_sub_ps(_mm256_loadu_ps(x), _mm256_loadu_ps(y)); + d = _mm256_add_ps(d, _mm256_mul_ps(diff, diff)); // no support for fmadd in AVX... + x += 8; + y += 8; + } + // Sum all floats in dot register. + result = hsum256_ps_avx(d); + } + // Don't forget the remaining values. + for (; f > 0; f--) { + float tmp = *x - *y; + result += tmp * tmp; + x++; + y++; + } + return result; +} + +inline float +dot(const float* x, const float* y, size_t f) { + float result = 0; + if (f > 7) { + __m256 d = _mm256_setzero_ps(); + for (; f > 7; f -= 8) { + d = _mm256_add_ps(d, _mm256_mul_ps(_mm256_loadu_ps(x), _mm256_loadu_ps(y))); + x += 8; + y += 8; + } + // Sum all floats in dot register. + result += hsum256_ps_avx(d); + } + // Don't forget the remaining values. + for (; f > 0; f--) { + result += *x * *y; + x++; + y++; + } + return result; +} +} // namespace ANNOY + +namespace HNSW { +#define PORTABLE_ALIGN32 __attribute__((aligned(32))) + +static float +L2SqrSIMD16Ext(const float* pVect1v, const float* pVect2v, size_t qty) { + float* pVect1 = (float*)pVect1v; + float* pVect2 = (float*)pVect2v; + // size_t qty = *((size_t *) qty_ptr); + float PORTABLE_ALIGN32 TmpRes[8]; + size_t qty16 = qty >> 4; + + const float* pEnd1 = pVect1 + (qty16 << 4); + + __m256 diff, v1, v2; + __m256 sum = _mm256_set1_ps(0); + + while (pVect1 < pEnd1) { + v1 = _mm256_loadu_ps(pVect1); + pVect1 += 8; + v2 = _mm256_loadu_ps(pVect2); + pVect2 += 8; + diff = _mm256_sub_ps(v1, v2); + sum = _mm256_add_ps(sum, _mm256_mul_ps(diff, diff)); + + v1 = _mm256_loadu_ps(pVect1); + pVect1 += 8; + v2 = _mm256_loadu_ps(pVect2); + pVect2 += 8; + diff = _mm256_sub_ps(v1, v2); + sum = _mm256_add_ps(sum, _mm256_mul_ps(diff, diff)); + } + + _mm256_store_ps(TmpRes, sum); + float res = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3] + TmpRes[4] + TmpRes[5] + TmpRes[6] + TmpRes[7]; + + return (res); +} + +static float +InnerProductSIMD16Ext(const float* pVect1v, const float* pVect2v, size_t qty) { + float PORTABLE_ALIGN32 TmpRes[8]; + float* pVect1 = (float*)pVect1v; + float* pVect2 = (float*)pVect2v; + // size_t qty = *((size_t *) qty_ptr); + + size_t qty16 = qty / 16; + + const float* pEnd1 = pVect1 + 16 * qty16; + + __m256 sum256 = _mm256_set1_ps(0); + + while (pVect1 < pEnd1) { + //_mm_prefetch((char*)(pVect2 + 16), _MM_HINT_T0); + + __m256 v1 = _mm256_loadu_ps(pVect1); + pVect1 += 8; + __m256 v2 = _mm256_loadu_ps(pVect2); + pVect2 += 8; + sum256 = _mm256_add_ps(sum256, _mm256_mul_ps(v1, v2)); + + v1 = _mm256_loadu_ps(pVect1); + pVect1 += 8; + v2 = _mm256_loadu_ps(pVect2); + pVect2 += 8; + sum256 = _mm256_add_ps(sum256, _mm256_mul_ps(v1, v2)); + } + + _mm256_store_ps(TmpRes, sum256); + float sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3] + TmpRes[4] + TmpRes[5] + TmpRes[6] + TmpRes[7]; + + return sum; +} +} // namespace HNSW + +TEST(METRICTEST, BENCHMARK) { + std::unordered_map func_map; + func_map["FAISS::L2"] = FAISS::fvec_L2sqr_avx; + func_map["NSG::L2"] = NSG::DistanceL2_Compare; + func_map["HNSW::L2"] = HNSW::L2SqrSIMD16Ext; + func_map["ANNOY::L2"] = ANNOY::euclidean_distance; + + func_map["FAISS::IP"] = FAISS::fvec_inner_product_avx; + func_map["NSG::IP"] = NSG::DistanceIP_Compare; + func_map["HNSW::IP"] = HNSW::InnerProductSIMD16Ext; + func_map["ANNOY::IP"] = ANNOY::dot; + + std::vector xb(NB * DIM); + std::vector xq(NQ * DIM); + GenerateData(DIM, NB, xb.data()); + GenerateData(DIM, NQ, xq.data()); + + std::vector distance_faiss(NB * NQ); + // std::vector distance_nsg(NB * NQ); + std::vector distance_annoy(NB * NQ); + // std::vector distance_hnsw(NB * NQ); + + std::cout << "==========" << std::endl; + TestMetricAlg(func_map, "FAISS::L2", LOOP, distance_faiss.data(), NB, xb.data(), NQ, xq.data(), DIM); + + TestMetricAlg(func_map, "ANNOY::L2", LOOP, distance_annoy.data(), NB, xb.data(), NQ, xq.data(), DIM); + CheckResult(distance_faiss.data(), distance_annoy.data(), NB * NQ); + + std::cout << "==========" << std::endl; + TestMetricAlg(func_map, "FAISS::IP", LOOP, distance_faiss.data(), NB, xb.data(), NQ, xq.data(), DIM); + + TestMetricAlg(func_map, "ANNOY::IP", LOOP, distance_annoy.data(), NB, xb.data(), NQ, xq.data(), DIM); + CheckResult(distance_faiss.data(), distance_annoy.data(), NB * NQ); +} diff --git a/core/src/index/unittest/sift.50NN.graph b/core/src/index/unittest/sift.50NN.graph new file mode 100644 index 0000000000..bf7ba7555d Binary files /dev/null and b/core/src/index/unittest/sift.50NN.graph differ diff --git a/core/src/index/unittest/siftsmall_base.fvecs b/core/src/index/unittest/siftsmall_base.fvecs new file mode 100644 index 0000000000..e3b90ae1ee Binary files /dev/null and b/core/src/index/unittest/siftsmall_base.fvecs differ diff --git a/core/src/index/unittest/test_annoy.cpp b/core/src/index/unittest/test_annoy.cpp new file mode 100644 index 0000000000..69d7809747 --- /dev/null +++ b/core/src/index/unittest/test_annoy.cpp @@ -0,0 +1,315 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include +#include +#include +#include + +#include "knowhere/common/Exception.h" +#include "knowhere/index/vector_index/IndexAnnoy.h" + +#include "unittest/utils.h" + +using ::testing::Combine; +using ::testing::TestWithParam; +using ::testing::Values; + +class AnnoyTest : public DataGen, public TestWithParam { + protected: + void + SetUp() override { + IndexType = GetParam(); + Generate(128, 10000, 10); + index_ = std::make_shared(); + conf = milvus::knowhere::Config{ + {milvus::knowhere::meta::DIM, dim}, + {milvus::knowhere::meta::TOPK, 10}, + {milvus::knowhere::IndexParams::n_trees, 4}, + {milvus::knowhere::IndexParams::search_k, 100}, + {milvus::knowhere::Metric::TYPE, milvus::knowhere::Metric::L2}, + }; + } + + protected: + milvus::knowhere::Config conf; + std::shared_ptr index_ = nullptr; + std::string IndexType; +}; + +INSTANTIATE_TEST_CASE_P(AnnoyParameters, AnnoyTest, Values("Annoy")); + +TEST_P(AnnoyTest, annoy_basic) { + assert(!xb.empty()); + + // null faiss index + { + ASSERT_ANY_THROW(index_->Train(base_dataset, conf)); + ASSERT_ANY_THROW(index_->Query(query_dataset, conf)); + ASSERT_ANY_THROW(index_->Serialize(conf)); + ASSERT_ANY_THROW(index_->Add(base_dataset, conf)); + ASSERT_ANY_THROW(index_->AddWithoutIds(base_dataset, conf)); + ASSERT_ANY_THROW(index_->Count()); + ASSERT_ANY_THROW(index_->Dim()); + } + + index_->BuildAll(base_dataset, conf); // Train + Add + ASSERT_EQ(index_->Count(), nb); + ASSERT_EQ(index_->Dim(), dim); + + auto result = index_->Query(query_dataset, conf); + AssertAnns(result, nq, k); + + /* + * output result to check by eyes + { + auto ids = result->Get(milvus::knowhere::meta::IDS); + auto dist = result->Get(milvus::knowhere::meta::DISTANCE); + + std::stringstream ss_id; + std::stringstream ss_dist; + for (auto i = 0; i < nq; i++) { + for (auto j = 0; j < k; ++j) { + // ss_id << *ids->data()->GetValues(1, i * k + j) << " "; + // ss_dist << *dists->data()->GetValues(1, i * k + j) << " "; + ss_id << *((int64_t*)(ids) + i * k + j) << " "; + ss_dist << *((float*)(dist) + i * k + j) << " "; + } + ss_id << std::endl; + ss_dist << std::endl; + } + std::cout << "id\n" << ss_id.str() << std::endl; + std::cout << "dist\n" << ss_dist.str() << std::endl; + } + */ +} + +TEST_P(AnnoyTest, annoy_delete) { + assert(!xb.empty()); + + index_->BuildAll(base_dataset, conf); // Train + Add + ASSERT_EQ(index_->Count(), nb); + ASSERT_EQ(index_->Dim(), dim); + + faiss::ConcurrentBitsetPtr bitset = std::make_shared(nb); + for (auto i = 0; i < nq; ++i) { + bitset->set(i); + } + + auto result1 = index_->Query(query_dataset, conf); + AssertAnns(result1, nq, k); + + index_->SetBlacklist(bitset); + auto result2 = index_->Query(query_dataset, conf); + AssertAnns(result2, nq, k, CheckMode::CHECK_NOT_EQUAL); + + /* + * delete result checked by eyes + auto ids1 = result1->Get(milvus::knowhere::meta::IDS); + auto ids2 = result2->Get(milvus::knowhere::meta::IDS); + std::cout << std::endl; + for (int i = 0; i < nq; ++ i) { + std::cout << "ids1: "; + for (int j = 0; j < k; ++ j) { + std::cout << *(ids1 + i * k + j) << " "; + } + std::cout << " ids2: "; + for (int j = 0; j < k; ++ j) { + std::cout << *(ids2 + i * k + j) << " "; + } + std::cout << std::endl; + for (int j = 0; j < std::min(5, k>>1); ++ j) { + ASSERT_EQ(*(ids1 + i * k + j + 1), *(ids2 + i * k + j)); + } + } + */ + /* + * output result to check by eyes + { + auto ids = result->Get(milvus::knowhere::meta::IDS); + auto dist = result->Get(milvus::knowhere::meta::DISTANCE); + + std::stringstream ss_id; + std::stringstream ss_dist; + for (auto i = 0; i < nq; i++) { + for (auto j = 0; j < k; ++j) { + // ss_id << *ids->data()->GetValues(1, i * k + j) << " "; + // ss_dist << *dists->data()->GetValues(1, i * k + j) << " "; + ss_id << *((int64_t*)(ids) + i * k + j) << " "; + ss_dist << *((float*)(dist) + i * k + j) << " "; + } + ss_id << std::endl; + ss_dist << std::endl; + } + std::cout << "id\n" << ss_id.str() << std::endl; + std::cout << "dist\n" << ss_dist.str() << std::endl; + } + */ +} + +TEST_P(AnnoyTest, annoy_serialize) { + auto serialize = [](const std::string& filename, milvus::knowhere::BinaryPtr& bin, uint8_t* ret) { + { + // write and flush + FileIOWriter writer(filename); + writer(static_cast(bin->data.get()), bin->size); + } + + FileIOReader reader(filename); + reader(ret, bin->size); + }; + + { + // serialize index + index_->BuildAll(base_dataset, conf); + auto binaryset = index_->Serialize(milvus::knowhere::Config()); + + auto bin_data = binaryset.GetByName("annoy_index_data"); + std::string filename1 = "/tmp/annoy_test_data_serialize.bin"; + auto load_data1 = new uint8_t[bin_data->size]; + serialize(filename1, bin_data, load_data1); + + auto bin_metric_type = binaryset.GetByName("annoy_metric_type"); + std::string filename2 = "/tmp/annoy_test_metric_type_serialize.bin"; + auto load_data2 = new uint8_t[bin_metric_type->size]; + serialize(filename2, bin_metric_type, load_data2); + + auto bin_dim = binaryset.GetByName("annoy_dim"); + std::string filename3 = "/tmp/annoy_test_dim_serialize.bin"; + auto load_data3 = new uint8_t[bin_dim->size]; + serialize(filename3, bin_dim, load_data3); + + binaryset.clear(); + std::shared_ptr index_data(load_data1); + binaryset.Append("annoy_index_data", index_data, bin_data->size); + + std::shared_ptr metric_data(load_data2); + binaryset.Append("annoy_metric_type", metric_data, bin_metric_type->size); + + std::shared_ptr dim_data(load_data3); + binaryset.Append("annoy_dim", dim_data, bin_dim->size); + + index_->Load(binaryset); + ASSERT_EQ(index_->Count(), nb); + ASSERT_EQ(index_->Dim(), dim); + auto result = index_->Query(query_dataset, conf); + AssertAnns(result, nq, conf[milvus::knowhere::meta::TOPK]); + } +} + +/* + * faiss style test + * keep it +int +main() { + int64_t d = 64; // dimension + int64_t nb = 10000; // database size + int64_t nq = 10; // 10000; // nb of queries + faiss::ConcurrentBitsetPtr bitset = std::make_shared(nb); + + int64_t* ids = new int64_t[nb]; + float* xb = new float[d * nb]; + float* xq = new float[d * nq]; + + for (int i = 0; i < nb; i++) { + for (int j = 0; j < d; j++) xb[d * i + j] = (float)drand48(); + xb[d * i] += i / 1000.; + ids[i] = i; + } + printf("gen xb and ids done! \n"); + + // srand((unsigned)time(nullptr)); + auto random_seed = (unsigned)time(nullptr); + printf("delete ids: \n"); + for (int i = 0; i < nq; i++) { + auto tmp = rand_r(&random_seed) % nb; + printf("%d\n", tmp); + // std::cout << "before delete, test result: " << bitset->test(tmp) << std::endl; + bitset->set(tmp); + // std::cout << "after delete, test result: " << bitset->test(tmp) << std::endl; + for (int j = 0; j < d; j++) xq[d * i + j] = xb[d * tmp + j]; + // xq[d * i] += i / 1000.; + } + printf("\n"); + + int k = 4; + int n_trees = 5; + int search_k = 100; + milvus::knowhere::IndexAnnoy index; + milvus::knowhere::DatasetPtr base_dataset = generate_dataset(nb, d, (const void*)xb, ids); + + milvus::knowhere::Config base_conf{ + {milvus::knowhere::meta::DIM, d}, + {milvus::knowhere::meta::TOPK, k}, + {milvus::knowhere::IndexParams::n_trees, n_trees}, + {milvus::knowhere::Metric::TYPE, milvus::knowhere::Metric::L2}, + }; + milvus::knowhere::DatasetPtr query_dataset = generate_query_dataset(nq, d, (const void*)xq); + milvus::knowhere::Config query_conf{ + {milvus::knowhere::meta::DIM, d}, + {milvus::knowhere::meta::TOPK, k}, + {milvus::knowhere::IndexParams::search_k, search_k}, + }; + + index.BuildAll(base_dataset, base_conf); + + printf("------------sanity check----------------\n"); + { // sanity check + auto res = index.Query(query_dataset, query_conf); + printf("Query done!\n"); + const int64_t* I = res->Get(milvus::knowhere::meta::IDS); + float* D = res->Get(milvus::knowhere::meta::DISTANCE); + + printf("I=\n"); + for (int i = 0; i < 5; i++) { + for (int j = 0; j < k; j++) printf("%5ld ", I[i * k + j]); + printf("\n"); + } + + printf("D=\n"); + for (int i = 0; i < 5; i++) { + for (int j = 0; j < k; j++) printf("%7g ", D[i * k + j]); + printf("\n"); + } + } + + printf("---------------search xq-------------\n"); + { // search xq + auto res = index.Query(query_dataset, query_conf); + const int64_t* I = res->Get(milvus::knowhere::meta::IDS); + + printf("I=\n"); + for (int i = 0; i < nq; i++) { + for (int j = 0; j < k; j++) printf("%5ld ", I[i * k + j]); + printf("\n"); + } + } + + printf("----------------search xq with delete------------\n"); + { // search xq with delete + index.SetBlacklist(bitset); + auto res = index.Query(query_dataset, query_conf); + auto I = res->Get(milvus::knowhere::meta::IDS); + + printf("I=\n"); + for (int i = 0; i < nq; i++) { + for (int j = 0; j < k; j++) printf("%5ld ", I[i * k + j]); + printf("\n"); + } + } + + delete[] xb; + delete[] xq; + delete[] ids; + + return 0; +} +*/ diff --git a/core/src/index/unittest/test_binaryidmap.cpp b/core/src/index/unittest/test_binaryidmap.cpp new file mode 100644 index 0000000000..42c1eeb321 --- /dev/null +++ b/core/src/index/unittest/test_binaryidmap.cpp @@ -0,0 +1,133 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include + +#include "knowhere/common/Exception.h" +#include "knowhere/index/vector_index/IndexBinaryIDMAP.h" + +#include "Helper.h" +#include "unittest/utils.h" + +using ::testing::Combine; +using ::testing::TestWithParam; +using ::testing::Values; + +class BinaryIDMAPTest : public DataGen, public TestWithParam { + protected: + void + SetUp() override { + Init_with_default(true); + index_ = std::make_shared(); + } + + void + TearDown() override{}; + + protected: + milvus::knowhere::BinaryIDMAPPtr index_ = nullptr; +}; + +INSTANTIATE_TEST_CASE_P(METRICParameters, BinaryIDMAPTest, + Values(std::string("JACCARD"), std::string("TANIMOTO"), std::string("HAMMING"))); + +TEST_P(BinaryIDMAPTest, binaryidmap_basic) { + ASSERT_TRUE(!xb_bin.empty()); + + std::string MetricType = GetParam(); + milvus::knowhere::Config conf{ + {milvus::knowhere::meta::DIM, dim}, + {milvus::knowhere::meta::TOPK, k}, + {milvus::knowhere::Metric::TYPE, MetricType}, + }; + + // null faiss index + { + ASSERT_ANY_THROW(index_->Serialize(conf)); + ASSERT_ANY_THROW(index_->Query(query_dataset, conf)); + ASSERT_ANY_THROW(index_->Add(nullptr, conf)); + ASSERT_ANY_THROW(index_->AddWithoutIds(nullptr, conf)); + } + + index_->Train(base_dataset, conf); + index_->Add(base_dataset, conf); + EXPECT_EQ(index_->Count(), nb); + EXPECT_EQ(index_->Dim(), dim); + ASSERT_TRUE(index_->GetRawVectors() != nullptr); + ASSERT_TRUE(index_->GetRawIds() != nullptr); + auto result = index_->Query(query_dataset, conf); + AssertAnns(result, nq, k); + // PrintResult(result, nq, k); + + auto binaryset = index_->Serialize(conf); + auto new_index = std::make_shared(); + new_index->Load(binaryset); + auto result2 = new_index->Query(query_dataset, conf); + AssertAnns(result2, nq, k); + // PrintResult(re_result, nq, k); + + faiss::ConcurrentBitsetPtr concurrent_bitset_ptr = std::make_shared(nb); + for (int64_t i = 0; i < nq; ++i) { + concurrent_bitset_ptr->set(i); + } + index_->SetBlacklist(concurrent_bitset_ptr); + + auto result_bs_1 = index_->Query(query_dataset, conf); + AssertAnns(result_bs_1, nq, k, CheckMode::CHECK_NOT_EQUAL); + + // auto result4 = index_->SearchById(id_dataset, conf); + // AssertAneq(result4, nq, k); +} + +TEST_P(BinaryIDMAPTest, binaryidmap_serialize) { + auto serialize = [](const std::string& filename, milvus::knowhere::BinaryPtr& bin, uint8_t* ret) { + FileIOWriter writer(filename); + writer(static_cast(bin->data.get()), bin->size); + + FileIOReader reader(filename); + reader(ret, bin->size); + }; + + std::string MetricType = GetParam(); + milvus::knowhere::Config conf{ + {milvus::knowhere::meta::DIM, dim}, + {milvus::knowhere::meta::TOPK, k}, + {milvus::knowhere::Metric::TYPE, MetricType}, + }; + + { + // serialize index + index_->Train(base_dataset, conf); + index_->AddWithoutIds(base_dataset, milvus::knowhere::Config()); + auto re_result = index_->Query(query_dataset, conf); + AssertAnns(re_result, nq, k); + // PrintResult(re_result, nq, k); + EXPECT_EQ(index_->Count(), nb); + EXPECT_EQ(index_->Dim(), dim); + auto binaryset = index_->Serialize(conf); + auto bin = binaryset.GetByName("BinaryIVF"); + + std::string filename = "/tmp/bianryidmap_test_serialize.bin"; + auto load_data = new uint8_t[bin->size]; + serialize(filename, bin, load_data); + + binaryset.clear(); + std::shared_ptr data(load_data); + binaryset.Append("BinaryIVF", data, bin->size); + + index_->Load(binaryset); + EXPECT_EQ(index_->Count(), nb); + EXPECT_EQ(index_->Dim(), dim); + auto result = index_->Query(query_dataset, conf); + AssertAnns(result, nq, k); + // PrintResult(result, nq, k); + } +} diff --git a/core/src/index/unittest/test_binaryivf.cpp b/core/src/index/unittest/test_binaryivf.cpp new file mode 100644 index 0000000000..ac07e3ad06 --- /dev/null +++ b/core/src/index/unittest/test_binaryivf.cpp @@ -0,0 +1,152 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include +#include +#include + +#include "knowhere/common/Exception.h" +#include "knowhere/common/Timer.h" +#include "knowhere/index/vector_index/IndexBinaryIVF.h" +#include "knowhere/index/vector_index/adapter/VectorAdapter.h" +#include "unittest/Helper.h" +#include "unittest/utils.h" + +using ::testing::Combine; +using ::testing::TestWithParam; +using ::testing::Values; + +class BinaryIVFTest : public DataGen, public TestWithParam { + protected: + void + SetUp() override { + std::string MetricType = GetParam(); + Init_with_default(true); + // nb = 1000000; + // nq = 1000; + // k = 1000; + // Generate(DIM, NB, NQ); + index_ = std::make_shared(); + + milvus::knowhere::Config temp_conf{ + {milvus::knowhere::meta::DIM, dim}, {milvus::knowhere::meta::TOPK, k}, + {milvus::knowhere::IndexParams::nlist, 100}, {milvus::knowhere::IndexParams::nprobe, 10}, + {milvus::knowhere::Metric::TYPE, MetricType}, + }; + conf = temp_conf; + } + + void + TearDown() override { + } + + protected: + std::string index_type; + milvus::knowhere::Config conf; + milvus::knowhere::BinaryIVFIndexPtr index_ = nullptr; +}; + +INSTANTIATE_TEST_CASE_P(METRICParameters, BinaryIVFTest, + Values(std::string("JACCARD"), std::string("TANIMOTO"), std::string("HAMMING"))); + +TEST_P(BinaryIVFTest, binaryivf_basic) { + assert(!xb_bin.empty()); + + // null faiss index + { + ASSERT_ANY_THROW(index_->Serialize(conf)); + ASSERT_ANY_THROW(index_->Query(query_dataset, conf)); + ASSERT_ANY_THROW(index_->Add(nullptr, conf)); + ASSERT_ANY_THROW(index_->AddWithoutIds(nullptr, conf)); + } + + index_->BuildAll(base_dataset, conf); + EXPECT_EQ(index_->Count(), nb); + EXPECT_EQ(index_->Dim(), dim); + + auto result = index_->Query(query_dataset, conf); + AssertAnns(result, nq, conf[milvus::knowhere::meta::TOPK]); + // PrintResult(result, nq, k); + + faiss::ConcurrentBitsetPtr concurrent_bitset_ptr = std::make_shared(nb); + for (int64_t i = 0; i < nq; ++i) { + concurrent_bitset_ptr->set(i); + } + index_->SetBlacklist(concurrent_bitset_ptr); + + auto result2 = index_->Query(query_dataset, conf); + AssertAnns(result2, nq, k, CheckMode::CHECK_NOT_EQUAL); + +#if 0 + auto result3 = index_->QueryById(id_dataset, conf); + AssertAnns(result3, nq, k, CheckMode::CHECK_NOT_EQUAL); + + auto result4 = index_->GetVectorById(xid_dataset, conf); + AssertBinVeceq(result4, base_dataset, xid_dataset, nq, dim/8); +#endif +} + +TEST_P(BinaryIVFTest, binaryivf_serialize) { + auto serialize = [](const std::string& filename, milvus::knowhere::BinaryPtr& bin, uint8_t* ret) { + FileIOWriter writer(filename); + writer(static_cast(bin->data.get()), bin->size); + + FileIOReader reader(filename); + reader(ret, bin->size); + }; + + // { + // // serialize index-model + // auto model = index_->Train(base_dataset, conf); + // auto binaryset = model->Serialize(); + // auto bin = binaryset.GetByName("BinaryIVF"); + // + // std::string filename = "/tmp/binaryivf_test_model_serialize.bin"; + // auto load_data = new uint8_t[bin->size]; + // serialize(filename, bin, load_data); + // + // binaryset.clear(); + // auto data = std::make_shared(); + // data.reset(load_data); + // binaryset.Append("BinaryIVF", data, bin->size); + // + // model->Load(binaryset); + // + // index_->set_index_model(model); + // index_->Add(base_dataset, conf); + // auto result = index_->Query(query_dataset, conf); + // AssertAnns(result, nq, conf[milvus::knowhere::meta::TOPK]); + // } + + { + // serialize index + index_->BuildAll(base_dataset, conf); + // index_->set_index_model(model); + // index_->Add(base_dataset, conf); + auto binaryset = index_->Serialize(conf); + auto bin = binaryset.GetByName("BinaryIVF"); + + std::string filename = "/tmp/binaryivf_test_serialize.bin"; + auto load_data = new uint8_t[bin->size]; + serialize(filename, bin, load_data); + + binaryset.clear(); + std::shared_ptr data(load_data); + binaryset.Append("BinaryIVF", data, bin->size); + + index_->Load(binaryset); + EXPECT_EQ(index_->Count(), nb); + EXPECT_EQ(index_->Dim(), dim); + auto result = index_->Query(query_dataset, conf); + AssertAnns(result, nq, conf[milvus::knowhere::meta::TOPK]); + // PrintResult(result, nq, k); + } +} diff --git a/core/src/index/unittest/test_common.cpp b/core/src/index/unittest/test_common.cpp new file mode 100644 index 0000000000..bdebac1eb2 --- /dev/null +++ b/core/src/index/unittest/test_common.cpp @@ -0,0 +1,45 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include +#include "knowhere/common/Dataset.h" +#include "knowhere/common/Timer.h" +#include "knowhere/knowhere/common/Exception.h" +#include "unittest/utils.h" + +/*Some unittest for knowhere/common, mainly for improve code coverage.*/ + +TEST(COMMON_TEST, dataset_test) { + milvus::knowhere::Dataset set; + int64_t v1 = 111; + + set.Set("key1", v1); + auto get_v1 = set.Get("key1"); + ASSERT_EQ(get_v1, v1); + + ASSERT_ANY_THROW(set.Get("key1")); + ASSERT_ANY_THROW(set.Get("dummy")); +} + +TEST(COMMON_TEST, knowhere_exception) { + const std::string msg = "test"; + milvus::knowhere::KnowhereException ex(msg); + ASSERT_EQ(ex.what(), msg); +} + +TEST(COMMON_TEST, time_recoder) { + InitLog(); + + milvus::knowhere::TimeRecorder recoder("COMMTEST", 0); + sleep(1); + double span = recoder.ElapseFromBegin("get time"); + ASSERT_GE(span, 1.0); +} diff --git a/core/src/index/unittest/test_customized_index.cpp b/core/src/index/unittest/test_customized_index.cpp new file mode 100644 index 0000000000..21c61c880f --- /dev/null +++ b/core/src/index/unittest/test_customized_index.cpp @@ -0,0 +1,232 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include +#include +#include +#include + +#include "knowhere/common/Timer.h" +#include "knowhere/index/IndexType.h" +#include "unittest/Helper.h" +#include "unittest/utils.h" + +class SingleIndexTest : public DataGen, public TestGpuIndexBase { + protected: + void + SetUp() override { + TestGpuIndexBase::SetUp(); + nb = 100000; + nq = 1000; + dim = DIM; + Generate(dim, nb, nq); + k = 1000; + } + + void + TearDown() override { + TestGpuIndexBase::TearDown(); + } + + protected: + milvus::knowhere::IndexType index_type_; + milvus::knowhere::IndexMode index_mode_; + milvus::knowhere::IVFPtr index_ = nullptr; +}; + +#ifdef MILVUS_GPU_VERSION +TEST_F(SingleIndexTest, IVFSQHybrid) { + assert(!xb.empty()); + + index_type_ = milvus::knowhere::IndexEnum::INDEX_FAISS_IVFSQ8H; + index_mode_ = milvus::knowhere::IndexMode::MODE_GPU; + index_ = IndexFactory(index_type_, index_mode_); + + auto conf = ParamGenerator::GetInstance().Gen(index_type_); + + fiu_init(0); + index_->Train(base_dataset, conf); + index_->Add(base_dataset, conf); + EXPECT_EQ(index_->Count(), nb); + EXPECT_EQ(index_->Dim(), dim); + + auto binaryset = index_->Serialize(conf); + { + // copy cpu to gpu + auto cpu_idx = std::make_shared(DEVICEID); + cpu_idx->Load(binaryset); + + { + for (int i = 0; i < 3; ++i) { + auto gpu_idx = cpu_idx->CopyCpuToGpu(DEVICEID, conf); + auto result = gpu_idx->Query(query_dataset, conf); + AssertAnns(result, nq, conf[milvus::knowhere::meta::TOPK]); + // PrintResult(result, nq, k); + } + } + } + + { + // quantization already in gpu, only copy data + auto cpu_idx = std::make_shared(DEVICEID); + cpu_idx->Load(binaryset); + + ASSERT_ANY_THROW(cpu_idx->CopyCpuToGpuWithQuantizer(-1, conf)); + auto pair = cpu_idx->CopyCpuToGpuWithQuantizer(DEVICEID, conf); + auto gpu_idx = pair.first; + + auto result = gpu_idx->Query(query_dataset, conf); + AssertAnns(result, nq, conf[milvus::knowhere::meta::TOPK]); + // PrintResult(result, nq, k); + + milvus::json quantizer_conf{{milvus::knowhere::meta::DEVICEID, DEVICEID}, {"mode", 2}}; + for (int i = 0; i < 2; ++i) { + auto hybrid_idx = std::make_shared(DEVICEID); + hybrid_idx->Load(binaryset); + auto quantization = hybrid_idx->LoadQuantizer(quantizer_conf); + auto new_idx = hybrid_idx->LoadData(quantization, quantizer_conf); + auto result = new_idx->Query(query_dataset, conf); + AssertAnns(result, nq, conf[milvus::knowhere::meta::TOPK]); + // PrintResult(result, nq, k); + } + } + + { + // quantization already in gpu, only set quantization + auto cpu_idx = std::make_shared(DEVICEID); + cpu_idx->Load(binaryset); + + auto pair = cpu_idx->CopyCpuToGpuWithQuantizer(DEVICEID, conf); + auto quantization = pair.second; + + for (int i = 0; i < 2; ++i) { + auto hybrid_idx = std::make_shared(DEVICEID); + hybrid_idx->Load(binaryset); + + hybrid_idx->SetQuantizer(quantization); + auto result = hybrid_idx->Query(query_dataset, conf); + AssertAnns(result, nq, conf[milvus::knowhere::meta::TOPK]); + // PrintResult(result, nq, k); + hybrid_idx->UnsetQuantizer(); + } + } +} + +// TEST_F(SingleIndexTest, thread_safe) { +// assert(!xb.empty()); +// +// index_type = "IVFSQHybrid"; +// index_ = IndexFactory(index_type); +// auto base = ParamGenerator::GetInstance().Gen(ParameterType::ivfsq); +// auto conf = std::dynamic_pointer_cast(base); +// conf->nlist = 16384; +// conf->k = k; +// conf->nprobe = 10; +// conf->d = dim; +// auto preprocessor = index_->BuildPreprocessor(base_dataset, conf); +// index_->set_preprocessor(preprocessor); +// +// auto model = index_->Train(base_dataset, conf); +// index_->set_index_model(model); +// index_->Add(base_dataset, conf); +// EXPECT_EQ(index_->Count(), nb); +// EXPECT_EQ(index_->Dimension(), dim); +// +// auto binaryset = index_->Serialize(); +// +// +// +// auto cpu_idx = std::make_shared(DEVICEID); +// cpu_idx->Load(binaryset); +// auto pair = cpu_idx->CopyCpuToGpuWithQuantizer(DEVICEID, conf); +// auto quantizer = pair.second; +// +// auto quantizer_conf = std::make_shared(); +// quantizer_conf->mode = 2; // only copy data +// quantizer_conf->gpu_id = DEVICEID; +// +// auto CopyAllToGpu = [&](int64_t search_count, bool do_search = false) { +// for (int i = 0; i < search_count; ++i) { +// auto gpu_idx = cpu_idx->CopyCpuToGpu(DEVICEID, conf); +// if (do_search) { +// auto result = gpu_idx->Search(query_dataset, conf); +// AssertAnns(result, nq, conf->k); +// } +// } +// }; +// +// auto hybrid_qt_idx = std::make_shared(DEVICEID); +// hybrid_qt_idx->Load(binaryset); +// auto SetQuantizerDoSearch = [&](int64_t search_count) { +// for (int i = 0; i < search_count; ++i) { +// hybrid_qt_idx->SetQuantizer(quantizer); +// auto result = hybrid_qt_idx->Search(query_dataset, conf); +// AssertAnns(result, nq, conf->k); +// // PrintResult(result, nq, k); +// hybrid_qt_idx->UnsetQuantizer(); +// } +// }; +// +// auto hybrid_data_idx = std::make_shared(DEVICEID); +// hybrid_data_idx->Load(binaryset); +// auto LoadDataDoSearch = [&](int64_t search_count, bool do_search = false) { +// for (int i = 0; i < search_count; ++i) { +// auto hybrid_idx = hybrid_data_idx->LoadData(quantizer, quantizer_conf); +// if (do_search) { +// auto result = hybrid_idx->Search(query_dataset, conf); +//// AssertAnns(result, nq, conf->k); +// } +// } +// }; +// +// milvus::knowhere::TimeRecorder tc(""); +// CopyAllToGpu(2000/2, false); +// tc.RecordSection("CopyAllToGpu witout search"); +// CopyAllToGpu(400/2, true); +// tc.RecordSection("CopyAllToGpu with search"); +// SetQuantizerDoSearch(6); +// tc.RecordSection("SetQuantizer with search"); +// LoadDataDoSearch(2000/2, false); +// tc.RecordSection("LoadData without search"); +// LoadDataDoSearch(400/2, true); +// tc.RecordSection("LoadData with search"); +// +// { +// std::thread t1(CopyAllToGpu, 2000, false); +// std::thread t2(CopyAllToGpu, 400, true); +// t1.join(); +// t2.join(); +// } +// +// { +// std::thread t1(SetQuantizerDoSearch, 12); +// std::thread t2(CopyAllToGpu, 400, true); +// t1.join(); +// t2.join(); +// } +// +// { +// std::thread t1(SetQuantizerDoSearch, 12); +// std::thread t2(LoadDataDoSearch, 400, true); +// t1.join(); +// t2.join(); +// } +// +// { +// std::thread t1(LoadDataDoSearch, 2000, false); +// std::thread t2(LoadDataDoSearch, 400, true); +// t1.join(); +// t2.join(); +// } +// +//} + +#endif diff --git a/core/src/index/unittest/test_gpuresource.cpp b/core/src/index/unittest/test_gpuresource.cpp new file mode 100644 index 0000000000..2c7a30a565 --- /dev/null +++ b/core/src/index/unittest/test_gpuresource.cpp @@ -0,0 +1,302 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include + +#include +#include + +#include +#include +#include + +#include "knowhere/common/Exception.h" +#include "knowhere/common/Timer.h" +#include "knowhere/index/IndexType.h" +#include "knowhere/index/vector_index/IndexIVF.h" +#include "knowhere/index/vector_index/IndexIVFPQ.h" +#include "knowhere/index/vector_index/IndexIVFSQ.h" +#include "knowhere/index/vector_index/gpu/IndexGPUIVF.h" +#include "knowhere/index/vector_index/gpu/IndexGPUIVFPQ.h" +#include "knowhere/index/vector_index/gpu/IndexGPUIVFSQ.h" +#include "knowhere/index/vector_index/gpu/IndexIVFSQHybrid.h" +#include "knowhere/index/vector_index/helpers/Cloner.h" + +#include "unittest/Helper.h" +#include "unittest/utils.h" + +class GPURESTEST : public DataGen, public TestGpuIndexBase { + protected: + void + SetUp() override { + TestGpuIndexBase::SetUp(); + Generate(DIM, NB, NQ); + + k = K; + elems = nq * k; + ids = (int64_t*)malloc(sizeof(int64_t) * elems); + dis = (float*)malloc(sizeof(float) * elems); + } + + void + TearDown() override { + delete ids; + delete dis; + TestGpuIndexBase::TearDown(); + } + + protected: + milvus::knowhere::IndexType index_type_; + milvus::knowhere::IndexMode index_mode_; + milvus::knowhere::IVFPtr index_ = nullptr; + + int64_t* ids = nullptr; + float* dis = nullptr; + int64_t elems = 0; +}; + +TEST_F(GPURESTEST, copyandsearch) { + // search and copy at the same time + printf("==================\n"); + + index_type_ = milvus::knowhere::IndexEnum::INDEX_FAISS_IVFFLAT; + index_mode_ = milvus::knowhere::IndexMode::MODE_GPU; + index_ = IndexFactory(index_type_, index_mode_); + + auto conf = ParamGenerator::GetInstance().Gen(index_type_); + index_->Train(base_dataset, conf); + index_->Add(base_dataset, conf); + auto result = index_->Query(query_dataset, conf); + AssertAnns(result, nq, k); + + index_->SetIndexSize(nb * dim * sizeof(float)); + auto cpu_idx = milvus::knowhere::cloner::CopyGpuToCpu(index_, milvus::knowhere::Config()); + milvus::knowhere::IVFPtr ivf_idx = std::dynamic_pointer_cast(cpu_idx); + ivf_idx->Seal(); + auto search_idx = milvus::knowhere::cloner::CopyCpuToGpu(cpu_idx, DEVICEID, milvus::knowhere::Config()); + + constexpr int64_t search_count = 50; + constexpr int64_t load_count = 15; + auto search_func = [&] { + // TimeRecorder tc("search&load"); + for (int i = 0; i < search_count; ++i) { + search_idx->Query(query_dataset, conf); + // if (i > search_count - 6 || i == 0) + // tc.RecordSection("search once"); + } + // tc.ElapseFromBegin("search finish"); + }; + auto load_func = [&] { + // TimeRecorder tc("search&load"); + for (int i = 0; i < load_count; ++i) { + milvus::knowhere::cloner::CopyCpuToGpu(cpu_idx, DEVICEID, milvus::knowhere::Config()); + // if (i > load_count -5 || i < 5) + // tc.RecordSection("Copy to gpu"); + } + // tc.ElapseFromBegin("load finish"); + }; + + milvus::knowhere::TimeRecorder tc("Basic"); + milvus::knowhere::cloner::CopyCpuToGpu(cpu_idx, DEVICEID, milvus::knowhere::Config()); + tc.RecordSection("Copy to gpu once"); + search_idx->Query(query_dataset, conf); + tc.RecordSection("Search once"); + search_func(); + tc.RecordSection("Search total cost"); + load_func(); + tc.RecordSection("Copy total cost"); + + std::thread search_thread(search_func); + std::thread load_thread(load_func); + search_thread.join(); + load_thread.join(); + tc.RecordSection("Copy&Search total"); +} + +TEST_F(GPURESTEST, trainandsearch) { + index_type_ = milvus::knowhere::IndexEnum::INDEX_FAISS_IVFFLAT; + index_mode_ = milvus::knowhere::IndexMode::MODE_GPU; + index_ = IndexFactory(index_type_, index_mode_); + + auto conf = ParamGenerator::GetInstance().Gen(index_type_); + index_->Train(base_dataset, conf); + index_->Add(base_dataset, conf); + index_->SetIndexSize(nb * dim * sizeof(float)); + auto cpu_idx = milvus::knowhere::cloner::CopyGpuToCpu(index_, milvus::knowhere::Config()); + milvus::knowhere::IVFPtr ivf_idx = std::dynamic_pointer_cast(cpu_idx); + ivf_idx->Seal(); + auto search_idx = milvus::knowhere::cloner::CopyCpuToGpu(cpu_idx, DEVICEID, milvus::knowhere::Config()); + + constexpr int train_count = 5; + constexpr int search_count = 200; + auto train_stage = [&] { + for (int i = 0; i < train_count; ++i) { + index_->Train(base_dataset, conf); + index_->Add(base_dataset, conf); + } + }; + auto search_stage = [&](milvus::knowhere::VecIndexPtr& search_idx) { + for (int i = 0; i < search_count; ++i) { + auto result = search_idx->Query(query_dataset, conf); + AssertAnns(result, nq, k); + } + }; + + // TimeRecorder tc("record"); + // train_stage(); + // tc.RecordSection("train cost"); + // search_stage(search_idx); + // tc.RecordSection("search cost"); + + { + // search and build parallel + std::thread search_thread(search_stage, std::ref(search_idx)); + std::thread train_thread(train_stage); + train_thread.join(); + search_thread.join(); + } + { + // build parallel + std::thread train_1(train_stage); + std::thread train_2(train_stage); + train_1.join(); + train_2.join(); + } + { + // search parallel + auto search_idx_2 = milvus::knowhere::cloner::CopyCpuToGpu(cpu_idx, DEVICEID, milvus::knowhere::Config()); + std::thread search_1(search_stage, std::ref(search_idx)); + std::thread search_2(search_stage, std::ref(search_idx_2)); + search_1.join(); + search_2.join(); + } +} + +#ifdef CompareToOriFaiss +TEST_F(GPURESTEST, gpu_ivf_resource_test) { + assert(!xb.empty()); + + { + index_ = std::make_shared(-1); + ASSERT_EQ(std::dynamic_pointer_cast(index_)->GetGpuDevice(), -1); + std::dynamic_pointer_cast(index_)->SetGpuDevice(DEVICEID); + ASSERT_EQ(std::dynamic_pointer_cast(index_)->GetGpuDevice(), DEVICEID); + + auto conf = ParamGenerator::GetInstance().Gen(ParameterType::ivfsq); + auto preprocessor = index_->BuildPreprocessor(base_dataset, conf); + index_->set_preprocessor(preprocessor); + auto model = index_->Train(base_dataset, conf); + index_->set_index_model(model); + index_->Add(base_dataset, conf); + EXPECT_EQ(index_->Count(), nb); + EXPECT_EQ(index_->Dimension(), dim); + + // milvus::knowhere::TimeRecorder tc("knowere GPUIVF"); + for (int i = 0; i < search_count; ++i) { + index_->Search(query_dataset, conf); + if (i > search_count - 6 || i < 5) + // tc.RecordSection("search once"); + } + // tc.ElapseFromBegin("search all"); + } + milvus::knowhere::FaissGpuResourceMgr::GetInstance().Dump(); + + // { + // // ori faiss IVF-Search + // faiss::gpu::StandardGpuResources res; + // faiss::gpu::GpuIndexIVFFlatConfig idx_config; + // idx_config.device = DEVICEID; + // faiss::gpu::GpuIndexIVFFlat device_index(&res, dim, 1638, faiss::METRIC_L2, idx_config); + // device_index.train(nb, xb.data()); + // device_index.add(nb, xb.data()); + // + // milvus::knowhere::TimeRecorder tc("ori IVF"); + // for (int i = 0; i < search_count; ++i) { + // device_index.search(nq, xq.data(), k, dis, ids); + // if (i > search_count - 6 || i < 5) + // tc.RecordSection("search once"); + // } + // tc.ElapseFromBegin("search all"); + // } +} + +TEST_F(GPURESTEST, gpuivfsq) { + { + // knowhere gpu ivfsq + index_type = "GPUIVFSQ"; + index_ = IndexFactory(index_type); + + auto conf = std::make_shared(); + conf->nlist = 1638; + conf->d = dim; + conf->gpu_id = DEVICEID; + conf->metric_type = milvus::knowhere::METRICTYPE::L2; + conf->k = k; + conf->nbits = 8; + conf->nprobe = 1; + + auto preprocessor = index_->BuildPreprocessor(base_dataset, conf); + index_->set_preprocessor(preprocessor); + auto model = index_->Train(base_dataset, conf); + index_->set_index_model(model); + index_->Add(base_dataset, conf); + // auto result = index_->Search(query_dataset, conf); + // AssertAnns(result, nq, k); + + auto cpu_idx = milvus::knowhere::cloner::CopyGpuToCpu(index_, milvus::knowhere::Config()); + cpu_idx->Seal(); + + milvus::knowhere::TimeRecorder tc("knowhere GPUSQ8"); + auto search_idx = milvus::knowhere::cloner::CopyCpuToGpu(cpu_idx, DEVICEID, milvus::knowhere::Config()); + tc.RecordSection("Copy to gpu"); + for (int i = 0; i < search_count; ++i) { + search_idx->Search(query_dataset, conf); + if (i > search_count - 6 || i < 5) + tc.RecordSection("search once"); + } + tc.ElapseFromBegin("search all"); + } + + { + // Ori gpuivfsq Test + const char* index_description = "IVF1638,SQ8"; + faiss::Index* ori_index = faiss::index_factory(dim, index_description, faiss::METRIC_L2); + + faiss::gpu::StandardGpuResources res; + auto device_index = faiss::gpu::index_cpu_to_gpu(&res, DEVICEID, ori_index); + device_index->train(nb, xb.data()); + device_index->add(nb, xb.data()); + + auto cpu_index = faiss::gpu::index_gpu_to_cpu(device_index); + auto idx = dynamic_cast(cpu_index); + if (idx != nullptr) { + idx->to_readonly(); + } + delete device_index; + delete ori_index; + + faiss::gpu::GpuClonerOptions option; + option.allInGpu = true; + + milvus::knowhere::TimeRecorder tc("ori GPUSQ8"); + faiss::Index* search_idx = faiss::gpu::index_cpu_to_gpu(&res, DEVICEID, cpu_index, &option); + tc.RecordSection("Copy to gpu"); + for (int i = 0; i < search_count; ++i) { + search_idx->search(nq, xq.data(), k, dis, ids); + if (i > search_count - 6 || i < 5) + tc.RecordSection("search once"); + } + tc.ElapseFromBegin("search all"); + delete cpu_index; + delete search_idx; + } +} +#endif diff --git a/core/src/index/unittest/test_hnsw.cpp b/core/src/index/unittest/test_hnsw.cpp new file mode 100644 index 0000000000..a07bb61baa --- /dev/null +++ b/core/src/index/unittest/test_hnsw.cpp @@ -0,0 +1,292 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include +#include +#include +#include +#include +#include +#include "knowhere/common/Exception.h" +#include "unittest/utils.h" + +using ::testing::Combine; +using ::testing::TestWithParam; +using ::testing::Values; + +class HNSWTest : public DataGen, public TestWithParam { + protected: + void + SetUp() override { + IndexType = GetParam(); + std::cout << "IndexType from GetParam() is: " << IndexType << std::endl; + Generate(64, 10000, 10); // dim = 64, nb = 10000, nq = 10 + index_ = std::make_shared(); + conf = milvus::knowhere::Config{ + {milvus::knowhere::meta::DIM, 64}, {milvus::knowhere::meta::TOPK, 10}, + {milvus::knowhere::IndexParams::M, 16}, {milvus::knowhere::IndexParams::efConstruction, 200}, + {milvus::knowhere::IndexParams::ef, 200}, {milvus::knowhere::Metric::TYPE, milvus::knowhere::Metric::L2}, + }; + } + + protected: + milvus::knowhere::Config conf; + std::shared_ptr index_ = nullptr; + std::string IndexType; +}; + +INSTANTIATE_TEST_CASE_P(HNSWParameters, HNSWTest, Values("HNSW")); + +TEST_P(HNSWTest, HNSW_basic) { + assert(!xb.empty()); + + // null faiss index + /* + { + ASSERT_ANY_THROW(index_->Serialize()); + ASSERT_ANY_THROW(index_->Query(query_dataset, conf)); + ASSERT_ANY_THROW(index_->Add(nullptr, conf)); + ASSERT_ANY_THROW(index_->AddWithoutIds(nullptr, conf)); + ASSERT_ANY_THROW(index_->Count()); + ASSERT_ANY_THROW(index_->Dim()); + } + */ + + index_->Train(base_dataset, conf); + index_->Add(base_dataset, conf); + EXPECT_EQ(index_->Count(), nb); + EXPECT_EQ(index_->Dim(), dim); + + // Serialize and Load before Query + milvus::knowhere::BinarySet bs = index_->Serialize(conf); + + int64_t dim = base_dataset->Get(milvus::knowhere::meta::DIM); + int64_t rows = base_dataset->Get(milvus::knowhere::meta::ROWS); + auto raw_data = base_dataset->Get(milvus::knowhere::meta::TENSOR); + milvus::knowhere::BinaryPtr bptr = std::make_shared(); + bptr->data = std::shared_ptr((uint8_t*)raw_data, [&](uint8_t*) {}); + bptr->size = dim * rows * sizeof(float); + bs.Append(RAW_DATA, bptr); + + index_->Load(bs); + + auto result = index_->Query(query_dataset, conf); + AssertAnns(result, nq, k); +} + +TEST_P(HNSWTest, HNSW_delete) { + assert(!xb.empty()); + + index_->Train(base_dataset, conf); + index_->Add(base_dataset, conf); + EXPECT_EQ(index_->Count(), nb); + EXPECT_EQ(index_->Dim(), dim); + + faiss::ConcurrentBitsetPtr bitset = std::make_shared(nb); + for (auto i = 0; i < nq; ++i) { + bitset->set(i); + } + + // Serialize and Load before Query + milvus::knowhere::BinarySet bs = index_->Serialize(conf); + + int64_t dim = base_dataset->Get(milvus::knowhere::meta::DIM); + int64_t rows = base_dataset->Get(milvus::knowhere::meta::ROWS); + auto raw_data = base_dataset->Get(milvus::knowhere::meta::TENSOR); + milvus::knowhere::BinaryPtr bptr = std::make_shared(); + bptr->data = std::shared_ptr((uint8_t*)raw_data, [&](uint8_t*) {}); + bptr->size = dim * rows * sizeof(float); + bs.Append(RAW_DATA, bptr); + + index_->Load(bs); + + auto result1 = index_->Query(query_dataset, conf); + AssertAnns(result1, nq, k); + + index_->SetBlacklist(bitset); + auto result2 = index_->Query(query_dataset, conf); + AssertAnns(result2, nq, k, CheckMode::CHECK_NOT_EQUAL); + + /* + * delete result checked by eyes + auto ids1 = result1->Get(milvus::knowhere::meta::IDS); + auto ids2 = result2->Get(milvus::knowhere::meta::IDS); + std::cout << std::endl; + for (int i = 0; i < nq; ++ i) { + std::cout << "ids1: "; + for (int j = 0; j < k; ++ j) { + std::cout << *(ids1 + i * k + j) << " "; + } + std::cout << "ids2: "; + for (int j = 0; j < k; ++ j) { + std::cout << *(ids2 + i * k + j) << " "; + } + std::cout << std::endl; + for (int j = 0; j < std::min(5, k>>1); ++ j) { + ASSERT_EQ(*(ids1 + i * k + j + 1), *(ids2 + i * k + j)); + } + } + */ +} + +/* +TEST_P(HNSWTest, HNSW_serialize) { + auto serialize = [](const std::string& filename, milvus::knowhere::BinaryPtr& bin, uint8_t* ret) { + { + FileIOWriter writer(filename); + writer(static_cast(bin->data.get()), bin->size); + } + + FileIOReader reader(filename); + reader(ret, bin->size); + }; + + { + index_->Train(base_dataset, conf); + index_->Add(base_dataset, conf); + auto binaryset = index_->Serialize(); + auto bin = binaryset.GetByName("HNSW"); + + std::string filename = "/tmp/HNSW_test_serialize.bin"; + auto load_data = new uint8_t[bin->size]; + serialize(filename, bin, load_data); + + binaryset.clear(); + std::shared_ptr data(load_data); + binaryset.Append("HNSW", data, bin->size); + + index_->Load(binaryset); + EXPECT_EQ(index_->Count(), nb); + EXPECT_EQ(index_->Dim(), dim); + auto result = index_->Query(query_dataset, conf); + AssertAnns(result, nq, conf[milvus::knowhere::meta::TOPK]); + } +}*/ + +/* + * faiss style test + * keep it +int +main() { + int64_t d = 64; // dimension + int64_t nb = 10000; // database size + int64_t nq = 10; // 10000; // nb of queries + faiss::ConcurrentBitsetPtr bitset = std::make_shared(nb); + + int64_t* ids = new int64_t[nb]; + float* xb = new float[d * nb]; + float* xq = new float[d * nq]; + // int64_t *ids = (int64_t*)malloc(nb * sizeof(int64_t)); + // float* xb = (float*)malloc(d * nb * sizeof(float)); + // float* xq = (float*)malloc(d * nq * sizeof(float)); + + for (int i = 0; i < nb; i++) { + for (int j = 0; j < d; j++) xb[d * i + j] = drand48(); + xb[d * i] += i / 1000.; + ids[i] = i; + } +// printf("gen xb and ids done! \n"); + + // srand((unsigned)time(nullptr)); + auto random_seed = (unsigned)time(nullptr); +// printf("delete ids: \n"); + for (int i = 0; i < nq; i++) { + auto tmp = rand_r(&random_seed) % nb; +// printf("%ld\n", tmp); + // std::cout << "before delete, test result: " << bitset->test(tmp) << std::endl; + bitset->set(tmp); + // std::cout << "after delete, test result: " << bitset->test(tmp) << std::endl; + for (int j = 0; j < d; j++) xq[d * i + j] = xb[d * tmp + j]; + // xq[d * i] += i / 1000.; + } +// printf("\n"); + + int k = 4; + int m = 16; + int ef = 200; + milvus::knowhere::IndexHNSW_NM index; + milvus::knowhere::DatasetPtr base_dataset = generate_dataset(nb, d, (const void*)xb, ids); +// base_dataset->Set(milvus::knowhere::meta::ROWS, nb); +// base_dataset->Set(milvus::knowhere::meta::DIM, d); +// base_dataset->Set(milvus::knowhere::meta::TENSOR, (const void*)xb); +// base_dataset->Set(milvus::knowhere::meta::IDS, (const int64_t*)ids); + + milvus::knowhere::Config base_conf{ + {milvus::knowhere::meta::DIM, d}, + {milvus::knowhere::meta::TOPK, k}, + {milvus::knowhere::IndexParams::M, m}, + {milvus::knowhere::IndexParams::efConstruction, ef}, + {milvus::knowhere::Metric::TYPE, milvus::knowhere::Metric::L2}, + }; + milvus::knowhere::DatasetPtr query_dataset = generate_query_dataset(nq, d, (const void*)xq); + milvus::knowhere::Config query_conf{ + {milvus::knowhere::meta::DIM, d}, + {milvus::knowhere::meta::TOPK, k}, + {milvus::knowhere::IndexParams::M, m}, + {milvus::knowhere::IndexParams::ef, ef}, + {milvus::knowhere::Metric::TYPE, milvus::knowhere::Metric::L2}, + }; + + index.Train(base_dataset, base_conf); + index.Add(base_dataset, base_conf); + +// printf("------------sanity check----------------\n"); + { // sanity check + auto res = index.Query(query_dataset, query_conf); +// printf("Query done!\n"); + const int64_t* I = res->Get(milvus::knowhere::meta::IDS); +// float* D = res->Get(milvus::knowhere::meta::DISTANCE); + +// printf("I=\n"); +// for (int i = 0; i < 5; i++) { +// for (int j = 0; j < k; j++) printf("%5ld ", I[i * k + j]); +// printf("\n"); +// } + +// printf("D=\n"); +// for (int i = 0; i < 5; i++) { +// for (int j = 0; j < k; j++) printf("%7g ", D[i * k + j]); +// printf("\n"); +// } + } + +// printf("---------------search xq-------------\n"); + { // search xq + auto res = index.Query(query_dataset, query_conf); + const int64_t* I = res->Get(milvus::knowhere::meta::IDS); + + printf("I=\n"); + for (int i = 0; i < nq; i++) { + for (int j = 0; j < k; j++) printf("%5ld ", I[i * k + j]); + printf("\n"); + } + } + + printf("----------------search xq with delete------------\n"); + { // search xq with delete + index.SetBlacklist(bitset); + auto res = index.Query(query_dataset, query_conf); + auto I = res->Get(milvus::knowhere::meta::IDS); + + printf("I=\n"); + for (int i = 0; i < nq; i++) { + for (int j = 0; j < k; j++) printf("%5ld ", I[i * k + j]); + printf("\n"); + } + } + + delete[] xb; + delete[] xq; + delete[] ids; + + return 0; +} +*/ diff --git a/core/src/index/unittest/test_idmap.cpp b/core/src/index/unittest/test_idmap.cpp new file mode 100644 index 0000000000..eebb488c0a --- /dev/null +++ b/core/src/index/unittest/test_idmap.cpp @@ -0,0 +1,246 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include + +#include +#include +#include +#include + +#include "knowhere/common/Exception.h" +#include "knowhere/index/IndexType.h" +#include "knowhere/index/vector_index/IndexIDMAP.h" +#ifdef MILVUS_GPU_VERSION +#include +#include "knowhere/index/vector_index/gpu/IndexGPUIDMAP.h" +#include "knowhere/index/vector_index/helpers/Cloner.h" +#include "knowhere/index/vector_index/helpers/FaissGpuResourceMgr.h" +#endif +#include "Helper.h" +#include "unittest/utils.h" + +using ::testing::Combine; +using ::testing::TestWithParam; +using ::testing::Values; + +class IDMAPTest : public DataGen, public TestWithParam { + protected: + void + SetUp() override { +#ifdef MILVUS_GPU_VERSION + milvus::knowhere::FaissGpuResourceMgr::GetInstance().InitDevice(DEVICEID, PINMEM, TEMPMEM, RESNUM); +#endif + index_mode_ = GetParam(); + Init_with_default(); + index_ = std::make_shared(); + } + + void + TearDown() override { +#ifdef MILVUS_GPU_VERSION + milvus::knowhere::FaissGpuResourceMgr::GetInstance().Free(); +#endif + } + + protected: + milvus::knowhere::IDMAPPtr index_ = nullptr; + milvus::knowhere::IndexMode index_mode_; +}; + +INSTANTIATE_TEST_CASE_P(IDMAPParameters, IDMAPTest, + Values( +#ifdef MILVUS_GPU_VERSION + milvus::knowhere::IndexMode::MODE_GPU, +#endif + milvus::knowhere::IndexMode::MODE_CPU)); + +TEST_P(IDMAPTest, idmap_basic) { + ASSERT_TRUE(!xb.empty()); + + milvus::knowhere::Config conf{{milvus::knowhere::meta::DIM, dim}, + {milvus::knowhere::meta::TOPK, k}, + {milvus::knowhere::Metric::TYPE, milvus::knowhere::Metric::L2}}; + + // null faiss index + { + ASSERT_ANY_THROW(index_->Serialize(conf)); + ASSERT_ANY_THROW(index_->Query(query_dataset, conf)); + ASSERT_ANY_THROW(index_->Add(nullptr, conf)); + ASSERT_ANY_THROW(index_->AddWithoutIds(nullptr, conf)); + } + + index_->Train(base_dataset, conf); + index_->Add(base_dataset, conf); + EXPECT_EQ(index_->Count(), nb); + EXPECT_EQ(index_->Dim(), dim); + ASSERT_TRUE(index_->GetRawVectors() != nullptr); + ASSERT_TRUE(index_->GetRawIds() != nullptr); + auto result = index_->Query(query_dataset, conf); + AssertAnns(result, nq, k); + // PrintResult(result, nq, k); + + if (index_mode_ == milvus::knowhere::IndexMode::MODE_GPU) { +#ifdef MILVUS_GPU_VERSION + // cpu to gpu + index_ = std::dynamic_pointer_cast(index_->CopyCpuToGpu(DEVICEID, conf)); +#endif + } + + auto binaryset = index_->Serialize(conf); + auto new_index = std::make_shared(); + new_index->Load(binaryset); + auto result2 = new_index->Query(query_dataset, conf); + AssertAnns(result2, nq, k); + // PrintResult(re_result, nq, k); + +#if 0 + auto result3 = new_index->QueryById(id_dataset, conf); + AssertAnns(result3, nq, k); + + auto result4 = new_index->GetVectorById(xid_dataset, conf); + AssertVec(result4, base_dataset, xid_dataset, 1, dim); +#endif + + faiss::ConcurrentBitsetPtr concurrent_bitset_ptr = std::make_shared(nb); + for (int64_t i = 0; i < nq; ++i) { + concurrent_bitset_ptr->set(i); + } + index_->SetBlacklist(concurrent_bitset_ptr); + + auto result_bs_1 = index_->Query(query_dataset, conf); + AssertAnns(result_bs_1, nq, k, CheckMode::CHECK_NOT_EQUAL); + +#if 0 + auto result_bs_2 = index_->QueryById(id_dataset, conf); + AssertAnns(result_bs_2, nq, k, CheckMode::CHECK_NOT_EQUAL); + + auto result_bs_3 = index_->GetVectorById(xid_dataset, conf); + AssertVec(result_bs_3, base_dataset, xid_dataset, 1, dim, CheckMode::CHECK_NOT_EQUAL); +#endif +} + +TEST_P(IDMAPTest, idmap_serialize) { + auto serialize = [](const std::string& filename, milvus::knowhere::BinaryPtr& bin, uint8_t* ret) { + FileIOWriter writer(filename); + writer(static_cast(bin->data.get()), bin->size); + + FileIOReader reader(filename); + reader(ret, bin->size); + }; + + milvus::knowhere::Config conf{{milvus::knowhere::meta::DIM, dim}, + {milvus::knowhere::meta::TOPK, k}, + {milvus::knowhere::Metric::TYPE, milvus::knowhere::Metric::L2}}; + + { + // serialize index + index_->Train(base_dataset, conf); + index_->Add(base_dataset, milvus::knowhere::Config()); + + if (index_mode_ == milvus::knowhere::IndexMode::MODE_GPU) { +#ifdef MILVUS_GPU_VERSION + // cpu to gpu + index_ = std::dynamic_pointer_cast(index_->CopyCpuToGpu(DEVICEID, conf)); +#endif + } + + auto re_result = index_->Query(query_dataset, conf); + AssertAnns(re_result, nq, k); + // PrintResult(re_result, nq, k); + EXPECT_EQ(index_->Count(), nb); + EXPECT_EQ(index_->Dim(), dim); + auto binaryset = index_->Serialize(conf); + auto bin = binaryset.GetByName("IVF"); + + std::string filename = "/tmp/idmap_test_serialize.bin"; + auto load_data = new uint8_t[bin->size]; + serialize(filename, bin, load_data); + + binaryset.clear(); + std::shared_ptr data(load_data); + binaryset.Append("IVF", data, bin->size); + + index_->Load(binaryset); + EXPECT_EQ(index_->Count(), nb); + EXPECT_EQ(index_->Dim(), dim); + auto result = index_->Query(query_dataset, conf); + AssertAnns(result, nq, k); + // PrintResult(result, nq, k); + } +} + +#ifdef MILVUS_GPU_VERSION +TEST_P(IDMAPTest, idmap_copy) { + ASSERT_TRUE(!xb.empty()); + + milvus::knowhere::Config conf{{milvus::knowhere::meta::DIM, dim}, + {milvus::knowhere::meta::TOPK, k}, + {milvus::knowhere::Metric::TYPE, milvus::knowhere::Metric::L2}}; + + index_->Train(base_dataset, conf); + index_->Add(base_dataset, conf); + EXPECT_EQ(index_->Count(), nb); + EXPECT_EQ(index_->Dim(), dim); + ASSERT_TRUE(index_->GetRawVectors() != nullptr); + ASSERT_TRUE(index_->GetRawIds() != nullptr); + auto result = index_->Query(query_dataset, conf); + AssertAnns(result, nq, k); + // PrintResult(result, nq, k); + + { + // clone + // auto clone_index = index_->Clone(); + // auto clone_result = clone_index->Search(query_dataset, conf); + // AssertAnns(clone_result, nq, k); + } + + { + // cpu to gpu + ASSERT_ANY_THROW(milvus::knowhere::cloner::CopyCpuToGpu(index_, -1, conf)); + auto clone_index = milvus::knowhere::cloner::CopyCpuToGpu(index_, DEVICEID, conf); + auto clone_result = clone_index->Query(query_dataset, conf); + AssertAnns(clone_result, nq, k); + ASSERT_THROW({ std::static_pointer_cast(clone_index)->GetRawVectors(); }, + milvus::knowhere::KnowhereException); + ASSERT_THROW({ std::static_pointer_cast(clone_index)->GetRawIds(); }, + milvus::knowhere::KnowhereException); + + fiu_init(0); + fiu_enable("GPUIDMP.SerializeImpl.throw_exception", 1, nullptr, 0); + ASSERT_ANY_THROW(clone_index->Serialize(conf)); + fiu_disable("GPUIDMP.SerializeImpl.throw_exception"); + + auto binary = clone_index->Serialize(conf); + clone_index->Load(binary); + auto new_result = clone_index->Query(query_dataset, conf); + AssertAnns(new_result, nq, k); + + // auto clone_gpu_idx = clone_index->Clone(); + // auto clone_gpu_res = clone_gpu_idx->Search(query_dataset, conf); + // AssertAnns(clone_gpu_res, nq, k); + + // gpu to cpu + auto host_index = milvus::knowhere::cloner::CopyGpuToCpu(clone_index, conf); + auto host_result = host_index->Query(query_dataset, conf); + AssertAnns(host_result, nq, k); + ASSERT_TRUE(std::static_pointer_cast(host_index)->GetRawVectors() != nullptr); + ASSERT_TRUE(std::static_pointer_cast(host_index)->GetRawIds() != nullptr); + + // gpu to gpu + auto device_index = milvus::knowhere::cloner::CopyCpuToGpu(index_, DEVICEID, conf); + auto new_device_index = + std::static_pointer_cast(device_index)->CopyGpuToGpu(DEVICEID, conf); + auto device_result = new_device_index->Query(query_dataset, conf); + AssertAnns(device_result, nq, k); + } +} +#endif diff --git a/core/src/index/unittest/test_instructionset.cpp b/core/src/index/unittest/test_instructionset.cpp new file mode 100644 index 0000000000..f1eb5746ae --- /dev/null +++ b/core/src/index/unittest/test_instructionset.cpp @@ -0,0 +1,95 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "faiss/utils/instruction_set.h" + +#include +#include + +void +ShowInstructionSet() { + auto& outstream = std::cout; + + auto support_message = [&outstream](const std::string& isa_feature, bool is_supported) { + outstream << isa_feature << (is_supported ? " supported" : " not supported") << std::endl; + }; + + faiss::InstructionSet& instruction_set_inst = faiss::InstructionSet::GetInstance(); + + std::cout << instruction_set_inst.Vendor() << std::endl; + std::cout << instruction_set_inst.Brand() << std::endl; + + support_message("3DNOW", instruction_set_inst._3DNOW()); + support_message("3DNOWEXT", instruction_set_inst._3DNOWEXT()); + support_message("ABM", instruction_set_inst.ABM()); + support_message("ADX", instruction_set_inst.ADX()); + support_message("AES", instruction_set_inst.AES()); + support_message("AVX", instruction_set_inst.AVX()); + support_message("AVX2", instruction_set_inst.AVX2()); + support_message("AVX512BW", instruction_set_inst.AVX512BW()); + support_message("AVX512CD", instruction_set_inst.AVX512CD()); + support_message("AVX512DQ", instruction_set_inst.AVX512DQ()); + support_message("AVX512ER", instruction_set_inst.AVX512ER()); + support_message("AVX512F", instruction_set_inst.AVX512F()); + support_message("AVX512PF", instruction_set_inst.AVX512PF()); + support_message("AVX512VL", instruction_set_inst.AVX512VL()); + support_message("BMI1", instruction_set_inst.BMI1()); + support_message("BMI2", instruction_set_inst.BMI2()); + support_message("CLFSH", instruction_set_inst.CLFSH()); + support_message("CMOV", instruction_set_inst.CMOV()); + support_message("CMPXCHG16B", instruction_set_inst.CMPXCHG16B()); + support_message("CX8", instruction_set_inst.CX8()); + support_message("ERMS", instruction_set_inst.ERMS()); + support_message("F16C", instruction_set_inst.F16C()); + support_message("FMA", instruction_set_inst.FMA()); + support_message("FSGSBASE", instruction_set_inst.FSGSBASE()); + support_message("FXSR", instruction_set_inst.FXSR()); + support_message("HLE", instruction_set_inst.HLE()); + support_message("INVPCID", instruction_set_inst.INVPCID()); + support_message("LAHF", instruction_set_inst.LAHF()); + support_message("LZCNT", instruction_set_inst.LZCNT()); + support_message("MMX", instruction_set_inst.MMX()); + support_message("MMXEXT", instruction_set_inst.MMXEXT()); + support_message("MONITOR", instruction_set_inst.MONITOR()); + support_message("MOVBE", instruction_set_inst.MOVBE()); + support_message("MSR", instruction_set_inst.MSR()); + support_message("OSXSAVE", instruction_set_inst.OSXSAVE()); + support_message("PCLMULQDQ", instruction_set_inst.PCLMULQDQ()); + support_message("POPCNT", instruction_set_inst.POPCNT()); + support_message("PREFETCHWT1", instruction_set_inst.PREFETCHWT1()); + support_message("RDRAND", instruction_set_inst.RDRAND()); + support_message("RDSEED", instruction_set_inst.RDSEED()); + support_message("RDTSCP", instruction_set_inst.RDTSCP()); + support_message("RTM", instruction_set_inst.RTM()); + support_message("SEP", instruction_set_inst.SEP()); + support_message("SHA", instruction_set_inst.SHA()); + support_message("SSE", instruction_set_inst.SSE()); + support_message("SSE2", instruction_set_inst.SSE2()); + support_message("SSE3", instruction_set_inst.SSE3()); + support_message("SSE4.1", instruction_set_inst.SSE41()); + support_message("SSE4.2", instruction_set_inst.SSE42()); + support_message("SSE4a", instruction_set_inst.SSE4a()); + support_message("SSSE3", instruction_set_inst.SSSE3()); + support_message("SYSCALL", instruction_set_inst.SYSCALL()); + support_message("TBM", instruction_set_inst.TBM()); + support_message("XOP", instruction_set_inst.XOP()); + support_message("XSAVE", instruction_set_inst.XSAVE()); +} + +TEST(InstructionSetTest, INSTRUCTION_SET_TEST) { + ASSERT_NO_FATAL_FAILURE(ShowInstructionSet()); +} diff --git a/core/src/index/unittest/test_ivf.cpp b/core/src/index/unittest/test_ivf.cpp new file mode 100644 index 0000000000..1017af8f86 --- /dev/null +++ b/core/src/index/unittest/test_ivf.cpp @@ -0,0 +1,383 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include + +#include +#include +#include +#include + +#ifdef MILVUS_GPU_VERSION +#include +#endif + +#include "knowhere/common/Exception.h" +#include "knowhere/common/Timer.h" +#include "knowhere/index/IndexType.h" +#include "knowhere/index/vector_index/IndexIVF.h" +#include "knowhere/index/vector_index/IndexIVFPQ.h" +#include "knowhere/index/vector_index/IndexIVFSQ.h" +#include "knowhere/index/vector_index/adapter/VectorAdapter.h" + +#ifdef MILVUS_GPU_VERSION +#include "knowhere/index/vector_index/gpu/IndexGPUIVF.h" +#include "knowhere/index/vector_index/gpu/IndexGPUIVFPQ.h" +#include "knowhere/index/vector_index/gpu/IndexGPUIVFSQ.h" +#include "knowhere/index/vector_index/gpu/IndexIVFSQHybrid.h" +#include "knowhere/index/vector_index/helpers/Cloner.h" +#include "knowhere/index/vector_index/helpers/FaissGpuResourceMgr.h" +#endif + +#include "unittest/Helper.h" +#include "unittest/utils.h" + +using ::testing::Combine; +using ::testing::TestWithParam; +using ::testing::Values; + +class IVFTest : public DataGen, + public TestWithParam<::std::tuple> { + protected: + void + SetUp() override { +#ifdef MILVUS_GPU_VERSION + milvus::knowhere::FaissGpuResourceMgr::GetInstance().InitDevice(DEVICEID, PINMEM, TEMPMEM, RESNUM); +#endif + std::tie(index_type_, index_mode_) = GetParam(); + // Init_with_default(); + // nb = 1000000; + // nq = 1000; + // k = 1000; + Generate(DIM, NB, NQ); + index_ = IndexFactory(index_type_, index_mode_); + conf_ = ParamGenerator::GetInstance().Gen(index_type_); + // conf_->Dump(); + } + + void + TearDown() override { +#ifdef MILVUS_GPU_VERSION + milvus::knowhere::FaissGpuResourceMgr::GetInstance().Free(); +#endif + } + + protected: + milvus::knowhere::IndexType index_type_; + milvus::knowhere::IndexMode index_mode_; + milvus::knowhere::Config conf_; + milvus::knowhere::IVFPtr index_ = nullptr; +}; + +INSTANTIATE_TEST_CASE_P( + IVFParameters, IVFTest, + Values( +#ifdef MILVUS_GPU_VERSION + std::make_tuple(milvus::knowhere::IndexEnum::INDEX_FAISS_IVFPQ, milvus::knowhere::IndexMode::MODE_GPU), + std::make_tuple(milvus::knowhere::IndexEnum::INDEX_FAISS_IVFSQ8, milvus::knowhere::IndexMode::MODE_GPU), + std::make_tuple(milvus::knowhere::IndexEnum::INDEX_FAISS_IVFSQ8H, milvus::knowhere::IndexMode::MODE_GPU), +#endif + std::make_tuple(milvus::knowhere::IndexEnum::INDEX_FAISS_IVFPQ, milvus::knowhere::IndexMode::MODE_CPU), + std::make_tuple(milvus::knowhere::IndexEnum::INDEX_FAISS_IVFSQ8, milvus::knowhere::IndexMode::MODE_CPU))); + +TEST_P(IVFTest, ivf_basic_cpu) { + assert(!xb.empty()); + + if (index_mode_ != milvus::knowhere::IndexMode::MODE_CPU) { + return; + } + + // null faiss index + ASSERT_ANY_THROW(index_->Add(base_dataset, conf_)); + ASSERT_ANY_THROW(index_->AddWithoutIds(base_dataset, conf_)); + + index_->Train(base_dataset, conf_); + index_->AddWithoutIds(base_dataset, conf_); + EXPECT_EQ(index_->Count(), nb); + EXPECT_EQ(index_->Dim(), dim); + + auto result = index_->Query(query_dataset, conf_); + AssertAnns(result, nq, k); + // PrintResult(result, nq, k); + + if (index_type_ != milvus::knowhere::IndexEnum::INDEX_FAISS_IVFPQ) { +#if 0 + auto result2 = index_->QueryById(id_dataset, conf_); + AssertAnns(result2, nq, k); + + if (index_type_ != milvus::knowhere::IndexEnum::INDEX_FAISS_IVFSQ8) { + auto result3 = index_->GetVectorById(xid_dataset, conf_); + AssertVec(result3, base_dataset, xid_dataset, 1, dim); + } else { + auto result3 = index_->GetVectorById(xid_dataset, conf_); + /* for SQ8, sometimes the mean diff can bigger than 20% */ + // AssertVec(result3, base_dataset, xid_dataset, 1, dim, CheckMode::CHECK_APPROXIMATE_EQUAL); + } +#endif + + faiss::ConcurrentBitsetPtr concurrent_bitset_ptr = std::make_shared(nb); + for (int64_t i = 0; i < nq; ++i) { + concurrent_bitset_ptr->set(i); + } + index_->SetBlacklist(concurrent_bitset_ptr); + + auto result_bs_1 = index_->Query(query_dataset, conf_); + AssertAnns(result_bs_1, nq, k, CheckMode::CHECK_NOT_EQUAL); + // PrintResult(result, nq, k); + +#if 0 + auto result_bs_2 = index_->QueryById(id_dataset, conf_); + AssertAnns(result_bs_2, nq, k, CheckMode::CHECK_NOT_EQUAL); + // PrintResult(result, nq, k); + + auto result_bs_3 = index_->GetVectorById(xid_dataset, conf_); + AssertVec(result_bs_3, base_dataset, xid_dataset, 1, dim, CheckMode::CHECK_NOT_EQUAL); +#endif + } + +#ifdef MILVUS_GPU_VERSION + milvus::knowhere::FaissGpuResourceMgr::GetInstance().Dump(); +#endif +} + +TEST_P(IVFTest, ivf_basic_gpu) { + assert(!xb.empty()); + + if (index_mode_ != milvus::knowhere::IndexMode::MODE_GPU) { + return; + } + + // null faiss index + ASSERT_ANY_THROW(index_->Add(base_dataset, conf_)); + ASSERT_ANY_THROW(index_->AddWithoutIds(base_dataset, conf_)); + + index_->BuildAll(base_dataset, conf_); + EXPECT_EQ(index_->Count(), nb); + EXPECT_EQ(index_->Dim(), dim); + + auto result = index_->Query(query_dataset, conf_); + AssertAnns(result, nq, k); + // PrintResult(result, nq, k); + + faiss::ConcurrentBitsetPtr concurrent_bitset_ptr = std::make_shared(nb); + for (int64_t i = 0; i < nq; ++i) { + concurrent_bitset_ptr->set(i); + } + index_->SetBlacklist(concurrent_bitset_ptr); + + auto result_bs_1 = index_->Query(query_dataset, conf_); + AssertAnns(result_bs_1, nq, k, CheckMode::CHECK_NOT_EQUAL); + // PrintResult(result, nq, k); + +#ifdef MILVUS_GPU_VERSION + milvus::knowhere::FaissGpuResourceMgr::GetInstance().Dump(); +#endif +} + +TEST_P(IVFTest, ivf_serialize) { + fiu_init(0); + auto serialize = [](const std::string& filename, milvus::knowhere::BinaryPtr& bin, uint8_t* ret) { + FileIOWriter writer(filename); + writer(static_cast(bin->data.get()), bin->size); + + FileIOReader reader(filename); + reader(ret, bin->size); + }; + + { + // serialize index + index_->Train(base_dataset, conf_); + index_->Add(base_dataset, conf_); + auto binaryset = index_->Serialize(conf_); + auto bin = binaryset.GetByName("IVF"); + + std::string filename = "/tmp/ivf_test_serialize.bin"; + auto load_data = new uint8_t[bin->size]; + serialize(filename, bin, load_data); + + binaryset.clear(); + std::shared_ptr data(load_data); + binaryset.Append("IVF", data, bin->size); + + index_->Load(binaryset); + EXPECT_EQ(index_->Count(), nb); + EXPECT_EQ(index_->Dim(), dim); + auto result = index_->Query(query_dataset, conf_); + AssertAnns(result, nq, conf_[milvus::knowhere::meta::TOPK]); + } +} + +// TODO(linxj): deprecated +#ifdef MILVUS_GPU_VERSION +TEST_P(IVFTest, clone_test) { + assert(!xb.empty()); + + index_->Train(base_dataset, conf_); + index_->Add(base_dataset, conf_); + EXPECT_EQ(index_->Count(), nb); + EXPECT_EQ(index_->Dim(), dim); + + /* set peseodo index size, avoid throw exception */ + index_->SetIndexSize(nq * dim * sizeof(float)); + + auto result = index_->Query(query_dataset, conf_); + AssertAnns(result, nq, conf_[milvus::knowhere::meta::TOPK]); + // PrintResult(result, nq, k); + + auto AssertEqual = [&](milvus::knowhere::DatasetPtr p1, milvus::knowhere::DatasetPtr p2) { + auto ids_p1 = p1->Get(milvus::knowhere::meta::IDS); + auto ids_p2 = p2->Get(milvus::knowhere::meta::IDS); + + for (int i = 0; i < nq * k; ++i) { + EXPECT_EQ(*((int64_t*)(ids_p2) + i), *((int64_t*)(ids_p1) + i)); + // EXPECT_EQ(*(ids_p2->data()->GetValues(1, i)), *(ids_p1->data()->GetValues(1, + // i))); + } + }; + + { + // copy from gpu to cpu + if (index_mode_ == milvus::knowhere::IndexMode::MODE_GPU) { + EXPECT_NO_THROW({ + auto clone_index = milvus::knowhere::cloner::CopyGpuToCpu(index_, milvus::knowhere::Config()); + auto clone_result = clone_index->Query(query_dataset, conf_); + AssertEqual(result, clone_result); + std::cout << "clone G <=> C [" << index_type_ << "] success" << std::endl; + }); + } else { + EXPECT_THROW( + { + std::cout << "clone G <=> C [" << index_type_ << "] failed" << std::endl; + auto clone_index = milvus::knowhere::cloner::CopyGpuToCpu(index_, milvus::knowhere::Config()); + }, + milvus::knowhere::KnowhereException); + } + } + + { + // copy to gpu + if (index_type_ != milvus::knowhere::IndexEnum::INDEX_FAISS_IVFSQ8H) { + EXPECT_NO_THROW({ + auto clone_index = milvus::knowhere::cloner::CopyCpuToGpu(index_, DEVICEID, milvus::knowhere::Config()); + auto clone_result = clone_index->Query(query_dataset, conf_); + AssertEqual(result, clone_result); + std::cout << "clone C <=> G [" << index_type_ << "] success" << std::endl; + }); + EXPECT_ANY_THROW(milvus::knowhere::cloner::CopyCpuToGpu(index_, -1, milvus::knowhere::Config())); + } + } +} +#endif + +#ifdef MILVUS_GPU_VERSION +TEST_P(IVFTest, gpu_seal_test) { + if (index_mode_ != milvus::knowhere::IndexMode::MODE_GPU) { + return; + } + assert(!xb.empty()); + + ASSERT_ANY_THROW(index_->Query(query_dataset, conf_)); + ASSERT_ANY_THROW(index_->Seal()); + + index_->Train(base_dataset, conf_); + index_->Add(base_dataset, conf_); + EXPECT_EQ(index_->Count(), nb); + EXPECT_EQ(index_->Dim(), dim); + + /* set peseodo index size, avoid throw exception */ + index_->SetIndexSize(nq * dim * sizeof(float)); + + auto result = index_->Query(query_dataset, conf_); + AssertAnns(result, nq, conf_[milvus::knowhere::meta::TOPK]); + + fiu_init(0); + fiu_enable("IVF.Search.throw_std_exception", 1, nullptr, 0); + ASSERT_ANY_THROW(index_->Query(query_dataset, conf_)); + fiu_disable("IVF.Search.throw_std_exception"); + fiu_enable("IVF.Search.throw_faiss_exception", 1, nullptr, 0); + ASSERT_ANY_THROW(index_->Query(query_dataset, conf_)); + fiu_disable("IVF.Search.throw_faiss_exception"); + + auto cpu_idx = milvus::knowhere::cloner::CopyGpuToCpu(index_, milvus::knowhere::Config()); + milvus::knowhere::IVFPtr ivf_idx = std::dynamic_pointer_cast(cpu_idx); + + milvus::knowhere::TimeRecorder tc("CopyToGpu"); + milvus::knowhere::cloner::CopyCpuToGpu(cpu_idx, DEVICEID, milvus::knowhere::Config()); + auto without_seal = tc.RecordSection("Without seal"); + ivf_idx->Seal(); + tc.RecordSection("seal cost"); + milvus::knowhere::cloner::CopyCpuToGpu(cpu_idx, DEVICEID, milvus::knowhere::Config()); + auto with_seal = tc.RecordSection("With seal"); + ASSERT_GE(without_seal, with_seal); + + // copy to GPU with invalid device id + ASSERT_ANY_THROW(milvus::knowhere::cloner::CopyCpuToGpu(cpu_idx, -1, milvus::knowhere::Config())); +} + +TEST_P(IVFTest, invalid_gpu_source) { + if (index_mode_ != milvus::knowhere::IndexMode::MODE_GPU) { + return; + } + + auto invalid_conf = ParamGenerator::GetInstance().Gen(index_type_); + invalid_conf[milvus::knowhere::meta::DEVICEID] = -1; + + // if (index_type_ == milvus::knowhere::IndexEnum::INDEX_FAISS_IVFFLAT) { + // null faiss index + // index_->SetIndexSize(0); + // milvus::knowhere::cloner::CopyGpuToCpu(index_, milvus::knowhere::Config()); + // } + + index_->Train(base_dataset, conf_); + + fiu_init(0); + fiu_enable("GPUIVF.SerializeImpl.throw_exception", 1, nullptr, 0); + ASSERT_ANY_THROW(index_->Serialize(conf_)); + fiu_disable("GPUIVF.SerializeImpl.throw_exception"); + + fiu_enable("GPUIVF.search_impl.invald_index", 1, nullptr, 0); + ASSERT_ANY_THROW(index_->Query(base_dataset, invalid_conf)); + fiu_disable("GPUIVF.search_impl.invald_index"); + + auto ivf_index = std::dynamic_pointer_cast(index_); + if (ivf_index) { + auto gpu_index = std::dynamic_pointer_cast(ivf_index); + gpu_index->SetGpuDevice(-1); + ASSERT_EQ(gpu_index->GetGpuDevice(), -1); + } + + // ASSERT_ANY_THROW(index_->Load(binaryset)); + ASSERT_ANY_THROW(index_->Train(base_dataset, invalid_conf)); +} + +TEST_P(IVFTest, IVFSQHybrid_test) { + if (index_type_ != milvus::knowhere::IndexEnum::INDEX_FAISS_IVFSQ8H) { + return; + } + fiu_init(0); + + index_->SetIndexSize(0); + milvus::knowhere::cloner::CopyGpuToCpu(index_, conf_); + ASSERT_ANY_THROW(milvus::knowhere::cloner::CopyCpuToGpu(index_, -1, conf_)); + + fiu_enable("FaissGpuResourceMgr.GetRes.ret_null", 1, nullptr, 0); + ASSERT_ANY_THROW(index_->Train(base_dataset, conf_)); + ASSERT_ANY_THROW(index_->CopyCpuToGpu(DEVICEID, conf_)); + fiu_disable("FaissGpuResourceMgr.GetRes.ret_null"); + + index_->Train(base_dataset, conf_); + auto index = std::dynamic_pointer_cast(index_); + ASSERT_TRUE(index != nullptr); + ASSERT_ANY_THROW(index->UnsetQuantizer()); + + ASSERT_ANY_THROW(index->SetQuantizer(nullptr)); +} +#endif diff --git a/core/src/index/unittest/test_ivf_cpu_nm.cpp b/core/src/index/unittest/test_ivf_cpu_nm.cpp new file mode 100644 index 0000000000..cc9095c5d7 --- /dev/null +++ b/core/src/index/unittest/test_ivf_cpu_nm.cpp @@ -0,0 +1,131 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include + +#include +#include +#include +#include + +#ifdef MILVUS_GPU_VERSION +#include +#endif + +#include "knowhere/common/Exception.h" +#include "knowhere/common/Timer.h" +#include "knowhere/index/IndexType.h" +#include "knowhere/index/vector_index/adapter/VectorAdapter.h" +#include "knowhere/index/vector_offset_index/IndexIVF_NM.h" + +#ifdef MILVUS_GPU_VERSION +#include "knowhere/index/vector_index/helpers/Cloner.h" +#include "knowhere/index/vector_index/helpers/FaissGpuResourceMgr.h" +#include "knowhere/index/vector_offset_index/gpu/IndexGPUIVF_NM.h" +#endif + +#include "unittest/Helper.h" +#include "unittest/utils.h" + +using ::testing::Combine; +using ::testing::TestWithParam; +using ::testing::Values; + +class IVFNMCPUTest : public DataGen, + public TestWithParam<::std::tuple> { + protected: + void + SetUp() override { +#ifdef MILVUS_GPU_VERSION + milvus::knowhere::FaissGpuResourceMgr::GetInstance().InitDevice(DEVICEID, PINMEM, TEMPMEM, RESNUM); +#endif + std::tie(index_type_, index_mode_) = GetParam(); + Generate(DIM, NB, NQ); + index_ = IndexFactoryNM(index_type_, index_mode_); + conf_ = ParamGenerator::GetInstance().Gen(index_type_); + } + + void + TearDown() override { +#ifdef MILVUS_GPU_VERSION + milvus::knowhere::FaissGpuResourceMgr::GetInstance().Free(); +#endif + } + + protected: + milvus::knowhere::IndexType index_type_; + milvus::knowhere::IndexMode index_mode_; + milvus::knowhere::Config conf_; + milvus::knowhere::IVFNMPtr index_ = nullptr; +}; + +INSTANTIATE_TEST_CASE_P(IVFParameters, IVFNMCPUTest, + Values(std::make_tuple(milvus::knowhere::IndexEnum::INDEX_FAISS_IVFFLAT, + milvus::knowhere::IndexMode::MODE_CPU))); + +TEST_P(IVFNMCPUTest, ivf_basic_cpu) { + assert(!xb.empty()); + + if (index_mode_ != milvus::knowhere::IndexMode::MODE_CPU) { + return; + } + + // null faiss index + ASSERT_ANY_THROW(index_->Add(base_dataset, conf_)); + ASSERT_ANY_THROW(index_->AddWithoutIds(base_dataset, conf_)); + + index_->Train(base_dataset, conf_); + index_->AddWithoutIds(base_dataset, conf_); + EXPECT_EQ(index_->Count(), nb); + EXPECT_EQ(index_->Dim(), dim); + + index_->SetIndexSize(nq * dim * sizeof(float)); + + milvus::knowhere::BinarySet bs = index_->Serialize(conf_); + + int64_t dim = base_dataset->Get(milvus::knowhere::meta::DIM); + int64_t rows = base_dataset->Get(milvus::knowhere::meta::ROWS); + auto raw_data = base_dataset->Get(milvus::knowhere::meta::TENSOR); + milvus::knowhere::BinaryPtr bptr = std::make_shared(); + bptr->data = std::shared_ptr((uint8_t*)raw_data, [&](uint8_t*) {}); + bptr->size = dim * rows * sizeof(float); + bs.Append(RAW_DATA, bptr); + index_->Load(bs); + + auto result = index_->Query(query_dataset, conf_); + AssertAnns(result, nq, k); + +#ifdef MILVUS_GPU_VERSION + // copy from cpu to gpu + { + EXPECT_NO_THROW({ + auto clone_index = milvus::knowhere::cloner::CopyCpuToGpu(index_, DEVICEID, conf_); + auto clone_result = clone_index->Query(query_dataset, conf_); + AssertAnns(clone_result, nq, k); + std::cout << "clone C <=> G [" << index_type_ << "] success" << std::endl; + }); + EXPECT_ANY_THROW(milvus::knowhere::cloner::CopyCpuToGpu(index_, -1, milvus::knowhere::Config())); + } +#endif + + faiss::ConcurrentBitsetPtr concurrent_bitset_ptr = std::make_shared(nb); + for (int64_t i = 0; i < nq; ++i) { + concurrent_bitset_ptr->set(i); + } + index_->SetBlacklist(concurrent_bitset_ptr); + + auto result_bs_1 = index_->Query(query_dataset, conf_); + AssertAnns(result_bs_1, nq, k, CheckMode::CHECK_NOT_EQUAL); + +#ifdef MILVUS_GPU_VERSION + milvus::knowhere::FaissGpuResourceMgr::GetInstance().Dump(); +#endif +} diff --git a/core/src/index/unittest/test_ivf_gpu_nm.cpp b/core/src/index/unittest/test_ivf_gpu_nm.cpp new file mode 100644 index 0000000000..68c5fb7946 --- /dev/null +++ b/core/src/index/unittest/test_ivf_gpu_nm.cpp @@ -0,0 +1,138 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include + +#include +#include +#include +#include + +#ifdef MILVUS_GPU_VERSION +#include +#endif + +#include "knowhere/common/Exception.h" +#include "knowhere/common/Timer.h" +#include "knowhere/index/IndexType.h" +#include "knowhere/index/vector_index/adapter/VectorAdapter.h" +#include "knowhere/index/vector_offset_index/IndexIVF_NM.h" + +#ifdef MILVUS_GPU_VERSION +#include "knowhere/index/vector_index/helpers/Cloner.h" +#include "knowhere/index/vector_index/helpers/FaissGpuResourceMgr.h" +#include "knowhere/index/vector_offset_index/gpu/IndexGPUIVF_NM.h" +#endif + +#include "unittest/Helper.h" +#include "unittest/utils.h" + +#define SERIALIZE_AND_LOAD(index_) \ + milvus::knowhere::BinarySet bs = index_->Serialize(conf_); \ + int64_t dim = base_dataset->Get(milvus::knowhere::meta::DIM); \ + int64_t rows = base_dataset->Get(milvus::knowhere::meta::ROWS); \ + auto raw_data = base_dataset->Get(milvus::knowhere::meta::TENSOR); \ + milvus::knowhere::BinaryPtr bptr = std::make_shared(); \ + bptr->data = std::shared_ptr((uint8_t*)raw_data, [&](uint8_t*) {}); \ + bptr->size = dim * rows * sizeof(float); \ + bs.Append(RAW_DATA, bptr); \ + index_->Load(bs); + +using ::testing::Combine; +using ::testing::TestWithParam; +using ::testing::Values; + +class IVFNMGPUTest : public DataGen, + public TestWithParam<::std::tuple> { + protected: + void + SetUp() override { +#ifdef MILVUS_GPU_VERSION + milvus::knowhere::FaissGpuResourceMgr::GetInstance().InitDevice(DEVICEID, PINMEM, TEMPMEM, RESNUM); +#endif + index_type_ = milvus::knowhere::IndexEnum::INDEX_FAISS_IVFFLAT; + index_mode_ = milvus::knowhere::IndexMode::MODE_GPU; + Generate(DIM, NB, NQ); +#ifdef MILVUS_GPU_VERSION + index_ = std::make_shared(DEVICEID); +#endif + conf_ = ParamGenerator::GetInstance().Gen(index_type_); + } + + void + TearDown() override { +#ifdef MILVUS_GPU_VERSION + milvus::knowhere::FaissGpuResourceMgr::GetInstance().Free(); +#endif + } + + protected: + milvus::knowhere::IndexType index_type_; + milvus::knowhere::IndexMode index_mode_; + milvus::knowhere::Config conf_; + milvus::knowhere::IVFPtr index_ = nullptr; +}; + +#ifdef MILVUS_GPU_VERSION +TEST_F(IVFNMGPUTest, ivf_basic_gpu) { + assert(!xb.empty()); + + if (index_mode_ != milvus::knowhere::IndexMode::MODE_GPU) { + return; + } + + // null faiss index + ASSERT_ANY_THROW(index_->Add(base_dataset, conf_)); + ASSERT_ANY_THROW(index_->AddWithoutIds(base_dataset, conf_)); + + index_->BuildAll(base_dataset, conf_); + EXPECT_EQ(index_->Count(), nb); + EXPECT_EQ(index_->Dim(), dim); + + index_->SetIndexSize(nq * dim * sizeof(float)); + + SERIALIZE_AND_LOAD(index_); + + auto result = index_->Query(query_dataset, conf_); + AssertAnns(result, nq, k); + + auto AssertEqual = [&](milvus::knowhere::DatasetPtr p1, milvus::knowhere::DatasetPtr p2) { + auto ids_p1 = p1->Get(milvus::knowhere::meta::IDS); + auto ids_p2 = p2->Get(milvus::knowhere::meta::IDS); + + for (int i = 0; i < nq * k; ++i) { + EXPECT_EQ(*((int64_t*)(ids_p2) + i), *((int64_t*)(ids_p1) + i)); + } + }; + + // copy from gpu to cpu + { + EXPECT_NO_THROW({ + auto clone_index = milvus::knowhere::cloner::CopyGpuToCpu(index_, conf_); + SERIALIZE_AND_LOAD(clone_index); + auto clone_result = clone_index->Query(query_dataset, conf_); + AssertEqual(result, clone_result); + std::cout << "clone G <=> C [" << index_type_ << "] success" << std::endl; + }); + } + + faiss::ConcurrentBitsetPtr concurrent_bitset_ptr = std::make_shared(nb); + for (int64_t i = 0; i < nq; ++i) { + concurrent_bitset_ptr->set(i); + } + index_->SetBlacklist(concurrent_bitset_ptr); + + auto result_bs_1 = index_->Query(query_dataset, conf_); + AssertAnns(result_bs_1, nq, k, CheckMode::CHECK_NOT_EQUAL); + + milvus::knowhere::FaissGpuResourceMgr::GetInstance().Dump(); +} +#endif diff --git a/core/src/index/unittest/test_knowhere.cpp b/core/src/index/unittest/test_knowhere.cpp new file mode 100644 index 0000000000..bf6e3a03d5 --- /dev/null +++ b/core/src/index/unittest/test_knowhere.cpp @@ -0,0 +1,52 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include "config/ServerConfig.h" +#include "wrapper/KnowhereResource.h" +#include "wrapper/utils.h" + +#include +#include +#include + +TEST_F(KnowhereTest, KNOWHERE_RESOURCE_TEST) { + std::string config_path(CONFIG_PATH); + config_path += CONFIG_FILE; + milvus::server::Config& config = milvus::server::Config::GetInstance(); + milvus::Status s = config.LoadConfigFile(config_path); + ASSERT_TRUE(s.ok()); + + milvus::engine::KnowhereResource::Initialize(); + milvus::engine::KnowhereResource::Finalize(); + +#ifdef MILVUS_GPU_VERSION + fiu_init(0); + fiu_enable("check_config_gpu_resource_enable_fail", 1, nullptr, 0); + s = milvus::engine::KnowhereResource::Initialize(); + ASSERT_FALSE(s.ok()); + fiu_disable("check_config_gpu_resource_enable_fail"); + + fiu_enable("KnowhereResource.Initialize.disable_gpu", 1, nullptr, 0); + s = milvus::engine::KnowhereResource::Initialize(); + ASSERT_TRUE(s.ok()); + fiu_disable("KnowhereResource.Initialize.disable_gpu"); + + fiu_enable("check_gpu_resource_config_build_index_fail", 1, nullptr, 0); + s = milvus::engine::KnowhereResource::Initialize(); + ASSERT_FALSE(s.ok()); + fiu_disable("check_gpu_resource_config_build_index_fail"); + + fiu_enable("check_gpu_resource_config_search_fail", 1, nullptr, 0); + s = milvus::engine::KnowhereResource::Initialize(); + ASSERT_FALSE(s.ok()); + fiu_disable("check_gpu_resource_config_search_fail"); +#endif +} diff --git a/core/src/index/unittest/test_nsg.cpp b/core/src/index/unittest/test_nsg.cpp new file mode 100644 index 0000000000..4de272458d --- /dev/null +++ b/core/src/index/unittest/test_nsg.cpp @@ -0,0 +1,202 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include +#include +#include +#include + +#include "knowhere/common/Exception.h" +#include "knowhere/index/vector_index/helpers/IndexParameter.h" +#include "knowhere/index/vector_offset_index/IndexNSG_NM.h" +#ifdef MILVUS_GPU_VERSION +#include "knowhere/index/vector_index/gpu/IndexGPUIDMAP.h" +#include "knowhere/index/vector_index/helpers/Cloner.h" +#include "knowhere/index/vector_index/helpers/FaissGpuResourceMgr.h" +#endif + +#include "knowhere/common/Timer.h" +#include "knowhere/index/vector_index/impl/nsg/NSGIO.h" + +#include "unittest/utils.h" + +using ::testing::Combine; +using ::testing::TestWithParam; +using ::testing::Values; + +constexpr int64_t DEVICE_GPU0 = 0; + +class NSGInterfaceTest : public DataGen, public ::testing::Test { + protected: + void + SetUp() override { +#ifdef MILVUS_GPU_VERSION + int64_t MB = 1024 * 1024; + milvus::knowhere::FaissGpuResourceMgr::GetInstance().InitDevice(DEVICE_GPU0, MB * 200, MB * 600, 1); +#endif + int nsg_dim = 256; + Generate(nsg_dim, 20000, nq); + index_ = std::make_shared(); + + train_conf = milvus::knowhere::Config{{milvus::knowhere::meta::DIM, 256}, + {milvus::knowhere::IndexParams::nlist, 163}, + {milvus::knowhere::IndexParams::nprobe, 8}, + {milvus::knowhere::IndexParams::knng, 20}, + {milvus::knowhere::IndexParams::search_length, 40}, + {milvus::knowhere::IndexParams::out_degree, 30}, + {milvus::knowhere::IndexParams::candidate, 100}, + {milvus::knowhere::Metric::TYPE, milvus::knowhere::Metric::L2}}; + + search_conf = milvus::knowhere::Config{ + {milvus::knowhere::meta::TOPK, k}, + {milvus::knowhere::IndexParams::search_length, 30}, + }; + } + + void + TearDown() override { +#ifdef MILVUS_GPU_VERSION + milvus::knowhere::FaissGpuResourceMgr::GetInstance().Free(); +#endif + } + + protected: + std::shared_ptr index_; + milvus::knowhere::Config train_conf; + milvus::knowhere::Config search_conf; +}; + +TEST_F(NSGInterfaceTest, basic_test) { + assert(!xb.empty()); + fiu_init(0); + // untrained index + { + ASSERT_ANY_THROW(index_->Serialize(search_conf)); + ASSERT_ANY_THROW(index_->Query(query_dataset, search_conf)); + ASSERT_ANY_THROW(index_->Add(base_dataset, search_conf)); + ASSERT_ANY_THROW(index_->AddWithoutIds(base_dataset, search_conf)); + } + + train_conf[milvus::knowhere::meta::DEVICEID] = -1; + index_->BuildAll(base_dataset, train_conf); + + // Serialize and Load before Query + milvus::knowhere::BinarySet bs = index_->Serialize(search_conf); + + int64_t dim = base_dataset->Get(milvus::knowhere::meta::DIM); + int64_t rows = base_dataset->Get(milvus::knowhere::meta::ROWS); + auto raw_data = base_dataset->Get(milvus::knowhere::meta::TENSOR); + milvus::knowhere::BinaryPtr bptr = std::make_shared(); + bptr->data = std::shared_ptr((uint8_t*)raw_data, [&](uint8_t*) {}); + bptr->size = dim * rows * sizeof(float); + bs.Append(RAW_DATA, bptr); + + index_->Load(bs); + + auto result = index_->Query(query_dataset, search_conf); + AssertAnns(result, nq, k); + + /* test NSG GPU train */ + auto new_index_1 = std::make_shared(DEVICE_GPU0); + train_conf[milvus::knowhere::meta::DEVICEID] = DEVICE_GPU0; + new_index_1->BuildAll(base_dataset, train_conf); + + // Serialize and Load before Query + bs = new_index_1->Serialize(search_conf); + + dim = base_dataset->Get(milvus::knowhere::meta::DIM); + rows = base_dataset->Get(milvus::knowhere::meta::ROWS); + raw_data = base_dataset->Get(milvus::knowhere::meta::TENSOR); + bptr = std::make_shared(); + bptr->data = std::shared_ptr((uint8_t*)raw_data, [&](uint8_t*) {}); + bptr->size = dim * rows * sizeof(float); + bs.Append(RAW_DATA, bptr); + + new_index_1->Load(bs); + + auto new_result_1 = new_index_1->Query(query_dataset, search_conf); + AssertAnns(new_result_1, nq, k); + + ASSERT_EQ(index_->Count(), nb); + ASSERT_EQ(index_->Dim(), dim); +} + +TEST_F(NSGInterfaceTest, compare_test) { + milvus::knowhere::impl::DistanceL2 distanceL2; + milvus::knowhere::impl::DistanceIP distanceIP; + + milvus::knowhere::TimeRecorder tc("Compare"); + for (int i = 0; i < 1000; ++i) { + distanceL2.Compare(xb.data(), xq.data(), 256); + } + tc.RecordSection("L2"); + for (int i = 0; i < 1000; ++i) { + distanceIP.Compare(xb.data(), xq.data(), 256); + } + tc.RecordSection("IP"); +} + +TEST_F(NSGInterfaceTest, delete_test) { + assert(!xb.empty()); + + train_conf[milvus::knowhere::meta::DEVICEID] = DEVICE_GPU0; + index_->Train(base_dataset, train_conf); + + // Serialize and Load before Query + milvus::knowhere::BinarySet bs = index_->Serialize(search_conf); + + int64_t dim = base_dataset->Get(milvus::knowhere::meta::DIM); + int64_t rows = base_dataset->Get(milvus::knowhere::meta::ROWS); + auto raw_data = base_dataset->Get(milvus::knowhere::meta::TENSOR); + milvus::knowhere::BinaryPtr bptr = std::make_shared(); + bptr->data = std::shared_ptr((uint8_t*)raw_data, [&](uint8_t*) {}); + bptr->size = dim * rows * sizeof(float); + bs.Append(RAW_DATA, bptr); + + index_->Load(bs); + + auto result = index_->Query(query_dataset, search_conf); + AssertAnns(result, nq, k); + + ASSERT_EQ(index_->Count(), nb); + ASSERT_EQ(index_->Dim(), dim); + + faiss::ConcurrentBitsetPtr bitset = std::make_shared(nb); + for (int i = 0; i < nq; i++) { + bitset->set(i); + } + + auto I_before = result->Get(milvus::knowhere::meta::IDS); + + // search xq with delete + index_->SetBlacklist(bitset); + + // Serialize and Load before Query + bs = index_->Serialize(search_conf); + + dim = base_dataset->Get(milvus::knowhere::meta::DIM); + rows = base_dataset->Get(milvus::knowhere::meta::ROWS); + raw_data = base_dataset->Get(milvus::knowhere::meta::TENSOR); + bptr = std::make_shared(); + bptr->data = std::shared_ptr((uint8_t*)raw_data, [&](uint8_t*) {}); + bptr->size = dim * rows * sizeof(float); + bs.Append(RAW_DATA, bptr); + + index_->Load(bs); + auto result_after = index_->Query(query_dataset, search_conf); + AssertAnns(result_after, nq, k, CheckMode::CHECK_NOT_EQUAL); + auto I_after = result_after->Get(milvus::knowhere::meta::IDS); + + // First vector deleted + for (int i = 0; i < nq; i++) { + ASSERT_NE(I_before[i * k], I_after[i * k]); + } +} diff --git a/core/src/index/unittest/test_rhnsw_flat.cpp b/core/src/index/unittest/test_rhnsw_flat.cpp new file mode 100644 index 0000000000..c7644da746 --- /dev/null +++ b/core/src/index/unittest/test_rhnsw_flat.cpp @@ -0,0 +1,158 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include +#include +#include +#include +#include +#include "knowhere/common/Exception.h" +#include "unittest/utils.h" + +using ::testing::Combine; +using ::testing::TestWithParam; +using ::testing::Values; + +class RHNSWFlatTest : public DataGen, public TestWithParam { + protected: + void + SetUp() override { + IndexType = GetParam(); + std::cout << "IndexType from GetParam() is: " << IndexType << std::endl; + Generate(64, 10000, 10); // dim = 64, nb = 10000, nq = 10 + index_ = std::make_shared(); + conf = milvus::knowhere::Config{ + {milvus::knowhere::meta::DIM, 64}, {milvus::knowhere::meta::TOPK, 10}, + {milvus::knowhere::IndexParams::M, 16}, {milvus::knowhere::IndexParams::efConstruction, 200}, + {milvus::knowhere::IndexParams::ef, 200}, {milvus::knowhere::Metric::TYPE, milvus::knowhere::Metric::L2}, + }; + } + + protected: + milvus::knowhere::Config conf; + std::shared_ptr index_ = nullptr; + std::string IndexType; +}; + +INSTANTIATE_TEST_CASE_P(HNSWParameters, RHNSWFlatTest, Values("RHNSWFlat")); + +TEST_P(RHNSWFlatTest, HNSW_basic) { + assert(!xb.empty()); + + index_->Train(base_dataset, conf); + index_->Add(base_dataset, conf); + EXPECT_EQ(index_->Count(), nb); + EXPECT_EQ(index_->Dim(), dim); + + auto result1 = index_->Query(query_dataset, conf); +// AssertAnns(result1, nq, k); + + // Serialize and Load before Query + milvus::knowhere::BinarySet bs = index_->Serialize(conf); + + auto tmp_index = std::make_shared(); + + tmp_index->Load(bs); + + auto result2 = tmp_index->Query(query_dataset, conf); +// AssertAnns(result2, nq, k); +} + +TEST_P(RHNSWFlatTest, HNSW_delete) { + assert(!xb.empty()); + + index_->Train(base_dataset, conf); + index_->Add(base_dataset, conf); + EXPECT_EQ(index_->Count(), nb); + EXPECT_EQ(index_->Dim(), dim); + + faiss::ConcurrentBitsetPtr bitset = std::make_shared(nb); + for (auto i = 0; i < nq; ++i) { + bitset->set(i); + } + + auto result1 = index_->Query(query_dataset, conf); +// AssertAnns(result1, nq, k); + + index_->SetBlacklist(bitset); + auto result2 = index_->Query(query_dataset, conf); +// AssertAnns(result2, nq, k, CheckMode::CHECK_NOT_EQUAL); + + /* + * delete result checked by eyes + auto ids1 = result1->Get(milvus::knowhere::meta::IDS); + auto ids2 = result2->Get(milvus::knowhere::meta::IDS); + std::cout << std::endl; + for (int i = 0; i < nq; ++ i) { + std::cout << "ids1: "; + for (int j = 0; j < k; ++ j) { + std::cout << *(ids1 + i * k + j) << " "; + } + std::cout << "ids2: "; + for (int j = 0; j < k; ++ j) { + std::cout << *(ids2 + i * k + j) << " "; + } + std::cout << std::endl; + for (int j = 0; j < std::min(5, k>>1); ++ j) { + ASSERT_EQ(*(ids1 + i * k + j + 1), *(ids2 + i * k + j)); + } + } + */ +} + +TEST_P(RHNSWFlatTest, HNSW_serialize) { + auto serialize = [](const std::string& filename, milvus::knowhere::BinaryPtr& bin, uint8_t* ret) { + { + FileIOWriter writer(filename); + writer(static_cast(bin->data.get()), bin->size); + } + + FileIOReader reader(filename); + reader(ret, bin->size); + }; + + { + index_->Train(base_dataset, conf); + index_->Add(base_dataset, conf); + auto binaryset = index_->Serialize(conf); + std::string index_type = index_->index_type(); + std::string idx_name = index_type + "_Index"; + std::string dat_name = index_type + "_Data"; + if (binaryset.binary_map_.find(idx_name) == binaryset.binary_map_.end()) { + std::cout << "no idx!" << std::endl; + } + if (binaryset.binary_map_.find(dat_name) == binaryset.binary_map_.end()) { + std::cout << "no dat!" << std::endl; + } + auto bin_idx = binaryset.GetByName(idx_name); + auto bin_dat = binaryset.GetByName(dat_name); + + std::string filename_idx = "/tmp/RHNSWFlat_test_serialize_idx.bin"; + std::string filename_dat = "/tmp/RHNSWFlat_test_serialize_dat.bin"; + auto load_idx = new uint8_t[bin_idx->size]; + auto load_dat = new uint8_t[bin_dat->size]; + serialize(filename_idx, bin_idx, load_idx); + serialize(filename_dat, bin_dat, load_dat); + + binaryset.clear(); + auto new_idx = std::make_shared(); + std::shared_ptr dat(load_dat); + std::shared_ptr idx(load_idx); + binaryset.Append(new_idx->index_type() + "_Index", idx, bin_idx->size); + binaryset.Append(new_idx->index_type() + "_Data", dat, bin_dat->size); + + new_idx->Load(binaryset); + EXPECT_EQ(new_idx->Count(), nb); + EXPECT_EQ(new_idx->Dim(), dim); + auto result = new_idx->Query(query_dataset, conf); +// AssertAnns(result, nq, conf[milvus::knowhere::meta::TOPK]); + } +} diff --git a/core/src/index/unittest/test_rhnsw_pq.cpp b/core/src/index/unittest/test_rhnsw_pq.cpp new file mode 100644 index 0000000000..54af490042 --- /dev/null +++ b/core/src/index/unittest/test_rhnsw_pq.cpp @@ -0,0 +1,148 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include +#include +#include +#include +#include +#include "knowhere/common/Exception.h" +#include "unittest/utils.h" + +using ::testing::Combine; +using ::testing::TestWithParam; +using ::testing::Values; + +class RHNSWPQTest : public DataGen, public TestWithParam { + protected: + void + SetUp() override { + IndexType = GetParam(); + std::cout << "IndexType from GetParam() is: " << IndexType << std::endl; + Generate(64, 10000, 10); // dim = 64, nb = 10000, nq = 10 + index_ = std::make_shared(); + conf = milvus::knowhere::Config{ + {milvus::knowhere::meta::DIM, 64}, {milvus::knowhere::meta::TOPK, 10}, + {milvus::knowhere::IndexParams::M, 16}, {milvus::knowhere::IndexParams::efConstruction, 200}, + {milvus::knowhere::IndexParams::ef, 200}, {milvus::knowhere::Metric::TYPE, milvus::knowhere::Metric::L2}, + {milvus::knowhere::IndexParams::PQM, 8}}; + } + + protected: + milvus::knowhere::Config conf; + std::shared_ptr index_ = nullptr; + std::string IndexType; +}; + +INSTANTIATE_TEST_CASE_P(HNSWParameters, RHNSWPQTest, Values("RHNSWPQ")); + +TEST_P(RHNSWPQTest, HNSW_basic) { + assert(!xb.empty()); + + index_->Train(base_dataset, conf); + index_->Add(base_dataset, conf); + EXPECT_EQ(index_->Count(), nb); + EXPECT_EQ(index_->Dim(), dim); + + // Serialize and Load before Query + milvus::knowhere::BinarySet bs = index_->Serialize(conf); + auto result1 = index_->Query(query_dataset, conf); + // AssertAnns(result1, nq, k); + + auto tmp_index = std::make_shared(); + + tmp_index->Load(bs); + + auto result2 = tmp_index->Query(query_dataset, conf); + // AssertAnns(result2, nq, k); +} + +TEST_P(RHNSWPQTest, HNSW_delete) { + assert(!xb.empty()); + + index_->Train(base_dataset, conf); + index_->Add(base_dataset, conf); + EXPECT_EQ(index_->Count(), nb); + EXPECT_EQ(index_->Dim(), dim); + + faiss::ConcurrentBitsetPtr bitset = std::make_shared(nb); + for (auto i = 0; i < nq; ++i) { + bitset->set(i); + } + + auto result1 = index_->Query(query_dataset, conf); + // AssertAnns(result1, nq, k); + + index_->SetBlacklist(bitset); + auto result2 = index_->Query(query_dataset, conf); + // AssertAnns(result2, nq, k, CheckMode::CHECK_NOT_EQUAL); + + /* + * delete result checked by eyes + auto ids1 = result1->Get(milvus::knowhere::meta::IDS); + auto ids2 = result2->Get(milvus::knowhere::meta::IDS); + std::cout << std::endl; + for (int i = 0; i < nq; ++ i) { + std::cout << "ids1: "; + for (int j = 0; j < k; ++ j) { + std::cout << *(ids1 + i * k + j) << " "; + } + std::cout << "ids2: "; + for (int j = 0; j < k; ++ j) { + std::cout << *(ids2 + i * k + j) << " "; + } + std::cout << std::endl; + for (int j = 0; j < std::min(5, k>>1); ++ j) { + ASSERT_EQ(*(ids1 + i * k + j + 1), *(ids2 + i * k + j)); + } + } + */ +} + +TEST_P(RHNSWPQTest, HNSW_serialize) { + auto serialize = [](const std::string& filename, milvus::knowhere::BinaryPtr& bin, uint8_t* ret) { + { + FileIOWriter writer(filename); + writer(static_cast(bin->data.get()), bin->size); + } + + FileIOReader reader(filename); + reader(ret, bin->size); + }; + + { + index_->Train(base_dataset, conf); + index_->Add(base_dataset, conf); + auto binaryset = index_->Serialize(conf); + auto bin_idx = binaryset.GetByName(index_->index_type() + "_Index"); + auto bin_dat = binaryset.GetByName(QUANTIZATION_DATA); + + std::string filename_idx = "/tmp/RHNSWPQ_test_serialize_idx.bin"; + std::string filename_dat = "/tmp/RHNSWPQ_test_serialize_dat.bin"; + auto load_idx = new uint8_t[bin_idx->size]; + auto load_dat = new uint8_t[bin_dat->size]; + serialize(filename_idx, bin_idx, load_idx); + serialize(filename_dat, bin_dat, load_dat); + + binaryset.clear(); + auto new_idx = std::make_shared(); + std::shared_ptr dat(load_dat); + std::shared_ptr idx(load_idx); + binaryset.Append(new_idx->index_type() + "_Index", idx, bin_idx->size); + binaryset.Append(QUANTIZATION_DATA, dat, bin_dat->size); + + new_idx->Load(binaryset); + EXPECT_EQ(new_idx->Count(), nb); + EXPECT_EQ(new_idx->Dim(), dim); + auto result = new_idx->Query(query_dataset, conf); + // AssertAnns(result, nq, conf[milvus::knowhere::meta::TOPK]); + } +} diff --git a/core/src/index/unittest/test_rhnsw_sq8.cpp b/core/src/index/unittest/test_rhnsw_sq8.cpp new file mode 100644 index 0000000000..7e523ad2c1 --- /dev/null +++ b/core/src/index/unittest/test_rhnsw_sq8.cpp @@ -0,0 +1,149 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include +#include +#include +#include +#include +#include "knowhere/common/Exception.h" +#include "unittest/utils.h" + +using ::testing::Combine; +using ::testing::TestWithParam; +using ::testing::Values; + +class RHNSWSQ8Test : public DataGen, public TestWithParam { + protected: + void + SetUp() override { + IndexType = GetParam(); + std::cout << "IndexType from GetParam() is: " << IndexType << std::endl; + Generate(64, 10000, 10); // dim = 64, nb = 10000, nq = 10 + // Generate(2, 10, 2); // dim = 64, nb = 10000, nq = 10 + index_ = std::make_shared(); + conf = milvus::knowhere::Config{ + {milvus::knowhere::meta::DIM, 64}, {milvus::knowhere::meta::TOPK, 10}, + {milvus::knowhere::IndexParams::M, 16}, {milvus::knowhere::IndexParams::efConstruction, 200}, + {milvus::knowhere::IndexParams::ef, 200}, {milvus::knowhere::Metric::TYPE, milvus::knowhere::Metric::L2}, + }; + } + + protected: + milvus::knowhere::Config conf; + std::shared_ptr index_ = nullptr; + std::string IndexType; +}; + +INSTANTIATE_TEST_CASE_P(HNSWParameters, RHNSWSQ8Test, Values("RHNSWSQ8")); + +TEST_P(RHNSWSQ8Test, HNSW_basic) { + assert(!xb.empty()); + + index_->Train(base_dataset, conf); + index_->Add(base_dataset, conf); + EXPECT_EQ(index_->Count(), nb); + EXPECT_EQ(index_->Dim(), dim); + + // Serialize and Load before Query + milvus::knowhere::BinarySet bs = index_->Serialize(conf); + auto result1 = index_->Query(query_dataset, conf); + AssertAnns(result1, nq, k); + + auto tmp_index = std::make_shared(); + + tmp_index->Load(bs); + + auto result2 = tmp_index->Query(query_dataset, conf); + AssertAnns(result2, nq, k); +} + +TEST_P(RHNSWSQ8Test, HNSW_delete) { + assert(!xb.empty()); + + index_->Train(base_dataset, conf); + index_->Add(base_dataset, conf); + EXPECT_EQ(index_->Count(), nb); + EXPECT_EQ(index_->Dim(), dim); + + faiss::ConcurrentBitsetPtr bitset = std::make_shared(nb); + for (auto i = 0; i < nq; ++i) { + bitset->set(i); + } + + auto result1 = index_->Query(query_dataset, conf); + AssertAnns(result1, nq, k); + + index_->SetBlacklist(bitset); + auto result2 = index_->Query(query_dataset, conf); + AssertAnns(result2, nq, k, CheckMode::CHECK_NOT_EQUAL); + + /* + * delete result checked by eyes + auto ids1 = result1->Get(milvus::knowhere::meta::IDS); + auto ids2 = result2->Get(milvus::knowhere::meta::IDS); + std::cout << std::endl; + for (int i = 0; i < nq; ++ i) { + std::cout << "ids1: "; + for (int j = 0; j < k; ++ j) { + std::cout << *(ids1 + i * k + j) << " "; + } + std::cout << "ids2: "; + for (int j = 0; j < k; ++ j) { + std::cout << *(ids2 + i * k + j) << " "; + } + std::cout << std::endl; + for (int j = 0; j < std::min(5, k>>1); ++ j) { + ASSERT_EQ(*(ids1 + i * k + j + 1), *(ids2 + i * k + j)); + } + } + */ +} + +TEST_P(RHNSWSQ8Test, HNSW_serialize) { + auto serialize = [](const std::string& filename, milvus::knowhere::BinaryPtr& bin, uint8_t* ret) { + { + FileIOWriter writer(filename); + writer(static_cast(bin->data.get()), bin->size); + } + + FileIOReader reader(filename); + reader(ret, bin->size); + }; + + { + index_->Train(base_dataset, conf); + index_->Add(base_dataset, conf); + auto binaryset = index_->Serialize(conf); + auto bin_idx = binaryset.GetByName(index_->index_type() + "_Index"); + auto bin_dat = binaryset.GetByName(QUANTIZATION_DATA); + + std::string filename_idx = "/tmp/RHNSWSQ_test_serialize_idx.bin"; + std::string filename_dat = "/tmp/RHNSWSQ_test_serialize_dat.bin"; + auto load_idx = new uint8_t[bin_idx->size]; + auto load_dat = new uint8_t[bin_dat->size]; + serialize(filename_idx, bin_idx, load_idx); + serialize(filename_dat, bin_dat, load_dat); + + binaryset.clear(); + auto new_idx = std::make_shared(); + std::shared_ptr dat(load_dat); + std::shared_ptr idx(load_idx); + binaryset.Append(new_idx->index_type() + "_Index", idx, bin_idx->size); + binaryset.Append(QUANTIZATION_DATA, dat, bin_dat->size); + + new_idx->Load(binaryset); + EXPECT_EQ(new_idx->Count(), nb); + EXPECT_EQ(new_idx->Dim(), dim); + auto result = new_idx->Query(query_dataset, conf); + AssertAnns(result, nq, conf[milvus::knowhere::meta::TOPK]); + } +} diff --git a/core/src/index/unittest/test_sptag.cpp b/core/src/index/unittest/test_sptag.cpp new file mode 100644 index 0000000000..65349e5e8e --- /dev/null +++ b/core/src/index/unittest/test_sptag.cpp @@ -0,0 +1,143 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include +#include +#include + +#include "knowhere/common/Exception.h" +#include "knowhere/index/vector_index/IndexSPTAG.h" +#include "knowhere/index/vector_index/adapter/SptagAdapter.h" +#include "knowhere/index/vector_index/adapter/VectorAdapter.h" + +#include "unittest/utils.h" + +using ::testing::Combine; +using ::testing::TestWithParam; +using ::testing::Values; + +class SPTAGTest : public DataGen, public TestWithParam { + protected: + void + SetUp() override { + IndexType = GetParam(); + Generate(128, 100, 5); + index_ = std::make_shared(IndexType); + if (IndexType == "KDT") { + conf = milvus::knowhere::Config{ + {milvus::knowhere::meta::DIM, dim}, + {milvus::knowhere::meta::TOPK, 10}, + {milvus::knowhere::Metric::TYPE, milvus::knowhere::Metric::L2}, + }; + } else { + conf = milvus::knowhere::Config{ + {milvus::knowhere::meta::DIM, dim}, + {milvus::knowhere::meta::TOPK, 10}, + {milvus::knowhere::Metric::TYPE, milvus::knowhere::Metric::L2}, + }; + } + + Init_with_default(); + } + + protected: + milvus::knowhere::Config conf; + std::shared_ptr index_ = nullptr; + std::string IndexType; +}; + +INSTANTIATE_TEST_CASE_P(SPTAGParameters, SPTAGTest, Values("KDT", "BKT")); + +// TODO(lxj): add test about count() and dimension() +TEST_P(SPTAGTest, sptag_basic) { + assert(!xb.empty()); + + // null faiss index + { + ASSERT_ANY_THROW(index_->Add(nullptr, conf)); + ASSERT_ANY_THROW(index_->AddWithoutIds(nullptr, conf)); + } + + index_->BuildAll(base_dataset, conf); + // index_->Add(base_dataset, conf); + auto result = index_->Query(query_dataset, conf); + AssertAnns(result, nq, k); + + { + auto ids = result->Get(milvus::knowhere::meta::IDS); + auto dist = result->Get(milvus::knowhere::meta::DISTANCE); + + std::stringstream ss_id; + std::stringstream ss_dist; + for (auto i = 0; i < nq; i++) { + for (auto j = 0; j < k; ++j) { + // ss_id << *ids->data()->GetValues(1, i * k + j) << " "; + // ss_dist << *dists->data()->GetValues(1, i * k + j) << " "; + ss_id << *((int64_t*)(ids) + i * k + j) << " "; + ss_dist << *((float*)(dist) + i * k + j) << " "; + } + ss_id << std::endl; + ss_dist << std::endl; + } + std::cout << "id\n" << ss_id.str() << std::endl; + std::cout << "dist\n" << ss_dist.str() << std::endl; + } +} + +TEST_P(SPTAGTest, sptag_serialize) { + assert(!xb.empty()); + + index_->Train(base_dataset, conf); + // index_->Add(base_dataset, conf); + auto binaryset = index_->Serialize(); + auto new_index = std::make_shared(IndexType); + new_index->Load(binaryset); + auto result = new_index->Query(query_dataset, conf); + AssertAnns(result, nq, k); + PrintResult(result, nq, k); + ASSERT_EQ(new_index->Count(), nb); + ASSERT_EQ(new_index->Dim(), dim); + // ASSERT_THROW({ new_index->Clone(); }, milvus::knowhere::KnowhereException); + // ASSERT_NO_THROW({ new_index->Seal(); }); + + { + int fileno = 0; + const std::string& base_name = "/tmp/sptag_serialize_test_bin_"; + std::vector filename_list; + std::vector> meta_list; + for (auto& iter : binaryset.binary_map_) { + const std::string& filename = base_name + std::to_string(fileno); + FileIOWriter writer(filename); + writer(iter.second->data.get(), iter.second->size); + + meta_list.emplace_back(std::make_pair(iter.first, iter.second->size)); + filename_list.push_back(filename); + ++fileno; + } + + milvus::knowhere::BinarySet load_data_list; + for (int i = 0; i < filename_list.size() && i < meta_list.size(); ++i) { + auto bin_size = meta_list[i].second; + FileIOReader reader(filename_list[i]); + + auto load_data = new uint8_t[bin_size]; + reader(load_data, bin_size); + std::shared_ptr data(load_data); + load_data_list.Append(meta_list[i].first, data, bin_size); + } + + auto new_index = std::make_shared(IndexType); + new_index->Load(load_data_list); + auto result = new_index->Query(query_dataset, conf); + AssertAnns(result, nq, k); + PrintResult(result, nq, k); + } +} diff --git a/core/src/index/unittest/test_structured_index_sort.cpp b/core/src/index/unittest/test_structured_index_sort.cpp new file mode 100644 index 0000000000..896ab8888e --- /dev/null +++ b/core/src/index/unittest/test_structured_index_sort.cpp @@ -0,0 +1,270 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include +#include +#include +#include +#include +#include + +#include "knowhere/index/structured_index/StructuredIndexSort.h" + +#include "unittest/utils.h" + +void +gen_rand_data(int range, int n, int*& p) { + srand((unsigned int)time(nullptr)); + p = (int*)malloc(n * sizeof(int)); + int* q = p; + for (auto i = 0; i < n; ++i) { + *q++ = (int)random() % range; + } +} + +void +gen_rand_int64_data(int64_t range, int64_t n, int64_t*& p) { + srand((int64_t)time(nullptr)); + p = (int64_t*)malloc(n * sizeof(int64_t)); + int64_t* q = p; + for (auto i = 0; i < n; ++i) { + *q++ = (int64_t)random() % range; + } +} + +void +gen_rand_double_data(double range, int64_t n, double*& p) { + std::uniform_real_distribution unif(0, range); + std::default_random_engine re; + p = (double*)malloc(n * sizeof(double)); + double* q = p; + for (auto i = 0; i < n; ++i) { + *q++ = unif(re); + } +} + +TEST(STRUCTUREDINDEXSORT_TEST, test_build) { + int range = 100, n = 1000, *p = nullptr; + gen_rand_data(range, n, p); + + milvus::knowhere::StructuredIndexSort structuredIndexSort((size_t)n, p); // Build default + std::sort(p, p + n); + const std::vector> index_data = structuredIndexSort.GetData(); + for (auto i = 0; i < n; ++i) { + ASSERT_EQ(*(p + i), index_data[i].a_); + } + free(p); +} + +TEST(STRUCTUREDINDEXSORT_TEST, test_serialize_and_load) { + auto serialize = [](const std::string& filename, milvus::knowhere::BinaryPtr& bin, uint8_t* ret) { + { + // write and flush + FileIOWriter writer(filename); + writer(static_cast(bin->data.get()), bin->size); + } + + FileIOReader reader(filename); + reader(ret, bin->size); + }; + + int range = 100, n = 1000, *p = nullptr; + gen_rand_data(range, n, p); + + milvus::knowhere::StructuredIndexSort structuredIndexSort((size_t)n, p); // Build default + auto binaryset = structuredIndexSort.Serialize(); + + auto bin_data = binaryset.GetByName("index_data"); + std::string data_file = "/tmp/sort_test_data_serialize.bin"; + auto load_data = new uint8_t[bin_data->size]; + serialize(data_file, bin_data, load_data); + + auto bin_length = binaryset.GetByName("index_length"); + std::string length_file = "/tmp/sort_test_length_serialize.bin"; + auto load_length = new uint8_t[bin_length->size]; + serialize(length_file, bin_length, load_length); + + binaryset.clear(); + std::shared_ptr index_data(load_data); + binaryset.Append("index_data", index_data, bin_data->size); + + std::shared_ptr length_data(load_length); + binaryset.Append("index_length", length_data, bin_length->size); + + structuredIndexSort.Load(binaryset); + EXPECT_EQ(n, (int)structuredIndexSort.Size()); + EXPECT_EQ(true, structuredIndexSort.IsBuilt()); + std::sort(p, p + n); + const std::vector> const_index_data = structuredIndexSort.GetData(); + for (auto i = 0; i < n; ++i) { + ASSERT_EQ(*(p + i), const_index_data[i].a_); + } + + free(p); +} + +TEST(STRUCTUREDINDEXSORT_TEST, test_in) { + int range = 1000, n = 1000, *p = nullptr; + gen_rand_data(range, n, p); + milvus::knowhere::StructuredIndexSort structuredIndexSort((size_t)n, p); // Build default + + int test_times = 10; + std::vector test_vals, test_off; + test_vals.reserve(test_times); + test_off.reserve(test_times); + // std::cout << "STRUCTUREDINDEXSORT_TEST test_in" << std::endl; + for (auto i = 0; i < test_times; ++i) { + auto off = random() % n; + test_vals.emplace_back(*(p + off)); + test_off.emplace_back(off); + // std::cout << "val: " << *(p + off) << ", off: " << off << std::endl; + } + auto res = structuredIndexSort.In(test_times, test_vals.data()); + for (auto i = 0; i < test_times; ++i) { + // std::cout << test_off[i] << " "; + ASSERT_EQ(true, res->test(test_off[i])); + } + + free(p); +} + +TEST(STRUCTUREDINDEXSORT_TEST, test_not_in) { + int range = 10000, n = 1000, *p = nullptr; + gen_rand_data(range, n, p); + milvus::knowhere::StructuredIndexSort structuredIndexSort((size_t)n, p); // Build default + + int test_times = 10; + std::vector test_vals, test_off; + test_vals.reserve(test_times); + test_off.reserve(test_times); + // std::cout << "STRUCTUREDINDEXSORT_TEST test_notin" << std::endl; + for (auto i = 0; i < test_times; ++i) { + auto off = random() % n; + test_vals.emplace_back(*(p + off)); + test_off.emplace_back(off); + // std::cout << off << " "; + } + // std::cout << std::endl; + auto res = structuredIndexSort.NotIn(test_times, test_vals.data()); + // std::cout << "assert values: " << std::endl; + for (auto i = 0; i < test_times; ++i) { + // std::cout << test_off[i] << " "; + ASSERT_EQ(false, res->test(test_off[i])); + } + // std::cout << std::endl; + + free(p); +} + +TEST(STRUCTUREDINDEXSORT_TEST, test_single_border_range) { + int range = 100, n = 1000, *p = nullptr; + gen_rand_data(range, n, p); + milvus::knowhere::StructuredIndexSort structuredIndexSort((size_t)n, p); // Build default + + srand((unsigned int)time(nullptr)); + int val; + // test LT + val = (int)random() % 100; + auto lt_res = structuredIndexSort.Range(val, milvus::knowhere::OperatorType::LT); + for (auto i = 0; i < n; ++i) { + if (*(p + i) < val) + ASSERT_EQ(true, lt_res->test(i)); + else + ASSERT_EQ(false, lt_res->test(i)); + } + // test LE + val = (int)random() % 100; + auto le_res = structuredIndexSort.Range(val, milvus::knowhere::OperatorType::LE); + for (auto i = 0; i < n; ++i) { + if (*(p + i) <= val) + ASSERT_EQ(true, le_res->test(i)); + else + ASSERT_EQ(false, le_res->test(i)); + } + // test GE + val = (int)random() % 100; + auto ge_res = structuredIndexSort.Range(val, milvus::knowhere::OperatorType::GE); + for (auto i = 0; i < n; ++i) { + if (*(p + i) >= val) + ASSERT_EQ(true, ge_res->test(i)); + else + ASSERT_EQ(false, ge_res->test(i)); + } + // test GT + val = (int)random() % 100; + auto gt_res = structuredIndexSort.Range(val, milvus::knowhere::OperatorType::GT); + for (auto i = 0; i < n; ++i) { + if (*(p + i) > val) + ASSERT_EQ(true, gt_res->test(i)); + else + ASSERT_EQ(false, gt_res->test(i)); + } + + free(p); +} + +TEST(STRUCTUREDINDEXSORT_TEST, test_double_border_range) { + int range = 100, n = 1000, *p = nullptr; + gen_rand_data(range, n, p); + milvus::knowhere::StructuredIndexSort structuredIndexSort((size_t)n, p); // Build default + + srand((unsigned int)time(nullptr)); + int lb, ub; + // [] + lb = (int)random() % 100; + ub = (int)random() % 100; + if (lb > ub) + std::swap(lb, ub); + auto res1 = structuredIndexSort.Range(lb, true, ub, true); + for (auto i = 0; i < n; ++i) { + if (*(p + i) >= lb && *(p + i) <= ub) + ASSERT_EQ(true, res1->test(i)); + else + ASSERT_EQ(false, res1->test(i)); + } + // [) + lb = (int)random() % 100; + ub = (int)random() % 100; + if (lb > ub) + std::swap(lb, ub); + auto res2 = structuredIndexSort.Range(lb, true, ub, false); + for (auto i = 0; i < n; ++i) { + if (*(p + i) >= lb && *(p + i) < ub) + ASSERT_EQ(true, res2->test(i)); + else + ASSERT_EQ(false, res2->test(i)); + } + // (] + lb = (int)random() % 100; + ub = (int)random() % 100; + if (lb > ub) + std::swap(lb, ub); + auto res3 = structuredIndexSort.Range(lb, false, ub, true); + for (auto i = 0; i < n; ++i) { + if (*(p + i) > lb && *(p + i) <= ub) + ASSERT_EQ(true, res3->test(i)); + else + ASSERT_EQ(false, res3->test(i)); + } + // () + lb = (int)random() % 100; + ub = (int)random() % 100; + if (lb > ub) + std::swap(lb, ub); + auto res4 = structuredIndexSort.Range(lb, false, ub, false); + for (auto i = 0; i < n; ++i) { + if (*(p + i) > lb && *(p + i) < ub) + ASSERT_EQ(true, res4->test(i)); + else + ASSERT_EQ(false, res4->test(i)); + } + free(p); +} diff --git a/core/src/index/unittest/test_vecindex.cpp b/core/src/index/unittest/test_vecindex.cpp new file mode 100644 index 0000000000..713e9d7988 --- /dev/null +++ b/core/src/index/unittest/test_vecindex.cpp @@ -0,0 +1,119 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include + +#include "knowhere/index/IndexType.h" +#include "knowhere/index/vector_index/VecIndex.h" +#include "knowhere/index/vector_index/VecIndexFactory.h" +#include "knowhere/index/vector_index/helpers/FaissGpuResourceMgr.h" + +#ifdef MILVUS_GPU_VERSION +#include "knowhere/index/vector_index/helpers/Cloner.h" +#endif + +#include "unittest/Helper.h" +#include "unittest/utils.h" + +using ::testing::Combine; +using ::testing::TestWithParam; +using ::testing::Values; + +class VecIndexTest : public DataGen, public Tuple> { + protected: + void + SetUp() override { +#ifdef MILVUS_GPU_VERSION + milvus::knowhere::FaissGpuResourceMgr::GetInstance().InitDevice(DEVICEID, PINMEM, TEMPMEM, RESNUM); +#endif + std::tie(index_type_, index_mode_, parameter_type_) = GetParam(); + Generate(DIM, NB, NQ); + index_ = milvus::knowhere::VecIndexFactory::GetInstance().CreateVecIndex(index_type_, index_mode_); + conf = ParamGenerator::GetInstance().Gen(parameter_type_); + } + + void + TearDown() override { +#ifdef MILVUS_GPU_VERSION + milvus::knowhere::FaissGpuResourceMgr::GetInstance().Free(); +#endif + } + + protected: + milvus::knowhere::IndexType index_type_; + milvus::knowhere::IndexMode index_mode_; + ParameterType parameter_type_; + milvus::knowhere::Config conf; + milvus::knowhere::VecIndexPtr index_ = nullptr; +}; + +INSTANTIATE_TEST_CASE_P( + IVFParameters, IVFTest, + Values( +#ifdef MILVUS_GPU_VERSION + std::make_tuple(milvus::knowhere::IndexType::INDEX_FAISS_IVFFLAT, milvus::knowhere::IndexMode::MODE_GPU), + std::make_tuple(milvus::knowhere::IndexType::INDEX_FAISS_IVFPQ, milvus::knowhere::IndexMode::MODE_GPU), + std::make_tuple(milvus::knowhere::IndexType::INDEX_FAISS_IVFSQ8, milvus::knowhere::IndexMode::MODE_GPU), + std::make_tuple(milvus::knowhere::IndexType::INDEX_FAISS_IVFSQ8H, milvus::knowhere::IndexMode::MODE_GPU), +#endif + std::make_tuple(milvus::knowhere::IndexType::INDEX_FAISS_IVFFLAT, milvus::knowhere::IndexMode::MODE_CPU), + std::make_tuple(milvus::knowhere::IndexType::INDEX_FAISS_IVFPQ, milvus::knowhere::IndexMode::MODE_CPU), + std::make_tuple(milvus::knowhere::IndexType::INDEX_FAISS_IVFSQ8, milvus::knowhere::IndexMode::MODE_CPU), + std::make_tuple(milvus::knowhere::IndexType::INDEX_NSG, milvus::knowhere::IndexMode::MODE_CPU), + std::make_tuple(milvus::knowhere::IndexType::INDEX_HNSW, milvus::knowhere::IndexMode::MODE_CPU), + std::make_tuple(milvus::knowhere::IndexType::INDEX_SPTAG_KDT_RNT, milvus::knowhere::IndexMode::MODE_CPU), + std::make_tuple(milvus::knowhere::IndexType::INDEX_SPTAG_BKT_RNT, milvus::knowhere::IndexMode::MODE_CPU))); + +TEST_P(VecIndexTest, basic) { + assert(!xb.empty()); + KNOWHERE_LOG_DEBUG << "conf: " << conf->dump(); + + index_->BuildAll(base_dataset, conf); + EXPECT_EQ(index_->Dim(), dim); + EXPECT_EQ(index_->Count(), nb); + EXPECT_EQ(index_->index_type(), index_type_); + EXPECT_EQ(index_->index_mode(), index_mode_); + + auto result = index_->Query(query_dataset, conf); + AssertAnns(result, nq, conf[milvus::knowhere::meta::TOPK]); + PrintResult(result, nq, k); +} + +TEST_P(VecIndexTest, serialize) { + index_->BuildAll(base_dataset, conf); + EXPECT_EQ(index_->Dim(), dim); + EXPECT_EQ(index_->Count(), nb); + EXPECT_EQ(index_->index_type(), index_type_); + EXPECT_EQ(index_->index_mode(), index_mode_); + auto result = index_->Query(query_dataset, conf); + AssertAnns(result, nq, conf[milvus::knowhere::meta::TOPK]); + + auto binaryset = index_->Serialize(); + auto new_index = milvus::knowhere::VecIndexFactory::GetInstance().CreateVecIndex(index_type_, index_mode_); + new_index->Load(binaryset); + EXPECT_EQ(index_->Dim(), new_index->Dim()); + EXPECT_EQ(index_->Count(), new_index->Count()); + EXPECT_EQ(index_->index_type(), new_index->index_type()); + EXPECT_EQ(index_->index_mode(), new_index->index_mode()); + auto new_result = new_index_->Query(query_dataset, conf); + AssertAnns(new_result, nq, conf[milvus::knowhere::meta::TOPK]); +} + +// todo +#ifdef MILVUS_GPU_VERSION +TEST_P(VecIndexTest, copytogpu) { + // todo +} + +TEST_P(VecIndexTest, copytocpu) { + // todo +} +#endif diff --git a/core/src/index/unittest/test_wrapper.cpp b/core/src/index/unittest/test_wrapper.cpp new file mode 100644 index 0000000000..022fbaea31 --- /dev/null +++ b/core/src/index/unittest/test_wrapper.cpp @@ -0,0 +1,447 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include "easyloggingpp/easylogging++.h" + +#ifdef MILVUS_GPU_VERSION + +#include "knowhere/index/vector_index/helpers/FaissGpuResourceMgr.h" +#include "wrapper/WrapperException.h" + +#endif + +#include +#include +#include + +#include "knowhere/index/vector_index/helpers/IndexParameter.h" +#include "wrapper/VecIndex.h" +#include "wrapper/utils.h" + +INITIALIZE_EASYLOGGINGPP + +using ::testing::Combine; +using ::testing::TestWithParam; +using ::testing::Values; + +class KnowhereWrapperTest + : public DataGenBase, + public TestWithParam<::std::tuple> { + protected: + void + SetUp() override { +#ifdef MILVUS_GPU_VERSION + knowhere::FaissGpuResourceMgr::GetInstance().InitDevice(DEVICEID, PINMEM, TEMPMEM, RESNUM); +#endif + std::string generator_type; + std::tie(index_type, generator_type, dim, nb, nq, k) = GetParam(); + GenData(dim, nb, nq, xb, xq, ids, k, gt_ids, gt_dis); + + knowhere::Config tempconf{{knowhere::Metric::TYPE, knowhere::Metric::L2}, + {knowhere::meta::ROWS, nb}, + {knowhere::meta::DIM, dim}, + {knowhere::meta::TOPK, k}, + {knowhere::meta::DEVICEID, DEVICEID}}; + + index_ = GetVecIndexFactory(index_type); + conf = ParamGenerator::GetInstance().GenBuild(index_type, tempconf); + searchconf = ParamGenerator::GetInstance().GenSearchConf(index_type, tempconf); + } + + void + TearDown() override { +#ifdef MILVUS_GPU_VERSION + knowhere::FaissGpuResourceMgr::GetInstance().Free(); +#endif + } + + protected: + milvus::engine::IndexType index_type; + milvus::engine::VecIndexPtr index_ = nullptr; + knowhere::Config conf; + knowhere::Config searchconf; +}; + +INSTANTIATE_TEST_CASE_P( + WrapperParam, KnowhereWrapperTest, + Values( +//["Index type", "Generator type", "dim", "nb", "nq", "k", "build config", "search config"] +#ifdef MILVUS_GPU_VERSION + std::make_tuple(milvus::engine::IndexType::FAISS_IVFFLAT_GPU, "Default", DIM, NB, 10, 10), + std::make_tuple(milvus::engine::IndexType::FAISS_IVFFLAT_MIX, "Default", 64, 1000, 10, 10), + std::make_tuple(milvus::engine::IndexType::FAISS_IVFSQ8_GPU, "Default", DIM, NB, 10, 10), + std::make_tuple(milvus::engine::IndexType::FAISS_IVFSQ8_MIX, "Default", DIM, NB, 10, 10), + std::make_tuple(milvus::engine::IndexType::FAISS_IVFPQ_MIX, "Default", 64, 1000, 10, 10), +// std::make_tuple(milvus::engine::IndexType::NSG_MIX, "Default", 128, 250000, 10, 10), +#endif + // std::make_tuple(milvus::engine::IndexType::SPTAG_KDT_RNT_CPU, "Default", 128, 100, 10, 10), + // std::make_tuple(milvus::engine::IndexType::SPTAG_BKT_RNT_CPU, "Default", 126, 100, 10, 10), + std::make_tuple(milvus::engine::IndexType::HNSW, "Default", 64, 10000, 5, 10), + std::make_tuple(milvus::engine::IndexType::FAISS_IDMAP, "Default", 64, 1000, 10, 10), + std::make_tuple(milvus::engine::IndexType::FAISS_IVFFLAT_CPU, "Default", 64, 1000, 10, 10), + std::make_tuple(milvus::engine::IndexType::FAISS_IVFSQ8_CPU, "Default", DIM, NB, 10, 10))); + +#ifdef MILVUS_GPU_VERSION +TEST_P(KnowhereWrapperTest, WRAPPER_EXCEPTION_TEST) { + std::string err_msg = "failed"; + milvus::engine::WrapperException ex(err_msg); + + std::string msg = ex.what(); + EXPECT_EQ(msg, err_msg); +} + +#endif + +TEST_P(KnowhereWrapperTest, BASE_TEST) { + EXPECT_EQ(index_->GetType(), index_type); + + auto elems = nq * k; + std::vector res_ids(elems); + std::vector res_dis(elems); + + index_->BuildAll(nb, xb.data(), ids.data(), conf); + index_->Search(nq, xq.data(), res_dis.data(), res_ids.data(), searchconf); + AssertResult(res_ids, res_dis); + + { + index_->GetDeviceId(); + + fiu_init(0); + fiu_enable("VecIndexImpl.BuildAll.throw_knowhere_exception", 1, nullptr, 0); + fiu_enable("BFIndex.BuildAll.throw_knowhere_exception", 1, nullptr, 0); + fiu_enable("IVFMixIndex.BuildAll.throw_knowhere_exception", 1, nullptr, 0); + index_->BuildAll(nb, xb.data(), ids.data(), conf); + fiu_disable("IVFMixIndex.BuildAll.throw_knowhere_exception"); + fiu_disable("BFIndex.BuildAll.throw_knowhere_exception"); + fiu_disable("VecIndexImpl.BuildAll.throw_knowhere_exception"); + + fiu_enable("VecIndexImpl.BuildAll.throw_std_exception", 1, nullptr, 0); + fiu_enable("BFIndex.BuildAll.throw_std_exception", 1, nullptr, 0); + fiu_enable("IVFMixIndex.BuildAll.throw_std_exception", 1, nullptr, 0); + index_->BuildAll(nb, xb.data(), ids.data(), conf); + fiu_disable("IVFMixIndex.BuildAll.throw_std_exception"); + fiu_disable("BFIndex.BuildAll.throw_std_exception"); + fiu_disable("VecIndexImpl.BuildAll.throw_std_exception"); + + fiu_enable("VecIndexImpl.Add.throw_knowhere_exception", 1, nullptr, 0); + index_->Add(nb, xb.data(), ids.data()); + fiu_disable("VecIndexImpl.Add.throw_knowhere_exception"); + + fiu_enable("VecIndexImpl.Add.throw_std_exception", 1, nullptr, 0); + index_->Add(nb, xb.data(), ids.data()); + fiu_disable("VecIndexImpl.Add.throw_std_exception"); + + fiu_enable("VecIndexImpl.Search.throw_knowhere_exception", 1, nullptr, 0); + index_->Search(nq, xq.data(), res_dis.data(), res_ids.data(), searchconf); + fiu_disable("VecIndexImpl.Search.throw_knowhere_exception"); + + fiu_enable("VecIndexImpl.Search.throw_std_exception", 1, nullptr, 0); + index_->Search(nq, xq.data(), res_dis.data(), res_ids.data(), searchconf); + fiu_disable("VecIndexImpl.Search.throw_std_exception"); + } +} + +#ifdef MILVUS_GPU_VERSION +TEST_P(KnowhereWrapperTest, TO_GPU_TEST) { + if (index_type == milvus::engine::IndexType::HNSW) { + return; + } + EXPECT_EQ(index_->GetType(), index_type); + + auto elems = nq * k; + std::vector res_ids(elems); + std::vector res_dis(elems); + + index_->BuildAll(nb, xb.data(), ids.data(), conf); + index_->Search(nq, xq.data(), res_dis.data(), res_ids.data(), searchconf); + AssertResult(res_ids, res_dis); + + { + auto dev_idx = index_->CopyToGpu(DEVICEID); + for (int i = 0; i < 10; ++i) { + dev_idx->Search(nq, xq.data(), res_dis.data(), res_ids.data(), searchconf); + } + AssertResult(res_ids, res_dis); + } + + { + std::string file_location = "/tmp/knowhere_gpu_file"; + write_index(index_, file_location); + auto new_index = milvus::engine::read_index(file_location); + + auto dev_idx = new_index->CopyToGpu(DEVICEID); + for (int i = 0; i < 10; ++i) { + dev_idx->Search(nq, xq.data(), res_dis.data(), res_ids.data(), searchconf); + } + AssertResult(res_ids, res_dis); + } +} + +#endif + +TEST_P(KnowhereWrapperTest, SERIALIZE_TEST) { + std::cout << "type: " << static_cast(index_type) << std::endl; + EXPECT_EQ(index_->GetType(), index_type); + + auto elems = nq * k; + std::vector res_ids(elems); + std::vector res_dis(elems); + index_->BuildAll(nb, xb.data(), ids.data(), conf); + index_->Search(nq, xq.data(), res_dis.data(), res_ids.data(), searchconf); + AssertResult(res_ids, res_dis); + + { + auto binary = index_->Serialize(); + auto type = index_->GetType(); + auto new_index = GetVecIndexFactory(type); + new_index->Load(binary); + EXPECT_EQ(new_index->Dimension(), index_->Dimension()); + EXPECT_EQ(new_index->Count(), index_->Count()); + + std::vector res_ids(elems); + std::vector res_dis(elems); + new_index->Search(nq, xq.data(), res_dis.data(), res_ids.data(), searchconf); + AssertResult(res_ids, res_dis); + } + + { + std::string file_location = "/tmp/knowhere"; + write_index(index_, file_location); + auto new_index = milvus::engine::read_index(file_location); + EXPECT_EQ(new_index->GetType(), ConvertToCpuIndexType(index_type)); + EXPECT_EQ(new_index->Dimension(), index_->Dimension()); + EXPECT_EQ(new_index->Count(), index_->Count()); + + std::vector res_ids(elems); + std::vector res_dis(elems); + new_index->Search(nq, xq.data(), res_dis.data(), res_ids.data(), searchconf); + AssertResult(res_ids, res_dis); + } + + { + std::string file_location = "/tmp/knowhere_gpu_file"; + fiu_init(0); + fiu_enable("VecIndex.write_index.throw_knowhere_exception", 1, nullptr, 0); + auto s = write_index(index_, file_location); + ASSERT_FALSE(s.ok()); + fiu_disable("VecIndex.write_index.throw_knowhere_exception"); + + fiu_enable("VecIndex.write_index.throw_std_exception", 1, nullptr, 0); + s = write_index(index_, file_location); + ASSERT_FALSE(s.ok()); + fiu_disable("VecIndex.write_index.throw_std_exception"); + + fiu_enable("VecIndex.write_index.throw_no_space_exception", 1, nullptr, 0); + s = write_index(index_, file_location); + ASSERT_FALSE(s.ok()); + fiu_disable("VecIndex.write_index.throw_no_space_exception"); + } +} + +// #include "wrapper/ConfAdapter.h" + +// TEST(whatever, test_config) { +// milvus::engine::TempMetaConf conf; +// conf.nprobe = 16; +// conf.dim = 128; +// auto nsg_conf = std::make_shared(); +// nsg_conf->Match(conf); +// nsg_conf->MatchSearch(conf, milvus::engine::IndexType::NSG_MIX); + +// auto pq_conf = std::make_shared(); +// pq_conf->Match(conf); +// pq_conf->MatchSearch(conf, milvus::engine::IndexType::FAISS_IVFPQ_MIX); + +// auto kdt_conf = std::make_shared(); +// kdt_conf->Match(conf); +// kdt_conf->MatchSearch(conf, milvus::engine::IndexType::SPTAG_KDT_RNT_CPU); + +// auto bkt_conf = std::make_shared(); +// bkt_conf->Match(conf); +// bkt_conf->MatchSearch(conf, milvus::engine::IndexType::SPTAG_BKT_RNT_CPU); + +// auto config_mgr = milvus::engine::AdapterMgr::GetInstance(); +// try { +// config_mgr.GetAdapter(milvus::engine::IndexType::INVALID); +// } catch (std::exception& e) { +// std::cout << "catch an expected exception" << std::endl; +// } + +// conf.size = 1000000.0; +// conf.nlist = 10; +// auto ivf_conf = std::make_shared(); +// ivf_conf->Match(conf); +// conf.nprobe = -1; +// ivf_conf->MatchSearch(conf, milvus::engine::IndexType::FAISS_IVFFLAT_GPU); +// conf.nprobe = 4096; +// ivf_conf->MatchSearch(conf, milvus::engine::IndexType::FAISS_IVFPQ_GPU); + +// auto ivf_pq_conf = std::make_shared(); +// conf.metric_type = knowhere::METRICTYPE::IP; +// try { +// ivf_pq_conf->Match(conf); +// } catch (std::exception& e) { +// std::cout << "catch an expected exception" << std::endl; +// } + +// conf.metric_type = knowhere::METRICTYPE::L2; +// fiu_init(0); +// fiu_enable("IVFPQConfAdapter.Match.empty_resset", 1, nullptr, 0); +// try { +// ivf_pq_conf->Match(conf); +// } catch (std::exception& e) { +// std::cout << "catch an expected exception" << std::endl; +// } +// fiu_disable("IVFPQConfAdapter.Match.empty_resset"); + +// conf.nprobe = -1; +// try { +// ivf_pq_conf->MatchSearch(conf, milvus::engine::IndexType::FAISS_IVFPQ_GPU); +// } catch (std::exception& e) { +// std::cout << "catch an expected exception" << std::endl; +// } +// } + +#include "wrapper/VecImpl.h" + +TEST(BFIndex, test_bf_index_fail) { + auto bf_ptr = std::make_shared(nullptr); + auto float_vec = bf_ptr->GetRawVectors(); + ASSERT_EQ(float_vec, nullptr); + milvus::engine::Config config; + + fiu_init(0); + fiu_enable("BFIndex.Build.throw_knowhere_exception", 1, nullptr, 0); + auto err_code = bf_ptr->Build(config); + ASSERT_EQ(err_code, milvus::KNOWHERE_UNEXPECTED_ERROR); + fiu_disable("BFIndex.Build.throw_knowhere_exception"); + + fiu_enable("BFIndex.Build.throw_std_exception", 1, nullptr, 0); + err_code = bf_ptr->Build(config); + ASSERT_EQ(err_code, milvus::KNOWHERE_ERROR); + fiu_disable("BFIndex.Build.throw_std_exception"); +} + +// #include "knowhere/index/vector_index/IndexIDMAP.h" +// #include "src/wrapper/VecImpl.h" +// #include "src/index/unittest/utils.h" +// The two case below prove NSG is concern with data distribution +// Further work: 1. Use right basedata and pass it by milvus +// a. batch size is 100000 [Pass] +// b. transfer all at once [Pass] +// 2. Use SIFT1M in test and check time cost [] +// TEST_P(KnowhereWrapperTest, nsgwithidmap) { +// auto idmap = GetVecIndexFactory(milvus::engine::IndexType::FAISS_IDMAP); +// auto ori_xb = xb; +// auto ori_ids = ids; +// std::vector temp_xb; +// std::vector temp_ids; +// nb = 50000; +// for (int i = 0; i < 20; ++i) { +// GenData(dim, nb, nq, xb, xq, ids, k, gt_ids, gt_dis); +// assert(xb.size() == nb*dim); +// //#define IDMAP +// #ifdef IDMAP +// temp_xb.insert(temp_xb.end(), xb.data(), xb.data() + nb*dim); +// temp_ids.insert(temp_ids.end(), ori_ids.data()+nb*i, ori_ids.data() + nb*(i+1)); +// if (i == 0) { +// idmap->BuildAll(nb, temp_xb.data(), temp_ids.data(), conf); +// } else { +// idmap->Add(nb, temp_xb.data(), temp_ids.data()); +// } +// temp_xb.clear(); +// temp_ids.clear(); +// #else +// temp_xb.insert(temp_xb.end(), xb.data(), xb.data() + nb*dim); +// temp_ids.insert(temp_ids.end(), ori_ids.data()+nb*i, ori_ids.data() + nb*(i+1)); +// #endif +// } + +// #ifdef IDMAP +// auto idmap_idx = std::dynamic_pointer_cast(idmap); +// auto x = idmap_idx->Count(); +// index_->BuildAll(idmap_idx->Count(), idmap_idx->GetRawVectors(), idmap_idx->GetRawIds(), conf); +// #else +// assert(temp_xb.size() == 1000000*128); +// index_->BuildAll(1000000, temp_xb.data(), ori_ids.data(), conf); +// #endif +// } + +// TEST_P(KnowhereWrapperTest, nsgwithsidmap) { +// auto idmap = GetVecIndexFactory(milvus::engine::IndexType::FAISS_IDMAP); +// auto ori_xb = xb; +// std::vector temp_xb; +// std::vector temp_ids; +// nb = 50000; +// for (int i = 0; i < 20; ++i) { +// #define IDMAP +// #ifdef IDMAP +// temp_xb.insert(temp_xb.end(), ori_xb.data()+nb*dim*i, ori_xb.data() + nb*dim*(i+1)); +// temp_ids.insert(temp_ids.end(), ids.data()+nb*i, ids.data() + nb*(i+1)); +// if (i == 0) { +// idmap->BuildAll(nb, temp_xb.data(), temp_ids.data(), conf); +// } else { +// idmap->Add(nb, temp_xb.data(), temp_ids.data()); +// } +// temp_xb.clear(); +// temp_ids.clear(); +// #else +// temp_xb.insert(temp_xb.end(), ori_xb.data()+nb*dim*i, ori_xb.data() + nb*dim*(i+1)); +// temp_ids.insert(temp_ids.end(), ids.data()+nb*i, ids.data() + nb*(i+1)); +// #endif +// } + +// #ifdef IDMAP +// auto idmap_idx = std::dynamic_pointer_cast(idmap); +// auto x = idmap_idx->Count(); +// index_->BuildAll(idmap_idx->Count(), idmap_idx->GetRawVectors(), idmap_idx->GetRawIds(), conf); +// #else +// index_->BuildAll(1000000, temp_xb.data(), temp_ids.data(), conf); +// #endif + +// // The code use to store raw base data +// FileIOWriter writer("/tmp/newraw"); +// ori_xb.shrink_to_fit(); +// std::cout << "size" << ori_xb.size(); +// writer(static_cast(ori_xb.data()), ori_xb.size()* sizeof(float)); +// std::cout << "Finish!" << std::endl; +// } + +// void load_data(char* filename, float*& data, unsigned& num, +// unsigned& dim) { // load data with sift10K pattern +// std::ifstream in(filename, std::ios::binary); +// if (!in.is_open()) { +// std::cout << "open file error" << std::endl; +// exit(-1); +// } +// in.read((char*)&dim, 4); +// in.seekg(0, std::ios::end); +// std::ios::pos_type ss = in.tellg(); +// size_t fsize = (size_t)ss; +// num = (unsigned)(fsize / (dim + 1) / 4); +// data = new float[(size_t)num * (size_t)dim]; + +// in.seekg(0, std::ios::beg); +// for (size_t i = 0; i < num; i++) { +// in.seekg(4, std::ios::cur); +// in.read((char*)(data + i * dim), dim * 4); +// } +// in.close(); +// } + +// TEST_P(KnowhereWrapperTest, Sift1M) { +// float* data = nullptr; +// unsigned points_num, dim; +// load_data("/mnt/112d53a6-5592-4360-a33b-7fd789456fce/workspace/Data/sift/sift_base.fvecs", data, points_num, +// dim); std::cout << points_num << " " << dim << std::endl; + +// index_->BuildAll(points_num, data, ids.data(), conf); +// } diff --git a/core/src/index/unittest/utils.cpp b/core/src/index/unittest/utils.cpp new file mode 100644 index 0000000000..df9c2a0821 --- /dev/null +++ b/core/src/index/unittest/utils.cpp @@ -0,0 +1,303 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include "unittest/utils.h" +#include "knowhere/index/vector_index/adapter/VectorAdapter.h" + +#include +#include +#include +#include +#include + +INITIALIZE_EASYLOGGINGPP + +void +InitLog() { + el::Configurations defaultConf; + defaultConf.setToDefault(); + defaultConf.set(el::Level::Debug, el::ConfigurationType::Format, "[%thread-%datetime-%level]: %msg (%fbase:%line)"); + el::Loggers::reconfigureLogger("default", defaultConf); +} + +void +DataGen::Init_with_default(const bool is_binary) { + Generate(dim, nb, nq, is_binary); +} + +void +DataGen::Generate(const int dim, const int nb, const int nq, const bool is_binary) { + this->dim = dim; + this->nb = nb; + this->nq = nq; + + if (!is_binary) { + GenAll(dim, nb, xb, ids, xids, nq, xq); + assert(xb.size() == (size_t)dim * nb); + assert(xq.size() == (size_t)dim * nq); + + base_dataset = milvus::knowhere::GenDatasetWithIds(nb, dim, xb.data(), ids.data()); + query_dataset = milvus::knowhere::GenDataset(nq, dim, xq.data()); + } else { + int64_t dim_x = dim / 8; + GenAll(dim_x, nb, xb_bin, ids, xids, nq, xq_bin); + assert(xb_bin.size() == (size_t)dim_x * nb); + assert(xq_bin.size() == (size_t)dim_x * nq); + + base_dataset = milvus::knowhere::GenDatasetWithIds(nb, dim, xb_bin.data(), ids.data()); + query_dataset = milvus::knowhere::GenDataset(nq, dim, xq_bin.data()); + } + + id_dataset = milvus::knowhere::GenDatasetWithIds(nq, dim, nullptr, ids.data()); + xid_dataset = milvus::knowhere::GenDatasetWithIds(nq, dim, nullptr, xids.data()); +} + +void +GenAll(const int64_t dim, const int64_t nb, std::vector& xb, std::vector& ids, + std::vector& xids, const int64_t nq, std::vector& xq) { + xb.resize(nb * dim); + xq.resize(nq * dim); + ids.resize(nb); + xids.resize(1); + GenBase(dim, nb, xb.data(), ids.data(), nq, xq.data(), xids.data(), false); +} + +void +GenAll(const int64_t dim, const int64_t nb, std::vector& xb, std::vector& ids, + std::vector& xids, const int64_t nq, std::vector& xq) { + xb.resize(nb * dim); + xq.resize(nq * dim); + ids.resize(nb); + xids.resize(1); + GenBase(dim, nb, xb.data(), ids.data(), nq, xq.data(), xids.data(), true); +} + +void +GenBase(const int64_t dim, const int64_t nb, const void* xb, int64_t* ids, const int64_t nq, const void* xq, + int64_t* xids, bool is_binary) { + if (!is_binary) { + float* xb_f = (float*)xb; + float* xq_f = (float*)xq; + for (auto i = 0; i < nb; ++i) { + for (auto j = 0; j < dim; ++j) { + xb_f[i * dim + j] = drand48(); + } + xb_f[dim * i] += i / 1000.; + ids[i] = i; + } + for (int64_t i = 0; i < nq * dim; ++i) { + xq_f[i] = xb_f[i]; + } + } else { + uint8_t* xb_u = (uint8_t*)xb; + uint8_t* xq_u = (uint8_t*)xq; + for (auto i = 0; i < nb; ++i) { + for (auto j = 0; j < dim; ++j) { + xb_u[i * dim + j] = (uint8_t)lrand48(); + } + xb_u[dim * i] += i / 1000.; + ids[i] = i; + } + for (int64_t i = 0; i < nq * dim; ++i) { + xq_u[i] = xb_u[i]; + } + } + xids[0] = 3; // pseudo random +} + +FileIOReader::FileIOReader(const std::string& fname) { + name = fname; + fs = std::fstream(name, std::ios::in | std::ios::binary); +} + +FileIOReader::~FileIOReader() { + fs.close(); +} + +size_t +FileIOReader::operator()(void* ptr, size_t size) { + fs.read(reinterpret_cast(ptr), size); + return size; +} + +FileIOWriter::FileIOWriter(const std::string& fname) { + name = fname; + fs = std::fstream(name, std::ios::out | std::ios::binary); +} + +FileIOWriter::~FileIOWriter() { + fs.close(); +} + +size_t +FileIOWriter::operator()(void* ptr, size_t size) { + fs.write(reinterpret_cast(ptr), size); + return size; +} + +void +AssertAnns(const milvus::knowhere::DatasetPtr& result, const int nq, const int k, const CheckMode check_mode) { + auto ids = result->Get(milvus::knowhere::meta::IDS); + for (auto i = 0; i < nq; i++) { + switch (check_mode) { + case CheckMode::CHECK_EQUAL: + ASSERT_EQ(i, *((int64_t*)(ids) + i * k)); + break; + case CheckMode::CHECK_NOT_EQUAL: + ASSERT_NE(i, *((int64_t*)(ids) + i * k)); + break; + default: + ASSERT_TRUE(false); + break; + } + } +} + +#if 0 +void +AssertVec(const milvus::knowhere::DatasetPtr& result, const milvus::knowhere::DatasetPtr& base_dataset, + const milvus::knowhere::DatasetPtr& id_dataset, const int n, const int dim, const CheckMode check_mode) { + float* base = (float*)base_dataset->Get(milvus::knowhere::meta::TENSOR); + auto ids = id_dataset->Get(milvus::knowhere::meta::IDS); + auto x = result->Get(milvus::knowhere::meta::TENSOR); + for (auto i = 0; i < n; i++) { + auto id = ids[i]; + for (auto j = 0; j < dim; j++) { + switch (check_mode) { + case CheckMode::CHECK_EQUAL: { + ASSERT_EQ(*(base + id * dim + j), *(x + i * dim + j)); + break; + } + case CheckMode::CHECK_NOT_EQUAL: { + ASSERT_NE(*(base + id * dim + j), *(x + i * dim + j)); + break; + } + case CheckMode::CHECK_APPROXIMATE_EQUAL: { + float a = *(base + id * dim + j); + float b = *(x + i * dim + j); + ASSERT_TRUE((std::fabs(a - b) / std::fabs(a)) < 0.1); + break; + } + default: + ASSERT_TRUE(false); + break; + } + } + } +} + +void +AssertBinVec(const milvus::knowhere::DatasetPtr& result, const milvus::knowhere::DatasetPtr& base_dataset, + const milvus::knowhere::DatasetPtr& id_dataset, const int n, const int dim, const CheckMode check_mode) { + auto base = (uint8_t*)base_dataset->Get(milvus::knowhere::meta::TENSOR); + auto ids = id_dataset->Get(milvus::knowhere::meta::IDS); + auto x = result->Get(milvus::knowhere::meta::TENSOR); + for (auto i = 0; i < 1; i++) { + auto id = ids[i]; + for (auto j = 0; j < dim; j++) { + ASSERT_EQ(*(base + id * dim + j), *(x + i * dim + j)); + } + } +} +#endif + +void +PrintResult(const milvus::knowhere::DatasetPtr& result, const int& nq, const int& k) { + auto ids = result->Get(milvus::knowhere::meta::IDS); + auto dist = result->Get(milvus::knowhere::meta::DISTANCE); + + std::stringstream ss_id; + std::stringstream ss_dist; + for (auto i = 0; i < nq; i++) { + for (auto j = 0; j < k; ++j) { + // ss_id << *(ids->data()->GetValues(1, i * k + j)) << " "; + // ss_dist << *(dists->data()->GetValues(1, i * k + j)) << " "; + ss_id << *((int64_t*)(ids) + i * k + j) << " "; + ss_dist << *((float*)(dist) + i * k + j) << " "; + } + ss_id << std::endl; + ss_dist << std::endl; + } + std::cout << "id\n" << ss_id.str() << std::endl; + std::cout << "dist\n" << ss_dist.str() << std::endl; +} + +// not used +#if 0 +void +Load_nns_graph(std::vector>& final_graph, const char* filename) { + std::vector> knng; + + std::ifstream in(filename, std::ios::binary); + unsigned k; + in.read((char*)&k, sizeof(unsigned)); + in.seekg(0, std::ios::end); + std::ios::pos_type ss = in.tellg(); + size_t fsize = (size_t)ss; + size_t num = (size_t)(fsize / (k + 1) / 4); + in.seekg(0, std::ios::beg); + + knng.resize(num); + knng.reserve(num); + int64_t kk = (k + 3) / 4 * 4; + for (size_t i = 0; i < num; i++) { + in.seekg(4, std::ios::cur); + knng[i].resize(k); + knng[i].reserve(kk); + in.read((char*)knng[i].data(), k * sizeof(unsigned)); + } + in.close(); + + final_graph.resize(knng.size()); + for (int i = 0; i < knng.size(); ++i) { + final_graph[i].resize(knng[i].size()); + for (int j = 0; j < knng[i].size(); ++j) { + final_graph[i][j] = knng[i][j]; + } + } +} + +float* +fvecs_read(const char* fname, size_t* d_out, size_t* n_out) { + FILE* f = fopen(fname, "r"); + if (!f) { + fprintf(stderr, "could not open %s\n", fname); + perror(""); + abort(); + } + int d; + fread(&d, 1, sizeof(int), f); + assert((d > 0 && d < 1000000) || !"unreasonable dimension"); + fseek(f, 0, SEEK_SET); + struct stat st; + fstat(fileno(f), &st); + size_t sz = st.st_size; + assert(sz % ((d + 1) * 4) == 0 || !"weird file size"); + size_t n = sz / ((d + 1) * 4); + + *d_out = d; + *n_out = n; + float* x = new float[n * (d + 1)]; + size_t nr = fread(x, sizeof(float), n * (d + 1), f); + assert(nr == n * (d + 1) || !"could not read whole file"); + + // shift array to remove row headers + for (size_t i = 0; i < n; i++) memmove(x + i * d, x + 1 + i * (d + 1), d * sizeof(*x)); + + fclose(f); + return x; +} + +int* // not very clean, but works as long as sizeof(int) == sizeof(float) +ivecs_read(const char* fname, size_t* d_out, size_t* n_out) { + return (int*)fvecs_read(fname, d_out, n_out); +} +#endif diff --git a/core/src/index/unittest/utils.h b/core/src/index/unittest/utils.h new file mode 100644 index 0000000000..2d10762550 --- /dev/null +++ b/core/src/index/unittest/utils.h @@ -0,0 +1,113 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include + +#include "knowhere/common/Dataset.h" +#include "knowhere/common/Log.h" + +class DataGen { + protected: + void + Init_with_default(const bool is_binary = false); + + void + Generate(const int dim, const int nb, const int nq, const bool is_binary = false); + + protected: + int nb = 10000; + int nq = 10; + int dim = 64; + int k = 10; + std::vector xb; + std::vector xq; + std::vector xb_bin; + std::vector xq_bin; + std::vector ids; + std::vector xids; + milvus::knowhere::DatasetPtr base_dataset = nullptr; + milvus::knowhere::DatasetPtr query_dataset = nullptr; + milvus::knowhere::DatasetPtr id_dataset = nullptr; + milvus::knowhere::DatasetPtr xid_dataset = nullptr; +}; + +extern void +GenAll(const int64_t dim, const int64_t nb, std::vector& xb, std::vector& ids, + std::vector& xids, const int64_t nq, std::vector& xq); + +extern void +GenAll(const int64_t dim, const int64_t nb, std::vector& xb, std::vector& ids, + std::vector& xids, const int64_t nq, std::vector& xq); + +extern void +GenBase(const int64_t dim, const int64_t nb, const void* xb, int64_t* ids, const int64_t nq, const void* xq, + int64_t* xids, const bool is_binary); + +extern void +InitLog(); + +enum class CheckMode { + CHECK_EQUAL = 0, + CHECK_NOT_EQUAL = 1, + CHECK_APPROXIMATE_EQUAL = 2, +}; + +void +AssertAnns(const milvus::knowhere::DatasetPtr& result, const int nq, const int k, + const CheckMode check_mode = CheckMode::CHECK_EQUAL); + +void +AssertVec(const milvus::knowhere::DatasetPtr& result, const milvus::knowhere::DatasetPtr& base_dataset, + const milvus::knowhere::DatasetPtr& id_dataset, const int n, const int dim, + const CheckMode check_mode = CheckMode::CHECK_EQUAL); + +void +AssertBinVec(const milvus::knowhere::DatasetPtr& result, const milvus::knowhere::DatasetPtr& base_dataset, + const milvus::knowhere::DatasetPtr& id_dataset, const int n, const int dim, + const CheckMode check_mode = CheckMode::CHECK_EQUAL); + +void +PrintResult(const milvus::knowhere::DatasetPtr& result, const int& nq, const int& k); + +struct FileIOWriter { + std::fstream fs; + std::string name; + + explicit FileIOWriter(const std::string& fname); + ~FileIOWriter(); + size_t + operator()(void* ptr, size_t size); +}; + +struct FileIOReader { + std::fstream fs; + std::string name; + + explicit FileIOReader(const std::string& fname); + ~FileIOReader(); + size_t + operator()(void* ptr, size_t size); +}; + +void +Load_nns_graph(std::vector>& final_graph_, const char* filename); + +float* +fvecs_read(const char* fname, size_t* d_out, size_t* n_out); + +int* +ivecs_read(const char* fname, size_t* d_out, size_t* n_out); diff --git a/core/src/log/CMakeLists.txt b/core/src/log/CMakeLists.txt new file mode 100644 index 0000000000..83ef94aaef --- /dev/null +++ b/core/src/log/CMakeLists.txt @@ -0,0 +1,21 @@ +#------------------------------------------------------------------------------- +# Copyright (C) 2019-2020 Zilliz. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License +# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing permissions and limitations under the License. +#------------------------------------------------------------------------------- +set(LOG_FILES ${MILVUS_ENGINE_SRC}/log/Log.cpp + ${MILVUS_ENGINE_SRC}/log/Log.h + ${MILVUS_ENGINE_SRC}/log/LogMgr.cpp + ${MILVUS_ENGINE_SRC}/log/LogMgr.h + ${MILVUS_THIRDPARTY_SRC}/easyloggingpp/easylogging++.cc + ${MILVUS_THIRDPARTY_SRC}/easyloggingpp/easylogging++.h + ) + +add_library(log STATIC ${LOG_FILES}) diff --git a/core/src/log/Log.cpp b/core/src/log/Log.cpp new file mode 100644 index 0000000000..f6b47ba4fb --- /dev/null +++ b/core/src/log/Log.cpp @@ -0,0 +1,55 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include "log/Log.h" +INITIALIZE_EASYLOGGINGPP + +#include +#include +#include +#include + +namespace milvus { + +std::string +LogOut(const char* pattern, ...) { + size_t len = strnlen(pattern, 1024) + 256; + auto str_p = std::make_unique(len); + memset(str_p.get(), 0, len); + + va_list vl; + va_start(vl, pattern); + vsnprintf(str_p.get(), len, pattern, vl); // NOLINT + va_end(vl); + + return std::string(str_p.get()); +} + +void +SetThreadName(const std::string& name) { + // Note: the name cannot exceed 16 bytes + pthread_setname_np(pthread_self(), name.c_str()); +} + +std::string +GetThreadName() { + std::string thread_name = "unamed"; + char name[16]; + size_t len = 16; + auto err = pthread_getname_np(pthread_self(), name, len); + if (not err) { + thread_name = name; + } + + return thread_name; +} + +} // namespace milvus diff --git a/core/src/log/Log.h b/core/src/log/Log.h new file mode 100644 index 0000000000..a1ab5d945b --- /dev/null +++ b/core/src/log/Log.h @@ -0,0 +1,135 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#pragma once + +#include + +#include "easyloggingpp/easylogging++.h" + +namespace milvus { + +/* + * Please use LOG_MODULE_LEVEL_C macro in member function of class + * and LOG_MODULE_LEVEL_ macro in other functions. + */ + +///////////////////////////////////////////////////////////////////////////////////////////////// +#define SERVER_MODULE_NAME "SERVER" +#define SERVER_MODULE_CLASS_FUNCTION \ + LogOut("[%s][%s::%s][%s] ", SERVER_MODULE_NAME, (typeid(*this).name()), __FUNCTION__, GetThreadName().c_str()) +#define SERVER_MODULE_FUNCTION LogOut("[%s][%s][%s] ", SERVER_MODULE_NAME, __FUNCTION__, GetThreadName().c_str()) + +#define LOG_SERVER_TRACE_C LOG(TRACE) << SERVER_MODULE_CLASS_FUNCTION +#define LOG_SERVER_DEBUG_C LOG(DEBUG) << SERVER_MODULE_CLASS_FUNCTION +#define LOG_SERVER_INFO_C LOG(INFO) << SERVER_MODULE_CLASS_FUNCTION +#define LOG_SERVER_WARNING_C LOG(WARNING) << SERVER_MODULE_CLASS_FUNCTION +#define LOG_SERVER_ERROR_C LOG(ERROR) << SERVER_MODULE_CLASS_FUNCTION +#define LOG_SERVER_FATAL_C LOG(FATAL) << SERVER_MODULE_CLASS_FUNCTION + +#define LOG_SERVER_TRACE_ LOG(TRACE) << SERVER_MODULE_FUNCTION +#define LOG_SERVER_DEBUG_ LOG(DEBUG) << SERVER_MODULE_FUNCTION +#define LOG_SERVER_INFO_ LOG(INFO) << SERVER_MODULE_FUNCTION +#define LOG_SERVER_WARNING_ LOG(WARNING) << SERVER_MODULE_FUNCTION +#define LOG_SERVER_ERROR_ LOG(ERROR) << SERVER_MODULE_FUNCTION +#define LOG_SERVER_FATAL_ LOG(FATAL) << SERVER_MODULE_FUNCTION + +///////////////////////////////////////////////////////////////////////////////////////////////// +#define ENGINE_MODULE_NAME "ENGINE" +#define ENGINE_MODULE_CLASS_FUNCTION \ + LogOut("[%s][%s::%s][%s] ", ENGINE_MODULE_NAME, (typeid(*this).name()), __FUNCTION__, GetThreadName().c_str()) +#define ENGINE_MODULE_FUNCTION LogOut("[%s][%s][%s] ", ENGINE_MODULE_NAME, __FUNCTION__, GetThreadName().c_str()) + +#define LOG_ENGINE_TRACE_C LOG(TRACE) << ENGINE_MODULE_CLASS_FUNCTION +#define LOG_ENGINE_DEBUG_C LOG(DEBUG) << ENGINE_MODULE_CLASS_FUNCTION +#define LOG_ENGINE_INFO_C LOG(INFO) << ENGINE_MODULE_CLASS_FUNCTION +#define LOG_ENGINE_WARNING_C LOG(WARNING) << ENGINE_MODULE_CLASS_FUNCTION +#define LOG_ENGINE_ERROR_C LOG(ERROR) << ENGINE_MODULE_CLASS_FUNCTION +#define LOG_ENGINE_FATAL_C LOG(FATAL) << ENGINE_MODULE_CLASS_FUNCTION + +#define LOG_ENGINE_TRACE_ LOG(TRACE) << ENGINE_MODULE_FUNCTION +#define LOG_ENGINE_DEBUG_ LOG(DEBUG) << ENGINE_MODULE_FUNCTION +#define LOG_ENGINE_INFO_ LOG(INFO) << ENGINE_MODULE_FUNCTION +#define LOG_ENGINE_WARNING_ LOG(WARNING) << ENGINE_MODULE_FUNCTION +#define LOG_ENGINE_ERROR_ LOG(ERROR) << ENGINE_MODULE_FUNCTION +#define LOG_ENGINE_FATAL_ LOG(FATAL) << ENGINE_MODULE_FUNCTION + +///////////////////////////////////////////////////////////////////////////////////////////////// +#define WRAPPER_MODULE_NAME "WRAPPER" +#define WRAPPER_MODULE_CLASS_FUNCTION \ + LogOut("[%s][%s::%s][%s] ", WRAPPER_MODULE_NAME, (typeid(*this).name()), __FUNCTION__, GetThreadName().c_str()) +#define WRAPPER_MODULE_FUNCTION LogOut("[%s][%s][%s] ", WRAPPER_MODULE_NAME, __FUNCTION__, GetThreadName().c_str()) + +#define LOG_WRAPPER_TRACE_C LOG(TRACE) << WRAPPER_MODULE_CLASS_FUNCTION +#define LOG_WRAPPER_DEBUG_C LOG(DEBUG) << WRAPPER_MODULE_CLASS_FUNCTION +#define LOG_WRAPPER_INFO_C LOG(INFO) << WRAPPER_MODULE_CLASS_FUNCTION +#define LOG_WRAPPER_WARNING_C LOG(WARNING) << WRAPPER_MODULE_CLASS_FUNCTION +#define LOG_WRAPPER_ERROR_C LOG(ERROR) << WRAPPER_MODULE_CLASS_FUNCTION +#define LOG_WRAPPER_FATAL_C LOG(FATAL) << WRAPPER_MODULE_CLASS_FUNCTION + +#define LOG_WRAPPER_TRACE_ LOG(TRACE) << WRAPPER_MODULE_FUNCTION +#define LOG_WRAPPER_DEBUG_ LOG(DEBUG) << WRAPPER_MODULE_FUNCTION +#define LOG_WRAPPER_INFO_ LOG(INFO) << WRAPPER_MODULE_FUNCTION +#define LOG_WRAPPER_WARNING_ LOG(WARNING) << WRAPPER_MODULE_FUNCTION +#define LOG_WRAPPER_ERROR_ LOG(ERROR) << WRAPPER_MODULE_FUNCTION +#define LOG_WRAPPER_FATAL_ LOG(FATAL) << WRAPPER_MODULE_FUNCTION + +///////////////////////////////////////////////////////////////////////////////////////////////// +#define STORAGE_MODULE_NAME "STORAGE" +#define STORAGE_MODULE_CLASS_FUNCTION \ + LogOut("[%s][%s::%s][%s] ", STORAGE_MODULE_NAME, (typeid(*this).name()), __FUNCTION__, GetThreadName().c_str()) +#define STORAGE_MODULE_FUNCTION LogOut("[%s][%s][%s] ", STORAGE_MODULE_NAME, __FUNCTION__, GetThreadName().c_str()) + +#define LOG_STORAGE_TRACE_C LOG(TRACE) << STORAGE_MODULE_CLASS_FUNCTION +#define LOG_STORAGE_DEBUG_C LOG(DEBUG) << STORAGE_MODULE_CLASS_FUNCTION +#define LOG_STORAGE_INFO_C LOG(INFO) << STORAGE_MODULE_CLASS_FUNCTION +#define LOG_STORAGE_WARNING_C LOG(WARNING) << STORAGE_MODULE_CLASS_FUNCTION +#define LOG_STORAGE_ERROR_C LOG(ERROR) << STORAGE_MODULE_CLASS_FUNCTION +#define LOG_STORAGE_FATAL_C LOG(FATAL) << STORAGE_MODULE_CLASS_FUNCTION + +#define LOG_STORAGE_TRACE_ LOG(TRACE) << STORAGE_MODULE_FUNCTION +#define LOG_STORAGE_DEBUG_ LOG(DEBUG) << STORAGE_MODULE_FUNCTION +#define LOG_STORAGE_INFO_ LOG(INFO) << STORAGE_MODULE_FUNCTION +#define LOG_STORAGE_WARNING_ LOG(WARNING) << STORAGE_MODULE_FUNCTION +#define LOG_STORAGE_ERROR_ LOG(ERROR) << STORAGE_MODULE_FUNCTION +#define LOG_STORAGE_FATAL_ LOG(FATAL) << STORAGE_MODULE_FUNCTION + +///////////////////////////////////////////////////////////////////////////////////////////////// +#define WAL_MODULE_NAME "WAL" +#define WAL_MODULE_CLASS_FUNCTION \ + LogOut("[%s][%s::%s][%s] ", WAL_MODULE_NAME, (typeid(*this).name()), __FUNCTION__, GetThreadName().c_str()) +#define WAL_MODULE_FUNCTION LogOut("[%s][%s][%s] ", WAL_MODULE_NAME, __FUNCTION__, GetThreadName().c_str()) + +#define LOG_WAL_TRACE_C LOG(TRACE) << WAL_MODULE_CLASS_FUNCTION +#define LOG_WAL_DEBUG_C LOG(DEBUG) << WAL_MODULE_CLASS_FUNCTION +#define LOG_WAL_INFO_C LOG(INFO) << WAL_MODULE_CLASS_FUNCTION +#define LOG_WAL_WARNING_C LOG(WARNING) << WAL_MODULE_CLASS_FUNCTION +#define LOG_WAL_ERROR_C LOG(ERROR) << WAL_MODULE_CLASS_FUNCTION +#define LOG_WAL_FATAL_C LOG(FATAL) << WAL_MODULE_CLASS_FUNCTION + +#define LOG_WAL_TRACE_ LOG(TRACE) << WAL_MODULE_FUNCTION +#define LOG_WAL_DEBUG_ LOG(DEBUG) << WAL_MODULE_FUNCTION +#define LOG_WAL_INFO_ LOG(INFO) << WAL_MODULE_FUNCTION +#define LOG_WAL_WARNING_ LOG(WARNING) << WAL_MODULE_FUNCTION +#define LOG_WAL_ERROR_ LOG(ERROR) << WAL_MODULE_FUNCTION +#define LOG_WAL_FATAL_ LOG(FATAL) << WAL_MODULE_FUNCTION + +///////////////////////////////////////////////////////////////////////////////////////////////////// +std::string +LogOut(const char* pattern, ...); + +void +SetThreadName(const std::string& name); + +std::string +GetThreadName(); + +} // namespace milvus diff --git a/core/src/log/LogMgr.cpp b/core/src/log/LogMgr.cpp new file mode 100644 index 0000000000..d3bb2809f0 --- /dev/null +++ b/core/src/log/LogMgr.cpp @@ -0,0 +1,245 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include +#include +#include +#include + +#include + +#include "config/ServerConfig.h" +#include "log/LogMgr.h" +#include "utils/Status.h" + +namespace milvus { + +namespace { +static int global_idx = 0; +static int debug_idx = 0; +static int warning_idx = 0; +static int trace_idx = 0; +static int error_idx = 0; +static int fatal_idx = 0; +static int64_t logs_delete_exceeds = 1; +static bool enable_log_delete = false; + +/* module constant */ +const int64_t CONFIG_LOGS_MAX_LOG_FILE_SIZE_MIN = 536870912; /* 512 MB */ +const int64_t CONFIG_LOGS_MAX_LOG_FILE_SIZE_MAX = 4294967296; /* 4 GB */ +const int64_t CONFIG_LOGS_LOG_ROTATE_NUM_MIN = 0; +const int64_t CONFIG_LOGS_LOG_ROTATE_NUM_MAX = 1024; +} // namespace + +// TODO(yzb) : change the easylogging library to get the log level from parameter rather than filename +void +RolloutHandler(const char* filename, std::size_t size, el::Level level) { + char* dirc = strdup(filename); + char* basec = strdup(filename); + char* dir = dirname(dirc); + char* base = basename(basec); + + std::string s(base); + std::string list[] = {"\\", " ", "\'", "\"", "*", "\?", "{", "}", ";", "<", + ">", "|", "^", "&", "$", "#", "!", "`", "~"}; + std::string::size_type position; + for (auto substr : list) { + position = 0; + while ((position = s.find_first_of(substr, position)) != std::string::npos) { + s.insert(position, "\\"); + position += 2; + } + } + std::string m(std::string(dir) + "/" + s); + s = m; + try { + switch (level) { + case el::Level::Debug: { + s.append("." + std::to_string(++debug_idx)); + rename(m.c_str(), s.c_str()); + if (enable_log_delete && debug_idx - logs_delete_exceeds > 0) { + std::string to_delete = m + "." + std::to_string(debug_idx - logs_delete_exceeds); + // std::cout << "remote " << to_delete << std::endl; + boost::filesystem::remove(to_delete); + } + break; + } + case el::Level::Warning: { + s.append("." + std::to_string(++warning_idx)); + rename(m.c_str(), s.c_str()); + if (enable_log_delete && warning_idx - logs_delete_exceeds > 0) { + std::string to_delete = m + "." + std::to_string(warning_idx - logs_delete_exceeds); + boost::filesystem::remove(to_delete); + } + break; + } + case el::Level::Trace: { + s.append("." + std::to_string(++trace_idx)); + rename(m.c_str(), s.c_str()); + if (enable_log_delete && trace_idx - logs_delete_exceeds > 0) { + std::string to_delete = m + "." + std::to_string(trace_idx - logs_delete_exceeds); + boost::filesystem::remove(to_delete); + } + break; + } + case el::Level::Error: { + s.append("." + std::to_string(++error_idx)); + rename(m.c_str(), s.c_str()); + if (enable_log_delete && error_idx - logs_delete_exceeds > 0) { + std::string to_delete = m + "." + std::to_string(error_idx - logs_delete_exceeds); + boost::filesystem::remove(to_delete); + } + break; + } + case el::Level::Fatal: { + s.append("." + std::to_string(++fatal_idx)); + rename(m.c_str(), s.c_str()); + if (enable_log_delete && fatal_idx - logs_delete_exceeds > 0) { + std::string to_delete = m + "." + std::to_string(fatal_idx - logs_delete_exceeds); + boost::filesystem::remove(to_delete); + } + break; + } + default: { + s.append("." + std::to_string(++global_idx)); + rename(m.c_str(), s.c_str()); + if (enable_log_delete && global_idx - logs_delete_exceeds > 0) { + std::string to_delete = m + "." + std::to_string(global_idx - logs_delete_exceeds); + boost::filesystem::remove(to_delete); + } + break; + } + } + } catch (const std::exception& exc) { + std::cerr << exc.what() << ". Exception throws from RolloutHandler." << std::endl; + } +} + +Status +LogMgr::InitLog(bool trace_enable, const std::string& level, const std::string& logs_path, int64_t max_log_file_size, + int64_t delete_exceeds) { + std::unordered_map level_to_int{ + {"debug", 5}, {"info", 4}, {"warning", 3}, {"error", 2}, {"fatal", 1}, + }; + + bool debug_enable = false; + bool info_enable = false; + bool warning_enable = false; + bool error_enable = false; + bool fatal_enable = false; + + switch (level_to_int[level]) { + case 5: + debug_enable = true; + case 4: + info_enable = true; + case 3: + warning_enable = true; + case 2: + error_enable = true; + case 1: + fatal_enable = true; + break; + default: + return Status(SERVER_UNEXPECTED_ERROR, "invalid log level"); + } + + el::Configurations defaultConf; + defaultConf.setToDefault(); + defaultConf.setGlobally(el::ConfigurationType::Format, "[%datetime][%level]%msg"); + defaultConf.setGlobally(el::ConfigurationType::ToFile, "true"); + defaultConf.setGlobally(el::ConfigurationType::ToStandardOutput, "false"); + defaultConf.setGlobally(el::ConfigurationType::SubsecondPrecision, "3"); + defaultConf.setGlobally(el::ConfigurationType::PerformanceTracking, "false"); + + std::string logs_reg_path = logs_path.rfind('/') == logs_path.length() - 1 ? logs_path : logs_path + "/"; + std::string global_log_path = logs_reg_path + "milvus-%datetime{%y-%M-%d-%H:%m}-global.log"; + defaultConf.set(el::Level::Global, el::ConfigurationType::Filename, global_log_path.c_str()); + defaultConf.set(el::Level::Global, el::ConfigurationType::Enabled, "true"); + + std::string info_log_path = logs_reg_path + "milvus-%datetime{%y-%M-%d-%H:%m}-info.log"; + defaultConf.set(el::Level::Info, el::ConfigurationType::Filename, info_log_path.c_str()); + if (info_enable) { + defaultConf.set(el::Level::Info, el::ConfigurationType::Enabled, "true"); + } else { + defaultConf.set(el::Level::Info, el::ConfigurationType::Enabled, "false"); + } + + std::string debug_log_path = logs_reg_path + "milvus-%datetime{%y-%M-%d-%H:%m}-debug.log"; + defaultConf.set(el::Level::Debug, el::ConfigurationType::Filename, debug_log_path.c_str()); + if (debug_enable) { + defaultConf.set(el::Level::Debug, el::ConfigurationType::Enabled, "true"); + } else { + defaultConf.set(el::Level::Debug, el::ConfigurationType::Enabled, "false"); + } + + std::string warning_log_path = logs_reg_path + "milvus-%datetime{%y-%M-%d-%H:%m}-warning.log"; + defaultConf.set(el::Level::Warning, el::ConfigurationType::Filename, warning_log_path.c_str()); + if (warning_enable) { + defaultConf.set(el::Level::Warning, el::ConfigurationType::Enabled, "true"); + } else { + defaultConf.set(el::Level::Warning, el::ConfigurationType::Enabled, "false"); + } + + std::string trace_log_path = logs_reg_path + "milvus-%datetime{%y-%M-%d-%H:%m}-trace.log"; + defaultConf.set(el::Level::Trace, el::ConfigurationType::Filename, trace_log_path.c_str()); + if (trace_enable) { + defaultConf.set(el::Level::Trace, el::ConfigurationType::Enabled, "true"); + } else { + defaultConf.set(el::Level::Trace, el::ConfigurationType::Enabled, "false"); + } + + std::string error_log_path = logs_reg_path + "milvus-%datetime{%y-%M-%d-%H:%m}-error.log"; + defaultConf.set(el::Level::Error, el::ConfigurationType::Filename, error_log_path.c_str()); + if (error_enable) { + defaultConf.set(el::Level::Error, el::ConfigurationType::Enabled, "true"); + } else { + defaultConf.set(el::Level::Error, el::ConfigurationType::Enabled, "false"); + } + + std::string fatal_log_path = logs_reg_path + "milvus-%datetime{%y-%M-%d-%H:%m}-fatal.log"; + defaultConf.set(el::Level::Fatal, el::ConfigurationType::Filename, fatal_log_path.c_str()); + if (fatal_enable) { + defaultConf.set(el::Level::Fatal, el::ConfigurationType::Enabled, "true"); + } else { + defaultConf.set(el::Level::Fatal, el::ConfigurationType::Enabled, "false"); + } + + if (max_log_file_size < CONFIG_LOGS_MAX_LOG_FILE_SIZE_MIN || + max_log_file_size > CONFIG_LOGS_MAX_LOG_FILE_SIZE_MAX) { + return Status(SERVER_UNEXPECTED_ERROR, "max_log_file_size must in range[" + + std::to_string(CONFIG_LOGS_MAX_LOG_FILE_SIZE_MIN) + ", " + + std::to_string(CONFIG_LOGS_MAX_LOG_FILE_SIZE_MAX) + "], now is " + + std::to_string(max_log_file_size)); + } + defaultConf.setGlobally(el::ConfigurationType::MaxLogFileSize, std::to_string(max_log_file_size)); + el::Loggers::addFlag(el::LoggingFlag::StrictLogFileSizeCheck); + el::Helpers::installPreRollOutCallback(RolloutHandler); + el::Loggers::addFlag(el::LoggingFlag::DisableApplicationAbortOnFatalLog); + + // set delete_exceeds = 0 means disable throw away log file even they reach certain limit. + if (delete_exceeds != 0) { + if (delete_exceeds < CONFIG_LOGS_LOG_ROTATE_NUM_MIN || delete_exceeds > CONFIG_LOGS_LOG_ROTATE_NUM_MAX) { + return Status(SERVER_UNEXPECTED_ERROR, "delete_exceeds must in range[" + + std::to_string(CONFIG_LOGS_LOG_ROTATE_NUM_MIN) + ", " + + std::to_string(CONFIG_LOGS_LOG_ROTATE_NUM_MAX) + "], now is " + + std::to_string(delete_exceeds)); + } + enable_log_delete = true; + logs_delete_exceeds = delete_exceeds; + } + + el::Loggers::reconfigureLogger("default", defaultConf); + + return Status::OK(); +} + +} // namespace milvus diff --git a/core/src/log/LogMgr.h b/core/src/log/LogMgr.h new file mode 100644 index 0000000000..acd283ce03 --- /dev/null +++ b/core/src/log/LogMgr.h @@ -0,0 +1,29 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#pragma once + +#include "easyloggingpp/easylogging++.h" +#include "utils/Status.h" + +#include +#include + +namespace milvus { + +class LogMgr { + public: + static Status + InitLog(bool trace_enable, const std::string& level, const std::string& logs_path, int64_t max_log_file_size, + int64_t delete_exceeds); +}; + +} // namespace milvus diff --git a/core/src/main.cpp b/core/src/main.cpp new file mode 100644 index 0000000000..a66794cc87 --- /dev/null +++ b/core/src/main.cpp @@ -0,0 +1,51 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include +#include +#include +#include +#include + +#include "config/ConfigMgr.h" +#include "easyloggingpp/easylogging++.h" +#include "utils/SignalHandler.h" +#include "utils/Status.h" + +INITIALIZE_EASYLOGGINGPP + + +void +print_help(const std::string& app_name) { + std::cout << std::endl << "Usage: " << app_name << " [OPTIONS]" << std::endl; + std::cout << R"( + Options: + -h --help Print this help. + -c --conf_file filename Read configuration from the file. + -d --daemon Daemonize this application. + -p --pid_file filename PID file used by daemonized app. +)" << std::endl; +} + +void +print_banner() { + std::cout << std::endl; + std::cout << " __ _________ _ ____ ______ " << std::endl; + std::cout << " / |/ / _/ /| | / / / / / __/ " << std::endl; + std::cout << " / /|_/ // // /_| |/ / /_/ /\\ \\ " << std::endl; + std::cout << " /_/ /_/___/____/___/\\____/___/ " << std::endl; + std::cout << std::endl; +} + +int +main(int argc, char* argv[]) { + print_banner(); +} diff --git a/core/src/utils/BlockingQueue.h b/core/src/utils/BlockingQueue.h new file mode 100644 index 0000000000..1b489f53e6 --- /dev/null +++ b/core/src/utils/BlockingQueue.h @@ -0,0 +1,95 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include + +namespace milvus { + +template +class BlockingQueue { + public: + BlockingQueue() : mtx(), full_(), empty_() { + } + + virtual ~BlockingQueue() { + } + + BlockingQueue(const BlockingQueue& rhs) = delete; + + BlockingQueue& + operator=(const BlockingQueue& rhs) = delete; + + void + Put(const T& task) { + std::unique_lock lock(mtx); + full_.wait(lock, [this] { return (queue_.size() < capacity_); }); + queue_.push(task); + empty_.notify_all(); + } + + T + Take() { + std::unique_lock lock(mtx); + empty_.wait(lock, [this] { return !queue_.empty(); }); + T front(queue_.front()); + queue_.pop(); + full_.notify_all(); + return front; + } + + T + Front() { + std::unique_lock lock(mtx); + empty_.wait(lock, [this] { return !queue_.empty(); }); + T front(queue_.front()); + return front; + } + + T + Back() { + std::unique_lock lock(mtx); + empty_.wait(lock, [this] { return !queue_.empty(); }); + T back(queue_.back()); + return back; + } + + size_t + Size() const { + std::lock_guard lock(mtx); + return queue_.size(); + } + + bool + Empty() const { + std::unique_lock lock(mtx); + return queue_.empty(); + } + + void + SetCapacity(const size_t capacity) { + capacity_ = (capacity > 0 ? capacity : capacity_); + } + + protected: + mutable std::mutex mtx; + std::condition_variable full_; + std::condition_variable empty_; + std::queue queue_; + size_t capacity_ = 32; +}; + +} // namespace milvus diff --git a/core/src/utils/CMakeLists.txt b/core/src/utils/CMakeLists.txt index c932d67b1f..b99f5a59bb 100644 --- a/core/src/utils/CMakeLists.txt +++ b/core/src/utils/CMakeLists.txt @@ -1,5 +1,23 @@ +#------------------------------------------------------------------------------- +# Copyright (C) 2019-2020 Zilliz. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License +# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing permissions and limitations under the License. +#------------------------------------------------------------------------------- -set(UTILS_FILES Status.cpp) -add_library(milvus_utils - ${UTILS_FILES} -) \ No newline at end of file +# aux_source_directory( ${MILVUS_ENGINE_SRC}/utils UTILS_FILES ) +set(UTILS_FILES +Status.cpp +) + +add_library( utils STATIC ${UTILS_FILES} ) + +target_link_libraries(utils + libboost_filesystem.a + libboost_system.a) diff --git a/core/src/utils/CommonUtil.cpp b/core/src/utils/CommonUtil.cpp new file mode 100644 index 0000000000..296a2d3078 --- /dev/null +++ b/core/src/utils/CommonUtil.cpp @@ -0,0 +1,216 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include "utils/CommonUtil.h" +#include "utils/Log.h" + +#include +#include + +#include +#include +#include +#include +#include +#include +#include + + +namespace milvus { + +namespace fs = boost::filesystem; + +bool +CommonUtil::IsDirectoryExist(const std::string& path) { + DIR* dp = nullptr; + if ((dp = opendir(path.c_str())) == nullptr) { + return false; + } + + closedir(dp); + return true; +} + +Status +CommonUtil::CreateDirectory(const std::string& path) { + if (path.empty()) { + return Status::OK(); + } + + struct stat directory_stat; + int status = stat(path.c_str(), &directory_stat); + if (status == 0) { + return Status::OK(); // already exist + } + + fs::path fs_path(path); + fs::path parent_path = fs_path.parent_path(); + Status err_status = CreateDirectory(parent_path.string()); + if (!err_status.ok()) { + return err_status; + } + + status = stat(path.c_str(), &directory_stat); + if (status == 0) { + return Status::OK(); // already exist + } + + int makeOK = mkdir(path.c_str(), S_IRWXU | S_IRGRP | S_IROTH); + if (makeOK != 0) { + return Status(SERVER_UNEXPECTED_ERROR, "failed to create directory: " + path); + } + + return Status::OK(); +} + +namespace { +void +RemoveDirectory(const std::string& path) { + DIR* dir = nullptr; + const int32_t buf_size = 256; + char file_name[buf_size]; + + std::string folder_name = path + "/%s"; + if ((dir = opendir(path.c_str())) != nullptr) { + struct dirent* dmsg; + while ((dmsg = readdir(dir)) != nullptr) { + if (strcmp(dmsg->d_name, ".") != 0 && strcmp(dmsg->d_name, "..") != 0) { + snprintf(file_name, buf_size, folder_name.c_str(), dmsg->d_name); + std::string tmp = file_name; + if (tmp.find(".") == std::string::npos) { + RemoveDirectory(file_name); + } + remove(file_name); + } + } + } + + if (dir != nullptr) { + closedir(dir); + } + remove(path.c_str()); +} +} // namespace + +Status +CommonUtil::DeleteDirectory(const std::string& path) { + if (path.empty()) { + return Status::OK(); + } + + struct stat directory_stat; + int statOK = stat(path.c_str(), &directory_stat); + if (statOK != 0) { + return Status::OK(); + } + + RemoveDirectory(path); + return Status::OK(); +} + +bool +CommonUtil::IsFileExist(const std::string& path) { + return (access(path.c_str(), F_OK) == 0); +} + +uint64_t +CommonUtil::GetFileSize(const std::string& path) { + struct stat file_info; + if (stat(path.c_str(), &file_info) < 0) { + return 0; + } + + return static_cast(file_info.st_size); +} + +std::string +CommonUtil::GetFileName(std::string filename) { + int pos = filename.find_last_of('/'); + return filename.substr(pos + 1); +} + +std::string +CommonUtil::GetExePath() { + const int64_t buf_len = 1024; + char buf[buf_len]; + int64_t cnt = readlink("/proc/self/exe", buf, buf_len); + if (cnt < 0 || cnt >= buf_len) { + return ""; + } + + buf[cnt] = '\0'; + + std::string exe_path = buf; + if (exe_path.rfind('/') != exe_path.length() - 1) { + std::string sub_str = exe_path.substr(0, exe_path.rfind('/')); + return sub_str + "/"; + } + return exe_path; +} + +bool +CommonUtil::TimeStrToTime(const std::string& time_str, time_t& time_integer, tm& time_struct, + const std::string& format) { + time_integer = 0; + memset(&time_struct, 0, sizeof(tm)); + + int ret = sscanf(time_str.c_str(), format.c_str(), &(time_struct.tm_year), &(time_struct.tm_mon), + &(time_struct.tm_mday), &(time_struct.tm_hour), &(time_struct.tm_min), &(time_struct.tm_sec)); + if (ret <= 0) { + return false; + } + + time_struct.tm_year -= 1900; + time_struct.tm_mon--; + time_integer = mktime(&time_struct); + + return true; +} + +void +CommonUtil::ConvertTime(time_t time_integer, tm& time_struct) { + localtime_r(&time_integer, &time_struct); +} + +void +CommonUtil::ConvertTime(tm time_struct, time_t& time_integer) { + time_integer = mktime(&time_struct); +} + +uint64_t +CommonUtil::RandomUINT64(){ + std::random_device rd; //Get a random seed from the OS entropy device, or whatever + std::mt19937_64 eng(rd()); //Use the 64-bit Mersenne Twister 19937 generator + //and seed it with entropy. + //Define the distribution, by default it goes from 0 to MAX(unsigned long long) + //or what have you. + std::uniform_int_distribution distr; + return distr(eng); + +} + +#ifdef ENABLE_CPU_PROFILING +std::string +CommonUtil::GetCurrentTimeStr() { + time_t tt; + time(&tt); + tt = tt + 8 * 60; + tm t; + gmtime_r(&tt, &t); + + std::string str = std::to_string(t.tm_year + 1900) + "_" + std::to_string(t.tm_mon + 1) + "_" + + std::to_string(t.tm_mday) + "_" + std::to_string(t.tm_hour) + "_" + std::to_string(t.tm_min) + + "_" + std::to_string(t.tm_sec); + return str; +} +#endif + +} // namespace milvus diff --git a/core/src/utils/CommonUtil.h b/core/src/utils/CommonUtil.h new file mode 100644 index 0000000000..ded802379d --- /dev/null +++ b/core/src/utils/CommonUtil.h @@ -0,0 +1,57 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#pragma once + +#include "utils/Status.h" + +#include +#include + +namespace milvus { + +class CommonUtil { + public: + static bool + IsFileExist(const std::string& path); + static uint64_t + GetFileSize(const std::string& path); + static bool + IsDirectoryExist(const std::string& path); + static Status + CreateDirectory(const std::string& path); + static Status + DeleteDirectory(const std::string& path); + + static std::string + GetFileName(std::string filename); + static std::string + GetExePath(); + + static bool + TimeStrToTime(const std::string& time_str, time_t& time_integer, tm& time_struct, + const std::string& format = "%d-%d-%d %d:%d:%d"); + + static void + ConvertTime(time_t time_integer, tm& time_struct); + static void + ConvertTime(tm time_struct, time_t& time_integer); + + static uint64_t + RandomUINT64(); + +#ifdef ENABLE_CPU_PROFILING + static std::string + GetCurrentTimeStr(); +#endif +}; + +} // namespace milvus diff --git a/core/src/utils/ConfigUtils.cpp b/core/src/utils/ConfigUtils.cpp new file mode 100644 index 0000000000..da566b7971 --- /dev/null +++ b/core/src/utils/ConfigUtils.cpp @@ -0,0 +1,309 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include "utils/ConfigUtils.h" +#include "utils/Log.h" +#include "utils/StringHelpFunctions.h" + +#include +#include +#include +#ifdef MILVUS_GPU_VERSION +#include +#endif +#include +#include +#include +#include +#include +#include +#include + +#if defined(__x86_64__) +#define THREAD_MULTIPLY_CPU 1 +#elif defined(__powerpc64__) +#define THREAD_MULTIPLY_CPU 4 +#else +#define THREAD_MULTIPLY_CPU 1 +#endif + +namespace milvus { +namespace server { + +std::unordered_map BYTE_UNITS = { + {"b", 1}, + {"k", 1024}, + {"m", 1024 * 1024}, + {"g", 1024 * 1024 * 1024}, +}; + +bool +is_number(const std::string& s) { + return !s.empty() && std::find_if(s.begin(), s.end(), [](unsigned char c) { return !std::isdigit(c); }) == s.end(); +} + +bool +is_alpha(const std::string& s) { + return !s.empty() && std::find_if(s.begin(), s.end(), [](unsigned char c) { return !std::isalpha(c); }) == s.end(); +} + +std::string +str_tolower(std::string s) { + std::transform(s.begin(), s.end(), s.begin(), [](unsigned char c) { return std::tolower(c); }); + return s; +} + +int64_t +parse_bytes(const std::string& str, std::string& err) { + try { + std::string s = str; + if (is_number(s)) + return std::stoll(s); + if (s.length() == 0) + return 0; + + auto last_two = s.substr(s.length() - 2, 2); + auto last_one = s.substr(s.length() - 1); + if (is_alpha(last_two) && is_alpha(last_one)) + if (last_one == "b" or last_one == "B") + s = s.substr(0, s.length() - 1); + auto& units = BYTE_UNITS; + auto suffix = str_tolower(s.substr(s.length() - 1)); + + std::string digits_part; + if (is_number(suffix)) { + digits_part = s; + suffix = 'b'; + } else { + digits_part = s.substr(0, s.length() - 1); + } + + if (units.find(suffix) != units.end() or is_number(suffix)) { + auto digits = std::stoll(digits_part); + return digits * units[suffix]; + } else { + std::stringstream ss; + ss << "The specified value for memory (" << str << ") should specify the units." + << "The postfix should be one of the `b` `k` `m` `g` characters"; + err = ss.str(); + } + } catch (...) { + err = "Unknown error happened on parse bytes."; + } + return 0; +} + +bool +GetSystemMemInfo(int64_t& total_mem, int64_t& free_mem) { + struct sysinfo info; + int ret = sysinfo(&info); + total_mem = info.totalram; + free_mem = info.freeram; + + return ret == 0; // succeed 0, failed -1 +} + +bool +GetSystemAvailableThreads(int64_t& thread_count) { + // threadCnt = std::thread::hardware_concurrency(); + thread_count = sysconf(_SC_NPROCESSORS_CONF); + thread_count *= THREAD_MULTIPLY_CPU; + + if (thread_count == 0) { + thread_count = 8; + } + + return true; +} + +Status +ValidateGpuIndex(int32_t gpu_index) { +#ifdef MILVUS_GPU_VERSION + int num_devices = 0; + auto cuda_err = cudaGetDeviceCount(&num_devices); + + if (cuda_err != cudaSuccess) { + std::string msg = "Failed to get gpu card number, cuda error:" + std::to_string(cuda_err); + LOG_SERVER_ERROR_ << msg; + return Status(SERVER_UNEXPECTED_ERROR, msg); + } + + if (gpu_index >= num_devices) { + std::string msg = "Invalid gpu index: " + std::to_string(gpu_index); + LOG_SERVER_ERROR_ << msg; + return Status(SERVER_INVALID_ARGUMENT, msg); + } +#endif + + return Status::OK(); +} + +#ifdef MILVUS_GPU_VERSION +Status +GetGpuMemory(int32_t gpu_index, int64_t& memory) { + + cudaDeviceProp deviceProp; + auto cuda_err = cudaGetDeviceProperties(&deviceProp, gpu_index); + if (cuda_err) { + std::string msg = "Failed to get gpu properties for gpu" + std::to_string(gpu_index) + + " , cuda error:" + std::to_string(cuda_err); + LOG_SERVER_ERROR_ << msg; + return Status(SERVER_UNEXPECTED_ERROR, msg); + } + + memory = deviceProp.totalGlobalMem; + return Status::OK(); +} +#endif + +Status +ValidateIpAddress(const std::string& ip_address) { + struct in_addr address; + + int result = inet_pton(AF_INET, ip_address.c_str(), &address); + + switch (result) { + case 1: + return Status::OK(); + case 0: { + std::string msg = "Invalid IP address: " + ip_address; + LOG_SERVER_ERROR_ << msg; + return Status(SERVER_INVALID_ARGUMENT, msg); + } + default: { + std::string msg = "IP address conversion error: " + ip_address; + LOG_SERVER_ERROR_ << msg; + return Status(SERVER_UNEXPECTED_ERROR, msg); + } + } +} + +Status +ValidateStringIsNumber(const std::string& str) { + if (str.empty() || !std::all_of(str.begin(), str.end(), ::isdigit)) { + return Status(SERVER_INVALID_ARGUMENT, "Invalid number"); + } + try { + int64_t value = std::stol(str); + if (value < 0) { + return Status(SERVER_INVALID_ARGUMENT, "Negative number"); + } + } catch (...) { + return Status(SERVER_INVALID_ARGUMENT, "Invalid number"); + } + return Status::OK(); +} + +Status +ValidateStringIsBool(const std::string& str) { + std::string s = str; + std::transform(s.begin(), s.end(), s.begin(), ::tolower); + if (s == "true" || s == "on" || s == "yes" || s == "1" || s == "false" || s == "off" || s == "no" || s == "0" || + s.empty()) { + return Status::OK(); + } + return Status(SERVER_INVALID_ARGUMENT, "Invalid boolean: " + str); +} + +Status +ValidateStringIsFloat(const std::string& str) { + try { + float val = std::stof(str); + if (val < 0.0) { + return Status(SERVER_INVALID_ARGUMENT, "Negative float: " + str); + } + } catch (...) { + return Status(SERVER_INVALID_ARGUMENT, "Invalid float: " + str); + } + return Status::OK(); +} + +Status +ValidateDbURI(const std::string& uri) { + std::string dialectRegex = "(.*)"; + std::string usernameRegex = "(.*)"; + std::string passwordRegex = "(.*)"; + std::string hostRegex = "(.*)"; + std::string portRegex = "(.*)"; + std::string dbNameRegex = "(.*)"; + std::string uriRegexStr = dialectRegex + "\\:\\/\\/" + usernameRegex + "\\:" + passwordRegex + "\\@" + hostRegex + + "\\:" + portRegex + "\\/" + dbNameRegex; + std::regex uriRegex(uriRegexStr); + std::smatch pieces_match; + + bool okay = true; + + if (std::regex_match(uri, pieces_match, uriRegex)) { + std::string dialect = pieces_match[1].str(); + std::transform(dialect.begin(), dialect.end(), dialect.begin(), ::tolower); + if (dialect.find("mysql") == std::string::npos && dialect.find("sqlite") == std::string::npos && + dialect.find("mock") == std::string::npos) { + LOG_SERVER_ERROR_ << "Invalid dialect in URI: dialect = " << dialect; + okay = false; + } + + /* + * Could be DNS, skip checking + * + std::string host = pieces_match[4].str(); + if (!host.empty() && host != "localhost") { + if (ValidateIpAddress(host) != SERVER_SUCCESS) { + LOG_SERVER_ERROR_ << "Invalid host ip address in uri = " << host; + okay = false; + } + } + */ + + std::string port = pieces_match[5].str(); + if (!port.empty()) { + auto status = ValidateStringIsNumber(port); + if (!status.ok()) { + LOG_SERVER_ERROR_ << "Invalid port in uri = " << port; + okay = false; + } + } + } else { + LOG_SERVER_ERROR_ << "Wrong URI format: URI = " << uri; + okay = false; + } + + return (okay ? Status::OK() : Status(SERVER_INVALID_ARGUMENT, "Invalid db backend uri")); +} + +Status +ValidateStoragePath(const std::string& path) { + // Validate storage path if is valid, only correct absolute path will be validated pass + // Invalid path only contain character[a-zA-Z], number[0-9], '-', and '_', + // and path must start with '/'. + // examples below are invalid + // '/a//a', '/a--/a', '/-a/a', '/a@#/a', 'aaa/sfs' + std::string path_pattern = "^\\/(\\w+-?\\/?)+$"; + std::regex regex(path_pattern); + + return std::regex_match(path, regex) ? Status::OK() : Status(SERVER_INVALID_ARGUMENT, "Invalid file path"); +} + +Status +ValidateLogLevel(const std::string& level) { + std::set supported_level{"debug", "info", "warning", "error", "fatal"}; + + return supported_level.find(level) != supported_level.end() + ? Status::OK() + : Status(SERVER_INVALID_ARGUMENT, "Log level must be one of debug, info, warning, error and fatal."); +} + +bool +IsNumber(const std::string& s) { + return !s.empty() && std::all_of(s.begin(), s.end(), ::isdigit); +} + +} // namespace server +} // namespace milvus diff --git a/core/src/utils/ConfigUtils.h b/core/src/utils/ConfigUtils.h new file mode 100644 index 0000000000..27d3ce6842 --- /dev/null +++ b/core/src/utils/ConfigUtils.h @@ -0,0 +1,63 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#pragma once + +#include +#include + +#include "utils/Status.h" + +namespace milvus { +namespace server { + +extern int64_t +parse_bytes(const std::string& str, std::string& err); + +extern bool +GetSystemMemInfo(int64_t& total_mem, int64_t& free_mem); + +extern bool +GetSystemAvailableThreads(int64_t& thread_count); + +extern Status +ValidateGpuIndex(int32_t gpu_index); + +#ifdef MILVUS_GPU_VERSION +extern Status +GetGpuMemory(int32_t gpu_index, int64_t& memory); +#endif + +extern Status +ValidateIpAddress(const std::string& ip_address); + +extern Status +ValidateStringIsNumber(const std::string& str); + +extern Status +ValidateStringIsBool(const std::string& str); + +extern Status +ValidateStringIsFloat(const std::string& str); + +extern Status +ValidateDbURI(const std::string& uri); + +extern Status +ValidateStoragePath(const std::string& path); + +extern Status +ValidateLogLevel(const std::string& level); + +extern bool +IsNumber(const std::string& s); +} // namespace server +} // namespace milvus diff --git a/core/src/utils/Exception.h b/core/src/utils/Exception.h new file mode 100644 index 0000000000..abf1d1d01e --- /dev/null +++ b/core/src/utils/Exception.h @@ -0,0 +1,60 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#pragma once + +#include "utils/Error.h" + +#include +#include + +namespace milvus { + +#define THROW_ERROR(err_code, err_msg) \ + LOG_ENGINE_ERROR_ << err_msg; \ + throw Exception(err_code, err_msg); + +class Exception : public std::exception { + public: + Exception(ErrorCode code, const std::string& message) : code_(code), message_(message) { + } + + ErrorCode + code() const noexcept { + return code_; + } + + const char* + what() const noexcept override { + if (message_.empty()) { + return "Default Exception."; + } else { + return message_.c_str(); + } + } + + ~Exception() noexcept override = default; + + protected: + ErrorCode code_; + std::string message_; +}; + +class InvalidArgumentException : public Exception { + public: + InvalidArgumentException() : Exception(SERVER_INVALID_ARGUMENT, "Invalid Argument") { + } + + explicit InvalidArgumentException(const std::string& message) : Exception(SERVER_INVALID_ARGUMENT, message) { + } +}; + +} // namespace milvus diff --git a/core/src/utils/Json.h b/core/src/utils/Json.h new file mode 100644 index 0000000000..03ee2127fb --- /dev/null +++ b/core/src/utils/Json.h @@ -0,0 +1,34 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#pragma once + +#include "nlohmann/json.hpp" + +namespace milvus { + +using json = nlohmann::json; + +#define JSON_NULL_CHECK(json) \ + do { \ + if (json.empty()) { \ + return Status{SERVER_INVALID_ARGUMENT, "Json is null"}; \ + } \ + } while (false) + +#define JSON_OBJECT_CHECK(json) \ + do { \ + if (!json.is_object()) { \ + return Status{SERVER_INVALID_ARGUMENT, "Json is not a json object"}; \ + } \ + } while (false) + +} // namespace milvus diff --git a/core/src/utils/Log.h b/core/src/utils/Log.h new file mode 100644 index 0000000000..f2a25d7de1 --- /dev/null +++ b/core/src/utils/Log.h @@ -0,0 +1,14 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#pragma once + +#include "log/Log.h" diff --git a/core/src/utils/SignalHandler.cpp b/core/src/utils/SignalHandler.cpp new file mode 100644 index 0000000000..69ef16fedc --- /dev/null +++ b/core/src/utils/SignalHandler.cpp @@ -0,0 +1,61 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include "utils/SignalHandler.h" +#include "utils/Log.h" + +#include +#include +#include + +namespace milvus { + +signal_func_ptr signal_routine_func = nullptr; + +void +HandleSignal(int signum) { + int32_t exit_code = 1; /* 0: normal exit; 1: exception */ + switch (signum) { + case SIGINT: + case SIGUSR2: + exit_code = 0; + /* no break */ + default: { + if (exit_code == 0) { + LOG_SERVER_INFO_ << "Server received signal: " << signum; + } else { + LOG_SERVER_INFO_ << "Server received critical signal: " << signum; + PrintStacktrace(); + } + if (signal_routine_func != nullptr) { + (*signal_routine_func)(exit_code); + } + } + } +} + +void +PrintStacktrace() { + const int bt_depth = 128; + void* array[bt_depth]; + int stack_num = backtrace(array, bt_depth); + char** stacktrace = backtrace_symbols(array, stack_num); + + LOG_SERVER_INFO_ << "Call stack:"; + for (int i = 0; i < stack_num; ++i) { + std::string info = stacktrace[i]; + std::cout << "No." << i << ": " << info << std::endl; + LOG_SERVER_INFO_ << info; + } + free(stacktrace); +} + +} // namespace milvus diff --git a/core/src/utils/SignalHandler.h b/core/src/utils/SignalHandler.h new file mode 100644 index 0000000000..5c58886e45 --- /dev/null +++ b/core/src/utils/SignalHandler.h @@ -0,0 +1,28 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#pragma once + +#include + +namespace milvus { + +typedef void (*signal_func_ptr)(int32_t); + +extern signal_func_ptr signal_routine_func; + +extern void +HandleSignal(int signum); + +extern void +PrintStacktrace(); + +} // namespace milvus diff --git a/core/src/utils/Status.cpp b/core/src/utils/Status.cpp index 5cf627ef32..a4f987bdf2 100644 --- a/core/src/utils/Status.cpp +++ b/core/src/utils/Status.cpp @@ -21,7 +21,7 @@ Status::Status(StatusCode code, const std::string& msg) { // 4 bytes store code // 4 bytes store message length // the left bytes store message string - auto length = static_cast(msg.size()); + const uint32_t length = (uint32_t)msg.size(); auto result = new char[length + sizeof(length) + CODE_WIDTH]; std::memcpy(result, &code, CODE_WIDTH); std::memcpy(result + CODE_WIDTH, &length, sizeof(length)); @@ -30,26 +30,29 @@ Status::Status(StatusCode code, const std::string& msg) { state_ = result; } +Status::Status() : state_(nullptr) { +} + Status::~Status() { delete state_; } -Status::Status(const Status& s) { +Status::Status(const Status& s) : state_(nullptr) { CopyFrom(s); } -Status::Status(Status&& s) noexcept { - MoveFrom(s); -} - Status& Status::operator=(const Status& s) { CopyFrom(s); return *this; } +Status::Status(Status&& s) : state_(nullptr) { + MoveFrom(s); +} + Status& -Status::operator=(Status&& s) noexcept { +Status::operator=(Status&& s) { MoveFrom(s); return *this; } diff --git a/core/src/utils/Status.h b/core/src/utils/Status.h index 1a45e7bb73..e67d6ed048 100644 --- a/core/src/utils/Status.h +++ b/core/src/utils/Status.h @@ -31,18 +31,18 @@ using StatusCode = ErrorCode; class Status { public: Status(StatusCode code, const std::string& msg); - Status() = default; - virtual ~Status(); + Status(); + ~Status(); Status(const Status& s); - Status(Status&& s) noexcept; - Status& operator=(const Status& s); + Status(Status&& s); + Status& - operator=(Status&& s) noexcept; + operator=(Status&& s); static Status OK() { diff --git a/core/src/utils/StringHelpFunctions.cpp b/core/src/utils/StringHelpFunctions.cpp new file mode 100644 index 0000000000..49ab32f050 --- /dev/null +++ b/core/src/utils/StringHelpFunctions.cpp @@ -0,0 +1,160 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include "utils/StringHelpFunctions.h" + +#include +#include +#include + +namespace milvus { + +void +StringHelpFunctions::TrimStringBlank(std::string& string) { + if (!string.empty()) { + static std::string s_format(" \n\r\t"); + string.erase(0, string.find_first_not_of(s_format)); + string.erase(string.find_last_not_of(s_format) + 1); + } +} + +void +StringHelpFunctions::TrimStringQuote(std::string& string, const std::string& qoute) { + if (!string.empty()) { + string.erase(0, string.find_first_not_of(qoute)); + string.erase(string.find_last_not_of(qoute) + 1); + } +} + +void +StringHelpFunctions::SplitStringByDelimeter(const std::string& str, const std::string& delimeter, + std::vector& result) { + if (str.empty()) { + return; + } + + size_t prev = 0; + while (true) { + size_t pos = str.find_first_of(delimeter, prev); + if (pos == std::string::npos) { + result.emplace_back(str.substr(prev)); + break; + } else { + result.emplace_back(str.substr(prev, pos - prev)); + prev = pos + 1; + } + } +} + +void +StringHelpFunctions::MergeStringWithDelimeter(const std::vector& strs, const std::string& delimeter, + std::string& result) { + if (strs.empty()) { + result = ""; + return; + } + + result = strs[0]; + for (size_t i = 1; i < strs.size(); i++) { + result = result + delimeter + strs[i]; + } +} + +Status +StringHelpFunctions::SplitStringByQuote(const std::string& str, const std::string& delimeter, const std::string& quote, + std::vector& result) { + if (quote.empty()) { + SplitStringByDelimeter(str, delimeter, result); + return Status::OK(); + } + + size_t last = 0; + size_t index = str.find_first_of(quote, last); + if (index == std::string::npos) { + SplitStringByDelimeter(str, delimeter, result); + return Status::OK(); + } + + std::string process_str = str; + while (index != std::string::npos) { + std::string prefix = process_str.substr(last, index - last); + std::string append_prefix; + if (!prefix.empty()) { + std::vector prefix_split; + SplitStringByDelimeter(prefix, delimeter, prefix_split); + for (size_t i = 0; i < prefix_split.size() - 1; i++) { + result.push_back(prefix_split[i]); + } + append_prefix = prefix_split[prefix_split.size() - 1]; + } + last = index + 1; + std::string postfix = process_str.substr(last); + index = postfix.find_first_of(quote, 0); + + + if (index == std::string::npos) { + return Status(SERVER_UNEXPECTED_ERROR, ""); + } + std::string quoted_text = postfix.substr(0, index); + append_prefix += quoted_text; + + last = index + 1; + index = postfix.find_first_of(delimeter, last); + + + if (index != std::string::npos) { + if (index > last) { + append_prefix += postfix.substr(last, index - last); + } + } else { + append_prefix += postfix.substr(last); + } + result.emplace_back(append_prefix); + + if (last == postfix.length()) { + return Status::OK(); + } + + process_str = postfix.substr(index + 1); + last = 0; + index = process_str.find_first_of(quote, last); + } + + if (!process_str.empty()) { + SplitStringByDelimeter(process_str, delimeter, result); + } + + return Status::OK(); +} + +bool +StringHelpFunctions::IsRegexMatch(const std::string& target_str, const std::string& pattern_str) { + // if target_str equals pattern_str, return true + if (target_str == pattern_str) { + return true; + } + + // regex match + std::regex pattern(pattern_str); + std::smatch results; + return std::regex_match(target_str, results, pattern); +} + +Status +StringHelpFunctions::ConvertToBoolean(const std::string& str, bool& value) { + std::string s = str; + std::transform(s.begin(), s.end(), s.begin(), ::tolower); + value = s == "true" || s == "on" || s == "yes" || s == "1"; + + return Status::OK(); +} + +} // namespace milvus diff --git a/core/src/utils/StringHelpFunctions.h b/core/src/utils/StringHelpFunctions.h new file mode 100644 index 0000000000..2b779084c4 --- /dev/null +++ b/core/src/utils/StringHelpFunctions.h @@ -0,0 +1,74 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#pragma once + +#include "utils/Status.h" + +#include +#include + +namespace milvus { + +class StringHelpFunctions { + private: + StringHelpFunctions() = default; + + public: + // trim blanks from begin and end + // " a b c " => "a b c" + static void + TrimStringBlank(std::string& string); + + // trim quotes from begin and end + // "'abc'" => "abc" + static void + TrimStringQuote(std::string& string, const std::string& qoute); + + // split string by delimeter ',' + // a,b,c a | b | c + // a,b, a | b | + // ,b,c | b | c + // ,b, | b | + // ,, | | + // a a + static void + SplitStringByDelimeter(const std::string& str, const std::string& delimeter, std::vector& result); + + // merge strings with delimeter + // "a", "b", "c" => "a,b,c" + static void + MergeStringWithDelimeter(const std::vector& strs, const std::string& delimeter, std::string& result); + + // assume the collection has two columns, quote='\"', delimeter=',' + // a,b a | b + // "aa,gg,yy",b aa,gg,yy | b + // aa"dd,rr"kk,pp aadd,rrkk | pp + // "aa,bb" aa,bb + // 55,1122\"aa,bb\",yyy,\"kkk\" 55 | 1122aa,bb | yyy | kkk + // "55,1122"aa,bb",yyy,"kkk" illegal + static Status + SplitStringByQuote(const std::string& str, const std::string& delimeter, const std::string& quote, + std::vector& result); + + // std regex match function + // regex grammar reference: http://www.cplusplus.com/reference/regex/ECMAScript/ + static bool + IsRegexMatch(const std::string& target_str, const std::string& pattern); + + // conversion rules refer to ValidationUtil::ValidateStringIsBool() + // "true", "on", "yes", "1" ==> true + // "false", "off", "no", "0", "" ==> false + static Status + ConvertToBoolean(const std::string& str, bool& value); +}; + +} // namespace milvus diff --git a/core/src/utils/ThreadPool.h b/core/src/utils/ThreadPool.h new file mode 100644 index 0000000000..ab42d11e1d --- /dev/null +++ b/core/src/utils/ThreadPool.h @@ -0,0 +1,112 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define MAX_THREADS_NUM 32 + +namespace milvus { + +class ThreadPool { + public: + explicit ThreadPool(size_t threads, size_t queue_size = 1000); + + template + auto + enqueue(F&& f, Args&&... args) -> std::future::type>; + + ~ThreadPool(); + + private: + // need to keep track of threads so we can join them + std::vector workers_; + + // the task queue + std::queue > tasks_; + + size_t max_queue_size_; + + // synchronization + std::mutex queue_mutex_; + + std::condition_variable condition_; + + bool stop; +}; + +// the constructor just launches some amount of workers +inline ThreadPool::ThreadPool(size_t threads, size_t queue_size) : max_queue_size_(queue_size), stop(false) { + for (size_t i = 0; i < threads; ++i) + workers_.emplace_back([this] { + for (;;) { + std::function task; + + { + std::unique_lock lock(this->queue_mutex_); + this->condition_.wait(lock, [this] { return this->stop || !this->tasks_.empty(); }); + if (this->stop && this->tasks_.empty()) + return; + task = std::move(this->tasks_.front()); + this->tasks_.pop(); + } + this->condition_.notify_all(); + + task(); + } + }); +} + +// add new work item to the pool +template +auto +ThreadPool::enqueue(F&& f, Args&&... args) -> std::future::type> { + using return_type = typename std::result_of::type; + + auto task = std::make_shared >( + std::bind(std::forward(f), std::forward(args)...)); + std::future res = task->get_future(); + { + std::unique_lock lock(queue_mutex_); + this->condition_.wait(lock, [this] { return this->tasks_.size() < max_queue_size_; }); + // don't allow enqueueing after stopping the pool + if (stop) + throw std::runtime_error("enqueue on stopped ThreadPool"); + + tasks_.emplace([task]() { (*task)(); }); + } + condition_.notify_all(); + return res; +} + +// the destructor joins all threads +inline ThreadPool::~ThreadPool() { + { + std::unique_lock lock(queue_mutex_); + stop = true; + } + condition_.notify_all(); + for (std::thread& worker : workers_) { + worker.join(); + } +} + +} // namespace milvus diff --git a/core/src/utils/TimeRecorder.cpp b/core/src/utils/TimeRecorder.cpp new file mode 100644 index 0000000000..cb8b674753 --- /dev/null +++ b/core/src/utils/TimeRecorder.cpp @@ -0,0 +1,99 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include "utils/TimeRecorder.h" +#include "utils/Log.h" + +namespace milvus { + +TimeRecorder::TimeRecorder(const std::string& header, int64_t log_level) : header_(header), log_level_(log_level) { + start_ = last_ = stdclock::now(); +} + +TimeRecorder::~TimeRecorder() = default; + +std::string +TimeRecorder::GetTimeSpanStr(double span) { + std::string str_sec = std::to_string(span * 0.000001) + ((span > 1000000) ? " seconds" : " second"); + std::string str_ms = std::to_string(span * 0.001) + " ms"; + + return str_sec + " [" + str_ms + "]"; +} + +void +TimeRecorder::PrintTimeRecord(const std::string& msg, double span) { + std::string str_log; + if (!header_.empty()) + str_log += header_ + ": "; + str_log += msg; + str_log += " ("; + str_log += TimeRecorder::GetTimeSpanStr(span); + str_log += ")"; + + switch (log_level_) { + case 0: { + LOG_SERVER_TRACE_ << str_log; + break; + } + case 1: { + LOG_SERVER_DEBUG_ << str_log; + break; + } + case 2: { + LOG_SERVER_INFO_ << str_log; + break; + } + case 3: { + LOG_SERVER_WARNING_ << str_log; + break; + } + case 4: { + LOG_SERVER_ERROR_ << str_log; + break; + } + case 5: { + LOG_SERVER_FATAL_ << str_log; + break; + } + default: { + LOG_SERVER_INFO_ << str_log; + break; + } + } +} + +double +TimeRecorder::RecordSection(const std::string& msg) { + stdclock::time_point curr = stdclock::now(); + double span = (std::chrono::duration(curr - last_)).count(); + last_ = curr; + + PrintTimeRecord(msg, span); + return span; +} + +double +TimeRecorder::ElapseFromBegin(const std::string& msg) { + stdclock::time_point curr = stdclock::now(); + double span = (std::chrono::duration(curr - start_)).count(); + + PrintTimeRecord(msg, span); + return span; +} + +TimeRecorderAuto::TimeRecorderAuto(const std::string& header, int64_t log_level) : TimeRecorder(header, log_level) { +} + +TimeRecorderAuto::~TimeRecorderAuto() { + ElapseFromBegin("totally cost"); +} + +} // namespace milvus diff --git a/core/src/utils/TimeRecorder.h b/core/src/utils/TimeRecorder.h new file mode 100644 index 0000000000..5103513cfa --- /dev/null +++ b/core/src/utils/TimeRecorder.h @@ -0,0 +1,66 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#pragma once + +#include +#include +#include "utils/Log.h" + +namespace milvus { + +inline void +print_timestamp(const std::string& message) { + std::chrono::time_point now = std::chrono::system_clock::now(); + auto duration = now.time_since_epoch(); + auto micros = std::chrono::duration_cast(duration).count(); + micros %= 1000000; + double millisecond = (double)micros / 1000.0; + + LOG_SERVER_DEBUG_ << std::fixed << " " << millisecond << "(ms) [timestamp]" << message; +} + +class TimeRecorder { + using stdclock = std::chrono::high_resolution_clock; + + public: + explicit TimeRecorder(const std::string& header, int64_t log_level = 1); + + virtual ~TimeRecorder(); // trace = 0, debug = 1, info = 2, warn = 3, error = 4, critical = 5 + + double + RecordSection(const std::string& msg); + + double + ElapseFromBegin(const std::string& msg); + + static std::string + GetTimeSpanStr(double span); + + private: + void + PrintTimeRecord(const std::string& msg, double span); + + private: + std::string header_; + stdclock::time_point start_; + stdclock::time_point last_; + int64_t log_level_; +}; + +class TimeRecorderAuto : public TimeRecorder { + public: + explicit TimeRecorderAuto(const std::string& header, int64_t log_level = 1); + + ~TimeRecorderAuto() override; +}; + +} // namespace milvus diff --git a/core/thirdparty/CMakeLists.txt b/core/thirdparty/CMakeLists.txt new file mode 100644 index 0000000000..233140a68a --- /dev/null +++ b/core/thirdparty/CMakeLists.txt @@ -0,0 +1,48 @@ +#------------------------------------------------------------------------------- +# Copyright (C) 2019-2020 Zilliz. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License +# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing permissions and limitations under the License. +#------------------------------------------------------------------------------- +# Using default c and cxx compiler in our build tree +# Thirdpart cxx and c flags +add_compile_options( -O3 -fPIC -Wno-error -fopenmp ) + +if ( NOT KNOWHERE_VERBOSE_THIRDPARTY_BUILD ) + set( EP_LOG_OPTIONS LOG_CONFIGURE 1 LOG_BUILD 1 LOG_INSTALL 1 LOG_DOWNLOAD 1 ) +else () + set( EP_LOG_OPTIONS ) +endif () + +set( MAKE_BUILD_ARGS "-j6" ) + +include( FetchContent ) +set( FETCHCONTENT_BASE_DIR ${MILVUS_BINARY_DIR}/3rdparty_download ) +set( FETCHCONTENT_QUIET OFF ) + +# ---------------------------------------------------------------------- +# Find pthreads + +set( THREADS_PREFER_PTHREAD_FLAG ON ) +find_package( Threads REQUIRED ) + +# ****************************** Thirdparty googletest *************************************** +if ( MILVUS_BUILD_TESTS ) + # add_subdirectory( gtest ) +endif() + +# ****************************** Thirdparty yaml *************************************** +if ( MILVUS_WITH_YAMLCPP ) + add_subdirectory( yaml-cpp ) +endif() + +# ****************************** Thirdparty opentracing *************************************** +if ( MILVUS_WITH_OPENTRACING ) + add_subdirectory( opentracing ) +endif() \ No newline at end of file diff --git a/core/thirdparty/easyloggingpp/easylogging++.cc b/core/thirdparty/easyloggingpp/easylogging++.cc new file mode 100644 index 0000000000..4c6df12686 --- /dev/null +++ b/core/thirdparty/easyloggingpp/easylogging++.cc @@ -0,0 +1,3299 @@ +// +// Bismillah ar-Rahmaan ar-Raheem +// +// Easylogging++ v9.96.7 +// Cross-platform logging library for C++ applications +// +// Copyright (c) 2012-2018 Zuhd Web Services +// Copyright (c) 2012-2018 @abumusamq +// +// This library is released under the MIT Licence. +// https://github.com/zuhd-org/easyloggingpp/blob/master/LICENSE +// +// https://zuhd.org +// http://muflihun.com +// + +#include "easylogging++.h" + +#if defined(AUTO_INITIALIZE_EASYLOGGINGPP) +INITIALIZE_EASYLOGGINGPP +#endif + +namespace el { + +// el::base +namespace base { +// el::base::consts +namespace consts { + +// Level log values - These are values that are replaced in place of %level format specifier +// Extra spaces after format specifiers are only for readability purposes in log files +static const base::type::char_t* kInfoLevelLogValue = ELPP_LITERAL("INFO"); +static const base::type::char_t* kDebugLevelLogValue = ELPP_LITERAL("DEBUG"); +static const base::type::char_t* kWarningLevelLogValue = ELPP_LITERAL("WARNING"); +static const base::type::char_t* kErrorLevelLogValue = ELPP_LITERAL("ERROR"); +static const base::type::char_t* kFatalLevelLogValue = ELPP_LITERAL("FATAL"); +static const base::type::char_t* kVerboseLevelLogValue = + ELPP_LITERAL("VERBOSE"); // will become VERBOSE-x where x = verbose level +static const base::type::char_t* kTraceLevelLogValue = ELPP_LITERAL("TRACE"); +static const base::type::char_t* kInfoLevelShortLogValue = ELPP_LITERAL("I"); +static const base::type::char_t* kDebugLevelShortLogValue = ELPP_LITERAL("D"); +static const base::type::char_t* kWarningLevelShortLogValue = ELPP_LITERAL("W"); +static const base::type::char_t* kErrorLevelShortLogValue = ELPP_LITERAL("E"); +static const base::type::char_t* kFatalLevelShortLogValue = ELPP_LITERAL("F"); +static const base::type::char_t* kVerboseLevelShortLogValue = ELPP_LITERAL("V"); +static const base::type::char_t* kTraceLevelShortLogValue = ELPP_LITERAL("T"); +// Format specifiers - These are used to define log format +static const base::type::char_t* kAppNameFormatSpecifier = ELPP_LITERAL("%app"); +static const base::type::char_t* kLoggerIdFormatSpecifier = ELPP_LITERAL("%logger"); +static const base::type::char_t* kThreadIdFormatSpecifier = ELPP_LITERAL("%thread"); +static const base::type::char_t* kSeverityLevelFormatSpecifier = ELPP_LITERAL("%level"); +static const base::type::char_t* kSeverityLevelShortFormatSpecifier = ELPP_LITERAL("%levshort"); +static const base::type::char_t* kDateTimeFormatSpecifier = ELPP_LITERAL("%datetime"); +static const base::type::char_t* kLogFileFormatSpecifier = ELPP_LITERAL("%file"); +static const base::type::char_t* kLogFileBaseFormatSpecifier = ELPP_LITERAL("%fbase"); +static const base::type::char_t* kLogLineFormatSpecifier = ELPP_LITERAL("%line"); +static const base::type::char_t* kLogLocationFormatSpecifier = ELPP_LITERAL("%loc"); +static const base::type::char_t* kLogFunctionFormatSpecifier = ELPP_LITERAL("%func"); +static const base::type::char_t* kCurrentUserFormatSpecifier = ELPP_LITERAL("%user"); +static const base::type::char_t* kCurrentHostFormatSpecifier = ELPP_LITERAL("%host"); +static const base::type::char_t* kMessageFormatSpecifier = ELPP_LITERAL("%msg"); +static const base::type::char_t* kVerboseLevelFormatSpecifier = ELPP_LITERAL("%vlevel"); +static const char* kDateTimeFormatSpecifierForFilename = "%datetime"; +// Date/time +static const char* kDays[7] = {"Sundayaaa", "Monday", "Tuesday", "Wednesday", "Thursday", "Friday", "Saturday"}; +static const char* kDaysAbbrev[7] = {"Sun", "Mon", "Tue", "Wed", "Thu", "Fri", "Sat"}; +static const char* kMonths[12] = {"January", "February", "March", "Apri", "May", "June", + "July", "August", "September", "October", "November", "December"}; +static const char* kMonthsAbbrev[12] = {"Jan", "Feb", "Mar", "Apr", "May", "Jun", + "Jul", "Aug", "Sep", "Oct", "Nov", "Dec"}; +static const char* kDefaultDateTimeFormat = "%Y-%M-%d %H:%m:%s,%g"; +static const char* kDefaultDateTimeFormatInFilename = "%Y-%M-%d_%H-%m"; +static const int kYearBase = 1900; +static const char* kAm = "AM"; +static const char* kPm = "PM"; +// Miscellaneous constants + +static const char* kNullPointer = "nullptr"; +#if ELPP_VARIADIC_TEMPLATES_SUPPORTED +#endif // ELPP_VARIADIC_TEMPLATES_SUPPORTED +static const base::type::VerboseLevel kMaxVerboseLevel = 9; +static const char* kUnknownUser = "user"; +static const char* kUnknownHost = "unknown-host"; + +//---------------- DEFAULT LOG FILE ----------------------- + +#if defined(ELPP_NO_DEFAULT_LOG_FILE) +#if ELPP_OS_UNIX +static const char* kDefaultLogFile = "/dev/null"; +#elif ELPP_OS_WINDOWS +static const char* kDefaultLogFile = "nul"; +#endif // ELPP_OS_UNIX +#elif defined(ELPP_DEFAULT_LOG_FILE) +static const char* kDefaultLogFile = ELPP_DEFAULT_LOG_FILE; +#else +static const char* kDefaultLogFile = "myeasylog.log"; +#endif // defined(ELPP_NO_DEFAULT_LOG_FILE) + +#if !defined(ELPP_DISABLE_LOG_FILE_FROM_ARG) +static const char* kDefaultLogFileParam = "--default-log-file"; +#endif // !defined(ELPP_DISABLE_LOG_FILE_FROM_ARG) +#if defined(ELPP_LOGGING_FLAGS_FROM_ARG) +static const char* kLoggingFlagsParam = "--logging-flags"; +#endif // defined(ELPP_LOGGING_FLAGS_FROM_ARG) +static const char* kValidLoggerIdSymbols = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-._"; +static const char* kConfigurationComment = "##"; +static const char* kConfigurationLevel = "*"; +static const char* kConfigurationLoggerId = "--"; +} // namespace consts +// el::base::utils +namespace utils { + +/// @brief Aborts application due with user-defined status +static void +abort(int status, const std::string& reason) { + // Both status and reason params are there for debugging with tools like gdb etc + ELPP_UNUSED(status); + ELPP_UNUSED(reason); +#if defined(ELPP_COMPILER_MSVC) && defined(_M_IX86) && defined(_DEBUG) + // Ignore msvc critical error dialog - break instead (on debug mode) + _asm int 3 +#else + ::abort(); +#endif // defined(ELPP_COMPILER_MSVC) && defined(_M_IX86) && defined(_DEBUG) +} + +} // namespace utils +} // namespace base + +// el + +// LevelHelper + +const char* +LevelHelper::convertToString(Level level) { + // Do not use switch over strongly typed enums because Intel C++ compilers dont support them yet. + if (level == Level::Global) + return "GLOBAL"; + if (level == Level::Debug) + return "DEBUG"; + if (level == Level::Info) + return "INFO"; + if (level == Level::Warning) + return "WARNING"; + if (level == Level::Error) + return "ERROR"; + if (level == Level::Fatal) + return "FATAL"; + if (level == Level::Verbose) + return "VERBOSE"; + if (level == Level::Trace) + return "TRACE"; + return "UNKNOWN"; +} + +struct StringToLevelItem { + const char* levelString; + Level level; +}; + +static struct StringToLevelItem stringToLevelMap[] = { + {"global", Level::Global}, {"debug", Level::Debug}, {"info", Level::Info}, {"warning", Level::Warning}, + {"error", Level::Error}, {"fatal", Level::Fatal}, {"verbose", Level::Verbose}, {"trace", Level::Trace}}; + +Level +LevelHelper::convertFromString(const char* levelStr) { + for (auto& item : stringToLevelMap) { + if (base::utils::Str::cStringCaseEq(levelStr, item.levelString)) { + return item.level; + } + } + return Level::Unknown; +} + +void +LevelHelper::forEachLevel(base::type::EnumType* startIndex, const std::function& fn) { + base::type::EnumType lIndexMax = LevelHelper::kMaxValid; + do { + if (fn()) { + break; + } + *startIndex = static_cast(*startIndex << 1); + } while (*startIndex <= lIndexMax); +} + +// ConfigurationTypeHelper + +const char* +ConfigurationTypeHelper::convertToString(ConfigurationType configurationType) { + // Do not use switch over strongly typed enums because Intel C++ compilers dont support them yet. + if (configurationType == ConfigurationType::Enabled) + return "ENABLED"; + if (configurationType == ConfigurationType::Filename) + return "FILENAME"; + if (configurationType == ConfigurationType::Format) + return "FORMAT"; + if (configurationType == ConfigurationType::ToFile) + return "TO_FILE"; + if (configurationType == ConfigurationType::ToStandardOutput) + return "TO_STANDARD_OUTPUT"; + if (configurationType == ConfigurationType::SubsecondPrecision) + return "SUBSECOND_PRECISION"; + if (configurationType == ConfigurationType::PerformanceTracking) + return "PERFORMANCE_TRACKING"; + if (configurationType == ConfigurationType::MaxLogFileSize) + return "MAX_LOG_FILE_SIZE"; + if (configurationType == ConfigurationType::LogFlushThreshold) + return "LOG_FLUSH_THRESHOLD"; + return "UNKNOWN"; +} + +struct ConfigurationStringToTypeItem { + const char* configString; + ConfigurationType configType; +}; + +static struct ConfigurationStringToTypeItem configStringToTypeMap[] = { + {"enabled", ConfigurationType::Enabled}, + {"to_file", ConfigurationType::ToFile}, + {"to_standard_output", ConfigurationType::ToStandardOutput}, + {"format", ConfigurationType::Format}, + {"filename", ConfigurationType::Filename}, + {"subsecond_precision", ConfigurationType::SubsecondPrecision}, + {"milliseconds_width", ConfigurationType::MillisecondsWidth}, + {"performance_tracking", ConfigurationType::PerformanceTracking}, + {"max_log_file_size", ConfigurationType::MaxLogFileSize}, + {"log_flush_threshold", ConfigurationType::LogFlushThreshold}, +}; + +ConfigurationType +ConfigurationTypeHelper::convertFromString(const char* configStr) { + for (auto& item : configStringToTypeMap) { + if (base::utils::Str::cStringCaseEq(configStr, item.configString)) { + return item.configType; + } + } + return ConfigurationType::Unknown; +} + +void +ConfigurationTypeHelper::forEachConfigType(base::type::EnumType* startIndex, const std::function& fn) { + base::type::EnumType cIndexMax = ConfigurationTypeHelper::kMaxValid; + do { + if (fn()) { + break; + } + *startIndex = static_cast(*startIndex << 1); + } while (*startIndex <= cIndexMax); +} + +// Configuration + +Configuration::Configuration(const Configuration& c) + : m_level(c.m_level), m_configurationType(c.m_configurationType), m_value(c.m_value) { +} + +Configuration& +Configuration::operator=(const Configuration& c) { + if (&c != this) { + m_level = c.m_level; + m_configurationType = c.m_configurationType; + m_value = c.m_value; + } + return *this; +} + +/// @brief Full constructor used to sets value of configuration +Configuration::Configuration(Level level, ConfigurationType configurationType, const std::string& value) + : m_level(level), m_configurationType(configurationType), m_value(value) { +} + +void +Configuration::log(el::base::type::ostream_t& os) const { + os << LevelHelper::convertToString(m_level) << ELPP_LITERAL(" ") + << ConfigurationTypeHelper::convertToString(m_configurationType) << ELPP_LITERAL(" = ") << m_value.c_str(); +} + +/// @brief Used to find configuration from configuration (pointers) repository. Avoid using it. +Configuration::Predicate::Predicate(Level level, ConfigurationType configurationType) + : m_level(level), m_configurationType(configurationType) { +} + +bool +Configuration::Predicate::operator()(const Configuration* conf) const { + return ((conf != nullptr) && (conf->level() == m_level) && (conf->configurationType() == m_configurationType)); +} + +// Configurations + +Configurations::Configurations(void) : m_configurationFile(std::string()), m_isFromFile(false) { +} + +Configurations::Configurations(const std::string& configurationFile, bool useDefaultsForRemaining, Configurations* base) + : m_configurationFile(configurationFile), m_isFromFile(false) { + parseFromFile(configurationFile, base); + if (useDefaultsForRemaining) { + setRemainingToDefault(); + } +} + +bool +Configurations::parseFromFile(const std::string& configurationFile, Configurations* base) { + // We initial assertion with true because if we have assertion diabled, we want to pass this + // check and if assertion is enabled we will have values re-assigned any way. + bool assertionPassed = true; + ELPP_ASSERT((assertionPassed = base::utils::File::pathExists(configurationFile.c_str(), true)) == true, + "Configuration file [" << configurationFile << "] does not exist!"); + if (!assertionPassed) { + return false; + } + bool success = Parser::parseFromFile(configurationFile, this, base); + m_isFromFile = success; + return success; +} + +bool +Configurations::parseFromText(const std::string& configurationsString, Configurations* base) { + bool success = Parser::parseFromText(configurationsString, this, base); + if (success) { + m_isFromFile = false; + } + return success; +} + +void +Configurations::setFromBase(Configurations* base) { + if (base == nullptr || base == this) { + return; + } + base::threading::ScopedLock scopedLock(base->lock()); + for (Configuration*& conf : base->list()) { + set(conf); + } +} + +bool +Configurations::hasConfiguration(ConfigurationType configurationType) { + base::type::EnumType lIndex = LevelHelper::kMinValid; + bool result = false; + LevelHelper::forEachLevel(&lIndex, [&](void) -> bool { + if (hasConfiguration(LevelHelper::castFromInt(lIndex), configurationType)) { + result = true; + } + return result; + }); + return result; +} + +bool +Configurations::hasConfiguration(Level level, ConfigurationType configurationType) { + base::threading::ScopedLock scopedLock(lock()); +#if ELPP_COMPILER_INTEL + // We cant specify template types here, Intel C++ throws compilation error + // "error: type name is not allowed" + return RegistryWithPred::get(level, configurationType) != nullptr; +#else + return RegistryWithPred::get(level, configurationType) != nullptr; +#endif // ELPP_COMPILER_INTEL +} + +void +Configurations::set(Level level, ConfigurationType configurationType, const std::string& value) { + base::threading::ScopedLock scopedLock(lock()); + unsafeSet(level, configurationType, value); // This is not unsafe anymore as we have locked mutex + if (level == Level::Global) { + unsafeSetGlobally(configurationType, value, false); // Again this is not unsafe either + } +} + +void +Configurations::set(Configuration* conf) { + if (conf == nullptr) { + return; + } + set(conf->level(), conf->configurationType(), conf->value()); +} + +void +Configurations::setToDefault(void) { + setGlobally(ConfigurationType::Enabled, std::string("true"), true); + setGlobally(ConfigurationType::Filename, std::string(base::consts::kDefaultLogFile), true); +#if defined(ELPP_NO_LOG_TO_FILE) + setGlobally(ConfigurationType::ToFile, std::string("false"), true); +#else + setGlobally(ConfigurationType::ToFile, std::string("true"), true); +#endif // defined(ELPP_NO_LOG_TO_FILE) + setGlobally(ConfigurationType::ToStandardOutput, std::string("true"), true); + setGlobally(ConfigurationType::SubsecondPrecision, std::string("3"), true); + setGlobally(ConfigurationType::PerformanceTracking, std::string("true"), true); + setGlobally(ConfigurationType::MaxLogFileSize, std::string("0"), true); + setGlobally(ConfigurationType::LogFlushThreshold, std::string("0"), true); + + setGlobally(ConfigurationType::Format, std::string("%datetime %level [%logger] %msg"), true); + set(Level::Debug, ConfigurationType::Format, + std::string("%datetime %level [%logger] [%user@%host] [%func] [%loc] %msg")); + // INFO and WARNING are set to default by Level::Global + set(Level::Error, ConfigurationType::Format, std::string("%datetime %level [%logger] %msg")); + set(Level::Fatal, ConfigurationType::Format, std::string("%datetime %level [%logger] %msg")); + set(Level::Verbose, ConfigurationType::Format, std::string("%datetime %level-%vlevel [%logger] %msg")); + set(Level::Trace, ConfigurationType::Format, std::string("%datetime %level [%logger] [%func] [%loc] %msg")); +} + +void +Configurations::setRemainingToDefault(void) { + base::threading::ScopedLock scopedLock(lock()); +#if defined(ELPP_NO_LOG_TO_FILE) + unsafeSetIfNotExist(Level::Global, ConfigurationType::Enabled, std::string("false")); +#else + unsafeSetIfNotExist(Level::Global, ConfigurationType::Enabled, std::string("true")); +#endif // defined(ELPP_NO_LOG_TO_FILE) + unsafeSetIfNotExist(Level::Global, ConfigurationType::Filename, std::string(base::consts::kDefaultLogFile)); + unsafeSetIfNotExist(Level::Global, ConfigurationType::ToStandardOutput, std::string("true")); + unsafeSetIfNotExist(Level::Global, ConfigurationType::SubsecondPrecision, std::string("3")); + unsafeSetIfNotExist(Level::Global, ConfigurationType::PerformanceTracking, std::string("true")); + unsafeSetIfNotExist(Level::Global, ConfigurationType::MaxLogFileSize, std::string("0")); + unsafeSetIfNotExist(Level::Global, ConfigurationType::Format, std::string("%datetime %level [%logger] %msg")); + unsafeSetIfNotExist(Level::Debug, ConfigurationType::Format, + std::string("%datetime %level [%logger] [%user@%host] [%func] [%loc] %msg")); + // INFO and WARNING are set to default by Level::Global + unsafeSetIfNotExist(Level::Error, ConfigurationType::Format, std::string("%datetime %level [%logger] %msg")); + unsafeSetIfNotExist(Level::Fatal, ConfigurationType::Format, std::string("%datetime %level [%logger] %msg")); + unsafeSetIfNotExist(Level::Verbose, ConfigurationType::Format, + std::string("%datetime %level-%vlevel [%logger] %msg")); + unsafeSetIfNotExist(Level::Trace, ConfigurationType::Format, + std::string("%datetime %level [%logger] [%func] [%loc] %msg")); +} + +bool +Configurations::Parser::parseFromFile(const std::string& configurationFile, Configurations* sender, + Configurations* base) { + sender->setFromBase(base); + std::ifstream fileStream_(configurationFile.c_str(), std::ifstream::in); + ELPP_ASSERT(fileStream_.is_open(), "Unable to open configuration file [" << configurationFile << "] for parsing."); + bool parsedSuccessfully = false; + std::string line = std::string(); + Level currLevel = Level::Unknown; + std::string currConfigStr = std::string(); + std::string currLevelStr = std::string(); + while (fileStream_.good()) { + std::getline(fileStream_, line); + parsedSuccessfully = parseLine(&line, &currConfigStr, &currLevelStr, &currLevel, sender); + ELPP_ASSERT(parsedSuccessfully, "Unable to parse configuration line: " << line); + } + return parsedSuccessfully; +} + +bool +Configurations::Parser::parseFromText(const std::string& configurationsString, Configurations* sender, + Configurations* base) { + sender->setFromBase(base); + bool parsedSuccessfully = false; + std::stringstream ss(configurationsString); + std::string line = std::string(); + Level currLevel = Level::Unknown; + std::string currConfigStr = std::string(); + std::string currLevelStr = std::string(); + while (std::getline(ss, line)) { + parsedSuccessfully = parseLine(&line, &currConfigStr, &currLevelStr, &currLevel, sender); + ELPP_ASSERT(parsedSuccessfully, "Unable to parse configuration line: " << line); + } + return parsedSuccessfully; +} + +void +Configurations::Parser::ignoreComments(std::string* line) { + std::size_t foundAt = 0; + std::size_t quotesStart = line->find("\""); + std::size_t quotesEnd = std::string::npos; + if (quotesStart != std::string::npos) { + quotesEnd = line->find("\"", quotesStart + 1); + while (quotesEnd != std::string::npos && line->at(quotesEnd - 1) == '\\') { + // Do not erase slash yet - we will erase it in parseLine(..) while loop + quotesEnd = line->find("\"", quotesEnd + 2); + } + } + if ((foundAt = line->find(base::consts::kConfigurationComment)) != std::string::npos) { + if (foundAt < quotesEnd) { + foundAt = line->find(base::consts::kConfigurationComment, quotesEnd + 1); + } + *line = line->substr(0, foundAt); + } +} + +bool +Configurations::Parser::isLevel(const std::string& line) { + return base::utils::Str::startsWith(line, std::string(base::consts::kConfigurationLevel)); +} + +bool +Configurations::Parser::isComment(const std::string& line) { + return base::utils::Str::startsWith(line, std::string(base::consts::kConfigurationComment)); +} + +bool +Configurations::Parser::isConfig(const std::string& line) { + std::size_t assignment = line.find('='); + return line != "" && ((line[0] >= 'A' && line[0] <= 'Z') || (line[0] >= 'a' && line[0] <= 'z')) && + (assignment != std::string::npos) && (line.size() > assignment); +} + +bool +Configurations::Parser::parseLine(std::string* line, std::string* currConfigStr, std::string* currLevelStr, + Level* currLevel, Configurations* conf) { + ConfigurationType currConfig = ConfigurationType::Unknown; + std::string currValue = std::string(); + *line = base::utils::Str::trim(*line); + if (isComment(*line)) + return true; + ignoreComments(line); + *line = base::utils::Str::trim(*line); + if (line->empty()) { + // Comment ignored + return true; + } + if (isLevel(*line)) { + if (line->size() <= 2) { + return true; + } + *currLevelStr = line->substr(1, line->size() - 2); + *currLevelStr = base::utils::Str::toUpper(*currLevelStr); + *currLevelStr = base::utils::Str::trim(*currLevelStr); + *currLevel = LevelHelper::convertFromString(currLevelStr->c_str()); + return true; + } + if (isConfig(*line)) { + std::size_t assignment = line->find('='); + *currConfigStr = line->substr(0, assignment); + *currConfigStr = base::utils::Str::toUpper(*currConfigStr); + *currConfigStr = base::utils::Str::trim(*currConfigStr); + currConfig = ConfigurationTypeHelper::convertFromString(currConfigStr->c_str()); + currValue = line->substr(assignment + 1); + currValue = base::utils::Str::trim(currValue); + std::size_t quotesStart = currValue.find("\"", 0); + std::size_t quotesEnd = std::string::npos; + if (quotesStart != std::string::npos) { + quotesEnd = currValue.find("\"", quotesStart + 1); + while (quotesEnd != std::string::npos && currValue.at(quotesEnd - 1) == '\\') { + currValue = currValue.erase(quotesEnd - 1, 1); + quotesEnd = currValue.find("\"", quotesEnd + 2); + } + } + if (quotesStart != std::string::npos && quotesEnd != std::string::npos) { + // Quote provided - check and strip if valid + ELPP_ASSERT((quotesStart < quotesEnd), + "Configuration error - No ending quote found in [" << currConfigStr << "]"); + ELPP_ASSERT((quotesStart + 1 != quotesEnd), "Empty configuration value for [" << currConfigStr << "]"); + if ((quotesStart != quotesEnd) && (quotesStart + 1 != quotesEnd)) { + // Explicit check in case if assertion is disabled + currValue = currValue.substr(quotesStart + 1, quotesEnd - 1); + } + } + } + ELPP_ASSERT(*currLevel != Level::Unknown, "Unrecognized severity level [" << *currLevelStr << "]"); + ELPP_ASSERT(currConfig != ConfigurationType::Unknown, "Unrecognized configuration [" << *currConfigStr << "]"); + if (*currLevel == Level::Unknown || currConfig == ConfigurationType::Unknown) { + return false; // unrecognizable level or config + } + conf->set(*currLevel, currConfig, currValue); + return true; +} + +void +Configurations::unsafeSetIfNotExist(Level level, ConfigurationType configurationType, const std::string& value) { + Configuration* conf = RegistryWithPred::get(level, configurationType); + if (conf == nullptr) { + unsafeSet(level, configurationType, value); + } +} + +void +Configurations::unsafeSet(Level level, ConfigurationType configurationType, const std::string& value) { + Configuration* conf = RegistryWithPred::get(level, configurationType); + if (conf == nullptr) { + registerNew(new Configuration(level, configurationType, value)); + } else { + conf->setValue(value); + } + if (level == Level::Global) { + unsafeSetGlobally(configurationType, value, false); + } +} + +void +Configurations::setGlobally(ConfigurationType configurationType, const std::string& value, bool includeGlobalLevel) { + if (includeGlobalLevel) { + set(Level::Global, configurationType, value); + } + base::type::EnumType lIndex = LevelHelper::kMinValid; + LevelHelper::forEachLevel(&lIndex, [&](void) -> bool { + set(LevelHelper::castFromInt(lIndex), configurationType, value); + return false; // Do not break lambda function yet as we need to set all levels regardless + }); +} + +void +Configurations::unsafeSetGlobally(ConfigurationType configurationType, const std::string& value, + bool includeGlobalLevel) { + if (includeGlobalLevel) { + unsafeSet(Level::Global, configurationType, value); + } + base::type::EnumType lIndex = LevelHelper::kMinValid; + LevelHelper::forEachLevel(&lIndex, [&](void) -> bool { + unsafeSet(LevelHelper::castFromInt(lIndex), configurationType, value); + return false; // Do not break lambda function yet as we need to set all levels regardless + }); +} + +// LogBuilder + +void +LogBuilder::convertToColoredOutput(base::type::string_t* logLine, Level level) { + if (!m_termSupportsColor) + return; + const base::type::char_t* resetColor = ELPP_LITERAL("\x1b[0m"); + if (level == Level::Error || level == Level::Fatal) + *logLine = ELPP_LITERAL("\x1b[31m") + *logLine + resetColor; + else if (level == Level::Warning) + *logLine = ELPP_LITERAL("\x1b[33m") + *logLine + resetColor; + else if (level == Level::Debug) + *logLine = ELPP_LITERAL("\x1b[32m") + *logLine + resetColor; + else if (level == Level::Info) + *logLine = ELPP_LITERAL("\x1b[36m") + *logLine + resetColor; + else if (level == Level::Trace) + *logLine = ELPP_LITERAL("\x1b[35m") + *logLine + resetColor; +} + +// Logger + +Logger::Logger(const std::string& id, base::LogStreamsReferenceMap* logStreamsReference) + : m_id(id), + m_typedConfigurations(nullptr), + m_parentApplicationName(std::string()), + m_isConfigured(false), + m_logStreamsReference(logStreamsReference) { + initUnflushedCount(); +} + +Logger::Logger(const std::string& id, const Configurations& configurations, + base::LogStreamsReferenceMap* logStreamsReference) + : m_id(id), + m_typedConfigurations(nullptr), + m_parentApplicationName(std::string()), + m_isConfigured(false), + m_logStreamsReference(logStreamsReference) { + initUnflushedCount(); + configure(configurations); +} + +Logger::Logger(const Logger& logger) { + base::utils::safeDelete(m_typedConfigurations); + m_id = logger.m_id; + m_typedConfigurations = logger.m_typedConfigurations; + m_parentApplicationName = logger.m_parentApplicationName; + m_isConfigured = logger.m_isConfigured; + m_configurations = logger.m_configurations; + m_unflushedCount = logger.m_unflushedCount; + m_logStreamsReference = logger.m_logStreamsReference; +} + +Logger& +Logger::operator=(const Logger& logger) { + if (&logger != this) { + base::utils::safeDelete(m_typedConfigurations); + m_id = logger.m_id; + m_typedConfigurations = logger.m_typedConfigurations; + m_parentApplicationName = logger.m_parentApplicationName; + m_isConfigured = logger.m_isConfigured; + m_configurations = logger.m_configurations; + m_unflushedCount = logger.m_unflushedCount; + m_logStreamsReference = logger.m_logStreamsReference; + } + return *this; +} + +void +Logger::configure(const Configurations& configurations) { + m_isConfigured = false; // we set it to false in case if we fail + initUnflushedCount(); + if (m_typedConfigurations != nullptr) { + Configurations* c = const_cast(m_typedConfigurations->configurations()); + if (c->hasConfiguration(Level::Global, ConfigurationType::Filename)) { + flush(); + } + } + base::threading::ScopedLock scopedLock(lock()); + if (m_configurations != configurations) { + m_configurations.setFromBase(const_cast(&configurations)); + } + base::utils::safeDelete(m_typedConfigurations); + m_typedConfigurations = new base::TypedConfigurations(&m_configurations, m_logStreamsReference); + resolveLoggerFormatSpec(); + m_isConfigured = true; +} + +void +Logger::reconfigure(void) { + ELPP_INTERNAL_INFO(1, "Reconfiguring logger [" << m_id << "]"); + configure(m_configurations); +} + +bool +Logger::isValidId(const std::string& id) { + for (std::string::const_iterator it = id.begin(); it != id.end(); ++it) { + if (!base::utils::Str::contains(base::consts::kValidLoggerIdSymbols, *it)) { + return false; + } + } + return true; +} + +void +Logger::flush(void) { + ELPP_INTERNAL_INFO(3, "Flushing logger [" << m_id << "] all levels"); + base::threading::ScopedLock scopedLock(lock()); + base::type::EnumType lIndex = LevelHelper::kMinValid; + LevelHelper::forEachLevel(&lIndex, [&](void) -> bool { + flush(LevelHelper::castFromInt(lIndex), nullptr); + return false; + }); +} + +void +Logger::flush(Level level, base::type::fstream_t* fs) { + if (fs == nullptr && m_typedConfigurations->toFile(level)) { + fs = m_typedConfigurations->fileStream(level); + } + if (fs != nullptr) { + fs->flush(); + std::unordered_map::iterator iter = m_unflushedCount.find(level); + if (iter != m_unflushedCount.end()) { + iter->second = 0; + } + Helpers::validateFileRolling(this, level); + } +} + +void +Logger::initUnflushedCount(void) { + m_unflushedCount.clear(); + base::type::EnumType lIndex = LevelHelper::kMinValid; + LevelHelper::forEachLevel(&lIndex, [&](void) -> bool { + m_unflushedCount.insert(std::make_pair(LevelHelper::castFromInt(lIndex), 0)); + return false; + }); +} + +void +Logger::resolveLoggerFormatSpec(void) const { + base::type::EnumType lIndex = LevelHelper::kMinValid; + LevelHelper::forEachLevel(&lIndex, [&](void) -> bool { + base::LogFormat* logFormat = + const_cast(&m_typedConfigurations->logFormat(LevelHelper::castFromInt(lIndex))); + base::utils::Str::replaceFirstWithEscape(logFormat->m_format, base::consts::kLoggerIdFormatSpecifier, m_id); + return false; + }); +} + +// el::base +namespace base { + +// el::base::utils +namespace utils { + +// File + +base::type::fstream_t* +File::newFileStream(const std::string& filename) { + base::type::fstream_t* fs = new base::type::fstream_t(filename.c_str(), base::type::fstream_t::out +#if !defined(ELPP_FRESH_LOG_FILE) + | base::type::fstream_t::app +#endif + ); +#if defined(ELPP_UNICODE) + std::locale elppUnicodeLocale(""); +#if ELPP_OS_WINDOWS + std::locale elppUnicodeLocaleWindows(elppUnicodeLocale, new std::codecvt_utf8_utf16); + elppUnicodeLocale = elppUnicodeLocaleWindows; +#endif // ELPP_OS_WINDOWS + fs->imbue(elppUnicodeLocale); +#endif // defined(ELPP_UNICODE) + if (fs->is_open()) { + fs->flush(); + } else { + base::utils::safeDelete(fs); + ELPP_INTERNAL_ERROR("Bad file [" << filename << "]", true); + } + return fs; +} + +std::size_t +File::getSizeOfFile(base::type::fstream_t* fs) { + if (fs == nullptr) { + return 0; + } + // Since the file stream is appended to or truncated, the current + // offset is the file size. + std::size_t size = static_cast(fs->tellg()); + return size; +} + +bool +File::pathExists(const char* path, bool considerFile) { + if (path == nullptr) { + return false; + } +#if ELPP_OS_UNIX + ELPP_UNUSED(considerFile); + struct stat st; + return (stat(path, &st) == 0); +#elif ELPP_OS_WINDOWS + DWORD fileType = GetFileAttributesA(path); + if (fileType == INVALID_FILE_ATTRIBUTES) { + return false; + } + return considerFile ? true : ((fileType & FILE_ATTRIBUTE_DIRECTORY) == 0 ? false : true); +#endif // ELPP_OS_UNIX +} + +bool +File::createPath(const std::string& path) { + if (path.empty()) { + return false; + } + if (base::utils::File::pathExists(path.c_str())) { + return true; + } + int status = -1; + + char* currPath = const_cast(path.c_str()); + std::string builtPath = std::string(); +#if ELPP_OS_UNIX + if (path[0] == '/') { + builtPath = "/"; + } + currPath = STRTOK(currPath, base::consts::kFilePathSeperator, 0); +#elif ELPP_OS_WINDOWS + // Use secure functions API + char* nextTok_ = nullptr; + currPath = STRTOK(currPath, base::consts::kFilePathSeperator, &nextTok_); + ELPP_UNUSED(nextTok_); +#endif // ELPP_OS_UNIX + while (currPath != nullptr) { + builtPath.append(currPath); + builtPath.append(base::consts::kFilePathSeperator); +#if ELPP_OS_UNIX + status = mkdir(builtPath.c_str(), ELPP_LOG_PERMS); + currPath = STRTOK(nullptr, base::consts::kFilePathSeperator, 0); +#elif ELPP_OS_WINDOWS + status = _mkdir(builtPath.c_str()); + currPath = STRTOK(nullptr, base::consts::kFilePathSeperator, &nextTok_); +#endif // ELPP_OS_UNIX + } + if (status == -1) { + ELPP_INTERNAL_ERROR("Error while creating path [" << path << "]", true); + return false; + } + return true; +} + +std::string +File::extractPathFromFilename(const std::string& fullPath, const char* separator) { + if ((fullPath == "") || (fullPath.find(separator) == std::string::npos)) { + return fullPath; + } + std::size_t lastSlashAt = fullPath.find_last_of(separator); + if (lastSlashAt == 0) { + return std::string(separator); + } + return fullPath.substr(0, lastSlashAt + 1); +} + +void +File::buildStrippedFilename(const char* filename, char buff[], std::size_t limit) { + std::size_t sizeOfFilename = strlen(filename); + if (sizeOfFilename >= limit) { + filename += (sizeOfFilename - limit); + if (filename[0] != '.' && filename[1] != '.') { // prepend if not already + filename += 3; // 3 = '..' + STRCAT(buff, "..", limit); + } + } + STRCAT(buff, filename, limit); +} + +void +File::buildBaseFilename(const std::string& fullPath, char buff[], std::size_t limit, const char* separator) { + const char* filename = fullPath.c_str(); + std::size_t lastSlashAt = fullPath.find_last_of(separator); + filename += lastSlashAt ? lastSlashAt + 1 : 0; + std::size_t sizeOfFilename = strlen(filename); + if (sizeOfFilename >= limit) { + filename += (sizeOfFilename - limit); + if (filename[0] != '.' && filename[1] != '.') { // prepend if not already + filename += 3; // 3 = '..' + STRCAT(buff, "..", limit); + } + } + STRCAT(buff, filename, limit); +} + +// Str + +bool +Str::wildCardMatch(const char* str, const char* pattern) { + while (*pattern) { + switch (*pattern) { + case '?': + if (!*str) + return false; + ++str; + ++pattern; + break; + case '*': + if (wildCardMatch(str, pattern + 1)) + return true; + if (*str && wildCardMatch(str + 1, pattern)) + return true; + return false; + default: + if (*str++ != *pattern++) + return false; + break; + } + } + return !*str && !*pattern; +} + +std::string& +Str::ltrim(std::string& str) { + str.erase(str.begin(), std::find_if(str.begin(), str.end(), [](char c) { return !std::isspace(c); })); + return str; +} + +std::string& +Str::rtrim(std::string& str) { + str.erase(std::find_if(str.rbegin(), str.rend(), [](char c) { return !std::isspace(c); }).base(), str.end()); + return str; +} + +std::string& +Str::trim(std::string& str) { + return ltrim(rtrim(str)); +} + +bool +Str::startsWith(const std::string& str, const std::string& start) { + return (str.length() >= start.length()) && (str.compare(0, start.length(), start) == 0); +} + +bool +Str::endsWith(const std::string& str, const std::string& end) { + return (str.length() >= end.length()) && (str.compare(str.length() - end.length(), end.length(), end) == 0); +} + +std::string& +Str::replaceAll(std::string& str, char replaceWhat, char replaceWith) { + std::replace(str.begin(), str.end(), replaceWhat, replaceWith); + return str; +} + +std::string& +Str::replaceAll(std::string& str, const std::string& replaceWhat, const std::string& replaceWith) { + if (replaceWhat == replaceWith) + return str; + std::size_t foundAt = std::string::npos; + while ((foundAt = str.find(replaceWhat, foundAt + 1)) != std::string::npos) { + str.replace(foundAt, replaceWhat.length(), replaceWith); + } + return str; +} + +void +Str::replaceFirstWithEscape(base::type::string_t& str, const base::type::string_t& replaceWhat, + const base::type::string_t& replaceWith) { + std::size_t foundAt = base::type::string_t::npos; + while ((foundAt = str.find(replaceWhat, foundAt + 1)) != base::type::string_t::npos) { + if (foundAt > 0 && str[foundAt - 1] == base::consts::kFormatSpecifierChar) { + str.erase(foundAt - 1, 1); + ++foundAt; + } else { + str.replace(foundAt, replaceWhat.length(), replaceWith); + return; + } + } +} +#if defined(ELPP_UNICODE) +void +Str::replaceFirstWithEscape(base::type::string_t& str, const base::type::string_t& replaceWhat, + const std::string& replaceWith) { + replaceFirstWithEscape(str, replaceWhat, base::type::string_t(replaceWith.begin(), replaceWith.end())); +} +#endif // defined(ELPP_UNICODE) + +std::string& +Str::toUpper(std::string& str) { + std::transform(str.begin(), str.end(), str.begin(), [](char c) { return static_cast(::toupper(c)); }); + return str; +} + +bool +Str::cStringEq(const char* s1, const char* s2) { + if (s1 == nullptr && s2 == nullptr) + return true; + if (s1 == nullptr || s2 == nullptr) + return false; + return strcmp(s1, s2) == 0; +} + +bool +Str::cStringCaseEq(const char* s1, const char* s2) { + if (s1 == nullptr && s2 == nullptr) + return true; + if (s1 == nullptr || s2 == nullptr) + return false; + + // With thanks to cygwin for this code + int d = 0; + + while (true) { + const int c1 = toupper(*s1++); + const int c2 = toupper(*s2++); + + if (((d = c1 - c2) != 0) || (c2 == '\0')) { + break; + } + } + + return d == 0; +} + +bool +Str::contains(const char* str, char c) { + for (; *str; ++str) { + if (*str == c) + return true; + } + return false; +} + +char* +Str::convertAndAddToBuff(std::size_t n, int len, char* buf, const char* bufLim, bool zeroPadded) { + char localBuff[10] = ""; + char* p = localBuff + sizeof(localBuff) - 2; + if (n > 0) { + for (; n > 0 && p > localBuff && len > 0; n /= 10, --len) *--p = static_cast(n % 10 + '0'); + } else { + *--p = '0'; + --len; + } + if (zeroPadded) + while (p > localBuff && len-- > 0) *--p = static_cast('0'); + return addToBuff(p, buf, bufLim); +} + +char* +Str::addToBuff(const char* str, char* buf, const char* bufLim) { + while ((buf < bufLim) && ((*buf = *str++) != '\0')) ++buf; + return buf; +} + +char* +Str::clearBuff(char buff[], std::size_t lim) { + STRCPY(buff, "", lim); + ELPP_UNUSED(lim); // For *nix we dont have anything using lim in above STRCPY macro + return buff; +} + +/// @brief Converst wchar* to char* +/// NOTE: Need to free return value after use! +char* +Str::wcharPtrToCharPtr(const wchar_t* line) { + std::size_t len_ = wcslen(line) + 1; + char* buff_ = static_cast(malloc(len_ + 1)); +#if ELPP_OS_UNIX || (ELPP_OS_WINDOWS && !ELPP_CRT_DBG_WARNINGS) + std::wcstombs(buff_, line, len_); +#elif ELPP_OS_WINDOWS + std::size_t convCount_ = 0; + mbstate_t mbState_; + ::memset(static_cast(&mbState_), 0, sizeof(mbState_)); + wcsrtombs_s(&convCount_, buff_, len_, &line, len_, &mbState_); +#endif // ELPP_OS_UNIX || (ELPP_OS_WINDOWS && !ELPP_CRT_DBG_WARNINGS) + return buff_; +} + +// OS + +#if ELPP_OS_WINDOWS +/// @brief Gets environment variables for Windows based OS. +/// We are not using getenv(const char*) because of CRT deprecation +/// @param varname Variable name to get environment variable value for +/// @return If variable exist the value of it otherwise nullptr +const char* +OS::getWindowsEnvironmentVariable(const char* varname) { + const DWORD bufferLen = 50; + static char buffer[bufferLen]; + if (GetEnvironmentVariableA(varname, buffer, bufferLen)) { + return buffer; + } + return nullptr; +} +#endif // ELPP_OS_WINDOWS +#if ELPP_OS_ANDROID +std::string +OS::getProperty(const char* prop) { + char propVal[PROP_VALUE_MAX + 1]; + int ret = __system_property_get(prop, propVal); + return ret == 0 ? std::string() : std::string(propVal); +} + +std::string +OS::getDeviceName(void) { + std::stringstream ss; + std::string manufacturer = getProperty("ro.product.manufacturer"); + std::string model = getProperty("ro.product.model"); + if (manufacturer.empty() || model.empty()) { + return std::string(); + } + ss << manufacturer << "-" << model; + return ss.str(); +} +#endif // ELPP_OS_ANDROID + +const std::string +OS::getBashOutput(const char* command) { +#if (ELPP_OS_UNIX && !ELPP_OS_ANDROID && !ELPP_CYGWIN) + if (command == nullptr) { + return std::string(); + } + FILE* proc = nullptr; + if ((proc = popen(command, "r")) == nullptr) { + ELPP_INTERNAL_ERROR("\nUnable to run command [" << command << "]", true); + return std::string(); + } + char hBuff[4096]; + if (fgets(hBuff, sizeof(hBuff), proc) != nullptr) { + pclose(proc); + const std::size_t buffLen = strlen(hBuff); + if (buffLen > 0 && hBuff[buffLen - 1] == '\n') { + hBuff[buffLen - 1] = '\0'; + } + return std::string(hBuff); + } else { + pclose(proc); + } + return std::string(); +#else + ELPP_UNUSED(command); + return std::string(); +#endif // (ELPP_OS_UNIX && !ELPP_OS_ANDROID && !ELPP_CYGWIN) +} + +std::string +OS::getEnvironmentVariable(const char* variableName, const char* defaultVal, const char* alternativeBashCommand) { +#if ELPP_OS_UNIX + const char* val = getenv(variableName); +#elif ELPP_OS_WINDOWS + const char* val = getWindowsEnvironmentVariable(variableName); +#endif // ELPP_OS_UNIX + if ((val == nullptr) || ((strcmp(val, "") == 0))) { +#if ELPP_OS_UNIX && defined(ELPP_FORCE_ENV_VAR_FROM_BASH) + // Try harder on unix-based systems + std::string valBash = base::utils::OS::getBashOutput(alternativeBashCommand); + if (valBash.empty()) { + return std::string(defaultVal); + } else { + return valBash; + } +#elif ELPP_OS_WINDOWS || ELPP_OS_UNIX + ELPP_UNUSED(alternativeBashCommand); + return std::string(defaultVal); +#endif // ELPP_OS_UNIX && defined(ELPP_FORCE_ENV_VAR_FROM_BASH) + } + return std::string(val); +} + +std::string +OS::currentUser(void) { +#if ELPP_OS_UNIX && !ELPP_OS_ANDROID + return getEnvironmentVariable("USER", base::consts::kUnknownUser, "whoami"); +#elif ELPP_OS_WINDOWS + return getEnvironmentVariable("USERNAME", base::consts::kUnknownUser); +#elif ELPP_OS_ANDROID + ELPP_UNUSED(base::consts::kUnknownUser); + return std::string("android"); +#else + return std::string(); +#endif // ELPP_OS_UNIX && !ELPP_OS_ANDROID +} + +std::string +OS::currentHost(void) { +#if ELPP_OS_UNIX && !ELPP_OS_ANDROID + return getEnvironmentVariable("HOSTNAME", base::consts::kUnknownHost, "hostname"); +#elif ELPP_OS_WINDOWS + return getEnvironmentVariable("COMPUTERNAME", base::consts::kUnknownHost); +#elif ELPP_OS_ANDROID + ELPP_UNUSED(base::consts::kUnknownHost); + return getDeviceName(); +#else + return std::string(); +#endif // ELPP_OS_UNIX && !ELPP_OS_ANDROID +} + +bool +OS::termSupportsColor(void) { + std::string term = getEnvironmentVariable("TERM", ""); + return term == "xterm" || term == "xterm-color" || term == "xterm-256color" || term == "screen" || + term == "linux" || term == "cygwin" || term == "screen-256color"; +} + +// DateTime + +void +DateTime::gettimeofday(struct timeval* tv) { +#if ELPP_OS_WINDOWS + if (tv != nullptr) { +#if ELPP_COMPILER_MSVC || defined(_MSC_EXTENSIONS) + const unsigned __int64 delta_ = 11644473600000000Ui64; +#else + const unsigned __int64 delta_ = 11644473600000000ULL; +#endif // ELPP_COMPILER_MSVC || defined(_MSC_EXTENSIONS) + const double secOffSet = 0.000001; + const unsigned long usecOffSet = 1000000; + FILETIME fileTime; + GetSystemTimeAsFileTime(&fileTime); + unsigned __int64 present = 0; + present |= fileTime.dwHighDateTime; + present = present << 32; + present |= fileTime.dwLowDateTime; + present /= 10; // mic-sec + // Subtract the difference + present -= delta_; + tv->tv_sec = static_cast(present * secOffSet); + tv->tv_usec = static_cast(present % usecOffSet); + } +#else + ::gettimeofday(tv, nullptr); +#endif // ELPP_OS_WINDOWS +} + +std::string +DateTime::getDateTime(const char* format, const base::SubsecondPrecision* ssPrec) { + struct timeval currTime; + gettimeofday(&currTime); + return timevalToString(currTime, format, ssPrec); +} + +std::string +DateTime::timevalToString(struct timeval tval, const char* format, const el::base::SubsecondPrecision* ssPrec) { + struct ::tm timeInfo; + buildTimeInfo(&tval, &timeInfo); + const int kBuffSize = 30; + char buff_[kBuffSize] = ""; + parseFormat(buff_, kBuffSize, format, &timeInfo, static_cast(tval.tv_usec / ssPrec->m_offset), ssPrec); + return std::string(buff_); +} + +base::type::string_t +DateTime::formatTime(unsigned long long time, base::TimestampUnit timestampUnit) { + base::type::EnumType start = static_cast(timestampUnit); + const base::type::char_t* unit = base::consts::kTimeFormats[start].unit; + for (base::type::EnumType i = start; i < base::consts::kTimeFormatsCount - 1; ++i) { + if (time <= base::consts::kTimeFormats[i].value) { + break; + } + if (base::consts::kTimeFormats[i].value == 1000.0f && time / 1000.0f < 1.9f) { + break; + } + time /= static_cast(base::consts::kTimeFormats[i].value); + unit = base::consts::kTimeFormats[i + 1].unit; + } + base::type::stringstream_t ss; + ss << time << " " << unit; + return ss.str(); +} + +unsigned long long +DateTime::getTimeDifference(const struct timeval& endTime, const struct timeval& startTime, + base::TimestampUnit timestampUnit) { + if (timestampUnit == base::TimestampUnit::Microsecond) { + return static_cast( + static_cast(1000000 * endTime.tv_sec + endTime.tv_usec) - + static_cast(1000000 * startTime.tv_sec + startTime.tv_usec)); + } + // milliseconds + auto conv = [](const struct timeval& tim) { + return static_cast((tim.tv_sec * 1000) + (tim.tv_usec / 1000)); + }; + return static_cast(conv(endTime) - conv(startTime)); +} + +struct ::tm* +DateTime::buildTimeInfo(struct timeval* currTime, struct ::tm* timeInfo) { +#if ELPP_OS_UNIX + time_t rawTime = currTime->tv_sec; + ::elpptime_r(&rawTime, timeInfo); + return timeInfo; +#else +#if ELPP_COMPILER_MSVC + ELPP_UNUSED(currTime); + time_t t; +#if defined(_USE_32BIT_TIME_T) + _time32(&t); +#else + _time64(&t); +#endif + elpptime_s(timeInfo, &t); + return timeInfo; +#else + // For any other compilers that don't have CRT warnings issue e.g, MinGW or TDM GCC- we use different method + time_t rawTime = currTime->tv_sec; + struct tm* tmInf = elpptime(&rawTime); + *timeInfo = *tmInf; + return timeInfo; +#endif // ELPP_COMPILER_MSVC +#endif // ELPP_OS_UNIX +} + +char* +DateTime::parseFormat(char* buf, std::size_t bufSz, const char* format, const struct tm* tInfo, std::size_t msec, + const base::SubsecondPrecision* ssPrec) { + const char* bufLim = buf + bufSz; + for (; *format; ++format) { + if (*format == base::consts::kFormatSpecifierChar) { + switch (*++format) { + case base::consts::kFormatSpecifierChar: // Escape + break; + case '\0': // End + --format; + break; + case 'd': // Day + buf = base::utils::Str::convertAndAddToBuff(tInfo->tm_mday, 2, buf, bufLim); + continue; + case 'a': // Day of week (short) + buf = base::utils::Str::addToBuff(base::consts::kDaysAbbrev[tInfo->tm_wday], buf, bufLim); + continue; + case 'A': // Day of week (long) + buf = base::utils::Str::addToBuff(base::consts::kDays[tInfo->tm_wday], buf, bufLim); + continue; + case 'M': // month + buf = base::utils::Str::convertAndAddToBuff(tInfo->tm_mon + 1, 2, buf, bufLim); + continue; + case 'b': // month (short) + buf = base::utils::Str::addToBuff(base::consts::kMonthsAbbrev[tInfo->tm_mon], buf, bufLim); + continue; + case 'B': // month (long) + buf = base::utils::Str::addToBuff(base::consts::kMonths[tInfo->tm_mon], buf, bufLim); + continue; + case 'y': // year (two digits) + buf = + base::utils::Str::convertAndAddToBuff(tInfo->tm_year + base::consts::kYearBase, 2, buf, bufLim); + continue; + case 'Y': // year (four digits) + buf = + base::utils::Str::convertAndAddToBuff(tInfo->tm_year + base::consts::kYearBase, 4, buf, bufLim); + continue; + case 'h': // hour (12-hour) + buf = base::utils::Str::convertAndAddToBuff(tInfo->tm_hour % 12, 2, buf, bufLim); + continue; + case 'H': // hour (24-hour) + buf = base::utils::Str::convertAndAddToBuff(tInfo->tm_hour, 2, buf, bufLim); + continue; + case 'm': // minute + buf = base::utils::Str::convertAndAddToBuff(tInfo->tm_min, 2, buf, bufLim); + continue; + case 's': // second + buf = base::utils::Str::convertAndAddToBuff(tInfo->tm_sec, 2, buf, bufLim); + continue; + case 'z': // subsecond part + case 'g': + buf = base::utils::Str::convertAndAddToBuff(msec, ssPrec->m_width, buf, bufLim); + continue; + case 'F': // AM/PM + buf = base::utils::Str::addToBuff((tInfo->tm_hour >= 12) ? base::consts::kPm : base::consts::kAm, + buf, bufLim); + continue; + default: + continue; + } + } + if (buf == bufLim) + break; + *buf++ = *format; + } + return buf; +} + +// CommandLineArgs + +void +CommandLineArgs::setArgs(int argc, char** argv) { + m_params.clear(); + m_paramsWithValue.clear(); + if (argc == 0 || argv == nullptr) { + return; + } + m_argc = argc; + m_argv = argv; + for (int i = 1; i < m_argc; ++i) { + const char* v = (strstr(m_argv[i], "=")); + if (v != nullptr && strlen(v) > 0) { + std::string key = std::string(m_argv[i]); + key = key.substr(0, key.find_first_of('=')); + if (hasParamWithValue(key.c_str())) { + ELPP_INTERNAL_INFO(1, "Skipping [" << key << "] arg since it already has value [" + << getParamValue(key.c_str()) << "]"); + } else { + m_paramsWithValue.insert(std::make_pair(key, std::string(v + 1))); + } + } + if (v == nullptr) { + if (hasParam(m_argv[i])) { + ELPP_INTERNAL_INFO(1, "Skipping [" << m_argv[i] << "] arg since it already exists"); + } else { + m_params.push_back(std::string(m_argv[i])); + } + } + } +} + +bool +CommandLineArgs::hasParamWithValue(const char* paramKey) const { + return m_paramsWithValue.find(std::string(paramKey)) != m_paramsWithValue.end(); +} + +const char* +CommandLineArgs::getParamValue(const char* paramKey) const { + std::unordered_map::const_iterator iter = m_paramsWithValue.find(std::string(paramKey)); + return iter != m_paramsWithValue.end() ? iter->second.c_str() : ""; +} + +bool +CommandLineArgs::hasParam(const char* paramKey) const { + return std::find(m_params.begin(), m_params.end(), std::string(paramKey)) != m_params.end(); +} + +bool +CommandLineArgs::empty(void) const { + return m_params.empty() && m_paramsWithValue.empty(); +} + +std::size_t +CommandLineArgs::size(void) const { + return m_params.size() + m_paramsWithValue.size(); +} + +base::type::ostream_t& +operator<<(base::type::ostream_t& os, const CommandLineArgs& c) { + for (int i = 1; i < c.m_argc; ++i) { + os << ELPP_LITERAL("[") << c.m_argv[i] << ELPP_LITERAL("]"); + if (i < c.m_argc - 1) { + os << ELPP_LITERAL(" "); + } + } + return os; +} + +} // namespace utils + +// el::base::threading +namespace threading { + +#if ELPP_THREADING_ENABLED +#if ELPP_USE_STD_THREADING +#if ELPP_ASYNC_LOGGING +static void +msleep(int ms) { + // Only when async logging enabled - this is because async is strict on compiler +#if defined(ELPP_NO_SLEEP_FOR) + usleep(ms * 1000); +#else + std::this_thread::sleep_for(std::chrono::milliseconds(ms)); +#endif // defined(ELPP_NO_SLEEP_FOR) +} +#endif // ELPP_ASYNC_LOGGING +#endif // !ELPP_USE_STD_THREADING +#endif // ELPP_THREADING_ENABLED + +} // namespace threading + +// el::base + +// SubsecondPrecision + +void +SubsecondPrecision::init(int width) { + if (width < 1 || width > 6) { + width = base::consts::kDefaultSubsecondPrecision; + } + m_width = width; + switch (m_width) { + case 3: + m_offset = 1000; + break; + case 4: + m_offset = 100; + break; + case 5: + m_offset = 10; + break; + case 6: + m_offset = 1; + break; + default: + m_offset = 1000; + break; + } +} + +// LogFormat + +LogFormat::LogFormat(void) + : m_level(Level::Unknown), + m_userFormat(base::type::string_t()), + m_format(base::type::string_t()), + m_dateTimeFormat(std::string()), + m_flags(0x0), + m_currentUser(base::utils::OS::currentUser()), + m_currentHost(base::utils::OS::currentHost()) { +} + +LogFormat::LogFormat(Level level, const base::type::string_t& format) + : m_level(level), + m_userFormat(format), + m_currentUser(base::utils::OS::currentUser()), + m_currentHost(base::utils::OS::currentHost()) { + parseFromFormat(m_userFormat); +} + +LogFormat::LogFormat(const LogFormat& logFormat) + : m_level(logFormat.m_level), + m_userFormat(logFormat.m_userFormat), + m_format(logFormat.m_format), + m_dateTimeFormat(logFormat.m_dateTimeFormat), + m_flags(logFormat.m_flags), + m_currentUser(logFormat.m_currentUser), + m_currentHost(logFormat.m_currentHost) { +} + +LogFormat::LogFormat(LogFormat&& logFormat) { + m_level = std::move(logFormat.m_level); + m_userFormat = std::move(logFormat.m_userFormat); + m_format = std::move(logFormat.m_format); + m_dateTimeFormat = std::move(logFormat.m_dateTimeFormat); + m_flags = std::move(logFormat.m_flags); + m_currentUser = std::move(logFormat.m_currentUser); + m_currentHost = std::move(logFormat.m_currentHost); +} + +LogFormat& +LogFormat::operator=(const LogFormat& logFormat) { + if (&logFormat != this) { + m_level = logFormat.m_level; + m_userFormat = logFormat.m_userFormat; + m_dateTimeFormat = logFormat.m_dateTimeFormat; + m_flags = logFormat.m_flags; + m_currentUser = logFormat.m_currentUser; + m_currentHost = logFormat.m_currentHost; + } + return *this; +} + +bool +LogFormat::operator==(const LogFormat& other) { + return m_level == other.m_level && m_userFormat == other.m_userFormat && m_format == other.m_format && + m_dateTimeFormat == other.m_dateTimeFormat && m_flags == other.m_flags; +} + +/// @brief Updates format to be used while logging. +/// @param userFormat User provided format +void +LogFormat::parseFromFormat(const base::type::string_t& userFormat) { + // We make copy because we will be changing the format + // i.e, removing user provided date format from original format + // and then storing it. + base::type::string_t formatCopy = userFormat; + m_flags = 0x0; + auto conditionalAddFlag = [&](const base::type::char_t* specifier, base::FormatFlags flag) { + std::size_t foundAt = base::type::string_t::npos; + while ((foundAt = formatCopy.find(specifier, foundAt + 1)) != base::type::string_t::npos) { + if (foundAt > 0 && formatCopy[foundAt - 1] == base::consts::kFormatSpecifierChar) { + if (hasFlag(flag)) { + // If we already have flag we remove the escape chars so that '%%' is turned to '%' + // even after specifier resolution - this is because we only replaceFirst specifier + formatCopy.erase(foundAt - 1, 1); + ++foundAt; + } + } else { + if (!hasFlag(flag)) + addFlag(flag); + } + } + }; + conditionalAddFlag(base::consts::kAppNameFormatSpecifier, base::FormatFlags::AppName); + conditionalAddFlag(base::consts::kSeverityLevelFormatSpecifier, base::FormatFlags::Level); + conditionalAddFlag(base::consts::kSeverityLevelShortFormatSpecifier, base::FormatFlags::LevelShort); + conditionalAddFlag(base::consts::kLoggerIdFormatSpecifier, base::FormatFlags::LoggerId); + conditionalAddFlag(base::consts::kThreadIdFormatSpecifier, base::FormatFlags::ThreadId); + conditionalAddFlag(base::consts::kLogFileFormatSpecifier, base::FormatFlags::File); + conditionalAddFlag(base::consts::kLogFileBaseFormatSpecifier, base::FormatFlags::FileBase); + conditionalAddFlag(base::consts::kLogLineFormatSpecifier, base::FormatFlags::Line); + conditionalAddFlag(base::consts::kLogLocationFormatSpecifier, base::FormatFlags::Location); + conditionalAddFlag(base::consts::kLogFunctionFormatSpecifier, base::FormatFlags::Function); + conditionalAddFlag(base::consts::kCurrentUserFormatSpecifier, base::FormatFlags::User); + conditionalAddFlag(base::consts::kCurrentHostFormatSpecifier, base::FormatFlags::Host); + conditionalAddFlag(base::consts::kMessageFormatSpecifier, base::FormatFlags::LogMessage); + conditionalAddFlag(base::consts::kVerboseLevelFormatSpecifier, base::FormatFlags::VerboseLevel); + // For date/time we need to extract user's date format first + std::size_t dateIndex = std::string::npos; + if ((dateIndex = formatCopy.find(base::consts::kDateTimeFormatSpecifier)) != std::string::npos) { + while (dateIndex > 0 && formatCopy[dateIndex - 1] == base::consts::kFormatSpecifierChar) { + dateIndex = formatCopy.find(base::consts::kDateTimeFormatSpecifier, dateIndex + 1); + } + if (dateIndex != std::string::npos) { + addFlag(base::FormatFlags::DateTime); + updateDateFormat(dateIndex, formatCopy); + } + } + m_format = formatCopy; + updateFormatSpec(); +} + +void +LogFormat::updateDateFormat(std::size_t index, base::type::string_t& currFormat) { + if (hasFlag(base::FormatFlags::DateTime)) { + index += ELPP_STRLEN(base::consts::kDateTimeFormatSpecifier); + } + const base::type::char_t* ptr = currFormat.c_str() + index; + if ((currFormat.size() > index) && (ptr[0] == '{')) { + // User has provided format for date/time + ++ptr; + int count = 1; // Start by 1 in order to remove starting brace + std::stringstream ss; + for (; *ptr; ++ptr, ++count) { + if (*ptr == '}') { + ++count; // In order to remove ending brace + break; + } + ss << static_cast(*ptr); + } + currFormat.erase(index, count); + m_dateTimeFormat = ss.str(); + } else { + // No format provided, use default + if (hasFlag(base::FormatFlags::DateTime)) { + m_dateTimeFormat = std::string(base::consts::kDefaultDateTimeFormat); + } + } +} + +void +LogFormat::updateFormatSpec(void) { + // Do not use switch over strongly typed enums because Intel C++ compilers dont support them yet. + if (m_level == Level::Debug) { + base::utils::Str::replaceFirstWithEscape(m_format, base::consts::kSeverityLevelFormatSpecifier, + base::consts::kDebugLevelLogValue); + base::utils::Str::replaceFirstWithEscape(m_format, base::consts::kSeverityLevelShortFormatSpecifier, + base::consts::kDebugLevelShortLogValue); + } else if (m_level == Level::Info) { + base::utils::Str::replaceFirstWithEscape(m_format, base::consts::kSeverityLevelFormatSpecifier, + base::consts::kInfoLevelLogValue); + base::utils::Str::replaceFirstWithEscape(m_format, base::consts::kSeverityLevelShortFormatSpecifier, + base::consts::kInfoLevelShortLogValue); + } else if (m_level == Level::Warning) { + base::utils::Str::replaceFirstWithEscape(m_format, base::consts::kSeverityLevelFormatSpecifier, + base::consts::kWarningLevelLogValue); + base::utils::Str::replaceFirstWithEscape(m_format, base::consts::kSeverityLevelShortFormatSpecifier, + base::consts::kWarningLevelShortLogValue); + } else if (m_level == Level::Error) { + base::utils::Str::replaceFirstWithEscape(m_format, base::consts::kSeverityLevelFormatSpecifier, + base::consts::kErrorLevelLogValue); + base::utils::Str::replaceFirstWithEscape(m_format, base::consts::kSeverityLevelShortFormatSpecifier, + base::consts::kErrorLevelShortLogValue); + } else if (m_level == Level::Fatal) { + base::utils::Str::replaceFirstWithEscape(m_format, base::consts::kSeverityLevelFormatSpecifier, + base::consts::kFatalLevelLogValue); + base::utils::Str::replaceFirstWithEscape(m_format, base::consts::kSeverityLevelShortFormatSpecifier, + base::consts::kFatalLevelShortLogValue); + } else if (m_level == Level::Verbose) { + base::utils::Str::replaceFirstWithEscape(m_format, base::consts::kSeverityLevelFormatSpecifier, + base::consts::kVerboseLevelLogValue); + base::utils::Str::replaceFirstWithEscape(m_format, base::consts::kSeverityLevelShortFormatSpecifier, + base::consts::kVerboseLevelShortLogValue); + } else if (m_level == Level::Trace) { + base::utils::Str::replaceFirstWithEscape(m_format, base::consts::kSeverityLevelFormatSpecifier, + base::consts::kTraceLevelLogValue); + base::utils::Str::replaceFirstWithEscape(m_format, base::consts::kSeverityLevelShortFormatSpecifier, + base::consts::kTraceLevelShortLogValue); + } + if (hasFlag(base::FormatFlags::User)) { + base::utils::Str::replaceFirstWithEscape(m_format, base::consts::kCurrentUserFormatSpecifier, m_currentUser); + } + if (hasFlag(base::FormatFlags::Host)) { + base::utils::Str::replaceFirstWithEscape(m_format, base::consts::kCurrentHostFormatSpecifier, m_currentHost); + } + // Ignore Level::Global and Level::Unknown +} + +// TypedConfigurations + +TypedConfigurations::TypedConfigurations(Configurations* configurations, + base::LogStreamsReferenceMap* logStreamsReference) { + m_configurations = configurations; + m_logStreamsReference = logStreamsReference; + build(m_configurations); +} + +TypedConfigurations::TypedConfigurations(const TypedConfigurations& other) { + this->m_configurations = other.m_configurations; + this->m_logStreamsReference = other.m_logStreamsReference; + build(m_configurations); +} + +bool +TypedConfigurations::enabled(Level level) { + return getConfigByVal(level, &m_enabledMap, "enabled"); +} + +bool +TypedConfigurations::toFile(Level level) { + return getConfigByVal(level, &m_toFileMap, "toFile"); +} + +const std::string& +TypedConfigurations::filename(Level level) { + return getConfigByRef(level, &m_filenameMap, "filename"); +} + +bool +TypedConfigurations::toStandardOutput(Level level) { + return getConfigByVal(level, &m_toStandardOutputMap, "toStandardOutput"); +} + +const base::LogFormat& +TypedConfigurations::logFormat(Level level) { + return getConfigByRef(level, &m_logFormatMap, "logFormat"); +} + +const base::SubsecondPrecision& +TypedConfigurations::subsecondPrecision(Level level) { + return getConfigByRef(level, &m_subsecondPrecisionMap, "subsecondPrecision"); +} + +const base::MillisecondsWidth& +TypedConfigurations::millisecondsWidth(Level level) { + return getConfigByRef(level, &m_subsecondPrecisionMap, "millisecondsWidth"); +} + +bool +TypedConfigurations::performanceTracking(Level level) { + return getConfigByVal(level, &m_performanceTrackingMap, "performanceTracking"); +} + +base::type::fstream_t* +TypedConfigurations::fileStream(Level level) { + return getConfigByRef(level, &m_fileStreamMap, "fileStream").get(); +} + +std::size_t +TypedConfigurations::maxLogFileSize(Level level) { + return getConfigByVal(level, &m_maxLogFileSizeMap, "maxLogFileSize"); +} + +std::size_t +TypedConfigurations::logFlushThreshold(Level level) { + return getConfigByVal(level, &m_logFlushThresholdMap, "logFlushThreshold"); +} + +void +TypedConfigurations::build(Configurations* configurations) { + base::threading::ScopedLock scopedLock(lock()); + auto getBool = [](std::string boolStr) -> bool { // Pass by value for trimming + base::utils::Str::trim(boolStr); + return (boolStr == "TRUE" || boolStr == "true" || boolStr == "1"); + }; + std::vector withFileSizeLimit; + for (Configurations::const_iterator it = configurations->begin(); it != configurations->end(); ++it) { + Configuration* conf = *it; + // We cannot use switch on strong enums because Intel C++ dont support them yet + if (conf->configurationType() == ConfigurationType::Enabled) { + setValue(conf->level(), getBool(conf->value()), &m_enabledMap); + } else if (conf->configurationType() == ConfigurationType::ToFile) { + setValue(conf->level(), getBool(conf->value()), &m_toFileMap); + } else if (conf->configurationType() == ConfigurationType::ToStandardOutput) { + setValue(conf->level(), getBool(conf->value()), &m_toStandardOutputMap); + } else if (conf->configurationType() == ConfigurationType::Filename) { + // We do not yet configure filename but we will configure in another + // loop. This is because if file cannot be created, we will force ToFile + // to be false. Because configuring logger is not necessarily performance + // sensative operation, we can live with another loop; (by the way this loop + // is not very heavy either) + } else if (conf->configurationType() == ConfigurationType::Format) { + setValue(conf->level(), + base::LogFormat(conf->level(), base::type::string_t(conf->value().begin(), conf->value().end())), + &m_logFormatMap); + } else if (conf->configurationType() == ConfigurationType::SubsecondPrecision) { + setValue(Level::Global, base::SubsecondPrecision(static_cast(getULong(conf->value()))), + &m_subsecondPrecisionMap); + } else if (conf->configurationType() == ConfigurationType::PerformanceTracking) { + setValue(Level::Global, getBool(conf->value()), &m_performanceTrackingMap); + } else if (conf->configurationType() == ConfigurationType::MaxLogFileSize) { + auto v = getULong(conf->value()); + setValue(conf->level(), static_cast(v), &m_maxLogFileSizeMap); + if (v != 0) { + withFileSizeLimit.push_back(conf); + } + } else if (conf->configurationType() == ConfigurationType::LogFlushThreshold) { + setValue(conf->level(), static_cast(getULong(conf->value())), &m_logFlushThresholdMap); + } + } + // As mentioned earlier, we will now set filename configuration in separate loop to deal with non-existent files + for (Configurations::const_iterator it = configurations->begin(); it != configurations->end(); ++it) { + Configuration* conf = *it; + if (conf->configurationType() == ConfigurationType::Filename) { + insertFile(conf->level(), conf->value()); + } + } + for (std::vector::iterator conf = withFileSizeLimit.begin(); conf != withFileSizeLimit.end(); + ++conf) { + // This is not unsafe as mutex is locked in currect scope + unsafeValidateFileRolling((*conf)->level(), base::defaultPreRollOutCallback); + } +} + +unsigned long +TypedConfigurations::getULong(std::string confVal) { + bool valid = true; + base::utils::Str::trim(confVal); + valid = !confVal.empty() && std::find_if(confVal.begin(), confVal.end(), + [](char c) { return !base::utils::Str::isDigit(c); }) == confVal.end(); + if (!valid) { + valid = false; + ELPP_ASSERT(valid, "Configuration value not a valid integer [" << confVal << "]"); + return 0; + } + return atol(confVal.c_str()); +} + +std::string +TypedConfigurations::resolveFilename(const std::string& filename) { + std::string resultingFilename = filename; + std::size_t dateIndex = std::string::npos; + std::string dateTimeFormatSpecifierStr = std::string(base::consts::kDateTimeFormatSpecifierForFilename); + if ((dateIndex = resultingFilename.find(dateTimeFormatSpecifierStr.c_str())) != std::string::npos) { + while (dateIndex > 0 && resultingFilename[dateIndex - 1] == base::consts::kFormatSpecifierChar) { + dateIndex = resultingFilename.find(dateTimeFormatSpecifierStr.c_str(), dateIndex + 1); + } + if (dateIndex != std::string::npos) { + const char* ptr = resultingFilename.c_str() + dateIndex; + // Goto end of specifier + ptr += dateTimeFormatSpecifierStr.size(); + std::string fmt; + if ((resultingFilename.size() > dateIndex) && (ptr[0] == '{')) { + // User has provided format for date/time + ++ptr; + int count = 1; // Start by 1 in order to remove starting brace + std::stringstream ss; + for (; *ptr; ++ptr, ++count) { + if (*ptr == '}') { + ++count; // In order to remove ending brace + break; + } + ss << *ptr; + } + resultingFilename.erase(dateIndex + dateTimeFormatSpecifierStr.size(), count); + fmt = ss.str(); + } else { + fmt = std::string(base::consts::kDefaultDateTimeFormatInFilename); + } + base::SubsecondPrecision ssPrec(3); + std::string now = base::utils::DateTime::getDateTime(fmt.c_str(), &ssPrec); + base::utils::Str::replaceAll(now, '/', '-'); // Replace path element since we are dealing with filename + base::utils::Str::replaceAll(resultingFilename, dateTimeFormatSpecifierStr, now); + } + } + return resultingFilename; +} + +void +TypedConfigurations::insertFile(Level level, const std::string& fullFilename) { + std::string resolvedFilename = resolveFilename(fullFilename); + if (resolvedFilename.empty()) { + std::cerr << "Could not load empty file for logging, please re-check your configurations for level [" + << LevelHelper::convertToString(level) << "]"; + } + std::string filePath = + base::utils::File::extractPathFromFilename(resolvedFilename, base::consts::kFilePathSeperator); + if (filePath.size() < resolvedFilename.size()) { + base::utils::File::createPath(filePath); + } + auto create = [&](Level level) { + base::LogStreamsReferenceMap::iterator filestreamIter = m_logStreamsReference->find(resolvedFilename); + base::type::fstream_t* fs = nullptr; + if (filestreamIter == m_logStreamsReference->end()) { + // We need a completely new stream, nothing to share with + fs = base::utils::File::newFileStream(resolvedFilename); + m_filenameMap.insert(std::make_pair(level, resolvedFilename)); + m_fileStreamMap.insert(std::make_pair(level, base::FileStreamPtr(fs))); + m_logStreamsReference->insert( + std::make_pair(resolvedFilename, base::FileStreamPtr(m_fileStreamMap.at(level)))); + } else { + // Woops! we have an existing one, share it! + m_filenameMap.insert(std::make_pair(level, filestreamIter->first)); + m_fileStreamMap.insert(std::make_pair(level, base::FileStreamPtr(filestreamIter->second))); + fs = filestreamIter->second.get(); + } + if (fs == nullptr) { + // We display bad file error from newFileStream() + ELPP_INTERNAL_ERROR("Setting [TO_FILE] of [" << LevelHelper::convertToString(level) << "] to FALSE", false); + setValue(level, false, &m_toFileMap); + } + }; + // If we dont have file conf for any level, create it for Level::Global first + // otherwise create for specified level + create(m_filenameMap.empty() && m_fileStreamMap.empty() ? Level::Global : level); +} + +bool +TypedConfigurations::unsafeValidateFileRolling(Level level, const PreRollOutCallback& preRollOutCallback) { + base::type::fstream_t* fs = unsafeGetConfigByRef(level, &m_fileStreamMap, "fileStream").get(); + if (fs == nullptr) { + return true; + } + std::size_t maxLogFileSize = unsafeGetConfigByVal(level, &m_maxLogFileSizeMap, "maxLogFileSize"); + std::size_t currFileSize = base::utils::File::getSizeOfFile(fs); + if (maxLogFileSize != 0 && currFileSize >= maxLogFileSize) { + std::string fname = unsafeGetConfigByRef(level, &m_filenameMap, "filename"); + ELPP_INTERNAL_INFO(1, "Truncating log file [" << fname << "] as a result of configurations for level [" + << LevelHelper::convertToString(level) << "]"); + fs->close(); + preRollOutCallback(fname.c_str(), currFileSize, level); + fs->open(fname, std::fstream::out | std::fstream::trunc); + return true; + } + return false; +} + +// RegisteredHitCounters + +bool +RegisteredHitCounters::validateEveryN(const char* filename, base::type::LineNumber lineNumber, std::size_t n) { + base::threading::ScopedLock scopedLock(lock()); + base::HitCounter* counter = get(filename, lineNumber); + if (counter == nullptr) { + registerNew(counter = new base::HitCounter(filename, lineNumber)); + } + counter->validateHitCounts(n); + bool result = (n >= 1 && counter->hitCounts() != 0 && counter->hitCounts() % n == 0); + return result; +} + +/// @brief Validates counter for hits >= N, i.e, registers new if does not exist otherwise updates original one +/// @return True if validation resulted in triggering hit. Meaning logs should be written everytime true is returned +bool +RegisteredHitCounters::validateAfterN(const char* filename, base::type::LineNumber lineNumber, std::size_t n) { + base::threading::ScopedLock scopedLock(lock()); + base::HitCounter* counter = get(filename, lineNumber); + if (counter == nullptr) { + registerNew(counter = new base::HitCounter(filename, lineNumber)); + } + // Do not use validateHitCounts here since we do not want to reset counter here + // Note the >= instead of > because we are incrementing + // after this check + if (counter->hitCounts() >= n) + return true; + counter->increment(); + return false; +} + +/// @brief Validates counter for hits are <= n, i.e, registers new if does not exist otherwise updates original one +/// @return True if validation resulted in triggering hit. Meaning logs should be written everytime true is returned +bool +RegisteredHitCounters::validateNTimes(const char* filename, base::type::LineNumber lineNumber, std::size_t n) { + base::threading::ScopedLock scopedLock(lock()); + base::HitCounter* counter = get(filename, lineNumber); + if (counter == nullptr) { + registerNew(counter = new base::HitCounter(filename, lineNumber)); + } + counter->increment(); + // Do not use validateHitCounts here since we do not want to reset counter here + if (counter->hitCounts() <= n) + return true; + return false; +} + +// RegisteredLoggers + +RegisteredLoggers::RegisteredLoggers(const LogBuilderPtr& defaultLogBuilder) : m_defaultLogBuilder(defaultLogBuilder) { + m_defaultConfigurations.setToDefault(); +} + +Logger* +RegisteredLoggers::get(const std::string& id, bool forceCreation) { + base::threading::ScopedLock scopedLock(lock()); + Logger* logger_ = base::utils::Registry::get(id); + if (logger_ == nullptr && forceCreation) { + bool validId = Logger::isValidId(id); + if (!validId) { + ELPP_ASSERT(validId, "Invalid logger ID [" << id << "]. Not registering this logger."); + return nullptr; + } + logger_ = new Logger(id, m_defaultConfigurations, &m_logStreamsReference); + logger_->m_logBuilder = m_defaultLogBuilder; + registerNew(id, logger_); + LoggerRegistrationCallback* callback = nullptr; + for (const std::pair& h : + m_loggerRegistrationCallbacks) { + callback = h.second.get(); + if (callback != nullptr && callback->enabled()) { + callback->handle(logger_); + } + } + } + return logger_; +} + +bool +RegisteredLoggers::remove(const std::string& id) { + if (id == base::consts::kDefaultLoggerId) { + return false; + } + // get has internal lock + Logger* logger = base::utils::Registry::get(id); + if (logger != nullptr) { + // unregister has internal lock + unregister(logger); + } + return true; +} + +void +RegisteredLoggers::unsafeFlushAll(void) { + ELPP_INTERNAL_INFO(1, "Flushing all log files"); + for (base::LogStreamsReferenceMap::iterator it = m_logStreamsReference.begin(); it != m_logStreamsReference.end(); + ++it) { + if (it->second.get() == nullptr) + continue; + it->second->flush(); + } +} + +// VRegistry + +VRegistry::VRegistry(base::type::VerboseLevel level, base::type::EnumType* pFlags) : m_level(level), m_pFlags(pFlags) { +} + +/// @brief Sets verbose level. Accepted range is 0-9 +void +VRegistry::setLevel(base::type::VerboseLevel level) { + base::threading::ScopedLock scopedLock(lock()); + if (level > 9) + m_level = base::consts::kMaxVerboseLevel; + else + m_level = level; +} + +void +VRegistry::setModules(const char* modules) { + base::threading::ScopedLock scopedLock(lock()); + auto addSuffix = [](std::stringstream& ss, const char* sfx, const char* prev) { + if (prev != nullptr && base::utils::Str::endsWith(ss.str(), std::string(prev))) { + std::string chr(ss.str().substr(0, ss.str().size() - strlen(prev))); + ss.str(std::string("")); + ss << chr; + } + if (base::utils::Str::endsWith(ss.str(), std::string(sfx))) { + std::string chr(ss.str().substr(0, ss.str().size() - strlen(sfx))); + ss.str(std::string("")); + ss << chr; + } + ss << sfx; + }; + auto insert = [&](std::stringstream& ss, base::type::VerboseLevel level) { + if (!base::utils::hasFlag(LoggingFlag::DisableVModulesExtensions, *m_pFlags)) { + addSuffix(ss, ".h", nullptr); + m_modules.insert(std::make_pair(ss.str(), level)); + addSuffix(ss, ".c", ".h"); + m_modules.insert(std::make_pair(ss.str(), level)); + addSuffix(ss, ".cpp", ".c"); + m_modules.insert(std::make_pair(ss.str(), level)); + addSuffix(ss, ".cc", ".cpp"); + m_modules.insert(std::make_pair(ss.str(), level)); + addSuffix(ss, ".cxx", ".cc"); + m_modules.insert(std::make_pair(ss.str(), level)); + addSuffix(ss, ".-inl.h", ".cxx"); + m_modules.insert(std::make_pair(ss.str(), level)); + addSuffix(ss, ".hxx", ".-inl.h"); + m_modules.insert(std::make_pair(ss.str(), level)); + addSuffix(ss, ".hpp", ".hxx"); + m_modules.insert(std::make_pair(ss.str(), level)); + addSuffix(ss, ".hh", ".hpp"); + } + m_modules.insert(std::make_pair(ss.str(), level)); + }; + bool isMod = true; + bool isLevel = false; + std::stringstream ss; + int level = -1; + for (; *modules; ++modules) { + switch (*modules) { + case '=': + isLevel = true; + isMod = false; + break; + case ',': + isLevel = false; + isMod = true; + if (!ss.str().empty() && level != -1) { + insert(ss, static_cast(level)); + ss.str(std::string("")); + level = -1; + } + break; + default: + if (isMod) { + ss << *modules; + } else if (isLevel) { + if (isdigit(*modules)) { + level = static_cast(*modules) - 48; + } + } + break; + } + } + if (!ss.str().empty() && level != -1) { + insert(ss, static_cast(level)); + } +} + +bool +VRegistry::allowed(base::type::VerboseLevel vlevel, const char* file) { + base::threading::ScopedLock scopedLock(lock()); + if (m_modules.empty() || file == nullptr) { + return vlevel <= m_level; + } else { + char baseFilename[base::consts::kSourceFilenameMaxLength] = ""; + base::utils::File::buildBaseFilename(file, baseFilename); + std::unordered_map::iterator it = m_modules.begin(); + for (; it != m_modules.end(); ++it) { + if (base::utils::Str::wildCardMatch(baseFilename, it->first.c_str())) { + return vlevel <= it->second; + } + } + if (base::utils::hasFlag(LoggingFlag::AllowVerboseIfModuleNotSpecified, *m_pFlags)) { + return true; + } + return false; + } +} + +void +VRegistry::setFromArgs(const base::utils::CommandLineArgs* commandLineArgs) { + if (commandLineArgs->hasParam("-v") || commandLineArgs->hasParam("--verbose") || commandLineArgs->hasParam("-V") || + commandLineArgs->hasParam("--VERBOSE")) { + setLevel(base::consts::kMaxVerboseLevel); + } else if (commandLineArgs->hasParamWithValue("--v")) { + setLevel(static_cast(atoi(commandLineArgs->getParamValue("--v")))); + } else if (commandLineArgs->hasParamWithValue("--V")) { + setLevel(static_cast(atoi(commandLineArgs->getParamValue("--V")))); + } else if ((commandLineArgs->hasParamWithValue("-vmodule")) && vModulesEnabled()) { + setModules(commandLineArgs->getParamValue("-vmodule")); + } else if (commandLineArgs->hasParamWithValue("-VMODULE") && vModulesEnabled()) { + setModules(commandLineArgs->getParamValue("-VMODULE")); + } +} + +#if !defined(ELPP_DEFAULT_LOGGING_FLAGS) +#define ELPP_DEFAULT_LOGGING_FLAGS 0x0 +#endif // !defined(ELPP_DEFAULT_LOGGING_FLAGS) +// Storage +#if ELPP_ASYNC_LOGGING +Storage::Storage(const LogBuilderPtr& defaultLogBuilder, base::IWorker* asyncDispatchWorker) + : +#else +Storage::Storage(const LogBuilderPtr& defaultLogBuilder) + : +#endif // ELPP_ASYNC_LOGGING + m_registeredHitCounters(new base::RegisteredHitCounters()), + m_registeredLoggers(new base::RegisteredLoggers(defaultLogBuilder)), + m_flags(ELPP_DEFAULT_LOGGING_FLAGS), + m_vRegistry(new base::VRegistry(0, &m_flags)), + +#if ELPP_ASYNC_LOGGING + m_asyncLogQueue(new base::AsyncLogQueue()), + m_asyncDispatchWorker(asyncDispatchWorker), +#endif // ELPP_ASYNC_LOGGING + + m_preRollOutCallback(base::defaultPreRollOutCallback) { + // Register default logger + m_registeredLoggers->get(std::string(base::consts::kDefaultLoggerId)); + // We register default logger anyway (worse case it's not going to register) just in case + m_registeredLoggers->get("default"); + +#if defined(ELPP_FEATURE_ALL) || defined(ELPP_FEATURE_PERFORMANCE_TRACKING) + // Register performance logger and reconfigure format + Logger* performanceLogger = m_registeredLoggers->get(std::string(base::consts::kPerformanceLoggerId)); + m_registeredLoggers->get("performance"); + performanceLogger->configurations()->setGlobally(ConfigurationType::Format, std::string("%datetime %level %msg")); + performanceLogger->reconfigure(); +#endif // defined(ELPP_FEATURE_ALL) || defined(ELPP_FEATURE_PERFORMANCE_TRACKING) + +#if defined(ELPP_SYSLOG) + // Register syslog logger and reconfigure format + Logger* sysLogLogger = m_registeredLoggers->get(std::string(base::consts::kSysLogLoggerId)); + sysLogLogger->configurations()->setGlobally(ConfigurationType::Format, std::string("%level: %msg")); + sysLogLogger->reconfigure(); +#endif // defined(ELPP_SYSLOG) + addFlag(LoggingFlag::AllowVerboseIfModuleNotSpecified); +#if ELPP_ASYNC_LOGGING + installLogDispatchCallback(std::string("AsyncLogDispatchCallback")); +#else + installLogDispatchCallback(std::string("DefaultLogDispatchCallback")); +#endif // ELPP_ASYNC_LOGGING +#if defined(ELPP_FEATURE_ALL) || defined(ELPP_FEATURE_PERFORMANCE_TRACKING) + installPerformanceTrackingCallback( + std::string("DefaultPerformanceTrackingCallback")); +#endif // defined(ELPP_FEATURE_ALL) || defined(ELPP_FEATURE_PERFORMANCE_TRACKING) + ELPP_INTERNAL_INFO(1, "Easylogging++ has been initialized"); +#if ELPP_ASYNC_LOGGING + m_asyncDispatchWorker->start(); +#endif // ELPP_ASYNC_LOGGING +} + +Storage::~Storage(void) { + ELPP_INTERNAL_INFO(4, "Destroying storage"); +#if ELPP_ASYNC_LOGGING + ELPP_INTERNAL_INFO(5, "Replacing log dispatch callback to synchronous"); + uninstallLogDispatchCallback(std::string("AsyncLogDispatchCallback")); + installLogDispatchCallback(std::string("DefaultLogDispatchCallback")); + ELPP_INTERNAL_INFO(5, "Destroying asyncDispatchWorker"); + base::utils::safeDelete(m_asyncDispatchWorker); + ELPP_INTERNAL_INFO(5, "Destroying asyncLogQueue"); + base::utils::safeDelete(m_asyncLogQueue); +#endif // ELPP_ASYNC_LOGGING + ELPP_INTERNAL_INFO(5, "Destroying registeredHitCounters"); + base::utils::safeDelete(m_registeredHitCounters); + ELPP_INTERNAL_INFO(5, "Destroying registeredLoggers"); + base::utils::safeDelete(m_registeredLoggers); + ELPP_INTERNAL_INFO(5, "Destroying vRegistry"); + base::utils::safeDelete(m_vRegistry); +} + +bool +Storage::hasCustomFormatSpecifier(const char* formatSpecifier) { + base::threading::ScopedLock scopedLock(customFormatSpecifiersLock()); + return std::find(m_customFormatSpecifiers.begin(), m_customFormatSpecifiers.end(), formatSpecifier) != + m_customFormatSpecifiers.end(); +} + +void +Storage::installCustomFormatSpecifier(const CustomFormatSpecifier& customFormatSpecifier) { + if (hasCustomFormatSpecifier(customFormatSpecifier.formatSpecifier())) { + return; + } + base::threading::ScopedLock scopedLock(customFormatSpecifiersLock()); + m_customFormatSpecifiers.push_back(customFormatSpecifier); +} + +bool +Storage::uninstallCustomFormatSpecifier(const char* formatSpecifier) { + base::threading::ScopedLock scopedLock(customFormatSpecifiersLock()); + std::vector::iterator it = + std::find(m_customFormatSpecifiers.begin(), m_customFormatSpecifiers.end(), formatSpecifier); + if (it != m_customFormatSpecifiers.end() && strcmp(formatSpecifier, it->formatSpecifier()) == 0) { + m_customFormatSpecifiers.erase(it); + return true; + } + return false; +} + +void +Storage::setApplicationArguments(int argc, char** argv) { + m_commandLineArgs.setArgs(argc, argv); + m_vRegistry->setFromArgs(commandLineArgs()); + // default log file +#if !defined(ELPP_DISABLE_LOG_FILE_FROM_ARG) + if (m_commandLineArgs.hasParamWithValue(base::consts::kDefaultLogFileParam)) { + Configurations c; + c.setGlobally(ConfigurationType::Filename, + std::string(m_commandLineArgs.getParamValue(base::consts::kDefaultLogFileParam))); + registeredLoggers()->setDefaultConfigurations(c); + for (base::RegisteredLoggers::iterator it = registeredLoggers()->begin(); it != registeredLoggers()->end(); + ++it) { + it->second->configure(c); + } + } +#endif // !defined(ELPP_DISABLE_LOG_FILE_FROM_ARG) +#if defined(ELPP_LOGGING_FLAGS_FROM_ARG) + if (m_commandLineArgs.hasParamWithValue(base::consts::kLoggingFlagsParam)) { + int userInput = atoi(m_commandLineArgs.getParamValue(base::consts::kLoggingFlagsParam)); + if (ELPP_DEFAULT_LOGGING_FLAGS == 0x0) { + m_flags = userInput; + } else { + base::utils::addFlag(userInput, &m_flags); + } + } +#endif // defined(ELPP_LOGGING_FLAGS_FROM_ARG) +} + +} // namespace base + +// LogDispatchCallback +void +LogDispatchCallback::handle(const LogDispatchData* data) { +#if defined(ELPP_THREAD_SAFE) + base::threading::ScopedLock scopedLock(m_fileLocksMapLock); + std::string filename = data->logMessage()->logger()->typedConfigurations()->filename(data->logMessage()->level()); + auto lock = m_fileLocks.find(filename); + if (lock == m_fileLocks.end()) { + m_fileLocks.emplace( + std::make_pair(filename, std::unique_ptr(new base::threading::Mutex))); + } +#endif +} + +base::threading::Mutex& +LogDispatchCallback::fileHandle(const LogDispatchData* data) { + auto it = + m_fileLocks.find(data->logMessage()->logger()->typedConfigurations()->filename(data->logMessage()->level())); + return *(it->second.get()); +} + +namespace base { +// DefaultLogDispatchCallback + +void +DefaultLogDispatchCallback::handle(const LogDispatchData* data) { +#if defined(ELPP_THREAD_SAFE) + LogDispatchCallback::handle(data); + base::threading::ScopedLock scopedLock(fileHandle(data)); +#endif + m_data = data; + dispatch(m_data->logMessage()->logger()->logBuilder()->build( + m_data->logMessage(), m_data->dispatchAction() == base::DispatchAction::NormalLog)); +} + +void +DefaultLogDispatchCallback::dispatch(base::type::string_t&& logLine) { + if (m_data->dispatchAction() == base::DispatchAction::NormalLog) { + if (m_data->logMessage()->logger()->m_typedConfigurations->toFile(m_data->logMessage()->level())) { + base::type::fstream_t* fs = + m_data->logMessage()->logger()->m_typedConfigurations->fileStream(m_data->logMessage()->level()); + if (fs != nullptr) { + fs->write(logLine.c_str(), logLine.size()); + if (fs->fail()) { + ELPP_INTERNAL_ERROR("Unable to write log to file [" + << m_data->logMessage()->logger()->m_typedConfigurations->filename( + m_data->logMessage()->level()) + << "].\n" + << "Few possible reasons (could be something else):\n" + << " * Permission denied\n" + << " * Disk full\n" + << " * Disk is not writable", + true); + } else { + if (ELPP->hasFlag(LoggingFlag::ImmediateFlush) || + (m_data->logMessage()->logger()->isFlushNeeded(m_data->logMessage()->level()))) { + m_data->logMessage()->logger()->flush(m_data->logMessage()->level(), fs); + } + } + } else { + ELPP_INTERNAL_ERROR("Log file for [" + << LevelHelper::convertToString(m_data->logMessage()->level()) << "] " + << "has not been configured but [TO_FILE] is configured to TRUE. [Logger ID: " + << m_data->logMessage()->logger()->id() << "]", + false); + } + } + if (m_data->logMessage()->logger()->m_typedConfigurations->toStandardOutput(m_data->logMessage()->level())) { + if (ELPP->hasFlag(LoggingFlag::ColoredTerminalOutput)) + m_data->logMessage()->logger()->logBuilder()->convertToColoredOutput(&logLine, + m_data->logMessage()->level()); + ELPP_COUT << ELPP_COUT_LINE(logLine); + } + } +#if defined(ELPP_SYSLOG) + else if (m_data->dispatchAction() == base::DispatchAction::SysLog) { + // Determine syslog priority + int sysLogPriority = 0; + if (m_data->logMessage()->level() == Level::Fatal) + sysLogPriority = LOG_EMERG; + else if (m_data->logMessage()->level() == Level::Error) + sysLogPriority = LOG_ERR; + else if (m_data->logMessage()->level() == Level::Warning) + sysLogPriority = LOG_WARNING; + else if (m_data->logMessage()->level() == Level::Info) + sysLogPriority = LOG_INFO; + else if (m_data->logMessage()->level() == Level::Debug) + sysLogPriority = LOG_DEBUG; + else + sysLogPriority = LOG_NOTICE; +#if defined(ELPP_UNICODE) + char* line = base::utils::Str::wcharPtrToCharPtr(logLine.c_str()); + syslog(sysLogPriority, "%s", line); + free(line); +#else + syslog(sysLogPriority, "%s", logLine.c_str()); +#endif + } +#endif // defined(ELPP_SYSLOG) +} + +#if ELPP_ASYNC_LOGGING + +// AsyncLogDispatchCallback + +void +AsyncLogDispatchCallback::handle(const LogDispatchData* data) { + base::type::string_t logLine = data->logMessage()->logger()->logBuilder()->build( + data->logMessage(), data->dispatchAction() == base::DispatchAction::NormalLog); + if (data->dispatchAction() == base::DispatchAction::NormalLog && + data->logMessage()->logger()->typedConfigurations()->toStandardOutput(data->logMessage()->level())) { + if (ELPP->hasFlag(LoggingFlag::ColoredTerminalOutput)) + data->logMessage()->logger()->logBuilder()->convertToColoredOutput(&logLine, data->logMessage()->level()); + ELPP_COUT << ELPP_COUT_LINE(logLine); + } + // Save resources and only queue if we want to write to file otherwise just ignore handler + if (data->logMessage()->logger()->typedConfigurations()->toFile(data->logMessage()->level())) { + ELPP->asyncLogQueue()->push(AsyncLogItem(*(data->logMessage()), *data, logLine)); + } +} + +// AsyncDispatchWorker +AsyncDispatchWorker::AsyncDispatchWorker() { + setContinueRunning(false); +} + +AsyncDispatchWorker::~AsyncDispatchWorker() { + setContinueRunning(false); + ELPP_INTERNAL_INFO(6, "Stopping dispatch worker - Cleaning log queue"); + clean(); + ELPP_INTERNAL_INFO(6, "Log queue cleaned"); +} + +bool +AsyncDispatchWorker::clean(void) { + std::mutex m; + std::unique_lock lk(m); + cv.wait(lk, [] { return !ELPP->asyncLogQueue()->empty(); }); + emptyQueue(); + lk.unlock(); + cv.notify_one(); + return ELPP->asyncLogQueue()->empty(); +} + +void +AsyncDispatchWorker::emptyQueue(void) { + while (!ELPP->asyncLogQueue()->empty()) { + AsyncLogItem data = ELPP->asyncLogQueue()->next(); + handle(&data); + base::threading::msleep(100); + } +} + +void +AsyncDispatchWorker::start(void) { + base::threading::msleep(5000); // 5s (why?) + setContinueRunning(true); + std::thread t1(&AsyncDispatchWorker::run, this); + t1.join(); +} + +void +AsyncDispatchWorker::handle(AsyncLogItem* logItem) { + LogDispatchData* data = logItem->data(); + LogMessage* logMessage = logItem->logMessage(); + Logger* logger = logMessage->logger(); + base::TypedConfigurations* conf = logger->typedConfigurations(); + base::type::string_t logLine = logItem->logLine(); + if (data->dispatchAction() == base::DispatchAction::NormalLog) { + if (conf->toFile(logMessage->level())) { + base::type::fstream_t* fs = conf->fileStream(logMessage->level()); + if (fs != nullptr) { + fs->write(logLine.c_str(), logLine.size()); + if (fs->fail()) { + ELPP_INTERNAL_ERROR("Unable to write log to file [" + << conf->filename(logMessage->level()) << "].\n" + << "Few possible reasons (could be something else):\n" + << " * Permission denied\n" + << " * Disk full\n" + << " * Disk is not writable", + true); + } else { + if (ELPP->hasFlag(LoggingFlag::ImmediateFlush) || (logger->isFlushNeeded(logMessage->level()))) { + logger->flush(logMessage->level(), fs); + } + } + } else { + ELPP_INTERNAL_ERROR("Log file for [" + << LevelHelper::convertToString(logMessage->level()) << "] " + << "has not been configured but [TO_FILE] is configured to TRUE. [Logger ID: " + << logger->id() << "]", + false); + } + } + } +#if defined(ELPP_SYSLOG) + else if (data->dispatchAction() == base::DispatchAction::SysLog) { + // Determine syslog priority + int sysLogPriority = 0; + if (logMessage->level() == Level::Fatal) + sysLogPriority = LOG_EMERG; + else if (logMessage->level() == Level::Error) + sysLogPriority = LOG_ERR; + else if (logMessage->level() == Level::Warning) + sysLogPriority = LOG_WARNING; + else if (logMessage->level() == Level::Info) + sysLogPriority = LOG_INFO; + else if (logMessage->level() == Level::Debug) + sysLogPriority = LOG_DEBUG; + else + sysLogPriority = LOG_NOTICE; +#if defined(ELPP_UNICODE) + char* line = base::utils::Str::wcharPtrToCharPtr(logLine.c_str()); + syslog(sysLogPriority, "%s", line); + free(line); +#else + syslog(sysLogPriority, "%s", logLine.c_str()); +#endif + } +#endif // defined(ELPP_SYSLOG) +} + +void +AsyncDispatchWorker::run(void) { + while (continueRunning()) { + emptyQueue(); + base::threading::msleep(10); // 10ms + } +} +#endif // ELPP_ASYNC_LOGGING + +// DefaultLogBuilder + +base::type::string_t +DefaultLogBuilder::build(const LogMessage* logMessage, bool appendNewLine) const { + base::TypedConfigurations* tc = logMessage->logger()->typedConfigurations(); + const base::LogFormat* logFormat = &tc->logFormat(logMessage->level()); + base::type::string_t logLine = logFormat->format(); + char buff[base::consts::kSourceFilenameMaxLength + base::consts::kSourceLineMaxLength] = ""; + const char* bufLim = buff + sizeof(buff); + if (logFormat->hasFlag(base::FormatFlags::AppName)) { + // App name + base::utils::Str::replaceFirstWithEscape(logLine, base::consts::kAppNameFormatSpecifier, + logMessage->logger()->parentApplicationName()); + } + if (logFormat->hasFlag(base::FormatFlags::ThreadId)) { + // Thread ID + base::utils::Str::replaceFirstWithEscape(logLine, base::consts::kThreadIdFormatSpecifier, + ELPP->getThreadName(base::threading::getCurrentThreadId())); + } + if (logFormat->hasFlag(base::FormatFlags::DateTime)) { + // DateTime + base::utils::Str::replaceFirstWithEscape( + logLine, base::consts::kDateTimeFormatSpecifier, + base::utils::DateTime::getDateTime(logFormat->dateTimeFormat().c_str(), + &tc->subsecondPrecision(logMessage->level()))); + } + if (logFormat->hasFlag(base::FormatFlags::Function)) { + // Function + base::utils::Str::replaceFirstWithEscape(logLine, base::consts::kLogFunctionFormatSpecifier, + logMessage->func()); + } + if (logFormat->hasFlag(base::FormatFlags::File)) { + // File + base::utils::Str::clearBuff(buff, base::consts::kSourceFilenameMaxLength); + base::utils::File::buildStrippedFilename(logMessage->file().c_str(), buff); + base::utils::Str::replaceFirstWithEscape(logLine, base::consts::kLogFileFormatSpecifier, std::string(buff)); + } + if (logFormat->hasFlag(base::FormatFlags::FileBase)) { + // FileBase + base::utils::Str::clearBuff(buff, base::consts::kSourceFilenameMaxLength); + base::utils::File::buildBaseFilename(logMessage->file(), buff); + base::utils::Str::replaceFirstWithEscape(logLine, base::consts::kLogFileBaseFormatSpecifier, std::string(buff)); + } + if (logFormat->hasFlag(base::FormatFlags::Line)) { + // Line + char* buf = base::utils::Str::clearBuff(buff, base::consts::kSourceLineMaxLength); + buf = base::utils::Str::convertAndAddToBuff(logMessage->line(), base::consts::kSourceLineMaxLength, buf, bufLim, + false); + base::utils::Str::replaceFirstWithEscape(logLine, base::consts::kLogLineFormatSpecifier, std::string(buff)); + } + if (logFormat->hasFlag(base::FormatFlags::Location)) { + // Location + char* buf = base::utils::Str::clearBuff( + buff, base::consts::kSourceFilenameMaxLength + base::consts::kSourceLineMaxLength); + base::utils::File::buildStrippedFilename(logMessage->file().c_str(), buff); + buf = base::utils::Str::addToBuff(buff, buf, bufLim); + buf = base::utils::Str::addToBuff(":", buf, bufLim); + buf = base::utils::Str::convertAndAddToBuff(logMessage->line(), base::consts::kSourceLineMaxLength, buf, bufLim, + false); + base::utils::Str::replaceFirstWithEscape(logLine, base::consts::kLogLocationFormatSpecifier, std::string(buff)); + } + if (logMessage->level() == Level::Verbose && logFormat->hasFlag(base::FormatFlags::VerboseLevel)) { + // Verbose level + char* buf = base::utils::Str::clearBuff(buff, 1); + buf = base::utils::Str::convertAndAddToBuff(logMessage->verboseLevel(), 1, buf, bufLim, false); + base::utils::Str::replaceFirstWithEscape(logLine, base::consts::kVerboseLevelFormatSpecifier, + std::string(buff)); + } + if (logFormat->hasFlag(base::FormatFlags::LogMessage)) { + // Log message + base::utils::Str::replaceFirstWithEscape(logLine, base::consts::kMessageFormatSpecifier, logMessage->message()); + } +#if !defined(ELPP_DISABLE_CUSTOM_FORMAT_SPECIFIERS) + el::base::threading::ScopedLock lock_(ELPP->customFormatSpecifiersLock()); + ELPP_UNUSED(lock_); + for (std::vector::const_iterator it = ELPP->customFormatSpecifiers()->begin(); + it != ELPP->customFormatSpecifiers()->end(); ++it) { + std::string fs(it->formatSpecifier()); + base::type::string_t wcsFormatSpecifier(fs.begin(), fs.end()); + base::utils::Str::replaceFirstWithEscape(logLine, wcsFormatSpecifier, it->resolver()(logMessage)); + } +#endif // !defined(ELPP_DISABLE_CUSTOM_FORMAT_SPECIFIERS) + if (appendNewLine) + logLine += ELPP_LITERAL("\n"); + return logLine; +} + +// LogDispatcher + +void +LogDispatcher::dispatch(void) { + if (m_proceed && m_dispatchAction == base::DispatchAction::None) { + m_proceed = false; + } + if (!m_proceed) { + return; + } +#ifndef ELPP_NO_GLOBAL_LOCK + // see https://github.com/muflihun/easyloggingpp/issues/580 + // global lock is turned off by default unless + // ELPP_NO_GLOBAL_LOCK is defined + base::threading::ScopedLock scopedLock(ELPP->lock()); +#endif + base::TypedConfigurations* tc = m_logMessage->logger()->m_typedConfigurations; + if (ELPP->hasFlag(LoggingFlag::StrictLogFileSizeCheck)) { + tc->validateFileRolling(m_logMessage->level(), ELPP->preRollOutCallback()); + } + LogDispatchCallback* callback = nullptr; + LogDispatchData data; + for (const std::pair& h : ELPP->m_logDispatchCallbacks) { + callback = h.second.get(); + if (callback != nullptr && callback->enabled()) { + data.setLogMessage(m_logMessage); + data.setDispatchAction(m_dispatchAction); + callback->handle(&data); + } + } +} + +// MessageBuilder + +void +MessageBuilder::initialize(Logger* logger) { + m_logger = logger; + m_containerLogSeperator = + ELPP->hasFlag(LoggingFlag::NewLineForContainer) ? ELPP_LITERAL("\n ") : ELPP_LITERAL(", "); +} + +MessageBuilder& +MessageBuilder::operator<<(const wchar_t* msg) { + if (msg == nullptr) { + m_logger->stream() << base::consts::kNullPointer; + return *this; + } +#if defined(ELPP_UNICODE) + m_logger->stream() << msg; +#else + char* buff_ = base::utils::Str::wcharPtrToCharPtr(msg); + m_logger->stream() << buff_; + free(buff_); +#endif + if (ELPP->hasFlag(LoggingFlag::AutoSpacing)) { + m_logger->stream() << " "; + } + return *this; +} + +// Writer + +Writer& +Writer::construct(Logger* logger, bool needLock) { + m_logger = logger; + initializeLogger(logger->id(), false, needLock); + m_messageBuilder.initialize(m_logger); + return *this; +} + +Writer& +Writer::construct(int count, const char* loggerIds, ...) { + if (ELPP->hasFlag(LoggingFlag::MultiLoggerSupport)) { + va_list loggersList; + va_start(loggersList, loggerIds); + const char* id = loggerIds; + m_loggerIds.reserve(count); + for (int i = 0; i < count; ++i) { + m_loggerIds.push_back(std::string(id)); + id = va_arg(loggersList, const char*); + } + va_end(loggersList); + initializeLogger(m_loggerIds.at(0)); + } else { + initializeLogger(std::string(loggerIds)); + } + m_messageBuilder.initialize(m_logger); + return *this; +} + +void +Writer::initializeLogger(const std::string& loggerId, bool lookup, bool needLock) { + if (lookup) { + m_logger = ELPP->registeredLoggers()->get(loggerId, ELPP->hasFlag(LoggingFlag::CreateLoggerAutomatically)); + } + if (m_logger == nullptr) { + { + if (!ELPP->registeredLoggers()->has(std::string(base::consts::kDefaultLoggerId))) { + // Somehow default logger has been unregistered. Not good! Register again + ELPP->registeredLoggers()->get(std::string(base::consts::kDefaultLoggerId)); + } + } + Writer(Level::Debug, m_file, m_line, m_func).construct(1, base::consts::kDefaultLoggerId) + << "Logger [" << loggerId << "] is not registered yet!"; + m_proceed = false; + } else { + if (needLock) { + m_logger->acquireLock(); // This should not be unlocked by checking m_proceed because + // m_proceed can be changed by lines below + } + if (ELPP->hasFlag(LoggingFlag::HierarchicalLogging)) { + m_proceed = m_level == Level::Verbose + ? m_logger->enabled(m_level) + : LevelHelper::castToInt(m_level) >= LevelHelper::castToInt(ELPP->m_loggingLevel); + } else { + m_proceed = m_logger->enabled(m_level); + } + } +} + +void +Writer::processDispatch() { +#if ELPP_LOGGING_ENABLED + if (ELPP->hasFlag(LoggingFlag::MultiLoggerSupport)) { + bool firstDispatched = false; + base::type::string_t logMessage; + std::size_t i = 0; + do { + if (m_proceed) { + if (firstDispatched) { + m_logger->stream() << logMessage; + } else { + firstDispatched = true; + if (m_loggerIds.size() > 1) { + logMessage = m_logger->stream().str(); + } + } + triggerDispatch(); + } else if (m_logger != nullptr) { + m_logger->stream().str(ELPP_LITERAL("")); + m_logger->releaseLock(); + } + if (i + 1 < m_loggerIds.size()) { + initializeLogger(m_loggerIds.at(i + 1)); + } + } while (++i < m_loggerIds.size()); + } else { + if (m_proceed) { + triggerDispatch(); + } else if (m_logger != nullptr) { + m_logger->stream().str(ELPP_LITERAL("")); + m_logger->releaseLock(); + } + } +#else + if (m_logger != nullptr) { + m_logger->stream().str(ELPP_LITERAL("")); + m_logger->releaseLock(); + } +#endif // ELPP_LOGGING_ENABLED +} + +void +Writer::triggerDispatch(void) { + if (m_proceed) { + if (m_msg == nullptr) { + LogMessage msg(m_level, m_file, m_line, m_func, m_verboseLevel, m_logger); + base::LogDispatcher(m_proceed, &msg, m_dispatchAction).dispatch(); + } else { + base::LogDispatcher(m_proceed, m_msg, m_dispatchAction).dispatch(); + } + } + if (m_logger != nullptr) { + m_logger->stream().str(ELPP_LITERAL("")); + m_logger->releaseLock(); + } + if (m_proceed && m_level == Level::Fatal && !ELPP->hasFlag(LoggingFlag::DisableApplicationAbortOnFatalLog)) { + base::Writer(Level::Warning, m_file, m_line, m_func).construct(1, base::consts::kDefaultLoggerId) + << "Aborting application. Reason: Fatal log at [" << m_file << ":" << m_line << "]"; + std::stringstream reasonStream; + reasonStream << "Fatal log at [" << m_file << ":" << m_line << "]" + << " If you wish to disable 'abort on fatal log' please use " + << "el::Loggers::addFlag(el::LoggingFlag::DisableApplicationAbortOnFatalLog)"; + base::utils::abort(1, reasonStream.str()); + } + m_proceed = false; +} + +// PErrorWriter + +PErrorWriter::~PErrorWriter(void) { + if (m_proceed) { +#if ELPP_COMPILER_MSVC + char buff[256]; + strerror_s(buff, 256, errno); + m_logger->stream() << ": " << buff << " [" << errno << "]"; +#else + m_logger->stream() << ": " << strerror(errno) << " [" << errno << "]"; +#endif + } +} + +// PerformanceTracker + +#if defined(ELPP_FEATURE_ALL) || defined(ELPP_FEATURE_PERFORMANCE_TRACKING) + +PerformanceTracker::PerformanceTracker(const std::string& blockName, base::TimestampUnit timestampUnit, + const std::string& loggerId, bool scopedLog, Level level) + : m_blockName(blockName), + m_timestampUnit(timestampUnit), + m_loggerId(loggerId), + m_scopedLog(scopedLog), + m_level(level), + m_hasChecked(false), + m_lastCheckpointId(std::string()), + m_enabled(false) { +#if !defined(ELPP_DISABLE_PERFORMANCE_TRACKING) && ELPP_LOGGING_ENABLED + // We store it locally so that if user happen to change configuration by the end of scope + // or before calling checkpoint, we still depend on state of configuraton at time of construction + el::Logger* loggerPtr = ELPP->registeredLoggers()->get(loggerId, false); + m_enabled = loggerPtr != nullptr && loggerPtr->m_typedConfigurations->performanceTracking(m_level); + if (m_enabled) { + base::utils::DateTime::gettimeofday(&m_startTime); + } +#endif // !defined(ELPP_DISABLE_PERFORMANCE_TRACKING) && ELPP_LOGGING_ENABLED +} + +PerformanceTracker::~PerformanceTracker(void) { +#if !defined(ELPP_DISABLE_PERFORMANCE_TRACKING) && ELPP_LOGGING_ENABLED + if (m_enabled) { + base::threading::ScopedLock scopedLock(lock()); + if (m_scopedLog) { + base::utils::DateTime::gettimeofday(&m_endTime); + base::type::string_t formattedTime = getFormattedTimeTaken(); + PerformanceTrackingData data(PerformanceTrackingData::DataType::Complete); + data.init(this); + data.m_formattedTimeTaken = formattedTime; + PerformanceTrackingCallback* callback = nullptr; + for (const std::pair& h : + ELPP->m_performanceTrackingCallbacks) { + callback = h.second.get(); + if (callback != nullptr && callback->enabled()) { + callback->handle(&data); + } + } + } + } +#endif // !defined(ELPP_DISABLE_PERFORMANCE_TRACKING) +} + +void +PerformanceTracker::checkpoint(const std::string& id, const char* file, base::type::LineNumber line, const char* func) { +#if !defined(ELPP_DISABLE_PERFORMANCE_TRACKING) && ELPP_LOGGING_ENABLED + if (m_enabled) { + base::threading::ScopedLock scopedLock(lock()); + base::utils::DateTime::gettimeofday(&m_endTime); + base::type::string_t formattedTime = + m_hasChecked ? getFormattedTimeTaken(m_lastCheckpointTime) : ELPP_LITERAL(""); + PerformanceTrackingData data(PerformanceTrackingData::DataType::Checkpoint); + data.init(this); + data.m_checkpointId = id; + data.m_file = file; + data.m_line = line; + data.m_func = func; + data.m_formattedTimeTaken = formattedTime; + PerformanceTrackingCallback* callback = nullptr; + for (const std::pair& h : + ELPP->m_performanceTrackingCallbacks) { + callback = h.second.get(); + if (callback != nullptr && callback->enabled()) { + callback->handle(&data); + } + } + base::utils::DateTime::gettimeofday(&m_lastCheckpointTime); + m_hasChecked = true; + m_lastCheckpointId = id; + } +#endif // !defined(ELPP_DISABLE_PERFORMANCE_TRACKING) && ELPP_LOGGING_ENABLED + ELPP_UNUSED(id); + ELPP_UNUSED(file); + ELPP_UNUSED(line); + ELPP_UNUSED(func); +} + +const base::type::string_t +PerformanceTracker::getFormattedTimeTaken(struct timeval startTime) const { + if (ELPP->hasFlag(LoggingFlag::FixedTimeFormat)) { + base::type::stringstream_t ss; + ss << base::utils::DateTime::getTimeDifference(m_endTime, startTime, m_timestampUnit) << " " + << base::consts::kTimeFormats[static_cast(m_timestampUnit)].unit; + return ss.str(); + } + return base::utils::DateTime::formatTime( + base::utils::DateTime::getTimeDifference(m_endTime, startTime, m_timestampUnit), m_timestampUnit); +} + +#endif // defined(ELPP_FEATURE_ALL) || defined(ELPP_FEATURE_PERFORMANCE_TRACKING) + +namespace debug { +#if defined(ELPP_FEATURE_ALL) || defined(ELPP_FEATURE_CRASH_LOG) + +// StackTrace + +StackTrace::StackTraceEntry::StackTraceEntry(std::size_t index, const std::string& loc, const std::string& demang, + const std::string& hex, const std::string& addr) + : m_index(index), m_location(loc), m_demangled(demang), m_hex(hex), m_addr(addr) { +} + +std::ostream& +operator<<(std::ostream& ss, const StackTrace::StackTraceEntry& si) { + ss << "[" << si.m_index << "] " << si.m_location << (si.m_hex.empty() ? "" : "+") << si.m_hex << " " << si.m_addr + << (si.m_demangled.empty() ? "" : ":") << si.m_demangled; + return ss; +} + +std::ostream& +operator<<(std::ostream& os, const StackTrace& st) { + std::vector::const_iterator it = st.m_stack.begin(); + while (it != st.m_stack.end()) { + os << " " << *it++ << "\n"; + } + return os; +} + +void +StackTrace::generateNew(void) { +#if ELPP_STACKTRACE + m_stack.clear(); + void* stack[kMaxStack]; + unsigned int size = backtrace(stack, kMaxStack); + char** strings = backtrace_symbols(stack, size); + if (size > kStackStart) { // Skip StackTrace c'tor and generateNew + for (std::size_t i = kStackStart; i < size; ++i) { + std::string mangName; + std::string location; + std::string hex; + std::string addr; + + // entry: 2 crash.cpp.bin 0x0000000101552be5 _ZN2el4base5debug10StackTraceC1Ev + 21 + const std::string line(strings[i]); + auto p = line.find("_"); + if (p != std::string::npos) { + mangName = line.substr(p); + mangName = mangName.substr(0, mangName.find(" +")); + } + p = line.find("0x"); + if (p != std::string::npos) { + addr = line.substr(p); + addr = addr.substr(0, addr.find("_")); + } + // Perform demangling if parsed properly + if (!mangName.empty()) { + int status = 0; + char* demangName = abi::__cxa_demangle(mangName.data(), 0, 0, &status); + // if demangling is successful, output the demangled function name + if (status == 0) { + // Success (see http://gcc.gnu.org/onlinedocs/libstdc++/libstdc++-html-USERS-4.3/a01696.html) + StackTraceEntry entry(i - 1, location, demangName, hex, addr); + m_stack.push_back(entry); + } else { + // Not successful - we will use mangled name + StackTraceEntry entry(i - 1, location, mangName, hex, addr); + m_stack.push_back(entry); + } + free(demangName); + } else { + StackTraceEntry entry(i - 1, line); + m_stack.push_back(entry); + } + } + } + free(strings); +#else + ELPP_INTERNAL_INFO(1, "Stacktrace generation not supported for selected compiler"); +#endif // ELPP_STACKTRACE +} + +// Static helper functions + +static std::string +crashReason(int sig) { + std::stringstream ss; + bool foundReason = false; + for (int i = 0; i < base::consts::kCrashSignalsCount; ++i) { + if (base::consts::kCrashSignals[i].numb == sig) { + ss << "Application has crashed due to [" << base::consts::kCrashSignals[i].name << "] signal"; + if (ELPP->hasFlag(el::LoggingFlag::LogDetailedCrashReason)) { + ss << std::endl + << " " << base::consts::kCrashSignals[i].brief << std::endl + << " " << base::consts::kCrashSignals[i].detail; + } + foundReason = true; + } + } + if (!foundReason) { + ss << "Application has crashed due to unknown signal [" << sig << "]"; + } + return ss.str(); +} +/// @brief Logs reason of crash from sig +static void +logCrashReason(int sig, bool stackTraceIfAvailable, Level level, const char* logger) { + if (sig == SIGINT && ELPP->hasFlag(el::LoggingFlag::IgnoreSigInt)) { + return; + } + std::stringstream ss; + ss << "CRASH HANDLED; "; + ss << crashReason(sig); +#if ELPP_STACKTRACE + if (stackTraceIfAvailable) { + ss << std::endl << " ======= Backtrace: =========" << std::endl << base::debug::StackTrace(); + } +#else + ELPP_UNUSED(stackTraceIfAvailable); +#endif // ELPP_STACKTRACE + ELPP_WRITE_LOG(el::base::Writer, level, base::DispatchAction::NormalLog, logger) << ss.str(); +} + +static inline void +crashAbort(int sig) { + base::utils::abort(sig, std::string()); +} + +/// @brief Default application crash handler +/// +/// @detail This function writes log using 'default' logger, prints stack trace for GCC based compilers and aborts +/// program. +static inline void +defaultCrashHandler(int sig) { + base::debug::logCrashReason(sig, true, Level::Fatal, base::consts::kDefaultLoggerId); + base::debug::crashAbort(sig); +} + +// CrashHandler + +CrashHandler::CrashHandler(bool useDefault) { + if (useDefault) { + setHandler(defaultCrashHandler); + } +} + +void +CrashHandler::setHandler(const Handler& cHandler) { + m_handler = cHandler; +#if defined(ELPP_HANDLE_SIGABRT) + int i = 0; // SIGABRT is at base::consts::kCrashSignals[0] +#else + int i = 1; +#endif // defined(ELPP_HANDLE_SIGABRT) + for (; i < base::consts::kCrashSignalsCount; ++i) { + m_handler = signal(base::consts::kCrashSignals[i].numb, cHandler); + } +} + +#endif // defined(ELPP_FEATURE_ALL) || defined(ELPP_FEATURE_CRASH_LOG) +} // namespace debug +} // namespace base + +// el + +// Helpers + +#if defined(ELPP_FEATURE_ALL) || defined(ELPP_FEATURE_CRASH_LOG) + +void +Helpers::crashAbort(int sig, const char* sourceFile, unsigned int long line) { + std::stringstream ss; + ss << base::debug::crashReason(sig).c_str(); + ss << " - [Called el::Helpers::crashAbort(" << sig << ")]"; + if (sourceFile != nullptr && strlen(sourceFile) > 0) { + ss << " - Source: " << sourceFile; + if (line > 0) + ss << ":" << line; + else + ss << " (line number not specified)"; + } + base::utils::abort(sig, ss.str()); +} + +void +Helpers::logCrashReason(int sig, bool stackTraceIfAvailable, Level level, const char* logger) { + el::base::debug::logCrashReason(sig, stackTraceIfAvailable, level, logger); +} + +#endif // defined(ELPP_FEATURE_ALL) || defined(ELPP_FEATURE_CRASH_LOG) + +// Loggers + +Logger* +Loggers::getLogger(const std::string& identity, bool registerIfNotAvailable) { + return ELPP->registeredLoggers()->get(identity, registerIfNotAvailable); +} + +void +Loggers::setDefaultLogBuilder(el::LogBuilderPtr& logBuilderPtr) { + ELPP->registeredLoggers()->setDefaultLogBuilder(logBuilderPtr); +} + +bool +Loggers::unregisterLogger(const std::string& identity) { + return ELPP->registeredLoggers()->remove(identity); +} + +bool +Loggers::hasLogger(const std::string& identity) { + return ELPP->registeredLoggers()->has(identity); +} + +Logger* +Loggers::reconfigureLogger(Logger* logger, const Configurations& configurations) { + if (!logger) + return nullptr; + logger->configure(configurations); + return logger; +} + +Logger* +Loggers::reconfigureLogger(const std::string& identity, const Configurations& configurations) { + return Loggers::reconfigureLogger(Loggers::getLogger(identity), configurations); +} + +Logger* +Loggers::reconfigureLogger(const std::string& identity, ConfigurationType configurationType, const std::string& value) { + Logger* logger = Loggers::getLogger(identity); + if (logger == nullptr) { + return nullptr; + } + logger->configurations()->set(Level::Global, configurationType, value); + logger->reconfigure(); + return logger; +} + +void +Loggers::reconfigureAllLoggers(const Configurations& configurations) { + for (base::RegisteredLoggers::iterator it = ELPP->registeredLoggers()->begin(); + it != ELPP->registeredLoggers()->end(); ++it) { + Loggers::reconfigureLogger(it->second, configurations); + } +} + +void +Loggers::reconfigureAllLoggers(Level level, ConfigurationType configurationType, const std::string& value) { + for (base::RegisteredLoggers::iterator it = ELPP->registeredLoggers()->begin(); + it != ELPP->registeredLoggers()->end(); ++it) { + Logger* logger = it->second; + logger->configurations()->set(level, configurationType, value); + logger->reconfigure(); + } +} + +void +Loggers::setDefaultConfigurations(const Configurations& configurations, bool reconfigureExistingLoggers) { + ELPP->registeredLoggers()->setDefaultConfigurations(configurations); + if (reconfigureExistingLoggers) { + Loggers::reconfigureAllLoggers(configurations); + } +} + +const Configurations* +Loggers::defaultConfigurations(void) { + return ELPP->registeredLoggers()->defaultConfigurations(); +} + +const base::LogStreamsReferenceMap* +Loggers::logStreamsReference(void) { + return ELPP->registeredLoggers()->logStreamsReference(); +} + +base::TypedConfigurations +Loggers::defaultTypedConfigurations(void) { + return base::TypedConfigurations(ELPP->registeredLoggers()->defaultConfigurations(), + ELPP->registeredLoggers()->logStreamsReference()); +} + +std::vector* +Loggers::populateAllLoggerIds(std::vector* targetList) { + targetList->clear(); + for (base::RegisteredLoggers::iterator it = ELPP->registeredLoggers()->list().begin(); + it != ELPP->registeredLoggers()->list().end(); ++it) { + targetList->push_back(it->first); + } + return targetList; +} + +void +Loggers::configureFromGlobal(const char* globalConfigurationFilePath) { + std::ifstream gcfStream(globalConfigurationFilePath, std::ifstream::in); + ELPP_ASSERT(gcfStream.is_open(), + "Unable to open global configuration file [" << globalConfigurationFilePath << "] for parsing."); + std::string line = std::string(); + std::stringstream ss; + Logger* logger = nullptr; + auto configure = [&](void) { + ELPP_INTERNAL_INFO(8, "Configuring logger: '" << logger->id() << "' with configurations \n" + << ss.str() << "\n--------------"); + Configurations c; + c.parseFromText(ss.str()); + logger->configure(c); + }; + while (gcfStream.good()) { + std::getline(gcfStream, line); + ELPP_INTERNAL_INFO(1, "Parsing line: " << line); + base::utils::Str::trim(line); + if (Configurations::Parser::isComment(line)) + continue; + Configurations::Parser::ignoreComments(&line); + base::utils::Str::trim(line); + if (line.size() > 2 && base::utils::Str::startsWith(line, std::string(base::consts::kConfigurationLoggerId))) { + if (!ss.str().empty() && logger != nullptr) { + configure(); + } + ss.str(std::string("")); + line = line.substr(2); + base::utils::Str::trim(line); + if (line.size() > 1) { + ELPP_INTERNAL_INFO(1, "Getting logger: '" << line << "'"); + logger = getLogger(line); + } + } else { + ss << line << "\n"; + } + } + if (!ss.str().empty() && logger != nullptr) { + configure(); + } +} + +bool +Loggers::configureFromArg(const char* argKey) { +#if defined(ELPP_DISABLE_CONFIGURATION_FROM_PROGRAM_ARGS) + ELPP_UNUSED(argKey); +#else + if (!Helpers::commandLineArgs()->hasParamWithValue(argKey)) { + return false; + } + configureFromGlobal(Helpers::commandLineArgs()->getParamValue(argKey)); +#endif // defined(ELPP_DISABLE_CONFIGURATION_FROM_PROGRAM_ARGS) + return true; +} + +void +Loggers::flushAll(void) { + ELPP->registeredLoggers()->flushAll(); +} + +void +Loggers::setVerboseLevel(base::type::VerboseLevel level) { + ELPP->vRegistry()->setLevel(level); +} + +base::type::VerboseLevel +Loggers::verboseLevel(void) { + return ELPP->vRegistry()->level(); +} + +void +Loggers::setVModules(const char* modules) { + if (ELPP->vRegistry()->vModulesEnabled()) { + ELPP->vRegistry()->setModules(modules); + } +} + +void +Loggers::clearVModules(void) { + ELPP->vRegistry()->clearModules(); +} + +// VersionInfo + +const std::string +VersionInfo::version(void) { + return std::string("9.96.7"); +} +/// @brief Release date of current version +const std::string +VersionInfo::releaseDate(void) { + return std::string("24-11-2018 0728hrs"); +} + +} // namespace el diff --git a/core/thirdparty/easyloggingpp/easylogging++.h b/core/thirdparty/easyloggingpp/easylogging++.h new file mode 100644 index 0000000000..ce6c7ece48 --- /dev/null +++ b/core/thirdparty/easyloggingpp/easylogging++.h @@ -0,0 +1,5193 @@ +// +// Bismillah ar-Rahmaan ar-Raheem +// +// Easylogging++ v9.96.7 +// Single-header only, cross-platform logging library for C++ applications +// +// Copyright (c) 2012-2018 Zuhd Web Services +// Copyright (c) 2012-2018 @abumusamq +// +// This library is released under the MIT Licence. +// https://github.com/zuhd-org/easyloggingpp/blob/master/LICENSE +// +// https://zuhd.org +// http://muflihun.com +// + +#ifndef EASYLOGGINGPP_H +#define EASYLOGGINGPP_H +// Compilers and C++0x/C++11 Evaluation +#if __cplusplus >= 201103L +#define ELPP_CXX11 1 +#endif // __cplusplus >= 201103L +#if (defined(__GNUC__)) +#define ELPP_COMPILER_GCC 1 +#else +#define ELPP_COMPILER_GCC 0 +#endif +#if ELPP_COMPILER_GCC +#define ELPP_GCC_VERSION (__GNUC__ * 10000 + __GNUC_MINOR__ * 100 + __GNUC_PATCHLEVEL__) +#if defined(__GXX_EXPERIMENTAL_CXX0X__) +#define ELPP_CXX0X 1 +#endif +#endif +// Visual C++ +#if defined(_MSC_VER) +#define ELPP_COMPILER_MSVC 1 +#else +#define ELPP_COMPILER_MSVC 0 +#endif +#define ELPP_CRT_DBG_WARNINGS ELPP_COMPILER_MSVC +#if ELPP_COMPILER_MSVC +#if (_MSC_VER == 1600) +#define ELPP_CXX0X 1 +#elif (_MSC_VER >= 1700) +#define ELPP_CXX11 1 +#endif +#endif +// Clang++ +#if (defined(__clang__) && (__clang__ == 1)) +#define ELPP_COMPILER_CLANG 1 +#else +#define ELPP_COMPILER_CLANG 0 +#endif +#if ELPP_COMPILER_CLANG +#if __has_include() +#include // Make __GLIBCXX__ defined when using libstdc++ +#if !defined(__GLIBCXX__) || __GLIBCXX__ >= 20150426 +#define ELPP_CLANG_SUPPORTS_THREAD +#endif // !defined(__GLIBCXX__) || __GLIBCXX__ >= 20150426 +#endif // __has_include() +#endif +#if (defined(__MINGW32__) || defined(__MINGW64__)) +#define ELPP_MINGW 1 +#else +#define ELPP_MINGW 0 +#endif +#if (defined(__CYGWIN__) && (__CYGWIN__ == 1)) +#define ELPP_CYGWIN 1 +#else +#define ELPP_CYGWIN 0 +#endif +#if (defined(__INTEL_COMPILER)) +#define ELPP_COMPILER_INTEL 1 +#else +#define ELPP_COMPILER_INTEL 0 +#endif +// Operating System Evaluation +// Windows +#if (defined(_WIN32) || defined(_WIN64)) +#define ELPP_OS_WINDOWS 1 +#else +#define ELPP_OS_WINDOWS 0 +#endif +// Linux +#if (defined(__linux) || defined(__linux__)) +#define ELPP_OS_LINUX 1 +#else +#define ELPP_OS_LINUX 0 +#endif +#if (defined(__APPLE__)) +#define ELPP_OS_MAC 1 +#else +#define ELPP_OS_MAC 0 +#endif +#if (defined(__FreeBSD__) || defined(__FreeBSD_kernel__)) +#define ELPP_OS_FREEBSD 1 +#else +#define ELPP_OS_FREEBSD 0 +#endif +#if (defined(__sun)) +#define ELPP_OS_SOLARIS 1 +#else +#define ELPP_OS_SOLARIS 0 +#endif +#if (defined(_AIX)) +#define ELPP_OS_AIX 1 +#else +#define ELPP_OS_AIX 0 +#endif +#if (defined(__NetBSD__)) +#define ELPP_OS_NETBSD 1 +#else +#define ELPP_OS_NETBSD 0 +#endif +#if defined(__EMSCRIPTEN__) +#define ELPP_OS_EMSCRIPTEN 1 +#else +#define ELPP_OS_EMSCRIPTEN 0 +#endif +// Unix +#if ((ELPP_OS_LINUX || ELPP_OS_MAC || ELPP_OS_FREEBSD || ELPP_OS_NETBSD || ELPP_OS_SOLARIS || ELPP_OS_AIX || \ + ELPP_OS_EMSCRIPTEN) && \ + (!ELPP_OS_WINDOWS)) +#define ELPP_OS_UNIX 1 +#else +#define ELPP_OS_UNIX 0 +#endif +#if (defined(__ANDROID__)) +#define ELPP_OS_ANDROID 1 +#else +#define ELPP_OS_ANDROID 0 +#endif +// Evaluating Cygwin as *nix OS +#if !ELPP_OS_UNIX && !ELPP_OS_WINDOWS && ELPP_CYGWIN +#undef ELPP_OS_UNIX +#undef ELPP_OS_LINUX +#define ELPP_OS_UNIX 1 +#define ELPP_OS_LINUX 1 +#endif // !ELPP_OS_UNIX && !ELPP_OS_WINDOWS && ELPP_CYGWIN +#if !defined(ELPP_INTERNAL_DEBUGGING_OUT_INFO) +#define ELPP_INTERNAL_DEBUGGING_OUT_INFO std::cout +#endif // !defined(ELPP_INTERNAL_DEBUGGING_OUT) +#if !defined(ELPP_INTERNAL_DEBUGGING_OUT_ERROR) +#define ELPP_INTERNAL_DEBUGGING_OUT_ERROR std::cerr +#endif // !defined(ELPP_INTERNAL_DEBUGGING_OUT) +#if !defined(ELPP_INTERNAL_DEBUGGING_ENDL) +#define ELPP_INTERNAL_DEBUGGING_ENDL std::endl +#endif // !defined(ELPP_INTERNAL_DEBUGGING_OUT) +#if !defined(ELPP_INTERNAL_DEBUGGING_MSG) +#define ELPP_INTERNAL_DEBUGGING_MSG(msg) msg +#endif // !defined(ELPP_INTERNAL_DEBUGGING_OUT) +// Internal Assertions and errors +#if !defined(ELPP_DISABLE_ASSERT) +#if (defined(ELPP_DEBUG_ASSERT_FAILURE)) +#define ELPP_ASSERT(expr, msg) \ + if (!(expr)) { \ + std::stringstream internalInfoStream; \ + internalInfoStream << msg; \ + ELPP_INTERNAL_DEBUGGING_OUT_ERROR \ + << "EASYLOGGING++ ASSERTION FAILED (LINE: " << __LINE__ << ") [" #expr << "] WITH MESSAGE \"" \ + << ELPP_INTERNAL_DEBUGGING_MSG(internalInfoStream.str()) << "\"" << ELPP_INTERNAL_DEBUGGING_ENDL; \ + base::utils::abort(1, "ELPP Assertion failure, please define ELPP_DEBUG_ASSERT_FAILURE"); \ + } +#else +#define ELPP_ASSERT(expr, msg) \ + if (!(expr)) { \ + std::stringstream internalInfoStream; \ + internalInfoStream << msg; \ + ELPP_INTERNAL_DEBUGGING_OUT_ERROR \ + << "ASSERTION FAILURE FROM EASYLOGGING++ (LINE: " << __LINE__ << ") [" #expr << "] WITH MESSAGE \"" \ + << ELPP_INTERNAL_DEBUGGING_MSG(internalInfoStream.str()) << "\"" << ELPP_INTERNAL_DEBUGGING_ENDL; \ + } +#endif // (defined(ELPP_DEBUG_ASSERT_FAILURE)) +#else +#define ELPP_ASSERT(x, y) +#endif //(!defined(ELPP_DISABLE_ASSERT) +#if ELPP_COMPILER_MSVC +#define ELPP_INTERNAL_DEBUGGING_WRITE_PERROR \ + { \ + char buff[256]; \ + strerror_s(buff, 256, errno); \ + ELPP_INTERNAL_DEBUGGING_OUT_ERROR << ": " << buff << " [" << errno << "]"; \ + } \ + (void)0 +#else +#define ELPP_INTERNAL_DEBUGGING_WRITE_PERROR \ + ELPP_INTERNAL_DEBUGGING_OUT_ERROR << ": " << strerror(errno) << " [" << errno << "]"; \ + (void)0 +#endif // ELPP_COMPILER_MSVC +#if defined(ELPP_DEBUG_ERRORS) +#if !defined(ELPP_INTERNAL_ERROR) +#define ELPP_INTERNAL_ERROR(msg, pe) \ + { \ + std::stringstream internalInfoStream; \ + internalInfoStream << " " << msg; \ + ELPP_INTERNAL_DEBUGGING_OUT_ERROR << "ERROR FROM EASYLOGGING++ (LINE: " << __LINE__ << ") " \ + << ELPP_INTERNAL_DEBUGGING_MSG(internalInfoStream.str()) \ + << ELPP_INTERNAL_DEBUGGING_ENDL; \ + if (pe) { \ + ELPP_INTERNAL_DEBUGGING_OUT_ERROR << " "; \ + ELPP_INTERNAL_DEBUGGING_WRITE_PERROR; \ + } \ + } \ + (void)0 +#endif +#else +#undef ELPP_INTERNAL_INFO +#define ELPP_INTERNAL_ERROR(msg, pe) +#endif // defined(ELPP_DEBUG_ERRORS) +#if (defined(ELPP_DEBUG_INFO)) +#if !(defined(ELPP_INTERNAL_INFO_LEVEL)) +#define ELPP_INTERNAL_INFO_LEVEL 9 +#endif // !(defined(ELPP_INTERNAL_INFO_LEVEL)) +#if !defined(ELPP_INTERNAL_INFO) +#define ELPP_INTERNAL_INFO(lvl, msg) \ + { \ + if (lvl <= ELPP_INTERNAL_INFO_LEVEL) { \ + std::stringstream internalInfoStream; \ + internalInfoStream << " " << msg; \ + ELPP_INTERNAL_DEBUGGING_OUT_INFO << ELPP_INTERNAL_DEBUGGING_MSG(internalInfoStream.str()) \ + << ELPP_INTERNAL_DEBUGGING_ENDL; \ + } \ + } +#endif +#else +#undef ELPP_INTERNAL_INFO +#define ELPP_INTERNAL_INFO(lvl, msg) +#endif // (defined(ELPP_DEBUG_INFO)) +#if (defined(ELPP_FEATURE_ALL)) || (defined(ELPP_FEATURE_CRASH_LOG)) +#if (ELPP_COMPILER_GCC && !ELPP_MINGW && !ELPP_OS_ANDROID && !ELPP_OS_EMSCRIPTEN) +#define ELPP_STACKTRACE 1 +#else +#if ELPP_COMPILER_MSVC +#pragma message("Stack trace not available for this compiler") +#else +#warning "Stack trace not available for this compiler"; +#endif // ELPP_COMPILER_MSVC +#define ELPP_STACKTRACE 0 +#endif // ELPP_COMPILER_GCC +#else +#define ELPP_STACKTRACE 0 +#endif // (defined(ELPP_FEATURE_ALL)) || (defined(ELPP_FEATURE_CRASH_LOG)) +// Miscellaneous macros +#define ELPP_UNUSED(x) (void)x +#if ELPP_OS_UNIX +// Log file permissions for unix-based systems +#define ELPP_LOG_PERMS S_IRUSR | S_IWUSR | S_IXUSR | S_IWGRP | S_IRGRP | S_IXGRP | S_IWOTH | S_IXOTH +#endif // ELPP_OS_UNIX +#if defined(ELPP_AS_DLL) && ELPP_COMPILER_MSVC +#if defined(ELPP_EXPORT_SYMBOLS) +#define ELPP_EXPORT __declspec(dllexport) +#else +#define ELPP_EXPORT __declspec(dllimport) +#endif // defined(ELPP_EXPORT_SYMBOLS) +#else +#define ELPP_EXPORT +#endif // defined(ELPP_AS_DLL) && ELPP_COMPILER_MSVC +// Some special functions that are VC++ specific +#undef STRTOK +#undef STRERROR +#undef STRCAT +#undef STRCPY +#if ELPP_CRT_DBG_WARNINGS +#define STRTOK(a, b, c) strtok_s(a, b, c) +#define STRERROR(a, b, c) strerror_s(a, b, c) +#define STRCAT(a, b, len) strcat_s(a, len, b) +#define STRCPY(a, b, len) strcpy_s(a, len, b) +#else +#define STRTOK(a, b, c) strtok(a, b) +#define STRERROR(a, b, c) strerror(c) +#define STRCAT(a, b, len) strcat(a, b) +#define STRCPY(a, b, len) strcpy(a, b) +#endif +// Compiler specific support evaluations +#if (ELPP_MINGW && !defined(ELPP_FORCE_USE_STD_THREAD)) +#define ELPP_USE_STD_THREADING 0 +#else +#if ((ELPP_COMPILER_CLANG && defined(ELPP_CLANG_SUPPORTS_THREAD)) || (!ELPP_COMPILER_CLANG && defined(ELPP_CXX11)) || \ + defined(ELPP_FORCE_USE_STD_THREAD)) +#define ELPP_USE_STD_THREADING 1 +#else +#define ELPP_USE_STD_THREADING 0 +#endif +#endif +#undef ELPP_FINAL +#if ELPP_COMPILER_INTEL || (ELPP_GCC_VERSION < 40702) +#define ELPP_FINAL +#else +#define ELPP_FINAL final +#endif // ELPP_COMPILER_INTEL || (ELPP_GCC_VERSION < 40702) +#if defined(ELPP_EXPERIMENTAL_ASYNC) +#define ELPP_ASYNC_LOGGING 1 +#else +#define ELPP_ASYNC_LOGGING 0 +#endif // defined(ELPP_EXPERIMENTAL_ASYNC) +#if defined(ELPP_THREAD_SAFE) || ELPP_ASYNC_LOGGING +#define ELPP_THREADING_ENABLED 1 +#else +#define ELPP_THREADING_ENABLED 0 +#endif // defined(ELPP_THREAD_SAFE) || ELPP_ASYNC_LOGGING +// Function macro ELPP_FUNC +#undef ELPP_FUNC +#if ELPP_COMPILER_MSVC // Visual C++ +#define ELPP_FUNC __FUNCSIG__ +#elif ELPP_COMPILER_GCC // GCC +#define ELPP_FUNC __PRETTY_FUNCTION__ +#elif ELPP_COMPILER_INTEL // Intel C++ +#define ELPP_FUNC __PRETTY_FUNCTION__ +#elif ELPP_COMPILER_CLANG // Clang++ +#define ELPP_FUNC __PRETTY_FUNCTION__ +#else +#if defined(__func__) +#define ELPP_FUNC __func__ +#else +#define ELPP_FUNC "" +#endif // defined(__func__) +#endif // defined(_MSC_VER) +#undef ELPP_VARIADIC_TEMPLATES_SUPPORTED +// Keep following line commented until features are fixed +#define ELPP_VARIADIC_TEMPLATES_SUPPORTED \ + (ELPP_COMPILER_GCC || ELPP_COMPILER_CLANG || ELPP_COMPILER_INTEL || (ELPP_COMPILER_MSVC && _MSC_VER >= 1800)) +// Logging Enable/Disable macros +#if defined(ELPP_DISABLE_LOGS) +#define ELPP_LOGGING_ENABLED 0 +#else +#define ELPP_LOGGING_ENABLED 1 +#endif +#if (!defined(ELPP_DISABLE_DEBUG_LOGS) && (ELPP_LOGGING_ENABLED)) +#define ELPP_DEBUG_LOG 1 +#else +#define ELPP_DEBUG_LOG 0 +#endif // (!defined(ELPP_DISABLE_DEBUG_LOGS) && (ELPP_LOGGING_ENABLED)) +#if (!defined(ELPP_DISABLE_INFO_LOGS) && (ELPP_LOGGING_ENABLED)) +#define ELPP_INFO_LOG 1 +#else +#define ELPP_INFO_LOG 0 +#endif // (!defined(ELPP_DISABLE_INFO_LOGS) && (ELPP_LOGGING_ENABLED)) +#if (!defined(ELPP_DISABLE_WARNING_LOGS) && (ELPP_LOGGING_ENABLED)) +#define ELPP_WARNING_LOG 1 +#else +#define ELPP_WARNING_LOG 0 +#endif // (!defined(ELPP_DISABLE_WARNING_LOGS) && (ELPP_LOGGING_ENABLED)) +#if (!defined(ELPP_DISABLE_ERROR_LOGS) && (ELPP_LOGGING_ENABLED)) +#define ELPP_ERROR_LOG 1 +#else +#define ELPP_ERROR_LOG 0 +#endif // (!defined(ELPP_DISABLE_ERROR_LOGS) && (ELPP_LOGGING_ENABLED)) +#if (!defined(ELPP_DISABLE_FATAL_LOGS) && (ELPP_LOGGING_ENABLED)) +#define ELPP_FATAL_LOG 1 +#else +#define ELPP_FATAL_LOG 0 +#endif // (!defined(ELPP_DISABLE_FATAL_LOGS) && (ELPP_LOGGING_ENABLED)) +#if (!defined(ELPP_DISABLE_TRACE_LOGS) && (ELPP_LOGGING_ENABLED)) +#define ELPP_TRACE_LOG 1 +#else +#define ELPP_TRACE_LOG 0 +#endif // (!defined(ELPP_DISABLE_TRACE_LOGS) && (ELPP_LOGGING_ENABLED)) +#if (!defined(ELPP_DISABLE_VERBOSE_LOGS) && (ELPP_LOGGING_ENABLED)) +#define ELPP_VERBOSE_LOG 1 +#else +#define ELPP_VERBOSE_LOG 0 +#endif // (!defined(ELPP_DISABLE_VERBOSE_LOGS) && (ELPP_LOGGING_ENABLED)) +#if (!(ELPP_CXX0X || ELPP_CXX11)) +#error "C++0x (or higher) support not detected! (Is `-std=c++11' missing?)" +#endif // (!(ELPP_CXX0X || ELPP_CXX11)) +// Headers +#if defined(ELPP_SYSLOG) +#include +#endif // defined(ELPP_SYSLOG) +#include +#include +#include +#include +#include +#include +#include +#include +#if defined(ELPP_UNICODE) +#include +#if ELPP_OS_WINDOWS +#include +#endif // ELPP_OS_WINDOWS +#endif // defined(ELPP_UNICODE) +#if ELPP_STACKTRACE +#include +#include +#endif // ELPP_STACKTRACE +#if ELPP_OS_ANDROID +#include +#endif // ELPP_OS_ANDROID +#if ELPP_OS_UNIX +#include +#include +#elif ELPP_OS_WINDOWS +#include +#include +#if defined(WIN32_LEAN_AND_MEAN) +#if defined(ELPP_WINSOCK2) +#include +#else +#include +#endif // defined(ELPP_WINSOCK2) +#endif // defined(WIN32_LEAN_AND_MEAN) +#endif // ELPP_OS_UNIX +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#if ELPP_THREADING_ENABLED +#if ELPP_USE_STD_THREADING +#include +#include +#else +#if ELPP_OS_UNIX +#include +#endif // ELPP_OS_UNIX +#endif // ELPP_USE_STD_THREADING +#endif // ELPP_THREADING_ENABLED +#if ELPP_ASYNC_LOGGING +#if defined(ELPP_NO_SLEEP_FOR) +#include +#endif // defined(ELPP_NO_SLEEP_FOR) +#include +#include +#include +#endif // ELPP_ASYNC_LOGGING +#if defined(ELPP_STL_LOGGING) +// For logging STL based templates +#include +#include +#include +#include +#include +#include +#if defined(ELPP_LOG_STD_ARRAY) +#include +#endif // defined(ELPP_LOG_STD_ARRAY) +#if defined(ELPP_LOG_UNORDERED_SET) +#include +#endif // defined(ELPP_UNORDERED_SET) +#endif // defined(ELPP_STL_LOGGING) +#if defined(ELPP_QT_LOGGING) +// For logging Qt based classes & templates +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#endif // defined(ELPP_QT_LOGGING) +#if defined(ELPP_BOOST_LOGGING) +// For logging boost based classes & templates +#include +#include +#include +#include +#include +#include +#include +#include +#endif // defined(ELPP_BOOST_LOGGING) +#if defined(ELPP_WXWIDGETS_LOGGING) +// For logging wxWidgets based classes & templates +#include +#endif // defined(ELPP_WXWIDGETS_LOGGING) +#if defined(ELPP_UTC_DATETIME) +#define elpptime_r gmtime_r +#define elpptime_s gmtime_s +#define elpptime gmtime +#else +#define elpptime_r localtime_r +#define elpptime_s localtime_s +#define elpptime localtime +#endif // defined(ELPP_UTC_DATETIME) +// Forward declarations +namespace el { +class Logger; +class LogMessage; +class PerformanceTrackingData; +class Loggers; +class Helpers; +template +class Callback; +class LogDispatchCallback; +class PerformanceTrackingCallback; +class LoggerRegistrationCallback; +class LogDispatchData; +namespace base { +class Storage; +class RegisteredLoggers; +class PerformanceTracker; +class MessageBuilder; +class Writer; +class PErrorWriter; +class LogDispatcher; +class DefaultLogBuilder; +class DefaultLogDispatchCallback; +#if ELPP_ASYNC_LOGGING +class AsyncLogDispatchCallback; +class AsyncDispatchWorker; +#endif // ELPP_ASYNC_LOGGING +class DefaultPerformanceTrackingCallback; +} // namespace base +} // namespace el +/// @brief Easylogging++ entry namespace +namespace el { +/// @brief Namespace containing base/internal functionality used by Easylogging++ +namespace base { +/// @brief Data types used by Easylogging++ +namespace type { +#undef ELPP_LITERAL +#undef ELPP_STRLEN +#undef ELPP_COUT +#if defined(ELPP_UNICODE) +#define ELPP_LITERAL(txt) L##txt +#define ELPP_STRLEN wcslen +#if defined ELPP_CUSTOM_COUT +#define ELPP_COUT ELPP_CUSTOM_COUT +#else +#define ELPP_COUT std::wcout +#endif // defined ELPP_CUSTOM_COUT +typedef wchar_t char_t; +typedef std::wstring string_t; +typedef std::wstringstream stringstream_t; +typedef std::wfstream fstream_t; +typedef std::wostream ostream_t; +#else +#define ELPP_LITERAL(txt) txt +#define ELPP_STRLEN strlen +#if defined ELPP_CUSTOM_COUT +#define ELPP_COUT ELPP_CUSTOM_COUT +#else +#define ELPP_COUT std::cout +#endif // defined ELPP_CUSTOM_COUT +typedef char char_t; +typedef std::string string_t; +typedef std::stringstream stringstream_t; +typedef std::fstream fstream_t; +typedef std::ostream ostream_t; +#endif // defined(ELPP_UNICODE) +#if defined(ELPP_CUSTOM_COUT_LINE) +#define ELPP_COUT_LINE(logLine) ELPP_CUSTOM_COUT_LINE(logLine) +#else +#define ELPP_COUT_LINE(logLine) logLine << std::flush +#endif // defined(ELPP_CUSTOM_COUT_LINE) +typedef unsigned int EnumType; +typedef unsigned short VerboseLevel; +typedef unsigned long int LineNumber; +typedef std::shared_ptr StoragePointer; +typedef std::shared_ptr LogDispatchCallbackPtr; +typedef std::shared_ptr PerformanceTrackingCallbackPtr; +typedef std::shared_ptr LoggerRegistrationCallbackPtr; +typedef std::unique_ptr PerformanceTrackerPtr; +} // namespace type +/// @brief Internal helper class that prevent copy constructor for class +/// +/// @detail When using this class simply inherit it privately +class NoCopy { + protected: + NoCopy(void) { + } + + private: + NoCopy(const NoCopy&); + NoCopy& + operator=(const NoCopy&); +}; +/// @brief Internal helper class that makes all default constructors private. +/// +/// @detail This prevents initializing class making it static unless an explicit constructor is declared. +/// When using this class simply inherit it privately +class StaticClass { + private: + StaticClass(void); + StaticClass(const StaticClass&); + StaticClass& + operator=(const StaticClass&); +}; +} // namespace base +/// @brief Represents enumeration for severity level used to determine level of logging +/// +/// @detail With Easylogging++, developers may disable or enable any level regardless of +/// what the severity is. Or they can choose to log using hierarchical logging flag +enum class Level : base::type::EnumType { + /// @brief Generic level that represents all the levels. Useful when setting global configuration for all levels + Global = 1, + /// @brief Information that can be useful to back-trace certain events - mostly useful than debug logs. + Trace = 2, + /// @brief Informational events most useful for developers to debug application + Debug = 4, + /// @brief Severe error information that will presumably abort application + Fatal = 8, + /// @brief Information representing errors in application but application will keep running + Error = 16, + /// @brief Useful when application has potentially harmful situtaions + Warning = 32, + /// @brief Information that can be highly useful and vary with verbose logging level. + Verbose = 64, + /// @brief Mainly useful to represent current progress of application + Info = 128, + /// @brief Represents unknown level + Unknown = 1010 +}; +} // namespace el +namespace std { +template <> +struct hash { + public: + std::size_t + operator()(const el::Level& l) const { + return hash{}(static_cast(l)); + } +}; +} // namespace std +namespace el { +/// @brief Static class that contains helper functions for el::Level +class LevelHelper : base::StaticClass { + public: + /// @brief Represents minimum valid level. Useful when iterating through enum. + static const base::type::EnumType kMinValid = static_cast(Level::Trace); + /// @brief Represents maximum valid level. This is used internally and you should not need it. + static const base::type::EnumType kMaxValid = static_cast(Level::Info); + /// @brief Casts level to int, useful for iterating through enum. + static base::type::EnumType + castToInt(Level level) { + return static_cast(level); + } + /// @brief Casts int(ushort) to level, useful for iterating through enum. + static Level + castFromInt(base::type::EnumType l) { + return static_cast(l); + } + /// @brief Converts level to associated const char* + /// @return Upper case string based level. + static const char* + convertToString(Level level); + /// @brief Converts from levelStr to Level + /// @param levelStr Upper case string based level. + /// Lower case is also valid but providing upper case is recommended. + static Level + convertFromString(const char* levelStr); + /// @brief Applies specified function to each level starting from startIndex + /// @param startIndex initial value to start the iteration from. This is passed as pointer and + /// is left-shifted so this can be used inside function (fn) to represent current level. + /// @param fn function to apply with each level. This bool represent whether or not to stop iterating through + /// levels. + static void + forEachLevel(base::type::EnumType* startIndex, const std::function& fn); +}; +/// @brief Represents enumeration of ConfigurationType used to configure or access certain aspect +/// of logging +enum class ConfigurationType : base::type::EnumType { + /// @brief Determines whether or not corresponding level and logger of logging is enabled + /// You may disable all logs by using el::Level::Global + Enabled = 1, + /// @brief Whether or not to write corresponding log to log file + ToFile = 2, + /// @brief Whether or not to write corresponding level and logger log to standard output. + /// By standard output meaning termnal, command prompt etc + ToStandardOutput = 4, + /// @brief Determines format of logging corresponding level and logger. + Format = 8, + /// @brief Determines log file (full path) to write logs to for correponding level and logger + Filename = 16, + /// @brief Specifies precision of the subsecond part. It should be within range (1-6). + SubsecondPrecision = 32, + /// @brief Alias of SubsecondPrecision (for backward compatibility) + MillisecondsWidth = SubsecondPrecision, + /// @brief Determines whether or not performance tracking is enabled. + /// + /// @detail This does not depend on logger or level. Performance tracking always uses 'performance' logger + PerformanceTracking = 64, + /// @brief Specifies log file max size. + /// + /// @detail If file size of corresponding log file (for corresponding level) is >= specified size, log file will + /// be truncated and re-initiated. + MaxLogFileSize = 128, + /// @brief Specifies number of log entries to hold until we flush pending log data + LogFlushThreshold = 256, + /// @brief Represents unknown configuration + Unknown = 1010 +}; +/// @brief Static class that contains helper functions for el::ConfigurationType +class ConfigurationTypeHelper : base::StaticClass { + public: + /// @brief Represents minimum valid configuration type. Useful when iterating through enum. + static const base::type::EnumType kMinValid = static_cast(ConfigurationType::Enabled); + /// @brief Represents maximum valid configuration type. This is used internally and you should not need it. + static const base::type::EnumType kMaxValid = static_cast(ConfigurationType::MaxLogFileSize); + /// @brief Casts configuration type to int, useful for iterating through enum. + static base::type::EnumType + castToInt(ConfigurationType configurationType) { + return static_cast(configurationType); + } + /// @brief Casts int(ushort) to configurationt type, useful for iterating through enum. + static ConfigurationType + castFromInt(base::type::EnumType c) { + return static_cast(c); + } + /// @brief Converts configuration type to associated const char* + /// @returns Upper case string based configuration type. + static const char* + convertToString(ConfigurationType configurationType); + /// @brief Converts from configStr to ConfigurationType + /// @param configStr Upper case string based configuration type. + /// Lower case is also valid but providing upper case is recommended. + static ConfigurationType + convertFromString(const char* configStr); + /// @brief Applies specified function to each configuration type starting from startIndex + /// @param startIndex initial value to start the iteration from. This is passed by pointer and is left-shifted + /// so this can be used inside function (fn) to represent current configuration type. + /// @param fn function to apply with each configuration type. + /// This bool represent whether or not to stop iterating through configurations. + static inline void + forEachConfigType(base::type::EnumType* startIndex, const std::function& fn); +}; +/// @brief Flags used while writing logs. This flags are set by user +enum class LoggingFlag : base::type::EnumType { + /// @brief Makes sure we have new line for each container log entry + NewLineForContainer = 1, + /// @brief Makes sure if -vmodule is used and does not specifies a module, then verbose + /// logging is allowed via that module. + AllowVerboseIfModuleNotSpecified = 2, + /// @brief When handling crashes by default, detailed crash reason will be logged as well + LogDetailedCrashReason = 4, + /// @brief Allows to disable application abortion when logged using FATAL level + DisableApplicationAbortOnFatalLog = 8, + /// @brief Flushes log with every log-entry (performance sensative) - Disabled by default + ImmediateFlush = 16, + /// @brief Enables strict file rolling + StrictLogFileSizeCheck = 32, + /// @brief Make terminal output colorful for supported terminals + ColoredTerminalOutput = 64, + /// @brief Supports use of multiple logging in same macro, e.g, CLOG(INFO, "default", "network") + MultiLoggerSupport = 128, + /// @brief Disables comparing performance tracker's checkpoints + DisablePerformanceTrackingCheckpointComparison = 256, + /// @brief Disable VModules + DisableVModules = 512, + /// @brief Disable VModules extensions + DisableVModulesExtensions = 1024, + /// @brief Enables hierarchical logging + HierarchicalLogging = 2048, + /// @brief Creates logger automatically when not available + CreateLoggerAutomatically = 4096, + /// @brief Adds spaces b/w logs that separated by left-shift operator + AutoSpacing = 8192, + /// @brief Preserves time format and does not convert it to sec, hour etc (performance tracking only) + FixedTimeFormat = 16384, + // @brief Ignore SIGINT or crash + IgnoreSigInt = 32768, +}; +namespace base { +/// @brief Namespace containing constants used internally. +namespace consts { +static const char kFormatSpecifierCharValue = 'v'; +static const char kFormatSpecifierChar = '%'; +static const unsigned int kMaxLogPerCounter = 100000; +static const unsigned int kMaxLogPerContainer = 100; +static const unsigned int kDefaultSubsecondPrecision = 3; + +#ifdef ELPP_DEFAULT_LOGGER +static const char* kDefaultLoggerId = ELPP_DEFAULT_LOGGER; +#else +static const char* kDefaultLoggerId = "default"; +#endif + +#if defined(ELPP_FEATURE_ALL) || defined(ELPP_FEATURE_PERFORMANCE_TRACKING) +#ifdef ELPP_DEFAULT_PERFORMANCE_LOGGER +static const char* kPerformanceLoggerId = ELPP_DEFAULT_PERFORMANCE_LOGGER; +#else +static const char* kPerformanceLoggerId = "performance"; +#endif // ELPP_DEFAULT_PERFORMANCE_LOGGER +#endif + +#if defined(ELPP_SYSLOG) +static const char* kSysLogLoggerId = "syslog"; +#endif // defined(ELPP_SYSLOG) + +#if ELPP_OS_WINDOWS +static const char* kFilePathSeperator = "\\"; +#else +static const char* kFilePathSeperator = "/"; +#endif // ELPP_OS_WINDOWS + +static const std::size_t kSourceFilenameMaxLength = 100; +static const std::size_t kSourceLineMaxLength = 10; +static const Level kPerformanceTrackerDefaultLevel = Level::Info; +const struct { + double value; + const base::type::char_t* unit; +} kTimeFormats[] = {{1000.0f, ELPP_LITERAL("us")}, {1000.0f, ELPP_LITERAL("ms")}, {60.0f, ELPP_LITERAL("seconds")}, + {60.0f, ELPP_LITERAL("minutes")}, {24.0f, ELPP_LITERAL("hours")}, {7.0f, ELPP_LITERAL("days")}}; +static const int kTimeFormatsCount = sizeof(kTimeFormats) / sizeof(kTimeFormats[0]); +const struct { + int numb; + const char* name; + const char* brief; + const char* detail; +} kCrashSignals[] = { + // NOTE: Do not re-order, if you do please check CrashHandler(bool) constructor and CrashHandler::setHandler(..) + {SIGABRT, "SIGABRT", "Abnormal termination", "Program was abnormally terminated."}, + {SIGFPE, "SIGFPE", "Erroneous arithmetic operation", + "Arithemetic operation issue such as division by zero or operation resulting in overflow."}, + {SIGILL, "SIGILL", "Illegal instruction", + "Generally due to a corruption in the code or to an attempt to execute data."}, + {SIGSEGV, "SIGSEGV", "Invalid access to memory", + "Program is trying to read an invalid (unallocated, deleted or corrupted) or inaccessible memory."}, + {SIGINT, "SIGINT", "Interactive attention signal", + "Interruption generated (generally) by user or operating system."}, +}; +static const int kCrashSignalsCount = sizeof(kCrashSignals) / sizeof(kCrashSignals[0]); +} // namespace consts +} // namespace base +typedef std::function PreRollOutCallback; +namespace base { +static inline void +defaultPreRollOutCallback(const char*, std::size_t, Level level) { +} +/// @brief Enum to represent timestamp unit +enum class TimestampUnit : base::type::EnumType { + Microsecond = 0, + Millisecond = 1, + Second = 2, + Minute = 3, + Hour = 4, + Day = 5 +}; +/// @brief Format flags used to determine specifiers that are active for performance improvements. +enum class FormatFlags : base::type::EnumType { + DateTime = 1 << 1, + LoggerId = 1 << 2, + File = 1 << 3, + Line = 1 << 4, + Location = 1 << 5, + Function = 1 << 6, + User = 1 << 7, + Host = 1 << 8, + LogMessage = 1 << 9, + VerboseLevel = 1 << 10, + AppName = 1 << 11, + ThreadId = 1 << 12, + Level = 1 << 13, + FileBase = 1 << 14, + LevelShort = 1 << 15 +}; +/// @brief A subsecond precision class containing actual width and offset of the subsecond part +class SubsecondPrecision { + public: + SubsecondPrecision(void) { + init(base::consts::kDefaultSubsecondPrecision); + } + explicit SubsecondPrecision(int width) { + init(width); + } + bool + operator==(const SubsecondPrecision& ssPrec) { + return m_width == ssPrec.m_width && m_offset == ssPrec.m_offset; + } + int m_width; + unsigned int m_offset; + + private: + void + init(int width); +}; +/// @brief Type alias of SubsecondPrecision +typedef SubsecondPrecision MillisecondsWidth; +/// @brief Namespace containing utility functions/static classes used internally +namespace utils { +/// @brief Deletes memory safely and points to null +template +static typename std::enable_if::value, void>::type +safeDelete(T*& pointer) { + if (pointer == nullptr) + return; + delete pointer; + pointer = nullptr; +} +/// @brief Bitwise operations for C++11 strong enum class. This casts e into Flag_T and returns value after bitwise +/// operation Use these function as
flag = bitwise::Or(MyEnum::val1, flag);
+namespace bitwise { +template +static inline base::type::EnumType +And(Enum e, base::type::EnumType flag) { + return static_cast(flag) & static_cast(e); +} +template +static inline base::type::EnumType +Not(Enum e, base::type::EnumType flag) { + return static_cast(flag) & ~(static_cast(e)); +} +template +static inline base::type::EnumType +Or(Enum e, base::type::EnumType flag) { + return static_cast(flag) | static_cast(e); +} +} // namespace bitwise +template +static inline void +addFlag(Enum e, base::type::EnumType* flag) { + *flag = base::utils::bitwise::Or(e, *flag); +} +template +static inline void +removeFlag(Enum e, base::type::EnumType* flag) { + *flag = base::utils::bitwise::Not(e, *flag); +} +template +static inline bool +hasFlag(Enum e, base::type::EnumType flag) { + return base::utils::bitwise::And(e, flag) > 0x0; +} +} // namespace utils +namespace threading { +#if ELPP_THREADING_ENABLED +#if !ELPP_USE_STD_THREADING +namespace internal { +/// @brief A mutex wrapper for compiler that dont yet support std::recursive_mutex +class Mutex : base::NoCopy { + public: + Mutex(void) { +#if ELPP_OS_UNIX + pthread_mutexattr_t attr; + pthread_mutexattr_init(&attr); + pthread_mutexattr_settype(&attr, PTHREAD_MUTEX_RECURSIVE); + pthread_mutex_init(&m_underlyingMutex, &attr); + pthread_mutexattr_destroy(&attr); +#elif ELPP_OS_WINDOWS + InitializeCriticalSection(&m_underlyingMutex); +#endif // ELPP_OS_UNIX + } + + virtual ~Mutex(void) { +#if ELPP_OS_UNIX + pthread_mutex_destroy(&m_underlyingMutex); +#elif ELPP_OS_WINDOWS + DeleteCriticalSection(&m_underlyingMutex); +#endif // ELPP_OS_UNIX + } + + inline void + lock(void) { +#if ELPP_OS_UNIX + pthread_mutex_lock(&m_underlyingMutex); +#elif ELPP_OS_WINDOWS + EnterCriticalSection(&m_underlyingMutex); +#endif // ELPP_OS_UNIX + } + + inline bool + try_lock(void) { +#if ELPP_OS_UNIX + return (pthread_mutex_trylock(&m_underlyingMutex) == 0); +#elif ELPP_OS_WINDOWS + return TryEnterCriticalSection(&m_underlyingMutex); +#endif // ELPP_OS_UNIX + } + + inline void + unlock(void) { +#if ELPP_OS_UNIX + pthread_mutex_unlock(&m_underlyingMutex); +#elif ELPP_OS_WINDOWS + LeaveCriticalSection(&m_underlyingMutex); +#endif // ELPP_OS_UNIX + } + + private: +#if ELPP_OS_UNIX + pthread_mutex_t m_underlyingMutex; +#elif ELPP_OS_WINDOWS + CRITICAL_SECTION m_underlyingMutex; +#endif // ELPP_OS_UNIX +}; +/// @brief Scoped lock for compiler that dont yet support std::lock_guard +template +class ScopedLock : base::NoCopy { + public: + explicit ScopedLock(M& mutex) { + m_mutex = &mutex; + m_mutex->lock(); + } + + virtual ~ScopedLock(void) { + m_mutex->unlock(); + } + + private: + M* m_mutex; + ScopedLock(void); +}; +} // namespace internal +typedef base::threading::internal::Mutex Mutex; +typedef base::threading::internal::ScopedLock ScopedLock; +#else +typedef std::recursive_mutex Mutex; +typedef std::lock_guard ScopedLock; +#endif // !ELPP_USE_STD_THREADING +#else +namespace internal { +/// @brief Mutex wrapper used when multi-threading is disabled. +class NoMutex : base::NoCopy { + public: + NoMutex(void) { + } + inline void + lock(void) { + } + inline bool + try_lock(void) { + return true; + } + inline void + unlock(void) { + } +}; +/// @brief Lock guard wrapper used when multi-threading is disabled. +template +class NoScopedLock : base::NoCopy { + public: + explicit NoScopedLock(Mutex&) { + } + virtual ~NoScopedLock(void) { + } + + private: + NoScopedLock(void); +}; +} // namespace internal +typedef base::threading::internal::NoMutex Mutex; +typedef base::threading::internal::NoScopedLock ScopedLock; +#endif // ELPP_THREADING_ENABLED +/// @brief Base of thread safe class, this class is inheritable-only +class ThreadSafe { + public: + virtual inline void + acquireLock(void) ELPP_FINAL { + m_mutex.lock(); + } + virtual inline void + releaseLock(void) ELPP_FINAL { + m_mutex.unlock(); + } + virtual inline base::threading::Mutex& + lock(void) ELPP_FINAL { + return m_mutex; + } + + protected: + ThreadSafe(void) { + } + virtual ~ThreadSafe(void) { + } + + private: + base::threading::Mutex m_mutex; +}; + +#if ELPP_THREADING_ENABLED +#if !ELPP_USE_STD_THREADING +/// @brief Gets ID of currently running threading in windows systems. On unix, nothing is returned. +static std::string +getCurrentThreadId(void) { + std::stringstream ss; +#if (ELPP_OS_WINDOWS) + ss << GetCurrentThreadId(); +#endif // (ELPP_OS_WINDOWS) + return ss.str(); +} +#else +/// @brief Gets ID of currently running threading using std::this_thread::get_id() +static std::string +getCurrentThreadId(void) { + std::stringstream ss; + ss << std::this_thread::get_id(); + return ss.str(); +} +#endif // !ELPP_USE_STD_THREADING +#else +static inline std::string +getCurrentThreadId(void) { + return std::string(); +} +#endif // ELPP_THREADING_ENABLED +} // namespace threading +namespace utils { +class File : base::StaticClass { + public: + /// @brief Creates new out file stream for specified filename. + /// @return Pointer to newly created fstream or nullptr + static base::type::fstream_t* + newFileStream(const std::string& filename); + + /// @brief Gets size of file provided in stream + static std::size_t + getSizeOfFile(base::type::fstream_t* fs); + + /// @brief Determines whether or not provided path exist in current file system + static bool + pathExists(const char* path, bool considerFile = false); + + /// @brief Creates specified path on file system + /// @param path Path to create. + static bool + createPath(const std::string& path); + /// @brief Extracts path of filename with leading slash + static std::string + extractPathFromFilename(const std::string& fullPath, const char* seperator = base::consts::kFilePathSeperator); + /// @brief builds stripped filename and puts it in buff + static void + buildStrippedFilename(const char* filename, char buff[], + std::size_t limit = base::consts::kSourceFilenameMaxLength); + /// @brief builds base filename and puts it in buff + static void + buildBaseFilename(const std::string& fullPath, char buff[], + std::size_t limit = base::consts::kSourceFilenameMaxLength, + const char* seperator = base::consts::kFilePathSeperator); +}; +/// @brief String utilities helper class used internally. You should not use it. +class Str : base::StaticClass { + public: + /// @brief Checks if character is digit. Dont use libc implementation of it to prevent locale issues. + static inline bool + isDigit(char c) { + return c >= '0' && c <= '9'; + } + + /// @brief Matches wildcards, '*' and '?' only supported. + static bool + wildCardMatch(const char* str, const char* pattern); + + static std::string& + ltrim(std::string& str); + static std::string& + rtrim(std::string& str); + static std::string& + trim(std::string& str); + + /// @brief Determines whether or not str starts with specified string + /// @param str String to check + /// @param start String to check against + /// @return Returns true if starts with specified string, false otherwise + static bool + startsWith(const std::string& str, const std::string& start); + + /// @brief Determines whether or not str ends with specified string + /// @param str String to check + /// @param end String to check against + /// @return Returns true if ends with specified string, false otherwise + static bool + endsWith(const std::string& str, const std::string& end); + + /// @brief Replaces all instances of replaceWhat with 'replaceWith'. Original variable is changed for performance. + /// @param [in,out] str String to replace from + /// @param replaceWhat Character to replace + /// @param replaceWith Character to replace with + /// @return Modified version of str + static std::string& + replaceAll(std::string& str, char replaceWhat, char replaceWith); + + /// @brief Replaces all instances of 'replaceWhat' with 'replaceWith'. (String version) Replaces in place + /// @param str String to replace from + /// @param replaceWhat Character to replace + /// @param replaceWith Character to replace with + /// @return Modified (original) str + static std::string& + replaceAll(std::string& str, const std::string& replaceWhat, const std::string& replaceWith); + + static void + replaceFirstWithEscape(base::type::string_t& str, const base::type::string_t& replaceWhat, + const base::type::string_t& replaceWith); +#if defined(ELPP_UNICODE) + static void + replaceFirstWithEscape(base::type::string_t& str, const base::type::string_t& replaceWhat, + const std::string& replaceWith); +#endif // defined(ELPP_UNICODE) + /// @brief Converts string to uppercase + /// @param str String to convert + /// @return Uppercase string + static std::string& + toUpper(std::string& str); + + /// @brief Compares cstring equality - uses strcmp + static bool + cStringEq(const char* s1, const char* s2); + + /// @brief Compares cstring equality (case-insensitive) - uses toupper(char) + /// Dont use strcasecmp because of CRT (VC++) + static bool + cStringCaseEq(const char* s1, const char* s2); + + /// @brief Returns true if c exist in str + static bool + contains(const char* str, char c); + + static char* + convertAndAddToBuff(std::size_t n, int len, char* buf, const char* bufLim, bool zeroPadded = true); + static char* + addToBuff(const char* str, char* buf, const char* bufLim); + static char* + clearBuff(char buff[], std::size_t lim); + + /// @brief Converst wchar* to char* + /// NOTE: Need to free return value after use! + static char* + wcharPtrToCharPtr(const wchar_t* line); +}; +/// @brief Operating System helper static class used internally. You should not use it. +class OS : base::StaticClass { + public: +#if ELPP_OS_WINDOWS + /// @brief Gets environment variables for Windows based OS. + /// We are not using getenv(const char*) because of CRT deprecation + /// @param varname Variable name to get environment variable value for + /// @return If variable exist the value of it otherwise nullptr + static const char* + getWindowsEnvironmentVariable(const char* varname); +#endif // ELPP_OS_WINDOWS +#if ELPP_OS_ANDROID + /// @brief Reads android property value + static std::string + getProperty(const char* prop); + + /// @brief Reads android device name + static std::string + getDeviceName(void); +#endif // ELPP_OS_ANDROID + + /// @brief Runs command on terminal and returns the output. + /// + /// @detail This is applicable only on unix based systems, for all other OS, an empty string is returned. + /// @param command Bash command + /// @return Result of bash output or empty string if no result found. + static const std::string + getBashOutput(const char* command); + + /// @brief Gets environment variable. This is cross-platform and CRT safe (for VC++) + /// @param variableName Environment variable name + /// @param defaultVal If no environment variable or value found the value to return by default + /// @param alternativeBashCommand If environment variable not found what would be alternative bash command + /// in order to look for value user is looking for. E.g, for 'user' alternative command will 'whoami' + static std::string + getEnvironmentVariable(const char* variableName, const char* defaultVal, + const char* alternativeBashCommand = nullptr); + /// @brief Gets current username. + static std::string + currentUser(void); + + /// @brief Gets current host name or computer name. + /// + /// @detail For android systems this is device name with its manufacturer and model seperated by hyphen + static std::string + currentHost(void); + /// @brief Whether or not terminal supports colors + static bool + termSupportsColor(void); +}; +/// @brief Contains utilities for cross-platform date/time. This class make use of el::base::utils::Str +class DateTime : base::StaticClass { + public: + /// @brief Cross platform gettimeofday for Windows and unix platform. This can be used to determine current + /// microsecond. + /// + /// @detail For unix system it uses gettimeofday(timeval*, timezone*) and for Windows, a seperate implementation is + /// provided + /// @param [in,out] tv Pointer that gets updated + static void + gettimeofday(struct timeval* tv); + + /// @brief Gets current date and time with a subsecond part. + /// @param format User provided date/time format + /// @param ssPrec A pointer to base::SubsecondPrecision from configuration (non-null) + /// @returns string based date time in specified format. + static std::string + getDateTime(const char* format, const base::SubsecondPrecision* ssPrec); + + /// @brief Converts timeval (struct from ctime) to string using specified format and subsecond precision + static std::string + timevalToString(struct timeval tval, const char* format, const el::base::SubsecondPrecision* ssPrec); + + /// @brief Formats time to get unit accordingly, units like second if > 1000 or minutes if > 60000 etc + static base::type::string_t + formatTime(unsigned long long time, base::TimestampUnit timestampUnit); + + /// @brief Gets time difference in milli/micro second depending on timestampUnit + static unsigned long long + getTimeDifference(const struct timeval& endTime, const struct timeval& startTime, + base::TimestampUnit timestampUnit); + + static struct ::tm* + buildTimeInfo(struct timeval* currTime, struct ::tm* timeInfo); + + private: + static char* + parseFormat(char* buf, std::size_t bufSz, const char* format, const struct tm* tInfo, std::size_t msec, + const base::SubsecondPrecision* ssPrec); +}; +/// @brief Command line arguments for application if specified using el::Helpers::setArgs(..) or START_EASYLOGGINGPP(..) +class CommandLineArgs { + public: + CommandLineArgs(void) { + setArgs(0, static_cast(nullptr)); + } + CommandLineArgs(int argc, const char** argv) { + setArgs(argc, argv); + } + CommandLineArgs(int argc, char** argv) { + setArgs(argc, argv); + } + virtual ~CommandLineArgs(void) { + } + /// @brief Sets arguments and parses them + inline void + setArgs(int argc, const char** argv) { + setArgs(argc, const_cast(argv)); + } + /// @brief Sets arguments and parses them + void + setArgs(int argc, char** argv); + /// @brief Returns true if arguments contain paramKey with a value (seperated by '=') + bool + hasParamWithValue(const char* paramKey) const; + /// @brief Returns value of arguments + /// @see hasParamWithValue(const char*) + const char* + getParamValue(const char* paramKey) const; + /// @brief Return true if arguments has a param (not having a value) i,e without '=' + bool + hasParam(const char* paramKey) const; + /// @brief Returns true if no params available. This exclude argv[0] + bool + empty(void) const; + /// @brief Returns total number of arguments. This exclude argv[0] + std::size_t + size(void) const; + friend base::type::ostream_t& + operator<<(base::type::ostream_t& os, const CommandLineArgs& c); + + private: + int m_argc; + char** m_argv; + std::unordered_map m_paramsWithValue; + std::vector m_params; +}; +/// @brief Abstract registry (aka repository) that provides basic interface for pointer repository specified by T_Ptr +/// type. +/// +/// @detail Most of the functions are virtual final methods but anything implementing this abstract class should +/// implement unregisterAll() and deepCopy(const AbstractRegistry&) and write registerNew() method +/// according to container and few more methods; get() to find element, unregister() to unregister single entry. Please +/// note that this is thread-unsafe and should also implement thread-safety mechanisms in implementation. +template +class AbstractRegistry : public base::threading::ThreadSafe { + public: + typedef typename Container::iterator iterator; + typedef typename Container::const_iterator const_iterator; + + /// @brief Default constructor + AbstractRegistry(void) { + } + + /// @brief Move constructor that is useful for base classes + AbstractRegistry(AbstractRegistry&& sr) { + if (this == &sr) { + return; + } + unregisterAll(); + m_list = std::move(sr.m_list); + } + + bool + operator==(const AbstractRegistry& other) { + if (size() != other.size()) { + return false; + } + for (std::size_t i = 0; i < m_list.size(); ++i) { + if (m_list.at(i) != other.m_list.at(i)) { + return false; + } + } + return true; + } + + bool + operator!=(const AbstractRegistry& other) { + if (size() != other.size()) { + return true; + } + for (std::size_t i = 0; i < m_list.size(); ++i) { + if (m_list.at(i) != other.m_list.at(i)) { + return true; + } + } + return false; + } + + /// @brief Assignment move operator + AbstractRegistry& + operator=(AbstractRegistry&& sr) { + if (this == &sr) { + return *this; + } + unregisterAll(); + m_list = std::move(sr.m_list); + return *this; + } + + virtual ~AbstractRegistry(void) { + } + + /// @return Iterator pointer from start of repository + virtual inline iterator + begin(void) ELPP_FINAL { + return m_list.begin(); + } + + /// @return Iterator pointer from end of repository + virtual inline iterator + end(void) ELPP_FINAL { + return m_list.end(); + } + + /// @return Constant iterator pointer from start of repository + virtual inline const_iterator + cbegin(void) const ELPP_FINAL { + return m_list.cbegin(); + } + + /// @return End of repository + virtual inline const_iterator + cend(void) const ELPP_FINAL { + return m_list.cend(); + } + + /// @return Whether or not repository is empty + virtual inline bool + empty(void) const ELPP_FINAL { + return m_list.empty(); + } + + /// @return Size of repository + virtual inline std::size_t + size(void) const ELPP_FINAL { + return m_list.size(); + } + + /// @brief Returns underlying container by reference + virtual inline Container& + list(void) ELPP_FINAL { + return m_list; + } + + /// @brief Returns underlying container by constant reference. + virtual inline const Container& + list(void) const ELPP_FINAL { + return m_list; + } + + /// @brief Unregisters all the pointers from current repository. + virtual void + unregisterAll(void) = 0; + + protected: + virtual void + deepCopy(const AbstractRegistry&) = 0; + void + reinitDeepCopy(const AbstractRegistry& sr) { + unregisterAll(); + deepCopy(sr); + } + + private: + Container m_list; +}; + +/// @brief A pointer registry mechanism to manage memory and provide search functionalities. (non-predicate version) +/// +/// @detail NOTE: This is thread-unsafe implementation (although it contains lock function, it does not use these +/// functions) +/// of AbstractRegistry. Any implementation of this class should be +/// explicitly (by using lock functions) +template +class Registry : public AbstractRegistry> { + public: + typedef typename Registry::iterator iterator; + typedef typename Registry::const_iterator const_iterator; + + Registry(void) { + } + + /// @brief Copy constructor that is useful for base classes. Try to avoid this constructor, use move constructor. + Registry(const Registry& sr) : AbstractRegistry>() { + if (this == &sr) { + return; + } + this->reinitDeepCopy(sr); + } + + /// @brief Assignment operator that unregisters all the existing registeries and deeply copies each of repo element + /// @see unregisterAll() + /// @see deepCopy(const AbstractRegistry&) + Registry& + operator=(const Registry& sr) { + if (this == &sr) { + return *this; + } + this->reinitDeepCopy(sr); + return *this; + } + + virtual ~Registry(void) { + unregisterAll(); + } + + protected: + virtual void + unregisterAll(void) ELPP_FINAL { + if (!this->empty()) { + for (auto&& curr : this->list()) { + base::utils::safeDelete(curr.second); + } + this->list().clear(); + } + } + + /// @brief Registers new registry to repository. + virtual void + registerNew(const T_Key& uniqKey, T_Ptr* ptr) ELPP_FINAL { + unregister(uniqKey); + this->list().insert(std::make_pair(uniqKey, ptr)); + } + + /// @brief Unregisters single entry mapped to specified unique key + void + unregister(const T_Key& uniqKey) { + T_Ptr* existing = get(uniqKey); + if (existing != nullptr) { + this->list().erase(uniqKey); + base::utils::safeDelete(existing); + } + } + + /// @brief Gets pointer from repository. If none found, nullptr is returned. + T_Ptr* + get(const T_Key& uniqKey) { + iterator it = this->list().find(uniqKey); + return it == this->list().end() ? nullptr : it->second; + } + + private: + virtual void + deepCopy(const AbstractRegistry>& sr) ELPP_FINAL { + for (const_iterator it = sr.cbegin(); it != sr.cend(); ++it) { + registerNew(it->first, new T_Ptr(*it->second)); + } + } +}; + +/// @brief A pointer registry mechanism to manage memory and provide search functionalities. (predicate version) +/// +/// @detail NOTE: This is thread-unsafe implementation of AbstractRegistry. Any implementation of this +/// class should be made thread-safe explicitly +template +class RegistryWithPred : public AbstractRegistry> { + public: + typedef typename RegistryWithPred::iterator iterator; + typedef typename RegistryWithPred::const_iterator const_iterator; + + RegistryWithPred(void) { + } + + virtual ~RegistryWithPred(void) { + unregisterAll(); + } + + /// @brief Copy constructor that is useful for base classes. Try to avoid this constructor, use move constructor. + RegistryWithPred(const RegistryWithPred& sr) : AbstractRegistry>() { + if (this == &sr) { + return; + } + this->reinitDeepCopy(sr); + } + + /// @brief Assignment operator that unregisters all the existing registeries and deeply copies each of repo element + /// @see unregisterAll() + /// @see deepCopy(const AbstractRegistry&) + RegistryWithPred& + operator=(const RegistryWithPred& sr) { + if (this == &sr) { + return *this; + } + this->reinitDeepCopy(sr); + return *this; + } + + friend base::type::ostream_t& + operator<<(base::type::ostream_t& os, const RegistryWithPred& sr) { + for (const_iterator it = sr.list().begin(); it != sr.list().end(); ++it) { + os << ELPP_LITERAL(" ") << **it << ELPP_LITERAL("\n"); + } + return os; + } + + protected: + virtual void + unregisterAll(void) ELPP_FINAL { + if (!this->empty()) { + for (auto&& curr : this->list()) { + base::utils::safeDelete(curr); + } + this->list().clear(); + } + } + + virtual void + unregister(T_Ptr*& ptr) ELPP_FINAL { + if (ptr) { + iterator iter = this->begin(); + for (; iter != this->end(); ++iter) { + if (ptr == *iter) { + break; + } + } + if (iter != this->end() && *iter != nullptr) { + this->list().erase(iter); + base::utils::safeDelete(*iter); + } + } + } + + virtual inline void + registerNew(T_Ptr* ptr) ELPP_FINAL { + this->list().push_back(ptr); + } + + /// @brief Gets pointer from repository with speicifed arguments. Arguments are passed to predicate + /// in order to validate pointer. + template + T_Ptr* + get(const T& arg1, const T2 arg2) { + iterator iter = std::find_if(this->list().begin(), this->list().end(), Pred(arg1, arg2)); + if (iter != this->list().end() && *iter != nullptr) { + return *iter; + } + return nullptr; + } + + private: + virtual void + deepCopy(const AbstractRegistry>& sr) { + for (const_iterator it = sr.list().begin(); it != sr.list().end(); ++it) { + registerNew(new T_Ptr(**it)); + } + } +}; +class Utils { + public: + template + static bool + installCallback(const std::string& id, std::unordered_map* mapT) { + if (mapT->find(id) == mapT->end()) { + mapT->insert(std::make_pair(id, TPtr(new T()))); + return true; + } + return false; + } + + template + static void + uninstallCallback(const std::string& id, std::unordered_map* mapT) { + if (mapT->find(id) != mapT->end()) { + mapT->erase(id); + } + } + + template + static T* + callback(const std::string& id, std::unordered_map* mapT) { + typename std::unordered_map::iterator iter = mapT->find(id); + if (iter != mapT->end()) { + return static_cast(iter->second.get()); + } + return nullptr; + } +}; +} // namespace utils +} // namespace base +/// @brief Base of Easylogging++ friendly class +/// +/// @detail After inheriting this class publicly, implement pure-virtual function `void log(std::ostream&) const` +class Loggable { + public: + virtual ~Loggable(void) { + } + virtual void + log(el::base::type::ostream_t&) const = 0; + + private: + friend inline el::base::type::ostream_t& + operator<<(el::base::type::ostream_t& os, const Loggable& loggable) { + loggable.log(os); + return os; + } +}; +namespace base { +/// @brief Represents log format containing flags and date format. This is used internally to start initial log +class LogFormat : public Loggable { + public: + LogFormat(void); + LogFormat(Level level, const base::type::string_t& format); + LogFormat(const LogFormat& logFormat); + LogFormat(LogFormat&& logFormat); + LogFormat& + operator=(const LogFormat& logFormat); + virtual ~LogFormat(void) { + } + bool + operator==(const LogFormat& other); + + /// @brief Updates format to be used while logging. + /// @param userFormat User provided format + void + parseFromFormat(const base::type::string_t& userFormat); + + inline Level + level(void) const { + return m_level; + } + + inline const base::type::string_t& + userFormat(void) const { + return m_userFormat; + } + + inline const base::type::string_t& + format(void) const { + return m_format; + } + + inline const std::string& + dateTimeFormat(void) const { + return m_dateTimeFormat; + } + + inline base::type::EnumType + flags(void) const { + return m_flags; + } + + inline bool + hasFlag(base::FormatFlags flag) const { + return base::utils::hasFlag(flag, m_flags); + } + + virtual void + log(el::base::type::ostream_t& os) const { + os << m_format; + } + + protected: + /// @brief Updates date time format if available in currFormat. + /// @param index Index where %datetime, %date or %time was found + /// @param [in,out] currFormat current format that is being used to format + virtual void + updateDateFormat(std::size_t index, base::type::string_t& currFormat) ELPP_FINAL; + + /// @brief Updates %level from format. This is so that we dont have to do it at log-writing-time. It uses m_format + /// and m_level + virtual void + updateFormatSpec(void) ELPP_FINAL; + + inline void + addFlag(base::FormatFlags flag) { + base::utils::addFlag(flag, &m_flags); + } + + private: + Level m_level; + base::type::string_t m_userFormat; + base::type::string_t m_format; + std::string m_dateTimeFormat; + base::type::EnumType m_flags; + std::string m_currentUser; + std::string m_currentHost; + friend class el::Logger; // To resolve loggerId format specifier easily +}; +} // namespace base +/// @brief Resolving function for format specifier +typedef std::function FormatSpecifierValueResolver; +/// @brief User-provided custom format specifier +/// @see el::Helpers::installCustomFormatSpecifier +/// @see FormatSpecifierValueResolver +class CustomFormatSpecifier { + public: + CustomFormatSpecifier(const char* formatSpecifier, const FormatSpecifierValueResolver& resolver) + : m_formatSpecifier(formatSpecifier), m_resolver(resolver) { + } + inline const char* + formatSpecifier(void) const { + return m_formatSpecifier; + } + inline const FormatSpecifierValueResolver& + resolver(void) const { + return m_resolver; + } + inline bool + operator==(const char* formatSpecifier) { + return strcmp(m_formatSpecifier, formatSpecifier) == 0; + } + + private: + const char* m_formatSpecifier; + FormatSpecifierValueResolver m_resolver; +}; +/// @brief Represents single configuration that has representing level, configuration type and a string based value. +/// +/// @detail String based value means any value either its boolean, integer or string itself, it will be embedded inside +/// quotes and will be parsed later. +/// +/// Consider some examples below: +/// * el::Configuration confEnabledInfo(el::Level::Info, el::ConfigurationType::Enabled, "true"); +/// * el::Configuration confMaxLogFileSizeInfo(el::Level::Info, el::ConfigurationType::MaxLogFileSize, "2048"); +/// * el::Configuration confFilenameInfo(el::Level::Info, el::ConfigurationType::Filename, "/var/log/my.log"); +class Configuration : public Loggable { + public: + Configuration(const Configuration& c); + Configuration& + operator=(const Configuration& c); + + virtual ~Configuration(void) { + } + + /// @brief Full constructor used to sets value of configuration + Configuration(Level level, ConfigurationType configurationType, const std::string& value); + + /// @brief Gets level of current configuration + inline Level + level(void) const { + return m_level; + } + + /// @brief Gets configuration type of current configuration + inline ConfigurationType + configurationType(void) const { + return m_configurationType; + } + + /// @brief Gets string based configuration value + inline const std::string& + value(void) const { + return m_value; + } + + /// @brief Set string based configuration value + /// @param value Value to set. Values have to be std::string; For boolean values use "true", "false", for any + /// integral values + /// use them in quotes. They will be parsed when configuring + inline void + setValue(const std::string& value) { + m_value = value; + } + + virtual void + log(el::base::type::ostream_t& os) const; + + /// @brief Used to find configuration from configuration (pointers) repository. Avoid using it. + class Predicate { + public: + Predicate(Level level, ConfigurationType configurationType); + + bool + operator()(const Configuration* conf) const; + + private: + Level m_level; + ConfigurationType m_configurationType; + }; + + private: + Level m_level; + ConfigurationType m_configurationType; + std::string m_value; +}; + +/// @brief Thread-safe Configuration repository +/// +/// @detail This repository represents configurations for all the levels and configuration type mapped to a value. +class Configurations : public base::utils::RegistryWithPred { + public: + /// @brief Default constructor with empty repository + Configurations(void); + + /// @brief Constructor used to set configurations using configuration file. + /// @param configurationFile Full path to configuration file + /// @param useDefaultsForRemaining Lets you set the remaining configurations to default. + /// @param base If provided, this configuration will be based off existing repository that this argument is pointing + /// to. + /// @see parseFromFile(const std::string&, Configurations* base) + /// @see setRemainingToDefault() + Configurations(const std::string& configurationFile, bool useDefaultsForRemaining = true, + Configurations* base = nullptr); + + virtual ~Configurations(void) { + } + + /// @brief Parses configuration from file. + /// @param configurationFile Full path to configuration file + /// @param base Configurations to base new configuration repository off. This value is used when you want to use + /// existing Configurations to base all the values and then set rest of configuration via configuration file. + /// @return True if successfully parsed, false otherwise. You may define 'ELPP_DEBUG_ASSERT_FAILURE' to make sure + /// you + /// do not proceed without successful parse. + bool + parseFromFile(const std::string& configurationFile, Configurations* base = nullptr); + + /// @brief Parse configurations from configuration string. + /// + /// @detail This configuration string has same syntax as configuration file contents. Make sure all the necessary + /// new line characters are provided. + /// @param base Configurations to base new configuration repository off. This value is used when you want to use + /// existing Configurations to base all the values and then set rest of configuration via configuration text. + /// @return True if successfully parsed, false otherwise. You may define 'ELPP_DEBUG_ASSERT_FAILURE' to make sure + /// you + /// do not proceed without successful parse. + bool + parseFromText(const std::string& configurationsString, Configurations* base = nullptr); + + /// @brief Sets configuration based-off an existing configurations. + /// @param base Pointer to existing configurations. + void + setFromBase(Configurations* base); + + /// @brief Determines whether or not specified configuration type exists in the repository. + /// + /// @detail Returns as soon as first level is found. + /// @param configurationType Type of configuration to check existence for. + bool + hasConfiguration(ConfigurationType configurationType); + + /// @brief Determines whether or not specified configuration type exists for specified level + /// @param level Level to check + /// @param configurationType Type of configuration to check existence for. + bool + hasConfiguration(Level level, ConfigurationType configurationType); + + /// @brief Sets value of configuration for specified level. + /// + /// @detail Any existing configuration for specified level will be replaced. Also note that configuration types + /// ConfigurationType::SubsecondPrecision and ConfigurationType::PerformanceTracking will be ignored if not set for + /// Level::Global because these configurations are not dependant on level. + /// @param level Level to set configuration for (el::Level). + /// @param configurationType Type of configuration (el::ConfigurationType) + /// @param value A string based value. Regardless of what the data type of configuration is, it will always be + /// string from users' point of view. This is then parsed later to be used internally. + /// @see Configuration::setValue(const std::string& value) + /// @see el::Level + /// @see el::ConfigurationType + void + set(Level level, ConfigurationType configurationType, const std::string& value); + + /// @brief Sets single configuration based on other single configuration. + /// @see set(Level level, ConfigurationType configurationType, const std::string& value) + void + set(Configuration* conf); + + inline Configuration* + get(Level level, ConfigurationType configurationType) { + base::threading::ScopedLock scopedLock(lock()); + return RegistryWithPred::get(level, configurationType); + } + + /// @brief Sets configuration for all levels. + /// @param configurationType Type of configuration + /// @param value String based value + /// @see Configurations::set(Level level, ConfigurationType configurationType, const std::string& value) + inline void + setGlobally(ConfigurationType configurationType, const std::string& value) { + setGlobally(configurationType, value, false); + } + + /// @brief Clears repository so that all the configurations are unset + inline void + clear(void) { + base::threading::ScopedLock scopedLock(lock()); + unregisterAll(); + } + + /// @brief Gets configuration file used in parsing this configurations. + /// + /// @detail If this repository was set manually or by text this returns empty string. + inline const std::string& + configurationFile(void) const { + return m_configurationFile; + } + + /// @brief Sets configurations to "factory based" configurations. + void + setToDefault(void); + + /// @brief Lets you set the remaining configurations to default. + /// + /// @detail By remaining, it means that the level/type a configuration does not exist for. + /// This function is useful when you want to minimize chances of failures, e.g, if you have a configuration file + /// that sets configuration for all the configurations except for Enabled or not, we use this so that ENABLED is set + /// to default i.e, true. If you dont do this explicitly (either by calling this function or by using second param + /// in Constructor and try to access a value, an error is thrown + void + setRemainingToDefault(void); + + /// @brief Parser used internally to parse configurations from file or text. + /// + /// @detail This class makes use of base::utils::Str. + /// You should not need this unless you are working on some tool for Easylogging++ + class Parser : base::StaticClass { + public: + /// @brief Parses configuration from file. + /// @param configurationFile Full path to configuration file + /// @param sender Sender configurations pointer. Usually 'this' is used from calling class + /// @param base Configurations to base new configuration repository off. This value is used when you want to use + /// existing Configurations to base all the values and then set rest of configuration via configuration + /// file. + /// @return True if successfully parsed, false otherwise. You may define '_STOP_ON_FIRSTELPP_ASSERTION' to make + /// sure you + /// do not proceed without successful parse. + static bool + parseFromFile(const std::string& configurationFile, Configurations* sender, Configurations* base = nullptr); + + /// @brief Parse configurations from configuration string. + /// + /// @detail This configuration string has same syntax as configuration file contents. Make sure all the + /// necessary new line characters are provided. You may define '_STOP_ON_FIRSTELPP_ASSERTION' to make sure you + /// do not proceed without successful parse (This is recommended) + /// @param configurationsString the configuration in plain text format + /// @param sender Sender configurations pointer. Usually 'this' is used from calling class + /// @param base Configurations to base new configuration repository off. This value is used when you want to use + /// existing Configurations to base all the values and then set rest of configuration via configuration + /// text. + /// @return True if successfully parsed, false otherwise. + static bool + parseFromText(const std::string& configurationsString, Configurations* sender, Configurations* base = nullptr); + + private: + friend class el::Loggers; + static void + ignoreComments(std::string* line); + static bool + isLevel(const std::string& line); + static bool + isComment(const std::string& line); + static inline bool + isConfig(const std::string& line); + static bool + parseLine(std::string* line, std::string* currConfigStr, std::string* currLevelStr, Level* currLevel, + Configurations* conf); + }; + + private: + std::string m_configurationFile; + bool m_isFromFile; + friend class el::Loggers; + + /// @brief Unsafely sets configuration if does not already exist + void + unsafeSetIfNotExist(Level level, ConfigurationType configurationType, const std::string& value); + + /// @brief Thread unsafe set + void + unsafeSet(Level level, ConfigurationType configurationType, const std::string& value); + + /// @brief Sets configurations for all levels including Level::Global if includeGlobalLevel is true + /// @see Configurations::setGlobally(ConfigurationType configurationType, const std::string& value) + void + setGlobally(ConfigurationType configurationType, const std::string& value, bool includeGlobalLevel); + + /// @brief Sets configurations (Unsafely) for all levels including Level::Global if includeGlobalLevel is true + /// @see Configurations::setGlobally(ConfigurationType configurationType, const std::string& value) + void + unsafeSetGlobally(ConfigurationType configurationType, const std::string& value, bool includeGlobalLevel); +}; + +namespace base { +typedef std::shared_ptr FileStreamPtr; +typedef std::unordered_map LogStreamsReferenceMap; +/// @brief Configurations with data types. +/// +/// @detail el::Configurations have string based values. This is whats used internally in order to read correct +/// configurations. This is to perform faster while writing logs using correct configurations. +/// +/// This is thread safe and final class containing non-virtual destructor (means nothing should inherit this class) +class TypedConfigurations : public base::threading::ThreadSafe { + public: + /// @brief Constructor to initialize (construct) the object off el::Configurations + /// @param configurations Configurations pointer/reference to base this typed configurations off. + /// @param logStreamsReference Use ELPP->registeredLoggers()->logStreamsReference() + TypedConfigurations(Configurations* configurations, base::LogStreamsReferenceMap* logStreamsReference); + + TypedConfigurations(const TypedConfigurations& other); + + virtual ~TypedConfigurations(void) { + } + + const Configurations* + configurations(void) const { + return m_configurations; + } + + bool + enabled(Level level); + bool + toFile(Level level); + const std::string& + filename(Level level); + bool + toStandardOutput(Level level); + const base::LogFormat& + logFormat(Level level); + const base::SubsecondPrecision& + subsecondPrecision(Level level = Level::Global); + const base::MillisecondsWidth& + millisecondsWidth(Level level = Level::Global); + bool + performanceTracking(Level level = Level::Global); + base::type::fstream_t* + fileStream(Level level); + std::size_t + maxLogFileSize(Level level); + std::size_t + logFlushThreshold(Level level); + + private: + Configurations* m_configurations; + std::unordered_map m_enabledMap; + std::unordered_map m_toFileMap; + std::unordered_map m_filenameMap; + std::unordered_map m_toStandardOutputMap; + std::unordered_map m_logFormatMap; + std::unordered_map m_subsecondPrecisionMap; + std::unordered_map m_performanceTrackingMap; + std::unordered_map m_fileStreamMap; + std::unordered_map m_maxLogFileSizeMap; + std::unordered_map m_logFlushThresholdMap; + base::LogStreamsReferenceMap* m_logStreamsReference; + + friend class el::Helpers; + friend class el::base::MessageBuilder; + friend class el::base::Writer; + friend class el::base::DefaultLogDispatchCallback; + friend class el::base::LogDispatcher; + + template + inline Conf_T + getConfigByVal(Level level, const std::unordered_map* confMap, const char* confName) { + base::threading::ScopedLock scopedLock(lock()); + return unsafeGetConfigByVal(level, confMap, confName); // This is not unsafe anymore - mutex locked in scope + } + + template + inline Conf_T& + getConfigByRef(Level level, std::unordered_map* confMap, const char* confName) { + base::threading::ScopedLock scopedLock(lock()); + return unsafeGetConfigByRef(level, confMap, confName); // This is not unsafe anymore - mutex locked in scope + } + + template + Conf_T + unsafeGetConfigByVal(Level level, const std::unordered_map* confMap, const char* confName) { + ELPP_UNUSED(confName); + typename std::unordered_map::const_iterator it = confMap->find(level); + if (it == confMap->end()) { + try { + return confMap->at(Level::Global); + } catch (...) { + ELPP_INTERNAL_ERROR("Unable to get configuration [" + << confName << "] for level [" << LevelHelper::convertToString(level) << "]" + << std::endl + << "Please ensure you have properly configured logger.", + false); + return Conf_T(); + } + } + return it->second; + } + + template + Conf_T& + unsafeGetConfigByRef(Level level, std::unordered_map* confMap, const char* confName) { + ELPP_UNUSED(confName); + typename std::unordered_map::iterator it = confMap->find(level); + if (it == confMap->end()) { + try { + return confMap->at(Level::Global); + } catch (...) { + ELPP_INTERNAL_ERROR("Unable to get configuration [" + << confName << "] for level [" << LevelHelper::convertToString(level) << "]" + << std::endl + << "Please ensure you have properly configured logger.", + false); + } + } + return it->second; + } + + template + void + setValue(Level level, const Conf_T& value, std::unordered_map* confMap, + bool includeGlobalLevel = true) { + // If map is empty and we are allowed to add into generic level (Level::Global), do it! + if (confMap->empty() && includeGlobalLevel) { + confMap->insert(std::make_pair(Level::Global, value)); + return; + } + // If same value exist in generic level already, dont add it to explicit level + typename std::unordered_map::iterator it = confMap->find(Level::Global); + if (it != confMap->end() && it->second == value) { + return; + } + // Now make sure we dont double up values if we really need to add it to explicit level + it = confMap->find(level); + if (it == confMap->end()) { + // Value not found for level, add new + confMap->insert(std::make_pair(level, value)); + } else { + // Value found, just update value + confMap->at(level) = value; + } + } + + void + build(Configurations* configurations); + unsigned long + getULong(std::string confVal); + std::string + resolveFilename(const std::string& filename); + void + insertFile(Level level, const std::string& fullFilename); + bool + unsafeValidateFileRolling(Level level, const PreRollOutCallback& preRollOutCallback); + + inline bool + validateFileRolling(Level level, const PreRollOutCallback& preRollOutCallback) { + base::threading::ScopedLock scopedLock(lock()); + return unsafeValidateFileRolling(level, preRollOutCallback); + } +}; +/// @brief Class that keeps record of current line hit for occasional logging +class HitCounter { + public: + HitCounter(void) : m_filename(""), m_lineNumber(0), m_hitCounts(0) { + } + + HitCounter(const char* filename, base::type::LineNumber lineNumber) + : m_filename(filename), m_lineNumber(lineNumber), m_hitCounts(0) { + } + + HitCounter(const HitCounter& hitCounter) + : m_filename(hitCounter.m_filename), + m_lineNumber(hitCounter.m_lineNumber), + m_hitCounts(hitCounter.m_hitCounts) { + } + + HitCounter& + operator=(const HitCounter& hitCounter) { + if (&hitCounter != this) { + m_filename = hitCounter.m_filename; + m_lineNumber = hitCounter.m_lineNumber; + m_hitCounts = hitCounter.m_hitCounts; + } + return *this; + } + + virtual ~HitCounter(void) { + } + + /// @brief Resets location of current hit counter + inline void + resetLocation(const char* filename, base::type::LineNumber lineNumber) { + m_filename = filename; + m_lineNumber = lineNumber; + } + + /// @brief Validates hit counts and resets it if necessary + inline void + validateHitCounts(std::size_t n) { + if (m_hitCounts >= base::consts::kMaxLogPerCounter) { + m_hitCounts = (n >= 1 ? base::consts::kMaxLogPerCounter % n : 0); + } + ++m_hitCounts; + } + + inline const char* + filename(void) const { + return m_filename; + } + + inline base::type::LineNumber + lineNumber(void) const { + return m_lineNumber; + } + + inline std::size_t + hitCounts(void) const { + return m_hitCounts; + } + + inline void + increment(void) { + ++m_hitCounts; + } + + class Predicate { + public: + Predicate(const char* filename, base::type::LineNumber lineNumber) + : m_filename(filename), m_lineNumber(lineNumber) { + } + inline bool + operator()(const HitCounter* counter) { + return ((counter != nullptr) && (strcmp(counter->m_filename, m_filename) == 0) && + (counter->m_lineNumber == m_lineNumber)); + } + + private: + const char* m_filename; + base::type::LineNumber m_lineNumber; + }; + + private: + const char* m_filename; + base::type::LineNumber m_lineNumber; + std::size_t m_hitCounts; +}; +/// @brief Repository for hit counters used across the application +class RegisteredHitCounters : public base::utils::RegistryWithPred { + public: + /// @brief Validates counter for every N, i.e, registers new if does not exist otherwise updates original one + /// @return True if validation resulted in triggering hit. Meaning logs should be written everytime true is returned + bool + validateEveryN(const char* filename, base::type::LineNumber lineNumber, std::size_t n); + + /// @brief Validates counter for hits >= N, i.e, registers new if does not exist otherwise updates original one + /// @return True if validation resulted in triggering hit. Meaning logs should be written everytime true is returned + bool + validateAfterN(const char* filename, base::type::LineNumber lineNumber, std::size_t n); + + /// @brief Validates counter for hits are <= n, i.e, registers new if does not exist otherwise updates original one + /// @return True if validation resulted in triggering hit. Meaning logs should be written everytime true is returned + bool + validateNTimes(const char* filename, base::type::LineNumber lineNumber, std::size_t n); + + /// @brief Gets hit counter registered at specified position + inline const base::HitCounter* + getCounter(const char* filename, base::type::LineNumber lineNumber) { + base::threading::ScopedLock scopedLock(lock()); + return get(filename, lineNumber); + } +}; +/// @brief Action to be taken for dispatching +enum class DispatchAction : base::type::EnumType { None = 1, NormalLog = 2, SysLog = 4 }; +} // namespace base +template +class Callback : protected base::threading::ThreadSafe { + public: + Callback(void) : m_enabled(true) { + } + inline bool + enabled(void) const { + return m_enabled; + } + inline void + setEnabled(bool enabled) { + base::threading::ScopedLock scopedLock(lock()); + m_enabled = enabled; + } + + protected: + virtual void + handle(const T* handlePtr) = 0; + + private: + bool m_enabled; +}; +class LogDispatchData { + public: + LogDispatchData() : m_logMessage(nullptr), m_dispatchAction(base::DispatchAction::None) { + } + inline const LogMessage* + logMessage(void) const { + return m_logMessage; + } + inline base::DispatchAction + dispatchAction(void) const { + return m_dispatchAction; + } + inline void + setLogMessage(LogMessage* logMessage) { + m_logMessage = logMessage; + } + inline void + setDispatchAction(base::DispatchAction dispatchAction) { + m_dispatchAction = dispatchAction; + } + + private: + LogMessage* m_logMessage; + base::DispatchAction m_dispatchAction; + friend class base::LogDispatcher; +}; +class LogDispatchCallback : public Callback { + protected: + virtual void + handle(const LogDispatchData* data); + base::threading::Mutex& + fileHandle(const LogDispatchData* data); + + private: + friend class base::LogDispatcher; + std::unordered_map> m_fileLocks; + base::threading::Mutex m_fileLocksMapLock; +}; +class PerformanceTrackingCallback : public Callback { + private: + friend class base::PerformanceTracker; +}; +class LoggerRegistrationCallback : public Callback { + private: + friend class base::RegisteredLoggers; +}; +class LogBuilder : base::NoCopy { + public: + LogBuilder() : m_termSupportsColor(base::utils::OS::termSupportsColor()) { + } + virtual ~LogBuilder(void) { + ELPP_INTERNAL_INFO(3, "Destroying log builder...") + } + virtual base::type::string_t + build(const LogMessage* logMessage, bool appendNewLine) const = 0; + void + convertToColoredOutput(base::type::string_t* logLine, Level level); + + private: + bool m_termSupportsColor; + friend class el::base::DefaultLogDispatchCallback; +}; +typedef std::shared_ptr LogBuilderPtr; +/// @brief Represents a logger holding ID and configurations we need to write logs +/// +/// @detail This class does not write logs itself instead its used by writer to read configuations from. +class Logger : public base::threading::ThreadSafe, public Loggable { + public: + Logger(const std::string& id, base::LogStreamsReferenceMap* logStreamsReference); + Logger(const std::string& id, const Configurations& configurations, + base::LogStreamsReferenceMap* logStreamsReference); + Logger(const Logger& logger); + Logger& + operator=(const Logger& logger); + + virtual ~Logger(void) { + base::utils::safeDelete(m_typedConfigurations); + } + + virtual inline void + log(el::base::type::ostream_t& os) const { + os << m_id.c_str(); + } + + /// @brief Configures the logger using specified configurations. + void + configure(const Configurations& configurations); + + /// @brief Reconfigures logger using existing configurations + void + reconfigure(void); + + inline const std::string& + id(void) const { + return m_id; + } + + inline const std::string& + parentApplicationName(void) const { + return m_parentApplicationName; + } + + inline void + setParentApplicationName(const std::string& parentApplicationName) { + m_parentApplicationName = parentApplicationName; + } + + inline Configurations* + configurations(void) { + return &m_configurations; + } + + inline base::TypedConfigurations* + typedConfigurations(void) { + return m_typedConfigurations; + } + + static bool + isValidId(const std::string& id); + + /// @brief Flushes logger to sync all log files for all levels + void + flush(void); + + void + flush(Level level, base::type::fstream_t* fs); + + inline bool + isFlushNeeded(Level level) { + return ++m_unflushedCount.find(level)->second >= m_typedConfigurations->logFlushThreshold(level); + } + + inline LogBuilder* + logBuilder(void) const { + return m_logBuilder.get(); + } + + inline void + setLogBuilder(const LogBuilderPtr& logBuilder) { + m_logBuilder = logBuilder; + } + + inline bool + enabled(Level level) const { + return m_typedConfigurations->enabled(level); + } + +#if ELPP_VARIADIC_TEMPLATES_SUPPORTED +#define LOGGER_LEVEL_WRITERS_SIGNATURES(FUNCTION_NAME) \ + template \ + inline void FUNCTION_NAME(const char*, const T&, const Args&...); \ + template \ + inline void FUNCTION_NAME(const T&); + + template + inline void + verbose(int, const char*, const T&, const Args&...); + + template + inline void + verbose(int, const T&); + + LOGGER_LEVEL_WRITERS_SIGNATURES(info) + LOGGER_LEVEL_WRITERS_SIGNATURES(debug) + LOGGER_LEVEL_WRITERS_SIGNATURES(warn) + LOGGER_LEVEL_WRITERS_SIGNATURES(error) + LOGGER_LEVEL_WRITERS_SIGNATURES(fatal) + LOGGER_LEVEL_WRITERS_SIGNATURES(trace) +#undef LOGGER_LEVEL_WRITERS_SIGNATURES +#endif // ELPP_VARIADIC_TEMPLATES_SUPPORTED + private: + std::string m_id; + base::TypedConfigurations* m_typedConfigurations; + base::type::stringstream_t m_stream; + std::string m_parentApplicationName; + bool m_isConfigured; + Configurations m_configurations; + std::unordered_map m_unflushedCount; + base::LogStreamsReferenceMap* m_logStreamsReference; + LogBuilderPtr m_logBuilder; + + friend class el::LogMessage; + friend class el::Loggers; + friend class el::Helpers; + friend class el::base::RegisteredLoggers; + friend class el::base::DefaultLogDispatchCallback; + friend class el::base::MessageBuilder; + friend class el::base::Writer; + friend class el::base::PErrorWriter; + friend class el::base::Storage; + friend class el::base::PerformanceTracker; + friend class el::base::LogDispatcher; + + Logger(void); + +#if ELPP_VARIADIC_TEMPLATES_SUPPORTED + template + void + log_(Level, int, const char*, const T&, const Args&...); + + template + inline void + log_(Level, int, const T&); + + template + void + log(Level, const char*, const T&, const Args&...); + + template + inline void + log(Level, const T&); +#endif // ELPP_VARIADIC_TEMPLATES_SUPPORTED + + void + initUnflushedCount(void); + + inline base::type::stringstream_t& + stream(void) { + return m_stream; + } + + void + resolveLoggerFormatSpec(void) const; +}; +namespace base { +/// @brief Loggers repository +class RegisteredLoggers : public base::utils::Registry { + public: + explicit RegisteredLoggers(const LogBuilderPtr& defaultLogBuilder); + + virtual ~RegisteredLoggers(void) { + unsafeFlushAll(); + } + + inline void + setDefaultConfigurations(const Configurations& configurations) { + base::threading::ScopedLock scopedLock(lock()); + m_defaultConfigurations.setFromBase(const_cast(&configurations)); + } + + inline Configurations* + defaultConfigurations(void) { + return &m_defaultConfigurations; + } + + Logger* + get(const std::string& id, bool forceCreation = true); + + template + inline bool + installLoggerRegistrationCallback(const std::string& id) { + return base::utils::Utils::installCallback( + id, &m_loggerRegistrationCallbacks); + } + + template + inline void + uninstallLoggerRegistrationCallback(const std::string& id) { + base::utils::Utils::uninstallCallback( + id, &m_loggerRegistrationCallbacks); + } + + template + inline T* + loggerRegistrationCallback(const std::string& id) { + return base::utils::Utils::callback( + id, &m_loggerRegistrationCallbacks); + } + + bool + remove(const std::string& id); + + inline bool + has(const std::string& id) { + return get(id, false) != nullptr; + } + + inline void + unregister(Logger*& logger) { + base::threading::ScopedLock scopedLock(lock()); + base::utils::Registry::unregister(logger->id()); + } + + inline base::LogStreamsReferenceMap* + logStreamsReference(void) { + return &m_logStreamsReference; + } + + inline void + flushAll(void) { + base::threading::ScopedLock scopedLock(lock()); + unsafeFlushAll(); + } + + inline void + setDefaultLogBuilder(LogBuilderPtr& logBuilderPtr) { + base::threading::ScopedLock scopedLock(lock()); + m_defaultLogBuilder = logBuilderPtr; + } + + private: + LogBuilderPtr m_defaultLogBuilder; + Configurations m_defaultConfigurations; + base::LogStreamsReferenceMap m_logStreamsReference; + std::unordered_map m_loggerRegistrationCallbacks; + friend class el::base::Storage; + + void + unsafeFlushAll(void); +}; +/// @brief Represents registries for verbose logging +class VRegistry : base::NoCopy, public base::threading::ThreadSafe { + public: + explicit VRegistry(base::type::VerboseLevel level, base::type::EnumType* pFlags); + + /// @brief Sets verbose level. Accepted range is 0-9 + void + setLevel(base::type::VerboseLevel level); + + inline base::type::VerboseLevel + level(void) const { + return m_level; + } + + inline void + clearModules(void) { + base::threading::ScopedLock scopedLock(lock()); + m_modules.clear(); + } + + void + setModules(const char* modules); + + bool + allowed(base::type::VerboseLevel vlevel, const char* file); + + inline const std::unordered_map& + modules(void) const { + return m_modules; + } + + void + setFromArgs(const base::utils::CommandLineArgs* commandLineArgs); + + /// @brief Whether or not vModules enabled + inline bool + vModulesEnabled(void) { + return !base::utils::hasFlag(LoggingFlag::DisableVModules, *m_pFlags); + } + + private: + base::type::VerboseLevel m_level; + base::type::EnumType* m_pFlags; + std::unordered_map m_modules; +}; +} // namespace base +class LogMessage { + public: + LogMessage(Level level, const std::string& file, base::type::LineNumber line, const std::string& func, + base::type::VerboseLevel verboseLevel, Logger* logger) + : m_level(level), + m_file(file), + m_line(line), + m_func(func), + m_verboseLevel(verboseLevel), + m_logger(logger), + m_message(logger->stream().str()) { + } + inline Level + level(void) const { + return m_level; + } + inline const std::string& + file(void) const { + return m_file; + } + inline base::type::LineNumber + line(void) const { + return m_line; + } + inline const std::string& + func(void) const { + return m_func; + } + inline base::type::VerboseLevel + verboseLevel(void) const { + return m_verboseLevel; + } + inline Logger* + logger(void) const { + return m_logger; + } + inline const base::type::string_t& + message(void) const { + return m_message; + } + + private: + Level m_level; + std::string m_file; + base::type::LineNumber m_line; + std::string m_func; + base::type::VerboseLevel m_verboseLevel; + Logger* m_logger; + base::type::string_t m_message; +}; +namespace base { +#if ELPP_ASYNC_LOGGING +class AsyncLogItem { + public: + explicit AsyncLogItem(const LogMessage& logMessage, const LogDispatchData& data, + const base::type::string_t& logLine) + : m_logMessage(logMessage), m_dispatchData(data), m_logLine(logLine) { + } + virtual ~AsyncLogItem() { + } + inline LogMessage* + logMessage(void) { + return &m_logMessage; + } + inline LogDispatchData* + data(void) { + return &m_dispatchData; + } + inline base::type::string_t + logLine(void) { + return m_logLine; + } + + private: + LogMessage m_logMessage; + LogDispatchData m_dispatchData; + base::type::string_t m_logLine; +}; +class AsyncLogQueue : public base::threading::ThreadSafe { + public: + virtual ~AsyncLogQueue() { + ELPP_INTERNAL_INFO(6, "~AsyncLogQueue"); + } + + inline AsyncLogItem + next(void) { + base::threading::ScopedLock scopedLock(lock()); + AsyncLogItem result = m_queue.front(); + m_queue.pop(); + return result; + } + + inline void + push(const AsyncLogItem& item) { + base::threading::ScopedLock scopedLock(lock()); + m_queue.push(item); + } + inline void + pop(void) { + base::threading::ScopedLock scopedLock(lock()); + m_queue.pop(); + } + inline AsyncLogItem + front(void) { + base::threading::ScopedLock scopedLock(lock()); + return m_queue.front(); + } + inline bool + empty(void) { + base::threading::ScopedLock scopedLock(lock()); + return m_queue.empty(); + } + + private: + std::queue m_queue; +}; +class IWorker { + public: + virtual ~IWorker() { + } + virtual void + start() = 0; +}; +#endif // ELPP_ASYNC_LOGGING +/// @brief Easylogging++ management storage +class Storage : base::NoCopy, public base::threading::ThreadSafe { + public: +#if ELPP_ASYNC_LOGGING + Storage(const LogBuilderPtr& defaultLogBuilder, base::IWorker* asyncDispatchWorker); +#else + explicit Storage(const LogBuilderPtr& defaultLogBuilder); +#endif // ELPP_ASYNC_LOGGING + + virtual ~Storage(void); + + inline bool + validateEveryNCounter(const char* filename, base::type::LineNumber lineNumber, std::size_t occasion) { + return hitCounters()->validateEveryN(filename, lineNumber, occasion); + } + + inline bool + validateAfterNCounter(const char* filename, base::type::LineNumber lineNumber, std::size_t n) { + return hitCounters()->validateAfterN(filename, lineNumber, n); + } + + inline bool + validateNTimesCounter(const char* filename, base::type::LineNumber lineNumber, std::size_t n) { + return hitCounters()->validateNTimes(filename, lineNumber, n); + } + + inline base::RegisteredHitCounters* + hitCounters(void) const { + return m_registeredHitCounters; + } + + inline base::RegisteredLoggers* + registeredLoggers(void) const { + return m_registeredLoggers; + } + + inline base::VRegistry* + vRegistry(void) const { + return m_vRegistry; + } + +#if ELPP_ASYNC_LOGGING + inline base::AsyncLogQueue* + asyncLogQueue(void) const { + return m_asyncLogQueue; + } +#endif // ELPP_ASYNC_LOGGING + + inline const base::utils::CommandLineArgs* + commandLineArgs(void) const { + return &m_commandLineArgs; + } + + inline void + addFlag(LoggingFlag flag) { + base::utils::addFlag(flag, &m_flags); + } + + inline void + removeFlag(LoggingFlag flag) { + base::utils::removeFlag(flag, &m_flags); + } + + inline bool + hasFlag(LoggingFlag flag) const { + return base::utils::hasFlag(flag, m_flags); + } + + inline base::type::EnumType + flags(void) const { + return m_flags; + } + + inline void + setFlags(base::type::EnumType flags) { + m_flags = flags; + } + + inline void + setPreRollOutCallback(const PreRollOutCallback& callback) { + m_preRollOutCallback = callback; + } + + inline void + unsetPreRollOutCallback(void) { + m_preRollOutCallback = base::defaultPreRollOutCallback; + } + + inline PreRollOutCallback& + preRollOutCallback(void) { + return m_preRollOutCallback; + } + + bool + hasCustomFormatSpecifier(const char* formatSpecifier); + void + installCustomFormatSpecifier(const CustomFormatSpecifier& customFormatSpecifier); + bool + uninstallCustomFormatSpecifier(const char* formatSpecifier); + + const std::vector* + customFormatSpecifiers(void) const { + return &m_customFormatSpecifiers; + } + + base::threading::Mutex& + customFormatSpecifiersLock() { + return m_customFormatSpecifiersLock; + } + + inline void + setLoggingLevel(Level level) { + m_loggingLevel = level; + } + + template + inline bool + installLogDispatchCallback(const std::string& id) { + return base::utils::Utils::installCallback(id, &m_logDispatchCallbacks); + } + + template + inline void + uninstallLogDispatchCallback(const std::string& id) { + base::utils::Utils::uninstallCallback(id, &m_logDispatchCallbacks); + } + template + inline T* + logDispatchCallback(const std::string& id) { + return base::utils::Utils::callback(id, &m_logDispatchCallbacks); + } + +#if defined(ELPP_FEATURE_ALL) || defined(ELPP_FEATURE_PERFORMANCE_TRACKING) + template + inline bool + installPerformanceTrackingCallback(const std::string& id) { + return base::utils::Utils::installCallback( + id, &m_performanceTrackingCallbacks); + } + + template + inline void + uninstallPerformanceTrackingCallback(const std::string& id) { + base::utils::Utils::uninstallCallback( + id, &m_performanceTrackingCallbacks); + } + + template + inline T* + performanceTrackingCallback(const std::string& id) { + return base::utils::Utils::callback( + id, &m_performanceTrackingCallbacks); + } +#endif // defined(ELPP_FEATURE_ALL) || defined(ELPP_FEATURE_PERFORMANCE_TRACKING) + + /// @brief Sets thread name for current thread. Requires std::thread + inline void + setThreadName(const std::string& name) { + if (name.empty()) + return; + base::threading::ScopedLock scopedLock(m_threadNamesLock); + m_threadNames[base::threading::getCurrentThreadId()] = name; + } + + inline std::string + getThreadName(const std::string& threadId) { + base::threading::ScopedLock scopedLock(m_threadNamesLock); + std::unordered_map::const_iterator it = m_threadNames.find(threadId); + if (it == m_threadNames.end()) { + return threadId; + } + return it->second; + } + + private: + base::RegisteredHitCounters* m_registeredHitCounters; + base::RegisteredLoggers* m_registeredLoggers; + base::type::EnumType m_flags; + base::VRegistry* m_vRegistry; +#if ELPP_ASYNC_LOGGING + base::AsyncLogQueue* m_asyncLogQueue; + base::IWorker* m_asyncDispatchWorker; +#endif // ELPP_ASYNC_LOGGING + base::utils::CommandLineArgs m_commandLineArgs; + PreRollOutCallback m_preRollOutCallback; + std::unordered_map m_logDispatchCallbacks; + std::unordered_map m_performanceTrackingCallbacks; + std::unordered_map m_threadNames; + std::vector m_customFormatSpecifiers; + base::threading::Mutex m_customFormatSpecifiersLock; + base::threading::Mutex m_threadNamesLock; + Level m_loggingLevel; + + friend class el::Helpers; + friend class el::base::DefaultLogDispatchCallback; + friend class el::LogBuilder; + friend class el::base::MessageBuilder; + friend class el::base::Writer; + friend class el::base::PerformanceTracker; + friend class el::base::LogDispatcher; + + void + setApplicationArguments(int argc, char** argv); + + inline void + setApplicationArguments(int argc, const char** argv) { + setApplicationArguments(argc, const_cast(argv)); + } +}; +extern ELPP_EXPORT base::type::StoragePointer elStorage; +#define ELPP el::base::elStorage +class DefaultLogDispatchCallback : public LogDispatchCallback { + protected: + void + handle(const LogDispatchData* data); + + private: + const LogDispatchData* m_data; + void + dispatch(base::type::string_t&& logLine); +}; +#if ELPP_ASYNC_LOGGING +class AsyncLogDispatchCallback : public LogDispatchCallback { + protected: + void + handle(const LogDispatchData* data); +}; +class AsyncDispatchWorker : public base::IWorker, public base::threading::ThreadSafe { + public: + AsyncDispatchWorker(); + virtual ~AsyncDispatchWorker(); + + bool + clean(void); + void + emptyQueue(void); + virtual void + start(void); + void + handle(AsyncLogItem* logItem); + void + run(void); + + void + setContinueRunning(bool value) { + base::threading::ScopedLock scopedLock(m_continueRunningLock); + m_continueRunning = value; + } + + bool + continueRunning(void) const { + return m_continueRunning; + } + + private: + std::condition_variable cv; + bool m_continueRunning; + base::threading::Mutex m_continueRunningLock; +}; +#endif // ELPP_ASYNC_LOGGING +} // namespace base +namespace base { +class DefaultLogBuilder : public LogBuilder { + public: + base::type::string_t + build(const LogMessage* logMessage, bool appendNewLine) const; +}; +/// @brief Dispatches log messages +class LogDispatcher : base::NoCopy { + public: + LogDispatcher(bool proceed, LogMessage* logMessage, base::DispatchAction dispatchAction) + : m_proceed(proceed), m_logMessage(logMessage), m_dispatchAction(std::move(dispatchAction)) { + } + + void + dispatch(void); + + private: + bool m_proceed; + LogMessage* m_logMessage; + base::DispatchAction m_dispatchAction; +}; +#if defined(ELPP_STL_LOGGING) +/// @brief Workarounds to write some STL logs +/// +/// @detail There is workaround needed to loop through some stl containers. In order to do that, we need iterable +/// containers of same type and provide iterator interface and pass it on to writeIterator(). Remember, this is passed +/// by value in constructor so that we dont change original containers. This operation is as expensive as +/// Big-O(std::min(class_.size(), base::consts::kMaxLogPerContainer)) +namespace workarounds { +/// @brief Abstract IterableContainer template that provides interface for iterable classes of type T +template +class IterableContainer { + public: + typedef typename Container::iterator iterator; + typedef typename Container::const_iterator const_iterator; + IterableContainer(void) { + } + virtual ~IterableContainer(void) { + } + iterator + begin(void) { + return getContainer().begin(); + } + iterator + end(void) { + return getContainer().end(); + } + + private: + virtual Container& + getContainer(void) = 0; +}; +/// @brief Implements IterableContainer and provides iterable std::priority_queue class +template , + typename Comparator = std::less> +class IterablePriorityQueue : public IterableContainer, + public std::priority_queue { + public: + IterablePriorityQueue(std::priority_queue queue_) { + std::size_t count_ = 0; + while (++count_ < base::consts::kMaxLogPerContainer && !queue_.empty()) { + this->push(queue_.top()); + queue_.pop(); + } + } + + private: + inline Container& + getContainer(void) { + return this->c; + } +}; +/// @brief Implements IterableContainer and provides iterable std::queue class +template > +class IterableQueue : public IterableContainer, public std::queue { + public: + IterableQueue(std::queue queue_) { + std::size_t count_ = 0; + while (++count_ < base::consts::kMaxLogPerContainer && !queue_.empty()) { + this->push(queue_.front()); + queue_.pop(); + } + } + + private: + inline Container& + getContainer(void) { + return this->c; + } +}; +/// @brief Implements IterableContainer and provides iterable std::stack class +template > +class IterableStack : public IterableContainer, public std::stack { + public: + IterableStack(std::stack stack_) { + std::size_t count_ = 0; + while (++count_ < base::consts::kMaxLogPerContainer && !stack_.empty()) { + this->push(stack_.top()); + stack_.pop(); + } + } + + private: + inline Container& + getContainer(void) { + return this->c; + } +}; +} // namespace workarounds +#endif // defined(ELPP_STL_LOGGING) +// Log message builder +class MessageBuilder { + public: + MessageBuilder(void) : m_logger(nullptr), m_containerLogSeperator(ELPP_LITERAL("")) { + } + void + initialize(Logger* logger); + +#define ELPP_SIMPLE_LOG(LOG_TYPE) \ + MessageBuilder& operator<<(LOG_TYPE msg) { \ + m_logger->stream() << msg; \ + if (ELPP->hasFlag(LoggingFlag::AutoSpacing)) { \ + m_logger->stream() << " "; \ + } \ + return *this; \ + } + + inline MessageBuilder& + operator<<(const std::string& msg) { + return operator<<(msg.c_str()); + } + ELPP_SIMPLE_LOG(char) + ELPP_SIMPLE_LOG(bool) + ELPP_SIMPLE_LOG(signed short) + ELPP_SIMPLE_LOG(unsigned short) + ELPP_SIMPLE_LOG(signed int) + ELPP_SIMPLE_LOG(unsigned int) + ELPP_SIMPLE_LOG(signed long) + ELPP_SIMPLE_LOG(unsigned long) + ELPP_SIMPLE_LOG(float) + ELPP_SIMPLE_LOG(double) + ELPP_SIMPLE_LOG(char*) + ELPP_SIMPLE_LOG(const char*) + ELPP_SIMPLE_LOG(const void*) + ELPP_SIMPLE_LOG(long double) + inline MessageBuilder& + operator<<(const std::wstring& msg) { + return operator<<(msg.c_str()); + } + MessageBuilder& + operator<<(const wchar_t* msg); + // ostream manipulators + inline MessageBuilder& + operator<<(std::ostream& (*OStreamMani)(std::ostream&)) { + m_logger->stream() << OStreamMani; + return *this; + } +#define ELPP_ITERATOR_CONTAINER_LOG_ONE_ARG(temp) \ + template \ + inline MessageBuilder& operator<<(const temp& template_inst) { \ + return writeIterator(template_inst.begin(), template_inst.end(), template_inst.size()); \ + } +#define ELPP_ITERATOR_CONTAINER_LOG_TWO_ARG(temp) \ + template \ + inline MessageBuilder& operator<<(const temp& template_inst) { \ + return writeIterator(template_inst.begin(), template_inst.end(), template_inst.size()); \ + } +#define ELPP_ITERATOR_CONTAINER_LOG_THREE_ARG(temp) \ + template \ + inline MessageBuilder& operator<<(const temp& template_inst) { \ + return writeIterator(template_inst.begin(), template_inst.end(), template_inst.size()); \ + } +#define ELPP_ITERATOR_CONTAINER_LOG_FOUR_ARG(temp) \ + template \ + inline MessageBuilder& operator<<(const temp& template_inst) { \ + return writeIterator(template_inst.begin(), template_inst.end(), template_inst.size()); \ + } +#define ELPP_ITERATOR_CONTAINER_LOG_FIVE_ARG(temp) \ + template \ + inline MessageBuilder& operator<<(const temp& template_inst) { \ + return writeIterator(template_inst.begin(), template_inst.end(), template_inst.size()); \ + } + +#if defined(ELPP_STL_LOGGING) + ELPP_ITERATOR_CONTAINER_LOG_TWO_ARG(std::vector) + ELPP_ITERATOR_CONTAINER_LOG_TWO_ARG(std::list) + ELPP_ITERATOR_CONTAINER_LOG_TWO_ARG(std::deque) + ELPP_ITERATOR_CONTAINER_LOG_THREE_ARG(std::set) + ELPP_ITERATOR_CONTAINER_LOG_THREE_ARG(std::multiset) + ELPP_ITERATOR_CONTAINER_LOG_FOUR_ARG(std::map) + ELPP_ITERATOR_CONTAINER_LOG_FOUR_ARG(std::multimap) + template + inline MessageBuilder& + operator<<(const std::queue& queue_) { + base::workarounds::IterableQueue iterableQueue_ = + static_cast>(queue_); + return writeIterator(iterableQueue_.begin(), iterableQueue_.end(), iterableQueue_.size()); + } + template + inline MessageBuilder& + operator<<(const std::stack& stack_) { + base::workarounds::IterableStack iterableStack_ = + static_cast>(stack_); + return writeIterator(iterableStack_.begin(), iterableStack_.end(), iterableStack_.size()); + } + template + inline MessageBuilder& + operator<<(const std::priority_queue& priorityQueue_) { + base::workarounds::IterablePriorityQueue iterablePriorityQueue_ = + static_cast>(priorityQueue_); + return writeIterator(iterablePriorityQueue_.begin(), iterablePriorityQueue_.end(), + iterablePriorityQueue_.size()); + } + template + MessageBuilder& + operator<<(const std::pair& pair_) { + m_logger->stream() << ELPP_LITERAL("("); + operator<<(static_cast(pair_.first)); + m_logger->stream() << ELPP_LITERAL(", "); + operator<<(static_cast(pair_.second)); + m_logger->stream() << ELPP_LITERAL(")"); + return *this; + } + template + MessageBuilder& + operator<<(const std::bitset& bitset_) { + m_logger->stream() << ELPP_LITERAL("["); + operator<<(bitset_.to_string()); + m_logger->stream() << ELPP_LITERAL("]"); + return *this; + } +#if defined(ELPP_LOG_STD_ARRAY) + template + inline MessageBuilder& + operator<<(const std::array& array) { + return writeIterator(array.begin(), array.end(), array.size()); + } +#endif // defined(ELPP_LOG_STD_ARRAY) +#if defined(ELPP_LOG_UNORDERED_MAP) + ELPP_ITERATOR_CONTAINER_LOG_FIVE_ARG(std::unordered_map) + ELPP_ITERATOR_CONTAINER_LOG_FIVE_ARG(std::unordered_multimap) +#endif // defined(ELPP_LOG_UNORDERED_MAP) +#if defined(ELPP_LOG_UNORDERED_SET) + ELPP_ITERATOR_CONTAINER_LOG_FOUR_ARG(std::unordered_set) + ELPP_ITERATOR_CONTAINER_LOG_FOUR_ARG(std::unordered_multiset) +#endif // defined(ELPP_LOG_UNORDERED_SET) +#endif // defined(ELPP_STL_LOGGING) +#if defined(ELPP_QT_LOGGING) + inline MessageBuilder& + operator<<(const QString& msg) { +#if defined(ELPP_UNICODE) + m_logger->stream() << msg.toStdWString(); +#else + m_logger->stream() << msg.toStdString(); +#endif // defined(ELPP_UNICODE) + return *this; + } + inline MessageBuilder& + operator<<(const QByteArray& msg) { + return operator<<(QString(msg)); + } + inline MessageBuilder& + operator<<(const QStringRef& msg) { + return operator<<(msg.toString()); + } + inline MessageBuilder& + operator<<(qint64 msg) { +#if defined(ELPP_UNICODE) + m_logger->stream() << QString::number(msg).toStdWString(); +#else + m_logger->stream() << QString::number(msg).toStdString(); +#endif // defined(ELPP_UNICODE) + return *this; + } + inline MessageBuilder& + operator<<(quint64 msg) { +#if defined(ELPP_UNICODE) + m_logger->stream() << QString::number(msg).toStdWString(); +#else + m_logger->stream() << QString::number(msg).toStdString(); +#endif // defined(ELPP_UNICODE) + return *this; + } + inline MessageBuilder& + operator<<(QChar msg) { + m_logger->stream() << msg.toLatin1(); + return *this; + } + inline MessageBuilder& + operator<<(const QLatin1String& msg) { + m_logger->stream() << msg.latin1(); + return *this; + } + ELPP_ITERATOR_CONTAINER_LOG_ONE_ARG(QList) + ELPP_ITERATOR_CONTAINER_LOG_ONE_ARG(QVector) + ELPP_ITERATOR_CONTAINER_LOG_ONE_ARG(QQueue) + ELPP_ITERATOR_CONTAINER_LOG_ONE_ARG(QSet) + ELPP_ITERATOR_CONTAINER_LOG_ONE_ARG(QLinkedList) + ELPP_ITERATOR_CONTAINER_LOG_ONE_ARG(QStack) + template + MessageBuilder& + operator<<(const QPair& pair_) { + m_logger->stream() << ELPP_LITERAL("("); + operator<<(static_cast(pair_.first)); + m_logger->stream() << ELPP_LITERAL(", "); + operator<<(static_cast(pair_.second)); + m_logger->stream() << ELPP_LITERAL(")"); + return *this; + } + template + MessageBuilder& + operator<<(const QMap& map_) { + m_logger->stream() << ELPP_LITERAL("["); + QList keys = map_.keys(); + typename QList::const_iterator begin = keys.begin(); + typename QList::const_iterator end = keys.end(); + int max_ = static_cast(base::consts::kMaxLogPerContainer); // to prevent warning + for (int index_ = 0; begin != end && index_ < max_; ++index_, ++begin) { + m_logger->stream() << ELPP_LITERAL("("); + operator<<(static_cast(*begin)); + m_logger->stream() << ELPP_LITERAL(", "); + operator<<(static_cast(map_.value(*begin))); + m_logger->stream() << ELPP_LITERAL(")"); + m_logger->stream() << ((index_ < keys.size() - 1) ? m_containerLogSeperator : ELPP_LITERAL("")); + } + if (begin != end) { + m_logger->stream() << ELPP_LITERAL("..."); + } + m_logger->stream() << ELPP_LITERAL("]"); + return *this; + } + template + inline MessageBuilder& + operator<<(const QMultiMap& map_) { + operator<<(static_cast>(map_)); + return *this; + } + template + MessageBuilder& + operator<<(const QHash& hash_) { + m_logger->stream() << ELPP_LITERAL("["); + QList keys = hash_.keys(); + typename QList::const_iterator begin = keys.begin(); + typename QList::const_iterator end = keys.end(); + int max_ = static_cast(base::consts::kMaxLogPerContainer); // prevent type warning + for (int index_ = 0; begin != end && index_ < max_; ++index_, ++begin) { + m_logger->stream() << ELPP_LITERAL("("); + operator<<(static_cast(*begin)); + m_logger->stream() << ELPP_LITERAL(", "); + operator<<(static_cast(hash_.value(*begin))); + m_logger->stream() << ELPP_LITERAL(")"); + m_logger->stream() << ((index_ < keys.size() - 1) ? m_containerLogSeperator : ELPP_LITERAL("")); + } + if (begin != end) { + m_logger->stream() << ELPP_LITERAL("..."); + } + m_logger->stream() << ELPP_LITERAL("]"); + return *this; + } + template + inline MessageBuilder& + operator<<(const QMultiHash& multiHash_) { + operator<<(static_cast>(multiHash_)); + return *this; + } +#endif // defined(ELPP_QT_LOGGING) +#if defined(ELPP_BOOST_LOGGING) + ELPP_ITERATOR_CONTAINER_LOG_TWO_ARG(boost::container::vector) + ELPP_ITERATOR_CONTAINER_LOG_TWO_ARG(boost::container::stable_vector) + ELPP_ITERATOR_CONTAINER_LOG_TWO_ARG(boost::container::list) + ELPP_ITERATOR_CONTAINER_LOG_TWO_ARG(boost::container::deque) + ELPP_ITERATOR_CONTAINER_LOG_FOUR_ARG(boost::container::map) + ELPP_ITERATOR_CONTAINER_LOG_FOUR_ARG(boost::container::flat_map) + ELPP_ITERATOR_CONTAINER_LOG_THREE_ARG(boost::container::set) + ELPP_ITERATOR_CONTAINER_LOG_THREE_ARG(boost::container::flat_set) +#endif // defined(ELPP_BOOST_LOGGING) + + /// @brief Macro used internally that can be used externally to make containers easylogging++ friendly + /// + /// @detail This macro expands to write an ostream& operator<< for container. This container is expected to + /// have begin() and end() methods that return respective iterators + /// @param ContainerType Type of container e.g, MyList from WX_DECLARE_LIST(int, MyList); in wxwidgets + /// @param SizeMethod Method used to get size of container. + /// @param ElementInstance Instance of element to be fed out. Insance name is "elem". See WXELPP_ENABLED macro + /// for an example usage +#define MAKE_CONTAINERELPP_FRIENDLY(ContainerType, SizeMethod, ElementInstance) \ + el::base::type::ostream_t& operator<<(el::base::type::ostream_t& ss, const ContainerType& container) { \ + const el::base::type::char_t* sep = \ + ELPP->hasFlag(el::LoggingFlag::NewLineForContainer) ? ELPP_LITERAL("\n ") : ELPP_LITERAL(", "); \ + ContainerType::const_iterator elem = container.begin(); \ + ContainerType::const_iterator endElem = container.end(); \ + std::size_t size_ = container.SizeMethod; \ + ss << ELPP_LITERAL("["); \ + for (std::size_t i = 0; elem != endElem && i < el::base::consts::kMaxLogPerContainer; ++i, ++elem) { \ + ss << ElementInstance; \ + ss << ((i < size_ - 1) ? sep : ELPP_LITERAL("")); \ + } \ + if (elem != endElem) { \ + ss << ELPP_LITERAL("..."); \ + } \ + ss << ELPP_LITERAL("]"); \ + return ss; \ + } +#if defined(ELPP_WXWIDGETS_LOGGING) + ELPP_ITERATOR_CONTAINER_LOG_ONE_ARG(wxVector) +#define ELPP_WX_PTR_ENABLED(ContainerType) MAKE_CONTAINERELPP_FRIENDLY(ContainerType, size(), *(*elem)) +#define ELPP_WX_ENABLED(ContainerType) MAKE_CONTAINERELPP_FRIENDLY(ContainerType, size(), (*elem)) +#define ELPP_WX_HASH_MAP_ENABLED(ContainerType) MAKE_CONTAINERELPP_FRIENDLY(ContainerType, size(), \ +ELPP_LITERAL("(") << elem->first << ELPP_LITERAL(", ") << elem->second << ELPP_LITERAL(")") +#else +#define ELPP_WX_PTR_ENABLED(ContainerType) +#define ELPP_WX_ENABLED(ContainerType) +#define ELPP_WX_HASH_MAP_ENABLED(ContainerType) +#endif // defined(ELPP_WXWIDGETS_LOGGING) + // Other classes + template + ELPP_SIMPLE_LOG(const Class&) +#undef ELPP_SIMPLE_LOG +#undef ELPP_ITERATOR_CONTAINER_LOG_ONE_ARG +#undef ELPP_ITERATOR_CONTAINER_LOG_TWO_ARG +#undef ELPP_ITERATOR_CONTAINER_LOG_THREE_ARG +#undef ELPP_ITERATOR_CONTAINER_LOG_FOUR_ARG +#undef ELPP_ITERATOR_CONTAINER_LOG_FIVE_ARG + private : Logger* m_logger; + const base::type::char_t* m_containerLogSeperator; + + template + MessageBuilder& + writeIterator(Iterator begin_, Iterator end_, std::size_t size_) { + m_logger->stream() << ELPP_LITERAL("["); + for (std::size_t i = 0; begin_ != end_ && i < base::consts::kMaxLogPerContainer; ++i, ++begin_) { + operator<<(*begin_); + m_logger->stream() << ((i < size_ - 1) ? m_containerLogSeperator : ELPP_LITERAL("")); + } + if (begin_ != end_) { + m_logger->stream() << ELPP_LITERAL("..."); + } + m_logger->stream() << ELPP_LITERAL("]"); + if (ELPP->hasFlag(LoggingFlag::AutoSpacing)) { + m_logger->stream() << " "; + } + return *this; + } +}; +/// @brief Writes nothing - Used when certain log is disabled +class NullWriter : base::NoCopy { + public: + NullWriter(void) { + } + + // Null manipulator + inline NullWriter& + operator<<(std::ostream& (*)(std::ostream&)) { + return *this; + } + + template + inline NullWriter& + operator<<(const T&) { + return *this; + } + + inline operator bool() { + return true; + } +}; +/// @brief Main entry point of each logging +class Writer : base::NoCopy { + public: + Writer(Level level, const char* file, base::type::LineNumber line, const char* func, + base::DispatchAction dispatchAction = base::DispatchAction::NormalLog, + base::type::VerboseLevel verboseLevel = 0) + : m_msg(nullptr), + m_level(level), + m_file(file), + m_line(line), + m_func(func), + m_verboseLevel(verboseLevel), + m_logger(nullptr), + m_proceed(false), + m_dispatchAction(dispatchAction) { + } + + Writer(LogMessage* msg, base::DispatchAction dispatchAction = base::DispatchAction::NormalLog) + : m_msg(msg), + m_level(msg != nullptr ? msg->level() : Level::Unknown), + m_line(0), + m_logger(nullptr), + m_proceed(false), + m_dispatchAction(dispatchAction) { + } + + virtual ~Writer(void) { + processDispatch(); + } + + template + inline Writer& + operator<<(const T& log) { +#if ELPP_LOGGING_ENABLED + if (m_proceed) { + m_messageBuilder << log; + } +#endif // ELPP_LOGGING_ENABLED + return *this; + } + + inline Writer& + operator<<(std::ostream& (*log)(std::ostream&)) { +#if ELPP_LOGGING_ENABLED + if (m_proceed) { + m_messageBuilder << log; + } +#endif // ELPP_LOGGING_ENABLED + return *this; + } + + inline operator bool() { + return true; + } + + Writer& + construct(Logger* logger, bool needLock = true); + Writer& + construct(int count, const char* loggerIds, ...); + + protected: + LogMessage* m_msg; + Level m_level; + const char* m_file; + const base::type::LineNumber m_line; + const char* m_func; + base::type::VerboseLevel m_verboseLevel; + Logger* m_logger; + bool m_proceed; + base::MessageBuilder m_messageBuilder; + base::DispatchAction m_dispatchAction; + std::vector m_loggerIds; + friend class el::Helpers; + + void + initializeLogger(const std::string& loggerId, bool lookup = true, bool needLock = true); + void + processDispatch(); + void + triggerDispatch(void); +}; +class PErrorWriter : public base::Writer { + public: + PErrorWriter(Level level, const char* file, base::type::LineNumber line, const char* func, + base::DispatchAction dispatchAction = base::DispatchAction::NormalLog, + base::type::VerboseLevel verboseLevel = 0) + : base::Writer(level, file, line, func, dispatchAction, verboseLevel) { + } + + virtual ~PErrorWriter(void); +}; +} // namespace base +// Logging from Logger class. Why this is here? Because we have Storage and Writer class available +#if ELPP_VARIADIC_TEMPLATES_SUPPORTED +template +void +Logger::log_(Level level, int vlevel, const char* s, const T& value, const Args&... args) { + base::MessageBuilder b; + b.initialize(this); + while (*s) { + if (*s == base::consts::kFormatSpecifierChar) { + if (*(s + 1) == base::consts::kFormatSpecifierChar) { + ++s; + } else { + if (*(s + 1) == base::consts::kFormatSpecifierCharValue) { + ++s; + b << value; + log_(level, vlevel, ++s, args...); + return; + } + } + } + b << *s++; + } + ELPP_INTERNAL_ERROR("Too many arguments provided. Unable to handle. Please provide more format specifiers", false); +} +template +void +Logger::log_(Level level, int vlevel, const T& log) { + if (level == Level::Verbose) { + if (ELPP->vRegistry()->allowed(vlevel, __FILE__)) { + base::Writer(Level::Verbose, "FILE", 0, "FUNCTION", base::DispatchAction::NormalLog, vlevel) + .construct(this, false) + << log; + } else { + stream().str(ELPP_LITERAL("")); + releaseLock(); + } + } else { + base::Writer(level, "FILE", 0, "FUNCTION").construct(this, false) << log; + } +} +template +inline void +Logger::log(Level level, const char* s, const T& value, const Args&... args) { + acquireLock(); // released in Writer! + log_(level, 0, s, value, args...); +} +template +inline void +Logger::log(Level level, const T& log) { + acquireLock(); // released in Writer! + log_(level, 0, log); +} +#if ELPP_VERBOSE_LOG +template +inline void +Logger::verbose(int vlevel, const char* s, const T& value, const Args&... args) { + acquireLock(); // released in Writer! + log_(el::Level::Verbose, vlevel, s, value, args...); +} +template +inline void +Logger::verbose(int vlevel, const T& log) { + acquireLock(); // released in Writer! + log_(el::Level::Verbose, vlevel, log); +} +#else +template +inline void +Logger::verbose(int, const char*, const T&, const Args&...) { + return; +} +template +inline void +Logger::verbose(int, const T&) { + return; +} +#endif // ELPP_VERBOSE_LOG +#define LOGGER_LEVEL_WRITERS(FUNCTION_NAME, LOG_LEVEL) \ + template \ + inline void Logger::FUNCTION_NAME(const char* s, const T& value, const Args&... args) { \ + log(LOG_LEVEL, s, value, args...); \ + } \ + template \ + inline void Logger::FUNCTION_NAME(const T& value) { \ + log(LOG_LEVEL, value); \ + } +#define LOGGER_LEVEL_WRITERS_DISABLED(FUNCTION_NAME, LOG_LEVEL) \ + template \ + inline void Logger::FUNCTION_NAME(const char*, const T&, const Args&...) { \ + return; \ + } \ + template \ + inline void Logger::FUNCTION_NAME(const T&) { \ + return; \ + } + +#if ELPP_INFO_LOG +LOGGER_LEVEL_WRITERS(info, Level::Info) +#else +LOGGER_LEVEL_WRITERS_DISABLED(info, Level::Info) +#endif // ELPP_INFO_LOG +#if ELPP_DEBUG_LOG +LOGGER_LEVEL_WRITERS(debug, Level::Debug) +#else +LOGGER_LEVEL_WRITERS_DISABLED(debug, Level::Debug) +#endif // ELPP_DEBUG_LOG +#if ELPP_WARNING_LOG +LOGGER_LEVEL_WRITERS(warn, Level::Warning) +#else +LOGGER_LEVEL_WRITERS_DISABLED(warn, Level::Warning) +#endif // ELPP_WARNING_LOG +#if ELPP_ERROR_LOG +LOGGER_LEVEL_WRITERS(error, Level::Error) +#else +LOGGER_LEVEL_WRITERS_DISABLED(error, Level::Error) +#endif // ELPP_ERROR_LOG +#if ELPP_FATAL_LOG +LOGGER_LEVEL_WRITERS(fatal, Level::Fatal) +#else +LOGGER_LEVEL_WRITERS_DISABLED(fatal, Level::Fatal) +#endif // ELPP_FATAL_LOG +#if ELPP_TRACE_LOG +LOGGER_LEVEL_WRITERS(trace, Level::Trace) +#else +LOGGER_LEVEL_WRITERS_DISABLED(trace, Level::Trace) +#endif // ELPP_TRACE_LOG +#undef LOGGER_LEVEL_WRITERS +#undef LOGGER_LEVEL_WRITERS_DISABLED +#endif // ELPP_VARIADIC_TEMPLATES_SUPPORTED +#if ELPP_COMPILER_MSVC +#define ELPP_VARIADIC_FUNC_MSVC(variadicFunction, variadicArgs) variadicFunction variadicArgs +#define ELPP_VARIADIC_FUNC_MSVC_RUN(variadicFunction, ...) ELPP_VARIADIC_FUNC_MSVC(variadicFunction, (__VA_ARGS__)) +#define el_getVALength(...) \ + ELPP_VARIADIC_FUNC_MSVC_RUN(el_resolveVALength, 0, ##__VA_ARGS__, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0) +#else +#if ELPP_COMPILER_CLANG +#define el_getVALength(...) el_resolveVALength(0, __VA_ARGS__, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0) +#else +#define el_getVALength(...) el_resolveVALength(0, ##__VA_ARGS__, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0) +#endif // ELPP_COMPILER_CLANG +#endif // ELPP_COMPILER_MSVC +#define el_resolveVALength(_0, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, N, ...) N +#define ELPP_WRITE_LOG(writer, level, dispatchAction, ...) \ + writer(level, __FILE__, __LINE__, ELPP_FUNC, dispatchAction).construct(el_getVALength(__VA_ARGS__), __VA_ARGS__) +#define ELPP_WRITE_LOG_IF(writer, condition, level, dispatchAction, ...) \ + if (condition) \ + writer(level, __FILE__, __LINE__, ELPP_FUNC, dispatchAction).construct(el_getVALength(__VA_ARGS__), __VA_ARGS__) +#define ELPP_WRITE_LOG_EVERY_N(writer, occasion, level, dispatchAction, ...) \ + ELPP->validateEveryNCounter(__FILE__, __LINE__, occasion) && \ + writer(level, __FILE__, __LINE__, ELPP_FUNC, dispatchAction) \ + .construct(el_getVALength(__VA_ARGS__), __VA_ARGS__) +#define ELPP_WRITE_LOG_AFTER_N(writer, n, level, dispatchAction, ...) \ + ELPP->validateAfterNCounter(__FILE__, __LINE__, n) && writer(level, __FILE__, __LINE__, ELPP_FUNC, dispatchAction) \ + .construct(el_getVALength(__VA_ARGS__), __VA_ARGS__) +#define ELPP_WRITE_LOG_N_TIMES(writer, n, level, dispatchAction, ...) \ + ELPP->validateNTimesCounter(__FILE__, __LINE__, n) && writer(level, __FILE__, __LINE__, ELPP_FUNC, dispatchAction) \ + .construct(el_getVALength(__VA_ARGS__), __VA_ARGS__) +#if defined(ELPP_FEATURE_ALL) || defined(ELPP_FEATURE_PERFORMANCE_TRACKING) +class PerformanceTrackingData { + public: + enum class DataType : base::type::EnumType { Checkpoint = 1, Complete = 2 }; + // Do not use constructor, will run into multiple definition error, use init(PerformanceTracker*) + explicit PerformanceTrackingData(DataType dataType) + : m_performanceTracker(nullptr), + m_dataType(dataType), + m_firstCheckpoint(false), + m_file(""), + m_line(0), + m_func("") { + } + inline const std::string* + blockName(void) const; + inline const struct timeval* + startTime(void) const; + inline const struct timeval* + endTime(void) const; + inline const struct timeval* + lastCheckpointTime(void) const; + inline const base::PerformanceTracker* + performanceTracker(void) const { + return m_performanceTracker; + } + inline PerformanceTrackingData::DataType + dataType(void) const { + return m_dataType; + } + inline bool + firstCheckpoint(void) const { + return m_firstCheckpoint; + } + inline std::string + checkpointId(void) const { + return m_checkpointId; + } + inline const char* + file(void) const { + return m_file; + } + inline base::type::LineNumber + line(void) const { + return m_line; + } + inline const char* + func(void) const { + return m_func; + } + inline const base::type::string_t* + formattedTimeTaken() const { + return &m_formattedTimeTaken; + } + inline const std::string& + loggerId(void) const; + + private: + base::PerformanceTracker* m_performanceTracker; + base::type::string_t m_formattedTimeTaken; + PerformanceTrackingData::DataType m_dataType; + bool m_firstCheckpoint; + std::string m_checkpointId; + const char* m_file; + base::type::LineNumber m_line; + const char* m_func; + inline void + init(base::PerformanceTracker* performanceTracker, bool firstCheckpoint = false) { + m_performanceTracker = performanceTracker; + m_firstCheckpoint = firstCheckpoint; + } + + friend class el::base::PerformanceTracker; +}; +namespace base { +/// @brief Represents performanceTracker block of code that conditionally adds performance status to log +/// either when goes outside the scope of when checkpoint() is called +class PerformanceTracker : public base::threading::ThreadSafe, public Loggable { + public: + PerformanceTracker(const std::string& blockName, + base::TimestampUnit timestampUnit = base::TimestampUnit::Millisecond, + const std::string& loggerId = std::string(el::base::consts::kPerformanceLoggerId), + bool scopedLog = true, Level level = base::consts::kPerformanceTrackerDefaultLevel); + /// @brief Copy constructor + PerformanceTracker(const PerformanceTracker& t) + : m_blockName(t.m_blockName), + m_timestampUnit(t.m_timestampUnit), + m_loggerId(t.m_loggerId), + m_scopedLog(t.m_scopedLog), + m_level(t.m_level), + m_hasChecked(t.m_hasChecked), + m_lastCheckpointId(t.m_lastCheckpointId), + m_enabled(t.m_enabled), + m_startTime(t.m_startTime), + m_endTime(t.m_endTime), + m_lastCheckpointTime(t.m_lastCheckpointTime) { + } + virtual ~PerformanceTracker(void); + /// @brief A checkpoint for current performanceTracker block. + void + checkpoint(const std::string& id = std::string(), const char* file = __FILE__, + base::type::LineNumber line = __LINE__, const char* func = ""); + inline Level + level(void) const { + return m_level; + } + + private: + std::string m_blockName; + base::TimestampUnit m_timestampUnit; + std::string m_loggerId; + bool m_scopedLog; + Level m_level; + bool m_hasChecked; + std::string m_lastCheckpointId; + bool m_enabled; + struct timeval m_startTime, m_endTime, m_lastCheckpointTime; + + PerformanceTracker(void); + + friend class el::PerformanceTrackingData; + friend class base::DefaultPerformanceTrackingCallback; + + const inline base::type::string_t + getFormattedTimeTaken() const { + return getFormattedTimeTaken(m_startTime); + } + + const base::type::string_t + getFormattedTimeTaken(struct timeval startTime) const; + + virtual inline void + log(el::base::type::ostream_t& os) const { + os << getFormattedTimeTaken(); + } +}; +class DefaultPerformanceTrackingCallback : public PerformanceTrackingCallback { + protected: + void + handle(const PerformanceTrackingData* data) { + m_data = data; + base::type::stringstream_t ss; + if (m_data->dataType() == PerformanceTrackingData::DataType::Complete) { + ss << ELPP_LITERAL("Executed [") << m_data->blockName()->c_str() << ELPP_LITERAL("] in [") + << *m_data->formattedTimeTaken() << ELPP_LITERAL("]"); + } else { + ss << ELPP_LITERAL("Performance checkpoint"); + if (!m_data->checkpointId().empty()) { + ss << ELPP_LITERAL(" [") << m_data->checkpointId().c_str() << ELPP_LITERAL("]"); + } + ss << ELPP_LITERAL(" for block [") << m_data->blockName()->c_str() << ELPP_LITERAL("] : [") + << *m_data->performanceTracker(); + if (!ELPP->hasFlag(LoggingFlag::DisablePerformanceTrackingCheckpointComparison) && + m_data->performanceTracker()->m_hasChecked) { + ss << ELPP_LITERAL(" ([") << *m_data->formattedTimeTaken() << ELPP_LITERAL("] from "); + if (m_data->performanceTracker()->m_lastCheckpointId.empty()) { + ss << ELPP_LITERAL("last checkpoint"); + } else { + ss << ELPP_LITERAL("checkpoint '") << m_data->performanceTracker()->m_lastCheckpointId.c_str() + << ELPP_LITERAL("'"); + } + ss << ELPP_LITERAL(")]"); + } else { + ss << ELPP_LITERAL("]"); + } + } + el::base::Writer(m_data->performanceTracker()->level(), m_data->file(), m_data->line(), m_data->func()) + .construct(1, m_data->loggerId().c_str()) + << ss.str(); + } + + private: + const PerformanceTrackingData* m_data; +}; +} // namespace base +inline const std::string* +PerformanceTrackingData::blockName() const { + return const_cast(&m_performanceTracker->m_blockName); +} +inline const struct timeval* +PerformanceTrackingData::startTime() const { + return const_cast(&m_performanceTracker->m_startTime); +} +inline const struct timeval* +PerformanceTrackingData::endTime() const { + return const_cast(&m_performanceTracker->m_endTime); +} +inline const struct timeval* +PerformanceTrackingData::lastCheckpointTime() const { + return const_cast(&m_performanceTracker->m_lastCheckpointTime); +} +inline const std::string& +PerformanceTrackingData::loggerId(void) const { + return m_performanceTracker->m_loggerId; +} +#endif // defined(ELPP_FEATURE_ALL) || defined(ELPP_FEATURE_PERFORMANCE_TRACKING) +namespace base { +/// @brief Contains some internal debugging tools like crash handler and stack tracer +namespace debug { +#if defined(ELPP_FEATURE_ALL) || defined(ELPP_FEATURE_CRASH_LOG) +class StackTrace : base::NoCopy { + public: + static const unsigned int kMaxStack = 64; + static const unsigned int kStackStart = 2; // We want to skip c'tor and StackTrace::generateNew() + class StackTraceEntry { + public: + StackTraceEntry(std::size_t index, const std::string& loc, const std::string& demang, const std::string& hex, + const std::string& addr); + StackTraceEntry(std::size_t index, const std::string& loc) : m_index(index), m_location(loc) { + } + std::size_t m_index; + std::string m_location; + std::string m_demangled; + std::string m_hex; + std::string m_addr; + friend std::ostream& + operator<<(std::ostream& ss, const StackTraceEntry& si); + + private: + StackTraceEntry(void); + }; + + StackTrace(void) { + generateNew(); + } + + virtual ~StackTrace(void) { + } + + inline std::vector& + getLatestStack(void) { + return m_stack; + } + + friend std::ostream& + operator<<(std::ostream& os, const StackTrace& st); + + private: + std::vector m_stack; + + void + generateNew(void); +}; +/// @brief Handles unexpected crashes +class CrashHandler : base::NoCopy { + public: + typedef void (*Handler)(int); + + explicit CrashHandler(bool useDefault); + explicit CrashHandler(const Handler& cHandler) { + setHandler(cHandler); + } + void + setHandler(const Handler& cHandler); + + private: + Handler m_handler; +}; +#else +class CrashHandler { + public: + explicit CrashHandler(bool) { + } +}; +#endif // defined(ELPP_FEATURE_ALL) || defined(ELPP_FEATURE_CRASH_LOG) +} // namespace debug +} // namespace base +extern base::debug::CrashHandler elCrashHandler; +#define MAKE_LOGGABLE(ClassType, ClassInstance, OutputStreamInstance) \ + el::base::type::ostream_t& operator<<(el::base::type::ostream_t& OutputStreamInstance, \ + const ClassType& ClassInstance) +/// @brief Initializes syslog with process ID, options and facility. calls closelog() on d'tor +class SysLogInitializer { + public: + SysLogInitializer(const char* processIdent, int options = 0, int facility = 0) { +#if defined(ELPP_SYSLOG) + openlog(processIdent, options, facility); +#else + ELPP_UNUSED(processIdent); + ELPP_UNUSED(options); + ELPP_UNUSED(facility); +#endif // defined(ELPP_SYSLOG) + } + virtual ~SysLogInitializer(void) { +#if defined(ELPP_SYSLOG) + closelog(); +#endif // defined(ELPP_SYSLOG) + } +}; +#define ELPP_INITIALIZE_SYSLOG(id, opt, fac) el::SysLogInitializer elSyslogInit(id, opt, fac) +/// @brief Static helpers for developers +class Helpers : base::StaticClass { + public: + /// @brief Shares logging repository (base::Storage) + static inline void + setStorage(base::type::StoragePointer storage) { + ELPP = storage; + } + /// @return Main storage repository + static inline base::type::StoragePointer + storage() { + return ELPP; + } + /// @brief Sets application arguments and figures out whats active for logging and whats not. + static inline void + setArgs(int argc, char** argv) { + ELPP->setApplicationArguments(argc, argv); + } + /// @copydoc setArgs(int argc, char** argv) + static inline void + setArgs(int argc, const char** argv) { + ELPP->setApplicationArguments(argc, const_cast(argv)); + } + /// @brief Sets thread name for current thread. Requires std::thread + static inline void + setThreadName(const std::string& name) { + ELPP->setThreadName(name); + } + static inline std::string + getThreadName() { + return ELPP->getThreadName(base::threading::getCurrentThreadId()); + } +#if defined(ELPP_FEATURE_ALL) || defined(ELPP_FEATURE_CRASH_LOG) + /// @brief Overrides default crash handler and installs custom handler. + /// @param crashHandler A functor with no return type that takes single int argument. + /// Handler is a typedef with specification: void (*Handler)(int) + static inline void + setCrashHandler(const el::base::debug::CrashHandler::Handler& crashHandler) { + el::elCrashHandler.setHandler(crashHandler); + } + /// @brief Abort due to crash with signal in parameter + /// @param sig Crash signal + static void + crashAbort(int sig, const char* sourceFile = "", unsigned int long line = 0); + /// @brief Logs reason of crash as per sig + /// @param sig Crash signal + /// @param stackTraceIfAvailable Includes stack trace if available + /// @param level Logging level + /// @param logger Logger to use for logging + static void + logCrashReason(int sig, bool stackTraceIfAvailable = false, Level level = Level::Fatal, + const char* logger = base::consts::kDefaultLoggerId); +#endif // defined(ELPP_FEATURE_ALL) || defined(ELPP_FEATURE_CRASH_LOG) + /// @brief Installs pre rollout callback, this callback is triggered when log file is about to be rolled out + /// (can be useful for backing up) + static inline void + installPreRollOutCallback(const PreRollOutCallback& callback) { + ELPP->setPreRollOutCallback(callback); + } + /// @brief Uninstalls pre rollout callback + static inline void + uninstallPreRollOutCallback(void) { + ELPP->unsetPreRollOutCallback(); + } + /// @brief Installs post log dispatch callback, this callback is triggered when log is dispatched + template + static inline bool + installLogDispatchCallback(const std::string& id) { + return ELPP->installLogDispatchCallback(id); + } + /// @brief Uninstalls log dispatch callback + template + static inline void + uninstallLogDispatchCallback(const std::string& id) { + ELPP->uninstallLogDispatchCallback(id); + } + template + static inline T* + logDispatchCallback(const std::string& id) { + return ELPP->logDispatchCallback(id); + } +#if defined(ELPP_FEATURE_ALL) || defined(ELPP_FEATURE_PERFORMANCE_TRACKING) + /// @brief Installs post performance tracking callback, this callback is triggered when performance tracking is + /// finished + template + static inline bool + installPerformanceTrackingCallback(const std::string& id) { + return ELPP->installPerformanceTrackingCallback(id); + } + /// @brief Uninstalls post performance tracking handler + template + static inline void + uninstallPerformanceTrackingCallback(const std::string& id) { + ELPP->uninstallPerformanceTrackingCallback(id); + } + template + static inline T* + performanceTrackingCallback(const std::string& id) { + return ELPP->performanceTrackingCallback(id); + } +#endif // defined(ELPP_FEATURE_ALL) || defined(ELPP_FEATURE_PERFORMANCE_TRACKING) + /// @brief Converts template to std::string - useful for loggable classes to log containers within + /// log(std::ostream&) const + template + static std::string + convertTemplateToStdString(const T& templ) { + el::Logger* logger = ELPP->registeredLoggers()->get(el::base::consts::kDefaultLoggerId); + if (logger == nullptr) { + return std::string(); + } + base::MessageBuilder b; + b.initialize(logger); + logger->acquireLock(); + b << templ; +#if defined(ELPP_UNICODE) + std::string s = std::string(logger->stream().str().begin(), logger->stream().str().end()); +#else + std::string s = logger->stream().str(); +#endif // defined(ELPP_UNICODE) + logger->stream().str(ELPP_LITERAL("")); + logger->releaseLock(); + return s; + } + /// @brief Returns command line arguments (pointer) provided to easylogging++ + static inline const el::base::utils::CommandLineArgs* + commandLineArgs(void) { + return ELPP->commandLineArgs(); + } + /// @brief Reserve space for custom format specifiers for performance + /// @see std::vector::reserve + static inline void + reserveCustomFormatSpecifiers(std::size_t size) { + ELPP->m_customFormatSpecifiers.reserve(size); + } + /// @brief Installs user defined format specifier and handler + static inline void + installCustomFormatSpecifier(const CustomFormatSpecifier& customFormatSpecifier) { + ELPP->installCustomFormatSpecifier(customFormatSpecifier); + } + /// @brief Uninstalls user defined format specifier and handler + static inline bool + uninstallCustomFormatSpecifier(const char* formatSpecifier) { + return ELPP->uninstallCustomFormatSpecifier(formatSpecifier); + } + /// @brief Returns true if custom format specifier is installed + static inline bool + hasCustomFormatSpecifier(const char* formatSpecifier) { + return ELPP->hasCustomFormatSpecifier(formatSpecifier); + } + static inline void + validateFileRolling(Logger* logger, Level level) { + if (ELPP == nullptr || logger == nullptr) + return; + logger->m_typedConfigurations->validateFileRolling(level, ELPP->preRollOutCallback()); + } +}; +/// @brief Static helpers to deal with loggers and their configurations +class Loggers : base::StaticClass { + public: + /// @brief Gets existing or registers new logger + static Logger* + getLogger(const std::string& identity, bool registerIfNotAvailable = true); + /// @brief Changes default log builder for future loggers + static void + setDefaultLogBuilder(el::LogBuilderPtr& logBuilderPtr); + /// @brief Installs logger registration callback, this callback is triggered when new logger is registered + template + static inline bool + installLoggerRegistrationCallback(const std::string& id) { + return ELPP->registeredLoggers()->installLoggerRegistrationCallback(id); + } + /// @brief Uninstalls log dispatch callback + template + static inline void + uninstallLoggerRegistrationCallback(const std::string& id) { + ELPP->registeredLoggers()->uninstallLoggerRegistrationCallback(id); + } + template + static inline T* + loggerRegistrationCallback(const std::string& id) { + return ELPP->registeredLoggers()->loggerRegistrationCallback(id); + } + /// @brief Unregisters logger - use it only when you know what you are doing, you may unregister + /// loggers initialized / used by third-party libs. + static bool + unregisterLogger(const std::string& identity); + /// @brief Whether or not logger with id is registered + static bool + hasLogger(const std::string& identity); + /// @brief Reconfigures specified logger with new configurations + static Logger* + reconfigureLogger(Logger* logger, const Configurations& configurations); + /// @brief Reconfigures logger with new configurations after looking it up using identity + static Logger* + reconfigureLogger(const std::string& identity, const Configurations& configurations); + /// @brief Reconfigures logger's single configuration + static Logger* + reconfigureLogger(const std::string& identity, ConfigurationType configurationType, const std::string& value); + /// @brief Reconfigures all the existing loggers with new configurations + static void + reconfigureAllLoggers(const Configurations& configurations); + /// @brief Reconfigures single configuration for all the loggers + static inline void + reconfigureAllLoggers(ConfigurationType configurationType, const std::string& value) { + reconfigureAllLoggers(Level::Global, configurationType, value); + } + /// @brief Reconfigures single configuration for all the loggers for specified level + static void + reconfigureAllLoggers(Level level, ConfigurationType configurationType, const std::string& value); + /// @brief Sets default configurations. This configuration is used for future (and conditionally for existing) + /// loggers + static void + setDefaultConfigurations(const Configurations& configurations, bool reconfigureExistingLoggers = false); + /// @brief Returns current default + static const Configurations* + defaultConfigurations(void); + /// @brief Returns log stream reference pointer if needed by user + static const base::LogStreamsReferenceMap* + logStreamsReference(void); + /// @brief Default typed configuration based on existing defaultConf + static base::TypedConfigurations + defaultTypedConfigurations(void); + /// @brief Populates all logger IDs in current repository. + /// @param [out] targetList List of fill up. + static std::vector* + populateAllLoggerIds(std::vector* targetList); + /// @brief Sets configurations from global configuration file. + static void + configureFromGlobal(const char* globalConfigurationFilePath); + /// @brief Configures loggers using command line arg. Ensure you have already set command line args, + /// @return False if invalid argument or argument with no value provided, true if attempted to configure logger. + /// If true is returned that does not mean it has been configured successfully, it only means that it + /// has attempeted to configure logger using configuration file provided in argument + static bool + configureFromArg(const char* argKey); + /// @brief Flushes all loggers for all levels - Be careful if you dont know how many loggers are registered + static void + flushAll(void); + /// @brief Adds logging flag used internally. + static inline void + addFlag(LoggingFlag flag) { + ELPP->addFlag(flag); + } + /// @brief Removes logging flag used internally. + static inline void + removeFlag(LoggingFlag flag) { + ELPP->removeFlag(flag); + } + /// @brief Determines whether or not certain flag is active + static inline bool + hasFlag(LoggingFlag flag) { + return ELPP->hasFlag(flag); + } + /// @brief Adds flag and removes it when scope goes out + class ScopedAddFlag { + public: + ScopedAddFlag(LoggingFlag flag) : m_flag(flag) { + Loggers::addFlag(m_flag); + } + ~ScopedAddFlag(void) { + Loggers::removeFlag(m_flag); + } + + private: + LoggingFlag m_flag; + }; + /// @brief Removes flag and add it when scope goes out + class ScopedRemoveFlag { + public: + ScopedRemoveFlag(LoggingFlag flag) : m_flag(flag) { + Loggers::removeFlag(m_flag); + } + ~ScopedRemoveFlag(void) { + Loggers::addFlag(m_flag); + } + + private: + LoggingFlag m_flag; + }; + /// @brief Sets hierarchy for logging. Needs to enable logging flag (HierarchicalLogging) + static void + setLoggingLevel(Level level) { + ELPP->setLoggingLevel(level); + } + /// @brief Sets verbose level on the fly + static void + setVerboseLevel(base::type::VerboseLevel level); + /// @brief Gets current verbose level + static base::type::VerboseLevel + verboseLevel(void); + /// @brief Sets vmodules as specified (on the fly) + static void + setVModules(const char* modules); + /// @brief Clears vmodules + static void + clearVModules(void); +}; +class VersionInfo : base::StaticClass { + public: + /// @brief Current version number + static const std::string + version(void); + + /// @brief Release date of current version + static const std::string + releaseDate(void); +}; +} // namespace el +#undef VLOG_IS_ON +/// @brief Determines whether verbose logging is on for specified level current file. +#define VLOG_IS_ON(verboseLevel) (ELPP->vRegistry()->allowed(verboseLevel, __FILE__)) +#undef TIMED_BLOCK +#undef TIMED_SCOPE +#undef TIMED_SCOPE_IF +#undef TIMED_FUNC +#undef TIMED_FUNC_IF +#undef ELPP_MIN_UNIT +#if defined(ELPP_PERFORMANCE_MICROSECONDS) +#define ELPP_MIN_UNIT el::base::TimestampUnit::Microsecond +#else +#define ELPP_MIN_UNIT el::base::TimestampUnit::Millisecond +#endif // (defined(ELPP_PERFORMANCE_MICROSECONDS)) +/// @brief Performance tracked scope. Performance gets written when goes out of scope using +/// 'performance' logger. +/// +/// @detail Please note in order to check the performance at a certain time you can use obj->checkpoint(); +/// @see el::base::PerformanceTracker +/// @see el::base::PerformanceTracker::checkpoint +// Note: Do not surround this definition with null macro because of obj instance +#define TIMED_SCOPE_IF(obj, blockname, condition) \ + el::base::type::PerformanceTrackerPtr obj(condition ? new el::base::PerformanceTracker(blockname, ELPP_MIN_UNIT) \ + : nullptr) +#define TIMED_SCOPE(obj, blockname) TIMED_SCOPE_IF(obj, blockname, true) +#define TIMED_BLOCK(obj, blockName) \ + for (struct { \ + int i; \ + el::base::type::PerformanceTrackerPtr timer; \ + } obj = {0, \ + el::base::type::PerformanceTrackerPtr(new el::base::PerformanceTracker(blockName, ELPP_MIN_UNIT))}; \ + obj.i < 1; ++obj.i) +/// @brief Performance tracked function. Performance gets written when goes out of scope using +/// 'performance' logger. +/// +/// @detail Please note in order to check the performance at a certain time you can use obj->checkpoint(); +/// @see el::base::PerformanceTracker +/// @see el::base::PerformanceTracker::checkpoint +#define TIMED_FUNC_IF(obj, condition) TIMED_SCOPE_IF(obj, ELPP_FUNC, condition) +#define TIMED_FUNC(obj) TIMED_SCOPE(obj, ELPP_FUNC) +#undef PERFORMANCE_CHECKPOINT +#undef PERFORMANCE_CHECKPOINT_WITH_ID +#define PERFORMANCE_CHECKPOINT(obj) obj->checkpoint(std::string(), __FILE__, __LINE__, ELPP_FUNC) +#define PERFORMANCE_CHECKPOINT_WITH_ID(obj, id) obj->checkpoint(id, __FILE__, __LINE__, ELPP_FUNC) +#undef ELPP_COUNTER +#undef ELPP_COUNTER_POS +/// @brief Gets hit counter for file/line +#define ELPP_COUNTER (ELPP->hitCounters()->getCounter(__FILE__, __LINE__)) +/// @brief Gets hit counter position for file/line, -1 if not registered yet +#define ELPP_COUNTER_POS (ELPP_COUNTER == nullptr ? -1 : ELPP_COUNTER->hitCounts()) +// Undef levels to support LOG(LEVEL) +#undef INFO +#undef WARNING +#undef DEBUG +#undef ERROR +#undef FATAL +#undef TRACE +#undef VERBOSE +// Undef existing +#undef CINFO +#undef CWARNING +#undef CDEBUG +#undef CFATAL +#undef CERROR +#undef CTRACE +#undef CVERBOSE +#undef CINFO_IF +#undef CWARNING_IF +#undef CDEBUG_IF +#undef CERROR_IF +#undef CFATAL_IF +#undef CTRACE_IF +#undef CVERBOSE_IF +#undef CINFO_EVERY_N +#undef CWARNING_EVERY_N +#undef CDEBUG_EVERY_N +#undef CERROR_EVERY_N +#undef CFATAL_EVERY_N +#undef CTRACE_EVERY_N +#undef CVERBOSE_EVERY_N +#undef CINFO_AFTER_N +#undef CWARNING_AFTER_N +#undef CDEBUG_AFTER_N +#undef CERROR_AFTER_N +#undef CFATAL_AFTER_N +#undef CTRACE_AFTER_N +#undef CVERBOSE_AFTER_N +#undef CINFO_N_TIMES +#undef CWARNING_N_TIMES +#undef CDEBUG_N_TIMES +#undef CERROR_N_TIMES +#undef CFATAL_N_TIMES +#undef CTRACE_N_TIMES +#undef CVERBOSE_N_TIMES +// Normal logs +#if ELPP_INFO_LOG +#define CINFO(writer, dispatchAction, ...) ELPP_WRITE_LOG(writer, el::Level::Info, dispatchAction, __VA_ARGS__) +#else +#define CINFO(writer, dispatchAction, ...) el::base::NullWriter() +#endif // ELPP_INFO_LOG +#if ELPP_WARNING_LOG +#define CWARNING(writer, dispatchAction, ...) ELPP_WRITE_LOG(writer, el::Level::Warning, dispatchAction, __VA_ARGS__) +#else +#define CWARNING(writer, dispatchAction, ...) el::base::NullWriter() +#endif // ELPP_WARNING_LOG +#if ELPP_DEBUG_LOG +#define CDEBUG(writer, dispatchAction, ...) ELPP_WRITE_LOG(writer, el::Level::Debug, dispatchAction, __VA_ARGS__) +#else +#define CDEBUG(writer, dispatchAction, ...) el::base::NullWriter() +#endif // ELPP_DEBUG_LOG +#if ELPP_ERROR_LOG +#define CERROR(writer, dispatchAction, ...) ELPP_WRITE_LOG(writer, el::Level::Error, dispatchAction, __VA_ARGS__) +#else +#define CERROR(writer, dispatchAction, ...) el::base::NullWriter() +#endif // ELPP_ERROR_LOG +#if ELPP_FATAL_LOG +#define CFATAL(writer, dispatchAction, ...) ELPP_WRITE_LOG(writer, el::Level::Fatal, dispatchAction, __VA_ARGS__) +#else +#define CFATAL(writer, dispatchAction, ...) el::base::NullWriter() +#endif // ELPP_FATAL_LOG +#if ELPP_TRACE_LOG +#define CTRACE(writer, dispatchAction, ...) ELPP_WRITE_LOG(writer, el::Level::Trace, dispatchAction, __VA_ARGS__) +#else +#define CTRACE(writer, dispatchAction, ...) el::base::NullWriter() +#endif // ELPP_TRACE_LOG +#if ELPP_VERBOSE_LOG +#define CVERBOSE(writer, vlevel, dispatchAction, ...) \ + if (VLOG_IS_ON(vlevel)) \ + writer(el::Level::Verbose, __FILE__, __LINE__, ELPP_FUNC, dispatchAction, vlevel) \ + .construct(el_getVALength(__VA_ARGS__), __VA_ARGS__) +#else +#define CVERBOSE(writer, vlevel, dispatchAction, ...) el::base::NullWriter() +#endif // ELPP_VERBOSE_LOG +// Conditional logs +#if ELPP_INFO_LOG +#define CINFO_IF(writer, condition_, dispatchAction, ...) \ + ELPP_WRITE_LOG_IF(writer, (condition_), el::Level::Info, dispatchAction, __VA_ARGS__) +#else +#define CINFO_IF(writer, condition_, dispatchAction, ...) el::base::NullWriter() +#endif // ELPP_INFO_LOG +#if ELPP_WARNING_LOG +#define CWARNING_IF(writer, condition_, dispatchAction, ...) \ + ELPP_WRITE_LOG_IF(writer, (condition_), el::Level::Warning, dispatchAction, __VA_ARGS__) +#else +#define CWARNING_IF(writer, condition_, dispatchAction, ...) el::base::NullWriter() +#endif // ELPP_WARNING_LOG +#if ELPP_DEBUG_LOG +#define CDEBUG_IF(writer, condition_, dispatchAction, ...) \ + ELPP_WRITE_LOG_IF(writer, (condition_), el::Level::Debug, dispatchAction, __VA_ARGS__) +#else +#define CDEBUG_IF(writer, condition_, dispatchAction, ...) el::base::NullWriter() +#endif // ELPP_DEBUG_LOG +#if ELPP_ERROR_LOG +#define CERROR_IF(writer, condition_, dispatchAction, ...) \ + ELPP_WRITE_LOG_IF(writer, (condition_), el::Level::Error, dispatchAction, __VA_ARGS__) +#else +#define CERROR_IF(writer, condition_, dispatchAction, ...) el::base::NullWriter() +#endif // ELPP_ERROR_LOG +#if ELPP_FATAL_LOG +#define CFATAL_IF(writer, condition_, dispatchAction, ...) \ + ELPP_WRITE_LOG_IF(writer, (condition_), el::Level::Fatal, dispatchAction, __VA_ARGS__) +#else +#define CFATAL_IF(writer, condition_, dispatchAction, ...) el::base::NullWriter() +#endif // ELPP_FATAL_LOG +#if ELPP_TRACE_LOG +#define CTRACE_IF(writer, condition_, dispatchAction, ...) \ + ELPP_WRITE_LOG_IF(writer, (condition_), el::Level::Trace, dispatchAction, __VA_ARGS__) +#else +#define CTRACE_IF(writer, condition_, dispatchAction, ...) el::base::NullWriter() +#endif // ELPP_TRACE_LOG +#if ELPP_VERBOSE_LOG +#define CVERBOSE_IF(writer, condition_, vlevel, dispatchAction, ...) \ + if (VLOG_IS_ON(vlevel) && (condition_)) \ + writer(el::Level::Verbose, __FILE__, __LINE__, ELPP_FUNC, dispatchAction, vlevel) \ + .construct(el_getVALength(__VA_ARGS__), __VA_ARGS__) +#else +#define CVERBOSE_IF(writer, condition_, vlevel, dispatchAction, ...) el::base::NullWriter() +#endif // ELPP_VERBOSE_LOG +// Occasional logs +#if ELPP_INFO_LOG +#define CINFO_EVERY_N(writer, occasion, dispatchAction, ...) \ + ELPP_WRITE_LOG_EVERY_N(writer, occasion, el::Level::Info, dispatchAction, __VA_ARGS__) +#else +#define CINFO_EVERY_N(writer, occasion, dispatchAction, ...) el::base::NullWriter() +#endif // ELPP_INFO_LOG +#if ELPP_WARNING_LOG +#define CWARNING_EVERY_N(writer, occasion, dispatchAction, ...) \ + ELPP_WRITE_LOG_EVERY_N(writer, occasion, el::Level::Warning, dispatchAction, __VA_ARGS__) +#else +#define CWARNING_EVERY_N(writer, occasion, dispatchAction, ...) el::base::NullWriter() +#endif // ELPP_WARNING_LOG +#if ELPP_DEBUG_LOG +#define CDEBUG_EVERY_N(writer, occasion, dispatchAction, ...) \ + ELPP_WRITE_LOG_EVERY_N(writer, occasion, el::Level::Debug, dispatchAction, __VA_ARGS__) +#else +#define CDEBUG_EVERY_N(writer, occasion, dispatchAction, ...) el::base::NullWriter() +#endif // ELPP_DEBUG_LOG +#if ELPP_ERROR_LOG +#define CERROR_EVERY_N(writer, occasion, dispatchAction, ...) \ + ELPP_WRITE_LOG_EVERY_N(writer, occasion, el::Level::Error, dispatchAction, __VA_ARGS__) +#else +#define CERROR_EVERY_N(writer, occasion, dispatchAction, ...) el::base::NullWriter() +#endif // ELPP_ERROR_LOG +#if ELPP_FATAL_LOG +#define CFATAL_EVERY_N(writer, occasion, dispatchAction, ...) \ + ELPP_WRITE_LOG_EVERY_N(writer, occasion, el::Level::Fatal, dispatchAction, __VA_ARGS__) +#else +#define CFATAL_EVERY_N(writer, occasion, dispatchAction, ...) el::base::NullWriter() +#endif // ELPP_FATAL_LOG +#if ELPP_TRACE_LOG +#define CTRACE_EVERY_N(writer, occasion, dispatchAction, ...) \ + ELPP_WRITE_LOG_EVERY_N(writer, occasion, el::Level::Trace, dispatchAction, __VA_ARGS__) +#else +#define CTRACE_EVERY_N(writer, occasion, dispatchAction, ...) el::base::NullWriter() +#endif // ELPP_TRACE_LOG +#if ELPP_VERBOSE_LOG +#define CVERBOSE_EVERY_N(writer, occasion, vlevel, dispatchAction, ...) \ + CVERBOSE_IF(writer, ELPP->validateEveryNCounter(__FILE__, __LINE__, occasion), vlevel, dispatchAction, __VA_ARGS__) +#else +#define CVERBOSE_EVERY_N(writer, occasion, vlevel, dispatchAction, ...) el::base::NullWriter() +#endif // ELPP_VERBOSE_LOG +// After N logs +#if ELPP_INFO_LOG +#define CINFO_AFTER_N(writer, n, dispatchAction, ...) \ + ELPP_WRITE_LOG_AFTER_N(writer, n, el::Level::Info, dispatchAction, __VA_ARGS__) +#else +#define CINFO_AFTER_N(writer, n, dispatchAction, ...) el::base::NullWriter() +#endif // ELPP_INFO_LOG +#if ELPP_WARNING_LOG +#define CWARNING_AFTER_N(writer, n, dispatchAction, ...) \ + ELPP_WRITE_LOG_AFTER_N(writer, n, el::Level::Warning, dispatchAction, __VA_ARGS__) +#else +#define CWARNING_AFTER_N(writer, n, dispatchAction, ...) el::base::NullWriter() +#endif // ELPP_WARNING_LOG +#if ELPP_DEBUG_LOG +#define CDEBUG_AFTER_N(writer, n, dispatchAction, ...) \ + ELPP_WRITE_LOG_AFTER_N(writer, n, el::Level::Debug, dispatchAction, __VA_ARGS__) +#else +#define CDEBUG_AFTER_N(writer, n, dispatchAction, ...) el::base::NullWriter() +#endif // ELPP_DEBUG_LOG +#if ELPP_ERROR_LOG +#define CERROR_AFTER_N(writer, n, dispatchAction, ...) \ + ELPP_WRITE_LOG_AFTER_N(writer, n, el::Level::Error, dispatchAction, __VA_ARGS__) +#else +#define CERROR_AFTER_N(writer, n, dispatchAction, ...) el::base::NullWriter() +#endif // ELPP_ERROR_LOG +#if ELPP_FATAL_LOG +#define CFATAL_AFTER_N(writer, n, dispatchAction, ...) \ + ELPP_WRITE_LOG_AFTER_N(writer, n, el::Level::Fatal, dispatchAction, __VA_ARGS__) +#else +#define CFATAL_AFTER_N(writer, n, dispatchAction, ...) el::base::NullWriter() +#endif // ELPP_FATAL_LOG +#if ELPP_TRACE_LOG +#define CTRACE_AFTER_N(writer, n, dispatchAction, ...) \ + ELPP_WRITE_LOG_AFTER_N(writer, n, el::Level::Trace, dispatchAction, __VA_ARGS__) +#else +#define CTRACE_AFTER_N(writer, n, dispatchAction, ...) el::base::NullWriter() +#endif // ELPP_TRACE_LOG +#if ELPP_VERBOSE_LOG +#define CVERBOSE_AFTER_N(writer, n, vlevel, dispatchAction, ...) \ + CVERBOSE_IF(writer, ELPP->validateAfterNCounter(__FILE__, __LINE__, n), vlevel, dispatchAction, __VA_ARGS__) +#else +#define CVERBOSE_AFTER_N(writer, n, vlevel, dispatchAction, ...) el::base::NullWriter() +#endif // ELPP_VERBOSE_LOG +// N Times logs +#if ELPP_INFO_LOG +#define CINFO_N_TIMES(writer, n, dispatchAction, ...) \ + ELPP_WRITE_LOG_N_TIMES(writer, n, el::Level::Info, dispatchAction, __VA_ARGS__) +#else +#define CINFO_N_TIMES(writer, n, dispatchAction, ...) el::base::NullWriter() +#endif // ELPP_INFO_LOG +#if ELPP_WARNING_LOG +#define CWARNING_N_TIMES(writer, n, dispatchAction, ...) \ + ELPP_WRITE_LOG_N_TIMES(writer, n, el::Level::Warning, dispatchAction, __VA_ARGS__) +#else +#define CWARNING_N_TIMES(writer, n, dispatchAction, ...) el::base::NullWriter() +#endif // ELPP_WARNING_LOG +#if ELPP_DEBUG_LOG +#define CDEBUG_N_TIMES(writer, n, dispatchAction, ...) \ + ELPP_WRITE_LOG_N_TIMES(writer, n, el::Level::Debug, dispatchAction, __VA_ARGS__) +#else +#define CDEBUG_N_TIMES(writer, n, dispatchAction, ...) el::base::NullWriter() +#endif // ELPP_DEBUG_LOG +#if ELPP_ERROR_LOG +#define CERROR_N_TIMES(writer, n, dispatchAction, ...) \ + ELPP_WRITE_LOG_N_TIMES(writer, n, el::Level::Error, dispatchAction, __VA_ARGS__) +#else +#define CERROR_N_TIMES(writer, n, dispatchAction, ...) el::base::NullWriter() +#endif // ELPP_ERROR_LOG +#if ELPP_FATAL_LOG +#define CFATAL_N_TIMES(writer, n, dispatchAction, ...) \ + ELPP_WRITE_LOG_N_TIMES(writer, n, el::Level::Fatal, dispatchAction, __VA_ARGS__) +#else +#define CFATAL_N_TIMES(writer, n, dispatchAction, ...) el::base::NullWriter() +#endif // ELPP_FATAL_LOG +#if ELPP_TRACE_LOG +#define CTRACE_N_TIMES(writer, n, dispatchAction, ...) \ + ELPP_WRITE_LOG_N_TIMES(writer, n, el::Level::Trace, dispatchAction, __VA_ARGS__) +#else +#define CTRACE_N_TIMES(writer, n, dispatchAction, ...) el::base::NullWriter() +#endif // ELPP_TRACE_LOG +#if ELPP_VERBOSE_LOG +#define CVERBOSE_N_TIMES(writer, n, vlevel, dispatchAction, ...) \ + CVERBOSE_IF(writer, ELPP->validateNTimesCounter(__FILE__, __LINE__, n), vlevel, dispatchAction, __VA_ARGS__) +#else +#define CVERBOSE_N_TIMES(writer, n, vlevel, dispatchAction, ...) el::base::NullWriter() +#endif // ELPP_VERBOSE_LOG +// +// Custom Loggers - Requires (level, dispatchAction, loggerId/s) +// +// undef existing +#undef CLOG +#undef CLOG_VERBOSE +#undef CVLOG +#undef CLOG_IF +#undef CLOG_VERBOSE_IF +#undef CVLOG_IF +#undef CLOG_EVERY_N +#undef CVLOG_EVERY_N +#undef CLOG_AFTER_N +#undef CVLOG_AFTER_N +#undef CLOG_N_TIMES +#undef CVLOG_N_TIMES +// Normal logs +#define CLOG(LEVEL, ...) C##LEVEL(el::base::Writer, el::base::DispatchAction::NormalLog, __VA_ARGS__) +#define CVLOG(vlevel, ...) CVERBOSE(el::base::Writer, vlevel, el::base::DispatchAction::NormalLog, __VA_ARGS__) +// Conditional logs +#define CLOG_IF(condition, LEVEL, ...) \ + C##LEVEL##_IF(el::base::Writer, condition, el::base::DispatchAction::NormalLog, __VA_ARGS__) +#define CVLOG_IF(condition, vlevel, ...) \ + CVERBOSE_IF(el::base::Writer, condition, vlevel, el::base::DispatchAction::NormalLog, __VA_ARGS__) +// Hit counts based logs +#define CLOG_EVERY_N(n, LEVEL, ...) \ + C##LEVEL##_EVERY_N(el::base::Writer, n, el::base::DispatchAction::NormalLog, __VA_ARGS__) +#define CVLOG_EVERY_N(n, vlevel, ...) \ + CVERBOSE_EVERY_N(el::base::Writer, n, vlevel, el::base::DispatchAction::NormalLog, __VA_ARGS__) +#define CLOG_AFTER_N(n, LEVEL, ...) \ + C##LEVEL##_AFTER_N(el::base::Writer, n, el::base::DispatchAction::NormalLog, __VA_ARGS__) +#define CVLOG_AFTER_N(n, vlevel, ...) \ + CVERBOSE_AFTER_N(el::base::Writer, n, vlevel, el::base::DispatchAction::NormalLog, __VA_ARGS__) +#define CLOG_N_TIMES(n, LEVEL, ...) \ + C##LEVEL##_N_TIMES(el::base::Writer, n, el::base::DispatchAction::NormalLog, __VA_ARGS__) +#define CVLOG_N_TIMES(n, vlevel, ...) \ + CVERBOSE_N_TIMES(el::base::Writer, n, vlevel, el::base::DispatchAction::NormalLog, __VA_ARGS__) +// +// Default Loggers macro using CLOG(), CLOG_VERBOSE() and CVLOG() macros +// +// undef existing +#undef LOG +#undef VLOG +#undef LOG_IF +#undef VLOG_IF +#undef LOG_EVERY_N +#undef VLOG_EVERY_N +#undef LOG_AFTER_N +#undef VLOG_AFTER_N +#undef LOG_N_TIMES +#undef VLOG_N_TIMES +#undef ELPP_CURR_FILE_LOGGER_ID +#if defined(ELPP_DEFAULT_LOGGER) +#define ELPP_CURR_FILE_LOGGER_ID ELPP_DEFAULT_LOGGER +#else +#define ELPP_CURR_FILE_LOGGER_ID el::base::consts::kDefaultLoggerId +#endif +#undef ELPP_TRACE +#define ELPP_TRACE CLOG(TRACE, ELPP_CURR_FILE_LOGGER_ID) +// Normal logs +#define LOG(LEVEL) CLOG(LEVEL, ELPP_CURR_FILE_LOGGER_ID) +#define VLOG(vlevel) CVLOG(vlevel, ELPP_CURR_FILE_LOGGER_ID) +// Conditional logs +#define LOG_IF(condition, LEVEL) CLOG_IF(condition, LEVEL, ELPP_CURR_FILE_LOGGER_ID) +#define VLOG_IF(condition, vlevel) CVLOG_IF(condition, vlevel, ELPP_CURR_FILE_LOGGER_ID) +// Hit counts based logs +#define LOG_EVERY_N(n, LEVEL) CLOG_EVERY_N(n, LEVEL, ELPP_CURR_FILE_LOGGER_ID) +#define VLOG_EVERY_N(n, vlevel) CVLOG_EVERY_N(n, vlevel, ELPP_CURR_FILE_LOGGER_ID) +#define LOG_AFTER_N(n, LEVEL) CLOG_AFTER_N(n, LEVEL, ELPP_CURR_FILE_LOGGER_ID) +#define VLOG_AFTER_N(n, vlevel) CVLOG_AFTER_N(n, vlevel, ELPP_CURR_FILE_LOGGER_ID) +#define LOG_N_TIMES(n, LEVEL) CLOG_N_TIMES(n, LEVEL, ELPP_CURR_FILE_LOGGER_ID) +#define VLOG_N_TIMES(n, vlevel) CVLOG_N_TIMES(n, vlevel, ELPP_CURR_FILE_LOGGER_ID) +// Generic PLOG() +#undef CPLOG +#undef CPLOG_IF +#undef PLOG +#undef PLOG_IF +#undef DCPLOG +#undef DCPLOG_IF +#undef DPLOG +#undef DPLOG_IF +#define CPLOG(LEVEL, ...) C##LEVEL(el::base::PErrorWriter, el::base::DispatchAction::NormalLog, __VA_ARGS__) +#define CPLOG_IF(condition, LEVEL, ...) \ + C##LEVEL##_IF(el::base::PErrorWriter, condition, el::base::DispatchAction::NormalLog, __VA_ARGS__) +#define DCPLOG(LEVEL, ...) \ + if (ELPP_DEBUG_LOG) \ + C##LEVEL(el::base::PErrorWriter, el::base::DispatchAction::NormalLog, __VA_ARGS__) +#define DCPLOG_IF(condition, LEVEL, ...) \ + C##LEVEL##_IF(el::base::PErrorWriter, (ELPP_DEBUG_LOG) && (condition), el::base::DispatchAction::NormalLog, \ + __VA_ARGS__) +#define PLOG(LEVEL) CPLOG(LEVEL, ELPP_CURR_FILE_LOGGER_ID) +#define PLOG_IF(condition, LEVEL) CPLOG_IF(condition, LEVEL, ELPP_CURR_FILE_LOGGER_ID) +#define DPLOG(LEVEL) DCPLOG(LEVEL, ELPP_CURR_FILE_LOGGER_ID) +#define DPLOG_IF(condition, LEVEL) DCPLOG_IF(condition, LEVEL, ELPP_CURR_FILE_LOGGER_ID) +// Generic SYSLOG() +#undef CSYSLOG +#undef CSYSLOG_IF +#undef CSYSLOG_EVERY_N +#undef CSYSLOG_AFTER_N +#undef CSYSLOG_N_TIMES +#undef SYSLOG +#undef SYSLOG_IF +#undef SYSLOG_EVERY_N +#undef SYSLOG_AFTER_N +#undef SYSLOG_N_TIMES +#undef DCSYSLOG +#undef DCSYSLOG_IF +#undef DCSYSLOG_EVERY_N +#undef DCSYSLOG_AFTER_N +#undef DCSYSLOG_N_TIMES +#undef DSYSLOG +#undef DSYSLOG_IF +#undef DSYSLOG_EVERY_N +#undef DSYSLOG_AFTER_N +#undef DSYSLOG_N_TIMES +#if defined(ELPP_SYSLOG) +#define CSYSLOG(LEVEL, ...) C##LEVEL(el::base::Writer, el::base::DispatchAction::SysLog, __VA_ARGS__) +#define CSYSLOG_IF(condition, LEVEL, ...) \ + C##LEVEL##_IF(el::base::Writer, condition, el::base::DispatchAction::SysLog, __VA_ARGS__) +#define CSYSLOG_EVERY_N(n, LEVEL, ...) \ + C##LEVEL##_EVERY_N(el::base::Writer, n, el::base::DispatchAction::SysLog, __VA_ARGS__) +#define CSYSLOG_AFTER_N(n, LEVEL, ...) \ + C##LEVEL##_AFTER_N(el::base::Writer, n, el::base::DispatchAction::SysLog, __VA_ARGS__) +#define CSYSLOG_N_TIMES(n, LEVEL, ...) \ + C##LEVEL##_N_TIMES(el::base::Writer, n, el::base::DispatchAction::SysLog, __VA_ARGS__) +#define SYSLOG(LEVEL) CSYSLOG(LEVEL, el::base::consts::kSysLogLoggerId) +#define SYSLOG_IF(condition, LEVEL) CSYSLOG_IF(condition, LEVEL, el::base::consts::kSysLogLoggerId) +#define SYSLOG_EVERY_N(n, LEVEL) CSYSLOG_EVERY_N(n, LEVEL, el::base::consts::kSysLogLoggerId) +#define SYSLOG_AFTER_N(n, LEVEL) CSYSLOG_AFTER_N(n, LEVEL, el::base::consts::kSysLogLoggerId) +#define SYSLOG_N_TIMES(n, LEVEL) CSYSLOG_N_TIMES(n, LEVEL, el::base::consts::kSysLogLoggerId) +#define DCSYSLOG(LEVEL, ...) \ + if (ELPP_DEBUG_LOG) \ + C##LEVEL(el::base::Writer, el::base::DispatchAction::SysLog, __VA_ARGS__) +#define DCSYSLOG_IF(condition, LEVEL, ...) \ + C##LEVEL##_IF(el::base::Writer, (ELPP_DEBUG_LOG) && (condition), el::base::DispatchAction::SysLog, __VA_ARGS__) +#define DCSYSLOG_EVERY_N(n, LEVEL, ...) \ + if (ELPP_DEBUG_LOG) \ + C##LEVEL##_EVERY_N(el::base::Writer, n, el::base::DispatchAction::SysLog, __VA_ARGS__) +#define DCSYSLOG_AFTER_N(n, LEVEL, ...) \ + if (ELPP_DEBUG_LOG) \ + C##LEVEL##_AFTER_N(el::base::Writer, n, el::base::DispatchAction::SysLog, __VA_ARGS__) +#define DCSYSLOG_N_TIMES(n, LEVEL, ...) \ + if (ELPP_DEBUG_LOG) \ + C##LEVEL##_EVERY_N(el::base::Writer, n, el::base::DispatchAction::SysLog, __VA_ARGS__) +#define DSYSLOG(LEVEL) DCSYSLOG(LEVEL, el::base::consts::kSysLogLoggerId) +#define DSYSLOG_IF(condition, LEVEL) DCSYSLOG_IF(condition, LEVEL, el::base::consts::kSysLogLoggerId) +#define DSYSLOG_EVERY_N(n, LEVEL) DCSYSLOG_EVERY_N(n, LEVEL, el::base::consts::kSysLogLoggerId) +#define DSYSLOG_AFTER_N(n, LEVEL) DCSYSLOG_AFTER_N(n, LEVEL, el::base::consts::kSysLogLoggerId) +#define DSYSLOG_N_TIMES(n, LEVEL) DCSYSLOG_N_TIMES(n, LEVEL, el::base::consts::kSysLogLoggerId) +#else +#define CSYSLOG(LEVEL, ...) el::base::NullWriter() +#define CSYSLOG_IF(condition, LEVEL, ...) el::base::NullWriter() +#define CSYSLOG_EVERY_N(n, LEVEL, ...) el::base::NullWriter() +#define CSYSLOG_AFTER_N(n, LEVEL, ...) el::base::NullWriter() +#define CSYSLOG_N_TIMES(n, LEVEL, ...) el::base::NullWriter() +#define SYSLOG(LEVEL) el::base::NullWriter() +#define SYSLOG_IF(condition, LEVEL) el::base::NullWriter() +#define SYSLOG_EVERY_N(n, LEVEL) el::base::NullWriter() +#define SYSLOG_AFTER_N(n, LEVEL) el::base::NullWriter() +#define SYSLOG_N_TIMES(n, LEVEL) el::base::NullWriter() +#define DCSYSLOG(LEVEL, ...) el::base::NullWriter() +#define DCSYSLOG_IF(condition, LEVEL, ...) el::base::NullWriter() +#define DCSYSLOG_EVERY_N(n, LEVEL, ...) el::base::NullWriter() +#define DCSYSLOG_AFTER_N(n, LEVEL, ...) el::base::NullWriter() +#define DCSYSLOG_N_TIMES(n, LEVEL, ...) el::base::NullWriter() +#define DSYSLOG(LEVEL) el::base::NullWriter() +#define DSYSLOG_IF(condition, LEVEL) el::base::NullWriter() +#define DSYSLOG_EVERY_N(n, LEVEL) el::base::NullWriter() +#define DSYSLOG_AFTER_N(n, LEVEL) el::base::NullWriter() +#define DSYSLOG_N_TIMES(n, LEVEL) el::base::NullWriter() +#endif // defined(ELPP_SYSLOG) +// +// Custom Debug Only Loggers - Requires (level, loggerId/s) +// +// undef existing +#undef DCLOG +#undef DCVLOG +#undef DCLOG_IF +#undef DCVLOG_IF +#undef DCLOG_EVERY_N +#undef DCVLOG_EVERY_N +#undef DCLOG_AFTER_N +#undef DCVLOG_AFTER_N +#undef DCLOG_N_TIMES +#undef DCVLOG_N_TIMES +// Normal logs +#define DCLOG(LEVEL, ...) \ + if (ELPP_DEBUG_LOG) \ + CLOG(LEVEL, __VA_ARGS__) +#define DCLOG_VERBOSE(vlevel, ...) \ + if (ELPP_DEBUG_LOG) \ + CLOG_VERBOSE(vlevel, __VA_ARGS__) +#define DCVLOG(vlevel, ...) \ + if (ELPP_DEBUG_LOG) \ + CVLOG(vlevel, __VA_ARGS__) +// Conditional logs +#define DCLOG_IF(condition, LEVEL, ...) \ + if (ELPP_DEBUG_LOG) \ + CLOG_IF(condition, LEVEL, __VA_ARGS__) +#define DCVLOG_IF(condition, vlevel, ...) \ + if (ELPP_DEBUG_LOG) \ + CVLOG_IF(condition, vlevel, __VA_ARGS__) +// Hit counts based logs +#define DCLOG_EVERY_N(n, LEVEL, ...) \ + if (ELPP_DEBUG_LOG) \ + CLOG_EVERY_N(n, LEVEL, __VA_ARGS__) +#define DCVLOG_EVERY_N(n, vlevel, ...) \ + if (ELPP_DEBUG_LOG) \ + CVLOG_EVERY_N(n, vlevel, __VA_ARGS__) +#define DCLOG_AFTER_N(n, LEVEL, ...) \ + if (ELPP_DEBUG_LOG) \ + CLOG_AFTER_N(n, LEVEL, __VA_ARGS__) +#define DCVLOG_AFTER_N(n, vlevel, ...) \ + if (ELPP_DEBUG_LOG) \ + CVLOG_AFTER_N(n, vlevel, __VA_ARGS__) +#define DCLOG_N_TIMES(n, LEVEL, ...) \ + if (ELPP_DEBUG_LOG) \ + CLOG_N_TIMES(n, LEVEL, __VA_ARGS__) +#define DCVLOG_N_TIMES(n, vlevel, ...) \ + if (ELPP_DEBUG_LOG) \ + CVLOG_N_TIMES(n, vlevel, __VA_ARGS__) +// +// Default Debug Only Loggers macro using CLOG(), CLOG_VERBOSE() and CVLOG() macros +// +#if !defined(ELPP_NO_DEBUG_MACROS) +// undef existing +#undef DLOG +#undef DVLOG +#undef DLOG_IF +#undef DVLOG_IF +#undef DLOG_EVERY_N +#undef DVLOG_EVERY_N +#undef DLOG_AFTER_N +#undef DVLOG_AFTER_N +#undef DLOG_N_TIMES +#undef DVLOG_N_TIMES +// Normal logs +#define DLOG(LEVEL) DCLOG(LEVEL, ELPP_CURR_FILE_LOGGER_ID) +#define DVLOG(vlevel) DCVLOG(vlevel, ELPP_CURR_FILE_LOGGER_ID) +// Conditional logs +#define DLOG_IF(condition, LEVEL) DCLOG_IF(condition, LEVEL, ELPP_CURR_FILE_LOGGER_ID) +#define DVLOG_IF(condition, vlevel) DCVLOG_IF(condition, vlevel, ELPP_CURR_FILE_LOGGER_ID) +// Hit counts based logs +#define DLOG_EVERY_N(n, LEVEL) DCLOG_EVERY_N(n, LEVEL, ELPP_CURR_FILE_LOGGER_ID) +#define DVLOG_EVERY_N(n, vlevel) DCVLOG_EVERY_N(n, vlevel, ELPP_CURR_FILE_LOGGER_ID) +#define DLOG_AFTER_N(n, LEVEL) DCLOG_AFTER_N(n, LEVEL, ELPP_CURR_FILE_LOGGER_ID) +#define DVLOG_AFTER_N(n, vlevel) DCVLOG_AFTER_N(n, vlevel, ELPP_CURR_FILE_LOGGER_ID) +#define DLOG_N_TIMES(n, LEVEL) DCLOG_N_TIMES(n, LEVEL, ELPP_CURR_FILE_LOGGER_ID) +#define DVLOG_N_TIMES(n, vlevel) DCVLOG_N_TIMES(n, vlevel, ELPP_CURR_FILE_LOGGER_ID) +#endif // defined(ELPP_NO_DEBUG_MACROS) +#if !defined(ELPP_NO_CHECK_MACROS) +// Check macros +#undef CCHECK +#undef CPCHECK +#undef CCHECK_EQ +#undef CCHECK_NE +#undef CCHECK_LT +#undef CCHECK_GT +#undef CCHECK_LE +#undef CCHECK_GE +#undef CCHECK_BOUNDS +#undef CCHECK_NOTNULL +#undef CCHECK_STRCASEEQ +#undef CCHECK_STRCASENE +#undef CHECK +#undef PCHECK +#undef CHECK_EQ +#undef CHECK_NE +#undef CHECK_LT +#undef CHECK_GT +#undef CHECK_LE +#undef CHECK_GE +#undef CHECK_BOUNDS +#undef CHECK_NOTNULL +#undef CHECK_STRCASEEQ +#undef CHECK_STRCASENE +#define CCHECK(condition, ...) CLOG_IF(!(condition), FATAL, __VA_ARGS__) << "Check failed: [" << #condition << "] " +#define CPCHECK(condition, ...) CPLOG_IF(!(condition), FATAL, __VA_ARGS__) << "Check failed: [" << #condition << "] " +#define CHECK(condition) CCHECK(condition, ELPP_CURR_FILE_LOGGER_ID) +#define PCHECK(condition) CPCHECK(condition, ELPP_CURR_FILE_LOGGER_ID) +#define CCHECK_EQ(a, b, ...) CCHECK(a == b, __VA_ARGS__) +#define CCHECK_NE(a, b, ...) CCHECK(a != b, __VA_ARGS__) +#define CCHECK_LT(a, b, ...) CCHECK(a < b, __VA_ARGS__) +#define CCHECK_GT(a, b, ...) CCHECK(a > b, __VA_ARGS__) +#define CCHECK_LE(a, b, ...) CCHECK(a <= b, __VA_ARGS__) +#define CCHECK_GE(a, b, ...) CCHECK(a >= b, __VA_ARGS__) +#define CCHECK_BOUNDS(val, min, max, ...) CCHECK(val >= min && val <= max, __VA_ARGS__) +#define CHECK_EQ(a, b) CCHECK_EQ(a, b, ELPP_CURR_FILE_LOGGER_ID) +#define CHECK_NE(a, b) CCHECK_NE(a, b, ELPP_CURR_FILE_LOGGER_ID) +#define CHECK_LT(a, b) CCHECK_LT(a, b, ELPP_CURR_FILE_LOGGER_ID) +#define CHECK_GT(a, b) CCHECK_GT(a, b, ELPP_CURR_FILE_LOGGER_ID) +#define CHECK_LE(a, b) CCHECK_LE(a, b, ELPP_CURR_FILE_LOGGER_ID) +#define CHECK_GE(a, b) CCHECK_GE(a, b, ELPP_CURR_FILE_LOGGER_ID) +#define CHECK_BOUNDS(val, min, max) CCHECK_BOUNDS(val, min, max, ELPP_CURR_FILE_LOGGER_ID) +#define CCHECK_NOTNULL(ptr, ...) CCHECK((ptr) != nullptr, __VA_ARGS__) +#define CCHECK_STREQ(str1, str2, ...) \ + CLOG_IF(!el::base::utils::Str::cStringEq(str1, str2), FATAL, __VA_ARGS__) \ + << "Check failed: [" << #str1 << " == " << #str2 << "] " +#define CCHECK_STRNE(str1, str2, ...) \ + CLOG_IF(el::base::utils::Str::cStringEq(str1, str2), FATAL, __VA_ARGS__) \ + << "Check failed: [" << #str1 << " != " << #str2 << "] " +#define CCHECK_STRCASEEQ(str1, str2, ...) \ + CLOG_IF(!el::base::utils::Str::cStringCaseEq(str1, str2), FATAL, __VA_ARGS__) \ + << "Check failed: [" << #str1 << " == " << #str2 << "] " +#define CCHECK_STRCASENE(str1, str2, ...) \ + CLOG_IF(el::base::utils::Str::cStringCaseEq(str1, str2), FATAL, __VA_ARGS__) \ + << "Check failed: [" << #str1 << " != " << #str2 << "] " +#define CHECK_NOTNULL(ptr) CCHECK_NOTNULL((ptr), ELPP_CURR_FILE_LOGGER_ID) +#define CHECK_STREQ(str1, str2) CCHECK_STREQ(str1, str2, ELPP_CURR_FILE_LOGGER_ID) +#define CHECK_STRNE(str1, str2) CCHECK_STRNE(str1, str2, ELPP_CURR_FILE_LOGGER_ID) +#define CHECK_STRCASEEQ(str1, str2) CCHECK_STRCASEEQ(str1, str2, ELPP_CURR_FILE_LOGGER_ID) +#define CHECK_STRCASENE(str1, str2) CCHECK_STRCASENE(str1, str2, ELPP_CURR_FILE_LOGGER_ID) +#undef DCCHECK +#undef DCCHECK_EQ +#undef DCCHECK_NE +#undef DCCHECK_LT +#undef DCCHECK_GT +#undef DCCHECK_LE +#undef DCCHECK_GE +#undef DCCHECK_BOUNDS +#undef DCCHECK_NOTNULL +#undef DCCHECK_STRCASEEQ +#undef DCCHECK_STRCASENE +#undef DCPCHECK +#undef DCHECK +#undef DCHECK_EQ +#undef DCHECK_NE +#undef DCHECK_LT +#undef DCHECK_GT +#undef DCHECK_LE +#undef DCHECK_GE +#undef DCHECK_BOUNDS_ +#undef DCHECK_NOTNULL +#undef DCHECK_STRCASEEQ +#undef DCHECK_STRCASENE +#undef DPCHECK +#define DCCHECK(condition, ...) \ + if (ELPP_DEBUG_LOG) \ + CCHECK(condition, __VA_ARGS__) +#define DCCHECK_EQ(a, b, ...) \ + if (ELPP_DEBUG_LOG) \ + CCHECK_EQ(a, b, __VA_ARGS__) +#define DCCHECK_NE(a, b, ...) \ + if (ELPP_DEBUG_LOG) \ + CCHECK_NE(a, b, __VA_ARGS__) +#define DCCHECK_LT(a, b, ...) \ + if (ELPP_DEBUG_LOG) \ + CCHECK_LT(a, b, __VA_ARGS__) +#define DCCHECK_GT(a, b, ...) \ + if (ELPP_DEBUG_LOG) \ + CCHECK_GT(a, b, __VA_ARGS__) +#define DCCHECK_LE(a, b, ...) \ + if (ELPP_DEBUG_LOG) \ + CCHECK_LE(a, b, __VA_ARGS__) +#define DCCHECK_GE(a, b, ...) \ + if (ELPP_DEBUG_LOG) \ + CCHECK_GE(a, b, __VA_ARGS__) +#define DCCHECK_BOUNDS(val, min, max, ...) \ + if (ELPP_DEBUG_LOG) \ + CCHECK_BOUNDS(val, min, max, __VA_ARGS__) +#define DCCHECK_NOTNULL(ptr, ...) \ + if (ELPP_DEBUG_LOG) \ + CCHECK_NOTNULL((ptr), __VA_ARGS__) +#define DCCHECK_STREQ(str1, str2, ...) \ + if (ELPP_DEBUG_LOG) \ + CCHECK_STREQ(str1, str2, __VA_ARGS__) +#define DCCHECK_STRNE(str1, str2, ...) \ + if (ELPP_DEBUG_LOG) \ + CCHECK_STRNE(str1, str2, __VA_ARGS__) +#define DCCHECK_STRCASEEQ(str1, str2, ...) \ + if (ELPP_DEBUG_LOG) \ + CCHECK_STRCASEEQ(str1, str2, __VA_ARGS__) +#define DCCHECK_STRCASENE(str1, str2, ...) \ + if (ELPP_DEBUG_LOG) \ + CCHECK_STRCASENE(str1, str2, __VA_ARGS__) +#define DCPCHECK(condition, ...) \ + if (ELPP_DEBUG_LOG) \ + CPCHECK(condition, __VA_ARGS__) +#define DCHECK(condition) DCCHECK(condition, ELPP_CURR_FILE_LOGGER_ID) +#define DCHECK_EQ(a, b) DCCHECK_EQ(a, b, ELPP_CURR_FILE_LOGGER_ID) +#define DCHECK_NE(a, b) DCCHECK_NE(a, b, ELPP_CURR_FILE_LOGGER_ID) +#define DCHECK_LT(a, b) DCCHECK_LT(a, b, ELPP_CURR_FILE_LOGGER_ID) +#define DCHECK_GT(a, b) DCCHECK_GT(a, b, ELPP_CURR_FILE_LOGGER_ID) +#define DCHECK_LE(a, b) DCCHECK_LE(a, b, ELPP_CURR_FILE_LOGGER_ID) +#define DCHECK_GE(a, b) DCCHECK_GE(a, b, ELPP_CURR_FILE_LOGGER_ID) +#define DCHECK_BOUNDS(val, min, max) DCCHECK_BOUNDS(val, min, max, ELPP_CURR_FILE_LOGGER_ID) +#define DCHECK_NOTNULL(ptr) DCCHECK_NOTNULL((ptr), ELPP_CURR_FILE_LOGGER_ID) +#define DCHECK_STREQ(str1, str2) DCCHECK_STREQ(str1, str2, ELPP_CURR_FILE_LOGGER_ID) +#define DCHECK_STRNE(str1, str2) DCCHECK_STRNE(str1, str2, ELPP_CURR_FILE_LOGGER_ID) +#define DCHECK_STRCASEEQ(str1, str2) DCCHECK_STRCASEEQ(str1, str2, ELPP_CURR_FILE_LOGGER_ID) +#define DCHECK_STRCASENE(str1, str2) DCCHECK_STRCASENE(str1, str2, ELPP_CURR_FILE_LOGGER_ID) +#define DPCHECK(condition) DCPCHECK(condition, ELPP_CURR_FILE_LOGGER_ID) +#endif // defined(ELPP_NO_CHECK_MACROS) +#if defined(ELPP_DISABLE_DEFAULT_CRASH_HANDLING) +#define ELPP_USE_DEF_CRASH_HANDLER false +#else +#define ELPP_USE_DEF_CRASH_HANDLER true +#endif // defined(ELPP_DISABLE_DEFAULT_CRASH_HANDLING) +#define ELPP_CRASH_HANDLER_INIT +#define ELPP_INIT_EASYLOGGINGPP(val) \ + namespace el { \ + namespace base { \ + el::base::type::StoragePointer elStorage(val); \ + } \ + el::base::debug::CrashHandler elCrashHandler(ELPP_USE_DEF_CRASH_HANDLER); \ + } + +#if ELPP_ASYNC_LOGGING +#define INITIALIZE_EASYLOGGINGPP \ + ELPP_INIT_EASYLOGGINGPP(new el::base::Storage(el::LogBuilderPtr(new el::base::DefaultLogBuilder()), \ + new el::base::AsyncDispatchWorker())) +#else +#define INITIALIZE_EASYLOGGINGPP \ + ELPP_INIT_EASYLOGGINGPP(new el::base::Storage(el::LogBuilderPtr(new el::base::DefaultLogBuilder()))) +#endif // ELPP_ASYNC_LOGGING +#define INITIALIZE_NULL_EASYLOGGINGPP \ + namespace el { \ + namespace base { \ + el::base::type::StoragePointer elStorage; \ + } \ + el::base::debug::CrashHandler elCrashHandler(ELPP_USE_DEF_CRASH_HANDLER); \ + } +#define SHARE_EASYLOGGINGPP(initializedStorage) \ + namespace el { \ + namespace base { \ + el::base::type::StoragePointer elStorage(initializedStorage); \ + } \ + el::base::debug::CrashHandler elCrashHandler(ELPP_USE_DEF_CRASH_HANDLER); \ + } + +#if defined(ELPP_UNICODE) +#define START_EASYLOGGINGPP(argc, argv) \ + el::Helpers::setArgs(argc, argv); \ + std::locale::global(std::locale("")) +#else +#define START_EASYLOGGINGPP(argc, argv) el::Helpers::setArgs(argc, argv) +#endif // defined(ELPP_UNICODE) +#endif // EASYLOGGINGPP_H diff --git a/core/thirdparty/gtest/CMakeLists.txt b/core/thirdparty/gtest/CMakeLists.txt new file mode 100644 index 0000000000..139fbfc5ac --- /dev/null +++ b/core/thirdparty/gtest/CMakeLists.txt @@ -0,0 +1,65 @@ +#------------------------------------------------------------------------------- +# Copyright (C) 2019-2020 Zilliz. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License +# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing permissions and limitations under the License. +#------------------------------------------------------------------------------- + +if ( DEFINED ENV{MILVUS_GTEST_URL} ) + set( GTEST_SOURCE_URL "$ENV{MILVUS_GTEST_URL}" ) +else() + set( GTEST_SOURCE_URL + "https://gitee.com/quicksilver/googletest/repository/archive/release-${GTEST_VERSION}.zip" ) +endif() + +message( STATUS "Building gtest-${GTEST_VERSION} from source" ) +include( FetchContent ) +set( CMAKE_POLICY_DEFAULT_CMP0022 NEW ) # for googletest only + +FetchContent_Declare( + googletest + URL ${GTEST_SOURCE_URL} + URL_MD5 "f9137c5bc18b7d74027936f0f1bfa5c8" + DOWNLOAD_DIR ${MILVUS_BINARY_DIR}/3rdparty_download/download + SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/googletest-src + BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR}/googletest-build + +) + +if ( NOT googletest_POPULATED ) + FetchContent_Populate( googletest ) + + # Adding the following targets: + # gtest, gtest_main, gmock, gmock_main + add_subdirectory( ${googletest_SOURCE_DIR} + ${googletest_BINARY_DIR} + EXCLUDE_FROM_ALL ) +endif() + +# **************************************************************** +# Create ALIAS Target +# **************************************************************** +# if (NOT TARGET GTest:gtest) +# add_library( GTest::gtest ALIAS gtest ) +# endif() +# if (NOT TARGET GTest:main) +# add_library( GTest::main ALIAS gtest_main ) +# endif() +# if (NOT TARGET GMock:gmock) +# target_link_libraries( gmock INTERFACE GTest::gtest ) +# add_library( GMock::gmock ALIAS gmock ) +# endif() +# if (NOT TARGET GMock:main) +# target_link_libraries( gmock_main INTERFACE GTest::gtest ) +# add_library( GMock::main ALIAS gmock_main ) +# endif() + + +get_property( var DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/googletest-src" PROPERTY COMPILE_OPTIONS ) +message( STATUS "gtest compile options: ${var}" ) diff --git a/core/thirdparty/nlohmann/json.hpp b/core/thirdparty/nlohmann/json.hpp new file mode 100644 index 0000000000..a70aaf8cbc --- /dev/null +++ b/core/thirdparty/nlohmann/json.hpp @@ -0,0 +1,25447 @@ +/* + __ _____ _____ _____ + __| | __| | | | JSON for Modern C++ +| | |__ | | | | | | version 3.9.1 +|_____|_____|_____|_|___| https://github.com/nlohmann/json + +Licensed under the MIT License . +SPDX-License-Identifier: MIT +Copyright (c) 2013-2019 Niels Lohmann . + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +*/ + +#ifndef INCLUDE_NLOHMANN_JSON_HPP_ +#define INCLUDE_NLOHMANN_JSON_HPP_ + +#define NLOHMANN_JSON_VERSION_MAJOR 3 +#define NLOHMANN_JSON_VERSION_MINOR 9 +#define NLOHMANN_JSON_VERSION_PATCH 1 + +#include // all_of, find, for_each +#include // nullptr_t, ptrdiff_t, size_t +#include // hash, less +#include // initializer_list +#include // istream, ostream +#include // random_access_iterator_tag +#include // unique_ptr +#include // accumulate +#include // string, stoi, to_string +#include // declval, forward, move, pair, swap +#include // vector + +// #include + + +#include + +// #include + + +#include // transform +#include // array +#include // forward_list +#include // inserter, front_inserter, end +#include // map +#include // string +#include // tuple, make_tuple +#include // is_arithmetic, is_same, is_enum, underlying_type, is_convertible +#include // unordered_map +#include // pair, declval +#include // valarray + +// #include + + +#include // exception +#include // runtime_error +#include // to_string + +// #include + + +#include // size_t + +namespace nlohmann +{ +namespace detail +{ +/// struct to capture the start position of the current token +struct position_t +{ + /// the total number of characters read + std::size_t chars_read_total = 0; + /// the number of characters read in the current line + std::size_t chars_read_current_line = 0; + /// the number of lines read + std::size_t lines_read = 0; + + /// conversion to size_t to preserve SAX interface + constexpr operator size_t() const + { + return chars_read_total; + } +}; + +} // namespace detail +} // namespace nlohmann + +// #include + + +#include // pair +// #include +/* Hedley - https://nemequ.github.io/hedley + * Created by Evan Nemerson + * + * To the extent possible under law, the author(s) have dedicated all + * copyright and related and neighboring rights to this software to + * the public domain worldwide. This software is distributed without + * any warranty. + * + * For details, see . + * SPDX-License-Identifier: CC0-1.0 + */ + +#if !defined(JSON_HEDLEY_VERSION) || (JSON_HEDLEY_VERSION < 13) +#if defined(JSON_HEDLEY_VERSION) + #undef JSON_HEDLEY_VERSION +#endif +#define JSON_HEDLEY_VERSION 13 + +#if defined(JSON_HEDLEY_STRINGIFY_EX) + #undef JSON_HEDLEY_STRINGIFY_EX +#endif +#define JSON_HEDLEY_STRINGIFY_EX(x) #x + +#if defined(JSON_HEDLEY_STRINGIFY) + #undef JSON_HEDLEY_STRINGIFY +#endif +#define JSON_HEDLEY_STRINGIFY(x) JSON_HEDLEY_STRINGIFY_EX(x) + +#if defined(JSON_HEDLEY_CONCAT_EX) + #undef JSON_HEDLEY_CONCAT_EX +#endif +#define JSON_HEDLEY_CONCAT_EX(a,b) a##b + +#if defined(JSON_HEDLEY_CONCAT) + #undef JSON_HEDLEY_CONCAT +#endif +#define JSON_HEDLEY_CONCAT(a,b) JSON_HEDLEY_CONCAT_EX(a,b) + +#if defined(JSON_HEDLEY_CONCAT3_EX) + #undef JSON_HEDLEY_CONCAT3_EX +#endif +#define JSON_HEDLEY_CONCAT3_EX(a,b,c) a##b##c + +#if defined(JSON_HEDLEY_CONCAT3) + #undef JSON_HEDLEY_CONCAT3 +#endif +#define JSON_HEDLEY_CONCAT3(a,b,c) JSON_HEDLEY_CONCAT3_EX(a,b,c) + +#if defined(JSON_HEDLEY_VERSION_ENCODE) + #undef JSON_HEDLEY_VERSION_ENCODE +#endif +#define JSON_HEDLEY_VERSION_ENCODE(major,minor,revision) (((major) * 1000000) + ((minor) * 1000) + (revision)) + +#if defined(JSON_HEDLEY_VERSION_DECODE_MAJOR) + #undef JSON_HEDLEY_VERSION_DECODE_MAJOR +#endif +#define JSON_HEDLEY_VERSION_DECODE_MAJOR(version) ((version) / 1000000) + +#if defined(JSON_HEDLEY_VERSION_DECODE_MINOR) + #undef JSON_HEDLEY_VERSION_DECODE_MINOR +#endif +#define JSON_HEDLEY_VERSION_DECODE_MINOR(version) (((version) % 1000000) / 1000) + +#if defined(JSON_HEDLEY_VERSION_DECODE_REVISION) + #undef JSON_HEDLEY_VERSION_DECODE_REVISION +#endif +#define JSON_HEDLEY_VERSION_DECODE_REVISION(version) ((version) % 1000) + +#if defined(JSON_HEDLEY_GNUC_VERSION) + #undef JSON_HEDLEY_GNUC_VERSION +#endif +#if defined(__GNUC__) && defined(__GNUC_PATCHLEVEL__) + #define JSON_HEDLEY_GNUC_VERSION JSON_HEDLEY_VERSION_ENCODE(__GNUC__, __GNUC_MINOR__, __GNUC_PATCHLEVEL__) +#elif defined(__GNUC__) + #define JSON_HEDLEY_GNUC_VERSION JSON_HEDLEY_VERSION_ENCODE(__GNUC__, __GNUC_MINOR__, 0) +#endif + +#if defined(JSON_HEDLEY_GNUC_VERSION_CHECK) + #undef JSON_HEDLEY_GNUC_VERSION_CHECK +#endif +#if defined(JSON_HEDLEY_GNUC_VERSION) + #define JSON_HEDLEY_GNUC_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_GNUC_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) +#else + #define JSON_HEDLEY_GNUC_VERSION_CHECK(major,minor,patch) (0) +#endif + +#if defined(JSON_HEDLEY_MSVC_VERSION) + #undef JSON_HEDLEY_MSVC_VERSION +#endif +#if defined(_MSC_FULL_VER) && (_MSC_FULL_VER >= 140000000) + #define JSON_HEDLEY_MSVC_VERSION JSON_HEDLEY_VERSION_ENCODE(_MSC_FULL_VER / 10000000, (_MSC_FULL_VER % 10000000) / 100000, (_MSC_FULL_VER % 100000) / 100) +#elif defined(_MSC_FULL_VER) + #define JSON_HEDLEY_MSVC_VERSION JSON_HEDLEY_VERSION_ENCODE(_MSC_FULL_VER / 1000000, (_MSC_FULL_VER % 1000000) / 10000, (_MSC_FULL_VER % 10000) / 10) +#elif defined(_MSC_VER) + #define JSON_HEDLEY_MSVC_VERSION JSON_HEDLEY_VERSION_ENCODE(_MSC_VER / 100, _MSC_VER % 100, 0) +#endif + +#if defined(JSON_HEDLEY_MSVC_VERSION_CHECK) + #undef JSON_HEDLEY_MSVC_VERSION_CHECK +#endif +#if !defined(_MSC_VER) + #define JSON_HEDLEY_MSVC_VERSION_CHECK(major,minor,patch) (0) +#elif defined(_MSC_VER) && (_MSC_VER >= 1400) + #define JSON_HEDLEY_MSVC_VERSION_CHECK(major,minor,patch) (_MSC_FULL_VER >= ((major * 10000000) + (minor * 100000) + (patch))) +#elif defined(_MSC_VER) && (_MSC_VER >= 1200) + #define JSON_HEDLEY_MSVC_VERSION_CHECK(major,minor,patch) (_MSC_FULL_VER >= ((major * 1000000) + (minor * 10000) + (patch))) +#else + #define JSON_HEDLEY_MSVC_VERSION_CHECK(major,minor,patch) (_MSC_VER >= ((major * 100) + (minor))) +#endif + +#if defined(JSON_HEDLEY_INTEL_VERSION) + #undef JSON_HEDLEY_INTEL_VERSION +#endif +#if defined(__INTEL_COMPILER) && defined(__INTEL_COMPILER_UPDATE) + #define JSON_HEDLEY_INTEL_VERSION JSON_HEDLEY_VERSION_ENCODE(__INTEL_COMPILER / 100, __INTEL_COMPILER % 100, __INTEL_COMPILER_UPDATE) +#elif defined(__INTEL_COMPILER) + #define JSON_HEDLEY_INTEL_VERSION JSON_HEDLEY_VERSION_ENCODE(__INTEL_COMPILER / 100, __INTEL_COMPILER % 100, 0) +#endif + +#if defined(JSON_HEDLEY_INTEL_VERSION_CHECK) + #undef JSON_HEDLEY_INTEL_VERSION_CHECK +#endif +#if defined(JSON_HEDLEY_INTEL_VERSION) + #define JSON_HEDLEY_INTEL_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_INTEL_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) +#else + #define JSON_HEDLEY_INTEL_VERSION_CHECK(major,minor,patch) (0) +#endif + +#if defined(JSON_HEDLEY_PGI_VERSION) + #undef JSON_HEDLEY_PGI_VERSION +#endif +#if defined(__PGI) && defined(__PGIC__) && defined(__PGIC_MINOR__) && defined(__PGIC_PATCHLEVEL__) + #define JSON_HEDLEY_PGI_VERSION JSON_HEDLEY_VERSION_ENCODE(__PGIC__, __PGIC_MINOR__, __PGIC_PATCHLEVEL__) +#endif + +#if defined(JSON_HEDLEY_PGI_VERSION_CHECK) + #undef JSON_HEDLEY_PGI_VERSION_CHECK +#endif +#if defined(JSON_HEDLEY_PGI_VERSION) + #define JSON_HEDLEY_PGI_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_PGI_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) +#else + #define JSON_HEDLEY_PGI_VERSION_CHECK(major,minor,patch) (0) +#endif + +#if defined(JSON_HEDLEY_SUNPRO_VERSION) + #undef JSON_HEDLEY_SUNPRO_VERSION +#endif +#if defined(__SUNPRO_C) && (__SUNPRO_C > 0x1000) + #define JSON_HEDLEY_SUNPRO_VERSION JSON_HEDLEY_VERSION_ENCODE((((__SUNPRO_C >> 16) & 0xf) * 10) + ((__SUNPRO_C >> 12) & 0xf), (((__SUNPRO_C >> 8) & 0xf) * 10) + ((__SUNPRO_C >> 4) & 0xf), (__SUNPRO_C & 0xf) * 10) +#elif defined(__SUNPRO_C) + #define JSON_HEDLEY_SUNPRO_VERSION JSON_HEDLEY_VERSION_ENCODE((__SUNPRO_C >> 8) & 0xf, (__SUNPRO_C >> 4) & 0xf, (__SUNPRO_C) & 0xf) +#elif defined(__SUNPRO_CC) && (__SUNPRO_CC > 0x1000) + #define JSON_HEDLEY_SUNPRO_VERSION JSON_HEDLEY_VERSION_ENCODE((((__SUNPRO_CC >> 16) & 0xf) * 10) + ((__SUNPRO_CC >> 12) & 0xf), (((__SUNPRO_CC >> 8) & 0xf) * 10) + ((__SUNPRO_CC >> 4) & 0xf), (__SUNPRO_CC & 0xf) * 10) +#elif defined(__SUNPRO_CC) + #define JSON_HEDLEY_SUNPRO_VERSION JSON_HEDLEY_VERSION_ENCODE((__SUNPRO_CC >> 8) & 0xf, (__SUNPRO_CC >> 4) & 0xf, (__SUNPRO_CC) & 0xf) +#endif + +#if defined(JSON_HEDLEY_SUNPRO_VERSION_CHECK) + #undef JSON_HEDLEY_SUNPRO_VERSION_CHECK +#endif +#if defined(JSON_HEDLEY_SUNPRO_VERSION) + #define JSON_HEDLEY_SUNPRO_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_SUNPRO_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) +#else + #define JSON_HEDLEY_SUNPRO_VERSION_CHECK(major,minor,patch) (0) +#endif + +#if defined(JSON_HEDLEY_EMSCRIPTEN_VERSION) + #undef JSON_HEDLEY_EMSCRIPTEN_VERSION +#endif +#if defined(__EMSCRIPTEN__) + #define JSON_HEDLEY_EMSCRIPTEN_VERSION JSON_HEDLEY_VERSION_ENCODE(__EMSCRIPTEN_major__, __EMSCRIPTEN_minor__, __EMSCRIPTEN_tiny__) +#endif + +#if defined(JSON_HEDLEY_EMSCRIPTEN_VERSION_CHECK) + #undef JSON_HEDLEY_EMSCRIPTEN_VERSION_CHECK +#endif +#if defined(JSON_HEDLEY_EMSCRIPTEN_VERSION) + #define JSON_HEDLEY_EMSCRIPTEN_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_EMSCRIPTEN_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) +#else + #define JSON_HEDLEY_EMSCRIPTEN_VERSION_CHECK(major,minor,patch) (0) +#endif + +#if defined(JSON_HEDLEY_ARM_VERSION) + #undef JSON_HEDLEY_ARM_VERSION +#endif +#if defined(__CC_ARM) && defined(__ARMCOMPILER_VERSION) + #define JSON_HEDLEY_ARM_VERSION JSON_HEDLEY_VERSION_ENCODE(__ARMCOMPILER_VERSION / 1000000, (__ARMCOMPILER_VERSION % 1000000) / 10000, (__ARMCOMPILER_VERSION % 10000) / 100) +#elif defined(__CC_ARM) && defined(__ARMCC_VERSION) + #define JSON_HEDLEY_ARM_VERSION JSON_HEDLEY_VERSION_ENCODE(__ARMCC_VERSION / 1000000, (__ARMCC_VERSION % 1000000) / 10000, (__ARMCC_VERSION % 10000) / 100) +#endif + +#if defined(JSON_HEDLEY_ARM_VERSION_CHECK) + #undef JSON_HEDLEY_ARM_VERSION_CHECK +#endif +#if defined(JSON_HEDLEY_ARM_VERSION) + #define JSON_HEDLEY_ARM_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_ARM_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) +#else + #define JSON_HEDLEY_ARM_VERSION_CHECK(major,minor,patch) (0) +#endif + +#if defined(JSON_HEDLEY_IBM_VERSION) + #undef JSON_HEDLEY_IBM_VERSION +#endif +#if defined(__ibmxl__) + #define JSON_HEDLEY_IBM_VERSION JSON_HEDLEY_VERSION_ENCODE(__ibmxl_version__, __ibmxl_release__, __ibmxl_modification__) +#elif defined(__xlC__) && defined(__xlC_ver__) + #define JSON_HEDLEY_IBM_VERSION JSON_HEDLEY_VERSION_ENCODE(__xlC__ >> 8, __xlC__ & 0xff, (__xlC_ver__ >> 8) & 0xff) +#elif defined(__xlC__) + #define JSON_HEDLEY_IBM_VERSION JSON_HEDLEY_VERSION_ENCODE(__xlC__ >> 8, __xlC__ & 0xff, 0) +#endif + +#if defined(JSON_HEDLEY_IBM_VERSION_CHECK) + #undef JSON_HEDLEY_IBM_VERSION_CHECK +#endif +#if defined(JSON_HEDLEY_IBM_VERSION) + #define JSON_HEDLEY_IBM_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_IBM_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) +#else + #define JSON_HEDLEY_IBM_VERSION_CHECK(major,minor,patch) (0) +#endif + +#if defined(JSON_HEDLEY_TI_VERSION) + #undef JSON_HEDLEY_TI_VERSION +#endif +#if \ + defined(__TI_COMPILER_VERSION__) && \ + ( \ + defined(__TMS470__) || defined(__TI_ARM__) || \ + defined(__MSP430__) || \ + defined(__TMS320C2000__) \ + ) +#if (__TI_COMPILER_VERSION__ >= 16000000) + #define JSON_HEDLEY_TI_VERSION JSON_HEDLEY_VERSION_ENCODE(__TI_COMPILER_VERSION__ / 1000000, (__TI_COMPILER_VERSION__ % 1000000) / 1000, (__TI_COMPILER_VERSION__ % 1000)) +#endif +#endif + +#if defined(JSON_HEDLEY_TI_VERSION_CHECK) + #undef JSON_HEDLEY_TI_VERSION_CHECK +#endif +#if defined(JSON_HEDLEY_TI_VERSION) + #define JSON_HEDLEY_TI_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_TI_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) +#else + #define JSON_HEDLEY_TI_VERSION_CHECK(major,minor,patch) (0) +#endif + +#if defined(JSON_HEDLEY_TI_CL2000_VERSION) + #undef JSON_HEDLEY_TI_CL2000_VERSION +#endif +#if defined(__TI_COMPILER_VERSION__) && defined(__TMS320C2000__) + #define JSON_HEDLEY_TI_CL2000_VERSION JSON_HEDLEY_VERSION_ENCODE(__TI_COMPILER_VERSION__ / 1000000, (__TI_COMPILER_VERSION__ % 1000000) / 1000, (__TI_COMPILER_VERSION__ % 1000)) +#endif + +#if defined(JSON_HEDLEY_TI_CL2000_VERSION_CHECK) + #undef JSON_HEDLEY_TI_CL2000_VERSION_CHECK +#endif +#if defined(JSON_HEDLEY_TI_CL2000_VERSION) + #define JSON_HEDLEY_TI_CL2000_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_TI_CL2000_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) +#else + #define JSON_HEDLEY_TI_CL2000_VERSION_CHECK(major,minor,patch) (0) +#endif + +#if defined(JSON_HEDLEY_TI_CL430_VERSION) + #undef JSON_HEDLEY_TI_CL430_VERSION +#endif +#if defined(__TI_COMPILER_VERSION__) && defined(__MSP430__) + #define JSON_HEDLEY_TI_CL430_VERSION JSON_HEDLEY_VERSION_ENCODE(__TI_COMPILER_VERSION__ / 1000000, (__TI_COMPILER_VERSION__ % 1000000) / 1000, (__TI_COMPILER_VERSION__ % 1000)) +#endif + +#if defined(JSON_HEDLEY_TI_CL430_VERSION_CHECK) + #undef JSON_HEDLEY_TI_CL430_VERSION_CHECK +#endif +#if defined(JSON_HEDLEY_TI_CL430_VERSION) + #define JSON_HEDLEY_TI_CL430_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_TI_CL430_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) +#else + #define JSON_HEDLEY_TI_CL430_VERSION_CHECK(major,minor,patch) (0) +#endif + +#if defined(JSON_HEDLEY_TI_ARMCL_VERSION) + #undef JSON_HEDLEY_TI_ARMCL_VERSION +#endif +#if defined(__TI_COMPILER_VERSION__) && (defined(__TMS470__) || defined(__TI_ARM__)) + #define JSON_HEDLEY_TI_ARMCL_VERSION JSON_HEDLEY_VERSION_ENCODE(__TI_COMPILER_VERSION__ / 1000000, (__TI_COMPILER_VERSION__ % 1000000) / 1000, (__TI_COMPILER_VERSION__ % 1000)) +#endif + +#if defined(JSON_HEDLEY_TI_ARMCL_VERSION_CHECK) + #undef JSON_HEDLEY_TI_ARMCL_VERSION_CHECK +#endif +#if defined(JSON_HEDLEY_TI_ARMCL_VERSION) + #define JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_TI_ARMCL_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) +#else + #define JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(major,minor,patch) (0) +#endif + +#if defined(JSON_HEDLEY_TI_CL6X_VERSION) + #undef JSON_HEDLEY_TI_CL6X_VERSION +#endif +#if defined(__TI_COMPILER_VERSION__) && defined(__TMS320C6X__) + #define JSON_HEDLEY_TI_CL6X_VERSION JSON_HEDLEY_VERSION_ENCODE(__TI_COMPILER_VERSION__ / 1000000, (__TI_COMPILER_VERSION__ % 1000000) / 1000, (__TI_COMPILER_VERSION__ % 1000)) +#endif + +#if defined(JSON_HEDLEY_TI_CL6X_VERSION_CHECK) + #undef JSON_HEDLEY_TI_CL6X_VERSION_CHECK +#endif +#if defined(JSON_HEDLEY_TI_CL6X_VERSION) + #define JSON_HEDLEY_TI_CL6X_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_TI_CL6X_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) +#else + #define JSON_HEDLEY_TI_CL6X_VERSION_CHECK(major,minor,patch) (0) +#endif + +#if defined(JSON_HEDLEY_TI_CL7X_VERSION) + #undef JSON_HEDLEY_TI_CL7X_VERSION +#endif +#if defined(__TI_COMPILER_VERSION__) && defined(__C7000__) + #define JSON_HEDLEY_TI_CL7X_VERSION JSON_HEDLEY_VERSION_ENCODE(__TI_COMPILER_VERSION__ / 1000000, (__TI_COMPILER_VERSION__ % 1000000) / 1000, (__TI_COMPILER_VERSION__ % 1000)) +#endif + +#if defined(JSON_HEDLEY_TI_CL7X_VERSION_CHECK) + #undef JSON_HEDLEY_TI_CL7X_VERSION_CHECK +#endif +#if defined(JSON_HEDLEY_TI_CL7X_VERSION) + #define JSON_HEDLEY_TI_CL7X_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_TI_CL7X_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) +#else + #define JSON_HEDLEY_TI_CL7X_VERSION_CHECK(major,minor,patch) (0) +#endif + +#if defined(JSON_HEDLEY_TI_CLPRU_VERSION) + #undef JSON_HEDLEY_TI_CLPRU_VERSION +#endif +#if defined(__TI_COMPILER_VERSION__) && defined(__PRU__) + #define JSON_HEDLEY_TI_CLPRU_VERSION JSON_HEDLEY_VERSION_ENCODE(__TI_COMPILER_VERSION__ / 1000000, (__TI_COMPILER_VERSION__ % 1000000) / 1000, (__TI_COMPILER_VERSION__ % 1000)) +#endif + +#if defined(JSON_HEDLEY_TI_CLPRU_VERSION_CHECK) + #undef JSON_HEDLEY_TI_CLPRU_VERSION_CHECK +#endif +#if defined(JSON_HEDLEY_TI_CLPRU_VERSION) + #define JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_TI_CLPRU_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) +#else + #define JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(major,minor,patch) (0) +#endif + +#if defined(JSON_HEDLEY_CRAY_VERSION) + #undef JSON_HEDLEY_CRAY_VERSION +#endif +#if defined(_CRAYC) + #if defined(_RELEASE_PATCHLEVEL) + #define JSON_HEDLEY_CRAY_VERSION JSON_HEDLEY_VERSION_ENCODE(_RELEASE_MAJOR, _RELEASE_MINOR, _RELEASE_PATCHLEVEL) + #else + #define JSON_HEDLEY_CRAY_VERSION JSON_HEDLEY_VERSION_ENCODE(_RELEASE_MAJOR, _RELEASE_MINOR, 0) + #endif +#endif + +#if defined(JSON_HEDLEY_CRAY_VERSION_CHECK) + #undef JSON_HEDLEY_CRAY_VERSION_CHECK +#endif +#if defined(JSON_HEDLEY_CRAY_VERSION) + #define JSON_HEDLEY_CRAY_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_CRAY_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) +#else + #define JSON_HEDLEY_CRAY_VERSION_CHECK(major,minor,patch) (0) +#endif + +#if defined(JSON_HEDLEY_IAR_VERSION) + #undef JSON_HEDLEY_IAR_VERSION +#endif +#if defined(__IAR_SYSTEMS_ICC__) + #if __VER__ > 1000 + #define JSON_HEDLEY_IAR_VERSION JSON_HEDLEY_VERSION_ENCODE((__VER__ / 1000000), ((__VER__ / 1000) % 1000), (__VER__ % 1000)) + #else + #define JSON_HEDLEY_IAR_VERSION JSON_HEDLEY_VERSION_ENCODE(VER / 100, __VER__ % 100, 0) + #endif +#endif + +#if defined(JSON_HEDLEY_IAR_VERSION_CHECK) + #undef JSON_HEDLEY_IAR_VERSION_CHECK +#endif +#if defined(JSON_HEDLEY_IAR_VERSION) + #define JSON_HEDLEY_IAR_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_IAR_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) +#else + #define JSON_HEDLEY_IAR_VERSION_CHECK(major,minor,patch) (0) +#endif + +#if defined(JSON_HEDLEY_TINYC_VERSION) + #undef JSON_HEDLEY_TINYC_VERSION +#endif +#if defined(__TINYC__) + #define JSON_HEDLEY_TINYC_VERSION JSON_HEDLEY_VERSION_ENCODE(__TINYC__ / 1000, (__TINYC__ / 100) % 10, __TINYC__ % 100) +#endif + +#if defined(JSON_HEDLEY_TINYC_VERSION_CHECK) + #undef JSON_HEDLEY_TINYC_VERSION_CHECK +#endif +#if defined(JSON_HEDLEY_TINYC_VERSION) + #define JSON_HEDLEY_TINYC_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_TINYC_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) +#else + #define JSON_HEDLEY_TINYC_VERSION_CHECK(major,minor,patch) (0) +#endif + +#if defined(JSON_HEDLEY_DMC_VERSION) + #undef JSON_HEDLEY_DMC_VERSION +#endif +#if defined(__DMC__) + #define JSON_HEDLEY_DMC_VERSION JSON_HEDLEY_VERSION_ENCODE(__DMC__ >> 8, (__DMC__ >> 4) & 0xf, __DMC__ & 0xf) +#endif + +#if defined(JSON_HEDLEY_DMC_VERSION_CHECK) + #undef JSON_HEDLEY_DMC_VERSION_CHECK +#endif +#if defined(JSON_HEDLEY_DMC_VERSION) + #define JSON_HEDLEY_DMC_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_DMC_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) +#else + #define JSON_HEDLEY_DMC_VERSION_CHECK(major,minor,patch) (0) +#endif + +#if defined(JSON_HEDLEY_COMPCERT_VERSION) + #undef JSON_HEDLEY_COMPCERT_VERSION +#endif +#if defined(__COMPCERT_VERSION__) + #define JSON_HEDLEY_COMPCERT_VERSION JSON_HEDLEY_VERSION_ENCODE(__COMPCERT_VERSION__ / 10000, (__COMPCERT_VERSION__ / 100) % 100, __COMPCERT_VERSION__ % 100) +#endif + +#if defined(JSON_HEDLEY_COMPCERT_VERSION_CHECK) + #undef JSON_HEDLEY_COMPCERT_VERSION_CHECK +#endif +#if defined(JSON_HEDLEY_COMPCERT_VERSION) + #define JSON_HEDLEY_COMPCERT_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_COMPCERT_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) +#else + #define JSON_HEDLEY_COMPCERT_VERSION_CHECK(major,minor,patch) (0) +#endif + +#if defined(JSON_HEDLEY_PELLES_VERSION) + #undef JSON_HEDLEY_PELLES_VERSION +#endif +#if defined(__POCC__) + #define JSON_HEDLEY_PELLES_VERSION JSON_HEDLEY_VERSION_ENCODE(__POCC__ / 100, __POCC__ % 100, 0) +#endif + +#if defined(JSON_HEDLEY_PELLES_VERSION_CHECK) + #undef JSON_HEDLEY_PELLES_VERSION_CHECK +#endif +#if defined(JSON_HEDLEY_PELLES_VERSION) + #define JSON_HEDLEY_PELLES_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_PELLES_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) +#else + #define JSON_HEDLEY_PELLES_VERSION_CHECK(major,minor,patch) (0) +#endif + +#if defined(JSON_HEDLEY_GCC_VERSION) + #undef JSON_HEDLEY_GCC_VERSION +#endif +#if \ + defined(JSON_HEDLEY_GNUC_VERSION) && \ + !defined(__clang__) && \ + !defined(JSON_HEDLEY_INTEL_VERSION) && \ + !defined(JSON_HEDLEY_PGI_VERSION) && \ + !defined(JSON_HEDLEY_ARM_VERSION) && \ + !defined(JSON_HEDLEY_TI_VERSION) && \ + !defined(JSON_HEDLEY_TI_ARMCL_VERSION) && \ + !defined(JSON_HEDLEY_TI_CL430_VERSION) && \ + !defined(JSON_HEDLEY_TI_CL2000_VERSION) && \ + !defined(JSON_HEDLEY_TI_CL6X_VERSION) && \ + !defined(JSON_HEDLEY_TI_CL7X_VERSION) && \ + !defined(JSON_HEDLEY_TI_CLPRU_VERSION) && \ + !defined(__COMPCERT__) + #define JSON_HEDLEY_GCC_VERSION JSON_HEDLEY_GNUC_VERSION +#endif + +#if defined(JSON_HEDLEY_GCC_VERSION_CHECK) + #undef JSON_HEDLEY_GCC_VERSION_CHECK +#endif +#if defined(JSON_HEDLEY_GCC_VERSION) + #define JSON_HEDLEY_GCC_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_GCC_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) +#else + #define JSON_HEDLEY_GCC_VERSION_CHECK(major,minor,patch) (0) +#endif + +#if defined(JSON_HEDLEY_HAS_ATTRIBUTE) + #undef JSON_HEDLEY_HAS_ATTRIBUTE +#endif +#if defined(__has_attribute) + #define JSON_HEDLEY_HAS_ATTRIBUTE(attribute) __has_attribute(attribute) +#else + #define JSON_HEDLEY_HAS_ATTRIBUTE(attribute) (0) +#endif + +#if defined(JSON_HEDLEY_GNUC_HAS_ATTRIBUTE) + #undef JSON_HEDLEY_GNUC_HAS_ATTRIBUTE +#endif +#if defined(__has_attribute) + #define JSON_HEDLEY_GNUC_HAS_ATTRIBUTE(attribute,major,minor,patch) __has_attribute(attribute) +#else + #define JSON_HEDLEY_GNUC_HAS_ATTRIBUTE(attribute,major,minor,patch) JSON_HEDLEY_GNUC_VERSION_CHECK(major,minor,patch) +#endif + +#if defined(JSON_HEDLEY_GCC_HAS_ATTRIBUTE) + #undef JSON_HEDLEY_GCC_HAS_ATTRIBUTE +#endif +#if defined(__has_attribute) + #define JSON_HEDLEY_GCC_HAS_ATTRIBUTE(attribute,major,minor,patch) __has_attribute(attribute) +#else + #define JSON_HEDLEY_GCC_HAS_ATTRIBUTE(attribute,major,minor,patch) JSON_HEDLEY_GCC_VERSION_CHECK(major,minor,patch) +#endif + +#if defined(JSON_HEDLEY_HAS_CPP_ATTRIBUTE) + #undef JSON_HEDLEY_HAS_CPP_ATTRIBUTE +#endif +#if \ + defined(__has_cpp_attribute) && \ + defined(__cplusplus) && \ + (!defined(JSON_HEDLEY_SUNPRO_VERSION) || JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,15,0)) + #define JSON_HEDLEY_HAS_CPP_ATTRIBUTE(attribute) __has_cpp_attribute(attribute) +#else + #define JSON_HEDLEY_HAS_CPP_ATTRIBUTE(attribute) (0) +#endif + +#if defined(JSON_HEDLEY_HAS_CPP_ATTRIBUTE_NS) + #undef JSON_HEDLEY_HAS_CPP_ATTRIBUTE_NS +#endif +#if !defined(__cplusplus) || !defined(__has_cpp_attribute) + #define JSON_HEDLEY_HAS_CPP_ATTRIBUTE_NS(ns,attribute) (0) +#elif \ + !defined(JSON_HEDLEY_PGI_VERSION) && \ + !defined(JSON_HEDLEY_IAR_VERSION) && \ + (!defined(JSON_HEDLEY_SUNPRO_VERSION) || JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,15,0)) && \ + (!defined(JSON_HEDLEY_MSVC_VERSION) || JSON_HEDLEY_MSVC_VERSION_CHECK(19,20,0)) + #define JSON_HEDLEY_HAS_CPP_ATTRIBUTE_NS(ns,attribute) JSON_HEDLEY_HAS_CPP_ATTRIBUTE(ns::attribute) +#else + #define JSON_HEDLEY_HAS_CPP_ATTRIBUTE_NS(ns,attribute) (0) +#endif + +#if defined(JSON_HEDLEY_GNUC_HAS_CPP_ATTRIBUTE) + #undef JSON_HEDLEY_GNUC_HAS_CPP_ATTRIBUTE +#endif +#if defined(__has_cpp_attribute) && defined(__cplusplus) + #define JSON_HEDLEY_GNUC_HAS_CPP_ATTRIBUTE(attribute,major,minor,patch) __has_cpp_attribute(attribute) +#else + #define JSON_HEDLEY_GNUC_HAS_CPP_ATTRIBUTE(attribute,major,minor,patch) JSON_HEDLEY_GNUC_VERSION_CHECK(major,minor,patch) +#endif + +#if defined(JSON_HEDLEY_GCC_HAS_CPP_ATTRIBUTE) + #undef JSON_HEDLEY_GCC_HAS_CPP_ATTRIBUTE +#endif +#if defined(__has_cpp_attribute) && defined(__cplusplus) + #define JSON_HEDLEY_GCC_HAS_CPP_ATTRIBUTE(attribute,major,minor,patch) __has_cpp_attribute(attribute) +#else + #define JSON_HEDLEY_GCC_HAS_CPP_ATTRIBUTE(attribute,major,minor,patch) JSON_HEDLEY_GCC_VERSION_CHECK(major,minor,patch) +#endif + +#if defined(JSON_HEDLEY_HAS_BUILTIN) + #undef JSON_HEDLEY_HAS_BUILTIN +#endif +#if defined(__has_builtin) + #define JSON_HEDLEY_HAS_BUILTIN(builtin) __has_builtin(builtin) +#else + #define JSON_HEDLEY_HAS_BUILTIN(builtin) (0) +#endif + +#if defined(JSON_HEDLEY_GNUC_HAS_BUILTIN) + #undef JSON_HEDLEY_GNUC_HAS_BUILTIN +#endif +#if defined(__has_builtin) + #define JSON_HEDLEY_GNUC_HAS_BUILTIN(builtin,major,minor,patch) __has_builtin(builtin) +#else + #define JSON_HEDLEY_GNUC_HAS_BUILTIN(builtin,major,minor,patch) JSON_HEDLEY_GNUC_VERSION_CHECK(major,minor,patch) +#endif + +#if defined(JSON_HEDLEY_GCC_HAS_BUILTIN) + #undef JSON_HEDLEY_GCC_HAS_BUILTIN +#endif +#if defined(__has_builtin) + #define JSON_HEDLEY_GCC_HAS_BUILTIN(builtin,major,minor,patch) __has_builtin(builtin) +#else + #define JSON_HEDLEY_GCC_HAS_BUILTIN(builtin,major,minor,patch) JSON_HEDLEY_GCC_VERSION_CHECK(major,minor,patch) +#endif + +#if defined(JSON_HEDLEY_HAS_FEATURE) + #undef JSON_HEDLEY_HAS_FEATURE +#endif +#if defined(__has_feature) + #define JSON_HEDLEY_HAS_FEATURE(feature) __has_feature(feature) +#else + #define JSON_HEDLEY_HAS_FEATURE(feature) (0) +#endif + +#if defined(JSON_HEDLEY_GNUC_HAS_FEATURE) + #undef JSON_HEDLEY_GNUC_HAS_FEATURE +#endif +#if defined(__has_feature) + #define JSON_HEDLEY_GNUC_HAS_FEATURE(feature,major,minor,patch) __has_feature(feature) +#else + #define JSON_HEDLEY_GNUC_HAS_FEATURE(feature,major,minor,patch) JSON_HEDLEY_GNUC_VERSION_CHECK(major,minor,patch) +#endif + +#if defined(JSON_HEDLEY_GCC_HAS_FEATURE) + #undef JSON_HEDLEY_GCC_HAS_FEATURE +#endif +#if defined(__has_feature) + #define JSON_HEDLEY_GCC_HAS_FEATURE(feature,major,minor,patch) __has_feature(feature) +#else + #define JSON_HEDLEY_GCC_HAS_FEATURE(feature,major,minor,patch) JSON_HEDLEY_GCC_VERSION_CHECK(major,minor,patch) +#endif + +#if defined(JSON_HEDLEY_HAS_EXTENSION) + #undef JSON_HEDLEY_HAS_EXTENSION +#endif +#if defined(__has_extension) + #define JSON_HEDLEY_HAS_EXTENSION(extension) __has_extension(extension) +#else + #define JSON_HEDLEY_HAS_EXTENSION(extension) (0) +#endif + +#if defined(JSON_HEDLEY_GNUC_HAS_EXTENSION) + #undef JSON_HEDLEY_GNUC_HAS_EXTENSION +#endif +#if defined(__has_extension) + #define JSON_HEDLEY_GNUC_HAS_EXTENSION(extension,major,minor,patch) __has_extension(extension) +#else + #define JSON_HEDLEY_GNUC_HAS_EXTENSION(extension,major,minor,patch) JSON_HEDLEY_GNUC_VERSION_CHECK(major,minor,patch) +#endif + +#if defined(JSON_HEDLEY_GCC_HAS_EXTENSION) + #undef JSON_HEDLEY_GCC_HAS_EXTENSION +#endif +#if defined(__has_extension) + #define JSON_HEDLEY_GCC_HAS_EXTENSION(extension,major,minor,patch) __has_extension(extension) +#else + #define JSON_HEDLEY_GCC_HAS_EXTENSION(extension,major,minor,patch) JSON_HEDLEY_GCC_VERSION_CHECK(major,minor,patch) +#endif + +#if defined(JSON_HEDLEY_HAS_DECLSPEC_ATTRIBUTE) + #undef JSON_HEDLEY_HAS_DECLSPEC_ATTRIBUTE +#endif +#if defined(__has_declspec_attribute) + #define JSON_HEDLEY_HAS_DECLSPEC_ATTRIBUTE(attribute) __has_declspec_attribute(attribute) +#else + #define JSON_HEDLEY_HAS_DECLSPEC_ATTRIBUTE(attribute) (0) +#endif + +#if defined(JSON_HEDLEY_GNUC_HAS_DECLSPEC_ATTRIBUTE) + #undef JSON_HEDLEY_GNUC_HAS_DECLSPEC_ATTRIBUTE +#endif +#if defined(__has_declspec_attribute) + #define JSON_HEDLEY_GNUC_HAS_DECLSPEC_ATTRIBUTE(attribute,major,minor,patch) __has_declspec_attribute(attribute) +#else + #define JSON_HEDLEY_GNUC_HAS_DECLSPEC_ATTRIBUTE(attribute,major,minor,patch) JSON_HEDLEY_GNUC_VERSION_CHECK(major,minor,patch) +#endif + +#if defined(JSON_HEDLEY_GCC_HAS_DECLSPEC_ATTRIBUTE) + #undef JSON_HEDLEY_GCC_HAS_DECLSPEC_ATTRIBUTE +#endif +#if defined(__has_declspec_attribute) + #define JSON_HEDLEY_GCC_HAS_DECLSPEC_ATTRIBUTE(attribute,major,minor,patch) __has_declspec_attribute(attribute) +#else + #define JSON_HEDLEY_GCC_HAS_DECLSPEC_ATTRIBUTE(attribute,major,minor,patch) JSON_HEDLEY_GCC_VERSION_CHECK(major,minor,patch) +#endif + +#if defined(JSON_HEDLEY_HAS_WARNING) + #undef JSON_HEDLEY_HAS_WARNING +#endif +#if defined(__has_warning) + #define JSON_HEDLEY_HAS_WARNING(warning) __has_warning(warning) +#else + #define JSON_HEDLEY_HAS_WARNING(warning) (0) +#endif + +#if defined(JSON_HEDLEY_GNUC_HAS_WARNING) + #undef JSON_HEDLEY_GNUC_HAS_WARNING +#endif +#if defined(__has_warning) + #define JSON_HEDLEY_GNUC_HAS_WARNING(warning,major,minor,patch) __has_warning(warning) +#else + #define JSON_HEDLEY_GNUC_HAS_WARNING(warning,major,minor,patch) JSON_HEDLEY_GNUC_VERSION_CHECK(major,minor,patch) +#endif + +#if defined(JSON_HEDLEY_GCC_HAS_WARNING) + #undef JSON_HEDLEY_GCC_HAS_WARNING +#endif +#if defined(__has_warning) + #define JSON_HEDLEY_GCC_HAS_WARNING(warning,major,minor,patch) __has_warning(warning) +#else + #define JSON_HEDLEY_GCC_HAS_WARNING(warning,major,minor,patch) JSON_HEDLEY_GCC_VERSION_CHECK(major,minor,patch) +#endif + +/* JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_ is for + HEDLEY INTERNAL USE ONLY. API subject to change without notice. */ +#if defined(JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_) + #undef JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_ +#endif +#if defined(__cplusplus) +# if JSON_HEDLEY_HAS_WARNING("-Wc++98-compat") +# if JSON_HEDLEY_HAS_WARNING("-Wc++17-extensions") +# define JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_(xpr) \ + JSON_HEDLEY_DIAGNOSTIC_PUSH \ + _Pragma("clang diagnostic ignored \"-Wc++98-compat\"") \ + _Pragma("clang diagnostic ignored \"-Wc++17-extensions\"") \ + xpr \ + JSON_HEDLEY_DIAGNOSTIC_POP +# else +# define JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_(xpr) \ + JSON_HEDLEY_DIAGNOSTIC_PUSH \ + _Pragma("clang diagnostic ignored \"-Wc++98-compat\"") \ + xpr \ + JSON_HEDLEY_DIAGNOSTIC_POP +# endif +# endif +#endif +#if !defined(JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_) + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_(x) x +#endif + +#if defined(JSON_HEDLEY_CONST_CAST) + #undef JSON_HEDLEY_CONST_CAST +#endif +#if defined(__cplusplus) +# define JSON_HEDLEY_CONST_CAST(T, expr) (const_cast(expr)) +#elif \ + JSON_HEDLEY_HAS_WARNING("-Wcast-qual") || \ + JSON_HEDLEY_GCC_VERSION_CHECK(4,6,0) || \ + JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) +# define JSON_HEDLEY_CONST_CAST(T, expr) (__extension__ ({ \ + JSON_HEDLEY_DIAGNOSTIC_PUSH \ + JSON_HEDLEY_DIAGNOSTIC_DISABLE_CAST_QUAL \ + ((T) (expr)); \ + JSON_HEDLEY_DIAGNOSTIC_POP \ + })) +#else +# define JSON_HEDLEY_CONST_CAST(T, expr) ((T) (expr)) +#endif + +#if defined(JSON_HEDLEY_REINTERPRET_CAST) + #undef JSON_HEDLEY_REINTERPRET_CAST +#endif +#if defined(__cplusplus) + #define JSON_HEDLEY_REINTERPRET_CAST(T, expr) (reinterpret_cast(expr)) +#else + #define JSON_HEDLEY_REINTERPRET_CAST(T, expr) ((T) (expr)) +#endif + +#if defined(JSON_HEDLEY_STATIC_CAST) + #undef JSON_HEDLEY_STATIC_CAST +#endif +#if defined(__cplusplus) + #define JSON_HEDLEY_STATIC_CAST(T, expr) (static_cast(expr)) +#else + #define JSON_HEDLEY_STATIC_CAST(T, expr) ((T) (expr)) +#endif + +#if defined(JSON_HEDLEY_CPP_CAST) + #undef JSON_HEDLEY_CPP_CAST +#endif +#if defined(__cplusplus) +# if JSON_HEDLEY_HAS_WARNING("-Wold-style-cast") +# define JSON_HEDLEY_CPP_CAST(T, expr) \ + JSON_HEDLEY_DIAGNOSTIC_PUSH \ + _Pragma("clang diagnostic ignored \"-Wold-style-cast\"") \ + ((T) (expr)) \ + JSON_HEDLEY_DIAGNOSTIC_POP +# elif JSON_HEDLEY_IAR_VERSION_CHECK(8,3,0) +# define JSON_HEDLEY_CPP_CAST(T, expr) \ + JSON_HEDLEY_DIAGNOSTIC_PUSH \ + _Pragma("diag_suppress=Pe137") \ + JSON_HEDLEY_DIAGNOSTIC_POP \ +# else +# define JSON_HEDLEY_CPP_CAST(T, expr) ((T) (expr)) +# endif +#else +# define JSON_HEDLEY_CPP_CAST(T, expr) (expr) +#endif + +#if \ + (defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L)) || \ + defined(__clang__) || \ + JSON_HEDLEY_GCC_VERSION_CHECK(3,0,0) || \ + JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \ + JSON_HEDLEY_IAR_VERSION_CHECK(8,0,0) || \ + JSON_HEDLEY_PGI_VERSION_CHECK(18,4,0) || \ + JSON_HEDLEY_ARM_VERSION_CHECK(4,1,0) || \ + JSON_HEDLEY_TI_VERSION_CHECK(15,12,0) || \ + JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(4,7,0) || \ + JSON_HEDLEY_TI_CL430_VERSION_CHECK(2,0,1) || \ + JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,1,0) || \ + JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,0,0) || \ + JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \ + JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,1,0) || \ + JSON_HEDLEY_CRAY_VERSION_CHECK(5,0,0) || \ + JSON_HEDLEY_TINYC_VERSION_CHECK(0,9,17) || \ + JSON_HEDLEY_SUNPRO_VERSION_CHECK(8,0,0) || \ + (JSON_HEDLEY_IBM_VERSION_CHECK(10,1,0) && defined(__C99_PRAGMA_OPERATOR)) + #define JSON_HEDLEY_PRAGMA(value) _Pragma(#value) +#elif JSON_HEDLEY_MSVC_VERSION_CHECK(15,0,0) + #define JSON_HEDLEY_PRAGMA(value) __pragma(value) +#else + #define JSON_HEDLEY_PRAGMA(value) +#endif + +#if defined(JSON_HEDLEY_DIAGNOSTIC_PUSH) + #undef JSON_HEDLEY_DIAGNOSTIC_PUSH +#endif +#if defined(JSON_HEDLEY_DIAGNOSTIC_POP) + #undef JSON_HEDLEY_DIAGNOSTIC_POP +#endif +#if defined(__clang__) + #define JSON_HEDLEY_DIAGNOSTIC_PUSH _Pragma("clang diagnostic push") + #define JSON_HEDLEY_DIAGNOSTIC_POP _Pragma("clang diagnostic pop") +#elif JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) + #define JSON_HEDLEY_DIAGNOSTIC_PUSH _Pragma("warning(push)") + #define JSON_HEDLEY_DIAGNOSTIC_POP _Pragma("warning(pop)") +#elif JSON_HEDLEY_GCC_VERSION_CHECK(4,6,0) + #define JSON_HEDLEY_DIAGNOSTIC_PUSH _Pragma("GCC diagnostic push") + #define JSON_HEDLEY_DIAGNOSTIC_POP _Pragma("GCC diagnostic pop") +#elif JSON_HEDLEY_MSVC_VERSION_CHECK(15,0,0) + #define JSON_HEDLEY_DIAGNOSTIC_PUSH __pragma(warning(push)) + #define JSON_HEDLEY_DIAGNOSTIC_POP __pragma(warning(pop)) +#elif JSON_HEDLEY_ARM_VERSION_CHECK(5,6,0) + #define JSON_HEDLEY_DIAGNOSTIC_PUSH _Pragma("push") + #define JSON_HEDLEY_DIAGNOSTIC_POP _Pragma("pop") +#elif \ + JSON_HEDLEY_TI_VERSION_CHECK(15,12,0) || \ + JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(5,2,0) || \ + JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,4,0) || \ + JSON_HEDLEY_TI_CL6X_VERSION_CHECK(8,1,0) || \ + JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \ + JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,1,0) + #define JSON_HEDLEY_DIAGNOSTIC_PUSH _Pragma("diag_push") + #define JSON_HEDLEY_DIAGNOSTIC_POP _Pragma("diag_pop") +#elif JSON_HEDLEY_PELLES_VERSION_CHECK(2,90,0) + #define JSON_HEDLEY_DIAGNOSTIC_PUSH _Pragma("warning(push)") + #define JSON_HEDLEY_DIAGNOSTIC_POP _Pragma("warning(pop)") +#else + #define JSON_HEDLEY_DIAGNOSTIC_PUSH + #define JSON_HEDLEY_DIAGNOSTIC_POP +#endif + +#if defined(JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED) + #undef JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED +#endif +#if JSON_HEDLEY_HAS_WARNING("-Wdeprecated-declarations") + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED _Pragma("clang diagnostic ignored \"-Wdeprecated-declarations\"") +#elif JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED _Pragma("warning(disable:1478 1786)") +#elif JSON_HEDLEY_PGI_VERSION_CHECK(17,10,0) + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED _Pragma("diag_suppress 1215,1444") +#elif JSON_HEDLEY_GCC_VERSION_CHECK(4,3,0) + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED _Pragma("GCC diagnostic ignored \"-Wdeprecated-declarations\"") +#elif JSON_HEDLEY_MSVC_VERSION_CHECK(15,0,0) + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED __pragma(warning(disable:4996)) +#elif \ + JSON_HEDLEY_TI_VERSION_CHECK(15,12,0) || \ + (JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(4,8,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(5,2,0) || \ + (JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,4,0) || \ + (JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,3,0) || \ + (JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,2,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,5,0) || \ + JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \ + JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,1,0) + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED _Pragma("diag_suppress 1291,1718") +#elif JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,13,0) && !defined(__cplusplus) + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED _Pragma("error_messages(off,E_DEPRECATED_ATT,E_DEPRECATED_ATT_MESS)") +#elif JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,13,0) && defined(__cplusplus) + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED _Pragma("error_messages(off,symdeprecated,symdeprecated2)") +#elif JSON_HEDLEY_IAR_VERSION_CHECK(8,0,0) + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED _Pragma("diag_suppress=Pe1444,Pe1215") +#elif JSON_HEDLEY_PELLES_VERSION_CHECK(2,90,0) + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED _Pragma("warn(disable:2241)") +#else + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED +#endif + +#if defined(JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS) + #undef JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS +#endif +#if JSON_HEDLEY_HAS_WARNING("-Wunknown-pragmas") + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS _Pragma("clang diagnostic ignored \"-Wunknown-pragmas\"") +#elif JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS _Pragma("warning(disable:161)") +#elif JSON_HEDLEY_PGI_VERSION_CHECK(17,10,0) + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS _Pragma("diag_suppress 1675") +#elif JSON_HEDLEY_GCC_VERSION_CHECK(4,3,0) + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS _Pragma("GCC diagnostic ignored \"-Wunknown-pragmas\"") +#elif JSON_HEDLEY_MSVC_VERSION_CHECK(15,0,0) + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS __pragma(warning(disable:4068)) +#elif \ + JSON_HEDLEY_TI_VERSION_CHECK(16,9,0) || \ + JSON_HEDLEY_TI_CL6X_VERSION_CHECK(8,0,0) || \ + JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \ + JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,3,0) + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS _Pragma("diag_suppress 163") +#elif JSON_HEDLEY_TI_CL6X_VERSION_CHECK(8,0,0) + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS _Pragma("diag_suppress 163") +#elif JSON_HEDLEY_IAR_VERSION_CHECK(8,0,0) + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS _Pragma("diag_suppress=Pe161") +#else + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS +#endif + +#if defined(JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES) + #undef JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES +#endif +#if JSON_HEDLEY_HAS_WARNING("-Wunknown-attributes") + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES _Pragma("clang diagnostic ignored \"-Wunknown-attributes\"") +#elif JSON_HEDLEY_GCC_VERSION_CHECK(4,6,0) + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES _Pragma("GCC diagnostic ignored \"-Wdeprecated-declarations\"") +#elif JSON_HEDLEY_INTEL_VERSION_CHECK(17,0,0) + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES _Pragma("warning(disable:1292)") +#elif JSON_HEDLEY_MSVC_VERSION_CHECK(19,0,0) + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES __pragma(warning(disable:5030)) +#elif JSON_HEDLEY_PGI_VERSION_CHECK(17,10,0) + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES _Pragma("diag_suppress 1097") +#elif JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,14,0) && defined(__cplusplus) + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES _Pragma("error_messages(off,attrskipunsup)") +#elif \ + JSON_HEDLEY_TI_VERSION_CHECK(18,1,0) || \ + JSON_HEDLEY_TI_CL6X_VERSION_CHECK(8,3,0) || \ + JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES _Pragma("diag_suppress 1173") +#elif JSON_HEDLEY_IAR_VERSION_CHECK(8,0,0) + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES _Pragma("diag_suppress=Pe1097") +#else + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES +#endif + +#if defined(JSON_HEDLEY_DIAGNOSTIC_DISABLE_CAST_QUAL) + #undef JSON_HEDLEY_DIAGNOSTIC_DISABLE_CAST_QUAL +#endif +#if JSON_HEDLEY_HAS_WARNING("-Wcast-qual") + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_CAST_QUAL _Pragma("clang diagnostic ignored \"-Wcast-qual\"") +#elif JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_CAST_QUAL _Pragma("warning(disable:2203 2331)") +#elif JSON_HEDLEY_GCC_VERSION_CHECK(3,0,0) + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_CAST_QUAL _Pragma("GCC diagnostic ignored \"-Wcast-qual\"") +#else + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_CAST_QUAL +#endif + +#if defined(JSON_HEDLEY_DEPRECATED) + #undef JSON_HEDLEY_DEPRECATED +#endif +#if defined(JSON_HEDLEY_DEPRECATED_FOR) + #undef JSON_HEDLEY_DEPRECATED_FOR +#endif +#if JSON_HEDLEY_MSVC_VERSION_CHECK(14,0,0) + #define JSON_HEDLEY_DEPRECATED(since) __declspec(deprecated("Since " # since)) + #define JSON_HEDLEY_DEPRECATED_FOR(since, replacement) __declspec(deprecated("Since " #since "; use " #replacement)) +#elif defined(__cplusplus) && (__cplusplus >= 201402L) + #define JSON_HEDLEY_DEPRECATED(since) JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_([[deprecated("Since " #since)]]) + #define JSON_HEDLEY_DEPRECATED_FOR(since, replacement) JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_([[deprecated("Since " #since "; use " #replacement)]]) +#elif \ + JSON_HEDLEY_HAS_EXTENSION(attribute_deprecated_with_message) || \ + JSON_HEDLEY_GCC_VERSION_CHECK(4,5,0) || \ + JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \ + JSON_HEDLEY_ARM_VERSION_CHECK(5,6,0) || \ + JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,13,0) || \ + JSON_HEDLEY_PGI_VERSION_CHECK(17,10,0) || \ + JSON_HEDLEY_TI_VERSION_CHECK(18,1,0) || \ + JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(18,1,0) || \ + JSON_HEDLEY_TI_CL6X_VERSION_CHECK(8,3,0) || \ + JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \ + JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,3,0) + #define JSON_HEDLEY_DEPRECATED(since) __attribute__((__deprecated__("Since " #since))) + #define JSON_HEDLEY_DEPRECATED_FOR(since, replacement) __attribute__((__deprecated__("Since " #since "; use " #replacement))) +#elif \ + JSON_HEDLEY_HAS_ATTRIBUTE(deprecated) || \ + JSON_HEDLEY_GCC_VERSION_CHECK(3,1,0) || \ + JSON_HEDLEY_ARM_VERSION_CHECK(4,1,0) || \ + JSON_HEDLEY_TI_VERSION_CHECK(15,12,0) || \ + (JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(4,8,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(5,2,0) || \ + (JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,4,0) || \ + (JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,3,0) || \ + (JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,2,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,5,0) || \ + JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \ + JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,1,0) + #define JSON_HEDLEY_DEPRECATED(since) __attribute__((__deprecated__)) + #define JSON_HEDLEY_DEPRECATED_FOR(since, replacement) __attribute__((__deprecated__)) +#elif \ + JSON_HEDLEY_MSVC_VERSION_CHECK(13,10,0) || \ + JSON_HEDLEY_PELLES_VERSION_CHECK(6,50,0) + #define JSON_HEDLEY_DEPRECATED(since) __declspec(deprecated) + #define JSON_HEDLEY_DEPRECATED_FOR(since, replacement) __declspec(deprecated) +#elif JSON_HEDLEY_IAR_VERSION_CHECK(8,0,0) + #define JSON_HEDLEY_DEPRECATED(since) _Pragma("deprecated") + #define JSON_HEDLEY_DEPRECATED_FOR(since, replacement) _Pragma("deprecated") +#else + #define JSON_HEDLEY_DEPRECATED(since) + #define JSON_HEDLEY_DEPRECATED_FOR(since, replacement) +#endif + +#if defined(JSON_HEDLEY_UNAVAILABLE) + #undef JSON_HEDLEY_UNAVAILABLE +#endif +#if \ + JSON_HEDLEY_HAS_ATTRIBUTE(warning) || \ + JSON_HEDLEY_GCC_VERSION_CHECK(4,3,0) || \ + JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) + #define JSON_HEDLEY_UNAVAILABLE(available_since) __attribute__((__warning__("Not available until " #available_since))) +#else + #define JSON_HEDLEY_UNAVAILABLE(available_since) +#endif + +#if defined(JSON_HEDLEY_WARN_UNUSED_RESULT) + #undef JSON_HEDLEY_WARN_UNUSED_RESULT +#endif +#if defined(JSON_HEDLEY_WARN_UNUSED_RESULT_MSG) + #undef JSON_HEDLEY_WARN_UNUSED_RESULT_MSG +#endif +#if (JSON_HEDLEY_HAS_CPP_ATTRIBUTE(nodiscard) >= 201907L) + #define JSON_HEDLEY_WARN_UNUSED_RESULT JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_([[nodiscard]]) + #define JSON_HEDLEY_WARN_UNUSED_RESULT_MSG(msg) JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_([[nodiscard(msg)]]) +#elif JSON_HEDLEY_HAS_CPP_ATTRIBUTE(nodiscard) + #define JSON_HEDLEY_WARN_UNUSED_RESULT JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_([[nodiscard]]) + #define JSON_HEDLEY_WARN_UNUSED_RESULT_MSG(msg) JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_([[nodiscard]]) +#elif \ + JSON_HEDLEY_HAS_ATTRIBUTE(warn_unused_result) || \ + JSON_HEDLEY_GCC_VERSION_CHECK(3,4,0) || \ + JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \ + JSON_HEDLEY_TI_VERSION_CHECK(15,12,0) || \ + (JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(4,8,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(5,2,0) || \ + (JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,4,0) || \ + (JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,3,0) || \ + (JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,2,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,5,0) || \ + JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \ + JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,1,0) || \ + (JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,15,0) && defined(__cplusplus)) || \ + JSON_HEDLEY_PGI_VERSION_CHECK(17,10,0) + #define JSON_HEDLEY_WARN_UNUSED_RESULT __attribute__((__warn_unused_result__)) + #define JSON_HEDLEY_WARN_UNUSED_RESULT_MSG(msg) __attribute__((__warn_unused_result__)) +#elif defined(_Check_return_) /* SAL */ + #define JSON_HEDLEY_WARN_UNUSED_RESULT _Check_return_ + #define JSON_HEDLEY_WARN_UNUSED_RESULT_MSG(msg) _Check_return_ +#else + #define JSON_HEDLEY_WARN_UNUSED_RESULT + #define JSON_HEDLEY_WARN_UNUSED_RESULT_MSG(msg) +#endif + +#if defined(JSON_HEDLEY_SENTINEL) + #undef JSON_HEDLEY_SENTINEL +#endif +#if \ + JSON_HEDLEY_HAS_ATTRIBUTE(sentinel) || \ + JSON_HEDLEY_GCC_VERSION_CHECK(4,0,0) || \ + JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \ + JSON_HEDLEY_ARM_VERSION_CHECK(5,4,0) + #define JSON_HEDLEY_SENTINEL(position) __attribute__((__sentinel__(position))) +#else + #define JSON_HEDLEY_SENTINEL(position) +#endif + +#if defined(JSON_HEDLEY_NO_RETURN) + #undef JSON_HEDLEY_NO_RETURN +#endif +#if JSON_HEDLEY_IAR_VERSION_CHECK(8,0,0) + #define JSON_HEDLEY_NO_RETURN __noreturn +#elif JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) + #define JSON_HEDLEY_NO_RETURN __attribute__((__noreturn__)) +#elif defined(__STDC_VERSION__) && __STDC_VERSION__ >= 201112L + #define JSON_HEDLEY_NO_RETURN _Noreturn +#elif defined(__cplusplus) && (__cplusplus >= 201103L) + #define JSON_HEDLEY_NO_RETURN JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_([[noreturn]]) +#elif \ + JSON_HEDLEY_HAS_ATTRIBUTE(noreturn) || \ + JSON_HEDLEY_GCC_VERSION_CHECK(3,2,0) || \ + JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,11,0) || \ + JSON_HEDLEY_ARM_VERSION_CHECK(4,1,0) || \ + JSON_HEDLEY_IBM_VERSION_CHECK(10,1,0) || \ + JSON_HEDLEY_TI_VERSION_CHECK(15,12,0) || \ + (JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(4,8,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(5,2,0) || \ + (JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,4,0) || \ + (JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,3,0) || \ + (JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,2,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,5,0) || \ + JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \ + JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,1,0) + #define JSON_HEDLEY_NO_RETURN __attribute__((__noreturn__)) +#elif JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,10,0) + #define JSON_HEDLEY_NO_RETURN _Pragma("does_not_return") +#elif JSON_HEDLEY_MSVC_VERSION_CHECK(13,10,0) + #define JSON_HEDLEY_NO_RETURN __declspec(noreturn) +#elif JSON_HEDLEY_TI_CL6X_VERSION_CHECK(6,0,0) && defined(__cplusplus) + #define JSON_HEDLEY_NO_RETURN _Pragma("FUNC_NEVER_RETURNS;") +#elif JSON_HEDLEY_COMPCERT_VERSION_CHECK(3,2,0) + #define JSON_HEDLEY_NO_RETURN __attribute((noreturn)) +#elif JSON_HEDLEY_PELLES_VERSION_CHECK(9,0,0) + #define JSON_HEDLEY_NO_RETURN __declspec(noreturn) +#else + #define JSON_HEDLEY_NO_RETURN +#endif + +#if defined(JSON_HEDLEY_NO_ESCAPE) + #undef JSON_HEDLEY_NO_ESCAPE +#endif +#if JSON_HEDLEY_HAS_ATTRIBUTE(noescape) + #define JSON_HEDLEY_NO_ESCAPE __attribute__((__noescape__)) +#else + #define JSON_HEDLEY_NO_ESCAPE +#endif + +#if defined(JSON_HEDLEY_UNREACHABLE) + #undef JSON_HEDLEY_UNREACHABLE +#endif +#if defined(JSON_HEDLEY_UNREACHABLE_RETURN) + #undef JSON_HEDLEY_UNREACHABLE_RETURN +#endif +#if defined(JSON_HEDLEY_ASSUME) + #undef JSON_HEDLEY_ASSUME +#endif +#if \ + JSON_HEDLEY_MSVC_VERSION_CHECK(13,10,0) || \ + JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) + #define JSON_HEDLEY_ASSUME(expr) __assume(expr) +#elif JSON_HEDLEY_HAS_BUILTIN(__builtin_assume) + #define JSON_HEDLEY_ASSUME(expr) __builtin_assume(expr) +#elif \ + JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,2,0) || \ + JSON_HEDLEY_TI_CL6X_VERSION_CHECK(4,0,0) + #if defined(__cplusplus) + #define JSON_HEDLEY_ASSUME(expr) std::_nassert(expr) + #else + #define JSON_HEDLEY_ASSUME(expr) _nassert(expr) + #endif +#endif +#if \ + (JSON_HEDLEY_HAS_BUILTIN(__builtin_unreachable) && (!defined(JSON_HEDLEY_ARM_VERSION))) || \ + JSON_HEDLEY_GCC_VERSION_CHECK(4,5,0) || \ + JSON_HEDLEY_PGI_VERSION_CHECK(18,10,0) || \ + JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \ + JSON_HEDLEY_IBM_VERSION_CHECK(13,1,5) + #define JSON_HEDLEY_UNREACHABLE() __builtin_unreachable() +#elif defined(JSON_HEDLEY_ASSUME) + #define JSON_HEDLEY_UNREACHABLE() JSON_HEDLEY_ASSUME(0) +#endif +#if !defined(JSON_HEDLEY_ASSUME) + #if defined(JSON_HEDLEY_UNREACHABLE) + #define JSON_HEDLEY_ASSUME(expr) JSON_HEDLEY_STATIC_CAST(void, ((expr) ? 1 : (JSON_HEDLEY_UNREACHABLE(), 1))) + #else + #define JSON_HEDLEY_ASSUME(expr) JSON_HEDLEY_STATIC_CAST(void, expr) + #endif +#endif +#if defined(JSON_HEDLEY_UNREACHABLE) + #if \ + JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,2,0) || \ + JSON_HEDLEY_TI_CL6X_VERSION_CHECK(4,0,0) + #define JSON_HEDLEY_UNREACHABLE_RETURN(value) return (JSON_HEDLEY_STATIC_CAST(void, JSON_HEDLEY_ASSUME(0)), (value)) + #else + #define JSON_HEDLEY_UNREACHABLE_RETURN(value) JSON_HEDLEY_UNREACHABLE() + #endif +#else + #define JSON_HEDLEY_UNREACHABLE_RETURN(value) return (value) +#endif +#if !defined(JSON_HEDLEY_UNREACHABLE) + #define JSON_HEDLEY_UNREACHABLE() JSON_HEDLEY_ASSUME(0) +#endif + +JSON_HEDLEY_DIAGNOSTIC_PUSH +#if JSON_HEDLEY_HAS_WARNING("-Wpedantic") + #pragma clang diagnostic ignored "-Wpedantic" +#endif +#if JSON_HEDLEY_HAS_WARNING("-Wc++98-compat-pedantic") && defined(__cplusplus) + #pragma clang diagnostic ignored "-Wc++98-compat-pedantic" +#endif +#if JSON_HEDLEY_GCC_HAS_WARNING("-Wvariadic-macros",4,0,0) + #if defined(__clang__) + #pragma clang diagnostic ignored "-Wvariadic-macros" + #elif defined(JSON_HEDLEY_GCC_VERSION) + #pragma GCC diagnostic ignored "-Wvariadic-macros" + #endif +#endif +#if defined(JSON_HEDLEY_NON_NULL) + #undef JSON_HEDLEY_NON_NULL +#endif +#if \ + JSON_HEDLEY_HAS_ATTRIBUTE(nonnull) || \ + JSON_HEDLEY_GCC_VERSION_CHECK(3,3,0) || \ + JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \ + JSON_HEDLEY_ARM_VERSION_CHECK(4,1,0) + #define JSON_HEDLEY_NON_NULL(...) __attribute__((__nonnull__(__VA_ARGS__))) +#else + #define JSON_HEDLEY_NON_NULL(...) +#endif +JSON_HEDLEY_DIAGNOSTIC_POP + +#if defined(JSON_HEDLEY_PRINTF_FORMAT) + #undef JSON_HEDLEY_PRINTF_FORMAT +#endif +#if defined(__MINGW32__) && JSON_HEDLEY_GCC_HAS_ATTRIBUTE(format,4,4,0) && !defined(__USE_MINGW_ANSI_STDIO) + #define JSON_HEDLEY_PRINTF_FORMAT(string_idx,first_to_check) __attribute__((__format__(ms_printf, string_idx, first_to_check))) +#elif defined(__MINGW32__) && JSON_HEDLEY_GCC_HAS_ATTRIBUTE(format,4,4,0) && defined(__USE_MINGW_ANSI_STDIO) + #define JSON_HEDLEY_PRINTF_FORMAT(string_idx,first_to_check) __attribute__((__format__(gnu_printf, string_idx, first_to_check))) +#elif \ + JSON_HEDLEY_HAS_ATTRIBUTE(format) || \ + JSON_HEDLEY_GCC_VERSION_CHECK(3,1,0) || \ + JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \ + JSON_HEDLEY_ARM_VERSION_CHECK(5,6,0) || \ + JSON_HEDLEY_IBM_VERSION_CHECK(10,1,0) || \ + JSON_HEDLEY_TI_VERSION_CHECK(15,12,0) || \ + (JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(4,8,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(5,2,0) || \ + (JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,4,0) || \ + (JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,3,0) || \ + (JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,2,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,5,0) || \ + JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \ + JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,1,0) + #define JSON_HEDLEY_PRINTF_FORMAT(string_idx,first_to_check) __attribute__((__format__(__printf__, string_idx, first_to_check))) +#elif JSON_HEDLEY_PELLES_VERSION_CHECK(6,0,0) + #define JSON_HEDLEY_PRINTF_FORMAT(string_idx,first_to_check) __declspec(vaformat(printf,string_idx,first_to_check)) +#else + #define JSON_HEDLEY_PRINTF_FORMAT(string_idx,first_to_check) +#endif + +#if defined(JSON_HEDLEY_CONSTEXPR) + #undef JSON_HEDLEY_CONSTEXPR +#endif +#if defined(__cplusplus) + #if __cplusplus >= 201103L + #define JSON_HEDLEY_CONSTEXPR JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_(constexpr) + #endif +#endif +#if !defined(JSON_HEDLEY_CONSTEXPR) + #define JSON_HEDLEY_CONSTEXPR +#endif + +#if defined(JSON_HEDLEY_PREDICT) + #undef JSON_HEDLEY_PREDICT +#endif +#if defined(JSON_HEDLEY_LIKELY) + #undef JSON_HEDLEY_LIKELY +#endif +#if defined(JSON_HEDLEY_UNLIKELY) + #undef JSON_HEDLEY_UNLIKELY +#endif +#if defined(JSON_HEDLEY_UNPREDICTABLE) + #undef JSON_HEDLEY_UNPREDICTABLE +#endif +#if JSON_HEDLEY_HAS_BUILTIN(__builtin_unpredictable) + #define JSON_HEDLEY_UNPREDICTABLE(expr) __builtin_unpredictable((expr)) +#endif +#if \ + JSON_HEDLEY_HAS_BUILTIN(__builtin_expect_with_probability) || \ + JSON_HEDLEY_GCC_VERSION_CHECK(9,0,0) +# define JSON_HEDLEY_PREDICT(expr, value, probability) __builtin_expect_with_probability( (expr), (value), (probability)) +# define JSON_HEDLEY_PREDICT_TRUE(expr, probability) __builtin_expect_with_probability(!!(expr), 1 , (probability)) +# define JSON_HEDLEY_PREDICT_FALSE(expr, probability) __builtin_expect_with_probability(!!(expr), 0 , (probability)) +# define JSON_HEDLEY_LIKELY(expr) __builtin_expect (!!(expr), 1 ) +# define JSON_HEDLEY_UNLIKELY(expr) __builtin_expect (!!(expr), 0 ) +#elif \ + JSON_HEDLEY_HAS_BUILTIN(__builtin_expect) || \ + JSON_HEDLEY_GCC_VERSION_CHECK(3,0,0) || \ + JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \ + (JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,15,0) && defined(__cplusplus)) || \ + JSON_HEDLEY_ARM_VERSION_CHECK(4,1,0) || \ + JSON_HEDLEY_IBM_VERSION_CHECK(10,1,0) || \ + JSON_HEDLEY_TI_VERSION_CHECK(15,12,0) || \ + JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(4,7,0) || \ + JSON_HEDLEY_TI_CL430_VERSION_CHECK(3,1,0) || \ + JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,1,0) || \ + JSON_HEDLEY_TI_CL6X_VERSION_CHECK(6,1,0) || \ + JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \ + JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,1,0) || \ + JSON_HEDLEY_TINYC_VERSION_CHECK(0,9,27) || \ + JSON_HEDLEY_CRAY_VERSION_CHECK(8,1,0) +# define JSON_HEDLEY_PREDICT(expr, expected, probability) \ + (((probability) >= 0.9) ? __builtin_expect((expr), (expected)) : (JSON_HEDLEY_STATIC_CAST(void, expected), (expr))) +# define JSON_HEDLEY_PREDICT_TRUE(expr, probability) \ + (__extension__ ({ \ + double hedley_probability_ = (probability); \ + ((hedley_probability_ >= 0.9) ? __builtin_expect(!!(expr), 1) : ((hedley_probability_ <= 0.1) ? __builtin_expect(!!(expr), 0) : !!(expr))); \ + })) +# define JSON_HEDLEY_PREDICT_FALSE(expr, probability) \ + (__extension__ ({ \ + double hedley_probability_ = (probability); \ + ((hedley_probability_ >= 0.9) ? __builtin_expect(!!(expr), 0) : ((hedley_probability_ <= 0.1) ? __builtin_expect(!!(expr), 1) : !!(expr))); \ + })) +# define JSON_HEDLEY_LIKELY(expr) __builtin_expect(!!(expr), 1) +# define JSON_HEDLEY_UNLIKELY(expr) __builtin_expect(!!(expr), 0) +#else +# define JSON_HEDLEY_PREDICT(expr, expected, probability) (JSON_HEDLEY_STATIC_CAST(void, expected), (expr)) +# define JSON_HEDLEY_PREDICT_TRUE(expr, probability) (!!(expr)) +# define JSON_HEDLEY_PREDICT_FALSE(expr, probability) (!!(expr)) +# define JSON_HEDLEY_LIKELY(expr) (!!(expr)) +# define JSON_HEDLEY_UNLIKELY(expr) (!!(expr)) +#endif +#if !defined(JSON_HEDLEY_UNPREDICTABLE) + #define JSON_HEDLEY_UNPREDICTABLE(expr) JSON_HEDLEY_PREDICT(expr, 1, 0.5) +#endif + +#if defined(JSON_HEDLEY_MALLOC) + #undef JSON_HEDLEY_MALLOC +#endif +#if \ + JSON_HEDLEY_HAS_ATTRIBUTE(malloc) || \ + JSON_HEDLEY_GCC_VERSION_CHECK(3,1,0) || \ + JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \ + JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,11,0) || \ + JSON_HEDLEY_ARM_VERSION_CHECK(4,1,0) || \ + JSON_HEDLEY_IBM_VERSION_CHECK(12,1,0) || \ + JSON_HEDLEY_TI_VERSION_CHECK(15,12,0) || \ + (JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(4,8,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(5,2,0) || \ + (JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,4,0) || \ + (JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,3,0) || \ + (JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,2,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,5,0) || \ + JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \ + JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,1,0) + #define JSON_HEDLEY_MALLOC __attribute__((__malloc__)) +#elif JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,10,0) + #define JSON_HEDLEY_MALLOC _Pragma("returns_new_memory") +#elif JSON_HEDLEY_MSVC_VERSION_CHECK(14, 0, 0) + #define JSON_HEDLEY_MALLOC __declspec(restrict) +#else + #define JSON_HEDLEY_MALLOC +#endif + +#if defined(JSON_HEDLEY_PURE) + #undef JSON_HEDLEY_PURE +#endif +#if \ + JSON_HEDLEY_HAS_ATTRIBUTE(pure) || \ + JSON_HEDLEY_GCC_VERSION_CHECK(2,96,0) || \ + JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \ + JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,11,0) || \ + JSON_HEDLEY_ARM_VERSION_CHECK(4,1,0) || \ + JSON_HEDLEY_IBM_VERSION_CHECK(10,1,0) || \ + JSON_HEDLEY_TI_VERSION_CHECK(15,12,0) || \ + (JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(4,8,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(5,2,0) || \ + (JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,4,0) || \ + (JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,3,0) || \ + (JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,2,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,5,0) || \ + JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \ + JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,1,0) || \ + JSON_HEDLEY_PGI_VERSION_CHECK(17,10,0) +# define JSON_HEDLEY_PURE __attribute__((__pure__)) +#elif JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,10,0) +# define JSON_HEDLEY_PURE _Pragma("does_not_write_global_data") +#elif defined(__cplusplus) && \ + ( \ + JSON_HEDLEY_TI_CL430_VERSION_CHECK(2,0,1) || \ + JSON_HEDLEY_TI_CL6X_VERSION_CHECK(4,0,0) || \ + JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) \ + ) +# define JSON_HEDLEY_PURE _Pragma("FUNC_IS_PURE;") +#else +# define JSON_HEDLEY_PURE +#endif + +#if defined(JSON_HEDLEY_CONST) + #undef JSON_HEDLEY_CONST +#endif +#if \ + JSON_HEDLEY_HAS_ATTRIBUTE(const) || \ + JSON_HEDLEY_GCC_VERSION_CHECK(2,5,0) || \ + JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \ + JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,11,0) || \ + JSON_HEDLEY_ARM_VERSION_CHECK(4,1,0) || \ + JSON_HEDLEY_IBM_VERSION_CHECK(10,1,0) || \ + JSON_HEDLEY_TI_VERSION_CHECK(15,12,0) || \ + (JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(4,8,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(5,2,0) || \ + (JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,4,0) || \ + (JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,3,0) || \ + (JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,2,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,5,0) || \ + JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \ + JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,1,0) || \ + JSON_HEDLEY_PGI_VERSION_CHECK(17,10,0) + #define JSON_HEDLEY_CONST __attribute__((__const__)) +#elif \ + JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,10,0) + #define JSON_HEDLEY_CONST _Pragma("no_side_effect") +#else + #define JSON_HEDLEY_CONST JSON_HEDLEY_PURE +#endif + +#if defined(JSON_HEDLEY_RESTRICT) + #undef JSON_HEDLEY_RESTRICT +#endif +#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) && !defined(__cplusplus) + #define JSON_HEDLEY_RESTRICT restrict +#elif \ + JSON_HEDLEY_GCC_VERSION_CHECK(3,1,0) || \ + JSON_HEDLEY_MSVC_VERSION_CHECK(14,0,0) || \ + JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \ + JSON_HEDLEY_ARM_VERSION_CHECK(4,1,0) || \ + JSON_HEDLEY_IBM_VERSION_CHECK(10,1,0) || \ + JSON_HEDLEY_PGI_VERSION_CHECK(17,10,0) || \ + JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,3,0) || \ + JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,2,4) || \ + JSON_HEDLEY_TI_CL6X_VERSION_CHECK(8,1,0) || \ + JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \ + (JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,14,0) && defined(__cplusplus)) || \ + JSON_HEDLEY_IAR_VERSION_CHECK(8,0,0) || \ + defined(__clang__) + #define JSON_HEDLEY_RESTRICT __restrict +#elif JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,3,0) && !defined(__cplusplus) + #define JSON_HEDLEY_RESTRICT _Restrict +#else + #define JSON_HEDLEY_RESTRICT +#endif + +#if defined(JSON_HEDLEY_INLINE) + #undef JSON_HEDLEY_INLINE +#endif +#if \ + (defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L)) || \ + (defined(__cplusplus) && (__cplusplus >= 199711L)) + #define JSON_HEDLEY_INLINE inline +#elif \ + defined(JSON_HEDLEY_GCC_VERSION) || \ + JSON_HEDLEY_ARM_VERSION_CHECK(6,2,0) + #define JSON_HEDLEY_INLINE __inline__ +#elif \ + JSON_HEDLEY_MSVC_VERSION_CHECK(12,0,0) || \ + JSON_HEDLEY_ARM_VERSION_CHECK(4,1,0) || \ + JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(5,1,0) || \ + JSON_HEDLEY_TI_CL430_VERSION_CHECK(3,1,0) || \ + JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,2,0) || \ + JSON_HEDLEY_TI_CL6X_VERSION_CHECK(8,0,0) || \ + JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \ + JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,1,0) + #define JSON_HEDLEY_INLINE __inline +#else + #define JSON_HEDLEY_INLINE +#endif + +#if defined(JSON_HEDLEY_ALWAYS_INLINE) + #undef JSON_HEDLEY_ALWAYS_INLINE +#endif +#if \ + JSON_HEDLEY_HAS_ATTRIBUTE(always_inline) || \ + JSON_HEDLEY_GCC_VERSION_CHECK(4,0,0) || \ + JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \ + JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,11,0) || \ + JSON_HEDLEY_ARM_VERSION_CHECK(4,1,0) || \ + JSON_HEDLEY_IBM_VERSION_CHECK(10,1,0) || \ + JSON_HEDLEY_TI_VERSION_CHECK(15,12,0) || \ + (JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(4,8,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(5,2,0) || \ + (JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,4,0) || \ + (JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,3,0) || \ + (JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,2,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,5,0) || \ + JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \ + JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,1,0) +# define JSON_HEDLEY_ALWAYS_INLINE __attribute__((__always_inline__)) JSON_HEDLEY_INLINE +#elif JSON_HEDLEY_MSVC_VERSION_CHECK(12,0,0) +# define JSON_HEDLEY_ALWAYS_INLINE __forceinline +#elif defined(__cplusplus) && \ + ( \ + JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(5,2,0) || \ + JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,3,0) || \ + JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,4,0) || \ + JSON_HEDLEY_TI_CL6X_VERSION_CHECK(6,1,0) || \ + JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \ + JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,1,0) \ + ) +# define JSON_HEDLEY_ALWAYS_INLINE _Pragma("FUNC_ALWAYS_INLINE;") +#elif JSON_HEDLEY_IAR_VERSION_CHECK(8,0,0) +# define JSON_HEDLEY_ALWAYS_INLINE _Pragma("inline=forced") +#else +# define JSON_HEDLEY_ALWAYS_INLINE JSON_HEDLEY_INLINE +#endif + +#if defined(JSON_HEDLEY_NEVER_INLINE) + #undef JSON_HEDLEY_NEVER_INLINE +#endif +#if \ + JSON_HEDLEY_HAS_ATTRIBUTE(noinline) || \ + JSON_HEDLEY_GCC_VERSION_CHECK(4,0,0) || \ + JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \ + JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,11,0) || \ + JSON_HEDLEY_ARM_VERSION_CHECK(4,1,0) || \ + JSON_HEDLEY_IBM_VERSION_CHECK(10,1,0) || \ + JSON_HEDLEY_TI_VERSION_CHECK(15,12,0) || \ + (JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(4,8,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(5,2,0) || \ + (JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,4,0) || \ + (JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,3,0) || \ + (JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,2,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,5,0) || \ + JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \ + JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,1,0) + #define JSON_HEDLEY_NEVER_INLINE __attribute__((__noinline__)) +#elif JSON_HEDLEY_MSVC_VERSION_CHECK(13,10,0) + #define JSON_HEDLEY_NEVER_INLINE __declspec(noinline) +#elif JSON_HEDLEY_PGI_VERSION_CHECK(10,2,0) + #define JSON_HEDLEY_NEVER_INLINE _Pragma("noinline") +#elif JSON_HEDLEY_TI_CL6X_VERSION_CHECK(6,0,0) && defined(__cplusplus) + #define JSON_HEDLEY_NEVER_INLINE _Pragma("FUNC_CANNOT_INLINE;") +#elif JSON_HEDLEY_IAR_VERSION_CHECK(8,0,0) + #define JSON_HEDLEY_NEVER_INLINE _Pragma("inline=never") +#elif JSON_HEDLEY_COMPCERT_VERSION_CHECK(3,2,0) + #define JSON_HEDLEY_NEVER_INLINE __attribute((noinline)) +#elif JSON_HEDLEY_PELLES_VERSION_CHECK(9,0,0) + #define JSON_HEDLEY_NEVER_INLINE __declspec(noinline) +#else + #define JSON_HEDLEY_NEVER_INLINE +#endif + +#if defined(JSON_HEDLEY_PRIVATE) + #undef JSON_HEDLEY_PRIVATE +#endif +#if defined(JSON_HEDLEY_PUBLIC) + #undef JSON_HEDLEY_PUBLIC +#endif +#if defined(JSON_HEDLEY_IMPORT) + #undef JSON_HEDLEY_IMPORT +#endif +#if defined(_WIN32) || defined(__CYGWIN__) +# define JSON_HEDLEY_PRIVATE +# define JSON_HEDLEY_PUBLIC __declspec(dllexport) +# define JSON_HEDLEY_IMPORT __declspec(dllimport) +#else +# if \ + JSON_HEDLEY_HAS_ATTRIBUTE(visibility) || \ + JSON_HEDLEY_GCC_VERSION_CHECK(3,3,0) || \ + JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,11,0) || \ + JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \ + JSON_HEDLEY_ARM_VERSION_CHECK(4,1,0) || \ + JSON_HEDLEY_IBM_VERSION_CHECK(13,1,0) || \ + ( \ + defined(__TI_EABI__) && \ + ( \ + (JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,2,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,5,0) \ + ) \ + ) +# define JSON_HEDLEY_PRIVATE __attribute__((__visibility__("hidden"))) +# define JSON_HEDLEY_PUBLIC __attribute__((__visibility__("default"))) +# else +# define JSON_HEDLEY_PRIVATE +# define JSON_HEDLEY_PUBLIC +# endif +# define JSON_HEDLEY_IMPORT extern +#endif + +#if defined(JSON_HEDLEY_NO_THROW) + #undef JSON_HEDLEY_NO_THROW +#endif +#if \ + JSON_HEDLEY_HAS_ATTRIBUTE(nothrow) || \ + JSON_HEDLEY_GCC_VERSION_CHECK(3,3,0) || \ + JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) + #define JSON_HEDLEY_NO_THROW __attribute__((__nothrow__)) +#elif \ + JSON_HEDLEY_MSVC_VERSION_CHECK(13,1,0) || \ + JSON_HEDLEY_ARM_VERSION_CHECK(4,1,0) + #define JSON_HEDLEY_NO_THROW __declspec(nothrow) +#else + #define JSON_HEDLEY_NO_THROW +#endif + +#if defined(JSON_HEDLEY_FALL_THROUGH) + #undef JSON_HEDLEY_FALL_THROUGH +#endif +#if \ + JSON_HEDLEY_HAS_ATTRIBUTE(fallthrough) || \ + JSON_HEDLEY_GCC_VERSION_CHECK(7,0,0) + #define JSON_HEDLEY_FALL_THROUGH __attribute__((__fallthrough__)) +#elif JSON_HEDLEY_HAS_CPP_ATTRIBUTE_NS(clang,fallthrough) + #define JSON_HEDLEY_FALL_THROUGH JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_([[clang::fallthrough]]) +#elif JSON_HEDLEY_HAS_CPP_ATTRIBUTE(fallthrough) + #define JSON_HEDLEY_FALL_THROUGH JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_([[fallthrough]]) +#elif defined(__fallthrough) /* SAL */ + #define JSON_HEDLEY_FALL_THROUGH __fallthrough +#else + #define JSON_HEDLEY_FALL_THROUGH +#endif + +#if defined(JSON_HEDLEY_RETURNS_NON_NULL) + #undef JSON_HEDLEY_RETURNS_NON_NULL +#endif +#if \ + JSON_HEDLEY_HAS_ATTRIBUTE(returns_nonnull) || \ + JSON_HEDLEY_GCC_VERSION_CHECK(4,9,0) + #define JSON_HEDLEY_RETURNS_NON_NULL __attribute__((__returns_nonnull__)) +#elif defined(_Ret_notnull_) /* SAL */ + #define JSON_HEDLEY_RETURNS_NON_NULL _Ret_notnull_ +#else + #define JSON_HEDLEY_RETURNS_NON_NULL +#endif + +#if defined(JSON_HEDLEY_ARRAY_PARAM) + #undef JSON_HEDLEY_ARRAY_PARAM +#endif +#if \ + defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) && \ + !defined(__STDC_NO_VLA__) && \ + !defined(__cplusplus) && \ + !defined(JSON_HEDLEY_PGI_VERSION) && \ + !defined(JSON_HEDLEY_TINYC_VERSION) + #define JSON_HEDLEY_ARRAY_PARAM(name) (name) +#else + #define JSON_HEDLEY_ARRAY_PARAM(name) +#endif + +#if defined(JSON_HEDLEY_IS_CONSTANT) + #undef JSON_HEDLEY_IS_CONSTANT +#endif +#if defined(JSON_HEDLEY_REQUIRE_CONSTEXPR) + #undef JSON_HEDLEY_REQUIRE_CONSTEXPR +#endif +/* JSON_HEDLEY_IS_CONSTEXPR_ is for + HEDLEY INTERNAL USE ONLY. API subject to change without notice. */ +#if defined(JSON_HEDLEY_IS_CONSTEXPR_) + #undef JSON_HEDLEY_IS_CONSTEXPR_ +#endif +#if \ + JSON_HEDLEY_HAS_BUILTIN(__builtin_constant_p) || \ + JSON_HEDLEY_GCC_VERSION_CHECK(3,4,0) || \ + JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \ + JSON_HEDLEY_TINYC_VERSION_CHECK(0,9,19) || \ + JSON_HEDLEY_ARM_VERSION_CHECK(4,1,0) || \ + JSON_HEDLEY_IBM_VERSION_CHECK(13,1,0) || \ + JSON_HEDLEY_TI_CL6X_VERSION_CHECK(6,1,0) || \ + (JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,10,0) && !defined(__cplusplus)) || \ + JSON_HEDLEY_CRAY_VERSION_CHECK(8,1,0) + #define JSON_HEDLEY_IS_CONSTANT(expr) __builtin_constant_p(expr) +#endif +#if !defined(__cplusplus) +# if \ + JSON_HEDLEY_HAS_BUILTIN(__builtin_types_compatible_p) || \ + JSON_HEDLEY_GCC_VERSION_CHECK(3,4,0) || \ + JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \ + JSON_HEDLEY_IBM_VERSION_CHECK(13,1,0) || \ + JSON_HEDLEY_CRAY_VERSION_CHECK(8,1,0) || \ + JSON_HEDLEY_ARM_VERSION_CHECK(5,4,0) || \ + JSON_HEDLEY_TINYC_VERSION_CHECK(0,9,24) +#if defined(__INTPTR_TYPE__) + #define JSON_HEDLEY_IS_CONSTEXPR_(expr) __builtin_types_compatible_p(__typeof__((1 ? (void*) ((__INTPTR_TYPE__) ((expr) * 0)) : (int*) 0)), int*) +#else + #include + #define JSON_HEDLEY_IS_CONSTEXPR_(expr) __builtin_types_compatible_p(__typeof__((1 ? (void*) ((intptr_t) ((expr) * 0)) : (int*) 0)), int*) +#endif +# elif \ + ( \ + defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 201112L) && \ + !defined(JSON_HEDLEY_SUNPRO_VERSION) && \ + !defined(JSON_HEDLEY_PGI_VERSION) && \ + !defined(JSON_HEDLEY_IAR_VERSION)) || \ + JSON_HEDLEY_HAS_EXTENSION(c_generic_selections) || \ + JSON_HEDLEY_GCC_VERSION_CHECK(4,9,0) || \ + JSON_HEDLEY_INTEL_VERSION_CHECK(17,0,0) || \ + JSON_HEDLEY_IBM_VERSION_CHECK(12,1,0) || \ + JSON_HEDLEY_ARM_VERSION_CHECK(5,3,0) +#if defined(__INTPTR_TYPE__) + #define JSON_HEDLEY_IS_CONSTEXPR_(expr) _Generic((1 ? (void*) ((__INTPTR_TYPE__) ((expr) * 0)) : (int*) 0), int*: 1, void*: 0) +#else + #include + #define JSON_HEDLEY_IS_CONSTEXPR_(expr) _Generic((1 ? (void*) ((intptr_t) * 0) : (int*) 0), int*: 1, void*: 0) +#endif +# elif \ + defined(JSON_HEDLEY_GCC_VERSION) || \ + defined(JSON_HEDLEY_INTEL_VERSION) || \ + defined(JSON_HEDLEY_TINYC_VERSION) || \ + defined(JSON_HEDLEY_TI_ARMCL_VERSION) || \ + JSON_HEDLEY_TI_CL430_VERSION_CHECK(18,12,0) || \ + defined(JSON_HEDLEY_TI_CL2000_VERSION) || \ + defined(JSON_HEDLEY_TI_CL6X_VERSION) || \ + defined(JSON_HEDLEY_TI_CL7X_VERSION) || \ + defined(JSON_HEDLEY_TI_CLPRU_VERSION) || \ + defined(__clang__) +# define JSON_HEDLEY_IS_CONSTEXPR_(expr) ( \ + sizeof(void) != \ + sizeof(*( \ + 1 ? \ + ((void*) ((expr) * 0L) ) : \ +((struct { char v[sizeof(void) * 2]; } *) 1) \ + ) \ + ) \ + ) +# endif +#endif +#if defined(JSON_HEDLEY_IS_CONSTEXPR_) + #if !defined(JSON_HEDLEY_IS_CONSTANT) + #define JSON_HEDLEY_IS_CONSTANT(expr) JSON_HEDLEY_IS_CONSTEXPR_(expr) + #endif + #define JSON_HEDLEY_REQUIRE_CONSTEXPR(expr) (JSON_HEDLEY_IS_CONSTEXPR_(expr) ? (expr) : (-1)) +#else + #if !defined(JSON_HEDLEY_IS_CONSTANT) + #define JSON_HEDLEY_IS_CONSTANT(expr) (0) + #endif + #define JSON_HEDLEY_REQUIRE_CONSTEXPR(expr) (expr) +#endif + +#if defined(JSON_HEDLEY_BEGIN_C_DECLS) + #undef JSON_HEDLEY_BEGIN_C_DECLS +#endif +#if defined(JSON_HEDLEY_END_C_DECLS) + #undef JSON_HEDLEY_END_C_DECLS +#endif +#if defined(JSON_HEDLEY_C_DECL) + #undef JSON_HEDLEY_C_DECL +#endif +#if defined(__cplusplus) + #define JSON_HEDLEY_BEGIN_C_DECLS extern "C" { + #define JSON_HEDLEY_END_C_DECLS } + #define JSON_HEDLEY_C_DECL extern "C" +#else + #define JSON_HEDLEY_BEGIN_C_DECLS + #define JSON_HEDLEY_END_C_DECLS + #define JSON_HEDLEY_C_DECL +#endif + +#if defined(JSON_HEDLEY_STATIC_ASSERT) + #undef JSON_HEDLEY_STATIC_ASSERT +#endif +#if \ + !defined(__cplusplus) && ( \ + (defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 201112L)) || \ + JSON_HEDLEY_HAS_FEATURE(c_static_assert) || \ + JSON_HEDLEY_GCC_VERSION_CHECK(6,0,0) || \ + JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \ + defined(_Static_assert) \ + ) +# define JSON_HEDLEY_STATIC_ASSERT(expr, message) _Static_assert(expr, message) +#elif \ + (defined(__cplusplus) && (__cplusplus >= 201103L)) || \ + JSON_HEDLEY_MSVC_VERSION_CHECK(16,0,0) +# define JSON_HEDLEY_STATIC_ASSERT(expr, message) JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_(static_assert(expr, message)) +#else +# define JSON_HEDLEY_STATIC_ASSERT(expr, message) +#endif + +#if defined(JSON_HEDLEY_NULL) + #undef JSON_HEDLEY_NULL +#endif +#if defined(__cplusplus) + #if __cplusplus >= 201103L + #define JSON_HEDLEY_NULL JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_(nullptr) + #elif defined(NULL) + #define JSON_HEDLEY_NULL NULL + #else + #define JSON_HEDLEY_NULL JSON_HEDLEY_STATIC_CAST(void*, 0) + #endif +#elif defined(NULL) + #define JSON_HEDLEY_NULL NULL +#else + #define JSON_HEDLEY_NULL ((void*) 0) +#endif + +#if defined(JSON_HEDLEY_MESSAGE) + #undef JSON_HEDLEY_MESSAGE +#endif +#if JSON_HEDLEY_HAS_WARNING("-Wunknown-pragmas") +# define JSON_HEDLEY_MESSAGE(msg) \ + JSON_HEDLEY_DIAGNOSTIC_PUSH \ + JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS \ + JSON_HEDLEY_PRAGMA(message msg) \ + JSON_HEDLEY_DIAGNOSTIC_POP +#elif \ + JSON_HEDLEY_GCC_VERSION_CHECK(4,4,0) || \ + JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) +# define JSON_HEDLEY_MESSAGE(msg) JSON_HEDLEY_PRAGMA(message msg) +#elif JSON_HEDLEY_CRAY_VERSION_CHECK(5,0,0) +# define JSON_HEDLEY_MESSAGE(msg) JSON_HEDLEY_PRAGMA(_CRI message msg) +#elif JSON_HEDLEY_IAR_VERSION_CHECK(8,0,0) +# define JSON_HEDLEY_MESSAGE(msg) JSON_HEDLEY_PRAGMA(message(msg)) +#elif JSON_HEDLEY_PELLES_VERSION_CHECK(2,0,0) +# define JSON_HEDLEY_MESSAGE(msg) JSON_HEDLEY_PRAGMA(message(msg)) +#else +# define JSON_HEDLEY_MESSAGE(msg) +#endif + +#if defined(JSON_HEDLEY_WARNING) + #undef JSON_HEDLEY_WARNING +#endif +#if JSON_HEDLEY_HAS_WARNING("-Wunknown-pragmas") +# define JSON_HEDLEY_WARNING(msg) \ + JSON_HEDLEY_DIAGNOSTIC_PUSH \ + JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS \ + JSON_HEDLEY_PRAGMA(clang warning msg) \ + JSON_HEDLEY_DIAGNOSTIC_POP +#elif \ + JSON_HEDLEY_GCC_VERSION_CHECK(4,8,0) || \ + JSON_HEDLEY_PGI_VERSION_CHECK(18,4,0) || \ + JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) +# define JSON_HEDLEY_WARNING(msg) JSON_HEDLEY_PRAGMA(GCC warning msg) +#elif JSON_HEDLEY_MSVC_VERSION_CHECK(15,0,0) +# define JSON_HEDLEY_WARNING(msg) JSON_HEDLEY_PRAGMA(message(msg)) +#else +# define JSON_HEDLEY_WARNING(msg) JSON_HEDLEY_MESSAGE(msg) +#endif + +#if defined(JSON_HEDLEY_REQUIRE) + #undef JSON_HEDLEY_REQUIRE +#endif +#if defined(JSON_HEDLEY_REQUIRE_MSG) + #undef JSON_HEDLEY_REQUIRE_MSG +#endif +#if JSON_HEDLEY_HAS_ATTRIBUTE(diagnose_if) +# if JSON_HEDLEY_HAS_WARNING("-Wgcc-compat") +# define JSON_HEDLEY_REQUIRE(expr) \ + JSON_HEDLEY_DIAGNOSTIC_PUSH \ + _Pragma("clang diagnostic ignored \"-Wgcc-compat\"") \ + __attribute__((diagnose_if(!(expr), #expr, "error"))) \ + JSON_HEDLEY_DIAGNOSTIC_POP +# define JSON_HEDLEY_REQUIRE_MSG(expr,msg) \ + JSON_HEDLEY_DIAGNOSTIC_PUSH \ + _Pragma("clang diagnostic ignored \"-Wgcc-compat\"") \ + __attribute__((diagnose_if(!(expr), msg, "error"))) \ + JSON_HEDLEY_DIAGNOSTIC_POP +# else +# define JSON_HEDLEY_REQUIRE(expr) __attribute__((diagnose_if(!(expr), #expr, "error"))) +# define JSON_HEDLEY_REQUIRE_MSG(expr,msg) __attribute__((diagnose_if(!(expr), msg, "error"))) +# endif +#else +# define JSON_HEDLEY_REQUIRE(expr) +# define JSON_HEDLEY_REQUIRE_MSG(expr,msg) +#endif + +#if defined(JSON_HEDLEY_FLAGS) + #undef JSON_HEDLEY_FLAGS +#endif +#if JSON_HEDLEY_HAS_ATTRIBUTE(flag_enum) + #define JSON_HEDLEY_FLAGS __attribute__((__flag_enum__)) +#endif + +#if defined(JSON_HEDLEY_FLAGS_CAST) + #undef JSON_HEDLEY_FLAGS_CAST +#endif +#if JSON_HEDLEY_INTEL_VERSION_CHECK(19,0,0) +# define JSON_HEDLEY_FLAGS_CAST(T, expr) (__extension__ ({ \ + JSON_HEDLEY_DIAGNOSTIC_PUSH \ + _Pragma("warning(disable:188)") \ + ((T) (expr)); \ + JSON_HEDLEY_DIAGNOSTIC_POP \ + })) +#else +# define JSON_HEDLEY_FLAGS_CAST(T, expr) JSON_HEDLEY_STATIC_CAST(T, expr) +#endif + +#if defined(JSON_HEDLEY_EMPTY_BASES) + #undef JSON_HEDLEY_EMPTY_BASES +#endif +#if JSON_HEDLEY_MSVC_VERSION_CHECK(19,0,23918) && !JSON_HEDLEY_MSVC_VERSION_CHECK(20,0,0) + #define JSON_HEDLEY_EMPTY_BASES __declspec(empty_bases) +#else + #define JSON_HEDLEY_EMPTY_BASES +#endif + +/* Remaining macros are deprecated. */ + +#if defined(JSON_HEDLEY_GCC_NOT_CLANG_VERSION_CHECK) + #undef JSON_HEDLEY_GCC_NOT_CLANG_VERSION_CHECK +#endif +#if defined(__clang__) + #define JSON_HEDLEY_GCC_NOT_CLANG_VERSION_CHECK(major,minor,patch) (0) +#else + #define JSON_HEDLEY_GCC_NOT_CLANG_VERSION_CHECK(major,minor,patch) JSON_HEDLEY_GCC_VERSION_CHECK(major,minor,patch) +#endif + +#if defined(JSON_HEDLEY_CLANG_HAS_ATTRIBUTE) + #undef JSON_HEDLEY_CLANG_HAS_ATTRIBUTE +#endif +#define JSON_HEDLEY_CLANG_HAS_ATTRIBUTE(attribute) JSON_HEDLEY_HAS_ATTRIBUTE(attribute) + +#if defined(JSON_HEDLEY_CLANG_HAS_CPP_ATTRIBUTE) + #undef JSON_HEDLEY_CLANG_HAS_CPP_ATTRIBUTE +#endif +#define JSON_HEDLEY_CLANG_HAS_CPP_ATTRIBUTE(attribute) JSON_HEDLEY_HAS_CPP_ATTRIBUTE(attribute) + +#if defined(JSON_HEDLEY_CLANG_HAS_BUILTIN) + #undef JSON_HEDLEY_CLANG_HAS_BUILTIN +#endif +#define JSON_HEDLEY_CLANG_HAS_BUILTIN(builtin) JSON_HEDLEY_HAS_BUILTIN(builtin) + +#if defined(JSON_HEDLEY_CLANG_HAS_FEATURE) + #undef JSON_HEDLEY_CLANG_HAS_FEATURE +#endif +#define JSON_HEDLEY_CLANG_HAS_FEATURE(feature) JSON_HEDLEY_HAS_FEATURE(feature) + +#if defined(JSON_HEDLEY_CLANG_HAS_EXTENSION) + #undef JSON_HEDLEY_CLANG_HAS_EXTENSION +#endif +#define JSON_HEDLEY_CLANG_HAS_EXTENSION(extension) JSON_HEDLEY_HAS_EXTENSION(extension) + +#if defined(JSON_HEDLEY_CLANG_HAS_DECLSPEC_DECLSPEC_ATTRIBUTE) + #undef JSON_HEDLEY_CLANG_HAS_DECLSPEC_DECLSPEC_ATTRIBUTE +#endif +#define JSON_HEDLEY_CLANG_HAS_DECLSPEC_ATTRIBUTE(attribute) JSON_HEDLEY_HAS_DECLSPEC_ATTRIBUTE(attribute) + +#if defined(JSON_HEDLEY_CLANG_HAS_WARNING) + #undef JSON_HEDLEY_CLANG_HAS_WARNING +#endif +#define JSON_HEDLEY_CLANG_HAS_WARNING(warning) JSON_HEDLEY_HAS_WARNING(warning) + +#endif /* !defined(JSON_HEDLEY_VERSION) || (JSON_HEDLEY_VERSION < X) */ + + +// This file contains all internal macro definitions +// You MUST include macro_unscope.hpp at the end of json.hpp to undef all of them + +// exclude unsupported compilers +#if !defined(JSON_SKIP_UNSUPPORTED_COMPILER_CHECK) + #if defined(__clang__) + #if (__clang_major__ * 10000 + __clang_minor__ * 100 + __clang_patchlevel__) < 30400 + #error "unsupported Clang version - see https://github.com/nlohmann/json#supported-compilers" + #endif + #elif defined(__GNUC__) && !(defined(__ICC) || defined(__INTEL_COMPILER)) + #if (__GNUC__ * 10000 + __GNUC_MINOR__ * 100 + __GNUC_PATCHLEVEL__) < 40800 + #error "unsupported GCC version - see https://github.com/nlohmann/json#supported-compilers" + #endif + #endif +#endif + +// C++ language standard detection +#if (defined(__cplusplus) && __cplusplus >= 202002L) || (defined(_MSVC_LANG) && _MSVC_LANG >= 202002L) + #define JSON_HAS_CPP_20 + #define JSON_HAS_CPP_17 + #define JSON_HAS_CPP_14 +#elif (defined(__cplusplus) && __cplusplus >= 201703L) || (defined(_HAS_CXX17) && _HAS_CXX17 == 1) // fix for issue #464 + #define JSON_HAS_CPP_17 + #define JSON_HAS_CPP_14 +#elif (defined(__cplusplus) && __cplusplus >= 201402L) || (defined(_HAS_CXX14) && _HAS_CXX14 == 1) + #define JSON_HAS_CPP_14 +#endif + +// disable float-equal warnings on GCC/clang +#if defined(__clang__) || defined(__GNUC__) || defined(__GNUG__) + #pragma GCC diagnostic push + #pragma GCC diagnostic ignored "-Wfloat-equal" +#endif + +// disable documentation warnings on clang +#if defined(__clang__) + #pragma GCC diagnostic push + #pragma GCC diagnostic ignored "-Wdocumentation" +#endif + +// allow to disable exceptions +#if (defined(__cpp_exceptions) || defined(__EXCEPTIONS) || defined(_CPPUNWIND)) && !defined(JSON_NOEXCEPTION) + #define JSON_THROW(exception) throw exception + #define JSON_TRY try + #define JSON_CATCH(exception) catch(exception) + #define JSON_INTERNAL_CATCH(exception) catch(exception) +#else + #include + #define JSON_THROW(exception) std::abort() + #define JSON_TRY if(true) + #define JSON_CATCH(exception) if(false) + #define JSON_INTERNAL_CATCH(exception) if(false) +#endif + +// override exception macros +#if defined(JSON_THROW_USER) + #undef JSON_THROW + #define JSON_THROW JSON_THROW_USER +#endif +#if defined(JSON_TRY_USER) + #undef JSON_TRY + #define JSON_TRY JSON_TRY_USER +#endif +#if defined(JSON_CATCH_USER) + #undef JSON_CATCH + #define JSON_CATCH JSON_CATCH_USER + #undef JSON_INTERNAL_CATCH + #define JSON_INTERNAL_CATCH JSON_CATCH_USER +#endif +#if defined(JSON_INTERNAL_CATCH_USER) + #undef JSON_INTERNAL_CATCH + #define JSON_INTERNAL_CATCH JSON_INTERNAL_CATCH_USER +#endif + +// allow to override assert +#if !defined(JSON_ASSERT) + #include // assert + #define JSON_ASSERT(x) assert(x) +#endif + +/*! +@brief macro to briefly define a mapping between an enum and JSON +@def NLOHMANN_JSON_SERIALIZE_ENUM +@since version 3.4.0 +*/ +#define NLOHMANN_JSON_SERIALIZE_ENUM(ENUM_TYPE, ...) \ + template \ + inline void to_json(BasicJsonType& j, const ENUM_TYPE& e) \ + { \ + static_assert(std::is_enum::value, #ENUM_TYPE " must be an enum!"); \ + static const std::pair m[] = __VA_ARGS__; \ + auto it = std::find_if(std::begin(m), std::end(m), \ + [e](const std::pair& ej_pair) -> bool \ + { \ + return ej_pair.first == e; \ + }); \ + j = ((it != std::end(m)) ? it : std::begin(m))->second; \ + } \ + template \ + inline void from_json(const BasicJsonType& j, ENUM_TYPE& e) \ + { \ + static_assert(std::is_enum::value, #ENUM_TYPE " must be an enum!"); \ + static const std::pair m[] = __VA_ARGS__; \ + auto it = std::find_if(std::begin(m), std::end(m), \ + [&j](const std::pair& ej_pair) -> bool \ + { \ + return ej_pair.second == j; \ + }); \ + e = ((it != std::end(m)) ? it : std::begin(m))->first; \ + } + +// Ugly macros to avoid uglier copy-paste when specializing basic_json. They +// may be removed in the future once the class is split. + +#define NLOHMANN_BASIC_JSON_TPL_DECLARATION \ + template class ObjectType, \ + template class ArrayType, \ + class StringType, class BooleanType, class NumberIntegerType, \ + class NumberUnsignedType, class NumberFloatType, \ + template class AllocatorType, \ + template class JSONSerializer, \ + class BinaryType> + +#define NLOHMANN_BASIC_JSON_TPL \ + basic_json + +// Macros to simplify conversion from/to types + +#define NLOHMANN_JSON_EXPAND( x ) x +#define NLOHMANN_JSON_GET_MACRO(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50, _51, _52, _53, _54, _55, _56, _57, _58, _59, _60, _61, _62, _63, _64, NAME,...) NAME +#define NLOHMANN_JSON_PASTE(...) NLOHMANN_JSON_EXPAND(NLOHMANN_JSON_GET_MACRO(__VA_ARGS__, \ + NLOHMANN_JSON_PASTE64, \ + NLOHMANN_JSON_PASTE63, \ + NLOHMANN_JSON_PASTE62, \ + NLOHMANN_JSON_PASTE61, \ + NLOHMANN_JSON_PASTE60, \ + NLOHMANN_JSON_PASTE59, \ + NLOHMANN_JSON_PASTE58, \ + NLOHMANN_JSON_PASTE57, \ + NLOHMANN_JSON_PASTE56, \ + NLOHMANN_JSON_PASTE55, \ + NLOHMANN_JSON_PASTE54, \ + NLOHMANN_JSON_PASTE53, \ + NLOHMANN_JSON_PASTE52, \ + NLOHMANN_JSON_PASTE51, \ + NLOHMANN_JSON_PASTE50, \ + NLOHMANN_JSON_PASTE49, \ + NLOHMANN_JSON_PASTE48, \ + NLOHMANN_JSON_PASTE47, \ + NLOHMANN_JSON_PASTE46, \ + NLOHMANN_JSON_PASTE45, \ + NLOHMANN_JSON_PASTE44, \ + NLOHMANN_JSON_PASTE43, \ + NLOHMANN_JSON_PASTE42, \ + NLOHMANN_JSON_PASTE41, \ + NLOHMANN_JSON_PASTE40, \ + NLOHMANN_JSON_PASTE39, \ + NLOHMANN_JSON_PASTE38, \ + NLOHMANN_JSON_PASTE37, \ + NLOHMANN_JSON_PASTE36, \ + NLOHMANN_JSON_PASTE35, \ + NLOHMANN_JSON_PASTE34, \ + NLOHMANN_JSON_PASTE33, \ + NLOHMANN_JSON_PASTE32, \ + NLOHMANN_JSON_PASTE31, \ + NLOHMANN_JSON_PASTE30, \ + NLOHMANN_JSON_PASTE29, \ + NLOHMANN_JSON_PASTE28, \ + NLOHMANN_JSON_PASTE27, \ + NLOHMANN_JSON_PASTE26, \ + NLOHMANN_JSON_PASTE25, \ + NLOHMANN_JSON_PASTE24, \ + NLOHMANN_JSON_PASTE23, \ + NLOHMANN_JSON_PASTE22, \ + NLOHMANN_JSON_PASTE21, \ + NLOHMANN_JSON_PASTE20, \ + NLOHMANN_JSON_PASTE19, \ + NLOHMANN_JSON_PASTE18, \ + NLOHMANN_JSON_PASTE17, \ + NLOHMANN_JSON_PASTE16, \ + NLOHMANN_JSON_PASTE15, \ + NLOHMANN_JSON_PASTE14, \ + NLOHMANN_JSON_PASTE13, \ + NLOHMANN_JSON_PASTE12, \ + NLOHMANN_JSON_PASTE11, \ + NLOHMANN_JSON_PASTE10, \ + NLOHMANN_JSON_PASTE9, \ + NLOHMANN_JSON_PASTE8, \ + NLOHMANN_JSON_PASTE7, \ + NLOHMANN_JSON_PASTE6, \ + NLOHMANN_JSON_PASTE5, \ + NLOHMANN_JSON_PASTE4, \ + NLOHMANN_JSON_PASTE3, \ + NLOHMANN_JSON_PASTE2, \ + NLOHMANN_JSON_PASTE1)(__VA_ARGS__)) +#define NLOHMANN_JSON_PASTE2(func, v1) func(v1) +#define NLOHMANN_JSON_PASTE3(func, v1, v2) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE2(func, v2) +#define NLOHMANN_JSON_PASTE4(func, v1, v2, v3) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE3(func, v2, v3) +#define NLOHMANN_JSON_PASTE5(func, v1, v2, v3, v4) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE4(func, v2, v3, v4) +#define NLOHMANN_JSON_PASTE6(func, v1, v2, v3, v4, v5) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE5(func, v2, v3, v4, v5) +#define NLOHMANN_JSON_PASTE7(func, v1, v2, v3, v4, v5, v6) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE6(func, v2, v3, v4, v5, v6) +#define NLOHMANN_JSON_PASTE8(func, v1, v2, v3, v4, v5, v6, v7) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE7(func, v2, v3, v4, v5, v6, v7) +#define NLOHMANN_JSON_PASTE9(func, v1, v2, v3, v4, v5, v6, v7, v8) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE8(func, v2, v3, v4, v5, v6, v7, v8) +#define NLOHMANN_JSON_PASTE10(func, v1, v2, v3, v4, v5, v6, v7, v8, v9) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE9(func, v2, v3, v4, v5, v6, v7, v8, v9) +#define NLOHMANN_JSON_PASTE11(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE10(func, v2, v3, v4, v5, v6, v7, v8, v9, v10) +#define NLOHMANN_JSON_PASTE12(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE11(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11) +#define NLOHMANN_JSON_PASTE13(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE12(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12) +#define NLOHMANN_JSON_PASTE14(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE13(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13) +#define NLOHMANN_JSON_PASTE15(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE14(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14) +#define NLOHMANN_JSON_PASTE16(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE15(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15) +#define NLOHMANN_JSON_PASTE17(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE16(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16) +#define NLOHMANN_JSON_PASTE18(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE17(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17) +#define NLOHMANN_JSON_PASTE19(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE18(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18) +#define NLOHMANN_JSON_PASTE20(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE19(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19) +#define NLOHMANN_JSON_PASTE21(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE20(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20) +#define NLOHMANN_JSON_PASTE22(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE21(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21) +#define NLOHMANN_JSON_PASTE23(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE22(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22) +#define NLOHMANN_JSON_PASTE24(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE23(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23) +#define NLOHMANN_JSON_PASTE25(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE24(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24) +#define NLOHMANN_JSON_PASTE26(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE25(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25) +#define NLOHMANN_JSON_PASTE27(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE26(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26) +#define NLOHMANN_JSON_PASTE28(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE27(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27) +#define NLOHMANN_JSON_PASTE29(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE28(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28) +#define NLOHMANN_JSON_PASTE30(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE29(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29) +#define NLOHMANN_JSON_PASTE31(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE30(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30) +#define NLOHMANN_JSON_PASTE32(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE31(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31) +#define NLOHMANN_JSON_PASTE33(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE32(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32) +#define NLOHMANN_JSON_PASTE34(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE33(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33) +#define NLOHMANN_JSON_PASTE35(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE34(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34) +#define NLOHMANN_JSON_PASTE36(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE35(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35) +#define NLOHMANN_JSON_PASTE37(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE36(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36) +#define NLOHMANN_JSON_PASTE38(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE37(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37) +#define NLOHMANN_JSON_PASTE39(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE38(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38) +#define NLOHMANN_JSON_PASTE40(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE39(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39) +#define NLOHMANN_JSON_PASTE41(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE40(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40) +#define NLOHMANN_JSON_PASTE42(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE41(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41) +#define NLOHMANN_JSON_PASTE43(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE42(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42) +#define NLOHMANN_JSON_PASTE44(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE43(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43) +#define NLOHMANN_JSON_PASTE45(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE44(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44) +#define NLOHMANN_JSON_PASTE46(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE45(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45) +#define NLOHMANN_JSON_PASTE47(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE46(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46) +#define NLOHMANN_JSON_PASTE48(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE47(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47) +#define NLOHMANN_JSON_PASTE49(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE48(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48) +#define NLOHMANN_JSON_PASTE50(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE49(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49) +#define NLOHMANN_JSON_PASTE51(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE50(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50) +#define NLOHMANN_JSON_PASTE52(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE51(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51) +#define NLOHMANN_JSON_PASTE53(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE52(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52) +#define NLOHMANN_JSON_PASTE54(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE53(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53) +#define NLOHMANN_JSON_PASTE55(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE54(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54) +#define NLOHMANN_JSON_PASTE56(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE55(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55) +#define NLOHMANN_JSON_PASTE57(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE56(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56) +#define NLOHMANN_JSON_PASTE58(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56, v57) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE57(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56, v57) +#define NLOHMANN_JSON_PASTE59(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56, v57, v58) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE58(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56, v57, v58) +#define NLOHMANN_JSON_PASTE60(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56, v57, v58, v59) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE59(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56, v57, v58, v59) +#define NLOHMANN_JSON_PASTE61(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56, v57, v58, v59, v60) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE60(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56, v57, v58, v59, v60) +#define NLOHMANN_JSON_PASTE62(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56, v57, v58, v59, v60, v61) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE61(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56, v57, v58, v59, v60, v61) +#define NLOHMANN_JSON_PASTE63(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56, v57, v58, v59, v60, v61, v62) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE62(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56, v57, v58, v59, v60, v61, v62) +#define NLOHMANN_JSON_PASTE64(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56, v57, v58, v59, v60, v61, v62, v63) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE63(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56, v57, v58, v59, v60, v61, v62, v63) + +#define NLOHMANN_JSON_TO(v1) nlohmann_json_j[#v1] = nlohmann_json_t.v1; +#define NLOHMANN_JSON_FROM(v1) nlohmann_json_j.at(#v1).get_to(nlohmann_json_t.v1); + +/*! +@brief macro +@def NLOHMANN_DEFINE_TYPE_INTRUSIVE +@since version 3.9.0 +*/ +#define NLOHMANN_DEFINE_TYPE_INTRUSIVE(Type, ...) \ + friend void to_json(nlohmann::json& nlohmann_json_j, const Type& nlohmann_json_t) { NLOHMANN_JSON_EXPAND(NLOHMANN_JSON_PASTE(NLOHMANN_JSON_TO, __VA_ARGS__)) } \ + friend void from_json(const nlohmann::json& nlohmann_json_j, Type& nlohmann_json_t) { NLOHMANN_JSON_EXPAND(NLOHMANN_JSON_PASTE(NLOHMANN_JSON_FROM, __VA_ARGS__)) } + +/*! +@brief macro +@def NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE +@since version 3.9.0 +*/ +#define NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE(Type, ...) \ + inline void to_json(nlohmann::json& nlohmann_json_j, const Type& nlohmann_json_t) { NLOHMANN_JSON_EXPAND(NLOHMANN_JSON_PASTE(NLOHMANN_JSON_TO, __VA_ARGS__)) } \ + inline void from_json(const nlohmann::json& nlohmann_json_j, Type& nlohmann_json_t) { NLOHMANN_JSON_EXPAND(NLOHMANN_JSON_PASTE(NLOHMANN_JSON_FROM, __VA_ARGS__)) } + +#ifndef JSON_USE_IMPLICIT_CONVERSIONS + #define JSON_USE_IMPLICIT_CONVERSIONS 1 +#endif + +#if JSON_USE_IMPLICIT_CONVERSIONS + #define JSON_EXPLICIT +#else + #define JSON_EXPLICIT explicit +#endif + + +namespace nlohmann +{ +namespace detail +{ +//////////////// +// exceptions // +//////////////// + +/*! +@brief general exception of the @ref basic_json class + +This class is an extension of `std::exception` objects with a member @a id for +exception ids. It is used as the base class for all exceptions thrown by the +@ref basic_json class. This class can hence be used as "wildcard" to catch +exceptions. + +Subclasses: +- @ref parse_error for exceptions indicating a parse error +- @ref invalid_iterator for exceptions indicating errors with iterators +- @ref type_error for exceptions indicating executing a member function with + a wrong type +- @ref out_of_range for exceptions indicating access out of the defined range +- @ref other_error for exceptions indicating other library errors + +@internal +@note To have nothrow-copy-constructible exceptions, we internally use + `std::runtime_error` which can cope with arbitrary-length error messages. + Intermediate strings are built with static functions and then passed to + the actual constructor. +@endinternal + +@liveexample{The following code shows how arbitrary library exceptions can be +caught.,exception} + +@since version 3.0.0 +*/ +class exception : public std::exception +{ + public: + /// returns the explanatory string + JSON_HEDLEY_RETURNS_NON_NULL + const char* what() const noexcept override + { + return m.what(); + } + + /// the id of the exception + const int id; + + protected: + JSON_HEDLEY_NON_NULL(3) + exception(int id_, const char* what_arg) : id(id_), m(what_arg) {} + + static std::string name(const std::string& ename, int id_) + { + return "[json.exception." + ename + "." + std::to_string(id_) + "] "; + } + + private: + /// an exception object as storage for error messages + std::runtime_error m; +}; + +/*! +@brief exception indicating a parse error + +This exception is thrown by the library when a parse error occurs. Parse errors +can occur during the deserialization of JSON text, CBOR, MessagePack, as well +as when using JSON Patch. + +Member @a byte holds the byte index of the last read character in the input +file. + +Exceptions have ids 1xx. + +name / id | example message | description +------------------------------ | --------------- | ------------------------- +json.exception.parse_error.101 | parse error at 2: unexpected end of input; expected string literal | This error indicates a syntax error while deserializing a JSON text. The error message describes that an unexpected token (character) was encountered, and the member @a byte indicates the error position. +json.exception.parse_error.102 | parse error at 14: missing or wrong low surrogate | JSON uses the `\uxxxx` format to describe Unicode characters. Code points above above 0xFFFF are split into two `\uxxxx` entries ("surrogate pairs"). This error indicates that the surrogate pair is incomplete or contains an invalid code point. +json.exception.parse_error.103 | parse error: code points above 0x10FFFF are invalid | Unicode supports code points up to 0x10FFFF. Code points above 0x10FFFF are invalid. +json.exception.parse_error.104 | parse error: JSON patch must be an array of objects | [RFC 6902](https://tools.ietf.org/html/rfc6902) requires a JSON Patch document to be a JSON document that represents an array of objects. +json.exception.parse_error.105 | parse error: operation must have string member 'op' | An operation of a JSON Patch document must contain exactly one "op" member, whose value indicates the operation to perform. Its value must be one of "add", "remove", "replace", "move", "copy", or "test"; other values are errors. +json.exception.parse_error.106 | parse error: array index '01' must not begin with '0' | An array index in a JSON Pointer ([RFC 6901](https://tools.ietf.org/html/rfc6901)) may be `0` or any number without a leading `0`. +json.exception.parse_error.107 | parse error: JSON pointer must be empty or begin with '/' - was: 'foo' | A JSON Pointer must be a Unicode string containing a sequence of zero or more reference tokens, each prefixed by a `/` character. +json.exception.parse_error.108 | parse error: escape character '~' must be followed with '0' or '1' | In a JSON Pointer, only `~0` and `~1` are valid escape sequences. +json.exception.parse_error.109 | parse error: array index 'one' is not a number | A JSON Pointer array index must be a number. +json.exception.parse_error.110 | parse error at 1: cannot read 2 bytes from vector | When parsing CBOR or MessagePack, the byte vector ends before the complete value has been read. +json.exception.parse_error.112 | parse error at 1: error reading CBOR; last byte: 0xF8 | Not all types of CBOR or MessagePack are supported. This exception occurs if an unsupported byte was read. +json.exception.parse_error.113 | parse error at 2: expected a CBOR string; last byte: 0x98 | While parsing a map key, a value that is not a string has been read. +json.exception.parse_error.114 | parse error: Unsupported BSON record type 0x0F | The parsing of the corresponding BSON record type is not implemented (yet). +json.exception.parse_error.115 | parse error at byte 5: syntax error while parsing UBJSON high-precision number: invalid number text: 1A | A UBJSON high-precision number could not be parsed. + +@note For an input with n bytes, 1 is the index of the first character and n+1 + is the index of the terminating null byte or the end of file. This also + holds true when reading a byte vector (CBOR or MessagePack). + +@liveexample{The following code shows how a `parse_error` exception can be +caught.,parse_error} + +@sa - @ref exception for the base class of the library exceptions +@sa - @ref invalid_iterator for exceptions indicating errors with iterators +@sa - @ref type_error for exceptions indicating executing a member function with + a wrong type +@sa - @ref out_of_range for exceptions indicating access out of the defined range +@sa - @ref other_error for exceptions indicating other library errors + +@since version 3.0.0 +*/ +class parse_error : public exception +{ + public: + /*! + @brief create a parse error exception + @param[in] id_ the id of the exception + @param[in] pos the position where the error occurred (or with + chars_read_total=0 if the position cannot be + determined) + @param[in] what_arg the explanatory string + @return parse_error object + */ + static parse_error create(int id_, const position_t& pos, const std::string& what_arg) + { + std::string w = exception::name("parse_error", id_) + "parse error" + + position_string(pos) + ": " + what_arg; + return parse_error(id_, pos.chars_read_total, w.c_str()); + } + + static parse_error create(int id_, std::size_t byte_, const std::string& what_arg) + { + std::string w = exception::name("parse_error", id_) + "parse error" + + (byte_ != 0 ? (" at byte " + std::to_string(byte_)) : "") + + ": " + what_arg; + return parse_error(id_, byte_, w.c_str()); + } + + /*! + @brief byte index of the parse error + + The byte index of the last read character in the input file. + + @note For an input with n bytes, 1 is the index of the first character and + n+1 is the index of the terminating null byte or the end of file. + This also holds true when reading a byte vector (CBOR or MessagePack). + */ + const std::size_t byte; + + private: + parse_error(int id_, std::size_t byte_, const char* what_arg) + : exception(id_, what_arg), byte(byte_) {} + + static std::string position_string(const position_t& pos) + { + return " at line " + std::to_string(pos.lines_read + 1) + + ", column " + std::to_string(pos.chars_read_current_line); + } +}; + +/*! +@brief exception indicating errors with iterators + +This exception is thrown if iterators passed to a library function do not match +the expected semantics. + +Exceptions have ids 2xx. + +name / id | example message | description +----------------------------------- | --------------- | ------------------------- +json.exception.invalid_iterator.201 | iterators are not compatible | The iterators passed to constructor @ref basic_json(InputIT first, InputIT last) are not compatible, meaning they do not belong to the same container. Therefore, the range (@a first, @a last) is invalid. +json.exception.invalid_iterator.202 | iterator does not fit current value | In an erase or insert function, the passed iterator @a pos does not belong to the JSON value for which the function was called. It hence does not define a valid position for the deletion/insertion. +json.exception.invalid_iterator.203 | iterators do not fit current value | Either iterator passed to function @ref erase(IteratorType first, IteratorType last) does not belong to the JSON value from which values shall be erased. It hence does not define a valid range to delete values from. +json.exception.invalid_iterator.204 | iterators out of range | When an iterator range for a primitive type (number, boolean, or string) is passed to a constructor or an erase function, this range has to be exactly (@ref begin(), @ref end()), because this is the only way the single stored value is expressed. All other ranges are invalid. +json.exception.invalid_iterator.205 | iterator out of range | When an iterator for a primitive type (number, boolean, or string) is passed to an erase function, the iterator has to be the @ref begin() iterator, because it is the only way to address the stored value. All other iterators are invalid. +json.exception.invalid_iterator.206 | cannot construct with iterators from null | The iterators passed to constructor @ref basic_json(InputIT first, InputIT last) belong to a JSON null value and hence to not define a valid range. +json.exception.invalid_iterator.207 | cannot use key() for non-object iterators | The key() member function can only be used on iterators belonging to a JSON object, because other types do not have a concept of a key. +json.exception.invalid_iterator.208 | cannot use operator[] for object iterators | The operator[] to specify a concrete offset cannot be used on iterators belonging to a JSON object, because JSON objects are unordered. +json.exception.invalid_iterator.209 | cannot use offsets with object iterators | The offset operators (+, -, +=, -=) cannot be used on iterators belonging to a JSON object, because JSON objects are unordered. +json.exception.invalid_iterator.210 | iterators do not fit | The iterator range passed to the insert function are not compatible, meaning they do not belong to the same container. Therefore, the range (@a first, @a last) is invalid. +json.exception.invalid_iterator.211 | passed iterators may not belong to container | The iterator range passed to the insert function must not be a subrange of the container to insert to. +json.exception.invalid_iterator.212 | cannot compare iterators of different containers | When two iterators are compared, they must belong to the same container. +json.exception.invalid_iterator.213 | cannot compare order of object iterators | The order of object iterators cannot be compared, because JSON objects are unordered. +json.exception.invalid_iterator.214 | cannot get value | Cannot get value for iterator: Either the iterator belongs to a null value or it is an iterator to a primitive type (number, boolean, or string), but the iterator is different to @ref begin(). + +@liveexample{The following code shows how an `invalid_iterator` exception can be +caught.,invalid_iterator} + +@sa - @ref exception for the base class of the library exceptions +@sa - @ref parse_error for exceptions indicating a parse error +@sa - @ref type_error for exceptions indicating executing a member function with + a wrong type +@sa - @ref out_of_range for exceptions indicating access out of the defined range +@sa - @ref other_error for exceptions indicating other library errors + +@since version 3.0.0 +*/ +class invalid_iterator : public exception +{ + public: + static invalid_iterator create(int id_, const std::string& what_arg) + { + std::string w = exception::name("invalid_iterator", id_) + what_arg; + return invalid_iterator(id_, w.c_str()); + } + + private: + JSON_HEDLEY_NON_NULL(3) + invalid_iterator(int id_, const char* what_arg) + : exception(id_, what_arg) {} +}; + +/*! +@brief exception indicating executing a member function with a wrong type + +This exception is thrown in case of a type error; that is, a library function is +executed on a JSON value whose type does not match the expected semantics. + +Exceptions have ids 3xx. + +name / id | example message | description +----------------------------- | --------------- | ------------------------- +json.exception.type_error.301 | cannot create object from initializer list | To create an object from an initializer list, the initializer list must consist only of a list of pairs whose first element is a string. When this constraint is violated, an array is created instead. +json.exception.type_error.302 | type must be object, but is array | During implicit or explicit value conversion, the JSON type must be compatible to the target type. For instance, a JSON string can only be converted into string types, but not into numbers or boolean types. +json.exception.type_error.303 | incompatible ReferenceType for get_ref, actual type is object | To retrieve a reference to a value stored in a @ref basic_json object with @ref get_ref, the type of the reference must match the value type. For instance, for a JSON array, the @a ReferenceType must be @ref array_t &. +json.exception.type_error.304 | cannot use at() with string | The @ref at() member functions can only be executed for certain JSON types. +json.exception.type_error.305 | cannot use operator[] with string | The @ref operator[] member functions can only be executed for certain JSON types. +json.exception.type_error.306 | cannot use value() with string | The @ref value() member functions can only be executed for certain JSON types. +json.exception.type_error.307 | cannot use erase() with string | The @ref erase() member functions can only be executed for certain JSON types. +json.exception.type_error.308 | cannot use push_back() with string | The @ref push_back() and @ref operator+= member functions can only be executed for certain JSON types. +json.exception.type_error.309 | cannot use insert() with | The @ref insert() member functions can only be executed for certain JSON types. +json.exception.type_error.310 | cannot use swap() with number | The @ref swap() member functions can only be executed for certain JSON types. +json.exception.type_error.311 | cannot use emplace_back() with string | The @ref emplace_back() member function can only be executed for certain JSON types. +json.exception.type_error.312 | cannot use update() with string | The @ref update() member functions can only be executed for certain JSON types. +json.exception.type_error.313 | invalid value to unflatten | The @ref unflatten function converts an object whose keys are JSON Pointers back into an arbitrary nested JSON value. The JSON Pointers must not overlap, because then the resulting value would not be well defined. +json.exception.type_error.314 | only objects can be unflattened | The @ref unflatten function only works for an object whose keys are JSON Pointers. +json.exception.type_error.315 | values in object must be primitive | The @ref unflatten function only works for an object whose keys are JSON Pointers and whose values are primitive. +json.exception.type_error.316 | invalid UTF-8 byte at index 10: 0x7E | The @ref dump function only works with UTF-8 encoded strings; that is, if you assign a `std::string` to a JSON value, make sure it is UTF-8 encoded. | +json.exception.type_error.317 | JSON value cannot be serialized to requested format | The dynamic type of the object cannot be represented in the requested serialization format (e.g. a raw `true` or `null` JSON object cannot be serialized to BSON) | + +@liveexample{The following code shows how a `type_error` exception can be +caught.,type_error} + +@sa - @ref exception for the base class of the library exceptions +@sa - @ref parse_error for exceptions indicating a parse error +@sa - @ref invalid_iterator for exceptions indicating errors with iterators +@sa - @ref out_of_range for exceptions indicating access out of the defined range +@sa - @ref other_error for exceptions indicating other library errors + +@since version 3.0.0 +*/ +class type_error : public exception +{ + public: + static type_error create(int id_, const std::string& what_arg) + { + std::string w = exception::name("type_error", id_) + what_arg; + return type_error(id_, w.c_str()); + } + + private: + JSON_HEDLEY_NON_NULL(3) + type_error(int id_, const char* what_arg) : exception(id_, what_arg) {} +}; + +/*! +@brief exception indicating access out of the defined range + +This exception is thrown in case a library function is called on an input +parameter that exceeds the expected range, for instance in case of array +indices or nonexisting object keys. + +Exceptions have ids 4xx. + +name / id | example message | description +------------------------------- | --------------- | ------------------------- +json.exception.out_of_range.401 | array index 3 is out of range | The provided array index @a i is larger than @a size-1. +json.exception.out_of_range.402 | array index '-' (3) is out of range | The special array index `-` in a JSON Pointer never describes a valid element of the array, but the index past the end. That is, it can only be used to add elements at this position, but not to read it. +json.exception.out_of_range.403 | key 'foo' not found | The provided key was not found in the JSON object. +json.exception.out_of_range.404 | unresolved reference token 'foo' | A reference token in a JSON Pointer could not be resolved. +json.exception.out_of_range.405 | JSON pointer has no parent | The JSON Patch operations 'remove' and 'add' can not be applied to the root element of the JSON value. +json.exception.out_of_range.406 | number overflow parsing '10E1000' | A parsed number could not be stored as without changing it to NaN or INF. +json.exception.out_of_range.407 | number overflow serializing '9223372036854775808' | UBJSON and BSON only support integer numbers up to 9223372036854775807. (until version 3.8.0) | +json.exception.out_of_range.408 | excessive array size: 8658170730974374167 | The size (following `#`) of an UBJSON array or object exceeds the maximal capacity. | +json.exception.out_of_range.409 | BSON key cannot contain code point U+0000 (at byte 2) | Key identifiers to be serialized to BSON cannot contain code point U+0000, since the key is stored as zero-terminated c-string | + +@liveexample{The following code shows how an `out_of_range` exception can be +caught.,out_of_range} + +@sa - @ref exception for the base class of the library exceptions +@sa - @ref parse_error for exceptions indicating a parse error +@sa - @ref invalid_iterator for exceptions indicating errors with iterators +@sa - @ref type_error for exceptions indicating executing a member function with + a wrong type +@sa - @ref other_error for exceptions indicating other library errors + +@since version 3.0.0 +*/ +class out_of_range : public exception +{ + public: + static out_of_range create(int id_, const std::string& what_arg) + { + std::string w = exception::name("out_of_range", id_) + what_arg; + return out_of_range(id_, w.c_str()); + } + + private: + JSON_HEDLEY_NON_NULL(3) + out_of_range(int id_, const char* what_arg) : exception(id_, what_arg) {} +}; + +/*! +@brief exception indicating other library errors + +This exception is thrown in case of errors that cannot be classified with the +other exception types. + +Exceptions have ids 5xx. + +name / id | example message | description +------------------------------ | --------------- | ------------------------- +json.exception.other_error.501 | unsuccessful: {"op":"test","path":"/baz", "value":"bar"} | A JSON Patch operation 'test' failed. The unsuccessful operation is also printed. + +@sa - @ref exception for the base class of the library exceptions +@sa - @ref parse_error for exceptions indicating a parse error +@sa - @ref invalid_iterator for exceptions indicating errors with iterators +@sa - @ref type_error for exceptions indicating executing a member function with + a wrong type +@sa - @ref out_of_range for exceptions indicating access out of the defined range + +@liveexample{The following code shows how an `other_error` exception can be +caught.,other_error} + +@since version 3.0.0 +*/ +class other_error : public exception +{ + public: + static other_error create(int id_, const std::string& what_arg) + { + std::string w = exception::name("other_error", id_) + what_arg; + return other_error(id_, w.c_str()); + } + + private: + JSON_HEDLEY_NON_NULL(3) + other_error(int id_, const char* what_arg) : exception(id_, what_arg) {} +}; +} // namespace detail +} // namespace nlohmann + +// #include + +// #include + + +#include // size_t +#include // conditional, enable_if, false_type, integral_constant, is_constructible, is_integral, is_same, remove_cv, remove_reference, true_type + +namespace nlohmann +{ +namespace detail +{ +// alias templates to reduce boilerplate +template +using enable_if_t = typename std::enable_if::type; + +template +using uncvref_t = typename std::remove_cv::type>::type; + +// implementation of C++14 index_sequence and affiliates +// source: https://stackoverflow.com/a/32223343 +template +struct index_sequence +{ + using type = index_sequence; + using value_type = std::size_t; + static constexpr std::size_t size() noexcept + { + return sizeof...(Ints); + } +}; + +template +struct merge_and_renumber; + +template +struct merge_and_renumber, index_sequence> + : index_sequence < I1..., (sizeof...(I1) + I2)... > {}; + +template +struct make_index_sequence + : merge_and_renumber < typename make_index_sequence < N / 2 >::type, + typename make_index_sequence < N - N / 2 >::type > {}; + +template<> struct make_index_sequence<0> : index_sequence<> {}; +template<> struct make_index_sequence<1> : index_sequence<0> {}; + +template +using index_sequence_for = make_index_sequence; + +// dispatch utility (taken from ranges-v3) +template struct priority_tag : priority_tag < N - 1 > {}; +template<> struct priority_tag<0> {}; + +// taken from ranges-v3 +template +struct static_const +{ + static constexpr T value{}; +}; + +template +constexpr T static_const::value; +} // namespace detail +} // namespace nlohmann + +// #include + + +#include // numeric_limits +#include // false_type, is_constructible, is_integral, is_same, true_type +#include // declval + +// #include + + +#include // random_access_iterator_tag + +// #include + + +namespace nlohmann +{ +namespace detail +{ +template struct make_void +{ + using type = void; +}; +template using void_t = typename make_void::type; +} // namespace detail +} // namespace nlohmann + +// #include + + +namespace nlohmann +{ +namespace detail +{ +template +struct iterator_types {}; + +template +struct iterator_types < + It, + void_t> +{ + using difference_type = typename It::difference_type; + using value_type = typename It::value_type; + using pointer = typename It::pointer; + using reference = typename It::reference; + using iterator_category = typename It::iterator_category; +}; + +// This is required as some compilers implement std::iterator_traits in a way that +// doesn't work with SFINAE. See https://github.com/nlohmann/json/issues/1341. +template +struct iterator_traits +{ +}; + +template +struct iterator_traits < T, enable_if_t < !std::is_pointer::value >> + : iterator_types +{ +}; + +template +struct iterator_traits::value>> +{ + using iterator_category = std::random_access_iterator_tag; + using value_type = T; + using difference_type = ptrdiff_t; + using pointer = T*; + using reference = T&; +}; +} // namespace detail +} // namespace nlohmann + +// #include + +// #include + +// #include + + +#include + +// #include + + +// https://en.cppreference.com/w/cpp/experimental/is_detected +namespace nlohmann +{ +namespace detail +{ +struct nonesuch +{ + nonesuch() = delete; + ~nonesuch() = delete; + nonesuch(nonesuch const&) = delete; + nonesuch(nonesuch const&&) = delete; + void operator=(nonesuch const&) = delete; + void operator=(nonesuch&&) = delete; +}; + +template class Op, + class... Args> +struct detector +{ + using value_t = std::false_type; + using type = Default; +}; + +template class Op, class... Args> +struct detector>, Op, Args...> +{ + using value_t = std::true_type; + using type = Op; +}; + +template class Op, class... Args> +using is_detected = typename detector::value_t; + +template class Op, class... Args> +using detected_t = typename detector::type; + +template class Op, class... Args> +using detected_or = detector; + +template class Op, class... Args> +using detected_or_t = typename detected_or::type; + +template class Op, class... Args> +using is_detected_exact = std::is_same>; + +template class Op, class... Args> +using is_detected_convertible = + std::is_convertible, To>; +} // namespace detail +} // namespace nlohmann + +// #include +#ifndef INCLUDE_NLOHMANN_JSON_FWD_HPP_ +#define INCLUDE_NLOHMANN_JSON_FWD_HPP_ + +#include // int64_t, uint64_t +#include // map +#include // allocator +#include // string +#include // vector + +/*! +@brief namespace for Niels Lohmann +@see https://github.com/nlohmann +@since version 1.0.0 +*/ +namespace nlohmann +{ +/*! +@brief default JSONSerializer template argument + +This serializer ignores the template arguments and uses ADL +([argument-dependent lookup](https://en.cppreference.com/w/cpp/language/adl)) +for serialization. +*/ +template +struct adl_serializer; + +template class ObjectType = + std::map, + template class ArrayType = std::vector, + class StringType = std::string, class BooleanType = bool, + class NumberIntegerType = std::int64_t, + class NumberUnsignedType = std::uint64_t, + class NumberFloatType = double, + template class AllocatorType = std::allocator, + template class JSONSerializer = + adl_serializer, + class BinaryType = std::vector> +class basic_json; + +/*! +@brief JSON Pointer + +A JSON pointer defines a string syntax for identifying a specific value +within a JSON document. It can be used with functions `at` and +`operator[]`. Furthermore, JSON pointers are the base for JSON patches. + +@sa [RFC 6901](https://tools.ietf.org/html/rfc6901) + +@since version 2.0.0 +*/ +template +class json_pointer; + +/*! +@brief default JSON class + +This type is the default specialization of the @ref basic_json class which +uses the standard template types. + +@since version 1.0.0 +*/ +using json = basic_json<>; + +template +struct ordered_map; + +/*! +@brief ordered JSON class + +This type preserves the insertion order of object keys. + +@since version 3.9.0 +*/ +using ordered_json = basic_json; + +} // namespace nlohmann + +#endif // INCLUDE_NLOHMANN_JSON_FWD_HPP_ + + +namespace nlohmann +{ +/*! +@brief detail namespace with internal helper functions + +This namespace collects functions that should not be exposed, +implementations of some @ref basic_json methods, and meta-programming helpers. + +@since version 2.1.0 +*/ +namespace detail +{ +///////////// +// helpers // +///////////// + +// Note to maintainers: +// +// Every trait in this file expects a non CV-qualified type. +// The only exceptions are in the 'aliases for detected' section +// (i.e. those of the form: decltype(T::member_function(std::declval()))) +// +// In this case, T has to be properly CV-qualified to constraint the function arguments +// (e.g. to_json(BasicJsonType&, const T&)) + +template struct is_basic_json : std::false_type {}; + +NLOHMANN_BASIC_JSON_TPL_DECLARATION +struct is_basic_json : std::true_type {}; + +////////////////////// +// json_ref helpers // +////////////////////// + +template +class json_ref; + +template +struct is_json_ref : std::false_type {}; + +template +struct is_json_ref> : std::true_type {}; + +////////////////////////// +// aliases for detected // +////////////////////////// + +template +using mapped_type_t = typename T::mapped_type; + +template +using key_type_t = typename T::key_type; + +template +using value_type_t = typename T::value_type; + +template +using difference_type_t = typename T::difference_type; + +template +using pointer_t = typename T::pointer; + +template +using reference_t = typename T::reference; + +template +using iterator_category_t = typename T::iterator_category; + +template +using iterator_t = typename T::iterator; + +template +using to_json_function = decltype(T::to_json(std::declval()...)); + +template +using from_json_function = decltype(T::from_json(std::declval()...)); + +template +using get_template_function = decltype(std::declval().template get()); + +// trait checking if JSONSerializer::from_json(json const&, udt&) exists +template +struct has_from_json : std::false_type {}; + +// trait checking if j.get is valid +// use this trait instead of std::is_constructible or std::is_convertible, +// both rely on, or make use of implicit conversions, and thus fail when T +// has several constructors/operator= (see https://github.com/nlohmann/json/issues/958) +template +struct is_getable +{ + static constexpr bool value = is_detected::value; +}; + +template +struct has_from_json < BasicJsonType, T, + enable_if_t < !is_basic_json::value >> +{ + using serializer = typename BasicJsonType::template json_serializer; + + static constexpr bool value = + is_detected_exact::value; +}; + +// This trait checks if JSONSerializer::from_json(json const&) exists +// this overload is used for non-default-constructible user-defined-types +template +struct has_non_default_from_json : std::false_type {}; + +template +struct has_non_default_from_json < BasicJsonType, T, enable_if_t < !is_basic_json::value >> +{ + using serializer = typename BasicJsonType::template json_serializer; + + static constexpr bool value = + is_detected_exact::value; +}; + +// This trait checks if BasicJsonType::json_serializer::to_json exists +// Do not evaluate the trait when T is a basic_json type, to avoid template instantiation infinite recursion. +template +struct has_to_json : std::false_type {}; + +template +struct has_to_json < BasicJsonType, T, enable_if_t < !is_basic_json::value >> +{ + using serializer = typename BasicJsonType::template json_serializer; + + static constexpr bool value = + is_detected_exact::value; +}; + + +/////////////////// +// is_ functions // +/////////////////// + +template +struct is_iterator_traits : std::false_type {}; + +template +struct is_iterator_traits> +{ + private: + using traits = iterator_traits; + + public: + static constexpr auto value = + is_detected::value && + is_detected::value && + is_detected::value && + is_detected::value && + is_detected::value; +}; + +// source: https://stackoverflow.com/a/37193089/4116453 + +template +struct is_complete_type : std::false_type {}; + +template +struct is_complete_type : std::true_type {}; + +template +struct is_compatible_object_type_impl : std::false_type {}; + +template +struct is_compatible_object_type_impl < + BasicJsonType, CompatibleObjectType, + enable_if_t < is_detected::value&& + is_detected::value >> +{ + + using object_t = typename BasicJsonType::object_t; + + // macOS's is_constructible does not play well with nonesuch... + static constexpr bool value = + std::is_constructible::value && + std::is_constructible::value; +}; + +template +struct is_compatible_object_type + : is_compatible_object_type_impl {}; + +template +struct is_constructible_object_type_impl : std::false_type {}; + +template +struct is_constructible_object_type_impl < + BasicJsonType, ConstructibleObjectType, + enable_if_t < is_detected::value&& + is_detected::value >> +{ + using object_t = typename BasicJsonType::object_t; + + static constexpr bool value = + (std::is_default_constructible::value && + (std::is_move_assignable::value || + std::is_copy_assignable::value) && + (std::is_constructible::value && + std::is_same < + typename object_t::mapped_type, + typename ConstructibleObjectType::mapped_type >::value)) || + (has_from_json::value || + has_non_default_from_json < + BasicJsonType, + typename ConstructibleObjectType::mapped_type >::value); +}; + +template +struct is_constructible_object_type + : is_constructible_object_type_impl {}; + +template +struct is_compatible_string_type_impl : std::false_type {}; + +template +struct is_compatible_string_type_impl < + BasicJsonType, CompatibleStringType, + enable_if_t::value >> +{ + static constexpr auto value = + std::is_constructible::value; +}; + +template +struct is_compatible_string_type + : is_compatible_string_type_impl {}; + +template +struct is_constructible_string_type_impl : std::false_type {}; + +template +struct is_constructible_string_type_impl < + BasicJsonType, ConstructibleStringType, + enable_if_t::value >> +{ + static constexpr auto value = + std::is_constructible::value; +}; + +template +struct is_constructible_string_type + : is_constructible_string_type_impl {}; + +template +struct is_compatible_array_type_impl : std::false_type {}; + +template +struct is_compatible_array_type_impl < + BasicJsonType, CompatibleArrayType, + enable_if_t < is_detected::value&& + is_detected::value&& +// This is needed because json_reverse_iterator has a ::iterator type... +// Therefore it is detected as a CompatibleArrayType. +// The real fix would be to have an Iterable concept. + !is_iterator_traits < + iterator_traits>::value >> +{ + static constexpr bool value = + std::is_constructible::value; +}; + +template +struct is_compatible_array_type + : is_compatible_array_type_impl {}; + +template +struct is_constructible_array_type_impl : std::false_type {}; + +template +struct is_constructible_array_type_impl < + BasicJsonType, ConstructibleArrayType, + enable_if_t::value >> + : std::true_type {}; + +template +struct is_constructible_array_type_impl < + BasicJsonType, ConstructibleArrayType, + enable_if_t < !std::is_same::value&& + std::is_default_constructible::value&& +(std::is_move_assignable::value || + std::is_copy_assignable::value)&& +is_detected::value&& +is_detected::value&& +is_complete_type < +detected_t>::value >> +{ + static constexpr bool value = + // This is needed because json_reverse_iterator has a ::iterator type, + // furthermore, std::back_insert_iterator (and other iterators) have a + // base class `iterator`... Therefore it is detected as a + // ConstructibleArrayType. The real fix would be to have an Iterable + // concept. + !is_iterator_traits>::value && + + (std::is_same::value || + has_from_json::value || + has_non_default_from_json < + BasicJsonType, typename ConstructibleArrayType::value_type >::value); +}; + +template +struct is_constructible_array_type + : is_constructible_array_type_impl {}; + +template +struct is_compatible_integer_type_impl : std::false_type {}; + +template +struct is_compatible_integer_type_impl < + RealIntegerType, CompatibleNumberIntegerType, + enable_if_t < std::is_integral::value&& + std::is_integral::value&& + !std::is_same::value >> +{ + // is there an assert somewhere on overflows? + using RealLimits = std::numeric_limits; + using CompatibleLimits = std::numeric_limits; + + static constexpr auto value = + std::is_constructible::value && + CompatibleLimits::is_integer && + RealLimits::is_signed == CompatibleLimits::is_signed; +}; + +template +struct is_compatible_integer_type + : is_compatible_integer_type_impl {}; + +template +struct is_compatible_type_impl: std::false_type {}; + +template +struct is_compatible_type_impl < + BasicJsonType, CompatibleType, + enable_if_t::value >> +{ + static constexpr bool value = + has_to_json::value; +}; + +template +struct is_compatible_type + : is_compatible_type_impl {}; + +// https://en.cppreference.com/w/cpp/types/conjunction +template struct conjunction : std::true_type { }; +template struct conjunction : B1 { }; +template +struct conjunction +: std::conditional, B1>::type {}; + +template +struct is_constructible_tuple : std::false_type {}; + +template +struct is_constructible_tuple> : conjunction...> {}; +} // namespace detail +} // namespace nlohmann + +// #include + + +#include // array +#include // size_t +#include // uint8_t +#include // string + +namespace nlohmann +{ +namespace detail +{ +/////////////////////////// +// JSON type enumeration // +/////////////////////////// + +/*! +@brief the JSON type enumeration + +This enumeration collects the different JSON types. It is internally used to +distinguish the stored values, and the functions @ref basic_json::is_null(), +@ref basic_json::is_object(), @ref basic_json::is_array(), +@ref basic_json::is_string(), @ref basic_json::is_boolean(), +@ref basic_json::is_number() (with @ref basic_json::is_number_integer(), +@ref basic_json::is_number_unsigned(), and @ref basic_json::is_number_float()), +@ref basic_json::is_discarded(), @ref basic_json::is_primitive(), and +@ref basic_json::is_structured() rely on it. + +@note There are three enumeration entries (number_integer, number_unsigned, and +number_float), because the library distinguishes these three types for numbers: +@ref basic_json::number_unsigned_t is used for unsigned integers, +@ref basic_json::number_integer_t is used for signed integers, and +@ref basic_json::number_float_t is used for floating-point numbers or to +approximate integers which do not fit in the limits of their respective type. + +@sa @ref basic_json::basic_json(const value_t value_type) -- create a JSON +value with the default value for a given type + +@since version 1.0.0 +*/ +enum class value_t : std::uint8_t +{ + null, ///< null value + object, ///< object (unordered set of name/value pairs) + array, ///< array (ordered collection of values) + string, ///< string value + boolean, ///< boolean value + number_integer, ///< number value (signed integer) + number_unsigned, ///< number value (unsigned integer) + number_float, ///< number value (floating-point) + binary, ///< binary array (ordered collection of bytes) + discarded ///< discarded by the parser callback function +}; + +/*! +@brief comparison operator for JSON types + +Returns an ordering that is similar to Python: +- order: null < boolean < number < object < array < string < binary +- furthermore, each type is not smaller than itself +- discarded values are not comparable +- binary is represented as a b"" string in python and directly comparable to a + string; however, making a binary array directly comparable with a string would + be surprising behavior in a JSON file. + +@since version 1.0.0 +*/ +inline bool operator<(const value_t lhs, const value_t rhs) noexcept +{ + static constexpr std::array order = {{ + 0 /* null */, 3 /* object */, 4 /* array */, 5 /* string */, + 1 /* boolean */, 2 /* integer */, 2 /* unsigned */, 2 /* float */, + 6 /* binary */ + } + }; + + const auto l_index = static_cast(lhs); + const auto r_index = static_cast(rhs); + return l_index < order.size() && r_index < order.size() && order[l_index] < order[r_index]; +} +} // namespace detail +} // namespace nlohmann + + +namespace nlohmann +{ +namespace detail +{ +template +void from_json(const BasicJsonType& j, typename std::nullptr_t& n) +{ + if (JSON_HEDLEY_UNLIKELY(!j.is_null())) + { + JSON_THROW(type_error::create(302, "type must be null, but is " + std::string(j.type_name()))); + } + n = nullptr; +} + +// overloads for basic_json template parameters +template < typename BasicJsonType, typename ArithmeticType, + enable_if_t < std::is_arithmetic::value&& + !std::is_same::value, + int > = 0 > +void get_arithmetic_value(const BasicJsonType& j, ArithmeticType& val) +{ + switch (static_cast(j)) + { + case value_t::number_unsigned: + { + val = static_cast(*j.template get_ptr()); + break; + } + case value_t::number_integer: + { + val = static_cast(*j.template get_ptr()); + break; + } + case value_t::number_float: + { + val = static_cast(*j.template get_ptr()); + break; + } + + default: + JSON_THROW(type_error::create(302, "type must be number, but is " + std::string(j.type_name()))); + } +} + +template +void from_json(const BasicJsonType& j, typename BasicJsonType::boolean_t& b) +{ + if (JSON_HEDLEY_UNLIKELY(!j.is_boolean())) + { + JSON_THROW(type_error::create(302, "type must be boolean, but is " + std::string(j.type_name()))); + } + b = *j.template get_ptr(); +} + +template +void from_json(const BasicJsonType& j, typename BasicJsonType::string_t& s) +{ + if (JSON_HEDLEY_UNLIKELY(!j.is_string())) + { + JSON_THROW(type_error::create(302, "type must be string, but is " + std::string(j.type_name()))); + } + s = *j.template get_ptr(); +} + +template < + typename BasicJsonType, typename ConstructibleStringType, + enable_if_t < + is_constructible_string_type::value&& + !std::is_same::value, + int > = 0 > +void from_json(const BasicJsonType& j, ConstructibleStringType& s) +{ + if (JSON_HEDLEY_UNLIKELY(!j.is_string())) + { + JSON_THROW(type_error::create(302, "type must be string, but is " + std::string(j.type_name()))); + } + + s = *j.template get_ptr(); +} + +template +void from_json(const BasicJsonType& j, typename BasicJsonType::number_float_t& val) +{ + get_arithmetic_value(j, val); +} + +template +void from_json(const BasicJsonType& j, typename BasicJsonType::number_unsigned_t& val) +{ + get_arithmetic_value(j, val); +} + +template +void from_json(const BasicJsonType& j, typename BasicJsonType::number_integer_t& val) +{ + get_arithmetic_value(j, val); +} + +template::value, int> = 0> +void from_json(const BasicJsonType& j, EnumType& e) +{ + typename std::underlying_type::type val; + get_arithmetic_value(j, val); + e = static_cast(val); +} + +// forward_list doesn't have an insert method +template::value, int> = 0> +void from_json(const BasicJsonType& j, std::forward_list& l) +{ + if (JSON_HEDLEY_UNLIKELY(!j.is_array())) + { + JSON_THROW(type_error::create(302, "type must be array, but is " + std::string(j.type_name()))); + } + l.clear(); + std::transform(j.rbegin(), j.rend(), + std::front_inserter(l), [](const BasicJsonType & i) + { + return i.template get(); + }); +} + +// valarray doesn't have an insert method +template::value, int> = 0> +void from_json(const BasicJsonType& j, std::valarray& l) +{ + if (JSON_HEDLEY_UNLIKELY(!j.is_array())) + { + JSON_THROW(type_error::create(302, "type must be array, but is " + std::string(j.type_name()))); + } + l.resize(j.size()); + std::transform(j.begin(), j.end(), std::begin(l), + [](const BasicJsonType & elem) + { + return elem.template get(); + }); +} + +template +auto from_json(const BasicJsonType& j, T (&arr)[N]) +-> decltype(j.template get(), void()) +{ + for (std::size_t i = 0; i < N; ++i) + { + arr[i] = j.at(i).template get(); + } +} + +template +void from_json_array_impl(const BasicJsonType& j, typename BasicJsonType::array_t& arr, priority_tag<3> /*unused*/) +{ + arr = *j.template get_ptr(); +} + +template +auto from_json_array_impl(const BasicJsonType& j, std::array& arr, + priority_tag<2> /*unused*/) +-> decltype(j.template get(), void()) +{ + for (std::size_t i = 0; i < N; ++i) + { + arr[i] = j.at(i).template get(); + } +} + +template +auto from_json_array_impl(const BasicJsonType& j, ConstructibleArrayType& arr, priority_tag<1> /*unused*/) +-> decltype( + arr.reserve(std::declval()), + j.template get(), + void()) +{ + using std::end; + + ConstructibleArrayType ret; + ret.reserve(j.size()); + std::transform(j.begin(), j.end(), + std::inserter(ret, end(ret)), [](const BasicJsonType & i) + { + // get() returns *this, this won't call a from_json + // method when value_type is BasicJsonType + return i.template get(); + }); + arr = std::move(ret); +} + +template +void from_json_array_impl(const BasicJsonType& j, ConstructibleArrayType& arr, + priority_tag<0> /*unused*/) +{ + using std::end; + + ConstructibleArrayType ret; + std::transform( + j.begin(), j.end(), std::inserter(ret, end(ret)), + [](const BasicJsonType & i) + { + // get() returns *this, this won't call a from_json + // method when value_type is BasicJsonType + return i.template get(); + }); + arr = std::move(ret); +} + +template < typename BasicJsonType, typename ConstructibleArrayType, + enable_if_t < + is_constructible_array_type::value&& + !is_constructible_object_type::value&& + !is_constructible_string_type::value&& + !std::is_same::value&& + !is_basic_json::value, + int > = 0 > +auto from_json(const BasicJsonType& j, ConstructibleArrayType& arr) +-> decltype(from_json_array_impl(j, arr, priority_tag<3> {}), +j.template get(), +void()) +{ + if (JSON_HEDLEY_UNLIKELY(!j.is_array())) + { + JSON_THROW(type_error::create(302, "type must be array, but is " + + std::string(j.type_name()))); + } + + from_json_array_impl(j, arr, priority_tag<3> {}); +} + +template +void from_json(const BasicJsonType& j, typename BasicJsonType::binary_t& bin) +{ + if (JSON_HEDLEY_UNLIKELY(!j.is_binary())) + { + JSON_THROW(type_error::create(302, "type must be binary, but is " + std::string(j.type_name()))); + } + + bin = *j.template get_ptr(); +} + +template::value, int> = 0> +void from_json(const BasicJsonType& j, ConstructibleObjectType& obj) +{ + if (JSON_HEDLEY_UNLIKELY(!j.is_object())) + { + JSON_THROW(type_error::create(302, "type must be object, but is " + std::string(j.type_name()))); + } + + ConstructibleObjectType ret; + auto inner_object = j.template get_ptr(); + using value_type = typename ConstructibleObjectType::value_type; + std::transform( + inner_object->begin(), inner_object->end(), + std::inserter(ret, ret.begin()), + [](typename BasicJsonType::object_t::value_type const & p) + { + return value_type(p.first, p.second.template get()); + }); + obj = std::move(ret); +} + +// overload for arithmetic types, not chosen for basic_json template arguments +// (BooleanType, etc..); note: Is it really necessary to provide explicit +// overloads for boolean_t etc. in case of a custom BooleanType which is not +// an arithmetic type? +template < typename BasicJsonType, typename ArithmeticType, + enable_if_t < + std::is_arithmetic::value&& + !std::is_same::value&& + !std::is_same::value&& + !std::is_same::value&& + !std::is_same::value, + int > = 0 > +void from_json(const BasicJsonType& j, ArithmeticType& val) +{ + switch (static_cast(j)) + { + case value_t::number_unsigned: + { + val = static_cast(*j.template get_ptr()); + break; + } + case value_t::number_integer: + { + val = static_cast(*j.template get_ptr()); + break; + } + case value_t::number_float: + { + val = static_cast(*j.template get_ptr()); + break; + } + case value_t::boolean: + { + val = static_cast(*j.template get_ptr()); + break; + } + + default: + JSON_THROW(type_error::create(302, "type must be number, but is " + std::string(j.type_name()))); + } +} + +template +void from_json(const BasicJsonType& j, std::pair& p) +{ + p = {j.at(0).template get(), j.at(1).template get()}; +} + +template +void from_json_tuple_impl(const BasicJsonType& j, Tuple& t, index_sequence /*unused*/) +{ + t = std::make_tuple(j.at(Idx).template get::type>()...); +} + +template +void from_json(const BasicJsonType& j, std::tuple& t) +{ + from_json_tuple_impl(j, t, index_sequence_for {}); +} + +template < typename BasicJsonType, typename Key, typename Value, typename Compare, typename Allocator, + typename = enable_if_t < !std::is_constructible < + typename BasicJsonType::string_t, Key >::value >> +void from_json(const BasicJsonType& j, std::map& m) +{ + if (JSON_HEDLEY_UNLIKELY(!j.is_array())) + { + JSON_THROW(type_error::create(302, "type must be array, but is " + std::string(j.type_name()))); + } + m.clear(); + for (const auto& p : j) + { + if (JSON_HEDLEY_UNLIKELY(!p.is_array())) + { + JSON_THROW(type_error::create(302, "type must be array, but is " + std::string(p.type_name()))); + } + m.emplace(p.at(0).template get(), p.at(1).template get()); + } +} + +template < typename BasicJsonType, typename Key, typename Value, typename Hash, typename KeyEqual, typename Allocator, + typename = enable_if_t < !std::is_constructible < + typename BasicJsonType::string_t, Key >::value >> +void from_json(const BasicJsonType& j, std::unordered_map& m) +{ + if (JSON_HEDLEY_UNLIKELY(!j.is_array())) + { + JSON_THROW(type_error::create(302, "type must be array, but is " + std::string(j.type_name()))); + } + m.clear(); + for (const auto& p : j) + { + if (JSON_HEDLEY_UNLIKELY(!p.is_array())) + { + JSON_THROW(type_error::create(302, "type must be array, but is " + std::string(p.type_name()))); + } + m.emplace(p.at(0).template get(), p.at(1).template get()); + } +} + +struct from_json_fn +{ + template + auto operator()(const BasicJsonType& j, T& val) const + noexcept(noexcept(from_json(j, val))) + -> decltype(from_json(j, val), void()) + { + return from_json(j, val); + } +}; +} // namespace detail + +/// namespace to hold default `from_json` function +/// to see why this is required: +/// http://www.open-std.org/jtc1/sc22/wg21/docs/papers/2015/n4381.html +namespace +{ +constexpr const auto& from_json = detail::static_const::value; +} // namespace +} // namespace nlohmann + +// #include + + +#include // copy +#include // begin, end +#include // string +#include // tuple, get +#include // is_same, is_constructible, is_floating_point, is_enum, underlying_type +#include // move, forward, declval, pair +#include // valarray +#include // vector + +// #include + + +#include // size_t +#include // input_iterator_tag +#include // string, to_string +#include // tuple_size, get, tuple_element + +// #include + +// #include + + +namespace nlohmann +{ +namespace detail +{ +template +void int_to_string( string_type& target, std::size_t value ) +{ + // For ADL + using std::to_string; + target = to_string(value); +} +template class iteration_proxy_value +{ + public: + using difference_type = std::ptrdiff_t; + using value_type = iteration_proxy_value; + using pointer = value_type * ; + using reference = value_type & ; + using iterator_category = std::input_iterator_tag; + using string_type = typename std::remove_cv< typename std::remove_reference().key() ) >::type >::type; + + private: + /// the iterator + IteratorType anchor; + /// an index for arrays (used to create key names) + std::size_t array_index = 0; + /// last stringified array index + mutable std::size_t array_index_last = 0; + /// a string representation of the array index + mutable string_type array_index_str = "0"; + /// an empty string (to return a reference for primitive values) + const string_type empty_str = ""; + + public: + explicit iteration_proxy_value(IteratorType it) noexcept : anchor(it) {} + + /// dereference operator (needed for range-based for) + iteration_proxy_value& operator*() + { + return *this; + } + + /// increment operator (needed for range-based for) + iteration_proxy_value& operator++() + { + ++anchor; + ++array_index; + + return *this; + } + + /// equality operator (needed for InputIterator) + bool operator==(const iteration_proxy_value& o) const + { + return anchor == o.anchor; + } + + /// inequality operator (needed for range-based for) + bool operator!=(const iteration_proxy_value& o) const + { + return anchor != o.anchor; + } + + /// return key of the iterator + const string_type& key() const + { + JSON_ASSERT(anchor.m_object != nullptr); + + switch (anchor.m_object->type()) + { + // use integer array index as key + case value_t::array: + { + if (array_index != array_index_last) + { + int_to_string( array_index_str, array_index ); + array_index_last = array_index; + } + return array_index_str; + } + + // use key from the object + case value_t::object: + return anchor.key(); + + // use an empty key for all primitive types + default: + return empty_str; + } + } + + /// return value of the iterator + typename IteratorType::reference value() const + { + return anchor.value(); + } +}; + +/// proxy class for the items() function +template class iteration_proxy +{ + private: + /// the container to iterate + typename IteratorType::reference container; + + public: + /// construct iteration proxy from a container + explicit iteration_proxy(typename IteratorType::reference cont) noexcept + : container(cont) {} + + /// return iterator begin (needed for range-based for) + iteration_proxy_value begin() noexcept + { + return iteration_proxy_value(container.begin()); + } + + /// return iterator end (needed for range-based for) + iteration_proxy_value end() noexcept + { + return iteration_proxy_value(container.end()); + } +}; +// Structured Bindings Support +// For further reference see https://blog.tartanllama.xyz/structured-bindings/ +// And see https://github.com/nlohmann/json/pull/1391 +template = 0> +auto get(const nlohmann::detail::iteration_proxy_value& i) -> decltype(i.key()) +{ + return i.key(); +} +// Structured Bindings Support +// For further reference see https://blog.tartanllama.xyz/structured-bindings/ +// And see https://github.com/nlohmann/json/pull/1391 +template = 0> +auto get(const nlohmann::detail::iteration_proxy_value& i) -> decltype(i.value()) +{ + return i.value(); +} +} // namespace detail +} // namespace nlohmann + +// The Addition to the STD Namespace is required to add +// Structured Bindings Support to the iteration_proxy_value class +// For further reference see https://blog.tartanllama.xyz/structured-bindings/ +// And see https://github.com/nlohmann/json/pull/1391 +namespace std +{ +#if defined(__clang__) + // Fix: https://github.com/nlohmann/json/issues/1401 + #pragma clang diagnostic push + #pragma clang diagnostic ignored "-Wmismatched-tags" +#endif +template +class tuple_size<::nlohmann::detail::iteration_proxy_value> + : public std::integral_constant {}; + +template +class tuple_element> +{ + public: + using type = decltype( + get(std::declval < + ::nlohmann::detail::iteration_proxy_value> ())); +}; +#if defined(__clang__) + #pragma clang diagnostic pop +#endif +} // namespace std + +// #include + +// #include + +// #include + + +namespace nlohmann +{ +namespace detail +{ +////////////////// +// constructors // +////////////////// + +template struct external_constructor; + +template<> +struct external_constructor +{ + template + static void construct(BasicJsonType& j, typename BasicJsonType::boolean_t b) noexcept + { + j.m_type = value_t::boolean; + j.m_value = b; + j.assert_invariant(); + } +}; + +template<> +struct external_constructor +{ + template + static void construct(BasicJsonType& j, const typename BasicJsonType::string_t& s) + { + j.m_type = value_t::string; + j.m_value = s; + j.assert_invariant(); + } + + template + static void construct(BasicJsonType& j, typename BasicJsonType::string_t&& s) + { + j.m_type = value_t::string; + j.m_value = std::move(s); + j.assert_invariant(); + } + + template < typename BasicJsonType, typename CompatibleStringType, + enable_if_t < !std::is_same::value, + int > = 0 > + static void construct(BasicJsonType& j, const CompatibleStringType& str) + { + j.m_type = value_t::string; + j.m_value.string = j.template create(str); + j.assert_invariant(); + } +}; + +template<> +struct external_constructor +{ + template + static void construct(BasicJsonType& j, const typename BasicJsonType::binary_t& b) + { + j.m_type = value_t::binary; + typename BasicJsonType::binary_t value{b}; + j.m_value = value; + j.assert_invariant(); + } + + template + static void construct(BasicJsonType& j, typename BasicJsonType::binary_t&& b) + { + j.m_type = value_t::binary; + typename BasicJsonType::binary_t value{std::move(b)}; + j.m_value = value; + j.assert_invariant(); + } +}; + +template<> +struct external_constructor +{ + template + static void construct(BasicJsonType& j, typename BasicJsonType::number_float_t val) noexcept + { + j.m_type = value_t::number_float; + j.m_value = val; + j.assert_invariant(); + } +}; + +template<> +struct external_constructor +{ + template + static void construct(BasicJsonType& j, typename BasicJsonType::number_unsigned_t val) noexcept + { + j.m_type = value_t::number_unsigned; + j.m_value = val; + j.assert_invariant(); + } +}; + +template<> +struct external_constructor +{ + template + static void construct(BasicJsonType& j, typename BasicJsonType::number_integer_t val) noexcept + { + j.m_type = value_t::number_integer; + j.m_value = val; + j.assert_invariant(); + } +}; + +template<> +struct external_constructor +{ + template + static void construct(BasicJsonType& j, const typename BasicJsonType::array_t& arr) + { + j.m_type = value_t::array; + j.m_value = arr; + j.assert_invariant(); + } + + template + static void construct(BasicJsonType& j, typename BasicJsonType::array_t&& arr) + { + j.m_type = value_t::array; + j.m_value = std::move(arr); + j.assert_invariant(); + } + + template < typename BasicJsonType, typename CompatibleArrayType, + enable_if_t < !std::is_same::value, + int > = 0 > + static void construct(BasicJsonType& j, const CompatibleArrayType& arr) + { + using std::begin; + using std::end; + j.m_type = value_t::array; + j.m_value.array = j.template create(begin(arr), end(arr)); + j.assert_invariant(); + } + + template + static void construct(BasicJsonType& j, const std::vector& arr) + { + j.m_type = value_t::array; + j.m_value = value_t::array; + j.m_value.array->reserve(arr.size()); + for (const bool x : arr) + { + j.m_value.array->push_back(x); + } + j.assert_invariant(); + } + + template::value, int> = 0> + static void construct(BasicJsonType& j, const std::valarray& arr) + { + j.m_type = value_t::array; + j.m_value = value_t::array; + j.m_value.array->resize(arr.size()); + if (arr.size() > 0) + { + std::copy(std::begin(arr), std::end(arr), j.m_value.array->begin()); + } + j.assert_invariant(); + } +}; + +template<> +struct external_constructor +{ + template + static void construct(BasicJsonType& j, const typename BasicJsonType::object_t& obj) + { + j.m_type = value_t::object; + j.m_value = obj; + j.assert_invariant(); + } + + template + static void construct(BasicJsonType& j, typename BasicJsonType::object_t&& obj) + { + j.m_type = value_t::object; + j.m_value = std::move(obj); + j.assert_invariant(); + } + + template < typename BasicJsonType, typename CompatibleObjectType, + enable_if_t < !std::is_same::value, int > = 0 > + static void construct(BasicJsonType& j, const CompatibleObjectType& obj) + { + using std::begin; + using std::end; + + j.m_type = value_t::object; + j.m_value.object = j.template create(begin(obj), end(obj)); + j.assert_invariant(); + } +}; + +///////////// +// to_json // +///////////// + +template::value, int> = 0> +void to_json(BasicJsonType& j, T b) noexcept +{ + external_constructor::construct(j, b); +} + +template::value, int> = 0> +void to_json(BasicJsonType& j, const CompatibleString& s) +{ + external_constructor::construct(j, s); +} + +template +void to_json(BasicJsonType& j, typename BasicJsonType::string_t&& s) +{ + external_constructor::construct(j, std::move(s)); +} + +template::value, int> = 0> +void to_json(BasicJsonType& j, FloatType val) noexcept +{ + external_constructor::construct(j, static_cast(val)); +} + +template::value, int> = 0> +void to_json(BasicJsonType& j, CompatibleNumberUnsignedType val) noexcept +{ + external_constructor::construct(j, static_cast(val)); +} + +template::value, int> = 0> +void to_json(BasicJsonType& j, CompatibleNumberIntegerType val) noexcept +{ + external_constructor::construct(j, static_cast(val)); +} + +template::value, int> = 0> +void to_json(BasicJsonType& j, EnumType e) noexcept +{ + using underlying_type = typename std::underlying_type::type; + external_constructor::construct(j, static_cast(e)); +} + +template +void to_json(BasicJsonType& j, const std::vector& e) +{ + external_constructor::construct(j, e); +} + +template < typename BasicJsonType, typename CompatibleArrayType, + enable_if_t < is_compatible_array_type::value&& + !is_compatible_object_type::value&& + !is_compatible_string_type::value&& + !std::is_same::value&& + !is_basic_json::value, + int > = 0 > +void to_json(BasicJsonType& j, const CompatibleArrayType& arr) +{ + external_constructor::construct(j, arr); +} + +template +void to_json(BasicJsonType& j, const typename BasicJsonType::binary_t& bin) +{ + external_constructor::construct(j, bin); +} + +template::value, int> = 0> +void to_json(BasicJsonType& j, const std::valarray& arr) +{ + external_constructor::construct(j, std::move(arr)); +} + +template +void to_json(BasicJsonType& j, typename BasicJsonType::array_t&& arr) +{ + external_constructor::construct(j, std::move(arr)); +} + +template < typename BasicJsonType, typename CompatibleObjectType, + enable_if_t < is_compatible_object_type::value&& !is_basic_json::value, int > = 0 > +void to_json(BasicJsonType& j, const CompatibleObjectType& obj) +{ + external_constructor::construct(j, obj); +} + +template +void to_json(BasicJsonType& j, typename BasicJsonType::object_t&& obj) +{ + external_constructor::construct(j, std::move(obj)); +} + +template < + typename BasicJsonType, typename T, std::size_t N, + enable_if_t < !std::is_constructible::value, + int > = 0 > +void to_json(BasicJsonType& j, const T(&arr)[N]) +{ + external_constructor::construct(j, arr); +} + +template < typename BasicJsonType, typename T1, typename T2, enable_if_t < std::is_constructible::value&& std::is_constructible::value, int > = 0 > +void to_json(BasicJsonType& j, const std::pair& p) +{ + j = { p.first, p.second }; +} + +// for https://github.com/nlohmann/json/pull/1134 +template>::value, int> = 0> +void to_json(BasicJsonType& j, const T& b) +{ + j = { {b.key(), b.value()} }; +} + +template +void to_json_tuple_impl(BasicJsonType& j, const Tuple& t, index_sequence /*unused*/) +{ + j = { std::get(t)... }; +} + +template::value, int > = 0> +void to_json(BasicJsonType& j, const T& t) +{ + to_json_tuple_impl(j, t, make_index_sequence::value> {}); +} + +struct to_json_fn +{ + template + auto operator()(BasicJsonType& j, T&& val) const noexcept(noexcept(to_json(j, std::forward(val)))) + -> decltype(to_json(j, std::forward(val)), void()) + { + return to_json(j, std::forward(val)); + } +}; +} // namespace detail + +/// namespace to hold default `to_json` function +namespace +{ +constexpr const auto& to_json = detail::static_const::value; +} // namespace +} // namespace nlohmann + + +namespace nlohmann +{ + +template +struct adl_serializer +{ + /*! + @brief convert a JSON value to any value type + + This function is usually called by the `get()` function of the + @ref basic_json class (either explicit or via conversion operators). + + @param[in] j JSON value to read from + @param[in,out] val value to write to + */ + template + static auto from_json(BasicJsonType&& j, ValueType& val) noexcept( + noexcept(::nlohmann::from_json(std::forward(j), val))) + -> decltype(::nlohmann::from_json(std::forward(j), val), void()) + { + ::nlohmann::from_json(std::forward(j), val); + } + + /*! + @brief convert any value type to a JSON value + + This function is usually called by the constructors of the @ref basic_json + class. + + @param[in,out] j JSON value to write to + @param[in] val value to read from + */ + template + static auto to_json(BasicJsonType& j, ValueType&& val) noexcept( + noexcept(::nlohmann::to_json(j, std::forward(val)))) + -> decltype(::nlohmann::to_json(j, std::forward(val)), void()) + { + ::nlohmann::to_json(j, std::forward(val)); + } +}; + +} // namespace nlohmann + +// #include + + +#include // uint8_t +#include // tie +#include // move + +namespace nlohmann +{ + +/*! +@brief an internal type for a backed binary type + +This type extends the template parameter @a BinaryType provided to `basic_json` +with a subtype used by BSON and MessagePack. This type exists so that the user +does not have to specify a type themselves with a specific naming scheme in +order to override the binary type. + +@tparam BinaryType container to store bytes (`std::vector` by + default) + +@since version 3.8.0 +*/ +template +class byte_container_with_subtype : public BinaryType +{ + public: + /// the type of the underlying container + using container_type = BinaryType; + + byte_container_with_subtype() noexcept(noexcept(container_type())) + : container_type() + {} + + byte_container_with_subtype(const container_type& b) noexcept(noexcept(container_type(b))) + : container_type(b) + {} + + byte_container_with_subtype(container_type&& b) noexcept(noexcept(container_type(std::move(b)))) + : container_type(std::move(b)) + {} + + byte_container_with_subtype(const container_type& b, std::uint8_t subtype) noexcept(noexcept(container_type(b))) + : container_type(b) + , m_subtype(subtype) + , m_has_subtype(true) + {} + + byte_container_with_subtype(container_type&& b, std::uint8_t subtype) noexcept(noexcept(container_type(std::move(b)))) + : container_type(std::move(b)) + , m_subtype(subtype) + , m_has_subtype(true) + {} + + bool operator==(const byte_container_with_subtype& rhs) const + { + return std::tie(static_cast(*this), m_subtype, m_has_subtype) == + std::tie(static_cast(rhs), rhs.m_subtype, rhs.m_has_subtype); + } + + bool operator!=(const byte_container_with_subtype& rhs) const + { + return !(rhs == *this); + } + + /*! + @brief sets the binary subtype + + Sets the binary subtype of the value, also flags a binary JSON value as + having a subtype, which has implications for serialization. + + @complexity Constant. + + @exceptionsafety No-throw guarantee: this member function never throws + exceptions. + + @sa @ref subtype() -- return the binary subtype + @sa @ref clear_subtype() -- clears the binary subtype + @sa @ref has_subtype() -- returns whether or not the binary value has a + subtype + + @since version 3.8.0 + */ + void set_subtype(std::uint8_t subtype) noexcept + { + m_subtype = subtype; + m_has_subtype = true; + } + + /*! + @brief return the binary subtype + + Returns the numerical subtype of the value if it has a subtype. If it does + not have a subtype, this function will return size_t(-1) as a sentinel + value. + + @return the numerical subtype of the binary value + + @complexity Constant. + + @exceptionsafety No-throw guarantee: this member function never throws + exceptions. + + @sa @ref set_subtype() -- sets the binary subtype + @sa @ref clear_subtype() -- clears the binary subtype + @sa @ref has_subtype() -- returns whether or not the binary value has a + subtype + + @since version 3.8.0 + */ + constexpr std::uint8_t subtype() const noexcept + { + return m_subtype; + } + + /*! + @brief return whether the value has a subtype + + @return whether the value has a subtype + + @complexity Constant. + + @exceptionsafety No-throw guarantee: this member function never throws + exceptions. + + @sa @ref subtype() -- return the binary subtype + @sa @ref set_subtype() -- sets the binary subtype + @sa @ref clear_subtype() -- clears the binary subtype + + @since version 3.8.0 + */ + constexpr bool has_subtype() const noexcept + { + return m_has_subtype; + } + + /*! + @brief clears the binary subtype + + Clears the binary subtype and flags the value as not having a subtype, which + has implications for serialization; for instance MessagePack will prefer the + bin family over the ext family. + + @complexity Constant. + + @exceptionsafety No-throw guarantee: this member function never throws + exceptions. + + @sa @ref subtype() -- return the binary subtype + @sa @ref set_subtype() -- sets the binary subtype + @sa @ref has_subtype() -- returns whether or not the binary value has a + subtype + + @since version 3.8.0 + */ + void clear_subtype() noexcept + { + m_subtype = 0; + m_has_subtype = false; + } + + private: + std::uint8_t m_subtype = 0; + bool m_has_subtype = false; +}; + +} // namespace nlohmann + +// #include + +// #include + +// #include + +// #include + + +#include // size_t, uint8_t +#include // hash + +namespace nlohmann +{ +namespace detail +{ + +// boost::hash_combine +inline std::size_t combine(std::size_t seed, std::size_t h) noexcept +{ + seed ^= h + 0x9e3779b9 + (seed << 6U) + (seed >> 2U); + return seed; +} + +/*! +@brief hash a JSON value + +The hash function tries to rely on std::hash where possible. Furthermore, the +type of the JSON value is taken into account to have different hash values for +null, 0, 0U, and false, etc. + +@tparam BasicJsonType basic_json specialization +@param j JSON value to hash +@return hash value of j +*/ +template +std::size_t hash(const BasicJsonType& j) +{ + using string_t = typename BasicJsonType::string_t; + using number_integer_t = typename BasicJsonType::number_integer_t; + using number_unsigned_t = typename BasicJsonType::number_unsigned_t; + using number_float_t = typename BasicJsonType::number_float_t; + + const auto type = static_cast(j.type()); + switch (j.type()) + { + case BasicJsonType::value_t::null: + case BasicJsonType::value_t::discarded: + { + return combine(type, 0); + } + + case BasicJsonType::value_t::object: + { + auto seed = combine(type, j.size()); + for (const auto& element : j.items()) + { + const auto h = std::hash {}(element.key()); + seed = combine(seed, h); + seed = combine(seed, hash(element.value())); + } + return seed; + } + + case BasicJsonType::value_t::array: + { + auto seed = combine(type, j.size()); + for (const auto& element : j) + { + seed = combine(seed, hash(element)); + } + return seed; + } + + case BasicJsonType::value_t::string: + { + const auto h = std::hash {}(j.template get_ref()); + return combine(type, h); + } + + case BasicJsonType::value_t::boolean: + { + const auto h = std::hash {}(j.template get()); + return combine(type, h); + } + + case BasicJsonType::value_t::number_integer: + { + const auto h = std::hash {}(j.template get()); + return combine(type, h); + } + + case nlohmann::detail::value_t::number_unsigned: + { + const auto h = std::hash {}(j.template get()); + return combine(type, h); + } + + case nlohmann::detail::value_t::number_float: + { + const auto h = std::hash {}(j.template get()); + return combine(type, h); + } + + case nlohmann::detail::value_t::binary: + { + auto seed = combine(type, j.get_binary().size()); + const auto h = std::hash {}(j.get_binary().has_subtype()); + seed = combine(seed, h); + seed = combine(seed, j.get_binary().subtype()); + for (const auto byte : j.get_binary()) + { + seed = combine(seed, std::hash {}(byte)); + } + return seed; + } + + default: // LCOV_EXCL_LINE + JSON_ASSERT(false); // LCOV_EXCL_LINE + } +} + +} // namespace detail +} // namespace nlohmann + +// #include + + +#include // generate_n +#include // array +#include // ldexp +#include // size_t +#include // uint8_t, uint16_t, uint32_t, uint64_t +#include // snprintf +#include // memcpy +#include // back_inserter +#include // numeric_limits +#include // char_traits, string +#include // make_pair, move + +// #include + +// #include + + +#include // array +#include // size_t +#include //FILE * +#include // strlen +#include // istream +#include // begin, end, iterator_traits, random_access_iterator_tag, distance, next +#include // shared_ptr, make_shared, addressof +#include // accumulate +#include // string, char_traits +#include // enable_if, is_base_of, is_pointer, is_integral, remove_pointer +#include // pair, declval + +// #include + +// #include + + +namespace nlohmann +{ +namespace detail +{ +/// the supported input formats +enum class input_format_t { json, cbor, msgpack, ubjson, bson }; + +//////////////////// +// input adapters // +//////////////////// + +/*! +Input adapter for stdio file access. This adapter read only 1 byte and do not use any + buffer. This adapter is a very low level adapter. +*/ +class file_input_adapter +{ + public: + using char_type = char; + + JSON_HEDLEY_NON_NULL(2) + explicit file_input_adapter(std::FILE* f) noexcept + : m_file(f) + {} + + // make class move-only + file_input_adapter(const file_input_adapter&) = delete; + file_input_adapter(file_input_adapter&&) = default; + file_input_adapter& operator=(const file_input_adapter&) = delete; + file_input_adapter& operator=(file_input_adapter&&) = delete; + + std::char_traits::int_type get_character() noexcept + { + return std::fgetc(m_file); + } + + private: + /// the file pointer to read from + std::FILE* m_file; +}; + + +/*! +Input adapter for a (caching) istream. Ignores a UFT Byte Order Mark at +beginning of input. Does not support changing the underlying std::streambuf +in mid-input. Maintains underlying std::istream and std::streambuf to support +subsequent use of standard std::istream operations to process any input +characters following those used in parsing the JSON input. Clears the +std::istream flags; any input errors (e.g., EOF) will be detected by the first +subsequent call for input from the std::istream. +*/ +class input_stream_adapter +{ + public: + using char_type = char; + + ~input_stream_adapter() + { + // clear stream flags; we use underlying streambuf I/O, do not + // maintain ifstream flags, except eof + if (is != nullptr) + { + is->clear(is->rdstate() & std::ios::eofbit); + } + } + + explicit input_stream_adapter(std::istream& i) + : is(&i), sb(i.rdbuf()) + {} + + // delete because of pointer members + input_stream_adapter(const input_stream_adapter&) = delete; + input_stream_adapter& operator=(input_stream_adapter&) = delete; + input_stream_adapter& operator=(input_stream_adapter&& rhs) = delete; + + input_stream_adapter(input_stream_adapter&& rhs) noexcept : is(rhs.is), sb(rhs.sb) + { + rhs.is = nullptr; + rhs.sb = nullptr; + } + + // std::istream/std::streambuf use std::char_traits::to_int_type, to + // ensure that std::char_traits::eof() and the character 0xFF do not + // end up as the same value, eg. 0xFFFFFFFF. + std::char_traits::int_type get_character() + { + auto res = sb->sbumpc(); + // set eof manually, as we don't use the istream interface. + if (JSON_HEDLEY_UNLIKELY(res == EOF)) + { + is->clear(is->rdstate() | std::ios::eofbit); + } + return res; + } + + private: + /// the associated input stream + std::istream* is = nullptr; + std::streambuf* sb = nullptr; +}; + +// General-purpose iterator-based adapter. It might not be as fast as +// theoretically possible for some containers, but it is extremely versatile. +template +class iterator_input_adapter +{ + public: + using char_type = typename std::iterator_traits::value_type; + + iterator_input_adapter(IteratorType first, IteratorType last) + : current(std::move(first)), end(std::move(last)) {} + + typename std::char_traits::int_type get_character() + { + if (JSON_HEDLEY_LIKELY(current != end)) + { + auto result = std::char_traits::to_int_type(*current); + std::advance(current, 1); + return result; + } + else + { + return std::char_traits::eof(); + } + } + + private: + IteratorType current; + IteratorType end; + + template + friend struct wide_string_input_helper; + + bool empty() const + { + return current == end; + } + +}; + + +template +struct wide_string_input_helper; + +template +struct wide_string_input_helper +{ + // UTF-32 + static void fill_buffer(BaseInputAdapter& input, + std::array::int_type, 4>& utf8_bytes, + size_t& utf8_bytes_index, + size_t& utf8_bytes_filled) + { + utf8_bytes_index = 0; + + if (JSON_HEDLEY_UNLIKELY(input.empty())) + { + utf8_bytes[0] = std::char_traits::eof(); + utf8_bytes_filled = 1; + } + else + { + // get the current character + const auto wc = input.get_character(); + + // UTF-32 to UTF-8 encoding + if (wc < 0x80) + { + utf8_bytes[0] = static_cast::int_type>(wc); + utf8_bytes_filled = 1; + } + else if (wc <= 0x7FF) + { + utf8_bytes[0] = static_cast::int_type>(0xC0u | ((static_cast(wc) >> 6u) & 0x1Fu)); + utf8_bytes[1] = static_cast::int_type>(0x80u | (static_cast(wc) & 0x3Fu)); + utf8_bytes_filled = 2; + } + else if (wc <= 0xFFFF) + { + utf8_bytes[0] = static_cast::int_type>(0xE0u | ((static_cast(wc) >> 12u) & 0x0Fu)); + utf8_bytes[1] = static_cast::int_type>(0x80u | ((static_cast(wc) >> 6u) & 0x3Fu)); + utf8_bytes[2] = static_cast::int_type>(0x80u | (static_cast(wc) & 0x3Fu)); + utf8_bytes_filled = 3; + } + else if (wc <= 0x10FFFF) + { + utf8_bytes[0] = static_cast::int_type>(0xF0u | ((static_cast(wc) >> 18u) & 0x07u)); + utf8_bytes[1] = static_cast::int_type>(0x80u | ((static_cast(wc) >> 12u) & 0x3Fu)); + utf8_bytes[2] = static_cast::int_type>(0x80u | ((static_cast(wc) >> 6u) & 0x3Fu)); + utf8_bytes[3] = static_cast::int_type>(0x80u | (static_cast(wc) & 0x3Fu)); + utf8_bytes_filled = 4; + } + else + { + // unknown character + utf8_bytes[0] = static_cast::int_type>(wc); + utf8_bytes_filled = 1; + } + } + } +}; + +template +struct wide_string_input_helper +{ + // UTF-16 + static void fill_buffer(BaseInputAdapter& input, + std::array::int_type, 4>& utf8_bytes, + size_t& utf8_bytes_index, + size_t& utf8_bytes_filled) + { + utf8_bytes_index = 0; + + if (JSON_HEDLEY_UNLIKELY(input.empty())) + { + utf8_bytes[0] = std::char_traits::eof(); + utf8_bytes_filled = 1; + } + else + { + // get the current character + const auto wc = input.get_character(); + + // UTF-16 to UTF-8 encoding + if (wc < 0x80) + { + utf8_bytes[0] = static_cast::int_type>(wc); + utf8_bytes_filled = 1; + } + else if (wc <= 0x7FF) + { + utf8_bytes[0] = static_cast::int_type>(0xC0u | ((static_cast(wc) >> 6u))); + utf8_bytes[1] = static_cast::int_type>(0x80u | (static_cast(wc) & 0x3Fu)); + utf8_bytes_filled = 2; + } + else if (0xD800 > wc || wc >= 0xE000) + { + utf8_bytes[0] = static_cast::int_type>(0xE0u | ((static_cast(wc) >> 12u))); + utf8_bytes[1] = static_cast::int_type>(0x80u | ((static_cast(wc) >> 6u) & 0x3Fu)); + utf8_bytes[2] = static_cast::int_type>(0x80u | (static_cast(wc) & 0x3Fu)); + utf8_bytes_filled = 3; + } + else + { + if (JSON_HEDLEY_UNLIKELY(!input.empty())) + { + const auto wc2 = static_cast(input.get_character()); + const auto charcode = 0x10000u + (((static_cast(wc) & 0x3FFu) << 10u) | (wc2 & 0x3FFu)); + utf8_bytes[0] = static_cast::int_type>(0xF0u | (charcode >> 18u)); + utf8_bytes[1] = static_cast::int_type>(0x80u | ((charcode >> 12u) & 0x3Fu)); + utf8_bytes[2] = static_cast::int_type>(0x80u | ((charcode >> 6u) & 0x3Fu)); + utf8_bytes[3] = static_cast::int_type>(0x80u | (charcode & 0x3Fu)); + utf8_bytes_filled = 4; + } + else + { + utf8_bytes[0] = static_cast::int_type>(wc); + utf8_bytes_filled = 1; + } + } + } + } +}; + +// Wraps another input apdater to convert wide character types into individual bytes. +template +class wide_string_input_adapter +{ + public: + using char_type = char; + + wide_string_input_adapter(BaseInputAdapter base) + : base_adapter(base) {} + + typename std::char_traits::int_type get_character() noexcept + { + // check if buffer needs to be filled + if (utf8_bytes_index == utf8_bytes_filled) + { + fill_buffer(); + + JSON_ASSERT(utf8_bytes_filled > 0); + JSON_ASSERT(utf8_bytes_index == 0); + } + + // use buffer + JSON_ASSERT(utf8_bytes_filled > 0); + JSON_ASSERT(utf8_bytes_index < utf8_bytes_filled); + return utf8_bytes[utf8_bytes_index++]; + } + + private: + BaseInputAdapter base_adapter; + + template + void fill_buffer() + { + wide_string_input_helper::fill_buffer(base_adapter, utf8_bytes, utf8_bytes_index, utf8_bytes_filled); + } + + /// a buffer for UTF-8 bytes + std::array::int_type, 4> utf8_bytes = {{0, 0, 0, 0}}; + + /// index to the utf8_codes array for the next valid byte + std::size_t utf8_bytes_index = 0; + /// number of valid bytes in the utf8_codes array + std::size_t utf8_bytes_filled = 0; +}; + + +template +struct iterator_input_adapter_factory +{ + using iterator_type = IteratorType; + using char_type = typename std::iterator_traits::value_type; + using adapter_type = iterator_input_adapter; + + static adapter_type create(IteratorType first, IteratorType last) + { + return adapter_type(std::move(first), std::move(last)); + } +}; + +template +struct is_iterator_of_multibyte +{ + using value_type = typename std::iterator_traits::value_type; + enum + { + value = sizeof(value_type) > 1 + }; +}; + +template +struct iterator_input_adapter_factory::value>> +{ + using iterator_type = IteratorType; + using char_type = typename std::iterator_traits::value_type; + using base_adapter_type = iterator_input_adapter; + using adapter_type = wide_string_input_adapter; + + static adapter_type create(IteratorType first, IteratorType last) + { + return adapter_type(base_adapter_type(std::move(first), std::move(last))); + } +}; + +// General purpose iterator-based input +template +typename iterator_input_adapter_factory::adapter_type input_adapter(IteratorType first, IteratorType last) +{ + using factory_type = iterator_input_adapter_factory; + return factory_type::create(first, last); +} + +// Convenience shorthand from container to iterator +template +auto input_adapter(const ContainerType& container) -> decltype(input_adapter(begin(container), end(container))) +{ + // Enable ADL + using std::begin; + using std::end; + + return input_adapter(begin(container), end(container)); +} + +// Special cases with fast paths +inline file_input_adapter input_adapter(std::FILE* file) +{ + return file_input_adapter(file); +} + +inline input_stream_adapter input_adapter(std::istream& stream) +{ + return input_stream_adapter(stream); +} + +inline input_stream_adapter input_adapter(std::istream&& stream) +{ + return input_stream_adapter(stream); +} + +using contiguous_bytes_input_adapter = decltype(input_adapter(std::declval(), std::declval())); + +// Null-delimited strings, and the like. +template < typename CharT, + typename std::enable_if < + std::is_pointer::value&& + !std::is_array::value&& + std::is_integral::type>::value&& + sizeof(typename std::remove_pointer::type) == 1, + int >::type = 0 > +contiguous_bytes_input_adapter input_adapter(CharT b) +{ + auto length = std::strlen(reinterpret_cast(b)); + const auto* ptr = reinterpret_cast(b); + return input_adapter(ptr, ptr + length); +} + +template +auto input_adapter(T (&array)[N]) -> decltype(input_adapter(array, array + N)) +{ + return input_adapter(array, array + N); +} + +// This class only handles inputs of input_buffer_adapter type. +// It's required so that expressions like {ptr, len} can be implicitely casted +// to the correct adapter. +class span_input_adapter +{ + public: + template < typename CharT, + typename std::enable_if < + std::is_pointer::value&& + std::is_integral::type>::value&& + sizeof(typename std::remove_pointer::type) == 1, + int >::type = 0 > + span_input_adapter(CharT b, std::size_t l) + : ia(reinterpret_cast(b), reinterpret_cast(b) + l) {} + + template::iterator_category, std::random_access_iterator_tag>::value, + int>::type = 0> + span_input_adapter(IteratorType first, IteratorType last) + : ia(input_adapter(first, last)) {} + + contiguous_bytes_input_adapter&& get() + { + return std::move(ia); + } + + private: + contiguous_bytes_input_adapter ia; +}; +} // namespace detail +} // namespace nlohmann + +// #include + + +#include +#include // string +#include // move +#include // vector + +// #include + +// #include + + +namespace nlohmann +{ + +/*! +@brief SAX interface + +This class describes the SAX interface used by @ref nlohmann::json::sax_parse. +Each function is called in different situations while the input is parsed. The +boolean return value informs the parser whether to continue processing the +input. +*/ +template +struct json_sax +{ + using number_integer_t = typename BasicJsonType::number_integer_t; + using number_unsigned_t = typename BasicJsonType::number_unsigned_t; + using number_float_t = typename BasicJsonType::number_float_t; + using string_t = typename BasicJsonType::string_t; + using binary_t = typename BasicJsonType::binary_t; + + /*! + @brief a null value was read + @return whether parsing should proceed + */ + virtual bool null() = 0; + + /*! + @brief a boolean value was read + @param[in] val boolean value + @return whether parsing should proceed + */ + virtual bool boolean(bool val) = 0; + + /*! + @brief an integer number was read + @param[in] val integer value + @return whether parsing should proceed + */ + virtual bool number_integer(number_integer_t val) = 0; + + /*! + @brief an unsigned integer number was read + @param[in] val unsigned integer value + @return whether parsing should proceed + */ + virtual bool number_unsigned(number_unsigned_t val) = 0; + + /*! + @brief an floating-point number was read + @param[in] val floating-point value + @param[in] s raw token value + @return whether parsing should proceed + */ + virtual bool number_float(number_float_t val, const string_t& s) = 0; + + /*! + @brief a string was read + @param[in] val string value + @return whether parsing should proceed + @note It is safe to move the passed string. + */ + virtual bool string(string_t& val) = 0; + + /*! + @brief a binary string was read + @param[in] val binary value + @return whether parsing should proceed + @note It is safe to move the passed binary. + */ + virtual bool binary(binary_t& val) = 0; + + /*! + @brief the beginning of an object was read + @param[in] elements number of object elements or -1 if unknown + @return whether parsing should proceed + @note binary formats may report the number of elements + */ + virtual bool start_object(std::size_t elements) = 0; + + /*! + @brief an object key was read + @param[in] val object key + @return whether parsing should proceed + @note It is safe to move the passed string. + */ + virtual bool key(string_t& val) = 0; + + /*! + @brief the end of an object was read + @return whether parsing should proceed + */ + virtual bool end_object() = 0; + + /*! + @brief the beginning of an array was read + @param[in] elements number of array elements or -1 if unknown + @return whether parsing should proceed + @note binary formats may report the number of elements + */ + virtual bool start_array(std::size_t elements) = 0; + + /*! + @brief the end of an array was read + @return whether parsing should proceed + */ + virtual bool end_array() = 0; + + /*! + @brief a parse error occurred + @param[in] position the position in the input where the error occurs + @param[in] last_token the last read token + @param[in] ex an exception object describing the error + @return whether parsing should proceed (must return false) + */ + virtual bool parse_error(std::size_t position, + const std::string& last_token, + const detail::exception& ex) = 0; + + virtual ~json_sax() = default; +}; + + +namespace detail +{ +/*! +@brief SAX implementation to create a JSON value from SAX events + +This class implements the @ref json_sax interface and processes the SAX events +to create a JSON value which makes it basically a DOM parser. The structure or +hierarchy of the JSON value is managed by the stack `ref_stack` which contains +a pointer to the respective array or object for each recursion depth. + +After successful parsing, the value that is passed by reference to the +constructor contains the parsed value. + +@tparam BasicJsonType the JSON type +*/ +template +class json_sax_dom_parser +{ + public: + using number_integer_t = typename BasicJsonType::number_integer_t; + using number_unsigned_t = typename BasicJsonType::number_unsigned_t; + using number_float_t = typename BasicJsonType::number_float_t; + using string_t = typename BasicJsonType::string_t; + using binary_t = typename BasicJsonType::binary_t; + + /*! + @param[in, out] r reference to a JSON value that is manipulated while + parsing + @param[in] allow_exceptions_ whether parse errors yield exceptions + */ + explicit json_sax_dom_parser(BasicJsonType& r, const bool allow_exceptions_ = true) + : root(r), allow_exceptions(allow_exceptions_) + {} + + // make class move-only + json_sax_dom_parser(const json_sax_dom_parser&) = delete; + json_sax_dom_parser(json_sax_dom_parser&&) = default; + json_sax_dom_parser& operator=(const json_sax_dom_parser&) = delete; + json_sax_dom_parser& operator=(json_sax_dom_parser&&) = default; + ~json_sax_dom_parser() = default; + + bool null() + { + handle_value(nullptr); + return true; + } + + bool boolean(bool val) + { + handle_value(val); + return true; + } + + bool number_integer(number_integer_t val) + { + handle_value(val); + return true; + } + + bool number_unsigned(number_unsigned_t val) + { + handle_value(val); + return true; + } + + bool number_float(number_float_t val, const string_t& /*unused*/) + { + handle_value(val); + return true; + } + + bool string(string_t& val) + { + handle_value(val); + return true; + } + + bool binary(binary_t& val) + { + handle_value(std::move(val)); + return true; + } + + bool start_object(std::size_t len) + { + ref_stack.push_back(handle_value(BasicJsonType::value_t::object)); + + if (JSON_HEDLEY_UNLIKELY(len != std::size_t(-1) && len > ref_stack.back()->max_size())) + { + JSON_THROW(out_of_range::create(408, + "excessive object size: " + std::to_string(len))); + } + + return true; + } + + bool key(string_t& val) + { + // add null at given key and store the reference for later + object_element = &(ref_stack.back()->m_value.object->operator[](val)); + return true; + } + + bool end_object() + { + ref_stack.pop_back(); + return true; + } + + bool start_array(std::size_t len) + { + ref_stack.push_back(handle_value(BasicJsonType::value_t::array)); + + if (JSON_HEDLEY_UNLIKELY(len != std::size_t(-1) && len > ref_stack.back()->max_size())) + { + JSON_THROW(out_of_range::create(408, + "excessive array size: " + std::to_string(len))); + } + + return true; + } + + bool end_array() + { + ref_stack.pop_back(); + return true; + } + + template + bool parse_error(std::size_t /*unused*/, const std::string& /*unused*/, + const Exception& ex) + { + errored = true; + static_cast(ex); + if (allow_exceptions) + { + JSON_THROW(ex); + } + return false; + } + + constexpr bool is_errored() const + { + return errored; + } + + private: + /*! + @invariant If the ref stack is empty, then the passed value will be the new + root. + @invariant If the ref stack contains a value, then it is an array or an + object to which we can add elements + */ + template + JSON_HEDLEY_RETURNS_NON_NULL + BasicJsonType* handle_value(Value&& v) + { + if (ref_stack.empty()) + { + root = BasicJsonType(std::forward(v)); + return &root; + } + + JSON_ASSERT(ref_stack.back()->is_array() || ref_stack.back()->is_object()); + + if (ref_stack.back()->is_array()) + { + ref_stack.back()->m_value.array->emplace_back(std::forward(v)); + return &(ref_stack.back()->m_value.array->back()); + } + + JSON_ASSERT(ref_stack.back()->is_object()); + JSON_ASSERT(object_element); + *object_element = BasicJsonType(std::forward(v)); + return object_element; + } + + /// the parsed JSON value + BasicJsonType& root; + /// stack to model hierarchy of values + std::vector ref_stack {}; + /// helper to hold the reference for the next object element + BasicJsonType* object_element = nullptr; + /// whether a syntax error occurred + bool errored = false; + /// whether to throw exceptions in case of errors + const bool allow_exceptions = true; +}; + +template +class json_sax_dom_callback_parser +{ + public: + using number_integer_t = typename BasicJsonType::number_integer_t; + using number_unsigned_t = typename BasicJsonType::number_unsigned_t; + using number_float_t = typename BasicJsonType::number_float_t; + using string_t = typename BasicJsonType::string_t; + using binary_t = typename BasicJsonType::binary_t; + using parser_callback_t = typename BasicJsonType::parser_callback_t; + using parse_event_t = typename BasicJsonType::parse_event_t; + + json_sax_dom_callback_parser(BasicJsonType& r, + const parser_callback_t cb, + const bool allow_exceptions_ = true) + : root(r), callback(cb), allow_exceptions(allow_exceptions_) + { + keep_stack.push_back(true); + } + + // make class move-only + json_sax_dom_callback_parser(const json_sax_dom_callback_parser&) = delete; + json_sax_dom_callback_parser(json_sax_dom_callback_parser&&) = default; + json_sax_dom_callback_parser& operator=(const json_sax_dom_callback_parser&) = delete; + json_sax_dom_callback_parser& operator=(json_sax_dom_callback_parser&&) = default; + ~json_sax_dom_callback_parser() = default; + + bool null() + { + handle_value(nullptr); + return true; + } + + bool boolean(bool val) + { + handle_value(val); + return true; + } + + bool number_integer(number_integer_t val) + { + handle_value(val); + return true; + } + + bool number_unsigned(number_unsigned_t val) + { + handle_value(val); + return true; + } + + bool number_float(number_float_t val, const string_t& /*unused*/) + { + handle_value(val); + return true; + } + + bool string(string_t& val) + { + handle_value(val); + return true; + } + + bool binary(binary_t& val) + { + handle_value(std::move(val)); + return true; + } + + bool start_object(std::size_t len) + { + // check callback for object start + const bool keep = callback(static_cast(ref_stack.size()), parse_event_t::object_start, discarded); + keep_stack.push_back(keep); + + auto val = handle_value(BasicJsonType::value_t::object, true); + ref_stack.push_back(val.second); + + // check object limit + if (ref_stack.back() && JSON_HEDLEY_UNLIKELY(len != std::size_t(-1) && len > ref_stack.back()->max_size())) + { + JSON_THROW(out_of_range::create(408, "excessive object size: " + std::to_string(len))); + } + + return true; + } + + bool key(string_t& val) + { + BasicJsonType k = BasicJsonType(val); + + // check callback for key + const bool keep = callback(static_cast(ref_stack.size()), parse_event_t::key, k); + key_keep_stack.push_back(keep); + + // add discarded value at given key and store the reference for later + if (keep && ref_stack.back()) + { + object_element = &(ref_stack.back()->m_value.object->operator[](val) = discarded); + } + + return true; + } + + bool end_object() + { + if (ref_stack.back() && !callback(static_cast(ref_stack.size()) - 1, parse_event_t::object_end, *ref_stack.back())) + { + // discard object + *ref_stack.back() = discarded; + } + + JSON_ASSERT(!ref_stack.empty()); + JSON_ASSERT(!keep_stack.empty()); + ref_stack.pop_back(); + keep_stack.pop_back(); + + if (!ref_stack.empty() && ref_stack.back() && ref_stack.back()->is_structured()) + { + // remove discarded value + for (auto it = ref_stack.back()->begin(); it != ref_stack.back()->end(); ++it) + { + if (it->is_discarded()) + { + ref_stack.back()->erase(it); + break; + } + } + } + + return true; + } + + bool start_array(std::size_t len) + { + const bool keep = callback(static_cast(ref_stack.size()), parse_event_t::array_start, discarded); + keep_stack.push_back(keep); + + auto val = handle_value(BasicJsonType::value_t::array, true); + ref_stack.push_back(val.second); + + // check array limit + if (ref_stack.back() && JSON_HEDLEY_UNLIKELY(len != std::size_t(-1) && len > ref_stack.back()->max_size())) + { + JSON_THROW(out_of_range::create(408, "excessive array size: " + std::to_string(len))); + } + + return true; + } + + bool end_array() + { + bool keep = true; + + if (ref_stack.back()) + { + keep = callback(static_cast(ref_stack.size()) - 1, parse_event_t::array_end, *ref_stack.back()); + if (!keep) + { + // discard array + *ref_stack.back() = discarded; + } + } + + JSON_ASSERT(!ref_stack.empty()); + JSON_ASSERT(!keep_stack.empty()); + ref_stack.pop_back(); + keep_stack.pop_back(); + + // remove discarded value + if (!keep && !ref_stack.empty() && ref_stack.back()->is_array()) + { + ref_stack.back()->m_value.array->pop_back(); + } + + return true; + } + + template + bool parse_error(std::size_t /*unused*/, const std::string& /*unused*/, + const Exception& ex) + { + errored = true; + static_cast(ex); + if (allow_exceptions) + { + JSON_THROW(ex); + } + return false; + } + + constexpr bool is_errored() const + { + return errored; + } + + private: + /*! + @param[in] v value to add to the JSON value we build during parsing + @param[in] skip_callback whether we should skip calling the callback + function; this is required after start_array() and + start_object() SAX events, because otherwise we would call the + callback function with an empty array or object, respectively. + + @invariant If the ref stack is empty, then the passed value will be the new + root. + @invariant If the ref stack contains a value, then it is an array or an + object to which we can add elements + + @return pair of boolean (whether value should be kept) and pointer (to the + passed value in the ref_stack hierarchy; nullptr if not kept) + */ + template + std::pair handle_value(Value&& v, const bool skip_callback = false) + { + JSON_ASSERT(!keep_stack.empty()); + + // do not handle this value if we know it would be added to a discarded + // container + if (!keep_stack.back()) + { + return {false, nullptr}; + } + + // create value + auto value = BasicJsonType(std::forward(v)); + + // check callback + const bool keep = skip_callback || callback(static_cast(ref_stack.size()), parse_event_t::value, value); + + // do not handle this value if we just learnt it shall be discarded + if (!keep) + { + return {false, nullptr}; + } + + if (ref_stack.empty()) + { + root = std::move(value); + return {true, &root}; + } + + // skip this value if we already decided to skip the parent + // (https://github.com/nlohmann/json/issues/971#issuecomment-413678360) + if (!ref_stack.back()) + { + return {false, nullptr}; + } + + // we now only expect arrays and objects + JSON_ASSERT(ref_stack.back()->is_array() || ref_stack.back()->is_object()); + + // array + if (ref_stack.back()->is_array()) + { + ref_stack.back()->m_value.array->push_back(std::move(value)); + return {true, &(ref_stack.back()->m_value.array->back())}; + } + + // object + JSON_ASSERT(ref_stack.back()->is_object()); + // check if we should store an element for the current key + JSON_ASSERT(!key_keep_stack.empty()); + const bool store_element = key_keep_stack.back(); + key_keep_stack.pop_back(); + + if (!store_element) + { + return {false, nullptr}; + } + + JSON_ASSERT(object_element); + *object_element = std::move(value); + return {true, object_element}; + } + + /// the parsed JSON value + BasicJsonType& root; + /// stack to model hierarchy of values + std::vector ref_stack {}; + /// stack to manage which values to keep + std::vector keep_stack {}; + /// stack to manage which object keys to keep + std::vector key_keep_stack {}; + /// helper to hold the reference for the next object element + BasicJsonType* object_element = nullptr; + /// whether a syntax error occurred + bool errored = false; + /// callback function + const parser_callback_t callback = nullptr; + /// whether to throw exceptions in case of errors + const bool allow_exceptions = true; + /// a discarded value for the callback + BasicJsonType discarded = BasicJsonType::value_t::discarded; +}; + +template +class json_sax_acceptor +{ + public: + using number_integer_t = typename BasicJsonType::number_integer_t; + using number_unsigned_t = typename BasicJsonType::number_unsigned_t; + using number_float_t = typename BasicJsonType::number_float_t; + using string_t = typename BasicJsonType::string_t; + using binary_t = typename BasicJsonType::binary_t; + + bool null() + { + return true; + } + + bool boolean(bool /*unused*/) + { + return true; + } + + bool number_integer(number_integer_t /*unused*/) + { + return true; + } + + bool number_unsigned(number_unsigned_t /*unused*/) + { + return true; + } + + bool number_float(number_float_t /*unused*/, const string_t& /*unused*/) + { + return true; + } + + bool string(string_t& /*unused*/) + { + return true; + } + + bool binary(binary_t& /*unused*/) + { + return true; + } + + bool start_object(std::size_t /*unused*/ = std::size_t(-1)) + { + return true; + } + + bool key(string_t& /*unused*/) + { + return true; + } + + bool end_object() + { + return true; + } + + bool start_array(std::size_t /*unused*/ = std::size_t(-1)) + { + return true; + } + + bool end_array() + { + return true; + } + + bool parse_error(std::size_t /*unused*/, const std::string& /*unused*/, const detail::exception& /*unused*/) + { + return false; + } +}; +} // namespace detail + +} // namespace nlohmann + +// #include + + +#include // array +#include // localeconv +#include // size_t +#include // snprintf +#include // strtof, strtod, strtold, strtoll, strtoull +#include // initializer_list +#include // char_traits, string +#include // move +#include // vector + +// #include + +// #include + +// #include + + +namespace nlohmann +{ +namespace detail +{ +/////////// +// lexer // +/////////// + +template +class lexer_base +{ + public: + /// token types for the parser + enum class token_type + { + uninitialized, ///< indicating the scanner is uninitialized + literal_true, ///< the `true` literal + literal_false, ///< the `false` literal + literal_null, ///< the `null` literal + value_string, ///< a string -- use get_string() for actual value + value_unsigned, ///< an unsigned integer -- use get_number_unsigned() for actual value + value_integer, ///< a signed integer -- use get_number_integer() for actual value + value_float, ///< an floating point number -- use get_number_float() for actual value + begin_array, ///< the character for array begin `[` + begin_object, ///< the character for object begin `{` + end_array, ///< the character for array end `]` + end_object, ///< the character for object end `}` + name_separator, ///< the name separator `:` + value_separator, ///< the value separator `,` + parse_error, ///< indicating a parse error + end_of_input, ///< indicating the end of the input buffer + literal_or_value ///< a literal or the begin of a value (only for diagnostics) + }; + + /// return name of values of type token_type (only used for errors) + JSON_HEDLEY_RETURNS_NON_NULL + JSON_HEDLEY_CONST + static const char* token_type_name(const token_type t) noexcept + { + switch (t) + { + case token_type::uninitialized: + return ""; + case token_type::literal_true: + return "true literal"; + case token_type::literal_false: + return "false literal"; + case token_type::literal_null: + return "null literal"; + case token_type::value_string: + return "string literal"; + case token_type::value_unsigned: + case token_type::value_integer: + case token_type::value_float: + return "number literal"; + case token_type::begin_array: + return "'['"; + case token_type::begin_object: + return "'{'"; + case token_type::end_array: + return "']'"; + case token_type::end_object: + return "'}'"; + case token_type::name_separator: + return "':'"; + case token_type::value_separator: + return "','"; + case token_type::parse_error: + return ""; + case token_type::end_of_input: + return "end of input"; + case token_type::literal_or_value: + return "'[', '{', or a literal"; + // LCOV_EXCL_START + default: // catch non-enum values + return "unknown token"; + // LCOV_EXCL_STOP + } + } +}; +/*! +@brief lexical analysis + +This class organizes the lexical analysis during JSON deserialization. +*/ +template +class lexer : public lexer_base +{ + using number_integer_t = typename BasicJsonType::number_integer_t; + using number_unsigned_t = typename BasicJsonType::number_unsigned_t; + using number_float_t = typename BasicJsonType::number_float_t; + using string_t = typename BasicJsonType::string_t; + using char_type = typename InputAdapterType::char_type; + using char_int_type = typename std::char_traits::int_type; + + public: + using token_type = typename lexer_base::token_type; + + explicit lexer(InputAdapterType&& adapter, bool ignore_comments_ = false) + : ia(std::move(adapter)) + , ignore_comments(ignore_comments_) + , decimal_point_char(static_cast(get_decimal_point())) + {} + + // delete because of pointer members + lexer(const lexer&) = delete; + lexer(lexer&&) = default; + lexer& operator=(lexer&) = delete; + lexer& operator=(lexer&&) = default; + ~lexer() = default; + + private: + ///////////////////// + // locales + ///////////////////// + + /// return the locale-dependent decimal point + JSON_HEDLEY_PURE + static char get_decimal_point() noexcept + { + const auto* loc = localeconv(); + JSON_ASSERT(loc != nullptr); + return (loc->decimal_point == nullptr) ? '.' : *(loc->decimal_point); + } + + ///////////////////// + // scan functions + ///////////////////// + + /*! + @brief get codepoint from 4 hex characters following `\u` + + For input "\u c1 c2 c3 c4" the codepoint is: + (c1 * 0x1000) + (c2 * 0x0100) + (c3 * 0x0010) + c4 + = (c1 << 12) + (c2 << 8) + (c3 << 4) + (c4 << 0) + + Furthermore, the possible characters '0'..'9', 'A'..'F', and 'a'..'f' + must be converted to the integers 0x0..0x9, 0xA..0xF, 0xA..0xF, resp. The + conversion is done by subtracting the offset (0x30, 0x37, and 0x57) + between the ASCII value of the character and the desired integer value. + + @return codepoint (0x0000..0xFFFF) or -1 in case of an error (e.g. EOF or + non-hex character) + */ + int get_codepoint() + { + // this function only makes sense after reading `\u` + JSON_ASSERT(current == 'u'); + int codepoint = 0; + + const auto factors = { 12u, 8u, 4u, 0u }; + for (const auto factor : factors) + { + get(); + + if (current >= '0' && current <= '9') + { + codepoint += static_cast((static_cast(current) - 0x30u) << factor); + } + else if (current >= 'A' && current <= 'F') + { + codepoint += static_cast((static_cast(current) - 0x37u) << factor); + } + else if (current >= 'a' && current <= 'f') + { + codepoint += static_cast((static_cast(current) - 0x57u) << factor); + } + else + { + return -1; + } + } + + JSON_ASSERT(0x0000 <= codepoint && codepoint <= 0xFFFF); + return codepoint; + } + + /*! + @brief check if the next byte(s) are inside a given range + + Adds the current byte and, for each passed range, reads a new byte and + checks if it is inside the range. If a violation was detected, set up an + error message and return false. Otherwise, return true. + + @param[in] ranges list of integers; interpreted as list of pairs of + inclusive lower and upper bound, respectively + + @pre The passed list @a ranges must have 2, 4, or 6 elements; that is, + 1, 2, or 3 pairs. This precondition is enforced by an assertion. + + @return true if and only if no range violation was detected + */ + bool next_byte_in_range(std::initializer_list ranges) + { + JSON_ASSERT(ranges.size() == 2 || ranges.size() == 4 || ranges.size() == 6); + add(current); + + for (auto range = ranges.begin(); range != ranges.end(); ++range) + { + get(); + if (JSON_HEDLEY_LIKELY(*range <= current && current <= *(++range))) + { + add(current); + } + else + { + error_message = "invalid string: ill-formed UTF-8 byte"; + return false; + } + } + + return true; + } + + /*! + @brief scan a string literal + + This function scans a string according to Sect. 7 of RFC 7159. While + scanning, bytes are escaped and copied into buffer token_buffer. Then the + function returns successfully, token_buffer is *not* null-terminated (as it + may contain \0 bytes), and token_buffer.size() is the number of bytes in the + string. + + @return token_type::value_string if string could be successfully scanned, + token_type::parse_error otherwise + + @note In case of errors, variable error_message contains a textual + description. + */ + token_type scan_string() + { + // reset token_buffer (ignore opening quote) + reset(); + + // we entered the function by reading an open quote + JSON_ASSERT(current == '\"'); + + while (true) + { + // get next character + switch (get()) + { + // end of file while parsing string + case std::char_traits::eof(): + { + error_message = "invalid string: missing closing quote"; + return token_type::parse_error; + } + + // closing quote + case '\"': + { + return token_type::value_string; + } + + // escapes + case '\\': + { + switch (get()) + { + // quotation mark + case '\"': + add('\"'); + break; + // reverse solidus + case '\\': + add('\\'); + break; + // solidus + case '/': + add('/'); + break; + // backspace + case 'b': + add('\b'); + break; + // form feed + case 'f': + add('\f'); + break; + // line feed + case 'n': + add('\n'); + break; + // carriage return + case 'r': + add('\r'); + break; + // tab + case 't': + add('\t'); + break; + + // unicode escapes + case 'u': + { + const int codepoint1 = get_codepoint(); + int codepoint = codepoint1; // start with codepoint1 + + if (JSON_HEDLEY_UNLIKELY(codepoint1 == -1)) + { + error_message = "invalid string: '\\u' must be followed by 4 hex digits"; + return token_type::parse_error; + } + + // check if code point is a high surrogate + if (0xD800 <= codepoint1 && codepoint1 <= 0xDBFF) + { + // expect next \uxxxx entry + if (JSON_HEDLEY_LIKELY(get() == '\\' && get() == 'u')) + { + const int codepoint2 = get_codepoint(); + + if (JSON_HEDLEY_UNLIKELY(codepoint2 == -1)) + { + error_message = "invalid string: '\\u' must be followed by 4 hex digits"; + return token_type::parse_error; + } + + // check if codepoint2 is a low surrogate + if (JSON_HEDLEY_LIKELY(0xDC00 <= codepoint2 && codepoint2 <= 0xDFFF)) + { + // overwrite codepoint + codepoint = static_cast( + // high surrogate occupies the most significant 22 bits + (static_cast(codepoint1) << 10u) + // low surrogate occupies the least significant 15 bits + + static_cast(codepoint2) + // there is still the 0xD800, 0xDC00 and 0x10000 noise + // in the result so we have to subtract with: + // (0xD800 << 10) + DC00 - 0x10000 = 0x35FDC00 + - 0x35FDC00u); + } + else + { + error_message = "invalid string: surrogate U+D800..U+DBFF must be followed by U+DC00..U+DFFF"; + return token_type::parse_error; + } + } + else + { + error_message = "invalid string: surrogate U+D800..U+DBFF must be followed by U+DC00..U+DFFF"; + return token_type::parse_error; + } + } + else + { + if (JSON_HEDLEY_UNLIKELY(0xDC00 <= codepoint1 && codepoint1 <= 0xDFFF)) + { + error_message = "invalid string: surrogate U+DC00..U+DFFF must follow U+D800..U+DBFF"; + return token_type::parse_error; + } + } + + // result of the above calculation yields a proper codepoint + JSON_ASSERT(0x00 <= codepoint && codepoint <= 0x10FFFF); + + // translate codepoint into bytes + if (codepoint < 0x80) + { + // 1-byte characters: 0xxxxxxx (ASCII) + add(static_cast(codepoint)); + } + else if (codepoint <= 0x7FF) + { + // 2-byte characters: 110xxxxx 10xxxxxx + add(static_cast(0xC0u | (static_cast(codepoint) >> 6u))); + add(static_cast(0x80u | (static_cast(codepoint) & 0x3Fu))); + } + else if (codepoint <= 0xFFFF) + { + // 3-byte characters: 1110xxxx 10xxxxxx 10xxxxxx + add(static_cast(0xE0u | (static_cast(codepoint) >> 12u))); + add(static_cast(0x80u | ((static_cast(codepoint) >> 6u) & 0x3Fu))); + add(static_cast(0x80u | (static_cast(codepoint) & 0x3Fu))); + } + else + { + // 4-byte characters: 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx + add(static_cast(0xF0u | (static_cast(codepoint) >> 18u))); + add(static_cast(0x80u | ((static_cast(codepoint) >> 12u) & 0x3Fu))); + add(static_cast(0x80u | ((static_cast(codepoint) >> 6u) & 0x3Fu))); + add(static_cast(0x80u | (static_cast(codepoint) & 0x3Fu))); + } + + break; + } + + // other characters after escape + default: + error_message = "invalid string: forbidden character after backslash"; + return token_type::parse_error; + } + + break; + } + + // invalid control characters + case 0x00: + { + error_message = "invalid string: control character U+0000 (NUL) must be escaped to \\u0000"; + return token_type::parse_error; + } + + case 0x01: + { + error_message = "invalid string: control character U+0001 (SOH) must be escaped to \\u0001"; + return token_type::parse_error; + } + + case 0x02: + { + error_message = "invalid string: control character U+0002 (STX) must be escaped to \\u0002"; + return token_type::parse_error; + } + + case 0x03: + { + error_message = "invalid string: control character U+0003 (ETX) must be escaped to \\u0003"; + return token_type::parse_error; + } + + case 0x04: + { + error_message = "invalid string: control character U+0004 (EOT) must be escaped to \\u0004"; + return token_type::parse_error; + } + + case 0x05: + { + error_message = "invalid string: control character U+0005 (ENQ) must be escaped to \\u0005"; + return token_type::parse_error; + } + + case 0x06: + { + error_message = "invalid string: control character U+0006 (ACK) must be escaped to \\u0006"; + return token_type::parse_error; + } + + case 0x07: + { + error_message = "invalid string: control character U+0007 (BEL) must be escaped to \\u0007"; + return token_type::parse_error; + } + + case 0x08: + { + error_message = "invalid string: control character U+0008 (BS) must be escaped to \\u0008 or \\b"; + return token_type::parse_error; + } + + case 0x09: + { + error_message = "invalid string: control character U+0009 (HT) must be escaped to \\u0009 or \\t"; + return token_type::parse_error; + } + + case 0x0A: + { + error_message = "invalid string: control character U+000A (LF) must be escaped to \\u000A or \\n"; + return token_type::parse_error; + } + + case 0x0B: + { + error_message = "invalid string: control character U+000B (VT) must be escaped to \\u000B"; + return token_type::parse_error; + } + + case 0x0C: + { + error_message = "invalid string: control character U+000C (FF) must be escaped to \\u000C or \\f"; + return token_type::parse_error; + } + + case 0x0D: + { + error_message = "invalid string: control character U+000D (CR) must be escaped to \\u000D or \\r"; + return token_type::parse_error; + } + + case 0x0E: + { + error_message = "invalid string: control character U+000E (SO) must be escaped to \\u000E"; + return token_type::parse_error; + } + + case 0x0F: + { + error_message = "invalid string: control character U+000F (SI) must be escaped to \\u000F"; + return token_type::parse_error; + } + + case 0x10: + { + error_message = "invalid string: control character U+0010 (DLE) must be escaped to \\u0010"; + return token_type::parse_error; + } + + case 0x11: + { + error_message = "invalid string: control character U+0011 (DC1) must be escaped to \\u0011"; + return token_type::parse_error; + } + + case 0x12: + { + error_message = "invalid string: control character U+0012 (DC2) must be escaped to \\u0012"; + return token_type::parse_error; + } + + case 0x13: + { + error_message = "invalid string: control character U+0013 (DC3) must be escaped to \\u0013"; + return token_type::parse_error; + } + + case 0x14: + { + error_message = "invalid string: control character U+0014 (DC4) must be escaped to \\u0014"; + return token_type::parse_error; + } + + case 0x15: + { + error_message = "invalid string: control character U+0015 (NAK) must be escaped to \\u0015"; + return token_type::parse_error; + } + + case 0x16: + { + error_message = "invalid string: control character U+0016 (SYN) must be escaped to \\u0016"; + return token_type::parse_error; + } + + case 0x17: + { + error_message = "invalid string: control character U+0017 (ETB) must be escaped to \\u0017"; + return token_type::parse_error; + } + + case 0x18: + { + error_message = "invalid string: control character U+0018 (CAN) must be escaped to \\u0018"; + return token_type::parse_error; + } + + case 0x19: + { + error_message = "invalid string: control character U+0019 (EM) must be escaped to \\u0019"; + return token_type::parse_error; + } + + case 0x1A: + { + error_message = "invalid string: control character U+001A (SUB) must be escaped to \\u001A"; + return token_type::parse_error; + } + + case 0x1B: + { + error_message = "invalid string: control character U+001B (ESC) must be escaped to \\u001B"; + return token_type::parse_error; + } + + case 0x1C: + { + error_message = "invalid string: control character U+001C (FS) must be escaped to \\u001C"; + return token_type::parse_error; + } + + case 0x1D: + { + error_message = "invalid string: control character U+001D (GS) must be escaped to \\u001D"; + return token_type::parse_error; + } + + case 0x1E: + { + error_message = "invalid string: control character U+001E (RS) must be escaped to \\u001E"; + return token_type::parse_error; + } + + case 0x1F: + { + error_message = "invalid string: control character U+001F (US) must be escaped to \\u001F"; + return token_type::parse_error; + } + + // U+0020..U+007F (except U+0022 (quote) and U+005C (backspace)) + case 0x20: + case 0x21: + case 0x23: + case 0x24: + case 0x25: + case 0x26: + case 0x27: + case 0x28: + case 0x29: + case 0x2A: + case 0x2B: + case 0x2C: + case 0x2D: + case 0x2E: + case 0x2F: + case 0x30: + case 0x31: + case 0x32: + case 0x33: + case 0x34: + case 0x35: + case 0x36: + case 0x37: + case 0x38: + case 0x39: + case 0x3A: + case 0x3B: + case 0x3C: + case 0x3D: + case 0x3E: + case 0x3F: + case 0x40: + case 0x41: + case 0x42: + case 0x43: + case 0x44: + case 0x45: + case 0x46: + case 0x47: + case 0x48: + case 0x49: + case 0x4A: + case 0x4B: + case 0x4C: + case 0x4D: + case 0x4E: + case 0x4F: + case 0x50: + case 0x51: + case 0x52: + case 0x53: + case 0x54: + case 0x55: + case 0x56: + case 0x57: + case 0x58: + case 0x59: + case 0x5A: + case 0x5B: + case 0x5D: + case 0x5E: + case 0x5F: + case 0x60: + case 0x61: + case 0x62: + case 0x63: + case 0x64: + case 0x65: + case 0x66: + case 0x67: + case 0x68: + case 0x69: + case 0x6A: + case 0x6B: + case 0x6C: + case 0x6D: + case 0x6E: + case 0x6F: + case 0x70: + case 0x71: + case 0x72: + case 0x73: + case 0x74: + case 0x75: + case 0x76: + case 0x77: + case 0x78: + case 0x79: + case 0x7A: + case 0x7B: + case 0x7C: + case 0x7D: + case 0x7E: + case 0x7F: + { + add(current); + break; + } + + // U+0080..U+07FF: bytes C2..DF 80..BF + case 0xC2: + case 0xC3: + case 0xC4: + case 0xC5: + case 0xC6: + case 0xC7: + case 0xC8: + case 0xC9: + case 0xCA: + case 0xCB: + case 0xCC: + case 0xCD: + case 0xCE: + case 0xCF: + case 0xD0: + case 0xD1: + case 0xD2: + case 0xD3: + case 0xD4: + case 0xD5: + case 0xD6: + case 0xD7: + case 0xD8: + case 0xD9: + case 0xDA: + case 0xDB: + case 0xDC: + case 0xDD: + case 0xDE: + case 0xDF: + { + if (JSON_HEDLEY_UNLIKELY(!next_byte_in_range({0x80, 0xBF}))) + { + return token_type::parse_error; + } + break; + } + + // U+0800..U+0FFF: bytes E0 A0..BF 80..BF + case 0xE0: + { + if (JSON_HEDLEY_UNLIKELY(!(next_byte_in_range({0xA0, 0xBF, 0x80, 0xBF})))) + { + return token_type::parse_error; + } + break; + } + + // U+1000..U+CFFF: bytes E1..EC 80..BF 80..BF + // U+E000..U+FFFF: bytes EE..EF 80..BF 80..BF + case 0xE1: + case 0xE2: + case 0xE3: + case 0xE4: + case 0xE5: + case 0xE6: + case 0xE7: + case 0xE8: + case 0xE9: + case 0xEA: + case 0xEB: + case 0xEC: + case 0xEE: + case 0xEF: + { + if (JSON_HEDLEY_UNLIKELY(!(next_byte_in_range({0x80, 0xBF, 0x80, 0xBF})))) + { + return token_type::parse_error; + } + break; + } + + // U+D000..U+D7FF: bytes ED 80..9F 80..BF + case 0xED: + { + if (JSON_HEDLEY_UNLIKELY(!(next_byte_in_range({0x80, 0x9F, 0x80, 0xBF})))) + { + return token_type::parse_error; + } + break; + } + + // U+10000..U+3FFFF F0 90..BF 80..BF 80..BF + case 0xF0: + { + if (JSON_HEDLEY_UNLIKELY(!(next_byte_in_range({0x90, 0xBF, 0x80, 0xBF, 0x80, 0xBF})))) + { + return token_type::parse_error; + } + break; + } + + // U+40000..U+FFFFF F1..F3 80..BF 80..BF 80..BF + case 0xF1: + case 0xF2: + case 0xF3: + { + if (JSON_HEDLEY_UNLIKELY(!(next_byte_in_range({0x80, 0xBF, 0x80, 0xBF, 0x80, 0xBF})))) + { + return token_type::parse_error; + } + break; + } + + // U+100000..U+10FFFF F4 80..8F 80..BF 80..BF + case 0xF4: + { + if (JSON_HEDLEY_UNLIKELY(!(next_byte_in_range({0x80, 0x8F, 0x80, 0xBF, 0x80, 0xBF})))) + { + return token_type::parse_error; + } + break; + } + + // remaining bytes (80..C1 and F5..FF) are ill-formed + default: + { + error_message = "invalid string: ill-formed UTF-8 byte"; + return token_type::parse_error; + } + } + } + } + + /*! + * @brief scan a comment + * @return whether comment could be scanned successfully + */ + bool scan_comment() + { + switch (get()) + { + // single-line comments skip input until a newline or EOF is read + case '/': + { + while (true) + { + switch (get()) + { + case '\n': + case '\r': + case std::char_traits::eof(): + case '\0': + return true; + + default: + break; + } + } + } + + // multi-line comments skip input until */ is read + case '*': + { + while (true) + { + switch (get()) + { + case std::char_traits::eof(): + case '\0': + { + error_message = "invalid comment; missing closing '*/'"; + return false; + } + + case '*': + { + switch (get()) + { + case '/': + return true; + + default: + { + unget(); + continue; + } + } + } + + default: + continue; + } + } + } + + // unexpected character after reading '/' + default: + { + error_message = "invalid comment; expecting '/' or '*' after '/'"; + return false; + } + } + } + + JSON_HEDLEY_NON_NULL(2) + static void strtof(float& f, const char* str, char** endptr) noexcept + { + f = std::strtof(str, endptr); + } + + JSON_HEDLEY_NON_NULL(2) + static void strtof(double& f, const char* str, char** endptr) noexcept + { + f = std::strtod(str, endptr); + } + + JSON_HEDLEY_NON_NULL(2) + static void strtof(long double& f, const char* str, char** endptr) noexcept + { + f = std::strtold(str, endptr); + } + + /*! + @brief scan a number literal + + This function scans a string according to Sect. 6 of RFC 7159. + + The function is realized with a deterministic finite state machine derived + from the grammar described in RFC 7159. Starting in state "init", the + input is read and used to determined the next state. Only state "done" + accepts the number. State "error" is a trap state to model errors. In the + table below, "anything" means any character but the ones listed before. + + state | 0 | 1-9 | e E | + | - | . | anything + ---------|----------|----------|----------|---------|---------|----------|----------- + init | zero | any1 | [error] | [error] | minus | [error] | [error] + minus | zero | any1 | [error] | [error] | [error] | [error] | [error] + zero | done | done | exponent | done | done | decimal1 | done + any1 | any1 | any1 | exponent | done | done | decimal1 | done + decimal1 | decimal2 | decimal2 | [error] | [error] | [error] | [error] | [error] + decimal2 | decimal2 | decimal2 | exponent | done | done | done | done + exponent | any2 | any2 | [error] | sign | sign | [error] | [error] + sign | any2 | any2 | [error] | [error] | [error] | [error] | [error] + any2 | any2 | any2 | done | done | done | done | done + + The state machine is realized with one label per state (prefixed with + "scan_number_") and `goto` statements between them. The state machine + contains cycles, but any cycle can be left when EOF is read. Therefore, + the function is guaranteed to terminate. + + During scanning, the read bytes are stored in token_buffer. This string is + then converted to a signed integer, an unsigned integer, or a + floating-point number. + + @return token_type::value_unsigned, token_type::value_integer, or + token_type::value_float if number could be successfully scanned, + token_type::parse_error otherwise + + @note The scanner is independent of the current locale. Internally, the + locale's decimal point is used instead of `.` to work with the + locale-dependent converters. + */ + token_type scan_number() // lgtm [cpp/use-of-goto] + { + // reset token_buffer to store the number's bytes + reset(); + + // the type of the parsed number; initially set to unsigned; will be + // changed if minus sign, decimal point or exponent is read + token_type number_type = token_type::value_unsigned; + + // state (init): we just found out we need to scan a number + switch (current) + { + case '-': + { + add(current); + goto scan_number_minus; + } + + case '0': + { + add(current); + goto scan_number_zero; + } + + case '1': + case '2': + case '3': + case '4': + case '5': + case '6': + case '7': + case '8': + case '9': + { + add(current); + goto scan_number_any1; + } + + // all other characters are rejected outside scan_number() + default: // LCOV_EXCL_LINE + JSON_ASSERT(false); // LCOV_EXCL_LINE + } + +scan_number_minus: + // state: we just parsed a leading minus sign + number_type = token_type::value_integer; + switch (get()) + { + case '0': + { + add(current); + goto scan_number_zero; + } + + case '1': + case '2': + case '3': + case '4': + case '5': + case '6': + case '7': + case '8': + case '9': + { + add(current); + goto scan_number_any1; + } + + default: + { + error_message = "invalid number; expected digit after '-'"; + return token_type::parse_error; + } + } + +scan_number_zero: + // state: we just parse a zero (maybe with a leading minus sign) + switch (get()) + { + case '.': + { + add(decimal_point_char); + goto scan_number_decimal1; + } + + case 'e': + case 'E': + { + add(current); + goto scan_number_exponent; + } + + default: + goto scan_number_done; + } + +scan_number_any1: + // state: we just parsed a number 0-9 (maybe with a leading minus sign) + switch (get()) + { + case '0': + case '1': + case '2': + case '3': + case '4': + case '5': + case '6': + case '7': + case '8': + case '9': + { + add(current); + goto scan_number_any1; + } + + case '.': + { + add(decimal_point_char); + goto scan_number_decimal1; + } + + case 'e': + case 'E': + { + add(current); + goto scan_number_exponent; + } + + default: + goto scan_number_done; + } + +scan_number_decimal1: + // state: we just parsed a decimal point + number_type = token_type::value_float; + switch (get()) + { + case '0': + case '1': + case '2': + case '3': + case '4': + case '5': + case '6': + case '7': + case '8': + case '9': + { + add(current); + goto scan_number_decimal2; + } + + default: + { + error_message = "invalid number; expected digit after '.'"; + return token_type::parse_error; + } + } + +scan_number_decimal2: + // we just parsed at least one number after a decimal point + switch (get()) + { + case '0': + case '1': + case '2': + case '3': + case '4': + case '5': + case '6': + case '7': + case '8': + case '9': + { + add(current); + goto scan_number_decimal2; + } + + case 'e': + case 'E': + { + add(current); + goto scan_number_exponent; + } + + default: + goto scan_number_done; + } + +scan_number_exponent: + // we just parsed an exponent + number_type = token_type::value_float; + switch (get()) + { + case '+': + case '-': + { + add(current); + goto scan_number_sign; + } + + case '0': + case '1': + case '2': + case '3': + case '4': + case '5': + case '6': + case '7': + case '8': + case '9': + { + add(current); + goto scan_number_any2; + } + + default: + { + error_message = + "invalid number; expected '+', '-', or digit after exponent"; + return token_type::parse_error; + } + } + +scan_number_sign: + // we just parsed an exponent sign + switch (get()) + { + case '0': + case '1': + case '2': + case '3': + case '4': + case '5': + case '6': + case '7': + case '8': + case '9': + { + add(current); + goto scan_number_any2; + } + + default: + { + error_message = "invalid number; expected digit after exponent sign"; + return token_type::parse_error; + } + } + +scan_number_any2: + // we just parsed a number after the exponent or exponent sign + switch (get()) + { + case '0': + case '1': + case '2': + case '3': + case '4': + case '5': + case '6': + case '7': + case '8': + case '9': + { + add(current); + goto scan_number_any2; + } + + default: + goto scan_number_done; + } + +scan_number_done: + // unget the character after the number (we only read it to know that + // we are done scanning a number) + unget(); + + char* endptr = nullptr; + errno = 0; + + // try to parse integers first and fall back to floats + if (number_type == token_type::value_unsigned) + { + const auto x = std::strtoull(token_buffer.data(), &endptr, 10); + + // we checked the number format before + JSON_ASSERT(endptr == token_buffer.data() + token_buffer.size()); + + if (errno == 0) + { + value_unsigned = static_cast(x); + if (value_unsigned == x) + { + return token_type::value_unsigned; + } + } + } + else if (number_type == token_type::value_integer) + { + const auto x = std::strtoll(token_buffer.data(), &endptr, 10); + + // we checked the number format before + JSON_ASSERT(endptr == token_buffer.data() + token_buffer.size()); + + if (errno == 0) + { + value_integer = static_cast(x); + if (value_integer == x) + { + return token_type::value_integer; + } + } + } + + // this code is reached if we parse a floating-point number or if an + // integer conversion above failed + strtof(value_float, token_buffer.data(), &endptr); + + // we checked the number format before + JSON_ASSERT(endptr == token_buffer.data() + token_buffer.size()); + + return token_type::value_float; + } + + /*! + @param[in] literal_text the literal text to expect + @param[in] length the length of the passed literal text + @param[in] return_type the token type to return on success + */ + JSON_HEDLEY_NON_NULL(2) + token_type scan_literal(const char_type* literal_text, const std::size_t length, + token_type return_type) + { + JSON_ASSERT(std::char_traits::to_char_type(current) == literal_text[0]); + for (std::size_t i = 1; i < length; ++i) + { + if (JSON_HEDLEY_UNLIKELY(std::char_traits::to_char_type(get()) != literal_text[i])) + { + error_message = "invalid literal"; + return token_type::parse_error; + } + } + return return_type; + } + + ///////////////////// + // input management + ///////////////////// + + /// reset token_buffer; current character is beginning of token + void reset() noexcept + { + token_buffer.clear(); + token_string.clear(); + token_string.push_back(std::char_traits::to_char_type(current)); + } + + /* + @brief get next character from the input + + This function provides the interface to the used input adapter. It does + not throw in case the input reached EOF, but returns a + `std::char_traits::eof()` in that case. Stores the scanned characters + for use in error messages. + + @return character read from the input + */ + char_int_type get() + { + ++position.chars_read_total; + ++position.chars_read_current_line; + + if (next_unget) + { + // just reset the next_unget variable and work with current + next_unget = false; + } + else + { + current = ia.get_character(); + } + + if (JSON_HEDLEY_LIKELY(current != std::char_traits::eof())) + { + token_string.push_back(std::char_traits::to_char_type(current)); + } + + if (current == '\n') + { + ++position.lines_read; + position.chars_read_current_line = 0; + } + + return current; + } + + /*! + @brief unget current character (read it again on next get) + + We implement unget by setting variable next_unget to true. The input is not + changed - we just simulate ungetting by modifying chars_read_total, + chars_read_current_line, and token_string. The next call to get() will + behave as if the unget character is read again. + */ + void unget() + { + next_unget = true; + + --position.chars_read_total; + + // in case we "unget" a newline, we have to also decrement the lines_read + if (position.chars_read_current_line == 0) + { + if (position.lines_read > 0) + { + --position.lines_read; + } + } + else + { + --position.chars_read_current_line; + } + + if (JSON_HEDLEY_LIKELY(current != std::char_traits::eof())) + { + JSON_ASSERT(!token_string.empty()); + token_string.pop_back(); + } + } + + /// add a character to token_buffer + void add(char_int_type c) + { + token_buffer.push_back(static_cast(c)); + } + + public: + ///////////////////// + // value getters + ///////////////////// + + /// return integer value + constexpr number_integer_t get_number_integer() const noexcept + { + return value_integer; + } + + /// return unsigned integer value + constexpr number_unsigned_t get_number_unsigned() const noexcept + { + return value_unsigned; + } + + /// return floating-point value + constexpr number_float_t get_number_float() const noexcept + { + return value_float; + } + + /// return current string value (implicitly resets the token; useful only once) + string_t& get_string() + { + return token_buffer; + } + + ///////////////////// + // diagnostics + ///////////////////// + + /// return position of last read token + constexpr position_t get_position() const noexcept + { + return position; + } + + /// return the last read token (for errors only). Will never contain EOF + /// (an arbitrary value that is not a valid char value, often -1), because + /// 255 may legitimately occur. May contain NUL, which should be escaped. + std::string get_token_string() const + { + // escape control characters + std::string result; + for (const auto c : token_string) + { + if (static_cast(c) <= '\x1F') + { + // escape control characters + std::array cs{{}}; + (std::snprintf)(cs.data(), cs.size(), "", static_cast(c)); + result += cs.data(); + } + else + { + // add character as is + result.push_back(static_cast(c)); + } + } + + return result; + } + + /// return syntax error message + JSON_HEDLEY_RETURNS_NON_NULL + constexpr const char* get_error_message() const noexcept + { + return error_message; + } + + ///////////////////// + // actual scanner + ///////////////////// + + /*! + @brief skip the UTF-8 byte order mark + @return true iff there is no BOM or the correct BOM has been skipped + */ + bool skip_bom() + { + if (get() == 0xEF) + { + // check if we completely parse the BOM + return get() == 0xBB && get() == 0xBF; + } + + // the first character is not the beginning of the BOM; unget it to + // process is later + unget(); + return true; + } + + void skip_whitespace() + { + do + { + get(); + } + while (current == ' ' || current == '\t' || current == '\n' || current == '\r'); + } + + token_type scan() + { + // initially, skip the BOM + if (position.chars_read_total == 0 && !skip_bom()) + { + error_message = "invalid BOM; must be 0xEF 0xBB 0xBF if given"; + return token_type::parse_error; + } + + // read next character and ignore whitespace + skip_whitespace(); + + // ignore comments + while (ignore_comments && current == '/') + { + if (!scan_comment()) + { + return token_type::parse_error; + } + + // skip following whitespace + skip_whitespace(); + } + + switch (current) + { + // structural characters + case '[': + return token_type::begin_array; + case ']': + return token_type::end_array; + case '{': + return token_type::begin_object; + case '}': + return token_type::end_object; + case ':': + return token_type::name_separator; + case ',': + return token_type::value_separator; + + // literals + case 't': + { + std::array true_literal = {{'t', 'r', 'u', 'e'}}; + return scan_literal(true_literal.data(), true_literal.size(), token_type::literal_true); + } + case 'f': + { + std::array false_literal = {{'f', 'a', 'l', 's', 'e'}}; + return scan_literal(false_literal.data(), false_literal.size(), token_type::literal_false); + } + case 'n': + { + std::array null_literal = {{'n', 'u', 'l', 'l'}}; + return scan_literal(null_literal.data(), null_literal.size(), token_type::literal_null); + } + + // string + case '\"': + return scan_string(); + + // number + case '-': + case '0': + case '1': + case '2': + case '3': + case '4': + case '5': + case '6': + case '7': + case '8': + case '9': + return scan_number(); + + // end of input (the null byte is needed when parsing from + // string literals) + case '\0': + case std::char_traits::eof(): + return token_type::end_of_input; + + // error + default: + error_message = "invalid literal"; + return token_type::parse_error; + } + } + + private: + /// input adapter + InputAdapterType ia; + + /// whether comments should be ignored (true) or signaled as errors (false) + const bool ignore_comments = false; + + /// the current character + char_int_type current = std::char_traits::eof(); + + /// whether the next get() call should just return current + bool next_unget = false; + + /// the start position of the current token + position_t position {}; + + /// raw input token string (for error messages) + std::vector token_string {}; + + /// buffer for variable-length tokens (numbers, strings) + string_t token_buffer {}; + + /// a description of occurred lexer errors + const char* error_message = ""; + + // number values + number_integer_t value_integer = 0; + number_unsigned_t value_unsigned = 0; + number_float_t value_float = 0; + + /// the decimal point + const char_int_type decimal_point_char = '.'; +}; +} // namespace detail +} // namespace nlohmann + +// #include + +// #include + + +#include // size_t +#include // declval +#include // string + +// #include + +// #include + + +namespace nlohmann +{ +namespace detail +{ +template +using null_function_t = decltype(std::declval().null()); + +template +using boolean_function_t = + decltype(std::declval().boolean(std::declval())); + +template +using number_integer_function_t = + decltype(std::declval().number_integer(std::declval())); + +template +using number_unsigned_function_t = + decltype(std::declval().number_unsigned(std::declval())); + +template +using number_float_function_t = decltype(std::declval().number_float( + std::declval(), std::declval())); + +template +using string_function_t = + decltype(std::declval().string(std::declval())); + +template +using binary_function_t = + decltype(std::declval().binary(std::declval())); + +template +using start_object_function_t = + decltype(std::declval().start_object(std::declval())); + +template +using key_function_t = + decltype(std::declval().key(std::declval())); + +template +using end_object_function_t = decltype(std::declval().end_object()); + +template +using start_array_function_t = + decltype(std::declval().start_array(std::declval())); + +template +using end_array_function_t = decltype(std::declval().end_array()); + +template +using parse_error_function_t = decltype(std::declval().parse_error( + std::declval(), std::declval(), + std::declval())); + +template +struct is_sax +{ + private: + static_assert(is_basic_json::value, + "BasicJsonType must be of type basic_json<...>"); + + using number_integer_t = typename BasicJsonType::number_integer_t; + using number_unsigned_t = typename BasicJsonType::number_unsigned_t; + using number_float_t = typename BasicJsonType::number_float_t; + using string_t = typename BasicJsonType::string_t; + using binary_t = typename BasicJsonType::binary_t; + using exception_t = typename BasicJsonType::exception; + + public: + static constexpr bool value = + is_detected_exact::value && + is_detected_exact::value && + is_detected_exact::value && + is_detected_exact::value && + is_detected_exact::value && + is_detected_exact::value && + is_detected_exact::value && + is_detected_exact::value && + is_detected_exact::value && + is_detected_exact::value && + is_detected_exact::value && + is_detected_exact::value && + is_detected_exact::value; +}; + +template +struct is_sax_static_asserts +{ + private: + static_assert(is_basic_json::value, + "BasicJsonType must be of type basic_json<...>"); + + using number_integer_t = typename BasicJsonType::number_integer_t; + using number_unsigned_t = typename BasicJsonType::number_unsigned_t; + using number_float_t = typename BasicJsonType::number_float_t; + using string_t = typename BasicJsonType::string_t; + using binary_t = typename BasicJsonType::binary_t; + using exception_t = typename BasicJsonType::exception; + + public: + static_assert(is_detected_exact::value, + "Missing/invalid function: bool null()"); + static_assert(is_detected_exact::value, + "Missing/invalid function: bool boolean(bool)"); + static_assert(is_detected_exact::value, + "Missing/invalid function: bool boolean(bool)"); + static_assert( + is_detected_exact::value, + "Missing/invalid function: bool number_integer(number_integer_t)"); + static_assert( + is_detected_exact::value, + "Missing/invalid function: bool number_unsigned(number_unsigned_t)"); + static_assert(is_detected_exact::value, + "Missing/invalid function: bool number_float(number_float_t, const string_t&)"); + static_assert( + is_detected_exact::value, + "Missing/invalid function: bool string(string_t&)"); + static_assert( + is_detected_exact::value, + "Missing/invalid function: bool binary(binary_t&)"); + static_assert(is_detected_exact::value, + "Missing/invalid function: bool start_object(std::size_t)"); + static_assert(is_detected_exact::value, + "Missing/invalid function: bool key(string_t&)"); + static_assert(is_detected_exact::value, + "Missing/invalid function: bool end_object()"); + static_assert(is_detected_exact::value, + "Missing/invalid function: bool start_array(std::size_t)"); + static_assert(is_detected_exact::value, + "Missing/invalid function: bool end_array()"); + static_assert( + is_detected_exact::value, + "Missing/invalid function: bool parse_error(std::size_t, const " + "std::string&, const exception&)"); +}; +} // namespace detail +} // namespace nlohmann + +// #include + + +namespace nlohmann +{ +namespace detail +{ + +/// how to treat CBOR tags +enum class cbor_tag_handler_t +{ + error, ///< throw a parse_error exception in case of a tag + ignore ///< ignore tags +}; + +/*! +@brief determine system byte order + +@return true if and only if system's byte order is little endian + +@note from https://stackoverflow.com/a/1001328/266378 +*/ +static inline bool little_endianess(int num = 1) noexcept +{ + return *reinterpret_cast(&num) == 1; +} + + +/////////////////// +// binary reader // +/////////////////// + +/*! +@brief deserialization of CBOR, MessagePack, and UBJSON values +*/ +template> +class binary_reader +{ + using number_integer_t = typename BasicJsonType::number_integer_t; + using number_unsigned_t = typename BasicJsonType::number_unsigned_t; + using number_float_t = typename BasicJsonType::number_float_t; + using string_t = typename BasicJsonType::string_t; + using binary_t = typename BasicJsonType::binary_t; + using json_sax_t = SAX; + using char_type = typename InputAdapterType::char_type; + using char_int_type = typename std::char_traits::int_type; + + public: + /*! + @brief create a binary reader + + @param[in] adapter input adapter to read from + */ + explicit binary_reader(InputAdapterType&& adapter) : ia(std::move(adapter)) + { + (void)detail::is_sax_static_asserts {}; + } + + // make class move-only + binary_reader(const binary_reader&) = delete; + binary_reader(binary_reader&&) = default; + binary_reader& operator=(const binary_reader&) = delete; + binary_reader& operator=(binary_reader&&) = default; + ~binary_reader() = default; + + /*! + @param[in] format the binary format to parse + @param[in] sax_ a SAX event processor + @param[in] strict whether to expect the input to be consumed completed + @param[in] tag_handler how to treat CBOR tags + + @return + */ + JSON_HEDLEY_NON_NULL(3) + bool sax_parse(const input_format_t format, + json_sax_t* sax_, + const bool strict = true, + const cbor_tag_handler_t tag_handler = cbor_tag_handler_t::error) + { + sax = sax_; + bool result = false; + + switch (format) + { + case input_format_t::bson: + result = parse_bson_internal(); + break; + + case input_format_t::cbor: + result = parse_cbor_internal(true, tag_handler); + break; + + case input_format_t::msgpack: + result = parse_msgpack_internal(); + break; + + case input_format_t::ubjson: + result = parse_ubjson_internal(); + break; + + default: // LCOV_EXCL_LINE + JSON_ASSERT(false); // LCOV_EXCL_LINE + } + + // strict mode: next byte must be EOF + if (result && strict) + { + if (format == input_format_t::ubjson) + { + get_ignore_noop(); + } + else + { + get(); + } + + if (JSON_HEDLEY_UNLIKELY(current != std::char_traits::eof())) + { + return sax->parse_error(chars_read, get_token_string(), + parse_error::create(110, chars_read, exception_message(format, "expected end of input; last byte: 0x" + get_token_string(), "value"))); + } + } + + return result; + } + + private: + ////////// + // BSON // + ////////// + + /*! + @brief Reads in a BSON-object and passes it to the SAX-parser. + @return whether a valid BSON-value was passed to the SAX parser + */ + bool parse_bson_internal() + { + std::int32_t document_size{}; + get_number(input_format_t::bson, document_size); + + if (JSON_HEDLEY_UNLIKELY(!sax->start_object(std::size_t(-1)))) + { + return false; + } + + if (JSON_HEDLEY_UNLIKELY(!parse_bson_element_list(/*is_array*/false))) + { + return false; + } + + return sax->end_object(); + } + + /*! + @brief Parses a C-style string from the BSON input. + @param[in, out] result A reference to the string variable where the read + string is to be stored. + @return `true` if the \x00-byte indicating the end of the string was + encountered before the EOF; false` indicates an unexpected EOF. + */ + bool get_bson_cstr(string_t& result) + { + auto out = std::back_inserter(result); + while (true) + { + get(); + if (JSON_HEDLEY_UNLIKELY(!unexpect_eof(input_format_t::bson, "cstring"))) + { + return false; + } + if (current == 0x00) + { + return true; + } + *out++ = static_cast(current); + } + } + + /*! + @brief Parses a zero-terminated string of length @a len from the BSON + input. + @param[in] len The length (including the zero-byte at the end) of the + string to be read. + @param[in, out] result A reference to the string variable where the read + string is to be stored. + @tparam NumberType The type of the length @a len + @pre len >= 1 + @return `true` if the string was successfully parsed + */ + template + bool get_bson_string(const NumberType len, string_t& result) + { + if (JSON_HEDLEY_UNLIKELY(len < 1)) + { + auto last_token = get_token_string(); + return sax->parse_error(chars_read, last_token, parse_error::create(112, chars_read, exception_message(input_format_t::bson, "string length must be at least 1, is " + std::to_string(len), "string"))); + } + + return get_string(input_format_t::bson, len - static_cast(1), result) && get() != std::char_traits::eof(); + } + + /*! + @brief Parses a byte array input of length @a len from the BSON input. + @param[in] len The length of the byte array to be read. + @param[in, out] result A reference to the binary variable where the read + array is to be stored. + @tparam NumberType The type of the length @a len + @pre len >= 0 + @return `true` if the byte array was successfully parsed + */ + template + bool get_bson_binary(const NumberType len, binary_t& result) + { + if (JSON_HEDLEY_UNLIKELY(len < 0)) + { + auto last_token = get_token_string(); + return sax->parse_error(chars_read, last_token, parse_error::create(112, chars_read, exception_message(input_format_t::bson, "byte array length cannot be negative, is " + std::to_string(len), "binary"))); + } + + // All BSON binary values have a subtype + std::uint8_t subtype{}; + get_number(input_format_t::bson, subtype); + result.set_subtype(subtype); + + return get_binary(input_format_t::bson, len, result); + } + + /*! + @brief Read a BSON document element of the given @a element_type. + @param[in] element_type The BSON element type, c.f. http://bsonspec.org/spec.html + @param[in] element_type_parse_position The position in the input stream, + where the `element_type` was read. + @warning Not all BSON element types are supported yet. An unsupported + @a element_type will give rise to a parse_error.114: + Unsupported BSON record type 0x... + @return whether a valid BSON-object/array was passed to the SAX parser + */ + bool parse_bson_element_internal(const char_int_type element_type, + const std::size_t element_type_parse_position) + { + switch (element_type) + { + case 0x01: // double + { + double number{}; + return get_number(input_format_t::bson, number) && sax->number_float(static_cast(number), ""); + } + + case 0x02: // string + { + std::int32_t len{}; + string_t value; + return get_number(input_format_t::bson, len) && get_bson_string(len, value) && sax->string(value); + } + + case 0x03: // object + { + return parse_bson_internal(); + } + + case 0x04: // array + { + return parse_bson_array(); + } + + case 0x05: // binary + { + std::int32_t len{}; + binary_t value; + return get_number(input_format_t::bson, len) && get_bson_binary(len, value) && sax->binary(value); + } + + case 0x08: // boolean + { + return sax->boolean(get() != 0); + } + + case 0x0A: // null + { + return sax->null(); + } + + case 0x10: // int32 + { + std::int32_t value{}; + return get_number(input_format_t::bson, value) && sax->number_integer(value); + } + + case 0x12: // int64 + { + std::int64_t value{}; + return get_number(input_format_t::bson, value) && sax->number_integer(value); + } + + default: // anything else not supported (yet) + { + std::array cr{{}}; + (std::snprintf)(cr.data(), cr.size(), "%.2hhX", static_cast(element_type)); + return sax->parse_error(element_type_parse_position, std::string(cr.data()), parse_error::create(114, element_type_parse_position, "Unsupported BSON record type 0x" + std::string(cr.data()))); + } + } + } + + /*! + @brief Read a BSON element list (as specified in the BSON-spec) + + The same binary layout is used for objects and arrays, hence it must be + indicated with the argument @a is_array which one is expected + (true --> array, false --> object). + + @param[in] is_array Determines if the element list being read is to be + treated as an object (@a is_array == false), or as an + array (@a is_array == true). + @return whether a valid BSON-object/array was passed to the SAX parser + */ + bool parse_bson_element_list(const bool is_array) + { + string_t key; + + while (auto element_type = get()) + { + if (JSON_HEDLEY_UNLIKELY(!unexpect_eof(input_format_t::bson, "element list"))) + { + return false; + } + + const std::size_t element_type_parse_position = chars_read; + if (JSON_HEDLEY_UNLIKELY(!get_bson_cstr(key))) + { + return false; + } + + if (!is_array && !sax->key(key)) + { + return false; + } + + if (JSON_HEDLEY_UNLIKELY(!parse_bson_element_internal(element_type, element_type_parse_position))) + { + return false; + } + + // get_bson_cstr only appends + key.clear(); + } + + return true; + } + + /*! + @brief Reads an array from the BSON input and passes it to the SAX-parser. + @return whether a valid BSON-array was passed to the SAX parser + */ + bool parse_bson_array() + { + std::int32_t document_size{}; + get_number(input_format_t::bson, document_size); + + if (JSON_HEDLEY_UNLIKELY(!sax->start_array(std::size_t(-1)))) + { + return false; + } + + if (JSON_HEDLEY_UNLIKELY(!parse_bson_element_list(/*is_array*/true))) + { + return false; + } + + return sax->end_array(); + } + + ////////// + // CBOR // + ////////// + + /*! + @param[in] get_char whether a new character should be retrieved from the + input (true) or whether the last read character should + be considered instead (false) + @param[in] tag_handler how CBOR tags should be treated + + @return whether a valid CBOR value was passed to the SAX parser + */ + bool parse_cbor_internal(const bool get_char, + const cbor_tag_handler_t tag_handler) + { + switch (get_char ? get() : current) + { + // EOF + case std::char_traits::eof(): + return unexpect_eof(input_format_t::cbor, "value"); + + // Integer 0x00..0x17 (0..23) + case 0x00: + case 0x01: + case 0x02: + case 0x03: + case 0x04: + case 0x05: + case 0x06: + case 0x07: + case 0x08: + case 0x09: + case 0x0A: + case 0x0B: + case 0x0C: + case 0x0D: + case 0x0E: + case 0x0F: + case 0x10: + case 0x11: + case 0x12: + case 0x13: + case 0x14: + case 0x15: + case 0x16: + case 0x17: + return sax->number_unsigned(static_cast(current)); + + case 0x18: // Unsigned integer (one-byte uint8_t follows) + { + std::uint8_t number{}; + return get_number(input_format_t::cbor, number) && sax->number_unsigned(number); + } + + case 0x19: // Unsigned integer (two-byte uint16_t follows) + { + std::uint16_t number{}; + return get_number(input_format_t::cbor, number) && sax->number_unsigned(number); + } + + case 0x1A: // Unsigned integer (four-byte uint32_t follows) + { + std::uint32_t number{}; + return get_number(input_format_t::cbor, number) && sax->number_unsigned(number); + } + + case 0x1B: // Unsigned integer (eight-byte uint64_t follows) + { + std::uint64_t number{}; + return get_number(input_format_t::cbor, number) && sax->number_unsigned(number); + } + + // Negative integer -1-0x00..-1-0x17 (-1..-24) + case 0x20: + case 0x21: + case 0x22: + case 0x23: + case 0x24: + case 0x25: + case 0x26: + case 0x27: + case 0x28: + case 0x29: + case 0x2A: + case 0x2B: + case 0x2C: + case 0x2D: + case 0x2E: + case 0x2F: + case 0x30: + case 0x31: + case 0x32: + case 0x33: + case 0x34: + case 0x35: + case 0x36: + case 0x37: + return sax->number_integer(static_cast(0x20 - 1 - current)); + + case 0x38: // Negative integer (one-byte uint8_t follows) + { + std::uint8_t number{}; + return get_number(input_format_t::cbor, number) && sax->number_integer(static_cast(-1) - number); + } + + case 0x39: // Negative integer -1-n (two-byte uint16_t follows) + { + std::uint16_t number{}; + return get_number(input_format_t::cbor, number) && sax->number_integer(static_cast(-1) - number); + } + + case 0x3A: // Negative integer -1-n (four-byte uint32_t follows) + { + std::uint32_t number{}; + return get_number(input_format_t::cbor, number) && sax->number_integer(static_cast(-1) - number); + } + + case 0x3B: // Negative integer -1-n (eight-byte uint64_t follows) + { + std::uint64_t number{}; + return get_number(input_format_t::cbor, number) && sax->number_integer(static_cast(-1) + - static_cast(number)); + } + + // Binary data (0x00..0x17 bytes follow) + case 0x40: + case 0x41: + case 0x42: + case 0x43: + case 0x44: + case 0x45: + case 0x46: + case 0x47: + case 0x48: + case 0x49: + case 0x4A: + case 0x4B: + case 0x4C: + case 0x4D: + case 0x4E: + case 0x4F: + case 0x50: + case 0x51: + case 0x52: + case 0x53: + case 0x54: + case 0x55: + case 0x56: + case 0x57: + case 0x58: // Binary data (one-byte uint8_t for n follows) + case 0x59: // Binary data (two-byte uint16_t for n follow) + case 0x5A: // Binary data (four-byte uint32_t for n follow) + case 0x5B: // Binary data (eight-byte uint64_t for n follow) + case 0x5F: // Binary data (indefinite length) + { + binary_t b; + return get_cbor_binary(b) && sax->binary(b); + } + + // UTF-8 string (0x00..0x17 bytes follow) + case 0x60: + case 0x61: + case 0x62: + case 0x63: + case 0x64: + case 0x65: + case 0x66: + case 0x67: + case 0x68: + case 0x69: + case 0x6A: + case 0x6B: + case 0x6C: + case 0x6D: + case 0x6E: + case 0x6F: + case 0x70: + case 0x71: + case 0x72: + case 0x73: + case 0x74: + case 0x75: + case 0x76: + case 0x77: + case 0x78: // UTF-8 string (one-byte uint8_t for n follows) + case 0x79: // UTF-8 string (two-byte uint16_t for n follow) + case 0x7A: // UTF-8 string (four-byte uint32_t for n follow) + case 0x7B: // UTF-8 string (eight-byte uint64_t for n follow) + case 0x7F: // UTF-8 string (indefinite length) + { + string_t s; + return get_cbor_string(s) && sax->string(s); + } + + // array (0x00..0x17 data items follow) + case 0x80: + case 0x81: + case 0x82: + case 0x83: + case 0x84: + case 0x85: + case 0x86: + case 0x87: + case 0x88: + case 0x89: + case 0x8A: + case 0x8B: + case 0x8C: + case 0x8D: + case 0x8E: + case 0x8F: + case 0x90: + case 0x91: + case 0x92: + case 0x93: + case 0x94: + case 0x95: + case 0x96: + case 0x97: + return get_cbor_array(static_cast(static_cast(current) & 0x1Fu), tag_handler); + + case 0x98: // array (one-byte uint8_t for n follows) + { + std::uint8_t len{}; + return get_number(input_format_t::cbor, len) && get_cbor_array(static_cast(len), tag_handler); + } + + case 0x99: // array (two-byte uint16_t for n follow) + { + std::uint16_t len{}; + return get_number(input_format_t::cbor, len) && get_cbor_array(static_cast(len), tag_handler); + } + + case 0x9A: // array (four-byte uint32_t for n follow) + { + std::uint32_t len{}; + return get_number(input_format_t::cbor, len) && get_cbor_array(static_cast(len), tag_handler); + } + + case 0x9B: // array (eight-byte uint64_t for n follow) + { + std::uint64_t len{}; + return get_number(input_format_t::cbor, len) && get_cbor_array(static_cast(len), tag_handler); + } + + case 0x9F: // array (indefinite length) + return get_cbor_array(std::size_t(-1), tag_handler); + + // map (0x00..0x17 pairs of data items follow) + case 0xA0: + case 0xA1: + case 0xA2: + case 0xA3: + case 0xA4: + case 0xA5: + case 0xA6: + case 0xA7: + case 0xA8: + case 0xA9: + case 0xAA: + case 0xAB: + case 0xAC: + case 0xAD: + case 0xAE: + case 0xAF: + case 0xB0: + case 0xB1: + case 0xB2: + case 0xB3: + case 0xB4: + case 0xB5: + case 0xB6: + case 0xB7: + return get_cbor_object(static_cast(static_cast(current) & 0x1Fu), tag_handler); + + case 0xB8: // map (one-byte uint8_t for n follows) + { + std::uint8_t len{}; + return get_number(input_format_t::cbor, len) && get_cbor_object(static_cast(len), tag_handler); + } + + case 0xB9: // map (two-byte uint16_t for n follow) + { + std::uint16_t len{}; + return get_number(input_format_t::cbor, len) && get_cbor_object(static_cast(len), tag_handler); + } + + case 0xBA: // map (four-byte uint32_t for n follow) + { + std::uint32_t len{}; + return get_number(input_format_t::cbor, len) && get_cbor_object(static_cast(len), tag_handler); + } + + case 0xBB: // map (eight-byte uint64_t for n follow) + { + std::uint64_t len{}; + return get_number(input_format_t::cbor, len) && get_cbor_object(static_cast(len), tag_handler); + } + + case 0xBF: // map (indefinite length) + return get_cbor_object(std::size_t(-1), tag_handler); + + case 0xC6: // tagged item + case 0xC7: + case 0xC8: + case 0xC9: + case 0xCA: + case 0xCB: + case 0xCC: + case 0xCD: + case 0xCE: + case 0xCF: + case 0xD0: + case 0xD1: + case 0xD2: + case 0xD3: + case 0xD4: + case 0xD8: // tagged item (1 bytes follow) + case 0xD9: // tagged item (2 bytes follow) + case 0xDA: // tagged item (4 bytes follow) + case 0xDB: // tagged item (8 bytes follow) + { + switch (tag_handler) + { + case cbor_tag_handler_t::error: + { + auto last_token = get_token_string(); + return sax->parse_error(chars_read, last_token, parse_error::create(112, chars_read, exception_message(input_format_t::cbor, "invalid byte: 0x" + last_token, "value"))); + } + + case cbor_tag_handler_t::ignore: + { + switch (current) + { + case 0xD8: + { + std::uint8_t len{}; + get_number(input_format_t::cbor, len); + break; + } + case 0xD9: + { + std::uint16_t len{}; + get_number(input_format_t::cbor, len); + break; + } + case 0xDA: + { + std::uint32_t len{}; + get_number(input_format_t::cbor, len); + break; + } + case 0xDB: + { + std::uint64_t len{}; + get_number(input_format_t::cbor, len); + break; + } + default: + break; + } + return parse_cbor_internal(true, tag_handler); + } + + default: // LCOV_EXCL_LINE + JSON_ASSERT(false); // LCOV_EXCL_LINE + } + } + + case 0xF4: // false + return sax->boolean(false); + + case 0xF5: // true + return sax->boolean(true); + + case 0xF6: // null + return sax->null(); + + case 0xF9: // Half-Precision Float (two-byte IEEE 754) + { + const auto byte1_raw = get(); + if (JSON_HEDLEY_UNLIKELY(!unexpect_eof(input_format_t::cbor, "number"))) + { + return false; + } + const auto byte2_raw = get(); + if (JSON_HEDLEY_UNLIKELY(!unexpect_eof(input_format_t::cbor, "number"))) + { + return false; + } + + const auto byte1 = static_cast(byte1_raw); + const auto byte2 = static_cast(byte2_raw); + + // code from RFC 7049, Appendix D, Figure 3: + // As half-precision floating-point numbers were only added + // to IEEE 754 in 2008, today's programming platforms often + // still only have limited support for them. It is very + // easy to include at least decoding support for them even + // without such support. An example of a small decoder for + // half-precision floating-point numbers in the C language + // is shown in Fig. 3. + const auto half = static_cast((byte1 << 8u) + byte2); + const double val = [&half] + { + const int exp = (half >> 10u) & 0x1Fu; + const unsigned int mant = half & 0x3FFu; + JSON_ASSERT(0 <= exp&& exp <= 32); + JSON_ASSERT(mant <= 1024); + switch (exp) + { + case 0: + return std::ldexp(mant, -24); + case 31: + return (mant == 0) + ? std::numeric_limits::infinity() + : std::numeric_limits::quiet_NaN(); + default: + return std::ldexp(mant + 1024, exp - 25); + } + }(); + return sax->number_float((half & 0x8000u) != 0 + ? static_cast(-val) + : static_cast(val), ""); + } + + case 0xFA: // Single-Precision Float (four-byte IEEE 754) + { + float number{}; + return get_number(input_format_t::cbor, number) && sax->number_float(static_cast(number), ""); + } + + case 0xFB: // Double-Precision Float (eight-byte IEEE 754) + { + double number{}; + return get_number(input_format_t::cbor, number) && sax->number_float(static_cast(number), ""); + } + + default: // anything else (0xFF is handled inside the other types) + { + auto last_token = get_token_string(); + return sax->parse_error(chars_read, last_token, parse_error::create(112, chars_read, exception_message(input_format_t::cbor, "invalid byte: 0x" + last_token, "value"))); + } + } + } + + /*! + @brief reads a CBOR string + + This function first reads starting bytes to determine the expected + string length and then copies this number of bytes into a string. + Additionally, CBOR's strings with indefinite lengths are supported. + + @param[out] result created string + + @return whether string creation completed + */ + bool get_cbor_string(string_t& result) + { + if (JSON_HEDLEY_UNLIKELY(!unexpect_eof(input_format_t::cbor, "string"))) + { + return false; + } + + switch (current) + { + // UTF-8 string (0x00..0x17 bytes follow) + case 0x60: + case 0x61: + case 0x62: + case 0x63: + case 0x64: + case 0x65: + case 0x66: + case 0x67: + case 0x68: + case 0x69: + case 0x6A: + case 0x6B: + case 0x6C: + case 0x6D: + case 0x6E: + case 0x6F: + case 0x70: + case 0x71: + case 0x72: + case 0x73: + case 0x74: + case 0x75: + case 0x76: + case 0x77: + { + return get_string(input_format_t::cbor, static_cast(current) & 0x1Fu, result); + } + + case 0x78: // UTF-8 string (one-byte uint8_t for n follows) + { + std::uint8_t len{}; + return get_number(input_format_t::cbor, len) && get_string(input_format_t::cbor, len, result); + } + + case 0x79: // UTF-8 string (two-byte uint16_t for n follow) + { + std::uint16_t len{}; + return get_number(input_format_t::cbor, len) && get_string(input_format_t::cbor, len, result); + } + + case 0x7A: // UTF-8 string (four-byte uint32_t for n follow) + { + std::uint32_t len{}; + return get_number(input_format_t::cbor, len) && get_string(input_format_t::cbor, len, result); + } + + case 0x7B: // UTF-8 string (eight-byte uint64_t for n follow) + { + std::uint64_t len{}; + return get_number(input_format_t::cbor, len) && get_string(input_format_t::cbor, len, result); + } + + case 0x7F: // UTF-8 string (indefinite length) + { + while (get() != 0xFF) + { + string_t chunk; + if (!get_cbor_string(chunk)) + { + return false; + } + result.append(chunk); + } + return true; + } + + default: + { + auto last_token = get_token_string(); + return sax->parse_error(chars_read, last_token, parse_error::create(113, chars_read, exception_message(input_format_t::cbor, "expected length specification (0x60-0x7B) or indefinite string type (0x7F); last byte: 0x" + last_token, "string"))); + } + } + } + + /*! + @brief reads a CBOR byte array + + This function first reads starting bytes to determine the expected + byte array length and then copies this number of bytes into the byte array. + Additionally, CBOR's byte arrays with indefinite lengths are supported. + + @param[out] result created byte array + + @return whether byte array creation completed + */ + bool get_cbor_binary(binary_t& result) + { + if (JSON_HEDLEY_UNLIKELY(!unexpect_eof(input_format_t::cbor, "binary"))) + { + return false; + } + + switch (current) + { + // Binary data (0x00..0x17 bytes follow) + case 0x40: + case 0x41: + case 0x42: + case 0x43: + case 0x44: + case 0x45: + case 0x46: + case 0x47: + case 0x48: + case 0x49: + case 0x4A: + case 0x4B: + case 0x4C: + case 0x4D: + case 0x4E: + case 0x4F: + case 0x50: + case 0x51: + case 0x52: + case 0x53: + case 0x54: + case 0x55: + case 0x56: + case 0x57: + { + return get_binary(input_format_t::cbor, static_cast(current) & 0x1Fu, result); + } + + case 0x58: // Binary data (one-byte uint8_t for n follows) + { + std::uint8_t len{}; + return get_number(input_format_t::cbor, len) && + get_binary(input_format_t::cbor, len, result); + } + + case 0x59: // Binary data (two-byte uint16_t for n follow) + { + std::uint16_t len{}; + return get_number(input_format_t::cbor, len) && + get_binary(input_format_t::cbor, len, result); + } + + case 0x5A: // Binary data (four-byte uint32_t for n follow) + { + std::uint32_t len{}; + return get_number(input_format_t::cbor, len) && + get_binary(input_format_t::cbor, len, result); + } + + case 0x5B: // Binary data (eight-byte uint64_t for n follow) + { + std::uint64_t len{}; + return get_number(input_format_t::cbor, len) && + get_binary(input_format_t::cbor, len, result); + } + + case 0x5F: // Binary data (indefinite length) + { + while (get() != 0xFF) + { + binary_t chunk; + if (!get_cbor_binary(chunk)) + { + return false; + } + result.insert(result.end(), chunk.begin(), chunk.end()); + } + return true; + } + + default: + { + auto last_token = get_token_string(); + return sax->parse_error(chars_read, last_token, parse_error::create(113, chars_read, exception_message(input_format_t::cbor, "expected length specification (0x40-0x5B) or indefinite binary array type (0x5F); last byte: 0x" + last_token, "binary"))); + } + } + } + + /*! + @param[in] len the length of the array or std::size_t(-1) for an + array of indefinite size + @param[in] tag_handler how CBOR tags should be treated + @return whether array creation completed + */ + bool get_cbor_array(const std::size_t len, + const cbor_tag_handler_t tag_handler) + { + if (JSON_HEDLEY_UNLIKELY(!sax->start_array(len))) + { + return false; + } + + if (len != std::size_t(-1)) + { + for (std::size_t i = 0; i < len; ++i) + { + if (JSON_HEDLEY_UNLIKELY(!parse_cbor_internal(true, tag_handler))) + { + return false; + } + } + } + else + { + while (get() != 0xFF) + { + if (JSON_HEDLEY_UNLIKELY(!parse_cbor_internal(false, tag_handler))) + { + return false; + } + } + } + + return sax->end_array(); + } + + /*! + @param[in] len the length of the object or std::size_t(-1) for an + object of indefinite size + @param[in] tag_handler how CBOR tags should be treated + @return whether object creation completed + */ + bool get_cbor_object(const std::size_t len, + const cbor_tag_handler_t tag_handler) + { + if (JSON_HEDLEY_UNLIKELY(!sax->start_object(len))) + { + return false; + } + + string_t key; + if (len != std::size_t(-1)) + { + for (std::size_t i = 0; i < len; ++i) + { + get(); + if (JSON_HEDLEY_UNLIKELY(!get_cbor_string(key) || !sax->key(key))) + { + return false; + } + + if (JSON_HEDLEY_UNLIKELY(!parse_cbor_internal(true, tag_handler))) + { + return false; + } + key.clear(); + } + } + else + { + while (get() != 0xFF) + { + if (JSON_HEDLEY_UNLIKELY(!get_cbor_string(key) || !sax->key(key))) + { + return false; + } + + if (JSON_HEDLEY_UNLIKELY(!parse_cbor_internal(true, tag_handler))) + { + return false; + } + key.clear(); + } + } + + return sax->end_object(); + } + + ///////////// + // MsgPack // + ///////////// + + /*! + @return whether a valid MessagePack value was passed to the SAX parser + */ + bool parse_msgpack_internal() + { + switch (get()) + { + // EOF + case std::char_traits::eof(): + return unexpect_eof(input_format_t::msgpack, "value"); + + // positive fixint + case 0x00: + case 0x01: + case 0x02: + case 0x03: + case 0x04: + case 0x05: + case 0x06: + case 0x07: + case 0x08: + case 0x09: + case 0x0A: + case 0x0B: + case 0x0C: + case 0x0D: + case 0x0E: + case 0x0F: + case 0x10: + case 0x11: + case 0x12: + case 0x13: + case 0x14: + case 0x15: + case 0x16: + case 0x17: + case 0x18: + case 0x19: + case 0x1A: + case 0x1B: + case 0x1C: + case 0x1D: + case 0x1E: + case 0x1F: + case 0x20: + case 0x21: + case 0x22: + case 0x23: + case 0x24: + case 0x25: + case 0x26: + case 0x27: + case 0x28: + case 0x29: + case 0x2A: + case 0x2B: + case 0x2C: + case 0x2D: + case 0x2E: + case 0x2F: + case 0x30: + case 0x31: + case 0x32: + case 0x33: + case 0x34: + case 0x35: + case 0x36: + case 0x37: + case 0x38: + case 0x39: + case 0x3A: + case 0x3B: + case 0x3C: + case 0x3D: + case 0x3E: + case 0x3F: + case 0x40: + case 0x41: + case 0x42: + case 0x43: + case 0x44: + case 0x45: + case 0x46: + case 0x47: + case 0x48: + case 0x49: + case 0x4A: + case 0x4B: + case 0x4C: + case 0x4D: + case 0x4E: + case 0x4F: + case 0x50: + case 0x51: + case 0x52: + case 0x53: + case 0x54: + case 0x55: + case 0x56: + case 0x57: + case 0x58: + case 0x59: + case 0x5A: + case 0x5B: + case 0x5C: + case 0x5D: + case 0x5E: + case 0x5F: + case 0x60: + case 0x61: + case 0x62: + case 0x63: + case 0x64: + case 0x65: + case 0x66: + case 0x67: + case 0x68: + case 0x69: + case 0x6A: + case 0x6B: + case 0x6C: + case 0x6D: + case 0x6E: + case 0x6F: + case 0x70: + case 0x71: + case 0x72: + case 0x73: + case 0x74: + case 0x75: + case 0x76: + case 0x77: + case 0x78: + case 0x79: + case 0x7A: + case 0x7B: + case 0x7C: + case 0x7D: + case 0x7E: + case 0x7F: + return sax->number_unsigned(static_cast(current)); + + // fixmap + case 0x80: + case 0x81: + case 0x82: + case 0x83: + case 0x84: + case 0x85: + case 0x86: + case 0x87: + case 0x88: + case 0x89: + case 0x8A: + case 0x8B: + case 0x8C: + case 0x8D: + case 0x8E: + case 0x8F: + return get_msgpack_object(static_cast(static_cast(current) & 0x0Fu)); + + // fixarray + case 0x90: + case 0x91: + case 0x92: + case 0x93: + case 0x94: + case 0x95: + case 0x96: + case 0x97: + case 0x98: + case 0x99: + case 0x9A: + case 0x9B: + case 0x9C: + case 0x9D: + case 0x9E: + case 0x9F: + return get_msgpack_array(static_cast(static_cast(current) & 0x0Fu)); + + // fixstr + case 0xA0: + case 0xA1: + case 0xA2: + case 0xA3: + case 0xA4: + case 0xA5: + case 0xA6: + case 0xA7: + case 0xA8: + case 0xA9: + case 0xAA: + case 0xAB: + case 0xAC: + case 0xAD: + case 0xAE: + case 0xAF: + case 0xB0: + case 0xB1: + case 0xB2: + case 0xB3: + case 0xB4: + case 0xB5: + case 0xB6: + case 0xB7: + case 0xB8: + case 0xB9: + case 0xBA: + case 0xBB: + case 0xBC: + case 0xBD: + case 0xBE: + case 0xBF: + case 0xD9: // str 8 + case 0xDA: // str 16 + case 0xDB: // str 32 + { + string_t s; + return get_msgpack_string(s) && sax->string(s); + } + + case 0xC0: // nil + return sax->null(); + + case 0xC2: // false + return sax->boolean(false); + + case 0xC3: // true + return sax->boolean(true); + + case 0xC4: // bin 8 + case 0xC5: // bin 16 + case 0xC6: // bin 32 + case 0xC7: // ext 8 + case 0xC8: // ext 16 + case 0xC9: // ext 32 + case 0xD4: // fixext 1 + case 0xD5: // fixext 2 + case 0xD6: // fixext 4 + case 0xD7: // fixext 8 + case 0xD8: // fixext 16 + { + binary_t b; + return get_msgpack_binary(b) && sax->binary(b); + } + + case 0xCA: // float 32 + { + float number{}; + return get_number(input_format_t::msgpack, number) && sax->number_float(static_cast(number), ""); + } + + case 0xCB: // float 64 + { + double number{}; + return get_number(input_format_t::msgpack, number) && sax->number_float(static_cast(number), ""); + } + + case 0xCC: // uint 8 + { + std::uint8_t number{}; + return get_number(input_format_t::msgpack, number) && sax->number_unsigned(number); + } + + case 0xCD: // uint 16 + { + std::uint16_t number{}; + return get_number(input_format_t::msgpack, number) && sax->number_unsigned(number); + } + + case 0xCE: // uint 32 + { + std::uint32_t number{}; + return get_number(input_format_t::msgpack, number) && sax->number_unsigned(number); + } + + case 0xCF: // uint 64 + { + std::uint64_t number{}; + return get_number(input_format_t::msgpack, number) && sax->number_unsigned(number); + } + + case 0xD0: // int 8 + { + std::int8_t number{}; + return get_number(input_format_t::msgpack, number) && sax->number_integer(number); + } + + case 0xD1: // int 16 + { + std::int16_t number{}; + return get_number(input_format_t::msgpack, number) && sax->number_integer(number); + } + + case 0xD2: // int 32 + { + std::int32_t number{}; + return get_number(input_format_t::msgpack, number) && sax->number_integer(number); + } + + case 0xD3: // int 64 + { + std::int64_t number{}; + return get_number(input_format_t::msgpack, number) && sax->number_integer(number); + } + + case 0xDC: // array 16 + { + std::uint16_t len{}; + return get_number(input_format_t::msgpack, len) && get_msgpack_array(static_cast(len)); + } + + case 0xDD: // array 32 + { + std::uint32_t len{}; + return get_number(input_format_t::msgpack, len) && get_msgpack_array(static_cast(len)); + } + + case 0xDE: // map 16 + { + std::uint16_t len{}; + return get_number(input_format_t::msgpack, len) && get_msgpack_object(static_cast(len)); + } + + case 0xDF: // map 32 + { + std::uint32_t len{}; + return get_number(input_format_t::msgpack, len) && get_msgpack_object(static_cast(len)); + } + + // negative fixint + case 0xE0: + case 0xE1: + case 0xE2: + case 0xE3: + case 0xE4: + case 0xE5: + case 0xE6: + case 0xE7: + case 0xE8: + case 0xE9: + case 0xEA: + case 0xEB: + case 0xEC: + case 0xED: + case 0xEE: + case 0xEF: + case 0xF0: + case 0xF1: + case 0xF2: + case 0xF3: + case 0xF4: + case 0xF5: + case 0xF6: + case 0xF7: + case 0xF8: + case 0xF9: + case 0xFA: + case 0xFB: + case 0xFC: + case 0xFD: + case 0xFE: + case 0xFF: + return sax->number_integer(static_cast(current)); + + default: // anything else + { + auto last_token = get_token_string(); + return sax->parse_error(chars_read, last_token, parse_error::create(112, chars_read, exception_message(input_format_t::msgpack, "invalid byte: 0x" + last_token, "value"))); + } + } + } + + /*! + @brief reads a MessagePack string + + This function first reads starting bytes to determine the expected + string length and then copies this number of bytes into a string. + + @param[out] result created string + + @return whether string creation completed + */ + bool get_msgpack_string(string_t& result) + { + if (JSON_HEDLEY_UNLIKELY(!unexpect_eof(input_format_t::msgpack, "string"))) + { + return false; + } + + switch (current) + { + // fixstr + case 0xA0: + case 0xA1: + case 0xA2: + case 0xA3: + case 0xA4: + case 0xA5: + case 0xA6: + case 0xA7: + case 0xA8: + case 0xA9: + case 0xAA: + case 0xAB: + case 0xAC: + case 0xAD: + case 0xAE: + case 0xAF: + case 0xB0: + case 0xB1: + case 0xB2: + case 0xB3: + case 0xB4: + case 0xB5: + case 0xB6: + case 0xB7: + case 0xB8: + case 0xB9: + case 0xBA: + case 0xBB: + case 0xBC: + case 0xBD: + case 0xBE: + case 0xBF: + { + return get_string(input_format_t::msgpack, static_cast(current) & 0x1Fu, result); + } + + case 0xD9: // str 8 + { + std::uint8_t len{}; + return get_number(input_format_t::msgpack, len) && get_string(input_format_t::msgpack, len, result); + } + + case 0xDA: // str 16 + { + std::uint16_t len{}; + return get_number(input_format_t::msgpack, len) && get_string(input_format_t::msgpack, len, result); + } + + case 0xDB: // str 32 + { + std::uint32_t len{}; + return get_number(input_format_t::msgpack, len) && get_string(input_format_t::msgpack, len, result); + } + + default: + { + auto last_token = get_token_string(); + return sax->parse_error(chars_read, last_token, parse_error::create(113, chars_read, exception_message(input_format_t::msgpack, "expected length specification (0xA0-0xBF, 0xD9-0xDB); last byte: 0x" + last_token, "string"))); + } + } + } + + /*! + @brief reads a MessagePack byte array + + This function first reads starting bytes to determine the expected + byte array length and then copies this number of bytes into a byte array. + + @param[out] result created byte array + + @return whether byte array creation completed + */ + bool get_msgpack_binary(binary_t& result) + { + // helper function to set the subtype + auto assign_and_return_true = [&result](std::int8_t subtype) + { + result.set_subtype(static_cast(subtype)); + return true; + }; + + switch (current) + { + case 0xC4: // bin 8 + { + std::uint8_t len{}; + return get_number(input_format_t::msgpack, len) && + get_binary(input_format_t::msgpack, len, result); + } + + case 0xC5: // bin 16 + { + std::uint16_t len{}; + return get_number(input_format_t::msgpack, len) && + get_binary(input_format_t::msgpack, len, result); + } + + case 0xC6: // bin 32 + { + std::uint32_t len{}; + return get_number(input_format_t::msgpack, len) && + get_binary(input_format_t::msgpack, len, result); + } + + case 0xC7: // ext 8 + { + std::uint8_t len{}; + std::int8_t subtype{}; + return get_number(input_format_t::msgpack, len) && + get_number(input_format_t::msgpack, subtype) && + get_binary(input_format_t::msgpack, len, result) && + assign_and_return_true(subtype); + } + + case 0xC8: // ext 16 + { + std::uint16_t len{}; + std::int8_t subtype{}; + return get_number(input_format_t::msgpack, len) && + get_number(input_format_t::msgpack, subtype) && + get_binary(input_format_t::msgpack, len, result) && + assign_and_return_true(subtype); + } + + case 0xC9: // ext 32 + { + std::uint32_t len{}; + std::int8_t subtype{}; + return get_number(input_format_t::msgpack, len) && + get_number(input_format_t::msgpack, subtype) && + get_binary(input_format_t::msgpack, len, result) && + assign_and_return_true(subtype); + } + + case 0xD4: // fixext 1 + { + std::int8_t subtype{}; + return get_number(input_format_t::msgpack, subtype) && + get_binary(input_format_t::msgpack, 1, result) && + assign_and_return_true(subtype); + } + + case 0xD5: // fixext 2 + { + std::int8_t subtype{}; + return get_number(input_format_t::msgpack, subtype) && + get_binary(input_format_t::msgpack, 2, result) && + assign_and_return_true(subtype); + } + + case 0xD6: // fixext 4 + { + std::int8_t subtype{}; + return get_number(input_format_t::msgpack, subtype) && + get_binary(input_format_t::msgpack, 4, result) && + assign_and_return_true(subtype); + } + + case 0xD7: // fixext 8 + { + std::int8_t subtype{}; + return get_number(input_format_t::msgpack, subtype) && + get_binary(input_format_t::msgpack, 8, result) && + assign_and_return_true(subtype); + } + + case 0xD8: // fixext 16 + { + std::int8_t subtype{}; + return get_number(input_format_t::msgpack, subtype) && + get_binary(input_format_t::msgpack, 16, result) && + assign_and_return_true(subtype); + } + + default: // LCOV_EXCL_LINE + return false; // LCOV_EXCL_LINE + } + } + + /*! + @param[in] len the length of the array + @return whether array creation completed + */ + bool get_msgpack_array(const std::size_t len) + { + if (JSON_HEDLEY_UNLIKELY(!sax->start_array(len))) + { + return false; + } + + for (std::size_t i = 0; i < len; ++i) + { + if (JSON_HEDLEY_UNLIKELY(!parse_msgpack_internal())) + { + return false; + } + } + + return sax->end_array(); + } + + /*! + @param[in] len the length of the object + @return whether object creation completed + */ + bool get_msgpack_object(const std::size_t len) + { + if (JSON_HEDLEY_UNLIKELY(!sax->start_object(len))) + { + return false; + } + + string_t key; + for (std::size_t i = 0; i < len; ++i) + { + get(); + if (JSON_HEDLEY_UNLIKELY(!get_msgpack_string(key) || !sax->key(key))) + { + return false; + } + + if (JSON_HEDLEY_UNLIKELY(!parse_msgpack_internal())) + { + return false; + } + key.clear(); + } + + return sax->end_object(); + } + + //////////// + // UBJSON // + //////////// + + /*! + @param[in] get_char whether a new character should be retrieved from the + input (true, default) or whether the last read + character should be considered instead + + @return whether a valid UBJSON value was passed to the SAX parser + */ + bool parse_ubjson_internal(const bool get_char = true) + { + return get_ubjson_value(get_char ? get_ignore_noop() : current); + } + + /*! + @brief reads a UBJSON string + + This function is either called after reading the 'S' byte explicitly + indicating a string, or in case of an object key where the 'S' byte can be + left out. + + @param[out] result created string + @param[in] get_char whether a new character should be retrieved from the + input (true, default) or whether the last read + character should be considered instead + + @return whether string creation completed + */ + bool get_ubjson_string(string_t& result, const bool get_char = true) + { + if (get_char) + { + get(); // TODO(niels): may we ignore N here? + } + + if (JSON_HEDLEY_UNLIKELY(!unexpect_eof(input_format_t::ubjson, "value"))) + { + return false; + } + + switch (current) + { + case 'U': + { + std::uint8_t len{}; + return get_number(input_format_t::ubjson, len) && get_string(input_format_t::ubjson, len, result); + } + + case 'i': + { + std::int8_t len{}; + return get_number(input_format_t::ubjson, len) && get_string(input_format_t::ubjson, len, result); + } + + case 'I': + { + std::int16_t len{}; + return get_number(input_format_t::ubjson, len) && get_string(input_format_t::ubjson, len, result); + } + + case 'l': + { + std::int32_t len{}; + return get_number(input_format_t::ubjson, len) && get_string(input_format_t::ubjson, len, result); + } + + case 'L': + { + std::int64_t len{}; + return get_number(input_format_t::ubjson, len) && get_string(input_format_t::ubjson, len, result); + } + + default: + auto last_token = get_token_string(); + return sax->parse_error(chars_read, last_token, parse_error::create(113, chars_read, exception_message(input_format_t::ubjson, "expected length type specification (U, i, I, l, L); last byte: 0x" + last_token, "string"))); + } + } + + /*! + @param[out] result determined size + @return whether size determination completed + */ + bool get_ubjson_size_value(std::size_t& result) + { + switch (get_ignore_noop()) + { + case 'U': + { + std::uint8_t number{}; + if (JSON_HEDLEY_UNLIKELY(!get_number(input_format_t::ubjson, number))) + { + return false; + } + result = static_cast(number); + return true; + } + + case 'i': + { + std::int8_t number{}; + if (JSON_HEDLEY_UNLIKELY(!get_number(input_format_t::ubjson, number))) + { + return false; + } + result = static_cast(number); + return true; + } + + case 'I': + { + std::int16_t number{}; + if (JSON_HEDLEY_UNLIKELY(!get_number(input_format_t::ubjson, number))) + { + return false; + } + result = static_cast(number); + return true; + } + + case 'l': + { + std::int32_t number{}; + if (JSON_HEDLEY_UNLIKELY(!get_number(input_format_t::ubjson, number))) + { + return false; + } + result = static_cast(number); + return true; + } + + case 'L': + { + std::int64_t number{}; + if (JSON_HEDLEY_UNLIKELY(!get_number(input_format_t::ubjson, number))) + { + return false; + } + result = static_cast(number); + return true; + } + + default: + { + auto last_token = get_token_string(); + return sax->parse_error(chars_read, last_token, parse_error::create(113, chars_read, exception_message(input_format_t::ubjson, "expected length type specification (U, i, I, l, L) after '#'; last byte: 0x" + last_token, "size"))); + } + } + } + + /*! + @brief determine the type and size for a container + + In the optimized UBJSON format, a type and a size can be provided to allow + for a more compact representation. + + @param[out] result pair of the size and the type + + @return whether pair creation completed + */ + bool get_ubjson_size_type(std::pair& result) + { + result.first = string_t::npos; // size + result.second = 0; // type + + get_ignore_noop(); + + if (current == '$') + { + result.second = get(); // must not ignore 'N', because 'N' maybe the type + if (JSON_HEDLEY_UNLIKELY(!unexpect_eof(input_format_t::ubjson, "type"))) + { + return false; + } + + get_ignore_noop(); + if (JSON_HEDLEY_UNLIKELY(current != '#')) + { + if (JSON_HEDLEY_UNLIKELY(!unexpect_eof(input_format_t::ubjson, "value"))) + { + return false; + } + auto last_token = get_token_string(); + return sax->parse_error(chars_read, last_token, parse_error::create(112, chars_read, exception_message(input_format_t::ubjson, "expected '#' after type information; last byte: 0x" + last_token, "size"))); + } + + return get_ubjson_size_value(result.first); + } + + if (current == '#') + { + return get_ubjson_size_value(result.first); + } + + return true; + } + + /*! + @param prefix the previously read or set type prefix + @return whether value creation completed + */ + bool get_ubjson_value(const char_int_type prefix) + { + switch (prefix) + { + case std::char_traits::eof(): // EOF + return unexpect_eof(input_format_t::ubjson, "value"); + + case 'T': // true + return sax->boolean(true); + case 'F': // false + return sax->boolean(false); + + case 'Z': // null + return sax->null(); + + case 'U': + { + std::uint8_t number{}; + return get_number(input_format_t::ubjson, number) && sax->number_unsigned(number); + } + + case 'i': + { + std::int8_t number{}; + return get_number(input_format_t::ubjson, number) && sax->number_integer(number); + } + + case 'I': + { + std::int16_t number{}; + return get_number(input_format_t::ubjson, number) && sax->number_integer(number); + } + + case 'l': + { + std::int32_t number{}; + return get_number(input_format_t::ubjson, number) && sax->number_integer(number); + } + + case 'L': + { + std::int64_t number{}; + return get_number(input_format_t::ubjson, number) && sax->number_integer(number); + } + + case 'd': + { + float number{}; + return get_number(input_format_t::ubjson, number) && sax->number_float(static_cast(number), ""); + } + + case 'D': + { + double number{}; + return get_number(input_format_t::ubjson, number) && sax->number_float(static_cast(number), ""); + } + + case 'H': + { + return get_ubjson_high_precision_number(); + } + + case 'C': // char + { + get(); + if (JSON_HEDLEY_UNLIKELY(!unexpect_eof(input_format_t::ubjson, "char"))) + { + return false; + } + if (JSON_HEDLEY_UNLIKELY(current > 127)) + { + auto last_token = get_token_string(); + return sax->parse_error(chars_read, last_token, parse_error::create(113, chars_read, exception_message(input_format_t::ubjson, "byte after 'C' must be in range 0x00..0x7F; last byte: 0x" + last_token, "char"))); + } + string_t s(1, static_cast(current)); + return sax->string(s); + } + + case 'S': // string + { + string_t s; + return get_ubjson_string(s) && sax->string(s); + } + + case '[': // array + return get_ubjson_array(); + + case '{': // object + return get_ubjson_object(); + + default: // anything else + { + auto last_token = get_token_string(); + return sax->parse_error(chars_read, last_token, parse_error::create(112, chars_read, exception_message(input_format_t::ubjson, "invalid byte: 0x" + last_token, "value"))); + } + } + } + + /*! + @return whether array creation completed + */ + bool get_ubjson_array() + { + std::pair size_and_type; + if (JSON_HEDLEY_UNLIKELY(!get_ubjson_size_type(size_and_type))) + { + return false; + } + + if (size_and_type.first != string_t::npos) + { + if (JSON_HEDLEY_UNLIKELY(!sax->start_array(size_and_type.first))) + { + return false; + } + + if (size_and_type.second != 0) + { + if (size_and_type.second != 'N') + { + for (std::size_t i = 0; i < size_and_type.first; ++i) + { + if (JSON_HEDLEY_UNLIKELY(!get_ubjson_value(size_and_type.second))) + { + return false; + } + } + } + } + else + { + for (std::size_t i = 0; i < size_and_type.first; ++i) + { + if (JSON_HEDLEY_UNLIKELY(!parse_ubjson_internal())) + { + return false; + } + } + } + } + else + { + if (JSON_HEDLEY_UNLIKELY(!sax->start_array(std::size_t(-1)))) + { + return false; + } + + while (current != ']') + { + if (JSON_HEDLEY_UNLIKELY(!parse_ubjson_internal(false))) + { + return false; + } + get_ignore_noop(); + } + } + + return sax->end_array(); + } + + /*! + @return whether object creation completed + */ + bool get_ubjson_object() + { + std::pair size_and_type; + if (JSON_HEDLEY_UNLIKELY(!get_ubjson_size_type(size_and_type))) + { + return false; + } + + string_t key; + if (size_and_type.first != string_t::npos) + { + if (JSON_HEDLEY_UNLIKELY(!sax->start_object(size_and_type.first))) + { + return false; + } + + if (size_and_type.second != 0) + { + for (std::size_t i = 0; i < size_and_type.first; ++i) + { + if (JSON_HEDLEY_UNLIKELY(!get_ubjson_string(key) || !sax->key(key))) + { + return false; + } + if (JSON_HEDLEY_UNLIKELY(!get_ubjson_value(size_and_type.second))) + { + return false; + } + key.clear(); + } + } + else + { + for (std::size_t i = 0; i < size_and_type.first; ++i) + { + if (JSON_HEDLEY_UNLIKELY(!get_ubjson_string(key) || !sax->key(key))) + { + return false; + } + if (JSON_HEDLEY_UNLIKELY(!parse_ubjson_internal())) + { + return false; + } + key.clear(); + } + } + } + else + { + if (JSON_HEDLEY_UNLIKELY(!sax->start_object(std::size_t(-1)))) + { + return false; + } + + while (current != '}') + { + if (JSON_HEDLEY_UNLIKELY(!get_ubjson_string(key, false) || !sax->key(key))) + { + return false; + } + if (JSON_HEDLEY_UNLIKELY(!parse_ubjson_internal())) + { + return false; + } + get_ignore_noop(); + key.clear(); + } + } + + return sax->end_object(); + } + + // Note, no reader for UBJSON binary types is implemented because they do + // not exist + + bool get_ubjson_high_precision_number() + { + // get size of following number string + std::size_t size{}; + auto res = get_ubjson_size_value(size); + if (JSON_HEDLEY_UNLIKELY(!res)) + { + return res; + } + + // get number string + std::vector number_vector; + for (std::size_t i = 0; i < size; ++i) + { + get(); + if (JSON_HEDLEY_UNLIKELY(!unexpect_eof(input_format_t::ubjson, "number"))) + { + return false; + } + number_vector.push_back(static_cast(current)); + } + + // parse number string + auto number_ia = detail::input_adapter(std::forward(number_vector)); + auto number_lexer = detail::lexer(std::move(number_ia), false); + const auto result_number = number_lexer.scan(); + const auto number_string = number_lexer.get_token_string(); + const auto result_remainder = number_lexer.scan(); + + using token_type = typename detail::lexer_base::token_type; + + if (JSON_HEDLEY_UNLIKELY(result_remainder != token_type::end_of_input)) + { + return sax->parse_error(chars_read, number_string, parse_error::create(115, chars_read, exception_message(input_format_t::ubjson, "invalid number text: " + number_lexer.get_token_string(), "high-precision number"))); + } + + switch (result_number) + { + case token_type::value_integer: + return sax->number_integer(number_lexer.get_number_integer()); + case token_type::value_unsigned: + return sax->number_unsigned(number_lexer.get_number_unsigned()); + case token_type::value_float: + return sax->number_float(number_lexer.get_number_float(), std::move(number_string)); + default: + return sax->parse_error(chars_read, number_string, parse_error::create(115, chars_read, exception_message(input_format_t::ubjson, "invalid number text: " + number_lexer.get_token_string(), "high-precision number"))); + } + } + + /////////////////////// + // Utility functions // + /////////////////////// + + /*! + @brief get next character from the input + + This function provides the interface to the used input adapter. It does + not throw in case the input reached EOF, but returns a -'ve valued + `std::char_traits::eof()` in that case. + + @return character read from the input + */ + char_int_type get() + { + ++chars_read; + return current = ia.get_character(); + } + + /*! + @return character read from the input after ignoring all 'N' entries + */ + char_int_type get_ignore_noop() + { + do + { + get(); + } + while (current == 'N'); + + return current; + } + + /* + @brief read a number from the input + + @tparam NumberType the type of the number + @param[in] format the current format (for diagnostics) + @param[out] result number of type @a NumberType + + @return whether conversion completed + + @note This function needs to respect the system's endianess, because + bytes in CBOR, MessagePack, and UBJSON are stored in network order + (big endian) and therefore need reordering on little endian systems. + */ + template + bool get_number(const input_format_t format, NumberType& result) + { + // step 1: read input into array with system's byte order + std::array vec; + for (std::size_t i = 0; i < sizeof(NumberType); ++i) + { + get(); + if (JSON_HEDLEY_UNLIKELY(!unexpect_eof(format, "number"))) + { + return false; + } + + // reverse byte order prior to conversion if necessary + if (is_little_endian != InputIsLittleEndian) + { + vec[sizeof(NumberType) - i - 1] = static_cast(current); + } + else + { + vec[i] = static_cast(current); // LCOV_EXCL_LINE + } + } + + // step 2: convert array into number of type T and return + std::memcpy(&result, vec.data(), sizeof(NumberType)); + return true; + } + + /*! + @brief create a string by reading characters from the input + + @tparam NumberType the type of the number + @param[in] format the current format (for diagnostics) + @param[in] len number of characters to read + @param[out] result string created by reading @a len bytes + + @return whether string creation completed + + @note We can not reserve @a len bytes for the result, because @a len + may be too large. Usually, @ref unexpect_eof() detects the end of + the input before we run out of string memory. + */ + template + bool get_string(const input_format_t format, + const NumberType len, + string_t& result) + { + bool success = true; + for (NumberType i = 0; i < len; i++) + { + get(); + if (JSON_HEDLEY_UNLIKELY(!unexpect_eof(format, "string"))) + { + success = false; + break; + } + result.push_back(static_cast(current)); + }; + return success; + } + + /*! + @brief create a byte array by reading bytes from the input + + @tparam NumberType the type of the number + @param[in] format the current format (for diagnostics) + @param[in] len number of bytes to read + @param[out] result byte array created by reading @a len bytes + + @return whether byte array creation completed + + @note We can not reserve @a len bytes for the result, because @a len + may be too large. Usually, @ref unexpect_eof() detects the end of + the input before we run out of memory. + */ + template + bool get_binary(const input_format_t format, + const NumberType len, + binary_t& result) + { + bool success = true; + for (NumberType i = 0; i < len; i++) + { + get(); + if (JSON_HEDLEY_UNLIKELY(!unexpect_eof(format, "binary"))) + { + success = false; + break; + } + result.push_back(static_cast(current)); + } + return success; + } + + /*! + @param[in] format the current format (for diagnostics) + @param[in] context further context information (for diagnostics) + @return whether the last read character is not EOF + */ + JSON_HEDLEY_NON_NULL(3) + bool unexpect_eof(const input_format_t format, const char* context) const + { + if (JSON_HEDLEY_UNLIKELY(current == std::char_traits::eof())) + { + return sax->parse_error(chars_read, "", + parse_error::create(110, chars_read, exception_message(format, "unexpected end of input", context))); + } + return true; + } + + /*! + @return a string representation of the last read byte + */ + std::string get_token_string() const + { + std::array cr{{}}; + (std::snprintf)(cr.data(), cr.size(), "%.2hhX", static_cast(current)); + return std::string{cr.data()}; + } + + /*! + @param[in] format the current format + @param[in] detail a detailed error message + @param[in] context further context information + @return a message string to use in the parse_error exceptions + */ + std::string exception_message(const input_format_t format, + const std::string& detail, + const std::string& context) const + { + std::string error_msg = "syntax error while parsing "; + + switch (format) + { + case input_format_t::cbor: + error_msg += "CBOR"; + break; + + case input_format_t::msgpack: + error_msg += "MessagePack"; + break; + + case input_format_t::ubjson: + error_msg += "UBJSON"; + break; + + case input_format_t::bson: + error_msg += "BSON"; + break; + + default: // LCOV_EXCL_LINE + JSON_ASSERT(false); // LCOV_EXCL_LINE + } + + return error_msg + " " + context + ": " + detail; + } + + private: + /// input adapter + InputAdapterType ia; + + /// the current character + char_int_type current = std::char_traits::eof(); + + /// the number of characters read + std::size_t chars_read = 0; + + /// whether we can assume little endianess + const bool is_little_endian = little_endianess(); + + /// the SAX parser + json_sax_t* sax = nullptr; +}; +} // namespace detail +} // namespace nlohmann + +// #include + +// #include + +// #include + + +#include // isfinite +#include // uint8_t +#include // function +#include // string +#include // move +#include // vector + +// #include + +// #include + +// #include + +// #include + +// #include + +// #include + +// #include + + +namespace nlohmann +{ +namespace detail +{ +//////////// +// parser // +//////////// + +enum class parse_event_t : uint8_t +{ + /// the parser read `{` and started to process a JSON object + object_start, + /// the parser read `}` and finished processing a JSON object + object_end, + /// the parser read `[` and started to process a JSON array + array_start, + /// the parser read `]` and finished processing a JSON array + array_end, + /// the parser read a key of a value in an object + key, + /// the parser finished reading a JSON value + value +}; + +template +using parser_callback_t = + std::function; + +/*! +@brief syntax analysis + +This class implements a recursive descent parser. +*/ +template +class parser +{ + using number_integer_t = typename BasicJsonType::number_integer_t; + using number_unsigned_t = typename BasicJsonType::number_unsigned_t; + using number_float_t = typename BasicJsonType::number_float_t; + using string_t = typename BasicJsonType::string_t; + using lexer_t = lexer; + using token_type = typename lexer_t::token_type; + + public: + /// a parser reading from an input adapter + explicit parser(InputAdapterType&& adapter, + const parser_callback_t cb = nullptr, + const bool allow_exceptions_ = true, + const bool skip_comments = false) + : callback(cb) + , m_lexer(std::move(adapter), skip_comments) + , allow_exceptions(allow_exceptions_) + { + // read first token + get_token(); + } + + /*! + @brief public parser interface + + @param[in] strict whether to expect the last token to be EOF + @param[in,out] result parsed JSON value + + @throw parse_error.101 in case of an unexpected token + @throw parse_error.102 if to_unicode fails or surrogate error + @throw parse_error.103 if to_unicode fails + */ + void parse(const bool strict, BasicJsonType& result) + { + if (callback) + { + json_sax_dom_callback_parser sdp(result, callback, allow_exceptions); + sax_parse_internal(&sdp); + result.assert_invariant(); + + // in strict mode, input must be completely read + if (strict && (get_token() != token_type::end_of_input)) + { + sdp.parse_error(m_lexer.get_position(), + m_lexer.get_token_string(), + parse_error::create(101, m_lexer.get_position(), + exception_message(token_type::end_of_input, "value"))); + } + + // in case of an error, return discarded value + if (sdp.is_errored()) + { + result = value_t::discarded; + return; + } + + // set top-level value to null if it was discarded by the callback + // function + if (result.is_discarded()) + { + result = nullptr; + } + } + else + { + json_sax_dom_parser sdp(result, allow_exceptions); + sax_parse_internal(&sdp); + result.assert_invariant(); + + // in strict mode, input must be completely read + if (strict && (get_token() != token_type::end_of_input)) + { + sdp.parse_error(m_lexer.get_position(), + m_lexer.get_token_string(), + parse_error::create(101, m_lexer.get_position(), + exception_message(token_type::end_of_input, "value"))); + } + + // in case of an error, return discarded value + if (sdp.is_errored()) + { + result = value_t::discarded; + return; + } + } + } + + /*! + @brief public accept interface + + @param[in] strict whether to expect the last token to be EOF + @return whether the input is a proper JSON text + */ + bool accept(const bool strict = true) + { + json_sax_acceptor sax_acceptor; + return sax_parse(&sax_acceptor, strict); + } + + template + JSON_HEDLEY_NON_NULL(2) + bool sax_parse(SAX* sax, const bool strict = true) + { + (void)detail::is_sax_static_asserts {}; + const bool result = sax_parse_internal(sax); + + // strict mode: next byte must be EOF + if (result && strict && (get_token() != token_type::end_of_input)) + { + return sax->parse_error(m_lexer.get_position(), + m_lexer.get_token_string(), + parse_error::create(101, m_lexer.get_position(), + exception_message(token_type::end_of_input, "value"))); + } + + return result; + } + + private: + template + JSON_HEDLEY_NON_NULL(2) + bool sax_parse_internal(SAX* sax) + { + // stack to remember the hierarchy of structured values we are parsing + // true = array; false = object + std::vector states; + // value to avoid a goto (see comment where set to true) + bool skip_to_state_evaluation = false; + + while (true) + { + if (!skip_to_state_evaluation) + { + // invariant: get_token() was called before each iteration + switch (last_token) + { + case token_type::begin_object: + { + if (JSON_HEDLEY_UNLIKELY(!sax->start_object(std::size_t(-1)))) + { + return false; + } + + // closing } -> we are done + if (get_token() == token_type::end_object) + { + if (JSON_HEDLEY_UNLIKELY(!sax->end_object())) + { + return false; + } + break; + } + + // parse key + if (JSON_HEDLEY_UNLIKELY(last_token != token_type::value_string)) + { + return sax->parse_error(m_lexer.get_position(), + m_lexer.get_token_string(), + parse_error::create(101, m_lexer.get_position(), + exception_message(token_type::value_string, "object key"))); + } + if (JSON_HEDLEY_UNLIKELY(!sax->key(m_lexer.get_string()))) + { + return false; + } + + // parse separator (:) + if (JSON_HEDLEY_UNLIKELY(get_token() != token_type::name_separator)) + { + return sax->parse_error(m_lexer.get_position(), + m_lexer.get_token_string(), + parse_error::create(101, m_lexer.get_position(), + exception_message(token_type::name_separator, "object separator"))); + } + + // remember we are now inside an object + states.push_back(false); + + // parse values + get_token(); + continue; + } + + case token_type::begin_array: + { + if (JSON_HEDLEY_UNLIKELY(!sax->start_array(std::size_t(-1)))) + { + return false; + } + + // closing ] -> we are done + if (get_token() == token_type::end_array) + { + if (JSON_HEDLEY_UNLIKELY(!sax->end_array())) + { + return false; + } + break; + } + + // remember we are now inside an array + states.push_back(true); + + // parse values (no need to call get_token) + continue; + } + + case token_type::value_float: + { + const auto res = m_lexer.get_number_float(); + + if (JSON_HEDLEY_UNLIKELY(!std::isfinite(res))) + { + return sax->parse_error(m_lexer.get_position(), + m_lexer.get_token_string(), + out_of_range::create(406, "number overflow parsing '" + m_lexer.get_token_string() + "'")); + } + + if (JSON_HEDLEY_UNLIKELY(!sax->number_float(res, m_lexer.get_string()))) + { + return false; + } + + break; + } + + case token_type::literal_false: + { + if (JSON_HEDLEY_UNLIKELY(!sax->boolean(false))) + { + return false; + } + break; + } + + case token_type::literal_null: + { + if (JSON_HEDLEY_UNLIKELY(!sax->null())) + { + return false; + } + break; + } + + case token_type::literal_true: + { + if (JSON_HEDLEY_UNLIKELY(!sax->boolean(true))) + { + return false; + } + break; + } + + case token_type::value_integer: + { + if (JSON_HEDLEY_UNLIKELY(!sax->number_integer(m_lexer.get_number_integer()))) + { + return false; + } + break; + } + + case token_type::value_string: + { + if (JSON_HEDLEY_UNLIKELY(!sax->string(m_lexer.get_string()))) + { + return false; + } + break; + } + + case token_type::value_unsigned: + { + if (JSON_HEDLEY_UNLIKELY(!sax->number_unsigned(m_lexer.get_number_unsigned()))) + { + return false; + } + break; + } + + case token_type::parse_error: + { + // using "uninitialized" to avoid "expected" message + return sax->parse_error(m_lexer.get_position(), + m_lexer.get_token_string(), + parse_error::create(101, m_lexer.get_position(), + exception_message(token_type::uninitialized, "value"))); + } + + default: // the last token was unexpected + { + return sax->parse_error(m_lexer.get_position(), + m_lexer.get_token_string(), + parse_error::create(101, m_lexer.get_position(), + exception_message(token_type::literal_or_value, "value"))); + } + } + } + else + { + skip_to_state_evaluation = false; + } + + // we reached this line after we successfully parsed a value + if (states.empty()) + { + // empty stack: we reached the end of the hierarchy: done + return true; + } + + if (states.back()) // array + { + // comma -> next value + if (get_token() == token_type::value_separator) + { + // parse a new value + get_token(); + continue; + } + + // closing ] + if (JSON_HEDLEY_LIKELY(last_token == token_type::end_array)) + { + if (JSON_HEDLEY_UNLIKELY(!sax->end_array())) + { + return false; + } + + // We are done with this array. Before we can parse a + // new value, we need to evaluate the new state first. + // By setting skip_to_state_evaluation to false, we + // are effectively jumping to the beginning of this if. + JSON_ASSERT(!states.empty()); + states.pop_back(); + skip_to_state_evaluation = true; + continue; + } + + return sax->parse_error(m_lexer.get_position(), + m_lexer.get_token_string(), + parse_error::create(101, m_lexer.get_position(), + exception_message(token_type::end_array, "array"))); + } + else // object + { + // comma -> next value + if (get_token() == token_type::value_separator) + { + // parse key + if (JSON_HEDLEY_UNLIKELY(get_token() != token_type::value_string)) + { + return sax->parse_error(m_lexer.get_position(), + m_lexer.get_token_string(), + parse_error::create(101, m_lexer.get_position(), + exception_message(token_type::value_string, "object key"))); + } + + if (JSON_HEDLEY_UNLIKELY(!sax->key(m_lexer.get_string()))) + { + return false; + } + + // parse separator (:) + if (JSON_HEDLEY_UNLIKELY(get_token() != token_type::name_separator)) + { + return sax->parse_error(m_lexer.get_position(), + m_lexer.get_token_string(), + parse_error::create(101, m_lexer.get_position(), + exception_message(token_type::name_separator, "object separator"))); + } + + // parse values + get_token(); + continue; + } + + // closing } + if (JSON_HEDLEY_LIKELY(last_token == token_type::end_object)) + { + if (JSON_HEDLEY_UNLIKELY(!sax->end_object())) + { + return false; + } + + // We are done with this object. Before we can parse a + // new value, we need to evaluate the new state first. + // By setting skip_to_state_evaluation to false, we + // are effectively jumping to the beginning of this if. + JSON_ASSERT(!states.empty()); + states.pop_back(); + skip_to_state_evaluation = true; + continue; + } + + return sax->parse_error(m_lexer.get_position(), + m_lexer.get_token_string(), + parse_error::create(101, m_lexer.get_position(), + exception_message(token_type::end_object, "object"))); + } + } + } + + /// get next token from lexer + token_type get_token() + { + return last_token = m_lexer.scan(); + } + + std::string exception_message(const token_type expected, const std::string& context) + { + std::string error_msg = "syntax error "; + + if (!context.empty()) + { + error_msg += "while parsing " + context + " "; + } + + error_msg += "- "; + + if (last_token == token_type::parse_error) + { + error_msg += std::string(m_lexer.get_error_message()) + "; last read: '" + + m_lexer.get_token_string() + "'"; + } + else + { + error_msg += "unexpected " + std::string(lexer_t::token_type_name(last_token)); + } + + if (expected != token_type::uninitialized) + { + error_msg += "; expected " + std::string(lexer_t::token_type_name(expected)); + } + + return error_msg; + } + + private: + /// callback function + const parser_callback_t callback = nullptr; + /// the type of the last read token + token_type last_token = token_type::uninitialized; + /// the lexer + lexer_t m_lexer; + /// whether to throw exceptions in case of errors + const bool allow_exceptions = true; +}; +} // namespace detail +} // namespace nlohmann + +// #include + + +// #include + + +#include // ptrdiff_t +#include // numeric_limits + +namespace nlohmann +{ +namespace detail +{ +/* +@brief an iterator for primitive JSON types + +This class models an iterator for primitive JSON types (boolean, number, +string). It's only purpose is to allow the iterator/const_iterator classes +to "iterate" over primitive values. Internally, the iterator is modeled by +a `difference_type` variable. Value begin_value (`0`) models the begin, +end_value (`1`) models past the end. +*/ +class primitive_iterator_t +{ + private: + using difference_type = std::ptrdiff_t; + static constexpr difference_type begin_value = 0; + static constexpr difference_type end_value = begin_value + 1; + + /// iterator as signed integer type + difference_type m_it = (std::numeric_limits::min)(); + + public: + constexpr difference_type get_value() const noexcept + { + return m_it; + } + + /// set iterator to a defined beginning + void set_begin() noexcept + { + m_it = begin_value; + } + + /// set iterator to a defined past the end + void set_end() noexcept + { + m_it = end_value; + } + + /// return whether the iterator can be dereferenced + constexpr bool is_begin() const noexcept + { + return m_it == begin_value; + } + + /// return whether the iterator is at end + constexpr bool is_end() const noexcept + { + return m_it == end_value; + } + + friend constexpr bool operator==(primitive_iterator_t lhs, primitive_iterator_t rhs) noexcept + { + return lhs.m_it == rhs.m_it; + } + + friend constexpr bool operator<(primitive_iterator_t lhs, primitive_iterator_t rhs) noexcept + { + return lhs.m_it < rhs.m_it; + } + + primitive_iterator_t operator+(difference_type n) noexcept + { + auto result = *this; + result += n; + return result; + } + + friend constexpr difference_type operator-(primitive_iterator_t lhs, primitive_iterator_t rhs) noexcept + { + return lhs.m_it - rhs.m_it; + } + + primitive_iterator_t& operator++() noexcept + { + ++m_it; + return *this; + } + + primitive_iterator_t const operator++(int) noexcept + { + auto result = *this; + ++m_it; + return result; + } + + primitive_iterator_t& operator--() noexcept + { + --m_it; + return *this; + } + + primitive_iterator_t const operator--(int) noexcept + { + auto result = *this; + --m_it; + return result; + } + + primitive_iterator_t& operator+=(difference_type n) noexcept + { + m_it += n; + return *this; + } + + primitive_iterator_t& operator-=(difference_type n) noexcept + { + m_it -= n; + return *this; + } +}; +} // namespace detail +} // namespace nlohmann + + +namespace nlohmann +{ +namespace detail +{ +/*! +@brief an iterator value + +@note This structure could easily be a union, but MSVC currently does not allow +unions members with complex constructors, see https://github.com/nlohmann/json/pull/105. +*/ +template struct internal_iterator +{ + /// iterator for JSON objects + typename BasicJsonType::object_t::iterator object_iterator {}; + /// iterator for JSON arrays + typename BasicJsonType::array_t::iterator array_iterator {}; + /// generic iterator for all other types + primitive_iterator_t primitive_iterator {}; +}; +} // namespace detail +} // namespace nlohmann + +// #include + + +#include // iterator, random_access_iterator_tag, bidirectional_iterator_tag, advance, next +#include // conditional, is_const, remove_const + +// #include + +// #include + +// #include + +// #include + +// #include + +// #include + +// #include + + +namespace nlohmann +{ +namespace detail +{ +// forward declare, to be able to friend it later on +template class iteration_proxy; +template class iteration_proxy_value; + +/*! +@brief a template for a bidirectional iterator for the @ref basic_json class +This class implements a both iterators (iterator and const_iterator) for the +@ref basic_json class. +@note An iterator is called *initialized* when a pointer to a JSON value has + been set (e.g., by a constructor or a copy assignment). If the iterator is + default-constructed, it is *uninitialized* and most methods are undefined. + **The library uses assertions to detect calls on uninitialized iterators.** +@requirement The class satisfies the following concept requirements: +- +[BidirectionalIterator](https://en.cppreference.com/w/cpp/named_req/BidirectionalIterator): + The iterator that can be moved can be moved in both directions (i.e. + incremented and decremented). +@since version 1.0.0, simplified in version 2.0.9, change to bidirectional + iterators in version 3.0.0 (see https://github.com/nlohmann/json/issues/593) +*/ +template +class iter_impl +{ + /// allow basic_json to access private members + friend iter_impl::value, typename std::remove_const::type, const BasicJsonType>::type>; + friend BasicJsonType; + friend iteration_proxy; + friend iteration_proxy_value; + + using object_t = typename BasicJsonType::object_t; + using array_t = typename BasicJsonType::array_t; + // make sure BasicJsonType is basic_json or const basic_json + static_assert(is_basic_json::type>::value, + "iter_impl only accepts (const) basic_json"); + + public: + + /// The std::iterator class template (used as a base class to provide typedefs) is deprecated in C++17. + /// The C++ Standard has never required user-defined iterators to derive from std::iterator. + /// A user-defined iterator should provide publicly accessible typedefs named + /// iterator_category, value_type, difference_type, pointer, and reference. + /// Note that value_type is required to be non-const, even for constant iterators. + using iterator_category = std::bidirectional_iterator_tag; + + /// the type of the values when the iterator is dereferenced + using value_type = typename BasicJsonType::value_type; + /// a type to represent differences between iterators + using difference_type = typename BasicJsonType::difference_type; + /// defines a pointer to the type iterated over (value_type) + using pointer = typename std::conditional::value, + typename BasicJsonType::const_pointer, + typename BasicJsonType::pointer>::type; + /// defines a reference to the type iterated over (value_type) + using reference = + typename std::conditional::value, + typename BasicJsonType::const_reference, + typename BasicJsonType::reference>::type; + + /// default constructor + iter_impl() = default; + + /*! + @brief constructor for a given JSON instance + @param[in] object pointer to a JSON object for this iterator + @pre object != nullptr + @post The iterator is initialized; i.e. `m_object != nullptr`. + */ + explicit iter_impl(pointer object) noexcept : m_object(object) + { + JSON_ASSERT(m_object != nullptr); + + switch (m_object->m_type) + { + case value_t::object: + { + m_it.object_iterator = typename object_t::iterator(); + break; + } + + case value_t::array: + { + m_it.array_iterator = typename array_t::iterator(); + break; + } + + default: + { + m_it.primitive_iterator = primitive_iterator_t(); + break; + } + } + } + + /*! + @note The conventional copy constructor and copy assignment are implicitly + defined. Combined with the following converting constructor and + assignment, they support: (1) copy from iterator to iterator, (2) + copy from const iterator to const iterator, and (3) conversion from + iterator to const iterator. However conversion from const iterator + to iterator is not defined. + */ + + /*! + @brief const copy constructor + @param[in] other const iterator to copy from + @note This copy constructor had to be defined explicitly to circumvent a bug + occurring on msvc v19.0 compiler (VS 2015) debug build. For more + information refer to: https://github.com/nlohmann/json/issues/1608 + */ + iter_impl(const iter_impl& other) noexcept + : m_object(other.m_object), m_it(other.m_it) + {} + + /*! + @brief converting assignment + @param[in] other const iterator to copy from + @return const/non-const iterator + @note It is not checked whether @a other is initialized. + */ + iter_impl& operator=(const iter_impl& other) noexcept + { + m_object = other.m_object; + m_it = other.m_it; + return *this; + } + + /*! + @brief converting constructor + @param[in] other non-const iterator to copy from + @note It is not checked whether @a other is initialized. + */ + iter_impl(const iter_impl::type>& other) noexcept + : m_object(other.m_object), m_it(other.m_it) + {} + + /*! + @brief converting assignment + @param[in] other non-const iterator to copy from + @return const/non-const iterator + @note It is not checked whether @a other is initialized. + */ + iter_impl& operator=(const iter_impl::type>& other) noexcept + { + m_object = other.m_object; + m_it = other.m_it; + return *this; + } + + private: + /*! + @brief set the iterator to the first value + @pre The iterator is initialized; i.e. `m_object != nullptr`. + */ + void set_begin() noexcept + { + JSON_ASSERT(m_object != nullptr); + + switch (m_object->m_type) + { + case value_t::object: + { + m_it.object_iterator = m_object->m_value.object->begin(); + break; + } + + case value_t::array: + { + m_it.array_iterator = m_object->m_value.array->begin(); + break; + } + + case value_t::null: + { + // set to end so begin()==end() is true: null is empty + m_it.primitive_iterator.set_end(); + break; + } + + default: + { + m_it.primitive_iterator.set_begin(); + break; + } + } + } + + /*! + @brief set the iterator past the last value + @pre The iterator is initialized; i.e. `m_object != nullptr`. + */ + void set_end() noexcept + { + JSON_ASSERT(m_object != nullptr); + + switch (m_object->m_type) + { + case value_t::object: + { + m_it.object_iterator = m_object->m_value.object->end(); + break; + } + + case value_t::array: + { + m_it.array_iterator = m_object->m_value.array->end(); + break; + } + + default: + { + m_it.primitive_iterator.set_end(); + break; + } + } + } + + public: + /*! + @brief return a reference to the value pointed to by the iterator + @pre The iterator is initialized; i.e. `m_object != nullptr`. + */ + reference operator*() const + { + JSON_ASSERT(m_object != nullptr); + + switch (m_object->m_type) + { + case value_t::object: + { + JSON_ASSERT(m_it.object_iterator != m_object->m_value.object->end()); + return m_it.object_iterator->second; + } + + case value_t::array: + { + JSON_ASSERT(m_it.array_iterator != m_object->m_value.array->end()); + return *m_it.array_iterator; + } + + case value_t::null: + JSON_THROW(invalid_iterator::create(214, "cannot get value")); + + default: + { + if (JSON_HEDLEY_LIKELY(m_it.primitive_iterator.is_begin())) + { + return *m_object; + } + + JSON_THROW(invalid_iterator::create(214, "cannot get value")); + } + } + } + + /*! + @brief dereference the iterator + @pre The iterator is initialized; i.e. `m_object != nullptr`. + */ + pointer operator->() const + { + JSON_ASSERT(m_object != nullptr); + + switch (m_object->m_type) + { + case value_t::object: + { + JSON_ASSERT(m_it.object_iterator != m_object->m_value.object->end()); + return &(m_it.object_iterator->second); + } + + case value_t::array: + { + JSON_ASSERT(m_it.array_iterator != m_object->m_value.array->end()); + return &*m_it.array_iterator; + } + + default: + { + if (JSON_HEDLEY_LIKELY(m_it.primitive_iterator.is_begin())) + { + return m_object; + } + + JSON_THROW(invalid_iterator::create(214, "cannot get value")); + } + } + } + + /*! + @brief post-increment (it++) + @pre The iterator is initialized; i.e. `m_object != nullptr`. + */ + iter_impl const operator++(int) + { + auto result = *this; + ++(*this); + return result; + } + + /*! + @brief pre-increment (++it) + @pre The iterator is initialized; i.e. `m_object != nullptr`. + */ + iter_impl& operator++() + { + JSON_ASSERT(m_object != nullptr); + + switch (m_object->m_type) + { + case value_t::object: + { + std::advance(m_it.object_iterator, 1); + break; + } + + case value_t::array: + { + std::advance(m_it.array_iterator, 1); + break; + } + + default: + { + ++m_it.primitive_iterator; + break; + } + } + + return *this; + } + + /*! + @brief post-decrement (it--) + @pre The iterator is initialized; i.e. `m_object != nullptr`. + */ + iter_impl const operator--(int) + { + auto result = *this; + --(*this); + return result; + } + + /*! + @brief pre-decrement (--it) + @pre The iterator is initialized; i.e. `m_object != nullptr`. + */ + iter_impl& operator--() + { + JSON_ASSERT(m_object != nullptr); + + switch (m_object->m_type) + { + case value_t::object: + { + std::advance(m_it.object_iterator, -1); + break; + } + + case value_t::array: + { + std::advance(m_it.array_iterator, -1); + break; + } + + default: + { + --m_it.primitive_iterator; + break; + } + } + + return *this; + } + + /*! + @brief comparison: equal + @pre The iterator is initialized; i.e. `m_object != nullptr`. + */ + bool operator==(const iter_impl& other) const + { + // if objects are not the same, the comparison is undefined + if (JSON_HEDLEY_UNLIKELY(m_object != other.m_object)) + { + JSON_THROW(invalid_iterator::create(212, "cannot compare iterators of different containers")); + } + + JSON_ASSERT(m_object != nullptr); + + switch (m_object->m_type) + { + case value_t::object: + return (m_it.object_iterator == other.m_it.object_iterator); + + case value_t::array: + return (m_it.array_iterator == other.m_it.array_iterator); + + default: + return (m_it.primitive_iterator == other.m_it.primitive_iterator); + } + } + + /*! + @brief comparison: not equal + @pre The iterator is initialized; i.e. `m_object != nullptr`. + */ + bool operator!=(const iter_impl& other) const + { + return !operator==(other); + } + + /*! + @brief comparison: smaller + @pre The iterator is initialized; i.e. `m_object != nullptr`. + */ + bool operator<(const iter_impl& other) const + { + // if objects are not the same, the comparison is undefined + if (JSON_HEDLEY_UNLIKELY(m_object != other.m_object)) + { + JSON_THROW(invalid_iterator::create(212, "cannot compare iterators of different containers")); + } + + JSON_ASSERT(m_object != nullptr); + + switch (m_object->m_type) + { + case value_t::object: + JSON_THROW(invalid_iterator::create(213, "cannot compare order of object iterators")); + + case value_t::array: + return (m_it.array_iterator < other.m_it.array_iterator); + + default: + return (m_it.primitive_iterator < other.m_it.primitive_iterator); + } + } + + /*! + @brief comparison: less than or equal + @pre The iterator is initialized; i.e. `m_object != nullptr`. + */ + bool operator<=(const iter_impl& other) const + { + return !other.operator < (*this); + } + + /*! + @brief comparison: greater than + @pre The iterator is initialized; i.e. `m_object != nullptr`. + */ + bool operator>(const iter_impl& other) const + { + return !operator<=(other); + } + + /*! + @brief comparison: greater than or equal + @pre The iterator is initialized; i.e. `m_object != nullptr`. + */ + bool operator>=(const iter_impl& other) const + { + return !operator<(other); + } + + /*! + @brief add to iterator + @pre The iterator is initialized; i.e. `m_object != nullptr`. + */ + iter_impl& operator+=(difference_type i) + { + JSON_ASSERT(m_object != nullptr); + + switch (m_object->m_type) + { + case value_t::object: + JSON_THROW(invalid_iterator::create(209, "cannot use offsets with object iterators")); + + case value_t::array: + { + std::advance(m_it.array_iterator, i); + break; + } + + default: + { + m_it.primitive_iterator += i; + break; + } + } + + return *this; + } + + /*! + @brief subtract from iterator + @pre The iterator is initialized; i.e. `m_object != nullptr`. + */ + iter_impl& operator-=(difference_type i) + { + return operator+=(-i); + } + + /*! + @brief add to iterator + @pre The iterator is initialized; i.e. `m_object != nullptr`. + */ + iter_impl operator+(difference_type i) const + { + auto result = *this; + result += i; + return result; + } + + /*! + @brief addition of distance and iterator + @pre The iterator is initialized; i.e. `m_object != nullptr`. + */ + friend iter_impl operator+(difference_type i, const iter_impl& it) + { + auto result = it; + result += i; + return result; + } + + /*! + @brief subtract from iterator + @pre The iterator is initialized; i.e. `m_object != nullptr`. + */ + iter_impl operator-(difference_type i) const + { + auto result = *this; + result -= i; + return result; + } + + /*! + @brief return difference + @pre The iterator is initialized; i.e. `m_object != nullptr`. + */ + difference_type operator-(const iter_impl& other) const + { + JSON_ASSERT(m_object != nullptr); + + switch (m_object->m_type) + { + case value_t::object: + JSON_THROW(invalid_iterator::create(209, "cannot use offsets with object iterators")); + + case value_t::array: + return m_it.array_iterator - other.m_it.array_iterator; + + default: + return m_it.primitive_iterator - other.m_it.primitive_iterator; + } + } + + /*! + @brief access to successor + @pre The iterator is initialized; i.e. `m_object != nullptr`. + */ + reference operator[](difference_type n) const + { + JSON_ASSERT(m_object != nullptr); + + switch (m_object->m_type) + { + case value_t::object: + JSON_THROW(invalid_iterator::create(208, "cannot use operator[] for object iterators")); + + case value_t::array: + return *std::next(m_it.array_iterator, n); + + case value_t::null: + JSON_THROW(invalid_iterator::create(214, "cannot get value")); + + default: + { + if (JSON_HEDLEY_LIKELY(m_it.primitive_iterator.get_value() == -n)) + { + return *m_object; + } + + JSON_THROW(invalid_iterator::create(214, "cannot get value")); + } + } + } + + /*! + @brief return the key of an object iterator + @pre The iterator is initialized; i.e. `m_object != nullptr`. + */ + const typename object_t::key_type& key() const + { + JSON_ASSERT(m_object != nullptr); + + if (JSON_HEDLEY_LIKELY(m_object->is_object())) + { + return m_it.object_iterator->first; + } + + JSON_THROW(invalid_iterator::create(207, "cannot use key() for non-object iterators")); + } + + /*! + @brief return the value of an iterator + @pre The iterator is initialized; i.e. `m_object != nullptr`. + */ + reference value() const + { + return operator*(); + } + + private: + /// associated JSON instance + pointer m_object = nullptr; + /// the actual iterator of the associated instance + internal_iterator::type> m_it {}; +}; +} // namespace detail +} // namespace nlohmann + +// #include + +// #include + + +#include // ptrdiff_t +#include // reverse_iterator +#include // declval + +namespace nlohmann +{ +namespace detail +{ +////////////////////// +// reverse_iterator // +////////////////////// + +/*! +@brief a template for a reverse iterator class + +@tparam Base the base iterator type to reverse. Valid types are @ref +iterator (to create @ref reverse_iterator) and @ref const_iterator (to +create @ref const_reverse_iterator). + +@requirement The class satisfies the following concept requirements: +- +[BidirectionalIterator](https://en.cppreference.com/w/cpp/named_req/BidirectionalIterator): + The iterator that can be moved can be moved in both directions (i.e. + incremented and decremented). +- [OutputIterator](https://en.cppreference.com/w/cpp/named_req/OutputIterator): + It is possible to write to the pointed-to element (only if @a Base is + @ref iterator). + +@since version 1.0.0 +*/ +template +class json_reverse_iterator : public std::reverse_iterator +{ + public: + using difference_type = std::ptrdiff_t; + /// shortcut to the reverse iterator adapter + using base_iterator = std::reverse_iterator; + /// the reference type for the pointed-to element + using reference = typename Base::reference; + + /// create reverse iterator from iterator + explicit json_reverse_iterator(const typename base_iterator::iterator_type& it) noexcept + : base_iterator(it) {} + + /// create reverse iterator from base class + explicit json_reverse_iterator(const base_iterator& it) noexcept : base_iterator(it) {} + + /// post-increment (it++) + json_reverse_iterator const operator++(int) + { + return static_cast(base_iterator::operator++(1)); + } + + /// pre-increment (++it) + json_reverse_iterator& operator++() + { + return static_cast(base_iterator::operator++()); + } + + /// post-decrement (it--) + json_reverse_iterator const operator--(int) + { + return static_cast(base_iterator::operator--(1)); + } + + /// pre-decrement (--it) + json_reverse_iterator& operator--() + { + return static_cast(base_iterator::operator--()); + } + + /// add to iterator + json_reverse_iterator& operator+=(difference_type i) + { + return static_cast(base_iterator::operator+=(i)); + } + + /// add to iterator + json_reverse_iterator operator+(difference_type i) const + { + return static_cast(base_iterator::operator+(i)); + } + + /// subtract from iterator + json_reverse_iterator operator-(difference_type i) const + { + return static_cast(base_iterator::operator-(i)); + } + + /// return difference + difference_type operator-(const json_reverse_iterator& other) const + { + return base_iterator(*this) - base_iterator(other); + } + + /// access to successor + reference operator[](difference_type n) const + { + return *(this->operator+(n)); + } + + /// return the key of an object iterator + auto key() const -> decltype(std::declval().key()) + { + auto it = --this->base(); + return it.key(); + } + + /// return the value of an iterator + reference value() const + { + auto it = --this->base(); + return it.operator * (); + } +}; +} // namespace detail +} // namespace nlohmann + +// #include + +// #include + + +#include // all_of +#include // isdigit +#include // max +#include // accumulate +#include // string +#include // move +#include // vector + +// #include + +// #include + +// #include + + +namespace nlohmann +{ +template +class json_pointer +{ + // allow basic_json to access private members + NLOHMANN_BASIC_JSON_TPL_DECLARATION + friend class basic_json; + + public: + /*! + @brief create JSON pointer + + Create a JSON pointer according to the syntax described in + [Section 3 of RFC6901](https://tools.ietf.org/html/rfc6901#section-3). + + @param[in] s string representing the JSON pointer; if omitted, the empty + string is assumed which references the whole JSON value + + @throw parse_error.107 if the given JSON pointer @a s is nonempty and does + not begin with a slash (`/`); see example below + + @throw parse_error.108 if a tilde (`~`) in the given JSON pointer @a s is + not followed by `0` (representing `~`) or `1` (representing `/`); see + example below + + @liveexample{The example shows the construction several valid JSON pointers + as well as the exceptional behavior.,json_pointer} + + @since version 2.0.0 + */ + explicit json_pointer(const std::string& s = "") + : reference_tokens(split(s)) + {} + + /*! + @brief return a string representation of the JSON pointer + + @invariant For each JSON pointer `ptr`, it holds: + @code {.cpp} + ptr == json_pointer(ptr.to_string()); + @endcode + + @return a string representation of the JSON pointer + + @liveexample{The example shows the result of `to_string`.,json_pointer__to_string} + + @since version 2.0.0 + */ + std::string to_string() const + { + return std::accumulate(reference_tokens.begin(), reference_tokens.end(), + std::string{}, + [](const std::string & a, const std::string & b) + { + return a + "/" + escape(b); + }); + } + + /// @copydoc to_string() + operator std::string() const + { + return to_string(); + } + + /*! + @brief append another JSON pointer at the end of this JSON pointer + + @param[in] ptr JSON pointer to append + @return JSON pointer with @a ptr appended + + @liveexample{The example shows the usage of `operator/=`.,json_pointer__operator_add} + + @complexity Linear in the length of @a ptr. + + @sa @ref operator/=(std::string) to append a reference token + @sa @ref operator/=(std::size_t) to append an array index + @sa @ref operator/(const json_pointer&, const json_pointer&) for a binary operator + + @since version 3.6.0 + */ + json_pointer& operator/=(const json_pointer& ptr) + { + reference_tokens.insert(reference_tokens.end(), + ptr.reference_tokens.begin(), + ptr.reference_tokens.end()); + return *this; + } + + /*! + @brief append an unescaped reference token at the end of this JSON pointer + + @param[in] token reference token to append + @return JSON pointer with @a token appended without escaping @a token + + @liveexample{The example shows the usage of `operator/=`.,json_pointer__operator_add} + + @complexity Amortized constant. + + @sa @ref operator/=(const json_pointer&) to append a JSON pointer + @sa @ref operator/=(std::size_t) to append an array index + @sa @ref operator/(const json_pointer&, std::size_t) for a binary operator + + @since version 3.6.0 + */ + json_pointer& operator/=(std::string token) + { + push_back(std::move(token)); + return *this; + } + + /*! + @brief append an array index at the end of this JSON pointer + + @param[in] array_idx array index to append + @return JSON pointer with @a array_idx appended + + @liveexample{The example shows the usage of `operator/=`.,json_pointer__operator_add} + + @complexity Amortized constant. + + @sa @ref operator/=(const json_pointer&) to append a JSON pointer + @sa @ref operator/=(std::string) to append a reference token + @sa @ref operator/(const json_pointer&, std::string) for a binary operator + + @since version 3.6.0 + */ + json_pointer& operator/=(std::size_t array_idx) + { + return *this /= std::to_string(array_idx); + } + + /*! + @brief create a new JSON pointer by appending the right JSON pointer at the end of the left JSON pointer + + @param[in] lhs JSON pointer + @param[in] rhs JSON pointer + @return a new JSON pointer with @a rhs appended to @a lhs + + @liveexample{The example shows the usage of `operator/`.,json_pointer__operator_add_binary} + + @complexity Linear in the length of @a lhs and @a rhs. + + @sa @ref operator/=(const json_pointer&) to append a JSON pointer + + @since version 3.6.0 + */ + friend json_pointer operator/(const json_pointer& lhs, + const json_pointer& rhs) + { + return json_pointer(lhs) /= rhs; + } + + /*! + @brief create a new JSON pointer by appending the unescaped token at the end of the JSON pointer + + @param[in] ptr JSON pointer + @param[in] token reference token + @return a new JSON pointer with unescaped @a token appended to @a ptr + + @liveexample{The example shows the usage of `operator/`.,json_pointer__operator_add_binary} + + @complexity Linear in the length of @a ptr. + + @sa @ref operator/=(std::string) to append a reference token + + @since version 3.6.0 + */ + friend json_pointer operator/(const json_pointer& ptr, std::string token) + { + return json_pointer(ptr) /= std::move(token); + } + + /*! + @brief create a new JSON pointer by appending the array-index-token at the end of the JSON pointer + + @param[in] ptr JSON pointer + @param[in] array_idx array index + @return a new JSON pointer with @a array_idx appended to @a ptr + + @liveexample{The example shows the usage of `operator/`.,json_pointer__operator_add_binary} + + @complexity Linear in the length of @a ptr. + + @sa @ref operator/=(std::size_t) to append an array index + + @since version 3.6.0 + */ + friend json_pointer operator/(const json_pointer& ptr, std::size_t array_idx) + { + return json_pointer(ptr) /= array_idx; + } + + /*! + @brief returns the parent of this JSON pointer + + @return parent of this JSON pointer; in case this JSON pointer is the root, + the root itself is returned + + @complexity Linear in the length of the JSON pointer. + + @liveexample{The example shows the result of `parent_pointer` for different + JSON Pointers.,json_pointer__parent_pointer} + + @since version 3.6.0 + */ + json_pointer parent_pointer() const + { + if (empty()) + { + return *this; + } + + json_pointer res = *this; + res.pop_back(); + return res; + } + + /*! + @brief remove last reference token + + @pre not `empty()` + + @liveexample{The example shows the usage of `pop_back`.,json_pointer__pop_back} + + @complexity Constant. + + @throw out_of_range.405 if JSON pointer has no parent + + @since version 3.6.0 + */ + void pop_back() + { + if (JSON_HEDLEY_UNLIKELY(empty())) + { + JSON_THROW(detail::out_of_range::create(405, "JSON pointer has no parent")); + } + + reference_tokens.pop_back(); + } + + /*! + @brief return last reference token + + @pre not `empty()` + @return last reference token + + @liveexample{The example shows the usage of `back`.,json_pointer__back} + + @complexity Constant. + + @throw out_of_range.405 if JSON pointer has no parent + + @since version 3.6.0 + */ + const std::string& back() const + { + if (JSON_HEDLEY_UNLIKELY(empty())) + { + JSON_THROW(detail::out_of_range::create(405, "JSON pointer has no parent")); + } + + return reference_tokens.back(); + } + + /*! + @brief append an unescaped token at the end of the reference pointer + + @param[in] token token to add + + @complexity Amortized constant. + + @liveexample{The example shows the result of `push_back` for different + JSON Pointers.,json_pointer__push_back} + + @since version 3.6.0 + */ + void push_back(const std::string& token) + { + reference_tokens.push_back(token); + } + + /// @copydoc push_back(const std::string&) + void push_back(std::string&& token) + { + reference_tokens.push_back(std::move(token)); + } + + /*! + @brief return whether pointer points to the root document + + @return true iff the JSON pointer points to the root document + + @complexity Constant. + + @exceptionsafety No-throw guarantee: this function never throws exceptions. + + @liveexample{The example shows the result of `empty` for different JSON + Pointers.,json_pointer__empty} + + @since version 3.6.0 + */ + bool empty() const noexcept + { + return reference_tokens.empty(); + } + + private: + /*! + @param[in] s reference token to be converted into an array index + + @return integer representation of @a s + + @throw parse_error.106 if an array index begins with '0' + @throw parse_error.109 if an array index begins not with a digit + @throw out_of_range.404 if string @a s could not be converted to an integer + @throw out_of_range.410 if an array index exceeds size_type + */ + static typename BasicJsonType::size_type array_index(const std::string& s) + { + using size_type = typename BasicJsonType::size_type; + + // error condition (cf. RFC 6901, Sect. 4) + if (JSON_HEDLEY_UNLIKELY(s.size() > 1 && s[0] == '0')) + { + JSON_THROW(detail::parse_error::create(106, 0, + "array index '" + s + + "' must not begin with '0'")); + } + + // error condition (cf. RFC 6901, Sect. 4) + if (JSON_HEDLEY_UNLIKELY(s.size() > 1 && !(s[0] >= '1' && s[0] <= '9'))) + { + JSON_THROW(detail::parse_error::create(109, 0, "array index '" + s + "' is not a number")); + } + + std::size_t processed_chars = 0; + unsigned long long res = 0; + JSON_TRY + { + res = std::stoull(s, &processed_chars); + } + JSON_CATCH(std::out_of_range&) + { + JSON_THROW(detail::out_of_range::create(404, "unresolved reference token '" + s + "'")); + } + + // check if the string was completely read + if (JSON_HEDLEY_UNLIKELY(processed_chars != s.size())) + { + JSON_THROW(detail::out_of_range::create(404, "unresolved reference token '" + s + "'")); + } + + // only triggered on special platforms (like 32bit), see also + // https://github.com/nlohmann/json/pull/2203 + if (res >= static_cast((std::numeric_limits::max)())) + { + JSON_THROW(detail::out_of_range::create(410, "array index " + s + " exceeds size_type")); // LCOV_EXCL_LINE + } + + return static_cast(res); + } + + json_pointer top() const + { + if (JSON_HEDLEY_UNLIKELY(empty())) + { + JSON_THROW(detail::out_of_range::create(405, "JSON pointer has no parent")); + } + + json_pointer result = *this; + result.reference_tokens = {reference_tokens[0]}; + return result; + } + + /*! + @brief create and return a reference to the pointed to value + + @complexity Linear in the number of reference tokens. + + @throw parse_error.109 if array index is not a number + @throw type_error.313 if value cannot be unflattened + */ + BasicJsonType& get_and_create(BasicJsonType& j) const + { + auto result = &j; + + // in case no reference tokens exist, return a reference to the JSON value + // j which will be overwritten by a primitive value + for (const auto& reference_token : reference_tokens) + { + switch (result->type()) + { + case detail::value_t::null: + { + if (reference_token == "0") + { + // start a new array if reference token is 0 + result = &result->operator[](0); + } + else + { + // start a new object otherwise + result = &result->operator[](reference_token); + } + break; + } + + case detail::value_t::object: + { + // create an entry in the object + result = &result->operator[](reference_token); + break; + } + + case detail::value_t::array: + { + // create an entry in the array + result = &result->operator[](array_index(reference_token)); + break; + } + + /* + The following code is only reached if there exists a reference + token _and_ the current value is primitive. In this case, we have + an error situation, because primitive values may only occur as + single value; that is, with an empty list of reference tokens. + */ + default: + JSON_THROW(detail::type_error::create(313, "invalid value to unflatten")); + } + } + + return *result; + } + + /*! + @brief return a reference to the pointed to value + + @note This version does not throw if a value is not present, but tries to + create nested values instead. For instance, calling this function + with pointer `"/this/that"` on a null value is equivalent to calling + `operator[]("this").operator[]("that")` on that value, effectively + changing the null value to an object. + + @param[in] ptr a JSON value + + @return reference to the JSON value pointed to by the JSON pointer + + @complexity Linear in the length of the JSON pointer. + + @throw parse_error.106 if an array index begins with '0' + @throw parse_error.109 if an array index was not a number + @throw out_of_range.404 if the JSON pointer can not be resolved + */ + BasicJsonType& get_unchecked(BasicJsonType* ptr) const + { + for (const auto& reference_token : reference_tokens) + { + // convert null values to arrays or objects before continuing + if (ptr->is_null()) + { + // check if reference token is a number + const bool nums = + std::all_of(reference_token.begin(), reference_token.end(), + [](const unsigned char x) + { + return std::isdigit(x); + }); + + // change value to array for numbers or "-" or to object otherwise + *ptr = (nums || reference_token == "-") + ? detail::value_t::array + : detail::value_t::object; + } + + switch (ptr->type()) + { + case detail::value_t::object: + { + // use unchecked object access + ptr = &ptr->operator[](reference_token); + break; + } + + case detail::value_t::array: + { + if (reference_token == "-") + { + // explicitly treat "-" as index beyond the end + ptr = &ptr->operator[](ptr->m_value.array->size()); + } + else + { + // convert array index to number; unchecked access + ptr = &ptr->operator[](array_index(reference_token)); + } + break; + } + + default: + JSON_THROW(detail::out_of_range::create(404, "unresolved reference token '" + reference_token + "'")); + } + } + + return *ptr; + } + + /*! + @throw parse_error.106 if an array index begins with '0' + @throw parse_error.109 if an array index was not a number + @throw out_of_range.402 if the array index '-' is used + @throw out_of_range.404 if the JSON pointer can not be resolved + */ + BasicJsonType& get_checked(BasicJsonType* ptr) const + { + for (const auto& reference_token : reference_tokens) + { + switch (ptr->type()) + { + case detail::value_t::object: + { + // note: at performs range check + ptr = &ptr->at(reference_token); + break; + } + + case detail::value_t::array: + { + if (JSON_HEDLEY_UNLIKELY(reference_token == "-")) + { + // "-" always fails the range check + JSON_THROW(detail::out_of_range::create(402, + "array index '-' (" + std::to_string(ptr->m_value.array->size()) + + ") is out of range")); + } + + // note: at performs range check + ptr = &ptr->at(array_index(reference_token)); + break; + } + + default: + JSON_THROW(detail::out_of_range::create(404, "unresolved reference token '" + reference_token + "'")); + } + } + + return *ptr; + } + + /*! + @brief return a const reference to the pointed to value + + @param[in] ptr a JSON value + + @return const reference to the JSON value pointed to by the JSON + pointer + + @throw parse_error.106 if an array index begins with '0' + @throw parse_error.109 if an array index was not a number + @throw out_of_range.402 if the array index '-' is used + @throw out_of_range.404 if the JSON pointer can not be resolved + */ + const BasicJsonType& get_unchecked(const BasicJsonType* ptr) const + { + for (const auto& reference_token : reference_tokens) + { + switch (ptr->type()) + { + case detail::value_t::object: + { + // use unchecked object access + ptr = &ptr->operator[](reference_token); + break; + } + + case detail::value_t::array: + { + if (JSON_HEDLEY_UNLIKELY(reference_token == "-")) + { + // "-" cannot be used for const access + JSON_THROW(detail::out_of_range::create(402, + "array index '-' (" + std::to_string(ptr->m_value.array->size()) + + ") is out of range")); + } + + // use unchecked array access + ptr = &ptr->operator[](array_index(reference_token)); + break; + } + + default: + JSON_THROW(detail::out_of_range::create(404, "unresolved reference token '" + reference_token + "'")); + } + } + + return *ptr; + } + + /*! + @throw parse_error.106 if an array index begins with '0' + @throw parse_error.109 if an array index was not a number + @throw out_of_range.402 if the array index '-' is used + @throw out_of_range.404 if the JSON pointer can not be resolved + */ + const BasicJsonType& get_checked(const BasicJsonType* ptr) const + { + for (const auto& reference_token : reference_tokens) + { + switch (ptr->type()) + { + case detail::value_t::object: + { + // note: at performs range check + ptr = &ptr->at(reference_token); + break; + } + + case detail::value_t::array: + { + if (JSON_HEDLEY_UNLIKELY(reference_token == "-")) + { + // "-" always fails the range check + JSON_THROW(detail::out_of_range::create(402, + "array index '-' (" + std::to_string(ptr->m_value.array->size()) + + ") is out of range")); + } + + // note: at performs range check + ptr = &ptr->at(array_index(reference_token)); + break; + } + + default: + JSON_THROW(detail::out_of_range::create(404, "unresolved reference token '" + reference_token + "'")); + } + } + + return *ptr; + } + + /*! + @throw parse_error.106 if an array index begins with '0' + @throw parse_error.109 if an array index was not a number + */ + bool contains(const BasicJsonType* ptr) const + { + for (const auto& reference_token : reference_tokens) + { + switch (ptr->type()) + { + case detail::value_t::object: + { + if (!ptr->contains(reference_token)) + { + // we did not find the key in the object + return false; + } + + ptr = &ptr->operator[](reference_token); + break; + } + + case detail::value_t::array: + { + if (JSON_HEDLEY_UNLIKELY(reference_token == "-")) + { + // "-" always fails the range check + return false; + } + if (JSON_HEDLEY_UNLIKELY(reference_token.size() == 1 && !("0" <= reference_token && reference_token <= "9"))) + { + // invalid char + return false; + } + if (JSON_HEDLEY_UNLIKELY(reference_token.size() > 1)) + { + if (JSON_HEDLEY_UNLIKELY(!('1' <= reference_token[0] && reference_token[0] <= '9'))) + { + // first char should be between '1' and '9' + return false; + } + for (std::size_t i = 1; i < reference_token.size(); i++) + { + if (JSON_HEDLEY_UNLIKELY(!('0' <= reference_token[i] && reference_token[i] <= '9'))) + { + // other char should be between '0' and '9' + return false; + } + } + } + + const auto idx = array_index(reference_token); + if (idx >= ptr->size()) + { + // index out of range + return false; + } + + ptr = &ptr->operator[](idx); + break; + } + + default: + { + // we do not expect primitive values if there is still a + // reference token to process + return false; + } + } + } + + // no reference token left means we found a primitive value + return true; + } + + /*! + @brief split the string input to reference tokens + + @note This function is only called by the json_pointer constructor. + All exceptions below are documented there. + + @throw parse_error.107 if the pointer is not empty or begins with '/' + @throw parse_error.108 if character '~' is not followed by '0' or '1' + */ + static std::vector split(const std::string& reference_string) + { + std::vector result; + + // special case: empty reference string -> no reference tokens + if (reference_string.empty()) + { + return result; + } + + // check if nonempty reference string begins with slash + if (JSON_HEDLEY_UNLIKELY(reference_string[0] != '/')) + { + JSON_THROW(detail::parse_error::create(107, 1, + "JSON pointer must be empty or begin with '/' - was: '" + + reference_string + "'")); + } + + // extract the reference tokens: + // - slash: position of the last read slash (or end of string) + // - start: position after the previous slash + for ( + // search for the first slash after the first character + std::size_t slash = reference_string.find_first_of('/', 1), + // set the beginning of the first reference token + start = 1; + // we can stop if start == 0 (if slash == std::string::npos) + start != 0; + // set the beginning of the next reference token + // (will eventually be 0 if slash == std::string::npos) + start = (slash == std::string::npos) ? 0 : slash + 1, + // find next slash + slash = reference_string.find_first_of('/', start)) + { + // use the text between the beginning of the reference token + // (start) and the last slash (slash). + auto reference_token = reference_string.substr(start, slash - start); + + // check reference tokens are properly escaped + for (std::size_t pos = reference_token.find_first_of('~'); + pos != std::string::npos; + pos = reference_token.find_first_of('~', pos + 1)) + { + JSON_ASSERT(reference_token[pos] == '~'); + + // ~ must be followed by 0 or 1 + if (JSON_HEDLEY_UNLIKELY(pos == reference_token.size() - 1 || + (reference_token[pos + 1] != '0' && + reference_token[pos + 1] != '1'))) + { + JSON_THROW(detail::parse_error::create(108, 0, "escape character '~' must be followed with '0' or '1'")); + } + } + + // finally, store the reference token + unescape(reference_token); + result.push_back(reference_token); + } + + return result; + } + + /*! + @brief replace all occurrences of a substring by another string + + @param[in,out] s the string to manipulate; changed so that all + occurrences of @a f are replaced with @a t + @param[in] f the substring to replace with @a t + @param[in] t the string to replace @a f + + @pre The search string @a f must not be empty. **This precondition is + enforced with an assertion.** + + @since version 2.0.0 + */ + static void replace_substring(std::string& s, const std::string& f, + const std::string& t) + { + JSON_ASSERT(!f.empty()); + for (auto pos = s.find(f); // find first occurrence of f + pos != std::string::npos; // make sure f was found + s.replace(pos, f.size(), t), // replace with t, and + pos = s.find(f, pos + t.size())) // find next occurrence of f + {} + } + + /// escape "~" to "~0" and "/" to "~1" + static std::string escape(std::string s) + { + replace_substring(s, "~", "~0"); + replace_substring(s, "/", "~1"); + return s; + } + + /// unescape "~1" to tilde and "~0" to slash (order is important!) + static void unescape(std::string& s) + { + replace_substring(s, "~1", "/"); + replace_substring(s, "~0", "~"); + } + + /*! + @param[in] reference_string the reference string to the current value + @param[in] value the value to consider + @param[in,out] result the result object to insert values to + + @note Empty objects or arrays are flattened to `null`. + */ + static void flatten(const std::string& reference_string, + const BasicJsonType& value, + BasicJsonType& result) + { + switch (value.type()) + { + case detail::value_t::array: + { + if (value.m_value.array->empty()) + { + // flatten empty array as null + result[reference_string] = nullptr; + } + else + { + // iterate array and use index as reference string + for (std::size_t i = 0; i < value.m_value.array->size(); ++i) + { + flatten(reference_string + "/" + std::to_string(i), + value.m_value.array->operator[](i), result); + } + } + break; + } + + case detail::value_t::object: + { + if (value.m_value.object->empty()) + { + // flatten empty object as null + result[reference_string] = nullptr; + } + else + { + // iterate object and use keys as reference string + for (const auto& element : *value.m_value.object) + { + flatten(reference_string + "/" + escape(element.first), element.second, result); + } + } + break; + } + + default: + { + // add primitive value with its reference string + result[reference_string] = value; + break; + } + } + } + + /*! + @param[in] value flattened JSON + + @return unflattened JSON + + @throw parse_error.109 if array index is not a number + @throw type_error.314 if value is not an object + @throw type_error.315 if object values are not primitive + @throw type_error.313 if value cannot be unflattened + */ + static BasicJsonType + unflatten(const BasicJsonType& value) + { + if (JSON_HEDLEY_UNLIKELY(!value.is_object())) + { + JSON_THROW(detail::type_error::create(314, "only objects can be unflattened")); + } + + BasicJsonType result; + + // iterate the JSON object values + for (const auto& element : *value.m_value.object) + { + if (JSON_HEDLEY_UNLIKELY(!element.second.is_primitive())) + { + JSON_THROW(detail::type_error::create(315, "values in object must be primitive")); + } + + // assign value to reference pointed to by JSON pointer; Note that if + // the JSON pointer is "" (i.e., points to the whole value), function + // get_and_create returns a reference to result itself. An assignment + // will then create a primitive value. + json_pointer(element.first).get_and_create(result) = element.second; + } + + return result; + } + + /*! + @brief compares two JSON pointers for equality + + @param[in] lhs JSON pointer to compare + @param[in] rhs JSON pointer to compare + @return whether @a lhs is equal to @a rhs + + @complexity Linear in the length of the JSON pointer + + @exceptionsafety No-throw guarantee: this function never throws exceptions. + */ + friend bool operator==(json_pointer const& lhs, + json_pointer const& rhs) noexcept + { + return lhs.reference_tokens == rhs.reference_tokens; + } + + /*! + @brief compares two JSON pointers for inequality + + @param[in] lhs JSON pointer to compare + @param[in] rhs JSON pointer to compare + @return whether @a lhs is not equal @a rhs + + @complexity Linear in the length of the JSON pointer + + @exceptionsafety No-throw guarantee: this function never throws exceptions. + */ + friend bool operator!=(json_pointer const& lhs, + json_pointer const& rhs) noexcept + { + return !(lhs == rhs); + } + + /// the reference tokens + std::vector reference_tokens; +}; +} // namespace nlohmann + +// #include + + +#include +#include + +// #include + + +namespace nlohmann +{ +namespace detail +{ +template +class json_ref +{ + public: + using value_type = BasicJsonType; + + json_ref(value_type&& value) + : owned_value(std::move(value)) + , value_ref(&owned_value) + , is_rvalue(true) + {} + + json_ref(const value_type& value) + : value_ref(const_cast(&value)) + , is_rvalue(false) + {} + + json_ref(std::initializer_list init) + : owned_value(init) + , value_ref(&owned_value) + , is_rvalue(true) + {} + + template < + class... Args, + enable_if_t::value, int> = 0 > + json_ref(Args && ... args) + : owned_value(std::forward(args)...) + , value_ref(&owned_value) + , is_rvalue(true) + {} + + // class should be movable only + json_ref(json_ref&&) = default; + json_ref(const json_ref&) = delete; + json_ref& operator=(const json_ref&) = delete; + json_ref& operator=(json_ref&&) = delete; + ~json_ref() = default; + + value_type moved_or_copied() const + { + if (is_rvalue) + { + return std::move(*value_ref); + } + return *value_ref; + } + + value_type const& operator*() const + { + return *static_cast(value_ref); + } + + value_type const* operator->() const + { + return static_cast(value_ref); + } + + private: + mutable value_type owned_value = nullptr; + value_type* value_ref = nullptr; + const bool is_rvalue = true; +}; +} // namespace detail +} // namespace nlohmann + +// #include + +// #include + +// #include + +// #include + + +#include // reverse +#include // array +#include // uint8_t, uint16_t, uint32_t, uint64_t +#include // memcpy +#include // numeric_limits +#include // string +#include // isnan, isinf + +// #include + +// #include + +// #include + + +#include // copy +#include // size_t +#include // streamsize +#include // back_inserter +#include // shared_ptr, make_shared +#include // basic_ostream +#include // basic_string +#include // vector +// #include + + +namespace nlohmann +{ +namespace detail +{ +/// abstract output adapter interface +template struct output_adapter_protocol +{ + virtual void write_character(CharType c) = 0; + virtual void write_characters(const CharType* s, std::size_t length) = 0; + virtual ~output_adapter_protocol() = default; +}; + +/// a type to simplify interfaces +template +using output_adapter_t = std::shared_ptr>; + +/// output adapter for byte vectors +template +class output_vector_adapter : public output_adapter_protocol +{ + public: + explicit output_vector_adapter(std::vector& vec) noexcept + : v(vec) + {} + + void write_character(CharType c) override + { + v.push_back(c); + } + + JSON_HEDLEY_NON_NULL(2) + void write_characters(const CharType* s, std::size_t length) override + { + std::copy(s, s + length, std::back_inserter(v)); + } + + private: + std::vector& v; +}; + +/// output adapter for output streams +template +class output_stream_adapter : public output_adapter_protocol +{ + public: + explicit output_stream_adapter(std::basic_ostream& s) noexcept + : stream(s) + {} + + void write_character(CharType c) override + { + stream.put(c); + } + + JSON_HEDLEY_NON_NULL(2) + void write_characters(const CharType* s, std::size_t length) override + { + stream.write(s, static_cast(length)); + } + + private: + std::basic_ostream& stream; +}; + +/// output adapter for basic_string +template> +class output_string_adapter : public output_adapter_protocol +{ + public: + explicit output_string_adapter(StringType& s) noexcept + : str(s) + {} + + void write_character(CharType c) override + { + str.push_back(c); + } + + JSON_HEDLEY_NON_NULL(2) + void write_characters(const CharType* s, std::size_t length) override + { + str.append(s, length); + } + + private: + StringType& str; +}; + +template> +class output_adapter +{ + public: + output_adapter(std::vector& vec) + : oa(std::make_shared>(vec)) {} + + output_adapter(std::basic_ostream& s) + : oa(std::make_shared>(s)) {} + + output_adapter(StringType& s) + : oa(std::make_shared>(s)) {} + + operator output_adapter_t() + { + return oa; + } + + private: + output_adapter_t oa = nullptr; +}; +} // namespace detail +} // namespace nlohmann + + +namespace nlohmann +{ +namespace detail +{ +/////////////////// +// binary writer // +/////////////////// + +/*! +@brief serialization to CBOR and MessagePack values +*/ +template +class binary_writer +{ + using string_t = typename BasicJsonType::string_t; + using binary_t = typename BasicJsonType::binary_t; + using number_float_t = typename BasicJsonType::number_float_t; + + public: + /*! + @brief create a binary writer + + @param[in] adapter output adapter to write to + */ + explicit binary_writer(output_adapter_t adapter) : oa(adapter) + { + JSON_ASSERT(oa); + } + + /*! + @param[in] j JSON value to serialize + @pre j.type() == value_t::object + */ + void write_bson(const BasicJsonType& j) + { + switch (j.type()) + { + case value_t::object: + { + write_bson_object(*j.m_value.object); + break; + } + + default: + { + JSON_THROW(type_error::create(317, "to serialize to BSON, top-level type must be object, but is " + std::string(j.type_name()))); + } + } + } + + /*! + @param[in] j JSON value to serialize + */ + void write_cbor(const BasicJsonType& j) + { + switch (j.type()) + { + case value_t::null: + { + oa->write_character(to_char_type(0xF6)); + break; + } + + case value_t::boolean: + { + oa->write_character(j.m_value.boolean + ? to_char_type(0xF5) + : to_char_type(0xF4)); + break; + } + + case value_t::number_integer: + { + if (j.m_value.number_integer >= 0) + { + // CBOR does not differentiate between positive signed + // integers and unsigned integers. Therefore, we used the + // code from the value_t::number_unsigned case here. + if (j.m_value.number_integer <= 0x17) + { + write_number(static_cast(j.m_value.number_integer)); + } + else if (j.m_value.number_integer <= (std::numeric_limits::max)()) + { + oa->write_character(to_char_type(0x18)); + write_number(static_cast(j.m_value.number_integer)); + } + else if (j.m_value.number_integer <= (std::numeric_limits::max)()) + { + oa->write_character(to_char_type(0x19)); + write_number(static_cast(j.m_value.number_integer)); + } + else if (j.m_value.number_integer <= (std::numeric_limits::max)()) + { + oa->write_character(to_char_type(0x1A)); + write_number(static_cast(j.m_value.number_integer)); + } + else + { + oa->write_character(to_char_type(0x1B)); + write_number(static_cast(j.m_value.number_integer)); + } + } + else + { + // The conversions below encode the sign in the first + // byte, and the value is converted to a positive number. + const auto positive_number = -1 - j.m_value.number_integer; + if (j.m_value.number_integer >= -24) + { + write_number(static_cast(0x20 + positive_number)); + } + else if (positive_number <= (std::numeric_limits::max)()) + { + oa->write_character(to_char_type(0x38)); + write_number(static_cast(positive_number)); + } + else if (positive_number <= (std::numeric_limits::max)()) + { + oa->write_character(to_char_type(0x39)); + write_number(static_cast(positive_number)); + } + else if (positive_number <= (std::numeric_limits::max)()) + { + oa->write_character(to_char_type(0x3A)); + write_number(static_cast(positive_number)); + } + else + { + oa->write_character(to_char_type(0x3B)); + write_number(static_cast(positive_number)); + } + } + break; + } + + case value_t::number_unsigned: + { + if (j.m_value.number_unsigned <= 0x17) + { + write_number(static_cast(j.m_value.number_unsigned)); + } + else if (j.m_value.number_unsigned <= (std::numeric_limits::max)()) + { + oa->write_character(to_char_type(0x18)); + write_number(static_cast(j.m_value.number_unsigned)); + } + else if (j.m_value.number_unsigned <= (std::numeric_limits::max)()) + { + oa->write_character(to_char_type(0x19)); + write_number(static_cast(j.m_value.number_unsigned)); + } + else if (j.m_value.number_unsigned <= (std::numeric_limits::max)()) + { + oa->write_character(to_char_type(0x1A)); + write_number(static_cast(j.m_value.number_unsigned)); + } + else + { + oa->write_character(to_char_type(0x1B)); + write_number(static_cast(j.m_value.number_unsigned)); + } + break; + } + + case value_t::number_float: + { + if (std::isnan(j.m_value.number_float)) + { + // NaN is 0xf97e00 in CBOR + oa->write_character(to_char_type(0xF9)); + oa->write_character(to_char_type(0x7E)); + oa->write_character(to_char_type(0x00)); + } + else if (std::isinf(j.m_value.number_float)) + { + // Infinity is 0xf97c00, -Infinity is 0xf9fc00 + oa->write_character(to_char_type(0xf9)); + oa->write_character(j.m_value.number_float > 0 ? to_char_type(0x7C) : to_char_type(0xFC)); + oa->write_character(to_char_type(0x00)); + } + else + { + write_compact_float(j.m_value.number_float, detail::input_format_t::cbor); + } + break; + } + + case value_t::string: + { + // step 1: write control byte and the string length + const auto N = j.m_value.string->size(); + if (N <= 0x17) + { + write_number(static_cast(0x60 + N)); + } + else if (N <= (std::numeric_limits::max)()) + { + oa->write_character(to_char_type(0x78)); + write_number(static_cast(N)); + } + else if (N <= (std::numeric_limits::max)()) + { + oa->write_character(to_char_type(0x79)); + write_number(static_cast(N)); + } + else if (N <= (std::numeric_limits::max)()) + { + oa->write_character(to_char_type(0x7A)); + write_number(static_cast(N)); + } + // LCOV_EXCL_START + else if (N <= (std::numeric_limits::max)()) + { + oa->write_character(to_char_type(0x7B)); + write_number(static_cast(N)); + } + // LCOV_EXCL_STOP + + // step 2: write the string + oa->write_characters( + reinterpret_cast(j.m_value.string->c_str()), + j.m_value.string->size()); + break; + } + + case value_t::array: + { + // step 1: write control byte and the array size + const auto N = j.m_value.array->size(); + if (N <= 0x17) + { + write_number(static_cast(0x80 + N)); + } + else if (N <= (std::numeric_limits::max)()) + { + oa->write_character(to_char_type(0x98)); + write_number(static_cast(N)); + } + else if (N <= (std::numeric_limits::max)()) + { + oa->write_character(to_char_type(0x99)); + write_number(static_cast(N)); + } + else if (N <= (std::numeric_limits::max)()) + { + oa->write_character(to_char_type(0x9A)); + write_number(static_cast(N)); + } + // LCOV_EXCL_START + else if (N <= (std::numeric_limits::max)()) + { + oa->write_character(to_char_type(0x9B)); + write_number(static_cast(N)); + } + // LCOV_EXCL_STOP + + // step 2: write each element + for (const auto& el : *j.m_value.array) + { + write_cbor(el); + } + break; + } + + case value_t::binary: + { + if (j.m_value.binary->has_subtype()) + { + write_number(static_cast(0xd8)); + write_number(j.m_value.binary->subtype()); + } + + // step 1: write control byte and the binary array size + const auto N = j.m_value.binary->size(); + if (N <= 0x17) + { + write_number(static_cast(0x40 + N)); + } + else if (N <= (std::numeric_limits::max)()) + { + oa->write_character(to_char_type(0x58)); + write_number(static_cast(N)); + } + else if (N <= (std::numeric_limits::max)()) + { + oa->write_character(to_char_type(0x59)); + write_number(static_cast(N)); + } + else if (N <= (std::numeric_limits::max)()) + { + oa->write_character(to_char_type(0x5A)); + write_number(static_cast(N)); + } + // LCOV_EXCL_START + else if (N <= (std::numeric_limits::max)()) + { + oa->write_character(to_char_type(0x5B)); + write_number(static_cast(N)); + } + // LCOV_EXCL_STOP + + // step 2: write each element + oa->write_characters( + reinterpret_cast(j.m_value.binary->data()), + N); + + break; + } + + case value_t::object: + { + // step 1: write control byte and the object size + const auto N = j.m_value.object->size(); + if (N <= 0x17) + { + write_number(static_cast(0xA0 + N)); + } + else if (N <= (std::numeric_limits::max)()) + { + oa->write_character(to_char_type(0xB8)); + write_number(static_cast(N)); + } + else if (N <= (std::numeric_limits::max)()) + { + oa->write_character(to_char_type(0xB9)); + write_number(static_cast(N)); + } + else if (N <= (std::numeric_limits::max)()) + { + oa->write_character(to_char_type(0xBA)); + write_number(static_cast(N)); + } + // LCOV_EXCL_START + else if (N <= (std::numeric_limits::max)()) + { + oa->write_character(to_char_type(0xBB)); + write_number(static_cast(N)); + } + // LCOV_EXCL_STOP + + // step 2: write each element + for (const auto& el : *j.m_value.object) + { + write_cbor(el.first); + write_cbor(el.second); + } + break; + } + + default: + break; + } + } + + /*! + @param[in] j JSON value to serialize + */ + void write_msgpack(const BasicJsonType& j) + { + switch (j.type()) + { + case value_t::null: // nil + { + oa->write_character(to_char_type(0xC0)); + break; + } + + case value_t::boolean: // true and false + { + oa->write_character(j.m_value.boolean + ? to_char_type(0xC3) + : to_char_type(0xC2)); + break; + } + + case value_t::number_integer: + { + if (j.m_value.number_integer >= 0) + { + // MessagePack does not differentiate between positive + // signed integers and unsigned integers. Therefore, we used + // the code from the value_t::number_unsigned case here. + if (j.m_value.number_unsigned < 128) + { + // positive fixnum + write_number(static_cast(j.m_value.number_integer)); + } + else if (j.m_value.number_unsigned <= (std::numeric_limits::max)()) + { + // uint 8 + oa->write_character(to_char_type(0xCC)); + write_number(static_cast(j.m_value.number_integer)); + } + else if (j.m_value.number_unsigned <= (std::numeric_limits::max)()) + { + // uint 16 + oa->write_character(to_char_type(0xCD)); + write_number(static_cast(j.m_value.number_integer)); + } + else if (j.m_value.number_unsigned <= (std::numeric_limits::max)()) + { + // uint 32 + oa->write_character(to_char_type(0xCE)); + write_number(static_cast(j.m_value.number_integer)); + } + else if (j.m_value.number_unsigned <= (std::numeric_limits::max)()) + { + // uint 64 + oa->write_character(to_char_type(0xCF)); + write_number(static_cast(j.m_value.number_integer)); + } + } + else + { + if (j.m_value.number_integer >= -32) + { + // negative fixnum + write_number(static_cast(j.m_value.number_integer)); + } + else if (j.m_value.number_integer >= (std::numeric_limits::min)() && + j.m_value.number_integer <= (std::numeric_limits::max)()) + { + // int 8 + oa->write_character(to_char_type(0xD0)); + write_number(static_cast(j.m_value.number_integer)); + } + else if (j.m_value.number_integer >= (std::numeric_limits::min)() && + j.m_value.number_integer <= (std::numeric_limits::max)()) + { + // int 16 + oa->write_character(to_char_type(0xD1)); + write_number(static_cast(j.m_value.number_integer)); + } + else if (j.m_value.number_integer >= (std::numeric_limits::min)() && + j.m_value.number_integer <= (std::numeric_limits::max)()) + { + // int 32 + oa->write_character(to_char_type(0xD2)); + write_number(static_cast(j.m_value.number_integer)); + } + else if (j.m_value.number_integer >= (std::numeric_limits::min)() && + j.m_value.number_integer <= (std::numeric_limits::max)()) + { + // int 64 + oa->write_character(to_char_type(0xD3)); + write_number(static_cast(j.m_value.number_integer)); + } + } + break; + } + + case value_t::number_unsigned: + { + if (j.m_value.number_unsigned < 128) + { + // positive fixnum + write_number(static_cast(j.m_value.number_integer)); + } + else if (j.m_value.number_unsigned <= (std::numeric_limits::max)()) + { + // uint 8 + oa->write_character(to_char_type(0xCC)); + write_number(static_cast(j.m_value.number_integer)); + } + else if (j.m_value.number_unsigned <= (std::numeric_limits::max)()) + { + // uint 16 + oa->write_character(to_char_type(0xCD)); + write_number(static_cast(j.m_value.number_integer)); + } + else if (j.m_value.number_unsigned <= (std::numeric_limits::max)()) + { + // uint 32 + oa->write_character(to_char_type(0xCE)); + write_number(static_cast(j.m_value.number_integer)); + } + else if (j.m_value.number_unsigned <= (std::numeric_limits::max)()) + { + // uint 64 + oa->write_character(to_char_type(0xCF)); + write_number(static_cast(j.m_value.number_integer)); + } + break; + } + + case value_t::number_float: + { + write_compact_float(j.m_value.number_float, detail::input_format_t::msgpack); + break; + } + + case value_t::string: + { + // step 1: write control byte and the string length + const auto N = j.m_value.string->size(); + if (N <= 31) + { + // fixstr + write_number(static_cast(0xA0 | N)); + } + else if (N <= (std::numeric_limits::max)()) + { + // str 8 + oa->write_character(to_char_type(0xD9)); + write_number(static_cast(N)); + } + else if (N <= (std::numeric_limits::max)()) + { + // str 16 + oa->write_character(to_char_type(0xDA)); + write_number(static_cast(N)); + } + else if (N <= (std::numeric_limits::max)()) + { + // str 32 + oa->write_character(to_char_type(0xDB)); + write_number(static_cast(N)); + } + + // step 2: write the string + oa->write_characters( + reinterpret_cast(j.m_value.string->c_str()), + j.m_value.string->size()); + break; + } + + case value_t::array: + { + // step 1: write control byte and the array size + const auto N = j.m_value.array->size(); + if (N <= 15) + { + // fixarray + write_number(static_cast(0x90 | N)); + } + else if (N <= (std::numeric_limits::max)()) + { + // array 16 + oa->write_character(to_char_type(0xDC)); + write_number(static_cast(N)); + } + else if (N <= (std::numeric_limits::max)()) + { + // array 32 + oa->write_character(to_char_type(0xDD)); + write_number(static_cast(N)); + } + + // step 2: write each element + for (const auto& el : *j.m_value.array) + { + write_msgpack(el); + } + break; + } + + case value_t::binary: + { + // step 0: determine if the binary type has a set subtype to + // determine whether or not to use the ext or fixext types + const bool use_ext = j.m_value.binary->has_subtype(); + + // step 1: write control byte and the byte string length + const auto N = j.m_value.binary->size(); + if (N <= (std::numeric_limits::max)()) + { + std::uint8_t output_type{}; + bool fixed = true; + if (use_ext) + { + switch (N) + { + case 1: + output_type = 0xD4; // fixext 1 + break; + case 2: + output_type = 0xD5; // fixext 2 + break; + case 4: + output_type = 0xD6; // fixext 4 + break; + case 8: + output_type = 0xD7; // fixext 8 + break; + case 16: + output_type = 0xD8; // fixext 16 + break; + default: + output_type = 0xC7; // ext 8 + fixed = false; + break; + } + + } + else + { + output_type = 0xC4; // bin 8 + fixed = false; + } + + oa->write_character(to_char_type(output_type)); + if (!fixed) + { + write_number(static_cast(N)); + } + } + else if (N <= (std::numeric_limits::max)()) + { + std::uint8_t output_type = use_ext + ? 0xC8 // ext 16 + : 0xC5; // bin 16 + + oa->write_character(to_char_type(output_type)); + write_number(static_cast(N)); + } + else if (N <= (std::numeric_limits::max)()) + { + std::uint8_t output_type = use_ext + ? 0xC9 // ext 32 + : 0xC6; // bin 32 + + oa->write_character(to_char_type(output_type)); + write_number(static_cast(N)); + } + + // step 1.5: if this is an ext type, write the subtype + if (use_ext) + { + write_number(static_cast(j.m_value.binary->subtype())); + } + + // step 2: write the byte string + oa->write_characters( + reinterpret_cast(j.m_value.binary->data()), + N); + + break; + } + + case value_t::object: + { + // step 1: write control byte and the object size + const auto N = j.m_value.object->size(); + if (N <= 15) + { + // fixmap + write_number(static_cast(0x80 | (N & 0xF))); + } + else if (N <= (std::numeric_limits::max)()) + { + // map 16 + oa->write_character(to_char_type(0xDE)); + write_number(static_cast(N)); + } + else if (N <= (std::numeric_limits::max)()) + { + // map 32 + oa->write_character(to_char_type(0xDF)); + write_number(static_cast(N)); + } + + // step 2: write each element + for (const auto& el : *j.m_value.object) + { + write_msgpack(el.first); + write_msgpack(el.second); + } + break; + } + + default: + break; + } + } + + /*! + @param[in] j JSON value to serialize + @param[in] use_count whether to use '#' prefixes (optimized format) + @param[in] use_type whether to use '$' prefixes (optimized format) + @param[in] add_prefix whether prefixes need to be used for this value + */ + void write_ubjson(const BasicJsonType& j, const bool use_count, + const bool use_type, const bool add_prefix = true) + { + switch (j.type()) + { + case value_t::null: + { + if (add_prefix) + { + oa->write_character(to_char_type('Z')); + } + break; + } + + case value_t::boolean: + { + if (add_prefix) + { + oa->write_character(j.m_value.boolean + ? to_char_type('T') + : to_char_type('F')); + } + break; + } + + case value_t::number_integer: + { + write_number_with_ubjson_prefix(j.m_value.number_integer, add_prefix); + break; + } + + case value_t::number_unsigned: + { + write_number_with_ubjson_prefix(j.m_value.number_unsigned, add_prefix); + break; + } + + case value_t::number_float: + { + write_number_with_ubjson_prefix(j.m_value.number_float, add_prefix); + break; + } + + case value_t::string: + { + if (add_prefix) + { + oa->write_character(to_char_type('S')); + } + write_number_with_ubjson_prefix(j.m_value.string->size(), true); + oa->write_characters( + reinterpret_cast(j.m_value.string->c_str()), + j.m_value.string->size()); + break; + } + + case value_t::array: + { + if (add_prefix) + { + oa->write_character(to_char_type('[')); + } + + bool prefix_required = true; + if (use_type && !j.m_value.array->empty()) + { + JSON_ASSERT(use_count); + const CharType first_prefix = ubjson_prefix(j.front()); + const bool same_prefix = std::all_of(j.begin() + 1, j.end(), + [this, first_prefix](const BasicJsonType & v) + { + return ubjson_prefix(v) == first_prefix; + }); + + if (same_prefix) + { + prefix_required = false; + oa->write_character(to_char_type('$')); + oa->write_character(first_prefix); + } + } + + if (use_count) + { + oa->write_character(to_char_type('#')); + write_number_with_ubjson_prefix(j.m_value.array->size(), true); + } + + for (const auto& el : *j.m_value.array) + { + write_ubjson(el, use_count, use_type, prefix_required); + } + + if (!use_count) + { + oa->write_character(to_char_type(']')); + } + + break; + } + + case value_t::binary: + { + if (add_prefix) + { + oa->write_character(to_char_type('[')); + } + + if (use_type && !j.m_value.binary->empty()) + { + JSON_ASSERT(use_count); + oa->write_character(to_char_type('$')); + oa->write_character('U'); + } + + if (use_count) + { + oa->write_character(to_char_type('#')); + write_number_with_ubjson_prefix(j.m_value.binary->size(), true); + } + + if (use_type) + { + oa->write_characters( + reinterpret_cast(j.m_value.binary->data()), + j.m_value.binary->size()); + } + else + { + for (size_t i = 0; i < j.m_value.binary->size(); ++i) + { + oa->write_character(to_char_type('U')); + oa->write_character(j.m_value.binary->data()[i]); + } + } + + if (!use_count) + { + oa->write_character(to_char_type(']')); + } + + break; + } + + case value_t::object: + { + if (add_prefix) + { + oa->write_character(to_char_type('{')); + } + + bool prefix_required = true; + if (use_type && !j.m_value.object->empty()) + { + JSON_ASSERT(use_count); + const CharType first_prefix = ubjson_prefix(j.front()); + const bool same_prefix = std::all_of(j.begin(), j.end(), + [this, first_prefix](const BasicJsonType & v) + { + return ubjson_prefix(v) == first_prefix; + }); + + if (same_prefix) + { + prefix_required = false; + oa->write_character(to_char_type('$')); + oa->write_character(first_prefix); + } + } + + if (use_count) + { + oa->write_character(to_char_type('#')); + write_number_with_ubjson_prefix(j.m_value.object->size(), true); + } + + for (const auto& el : *j.m_value.object) + { + write_number_with_ubjson_prefix(el.first.size(), true); + oa->write_characters( + reinterpret_cast(el.first.c_str()), + el.first.size()); + write_ubjson(el.second, use_count, use_type, prefix_required); + } + + if (!use_count) + { + oa->write_character(to_char_type('}')); + } + + break; + } + + default: + break; + } + } + + private: + ////////// + // BSON // + ////////// + + /*! + @return The size of a BSON document entry header, including the id marker + and the entry name size (and its null-terminator). + */ + static std::size_t calc_bson_entry_header_size(const string_t& name) + { + const auto it = name.find(static_cast(0)); + if (JSON_HEDLEY_UNLIKELY(it != BasicJsonType::string_t::npos)) + { + JSON_THROW(out_of_range::create(409, + "BSON key cannot contain code point U+0000 (at byte " + std::to_string(it) + ")")); + } + + return /*id*/ 1ul + name.size() + /*zero-terminator*/1u; + } + + /*! + @brief Writes the given @a element_type and @a name to the output adapter + */ + void write_bson_entry_header(const string_t& name, + const std::uint8_t element_type) + { + oa->write_character(to_char_type(element_type)); // boolean + oa->write_characters( + reinterpret_cast(name.c_str()), + name.size() + 1u); + } + + /*! + @brief Writes a BSON element with key @a name and boolean value @a value + */ + void write_bson_boolean(const string_t& name, + const bool value) + { + write_bson_entry_header(name, 0x08); + oa->write_character(value ? to_char_type(0x01) : to_char_type(0x00)); + } + + /*! + @brief Writes a BSON element with key @a name and double value @a value + */ + void write_bson_double(const string_t& name, + const double value) + { + write_bson_entry_header(name, 0x01); + write_number(value); + } + + /*! + @return The size of the BSON-encoded string in @a value + */ + static std::size_t calc_bson_string_size(const string_t& value) + { + return sizeof(std::int32_t) + value.size() + 1ul; + } + + /*! + @brief Writes a BSON element with key @a name and string value @a value + */ + void write_bson_string(const string_t& name, + const string_t& value) + { + write_bson_entry_header(name, 0x02); + + write_number(static_cast(value.size() + 1ul)); + oa->write_characters( + reinterpret_cast(value.c_str()), + value.size() + 1); + } + + /*! + @brief Writes a BSON element with key @a name and null value + */ + void write_bson_null(const string_t& name) + { + write_bson_entry_header(name, 0x0A); + } + + /*! + @return The size of the BSON-encoded integer @a value + */ + static std::size_t calc_bson_integer_size(const std::int64_t value) + { + return (std::numeric_limits::min)() <= value && value <= (std::numeric_limits::max)() + ? sizeof(std::int32_t) + : sizeof(std::int64_t); + } + + /*! + @brief Writes a BSON element with key @a name and integer @a value + */ + void write_bson_integer(const string_t& name, + const std::int64_t value) + { + if ((std::numeric_limits::min)() <= value && value <= (std::numeric_limits::max)()) + { + write_bson_entry_header(name, 0x10); // int32 + write_number(static_cast(value)); + } + else + { + write_bson_entry_header(name, 0x12); // int64 + write_number(static_cast(value)); + } + } + + /*! + @return The size of the BSON-encoded unsigned integer in @a j + */ + static constexpr std::size_t calc_bson_unsigned_size(const std::uint64_t value) noexcept + { + return (value <= static_cast((std::numeric_limits::max)())) + ? sizeof(std::int32_t) + : sizeof(std::int64_t); + } + + /*! + @brief Writes a BSON element with key @a name and unsigned @a value + */ + void write_bson_unsigned(const string_t& name, + const std::uint64_t value) + { + if (value <= static_cast((std::numeric_limits::max)())) + { + write_bson_entry_header(name, 0x10 /* int32 */); + write_number(static_cast(value)); + } + else if (value <= static_cast((std::numeric_limits::max)())) + { + write_bson_entry_header(name, 0x12 /* int64 */); + write_number(static_cast(value)); + } + else + { + JSON_THROW(out_of_range::create(407, "integer number " + std::to_string(value) + " cannot be represented by BSON as it does not fit int64")); + } + } + + /*! + @brief Writes a BSON element with key @a name and object @a value + */ + void write_bson_object_entry(const string_t& name, + const typename BasicJsonType::object_t& value) + { + write_bson_entry_header(name, 0x03); // object + write_bson_object(value); + } + + /*! + @return The size of the BSON-encoded array @a value + */ + static std::size_t calc_bson_array_size(const typename BasicJsonType::array_t& value) + { + std::size_t array_index = 0ul; + + const std::size_t embedded_document_size = std::accumulate(std::begin(value), std::end(value), std::size_t(0), [&array_index](std::size_t result, const typename BasicJsonType::array_t::value_type & el) + { + return result + calc_bson_element_size(std::to_string(array_index++), el); + }); + + return sizeof(std::int32_t) + embedded_document_size + 1ul; + } + + /*! + @return The size of the BSON-encoded binary array @a value + */ + static std::size_t calc_bson_binary_size(const typename BasicJsonType::binary_t& value) + { + return sizeof(std::int32_t) + value.size() + 1ul; + } + + /*! + @brief Writes a BSON element with key @a name and array @a value + */ + void write_bson_array(const string_t& name, + const typename BasicJsonType::array_t& value) + { + write_bson_entry_header(name, 0x04); // array + write_number(static_cast(calc_bson_array_size(value))); + + std::size_t array_index = 0ul; + + for (const auto& el : value) + { + write_bson_element(std::to_string(array_index++), el); + } + + oa->write_character(to_char_type(0x00)); + } + + /*! + @brief Writes a BSON element with key @a name and binary value @a value + */ + void write_bson_binary(const string_t& name, + const binary_t& value) + { + write_bson_entry_header(name, 0x05); + + write_number(static_cast(value.size())); + write_number(value.has_subtype() ? value.subtype() : std::uint8_t(0x00)); + + oa->write_characters(reinterpret_cast(value.data()), value.size()); + } + + /*! + @brief Calculates the size necessary to serialize the JSON value @a j with its @a name + @return The calculated size for the BSON document entry for @a j with the given @a name. + */ + static std::size_t calc_bson_element_size(const string_t& name, + const BasicJsonType& j) + { + const auto header_size = calc_bson_entry_header_size(name); + switch (j.type()) + { + case value_t::object: + return header_size + calc_bson_object_size(*j.m_value.object); + + case value_t::array: + return header_size + calc_bson_array_size(*j.m_value.array); + + case value_t::binary: + return header_size + calc_bson_binary_size(*j.m_value.binary); + + case value_t::boolean: + return header_size + 1ul; + + case value_t::number_float: + return header_size + 8ul; + + case value_t::number_integer: + return header_size + calc_bson_integer_size(j.m_value.number_integer); + + case value_t::number_unsigned: + return header_size + calc_bson_unsigned_size(j.m_value.number_unsigned); + + case value_t::string: + return header_size + calc_bson_string_size(*j.m_value.string); + + case value_t::null: + return header_size + 0ul; + + // LCOV_EXCL_START + default: + JSON_ASSERT(false); + return 0ul; + // LCOV_EXCL_STOP + } + } + + /*! + @brief Serializes the JSON value @a j to BSON and associates it with the + key @a name. + @param name The name to associate with the JSON entity @a j within the + current BSON document + @return The size of the BSON entry + */ + void write_bson_element(const string_t& name, + const BasicJsonType& j) + { + switch (j.type()) + { + case value_t::object: + return write_bson_object_entry(name, *j.m_value.object); + + case value_t::array: + return write_bson_array(name, *j.m_value.array); + + case value_t::binary: + return write_bson_binary(name, *j.m_value.binary); + + case value_t::boolean: + return write_bson_boolean(name, j.m_value.boolean); + + case value_t::number_float: + return write_bson_double(name, j.m_value.number_float); + + case value_t::number_integer: + return write_bson_integer(name, j.m_value.number_integer); + + case value_t::number_unsigned: + return write_bson_unsigned(name, j.m_value.number_unsigned); + + case value_t::string: + return write_bson_string(name, *j.m_value.string); + + case value_t::null: + return write_bson_null(name); + + // LCOV_EXCL_START + default: + JSON_ASSERT(false); + return; + // LCOV_EXCL_STOP + } + } + + /*! + @brief Calculates the size of the BSON serialization of the given + JSON-object @a j. + @param[in] j JSON value to serialize + @pre j.type() == value_t::object + */ + static std::size_t calc_bson_object_size(const typename BasicJsonType::object_t& value) + { + std::size_t document_size = std::accumulate(value.begin(), value.end(), std::size_t(0), + [](size_t result, const typename BasicJsonType::object_t::value_type & el) + { + return result += calc_bson_element_size(el.first, el.second); + }); + + return sizeof(std::int32_t) + document_size + 1ul; + } + + /*! + @param[in] j JSON value to serialize + @pre j.type() == value_t::object + */ + void write_bson_object(const typename BasicJsonType::object_t& value) + { + write_number(static_cast(calc_bson_object_size(value))); + + for (const auto& el : value) + { + write_bson_element(el.first, el.second); + } + + oa->write_character(to_char_type(0x00)); + } + + ////////// + // CBOR // + ////////// + + static constexpr CharType get_cbor_float_prefix(float /*unused*/) + { + return to_char_type(0xFA); // Single-Precision Float + } + + static constexpr CharType get_cbor_float_prefix(double /*unused*/) + { + return to_char_type(0xFB); // Double-Precision Float + } + + ///////////// + // MsgPack // + ///////////// + + static constexpr CharType get_msgpack_float_prefix(float /*unused*/) + { + return to_char_type(0xCA); // float 32 + } + + static constexpr CharType get_msgpack_float_prefix(double /*unused*/) + { + return to_char_type(0xCB); // float 64 + } + + //////////// + // UBJSON // + //////////// + + // UBJSON: write number (floating point) + template::value, int>::type = 0> + void write_number_with_ubjson_prefix(const NumberType n, + const bool add_prefix) + { + if (add_prefix) + { + oa->write_character(get_ubjson_float_prefix(n)); + } + write_number(n); + } + + // UBJSON: write number (unsigned integer) + template::value, int>::type = 0> + void write_number_with_ubjson_prefix(const NumberType n, + const bool add_prefix) + { + if (n <= static_cast((std::numeric_limits::max)())) + { + if (add_prefix) + { + oa->write_character(to_char_type('i')); // int8 + } + write_number(static_cast(n)); + } + else if (n <= (std::numeric_limits::max)()) + { + if (add_prefix) + { + oa->write_character(to_char_type('U')); // uint8 + } + write_number(static_cast(n)); + } + else if (n <= static_cast((std::numeric_limits::max)())) + { + if (add_prefix) + { + oa->write_character(to_char_type('I')); // int16 + } + write_number(static_cast(n)); + } + else if (n <= static_cast((std::numeric_limits::max)())) + { + if (add_prefix) + { + oa->write_character(to_char_type('l')); // int32 + } + write_number(static_cast(n)); + } + else if (n <= static_cast((std::numeric_limits::max)())) + { + if (add_prefix) + { + oa->write_character(to_char_type('L')); // int64 + } + write_number(static_cast(n)); + } + else + { + if (add_prefix) + { + oa->write_character(to_char_type('H')); // high-precision number + } + + const auto number = BasicJsonType(n).dump(); + write_number_with_ubjson_prefix(number.size(), true); + for (std::size_t i = 0; i < number.size(); ++i) + { + oa->write_character(to_char_type(static_cast(number[i]))); + } + } + } + + // UBJSON: write number (signed integer) + template < typename NumberType, typename std::enable_if < + std::is_signed::value&& + !std::is_floating_point::value, int >::type = 0 > + void write_number_with_ubjson_prefix(const NumberType n, + const bool add_prefix) + { + if ((std::numeric_limits::min)() <= n && n <= (std::numeric_limits::max)()) + { + if (add_prefix) + { + oa->write_character(to_char_type('i')); // int8 + } + write_number(static_cast(n)); + } + else if (static_cast((std::numeric_limits::min)()) <= n && n <= static_cast((std::numeric_limits::max)())) + { + if (add_prefix) + { + oa->write_character(to_char_type('U')); // uint8 + } + write_number(static_cast(n)); + } + else if ((std::numeric_limits::min)() <= n && n <= (std::numeric_limits::max)()) + { + if (add_prefix) + { + oa->write_character(to_char_type('I')); // int16 + } + write_number(static_cast(n)); + } + else if ((std::numeric_limits::min)() <= n && n <= (std::numeric_limits::max)()) + { + if (add_prefix) + { + oa->write_character(to_char_type('l')); // int32 + } + write_number(static_cast(n)); + } + else if ((std::numeric_limits::min)() <= n && n <= (std::numeric_limits::max)()) + { + if (add_prefix) + { + oa->write_character(to_char_type('L')); // int64 + } + write_number(static_cast(n)); + } + // LCOV_EXCL_START + else + { + if (add_prefix) + { + oa->write_character(to_char_type('H')); // high-precision number + } + + const auto number = BasicJsonType(n).dump(); + write_number_with_ubjson_prefix(number.size(), true); + for (std::size_t i = 0; i < number.size(); ++i) + { + oa->write_character(to_char_type(static_cast(number[i]))); + } + } + // LCOV_EXCL_STOP + } + + /*! + @brief determine the type prefix of container values + */ + CharType ubjson_prefix(const BasicJsonType& j) const noexcept + { + switch (j.type()) + { + case value_t::null: + return 'Z'; + + case value_t::boolean: + return j.m_value.boolean ? 'T' : 'F'; + + case value_t::number_integer: + { + if ((std::numeric_limits::min)() <= j.m_value.number_integer && j.m_value.number_integer <= (std::numeric_limits::max)()) + { + return 'i'; + } + if ((std::numeric_limits::min)() <= j.m_value.number_integer && j.m_value.number_integer <= (std::numeric_limits::max)()) + { + return 'U'; + } + if ((std::numeric_limits::min)() <= j.m_value.number_integer && j.m_value.number_integer <= (std::numeric_limits::max)()) + { + return 'I'; + } + if ((std::numeric_limits::min)() <= j.m_value.number_integer && j.m_value.number_integer <= (std::numeric_limits::max)()) + { + return 'l'; + } + if ((std::numeric_limits::min)() <= j.m_value.number_integer && j.m_value.number_integer <= (std::numeric_limits::max)()) + { + return 'L'; + } + // anything else is treated as high-precision number + return 'H'; // LCOV_EXCL_LINE + } + + case value_t::number_unsigned: + { + if (j.m_value.number_unsigned <= static_cast((std::numeric_limits::max)())) + { + return 'i'; + } + if (j.m_value.number_unsigned <= static_cast((std::numeric_limits::max)())) + { + return 'U'; + } + if (j.m_value.number_unsigned <= static_cast((std::numeric_limits::max)())) + { + return 'I'; + } + if (j.m_value.number_unsigned <= static_cast((std::numeric_limits::max)())) + { + return 'l'; + } + if (j.m_value.number_unsigned <= static_cast((std::numeric_limits::max)())) + { + return 'L'; + } + // anything else is treated as high-precision number + return 'H'; // LCOV_EXCL_LINE + } + + case value_t::number_float: + return get_ubjson_float_prefix(j.m_value.number_float); + + case value_t::string: + return 'S'; + + case value_t::array: // fallthrough + case value_t::binary: + return '['; + + case value_t::object: + return '{'; + + default: // discarded values + return 'N'; + } + } + + static constexpr CharType get_ubjson_float_prefix(float /*unused*/) + { + return 'd'; // float 32 + } + + static constexpr CharType get_ubjson_float_prefix(double /*unused*/) + { + return 'D'; // float 64 + } + + /////////////////////// + // Utility functions // + /////////////////////// + + /* + @brief write a number to output input + @param[in] n number of type @a NumberType + @tparam NumberType the type of the number + @tparam OutputIsLittleEndian Set to true if output data is + required to be little endian + + @note This function needs to respect the system's endianess, because bytes + in CBOR, MessagePack, and UBJSON are stored in network order (big + endian) and therefore need reordering on little endian systems. + */ + template + void write_number(const NumberType n) + { + // step 1: write number to array of length NumberType + std::array vec; + std::memcpy(vec.data(), &n, sizeof(NumberType)); + + // step 2: write array to output (with possible reordering) + if (is_little_endian != OutputIsLittleEndian) + { + // reverse byte order prior to conversion if necessary + std::reverse(vec.begin(), vec.end()); + } + + oa->write_characters(vec.data(), sizeof(NumberType)); + } + + void write_compact_float(const number_float_t n, detail::input_format_t format) + { + if (static_cast(n) >= static_cast(std::numeric_limits::lowest()) && + static_cast(n) <= static_cast((std::numeric_limits::max)()) && + static_cast(static_cast(n)) == static_cast(n)) + { + oa->write_character(format == detail::input_format_t::cbor + ? get_cbor_float_prefix(static_cast(n)) + : get_msgpack_float_prefix(static_cast(n))); + write_number(static_cast(n)); + } + else + { + oa->write_character(format == detail::input_format_t::cbor + ? get_cbor_float_prefix(n) + : get_msgpack_float_prefix(n)); + write_number(n); + } + } + + public: + // The following to_char_type functions are implement the conversion + // between uint8_t and CharType. In case CharType is not unsigned, + // such a conversion is required to allow values greater than 128. + // See for a discussion. + template < typename C = CharType, + enable_if_t < std::is_signed::value && std::is_signed::value > * = nullptr > + static constexpr CharType to_char_type(std::uint8_t x) noexcept + { + return *reinterpret_cast(&x); + } + + template < typename C = CharType, + enable_if_t < std::is_signed::value && std::is_unsigned::value > * = nullptr > + static CharType to_char_type(std::uint8_t x) noexcept + { + static_assert(sizeof(std::uint8_t) == sizeof(CharType), "size of CharType must be equal to std::uint8_t"); + static_assert(std::is_trivial::value, "CharType must be trivial"); + CharType result; + std::memcpy(&result, &x, sizeof(x)); + return result; + } + + template::value>* = nullptr> + static constexpr CharType to_char_type(std::uint8_t x) noexcept + { + return x; + } + + template < typename InputCharType, typename C = CharType, + enable_if_t < + std::is_signed::value && + std::is_signed::value && + std::is_same::type>::value + > * = nullptr > + static constexpr CharType to_char_type(InputCharType x) noexcept + { + return x; + } + + private: + /// whether we can assume little endianess + const bool is_little_endian = little_endianess(); + + /// the output + output_adapter_t oa = nullptr; +}; +} // namespace detail +} // namespace nlohmann + +// #include + +// #include + + +#include // reverse, remove, fill, find, none_of +#include // array +#include // localeconv, lconv +#include // labs, isfinite, isnan, signbit +#include // size_t, ptrdiff_t +#include // uint8_t +#include // snprintf +#include // numeric_limits +#include // string, char_traits +#include // is_same +#include // move + +// #include + + +#include // array +#include // signbit, isfinite +#include // intN_t, uintN_t +#include // memcpy, memmove +#include // numeric_limits +#include // conditional + +// #include + + +namespace nlohmann +{ +namespace detail +{ + +/*! +@brief implements the Grisu2 algorithm for binary to decimal floating-point +conversion. + +This implementation is a slightly modified version of the reference +implementation which may be obtained from +http://florian.loitsch.com/publications (bench.tar.gz). + +The code is distributed under the MIT license, Copyright (c) 2009 Florian Loitsch. + +For a detailed description of the algorithm see: + +[1] Loitsch, "Printing Floating-Point Numbers Quickly and Accurately with + Integers", Proceedings of the ACM SIGPLAN 2010 Conference on Programming + Language Design and Implementation, PLDI 2010 +[2] Burger, Dybvig, "Printing Floating-Point Numbers Quickly and Accurately", + Proceedings of the ACM SIGPLAN 1996 Conference on Programming Language + Design and Implementation, PLDI 1996 +*/ +namespace dtoa_impl +{ + +template +Target reinterpret_bits(const Source source) +{ + static_assert(sizeof(Target) == sizeof(Source), "size mismatch"); + + Target target; + std::memcpy(&target, &source, sizeof(Source)); + return target; +} + +struct diyfp // f * 2^e +{ + static constexpr int kPrecision = 64; // = q + + std::uint64_t f = 0; + int e = 0; + + constexpr diyfp(std::uint64_t f_, int e_) noexcept : f(f_), e(e_) {} + + /*! + @brief returns x - y + @pre x.e == y.e and x.f >= y.f + */ + static diyfp sub(const diyfp& x, const diyfp& y) noexcept + { + JSON_ASSERT(x.e == y.e); + JSON_ASSERT(x.f >= y.f); + + return {x.f - y.f, x.e}; + } + + /*! + @brief returns x * y + @note The result is rounded. (Only the upper q bits are returned.) + */ + static diyfp mul(const diyfp& x, const diyfp& y) noexcept + { + static_assert(kPrecision == 64, "internal error"); + + // Computes: + // f = round((x.f * y.f) / 2^q) + // e = x.e + y.e + q + + // Emulate the 64-bit * 64-bit multiplication: + // + // p = u * v + // = (u_lo + 2^32 u_hi) (v_lo + 2^32 v_hi) + // = (u_lo v_lo ) + 2^32 ((u_lo v_hi ) + (u_hi v_lo )) + 2^64 (u_hi v_hi ) + // = (p0 ) + 2^32 ((p1 ) + (p2 )) + 2^64 (p3 ) + // = (p0_lo + 2^32 p0_hi) + 2^32 ((p1_lo + 2^32 p1_hi) + (p2_lo + 2^32 p2_hi)) + 2^64 (p3 ) + // = (p0_lo ) + 2^32 (p0_hi + p1_lo + p2_lo ) + 2^64 (p1_hi + p2_hi + p3) + // = (p0_lo ) + 2^32 (Q ) + 2^64 (H ) + // = (p0_lo ) + 2^32 (Q_lo + 2^32 Q_hi ) + 2^64 (H ) + // + // (Since Q might be larger than 2^32 - 1) + // + // = (p0_lo + 2^32 Q_lo) + 2^64 (Q_hi + H) + // + // (Q_hi + H does not overflow a 64-bit int) + // + // = p_lo + 2^64 p_hi + + const std::uint64_t u_lo = x.f & 0xFFFFFFFFu; + const std::uint64_t u_hi = x.f >> 32u; + const std::uint64_t v_lo = y.f & 0xFFFFFFFFu; + const std::uint64_t v_hi = y.f >> 32u; + + const std::uint64_t p0 = u_lo * v_lo; + const std::uint64_t p1 = u_lo * v_hi; + const std::uint64_t p2 = u_hi * v_lo; + const std::uint64_t p3 = u_hi * v_hi; + + const std::uint64_t p0_hi = p0 >> 32u; + const std::uint64_t p1_lo = p1 & 0xFFFFFFFFu; + const std::uint64_t p1_hi = p1 >> 32u; + const std::uint64_t p2_lo = p2 & 0xFFFFFFFFu; + const std::uint64_t p2_hi = p2 >> 32u; + + std::uint64_t Q = p0_hi + p1_lo + p2_lo; + + // The full product might now be computed as + // + // p_hi = p3 + p2_hi + p1_hi + (Q >> 32) + // p_lo = p0_lo + (Q << 32) + // + // But in this particular case here, the full p_lo is not required. + // Effectively we only need to add the highest bit in p_lo to p_hi (and + // Q_hi + 1 does not overflow). + + Q += std::uint64_t{1} << (64u - 32u - 1u); // round, ties up + + const std::uint64_t h = p3 + p2_hi + p1_hi + (Q >> 32u); + + return {h, x.e + y.e + 64}; + } + + /*! + @brief normalize x such that the significand is >= 2^(q-1) + @pre x.f != 0 + */ + static diyfp normalize(diyfp x) noexcept + { + JSON_ASSERT(x.f != 0); + + while ((x.f >> 63u) == 0) + { + x.f <<= 1u; + x.e--; + } + + return x; + } + + /*! + @brief normalize x such that the result has the exponent E + @pre e >= x.e and the upper e - x.e bits of x.f must be zero. + */ + static diyfp normalize_to(const diyfp& x, const int target_exponent) noexcept + { + const int delta = x.e - target_exponent; + + JSON_ASSERT(delta >= 0); + JSON_ASSERT(((x.f << delta) >> delta) == x.f); + + return {x.f << delta, target_exponent}; + } +}; + +struct boundaries +{ + diyfp w; + diyfp minus; + diyfp plus; +}; + +/*! +Compute the (normalized) diyfp representing the input number 'value' and its +boundaries. + +@pre value must be finite and positive +*/ +template +boundaries compute_boundaries(FloatType value) +{ + JSON_ASSERT(std::isfinite(value)); + JSON_ASSERT(value > 0); + + // Convert the IEEE representation into a diyfp. + // + // If v is denormal: + // value = 0.F * 2^(1 - bias) = ( F) * 2^(1 - bias - (p-1)) + // If v is normalized: + // value = 1.F * 2^(E - bias) = (2^(p-1) + F) * 2^(E - bias - (p-1)) + + static_assert(std::numeric_limits::is_iec559, + "internal error: dtoa_short requires an IEEE-754 floating-point implementation"); + + constexpr int kPrecision = std::numeric_limits::digits; // = p (includes the hidden bit) + constexpr int kBias = std::numeric_limits::max_exponent - 1 + (kPrecision - 1); + constexpr int kMinExp = 1 - kBias; + constexpr std::uint64_t kHiddenBit = std::uint64_t{1} << (kPrecision - 1); // = 2^(p-1) + + using bits_type = typename std::conditional::type; + + const std::uint64_t bits = reinterpret_bits(value); + const std::uint64_t E = bits >> (kPrecision - 1); + const std::uint64_t F = bits & (kHiddenBit - 1); + + const bool is_denormal = E == 0; + const diyfp v = is_denormal + ? diyfp(F, kMinExp) + : diyfp(F + kHiddenBit, static_cast(E) - kBias); + + // Compute the boundaries m- and m+ of the floating-point value + // v = f * 2^e. + // + // Determine v- and v+, the floating-point predecessor and successor if v, + // respectively. + // + // v- = v - 2^e if f != 2^(p-1) or e == e_min (A) + // = v - 2^(e-1) if f == 2^(p-1) and e > e_min (B) + // + // v+ = v + 2^e + // + // Let m- = (v- + v) / 2 and m+ = (v + v+) / 2. All real numbers _strictly_ + // between m- and m+ round to v, regardless of how the input rounding + // algorithm breaks ties. + // + // ---+-------------+-------------+-------------+-------------+--- (A) + // v- m- v m+ v+ + // + // -----------------+------+------+-------------+-------------+--- (B) + // v- m- v m+ v+ + + const bool lower_boundary_is_closer = F == 0 && E > 1; + const diyfp m_plus = diyfp(2 * v.f + 1, v.e - 1); + const diyfp m_minus = lower_boundary_is_closer + ? diyfp(4 * v.f - 1, v.e - 2) // (B) + : diyfp(2 * v.f - 1, v.e - 1); // (A) + + // Determine the normalized w+ = m+. + const diyfp w_plus = diyfp::normalize(m_plus); + + // Determine w- = m- such that e_(w-) = e_(w+). + const diyfp w_minus = diyfp::normalize_to(m_minus, w_plus.e); + + return {diyfp::normalize(v), w_minus, w_plus}; +} + +// Given normalized diyfp w, Grisu needs to find a (normalized) cached +// power-of-ten c, such that the exponent of the product c * w = f * 2^e lies +// within a certain range [alpha, gamma] (Definition 3.2 from [1]) +// +// alpha <= e = e_c + e_w + q <= gamma +// +// or +// +// f_c * f_w * 2^alpha <= f_c 2^(e_c) * f_w 2^(e_w) * 2^q +// <= f_c * f_w * 2^gamma +// +// Since c and w are normalized, i.e. 2^(q-1) <= f < 2^q, this implies +// +// 2^(q-1) * 2^(q-1) * 2^alpha <= c * w * 2^q < 2^q * 2^q * 2^gamma +// +// or +// +// 2^(q - 2 + alpha) <= c * w < 2^(q + gamma) +// +// The choice of (alpha,gamma) determines the size of the table and the form of +// the digit generation procedure. Using (alpha,gamma)=(-60,-32) works out well +// in practice: +// +// The idea is to cut the number c * w = f * 2^e into two parts, which can be +// processed independently: An integral part p1, and a fractional part p2: +// +// f * 2^e = ( (f div 2^-e) * 2^-e + (f mod 2^-e) ) * 2^e +// = (f div 2^-e) + (f mod 2^-e) * 2^e +// = p1 + p2 * 2^e +// +// The conversion of p1 into decimal form requires a series of divisions and +// modulos by (a power of) 10. These operations are faster for 32-bit than for +// 64-bit integers, so p1 should ideally fit into a 32-bit integer. This can be +// achieved by choosing +// +// -e >= 32 or e <= -32 := gamma +// +// In order to convert the fractional part +// +// p2 * 2^e = p2 / 2^-e = d[-1] / 10^1 + d[-2] / 10^2 + ... +// +// into decimal form, the fraction is repeatedly multiplied by 10 and the digits +// d[-i] are extracted in order: +// +// (10 * p2) div 2^-e = d[-1] +// (10 * p2) mod 2^-e = d[-2] / 10^1 + ... +// +// The multiplication by 10 must not overflow. It is sufficient to choose +// +// 10 * p2 < 16 * p2 = 2^4 * p2 <= 2^64. +// +// Since p2 = f mod 2^-e < 2^-e, +// +// -e <= 60 or e >= -60 := alpha + +constexpr int kAlpha = -60; +constexpr int kGamma = -32; + +struct cached_power // c = f * 2^e ~= 10^k +{ + std::uint64_t f; + int e; + int k; +}; + +/*! +For a normalized diyfp w = f * 2^e, this function returns a (normalized) cached +power-of-ten c = f_c * 2^e_c, such that the exponent of the product w * c +satisfies (Definition 3.2 from [1]) + + alpha <= e_c + e + q <= gamma. +*/ +inline cached_power get_cached_power_for_binary_exponent(int e) +{ + // Now + // + // alpha <= e_c + e + q <= gamma (1) + // ==> f_c * 2^alpha <= c * 2^e * 2^q + // + // and since the c's are normalized, 2^(q-1) <= f_c, + // + // ==> 2^(q - 1 + alpha) <= c * 2^(e + q) + // ==> 2^(alpha - e - 1) <= c + // + // If c were an exact power of ten, i.e. c = 10^k, one may determine k as + // + // k = ceil( log_10( 2^(alpha - e - 1) ) ) + // = ceil( (alpha - e - 1) * log_10(2) ) + // + // From the paper: + // "In theory the result of the procedure could be wrong since c is rounded, + // and the computation itself is approximated [...]. In practice, however, + // this simple function is sufficient." + // + // For IEEE double precision floating-point numbers converted into + // normalized diyfp's w = f * 2^e, with q = 64, + // + // e >= -1022 (min IEEE exponent) + // -52 (p - 1) + // -52 (p - 1, possibly normalize denormal IEEE numbers) + // -11 (normalize the diyfp) + // = -1137 + // + // and + // + // e <= +1023 (max IEEE exponent) + // -52 (p - 1) + // -11 (normalize the diyfp) + // = 960 + // + // This binary exponent range [-1137,960] results in a decimal exponent + // range [-307,324]. One does not need to store a cached power for each + // k in this range. For each such k it suffices to find a cached power + // such that the exponent of the product lies in [alpha,gamma]. + // This implies that the difference of the decimal exponents of adjacent + // table entries must be less than or equal to + // + // floor( (gamma - alpha) * log_10(2) ) = 8. + // + // (A smaller distance gamma-alpha would require a larger table.) + + // NB: + // Actually this function returns c, such that -60 <= e_c + e + 64 <= -34. + + constexpr int kCachedPowersMinDecExp = -300; + constexpr int kCachedPowersDecStep = 8; + + static constexpr std::array kCachedPowers = + { + { + { 0xAB70FE17C79AC6CA, -1060, -300 }, + { 0xFF77B1FCBEBCDC4F, -1034, -292 }, + { 0xBE5691EF416BD60C, -1007, -284 }, + { 0x8DD01FAD907FFC3C, -980, -276 }, + { 0xD3515C2831559A83, -954, -268 }, + { 0x9D71AC8FADA6C9B5, -927, -260 }, + { 0xEA9C227723EE8BCB, -901, -252 }, + { 0xAECC49914078536D, -874, -244 }, + { 0x823C12795DB6CE57, -847, -236 }, + { 0xC21094364DFB5637, -821, -228 }, + { 0x9096EA6F3848984F, -794, -220 }, + { 0xD77485CB25823AC7, -768, -212 }, + { 0xA086CFCD97BF97F4, -741, -204 }, + { 0xEF340A98172AACE5, -715, -196 }, + { 0xB23867FB2A35B28E, -688, -188 }, + { 0x84C8D4DFD2C63F3B, -661, -180 }, + { 0xC5DD44271AD3CDBA, -635, -172 }, + { 0x936B9FCEBB25C996, -608, -164 }, + { 0xDBAC6C247D62A584, -582, -156 }, + { 0xA3AB66580D5FDAF6, -555, -148 }, + { 0xF3E2F893DEC3F126, -529, -140 }, + { 0xB5B5ADA8AAFF80B8, -502, -132 }, + { 0x87625F056C7C4A8B, -475, -124 }, + { 0xC9BCFF6034C13053, -449, -116 }, + { 0x964E858C91BA2655, -422, -108 }, + { 0xDFF9772470297EBD, -396, -100 }, + { 0xA6DFBD9FB8E5B88F, -369, -92 }, + { 0xF8A95FCF88747D94, -343, -84 }, + { 0xB94470938FA89BCF, -316, -76 }, + { 0x8A08F0F8BF0F156B, -289, -68 }, + { 0xCDB02555653131B6, -263, -60 }, + { 0x993FE2C6D07B7FAC, -236, -52 }, + { 0xE45C10C42A2B3B06, -210, -44 }, + { 0xAA242499697392D3, -183, -36 }, + { 0xFD87B5F28300CA0E, -157, -28 }, + { 0xBCE5086492111AEB, -130, -20 }, + { 0x8CBCCC096F5088CC, -103, -12 }, + { 0xD1B71758E219652C, -77, -4 }, + { 0x9C40000000000000, -50, 4 }, + { 0xE8D4A51000000000, -24, 12 }, + { 0xAD78EBC5AC620000, 3, 20 }, + { 0x813F3978F8940984, 30, 28 }, + { 0xC097CE7BC90715B3, 56, 36 }, + { 0x8F7E32CE7BEA5C70, 83, 44 }, + { 0xD5D238A4ABE98068, 109, 52 }, + { 0x9F4F2726179A2245, 136, 60 }, + { 0xED63A231D4C4FB27, 162, 68 }, + { 0xB0DE65388CC8ADA8, 189, 76 }, + { 0x83C7088E1AAB65DB, 216, 84 }, + { 0xC45D1DF942711D9A, 242, 92 }, + { 0x924D692CA61BE758, 269, 100 }, + { 0xDA01EE641A708DEA, 295, 108 }, + { 0xA26DA3999AEF774A, 322, 116 }, + { 0xF209787BB47D6B85, 348, 124 }, + { 0xB454E4A179DD1877, 375, 132 }, + { 0x865B86925B9BC5C2, 402, 140 }, + { 0xC83553C5C8965D3D, 428, 148 }, + { 0x952AB45CFA97A0B3, 455, 156 }, + { 0xDE469FBD99A05FE3, 481, 164 }, + { 0xA59BC234DB398C25, 508, 172 }, + { 0xF6C69A72A3989F5C, 534, 180 }, + { 0xB7DCBF5354E9BECE, 561, 188 }, + { 0x88FCF317F22241E2, 588, 196 }, + { 0xCC20CE9BD35C78A5, 614, 204 }, + { 0x98165AF37B2153DF, 641, 212 }, + { 0xE2A0B5DC971F303A, 667, 220 }, + { 0xA8D9D1535CE3B396, 694, 228 }, + { 0xFB9B7CD9A4A7443C, 720, 236 }, + { 0xBB764C4CA7A44410, 747, 244 }, + { 0x8BAB8EEFB6409C1A, 774, 252 }, + { 0xD01FEF10A657842C, 800, 260 }, + { 0x9B10A4E5E9913129, 827, 268 }, + { 0xE7109BFBA19C0C9D, 853, 276 }, + { 0xAC2820D9623BF429, 880, 284 }, + { 0x80444B5E7AA7CF85, 907, 292 }, + { 0xBF21E44003ACDD2D, 933, 300 }, + { 0x8E679C2F5E44FF8F, 960, 308 }, + { 0xD433179D9C8CB841, 986, 316 }, + { 0x9E19DB92B4E31BA9, 1013, 324 }, + } + }; + + // This computation gives exactly the same results for k as + // k = ceil((kAlpha - e - 1) * 0.30102999566398114) + // for |e| <= 1500, but doesn't require floating-point operations. + // NB: log_10(2) ~= 78913 / 2^18 + JSON_ASSERT(e >= -1500); + JSON_ASSERT(e <= 1500); + const int f = kAlpha - e - 1; + const int k = (f * 78913) / (1 << 18) + static_cast(f > 0); + + const int index = (-kCachedPowersMinDecExp + k + (kCachedPowersDecStep - 1)) / kCachedPowersDecStep; + JSON_ASSERT(index >= 0); + JSON_ASSERT(static_cast(index) < kCachedPowers.size()); + + const cached_power cached = kCachedPowers[static_cast(index)]; + JSON_ASSERT(kAlpha <= cached.e + e + 64); + JSON_ASSERT(kGamma >= cached.e + e + 64); + + return cached; +} + +/*! +For n != 0, returns k, such that pow10 := 10^(k-1) <= n < 10^k. +For n == 0, returns 1 and sets pow10 := 1. +*/ +inline int find_largest_pow10(const std::uint32_t n, std::uint32_t& pow10) +{ + // LCOV_EXCL_START + if (n >= 1000000000) + { + pow10 = 1000000000; + return 10; + } + // LCOV_EXCL_STOP + else if (n >= 100000000) + { + pow10 = 100000000; + return 9; + } + else if (n >= 10000000) + { + pow10 = 10000000; + return 8; + } + else if (n >= 1000000) + { + pow10 = 1000000; + return 7; + } + else if (n >= 100000) + { + pow10 = 100000; + return 6; + } + else if (n >= 10000) + { + pow10 = 10000; + return 5; + } + else if (n >= 1000) + { + pow10 = 1000; + return 4; + } + else if (n >= 100) + { + pow10 = 100; + return 3; + } + else if (n >= 10) + { + pow10 = 10; + return 2; + } + else + { + pow10 = 1; + return 1; + } +} + +inline void grisu2_round(char* buf, int len, std::uint64_t dist, std::uint64_t delta, + std::uint64_t rest, std::uint64_t ten_k) +{ + JSON_ASSERT(len >= 1); + JSON_ASSERT(dist <= delta); + JSON_ASSERT(rest <= delta); + JSON_ASSERT(ten_k > 0); + + // <--------------------------- delta ----> + // <---- dist ---------> + // --------------[------------------+-------------------]-------------- + // M- w M+ + // + // ten_k + // <------> + // <---- rest ----> + // --------------[------------------+----+--------------]-------------- + // w V + // = buf * 10^k + // + // ten_k represents a unit-in-the-last-place in the decimal representation + // stored in buf. + // Decrement buf by ten_k while this takes buf closer to w. + + // The tests are written in this order to avoid overflow in unsigned + // integer arithmetic. + + while (rest < dist + && delta - rest >= ten_k + && (rest + ten_k < dist || dist - rest > rest + ten_k - dist)) + { + JSON_ASSERT(buf[len - 1] != '0'); + buf[len - 1]--; + rest += ten_k; + } +} + +/*! +Generates V = buffer * 10^decimal_exponent, such that M- <= V <= M+. +M- and M+ must be normalized and share the same exponent -60 <= e <= -32. +*/ +inline void grisu2_digit_gen(char* buffer, int& length, int& decimal_exponent, + diyfp M_minus, diyfp w, diyfp M_plus) +{ + static_assert(kAlpha >= -60, "internal error"); + static_assert(kGamma <= -32, "internal error"); + + // Generates the digits (and the exponent) of a decimal floating-point + // number V = buffer * 10^decimal_exponent in the range [M-, M+]. The diyfp's + // w, M- and M+ share the same exponent e, which satisfies alpha <= e <= gamma. + // + // <--------------------------- delta ----> + // <---- dist ---------> + // --------------[------------------+-------------------]-------------- + // M- w M+ + // + // Grisu2 generates the digits of M+ from left to right and stops as soon as + // V is in [M-,M+]. + + JSON_ASSERT(M_plus.e >= kAlpha); + JSON_ASSERT(M_plus.e <= kGamma); + + std::uint64_t delta = diyfp::sub(M_plus, M_minus).f; // (significand of (M+ - M-), implicit exponent is e) + std::uint64_t dist = diyfp::sub(M_plus, w ).f; // (significand of (M+ - w ), implicit exponent is e) + + // Split M+ = f * 2^e into two parts p1 and p2 (note: e < 0): + // + // M+ = f * 2^e + // = ((f div 2^-e) * 2^-e + (f mod 2^-e)) * 2^e + // = ((p1 ) * 2^-e + (p2 )) * 2^e + // = p1 + p2 * 2^e + + const diyfp one(std::uint64_t{1} << -M_plus.e, M_plus.e); + + auto p1 = static_cast(M_plus.f >> -one.e); // p1 = f div 2^-e (Since -e >= 32, p1 fits into a 32-bit int.) + std::uint64_t p2 = M_plus.f & (one.f - 1); // p2 = f mod 2^-e + + // 1) + // + // Generate the digits of the integral part p1 = d[n-1]...d[1]d[0] + + JSON_ASSERT(p1 > 0); + + std::uint32_t pow10; + const int k = find_largest_pow10(p1, pow10); + + // 10^(k-1) <= p1 < 10^k, pow10 = 10^(k-1) + // + // p1 = (p1 div 10^(k-1)) * 10^(k-1) + (p1 mod 10^(k-1)) + // = (d[k-1] ) * 10^(k-1) + (p1 mod 10^(k-1)) + // + // M+ = p1 + p2 * 2^e + // = d[k-1] * 10^(k-1) + (p1 mod 10^(k-1)) + p2 * 2^e + // = d[k-1] * 10^(k-1) + ((p1 mod 10^(k-1)) * 2^-e + p2) * 2^e + // = d[k-1] * 10^(k-1) + ( rest) * 2^e + // + // Now generate the digits d[n] of p1 from left to right (n = k-1,...,0) + // + // p1 = d[k-1]...d[n] * 10^n + d[n-1]...d[0] + // + // but stop as soon as + // + // rest * 2^e = (d[n-1]...d[0] * 2^-e + p2) * 2^e <= delta * 2^e + + int n = k; + while (n > 0) + { + // Invariants: + // M+ = buffer * 10^n + (p1 + p2 * 2^e) (buffer = 0 for n = k) + // pow10 = 10^(n-1) <= p1 < 10^n + // + const std::uint32_t d = p1 / pow10; // d = p1 div 10^(n-1) + const std::uint32_t r = p1 % pow10; // r = p1 mod 10^(n-1) + // + // M+ = buffer * 10^n + (d * 10^(n-1) + r) + p2 * 2^e + // = (buffer * 10 + d) * 10^(n-1) + (r + p2 * 2^e) + // + JSON_ASSERT(d <= 9); + buffer[length++] = static_cast('0' + d); // buffer := buffer * 10 + d + // + // M+ = buffer * 10^(n-1) + (r + p2 * 2^e) + // + p1 = r; + n--; + // + // M+ = buffer * 10^n + (p1 + p2 * 2^e) + // pow10 = 10^n + // + + // Now check if enough digits have been generated. + // Compute + // + // p1 + p2 * 2^e = (p1 * 2^-e + p2) * 2^e = rest * 2^e + // + // Note: + // Since rest and delta share the same exponent e, it suffices to + // compare the significands. + const std::uint64_t rest = (std::uint64_t{p1} << -one.e) + p2; + if (rest <= delta) + { + // V = buffer * 10^n, with M- <= V <= M+. + + decimal_exponent += n; + + // We may now just stop. But instead look if the buffer could be + // decremented to bring V closer to w. + // + // pow10 = 10^n is now 1 ulp in the decimal representation V. + // The rounding procedure works with diyfp's with an implicit + // exponent of e. + // + // 10^n = (10^n * 2^-e) * 2^e = ulp * 2^e + // + const std::uint64_t ten_n = std::uint64_t{pow10} << -one.e; + grisu2_round(buffer, length, dist, delta, rest, ten_n); + + return; + } + + pow10 /= 10; + // + // pow10 = 10^(n-1) <= p1 < 10^n + // Invariants restored. + } + + // 2) + // + // The digits of the integral part have been generated: + // + // M+ = d[k-1]...d[1]d[0] + p2 * 2^e + // = buffer + p2 * 2^e + // + // Now generate the digits of the fractional part p2 * 2^e. + // + // Note: + // No decimal point is generated: the exponent is adjusted instead. + // + // p2 actually represents the fraction + // + // p2 * 2^e + // = p2 / 2^-e + // = d[-1] / 10^1 + d[-2] / 10^2 + ... + // + // Now generate the digits d[-m] of p1 from left to right (m = 1,2,...) + // + // p2 * 2^e = d[-1]d[-2]...d[-m] * 10^-m + // + 10^-m * (d[-m-1] / 10^1 + d[-m-2] / 10^2 + ...) + // + // using + // + // 10^m * p2 = ((10^m * p2) div 2^-e) * 2^-e + ((10^m * p2) mod 2^-e) + // = ( d) * 2^-e + ( r) + // + // or + // 10^m * p2 * 2^e = d + r * 2^e + // + // i.e. + // + // M+ = buffer + p2 * 2^e + // = buffer + 10^-m * (d + r * 2^e) + // = (buffer * 10^m + d) * 10^-m + 10^-m * r * 2^e + // + // and stop as soon as 10^-m * r * 2^e <= delta * 2^e + + JSON_ASSERT(p2 > delta); + + int m = 0; + for (;;) + { + // Invariant: + // M+ = buffer * 10^-m + 10^-m * (d[-m-1] / 10 + d[-m-2] / 10^2 + ...) * 2^e + // = buffer * 10^-m + 10^-m * (p2 ) * 2^e + // = buffer * 10^-m + 10^-m * (1/10 * (10 * p2) ) * 2^e + // = buffer * 10^-m + 10^-m * (1/10 * ((10*p2 div 2^-e) * 2^-e + (10*p2 mod 2^-e)) * 2^e + // + JSON_ASSERT(p2 <= (std::numeric_limits::max)() / 10); + p2 *= 10; + const std::uint64_t d = p2 >> -one.e; // d = (10 * p2) div 2^-e + const std::uint64_t r = p2 & (one.f - 1); // r = (10 * p2) mod 2^-e + // + // M+ = buffer * 10^-m + 10^-m * (1/10 * (d * 2^-e + r) * 2^e + // = buffer * 10^-m + 10^-m * (1/10 * (d + r * 2^e)) + // = (buffer * 10 + d) * 10^(-m-1) + 10^(-m-1) * r * 2^e + // + JSON_ASSERT(d <= 9); + buffer[length++] = static_cast('0' + d); // buffer := buffer * 10 + d + // + // M+ = buffer * 10^(-m-1) + 10^(-m-1) * r * 2^e + // + p2 = r; + m++; + // + // M+ = buffer * 10^-m + 10^-m * p2 * 2^e + // Invariant restored. + + // Check if enough digits have been generated. + // + // 10^-m * p2 * 2^e <= delta * 2^e + // p2 * 2^e <= 10^m * delta * 2^e + // p2 <= 10^m * delta + delta *= 10; + dist *= 10; + if (p2 <= delta) + { + break; + } + } + + // V = buffer * 10^-m, with M- <= V <= M+. + + decimal_exponent -= m; + + // 1 ulp in the decimal representation is now 10^-m. + // Since delta and dist are now scaled by 10^m, we need to do the + // same with ulp in order to keep the units in sync. + // + // 10^m * 10^-m = 1 = 2^-e * 2^e = ten_m * 2^e + // + const std::uint64_t ten_m = one.f; + grisu2_round(buffer, length, dist, delta, p2, ten_m); + + // By construction this algorithm generates the shortest possible decimal + // number (Loitsch, Theorem 6.2) which rounds back to w. + // For an input number of precision p, at least + // + // N = 1 + ceil(p * log_10(2)) + // + // decimal digits are sufficient to identify all binary floating-point + // numbers (Matula, "In-and-Out conversions"). + // This implies that the algorithm does not produce more than N decimal + // digits. + // + // N = 17 for p = 53 (IEEE double precision) + // N = 9 for p = 24 (IEEE single precision) +} + +/*! +v = buf * 10^decimal_exponent +len is the length of the buffer (number of decimal digits) +The buffer must be large enough, i.e. >= max_digits10. +*/ +JSON_HEDLEY_NON_NULL(1) +inline void grisu2(char* buf, int& len, int& decimal_exponent, + diyfp m_minus, diyfp v, diyfp m_plus) +{ + JSON_ASSERT(m_plus.e == m_minus.e); + JSON_ASSERT(m_plus.e == v.e); + + // --------(-----------------------+-----------------------)-------- (A) + // m- v m+ + // + // --------------------(-----------+-----------------------)-------- (B) + // m- v m+ + // + // First scale v (and m- and m+) such that the exponent is in the range + // [alpha, gamma]. + + const cached_power cached = get_cached_power_for_binary_exponent(m_plus.e); + + const diyfp c_minus_k(cached.f, cached.e); // = c ~= 10^-k + + // The exponent of the products is = v.e + c_minus_k.e + q and is in the range [alpha,gamma] + const diyfp w = diyfp::mul(v, c_minus_k); + const diyfp w_minus = diyfp::mul(m_minus, c_minus_k); + const diyfp w_plus = diyfp::mul(m_plus, c_minus_k); + + // ----(---+---)---------------(---+---)---------------(---+---)---- + // w- w w+ + // = c*m- = c*v = c*m+ + // + // diyfp::mul rounds its result and c_minus_k is approximated too. w, w- and + // w+ are now off by a small amount. + // In fact: + // + // w - v * 10^k < 1 ulp + // + // To account for this inaccuracy, add resp. subtract 1 ulp. + // + // --------+---[---------------(---+---)---------------]---+-------- + // w- M- w M+ w+ + // + // Now any number in [M-, M+] (bounds included) will round to w when input, + // regardless of how the input rounding algorithm breaks ties. + // + // And digit_gen generates the shortest possible such number in [M-, M+]. + // Note that this does not mean that Grisu2 always generates the shortest + // possible number in the interval (m-, m+). + const diyfp M_minus(w_minus.f + 1, w_minus.e); + const diyfp M_plus (w_plus.f - 1, w_plus.e ); + + decimal_exponent = -cached.k; // = -(-k) = k + + grisu2_digit_gen(buf, len, decimal_exponent, M_minus, w, M_plus); +} + +/*! +v = buf * 10^decimal_exponent +len is the length of the buffer (number of decimal digits) +The buffer must be large enough, i.e. >= max_digits10. +*/ +template +JSON_HEDLEY_NON_NULL(1) +void grisu2(char* buf, int& len, int& decimal_exponent, FloatType value) +{ + static_assert(diyfp::kPrecision >= std::numeric_limits::digits + 3, + "internal error: not enough precision"); + + JSON_ASSERT(std::isfinite(value)); + JSON_ASSERT(value > 0); + + // If the neighbors (and boundaries) of 'value' are always computed for double-precision + // numbers, all float's can be recovered using strtod (and strtof). However, the resulting + // decimal representations are not exactly "short". + // + // The documentation for 'std::to_chars' (https://en.cppreference.com/w/cpp/utility/to_chars) + // says "value is converted to a string as if by std::sprintf in the default ("C") locale" + // and since sprintf promotes float's to double's, I think this is exactly what 'std::to_chars' + // does. + // On the other hand, the documentation for 'std::to_chars' requires that "parsing the + // representation using the corresponding std::from_chars function recovers value exactly". That + // indicates that single precision floating-point numbers should be recovered using + // 'std::strtof'. + // + // NB: If the neighbors are computed for single-precision numbers, there is a single float + // (7.0385307e-26f) which can't be recovered using strtod. The resulting double precision + // value is off by 1 ulp. +#if 0 + const boundaries w = compute_boundaries(static_cast(value)); +#else + const boundaries w = compute_boundaries(value); +#endif + + grisu2(buf, len, decimal_exponent, w.minus, w.w, w.plus); +} + +/*! +@brief appends a decimal representation of e to buf +@return a pointer to the element following the exponent. +@pre -1000 < e < 1000 +*/ +JSON_HEDLEY_NON_NULL(1) +JSON_HEDLEY_RETURNS_NON_NULL +inline char* append_exponent(char* buf, int e) +{ + JSON_ASSERT(e > -1000); + JSON_ASSERT(e < 1000); + + if (e < 0) + { + e = -e; + *buf++ = '-'; + } + else + { + *buf++ = '+'; + } + + auto k = static_cast(e); + if (k < 10) + { + // Always print at least two digits in the exponent. + // This is for compatibility with printf("%g"). + *buf++ = '0'; + *buf++ = static_cast('0' + k); + } + else if (k < 100) + { + *buf++ = static_cast('0' + k / 10); + k %= 10; + *buf++ = static_cast('0' + k); + } + else + { + *buf++ = static_cast('0' + k / 100); + k %= 100; + *buf++ = static_cast('0' + k / 10); + k %= 10; + *buf++ = static_cast('0' + k); + } + + return buf; +} + +/*! +@brief prettify v = buf * 10^decimal_exponent + +If v is in the range [10^min_exp, 10^max_exp) it will be printed in fixed-point +notation. Otherwise it will be printed in exponential notation. + +@pre min_exp < 0 +@pre max_exp > 0 +*/ +JSON_HEDLEY_NON_NULL(1) +JSON_HEDLEY_RETURNS_NON_NULL +inline char* format_buffer(char* buf, int len, int decimal_exponent, + int min_exp, int max_exp) +{ + JSON_ASSERT(min_exp < 0); + JSON_ASSERT(max_exp > 0); + + const int k = len; + const int n = len + decimal_exponent; + + // v = buf * 10^(n-k) + // k is the length of the buffer (number of decimal digits) + // n is the position of the decimal point relative to the start of the buffer. + + if (k <= n && n <= max_exp) + { + // digits[000] + // len <= max_exp + 2 + + std::memset(buf + k, '0', static_cast(n) - static_cast(k)); + // Make it look like a floating-point number (#362, #378) + buf[n + 0] = '.'; + buf[n + 1] = '0'; + return buf + (static_cast(n) + 2); + } + + if (0 < n && n <= max_exp) + { + // dig.its + // len <= max_digits10 + 1 + + JSON_ASSERT(k > n); + + std::memmove(buf + (static_cast(n) + 1), buf + n, static_cast(k) - static_cast(n)); + buf[n] = '.'; + return buf + (static_cast(k) + 1U); + } + + if (min_exp < n && n <= 0) + { + // 0.[000]digits + // len <= 2 + (-min_exp - 1) + max_digits10 + + std::memmove(buf + (2 + static_cast(-n)), buf, static_cast(k)); + buf[0] = '0'; + buf[1] = '.'; + std::memset(buf + 2, '0', static_cast(-n)); + return buf + (2U + static_cast(-n) + static_cast(k)); + } + + if (k == 1) + { + // dE+123 + // len <= 1 + 5 + + buf += 1; + } + else + { + // d.igitsE+123 + // len <= max_digits10 + 1 + 5 + + std::memmove(buf + 2, buf + 1, static_cast(k) - 1); + buf[1] = '.'; + buf += 1 + static_cast(k); + } + + *buf++ = 'e'; + return append_exponent(buf, n - 1); +} + +} // namespace dtoa_impl + +/*! +@brief generates a decimal representation of the floating-point number value in [first, last). + +The format of the resulting decimal representation is similar to printf's %g +format. Returns an iterator pointing past-the-end of the decimal representation. + +@note The input number must be finite, i.e. NaN's and Inf's are not supported. +@note The buffer must be large enough. +@note The result is NOT null-terminated. +*/ +template +JSON_HEDLEY_NON_NULL(1, 2) +JSON_HEDLEY_RETURNS_NON_NULL +char* to_chars(char* first, const char* last, FloatType value) +{ + static_cast(last); // maybe unused - fix warning + JSON_ASSERT(std::isfinite(value)); + + // Use signbit(value) instead of (value < 0) since signbit works for -0. + if (std::signbit(value)) + { + value = -value; + *first++ = '-'; + } + + if (value == 0) // +-0 + { + *first++ = '0'; + // Make it look like a floating-point number (#362, #378) + *first++ = '.'; + *first++ = '0'; + return first; + } + + JSON_ASSERT(last - first >= std::numeric_limits::max_digits10); + + // Compute v = buffer * 10^decimal_exponent. + // The decimal digits are stored in the buffer, which needs to be interpreted + // as an unsigned decimal integer. + // len is the length of the buffer, i.e. the number of decimal digits. + int len = 0; + int decimal_exponent = 0; + dtoa_impl::grisu2(first, len, decimal_exponent, value); + + JSON_ASSERT(len <= std::numeric_limits::max_digits10); + + // Format the buffer like printf("%.*g", prec, value) + constexpr int kMinExp = -4; + // Use digits10 here to increase compatibility with version 2. + constexpr int kMaxExp = std::numeric_limits::digits10; + + JSON_ASSERT(last - first >= kMaxExp + 2); + JSON_ASSERT(last - first >= 2 + (-kMinExp - 1) + std::numeric_limits::max_digits10); + JSON_ASSERT(last - first >= std::numeric_limits::max_digits10 + 6); + + return dtoa_impl::format_buffer(first, len, decimal_exponent, kMinExp, kMaxExp); +} + +} // namespace detail +} // namespace nlohmann + +// #include + +// #include + +// #include + +// #include + +// #include + +// #include + + +namespace nlohmann +{ +namespace detail +{ +/////////////////// +// serialization // +/////////////////// + +/// how to treat decoding errors +enum class error_handler_t +{ + strict, ///< throw a type_error exception in case of invalid UTF-8 + replace, ///< replace invalid UTF-8 sequences with U+FFFD + ignore ///< ignore invalid UTF-8 sequences +}; + +template +class serializer +{ + using string_t = typename BasicJsonType::string_t; + using number_float_t = typename BasicJsonType::number_float_t; + using number_integer_t = typename BasicJsonType::number_integer_t; + using number_unsigned_t = typename BasicJsonType::number_unsigned_t; + using binary_char_t = typename BasicJsonType::binary_t::value_type; + static constexpr std::uint8_t UTF8_ACCEPT = 0; + static constexpr std::uint8_t UTF8_REJECT = 1; + + public: + /*! + @param[in] s output stream to serialize to + @param[in] ichar indentation character to use + @param[in] error_handler_ how to react on decoding errors + */ + serializer(output_adapter_t s, const char ichar, + error_handler_t error_handler_ = error_handler_t::strict) + : o(std::move(s)) + , loc(std::localeconv()) + , thousands_sep(loc->thousands_sep == nullptr ? '\0' : std::char_traits::to_char_type(* (loc->thousands_sep))) + , decimal_point(loc->decimal_point == nullptr ? '\0' : std::char_traits::to_char_type(* (loc->decimal_point))) + , indent_char(ichar) + , indent_string(512, indent_char) + , error_handler(error_handler_) + {} + + // delete because of pointer members + serializer(const serializer&) = delete; + serializer& operator=(const serializer&) = delete; + serializer(serializer&&) = delete; + serializer& operator=(serializer&&) = delete; + ~serializer() = default; + + /*! + @brief internal implementation of the serialization function + + This function is called by the public member function dump and organizes + the serialization internally. The indentation level is propagated as + additional parameter. In case of arrays and objects, the function is + called recursively. + + - strings and object keys are escaped using `escape_string()` + - integer numbers are converted implicitly via `operator<<` + - floating-point numbers are converted to a string using `"%g"` format + - binary values are serialized as objects containing the subtype and the + byte array + + @param[in] val value to serialize + @param[in] pretty_print whether the output shall be pretty-printed + @param[in] ensure_ascii If @a ensure_ascii is true, all non-ASCII characters + in the output are escaped with `\uXXXX` sequences, and the result consists + of ASCII characters only. + @param[in] indent_step the indent level + @param[in] current_indent the current indent level (only used internally) + */ + void dump(const BasicJsonType& val, + const bool pretty_print, + const bool ensure_ascii, + const unsigned int indent_step, + const unsigned int current_indent = 0) + { + switch (val.m_type) + { + case value_t::object: + { + if (val.m_value.object->empty()) + { + o->write_characters("{}", 2); + return; + } + + if (pretty_print) + { + o->write_characters("{\n", 2); + + // variable to hold indentation for recursive calls + const auto new_indent = current_indent + indent_step; + if (JSON_HEDLEY_UNLIKELY(indent_string.size() < new_indent)) + { + indent_string.resize(indent_string.size() * 2, ' '); + } + + // first n-1 elements + auto i = val.m_value.object->cbegin(); + for (std::size_t cnt = 0; cnt < val.m_value.object->size() - 1; ++cnt, ++i) + { + o->write_characters(indent_string.c_str(), new_indent); + o->write_character('\"'); + dump_escaped(i->first, ensure_ascii); + o->write_characters("\": ", 3); + dump(i->second, true, ensure_ascii, indent_step, new_indent); + o->write_characters(",\n", 2); + } + + // last element + JSON_ASSERT(i != val.m_value.object->cend()); + JSON_ASSERT(std::next(i) == val.m_value.object->cend()); + o->write_characters(indent_string.c_str(), new_indent); + o->write_character('\"'); + dump_escaped(i->first, ensure_ascii); + o->write_characters("\": ", 3); + dump(i->second, true, ensure_ascii, indent_step, new_indent); + + o->write_character('\n'); + o->write_characters(indent_string.c_str(), current_indent); + o->write_character('}'); + } + else + { + o->write_character('{'); + + // first n-1 elements + auto i = val.m_value.object->cbegin(); + for (std::size_t cnt = 0; cnt < val.m_value.object->size() - 1; ++cnt, ++i) + { + o->write_character('\"'); + dump_escaped(i->first, ensure_ascii); + o->write_characters("\":", 2); + dump(i->second, false, ensure_ascii, indent_step, current_indent); + o->write_character(','); + } + + // last element + JSON_ASSERT(i != val.m_value.object->cend()); + JSON_ASSERT(std::next(i) == val.m_value.object->cend()); + o->write_character('\"'); + dump_escaped(i->first, ensure_ascii); + o->write_characters("\":", 2); + dump(i->second, false, ensure_ascii, indent_step, current_indent); + + o->write_character('}'); + } + + return; + } + + case value_t::array: + { + if (val.m_value.array->empty()) + { + o->write_characters("[]", 2); + return; + } + + if (pretty_print) + { + o->write_characters("[\n", 2); + + // variable to hold indentation for recursive calls + const auto new_indent = current_indent + indent_step; + if (JSON_HEDLEY_UNLIKELY(indent_string.size() < new_indent)) + { + indent_string.resize(indent_string.size() * 2, ' '); + } + + // first n-1 elements + for (auto i = val.m_value.array->cbegin(); + i != val.m_value.array->cend() - 1; ++i) + { + o->write_characters(indent_string.c_str(), new_indent); + dump(*i, true, ensure_ascii, indent_step, new_indent); + o->write_characters(",\n", 2); + } + + // last element + JSON_ASSERT(!val.m_value.array->empty()); + o->write_characters(indent_string.c_str(), new_indent); + dump(val.m_value.array->back(), true, ensure_ascii, indent_step, new_indent); + + o->write_character('\n'); + o->write_characters(indent_string.c_str(), current_indent); + o->write_character(']'); + } + else + { + o->write_character('['); + + // first n-1 elements + for (auto i = val.m_value.array->cbegin(); + i != val.m_value.array->cend() - 1; ++i) + { + dump(*i, false, ensure_ascii, indent_step, current_indent); + o->write_character(','); + } + + // last element + JSON_ASSERT(!val.m_value.array->empty()); + dump(val.m_value.array->back(), false, ensure_ascii, indent_step, current_indent); + + o->write_character(']'); + } + + return; + } + + case value_t::string: + { + o->write_character('\"'); + dump_escaped(*val.m_value.string, ensure_ascii); + o->write_character('\"'); + return; + } + + case value_t::binary: + { + if (pretty_print) + { + o->write_characters("{\n", 2); + + // variable to hold indentation for recursive calls + const auto new_indent = current_indent + indent_step; + if (JSON_HEDLEY_UNLIKELY(indent_string.size() < new_indent)) + { + indent_string.resize(indent_string.size() * 2, ' '); + } + + o->write_characters(indent_string.c_str(), new_indent); + + o->write_characters("\"bytes\": [", 10); + + if (!val.m_value.binary->empty()) + { + for (auto i = val.m_value.binary->cbegin(); + i != val.m_value.binary->cend() - 1; ++i) + { + dump_integer(*i); + o->write_characters(", ", 2); + } + dump_integer(val.m_value.binary->back()); + } + + o->write_characters("],\n", 3); + o->write_characters(indent_string.c_str(), new_indent); + + o->write_characters("\"subtype\": ", 11); + if (val.m_value.binary->has_subtype()) + { + dump_integer(val.m_value.binary->subtype()); + } + else + { + o->write_characters("null", 4); + } + o->write_character('\n'); + o->write_characters(indent_string.c_str(), current_indent); + o->write_character('}'); + } + else + { + o->write_characters("{\"bytes\":[", 10); + + if (!val.m_value.binary->empty()) + { + for (auto i = val.m_value.binary->cbegin(); + i != val.m_value.binary->cend() - 1; ++i) + { + dump_integer(*i); + o->write_character(','); + } + dump_integer(val.m_value.binary->back()); + } + + o->write_characters("],\"subtype\":", 12); + if (val.m_value.binary->has_subtype()) + { + dump_integer(val.m_value.binary->subtype()); + o->write_character('}'); + } + else + { + o->write_characters("null}", 5); + } + } + return; + } + + case value_t::boolean: + { + if (val.m_value.boolean) + { + o->write_characters("true", 4); + } + else + { + o->write_characters("false", 5); + } + return; + } + + case value_t::number_integer: + { + dump_integer(val.m_value.number_integer); + return; + } + + case value_t::number_unsigned: + { + dump_integer(val.m_value.number_unsigned); + return; + } + + case value_t::number_float: + { + dump_float(val.m_value.number_float); + return; + } + + case value_t::discarded: + { + o->write_characters("", 11); + return; + } + + case value_t::null: + { + o->write_characters("null", 4); + return; + } + + default: // LCOV_EXCL_LINE + JSON_ASSERT(false); // LCOV_EXCL_LINE + } + } + + private: + /*! + @brief dump escaped string + + Escape a string by replacing certain special characters by a sequence of an + escape character (backslash) and another character and other control + characters by a sequence of "\u" followed by a four-digit hex + representation. The escaped string is written to output stream @a o. + + @param[in] s the string to escape + @param[in] ensure_ascii whether to escape non-ASCII characters with + \uXXXX sequences + + @complexity Linear in the length of string @a s. + */ + void dump_escaped(const string_t& s, const bool ensure_ascii) + { + std::uint32_t codepoint; + std::uint8_t state = UTF8_ACCEPT; + std::size_t bytes = 0; // number of bytes written to string_buffer + + // number of bytes written at the point of the last valid byte + std::size_t bytes_after_last_accept = 0; + std::size_t undumped_chars = 0; + + for (std::size_t i = 0; i < s.size(); ++i) + { + const auto byte = static_cast(s[i]); + + switch (decode(state, codepoint, byte)) + { + case UTF8_ACCEPT: // decode found a new code point + { + switch (codepoint) + { + case 0x08: // backspace + { + string_buffer[bytes++] = '\\'; + string_buffer[bytes++] = 'b'; + break; + } + + case 0x09: // horizontal tab + { + string_buffer[bytes++] = '\\'; + string_buffer[bytes++] = 't'; + break; + } + + case 0x0A: // newline + { + string_buffer[bytes++] = '\\'; + string_buffer[bytes++] = 'n'; + break; + } + + case 0x0C: // formfeed + { + string_buffer[bytes++] = '\\'; + string_buffer[bytes++] = 'f'; + break; + } + + case 0x0D: // carriage return + { + string_buffer[bytes++] = '\\'; + string_buffer[bytes++] = 'r'; + break; + } + + case 0x22: // quotation mark + { + string_buffer[bytes++] = '\\'; + string_buffer[bytes++] = '\"'; + break; + } + + case 0x5C: // reverse solidus + { + string_buffer[bytes++] = '\\'; + string_buffer[bytes++] = '\\'; + break; + } + + default: + { + // escape control characters (0x00..0x1F) or, if + // ensure_ascii parameter is used, non-ASCII characters + if ((codepoint <= 0x1F) || (ensure_ascii && (codepoint >= 0x7F))) + { + if (codepoint <= 0xFFFF) + { + (std::snprintf)(string_buffer.data() + bytes, 7, "\\u%04x", + static_cast(codepoint)); + bytes += 6; + } + else + { + (std::snprintf)(string_buffer.data() + bytes, 13, "\\u%04x\\u%04x", + static_cast(0xD7C0u + (codepoint >> 10u)), + static_cast(0xDC00u + (codepoint & 0x3FFu))); + bytes += 12; + } + } + else + { + // copy byte to buffer (all previous bytes + // been copied have in default case above) + string_buffer[bytes++] = s[i]; + } + break; + } + } + + // write buffer and reset index; there must be 13 bytes + // left, as this is the maximal number of bytes to be + // written ("\uxxxx\uxxxx\0") for one code point + if (string_buffer.size() - bytes < 13) + { + o->write_characters(string_buffer.data(), bytes); + bytes = 0; + } + + // remember the byte position of this accept + bytes_after_last_accept = bytes; + undumped_chars = 0; + break; + } + + case UTF8_REJECT: // decode found invalid UTF-8 byte + { + switch (error_handler) + { + case error_handler_t::strict: + { + std::string sn(3, '\0'); + (std::snprintf)(&sn[0], sn.size(), "%.2X", byte); + JSON_THROW(type_error::create(316, "invalid UTF-8 byte at index " + std::to_string(i) + ": 0x" + sn)); + } + + case error_handler_t::ignore: + case error_handler_t::replace: + { + // in case we saw this character the first time, we + // would like to read it again, because the byte + // may be OK for itself, but just not OK for the + // previous sequence + if (undumped_chars > 0) + { + --i; + } + + // reset length buffer to the last accepted index; + // thus removing/ignoring the invalid characters + bytes = bytes_after_last_accept; + + if (error_handler == error_handler_t::replace) + { + // add a replacement character + if (ensure_ascii) + { + string_buffer[bytes++] = '\\'; + string_buffer[bytes++] = 'u'; + string_buffer[bytes++] = 'f'; + string_buffer[bytes++] = 'f'; + string_buffer[bytes++] = 'f'; + string_buffer[bytes++] = 'd'; + } + else + { + string_buffer[bytes++] = detail::binary_writer::to_char_type('\xEF'); + string_buffer[bytes++] = detail::binary_writer::to_char_type('\xBF'); + string_buffer[bytes++] = detail::binary_writer::to_char_type('\xBD'); + } + + // write buffer and reset index; there must be 13 bytes + // left, as this is the maximal number of bytes to be + // written ("\uxxxx\uxxxx\0") for one code point + if (string_buffer.size() - bytes < 13) + { + o->write_characters(string_buffer.data(), bytes); + bytes = 0; + } + + bytes_after_last_accept = bytes; + } + + undumped_chars = 0; + + // continue processing the string + state = UTF8_ACCEPT; + break; + } + + default: // LCOV_EXCL_LINE + JSON_ASSERT(false); // LCOV_EXCL_LINE + } + break; + } + + default: // decode found yet incomplete multi-byte code point + { + if (!ensure_ascii) + { + // code point will not be escaped - copy byte to buffer + string_buffer[bytes++] = s[i]; + } + ++undumped_chars; + break; + } + } + } + + // we finished processing the string + if (JSON_HEDLEY_LIKELY(state == UTF8_ACCEPT)) + { + // write buffer + if (bytes > 0) + { + o->write_characters(string_buffer.data(), bytes); + } + } + else + { + // we finish reading, but do not accept: string was incomplete + switch (error_handler) + { + case error_handler_t::strict: + { + std::string sn(3, '\0'); + (std::snprintf)(&sn[0], sn.size(), "%.2X", static_cast(s.back())); + JSON_THROW(type_error::create(316, "incomplete UTF-8 string; last byte: 0x" + sn)); + } + + case error_handler_t::ignore: + { + // write all accepted bytes + o->write_characters(string_buffer.data(), bytes_after_last_accept); + break; + } + + case error_handler_t::replace: + { + // write all accepted bytes + o->write_characters(string_buffer.data(), bytes_after_last_accept); + // add a replacement character + if (ensure_ascii) + { + o->write_characters("\\ufffd", 6); + } + else + { + o->write_characters("\xEF\xBF\xBD", 3); + } + break; + } + + default: // LCOV_EXCL_LINE + JSON_ASSERT(false); // LCOV_EXCL_LINE + } + } + } + + /*! + @brief count digits + + Count the number of decimal (base 10) digits for an input unsigned integer. + + @param[in] x unsigned integer number to count its digits + @return number of decimal digits + */ + inline unsigned int count_digits(number_unsigned_t x) noexcept + { + unsigned int n_digits = 1; + for (;;) + { + if (x < 10) + { + return n_digits; + } + if (x < 100) + { + return n_digits + 1; + } + if (x < 1000) + { + return n_digits + 2; + } + if (x < 10000) + { + return n_digits + 3; + } + x = x / 10000u; + n_digits += 4; + } + } + + /*! + @brief dump an integer + + Dump a given integer to output stream @a o. Works internally with + @a number_buffer. + + @param[in] x integer number (signed or unsigned) to dump + @tparam NumberType either @a number_integer_t or @a number_unsigned_t + */ + template < typename NumberType, detail::enable_if_t < + std::is_same::value || + std::is_same::value || + std::is_same::value, + int > = 0 > + void dump_integer(NumberType x) + { + static constexpr std::array, 100> digits_to_99 + { + { + {{'0', '0'}}, {{'0', '1'}}, {{'0', '2'}}, {{'0', '3'}}, {{'0', '4'}}, {{'0', '5'}}, {{'0', '6'}}, {{'0', '7'}}, {{'0', '8'}}, {{'0', '9'}}, + {{'1', '0'}}, {{'1', '1'}}, {{'1', '2'}}, {{'1', '3'}}, {{'1', '4'}}, {{'1', '5'}}, {{'1', '6'}}, {{'1', '7'}}, {{'1', '8'}}, {{'1', '9'}}, + {{'2', '0'}}, {{'2', '1'}}, {{'2', '2'}}, {{'2', '3'}}, {{'2', '4'}}, {{'2', '5'}}, {{'2', '6'}}, {{'2', '7'}}, {{'2', '8'}}, {{'2', '9'}}, + {{'3', '0'}}, {{'3', '1'}}, {{'3', '2'}}, {{'3', '3'}}, {{'3', '4'}}, {{'3', '5'}}, {{'3', '6'}}, {{'3', '7'}}, {{'3', '8'}}, {{'3', '9'}}, + {{'4', '0'}}, {{'4', '1'}}, {{'4', '2'}}, {{'4', '3'}}, {{'4', '4'}}, {{'4', '5'}}, {{'4', '6'}}, {{'4', '7'}}, {{'4', '8'}}, {{'4', '9'}}, + {{'5', '0'}}, {{'5', '1'}}, {{'5', '2'}}, {{'5', '3'}}, {{'5', '4'}}, {{'5', '5'}}, {{'5', '6'}}, {{'5', '7'}}, {{'5', '8'}}, {{'5', '9'}}, + {{'6', '0'}}, {{'6', '1'}}, {{'6', '2'}}, {{'6', '3'}}, {{'6', '4'}}, {{'6', '5'}}, {{'6', '6'}}, {{'6', '7'}}, {{'6', '8'}}, {{'6', '9'}}, + {{'7', '0'}}, {{'7', '1'}}, {{'7', '2'}}, {{'7', '3'}}, {{'7', '4'}}, {{'7', '5'}}, {{'7', '6'}}, {{'7', '7'}}, {{'7', '8'}}, {{'7', '9'}}, + {{'8', '0'}}, {{'8', '1'}}, {{'8', '2'}}, {{'8', '3'}}, {{'8', '4'}}, {{'8', '5'}}, {{'8', '6'}}, {{'8', '7'}}, {{'8', '8'}}, {{'8', '9'}}, + {{'9', '0'}}, {{'9', '1'}}, {{'9', '2'}}, {{'9', '3'}}, {{'9', '4'}}, {{'9', '5'}}, {{'9', '6'}}, {{'9', '7'}}, {{'9', '8'}}, {{'9', '9'}}, + } + }; + + // special case for "0" + if (x == 0) + { + o->write_character('0'); + return; + } + + // use a pointer to fill the buffer + auto buffer_ptr = number_buffer.begin(); + + const bool is_negative = std::is_same::value && !(x >= 0); // see issue #755 + number_unsigned_t abs_value; + + unsigned int n_chars; + + if (is_negative) + { + *buffer_ptr = '-'; + abs_value = remove_sign(static_cast(x)); + + // account one more byte for the minus sign + n_chars = 1 + count_digits(abs_value); + } + else + { + abs_value = static_cast(x); + n_chars = count_digits(abs_value); + } + + // spare 1 byte for '\0' + JSON_ASSERT(n_chars < number_buffer.size() - 1); + + // jump to the end to generate the string from backward + // so we later avoid reversing the result + buffer_ptr += n_chars; + + // Fast int2ascii implementation inspired by "Fastware" talk by Andrei Alexandrescu + // See: https://www.youtube.com/watch?v=o4-CwDo2zpg + while (abs_value >= 100) + { + const auto digits_index = static_cast((abs_value % 100)); + abs_value /= 100; + *(--buffer_ptr) = digits_to_99[digits_index][1]; + *(--buffer_ptr) = digits_to_99[digits_index][0]; + } + + if (abs_value >= 10) + { + const auto digits_index = static_cast(abs_value); + *(--buffer_ptr) = digits_to_99[digits_index][1]; + *(--buffer_ptr) = digits_to_99[digits_index][0]; + } + else + { + *(--buffer_ptr) = static_cast('0' + abs_value); + } + + o->write_characters(number_buffer.data(), n_chars); + } + + /*! + @brief dump a floating-point number + + Dump a given floating-point number to output stream @a o. Works internally + with @a number_buffer. + + @param[in] x floating-point number to dump + */ + void dump_float(number_float_t x) + { + // NaN / inf + if (!std::isfinite(x)) + { + o->write_characters("null", 4); + return; + } + + // If number_float_t is an IEEE-754 single or double precision number, + // use the Grisu2 algorithm to produce short numbers which are + // guaranteed to round-trip, using strtof and strtod, resp. + // + // NB: The test below works if == . + static constexpr bool is_ieee_single_or_double + = (std::numeric_limits::is_iec559 && std::numeric_limits::digits == 24 && std::numeric_limits::max_exponent == 128) || + (std::numeric_limits::is_iec559 && std::numeric_limits::digits == 53 && std::numeric_limits::max_exponent == 1024); + + dump_float(x, std::integral_constant()); + } + + void dump_float(number_float_t x, std::true_type /*is_ieee_single_or_double*/) + { + char* begin = number_buffer.data(); + char* end = ::nlohmann::detail::to_chars(begin, begin + number_buffer.size(), x); + + o->write_characters(begin, static_cast(end - begin)); + } + + void dump_float(number_float_t x, std::false_type /*is_ieee_single_or_double*/) + { + // get number of digits for a float -> text -> float round-trip + static constexpr auto d = std::numeric_limits::max_digits10; + + // the actual conversion + std::ptrdiff_t len = (std::snprintf)(number_buffer.data(), number_buffer.size(), "%.*g", d, x); + + // negative value indicates an error + JSON_ASSERT(len > 0); + // check if buffer was large enough + JSON_ASSERT(static_cast(len) < number_buffer.size()); + + // erase thousands separator + if (thousands_sep != '\0') + { + const auto end = std::remove(number_buffer.begin(), + number_buffer.begin() + len, thousands_sep); + std::fill(end, number_buffer.end(), '\0'); + JSON_ASSERT((end - number_buffer.begin()) <= len); + len = (end - number_buffer.begin()); + } + + // convert decimal point to '.' + if (decimal_point != '\0' && decimal_point != '.') + { + const auto dec_pos = std::find(number_buffer.begin(), number_buffer.end(), decimal_point); + if (dec_pos != number_buffer.end()) + { + *dec_pos = '.'; + } + } + + o->write_characters(number_buffer.data(), static_cast(len)); + + // determine if need to append ".0" + const bool value_is_int_like = + std::none_of(number_buffer.begin(), number_buffer.begin() + len + 1, + [](char c) + { + return c == '.' || c == 'e'; + }); + + if (value_is_int_like) + { + o->write_characters(".0", 2); + } + } + + /*! + @brief check whether a string is UTF-8 encoded + + The function checks each byte of a string whether it is UTF-8 encoded. The + result of the check is stored in the @a state parameter. The function must + be called initially with state 0 (accept). State 1 means the string must + be rejected, because the current byte is not allowed. If the string is + completely processed, but the state is non-zero, the string ended + prematurely; that is, the last byte indicated more bytes should have + followed. + + @param[in,out] state the state of the decoding + @param[in,out] codep codepoint (valid only if resulting state is UTF8_ACCEPT) + @param[in] byte next byte to decode + @return new state + + @note The function has been edited: a std::array is used. + + @copyright Copyright (c) 2008-2009 Bjoern Hoehrmann + @sa http://bjoern.hoehrmann.de/utf-8/decoder/dfa/ + */ + static std::uint8_t decode(std::uint8_t& state, std::uint32_t& codep, const std::uint8_t byte) noexcept + { + static const std::array utf8d = + { + { + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 00..1F + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 20..3F + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 40..5F + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 60..7F + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, // 80..9F + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, // A0..BF + 8, 8, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, // C0..DF + 0xA, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, 0x4, 0x3, 0x3, // E0..EF + 0xB, 0x6, 0x6, 0x6, 0x5, 0x8, 0x8, 0x8, 0x8, 0x8, 0x8, 0x8, 0x8, 0x8, 0x8, 0x8, // F0..FF + 0x0, 0x1, 0x2, 0x3, 0x5, 0x8, 0x7, 0x1, 0x1, 0x1, 0x4, 0x6, 0x1, 0x1, 0x1, 0x1, // s0..s0 + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, // s1..s2 + 1, 2, 1, 1, 1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, // s3..s4 + 1, 2, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 1, 3, 1, 1, 1, 1, 1, 1, // s5..s6 + 1, 3, 1, 1, 1, 1, 1, 3, 1, 3, 1, 1, 1, 1, 1, 1, 1, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 // s7..s8 + } + }; + + const std::uint8_t type = utf8d[byte]; + + codep = (state != UTF8_ACCEPT) + ? (byte & 0x3fu) | (codep << 6u) + : (0xFFu >> type) & (byte); + + std::size_t index = 256u + static_cast(state) * 16u + static_cast(type); + JSON_ASSERT(index < 400); + state = utf8d[index]; + return state; + } + + /* + * Overload to make the compiler happy while it is instantiating + * dump_integer for number_unsigned_t. + * Must never be called. + */ + number_unsigned_t remove_sign(number_unsigned_t x) + { + JSON_ASSERT(false); // LCOV_EXCL_LINE + return x; // LCOV_EXCL_LINE + } + + /* + * Helper function for dump_integer + * + * This function takes a negative signed integer and returns its absolute + * value as unsigned integer. The plus/minus shuffling is necessary as we can + * not directly remove the sign of an arbitrary signed integer as the + * absolute values of INT_MIN and INT_MAX are usually not the same. See + * #1708 for details. + */ + inline number_unsigned_t remove_sign(number_integer_t x) noexcept + { + JSON_ASSERT(x < 0 && x < (std::numeric_limits::max)()); + return static_cast(-(x + 1)) + 1; + } + + private: + /// the output of the serializer + output_adapter_t o = nullptr; + + /// a (hopefully) large enough character buffer + std::array number_buffer{{}}; + + /// the locale + const std::lconv* loc = nullptr; + /// the locale's thousand separator character + const char thousands_sep = '\0'; + /// the locale's decimal point character + const char decimal_point = '\0'; + + /// string buffer + std::array string_buffer{{}}; + + /// the indentation character + const char indent_char; + /// the indentation string + string_t indent_string; + + /// error_handler how to react on decoding errors + const error_handler_t error_handler; +}; +} // namespace detail +} // namespace nlohmann + +// #include + +// #include + +// #include + + +#include // less +#include // allocator +#include // pair +#include // vector + +namespace nlohmann +{ + +/// ordered_map: a minimal map-like container that preserves insertion order +/// for use within nlohmann::basic_json +template , + class Allocator = std::allocator>> + struct ordered_map : std::vector, Allocator> +{ + using key_type = Key; + using mapped_type = T; + using Container = std::vector, Allocator>; + using typename Container::iterator; + using typename Container::const_iterator; + using typename Container::size_type; + using typename Container::value_type; + + // Explicit constructors instead of `using Container::Container` + // otherwise older compilers choke on it (GCC <= 5.5, xcode <= 9.4) + ordered_map(const Allocator& alloc = Allocator()) : Container{alloc} {} + template + ordered_map(It first, It last, const Allocator& alloc = Allocator()) + : Container{first, last, alloc} {} + ordered_map(std::initializer_list init, const Allocator& alloc = Allocator() ) + : Container{init, alloc} {} + + std::pair emplace(const key_type& key, T&& t) + { + for (auto it = this->begin(); it != this->end(); ++it) + { + if (it->first == key) + { + return {it, false}; + } + } + Container::emplace_back(key, t); + return {--this->end(), true}; + } + + T& operator[](const Key& key) + { + return emplace(key, T{}).first->second; + } + + const T& operator[](const Key& key) const + { + return at(key); + } + + T& at(const Key& key) + { + for (auto it = this->begin(); it != this->end(); ++it) + { + if (it->first == key) + { + return it->second; + } + } + + throw std::out_of_range("key not found"); + } + + const T& at(const Key& key) const + { + for (auto it = this->begin(); it != this->end(); ++it) + { + if (it->first == key) + { + return it->second; + } + } + + throw std::out_of_range("key not found"); + } + + size_type erase(const Key& key) + { + for (auto it = this->begin(); it != this->end(); ++it) + { + if (it->first == key) + { + // Since we cannot move const Keys, re-construct them in place + for (auto next = it; ++next != this->end(); ++it) + { + it->~value_type(); // Destroy but keep allocation + new (&*it) value_type{std::move(*next)}; + } + Container::pop_back(); + return 1; + } + } + return 0; + } + + iterator erase(iterator pos) + { + auto it = pos; + + // Since we cannot move const Keys, re-construct them in place + for (auto next = it; ++next != this->end(); ++it) + { + it->~value_type(); // Destroy but keep allocation + new (&*it) value_type{std::move(*next)}; + } + Container::pop_back(); + return pos; + } + + size_type count(const Key& key) const + { + for (auto it = this->begin(); it != this->end(); ++it) + { + if (it->first == key) + { + return 1; + } + } + return 0; + } + + iterator find(const Key& key) + { + for (auto it = this->begin(); it != this->end(); ++it) + { + if (it->first == key) + { + return it; + } + } + return Container::end(); + } + + const_iterator find(const Key& key) const + { + for (auto it = this->begin(); it != this->end(); ++it) + { + if (it->first == key) + { + return it; + } + } + return Container::end(); + } + + std::pair insert( value_type&& value ) + { + return emplace(value.first, std::move(value.second)); + } + + std::pair insert( const value_type& value ) + { + for (auto it = this->begin(); it != this->end(); ++it) + { + if (it->first == value.first) + { + return {it, false}; + } + } + Container::push_back(value); + return {--this->end(), true}; + } +}; + +} // namespace nlohmann + + +/*! +@brief namespace for Niels Lohmann +@see https://github.com/nlohmann +@since version 1.0.0 +*/ +namespace nlohmann +{ + +/*! +@brief a class to store JSON values + +@tparam ObjectType type for JSON objects (`std::map` by default; will be used +in @ref object_t) +@tparam ArrayType type for JSON arrays (`std::vector` by default; will be used +in @ref array_t) +@tparam StringType type for JSON strings and object keys (`std::string` by +default; will be used in @ref string_t) +@tparam BooleanType type for JSON booleans (`bool` by default; will be used +in @ref boolean_t) +@tparam NumberIntegerType type for JSON integer numbers (`int64_t` by +default; will be used in @ref number_integer_t) +@tparam NumberUnsignedType type for JSON unsigned integer numbers (@c +`uint64_t` by default; will be used in @ref number_unsigned_t) +@tparam NumberFloatType type for JSON floating-point numbers (`double` by +default; will be used in @ref number_float_t) +@tparam BinaryType type for packed binary data for compatibility with binary +serialization formats (`std::vector` by default; will be used in +@ref binary_t) +@tparam AllocatorType type of the allocator to use (`std::allocator` by +default) +@tparam JSONSerializer the serializer to resolve internal calls to `to_json()` +and `from_json()` (@ref adl_serializer by default) + +@requirement The class satisfies the following concept requirements: +- Basic + - [DefaultConstructible](https://en.cppreference.com/w/cpp/named_req/DefaultConstructible): + JSON values can be default constructed. The result will be a JSON null + value. + - [MoveConstructible](https://en.cppreference.com/w/cpp/named_req/MoveConstructible): + A JSON value can be constructed from an rvalue argument. + - [CopyConstructible](https://en.cppreference.com/w/cpp/named_req/CopyConstructible): + A JSON value can be copy-constructed from an lvalue expression. + - [MoveAssignable](https://en.cppreference.com/w/cpp/named_req/MoveAssignable): + A JSON value van be assigned from an rvalue argument. + - [CopyAssignable](https://en.cppreference.com/w/cpp/named_req/CopyAssignable): + A JSON value can be copy-assigned from an lvalue expression. + - [Destructible](https://en.cppreference.com/w/cpp/named_req/Destructible): + JSON values can be destructed. +- Layout + - [StandardLayoutType](https://en.cppreference.com/w/cpp/named_req/StandardLayoutType): + JSON values have + [standard layout](https://en.cppreference.com/w/cpp/language/data_members#Standard_layout): + All non-static data members are private and standard layout types, the + class has no virtual functions or (virtual) base classes. +- Library-wide + - [EqualityComparable](https://en.cppreference.com/w/cpp/named_req/EqualityComparable): + JSON values can be compared with `==`, see @ref + operator==(const_reference,const_reference). + - [LessThanComparable](https://en.cppreference.com/w/cpp/named_req/LessThanComparable): + JSON values can be compared with `<`, see @ref + operator<(const_reference,const_reference). + - [Swappable](https://en.cppreference.com/w/cpp/named_req/Swappable): + Any JSON lvalue or rvalue of can be swapped with any lvalue or rvalue of + other compatible types, using unqualified function call @ref swap(). + - [NullablePointer](https://en.cppreference.com/w/cpp/named_req/NullablePointer): + JSON values can be compared against `std::nullptr_t` objects which are used + to model the `null` value. +- Container + - [Container](https://en.cppreference.com/w/cpp/named_req/Container): + JSON values can be used like STL containers and provide iterator access. + - [ReversibleContainer](https://en.cppreference.com/w/cpp/named_req/ReversibleContainer); + JSON values can be used like STL containers and provide reverse iterator + access. + +@invariant The member variables @a m_value and @a m_type have the following +relationship: +- If `m_type == value_t::object`, then `m_value.object != nullptr`. +- If `m_type == value_t::array`, then `m_value.array != nullptr`. +- If `m_type == value_t::string`, then `m_value.string != nullptr`. +The invariants are checked by member function assert_invariant(). + +@internal +@note ObjectType trick from https://stackoverflow.com/a/9860911 +@endinternal + +@see [RFC 7159: The JavaScript Object Notation (JSON) Data Interchange +Format](http://rfc7159.net/rfc7159) + +@since version 1.0.0 + +@nosubgrouping +*/ +NLOHMANN_BASIC_JSON_TPL_DECLARATION +class basic_json +{ + private: + template friend struct detail::external_constructor; + friend ::nlohmann::json_pointer; + + template + friend class ::nlohmann::detail::parser; + friend ::nlohmann::detail::serializer; + template + friend class ::nlohmann::detail::iter_impl; + template + friend class ::nlohmann::detail::binary_writer; + template + friend class ::nlohmann::detail::binary_reader; + template + friend class ::nlohmann::detail::json_sax_dom_parser; + template + friend class ::nlohmann::detail::json_sax_dom_callback_parser; + + /// workaround type for MSVC + using basic_json_t = NLOHMANN_BASIC_JSON_TPL; + + // convenience aliases for types residing in namespace detail; + using lexer = ::nlohmann::detail::lexer_base; + + template + static ::nlohmann::detail::parser parser( + InputAdapterType adapter, + detail::parser_callback_tcb = nullptr, + const bool allow_exceptions = true, + const bool ignore_comments = false + ) + { + return ::nlohmann::detail::parser(std::move(adapter), + std::move(cb), allow_exceptions, ignore_comments); + } + + using primitive_iterator_t = ::nlohmann::detail::primitive_iterator_t; + template + using internal_iterator = ::nlohmann::detail::internal_iterator; + template + using iter_impl = ::nlohmann::detail::iter_impl; + template + using iteration_proxy = ::nlohmann::detail::iteration_proxy; + template using json_reverse_iterator = ::nlohmann::detail::json_reverse_iterator; + + template + using output_adapter_t = ::nlohmann::detail::output_adapter_t; + + template + using binary_reader = ::nlohmann::detail::binary_reader; + template using binary_writer = ::nlohmann::detail::binary_writer; + + using serializer = ::nlohmann::detail::serializer; + + public: + using value_t = detail::value_t; + /// JSON Pointer, see @ref nlohmann::json_pointer + using json_pointer = ::nlohmann::json_pointer; + template + using json_serializer = JSONSerializer; + /// how to treat decoding errors + using error_handler_t = detail::error_handler_t; + /// how to treat CBOR tags + using cbor_tag_handler_t = detail::cbor_tag_handler_t; + /// helper type for initializer lists of basic_json values + using initializer_list_t = std::initializer_list>; + + using input_format_t = detail::input_format_t; + /// SAX interface type, see @ref nlohmann::json_sax + using json_sax_t = json_sax; + + //////////////// + // exceptions // + //////////////// + + /// @name exceptions + /// Classes to implement user-defined exceptions. + /// @{ + + /// @copydoc detail::exception + using exception = detail::exception; + /// @copydoc detail::parse_error + using parse_error = detail::parse_error; + /// @copydoc detail::invalid_iterator + using invalid_iterator = detail::invalid_iterator; + /// @copydoc detail::type_error + using type_error = detail::type_error; + /// @copydoc detail::out_of_range + using out_of_range = detail::out_of_range; + /// @copydoc detail::other_error + using other_error = detail::other_error; + + /// @} + + + ///////////////////// + // container types // + ///////////////////// + + /// @name container types + /// The canonic container types to use @ref basic_json like any other STL + /// container. + /// @{ + + /// the type of elements in a basic_json container + using value_type = basic_json; + + /// the type of an element reference + using reference = value_type&; + /// the type of an element const reference + using const_reference = const value_type&; + + /// a type to represent differences between iterators + using difference_type = std::ptrdiff_t; + /// a type to represent container sizes + using size_type = std::size_t; + + /// the allocator type + using allocator_type = AllocatorType; + + /// the type of an element pointer + using pointer = typename std::allocator_traits::pointer; + /// the type of an element const pointer + using const_pointer = typename std::allocator_traits::const_pointer; + + /// an iterator for a basic_json container + using iterator = iter_impl; + /// a const iterator for a basic_json container + using const_iterator = iter_impl; + /// a reverse iterator for a basic_json container + using reverse_iterator = json_reverse_iterator; + /// a const reverse iterator for a basic_json container + using const_reverse_iterator = json_reverse_iterator; + + /// @} + + + /*! + @brief returns the allocator associated with the container + */ + static allocator_type get_allocator() + { + return allocator_type(); + } + + /*! + @brief returns version information on the library + + This function returns a JSON object with information about the library, + including the version number and information on the platform and compiler. + + @return JSON object holding version information + key | description + ----------- | --------------- + `compiler` | Information on the used compiler. It is an object with the following keys: `c++` (the used C++ standard), `family` (the compiler family; possible values are `clang`, `icc`, `gcc`, `ilecpp`, `msvc`, `pgcpp`, `sunpro`, and `unknown`), and `version` (the compiler version). + `copyright` | The copyright line for the library as string. + `name` | The name of the library as string. + `platform` | The used platform as string. Possible values are `win32`, `linux`, `apple`, `unix`, and `unknown`. + `url` | The URL of the project as string. + `version` | The version of the library. It is an object with the following keys: `major`, `minor`, and `patch` as defined by [Semantic Versioning](http://semver.org), and `string` (the version string). + + @liveexample{The following code shows an example output of the `meta()` + function.,meta} + + @exceptionsafety Strong guarantee: if an exception is thrown, there are no + changes to any JSON value. + + @complexity Constant. + + @since 2.1.0 + */ + JSON_HEDLEY_WARN_UNUSED_RESULT + static basic_json meta() + { + basic_json result; + + result["copyright"] = "(C) 2013-2020 Niels Lohmann"; + result["name"] = "JSON for Modern C++"; + result["url"] = "https://github.com/nlohmann/json"; + result["version"]["string"] = + std::to_string(NLOHMANN_JSON_VERSION_MAJOR) + "." + + std::to_string(NLOHMANN_JSON_VERSION_MINOR) + "." + + std::to_string(NLOHMANN_JSON_VERSION_PATCH); + result["version"]["major"] = NLOHMANN_JSON_VERSION_MAJOR; + result["version"]["minor"] = NLOHMANN_JSON_VERSION_MINOR; + result["version"]["patch"] = NLOHMANN_JSON_VERSION_PATCH; + +#ifdef _WIN32 + result["platform"] = "win32"; +#elif defined __linux__ + result["platform"] = "linux"; +#elif defined __APPLE__ + result["platform"] = "apple"; +#elif defined __unix__ + result["platform"] = "unix"; +#else + result["platform"] = "unknown"; +#endif + +#if defined(__ICC) || defined(__INTEL_COMPILER) + result["compiler"] = {{"family", "icc"}, {"version", __INTEL_COMPILER}}; +#elif defined(__clang__) + result["compiler"] = {{"family", "clang"}, {"version", __clang_version__}}; +#elif defined(__GNUC__) || defined(__GNUG__) + result["compiler"] = {{"family", "gcc"}, {"version", std::to_string(__GNUC__) + "." + std::to_string(__GNUC_MINOR__) + "." + std::to_string(__GNUC_PATCHLEVEL__)}}; +#elif defined(__HP_cc) || defined(__HP_aCC) + result["compiler"] = "hp" +#elif defined(__IBMCPP__) + result["compiler"] = {{"family", "ilecpp"}, {"version", __IBMCPP__}}; +#elif defined(_MSC_VER) + result["compiler"] = {{"family", "msvc"}, {"version", _MSC_VER}}; +#elif defined(__PGI) + result["compiler"] = {{"family", "pgcpp"}, {"version", __PGI}}; +#elif defined(__SUNPRO_CC) + result["compiler"] = {{"family", "sunpro"}, {"version", __SUNPRO_CC}}; +#else + result["compiler"] = {{"family", "unknown"}, {"version", "unknown"}}; +#endif + +#ifdef __cplusplus + result["compiler"]["c++"] = std::to_string(__cplusplus); +#else + result["compiler"]["c++"] = "unknown"; +#endif + return result; + } + + + /////////////////////////// + // JSON value data types // + /////////////////////////// + + /// @name JSON value data types + /// The data types to store a JSON value. These types are derived from + /// the template arguments passed to class @ref basic_json. + /// @{ + +#if defined(JSON_HAS_CPP_14) + // Use transparent comparator if possible, combined with perfect forwarding + // on find() and count() calls prevents unnecessary string construction. + using object_comparator_t = std::less<>; +#else + using object_comparator_t = std::less; +#endif + + /*! + @brief a type for an object + + [RFC 7159](http://rfc7159.net/rfc7159) describes JSON objects as follows: + > An object is an unordered collection of zero or more name/value pairs, + > where a name is a string and a value is a string, number, boolean, null, + > object, or array. + + To store objects in C++, a type is defined by the template parameters + described below. + + @tparam ObjectType the container to store objects (e.g., `std::map` or + `std::unordered_map`) + @tparam StringType the type of the keys or names (e.g., `std::string`). + The comparison function `std::less` is used to order elements + inside the container. + @tparam AllocatorType the allocator to use for objects (e.g., + `std::allocator`) + + #### Default type + + With the default values for @a ObjectType (`std::map`), @a StringType + (`std::string`), and @a AllocatorType (`std::allocator`), the default + value for @a object_t is: + + @code {.cpp} + std::map< + std::string, // key_type + basic_json, // value_type + std::less, // key_compare + std::allocator> // allocator_type + > + @endcode + + #### Behavior + + The choice of @a object_t influences the behavior of the JSON class. With + the default type, objects have the following behavior: + + - When all names are unique, objects will be interoperable in the sense + that all software implementations receiving that object will agree on + the name-value mappings. + - When the names within an object are not unique, it is unspecified which + one of the values for a given key will be chosen. For instance, + `{"key": 2, "key": 1}` could be equal to either `{"key": 1}` or + `{"key": 2}`. + - Internally, name/value pairs are stored in lexicographical order of the + names. Objects will also be serialized (see @ref dump) in this order. + For instance, `{"b": 1, "a": 2}` and `{"a": 2, "b": 1}` will be stored + and serialized as `{"a": 2, "b": 1}`. + - When comparing objects, the order of the name/value pairs is irrelevant. + This makes objects interoperable in the sense that they will not be + affected by these differences. For instance, `{"b": 1, "a": 2}` and + `{"a": 2, "b": 1}` will be treated as equal. + + #### Limits + + [RFC 7159](http://rfc7159.net/rfc7159) specifies: + > An implementation may set limits on the maximum depth of nesting. + + In this class, the object's limit of nesting is not explicitly constrained. + However, a maximum depth of nesting may be introduced by the compiler or + runtime environment. A theoretical limit can be queried by calling the + @ref max_size function of a JSON object. + + #### Storage + + Objects are stored as pointers in a @ref basic_json type. That is, for any + access to object values, a pointer of type `object_t*` must be + dereferenced. + + @sa @ref array_t -- type for an array value + + @since version 1.0.0 + + @note The order name/value pairs are added to the object is *not* + preserved by the library. Therefore, iterating an object may return + name/value pairs in a different order than they were originally stored. In + fact, keys will be traversed in alphabetical order as `std::map` with + `std::less` is used by default. Please note this behavior conforms to [RFC + 7159](http://rfc7159.net/rfc7159), because any order implements the + specified "unordered" nature of JSON objects. + */ + using object_t = ObjectType>>; + + /*! + @brief a type for an array + + [RFC 7159](http://rfc7159.net/rfc7159) describes JSON arrays as follows: + > An array is an ordered sequence of zero or more values. + + To store objects in C++, a type is defined by the template parameters + explained below. + + @tparam ArrayType container type to store arrays (e.g., `std::vector` or + `std::list`) + @tparam AllocatorType allocator to use for arrays (e.g., `std::allocator`) + + #### Default type + + With the default values for @a ArrayType (`std::vector`) and @a + AllocatorType (`std::allocator`), the default value for @a array_t is: + + @code {.cpp} + std::vector< + basic_json, // value_type + std::allocator // allocator_type + > + @endcode + + #### Limits + + [RFC 7159](http://rfc7159.net/rfc7159) specifies: + > An implementation may set limits on the maximum depth of nesting. + + In this class, the array's limit of nesting is not explicitly constrained. + However, a maximum depth of nesting may be introduced by the compiler or + runtime environment. A theoretical limit can be queried by calling the + @ref max_size function of a JSON array. + + #### Storage + + Arrays are stored as pointers in a @ref basic_json type. That is, for any + access to array values, a pointer of type `array_t*` must be dereferenced. + + @sa @ref object_t -- type for an object value + + @since version 1.0.0 + */ + using array_t = ArrayType>; + + /*! + @brief a type for a string + + [RFC 7159](http://rfc7159.net/rfc7159) describes JSON strings as follows: + > A string is a sequence of zero or more Unicode characters. + + To store objects in C++, a type is defined by the template parameter + described below. Unicode values are split by the JSON class into + byte-sized characters during deserialization. + + @tparam StringType the container to store strings (e.g., `std::string`). + Note this container is used for keys/names in objects, see @ref object_t. + + #### Default type + + With the default values for @a StringType (`std::string`), the default + value for @a string_t is: + + @code {.cpp} + std::string + @endcode + + #### Encoding + + Strings are stored in UTF-8 encoding. Therefore, functions like + `std::string::size()` or `std::string::length()` return the number of + bytes in the string rather than the number of characters or glyphs. + + #### String comparison + + [RFC 7159](http://rfc7159.net/rfc7159) states: + > Software implementations are typically required to test names of object + > members for equality. Implementations that transform the textual + > representation into sequences of Unicode code units and then perform the + > comparison numerically, code unit by code unit, are interoperable in the + > sense that implementations will agree in all cases on equality or + > inequality of two strings. For example, implementations that compare + > strings with escaped characters unconverted may incorrectly find that + > `"a\\b"` and `"a\u005Cb"` are not equal. + + This implementation is interoperable as it does compare strings code unit + by code unit. + + #### Storage + + String values are stored as pointers in a @ref basic_json type. That is, + for any access to string values, a pointer of type `string_t*` must be + dereferenced. + + @since version 1.0.0 + */ + using string_t = StringType; + + /*! + @brief a type for a boolean + + [RFC 7159](http://rfc7159.net/rfc7159) implicitly describes a boolean as a + type which differentiates the two literals `true` and `false`. + + To store objects in C++, a type is defined by the template parameter @a + BooleanType which chooses the type to use. + + #### Default type + + With the default values for @a BooleanType (`bool`), the default value for + @a boolean_t is: + + @code {.cpp} + bool + @endcode + + #### Storage + + Boolean values are stored directly inside a @ref basic_json type. + + @since version 1.0.0 + */ + using boolean_t = BooleanType; + + /*! + @brief a type for a number (integer) + + [RFC 7159](http://rfc7159.net/rfc7159) describes numbers as follows: + > The representation of numbers is similar to that used in most + > programming languages. A number is represented in base 10 using decimal + > digits. It contains an integer component that may be prefixed with an + > optional minus sign, which may be followed by a fraction part and/or an + > exponent part. Leading zeros are not allowed. (...) Numeric values that + > cannot be represented in the grammar below (such as Infinity and NaN) + > are not permitted. + + This description includes both integer and floating-point numbers. + However, C++ allows more precise storage if it is known whether the number + is a signed integer, an unsigned integer or a floating-point number. + Therefore, three different types, @ref number_integer_t, @ref + number_unsigned_t and @ref number_float_t are used. + + To store integer numbers in C++, a type is defined by the template + parameter @a NumberIntegerType which chooses the type to use. + + #### Default type + + With the default values for @a NumberIntegerType (`int64_t`), the default + value for @a number_integer_t is: + + @code {.cpp} + int64_t + @endcode + + #### Default behavior + + - The restrictions about leading zeros is not enforced in C++. Instead, + leading zeros in integer literals lead to an interpretation as octal + number. Internally, the value will be stored as decimal number. For + instance, the C++ integer literal `010` will be serialized to `8`. + During deserialization, leading zeros yield an error. + - Not-a-number (NaN) values will be serialized to `null`. + + #### Limits + + [RFC 7159](http://rfc7159.net/rfc7159) specifies: + > An implementation may set limits on the range and precision of numbers. + + When the default type is used, the maximal integer number that can be + stored is `9223372036854775807` (INT64_MAX) and the minimal integer number + that can be stored is `-9223372036854775808` (INT64_MIN). Integer numbers + that are out of range will yield over/underflow when used in a + constructor. During deserialization, too large or small integer numbers + will be automatically be stored as @ref number_unsigned_t or @ref + number_float_t. + + [RFC 7159](http://rfc7159.net/rfc7159) further states: + > Note that when such software is used, numbers that are integers and are + > in the range \f$[-2^{53}+1, 2^{53}-1]\f$ are interoperable in the sense + > that implementations will agree exactly on their numeric values. + + As this range is a subrange of the exactly supported range [INT64_MIN, + INT64_MAX], this class's integer type is interoperable. + + #### Storage + + Integer number values are stored directly inside a @ref basic_json type. + + @sa @ref number_float_t -- type for number values (floating-point) + + @sa @ref number_unsigned_t -- type for number values (unsigned integer) + + @since version 1.0.0 + */ + using number_integer_t = NumberIntegerType; + + /*! + @brief a type for a number (unsigned) + + [RFC 7159](http://rfc7159.net/rfc7159) describes numbers as follows: + > The representation of numbers is similar to that used in most + > programming languages. A number is represented in base 10 using decimal + > digits. It contains an integer component that may be prefixed with an + > optional minus sign, which may be followed by a fraction part and/or an + > exponent part. Leading zeros are not allowed. (...) Numeric values that + > cannot be represented in the grammar below (such as Infinity and NaN) + > are not permitted. + + This description includes both integer and floating-point numbers. + However, C++ allows more precise storage if it is known whether the number + is a signed integer, an unsigned integer or a floating-point number. + Therefore, three different types, @ref number_integer_t, @ref + number_unsigned_t and @ref number_float_t are used. + + To store unsigned integer numbers in C++, a type is defined by the + template parameter @a NumberUnsignedType which chooses the type to use. + + #### Default type + + With the default values for @a NumberUnsignedType (`uint64_t`), the + default value for @a number_unsigned_t is: + + @code {.cpp} + uint64_t + @endcode + + #### Default behavior + + - The restrictions about leading zeros is not enforced in C++. Instead, + leading zeros in integer literals lead to an interpretation as octal + number. Internally, the value will be stored as decimal number. For + instance, the C++ integer literal `010` will be serialized to `8`. + During deserialization, leading zeros yield an error. + - Not-a-number (NaN) values will be serialized to `null`. + + #### Limits + + [RFC 7159](http://rfc7159.net/rfc7159) specifies: + > An implementation may set limits on the range and precision of numbers. + + When the default type is used, the maximal integer number that can be + stored is `18446744073709551615` (UINT64_MAX) and the minimal integer + number that can be stored is `0`. Integer numbers that are out of range + will yield over/underflow when used in a constructor. During + deserialization, too large or small integer numbers will be automatically + be stored as @ref number_integer_t or @ref number_float_t. + + [RFC 7159](http://rfc7159.net/rfc7159) further states: + > Note that when such software is used, numbers that are integers and are + > in the range \f$[-2^{53}+1, 2^{53}-1]\f$ are interoperable in the sense + > that implementations will agree exactly on their numeric values. + + As this range is a subrange (when considered in conjunction with the + number_integer_t type) of the exactly supported range [0, UINT64_MAX], + this class's integer type is interoperable. + + #### Storage + + Integer number values are stored directly inside a @ref basic_json type. + + @sa @ref number_float_t -- type for number values (floating-point) + @sa @ref number_integer_t -- type for number values (integer) + + @since version 2.0.0 + */ + using number_unsigned_t = NumberUnsignedType; + + /*! + @brief a type for a number (floating-point) + + [RFC 7159](http://rfc7159.net/rfc7159) describes numbers as follows: + > The representation of numbers is similar to that used in most + > programming languages. A number is represented in base 10 using decimal + > digits. It contains an integer component that may be prefixed with an + > optional minus sign, which may be followed by a fraction part and/or an + > exponent part. Leading zeros are not allowed. (...) Numeric values that + > cannot be represented in the grammar below (such as Infinity and NaN) + > are not permitted. + + This description includes both integer and floating-point numbers. + However, C++ allows more precise storage if it is known whether the number + is a signed integer, an unsigned integer or a floating-point number. + Therefore, three different types, @ref number_integer_t, @ref + number_unsigned_t and @ref number_float_t are used. + + To store floating-point numbers in C++, a type is defined by the template + parameter @a NumberFloatType which chooses the type to use. + + #### Default type + + With the default values for @a NumberFloatType (`double`), the default + value for @a number_float_t is: + + @code {.cpp} + double + @endcode + + #### Default behavior + + - The restrictions about leading zeros is not enforced in C++. Instead, + leading zeros in floating-point literals will be ignored. Internally, + the value will be stored as decimal number. For instance, the C++ + floating-point literal `01.2` will be serialized to `1.2`. During + deserialization, leading zeros yield an error. + - Not-a-number (NaN) values will be serialized to `null`. + + #### Limits + + [RFC 7159](http://rfc7159.net/rfc7159) states: + > This specification allows implementations to set limits on the range and + > precision of numbers accepted. Since software that implements IEEE + > 754-2008 binary64 (double precision) numbers is generally available and + > widely used, good interoperability can be achieved by implementations + > that expect no more precision or range than these provide, in the sense + > that implementations will approximate JSON numbers within the expected + > precision. + + This implementation does exactly follow this approach, as it uses double + precision floating-point numbers. Note values smaller than + `-1.79769313486232e+308` and values greater than `1.79769313486232e+308` + will be stored as NaN internally and be serialized to `null`. + + #### Storage + + Floating-point number values are stored directly inside a @ref basic_json + type. + + @sa @ref number_integer_t -- type for number values (integer) + + @sa @ref number_unsigned_t -- type for number values (unsigned integer) + + @since version 1.0.0 + */ + using number_float_t = NumberFloatType; + + /*! + @brief a type for a packed binary type + + This type is a type designed to carry binary data that appears in various + serialized formats, such as CBOR's Major Type 2, MessagePack's bin, and + BSON's generic binary subtype. This type is NOT a part of standard JSON and + exists solely for compatibility with these binary types. As such, it is + simply defined as an ordered sequence of zero or more byte values. + + Additionally, as an implementation detail, the subtype of the binary data is + carried around as a `std::uint8_t`, which is compatible with both of the + binary data formats that use binary subtyping, (though the specific + numbering is incompatible with each other, and it is up to the user to + translate between them). + + [CBOR's RFC 7049](https://tools.ietf.org/html/rfc7049) describes this type + as: + > Major type 2: a byte string. The string's length in bytes is represented + > following the rules for positive integers (major type 0). + + [MessagePack's documentation on the bin type + family](https://github.com/msgpack/msgpack/blob/master/spec.md#bin-format-family) + describes this type as: + > Bin format family stores an byte array in 2, 3, or 5 bytes of extra bytes + > in addition to the size of the byte array. + + [BSON's specifications](http://bsonspec.org/spec.html) describe several + binary types; however, this type is intended to represent the generic binary + type which has the description: + > Generic binary subtype - This is the most commonly used binary subtype and + > should be the 'default' for drivers and tools. + + None of these impose any limitations on the internal representation other + than the basic unit of storage be some type of array whose parts are + decomposable into bytes. + + The default representation of this binary format is a + `std::vector`, which is a very common way to represent a byte + array in modern C++. + + #### Default type + + The default values for @a BinaryType is `std::vector` + + #### Storage + + Binary Arrays are stored as pointers in a @ref basic_json type. That is, + for any access to array values, a pointer of the type `binary_t*` must be + dereferenced. + + #### Notes on subtypes + + - CBOR + - Binary values are represented as byte strings. No subtypes are + supported and will be ignored when CBOR is written. + - MessagePack + - If a subtype is given and the binary array contains exactly 1, 2, 4, 8, + or 16 elements, the fixext family (fixext1, fixext2, fixext4, fixext8) + is used. For other sizes, the ext family (ext8, ext16, ext32) is used. + The subtype is then added as singed 8-bit integer. + - If no subtype is given, the bin family (bin8, bin16, bin32) is used. + - BSON + - If a subtype is given, it is used and added as unsigned 8-bit integer. + - If no subtype is given, the generic binary subtype 0x00 is used. + + @sa @ref binary -- create a binary array + + @since version 3.8.0 + */ + using binary_t = nlohmann::byte_container_with_subtype; + /// @} + + private: + + /// helper for exception-safe object creation + template + JSON_HEDLEY_RETURNS_NON_NULL + static T* create(Args&& ... args) + { + AllocatorType alloc; + using AllocatorTraits = std::allocator_traits>; + + auto deleter = [&](T * object) + { + AllocatorTraits::deallocate(alloc, object, 1); + }; + std::unique_ptr object(AllocatorTraits::allocate(alloc, 1), deleter); + AllocatorTraits::construct(alloc, object.get(), std::forward(args)...); + JSON_ASSERT(object != nullptr); + return object.release(); + } + + //////////////////////// + // JSON value storage // + //////////////////////// + + /*! + @brief a JSON value + + The actual storage for a JSON value of the @ref basic_json class. This + union combines the different storage types for the JSON value types + defined in @ref value_t. + + JSON type | value_t type | used type + --------- | --------------- | ------------------------ + object | object | pointer to @ref object_t + array | array | pointer to @ref array_t + string | string | pointer to @ref string_t + boolean | boolean | @ref boolean_t + number | number_integer | @ref number_integer_t + number | number_unsigned | @ref number_unsigned_t + number | number_float | @ref number_float_t + binary | binary | pointer to @ref binary_t + null | null | *no value is stored* + + @note Variable-length types (objects, arrays, and strings) are stored as + pointers. The size of the union should not exceed 64 bits if the default + value types are used. + + @since version 1.0.0 + */ + union json_value + { + /// object (stored with pointer to save storage) + object_t* object; + /// array (stored with pointer to save storage) + array_t* array; + /// string (stored with pointer to save storage) + string_t* string; + /// binary (stored with pointer to save storage) + binary_t* binary; + /// boolean + boolean_t boolean; + /// number (integer) + number_integer_t number_integer; + /// number (unsigned integer) + number_unsigned_t number_unsigned; + /// number (floating-point) + number_float_t number_float; + + /// default constructor (for null values) + json_value() = default; + /// constructor for booleans + json_value(boolean_t v) noexcept : boolean(v) {} + /// constructor for numbers (integer) + json_value(number_integer_t v) noexcept : number_integer(v) {} + /// constructor for numbers (unsigned) + json_value(number_unsigned_t v) noexcept : number_unsigned(v) {} + /// constructor for numbers (floating-point) + json_value(number_float_t v) noexcept : number_float(v) {} + /// constructor for empty values of a given type + json_value(value_t t) + { + switch (t) + { + case value_t::object: + { + object = create(); + break; + } + + case value_t::array: + { + array = create(); + break; + } + + case value_t::string: + { + string = create(""); + break; + } + + case value_t::binary: + { + binary = create(); + break; + } + + case value_t::boolean: + { + boolean = boolean_t(false); + break; + } + + case value_t::number_integer: + { + number_integer = number_integer_t(0); + break; + } + + case value_t::number_unsigned: + { + number_unsigned = number_unsigned_t(0); + break; + } + + case value_t::number_float: + { + number_float = number_float_t(0.0); + break; + } + + case value_t::null: + { + object = nullptr; // silence warning, see #821 + break; + } + + default: + { + object = nullptr; // silence warning, see #821 + if (JSON_HEDLEY_UNLIKELY(t == value_t::null)) + { + JSON_THROW(other_error::create(500, "961c151d2e87f2686a955a9be24d316f1362bf21 3.9.1")); // LCOV_EXCL_LINE + } + break; + } + } + } + + /// constructor for strings + json_value(const string_t& value) + { + string = create(value); + } + + /// constructor for rvalue strings + json_value(string_t&& value) + { + string = create(std::move(value)); + } + + /// constructor for objects + json_value(const object_t& value) + { + object = create(value); + } + + /// constructor for rvalue objects + json_value(object_t&& value) + { + object = create(std::move(value)); + } + + /// constructor for arrays + json_value(const array_t& value) + { + array = create(value); + } + + /// constructor for rvalue arrays + json_value(array_t&& value) + { + array = create(std::move(value)); + } + + /// constructor for binary arrays + json_value(const typename binary_t::container_type& value) + { + binary = create(value); + } + + /// constructor for rvalue binary arrays + json_value(typename binary_t::container_type&& value) + { + binary = create(std::move(value)); + } + + /// constructor for binary arrays (internal type) + json_value(const binary_t& value) + { + binary = create(value); + } + + /// constructor for rvalue binary arrays (internal type) + json_value(binary_t&& value) + { + binary = create(std::move(value)); + } + + void destroy(value_t t) noexcept + { + // flatten the current json_value to a heap-allocated stack + std::vector stack; + + // move the top-level items to stack + if (t == value_t::array) + { + stack.reserve(array->size()); + std::move(array->begin(), array->end(), std::back_inserter(stack)); + } + else if (t == value_t::object) + { + stack.reserve(object->size()); + for (auto&& it : *object) + { + stack.push_back(std::move(it.second)); + } + } + + while (!stack.empty()) + { + // move the last item to local variable to be processed + basic_json current_item(std::move(stack.back())); + stack.pop_back(); + + // if current_item is array/object, move + // its children to the stack to be processed later + if (current_item.is_array()) + { + std::move(current_item.m_value.array->begin(), current_item.m_value.array->end(), + std::back_inserter(stack)); + + current_item.m_value.array->clear(); + } + else if (current_item.is_object()) + { + for (auto&& it : *current_item.m_value.object) + { + stack.push_back(std::move(it.second)); + } + + current_item.m_value.object->clear(); + } + + // it's now safe that current_item get destructed + // since it doesn't have any children + } + + switch (t) + { + case value_t::object: + { + AllocatorType alloc; + std::allocator_traits::destroy(alloc, object); + std::allocator_traits::deallocate(alloc, object, 1); + break; + } + + case value_t::array: + { + AllocatorType alloc; + std::allocator_traits::destroy(alloc, array); + std::allocator_traits::deallocate(alloc, array, 1); + break; + } + + case value_t::string: + { + AllocatorType alloc; + std::allocator_traits::destroy(alloc, string); + std::allocator_traits::deallocate(alloc, string, 1); + break; + } + + case value_t::binary: + { + AllocatorType alloc; + std::allocator_traits::destroy(alloc, binary); + std::allocator_traits::deallocate(alloc, binary, 1); + break; + } + + default: + { + break; + } + } + } + }; + + /*! + @brief checks the class invariants + + This function asserts the class invariants. It needs to be called at the + end of every constructor to make sure that created objects respect the + invariant. Furthermore, it has to be called each time the type of a JSON + value is changed, because the invariant expresses a relationship between + @a m_type and @a m_value. + */ + void assert_invariant() const noexcept + { + JSON_ASSERT(m_type != value_t::object || m_value.object != nullptr); + JSON_ASSERT(m_type != value_t::array || m_value.array != nullptr); + JSON_ASSERT(m_type != value_t::string || m_value.string != nullptr); + JSON_ASSERT(m_type != value_t::binary || m_value.binary != nullptr); + } + + public: + ////////////////////////// + // JSON parser callback // + ////////////////////////// + + /*! + @brief parser event types + + The parser callback distinguishes the following events: + - `object_start`: the parser read `{` and started to process a JSON object + - `key`: the parser read a key of a value in an object + - `object_end`: the parser read `}` and finished processing a JSON object + - `array_start`: the parser read `[` and started to process a JSON array + - `array_end`: the parser read `]` and finished processing a JSON array + - `value`: the parser finished reading a JSON value + + @image html callback_events.png "Example when certain parse events are triggered" + + @sa @ref parser_callback_t for more information and examples + */ + using parse_event_t = detail::parse_event_t; + + /*! + @brief per-element parser callback type + + With a parser callback function, the result of parsing a JSON text can be + influenced. When passed to @ref parse, it is called on certain events + (passed as @ref parse_event_t via parameter @a event) with a set recursion + depth @a depth and context JSON value @a parsed. The return value of the + callback function is a boolean indicating whether the element that emitted + the callback shall be kept or not. + + We distinguish six scenarios (determined by the event type) in which the + callback function can be called. The following table describes the values + of the parameters @a depth, @a event, and @a parsed. + + parameter @a event | description | parameter @a depth | parameter @a parsed + ------------------ | ----------- | ------------------ | ------------------- + parse_event_t::object_start | the parser read `{` and started to process a JSON object | depth of the parent of the JSON object | a JSON value with type discarded + parse_event_t::key | the parser read a key of a value in an object | depth of the currently parsed JSON object | a JSON string containing the key + parse_event_t::object_end | the parser read `}` and finished processing a JSON object | depth of the parent of the JSON object | the parsed JSON object + parse_event_t::array_start | the parser read `[` and started to process a JSON array | depth of the parent of the JSON array | a JSON value with type discarded + parse_event_t::array_end | the parser read `]` and finished processing a JSON array | depth of the parent of the JSON array | the parsed JSON array + parse_event_t::value | the parser finished reading a JSON value | depth of the value | the parsed JSON value + + @image html callback_events.png "Example when certain parse events are triggered" + + Discarding a value (i.e., returning `false`) has different effects + depending on the context in which function was called: + + - Discarded values in structured types are skipped. That is, the parser + will behave as if the discarded value was never read. + - In case a value outside a structured type is skipped, it is replaced + with `null`. This case happens if the top-level element is skipped. + + @param[in] depth the depth of the recursion during parsing + + @param[in] event an event of type parse_event_t indicating the context in + the callback function has been called + + @param[in,out] parsed the current intermediate parse result; note that + writing to this value has no effect for parse_event_t::key events + + @return Whether the JSON value which called the function during parsing + should be kept (`true`) or not (`false`). In the latter case, it is either + skipped completely or replaced by an empty discarded object. + + @sa @ref parse for examples + + @since version 1.0.0 + */ + using parser_callback_t = detail::parser_callback_t; + + ////////////////// + // constructors // + ////////////////// + + /// @name constructors and destructors + /// Constructors of class @ref basic_json, copy/move constructor, copy + /// assignment, static functions creating objects, and the destructor. + /// @{ + + /*! + @brief create an empty value with a given type + + Create an empty JSON value with a given type. The value will be default + initialized with an empty value which depends on the type: + + Value type | initial value + ----------- | ------------- + null | `null` + boolean | `false` + string | `""` + number | `0` + object | `{}` + array | `[]` + binary | empty array + + @param[in] v the type of the value to create + + @complexity Constant. + + @exceptionsafety Strong guarantee: if an exception is thrown, there are no + changes to any JSON value. + + @liveexample{The following code shows the constructor for different @ref + value_t values,basic_json__value_t} + + @sa @ref clear() -- restores the postcondition of this constructor + + @since version 1.0.0 + */ + basic_json(const value_t v) + : m_type(v), m_value(v) + { + assert_invariant(); + } + + /*! + @brief create a null object + + Create a `null` JSON value. It either takes a null pointer as parameter + (explicitly creating `null`) or no parameter (implicitly creating `null`). + The passed null pointer itself is not read -- it is only used to choose + the right constructor. + + @complexity Constant. + + @exceptionsafety No-throw guarantee: this constructor never throws + exceptions. + + @liveexample{The following code shows the constructor with and without a + null pointer parameter.,basic_json__nullptr_t} + + @since version 1.0.0 + */ + basic_json(std::nullptr_t = nullptr) noexcept + : basic_json(value_t::null) + { + assert_invariant(); + } + + /*! + @brief create a JSON value + + This is a "catch all" constructor for all compatible JSON types; that is, + types for which a `to_json()` method exists. The constructor forwards the + parameter @a val to that method (to `json_serializer::to_json` method + with `U = uncvref_t`, to be exact). + + Template type @a CompatibleType includes, but is not limited to, the + following types: + - **arrays**: @ref array_t and all kinds of compatible containers such as + `std::vector`, `std::deque`, `std::list`, `std::forward_list`, + `std::array`, `std::valarray`, `std::set`, `std::unordered_set`, + `std::multiset`, and `std::unordered_multiset` with a `value_type` from + which a @ref basic_json value can be constructed. + - **objects**: @ref object_t and all kinds of compatible associative + containers such as `std::map`, `std::unordered_map`, `std::multimap`, + and `std::unordered_multimap` with a `key_type` compatible to + @ref string_t and a `value_type` from which a @ref basic_json value can + be constructed. + - **strings**: @ref string_t, string literals, and all compatible string + containers can be used. + - **numbers**: @ref number_integer_t, @ref number_unsigned_t, + @ref number_float_t, and all convertible number types such as `int`, + `size_t`, `int64_t`, `float` or `double` can be used. + - **boolean**: @ref boolean_t / `bool` can be used. + - **binary**: @ref binary_t / `std::vector` may be used, + unfortunately because string literals cannot be distinguished from binary + character arrays by the C++ type system, all types compatible with `const + char*` will be directed to the string constructor instead. This is both + for backwards compatibility, and due to the fact that a binary type is not + a standard JSON type. + + See the examples below. + + @tparam CompatibleType a type such that: + - @a CompatibleType is not derived from `std::istream`, + - @a CompatibleType is not @ref basic_json (to avoid hijacking copy/move + constructors), + - @a CompatibleType is not a different @ref basic_json type (i.e. with different template arguments) + - @a CompatibleType is not a @ref basic_json nested type (e.g., + @ref json_pointer, @ref iterator, etc ...) + - @ref @ref json_serializer has a + `to_json(basic_json_t&, CompatibleType&&)` method + + @tparam U = `uncvref_t` + + @param[in] val the value to be forwarded to the respective constructor + + @complexity Usually linear in the size of the passed @a val, also + depending on the implementation of the called `to_json()` + method. + + @exceptionsafety Depends on the called constructor. For types directly + supported by the library (i.e., all types for which no `to_json()` function + was provided), strong guarantee holds: if an exception is thrown, there are + no changes to any JSON value. + + @liveexample{The following code shows the constructor with several + compatible types.,basic_json__CompatibleType} + + @since version 2.1.0 + */ + template < typename CompatibleType, + typename U = detail::uncvref_t, + detail::enable_if_t < + !detail::is_basic_json::value && detail::is_compatible_type::value, int > = 0 > + basic_json(CompatibleType && val) noexcept(noexcept( + JSONSerializer::to_json(std::declval(), + std::forward(val)))) + { + JSONSerializer::to_json(*this, std::forward(val)); + assert_invariant(); + } + + /*! + @brief create a JSON value from an existing one + + This is a constructor for existing @ref basic_json types. + It does not hijack copy/move constructors, since the parameter has different + template arguments than the current ones. + + The constructor tries to convert the internal @ref m_value of the parameter. + + @tparam BasicJsonType a type such that: + - @a BasicJsonType is a @ref basic_json type. + - @a BasicJsonType has different template arguments than @ref basic_json_t. + + @param[in] val the @ref basic_json value to be converted. + + @complexity Usually linear in the size of the passed @a val, also + depending on the implementation of the called `to_json()` + method. + + @exceptionsafety Depends on the called constructor. For types directly + supported by the library (i.e., all types for which no `to_json()` function + was provided), strong guarantee holds: if an exception is thrown, there are + no changes to any JSON value. + + @since version 3.2.0 + */ + template < typename BasicJsonType, + detail::enable_if_t < + detail::is_basic_json::value&& !std::is_same::value, int > = 0 > + basic_json(const BasicJsonType& val) + { + using other_boolean_t = typename BasicJsonType::boolean_t; + using other_number_float_t = typename BasicJsonType::number_float_t; + using other_number_integer_t = typename BasicJsonType::number_integer_t; + using other_number_unsigned_t = typename BasicJsonType::number_unsigned_t; + using other_string_t = typename BasicJsonType::string_t; + using other_object_t = typename BasicJsonType::object_t; + using other_array_t = typename BasicJsonType::array_t; + using other_binary_t = typename BasicJsonType::binary_t; + + switch (val.type()) + { + case value_t::boolean: + JSONSerializer::to_json(*this, val.template get()); + break; + case value_t::number_float: + JSONSerializer::to_json(*this, val.template get()); + break; + case value_t::number_integer: + JSONSerializer::to_json(*this, val.template get()); + break; + case value_t::number_unsigned: + JSONSerializer::to_json(*this, val.template get()); + break; + case value_t::string: + JSONSerializer::to_json(*this, val.template get_ref()); + break; + case value_t::object: + JSONSerializer::to_json(*this, val.template get_ref()); + break; + case value_t::array: + JSONSerializer::to_json(*this, val.template get_ref()); + break; + case value_t::binary: + JSONSerializer::to_json(*this, val.template get_ref()); + break; + case value_t::null: + *this = nullptr; + break; + case value_t::discarded: + m_type = value_t::discarded; + break; + default: // LCOV_EXCL_LINE + JSON_ASSERT(false); // LCOV_EXCL_LINE + } + assert_invariant(); + } + + /*! + @brief create a container (array or object) from an initializer list + + Creates a JSON value of type array or object from the passed initializer + list @a init. In case @a type_deduction is `true` (default), the type of + the JSON value to be created is deducted from the initializer list @a init + according to the following rules: + + 1. If the list is empty, an empty JSON object value `{}` is created. + 2. If the list consists of pairs whose first element is a string, a JSON + object value is created where the first elements of the pairs are + treated as keys and the second elements are as values. + 3. In all other cases, an array is created. + + The rules aim to create the best fit between a C++ initializer list and + JSON values. The rationale is as follows: + + 1. The empty initializer list is written as `{}` which is exactly an empty + JSON object. + 2. C++ has no way of describing mapped types other than to list a list of + pairs. As JSON requires that keys must be of type string, rule 2 is the + weakest constraint one can pose on initializer lists to interpret them + as an object. + 3. In all other cases, the initializer list could not be interpreted as + JSON object type, so interpreting it as JSON array type is safe. + + With the rules described above, the following JSON values cannot be + expressed by an initializer list: + + - the empty array (`[]`): use @ref array(initializer_list_t) + with an empty initializer list in this case + - arrays whose elements satisfy rule 2: use @ref + array(initializer_list_t) with the same initializer list + in this case + + @note When used without parentheses around an empty initializer list, @ref + basic_json() is called instead of this function, yielding the JSON null + value. + + @param[in] init initializer list with JSON values + + @param[in] type_deduction internal parameter; when set to `true`, the type + of the JSON value is deducted from the initializer list @a init; when set + to `false`, the type provided via @a manual_type is forced. This mode is + used by the functions @ref array(initializer_list_t) and + @ref object(initializer_list_t). + + @param[in] manual_type internal parameter; when @a type_deduction is set + to `false`, the created JSON value will use the provided type (only @ref + value_t::array and @ref value_t::object are valid); when @a type_deduction + is set to `true`, this parameter has no effect + + @throw type_error.301 if @a type_deduction is `false`, @a manual_type is + `value_t::object`, but @a init contains an element which is not a pair + whose first element is a string. In this case, the constructor could not + create an object. If @a type_deduction would have be `true`, an array + would have been created. See @ref object(initializer_list_t) + for an example. + + @complexity Linear in the size of the initializer list @a init. + + @exceptionsafety Strong guarantee: if an exception is thrown, there are no + changes to any JSON value. + + @liveexample{The example below shows how JSON values are created from + initializer lists.,basic_json__list_init_t} + + @sa @ref array(initializer_list_t) -- create a JSON array + value from an initializer list + @sa @ref object(initializer_list_t) -- create a JSON object + value from an initializer list + + @since version 1.0.0 + */ + basic_json(initializer_list_t init, + bool type_deduction = true, + value_t manual_type = value_t::array) + { + // check if each element is an array with two elements whose first + // element is a string + bool is_an_object = std::all_of(init.begin(), init.end(), + [](const detail::json_ref& element_ref) + { + return element_ref->is_array() && element_ref->size() == 2 && (*element_ref)[0].is_string(); + }); + + // adjust type if type deduction is not wanted + if (!type_deduction) + { + // if array is wanted, do not create an object though possible + if (manual_type == value_t::array) + { + is_an_object = false; + } + + // if object is wanted but impossible, throw an exception + if (JSON_HEDLEY_UNLIKELY(manual_type == value_t::object && !is_an_object)) + { + JSON_THROW(type_error::create(301, "cannot create object from initializer list")); + } + } + + if (is_an_object) + { + // the initializer list is a list of pairs -> create object + m_type = value_t::object; + m_value = value_t::object; + + std::for_each(init.begin(), init.end(), [this](const detail::json_ref& element_ref) + { + auto element = element_ref.moved_or_copied(); + m_value.object->emplace( + std::move(*((*element.m_value.array)[0].m_value.string)), + std::move((*element.m_value.array)[1])); + }); + } + else + { + // the initializer list describes an array -> create array + m_type = value_t::array; + m_value.array = create(init.begin(), init.end()); + } + + assert_invariant(); + } + + /*! + @brief explicitly create a binary array (without subtype) + + Creates a JSON binary array value from a given binary container. Binary + values are part of various binary formats, such as CBOR, MessagePack, and + BSON. This constructor is used to create a value for serialization to those + formats. + + @note Note, this function exists because of the difficulty in correctly + specifying the correct template overload in the standard value ctor, as both + JSON arrays and JSON binary arrays are backed with some form of a + `std::vector`. Because JSON binary arrays are a non-standard extension it + was decided that it would be best to prevent automatic initialization of a + binary array type, for backwards compatibility and so it does not happen on + accident. + + @param[in] init container containing bytes to use as binary type + + @return JSON binary array value + + @complexity Linear in the size of @a init. + + @exceptionsafety Strong guarantee: if an exception is thrown, there are no + changes to any JSON value. + + @since version 3.8.0 + */ + JSON_HEDLEY_WARN_UNUSED_RESULT + static basic_json binary(const typename binary_t::container_type& init) + { + auto res = basic_json(); + res.m_type = value_t::binary; + res.m_value = init; + return res; + } + + /*! + @brief explicitly create a binary array (with subtype) + + Creates a JSON binary array value from a given binary container. Binary + values are part of various binary formats, such as CBOR, MessagePack, and + BSON. This constructor is used to create a value for serialization to those + formats. + + @note Note, this function exists because of the difficulty in correctly + specifying the correct template overload in the standard value ctor, as both + JSON arrays and JSON binary arrays are backed with some form of a + `std::vector`. Because JSON binary arrays are a non-standard extension it + was decided that it would be best to prevent automatic initialization of a + binary array type, for backwards compatibility and so it does not happen on + accident. + + @param[in] init container containing bytes to use as binary type + @param[in] subtype subtype to use in MessagePack and BSON + + @return JSON binary array value + + @complexity Linear in the size of @a init. + + @exceptionsafety Strong guarantee: if an exception is thrown, there are no + changes to any JSON value. + + @since version 3.8.0 + */ + JSON_HEDLEY_WARN_UNUSED_RESULT + static basic_json binary(const typename binary_t::container_type& init, std::uint8_t subtype) + { + auto res = basic_json(); + res.m_type = value_t::binary; + res.m_value = binary_t(init, subtype); + return res; + } + + /// @copydoc binary(const typename binary_t::container_type&) + JSON_HEDLEY_WARN_UNUSED_RESULT + static basic_json binary(typename binary_t::container_type&& init) + { + auto res = basic_json(); + res.m_type = value_t::binary; + res.m_value = std::move(init); + return res; + } + + /// @copydoc binary(const typename binary_t::container_type&, std::uint8_t) + JSON_HEDLEY_WARN_UNUSED_RESULT + static basic_json binary(typename binary_t::container_type&& init, std::uint8_t subtype) + { + auto res = basic_json(); + res.m_type = value_t::binary; + res.m_value = binary_t(std::move(init), subtype); + return res; + } + + /*! + @brief explicitly create an array from an initializer list + + Creates a JSON array value from a given initializer list. That is, given a + list of values `a, b, c`, creates the JSON value `[a, b, c]`. If the + initializer list is empty, the empty array `[]` is created. + + @note This function is only needed to express two edge cases that cannot + be realized with the initializer list constructor (@ref + basic_json(initializer_list_t, bool, value_t)). These cases + are: + 1. creating an array whose elements are all pairs whose first element is a + string -- in this case, the initializer list constructor would create an + object, taking the first elements as keys + 2. creating an empty array -- passing the empty initializer list to the + initializer list constructor yields an empty object + + @param[in] init initializer list with JSON values to create an array from + (optional) + + @return JSON array value + + @complexity Linear in the size of @a init. + + @exceptionsafety Strong guarantee: if an exception is thrown, there are no + changes to any JSON value. + + @liveexample{The following code shows an example for the `array` + function.,array} + + @sa @ref basic_json(initializer_list_t, bool, value_t) -- + create a JSON value from an initializer list + @sa @ref object(initializer_list_t) -- create a JSON object + value from an initializer list + + @since version 1.0.0 + */ + JSON_HEDLEY_WARN_UNUSED_RESULT + static basic_json array(initializer_list_t init = {}) + { + return basic_json(init, false, value_t::array); + } + + /*! + @brief explicitly create an object from an initializer list + + Creates a JSON object value from a given initializer list. The initializer + lists elements must be pairs, and their first elements must be strings. If + the initializer list is empty, the empty object `{}` is created. + + @note This function is only added for symmetry reasons. In contrast to the + related function @ref array(initializer_list_t), there are + no cases which can only be expressed by this function. That is, any + initializer list @a init can also be passed to the initializer list + constructor @ref basic_json(initializer_list_t, bool, value_t). + + @param[in] init initializer list to create an object from (optional) + + @return JSON object value + + @throw type_error.301 if @a init is not a list of pairs whose first + elements are strings. In this case, no object can be created. When such a + value is passed to @ref basic_json(initializer_list_t, bool, value_t), + an array would have been created from the passed initializer list @a init. + See example below. + + @complexity Linear in the size of @a init. + + @exceptionsafety Strong guarantee: if an exception is thrown, there are no + changes to any JSON value. + + @liveexample{The following code shows an example for the `object` + function.,object} + + @sa @ref basic_json(initializer_list_t, bool, value_t) -- + create a JSON value from an initializer list + @sa @ref array(initializer_list_t) -- create a JSON array + value from an initializer list + + @since version 1.0.0 + */ + JSON_HEDLEY_WARN_UNUSED_RESULT + static basic_json object(initializer_list_t init = {}) + { + return basic_json(init, false, value_t::object); + } + + /*! + @brief construct an array with count copies of given value + + Constructs a JSON array value by creating @a cnt copies of a passed value. + In case @a cnt is `0`, an empty array is created. + + @param[in] cnt the number of JSON copies of @a val to create + @param[in] val the JSON value to copy + + @post `std::distance(begin(),end()) == cnt` holds. + + @complexity Linear in @a cnt. + + @exceptionsafety Strong guarantee: if an exception is thrown, there are no + changes to any JSON value. + + @liveexample{The following code shows examples for the @ref + basic_json(size_type\, const basic_json&) + constructor.,basic_json__size_type_basic_json} + + @since version 1.0.0 + */ + basic_json(size_type cnt, const basic_json& val) + : m_type(value_t::array) + { + m_value.array = create(cnt, val); + assert_invariant(); + } + + /*! + @brief construct a JSON container given an iterator range + + Constructs the JSON value with the contents of the range `[first, last)`. + The semantics depends on the different types a JSON value can have: + - In case of a null type, invalid_iterator.206 is thrown. + - In case of other primitive types (number, boolean, or string), @a first + must be `begin()` and @a last must be `end()`. In this case, the value is + copied. Otherwise, invalid_iterator.204 is thrown. + - In case of structured types (array, object), the constructor behaves as + similar versions for `std::vector` or `std::map`; that is, a JSON array + or object is constructed from the values in the range. + + @tparam InputIT an input iterator type (@ref iterator or @ref + const_iterator) + + @param[in] first begin of the range to copy from (included) + @param[in] last end of the range to copy from (excluded) + + @pre Iterators @a first and @a last must be initialized. **This + precondition is enforced with an assertion (see warning).** If + assertions are switched off, a violation of this precondition yields + undefined behavior. + + @pre Range `[first, last)` is valid. Usually, this precondition cannot be + checked efficiently. Only certain edge cases are detected; see the + description of the exceptions below. A violation of this precondition + yields undefined behavior. + + @warning A precondition is enforced with a runtime assertion that will + result in calling `std::abort` if this precondition is not met. + Assertions can be disabled by defining `NDEBUG` at compile time. + See https://en.cppreference.com/w/cpp/error/assert for more + information. + + @throw invalid_iterator.201 if iterators @a first and @a last are not + compatible (i.e., do not belong to the same JSON value). In this case, + the range `[first, last)` is undefined. + @throw invalid_iterator.204 if iterators @a first and @a last belong to a + primitive type (number, boolean, or string), but @a first does not point + to the first element any more. In this case, the range `[first, last)` is + undefined. See example code below. + @throw invalid_iterator.206 if iterators @a first and @a last belong to a + null value. In this case, the range `[first, last)` is undefined. + + @complexity Linear in distance between @a first and @a last. + + @exceptionsafety Strong guarantee: if an exception is thrown, there are no + changes to any JSON value. + + @liveexample{The example below shows several ways to create JSON values by + specifying a subrange with iterators.,basic_json__InputIt_InputIt} + + @since version 1.0.0 + */ + template < class InputIT, typename std::enable_if < + std::is_same::value || + std::is_same::value, int >::type = 0 > + basic_json(InputIT first, InputIT last) + { + JSON_ASSERT(first.m_object != nullptr); + JSON_ASSERT(last.m_object != nullptr); + + // make sure iterator fits the current value + if (JSON_HEDLEY_UNLIKELY(first.m_object != last.m_object)) + { + JSON_THROW(invalid_iterator::create(201, "iterators are not compatible")); + } + + // copy type from first iterator + m_type = first.m_object->m_type; + + // check if iterator range is complete for primitive values + switch (m_type) + { + case value_t::boolean: + case value_t::number_float: + case value_t::number_integer: + case value_t::number_unsigned: + case value_t::string: + { + if (JSON_HEDLEY_UNLIKELY(!first.m_it.primitive_iterator.is_begin() + || !last.m_it.primitive_iterator.is_end())) + { + JSON_THROW(invalid_iterator::create(204, "iterators out of range")); + } + break; + } + + default: + break; + } + + switch (m_type) + { + case value_t::number_integer: + { + m_value.number_integer = first.m_object->m_value.number_integer; + break; + } + + case value_t::number_unsigned: + { + m_value.number_unsigned = first.m_object->m_value.number_unsigned; + break; + } + + case value_t::number_float: + { + m_value.number_float = first.m_object->m_value.number_float; + break; + } + + case value_t::boolean: + { + m_value.boolean = first.m_object->m_value.boolean; + break; + } + + case value_t::string: + { + m_value = *first.m_object->m_value.string; + break; + } + + case value_t::object: + { + m_value.object = create(first.m_it.object_iterator, + last.m_it.object_iterator); + break; + } + + case value_t::array: + { + m_value.array = create(first.m_it.array_iterator, + last.m_it.array_iterator); + break; + } + + case value_t::binary: + { + m_value = *first.m_object->m_value.binary; + break; + } + + default: + JSON_THROW(invalid_iterator::create(206, "cannot construct with iterators from " + + std::string(first.m_object->type_name()))); + } + + assert_invariant(); + } + + + /////////////////////////////////////// + // other constructors and destructor // + /////////////////////////////////////// + + template, + std::is_same>::value, int> = 0 > + basic_json(const JsonRef& ref) : basic_json(ref.moved_or_copied()) {} + + /*! + @brief copy constructor + + Creates a copy of a given JSON value. + + @param[in] other the JSON value to copy + + @post `*this == other` + + @complexity Linear in the size of @a other. + + @exceptionsafety Strong guarantee: if an exception is thrown, there are no + changes to any JSON value. + + @requirement This function helps `basic_json` satisfying the + [Container](https://en.cppreference.com/w/cpp/named_req/Container) + requirements: + - The complexity is linear. + - As postcondition, it holds: `other == basic_json(other)`. + + @liveexample{The following code shows an example for the copy + constructor.,basic_json__basic_json} + + @since version 1.0.0 + */ + basic_json(const basic_json& other) + : m_type(other.m_type) + { + // check of passed value is valid + other.assert_invariant(); + + switch (m_type) + { + case value_t::object: + { + m_value = *other.m_value.object; + break; + } + + case value_t::array: + { + m_value = *other.m_value.array; + break; + } + + case value_t::string: + { + m_value = *other.m_value.string; + break; + } + + case value_t::boolean: + { + m_value = other.m_value.boolean; + break; + } + + case value_t::number_integer: + { + m_value = other.m_value.number_integer; + break; + } + + case value_t::number_unsigned: + { + m_value = other.m_value.number_unsigned; + break; + } + + case value_t::number_float: + { + m_value = other.m_value.number_float; + break; + } + + case value_t::binary: + { + m_value = *other.m_value.binary; + break; + } + + default: + break; + } + + assert_invariant(); + } + + /*! + @brief move constructor + + Move constructor. Constructs a JSON value with the contents of the given + value @a other using move semantics. It "steals" the resources from @a + other and leaves it as JSON null value. + + @param[in,out] other value to move to this object + + @post `*this` has the same value as @a other before the call. + @post @a other is a JSON null value. + + @complexity Constant. + + @exceptionsafety No-throw guarantee: this constructor never throws + exceptions. + + @requirement This function helps `basic_json` satisfying the + [MoveConstructible](https://en.cppreference.com/w/cpp/named_req/MoveConstructible) + requirements. + + @liveexample{The code below shows the move constructor explicitly called + via std::move.,basic_json__moveconstructor} + + @since version 1.0.0 + */ + basic_json(basic_json&& other) noexcept + : m_type(std::move(other.m_type)), + m_value(std::move(other.m_value)) + { + // check that passed value is valid + other.assert_invariant(); + + // invalidate payload + other.m_type = value_t::null; + other.m_value = {}; + + assert_invariant(); + } + + /*! + @brief copy assignment + + Copy assignment operator. Copies a JSON value via the "copy and swap" + strategy: It is expressed in terms of the copy constructor, destructor, + and the `swap()` member function. + + @param[in] other value to copy from + + @complexity Linear. + + @requirement This function helps `basic_json` satisfying the + [Container](https://en.cppreference.com/w/cpp/named_req/Container) + requirements: + - The complexity is linear. + + @liveexample{The code below shows and example for the copy assignment. It + creates a copy of value `a` which is then swapped with `b`. Finally\, the + copy of `a` (which is the null value after the swap) is + destroyed.,basic_json__copyassignment} + + @since version 1.0.0 + */ + basic_json& operator=(basic_json other) noexcept ( + std::is_nothrow_move_constructible::value&& + std::is_nothrow_move_assignable::value&& + std::is_nothrow_move_constructible::value&& + std::is_nothrow_move_assignable::value + ) + { + // check that passed value is valid + other.assert_invariant(); + + using std::swap; + swap(m_type, other.m_type); + swap(m_value, other.m_value); + + assert_invariant(); + return *this; + } + + /*! + @brief destructor + + Destroys the JSON value and frees all allocated memory. + + @complexity Linear. + + @requirement This function helps `basic_json` satisfying the + [Container](https://en.cppreference.com/w/cpp/named_req/Container) + requirements: + - The complexity is linear. + - All stored elements are destroyed and all memory is freed. + + @since version 1.0.0 + */ + ~basic_json() noexcept + { + assert_invariant(); + m_value.destroy(m_type); + } + + /// @} + + public: + /////////////////////// + // object inspection // + /////////////////////// + + /// @name object inspection + /// Functions to inspect the type of a JSON value. + /// @{ + + /*! + @brief serialization + + Serialization function for JSON values. The function tries to mimic + Python's `json.dumps()` function, and currently supports its @a indent + and @a ensure_ascii parameters. + + @param[in] indent If indent is nonnegative, then array elements and object + members will be pretty-printed with that indent level. An indent level of + `0` will only insert newlines. `-1` (the default) selects the most compact + representation. + @param[in] indent_char The character to use for indentation if @a indent is + greater than `0`. The default is ` ` (space). + @param[in] ensure_ascii If @a ensure_ascii is true, all non-ASCII characters + in the output are escaped with `\uXXXX` sequences, and the result consists + of ASCII characters only. + @param[in] error_handler how to react on decoding errors; there are three + possible values: `strict` (throws and exception in case a decoding error + occurs; default), `replace` (replace invalid UTF-8 sequences with U+FFFD), + and `ignore` (ignore invalid UTF-8 sequences during serialization; all + bytes are copied to the output unchanged). + + @return string containing the serialization of the JSON value + + @throw type_error.316 if a string stored inside the JSON value is not + UTF-8 encoded and @a error_handler is set to strict + + @note Binary values are serialized as object containing two keys: + - "bytes": an array of bytes as integers + - "subtype": the subtype as integer or "null" if the binary has no subtype + + @complexity Linear. + + @exceptionsafety Strong guarantee: if an exception is thrown, there are no + changes in the JSON value. + + @liveexample{The following example shows the effect of different @a indent\, + @a indent_char\, and @a ensure_ascii parameters to the result of the + serialization.,dump} + + @see https://docs.python.org/2/library/json.html#json.dump + + @since version 1.0.0; indentation character @a indent_char, option + @a ensure_ascii and exceptions added in version 3.0.0; error + handlers added in version 3.4.0; serialization of binary values added + in version 3.8.0. + */ + string_t dump(const int indent = -1, + const char indent_char = ' ', + const bool ensure_ascii = false, + const error_handler_t error_handler = error_handler_t::strict) const + { + string_t result; + serializer s(detail::output_adapter(result), indent_char, error_handler); + + if (indent >= 0) + { + s.dump(*this, true, ensure_ascii, static_cast(indent)); + } + else + { + s.dump(*this, false, ensure_ascii, 0); + } + + return result; + } + + /*! + @brief return the type of the JSON value (explicit) + + Return the type of the JSON value as a value from the @ref value_t + enumeration. + + @return the type of the JSON value + Value type | return value + ------------------------- | ------------------------- + null | value_t::null + boolean | value_t::boolean + string | value_t::string + number (integer) | value_t::number_integer + number (unsigned integer) | value_t::number_unsigned + number (floating-point) | value_t::number_float + object | value_t::object + array | value_t::array + binary | value_t::binary + discarded | value_t::discarded + + @complexity Constant. + + @exceptionsafety No-throw guarantee: this member function never throws + exceptions. + + @liveexample{The following code exemplifies `type()` for all JSON + types.,type} + + @sa @ref operator value_t() -- return the type of the JSON value (implicit) + @sa @ref type_name() -- return the type as string + + @since version 1.0.0 + */ + constexpr value_t type() const noexcept + { + return m_type; + } + + /*! + @brief return whether type is primitive + + This function returns true if and only if the JSON type is primitive + (string, number, boolean, or null). + + @return `true` if type is primitive (string, number, boolean, or null), + `false` otherwise. + + @complexity Constant. + + @exceptionsafety No-throw guarantee: this member function never throws + exceptions. + + @liveexample{The following code exemplifies `is_primitive()` for all JSON + types.,is_primitive} + + @sa @ref is_structured() -- returns whether JSON value is structured + @sa @ref is_null() -- returns whether JSON value is `null` + @sa @ref is_string() -- returns whether JSON value is a string + @sa @ref is_boolean() -- returns whether JSON value is a boolean + @sa @ref is_number() -- returns whether JSON value is a number + @sa @ref is_binary() -- returns whether JSON value is a binary array + + @since version 1.0.0 + */ + constexpr bool is_primitive() const noexcept + { + return is_null() || is_string() || is_boolean() || is_number() || is_binary(); + } + + /*! + @brief return whether type is structured + + This function returns true if and only if the JSON type is structured + (array or object). + + @return `true` if type is structured (array or object), `false` otherwise. + + @complexity Constant. + + @exceptionsafety No-throw guarantee: this member function never throws + exceptions. + + @liveexample{The following code exemplifies `is_structured()` for all JSON + types.,is_structured} + + @sa @ref is_primitive() -- returns whether value is primitive + @sa @ref is_array() -- returns whether value is an array + @sa @ref is_object() -- returns whether value is an object + + @since version 1.0.0 + */ + constexpr bool is_structured() const noexcept + { + return is_array() || is_object(); + } + + /*! + @brief return whether value is null + + This function returns true if and only if the JSON value is null. + + @return `true` if type is null, `false` otherwise. + + @complexity Constant. + + @exceptionsafety No-throw guarantee: this member function never throws + exceptions. + + @liveexample{The following code exemplifies `is_null()` for all JSON + types.,is_null} + + @since version 1.0.0 + */ + constexpr bool is_null() const noexcept + { + return m_type == value_t::null; + } + + /*! + @brief return whether value is a boolean + + This function returns true if and only if the JSON value is a boolean. + + @return `true` if type is boolean, `false` otherwise. + + @complexity Constant. + + @exceptionsafety No-throw guarantee: this member function never throws + exceptions. + + @liveexample{The following code exemplifies `is_boolean()` for all JSON + types.,is_boolean} + + @since version 1.0.0 + */ + constexpr bool is_boolean() const noexcept + { + return m_type == value_t::boolean; + } + + /*! + @brief return whether value is a number + + This function returns true if and only if the JSON value is a number. This + includes both integer (signed and unsigned) and floating-point values. + + @return `true` if type is number (regardless whether integer, unsigned + integer or floating-type), `false` otherwise. + + @complexity Constant. + + @exceptionsafety No-throw guarantee: this member function never throws + exceptions. + + @liveexample{The following code exemplifies `is_number()` for all JSON + types.,is_number} + + @sa @ref is_number_integer() -- check if value is an integer or unsigned + integer number + @sa @ref is_number_unsigned() -- check if value is an unsigned integer + number + @sa @ref is_number_float() -- check if value is a floating-point number + + @since version 1.0.0 + */ + constexpr bool is_number() const noexcept + { + return is_number_integer() || is_number_float(); + } + + /*! + @brief return whether value is an integer number + + This function returns true if and only if the JSON value is a signed or + unsigned integer number. This excludes floating-point values. + + @return `true` if type is an integer or unsigned integer number, `false` + otherwise. + + @complexity Constant. + + @exceptionsafety No-throw guarantee: this member function never throws + exceptions. + + @liveexample{The following code exemplifies `is_number_integer()` for all + JSON types.,is_number_integer} + + @sa @ref is_number() -- check if value is a number + @sa @ref is_number_unsigned() -- check if value is an unsigned integer + number + @sa @ref is_number_float() -- check if value is a floating-point number + + @since version 1.0.0 + */ + constexpr bool is_number_integer() const noexcept + { + return m_type == value_t::number_integer || m_type == value_t::number_unsigned; + } + + /*! + @brief return whether value is an unsigned integer number + + This function returns true if and only if the JSON value is an unsigned + integer number. This excludes floating-point and signed integer values. + + @return `true` if type is an unsigned integer number, `false` otherwise. + + @complexity Constant. + + @exceptionsafety No-throw guarantee: this member function never throws + exceptions. + + @liveexample{The following code exemplifies `is_number_unsigned()` for all + JSON types.,is_number_unsigned} + + @sa @ref is_number() -- check if value is a number + @sa @ref is_number_integer() -- check if value is an integer or unsigned + integer number + @sa @ref is_number_float() -- check if value is a floating-point number + + @since version 2.0.0 + */ + constexpr bool is_number_unsigned() const noexcept + { + return m_type == value_t::number_unsigned; + } + + /*! + @brief return whether value is a floating-point number + + This function returns true if and only if the JSON value is a + floating-point number. This excludes signed and unsigned integer values. + + @return `true` if type is a floating-point number, `false` otherwise. + + @complexity Constant. + + @exceptionsafety No-throw guarantee: this member function never throws + exceptions. + + @liveexample{The following code exemplifies `is_number_float()` for all + JSON types.,is_number_float} + + @sa @ref is_number() -- check if value is number + @sa @ref is_number_integer() -- check if value is an integer number + @sa @ref is_number_unsigned() -- check if value is an unsigned integer + number + + @since version 1.0.0 + */ + constexpr bool is_number_float() const noexcept + { + return m_type == value_t::number_float; + } + + /*! + @brief return whether value is an object + + This function returns true if and only if the JSON value is an object. + + @return `true` if type is object, `false` otherwise. + + @complexity Constant. + + @exceptionsafety No-throw guarantee: this member function never throws + exceptions. + + @liveexample{The following code exemplifies `is_object()` for all JSON + types.,is_object} + + @since version 1.0.0 + */ + constexpr bool is_object() const noexcept + { + return m_type == value_t::object; + } + + /*! + @brief return whether value is an array + + This function returns true if and only if the JSON value is an array. + + @return `true` if type is array, `false` otherwise. + + @complexity Constant. + + @exceptionsafety No-throw guarantee: this member function never throws + exceptions. + + @liveexample{The following code exemplifies `is_array()` for all JSON + types.,is_array} + + @since version 1.0.0 + */ + constexpr bool is_array() const noexcept + { + return m_type == value_t::array; + } + + /*! + @brief return whether value is a string + + This function returns true if and only if the JSON value is a string. + + @return `true` if type is string, `false` otherwise. + + @complexity Constant. + + @exceptionsafety No-throw guarantee: this member function never throws + exceptions. + + @liveexample{The following code exemplifies `is_string()` for all JSON + types.,is_string} + + @since version 1.0.0 + */ + constexpr bool is_string() const noexcept + { + return m_type == value_t::string; + } + + /*! + @brief return whether value is a binary array + + This function returns true if and only if the JSON value is a binary array. + + @return `true` if type is binary array, `false` otherwise. + + @complexity Constant. + + @exceptionsafety No-throw guarantee: this member function never throws + exceptions. + + @liveexample{The following code exemplifies `is_binary()` for all JSON + types.,is_binary} + + @since version 3.8.0 + */ + constexpr bool is_binary() const noexcept + { + return m_type == value_t::binary; + } + + /*! + @brief return whether value is discarded + + This function returns true if and only if the JSON value was discarded + during parsing with a callback function (see @ref parser_callback_t). + + @note This function will always be `false` for JSON values after parsing. + That is, discarded values can only occur during parsing, but will be + removed when inside a structured value or replaced by null in other cases. + + @return `true` if type is discarded, `false` otherwise. + + @complexity Constant. + + @exceptionsafety No-throw guarantee: this member function never throws + exceptions. + + @liveexample{The following code exemplifies `is_discarded()` for all JSON + types.,is_discarded} + + @since version 1.0.0 + */ + constexpr bool is_discarded() const noexcept + { + return m_type == value_t::discarded; + } + + /*! + @brief return the type of the JSON value (implicit) + + Implicitly return the type of the JSON value as a value from the @ref + value_t enumeration. + + @return the type of the JSON value + + @complexity Constant. + + @exceptionsafety No-throw guarantee: this member function never throws + exceptions. + + @liveexample{The following code exemplifies the @ref value_t operator for + all JSON types.,operator__value_t} + + @sa @ref type() -- return the type of the JSON value (explicit) + @sa @ref type_name() -- return the type as string + + @since version 1.0.0 + */ + constexpr operator value_t() const noexcept + { + return m_type; + } + + /// @} + + private: + ////////////////// + // value access // + ////////////////// + + /// get a boolean (explicit) + boolean_t get_impl(boolean_t* /*unused*/) const + { + if (JSON_HEDLEY_LIKELY(is_boolean())) + { + return m_value.boolean; + } + + JSON_THROW(type_error::create(302, "type must be boolean, but is " + std::string(type_name()))); + } + + /// get a pointer to the value (object) + object_t* get_impl_ptr(object_t* /*unused*/) noexcept + { + return is_object() ? m_value.object : nullptr; + } + + /// get a pointer to the value (object) + constexpr const object_t* get_impl_ptr(const object_t* /*unused*/) const noexcept + { + return is_object() ? m_value.object : nullptr; + } + + /// get a pointer to the value (array) + array_t* get_impl_ptr(array_t* /*unused*/) noexcept + { + return is_array() ? m_value.array : nullptr; + } + + /// get a pointer to the value (array) + constexpr const array_t* get_impl_ptr(const array_t* /*unused*/) const noexcept + { + return is_array() ? m_value.array : nullptr; + } + + /// get a pointer to the value (string) + string_t* get_impl_ptr(string_t* /*unused*/) noexcept + { + return is_string() ? m_value.string : nullptr; + } + + /// get a pointer to the value (string) + constexpr const string_t* get_impl_ptr(const string_t* /*unused*/) const noexcept + { + return is_string() ? m_value.string : nullptr; + } + + /// get a pointer to the value (boolean) + boolean_t* get_impl_ptr(boolean_t* /*unused*/) noexcept + { + return is_boolean() ? &m_value.boolean : nullptr; + } + + /// get a pointer to the value (boolean) + constexpr const boolean_t* get_impl_ptr(const boolean_t* /*unused*/) const noexcept + { + return is_boolean() ? &m_value.boolean : nullptr; + } + + /// get a pointer to the value (integer number) + number_integer_t* get_impl_ptr(number_integer_t* /*unused*/) noexcept + { + return is_number_integer() ? &m_value.number_integer : nullptr; + } + + /// get a pointer to the value (integer number) + constexpr const number_integer_t* get_impl_ptr(const number_integer_t* /*unused*/) const noexcept + { + return is_number_integer() ? &m_value.number_integer : nullptr; + } + + /// get a pointer to the value (unsigned number) + number_unsigned_t* get_impl_ptr(number_unsigned_t* /*unused*/) noexcept + { + return is_number_unsigned() ? &m_value.number_unsigned : nullptr; + } + + /// get a pointer to the value (unsigned number) + constexpr const number_unsigned_t* get_impl_ptr(const number_unsigned_t* /*unused*/) const noexcept + { + return is_number_unsigned() ? &m_value.number_unsigned : nullptr; + } + + /// get a pointer to the value (floating-point number) + number_float_t* get_impl_ptr(number_float_t* /*unused*/) noexcept + { + return is_number_float() ? &m_value.number_float : nullptr; + } + + /// get a pointer to the value (floating-point number) + constexpr const number_float_t* get_impl_ptr(const number_float_t* /*unused*/) const noexcept + { + return is_number_float() ? &m_value.number_float : nullptr; + } + + /// get a pointer to the value (binary) + binary_t* get_impl_ptr(binary_t* /*unused*/) noexcept + { + return is_binary() ? m_value.binary : nullptr; + } + + /// get a pointer to the value (binary) + constexpr const binary_t* get_impl_ptr(const binary_t* /*unused*/) const noexcept + { + return is_binary() ? m_value.binary : nullptr; + } + + /*! + @brief helper function to implement get_ref() + + This function helps to implement get_ref() without code duplication for + const and non-const overloads + + @tparam ThisType will be deduced as `basic_json` or `const basic_json` + + @throw type_error.303 if ReferenceType does not match underlying value + type of the current JSON + */ + template + static ReferenceType get_ref_impl(ThisType& obj) + { + // delegate the call to get_ptr<>() + auto ptr = obj.template get_ptr::type>(); + + if (JSON_HEDLEY_LIKELY(ptr != nullptr)) + { + return *ptr; + } + + JSON_THROW(type_error::create(303, "incompatible ReferenceType for get_ref, actual type is " + std::string(obj.type_name()))); + } + + public: + /// @name value access + /// Direct access to the stored value of a JSON value. + /// @{ + + /*! + @brief get special-case overload + + This overloads avoids a lot of template boilerplate, it can be seen as the + identity method + + @tparam BasicJsonType == @ref basic_json + + @return a copy of *this + + @complexity Constant. + + @since version 2.1.0 + */ + template::type, basic_json_t>::value, + int> = 0> + basic_json get() const + { + return *this; + } + + /*! + @brief get special-case overload + + This overloads converts the current @ref basic_json in a different + @ref basic_json type + + @tparam BasicJsonType == @ref basic_json + + @return a copy of *this, converted into @tparam BasicJsonType + + @complexity Depending on the implementation of the called `from_json()` + method. + + @since version 3.2.0 + */ + template < typename BasicJsonType, detail::enable_if_t < + !std::is_same::value&& + detail::is_basic_json::value, int > = 0 > + BasicJsonType get() const + { + return *this; + } + + /*! + @brief get a value (explicit) + + Explicit type conversion between the JSON value and a compatible value + which is [CopyConstructible](https://en.cppreference.com/w/cpp/named_req/CopyConstructible) + and [DefaultConstructible](https://en.cppreference.com/w/cpp/named_req/DefaultConstructible). + The value is converted by calling the @ref json_serializer + `from_json()` method. + + The function is equivalent to executing + @code {.cpp} + ValueType ret; + JSONSerializer::from_json(*this, ret); + return ret; + @endcode + + This overloads is chosen if: + - @a ValueType is not @ref basic_json, + - @ref json_serializer has a `from_json()` method of the form + `void from_json(const basic_json&, ValueType&)`, and + - @ref json_serializer does not have a `from_json()` method of + the form `ValueType from_json(const basic_json&)` + + @tparam ValueTypeCV the provided value type + @tparam ValueType the returned value type + + @return copy of the JSON value, converted to @a ValueType + + @throw what @ref json_serializer `from_json()` method throws + + @liveexample{The example below shows several conversions from JSON values + to other types. There a few things to note: (1) Floating-point numbers can + be converted to integers\, (2) A JSON array can be converted to a standard + `std::vector`\, (3) A JSON object can be converted to C++ + associative containers such as `std::unordered_map`.,get__ValueType_const} + + @since version 2.1.0 + */ + template < typename ValueTypeCV, typename ValueType = detail::uncvref_t, + detail::enable_if_t < + !detail::is_basic_json::value && + detail::has_from_json::value && + !detail::has_non_default_from_json::value, + int > = 0 > + ValueType get() const noexcept(noexcept( + JSONSerializer::from_json(std::declval(), std::declval()))) + { + // we cannot static_assert on ValueTypeCV being non-const, because + // there is support for get(), which is why we + // still need the uncvref + static_assert(!std::is_reference::value, + "get() cannot be used with reference types, you might want to use get_ref()"); + static_assert(std::is_default_constructible::value, + "types must be DefaultConstructible when used with get()"); + + ValueType ret; + JSONSerializer::from_json(*this, ret); + return ret; + } + + /*! + @brief get a value (explicit); special case + + Explicit type conversion between the JSON value and a compatible value + which is **not** [CopyConstructible](https://en.cppreference.com/w/cpp/named_req/CopyConstructible) + and **not** [DefaultConstructible](https://en.cppreference.com/w/cpp/named_req/DefaultConstructible). + The value is converted by calling the @ref json_serializer + `from_json()` method. + + The function is equivalent to executing + @code {.cpp} + return JSONSerializer::from_json(*this); + @endcode + + This overloads is chosen if: + - @a ValueType is not @ref basic_json and + - @ref json_serializer has a `from_json()` method of the form + `ValueType from_json(const basic_json&)` + + @note If @ref json_serializer has both overloads of + `from_json()`, this one is chosen. + + @tparam ValueTypeCV the provided value type + @tparam ValueType the returned value type + + @return copy of the JSON value, converted to @a ValueType + + @throw what @ref json_serializer `from_json()` method throws + + @since version 2.1.0 + */ + template < typename ValueTypeCV, typename ValueType = detail::uncvref_t, + detail::enable_if_t < !std::is_same::value && + detail::has_non_default_from_json::value, + int > = 0 > + ValueType get() const noexcept(noexcept( + JSONSerializer::from_json(std::declval()))) + { + static_assert(!std::is_reference::value, + "get() cannot be used with reference types, you might want to use get_ref()"); + return JSONSerializer::from_json(*this); + } + + /*! + @brief get a value (explicit) + + Explicit type conversion between the JSON value and a compatible value. + The value is filled into the input parameter by calling the @ref json_serializer + `from_json()` method. + + The function is equivalent to executing + @code {.cpp} + ValueType v; + JSONSerializer::from_json(*this, v); + @endcode + + This overloads is chosen if: + - @a ValueType is not @ref basic_json, + - @ref json_serializer has a `from_json()` method of the form + `void from_json(const basic_json&, ValueType&)`, and + + @tparam ValueType the input parameter type. + + @return the input parameter, allowing chaining calls. + + @throw what @ref json_serializer `from_json()` method throws + + @liveexample{The example below shows several conversions from JSON values + to other types. There a few things to note: (1) Floating-point numbers can + be converted to integers\, (2) A JSON array can be converted to a standard + `std::vector`\, (3) A JSON object can be converted to C++ + associative containers such as `std::unordered_map`.,get_to} + + @since version 3.3.0 + */ + template < typename ValueType, + detail::enable_if_t < + !detail::is_basic_json::value&& + detail::has_from_json::value, + int > = 0 > + ValueType & get_to(ValueType& v) const noexcept(noexcept( + JSONSerializer::from_json(std::declval(), v))) + { + JSONSerializer::from_json(*this, v); + return v; + } + + // specialization to allow to call get_to with a basic_json value + // see https://github.com/nlohmann/json/issues/2175 + template::value, + int> = 0> + ValueType & get_to(ValueType& v) const + { + v = *this; + return v; + } + + template < + typename T, std::size_t N, + typename Array = T (&)[N], + detail::enable_if_t < + detail::has_from_json::value, int > = 0 > + Array get_to(T (&v)[N]) const + noexcept(noexcept(JSONSerializer::from_json( + std::declval(), v))) + { + JSONSerializer::from_json(*this, v); + return v; + } + + + /*! + @brief get a pointer value (implicit) + + Implicit pointer access to the internally stored JSON value. No copies are + made. + + @warning Writing data to the pointee of the result yields an undefined + state. + + @tparam PointerType pointer type; must be a pointer to @ref array_t, @ref + object_t, @ref string_t, @ref boolean_t, @ref number_integer_t, + @ref number_unsigned_t, or @ref number_float_t. Enforced by a static + assertion. + + @return pointer to the internally stored JSON value if the requested + pointer type @a PointerType fits to the JSON value; `nullptr` otherwise + + @complexity Constant. + + @liveexample{The example below shows how pointers to internal values of a + JSON value can be requested. Note that no type conversions are made and a + `nullptr` is returned if the value and the requested pointer type does not + match.,get_ptr} + + @since version 1.0.0 + */ + template::value, int>::type = 0> + auto get_ptr() noexcept -> decltype(std::declval().get_impl_ptr(std::declval())) + { + // delegate the call to get_impl_ptr<>() + return get_impl_ptr(static_cast(nullptr)); + } + + /*! + @brief get a pointer value (implicit) + @copydoc get_ptr() + */ + template < typename PointerType, typename std::enable_if < + std::is_pointer::value&& + std::is_const::type>::value, int >::type = 0 > + constexpr auto get_ptr() const noexcept -> decltype(std::declval().get_impl_ptr(std::declval())) + { + // delegate the call to get_impl_ptr<>() const + return get_impl_ptr(static_cast(nullptr)); + } + + /*! + @brief get a pointer value (explicit) + + Explicit pointer access to the internally stored JSON value. No copies are + made. + + @warning The pointer becomes invalid if the underlying JSON object + changes. + + @tparam PointerType pointer type; must be a pointer to @ref array_t, @ref + object_t, @ref string_t, @ref boolean_t, @ref number_integer_t, + @ref number_unsigned_t, or @ref number_float_t. + + @return pointer to the internally stored JSON value if the requested + pointer type @a PointerType fits to the JSON value; `nullptr` otherwise + + @complexity Constant. + + @liveexample{The example below shows how pointers to internal values of a + JSON value can be requested. Note that no type conversions are made and a + `nullptr` is returned if the value and the requested pointer type does not + match.,get__PointerType} + + @sa @ref get_ptr() for explicit pointer-member access + + @since version 1.0.0 + */ + template::value, int>::type = 0> + auto get() noexcept -> decltype(std::declval().template get_ptr()) + { + // delegate the call to get_ptr + return get_ptr(); + } + + /*! + @brief get a pointer value (explicit) + @copydoc get() + */ + template::value, int>::type = 0> + constexpr auto get() const noexcept -> decltype(std::declval().template get_ptr()) + { + // delegate the call to get_ptr + return get_ptr(); + } + + /*! + @brief get a reference value (implicit) + + Implicit reference access to the internally stored JSON value. No copies + are made. + + @warning Writing data to the referee of the result yields an undefined + state. + + @tparam ReferenceType reference type; must be a reference to @ref array_t, + @ref object_t, @ref string_t, @ref boolean_t, @ref number_integer_t, or + @ref number_float_t. Enforced by static assertion. + + @return reference to the internally stored JSON value if the requested + reference type @a ReferenceType fits to the JSON value; throws + type_error.303 otherwise + + @throw type_error.303 in case passed type @a ReferenceType is incompatible + with the stored JSON value; see example below + + @complexity Constant. + + @liveexample{The example shows several calls to `get_ref()`.,get_ref} + + @since version 1.1.0 + */ + template::value, int>::type = 0> + ReferenceType get_ref() + { + // delegate call to get_ref_impl + return get_ref_impl(*this); + } + + /*! + @brief get a reference value (implicit) + @copydoc get_ref() + */ + template < typename ReferenceType, typename std::enable_if < + std::is_reference::value&& + std::is_const::type>::value, int >::type = 0 > + ReferenceType get_ref() const + { + // delegate call to get_ref_impl + return get_ref_impl(*this); + } + + /*! + @brief get a value (implicit) + + Implicit type conversion between the JSON value and a compatible value. + The call is realized by calling @ref get() const. + + @tparam ValueType non-pointer type compatible to the JSON value, for + instance `int` for JSON integer numbers, `bool` for JSON booleans, or + `std::vector` types for JSON arrays. The character type of @ref string_t + as well as an initializer list of this type is excluded to avoid + ambiguities as these types implicitly convert to `std::string`. + + @return copy of the JSON value, converted to type @a ValueType + + @throw type_error.302 in case passed type @a ValueType is incompatible + to the JSON value type (e.g., the JSON value is of type boolean, but a + string is requested); see example below + + @complexity Linear in the size of the JSON value. + + @liveexample{The example below shows several conversions from JSON values + to other types. There a few things to note: (1) Floating-point numbers can + be converted to integers\, (2) A JSON array can be converted to a standard + `std::vector`\, (3) A JSON object can be converted to C++ + associative containers such as `std::unordered_map`.,operator__ValueType} + + @since version 1.0.0 + */ + template < typename ValueType, typename std::enable_if < + !std::is_pointer::value&& + !std::is_same>::value&& + !std::is_same::value&& + !detail::is_basic_json::value + && !std::is_same>::value +#if defined(JSON_HAS_CPP_17) && (defined(__GNUC__) || (defined(_MSC_VER) && _MSC_VER >= 1910 && _MSC_VER <= 1914)) + && !std::is_same::value +#endif + && detail::is_detected::value + , int >::type = 0 > + JSON_EXPLICIT operator ValueType() const + { + // delegate the call to get<>() const + return get(); + } + + /*! + @return reference to the binary value + + @throw type_error.302 if the value is not binary + + @sa @ref is_binary() to check if the value is binary + + @since version 3.8.0 + */ + binary_t& get_binary() + { + if (!is_binary()) + { + JSON_THROW(type_error::create(302, "type must be binary, but is " + std::string(type_name()))); + } + + return *get_ptr(); + } + + /// @copydoc get_binary() + const binary_t& get_binary() const + { + if (!is_binary()) + { + JSON_THROW(type_error::create(302, "type must be binary, but is " + std::string(type_name()))); + } + + return *get_ptr(); + } + + /// @} + + + //////////////////// + // element access // + //////////////////// + + /// @name element access + /// Access to the JSON value. + /// @{ + + /*! + @brief access specified array element with bounds checking + + Returns a reference to the element at specified location @a idx, with + bounds checking. + + @param[in] idx index of the element to access + + @return reference to the element at index @a idx + + @throw type_error.304 if the JSON value is not an array; in this case, + calling `at` with an index makes no sense. See example below. + @throw out_of_range.401 if the index @a idx is out of range of the array; + that is, `idx >= size()`. See example below. + + @exceptionsafety Strong guarantee: if an exception is thrown, there are no + changes in the JSON value. + + @complexity Constant. + + @since version 1.0.0 + + @liveexample{The example below shows how array elements can be read and + written using `at()`. It also demonstrates the different exceptions that + can be thrown.,at__size_type} + */ + reference at(size_type idx) + { + // at only works for arrays + if (JSON_HEDLEY_LIKELY(is_array())) + { + JSON_TRY + { + return m_value.array->at(idx); + } + JSON_CATCH (std::out_of_range&) + { + // create better exception explanation + JSON_THROW(out_of_range::create(401, "array index " + std::to_string(idx) + " is out of range")); + } + } + else + { + JSON_THROW(type_error::create(304, "cannot use at() with " + std::string(type_name()))); + } + } + + /*! + @brief access specified array element with bounds checking + + Returns a const reference to the element at specified location @a idx, + with bounds checking. + + @param[in] idx index of the element to access + + @return const reference to the element at index @a idx + + @throw type_error.304 if the JSON value is not an array; in this case, + calling `at` with an index makes no sense. See example below. + @throw out_of_range.401 if the index @a idx is out of range of the array; + that is, `idx >= size()`. See example below. + + @exceptionsafety Strong guarantee: if an exception is thrown, there are no + changes in the JSON value. + + @complexity Constant. + + @since version 1.0.0 + + @liveexample{The example below shows how array elements can be read using + `at()`. It also demonstrates the different exceptions that can be thrown., + at__size_type_const} + */ + const_reference at(size_type idx) const + { + // at only works for arrays + if (JSON_HEDLEY_LIKELY(is_array())) + { + JSON_TRY + { + return m_value.array->at(idx); + } + JSON_CATCH (std::out_of_range&) + { + // create better exception explanation + JSON_THROW(out_of_range::create(401, "array index " + std::to_string(idx) + " is out of range")); + } + } + else + { + JSON_THROW(type_error::create(304, "cannot use at() with " + std::string(type_name()))); + } + } + + /*! + @brief access specified object element with bounds checking + + Returns a reference to the element at with specified key @a key, with + bounds checking. + + @param[in] key key of the element to access + + @return reference to the element at key @a key + + @throw type_error.304 if the JSON value is not an object; in this case, + calling `at` with a key makes no sense. See example below. + @throw out_of_range.403 if the key @a key is is not stored in the object; + that is, `find(key) == end()`. See example below. + + @exceptionsafety Strong guarantee: if an exception is thrown, there are no + changes in the JSON value. + + @complexity Logarithmic in the size of the container. + + @sa @ref operator[](const typename object_t::key_type&) for unchecked + access by reference + @sa @ref value() for access by value with a default value + + @since version 1.0.0 + + @liveexample{The example below shows how object elements can be read and + written using `at()`. It also demonstrates the different exceptions that + can be thrown.,at__object_t_key_type} + */ + reference at(const typename object_t::key_type& key) + { + // at only works for objects + if (JSON_HEDLEY_LIKELY(is_object())) + { + JSON_TRY + { + return m_value.object->at(key); + } + JSON_CATCH (std::out_of_range&) + { + // create better exception explanation + JSON_THROW(out_of_range::create(403, "key '" + key + "' not found")); + } + } + else + { + JSON_THROW(type_error::create(304, "cannot use at() with " + std::string(type_name()))); + } + } + + /*! + @brief access specified object element with bounds checking + + Returns a const reference to the element at with specified key @a key, + with bounds checking. + + @param[in] key key of the element to access + + @return const reference to the element at key @a key + + @throw type_error.304 if the JSON value is not an object; in this case, + calling `at` with a key makes no sense. See example below. + @throw out_of_range.403 if the key @a key is is not stored in the object; + that is, `find(key) == end()`. See example below. + + @exceptionsafety Strong guarantee: if an exception is thrown, there are no + changes in the JSON value. + + @complexity Logarithmic in the size of the container. + + @sa @ref operator[](const typename object_t::key_type&) for unchecked + access by reference + @sa @ref value() for access by value with a default value + + @since version 1.0.0 + + @liveexample{The example below shows how object elements can be read using + `at()`. It also demonstrates the different exceptions that can be thrown., + at__object_t_key_type_const} + */ + const_reference at(const typename object_t::key_type& key) const + { + // at only works for objects + if (JSON_HEDLEY_LIKELY(is_object())) + { + JSON_TRY + { + return m_value.object->at(key); + } + JSON_CATCH (std::out_of_range&) + { + // create better exception explanation + JSON_THROW(out_of_range::create(403, "key '" + key + "' not found")); + } + } + else + { + JSON_THROW(type_error::create(304, "cannot use at() with " + std::string(type_name()))); + } + } + + /*! + @brief access specified array element + + Returns a reference to the element at specified location @a idx. + + @note If @a idx is beyond the range of the array (i.e., `idx >= size()`), + then the array is silently filled up with `null` values to make `idx` a + valid reference to the last stored element. + + @param[in] idx index of the element to access + + @return reference to the element at index @a idx + + @throw type_error.305 if the JSON value is not an array or null; in that + cases, using the [] operator with an index makes no sense. + + @complexity Constant if @a idx is in the range of the array. Otherwise + linear in `idx - size()`. + + @liveexample{The example below shows how array elements can be read and + written using `[]` operator. Note the addition of `null` + values.,operatorarray__size_type} + + @since version 1.0.0 + */ + reference operator[](size_type idx) + { + // implicitly convert null value to an empty array + if (is_null()) + { + m_type = value_t::array; + m_value.array = create(); + assert_invariant(); + } + + // operator[] only works for arrays + if (JSON_HEDLEY_LIKELY(is_array())) + { + // fill up array with null values if given idx is outside range + if (idx >= m_value.array->size()) + { + m_value.array->insert(m_value.array->end(), + idx - m_value.array->size() + 1, + basic_json()); + } + + return m_value.array->operator[](idx); + } + + JSON_THROW(type_error::create(305, "cannot use operator[] with a numeric argument with " + std::string(type_name()))); + } + + /*! + @brief access specified array element + + Returns a const reference to the element at specified location @a idx. + + @param[in] idx index of the element to access + + @return const reference to the element at index @a idx + + @throw type_error.305 if the JSON value is not an array; in that case, + using the [] operator with an index makes no sense. + + @complexity Constant. + + @liveexample{The example below shows how array elements can be read using + the `[]` operator.,operatorarray__size_type_const} + + @since version 1.0.0 + */ + const_reference operator[](size_type idx) const + { + // const operator[] only works for arrays + if (JSON_HEDLEY_LIKELY(is_array())) + { + return m_value.array->operator[](idx); + } + + JSON_THROW(type_error::create(305, "cannot use operator[] with a numeric argument with " + std::string(type_name()))); + } + + /*! + @brief access specified object element + + Returns a reference to the element at with specified key @a key. + + @note If @a key is not found in the object, then it is silently added to + the object and filled with a `null` value to make `key` a valid reference. + In case the value was `null` before, it is converted to an object. + + @param[in] key key of the element to access + + @return reference to the element at key @a key + + @throw type_error.305 if the JSON value is not an object or null; in that + cases, using the [] operator with a key makes no sense. + + @complexity Logarithmic in the size of the container. + + @liveexample{The example below shows how object elements can be read and + written using the `[]` operator.,operatorarray__key_type} + + @sa @ref at(const typename object_t::key_type&) for access by reference + with range checking + @sa @ref value() for access by value with a default value + + @since version 1.0.0 + */ + reference operator[](const typename object_t::key_type& key) + { + // implicitly convert null value to an empty object + if (is_null()) + { + m_type = value_t::object; + m_value.object = create(); + assert_invariant(); + } + + // operator[] only works for objects + if (JSON_HEDLEY_LIKELY(is_object())) + { + return m_value.object->operator[](key); + } + + JSON_THROW(type_error::create(305, "cannot use operator[] with a string argument with " + std::string(type_name()))); + } + + /*! + @brief read-only access specified object element + + Returns a const reference to the element at with specified key @a key. No + bounds checking is performed. + + @warning If the element with key @a key does not exist, the behavior is + undefined. + + @param[in] key key of the element to access + + @return const reference to the element at key @a key + + @pre The element with key @a key must exist. **This precondition is + enforced with an assertion.** + + @throw type_error.305 if the JSON value is not an object; in that case, + using the [] operator with a key makes no sense. + + @complexity Logarithmic in the size of the container. + + @liveexample{The example below shows how object elements can be read using + the `[]` operator.,operatorarray__key_type_const} + + @sa @ref at(const typename object_t::key_type&) for access by reference + with range checking + @sa @ref value() for access by value with a default value + + @since version 1.0.0 + */ + const_reference operator[](const typename object_t::key_type& key) const + { + // const operator[] only works for objects + if (JSON_HEDLEY_LIKELY(is_object())) + { + JSON_ASSERT(m_value.object->find(key) != m_value.object->end()); + return m_value.object->find(key)->second; + } + + JSON_THROW(type_error::create(305, "cannot use operator[] with a string argument with " + std::string(type_name()))); + } + + /*! + @brief access specified object element + + Returns a reference to the element at with specified key @a key. + + @note If @a key is not found in the object, then it is silently added to + the object and filled with a `null` value to make `key` a valid reference. + In case the value was `null` before, it is converted to an object. + + @param[in] key key of the element to access + + @return reference to the element at key @a key + + @throw type_error.305 if the JSON value is not an object or null; in that + cases, using the [] operator with a key makes no sense. + + @complexity Logarithmic in the size of the container. + + @liveexample{The example below shows how object elements can be read and + written using the `[]` operator.,operatorarray__key_type} + + @sa @ref at(const typename object_t::key_type&) for access by reference + with range checking + @sa @ref value() for access by value with a default value + + @since version 1.1.0 + */ + template + JSON_HEDLEY_NON_NULL(2) + reference operator[](T* key) + { + // implicitly convert null to object + if (is_null()) + { + m_type = value_t::object; + m_value = value_t::object; + assert_invariant(); + } + + // at only works for objects + if (JSON_HEDLEY_LIKELY(is_object())) + { + return m_value.object->operator[](key); + } + + JSON_THROW(type_error::create(305, "cannot use operator[] with a string argument with " + std::string(type_name()))); + } + + /*! + @brief read-only access specified object element + + Returns a const reference to the element at with specified key @a key. No + bounds checking is performed. + + @warning If the element with key @a key does not exist, the behavior is + undefined. + + @param[in] key key of the element to access + + @return const reference to the element at key @a key + + @pre The element with key @a key must exist. **This precondition is + enforced with an assertion.** + + @throw type_error.305 if the JSON value is not an object; in that case, + using the [] operator with a key makes no sense. + + @complexity Logarithmic in the size of the container. + + @liveexample{The example below shows how object elements can be read using + the `[]` operator.,operatorarray__key_type_const} + + @sa @ref at(const typename object_t::key_type&) for access by reference + with range checking + @sa @ref value() for access by value with a default value + + @since version 1.1.0 + */ + template + JSON_HEDLEY_NON_NULL(2) + const_reference operator[](T* key) const + { + // at only works for objects + if (JSON_HEDLEY_LIKELY(is_object())) + { + JSON_ASSERT(m_value.object->find(key) != m_value.object->end()); + return m_value.object->find(key)->second; + } + + JSON_THROW(type_error::create(305, "cannot use operator[] with a string argument with " + std::string(type_name()))); + } + + /*! + @brief access specified object element with default value + + Returns either a copy of an object's element at the specified key @a key + or a given default value if no element with key @a key exists. + + The function is basically equivalent to executing + @code {.cpp} + try { + return at(key); + } catch(out_of_range) { + return default_value; + } + @endcode + + @note Unlike @ref at(const typename object_t::key_type&), this function + does not throw if the given key @a key was not found. + + @note Unlike @ref operator[](const typename object_t::key_type& key), this + function does not implicitly add an element to the position defined by @a + key. This function is furthermore also applicable to const objects. + + @param[in] key key of the element to access + @param[in] default_value the value to return if @a key is not found + + @tparam ValueType type compatible to JSON values, for instance `int` for + JSON integer numbers, `bool` for JSON booleans, or `std::vector` types for + JSON arrays. Note the type of the expected value at @a key and the default + value @a default_value must be compatible. + + @return copy of the element at key @a key or @a default_value if @a key + is not found + + @throw type_error.302 if @a default_value does not match the type of the + value at @a key + @throw type_error.306 if the JSON value is not an object; in that case, + using `value()` with a key makes no sense. + + @complexity Logarithmic in the size of the container. + + @liveexample{The example below shows how object elements can be queried + with a default value.,basic_json__value} + + @sa @ref at(const typename object_t::key_type&) for access by reference + with range checking + @sa @ref operator[](const typename object_t::key_type&) for unchecked + access by reference + + @since version 1.0.0 + */ + // using std::is_convertible in a std::enable_if will fail when using explicit conversions + template < class ValueType, typename std::enable_if < + detail::is_getable::value + && !std::is_same::value, int >::type = 0 > + ValueType value(const typename object_t::key_type& key, const ValueType& default_value) const + { + // at only works for objects + if (JSON_HEDLEY_LIKELY(is_object())) + { + // if key is found, return value and given default value otherwise + const auto it = find(key); + if (it != end()) + { + return it->template get(); + } + + return default_value; + } + + JSON_THROW(type_error::create(306, "cannot use value() with " + std::string(type_name()))); + } + + /*! + @brief overload for a default value of type const char* + @copydoc basic_json::value(const typename object_t::key_type&, const ValueType&) const + */ + string_t value(const typename object_t::key_type& key, const char* default_value) const + { + return value(key, string_t(default_value)); + } + + /*! + @brief access specified object element via JSON Pointer with default value + + Returns either a copy of an object's element at the specified key @a key + or a given default value if no element with key @a key exists. + + The function is basically equivalent to executing + @code {.cpp} + try { + return at(ptr); + } catch(out_of_range) { + return default_value; + } + @endcode + + @note Unlike @ref at(const json_pointer&), this function does not throw + if the given key @a key was not found. + + @param[in] ptr a JSON pointer to the element to access + @param[in] default_value the value to return if @a ptr found no value + + @tparam ValueType type compatible to JSON values, for instance `int` for + JSON integer numbers, `bool` for JSON booleans, or `std::vector` types for + JSON arrays. Note the type of the expected value at @a key and the default + value @a default_value must be compatible. + + @return copy of the element at key @a key or @a default_value if @a key + is not found + + @throw type_error.302 if @a default_value does not match the type of the + value at @a ptr + @throw type_error.306 if the JSON value is not an object; in that case, + using `value()` with a key makes no sense. + + @complexity Logarithmic in the size of the container. + + @liveexample{The example below shows how object elements can be queried + with a default value.,basic_json__value_ptr} + + @sa @ref operator[](const json_pointer&) for unchecked access by reference + + @since version 2.0.2 + */ + template::value, int>::type = 0> + ValueType value(const json_pointer& ptr, const ValueType& default_value) const + { + // at only works for objects + if (JSON_HEDLEY_LIKELY(is_object())) + { + // if pointer resolves a value, return it or use default value + JSON_TRY + { + return ptr.get_checked(this).template get(); + } + JSON_INTERNAL_CATCH (out_of_range&) + { + return default_value; + } + } + + JSON_THROW(type_error::create(306, "cannot use value() with " + std::string(type_name()))); + } + + /*! + @brief overload for a default value of type const char* + @copydoc basic_json::value(const json_pointer&, ValueType) const + */ + JSON_HEDLEY_NON_NULL(3) + string_t value(const json_pointer& ptr, const char* default_value) const + { + return value(ptr, string_t(default_value)); + } + + /*! + @brief access the first element + + Returns a reference to the first element in the container. For a JSON + container `c`, the expression `c.front()` is equivalent to `*c.begin()`. + + @return In case of a structured type (array or object), a reference to the + first element is returned. In case of number, string, boolean, or binary + values, a reference to the value is returned. + + @complexity Constant. + + @pre The JSON value must not be `null` (would throw `std::out_of_range`) + or an empty array or object (undefined behavior, **guarded by + assertions**). + @post The JSON value remains unchanged. + + @throw invalid_iterator.214 when called on `null` value + + @liveexample{The following code shows an example for `front()`.,front} + + @sa @ref back() -- access the last element + + @since version 1.0.0 + */ + reference front() + { + return *begin(); + } + + /*! + @copydoc basic_json::front() + */ + const_reference front() const + { + return *cbegin(); + } + + /*! + @brief access the last element + + Returns a reference to the last element in the container. For a JSON + container `c`, the expression `c.back()` is equivalent to + @code {.cpp} + auto tmp = c.end(); + --tmp; + return *tmp; + @endcode + + @return In case of a structured type (array or object), a reference to the + last element is returned. In case of number, string, boolean, or binary + values, a reference to the value is returned. + + @complexity Constant. + + @pre The JSON value must not be `null` (would throw `std::out_of_range`) + or an empty array or object (undefined behavior, **guarded by + assertions**). + @post The JSON value remains unchanged. + + @throw invalid_iterator.214 when called on a `null` value. See example + below. + + @liveexample{The following code shows an example for `back()`.,back} + + @sa @ref front() -- access the first element + + @since version 1.0.0 + */ + reference back() + { + auto tmp = end(); + --tmp; + return *tmp; + } + + /*! + @copydoc basic_json::back() + */ + const_reference back() const + { + auto tmp = cend(); + --tmp; + return *tmp; + } + + /*! + @brief remove element given an iterator + + Removes the element specified by iterator @a pos. The iterator @a pos must + be valid and dereferenceable. Thus the `end()` iterator (which is valid, + but is not dereferenceable) cannot be used as a value for @a pos. + + If called on a primitive type other than `null`, the resulting JSON value + will be `null`. + + @param[in] pos iterator to the element to remove + @return Iterator following the last removed element. If the iterator @a + pos refers to the last element, the `end()` iterator is returned. + + @tparam IteratorType an @ref iterator or @ref const_iterator + + @post Invalidates iterators and references at or after the point of the + erase, including the `end()` iterator. + + @throw type_error.307 if called on a `null` value; example: `"cannot use + erase() with null"` + @throw invalid_iterator.202 if called on an iterator which does not belong + to the current JSON value; example: `"iterator does not fit current + value"` + @throw invalid_iterator.205 if called on a primitive type with invalid + iterator (i.e., any iterator which is not `begin()`); example: `"iterator + out of range"` + + @complexity The complexity depends on the type: + - objects: amortized constant + - arrays: linear in distance between @a pos and the end of the container + - strings and binary: linear in the length of the member + - other types: constant + + @liveexample{The example shows the result of `erase()` for different JSON + types.,erase__IteratorType} + + @sa @ref erase(IteratorType, IteratorType) -- removes the elements in + the given range + @sa @ref erase(const typename object_t::key_type&) -- removes the element + from an object at the given key + @sa @ref erase(const size_type) -- removes the element from an array at + the given index + + @since version 1.0.0 + */ + template < class IteratorType, typename std::enable_if < + std::is_same::value || + std::is_same::value, int >::type + = 0 > + IteratorType erase(IteratorType pos) + { + // make sure iterator fits the current value + if (JSON_HEDLEY_UNLIKELY(this != pos.m_object)) + { + JSON_THROW(invalid_iterator::create(202, "iterator does not fit current value")); + } + + IteratorType result = end(); + + switch (m_type) + { + case value_t::boolean: + case value_t::number_float: + case value_t::number_integer: + case value_t::number_unsigned: + case value_t::string: + case value_t::binary: + { + if (JSON_HEDLEY_UNLIKELY(!pos.m_it.primitive_iterator.is_begin())) + { + JSON_THROW(invalid_iterator::create(205, "iterator out of range")); + } + + if (is_string()) + { + AllocatorType alloc; + std::allocator_traits::destroy(alloc, m_value.string); + std::allocator_traits::deallocate(alloc, m_value.string, 1); + m_value.string = nullptr; + } + else if (is_binary()) + { + AllocatorType alloc; + std::allocator_traits::destroy(alloc, m_value.binary); + std::allocator_traits::deallocate(alloc, m_value.binary, 1); + m_value.binary = nullptr; + } + + m_type = value_t::null; + assert_invariant(); + break; + } + + case value_t::object: + { + result.m_it.object_iterator = m_value.object->erase(pos.m_it.object_iterator); + break; + } + + case value_t::array: + { + result.m_it.array_iterator = m_value.array->erase(pos.m_it.array_iterator); + break; + } + + default: + JSON_THROW(type_error::create(307, "cannot use erase() with " + std::string(type_name()))); + } + + return result; + } + + /*! + @brief remove elements given an iterator range + + Removes the element specified by the range `[first; last)`. The iterator + @a first does not need to be dereferenceable if `first == last`: erasing + an empty range is a no-op. + + If called on a primitive type other than `null`, the resulting JSON value + will be `null`. + + @param[in] first iterator to the beginning of the range to remove + @param[in] last iterator past the end of the range to remove + @return Iterator following the last removed element. If the iterator @a + second refers to the last element, the `end()` iterator is returned. + + @tparam IteratorType an @ref iterator or @ref const_iterator + + @post Invalidates iterators and references at or after the point of the + erase, including the `end()` iterator. + + @throw type_error.307 if called on a `null` value; example: `"cannot use + erase() with null"` + @throw invalid_iterator.203 if called on iterators which does not belong + to the current JSON value; example: `"iterators do not fit current value"` + @throw invalid_iterator.204 if called on a primitive type with invalid + iterators (i.e., if `first != begin()` and `last != end()`); example: + `"iterators out of range"` + + @complexity The complexity depends on the type: + - objects: `log(size()) + std::distance(first, last)` + - arrays: linear in the distance between @a first and @a last, plus linear + in the distance between @a last and end of the container + - strings and binary: linear in the length of the member + - other types: constant + + @liveexample{The example shows the result of `erase()` for different JSON + types.,erase__IteratorType_IteratorType} + + @sa @ref erase(IteratorType) -- removes the element at a given position + @sa @ref erase(const typename object_t::key_type&) -- removes the element + from an object at the given key + @sa @ref erase(const size_type) -- removes the element from an array at + the given index + + @since version 1.0.0 + */ + template < class IteratorType, typename std::enable_if < + std::is_same::value || + std::is_same::value, int >::type + = 0 > + IteratorType erase(IteratorType first, IteratorType last) + { + // make sure iterator fits the current value + if (JSON_HEDLEY_UNLIKELY(this != first.m_object || this != last.m_object)) + { + JSON_THROW(invalid_iterator::create(203, "iterators do not fit current value")); + } + + IteratorType result = end(); + + switch (m_type) + { + case value_t::boolean: + case value_t::number_float: + case value_t::number_integer: + case value_t::number_unsigned: + case value_t::string: + case value_t::binary: + { + if (JSON_HEDLEY_LIKELY(!first.m_it.primitive_iterator.is_begin() + || !last.m_it.primitive_iterator.is_end())) + { + JSON_THROW(invalid_iterator::create(204, "iterators out of range")); + } + + if (is_string()) + { + AllocatorType alloc; + std::allocator_traits::destroy(alloc, m_value.string); + std::allocator_traits::deallocate(alloc, m_value.string, 1); + m_value.string = nullptr; + } + else if (is_binary()) + { + AllocatorType alloc; + std::allocator_traits::destroy(alloc, m_value.binary); + std::allocator_traits::deallocate(alloc, m_value.binary, 1); + m_value.binary = nullptr; + } + + m_type = value_t::null; + assert_invariant(); + break; + } + + case value_t::object: + { + result.m_it.object_iterator = m_value.object->erase(first.m_it.object_iterator, + last.m_it.object_iterator); + break; + } + + case value_t::array: + { + result.m_it.array_iterator = m_value.array->erase(first.m_it.array_iterator, + last.m_it.array_iterator); + break; + } + + default: + JSON_THROW(type_error::create(307, "cannot use erase() with " + std::string(type_name()))); + } + + return result; + } + + /*! + @brief remove element from a JSON object given a key + + Removes elements from a JSON object with the key value @a key. + + @param[in] key value of the elements to remove + + @return Number of elements removed. If @a ObjectType is the default + `std::map` type, the return value will always be `0` (@a key was not + found) or `1` (@a key was found). + + @post References and iterators to the erased elements are invalidated. + Other references and iterators are not affected. + + @throw type_error.307 when called on a type other than JSON object; + example: `"cannot use erase() with null"` + + @complexity `log(size()) + count(key)` + + @liveexample{The example shows the effect of `erase()`.,erase__key_type} + + @sa @ref erase(IteratorType) -- removes the element at a given position + @sa @ref erase(IteratorType, IteratorType) -- removes the elements in + the given range + @sa @ref erase(const size_type) -- removes the element from an array at + the given index + + @since version 1.0.0 + */ + size_type erase(const typename object_t::key_type& key) + { + // this erase only works for objects + if (JSON_HEDLEY_LIKELY(is_object())) + { + return m_value.object->erase(key); + } + + JSON_THROW(type_error::create(307, "cannot use erase() with " + std::string(type_name()))); + } + + /*! + @brief remove element from a JSON array given an index + + Removes element from a JSON array at the index @a idx. + + @param[in] idx index of the element to remove + + @throw type_error.307 when called on a type other than JSON object; + example: `"cannot use erase() with null"` + @throw out_of_range.401 when `idx >= size()`; example: `"array index 17 + is out of range"` + + @complexity Linear in distance between @a idx and the end of the container. + + @liveexample{The example shows the effect of `erase()`.,erase__size_type} + + @sa @ref erase(IteratorType) -- removes the element at a given position + @sa @ref erase(IteratorType, IteratorType) -- removes the elements in + the given range + @sa @ref erase(const typename object_t::key_type&) -- removes the element + from an object at the given key + + @since version 1.0.0 + */ + void erase(const size_type idx) + { + // this erase only works for arrays + if (JSON_HEDLEY_LIKELY(is_array())) + { + if (JSON_HEDLEY_UNLIKELY(idx >= size())) + { + JSON_THROW(out_of_range::create(401, "array index " + std::to_string(idx) + " is out of range")); + } + + m_value.array->erase(m_value.array->begin() + static_cast(idx)); + } + else + { + JSON_THROW(type_error::create(307, "cannot use erase() with " + std::string(type_name()))); + } + } + + /// @} + + + //////////// + // lookup // + //////////// + + /// @name lookup + /// @{ + + /*! + @brief find an element in a JSON object + + Finds an element in a JSON object with key equivalent to @a key. If the + element is not found or the JSON value is not an object, end() is + returned. + + @note This method always returns @ref end() when executed on a JSON type + that is not an object. + + @param[in] key key value of the element to search for. + + @return Iterator to an element with key equivalent to @a key. If no such + element is found or the JSON value is not an object, past-the-end (see + @ref end()) iterator is returned. + + @complexity Logarithmic in the size of the JSON object. + + @liveexample{The example shows how `find()` is used.,find__key_type} + + @sa @ref contains(KeyT&&) const -- checks whether a key exists + + @since version 1.0.0 + */ + template + iterator find(KeyT&& key) + { + auto result = end(); + + if (is_object()) + { + result.m_it.object_iterator = m_value.object->find(std::forward(key)); + } + + return result; + } + + /*! + @brief find an element in a JSON object + @copydoc find(KeyT&&) + */ + template + const_iterator find(KeyT&& key) const + { + auto result = cend(); + + if (is_object()) + { + result.m_it.object_iterator = m_value.object->find(std::forward(key)); + } + + return result; + } + + /*! + @brief returns the number of occurrences of a key in a JSON object + + Returns the number of elements with key @a key. If ObjectType is the + default `std::map` type, the return value will always be `0` (@a key was + not found) or `1` (@a key was found). + + @note This method always returns `0` when executed on a JSON type that is + not an object. + + @param[in] key key value of the element to count + + @return Number of elements with key @a key. If the JSON value is not an + object, the return value will be `0`. + + @complexity Logarithmic in the size of the JSON object. + + @liveexample{The example shows how `count()` is used.,count} + + @since version 1.0.0 + */ + template + size_type count(KeyT&& key) const + { + // return 0 for all nonobject types + return is_object() ? m_value.object->count(std::forward(key)) : 0; + } + + /*! + @brief check the existence of an element in a JSON object + + Check whether an element exists in a JSON object with key equivalent to + @a key. If the element is not found or the JSON value is not an object, + false is returned. + + @note This method always returns false when executed on a JSON type + that is not an object. + + @param[in] key key value to check its existence. + + @return true if an element with specified @a key exists. If no such + element with such key is found or the JSON value is not an object, + false is returned. + + @complexity Logarithmic in the size of the JSON object. + + @liveexample{The following code shows an example for `contains()`.,contains} + + @sa @ref find(KeyT&&) -- returns an iterator to an object element + @sa @ref contains(const json_pointer&) const -- checks the existence for a JSON pointer + + @since version 3.6.0 + */ + template < typename KeyT, typename std::enable_if < + !std::is_same::type, json_pointer>::value, int >::type = 0 > + bool contains(KeyT && key) const + { + return is_object() && m_value.object->find(std::forward(key)) != m_value.object->end(); + } + + /*! + @brief check the existence of an element in a JSON object given a JSON pointer + + Check whether the given JSON pointer @a ptr can be resolved in the current + JSON value. + + @note This method can be executed on any JSON value type. + + @param[in] ptr JSON pointer to check its existence. + + @return true if the JSON pointer can be resolved to a stored value, false + otherwise. + + @post If `j.contains(ptr)` returns true, it is safe to call `j[ptr]`. + + @throw parse_error.106 if an array index begins with '0' + @throw parse_error.109 if an array index was not a number + + @complexity Logarithmic in the size of the JSON object. + + @liveexample{The following code shows an example for `contains()`.,contains_json_pointer} + + @sa @ref contains(KeyT &&) const -- checks the existence of a key + + @since version 3.7.0 + */ + bool contains(const json_pointer& ptr) const + { + return ptr.contains(this); + } + + /// @} + + + /////////////// + // iterators // + /////////////// + + /// @name iterators + /// @{ + + /*! + @brief returns an iterator to the first element + + Returns an iterator to the first element. + + @image html range-begin-end.svg "Illustration from cppreference.com" + + @return iterator to the first element + + @complexity Constant. + + @requirement This function helps `basic_json` satisfying the + [Container](https://en.cppreference.com/w/cpp/named_req/Container) + requirements: + - The complexity is constant. + + @liveexample{The following code shows an example for `begin()`.,begin} + + @sa @ref cbegin() -- returns a const iterator to the beginning + @sa @ref end() -- returns an iterator to the end + @sa @ref cend() -- returns a const iterator to the end + + @since version 1.0.0 + */ + iterator begin() noexcept + { + iterator result(this); + result.set_begin(); + return result; + } + + /*! + @copydoc basic_json::cbegin() + */ + const_iterator begin() const noexcept + { + return cbegin(); + } + + /*! + @brief returns a const iterator to the first element + + Returns a const iterator to the first element. + + @image html range-begin-end.svg "Illustration from cppreference.com" + + @return const iterator to the first element + + @complexity Constant. + + @requirement This function helps `basic_json` satisfying the + [Container](https://en.cppreference.com/w/cpp/named_req/Container) + requirements: + - The complexity is constant. + - Has the semantics of `const_cast(*this).begin()`. + + @liveexample{The following code shows an example for `cbegin()`.,cbegin} + + @sa @ref begin() -- returns an iterator to the beginning + @sa @ref end() -- returns an iterator to the end + @sa @ref cend() -- returns a const iterator to the end + + @since version 1.0.0 + */ + const_iterator cbegin() const noexcept + { + const_iterator result(this); + result.set_begin(); + return result; + } + + /*! + @brief returns an iterator to one past the last element + + Returns an iterator to one past the last element. + + @image html range-begin-end.svg "Illustration from cppreference.com" + + @return iterator one past the last element + + @complexity Constant. + + @requirement This function helps `basic_json` satisfying the + [Container](https://en.cppreference.com/w/cpp/named_req/Container) + requirements: + - The complexity is constant. + + @liveexample{The following code shows an example for `end()`.,end} + + @sa @ref cend() -- returns a const iterator to the end + @sa @ref begin() -- returns an iterator to the beginning + @sa @ref cbegin() -- returns a const iterator to the beginning + + @since version 1.0.0 + */ + iterator end() noexcept + { + iterator result(this); + result.set_end(); + return result; + } + + /*! + @copydoc basic_json::cend() + */ + const_iterator end() const noexcept + { + return cend(); + } + + /*! + @brief returns a const iterator to one past the last element + + Returns a const iterator to one past the last element. + + @image html range-begin-end.svg "Illustration from cppreference.com" + + @return const iterator one past the last element + + @complexity Constant. + + @requirement This function helps `basic_json` satisfying the + [Container](https://en.cppreference.com/w/cpp/named_req/Container) + requirements: + - The complexity is constant. + - Has the semantics of `const_cast(*this).end()`. + + @liveexample{The following code shows an example for `cend()`.,cend} + + @sa @ref end() -- returns an iterator to the end + @sa @ref begin() -- returns an iterator to the beginning + @sa @ref cbegin() -- returns a const iterator to the beginning + + @since version 1.0.0 + */ + const_iterator cend() const noexcept + { + const_iterator result(this); + result.set_end(); + return result; + } + + /*! + @brief returns an iterator to the reverse-beginning + + Returns an iterator to the reverse-beginning; that is, the last element. + + @image html range-rbegin-rend.svg "Illustration from cppreference.com" + + @complexity Constant. + + @requirement This function helps `basic_json` satisfying the + [ReversibleContainer](https://en.cppreference.com/w/cpp/named_req/ReversibleContainer) + requirements: + - The complexity is constant. + - Has the semantics of `reverse_iterator(end())`. + + @liveexample{The following code shows an example for `rbegin()`.,rbegin} + + @sa @ref crbegin() -- returns a const reverse iterator to the beginning + @sa @ref rend() -- returns a reverse iterator to the end + @sa @ref crend() -- returns a const reverse iterator to the end + + @since version 1.0.0 + */ + reverse_iterator rbegin() noexcept + { + return reverse_iterator(end()); + } + + /*! + @copydoc basic_json::crbegin() + */ + const_reverse_iterator rbegin() const noexcept + { + return crbegin(); + } + + /*! + @brief returns an iterator to the reverse-end + + Returns an iterator to the reverse-end; that is, one before the first + element. + + @image html range-rbegin-rend.svg "Illustration from cppreference.com" + + @complexity Constant. + + @requirement This function helps `basic_json` satisfying the + [ReversibleContainer](https://en.cppreference.com/w/cpp/named_req/ReversibleContainer) + requirements: + - The complexity is constant. + - Has the semantics of `reverse_iterator(begin())`. + + @liveexample{The following code shows an example for `rend()`.,rend} + + @sa @ref crend() -- returns a const reverse iterator to the end + @sa @ref rbegin() -- returns a reverse iterator to the beginning + @sa @ref crbegin() -- returns a const reverse iterator to the beginning + + @since version 1.0.0 + */ + reverse_iterator rend() noexcept + { + return reverse_iterator(begin()); + } + + /*! + @copydoc basic_json::crend() + */ + const_reverse_iterator rend() const noexcept + { + return crend(); + } + + /*! + @brief returns a const reverse iterator to the last element + + Returns a const iterator to the reverse-beginning; that is, the last + element. + + @image html range-rbegin-rend.svg "Illustration from cppreference.com" + + @complexity Constant. + + @requirement This function helps `basic_json` satisfying the + [ReversibleContainer](https://en.cppreference.com/w/cpp/named_req/ReversibleContainer) + requirements: + - The complexity is constant. + - Has the semantics of `const_cast(*this).rbegin()`. + + @liveexample{The following code shows an example for `crbegin()`.,crbegin} + + @sa @ref rbegin() -- returns a reverse iterator to the beginning + @sa @ref rend() -- returns a reverse iterator to the end + @sa @ref crend() -- returns a const reverse iterator to the end + + @since version 1.0.0 + */ + const_reverse_iterator crbegin() const noexcept + { + return const_reverse_iterator(cend()); + } + + /*! + @brief returns a const reverse iterator to one before the first + + Returns a const reverse iterator to the reverse-end; that is, one before + the first element. + + @image html range-rbegin-rend.svg "Illustration from cppreference.com" + + @complexity Constant. + + @requirement This function helps `basic_json` satisfying the + [ReversibleContainer](https://en.cppreference.com/w/cpp/named_req/ReversibleContainer) + requirements: + - The complexity is constant. + - Has the semantics of `const_cast(*this).rend()`. + + @liveexample{The following code shows an example for `crend()`.,crend} + + @sa @ref rend() -- returns a reverse iterator to the end + @sa @ref rbegin() -- returns a reverse iterator to the beginning + @sa @ref crbegin() -- returns a const reverse iterator to the beginning + + @since version 1.0.0 + */ + const_reverse_iterator crend() const noexcept + { + return const_reverse_iterator(cbegin()); + } + + public: + /*! + @brief wrapper to access iterator member functions in range-based for + + This function allows to access @ref iterator::key() and @ref + iterator::value() during range-based for loops. In these loops, a + reference to the JSON values is returned, so there is no access to the + underlying iterator. + + For loop without iterator_wrapper: + + @code{cpp} + for (auto it = j_object.begin(); it != j_object.end(); ++it) + { + std::cout << "key: " << it.key() << ", value:" << it.value() << '\n'; + } + @endcode + + Range-based for loop without iterator proxy: + + @code{cpp} + for (auto it : j_object) + { + // "it" is of type json::reference and has no key() member + std::cout << "value: " << it << '\n'; + } + @endcode + + Range-based for loop with iterator proxy: + + @code{cpp} + for (auto it : json::iterator_wrapper(j_object)) + { + std::cout << "key: " << it.key() << ", value:" << it.value() << '\n'; + } + @endcode + + @note When iterating over an array, `key()` will return the index of the + element as string (see example). + + @param[in] ref reference to a JSON value + @return iteration proxy object wrapping @a ref with an interface to use in + range-based for loops + + @liveexample{The following code shows how the wrapper is used,iterator_wrapper} + + @exceptionsafety Strong guarantee: if an exception is thrown, there are no + changes in the JSON value. + + @complexity Constant. + + @note The name of this function is not yet final and may change in the + future. + + @deprecated This stream operator is deprecated and will be removed in + future 4.0.0 of the library. Please use @ref items() instead; + that is, replace `json::iterator_wrapper(j)` with `j.items()`. + */ + JSON_HEDLEY_DEPRECATED_FOR(3.1.0, items()) + static iteration_proxy iterator_wrapper(reference ref) noexcept + { + return ref.items(); + } + + /*! + @copydoc iterator_wrapper(reference) + */ + JSON_HEDLEY_DEPRECATED_FOR(3.1.0, items()) + static iteration_proxy iterator_wrapper(const_reference ref) noexcept + { + return ref.items(); + } + + /*! + @brief helper to access iterator member functions in range-based for + + This function allows to access @ref iterator::key() and @ref + iterator::value() during range-based for loops. In these loops, a + reference to the JSON values is returned, so there is no access to the + underlying iterator. + + For loop without `items()` function: + + @code{cpp} + for (auto it = j_object.begin(); it != j_object.end(); ++it) + { + std::cout << "key: " << it.key() << ", value:" << it.value() << '\n'; + } + @endcode + + Range-based for loop without `items()` function: + + @code{cpp} + for (auto it : j_object) + { + // "it" is of type json::reference and has no key() member + std::cout << "value: " << it << '\n'; + } + @endcode + + Range-based for loop with `items()` function: + + @code{cpp} + for (auto& el : j_object.items()) + { + std::cout << "key: " << el.key() << ", value:" << el.value() << '\n'; + } + @endcode + + The `items()` function also allows to use + [structured bindings](https://en.cppreference.com/w/cpp/language/structured_binding) + (C++17): + + @code{cpp} + for (auto& [key, val] : j_object.items()) + { + std::cout << "key: " << key << ", value:" << val << '\n'; + } + @endcode + + @note When iterating over an array, `key()` will return the index of the + element as string (see example). For primitive types (e.g., numbers), + `key()` returns an empty string. + + @warning Using `items()` on temporary objects is dangerous. Make sure the + object's lifetime exeeds the iteration. See + for more + information. + + @return iteration proxy object wrapping @a ref with an interface to use in + range-based for loops + + @liveexample{The following code shows how the function is used.,items} + + @exceptionsafety Strong guarantee: if an exception is thrown, there are no + changes in the JSON value. + + @complexity Constant. + + @since version 3.1.0, structured bindings support since 3.5.0. + */ + iteration_proxy items() noexcept + { + return iteration_proxy(*this); + } + + /*! + @copydoc items() + */ + iteration_proxy items() const noexcept + { + return iteration_proxy(*this); + } + + /// @} + + + ////////////// + // capacity // + ////////////// + + /// @name capacity + /// @{ + + /*! + @brief checks whether the container is empty. + + Checks if a JSON value has no elements (i.e. whether its @ref size is `0`). + + @return The return value depends on the different types and is + defined as follows: + Value type | return value + ----------- | ------------- + null | `true` + boolean | `false` + string | `false` + number | `false` + binary | `false` + object | result of function `object_t::empty()` + array | result of function `array_t::empty()` + + @liveexample{The following code uses `empty()` to check if a JSON + object contains any elements.,empty} + + @complexity Constant, as long as @ref array_t and @ref object_t satisfy + the Container concept; that is, their `empty()` functions have constant + complexity. + + @iterators No changes. + + @exceptionsafety No-throw guarantee: this function never throws exceptions. + + @note This function does not return whether a string stored as JSON value + is empty - it returns whether the JSON container itself is empty which is + false in the case of a string. + + @requirement This function helps `basic_json` satisfying the + [Container](https://en.cppreference.com/w/cpp/named_req/Container) + requirements: + - The complexity is constant. + - Has the semantics of `begin() == end()`. + + @sa @ref size() -- returns the number of elements + + @since version 1.0.0 + */ + bool empty() const noexcept + { + switch (m_type) + { + case value_t::null: + { + // null values are empty + return true; + } + + case value_t::array: + { + // delegate call to array_t::empty() + return m_value.array->empty(); + } + + case value_t::object: + { + // delegate call to object_t::empty() + return m_value.object->empty(); + } + + default: + { + // all other types are nonempty + return false; + } + } + } + + /*! + @brief returns the number of elements + + Returns the number of elements in a JSON value. + + @return The return value depends on the different types and is + defined as follows: + Value type | return value + ----------- | ------------- + null | `0` + boolean | `1` + string | `1` + number | `1` + binary | `1` + object | result of function object_t::size() + array | result of function array_t::size() + + @liveexample{The following code calls `size()` on the different value + types.,size} + + @complexity Constant, as long as @ref array_t and @ref object_t satisfy + the Container concept; that is, their size() functions have constant + complexity. + + @iterators No changes. + + @exceptionsafety No-throw guarantee: this function never throws exceptions. + + @note This function does not return the length of a string stored as JSON + value - it returns the number of elements in the JSON value which is 1 in + the case of a string. + + @requirement This function helps `basic_json` satisfying the + [Container](https://en.cppreference.com/w/cpp/named_req/Container) + requirements: + - The complexity is constant. + - Has the semantics of `std::distance(begin(), end())`. + + @sa @ref empty() -- checks whether the container is empty + @sa @ref max_size() -- returns the maximal number of elements + + @since version 1.0.0 + */ + size_type size() const noexcept + { + switch (m_type) + { + case value_t::null: + { + // null values are empty + return 0; + } + + case value_t::array: + { + // delegate call to array_t::size() + return m_value.array->size(); + } + + case value_t::object: + { + // delegate call to object_t::size() + return m_value.object->size(); + } + + default: + { + // all other types have size 1 + return 1; + } + } + } + + /*! + @brief returns the maximum possible number of elements + + Returns the maximum number of elements a JSON value is able to hold due to + system or library implementation limitations, i.e. `std::distance(begin(), + end())` for the JSON value. + + @return The return value depends on the different types and is + defined as follows: + Value type | return value + ----------- | ------------- + null | `0` (same as `size()`) + boolean | `1` (same as `size()`) + string | `1` (same as `size()`) + number | `1` (same as `size()`) + binary | `1` (same as `size()`) + object | result of function `object_t::max_size()` + array | result of function `array_t::max_size()` + + @liveexample{The following code calls `max_size()` on the different value + types. Note the output is implementation specific.,max_size} + + @complexity Constant, as long as @ref array_t and @ref object_t satisfy + the Container concept; that is, their `max_size()` functions have constant + complexity. + + @iterators No changes. + + @exceptionsafety No-throw guarantee: this function never throws exceptions. + + @requirement This function helps `basic_json` satisfying the + [Container](https://en.cppreference.com/w/cpp/named_req/Container) + requirements: + - The complexity is constant. + - Has the semantics of returning `b.size()` where `b` is the largest + possible JSON value. + + @sa @ref size() -- returns the number of elements + + @since version 1.0.0 + */ + size_type max_size() const noexcept + { + switch (m_type) + { + case value_t::array: + { + // delegate call to array_t::max_size() + return m_value.array->max_size(); + } + + case value_t::object: + { + // delegate call to object_t::max_size() + return m_value.object->max_size(); + } + + default: + { + // all other types have max_size() == size() + return size(); + } + } + } + + /// @} + + + /////////////// + // modifiers // + /////////////// + + /// @name modifiers + /// @{ + + /*! + @brief clears the contents + + Clears the content of a JSON value and resets it to the default value as + if @ref basic_json(value_t) would have been called with the current value + type from @ref type(): + + Value type | initial value + ----------- | ------------- + null | `null` + boolean | `false` + string | `""` + number | `0` + binary | An empty byte vector + object | `{}` + array | `[]` + + @post Has the same effect as calling + @code {.cpp} + *this = basic_json(type()); + @endcode + + @liveexample{The example below shows the effect of `clear()` to different + JSON types.,clear} + + @complexity Linear in the size of the JSON value. + + @iterators All iterators, pointers and references related to this container + are invalidated. + + @exceptionsafety No-throw guarantee: this function never throws exceptions. + + @sa @ref basic_json(value_t) -- constructor that creates an object with the + same value than calling `clear()` + + @since version 1.0.0 + */ + void clear() noexcept + { + switch (m_type) + { + case value_t::number_integer: + { + m_value.number_integer = 0; + break; + } + + case value_t::number_unsigned: + { + m_value.number_unsigned = 0; + break; + } + + case value_t::number_float: + { + m_value.number_float = 0.0; + break; + } + + case value_t::boolean: + { + m_value.boolean = false; + break; + } + + case value_t::string: + { + m_value.string->clear(); + break; + } + + case value_t::binary: + { + m_value.binary->clear(); + break; + } + + case value_t::array: + { + m_value.array->clear(); + break; + } + + case value_t::object: + { + m_value.object->clear(); + break; + } + + default: + break; + } + } + + /*! + @brief add an object to an array + + Appends the given element @a val to the end of the JSON value. If the + function is called on a JSON null value, an empty array is created before + appending @a val. + + @param[in] val the value to add to the JSON array + + @throw type_error.308 when called on a type other than JSON array or + null; example: `"cannot use push_back() with number"` + + @complexity Amortized constant. + + @liveexample{The example shows how `push_back()` and `+=` can be used to + add elements to a JSON array. Note how the `null` value was silently + converted to a JSON array.,push_back} + + @since version 1.0.0 + */ + void push_back(basic_json&& val) + { + // push_back only works for null objects or arrays + if (JSON_HEDLEY_UNLIKELY(!(is_null() || is_array()))) + { + JSON_THROW(type_error::create(308, "cannot use push_back() with " + std::string(type_name()))); + } + + // transform null object into an array + if (is_null()) + { + m_type = value_t::array; + m_value = value_t::array; + assert_invariant(); + } + + // add element to array (move semantics) + m_value.array->push_back(std::move(val)); + // if val is moved from, basic_json move constructor marks it null so we do not call the destructor + } + + /*! + @brief add an object to an array + @copydoc push_back(basic_json&&) + */ + reference operator+=(basic_json&& val) + { + push_back(std::move(val)); + return *this; + } + + /*! + @brief add an object to an array + @copydoc push_back(basic_json&&) + */ + void push_back(const basic_json& val) + { + // push_back only works for null objects or arrays + if (JSON_HEDLEY_UNLIKELY(!(is_null() || is_array()))) + { + JSON_THROW(type_error::create(308, "cannot use push_back() with " + std::string(type_name()))); + } + + // transform null object into an array + if (is_null()) + { + m_type = value_t::array; + m_value = value_t::array; + assert_invariant(); + } + + // add element to array + m_value.array->push_back(val); + } + + /*! + @brief add an object to an array + @copydoc push_back(basic_json&&) + */ + reference operator+=(const basic_json& val) + { + push_back(val); + return *this; + } + + /*! + @brief add an object to an object + + Inserts the given element @a val to the JSON object. If the function is + called on a JSON null value, an empty object is created before inserting + @a val. + + @param[in] val the value to add to the JSON object + + @throw type_error.308 when called on a type other than JSON object or + null; example: `"cannot use push_back() with number"` + + @complexity Logarithmic in the size of the container, O(log(`size()`)). + + @liveexample{The example shows how `push_back()` and `+=` can be used to + add elements to a JSON object. Note how the `null` value was silently + converted to a JSON object.,push_back__object_t__value} + + @since version 1.0.0 + */ + void push_back(const typename object_t::value_type& val) + { + // push_back only works for null objects or objects + if (JSON_HEDLEY_UNLIKELY(!(is_null() || is_object()))) + { + JSON_THROW(type_error::create(308, "cannot use push_back() with " + std::string(type_name()))); + } + + // transform null object into an object + if (is_null()) + { + m_type = value_t::object; + m_value = value_t::object; + assert_invariant(); + } + + // add element to array + m_value.object->insert(val); + } + + /*! + @brief add an object to an object + @copydoc push_back(const typename object_t::value_type&) + */ + reference operator+=(const typename object_t::value_type& val) + { + push_back(val); + return *this; + } + + /*! + @brief add an object to an object + + This function allows to use `push_back` with an initializer list. In case + + 1. the current value is an object, + 2. the initializer list @a init contains only two elements, and + 3. the first element of @a init is a string, + + @a init is converted into an object element and added using + @ref push_back(const typename object_t::value_type&). Otherwise, @a init + is converted to a JSON value and added using @ref push_back(basic_json&&). + + @param[in] init an initializer list + + @complexity Linear in the size of the initializer list @a init. + + @note This function is required to resolve an ambiguous overload error, + because pairs like `{"key", "value"}` can be both interpreted as + `object_t::value_type` or `std::initializer_list`, see + https://github.com/nlohmann/json/issues/235 for more information. + + @liveexample{The example shows how initializer lists are treated as + objects when possible.,push_back__initializer_list} + */ + void push_back(initializer_list_t init) + { + if (is_object() && init.size() == 2 && (*init.begin())->is_string()) + { + basic_json&& key = init.begin()->moved_or_copied(); + push_back(typename object_t::value_type( + std::move(key.get_ref()), (init.begin() + 1)->moved_or_copied())); + } + else + { + push_back(basic_json(init)); + } + } + + /*! + @brief add an object to an object + @copydoc push_back(initializer_list_t) + */ + reference operator+=(initializer_list_t init) + { + push_back(init); + return *this; + } + + /*! + @brief add an object to an array + + Creates a JSON value from the passed parameters @a args to the end of the + JSON value. If the function is called on a JSON null value, an empty array + is created before appending the value created from @a args. + + @param[in] args arguments to forward to a constructor of @ref basic_json + @tparam Args compatible types to create a @ref basic_json object + + @return reference to the inserted element + + @throw type_error.311 when called on a type other than JSON array or + null; example: `"cannot use emplace_back() with number"` + + @complexity Amortized constant. + + @liveexample{The example shows how `push_back()` can be used to add + elements to a JSON array. Note how the `null` value was silently converted + to a JSON array.,emplace_back} + + @since version 2.0.8, returns reference since 3.7.0 + */ + template + reference emplace_back(Args&& ... args) + { + // emplace_back only works for null objects or arrays + if (JSON_HEDLEY_UNLIKELY(!(is_null() || is_array()))) + { + JSON_THROW(type_error::create(311, "cannot use emplace_back() with " + std::string(type_name()))); + } + + // transform null object into an array + if (is_null()) + { + m_type = value_t::array; + m_value = value_t::array; + assert_invariant(); + } + + // add element to array (perfect forwarding) +#ifdef JSON_HAS_CPP_17 + return m_value.array->emplace_back(std::forward(args)...); +#else + m_value.array->emplace_back(std::forward(args)...); + return m_value.array->back(); +#endif + } + + /*! + @brief add an object to an object if key does not exist + + Inserts a new element into a JSON object constructed in-place with the + given @a args if there is no element with the key in the container. If the + function is called on a JSON null value, an empty object is created before + appending the value created from @a args. + + @param[in] args arguments to forward to a constructor of @ref basic_json + @tparam Args compatible types to create a @ref basic_json object + + @return a pair consisting of an iterator to the inserted element, or the + already-existing element if no insertion happened, and a bool + denoting whether the insertion took place. + + @throw type_error.311 when called on a type other than JSON object or + null; example: `"cannot use emplace() with number"` + + @complexity Logarithmic in the size of the container, O(log(`size()`)). + + @liveexample{The example shows how `emplace()` can be used to add elements + to a JSON object. Note how the `null` value was silently converted to a + JSON object. Further note how no value is added if there was already one + value stored with the same key.,emplace} + + @since version 2.0.8 + */ + template + std::pair emplace(Args&& ... args) + { + // emplace only works for null objects or arrays + if (JSON_HEDLEY_UNLIKELY(!(is_null() || is_object()))) + { + JSON_THROW(type_error::create(311, "cannot use emplace() with " + std::string(type_name()))); + } + + // transform null object into an object + if (is_null()) + { + m_type = value_t::object; + m_value = value_t::object; + assert_invariant(); + } + + // add element to array (perfect forwarding) + auto res = m_value.object->emplace(std::forward(args)...); + // create result iterator and set iterator to the result of emplace + auto it = begin(); + it.m_it.object_iterator = res.first; + + // return pair of iterator and boolean + return {it, res.second}; + } + + /// Helper for insertion of an iterator + /// @note: This uses std::distance to support GCC 4.8, + /// see https://github.com/nlohmann/json/pull/1257 + template + iterator insert_iterator(const_iterator pos, Args&& ... args) + { + iterator result(this); + JSON_ASSERT(m_value.array != nullptr); + + auto insert_pos = std::distance(m_value.array->begin(), pos.m_it.array_iterator); + m_value.array->insert(pos.m_it.array_iterator, std::forward(args)...); + result.m_it.array_iterator = m_value.array->begin() + insert_pos; + + // This could have been written as: + // result.m_it.array_iterator = m_value.array->insert(pos.m_it.array_iterator, cnt, val); + // but the return value of insert is missing in GCC 4.8, so it is written this way instead. + + return result; + } + + /*! + @brief inserts element + + Inserts element @a val before iterator @a pos. + + @param[in] pos iterator before which the content will be inserted; may be + the end() iterator + @param[in] val element to insert + @return iterator pointing to the inserted @a val. + + @throw type_error.309 if called on JSON values other than arrays; + example: `"cannot use insert() with string"` + @throw invalid_iterator.202 if @a pos is not an iterator of *this; + example: `"iterator does not fit current value"` + + @complexity Constant plus linear in the distance between @a pos and end of + the container. + + @liveexample{The example shows how `insert()` is used.,insert} + + @since version 1.0.0 + */ + iterator insert(const_iterator pos, const basic_json& val) + { + // insert only works for arrays + if (JSON_HEDLEY_LIKELY(is_array())) + { + // check if iterator pos fits to this JSON value + if (JSON_HEDLEY_UNLIKELY(pos.m_object != this)) + { + JSON_THROW(invalid_iterator::create(202, "iterator does not fit current value")); + } + + // insert to array and return iterator + return insert_iterator(pos, val); + } + + JSON_THROW(type_error::create(309, "cannot use insert() with " + std::string(type_name()))); + } + + /*! + @brief inserts element + @copydoc insert(const_iterator, const basic_json&) + */ + iterator insert(const_iterator pos, basic_json&& val) + { + return insert(pos, val); + } + + /*! + @brief inserts elements + + Inserts @a cnt copies of @a val before iterator @a pos. + + @param[in] pos iterator before which the content will be inserted; may be + the end() iterator + @param[in] cnt number of copies of @a val to insert + @param[in] val element to insert + @return iterator pointing to the first element inserted, or @a pos if + `cnt==0` + + @throw type_error.309 if called on JSON values other than arrays; example: + `"cannot use insert() with string"` + @throw invalid_iterator.202 if @a pos is not an iterator of *this; + example: `"iterator does not fit current value"` + + @complexity Linear in @a cnt plus linear in the distance between @a pos + and end of the container. + + @liveexample{The example shows how `insert()` is used.,insert__count} + + @since version 1.0.0 + */ + iterator insert(const_iterator pos, size_type cnt, const basic_json& val) + { + // insert only works for arrays + if (JSON_HEDLEY_LIKELY(is_array())) + { + // check if iterator pos fits to this JSON value + if (JSON_HEDLEY_UNLIKELY(pos.m_object != this)) + { + JSON_THROW(invalid_iterator::create(202, "iterator does not fit current value")); + } + + // insert to array and return iterator + return insert_iterator(pos, cnt, val); + } + + JSON_THROW(type_error::create(309, "cannot use insert() with " + std::string(type_name()))); + } + + /*! + @brief inserts elements + + Inserts elements from range `[first, last)` before iterator @a pos. + + @param[in] pos iterator before which the content will be inserted; may be + the end() iterator + @param[in] first begin of the range of elements to insert + @param[in] last end of the range of elements to insert + + @throw type_error.309 if called on JSON values other than arrays; example: + `"cannot use insert() with string"` + @throw invalid_iterator.202 if @a pos is not an iterator of *this; + example: `"iterator does not fit current value"` + @throw invalid_iterator.210 if @a first and @a last do not belong to the + same JSON value; example: `"iterators do not fit"` + @throw invalid_iterator.211 if @a first or @a last are iterators into + container for which insert is called; example: `"passed iterators may not + belong to container"` + + @return iterator pointing to the first element inserted, or @a pos if + `first==last` + + @complexity Linear in `std::distance(first, last)` plus linear in the + distance between @a pos and end of the container. + + @liveexample{The example shows how `insert()` is used.,insert__range} + + @since version 1.0.0 + */ + iterator insert(const_iterator pos, const_iterator first, const_iterator last) + { + // insert only works for arrays + if (JSON_HEDLEY_UNLIKELY(!is_array())) + { + JSON_THROW(type_error::create(309, "cannot use insert() with " + std::string(type_name()))); + } + + // check if iterator pos fits to this JSON value + if (JSON_HEDLEY_UNLIKELY(pos.m_object != this)) + { + JSON_THROW(invalid_iterator::create(202, "iterator does not fit current value")); + } + + // check if range iterators belong to the same JSON object + if (JSON_HEDLEY_UNLIKELY(first.m_object != last.m_object)) + { + JSON_THROW(invalid_iterator::create(210, "iterators do not fit")); + } + + if (JSON_HEDLEY_UNLIKELY(first.m_object == this)) + { + JSON_THROW(invalid_iterator::create(211, "passed iterators may not belong to container")); + } + + // insert to array and return iterator + return insert_iterator(pos, first.m_it.array_iterator, last.m_it.array_iterator); + } + + /*! + @brief inserts elements + + Inserts elements from initializer list @a ilist before iterator @a pos. + + @param[in] pos iterator before which the content will be inserted; may be + the end() iterator + @param[in] ilist initializer list to insert the values from + + @throw type_error.309 if called on JSON values other than arrays; example: + `"cannot use insert() with string"` + @throw invalid_iterator.202 if @a pos is not an iterator of *this; + example: `"iterator does not fit current value"` + + @return iterator pointing to the first element inserted, or @a pos if + `ilist` is empty + + @complexity Linear in `ilist.size()` plus linear in the distance between + @a pos and end of the container. + + @liveexample{The example shows how `insert()` is used.,insert__ilist} + + @since version 1.0.0 + */ + iterator insert(const_iterator pos, initializer_list_t ilist) + { + // insert only works for arrays + if (JSON_HEDLEY_UNLIKELY(!is_array())) + { + JSON_THROW(type_error::create(309, "cannot use insert() with " + std::string(type_name()))); + } + + // check if iterator pos fits to this JSON value + if (JSON_HEDLEY_UNLIKELY(pos.m_object != this)) + { + JSON_THROW(invalid_iterator::create(202, "iterator does not fit current value")); + } + + // insert to array and return iterator + return insert_iterator(pos, ilist.begin(), ilist.end()); + } + + /*! + @brief inserts elements + + Inserts elements from range `[first, last)`. + + @param[in] first begin of the range of elements to insert + @param[in] last end of the range of elements to insert + + @throw type_error.309 if called on JSON values other than objects; example: + `"cannot use insert() with string"` + @throw invalid_iterator.202 if iterator @a first or @a last does does not + point to an object; example: `"iterators first and last must point to + objects"` + @throw invalid_iterator.210 if @a first and @a last do not belong to the + same JSON value; example: `"iterators do not fit"` + + @complexity Logarithmic: `O(N*log(size() + N))`, where `N` is the number + of elements to insert. + + @liveexample{The example shows how `insert()` is used.,insert__range_object} + + @since version 3.0.0 + */ + void insert(const_iterator first, const_iterator last) + { + // insert only works for objects + if (JSON_HEDLEY_UNLIKELY(!is_object())) + { + JSON_THROW(type_error::create(309, "cannot use insert() with " + std::string(type_name()))); + } + + // check if range iterators belong to the same JSON object + if (JSON_HEDLEY_UNLIKELY(first.m_object != last.m_object)) + { + JSON_THROW(invalid_iterator::create(210, "iterators do not fit")); + } + + // passed iterators must belong to objects + if (JSON_HEDLEY_UNLIKELY(!first.m_object->is_object())) + { + JSON_THROW(invalid_iterator::create(202, "iterators first and last must point to objects")); + } + + m_value.object->insert(first.m_it.object_iterator, last.m_it.object_iterator); + } + + /*! + @brief updates a JSON object from another object, overwriting existing keys + + Inserts all values from JSON object @a j and overwrites existing keys. + + @param[in] j JSON object to read values from + + @throw type_error.312 if called on JSON values other than objects; example: + `"cannot use update() with string"` + + @complexity O(N*log(size() + N)), where N is the number of elements to + insert. + + @liveexample{The example shows how `update()` is used.,update} + + @sa https://docs.python.org/3.6/library/stdtypes.html#dict.update + + @since version 3.0.0 + */ + void update(const_reference j) + { + // implicitly convert null value to an empty object + if (is_null()) + { + m_type = value_t::object; + m_value.object = create(); + assert_invariant(); + } + + if (JSON_HEDLEY_UNLIKELY(!is_object())) + { + JSON_THROW(type_error::create(312, "cannot use update() with " + std::string(type_name()))); + } + if (JSON_HEDLEY_UNLIKELY(!j.is_object())) + { + JSON_THROW(type_error::create(312, "cannot use update() with " + std::string(j.type_name()))); + } + + for (auto it = j.cbegin(); it != j.cend(); ++it) + { + m_value.object->operator[](it.key()) = it.value(); + } + } + + /*! + @brief updates a JSON object from another object, overwriting existing keys + + Inserts all values from from range `[first, last)` and overwrites existing + keys. + + @param[in] first begin of the range of elements to insert + @param[in] last end of the range of elements to insert + + @throw type_error.312 if called on JSON values other than objects; example: + `"cannot use update() with string"` + @throw invalid_iterator.202 if iterator @a first or @a last does does not + point to an object; example: `"iterators first and last must point to + objects"` + @throw invalid_iterator.210 if @a first and @a last do not belong to the + same JSON value; example: `"iterators do not fit"` + + @complexity O(N*log(size() + N)), where N is the number of elements to + insert. + + @liveexample{The example shows how `update()` is used__range.,update} + + @sa https://docs.python.org/3.6/library/stdtypes.html#dict.update + + @since version 3.0.0 + */ + void update(const_iterator first, const_iterator last) + { + // implicitly convert null value to an empty object + if (is_null()) + { + m_type = value_t::object; + m_value.object = create(); + assert_invariant(); + } + + if (JSON_HEDLEY_UNLIKELY(!is_object())) + { + JSON_THROW(type_error::create(312, "cannot use update() with " + std::string(type_name()))); + } + + // check if range iterators belong to the same JSON object + if (JSON_HEDLEY_UNLIKELY(first.m_object != last.m_object)) + { + JSON_THROW(invalid_iterator::create(210, "iterators do not fit")); + } + + // passed iterators must belong to objects + if (JSON_HEDLEY_UNLIKELY(!first.m_object->is_object() + || !last.m_object->is_object())) + { + JSON_THROW(invalid_iterator::create(202, "iterators first and last must point to objects")); + } + + for (auto it = first; it != last; ++it) + { + m_value.object->operator[](it.key()) = it.value(); + } + } + + /*! + @brief exchanges the values + + Exchanges the contents of the JSON value with those of @a other. Does not + invoke any move, copy, or swap operations on individual elements. All + iterators and references remain valid. The past-the-end iterator is + invalidated. + + @param[in,out] other JSON value to exchange the contents with + + @complexity Constant. + + @liveexample{The example below shows how JSON values can be swapped with + `swap()`.,swap__reference} + + @since version 1.0.0 + */ + void swap(reference other) noexcept ( + std::is_nothrow_move_constructible::value&& + std::is_nothrow_move_assignable::value&& + std::is_nothrow_move_constructible::value&& + std::is_nothrow_move_assignable::value + ) + { + std::swap(m_type, other.m_type); + std::swap(m_value, other.m_value); + assert_invariant(); + } + + /*! + @brief exchanges the values + + Exchanges the contents of the JSON value from @a left with those of @a right. Does not + invoke any move, copy, or swap operations on individual elements. All + iterators and references remain valid. The past-the-end iterator is + invalidated. implemented as a friend function callable via ADL. + + @param[in,out] left JSON value to exchange the contents with + @param[in,out] right JSON value to exchange the contents with + + @complexity Constant. + + @liveexample{The example below shows how JSON values can be swapped with + `swap()`.,swap__reference} + + @since version 1.0.0 + */ + friend void swap(reference left, reference right) noexcept ( + std::is_nothrow_move_constructible::value&& + std::is_nothrow_move_assignable::value&& + std::is_nothrow_move_constructible::value&& + std::is_nothrow_move_assignable::value + ) + { + left.swap(right); + } + + /*! + @brief exchanges the values + + Exchanges the contents of a JSON array with those of @a other. Does not + invoke any move, copy, or swap operations on individual elements. All + iterators and references remain valid. The past-the-end iterator is + invalidated. + + @param[in,out] other array to exchange the contents with + + @throw type_error.310 when JSON value is not an array; example: `"cannot + use swap() with string"` + + @complexity Constant. + + @liveexample{The example below shows how arrays can be swapped with + `swap()`.,swap__array_t} + + @since version 1.0.0 + */ + void swap(array_t& other) + { + // swap only works for arrays + if (JSON_HEDLEY_LIKELY(is_array())) + { + std::swap(*(m_value.array), other); + } + else + { + JSON_THROW(type_error::create(310, "cannot use swap() with " + std::string(type_name()))); + } + } + + /*! + @brief exchanges the values + + Exchanges the contents of a JSON object with those of @a other. Does not + invoke any move, copy, or swap operations on individual elements. All + iterators and references remain valid. The past-the-end iterator is + invalidated. + + @param[in,out] other object to exchange the contents with + + @throw type_error.310 when JSON value is not an object; example: + `"cannot use swap() with string"` + + @complexity Constant. + + @liveexample{The example below shows how objects can be swapped with + `swap()`.,swap__object_t} + + @since version 1.0.0 + */ + void swap(object_t& other) + { + // swap only works for objects + if (JSON_HEDLEY_LIKELY(is_object())) + { + std::swap(*(m_value.object), other); + } + else + { + JSON_THROW(type_error::create(310, "cannot use swap() with " + std::string(type_name()))); + } + } + + /*! + @brief exchanges the values + + Exchanges the contents of a JSON string with those of @a other. Does not + invoke any move, copy, or swap operations on individual elements. All + iterators and references remain valid. The past-the-end iterator is + invalidated. + + @param[in,out] other string to exchange the contents with + + @throw type_error.310 when JSON value is not a string; example: `"cannot + use swap() with boolean"` + + @complexity Constant. + + @liveexample{The example below shows how strings can be swapped with + `swap()`.,swap__string_t} + + @since version 1.0.0 + */ + void swap(string_t& other) + { + // swap only works for strings + if (JSON_HEDLEY_LIKELY(is_string())) + { + std::swap(*(m_value.string), other); + } + else + { + JSON_THROW(type_error::create(310, "cannot use swap() with " + std::string(type_name()))); + } + } + + /*! + @brief exchanges the values + + Exchanges the contents of a JSON string with those of @a other. Does not + invoke any move, copy, or swap operations on individual elements. All + iterators and references remain valid. The past-the-end iterator is + invalidated. + + @param[in,out] other binary to exchange the contents with + + @throw type_error.310 when JSON value is not a string; example: `"cannot + use swap() with boolean"` + + @complexity Constant. + + @liveexample{The example below shows how strings can be swapped with + `swap()`.,swap__binary_t} + + @since version 3.8.0 + */ + void swap(binary_t& other) + { + // swap only works for strings + if (JSON_HEDLEY_LIKELY(is_binary())) + { + std::swap(*(m_value.binary), other); + } + else + { + JSON_THROW(type_error::create(310, "cannot use swap() with " + std::string(type_name()))); + } + } + + /// @copydoc swap(binary_t) + void swap(typename binary_t::container_type& other) + { + // swap only works for strings + if (JSON_HEDLEY_LIKELY(is_binary())) + { + std::swap(*(m_value.binary), other); + } + else + { + JSON_THROW(type_error::create(310, "cannot use swap() with " + std::string(type_name()))); + } + } + + /// @} + + public: + ////////////////////////////////////////// + // lexicographical comparison operators // + ////////////////////////////////////////// + + /// @name lexicographical comparison operators + /// @{ + + /*! + @brief comparison: equal + + Compares two JSON values for equality according to the following rules: + - Two JSON values are equal if (1) they are from the same type and (2) + their stored values are the same according to their respective + `operator==`. + - Integer and floating-point numbers are automatically converted before + comparison. Note that two NaN values are always treated as unequal. + - Two JSON null values are equal. + + @note Floating-point inside JSON values numbers are compared with + `json::number_float_t::operator==` which is `double::operator==` by + default. To compare floating-point while respecting an epsilon, an alternative + [comparison function](https://github.com/mariokonrad/marnav/blob/master/include/marnav/math/floatingpoint.hpp#L34-#L39) + could be used, for instance + @code {.cpp} + template::value, T>::type> + inline bool is_same(T a, T b, T epsilon = std::numeric_limits::epsilon()) noexcept + { + return std::abs(a - b) <= epsilon; + } + @endcode + Or you can self-defined operator equal function like this: + @code {.cpp} + bool my_equal(const_reference lhs, const_reference rhs) { + const auto lhs_type lhs.type(); + const auto rhs_type rhs.type(); + if (lhs_type == rhs_type) { + switch(lhs_type) + // self_defined case + case value_t::number_float: + return std::abs(lhs - rhs) <= std::numeric_limits::epsilon(); + // other cases remain the same with the original + ... + } + ... + } + @endcode + + @note NaN values never compare equal to themselves or to other NaN values. + + @param[in] lhs first JSON value to consider + @param[in] rhs second JSON value to consider + @return whether the values @a lhs and @a rhs are equal + + @exceptionsafety No-throw guarantee: this function never throws exceptions. + + @complexity Linear. + + @liveexample{The example demonstrates comparing several JSON + types.,operator__equal} + + @since version 1.0.0 + */ + friend bool operator==(const_reference lhs, const_reference rhs) noexcept + { + const auto lhs_type = lhs.type(); + const auto rhs_type = rhs.type(); + + if (lhs_type == rhs_type) + { + switch (lhs_type) + { + case value_t::array: + return *lhs.m_value.array == *rhs.m_value.array; + + case value_t::object: + return *lhs.m_value.object == *rhs.m_value.object; + + case value_t::null: + return true; + + case value_t::string: + return *lhs.m_value.string == *rhs.m_value.string; + + case value_t::boolean: + return lhs.m_value.boolean == rhs.m_value.boolean; + + case value_t::number_integer: + return lhs.m_value.number_integer == rhs.m_value.number_integer; + + case value_t::number_unsigned: + return lhs.m_value.number_unsigned == rhs.m_value.number_unsigned; + + case value_t::number_float: + return lhs.m_value.number_float == rhs.m_value.number_float; + + case value_t::binary: + return *lhs.m_value.binary == *rhs.m_value.binary; + + default: + return false; + } + } + else if (lhs_type == value_t::number_integer && rhs_type == value_t::number_float) + { + return static_cast(lhs.m_value.number_integer) == rhs.m_value.number_float; + } + else if (lhs_type == value_t::number_float && rhs_type == value_t::number_integer) + { + return lhs.m_value.number_float == static_cast(rhs.m_value.number_integer); + } + else if (lhs_type == value_t::number_unsigned && rhs_type == value_t::number_float) + { + return static_cast(lhs.m_value.number_unsigned) == rhs.m_value.number_float; + } + else if (lhs_type == value_t::number_float && rhs_type == value_t::number_unsigned) + { + return lhs.m_value.number_float == static_cast(rhs.m_value.number_unsigned); + } + else if (lhs_type == value_t::number_unsigned && rhs_type == value_t::number_integer) + { + return static_cast(lhs.m_value.number_unsigned) == rhs.m_value.number_integer; + } + else if (lhs_type == value_t::number_integer && rhs_type == value_t::number_unsigned) + { + return lhs.m_value.number_integer == static_cast(rhs.m_value.number_unsigned); + } + + return false; + } + + /*! + @brief comparison: equal + @copydoc operator==(const_reference, const_reference) + */ + template::value, int>::type = 0> + friend bool operator==(const_reference lhs, const ScalarType rhs) noexcept + { + return lhs == basic_json(rhs); + } + + /*! + @brief comparison: equal + @copydoc operator==(const_reference, const_reference) + */ + template::value, int>::type = 0> + friend bool operator==(const ScalarType lhs, const_reference rhs) noexcept + { + return basic_json(lhs) == rhs; + } + + /*! + @brief comparison: not equal + + Compares two JSON values for inequality by calculating `not (lhs == rhs)`. + + @param[in] lhs first JSON value to consider + @param[in] rhs second JSON value to consider + @return whether the values @a lhs and @a rhs are not equal + + @complexity Linear. + + @exceptionsafety No-throw guarantee: this function never throws exceptions. + + @liveexample{The example demonstrates comparing several JSON + types.,operator__notequal} + + @since version 1.0.0 + */ + friend bool operator!=(const_reference lhs, const_reference rhs) noexcept + { + return !(lhs == rhs); + } + + /*! + @brief comparison: not equal + @copydoc operator!=(const_reference, const_reference) + */ + template::value, int>::type = 0> + friend bool operator!=(const_reference lhs, const ScalarType rhs) noexcept + { + return lhs != basic_json(rhs); + } + + /*! + @brief comparison: not equal + @copydoc operator!=(const_reference, const_reference) + */ + template::value, int>::type = 0> + friend bool operator!=(const ScalarType lhs, const_reference rhs) noexcept + { + return basic_json(lhs) != rhs; + } + + /*! + @brief comparison: less than + + Compares whether one JSON value @a lhs is less than another JSON value @a + rhs according to the following rules: + - If @a lhs and @a rhs have the same type, the values are compared using + the default `<` operator. + - Integer and floating-point numbers are automatically converted before + comparison + - In case @a lhs and @a rhs have different types, the values are ignored + and the order of the types is considered, see + @ref operator<(const value_t, const value_t). + + @param[in] lhs first JSON value to consider + @param[in] rhs second JSON value to consider + @return whether @a lhs is less than @a rhs + + @complexity Linear. + + @exceptionsafety No-throw guarantee: this function never throws exceptions. + + @liveexample{The example demonstrates comparing several JSON + types.,operator__less} + + @since version 1.0.0 + */ + friend bool operator<(const_reference lhs, const_reference rhs) noexcept + { + const auto lhs_type = lhs.type(); + const auto rhs_type = rhs.type(); + + if (lhs_type == rhs_type) + { + switch (lhs_type) + { + case value_t::array: + // note parentheses are necessary, see + // https://github.com/nlohmann/json/issues/1530 + return (*lhs.m_value.array) < (*rhs.m_value.array); + + case value_t::object: + return (*lhs.m_value.object) < (*rhs.m_value.object); + + case value_t::null: + return false; + + case value_t::string: + return (*lhs.m_value.string) < (*rhs.m_value.string); + + case value_t::boolean: + return (lhs.m_value.boolean) < (rhs.m_value.boolean); + + case value_t::number_integer: + return (lhs.m_value.number_integer) < (rhs.m_value.number_integer); + + case value_t::number_unsigned: + return (lhs.m_value.number_unsigned) < (rhs.m_value.number_unsigned); + + case value_t::number_float: + return (lhs.m_value.number_float) < (rhs.m_value.number_float); + + case value_t::binary: + return (*lhs.m_value.binary) < (*rhs.m_value.binary); + + default: + return false; + } + } + else if (lhs_type == value_t::number_integer && rhs_type == value_t::number_float) + { + return static_cast(lhs.m_value.number_integer) < rhs.m_value.number_float; + } + else if (lhs_type == value_t::number_float && rhs_type == value_t::number_integer) + { + return lhs.m_value.number_float < static_cast(rhs.m_value.number_integer); + } + else if (lhs_type == value_t::number_unsigned && rhs_type == value_t::number_float) + { + return static_cast(lhs.m_value.number_unsigned) < rhs.m_value.number_float; + } + else if (lhs_type == value_t::number_float && rhs_type == value_t::number_unsigned) + { + return lhs.m_value.number_float < static_cast(rhs.m_value.number_unsigned); + } + else if (lhs_type == value_t::number_integer && rhs_type == value_t::number_unsigned) + { + return lhs.m_value.number_integer < static_cast(rhs.m_value.number_unsigned); + } + else if (lhs_type == value_t::number_unsigned && rhs_type == value_t::number_integer) + { + return static_cast(lhs.m_value.number_unsigned) < rhs.m_value.number_integer; + } + + // We only reach this line if we cannot compare values. In that case, + // we compare types. Note we have to call the operator explicitly, + // because MSVC has problems otherwise. + return operator<(lhs_type, rhs_type); + } + + /*! + @brief comparison: less than + @copydoc operator<(const_reference, const_reference) + */ + template::value, int>::type = 0> + friend bool operator<(const_reference lhs, const ScalarType rhs) noexcept + { + return lhs < basic_json(rhs); + } + + /*! + @brief comparison: less than + @copydoc operator<(const_reference, const_reference) + */ + template::value, int>::type = 0> + friend bool operator<(const ScalarType lhs, const_reference rhs) noexcept + { + return basic_json(lhs) < rhs; + } + + /*! + @brief comparison: less than or equal + + Compares whether one JSON value @a lhs is less than or equal to another + JSON value by calculating `not (rhs < lhs)`. + + @param[in] lhs first JSON value to consider + @param[in] rhs second JSON value to consider + @return whether @a lhs is less than or equal to @a rhs + + @complexity Linear. + + @exceptionsafety No-throw guarantee: this function never throws exceptions. + + @liveexample{The example demonstrates comparing several JSON + types.,operator__greater} + + @since version 1.0.0 + */ + friend bool operator<=(const_reference lhs, const_reference rhs) noexcept + { + return !(rhs < lhs); + } + + /*! + @brief comparison: less than or equal + @copydoc operator<=(const_reference, const_reference) + */ + template::value, int>::type = 0> + friend bool operator<=(const_reference lhs, const ScalarType rhs) noexcept + { + return lhs <= basic_json(rhs); + } + + /*! + @brief comparison: less than or equal + @copydoc operator<=(const_reference, const_reference) + */ + template::value, int>::type = 0> + friend bool operator<=(const ScalarType lhs, const_reference rhs) noexcept + { + return basic_json(lhs) <= rhs; + } + + /*! + @brief comparison: greater than + + Compares whether one JSON value @a lhs is greater than another + JSON value by calculating `not (lhs <= rhs)`. + + @param[in] lhs first JSON value to consider + @param[in] rhs second JSON value to consider + @return whether @a lhs is greater than to @a rhs + + @complexity Linear. + + @exceptionsafety No-throw guarantee: this function never throws exceptions. + + @liveexample{The example demonstrates comparing several JSON + types.,operator__lessequal} + + @since version 1.0.0 + */ + friend bool operator>(const_reference lhs, const_reference rhs) noexcept + { + return !(lhs <= rhs); + } + + /*! + @brief comparison: greater than + @copydoc operator>(const_reference, const_reference) + */ + template::value, int>::type = 0> + friend bool operator>(const_reference lhs, const ScalarType rhs) noexcept + { + return lhs > basic_json(rhs); + } + + /*! + @brief comparison: greater than + @copydoc operator>(const_reference, const_reference) + */ + template::value, int>::type = 0> + friend bool operator>(const ScalarType lhs, const_reference rhs) noexcept + { + return basic_json(lhs) > rhs; + } + + /*! + @brief comparison: greater than or equal + + Compares whether one JSON value @a lhs is greater than or equal to another + JSON value by calculating `not (lhs < rhs)`. + + @param[in] lhs first JSON value to consider + @param[in] rhs second JSON value to consider + @return whether @a lhs is greater than or equal to @a rhs + + @complexity Linear. + + @exceptionsafety No-throw guarantee: this function never throws exceptions. + + @liveexample{The example demonstrates comparing several JSON + types.,operator__greaterequal} + + @since version 1.0.0 + */ + friend bool operator>=(const_reference lhs, const_reference rhs) noexcept + { + return !(lhs < rhs); + } + + /*! + @brief comparison: greater than or equal + @copydoc operator>=(const_reference, const_reference) + */ + template::value, int>::type = 0> + friend bool operator>=(const_reference lhs, const ScalarType rhs) noexcept + { + return lhs >= basic_json(rhs); + } + + /*! + @brief comparison: greater than or equal + @copydoc operator>=(const_reference, const_reference) + */ + template::value, int>::type = 0> + friend bool operator>=(const ScalarType lhs, const_reference rhs) noexcept + { + return basic_json(lhs) >= rhs; + } + + /// @} + + /////////////////// + // serialization // + /////////////////// + + /// @name serialization + /// @{ + + /*! + @brief serialize to stream + + Serialize the given JSON value @a j to the output stream @a o. The JSON + value will be serialized using the @ref dump member function. + + - The indentation of the output can be controlled with the member variable + `width` of the output stream @a o. For instance, using the manipulator + `std::setw(4)` on @a o sets the indentation level to `4` and the + serialization result is the same as calling `dump(4)`. + + - The indentation character can be controlled with the member variable + `fill` of the output stream @a o. For instance, the manipulator + `std::setfill('\\t')` sets indentation to use a tab character rather than + the default space character. + + @param[in,out] o stream to serialize to + @param[in] j JSON value to serialize + + @return the stream @a o + + @throw type_error.316 if a string stored inside the JSON value is not + UTF-8 encoded + + @complexity Linear. + + @liveexample{The example below shows the serialization with different + parameters to `width` to adjust the indentation level.,operator_serialize} + + @since version 1.0.0; indentation character added in version 3.0.0 + */ + friend std::ostream& operator<<(std::ostream& o, const basic_json& j) + { + // read width member and use it as indentation parameter if nonzero + const bool pretty_print = o.width() > 0; + const auto indentation = pretty_print ? o.width() : 0; + + // reset width to 0 for subsequent calls to this stream + o.width(0); + + // do the actual serialization + serializer s(detail::output_adapter(o), o.fill()); + s.dump(j, pretty_print, false, static_cast(indentation)); + return o; + } + + /*! + @brief serialize to stream + @deprecated This stream operator is deprecated and will be removed in + future 4.0.0 of the library. Please use + @ref operator<<(std::ostream&, const basic_json&) + instead; that is, replace calls like `j >> o;` with `o << j;`. + @since version 1.0.0; deprecated since version 3.0.0 + */ + JSON_HEDLEY_DEPRECATED_FOR(3.0.0, operator<<(std::ostream&, const basic_json&)) + friend std::ostream& operator>>(const basic_json& j, std::ostream& o) + { + return o << j; + } + + /// @} + + + ///////////////////// + // deserialization // + ///////////////////// + + /// @name deserialization + /// @{ + + /*! + @brief deserialize from a compatible input + + @tparam InputType A compatible input, for instance + - an std::istream object + - a FILE pointer + - a C-style array of characters + - a pointer to a null-terminated string of single byte characters + - an object obj for which begin(obj) and end(obj) produces a valid pair of + iterators. + + @param[in] i input to read from + @param[in] cb a parser callback function of type @ref parser_callback_t + which is used to control the deserialization by filtering unwanted values + (optional) + @param[in] allow_exceptions whether to throw exceptions in case of a + parse error (optional, true by default) + @param[in] ignore_comments whether comments should be ignored and treated + like whitespace (true) or yield a parse error (true); (optional, false by + default) + + @return deserialized JSON value; in case of a parse error and + @a allow_exceptions set to `false`, the return value will be + value_t::discarded. + + @throw parse_error.101 if a parse error occurs; example: `""unexpected end + of input; expected string literal""` + @throw parse_error.102 if to_unicode fails or surrogate error + @throw parse_error.103 if to_unicode fails + + @complexity Linear in the length of the input. The parser is a predictive + LL(1) parser. The complexity can be higher if the parser callback function + @a cb or reading from the input @a i has a super-linear complexity. + + @note A UTF-8 byte order mark is silently ignored. + + @liveexample{The example below demonstrates the `parse()` function reading + from an array.,parse__array__parser_callback_t} + + @liveexample{The example below demonstrates the `parse()` function with + and without callback function.,parse__string__parser_callback_t} + + @liveexample{The example below demonstrates the `parse()` function with + and without callback function.,parse__istream__parser_callback_t} + + @liveexample{The example below demonstrates the `parse()` function reading + from a contiguous container.,parse__contiguouscontainer__parser_callback_t} + + @since version 2.0.3 (contiguous containers); version 3.9.0 allowed to + ignore comments. + */ + template + JSON_HEDLEY_WARN_UNUSED_RESULT + static basic_json parse(InputType&& i, + const parser_callback_t cb = nullptr, + const bool allow_exceptions = true, + const bool ignore_comments = false) + { + basic_json result; + parser(detail::input_adapter(std::forward(i)), cb, allow_exceptions, ignore_comments).parse(true, result); + return result; + } + + /*! + @brief deserialize from a pair of character iterators + + The value_type of the iterator must be a integral type with size of 1, 2 or + 4 bytes, which will be interpreted respectively as UTF-8, UTF-16 and UTF-32. + + @param[in] first iterator to start of character range + @param[in] last iterator to end of character range + @param[in] cb a parser callback function of type @ref parser_callback_t + which is used to control the deserialization by filtering unwanted values + (optional) + @param[in] allow_exceptions whether to throw exceptions in case of a + parse error (optional, true by default) + @param[in] ignore_comments whether comments should be ignored and treated + like whitespace (true) or yield a parse error (true); (optional, false by + default) + + @return deserialized JSON value; in case of a parse error and + @a allow_exceptions set to `false`, the return value will be + value_t::discarded. + + @throw parse_error.101 if a parse error occurs; example: `""unexpected end + of input; expected string literal""` + @throw parse_error.102 if to_unicode fails or surrogate error + @throw parse_error.103 if to_unicode fails + */ + template + JSON_HEDLEY_WARN_UNUSED_RESULT + static basic_json parse(IteratorType first, + IteratorType last, + const parser_callback_t cb = nullptr, + const bool allow_exceptions = true, + const bool ignore_comments = false) + { + basic_json result; + parser(detail::input_adapter(std::move(first), std::move(last)), cb, allow_exceptions, ignore_comments).parse(true, result); + return result; + } + + JSON_HEDLEY_WARN_UNUSED_RESULT + JSON_HEDLEY_DEPRECATED_FOR(3.8.0, parse(ptr, ptr + len)) + static basic_json parse(detail::span_input_adapter&& i, + const parser_callback_t cb = nullptr, + const bool allow_exceptions = true, + const bool ignore_comments = false) + { + basic_json result; + parser(i.get(), cb, allow_exceptions, ignore_comments).parse(true, result); + return result; + } + + /*! + @brief check if the input is valid JSON + + Unlike the @ref parse(InputType&&, const parser_callback_t,const bool) + function, this function neither throws an exception in case of invalid JSON + input (i.e., a parse error) nor creates diagnostic information. + + @tparam InputType A compatible input, for instance + - an std::istream object + - a FILE pointer + - a C-style array of characters + - a pointer to a null-terminated string of single byte characters + - an object obj for which begin(obj) and end(obj) produces a valid pair of + iterators. + + @param[in] i input to read from + @param[in] ignore_comments whether comments should be ignored and treated + like whitespace (true) or yield a parse error (true); (optional, false by + default) + + @return Whether the input read from @a i is valid JSON. + + @complexity Linear in the length of the input. The parser is a predictive + LL(1) parser. + + @note A UTF-8 byte order mark is silently ignored. + + @liveexample{The example below demonstrates the `accept()` function reading + from a string.,accept__string} + */ + template + static bool accept(InputType&& i, + const bool ignore_comments = false) + { + return parser(detail::input_adapter(std::forward(i)), nullptr, false, ignore_comments).accept(true); + } + + template + static bool accept(IteratorType first, IteratorType last, + const bool ignore_comments = false) + { + return parser(detail::input_adapter(std::move(first), std::move(last)), nullptr, false, ignore_comments).accept(true); + } + + JSON_HEDLEY_WARN_UNUSED_RESULT + JSON_HEDLEY_DEPRECATED_FOR(3.8.0, accept(ptr, ptr + len)) + static bool accept(detail::span_input_adapter&& i, + const bool ignore_comments = false) + { + return parser(i.get(), nullptr, false, ignore_comments).accept(true); + } + + /*! + @brief generate SAX events + + The SAX event lister must follow the interface of @ref json_sax. + + This function reads from a compatible input. Examples are: + - an std::istream object + - a FILE pointer + - a C-style array of characters + - a pointer to a null-terminated string of single byte characters + - an object obj for which begin(obj) and end(obj) produces a valid pair of + iterators. + + @param[in] i input to read from + @param[in,out] sax SAX event listener + @param[in] format the format to parse (JSON, CBOR, MessagePack, or UBJSON) + @param[in] strict whether the input has to be consumed completely + @param[in] ignore_comments whether comments should be ignored and treated + like whitespace (true) or yield a parse error (true); (optional, false by + default); only applies to the JSON file format. + + @return return value of the last processed SAX event + + @throw parse_error.101 if a parse error occurs; example: `""unexpected end + of input; expected string literal""` + @throw parse_error.102 if to_unicode fails or surrogate error + @throw parse_error.103 if to_unicode fails + + @complexity Linear in the length of the input. The parser is a predictive + LL(1) parser. The complexity can be higher if the SAX consumer @a sax has + a super-linear complexity. + + @note A UTF-8 byte order mark is silently ignored. + + @liveexample{The example below demonstrates the `sax_parse()` function + reading from string and processing the events with a user-defined SAX + event consumer.,sax_parse} + + @since version 3.2.0 + */ + template + JSON_HEDLEY_NON_NULL(2) + static bool sax_parse(InputType&& i, SAX* sax, + input_format_t format = input_format_t::json, + const bool strict = true, + const bool ignore_comments = false) + { + auto ia = detail::input_adapter(std::forward(i)); + return format == input_format_t::json + ? parser(std::move(ia), nullptr, true, ignore_comments).sax_parse(sax, strict) + : detail::binary_reader(std::move(ia)).sax_parse(format, sax, strict); + } + + template + JSON_HEDLEY_NON_NULL(3) + static bool sax_parse(IteratorType first, IteratorType last, SAX* sax, + input_format_t format = input_format_t::json, + const bool strict = true, + const bool ignore_comments = false) + { + auto ia = detail::input_adapter(std::move(first), std::move(last)); + return format == input_format_t::json + ? parser(std::move(ia), nullptr, true, ignore_comments).sax_parse(sax, strict) + : detail::binary_reader(std::move(ia)).sax_parse(format, sax, strict); + } + + template + JSON_HEDLEY_DEPRECATED_FOR(3.8.0, sax_parse(ptr, ptr + len, ...)) + JSON_HEDLEY_NON_NULL(2) + static bool sax_parse(detail::span_input_adapter&& i, SAX* sax, + input_format_t format = input_format_t::json, + const bool strict = true, + const bool ignore_comments = false) + { + auto ia = i.get(); + return format == input_format_t::json + ? parser(std::move(ia), nullptr, true, ignore_comments).sax_parse(sax, strict) + : detail::binary_reader(std::move(ia)).sax_parse(format, sax, strict); + } + + /*! + @brief deserialize from stream + @deprecated This stream operator is deprecated and will be removed in + version 4.0.0 of the library. Please use + @ref operator>>(std::istream&, basic_json&) + instead; that is, replace calls like `j << i;` with `i >> j;`. + @since version 1.0.0; deprecated since version 3.0.0 + */ + JSON_HEDLEY_DEPRECATED_FOR(3.0.0, operator>>(std::istream&, basic_json&)) + friend std::istream& operator<<(basic_json& j, std::istream& i) + { + return operator>>(i, j); + } + + /*! + @brief deserialize from stream + + Deserializes an input stream to a JSON value. + + @param[in,out] i input stream to read a serialized JSON value from + @param[in,out] j JSON value to write the deserialized input to + + @throw parse_error.101 in case of an unexpected token + @throw parse_error.102 if to_unicode fails or surrogate error + @throw parse_error.103 if to_unicode fails + + @complexity Linear in the length of the input. The parser is a predictive + LL(1) parser. + + @note A UTF-8 byte order mark is silently ignored. + + @liveexample{The example below shows how a JSON value is constructed by + reading a serialization from a stream.,operator_deserialize} + + @sa parse(std::istream&, const parser_callback_t) for a variant with a + parser callback function to filter values while parsing + + @since version 1.0.0 + */ + friend std::istream& operator>>(std::istream& i, basic_json& j) + { + parser(detail::input_adapter(i)).parse(false, j); + return i; + } + + /// @} + + /////////////////////////// + // convenience functions // + /////////////////////////// + + /*! + @brief return the type as string + + Returns the type name as string to be used in error messages - usually to + indicate that a function was called on a wrong JSON type. + + @return a string representation of a the @a m_type member: + Value type | return value + ----------- | ------------- + null | `"null"` + boolean | `"boolean"` + string | `"string"` + number | `"number"` (for all number types) + object | `"object"` + array | `"array"` + binary | `"binary"` + discarded | `"discarded"` + + @exceptionsafety No-throw guarantee: this function never throws exceptions. + + @complexity Constant. + + @liveexample{The following code exemplifies `type_name()` for all JSON + types.,type_name} + + @sa @ref type() -- return the type of the JSON value + @sa @ref operator value_t() -- return the type of the JSON value (implicit) + + @since version 1.0.0, public since 2.1.0, `const char*` and `noexcept` + since 3.0.0 + */ + JSON_HEDLEY_RETURNS_NON_NULL + const char* type_name() const noexcept + { + { + switch (m_type) + { + case value_t::null: + return "null"; + case value_t::object: + return "object"; + case value_t::array: + return "array"; + case value_t::string: + return "string"; + case value_t::boolean: + return "boolean"; + case value_t::binary: + return "binary"; + case value_t::discarded: + return "discarded"; + default: + return "number"; + } + } + } + + + private: + ////////////////////// + // member variables // + ////////////////////// + + /// the type of the current element + value_t m_type = value_t::null; + + /// the value of the current element + json_value m_value = {}; + + ////////////////////////////////////////// + // binary serialization/deserialization // + ////////////////////////////////////////// + + /// @name binary serialization/deserialization support + /// @{ + + public: + /*! + @brief create a CBOR serialization of a given JSON value + + Serializes a given JSON value @a j to a byte vector using the CBOR (Concise + Binary Object Representation) serialization format. CBOR is a binary + serialization format which aims to be more compact than JSON itself, yet + more efficient to parse. + + The library uses the following mapping from JSON values types to + CBOR types according to the CBOR specification (RFC 7049): + + JSON value type | value/range | CBOR type | first byte + --------------- | ------------------------------------------ | ---------------------------------- | --------------- + null | `null` | Null | 0xF6 + boolean | `true` | True | 0xF5 + boolean | `false` | False | 0xF4 + number_integer | -9223372036854775808..-2147483649 | Negative integer (8 bytes follow) | 0x3B + number_integer | -2147483648..-32769 | Negative integer (4 bytes follow) | 0x3A + number_integer | -32768..-129 | Negative integer (2 bytes follow) | 0x39 + number_integer | -128..-25 | Negative integer (1 byte follow) | 0x38 + number_integer | -24..-1 | Negative integer | 0x20..0x37 + number_integer | 0..23 | Integer | 0x00..0x17 + number_integer | 24..255 | Unsigned integer (1 byte follow) | 0x18 + number_integer | 256..65535 | Unsigned integer (2 bytes follow) | 0x19 + number_integer | 65536..4294967295 | Unsigned integer (4 bytes follow) | 0x1A + number_integer | 4294967296..18446744073709551615 | Unsigned integer (8 bytes follow) | 0x1B + number_unsigned | 0..23 | Integer | 0x00..0x17 + number_unsigned | 24..255 | Unsigned integer (1 byte follow) | 0x18 + number_unsigned | 256..65535 | Unsigned integer (2 bytes follow) | 0x19 + number_unsigned | 65536..4294967295 | Unsigned integer (4 bytes follow) | 0x1A + number_unsigned | 4294967296..18446744073709551615 | Unsigned integer (8 bytes follow) | 0x1B + number_float | *any value representable by a float* | Single-Precision Float | 0xFA + number_float | *any value NOT representable by a float* | Double-Precision Float | 0xFB + string | *length*: 0..23 | UTF-8 string | 0x60..0x77 + string | *length*: 23..255 | UTF-8 string (1 byte follow) | 0x78 + string | *length*: 256..65535 | UTF-8 string (2 bytes follow) | 0x79 + string | *length*: 65536..4294967295 | UTF-8 string (4 bytes follow) | 0x7A + string | *length*: 4294967296..18446744073709551615 | UTF-8 string (8 bytes follow) | 0x7B + array | *size*: 0..23 | array | 0x80..0x97 + array | *size*: 23..255 | array (1 byte follow) | 0x98 + array | *size*: 256..65535 | array (2 bytes follow) | 0x99 + array | *size*: 65536..4294967295 | array (4 bytes follow) | 0x9A + array | *size*: 4294967296..18446744073709551615 | array (8 bytes follow) | 0x9B + object | *size*: 0..23 | map | 0xA0..0xB7 + object | *size*: 23..255 | map (1 byte follow) | 0xB8 + object | *size*: 256..65535 | map (2 bytes follow) | 0xB9 + object | *size*: 65536..4294967295 | map (4 bytes follow) | 0xBA + object | *size*: 4294967296..18446744073709551615 | map (8 bytes follow) | 0xBB + binary | *size*: 0..23 | byte string | 0x40..0x57 + binary | *size*: 23..255 | byte string (1 byte follow) | 0x58 + binary | *size*: 256..65535 | byte string (2 bytes follow) | 0x59 + binary | *size*: 65536..4294967295 | byte string (4 bytes follow) | 0x5A + binary | *size*: 4294967296..18446744073709551615 | byte string (8 bytes follow) | 0x5B + + @note The mapping is **complete** in the sense that any JSON value type + can be converted to a CBOR value. + + @note If NaN or Infinity are stored inside a JSON number, they are + serialized properly. This behavior differs from the @ref dump() + function which serializes NaN or Infinity to `null`. + + @note The following CBOR types are not used in the conversion: + - UTF-8 strings terminated by "break" (0x7F) + - arrays terminated by "break" (0x9F) + - maps terminated by "break" (0xBF) + - byte strings terminated by "break" (0x5F) + - date/time (0xC0..0xC1) + - bignum (0xC2..0xC3) + - decimal fraction (0xC4) + - bigfloat (0xC5) + - expected conversions (0xD5..0xD7) + - simple values (0xE0..0xF3, 0xF8) + - undefined (0xF7) + - half-precision floats (0xF9) + - break (0xFF) + + @param[in] j JSON value to serialize + @return CBOR serialization as byte vector + + @complexity Linear in the size of the JSON value @a j. + + @liveexample{The example shows the serialization of a JSON value to a byte + vector in CBOR format.,to_cbor} + + @sa http://cbor.io + @sa @ref from_cbor(detail::input_adapter&&, const bool, const bool, const cbor_tag_handler_t) for the + analogous deserialization + @sa @ref to_msgpack(const basic_json&) for the related MessagePack format + @sa @ref to_ubjson(const basic_json&, const bool, const bool) for the + related UBJSON format + + @since version 2.0.9; compact representation of floating-point numbers + since version 3.8.0 + */ + static std::vector to_cbor(const basic_json& j) + { + std::vector result; + to_cbor(j, result); + return result; + } + + static void to_cbor(const basic_json& j, detail::output_adapter o) + { + binary_writer(o).write_cbor(j); + } + + static void to_cbor(const basic_json& j, detail::output_adapter o) + { + binary_writer(o).write_cbor(j); + } + + /*! + @brief create a MessagePack serialization of a given JSON value + + Serializes a given JSON value @a j to a byte vector using the MessagePack + serialization format. MessagePack is a binary serialization format which + aims to be more compact than JSON itself, yet more efficient to parse. + + The library uses the following mapping from JSON values types to + MessagePack types according to the MessagePack specification: + + JSON value type | value/range | MessagePack type | first byte + --------------- | --------------------------------- | ---------------- | ---------- + null | `null` | nil | 0xC0 + boolean | `true` | true | 0xC3 + boolean | `false` | false | 0xC2 + number_integer | -9223372036854775808..-2147483649 | int64 | 0xD3 + number_integer | -2147483648..-32769 | int32 | 0xD2 + number_integer | -32768..-129 | int16 | 0xD1 + number_integer | -128..-33 | int8 | 0xD0 + number_integer | -32..-1 | negative fixint | 0xE0..0xFF + number_integer | 0..127 | positive fixint | 0x00..0x7F + number_integer | 128..255 | uint 8 | 0xCC + number_integer | 256..65535 | uint 16 | 0xCD + number_integer | 65536..4294967295 | uint 32 | 0xCE + number_integer | 4294967296..18446744073709551615 | uint 64 | 0xCF + number_unsigned | 0..127 | positive fixint | 0x00..0x7F + number_unsigned | 128..255 | uint 8 | 0xCC + number_unsigned | 256..65535 | uint 16 | 0xCD + number_unsigned | 65536..4294967295 | uint 32 | 0xCE + number_unsigned | 4294967296..18446744073709551615 | uint 64 | 0xCF + number_float | *any value representable by a float* | float 32 | 0xCA + number_float | *any value NOT representable by a float* | float 64 | 0xCB + string | *length*: 0..31 | fixstr | 0xA0..0xBF + string | *length*: 32..255 | str 8 | 0xD9 + string | *length*: 256..65535 | str 16 | 0xDA + string | *length*: 65536..4294967295 | str 32 | 0xDB + array | *size*: 0..15 | fixarray | 0x90..0x9F + array | *size*: 16..65535 | array 16 | 0xDC + array | *size*: 65536..4294967295 | array 32 | 0xDD + object | *size*: 0..15 | fix map | 0x80..0x8F + object | *size*: 16..65535 | map 16 | 0xDE + object | *size*: 65536..4294967295 | map 32 | 0xDF + binary | *size*: 0..255 | bin 8 | 0xC4 + binary | *size*: 256..65535 | bin 16 | 0xC5 + binary | *size*: 65536..4294967295 | bin 32 | 0xC6 + + @note The mapping is **complete** in the sense that any JSON value type + can be converted to a MessagePack value. + + @note The following values can **not** be converted to a MessagePack value: + - strings with more than 4294967295 bytes + - byte strings with more than 4294967295 bytes + - arrays with more than 4294967295 elements + - objects with more than 4294967295 elements + + @note Any MessagePack output created @ref to_msgpack can be successfully + parsed by @ref from_msgpack. + + @note If NaN or Infinity are stored inside a JSON number, they are + serialized properly. This behavior differs from the @ref dump() + function which serializes NaN or Infinity to `null`. + + @param[in] j JSON value to serialize + @return MessagePack serialization as byte vector + + @complexity Linear in the size of the JSON value @a j. + + @liveexample{The example shows the serialization of a JSON value to a byte + vector in MessagePack format.,to_msgpack} + + @sa http://msgpack.org + @sa @ref from_msgpack for the analogous deserialization + @sa @ref to_cbor(const basic_json& for the related CBOR format + @sa @ref to_ubjson(const basic_json&, const bool, const bool) for the + related UBJSON format + + @since version 2.0.9 + */ + static std::vector to_msgpack(const basic_json& j) + { + std::vector result; + to_msgpack(j, result); + return result; + } + + static void to_msgpack(const basic_json& j, detail::output_adapter o) + { + binary_writer(o).write_msgpack(j); + } + + static void to_msgpack(const basic_json& j, detail::output_adapter o) + { + binary_writer(o).write_msgpack(j); + } + + /*! + @brief create a UBJSON serialization of a given JSON value + + Serializes a given JSON value @a j to a byte vector using the UBJSON + (Universal Binary JSON) serialization format. UBJSON aims to be more compact + than JSON itself, yet more efficient to parse. + + The library uses the following mapping from JSON values types to + UBJSON types according to the UBJSON specification: + + JSON value type | value/range | UBJSON type | marker + --------------- | --------------------------------- | ----------- | ------ + null | `null` | null | `Z` + boolean | `true` | true | `T` + boolean | `false` | false | `F` + number_integer | -9223372036854775808..-2147483649 | int64 | `L` + number_integer | -2147483648..-32769 | int32 | `l` + number_integer | -32768..-129 | int16 | `I` + number_integer | -128..127 | int8 | `i` + number_integer | 128..255 | uint8 | `U` + number_integer | 256..32767 | int16 | `I` + number_integer | 32768..2147483647 | int32 | `l` + number_integer | 2147483648..9223372036854775807 | int64 | `L` + number_unsigned | 0..127 | int8 | `i` + number_unsigned | 128..255 | uint8 | `U` + number_unsigned | 256..32767 | int16 | `I` + number_unsigned | 32768..2147483647 | int32 | `l` + number_unsigned | 2147483648..9223372036854775807 | int64 | `L` + number_unsigned | 2147483649..18446744073709551615 | high-precision | `H` + number_float | *any value* | float64 | `D` + string | *with shortest length indicator* | string | `S` + array | *see notes on optimized format* | array | `[` + object | *see notes on optimized format* | map | `{` + + @note The mapping is **complete** in the sense that any JSON value type + can be converted to a UBJSON value. + + @note The following values can **not** be converted to a UBJSON value: + - strings with more than 9223372036854775807 bytes (theoretical) + + @note The following markers are not used in the conversion: + - `Z`: no-op values are not created. + - `C`: single-byte strings are serialized with `S` markers. + + @note Any UBJSON output created @ref to_ubjson can be successfully parsed + by @ref from_ubjson. + + @note If NaN or Infinity are stored inside a JSON number, they are + serialized properly. This behavior differs from the @ref dump() + function which serializes NaN or Infinity to `null`. + + @note The optimized formats for containers are supported: Parameter + @a use_size adds size information to the beginning of a container and + removes the closing marker. Parameter @a use_type further checks + whether all elements of a container have the same type and adds the + type marker to the beginning of the container. The @a use_type + parameter must only be used together with @a use_size = true. Note + that @a use_size = true alone may result in larger representations - + the benefit of this parameter is that the receiving side is + immediately informed on the number of elements of the container. + + @note If the JSON data contains the binary type, the value stored is a list + of integers, as suggested by the UBJSON documentation. In particular, + this means that serialization and the deserialization of a JSON + containing binary values into UBJSON and back will result in a + different JSON object. + + @param[in] j JSON value to serialize + @param[in] use_size whether to add size annotations to container types + @param[in] use_type whether to add type annotations to container types + (must be combined with @a use_size = true) + @return UBJSON serialization as byte vector + + @complexity Linear in the size of the JSON value @a j. + + @liveexample{The example shows the serialization of a JSON value to a byte + vector in UBJSON format.,to_ubjson} + + @sa http://ubjson.org + @sa @ref from_ubjson(detail::input_adapter&&, const bool, const bool) for the + analogous deserialization + @sa @ref to_cbor(const basic_json& for the related CBOR format + @sa @ref to_msgpack(const basic_json&) for the related MessagePack format + + @since version 3.1.0 + */ + static std::vector to_ubjson(const basic_json& j, + const bool use_size = false, + const bool use_type = false) + { + std::vector result; + to_ubjson(j, result, use_size, use_type); + return result; + } + + static void to_ubjson(const basic_json& j, detail::output_adapter o, + const bool use_size = false, const bool use_type = false) + { + binary_writer(o).write_ubjson(j, use_size, use_type); + } + + static void to_ubjson(const basic_json& j, detail::output_adapter o, + const bool use_size = false, const bool use_type = false) + { + binary_writer(o).write_ubjson(j, use_size, use_type); + } + + + /*! + @brief Serializes the given JSON object `j` to BSON and returns a vector + containing the corresponding BSON-representation. + + BSON (Binary JSON) is a binary format in which zero or more ordered key/value pairs are + stored as a single entity (a so-called document). + + The library uses the following mapping from JSON values types to BSON types: + + JSON value type | value/range | BSON type | marker + --------------- | --------------------------------- | ----------- | ------ + null | `null` | null | 0x0A + boolean | `true`, `false` | boolean | 0x08 + number_integer | -9223372036854775808..-2147483649 | int64 | 0x12 + number_integer | -2147483648..2147483647 | int32 | 0x10 + number_integer | 2147483648..9223372036854775807 | int64 | 0x12 + number_unsigned | 0..2147483647 | int32 | 0x10 + number_unsigned | 2147483648..9223372036854775807 | int64 | 0x12 + number_unsigned | 9223372036854775808..18446744073709551615| -- | -- + number_float | *any value* | double | 0x01 + string | *any value* | string | 0x02 + array | *any value* | document | 0x04 + object | *any value* | document | 0x03 + binary | *any value* | binary | 0x05 + + @warning The mapping is **incomplete**, since only JSON-objects (and things + contained therein) can be serialized to BSON. + Also, integers larger than 9223372036854775807 cannot be serialized to BSON, + and the keys may not contain U+0000, since they are serialized a + zero-terminated c-strings. + + @throw out_of_range.407 if `j.is_number_unsigned() && j.get() > 9223372036854775807` + @throw out_of_range.409 if a key in `j` contains a NULL (U+0000) + @throw type_error.317 if `!j.is_object()` + + @pre The input `j` is required to be an object: `j.is_object() == true`. + + @note Any BSON output created via @ref to_bson can be successfully parsed + by @ref from_bson. + + @param[in] j JSON value to serialize + @return BSON serialization as byte vector + + @complexity Linear in the size of the JSON value @a j. + + @liveexample{The example shows the serialization of a JSON value to a byte + vector in BSON format.,to_bson} + + @sa http://bsonspec.org/spec.html + @sa @ref from_bson(detail::input_adapter&&, const bool strict) for the + analogous deserialization + @sa @ref to_ubjson(const basic_json&, const bool, const bool) for the + related UBJSON format + @sa @ref to_cbor(const basic_json&) for the related CBOR format + @sa @ref to_msgpack(const basic_json&) for the related MessagePack format + */ + static std::vector to_bson(const basic_json& j) + { + std::vector result; + to_bson(j, result); + return result; + } + + /*! + @brief Serializes the given JSON object `j` to BSON and forwards the + corresponding BSON-representation to the given output_adapter `o`. + @param j The JSON object to convert to BSON. + @param o The output adapter that receives the binary BSON representation. + @pre The input `j` shall be an object: `j.is_object() == true` + @sa @ref to_bson(const basic_json&) + */ + static void to_bson(const basic_json& j, detail::output_adapter o) + { + binary_writer(o).write_bson(j); + } + + /*! + @copydoc to_bson(const basic_json&, detail::output_adapter) + */ + static void to_bson(const basic_json& j, detail::output_adapter o) + { + binary_writer(o).write_bson(j); + } + + + /*! + @brief create a JSON value from an input in CBOR format + + Deserializes a given input @a i to a JSON value using the CBOR (Concise + Binary Object Representation) serialization format. + + The library maps CBOR types to JSON value types as follows: + + CBOR type | JSON value type | first byte + ---------------------- | --------------- | ---------- + Integer | number_unsigned | 0x00..0x17 + Unsigned integer | number_unsigned | 0x18 + Unsigned integer | number_unsigned | 0x19 + Unsigned integer | number_unsigned | 0x1A + Unsigned integer | number_unsigned | 0x1B + Negative integer | number_integer | 0x20..0x37 + Negative integer | number_integer | 0x38 + Negative integer | number_integer | 0x39 + Negative integer | number_integer | 0x3A + Negative integer | number_integer | 0x3B + Byte string | binary | 0x40..0x57 + Byte string | binary | 0x58 + Byte string | binary | 0x59 + Byte string | binary | 0x5A + Byte string | binary | 0x5B + UTF-8 string | string | 0x60..0x77 + UTF-8 string | string | 0x78 + UTF-8 string | string | 0x79 + UTF-8 string | string | 0x7A + UTF-8 string | string | 0x7B + UTF-8 string | string | 0x7F + array | array | 0x80..0x97 + array | array | 0x98 + array | array | 0x99 + array | array | 0x9A + array | array | 0x9B + array | array | 0x9F + map | object | 0xA0..0xB7 + map | object | 0xB8 + map | object | 0xB9 + map | object | 0xBA + map | object | 0xBB + map | object | 0xBF + False | `false` | 0xF4 + True | `true` | 0xF5 + Null | `null` | 0xF6 + Half-Precision Float | number_float | 0xF9 + Single-Precision Float | number_float | 0xFA + Double-Precision Float | number_float | 0xFB + + @warning The mapping is **incomplete** in the sense that not all CBOR + types can be converted to a JSON value. The following CBOR types + are not supported and will yield parse errors (parse_error.112): + - date/time (0xC0..0xC1) + - bignum (0xC2..0xC3) + - decimal fraction (0xC4) + - bigfloat (0xC5) + - expected conversions (0xD5..0xD7) + - simple values (0xE0..0xF3, 0xF8) + - undefined (0xF7) + + @warning CBOR allows map keys of any type, whereas JSON only allows + strings as keys in object values. Therefore, CBOR maps with keys + other than UTF-8 strings are rejected (parse_error.113). + + @note Any CBOR output created @ref to_cbor can be successfully parsed by + @ref from_cbor. + + @param[in] i an input in CBOR format convertible to an input adapter + @param[in] strict whether to expect the input to be consumed until EOF + (true by default) + @param[in] allow_exceptions whether to throw exceptions in case of a + parse error (optional, true by default) + @param[in] tag_handler how to treat CBOR tags (optional, error by default) + + @return deserialized JSON value; in case of a parse error and + @a allow_exceptions set to `false`, the return value will be + value_t::discarded. + + @throw parse_error.110 if the given input ends prematurely or the end of + file was not reached when @a strict was set to true + @throw parse_error.112 if unsupported features from CBOR were + used in the given input @a v or if the input is not valid CBOR + @throw parse_error.113 if a string was expected as map key, but not found + + @complexity Linear in the size of the input @a i. + + @liveexample{The example shows the deserialization of a byte vector in CBOR + format to a JSON value.,from_cbor} + + @sa http://cbor.io + @sa @ref to_cbor(const basic_json&) for the analogous serialization + @sa @ref from_msgpack(detail::input_adapter&&, const bool, const bool) for the + related MessagePack format + @sa @ref from_ubjson(detail::input_adapter&&, const bool, const bool) for the + related UBJSON format + + @since version 2.0.9; parameter @a start_index since 2.1.1; changed to + consume input adapters, removed start_index parameter, and added + @a strict parameter since 3.0.0; added @a allow_exceptions parameter + since 3.2.0; added @a tag_handler parameter since 3.9.0. + */ + template + JSON_HEDLEY_WARN_UNUSED_RESULT + static basic_json from_cbor(InputType&& i, + const bool strict = true, + const bool allow_exceptions = true, + const cbor_tag_handler_t tag_handler = cbor_tag_handler_t::error) + { + basic_json result; + detail::json_sax_dom_parser sdp(result, allow_exceptions); + auto ia = detail::input_adapter(std::forward(i)); + const bool res = binary_reader(std::move(ia)).sax_parse(input_format_t::cbor, &sdp, strict, tag_handler); + return res ? result : basic_json(value_t::discarded); + } + + /*! + @copydoc from_cbor(detail::input_adapter&&, const bool, const bool, const cbor_tag_handler_t) + */ + template + JSON_HEDLEY_WARN_UNUSED_RESULT + static basic_json from_cbor(IteratorType first, IteratorType last, + const bool strict = true, + const bool allow_exceptions = true, + const cbor_tag_handler_t tag_handler = cbor_tag_handler_t::error) + { + basic_json result; + detail::json_sax_dom_parser sdp(result, allow_exceptions); + auto ia = detail::input_adapter(std::move(first), std::move(last)); + const bool res = binary_reader(std::move(ia)).sax_parse(input_format_t::cbor, &sdp, strict, tag_handler); + return res ? result : basic_json(value_t::discarded); + } + + template + JSON_HEDLEY_WARN_UNUSED_RESULT + JSON_HEDLEY_DEPRECATED_FOR(3.8.0, from_cbor(ptr, ptr + len)) + static basic_json from_cbor(const T* ptr, std::size_t len, + const bool strict = true, + const bool allow_exceptions = true, + const cbor_tag_handler_t tag_handler = cbor_tag_handler_t::error) + { + return from_cbor(ptr, ptr + len, strict, allow_exceptions, tag_handler); + } + + + JSON_HEDLEY_WARN_UNUSED_RESULT + JSON_HEDLEY_DEPRECATED_FOR(3.8.0, from_cbor(ptr, ptr + len)) + static basic_json from_cbor(detail::span_input_adapter&& i, + const bool strict = true, + const bool allow_exceptions = true, + const cbor_tag_handler_t tag_handler = cbor_tag_handler_t::error) + { + basic_json result; + detail::json_sax_dom_parser sdp(result, allow_exceptions); + auto ia = i.get(); + const bool res = binary_reader(std::move(ia)).sax_parse(input_format_t::cbor, &sdp, strict, tag_handler); + return res ? result : basic_json(value_t::discarded); + } + + /*! + @brief create a JSON value from an input in MessagePack format + + Deserializes a given input @a i to a JSON value using the MessagePack + serialization format. + + The library maps MessagePack types to JSON value types as follows: + + MessagePack type | JSON value type | first byte + ---------------- | --------------- | ---------- + positive fixint | number_unsigned | 0x00..0x7F + fixmap | object | 0x80..0x8F + fixarray | array | 0x90..0x9F + fixstr | string | 0xA0..0xBF + nil | `null` | 0xC0 + false | `false` | 0xC2 + true | `true` | 0xC3 + float 32 | number_float | 0xCA + float 64 | number_float | 0xCB + uint 8 | number_unsigned | 0xCC + uint 16 | number_unsigned | 0xCD + uint 32 | number_unsigned | 0xCE + uint 64 | number_unsigned | 0xCF + int 8 | number_integer | 0xD0 + int 16 | number_integer | 0xD1 + int 32 | number_integer | 0xD2 + int 64 | number_integer | 0xD3 + str 8 | string | 0xD9 + str 16 | string | 0xDA + str 32 | string | 0xDB + array 16 | array | 0xDC + array 32 | array | 0xDD + map 16 | object | 0xDE + map 32 | object | 0xDF + bin 8 | binary | 0xC4 + bin 16 | binary | 0xC5 + bin 32 | binary | 0xC6 + ext 8 | binary | 0xC7 + ext 16 | binary | 0xC8 + ext 32 | binary | 0xC9 + fixext 1 | binary | 0xD4 + fixext 2 | binary | 0xD5 + fixext 4 | binary | 0xD6 + fixext 8 | binary | 0xD7 + fixext 16 | binary | 0xD8 + negative fixint | number_integer | 0xE0-0xFF + + @note Any MessagePack output created @ref to_msgpack can be successfully + parsed by @ref from_msgpack. + + @param[in] i an input in MessagePack format convertible to an input + adapter + @param[in] strict whether to expect the input to be consumed until EOF + (true by default) + @param[in] allow_exceptions whether to throw exceptions in case of a + parse error (optional, true by default) + + @return deserialized JSON value; in case of a parse error and + @a allow_exceptions set to `false`, the return value will be + value_t::discarded. + + @throw parse_error.110 if the given input ends prematurely or the end of + file was not reached when @a strict was set to true + @throw parse_error.112 if unsupported features from MessagePack were + used in the given input @a i or if the input is not valid MessagePack + @throw parse_error.113 if a string was expected as map key, but not found + + @complexity Linear in the size of the input @a i. + + @liveexample{The example shows the deserialization of a byte vector in + MessagePack format to a JSON value.,from_msgpack} + + @sa http://msgpack.org + @sa @ref to_msgpack(const basic_json&) for the analogous serialization + @sa @ref from_cbor(detail::input_adapter&&, const bool, const bool, const cbor_tag_handler_t) for the + related CBOR format + @sa @ref from_ubjson(detail::input_adapter&&, const bool, const bool) for + the related UBJSON format + @sa @ref from_bson(detail::input_adapter&&, const bool, const bool) for + the related BSON format + + @since version 2.0.9; parameter @a start_index since 2.1.1; changed to + consume input adapters, removed start_index parameter, and added + @a strict parameter since 3.0.0; added @a allow_exceptions parameter + since 3.2.0 + */ + template + JSON_HEDLEY_WARN_UNUSED_RESULT + static basic_json from_msgpack(InputType&& i, + const bool strict = true, + const bool allow_exceptions = true) + { + basic_json result; + detail::json_sax_dom_parser sdp(result, allow_exceptions); + auto ia = detail::input_adapter(std::forward(i)); + const bool res = binary_reader(std::move(ia)).sax_parse(input_format_t::msgpack, &sdp, strict); + return res ? result : basic_json(value_t::discarded); + } + + /*! + @copydoc from_msgpack(detail::input_adapter&&, const bool, const bool) + */ + template + JSON_HEDLEY_WARN_UNUSED_RESULT + static basic_json from_msgpack(IteratorType first, IteratorType last, + const bool strict = true, + const bool allow_exceptions = true) + { + basic_json result; + detail::json_sax_dom_parser sdp(result, allow_exceptions); + auto ia = detail::input_adapter(std::move(first), std::move(last)); + const bool res = binary_reader(std::move(ia)).sax_parse(input_format_t::msgpack, &sdp, strict); + return res ? result : basic_json(value_t::discarded); + } + + + template + JSON_HEDLEY_WARN_UNUSED_RESULT + JSON_HEDLEY_DEPRECATED_FOR(3.8.0, from_msgpack(ptr, ptr + len)) + static basic_json from_msgpack(const T* ptr, std::size_t len, + const bool strict = true, + const bool allow_exceptions = true) + { + return from_msgpack(ptr, ptr + len, strict, allow_exceptions); + } + + JSON_HEDLEY_WARN_UNUSED_RESULT + JSON_HEDLEY_DEPRECATED_FOR(3.8.0, from_msgpack(ptr, ptr + len)) + static basic_json from_msgpack(detail::span_input_adapter&& i, + const bool strict = true, + const bool allow_exceptions = true) + { + basic_json result; + detail::json_sax_dom_parser sdp(result, allow_exceptions); + auto ia = i.get(); + const bool res = binary_reader(std::move(ia)).sax_parse(input_format_t::msgpack, &sdp, strict); + return res ? result : basic_json(value_t::discarded); + } + + + /*! + @brief create a JSON value from an input in UBJSON format + + Deserializes a given input @a i to a JSON value using the UBJSON (Universal + Binary JSON) serialization format. + + The library maps UBJSON types to JSON value types as follows: + + UBJSON type | JSON value type | marker + ----------- | --------------------------------------- | ------ + no-op | *no value, next value is read* | `N` + null | `null` | `Z` + false | `false` | `F` + true | `true` | `T` + float32 | number_float | `d` + float64 | number_float | `D` + uint8 | number_unsigned | `U` + int8 | number_integer | `i` + int16 | number_integer | `I` + int32 | number_integer | `l` + int64 | number_integer | `L` + high-precision number | number_integer, number_unsigned, or number_float - depends on number string | 'H' + string | string | `S` + char | string | `C` + array | array (optimized values are supported) | `[` + object | object (optimized values are supported) | `{` + + @note The mapping is **complete** in the sense that any UBJSON value can + be converted to a JSON value. + + @param[in] i an input in UBJSON format convertible to an input adapter + @param[in] strict whether to expect the input to be consumed until EOF + (true by default) + @param[in] allow_exceptions whether to throw exceptions in case of a + parse error (optional, true by default) + + @return deserialized JSON value; in case of a parse error and + @a allow_exceptions set to `false`, the return value will be + value_t::discarded. + + @throw parse_error.110 if the given input ends prematurely or the end of + file was not reached when @a strict was set to true + @throw parse_error.112 if a parse error occurs + @throw parse_error.113 if a string could not be parsed successfully + + @complexity Linear in the size of the input @a i. + + @liveexample{The example shows the deserialization of a byte vector in + UBJSON format to a JSON value.,from_ubjson} + + @sa http://ubjson.org + @sa @ref to_ubjson(const basic_json&, const bool, const bool) for the + analogous serialization + @sa @ref from_cbor(detail::input_adapter&&, const bool, const bool, const cbor_tag_handler_t) for the + related CBOR format + @sa @ref from_msgpack(detail::input_adapter&&, const bool, const bool) for + the related MessagePack format + @sa @ref from_bson(detail::input_adapter&&, const bool, const bool) for + the related BSON format + + @since version 3.1.0; added @a allow_exceptions parameter since 3.2.0 + */ + template + JSON_HEDLEY_WARN_UNUSED_RESULT + static basic_json from_ubjson(InputType&& i, + const bool strict = true, + const bool allow_exceptions = true) + { + basic_json result; + detail::json_sax_dom_parser sdp(result, allow_exceptions); + auto ia = detail::input_adapter(std::forward(i)); + const bool res = binary_reader(std::move(ia)).sax_parse(input_format_t::ubjson, &sdp, strict); + return res ? result : basic_json(value_t::discarded); + } + + /*! + @copydoc from_ubjson(detail::input_adapter&&, const bool, const bool) + */ + template + JSON_HEDLEY_WARN_UNUSED_RESULT + static basic_json from_ubjson(IteratorType first, IteratorType last, + const bool strict = true, + const bool allow_exceptions = true) + { + basic_json result; + detail::json_sax_dom_parser sdp(result, allow_exceptions); + auto ia = detail::input_adapter(std::move(first), std::move(last)); + const bool res = binary_reader(std::move(ia)).sax_parse(input_format_t::ubjson, &sdp, strict); + return res ? result : basic_json(value_t::discarded); + } + + template + JSON_HEDLEY_WARN_UNUSED_RESULT + JSON_HEDLEY_DEPRECATED_FOR(3.8.0, from_ubjson(ptr, ptr + len)) + static basic_json from_ubjson(const T* ptr, std::size_t len, + const bool strict = true, + const bool allow_exceptions = true) + { + return from_ubjson(ptr, ptr + len, strict, allow_exceptions); + } + + JSON_HEDLEY_WARN_UNUSED_RESULT + JSON_HEDLEY_DEPRECATED_FOR(3.8.0, from_ubjson(ptr, ptr + len)) + static basic_json from_ubjson(detail::span_input_adapter&& i, + const bool strict = true, + const bool allow_exceptions = true) + { + basic_json result; + detail::json_sax_dom_parser sdp(result, allow_exceptions); + auto ia = i.get(); + const bool res = binary_reader(std::move(ia)).sax_parse(input_format_t::ubjson, &sdp, strict); + return res ? result : basic_json(value_t::discarded); + } + + + /*! + @brief Create a JSON value from an input in BSON format + + Deserializes a given input @a i to a JSON value using the BSON (Binary JSON) + serialization format. + + The library maps BSON record types to JSON value types as follows: + + BSON type | BSON marker byte | JSON value type + --------------- | ---------------- | --------------------------- + double | 0x01 | number_float + string | 0x02 | string + document | 0x03 | object + array | 0x04 | array + binary | 0x05 | still unsupported + undefined | 0x06 | still unsupported + ObjectId | 0x07 | still unsupported + boolean | 0x08 | boolean + UTC Date-Time | 0x09 | still unsupported + null | 0x0A | null + Regular Expr. | 0x0B | still unsupported + DB Pointer | 0x0C | still unsupported + JavaScript Code | 0x0D | still unsupported + Symbol | 0x0E | still unsupported + JavaScript Code | 0x0F | still unsupported + int32 | 0x10 | number_integer + Timestamp | 0x11 | still unsupported + 128-bit decimal float | 0x13 | still unsupported + Max Key | 0x7F | still unsupported + Min Key | 0xFF | still unsupported + + @warning The mapping is **incomplete**. The unsupported mappings + are indicated in the table above. + + @param[in] i an input in BSON format convertible to an input adapter + @param[in] strict whether to expect the input to be consumed until EOF + (true by default) + @param[in] allow_exceptions whether to throw exceptions in case of a + parse error (optional, true by default) + + @return deserialized JSON value; in case of a parse error and + @a allow_exceptions set to `false`, the return value will be + value_t::discarded. + + @throw parse_error.114 if an unsupported BSON record type is encountered + + @complexity Linear in the size of the input @a i. + + @liveexample{The example shows the deserialization of a byte vector in + BSON format to a JSON value.,from_bson} + + @sa http://bsonspec.org/spec.html + @sa @ref to_bson(const basic_json&) for the analogous serialization + @sa @ref from_cbor(detail::input_adapter&&, const bool, const bool, const cbor_tag_handler_t) for the + related CBOR format + @sa @ref from_msgpack(detail::input_adapter&&, const bool, const bool) for + the related MessagePack format + @sa @ref from_ubjson(detail::input_adapter&&, const bool, const bool) for the + related UBJSON format + */ + template + JSON_HEDLEY_WARN_UNUSED_RESULT + static basic_json from_bson(InputType&& i, + const bool strict = true, + const bool allow_exceptions = true) + { + basic_json result; + detail::json_sax_dom_parser sdp(result, allow_exceptions); + auto ia = detail::input_adapter(std::forward(i)); + const bool res = binary_reader(std::move(ia)).sax_parse(input_format_t::bson, &sdp, strict); + return res ? result : basic_json(value_t::discarded); + } + + /*! + @copydoc from_bson(detail::input_adapter&&, const bool, const bool) + */ + template + JSON_HEDLEY_WARN_UNUSED_RESULT + static basic_json from_bson(IteratorType first, IteratorType last, + const bool strict = true, + const bool allow_exceptions = true) + { + basic_json result; + detail::json_sax_dom_parser sdp(result, allow_exceptions); + auto ia = detail::input_adapter(std::move(first), std::move(last)); + const bool res = binary_reader(std::move(ia)).sax_parse(input_format_t::bson, &sdp, strict); + return res ? result : basic_json(value_t::discarded); + } + + template + JSON_HEDLEY_WARN_UNUSED_RESULT + JSON_HEDLEY_DEPRECATED_FOR(3.8.0, from_bson(ptr, ptr + len)) + static basic_json from_bson(const T* ptr, std::size_t len, + const bool strict = true, + const bool allow_exceptions = true) + { + return from_bson(ptr, ptr + len, strict, allow_exceptions); + } + + JSON_HEDLEY_WARN_UNUSED_RESULT + JSON_HEDLEY_DEPRECATED_FOR(3.8.0, from_bson(ptr, ptr + len)) + static basic_json from_bson(detail::span_input_adapter&& i, + const bool strict = true, + const bool allow_exceptions = true) + { + basic_json result; + detail::json_sax_dom_parser sdp(result, allow_exceptions); + auto ia = i.get(); + const bool res = binary_reader(std::move(ia)).sax_parse(input_format_t::bson, &sdp, strict); + return res ? result : basic_json(value_t::discarded); + } + /// @} + + ////////////////////////// + // JSON Pointer support // + ////////////////////////// + + /// @name JSON Pointer functions + /// @{ + + /*! + @brief access specified element via JSON Pointer + + Uses a JSON pointer to retrieve a reference to the respective JSON value. + No bound checking is performed. Similar to @ref operator[](const typename + object_t::key_type&), `null` values are created in arrays and objects if + necessary. + + In particular: + - If the JSON pointer points to an object key that does not exist, it + is created an filled with a `null` value before a reference to it + is returned. + - If the JSON pointer points to an array index that does not exist, it + is created an filled with a `null` value before a reference to it + is returned. All indices between the current maximum and the given + index are also filled with `null`. + - The special value `-` is treated as a synonym for the index past the + end. + + @param[in] ptr a JSON pointer + + @return reference to the element pointed to by @a ptr + + @complexity Constant. + + @throw parse_error.106 if an array index begins with '0' + @throw parse_error.109 if an array index was not a number + @throw out_of_range.404 if the JSON pointer can not be resolved + + @liveexample{The behavior is shown in the example.,operatorjson_pointer} + + @since version 2.0.0 + */ + reference operator[](const json_pointer& ptr) + { + return ptr.get_unchecked(this); + } + + /*! + @brief access specified element via JSON Pointer + + Uses a JSON pointer to retrieve a reference to the respective JSON value. + No bound checking is performed. The function does not change the JSON + value; no `null` values are created. In particular, the special value + `-` yields an exception. + + @param[in] ptr JSON pointer to the desired element + + @return const reference to the element pointed to by @a ptr + + @complexity Constant. + + @throw parse_error.106 if an array index begins with '0' + @throw parse_error.109 if an array index was not a number + @throw out_of_range.402 if the array index '-' is used + @throw out_of_range.404 if the JSON pointer can not be resolved + + @liveexample{The behavior is shown in the example.,operatorjson_pointer_const} + + @since version 2.0.0 + */ + const_reference operator[](const json_pointer& ptr) const + { + return ptr.get_unchecked(this); + } + + /*! + @brief access specified element via JSON Pointer + + Returns a reference to the element at with specified JSON pointer @a ptr, + with bounds checking. + + @param[in] ptr JSON pointer to the desired element + + @return reference to the element pointed to by @a ptr + + @throw parse_error.106 if an array index in the passed JSON pointer @a ptr + begins with '0'. See example below. + + @throw parse_error.109 if an array index in the passed JSON pointer @a ptr + is not a number. See example below. + + @throw out_of_range.401 if an array index in the passed JSON pointer @a ptr + is out of range. See example below. + + @throw out_of_range.402 if the array index '-' is used in the passed JSON + pointer @a ptr. As `at` provides checked access (and no elements are + implicitly inserted), the index '-' is always invalid. See example below. + + @throw out_of_range.403 if the JSON pointer describes a key of an object + which cannot be found. See example below. + + @throw out_of_range.404 if the JSON pointer @a ptr can not be resolved. + See example below. + + @exceptionsafety Strong guarantee: if an exception is thrown, there are no + changes in the JSON value. + + @complexity Constant. + + @since version 2.0.0 + + @liveexample{The behavior is shown in the example.,at_json_pointer} + */ + reference at(const json_pointer& ptr) + { + return ptr.get_checked(this); + } + + /*! + @brief access specified element via JSON Pointer + + Returns a const reference to the element at with specified JSON pointer @a + ptr, with bounds checking. + + @param[in] ptr JSON pointer to the desired element + + @return reference to the element pointed to by @a ptr + + @throw parse_error.106 if an array index in the passed JSON pointer @a ptr + begins with '0'. See example below. + + @throw parse_error.109 if an array index in the passed JSON pointer @a ptr + is not a number. See example below. + + @throw out_of_range.401 if an array index in the passed JSON pointer @a ptr + is out of range. See example below. + + @throw out_of_range.402 if the array index '-' is used in the passed JSON + pointer @a ptr. As `at` provides checked access (and no elements are + implicitly inserted), the index '-' is always invalid. See example below. + + @throw out_of_range.403 if the JSON pointer describes a key of an object + which cannot be found. See example below. + + @throw out_of_range.404 if the JSON pointer @a ptr can not be resolved. + See example below. + + @exceptionsafety Strong guarantee: if an exception is thrown, there are no + changes in the JSON value. + + @complexity Constant. + + @since version 2.0.0 + + @liveexample{The behavior is shown in the example.,at_json_pointer_const} + */ + const_reference at(const json_pointer& ptr) const + { + return ptr.get_checked(this); + } + + /*! + @brief return flattened JSON value + + The function creates a JSON object whose keys are JSON pointers (see [RFC + 6901](https://tools.ietf.org/html/rfc6901)) and whose values are all + primitive. The original JSON value can be restored using the @ref + unflatten() function. + + @return an object that maps JSON pointers to primitive values + + @note Empty objects and arrays are flattened to `null` and will not be + reconstructed correctly by the @ref unflatten() function. + + @complexity Linear in the size the JSON value. + + @liveexample{The following code shows how a JSON object is flattened to an + object whose keys consist of JSON pointers.,flatten} + + @sa @ref unflatten() for the reverse function + + @since version 2.0.0 + */ + basic_json flatten() const + { + basic_json result(value_t::object); + json_pointer::flatten("", *this, result); + return result; + } + + /*! + @brief unflatten a previously flattened JSON value + + The function restores the arbitrary nesting of a JSON value that has been + flattened before using the @ref flatten() function. The JSON value must + meet certain constraints: + 1. The value must be an object. + 2. The keys must be JSON pointers (see + [RFC 6901](https://tools.ietf.org/html/rfc6901)) + 3. The mapped values must be primitive JSON types. + + @return the original JSON from a flattened version + + @note Empty objects and arrays are flattened by @ref flatten() to `null` + values and can not unflattened to their original type. Apart from + this example, for a JSON value `j`, the following is always true: + `j == j.flatten().unflatten()`. + + @complexity Linear in the size the JSON value. + + @throw type_error.314 if value is not an object + @throw type_error.315 if object values are not primitive + + @liveexample{The following code shows how a flattened JSON object is + unflattened into the original nested JSON object.,unflatten} + + @sa @ref flatten() for the reverse function + + @since version 2.0.0 + */ + basic_json unflatten() const + { + return json_pointer::unflatten(*this); + } + + /// @} + + ////////////////////////// + // JSON Patch functions // + ////////////////////////// + + /// @name JSON Patch functions + /// @{ + + /*! + @brief applies a JSON patch + + [JSON Patch](http://jsonpatch.com) defines a JSON document structure for + expressing a sequence of operations to apply to a JSON) document. With + this function, a JSON Patch is applied to the current JSON value by + executing all operations from the patch. + + @param[in] json_patch JSON patch document + @return patched document + + @note The application of a patch is atomic: Either all operations succeed + and the patched document is returned or an exception is thrown. In + any case, the original value is not changed: the patch is applied + to a copy of the value. + + @throw parse_error.104 if the JSON patch does not consist of an array of + objects + + @throw parse_error.105 if the JSON patch is malformed (e.g., mandatory + attributes are missing); example: `"operation add must have member path"` + + @throw out_of_range.401 if an array index is out of range. + + @throw out_of_range.403 if a JSON pointer inside the patch could not be + resolved successfully in the current JSON value; example: `"key baz not + found"` + + @throw out_of_range.405 if JSON pointer has no parent ("add", "remove", + "move") + + @throw other_error.501 if "test" operation was unsuccessful + + @complexity Linear in the size of the JSON value and the length of the + JSON patch. As usually only a fraction of the JSON value is affected by + the patch, the complexity can usually be neglected. + + @liveexample{The following code shows how a JSON patch is applied to a + value.,patch} + + @sa @ref diff -- create a JSON patch by comparing two JSON values + + @sa [RFC 6902 (JSON Patch)](https://tools.ietf.org/html/rfc6902) + @sa [RFC 6901 (JSON Pointer)](https://tools.ietf.org/html/rfc6901) + + @since version 2.0.0 + */ + basic_json patch(const basic_json& json_patch) const + { + // make a working copy to apply the patch to + basic_json result = *this; + + // the valid JSON Patch operations + enum class patch_operations {add, remove, replace, move, copy, test, invalid}; + + const auto get_op = [](const std::string & op) + { + if (op == "add") + { + return patch_operations::add; + } + if (op == "remove") + { + return patch_operations::remove; + } + if (op == "replace") + { + return patch_operations::replace; + } + if (op == "move") + { + return patch_operations::move; + } + if (op == "copy") + { + return patch_operations::copy; + } + if (op == "test") + { + return patch_operations::test; + } + + return patch_operations::invalid; + }; + + // wrapper for "add" operation; add value at ptr + const auto operation_add = [&result](json_pointer & ptr, basic_json val) + { + // adding to the root of the target document means replacing it + if (ptr.empty()) + { + result = val; + return; + } + + // make sure the top element of the pointer exists + json_pointer top_pointer = ptr.top(); + if (top_pointer != ptr) + { + result.at(top_pointer); + } + + // get reference to parent of JSON pointer ptr + const auto last_path = ptr.back(); + ptr.pop_back(); + basic_json& parent = result[ptr]; + + switch (parent.m_type) + { + case value_t::null: + case value_t::object: + { + // use operator[] to add value + parent[last_path] = val; + break; + } + + case value_t::array: + { + if (last_path == "-") + { + // special case: append to back + parent.push_back(val); + } + else + { + const auto idx = json_pointer::array_index(last_path); + if (JSON_HEDLEY_UNLIKELY(idx > parent.size())) + { + // avoid undefined behavior + JSON_THROW(out_of_range::create(401, "array index " + std::to_string(idx) + " is out of range")); + } + + // default case: insert add offset + parent.insert(parent.begin() + static_cast(idx), val); + } + break; + } + + // if there exists a parent it cannot be primitive + default: // LCOV_EXCL_LINE + JSON_ASSERT(false); // LCOV_EXCL_LINE + } + }; + + // wrapper for "remove" operation; remove value at ptr + const auto operation_remove = [&result](json_pointer & ptr) + { + // get reference to parent of JSON pointer ptr + const auto last_path = ptr.back(); + ptr.pop_back(); + basic_json& parent = result.at(ptr); + + // remove child + if (parent.is_object()) + { + // perform range check + auto it = parent.find(last_path); + if (JSON_HEDLEY_LIKELY(it != parent.end())) + { + parent.erase(it); + } + else + { + JSON_THROW(out_of_range::create(403, "key '" + last_path + "' not found")); + } + } + else if (parent.is_array()) + { + // note erase performs range check + parent.erase(json_pointer::array_index(last_path)); + } + }; + + // type check: top level value must be an array + if (JSON_HEDLEY_UNLIKELY(!json_patch.is_array())) + { + JSON_THROW(parse_error::create(104, 0, "JSON patch must be an array of objects")); + } + + // iterate and apply the operations + for (const auto& val : json_patch) + { + // wrapper to get a value for an operation + const auto get_value = [&val](const std::string & op, + const std::string & member, + bool string_type) -> basic_json & + { + // find value + auto it = val.m_value.object->find(member); + + // context-sensitive error message + const auto error_msg = (op == "op") ? "operation" : "operation '" + op + "'"; + + // check if desired value is present + if (JSON_HEDLEY_UNLIKELY(it == val.m_value.object->end())) + { + JSON_THROW(parse_error::create(105, 0, error_msg + " must have member '" + member + "'")); + } + + // check if result is of type string + if (JSON_HEDLEY_UNLIKELY(string_type && !it->second.is_string())) + { + JSON_THROW(parse_error::create(105, 0, error_msg + " must have string member '" + member + "'")); + } + + // no error: return value + return it->second; + }; + + // type check: every element of the array must be an object + if (JSON_HEDLEY_UNLIKELY(!val.is_object())) + { + JSON_THROW(parse_error::create(104, 0, "JSON patch must be an array of objects")); + } + + // collect mandatory members + const auto op = get_value("op", "op", true).template get(); + const auto path = get_value(op, "path", true).template get(); + json_pointer ptr(path); + + switch (get_op(op)) + { + case patch_operations::add: + { + operation_add(ptr, get_value("add", "value", false)); + break; + } + + case patch_operations::remove: + { + operation_remove(ptr); + break; + } + + case patch_operations::replace: + { + // the "path" location must exist - use at() + result.at(ptr) = get_value("replace", "value", false); + break; + } + + case patch_operations::move: + { + const auto from_path = get_value("move", "from", true).template get(); + json_pointer from_ptr(from_path); + + // the "from" location must exist - use at() + basic_json v = result.at(from_ptr); + + // The move operation is functionally identical to a + // "remove" operation on the "from" location, followed + // immediately by an "add" operation at the target + // location with the value that was just removed. + operation_remove(from_ptr); + operation_add(ptr, v); + break; + } + + case patch_operations::copy: + { + const auto from_path = get_value("copy", "from", true).template get(); + const json_pointer from_ptr(from_path); + + // the "from" location must exist - use at() + basic_json v = result.at(from_ptr); + + // The copy is functionally identical to an "add" + // operation at the target location using the value + // specified in the "from" member. + operation_add(ptr, v); + break; + } + + case patch_operations::test: + { + bool success = false; + JSON_TRY + { + // check if "value" matches the one at "path" + // the "path" location must exist - use at() + success = (result.at(ptr) == get_value("test", "value", false)); + } + JSON_INTERNAL_CATCH (out_of_range&) + { + // ignore out of range errors: success remains false + } + + // throw an exception if test fails + if (JSON_HEDLEY_UNLIKELY(!success)) + { + JSON_THROW(other_error::create(501, "unsuccessful: " + val.dump())); + } + + break; + } + + default: + { + // op must be "add", "remove", "replace", "move", "copy", or + // "test" + JSON_THROW(parse_error::create(105, 0, "operation value '" + op + "' is invalid")); + } + } + } + + return result; + } + + /*! + @brief creates a diff as a JSON patch + + Creates a [JSON Patch](http://jsonpatch.com) so that value @a source can + be changed into the value @a target by calling @ref patch function. + + @invariant For two JSON values @a source and @a target, the following code + yields always `true`: + @code {.cpp} + source.patch(diff(source, target)) == target; + @endcode + + @note Currently, only `remove`, `add`, and `replace` operations are + generated. + + @param[in] source JSON value to compare from + @param[in] target JSON value to compare against + @param[in] path helper value to create JSON pointers + + @return a JSON patch to convert the @a source to @a target + + @complexity Linear in the lengths of @a source and @a target. + + @liveexample{The following code shows how a JSON patch is created as a + diff for two JSON values.,diff} + + @sa @ref patch -- apply a JSON patch + @sa @ref merge_patch -- apply a JSON Merge Patch + + @sa [RFC 6902 (JSON Patch)](https://tools.ietf.org/html/rfc6902) + + @since version 2.0.0 + */ + JSON_HEDLEY_WARN_UNUSED_RESULT + static basic_json diff(const basic_json& source, const basic_json& target, + const std::string& path = "") + { + // the patch + basic_json result(value_t::array); + + // if the values are the same, return empty patch + if (source == target) + { + return result; + } + + if (source.type() != target.type()) + { + // different types: replace value + result.push_back( + { + {"op", "replace"}, {"path", path}, {"value", target} + }); + return result; + } + + switch (source.type()) + { + case value_t::array: + { + // first pass: traverse common elements + std::size_t i = 0; + while (i < source.size() && i < target.size()) + { + // recursive call to compare array values at index i + auto temp_diff = diff(source[i], target[i], path + "/" + std::to_string(i)); + result.insert(result.end(), temp_diff.begin(), temp_diff.end()); + ++i; + } + + // i now reached the end of at least one array + // in a second pass, traverse the remaining elements + + // remove my remaining elements + const auto end_index = static_cast(result.size()); + while (i < source.size()) + { + // add operations in reverse order to avoid invalid + // indices + result.insert(result.begin() + end_index, object( + { + {"op", "remove"}, + {"path", path + "/" + std::to_string(i)} + })); + ++i; + } + + // add other remaining elements + while (i < target.size()) + { + result.push_back( + { + {"op", "add"}, + {"path", path + "/-"}, + {"value", target[i]} + }); + ++i; + } + + break; + } + + case value_t::object: + { + // first pass: traverse this object's elements + for (auto it = source.cbegin(); it != source.cend(); ++it) + { + // escape the key name to be used in a JSON patch + const auto key = json_pointer::escape(it.key()); + + if (target.find(it.key()) != target.end()) + { + // recursive call to compare object values at key it + auto temp_diff = diff(it.value(), target[it.key()], path + "/" + key); + result.insert(result.end(), temp_diff.begin(), temp_diff.end()); + } + else + { + // found a key that is not in o -> remove it + result.push_back(object( + { + {"op", "remove"}, {"path", path + "/" + key} + })); + } + } + + // second pass: traverse other object's elements + for (auto it = target.cbegin(); it != target.cend(); ++it) + { + if (source.find(it.key()) == source.end()) + { + // found a key that is not in this -> add it + const auto key = json_pointer::escape(it.key()); + result.push_back( + { + {"op", "add"}, {"path", path + "/" + key}, + {"value", it.value()} + }); + } + } + + break; + } + + default: + { + // both primitive type: replace value + result.push_back( + { + {"op", "replace"}, {"path", path}, {"value", target} + }); + break; + } + } + + return result; + } + + /// @} + + //////////////////////////////// + // JSON Merge Patch functions // + //////////////////////////////// + + /// @name JSON Merge Patch functions + /// @{ + + /*! + @brief applies a JSON Merge Patch + + The merge patch format is primarily intended for use with the HTTP PATCH + method as a means of describing a set of modifications to a target + resource's content. This function applies a merge patch to the current + JSON value. + + The function implements the following algorithm from Section 2 of + [RFC 7396 (JSON Merge Patch)](https://tools.ietf.org/html/rfc7396): + + ``` + define MergePatch(Target, Patch): + if Patch is an Object: + if Target is not an Object: + Target = {} // Ignore the contents and set it to an empty Object + for each Name/Value pair in Patch: + if Value is null: + if Name exists in Target: + remove the Name/Value pair from Target + else: + Target[Name] = MergePatch(Target[Name], Value) + return Target + else: + return Patch + ``` + + Thereby, `Target` is the current object; that is, the patch is applied to + the current value. + + @param[in] apply_patch the patch to apply + + @complexity Linear in the lengths of @a patch. + + @liveexample{The following code shows how a JSON Merge Patch is applied to + a JSON document.,merge_patch} + + @sa @ref patch -- apply a JSON patch + @sa [RFC 7396 (JSON Merge Patch)](https://tools.ietf.org/html/rfc7396) + + @since version 3.0.0 + */ + void merge_patch(const basic_json& apply_patch) + { + if (apply_patch.is_object()) + { + if (!is_object()) + { + *this = object(); + } + for (auto it = apply_patch.begin(); it != apply_patch.end(); ++it) + { + if (it.value().is_null()) + { + erase(it.key()); + } + else + { + operator[](it.key()).merge_patch(it.value()); + } + } + } + else + { + *this = apply_patch; + } + } + + /// @} +}; + +/*! +@brief user-defined to_string function for JSON values + +This function implements a user-defined to_string for JSON objects. + +@param[in] j a JSON object +@return a std::string object +*/ + +NLOHMANN_BASIC_JSON_TPL_DECLARATION +std::string to_string(const NLOHMANN_BASIC_JSON_TPL& j) +{ + return j.dump(); +} +} // namespace nlohmann + +/////////////////////// +// nonmember support // +/////////////////////// + +// specialization of std::swap, and std::hash +namespace std +{ + +/// hash value for JSON objects +template<> +struct hash +{ + /*! + @brief return a hash value for a JSON object + + @since version 1.0.0 + */ + std::size_t operator()(const nlohmann::json& j) const + { + return nlohmann::detail::hash(j); + } +}; + +/// specialization for std::less +/// @note: do not remove the space after '<', +/// see https://github.com/nlohmann/json/pull/679 +template<> +struct less<::nlohmann::detail::value_t> +{ + /*! + @brief compare two value_t enum values + @since version 3.0.0 + */ + bool operator()(nlohmann::detail::value_t lhs, + nlohmann::detail::value_t rhs) const noexcept + { + return nlohmann::detail::operator<(lhs, rhs); + } +}; + +// C++20 prohibit function specialization in the std namespace. +#ifndef JSON_HAS_CPP_20 + +/*! +@brief exchanges the values of two JSON objects + +@since version 1.0.0 +*/ +template<> +inline void swap(nlohmann::json& j1, nlohmann::json& j2) noexcept( + is_nothrow_move_constructible::value&& + is_nothrow_move_assignable::value + ) +{ + j1.swap(j2); +} + +#endif + +} // namespace std + +/*! +@brief user-defined string literal for JSON values + +This operator implements a user-defined string literal for JSON objects. It +can be used by adding `"_json"` to a string literal and returns a JSON object +if no parse error occurred. + +@param[in] s a string representation of a JSON object +@param[in] n the length of string @a s +@return a JSON object + +@since version 1.0.0 +*/ +JSON_HEDLEY_NON_NULL(1) +inline nlohmann::json operator "" _json(const char* s, std::size_t n) +{ + return nlohmann::json::parse(s, s + n); +} + +/*! +@brief user-defined string literal for JSON pointer + +This operator implements a user-defined string literal for JSON Pointers. It +can be used by adding `"_json_pointer"` to a string literal and returns a JSON pointer +object if no parse error occurred. + +@param[in] s a string representation of a JSON Pointer +@param[in] n the length of string @a s +@return a JSON pointer object + +@since version 2.0.0 +*/ +JSON_HEDLEY_NON_NULL(1) +inline nlohmann::json::json_pointer operator "" _json_pointer(const char* s, std::size_t n) +{ + return nlohmann::json::json_pointer(std::string(s, n)); +} + +// #include + + +// restore GCC/clang diagnostic settings +#if defined(__clang__) || defined(__GNUC__) || defined(__GNUG__) + #pragma GCC diagnostic pop +#endif +#if defined(__clang__) + #pragma GCC diagnostic pop +#endif + +// clean up +#undef JSON_ASSERT +#undef JSON_INTERNAL_CATCH +#undef JSON_CATCH +#undef JSON_THROW +#undef JSON_TRY +#undef JSON_HAS_CPP_14 +#undef JSON_HAS_CPP_17 +#undef NLOHMANN_BASIC_JSON_TPL_DECLARATION +#undef NLOHMANN_BASIC_JSON_TPL +#undef JSON_EXPLICIT + +// #include +#undef JSON_HEDLEY_ALWAYS_INLINE +#undef JSON_HEDLEY_ARM_VERSION +#undef JSON_HEDLEY_ARM_VERSION_CHECK +#undef JSON_HEDLEY_ARRAY_PARAM +#undef JSON_HEDLEY_ASSUME +#undef JSON_HEDLEY_BEGIN_C_DECLS +#undef JSON_HEDLEY_CLANG_HAS_ATTRIBUTE +#undef JSON_HEDLEY_CLANG_HAS_BUILTIN +#undef JSON_HEDLEY_CLANG_HAS_CPP_ATTRIBUTE +#undef JSON_HEDLEY_CLANG_HAS_DECLSPEC_DECLSPEC_ATTRIBUTE +#undef JSON_HEDLEY_CLANG_HAS_EXTENSION +#undef JSON_HEDLEY_CLANG_HAS_FEATURE +#undef JSON_HEDLEY_CLANG_HAS_WARNING +#undef JSON_HEDLEY_COMPCERT_VERSION +#undef JSON_HEDLEY_COMPCERT_VERSION_CHECK +#undef JSON_HEDLEY_CONCAT +#undef JSON_HEDLEY_CONCAT3 +#undef JSON_HEDLEY_CONCAT3_EX +#undef JSON_HEDLEY_CONCAT_EX +#undef JSON_HEDLEY_CONST +#undef JSON_HEDLEY_CONSTEXPR +#undef JSON_HEDLEY_CONST_CAST +#undef JSON_HEDLEY_CPP_CAST +#undef JSON_HEDLEY_CRAY_VERSION +#undef JSON_HEDLEY_CRAY_VERSION_CHECK +#undef JSON_HEDLEY_C_DECL +#undef JSON_HEDLEY_DEPRECATED +#undef JSON_HEDLEY_DEPRECATED_FOR +#undef JSON_HEDLEY_DIAGNOSTIC_DISABLE_CAST_QUAL +#undef JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_ +#undef JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED +#undef JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES +#undef JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS +#undef JSON_HEDLEY_DIAGNOSTIC_POP +#undef JSON_HEDLEY_DIAGNOSTIC_PUSH +#undef JSON_HEDLEY_DMC_VERSION +#undef JSON_HEDLEY_DMC_VERSION_CHECK +#undef JSON_HEDLEY_EMPTY_BASES +#undef JSON_HEDLEY_EMSCRIPTEN_VERSION +#undef JSON_HEDLEY_EMSCRIPTEN_VERSION_CHECK +#undef JSON_HEDLEY_END_C_DECLS +#undef JSON_HEDLEY_FLAGS +#undef JSON_HEDLEY_FLAGS_CAST +#undef JSON_HEDLEY_GCC_HAS_ATTRIBUTE +#undef JSON_HEDLEY_GCC_HAS_BUILTIN +#undef JSON_HEDLEY_GCC_HAS_CPP_ATTRIBUTE +#undef JSON_HEDLEY_GCC_HAS_DECLSPEC_ATTRIBUTE +#undef JSON_HEDLEY_GCC_HAS_EXTENSION +#undef JSON_HEDLEY_GCC_HAS_FEATURE +#undef JSON_HEDLEY_GCC_HAS_WARNING +#undef JSON_HEDLEY_GCC_NOT_CLANG_VERSION_CHECK +#undef JSON_HEDLEY_GCC_VERSION +#undef JSON_HEDLEY_GCC_VERSION_CHECK +#undef JSON_HEDLEY_GNUC_HAS_ATTRIBUTE +#undef JSON_HEDLEY_GNUC_HAS_BUILTIN +#undef JSON_HEDLEY_GNUC_HAS_CPP_ATTRIBUTE +#undef JSON_HEDLEY_GNUC_HAS_DECLSPEC_ATTRIBUTE +#undef JSON_HEDLEY_GNUC_HAS_EXTENSION +#undef JSON_HEDLEY_GNUC_HAS_FEATURE +#undef JSON_HEDLEY_GNUC_HAS_WARNING +#undef JSON_HEDLEY_GNUC_VERSION +#undef JSON_HEDLEY_GNUC_VERSION_CHECK +#undef JSON_HEDLEY_HAS_ATTRIBUTE +#undef JSON_HEDLEY_HAS_BUILTIN +#undef JSON_HEDLEY_HAS_CPP_ATTRIBUTE +#undef JSON_HEDLEY_HAS_CPP_ATTRIBUTE_NS +#undef JSON_HEDLEY_HAS_DECLSPEC_ATTRIBUTE +#undef JSON_HEDLEY_HAS_EXTENSION +#undef JSON_HEDLEY_HAS_FEATURE +#undef JSON_HEDLEY_HAS_WARNING +#undef JSON_HEDLEY_IAR_VERSION +#undef JSON_HEDLEY_IAR_VERSION_CHECK +#undef JSON_HEDLEY_IBM_VERSION +#undef JSON_HEDLEY_IBM_VERSION_CHECK +#undef JSON_HEDLEY_IMPORT +#undef JSON_HEDLEY_INLINE +#undef JSON_HEDLEY_INTEL_VERSION +#undef JSON_HEDLEY_INTEL_VERSION_CHECK +#undef JSON_HEDLEY_IS_CONSTANT +#undef JSON_HEDLEY_IS_CONSTEXPR_ +#undef JSON_HEDLEY_LIKELY +#undef JSON_HEDLEY_MALLOC +#undef JSON_HEDLEY_MESSAGE +#undef JSON_HEDLEY_MSVC_VERSION +#undef JSON_HEDLEY_MSVC_VERSION_CHECK +#undef JSON_HEDLEY_NEVER_INLINE +#undef JSON_HEDLEY_NON_NULL +#undef JSON_HEDLEY_NO_ESCAPE +#undef JSON_HEDLEY_NO_RETURN +#undef JSON_HEDLEY_NO_THROW +#undef JSON_HEDLEY_NULL +#undef JSON_HEDLEY_PELLES_VERSION +#undef JSON_HEDLEY_PELLES_VERSION_CHECK +#undef JSON_HEDLEY_PGI_VERSION +#undef JSON_HEDLEY_PGI_VERSION_CHECK +#undef JSON_HEDLEY_PREDICT +#undef JSON_HEDLEY_PRINTF_FORMAT +#undef JSON_HEDLEY_PRIVATE +#undef JSON_HEDLEY_PUBLIC +#undef JSON_HEDLEY_PURE +#undef JSON_HEDLEY_REINTERPRET_CAST +#undef JSON_HEDLEY_REQUIRE +#undef JSON_HEDLEY_REQUIRE_CONSTEXPR +#undef JSON_HEDLEY_REQUIRE_MSG +#undef JSON_HEDLEY_RESTRICT +#undef JSON_HEDLEY_RETURNS_NON_NULL +#undef JSON_HEDLEY_SENTINEL +#undef JSON_HEDLEY_STATIC_ASSERT +#undef JSON_HEDLEY_STATIC_CAST +#undef JSON_HEDLEY_STRINGIFY +#undef JSON_HEDLEY_STRINGIFY_EX +#undef JSON_HEDLEY_SUNPRO_VERSION +#undef JSON_HEDLEY_SUNPRO_VERSION_CHECK +#undef JSON_HEDLEY_TINYC_VERSION +#undef JSON_HEDLEY_TINYC_VERSION_CHECK +#undef JSON_HEDLEY_TI_ARMCL_VERSION +#undef JSON_HEDLEY_TI_ARMCL_VERSION_CHECK +#undef JSON_HEDLEY_TI_CL2000_VERSION +#undef JSON_HEDLEY_TI_CL2000_VERSION_CHECK +#undef JSON_HEDLEY_TI_CL430_VERSION +#undef JSON_HEDLEY_TI_CL430_VERSION_CHECK +#undef JSON_HEDLEY_TI_CL6X_VERSION +#undef JSON_HEDLEY_TI_CL6X_VERSION_CHECK +#undef JSON_HEDLEY_TI_CL7X_VERSION +#undef JSON_HEDLEY_TI_CL7X_VERSION_CHECK +#undef JSON_HEDLEY_TI_CLPRU_VERSION +#undef JSON_HEDLEY_TI_CLPRU_VERSION_CHECK +#undef JSON_HEDLEY_TI_VERSION +#undef JSON_HEDLEY_TI_VERSION_CHECK +#undef JSON_HEDLEY_UNAVAILABLE +#undef JSON_HEDLEY_UNLIKELY +#undef JSON_HEDLEY_UNPREDICTABLE +#undef JSON_HEDLEY_UNREACHABLE +#undef JSON_HEDLEY_UNREACHABLE_RETURN +#undef JSON_HEDLEY_VERSION +#undef JSON_HEDLEY_VERSION_DECODE_MAJOR +#undef JSON_HEDLEY_VERSION_DECODE_MINOR +#undef JSON_HEDLEY_VERSION_DECODE_REVISION +#undef JSON_HEDLEY_VERSION_ENCODE +#undef JSON_HEDLEY_WARNING +#undef JSON_HEDLEY_WARN_UNUSED_RESULT +#undef JSON_HEDLEY_WARN_UNUSED_RESULT_MSG +#undef JSON_HEDLEY_FALL_THROUGH + + + +#endif // INCLUDE_NLOHMANN_JSON_HPP_ diff --git a/core/thirdparty/opentracing/CMakeLists.txt b/core/thirdparty/opentracing/CMakeLists.txt new file mode 100644 index 0000000000..0cc05ee624 --- /dev/null +++ b/core/thirdparty/opentracing/CMakeLists.txt @@ -0,0 +1,68 @@ +#------------------------------------------------------------------------------- +# Copyright (C) 2019-2020 Zilliz. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License +# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing permissions and limitations under the License. +#------------------------------------------------------------------------------- + +if ( DEFINED ENV{MILVUS_OPENTRACING_URL} ) + set(OPENTRACING_SOURCE_URL "$ENV{MILVUS_OPENTRACING_URL}") +else () + set(OPENTRACING_SOURCE_URL + "https://github.com/opentracing/opentracing-cpp/archive/${OPENTRACING_VERSION}.tar.gz" ) +endif () + +message(STATUS "Building OPENTRACING-${OPENTRACING_VERSION} from source") + +FetchContent_Declare( + opentracing + URL ${OPENTRACING_SOURCE_URL} + URL_MD5 "e598ba4b81ae8e1ceed8cd8bbf86f2fd" + DOWNLOAD_DIR ${MILVUS_BINARY_DIR}/3rdparty_download/download + SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/opentracing-src + BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR}/opentracing-build + ) + +set( BUILD_SHARED_LIBS CACHE BOOL OFF FORCE) +set( ENABLE_LINTING CACHE BOOL OFF FORCE) +set( BUILD_TESTING CACHE BOOL OFF FORCE ) +FetchContent_GetProperties( opentracing ) +if ( NOT opentracing_POPULATED ) + + FetchContent_Populate( opentracing ) + + # Adding the following targets: + # opentracing-static + # opentracing_mocktracer-static + add_subdirectory( ${opentracing_SOURCE_DIR} + ${opentracing_BINARY_DIR} + EXCLUDE_FROM_ALL ) + + # Opentracing-cpp CMakeLists.txt file didn't give a + # correct interface directories + target_include_directories( opentracing-static + INTERFACE $ + $ + $ ) + target_include_directories( opentracing_mocktracer-static + INTERFACE $ ) + # Adding the following ALIAS Targets: + # opentracing::opentracing + # opentracing::mocktracer + if ( NOT TARGET opentracing::opentracing ) + add_library( opentracing::opentracing ALIAS opentracing-static ) + endif() + if ( NOT TARGET opentracing::mocktracer ) + add_library( opentracing::mocktracer ALIAS opentracing_mocktracer-static ) + endif() + +endif() + +get_property( var DIRECTORY "${opentracing_SOURCE_DIR}" PROPERTY COMPILE_OPTIONS ) +message( STATUS "opentracing compile options: ${var}" ) diff --git a/core/thirdparty/versions.txt b/core/thirdparty/versions.txt new file mode 100644 index 0000000000..47a2c01301 --- /dev/null +++ b/core/thirdparty/versions.txt @@ -0,0 +1,6 @@ +EASYLOGGINGPP_VERSION=v9.96.7 +GTEST_VERSION=1.8.1 +YAMLCPP_VERSION=0.6.3 +ZLIB_VERSION=v1.2.11 +OPENTRACING_VERSION=v1.5.1 +# vim: set filetype=sh: diff --git a/core/thirdparty/yaml-cpp/CMakeLists.txt b/core/thirdparty/yaml-cpp/CMakeLists.txt new file mode 100644 index 0000000000..a21b7b751b --- /dev/null +++ b/core/thirdparty/yaml-cpp/CMakeLists.txt @@ -0,0 +1,50 @@ +#------------------------------------------------------------------------------- +# Copyright (C) 2019-2020 Zilliz. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License +# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing permissions and limitations under the License. +#------------------------------------------------------------------------------- + +if ( DEFINED ENV{MILVUS_YAMLCPP_URL} ) + set( YAMLCPP_SOURCE_URL "$ENV{MILVUS_YAMLCPP_URL}" ) +else() + set( YAMLCPP_SOURCE_URL + "https://github.com/jbeder/yaml-cpp/archive/yaml-cpp-${YAMLCPP_VERSION}.tar.gz" ) +endif() + +message( STATUS "Building yaml-cpp-${YAMLCPP_VERSION} from source" ) +FetchContent_Declare( + yaml-cpp + URL ${YAMLCPP_SOURCE_URL} + URL_MD5 "b45bf1089a382e81f6b661062c10d0c2" + DOWNLOAD_DIR ${MILVUS_BINARY_DIR}/3rdparty_download/download + SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/yaml-src + BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR}/yaml-build + ) + +set( YAML_CPP_BUILD_TESTS CACHE BOOL OFF FORCE ) +set( YAML_CPP_BUILD_TOOLS CACHE BOOL OFF FORCE ) +FetchContent_GetProperties( yaml-cpp ) +if ( NOT yaml-cpp_POPULATED ) + + FetchContent_Populate( yaml-cpp ) + + # Adding the following targets: + # yaml-cpp::yaml-cpp, yaml-cpp + add_subdirectory( ${yaml-cpp_SOURCE_DIR} + ${yaml-cpp_BINARY_DIR} + EXCLUDE_FROM_ALL ) + +endif() + +get_target_property( YAML_CPP_INCLUDE_DIR yaml-cpp INTERFACE_INCLUDE_DIRECTORIES ) +message( STATUS ${YAML_CPP_INCLUDE_DIR} ) + +get_property( var DIRECTORY "${yaml-cpp_SOURCE_DIR}" PROPERTY COMPILE_OPTIONS ) +message( STATUS "yaml compile options: ${var}" ) diff --git a/core/unittest/CMakeLists.txt b/core/unittest/CMakeLists.txt index 0a1fb55d1b..a2b05944e8 100644 --- a/core/unittest/CMakeLists.txt +++ b/core/unittest/CMakeLists.txt @@ -1,17 +1,24 @@ enable_testing() find_package(GTest REQUIRED) -set(MILVUS_TEST_FILES - test_naive.cpp - # test_dog_segment.cpp - test_c_api.cpp -) -add_executable(all_tests - ${MILVUS_TEST_FILES} -) -target_link_libraries(all_tests - gtest - gtest_main - milvus_dog_segment - pthread -) \ No newline at end of file +include_directories(${CMAKE_HOME_DIRECTORY}/src) +include_directories(>>>> ${CMAKE_HOME_DIRECTORY}/src/index/knowhere) +set(MILVUS_TEST_FILES + test_naive.cpp + test_dog_segment.cpp + test_concurrent_vector.cpp + test_c_api.cpp + test_indexing.cpp + ) +add_executable(all_tests + ${MILVUS_TEST_FILES} + ) + +target_link_libraries(all_tests + gtest + gtest_main + milvus_dog_segment + knowhere + log + pthread + ) \ No newline at end of file diff --git a/core/unittest/test_c_api.cpp b/core/unittest/test_c_api.cpp index 0e312c5ec0..17ed70936c 100644 --- a/core/unittest/test_c_api.cpp +++ b/core/unittest/test_c_api.cpp @@ -7,6 +7,8 @@ #include "dog_segment/segment_c.h" + + TEST(CApiTest, CollectionTest) { auto collection_name = "collection0"; auto schema_tmp_conf = "null_schema"; @@ -137,8 +139,9 @@ TEST(CApiTest, SearchTest) { long result_ids[10]; float result_distances[10]; - auto sea_res = Search(segment, nullptr, 0, result_ids, result_distances); + auto sea_res = Search(segment, nullptr, 1, result_ids, result_distances); assert(sea_res == 0); + assert(result_ids[0] == 100911); DeleteCollection(collection); DeletePartition(partition); @@ -180,7 +183,7 @@ TEST(CApiTest, CloseTest) { } - +namespace { auto generate_data(int N) { std::vector raw_data; std::vector timestamps; @@ -202,6 +205,7 @@ auto generate_data(int N) { } return std::make_tuple(raw_data, timestamps, uids); } +} TEST(CApiTest, TestQuery) { @@ -235,6 +239,9 @@ TEST(CApiTest, TestQuery) { auto pre_off = PreDelete(segment, N / 2); Delete(segment, pre_off, N / 2, uids.data(), del_ts.data()); + Close(segment); + BuildIndex(segment); + std::vector result_ids2(10); std::vector result_distances2(10); @@ -258,6 +265,7 @@ TEST(CApiTest, TestQuery) { } } + DeleteCollection(collection); DeletePartition(partition); DeleteSegment(segment); diff --git a/core/unittest/test_concurrent_vector.cpp b/core/unittest/test_concurrent_vector.cpp new file mode 100644 index 0000000000..023e310072 --- /dev/null +++ b/core/unittest/test_concurrent_vector.cpp @@ -0,0 +1,129 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include + +#include +#include +#include +#include +#include + +#include "dog_segment/ConcurrentVector.h" +#include "dog_segment/SegmentBase.h" +// #include "knowhere/index/vector_index/helpers/IndexParameter.h" + +#include "dog_segment/SegmentBase.h" +#include "dog_segment/AckResponder.h" + +using std::cin; +using std::cout; +using std::endl; +using namespace milvus::engine; +using namespace milvus::dog_segment; +using std::vector; + +TEST(ConcurrentVector, TestABI) { + ASSERT_EQ(TestABI(), 42); + assert(true); +} + +TEST(ConcurrentVector, TestSingle) { + auto dim = 8; + ConcurrentVector c_vec(dim); + std::default_random_engine e(42); + int data = 0; + auto total_count = 0; + for (int i = 0; i < 10000; ++i) { + int insert_size = e() % 150; + vector vec(insert_size * dim); + for (auto& x : vec) { + x = data++; + } + c_vec.grow_to_at_least(total_count + insert_size); + c_vec.set_data(total_count, vec.data(), insert_size); + total_count += insert_size; + } + ASSERT_EQ(c_vec.chunk_size(), (total_count + 31) / 32); + for (int i = 0; i < total_count; ++i) { + for (int d = 0; d < dim; ++d) { + auto std_data = d + i * dim; + ASSERT_EQ(c_vec.get_element(i)[d], std_data); + } + } +} + +TEST(ConcurrentVector, TestMultithreads) { + auto dim = 8; + constexpr int threads = 16; + std::vector total_counts(threads); + + ConcurrentVector c_vec(dim); + std::atomic ack_counter = 0; + // std::mutex mutex; + + auto executor = [&](int thread_id) { + std::default_random_engine e(42 + thread_id); + int64_t data = 0; + int64_t total_count = 0; + for (int i = 0; i < 10000; ++i) { + // std::lock_guard lck(mutex); + int insert_size = e() % 150; + vector vec(insert_size * dim); + for (auto& x : vec) { + x = data++ * threads + thread_id; + } + auto offset = ack_counter.fetch_add(insert_size); + c_vec.grow_to_at_least(offset + insert_size); + c_vec.set_data(offset, vec.data(), insert_size); + total_count += insert_size; + } + assert(data == total_count * dim); + total_counts[thread_id] = total_count; + }; + std::vector pool; + for (int i = 0; i < threads; ++i) { + pool.emplace_back(executor, i); + } + for (auto& thread : pool) { + thread.join(); + } + + std::vector counts(threads); + auto N = ack_counter.load(); + for (int64_t i = 0; i < N; ++i) { + for (int d = 0; d < dim; ++d) { + auto data = c_vec.get_element(i)[d]; + auto thread_id = data % threads; + auto raw_data = data / threads; + auto std_data = counts[thread_id]++; + ASSERT_EQ(raw_data, std_data) << data; + } + } +} +TEST(ConcurrentVector, TestAckSingle) { + std::vector> raw_data; + std::default_random_engine e(42); + AckResponder ack; + int N = 10000; + for(int i = 0; i < 10000; ++i) { + auto weight = i + e() % 100; + raw_data.emplace_back(weight, i, (i + 1)); + } + std::sort(raw_data.begin(), raw_data.end()); + for(auto [_, b, e]: raw_data) { + EXPECT_LE(ack.GetAck(), b); + ack.AddSegment(b, e); + auto seg = ack.GetAck(); + EXPECT_GE(seg + 100, b); + } + EXPECT_EQ(ack.GetAck(), N); +} diff --git a/core/unittest/test_dog_segment.cpp b/core/unittest/test_dog_segment.cpp index e8e9a957b5..b50addd8b3 100644 --- a/core/unittest/test_dog_segment.cpp +++ b/core/unittest/test_dog_segment.cpp @@ -9,71 +9,21 @@ // is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express // or implied. See the License for the specific language governing permissions and limitations under the License. -// #include -// #include -// #include +#include #include #include -// #include "db/SnapshotVisitor.h" -// #include "db/Types.h" -// #include "db/snapshot/IterateHandler.h" -// #include "db/snapshot/Resources.h" -// #include "db/utils.h" // #include "knowhere/index/vector_index/helpers/IndexParameter.h" // #include "segment/SegmentReader.h" // #include "segment/SegmentWriter.h" -// #include "src/dog_segment/SegmentBase.h" +#include "dog_segment/SegmentBase.h" // #include "utils/Json.h" #include -#include -#include "dog_segment/SegmentBase.h" using std::cin; using std::cout; using std::endl; -// using SegmentVisitor = milvus::engine::SegmentVisitor; - -// namespace { -// milvus::Status -// CreateCollection(std::shared_ptr db, const std::string& collection_name, const LSN_TYPE& lsn) { -// CreateCollectionContext context; -// context.lsn = lsn; -// auto collection_schema = std::make_shared(collection_name); -// context.collection = collection_schema; - -// int64_t collection_id = 0; -// int64_t field_id = 0; -// /* field uid */ -// auto uid_field = std::make_shared(milvus::engine::FIELD_UID, 0, milvus::engine::DataType::INT64, -// milvus::engine::snapshot::JEmpty, field_id); -// auto uid_field_element_blt = -// std::make_shared(collection_id, field_id, milvus::engine::ELEMENT_BLOOM_FILTER, -// milvus::engine::FieldElementType::FET_BLOOM_FILTER); -// auto uid_field_element_del = -// std::make_shared(collection_id, field_id, milvus::engine::ELEMENT_DELETED_DOCS, -// milvus::engine::FieldElementType::FET_DELETED_DOCS); - -// field_id++; -// /* field vector */ -// milvus::json vector_param = {{milvus::knowhere::meta::DIM, 4}}; -// auto vector_field = -// std::make_shared("vector", 0, milvus::engine::DataType::VECTOR_FLOAT, vector_param, field_id); -// auto vector_field_element_index = -// std::make_shared(collection_id, field_id, milvus::knowhere::IndexEnum::INDEX_FAISS_IVFSQ8, -// milvus::engine::FieldElementType::FET_INDEX); -// /* another field*/ -// auto int_field = std::make_shared("int", 0, milvus::engine::DataType::INT32, -// milvus::engine::snapshot::JEmpty, field_id++); - -// context.fields_schema[uid_field] = {uid_field_element_blt, uid_field_element_del}; -// context.fields_schema[vector_field] = {vector_field_element_index}; -// context.fields_schema[int_field] = {}; - -// return db->CreateCollection(context); -// } -// } // namespace TEST(DogSegmentTest, TestABI) { using namespace milvus::engine; @@ -82,60 +32,6 @@ TEST(DogSegmentTest, TestABI) { assert(true); } -// TEST_F(DogSegmentTest, TestCreateAndSchema) { -// using namespace milvus::engine; -// using namespace milvus::dog_segment; -// // step1: create segment from current snapshot. - -// LSN_TYPE lsn = 0; -// auto next_lsn = [&]() -> decltype(lsn) { return ++lsn; }; - -// // step 1.1: create collection -// std::string db_root = "/tmp/milvus_test/db/table"; -// std::string collection_name = "c1"; -// auto status = CreateCollection(db_, collection_name, next_lsn()); -// ASSERT_TRUE(status.ok()); - -// // step 1.2: get snapshot -// ScopedSnapshotT snapshot; -// status = Snapshots::GetInstance().GetSnapshot(snapshot, collection_name); -// ASSERT_TRUE(status.ok()); -// ASSERT_TRUE(snapshot); -// ASSERT_EQ(snapshot->GetName(), collection_name); - -// // step 1.3: get partition_id -// cout << endl; -// cout << endl; -// ID_TYPE partition_id = snapshot->GetResources().begin()->first; -// cout << partition_id; - -// // step 1.5 create schema from ids -// auto collection = snapshot->GetCollection(); - -// auto field_names = snapshot->GetFieldNames(); -// auto schema = std::make_shared(); -// for (const auto& field_name : field_names) { -// auto the_field = snapshot->GetField(field_name); -// auto param = the_field->GetParams(); -// auto type = the_field->GetFtype(); -// cout << field_name // -// << " " << (int)type // -// << " " << param // -// << endl; -// FieldMeta field(field_name, type); -// int dim = 1; -// if(field.is_vector()) { -// field.set_dim(dim); -// } -// schema->AddField(field); - -// } -// // step 1.6 create a segment from ids -// auto segment = CreateSegment(schema); -// std::vector primary_ids; -// } - - TEST(DogSegmentTest, MockTest) { using namespace milvus::dog_segment; @@ -145,7 +41,7 @@ TEST(DogSegmentTest, MockTest) { schema->AddField("age", DataType::INT32); std::vector raw_data; std::vector timestamps; - std::vector uids; + std::vector uids; int N = 10000; std::default_random_engine e(67); for(int i = 0; i < N; ++i) { @@ -163,108 +59,18 @@ TEST(DogSegmentTest, MockTest) { auto line_sizeof = (sizeof(int) + sizeof(float) * 16); assert(raw_data.size() == line_sizeof * N); - auto segment = CreateSegment(schema).release(); + + // auto index_meta = std::make_shared(schema); + auto segment = CreateSegment(schema, nullptr); + DogDataChunk data_chunk{raw_data.data(), (int)line_sizeof, N}; - segment->Insert(N, uids.data(), timestamps.data(), data_chunk); + auto offset = segment->PreInsert(N); + segment->Insert(offset, N, uids.data(), timestamps.data(), data_chunk); QueryResult query_result; - segment->Query(nullptr, 0, query_result); - delete segment; +// segment->Query(nullptr, 0, query_result); + segment->Close(); +// segment->BuildIndex(); int i = 0; i++; } -//TEST_F(DogSegmentTest, DogSegmentTest) { -// LSN_TYPE lsn = 0; -// auto next_lsn = [&]() -> decltype(lsn) { return ++lsn; }; -// -// std::string db_root = "/tmp/milvus_test/db/table"; -// std::string c1 = "c1"; -// auto status = CreateCollection(db_, c1, next_lsn()); -// ASSERT_TRUE(status.ok()); -// -// ScopedSnapshotT snapshot; -// status = Snapshots::GetInstance().GetSnapshot(snapshot, c1); -// ASSERT_TRUE(status.ok()); -// ASSERT_TRUE(snapshot); -// ASSERT_EQ(snapshot->GetName(), c1); -// { -// SegmentFileContext sf_context; -// SFContextBuilder(sf_context, snapshot); -// } -// std::vector segfile_ctxs; -// SFContextsBuilder(segfile_ctxs, snapshot); -// -// std::cout << snapshot->ToString() << std::endl; -// -// ID_TYPE partition_id; -// { -// auto& partitions = snapshot->GetResources(); -// partition_id = partitions.begin()->first; -// } -// -// [&next_lsn, // -// &segfile_ctxs, // -// &partition_id, // -// &snapshot, // -// &db_root] { -// /* commit new segment */ -// OperationContext op_ctx; -// op_ctx.lsn = next_lsn(); -// op_ctx.prev_partition = snapshot->GetResource(partition_id); -// -// auto new_seg_op = std::make_shared(op_ctx, snapshot); -// SegmentPtr new_seg; -// auto status = new_seg_op->CommitNewSegment(new_seg); -// ASSERT_TRUE(status.ok()); -// -// /* commit new segment file */ -// for (auto& cctx : segfile_ctxs) { -// SegmentFilePtr seg_file; -// auto nsf_context = cctx; -// nsf_context.segment_id = new_seg->GetID(); -// nsf_context.partition_id = new_seg->GetPartitionId(); -// status = new_seg_op->CommitNewSegmentFile(nsf_context, seg_file); -// } -// -// /* build segment visitor */ -// auto ctx = new_seg_op->GetContext(); -// ASSERT_TRUE(ctx.new_segment); -// auto visitor = SegmentVisitor::Build(snapshot, ctx.new_segment, ctx.new_segment_files); -// ASSERT_TRUE(visitor); -// ASSERT_EQ(visitor->GetSegment(), new_seg); -// ASSERT_FALSE(visitor->GetSegment()->IsActive()); -// // std::cout << visitor->ToString() << std::endl; -// // std::cout << snapshot->ToString() << std::endl; -// -// /* write data */ -// milvus::segment::SegmentWriter segment_writer(db_root, visitor); -// -// // std::vector raw_uids = {123}; -// // std::vector raw_vectors = {1, 2, 3, 4}; -// // status = segment_writer.AddChunk("test", raw_vectors, raw_uids); -// // ASSERT_TRUE(status.ok()) -// // -// // status = segment_writer.Serialize(); -// // ASSERT_TRUE(status.ok()); -// -// /* read data */ -// // milvus::segment::SSSegmentReader segment_reader(db_root, visitor); -// // -// // status = segment_reader.Load(); -// // ASSERT_TRUE(status.ok()); -// // -// // milvus::segment::SegmentPtr segment_ptr; -// // status = segment_reader.GetSegment(segment_ptr); -// // ASSERT_TRUE(status.ok()); -// // -// // auto& out_uids = segment_ptr->vectors_ptr_->GetUids(); -// // ASSERT_EQ(raw_uids.size(), out_uids.size()); -// // ASSERT_EQ(raw_uids[0], out_uids[0]); -// // auto& out_vectors = segment_ptr->vectors_ptr_->GetData(); -// // ASSERT_EQ(raw_vectors.size(), out_vectors.size()); -// // ASSERT_EQ(raw_vectors[0], out_vectors[0]); -// }(); -// -// status = db_->DropCollection(c1); -// ASSERT_TRUE(status.ok()); -//} diff --git a/core/unittest/test_indexing.cpp b/core/unittest/test_indexing.cpp new file mode 100644 index 0000000000..f83d31f9ff --- /dev/null +++ b/core/unittest/test_indexing.cpp @@ -0,0 +1,93 @@ +#include + +#include +#include +#include +#include +#include + +#include "dog_segment/ConcurrentVector.h" +#include "dog_segment/SegmentBase.h" +// #include "knowhere/index/vector_index/helpers/IndexParameter.h" + +#include "dog_segment/SegmentBase.h" +#include "dog_segment/AckResponder.h" + +#include +#include +#include + +using std::cin; +using std::cout; +using std::endl; +using namespace milvus::engine; +using namespace milvus::dog_segment; +using std::vector; +using namespace milvus; + +namespace { + template + auto generate_data(int N) { + std::vector raw_data; + std::vector timestamps; + std::vector uids; + std::default_random_engine er(42); + std::uniform_real_distribution<> distribution(0.0, 1.0); + std::default_random_engine ei(42); + for (int i = 0; i < N; ++i) { + uids.push_back(10 * N + i); + timestamps.push_back(0); + // append vec + float vec[DIM]; + for (auto &x: vec) { + x = distribution(er); + } + raw_data.insert(raw_data.end(), (const char *) std::begin(vec), (const char *) std::end(vec)); +// int age = ei() % 100; +// raw_data.insert(raw_data.end(), (const char *) &age, ((const char *) &age) + sizeof(age)); + } + return std::make_tuple(raw_data, timestamps, uids); + } +} + +TEST(TestIndex, Naive) { + constexpr int N = 100000; + constexpr int DIM = 16; + constexpr int TOPK = 10; + + auto[raw_data, timestamps, uids] = generate_data(N); + auto index = knowhere::VecIndexFactory::GetInstance().CreateVecIndex(knowhere::IndexEnum::INDEX_FAISS_IVFPQ, + knowhere::IndexMode::MODE_CPU); + auto conf = milvus::knowhere::Config{ + {milvus::knowhere::meta::DIM, DIM}, + {milvus::knowhere::meta::TOPK, TOPK}, + {milvus::knowhere::IndexParams::nlist, 100}, + {milvus::knowhere::IndexParams::nprobe, 4}, + {milvus::knowhere::IndexParams::m, 4}, + {milvus::knowhere::IndexParams::nbits, 8}, + {milvus::knowhere::Metric::TYPE, milvus::knowhere::Metric::L2}, + {milvus::knowhere::meta::DEVICEID, 0}, + }; + + auto ds = knowhere::GenDatasetWithIds(N / 2, DIM, raw_data.data(), uids.data()); + auto ds2 = knowhere::GenDatasetWithIds(N / 2, DIM, raw_data.data() + sizeof(float[DIM]) * N / 2, uids.data() + N / 2); + // NOTE: you must train first and then add + index->Train(ds, conf); + index->Train(ds2, conf); + index->Add(ds, conf); + index->Add(ds2, conf); + + auto query_ds = knowhere::GenDataset(1, DIM, raw_data.data()); + auto final = index->Query(query_ds, conf); + auto mmm = final->data(); + cout << endl; + for(auto [k, v]: mmm) { + cout << k << endl; + } + auto ids = final->Get(knowhere::meta::IDS); + auto distances = final->Get(knowhere::meta::DISTANCE); + for(int i = 0; i < TOPK; ++i) { + cout << ids[i] << "->" << distances[i] << endl; + } + int i = 1+1; +} diff --git a/pkg/master/mock/grpc_client_test.go b/pkg/master/mock/grpc_client_test.go index 4b743bddbe..a4d015a514 100644 --- a/pkg/master/mock/grpc_client_test.go +++ b/pkg/master/mock/grpc_client_test.go @@ -10,5 +10,6 @@ func TestFakeCreateCollectionByGRPC(t *testing.T) { if reason != "" { t.Error(reason) } + fmt.Println(collectionName) fmt.Println(segmentID) } diff --git a/pkg/master/server.go b/pkg/master/server.go index b3579da497..7e0c257648 100644 --- a/pkg/master/server.go +++ b/pkg/master/server.go @@ -14,18 +14,15 @@ import ( "github.com/czs007/suvlim/pkg/master/informer" "github.com/czs007/suvlim/pkg/master/kv" "github.com/czs007/suvlim/pkg/master/mock" - "github.com/google/uuid" "go.etcd.io/etcd/clientv3" "google.golang.org/grpc" ) func Run() { go mock.FakePulsarProducer() + go GRPCServer() go SegmentStatsController() - collectionChan := make(chan *messagepb.Mapping) - defer close(collectionChan) - go GRPCServer(collectionChan) - go CollectionController(collectionChan) + go CollectionController() for { } } @@ -78,13 +75,13 @@ func ComputeCloseTime(ss mock.SegmentStats, kvbase kv.Base) error { return nil } -func GRPCServer(ch chan *messagepb.Mapping) error { +func GRPCServer() error { lis, err := net.Listen("tcp", common.DEFAULT_GRPC_PORT) if err != nil { return err } s := grpc.NewServer() - pb.RegisterMasterServer(s, GRPCMasterServer{CreateRequest: ch}) + pb.RegisterMasterServer(s, GRPCMasterServer{}) if err := s.Serve(lis); err != nil { log.Fatalf("failed to serve: %v", err) return err @@ -92,13 +89,9 @@ func GRPCServer(ch chan *messagepb.Mapping) error { return nil } -type GRPCMasterServer struct { - CreateRequest chan *messagepb.Mapping -} +type GRPCMasterServer struct{} func (ms GRPCMasterServer) CreateCollection(ctx context.Context, in *messagepb.Mapping) (*messagepb.Status, error) { - ms.CreateRequest <- in - fmt.Println("Handle a new create collection request") return &messagepb.Status{ ErrorCode: 0, Reason: "", @@ -111,35 +104,26 @@ func (ms GRPCMasterServer) CreateCollection(ctx context.Context, in *messagepb.M // }, nil // } -func CollectionController(ch chan *messagepb.Mapping) { +func CollectionController() { cli, _ := clientv3.New(clientv3.Config{ Endpoints: []string{"127.0.0.1:12379"}, DialTimeout: 5 * time.Second, }) defer cli.Close() kvbase := kv.NewEtcdKVBase(cli, common.ETCD_ROOT_PATH) - for collection := range ch { - pTag := uuid.New() - cID := uuid.New() - c := mock.Collection{ - Name: collection.CollectionName, - CreateTime: time.Now(), - ID: uint64(cID.ID()), - PartitionTags: []string{pTag.String()}, - } - s := mock.FakeCreateSegment(uint64(pTag.ID()), c, time.Now(), time.Unix(1<<36-1, 0)) - collectionData, _ := mock.Collection2JSON(c) - segmentData, err := mock.Segment2JSON(s) - if err != nil { - log.Fatal(err) - } - err = kvbase.Save(cID.String(), collectionData) - if err != nil { - log.Fatal(err) - } - err = kvbase.Save(pTag.String(), segmentData) - if err != nil { - log.Fatal(err) - } + c := mock.FakeCreateCollection(uint64(3333)) + s := mock.FakeCreateSegment(uint64(11111), c, time.Now(), time.Unix(1<<36-1, 0)) + collectionData, _ := mock.Collection2JSON(c) + segmentData, err := mock.Segment2JSON(s) + if err != nil { + log.Fatal(err) + } + err = kvbase.Save("test-collection", collectionData) + if err != nil { + log.Fatal(err) + } + err = kvbase.Save("test-segment", segmentData) + if err != nil { + log.Fatal(err) } } diff --git a/proxy/src/message_client/ClientV2.cpp b/proxy/src/message_client/ClientV2.cpp index dab00da4a9..c1924022c6 100644 --- a/proxy/src/message_client/ClientV2.cpp +++ b/proxy/src/message_client/ClientV2.cpp @@ -54,17 +54,17 @@ milvus::grpc::QueryResult Aggregation(std::vector all_scores; std::vector all_distance; std::vector all_kv_pairs; - std::vector index(length * results[0]->distances_size()); + std::vector index(length * results[0]->scores_size()); - for (int n = 0; n < length * results[0]->distances_size(); ++n) { + for (int n = 0; n < length * results[0]->scores_size(); ++n) { index[n] = n; } for (int i = 0; i < length; i++) { - for (int j = 0; j < results[i]->distances_size(); j++) { -// all_scores.push_back(results[i]->scores()[j]); + for (int j = 0; j < results[i]->scores_size(); j++) { + all_scores.push_back(results[i]->scores()[j]); all_distance.push_back(results[i]->distances()[j]); -// all_kv_pairs.push_back(results[i]->extra_params()[j]); + all_kv_pairs.push_back(results[i]->extra_params()[j]); } } @@ -89,20 +89,22 @@ milvus::grpc::QueryResult Aggregation(std::vectorCopyFrom(results[0]->entities()); result.set_row_num(results[0]->row_num()); - for (int m = 0; m < results[0]->distances_size(); ++m) { -// result.add_scores(all_scores[index[m]]); + for (int m = 0; m < results[0]->scores_size(); ++m) { + result.add_scores(all_scores[index[m]]); result.add_distances(all_distance[m]); -// result.add_extra_params(); -// result.mutable_extra_params(m)->CopyFrom(all_kv_pairs[index[m]]); + result.add_extra_params(); + result.mutable_extra_params(m)->CopyFrom(all_kv_pairs[index[m]]); } -// result.set_query_id(results[0]->query_id()); -// result.set_client_id(results[0]->client_id()); + result.set_query_id(results[0]->query_id()); + result.set_client_id(results[0]->client_id()); return result; } -Status MsgClientV2::GetQueryResult(int64_t query_id, milvus::grpc::QueryResult* result) { +Status MsgClientV2::GetQueryResult(int64_t query_id, milvus::grpc::QueryResult &result) { + + std::vector> results; int64_t query_node_num = GetQueryNodeNum(); @@ -124,7 +126,7 @@ Status MsgClientV2::GetQueryResult(int64_t query_id, milvus::grpc::QueryResult* return Status(DB_ERROR, "can't parse message which from pulsar!"); } } - *result = Aggregation(total_results[query_id]); + result = Aggregation(total_results[query_id]); return Status::OK(); } diff --git a/proxy/src/message_client/ClientV2.h b/proxy/src/message_client/ClientV2.h index 67d8184241..3af1e798f4 100644 --- a/proxy/src/message_client/ClientV2.h +++ b/proxy/src/message_client/ClientV2.h @@ -31,7 +31,7 @@ class MsgClientV2 { // Status SendQueryMessage(const milvus::grpc::SearchParam &request, uint64_t timestamp, int64_t &query_id); - Status GetQueryResult(int64_t query_id, milvus::grpc::QueryResult* result); + Status GetQueryResult(int64_t query_id, milvus::grpc::QueryResult &result); private: int64_t GetUniqueQId() { diff --git a/proxy/src/server/delivery/request/SearchReq.cpp b/proxy/src/server/delivery/request/SearchReq.cpp index d5afd298f8..289776569d 100644 --- a/proxy/src/server/delivery/request/SearchReq.cpp +++ b/proxy/src/server/delivery/request/SearchReq.cpp @@ -58,7 +58,7 @@ SearchReq::OnExecute() { return send_status; } - Status status = client->GetQueryResult(query_id, result_); + Status status = client->GetQueryResult(query_id, *result_); return status; } diff --git a/proxy/thirdparty/grpc/CMakeLists.txt b/proxy/thirdparty/grpc/CMakeLists.txt index bfe7d99481..e41ec9d524 100644 --- a/proxy/thirdparty/grpc/CMakeLists.txt +++ b/proxy/thirdparty/grpc/CMakeLists.txt @@ -65,7 +65,7 @@ add_custom_command(TARGET generate_suvlim_pb_grpc POST_BUILD COMMAND echo "${PROTOC_EXCUTABLE}" COMMAND bash "${PROTO_GEN_SCRIPTS_DIR}/generate_go.sh" -p "${PROTOC_EXCUTABLE}" - COMMAND bash "${PROTO_GEN_SCRIPTS_DIR}/generate_cpp.sh" -p "${PROTOC_EXCUTABLE}" -g "${GRPC_CPP_PLUGIN_EXCUTABLE}" + COMMAND echo "${PROTO_GEN_SCRIPTS_DIR}/generate_cpp.sh" -p "${PROTOC_EXCUTABLE}" -g "${GRPC_CPP_PLUGIN_EXCUTABLE}" ) set_property( GLOBAL PROPERTY PROTOC_EXCUTABLE ${PROTOC_EXCUTABLE}) diff --git a/reader/index.go b/reader/index.go index 28d376c040..4831fa6bba 100644 --- a/reader/index.go +++ b/reader/index.go @@ -1,7 +1,7 @@ package reader import ( - msgPb "github.com/czs007/suvlim/pkg/master/grpc/message" + msgPb "github.com/czs007/suvlim/pkg/message" ) type IndexConfig struct {} diff --git a/reader/message_client/message_client.go b/reader/message_client/message_client.go index eeeb3b825a..3ab04497fe 100644 --- a/reader/message_client/message_client.go +++ b/reader/message_client/message_client.go @@ -2,12 +2,10 @@ package message_client import ( "context" - "fmt" "github.com/apache/pulsar-client-go/pulsar" - msgpb "github.com/czs007/suvlim/pkg/master/grpc/message" + msgpb "github.com/czs007/suvlim/pkg/message" "github.com/golang/protobuf/proto" "log" - "time" ) type MessageClient struct { @@ -34,21 +32,14 @@ type MessageClient struct { } func (mc *MessageClient) Send(ctx context.Context, msg msgpb.QueryResult) { - var msgBuffer, _ = proto.Marshal(&msg) if _, err := mc.searchResultProducer.Send(ctx, &pulsar.ProducerMessage{ - Payload: msgBuffer, + Payload: []byte(msg.String()), }); err != nil { log.Fatal(err) } } -func (mc *MessageClient) GetSearchChan() chan *msgpb.SearchMsg { - return mc.searchChan -} - func (mc *MessageClient) ReceiveInsertOrDeleteMsg() { - var count = 0 - var start time.Time for { insetOrDeleteMsg := msgpb.InsertOrDeleteMsg{} msg, err := mc.insertOrDeleteConsumer.Receive(context.Background()) @@ -56,16 +47,8 @@ func (mc *MessageClient) ReceiveInsertOrDeleteMsg() { if err != nil { log.Fatal(err) } - if count == 0 { - start = time.Now() - } - count++ mc.insertOrDeleteChan <- &insetOrDeleteMsg mc.insertOrDeleteConsumer.Ack(msg) - if count == 100000 - 1 { - elapsed := time.Since(start) - fmt.Println("Query node ReceiveInsertOrDeleteMsg time:", elapsed) - } } } @@ -112,7 +95,6 @@ func (mc *MessageClient) ReceiveMessage() { go mc.ReceiveInsertOrDeleteMsg() go mc.ReceiveSearchMsg() go mc.ReceiveTimeSyncMsg() - go mc.ReceiveKey2SegMsg() } func (mc *MessageClient) CreatProducer(topicName string) pulsar.Producer { @@ -215,30 +197,21 @@ func (mc *MessageClient) PrepareMsg(messageType MessageType, msgLen int) { } } -func (mc *MessageClient) PrepareKey2SegmentMsg() { - mc.Key2SegMsg = mc.Key2SegMsg[:0] - msgLen := len(mc.key2SegChan) - for i := 0; i < msgLen; i++ { - msg := <-mc.key2SegChan - mc.Key2SegMsg = append(mc.Key2SegMsg, msg) - } -} - func (mc *MessageClient) PrepareBatchMsg() []int { // assume the channel not full mc.InsertOrDeleteMsg = mc.InsertOrDeleteMsg[:0] - //mc.SearchMsg = mc.SearchMsg[:0] + mc.SearchMsg = mc.SearchMsg[:0] mc.TimeSyncMsg = mc.TimeSyncMsg[:0] // get the length of every channel insertOrDeleteLen := len(mc.insertOrDeleteChan) - //searchLen := len(mc.searchChan) + searchLen := len(mc.searchChan) timeLen := len(mc.timeSyncChan) // get message from channel to slice mc.PrepareMsg(InsertOrDelete, insertOrDeleteLen) - //mc.PrepareMsg(Search, searchLen) + mc.PrepareMsg(Search, searchLen) mc.PrepareMsg(TimeSync, timeLen) - return []int{insertOrDeleteLen} + return []int{insertOrDeleteLen, searchLen, timeLen} } diff --git a/reader/query_node.go b/reader/query_node.go index 890780db5b..f5ace7473c 100644 --- a/reader/query_node.go +++ b/reader/query_node.go @@ -15,36 +15,34 @@ import "C" import ( "fmt" - msgPb "github.com/czs007/suvlim/pkg/master/grpc/message" + msgPb "github.com/czs007/suvlim/pkg/message" "github.com/czs007/suvlim/reader/message_client" "sort" "sync" - "sync/atomic" - "time" ) type InsertData struct { - insertIDs map[int64][]int64 - insertTimestamps map[int64][]uint64 - insertRecords map[int64][][]byte - insertOffset map[int64]int64 + insertIDs map[int64][]int64 + insertTimestamps map[int64][]uint64 + insertRecords map[int64][][]byte + insertOffset map[int64]int64 } type DeleteData struct { - deleteIDs map[int64][]int64 - deleteTimestamps map[int64][]uint64 - deleteOffset map[int64]int64 + deleteIDs map[int64][]int64 + deleteTimestamps map[int64][]uint64 + deleteOffset map[int64]int64 } type DeleteRecord struct { - entityID int64 - timestamp uint64 - segmentID int64 + entityID int64 + timestamp uint64 + segmentID int64 } type DeletePreprocessData struct { - deleteRecords []*DeleteRecord - count int32 + deleteRecords []*DeleteRecord + count chan int } type QueryNodeDataBuffer struct { @@ -62,7 +60,7 @@ type QueryNode struct { queryNodeTimeSync *QueryNodeTime buffer QueryNodeDataBuffer deletePreprocessData DeletePreprocessData - deleteData DeleteData + deleteData DeleteData insertData InsertData } @@ -79,47 +77,15 @@ func NewQueryNode(queryNodeId uint64, timeSync uint64) *QueryNode { segmentsMap := make(map[int64]*Segment) - buffer := QueryNodeDataBuffer{ - InsertDeleteBuffer: make([]*msgPb.InsertOrDeleteMsg, 0), - SearchBuffer: make([]*msgPb.SearchMsg, 0), - validInsertDeleteBuffer: make([]bool, 0), - validSearchBuffer: make([]bool, 0), - } - return &QueryNode{ - QueryNodeId: queryNodeId, - Collections: nil, - SegmentsMap: segmentsMap, - messageClient: mc, - queryNodeTimeSync: queryNodeTimeSync, - buffer: buffer, + QueryNodeId: queryNodeId, + Collections: nil, + SegmentsMap: segmentsMap, + messageClient: mc, + queryNodeTimeSync: queryNodeTimeSync, } } -func (node *QueryNode) QueryNodeDataInit() { - deletePreprocessData := DeletePreprocessData{ - deleteRecords: make([]*DeleteRecord, 0), - count: 0, - } - - deleteData := DeleteData{ - deleteIDs: make(map[int64][]int64), - deleteTimestamps: make(map[int64][]uint64), - deleteOffset: make(map[int64]int64), - } - - insertData := InsertData{ - insertIDs: make(map[int64][]int64), - insertTimestamps: make(map[int64][]uint64), - insertRecords: make(map[int64][][]byte), - insertOffset: make(map[int64]int64), - } - - node.deletePreprocessData = deletePreprocessData - node.deleteData = deleteData - node.insertData = insertData -} - func (node *QueryNode) NewCollection(collectionName string, schemaConfig string) *Collection { cName := C.CString(collectionName) cSchema := C.CString(schemaConfig) @@ -140,14 +106,13 @@ func (node *QueryNode) DeleteCollection(collection *Collection) { //////////////////////////////////////////////////////////////////////////////////////////////////// -func (node *QueryNode) PrepareBatchMsg() []int { - var msgLen = node.messageClient.PrepareBatchMsg() - return msgLen +func (node *QueryNode) PrepareBatchMsg() { + node.messageClient.PrepareBatchMsg() } func (node *QueryNode) StartMessageClient() { // TODO: add consumerMsgSchema - node.messageClient.InitClient("pulsar://192.168.2.28:6650") + node.messageClient.InitClient("pulsar://localhost:6650") go node.messageClient.ReceiveMessage() } @@ -158,93 +123,53 @@ func (node *QueryNode) InitQueryNodeCollection() { var newCollection = node.NewCollection("collection1", "fakeSchema") var newPartition = newCollection.NewPartition("partition1") // TODO: add segment id - var segment = newPartition.NewSegment(0) - node.SegmentsMap[0] = segment + var _ = newPartition.NewSegment(0) } //////////////////////////////////////////////////////////////////////////////////////////////////// func (node *QueryNode) RunInsertDelete() { - var count = 0 - var start time.Time for { - //time.Sleep(2 * 1000 * time.Millisecond) - node.QueryNodeDataInit() // TODO: get timeRange from message client var timeRange = TimeRange{0, 0} - var msgLen = node.PrepareBatchMsg() - //fmt.Println("PrepareBatchMsg Done, Insert len = ", msgLen[0]) - if msgLen[0] == 0 { - //fmt.Println("0 msg found") - continue - } - if count == 0 { - start = time.Now() - } - count+=msgLen[0] + node.PrepareBatchMsg() node.MessagesPreprocess(node.messageClient.InsertOrDeleteMsg, timeRange) - //fmt.Println("MessagesPreprocess Done") node.WriterDelete() node.PreInsertAndDelete() - //fmt.Println("PreInsertAndDelete Done") node.DoInsertAndDelete() - //fmt.Println("DoInsertAndDelete Done") node.queryNodeTimeSync.UpdateSearchTimeSync(timeRange) - //fmt.Print("UpdateSearchTimeSync Done\n\n\n") - if count == 100000 - 1 { - elapsed := time.Since(start) - fmt.Println("Query node insert 10 × 10000 time:", elapsed) - } } } func (node *QueryNode) RunSearch() { for { - //time.Sleep(2 * 1000 * time.Millisecond) - - start := time.Now() - - if len(node.messageClient.GetSearchChan()) <= 0 { - //fmt.Println("null Search") - continue - } - node.messageClient.SearchMsg = node.messageClient.SearchMsg[:0] - msg := <-node.messageClient.GetSearchChan() - node.messageClient.SearchMsg = append(node.messageClient.SearchMsg, msg) - fmt.Println("Do Search...") node.Search(node.messageClient.SearchMsg) - - elapsed := time.Since(start) - fmt.Println("Query node search time:", elapsed) } } //////////////////////////////////////////////////////////////////////////////////////////////////// func (node *QueryNode) MessagesPreprocess(insertDeleteMessages []*msgPb.InsertOrDeleteMsg, timeRange TimeRange) msgPb.Status { - //var tMax = timeRange.timestampMax + var tMax = timeRange.timestampMax // 1. Extract messages before readTimeSync from QueryNodeDataBuffer. // Set valid bitmap to false. for i, msg := range node.buffer.InsertDeleteBuffer { - //if msg.Timestamp < tMax { - if msg.Op == msgPb.OpType_INSERT { - if msg.RowsData == nil { - continue + if msg.Timestamp < tMax { + if msg.Op == msgPb.OpType_INSERT { + node.insertData.insertIDs[msg.SegmentId] = append(node.insertData.insertIDs[msg.SegmentId], msg.Uid) + node.insertData.insertTimestamps[msg.SegmentId] = append(node.insertData.insertTimestamps[msg.SegmentId], msg.Timestamp) + node.insertData.insertRecords[msg.SegmentId] = append(node.insertData.insertRecords[msg.SegmentId], msg.RowsData.Blob) + } else if msg.Op == msgPb.OpType_DELETE { + var r = DeleteRecord { + entityID: msg.Uid, + timestamp: msg.Timestamp, + } + node.deletePreprocessData.deleteRecords = append(node.deletePreprocessData.deleteRecords, &r) + node.deletePreprocessData.count <- <- node.deletePreprocessData.count + 1 } - node.insertData.insertIDs[msg.SegmentId] = append(node.insertData.insertIDs[msg.SegmentId], msg.Uid) - node.insertData.insertTimestamps[msg.SegmentId] = append(node.insertData.insertTimestamps[msg.SegmentId], msg.Timestamp) - node.insertData.insertRecords[msg.SegmentId] = append(node.insertData.insertRecords[msg.SegmentId], msg.RowsData.Blob) - } else if msg.Op == msgPb.OpType_DELETE { - var r = DeleteRecord{ - entityID: msg.Uid, - timestamp: msg.Timestamp, - } - node.deletePreprocessData.deleteRecords = append(node.deletePreprocessData.deleteRecords, &r) - atomic.AddInt32(&node.deletePreprocessData.count, 1) + node.buffer.validInsertDeleteBuffer[i] = false } - node.buffer.validInsertDeleteBuffer[i] = false - //} } // 2. Remove invalid messages from buffer. @@ -260,26 +185,23 @@ func (node *QueryNode) MessagesPreprocess(insertDeleteMessages []*msgPb.InsertOr // Move massages after readTimeSync to QueryNodeDataBuffer. // Set valid bitmap to true. for _, msg := range insertDeleteMessages { - //if msg.Timestamp < tMax { - if msg.Op == msgPb.OpType_INSERT { - if msg.RowsData == nil { - continue + if msg.Timestamp < tMax { + if msg.Op == msgPb.OpType_INSERT { + node.insertData.insertIDs[msg.SegmentId] = append(node.insertData.insertIDs[msg.SegmentId], msg.Uid) + node.insertData.insertTimestamps[msg.SegmentId] = append(node.insertData.insertTimestamps[msg.SegmentId], msg.Timestamp) + node.insertData.insertRecords[msg.SegmentId] = append(node.insertData.insertRecords[msg.SegmentId], msg.RowsData.Blob) + } else if msg.Op == msgPb.OpType_DELETE { + var r = DeleteRecord { + entityID: msg.Uid, + timestamp: msg.Timestamp, + } + node.deletePreprocessData.deleteRecords = append(node.deletePreprocessData.deleteRecords, &r) + node.deletePreprocessData.count <- <- node.deletePreprocessData.count + 1 } - node.insertData.insertIDs[msg.SegmentId] = append(node.insertData.insertIDs[msg.SegmentId], msg.Uid) - node.insertData.insertTimestamps[msg.SegmentId] = append(node.insertData.insertTimestamps[msg.SegmentId], msg.Timestamp) - node.insertData.insertRecords[msg.SegmentId] = append(node.insertData.insertRecords[msg.SegmentId], msg.RowsData.Blob) - } else if msg.Op == msgPb.OpType_DELETE { - var r = DeleteRecord{ - entityID: msg.Uid, - timestamp: msg.Timestamp, - } - node.deletePreprocessData.deleteRecords = append(node.deletePreprocessData.deleteRecords, &r) - atomic.AddInt32(&node.deletePreprocessData.count, 1) + } else { + node.buffer.InsertDeleteBuffer = append(node.buffer.InsertDeleteBuffer, msg) + node.buffer.validInsertDeleteBuffer = append(node.buffer.validInsertDeleteBuffer, true) } - //} else { - // node.buffer.InsertDeleteBuffer = append(node.buffer.InsertDeleteBuffer, msg) - // node.buffer.validInsertDeleteBuffer = append(node.buffer.validInsertDeleteBuffer, true) - //} } return msgPb.Status{ErrorCode: msgPb.ErrorCode_SUCCESS} @@ -288,22 +210,21 @@ func (node *QueryNode) MessagesPreprocess(insertDeleteMessages []*msgPb.InsertOr func (node *QueryNode) WriterDelete() msgPb.Status { // TODO: set timeout for { - if node.deletePreprocessData.count == 0 { - return msgPb.Status{ErrorCode: msgPb.ErrorCode_SUCCESS} - } - node.messageClient.PrepareKey2SegmentMsg() var ids, timestamps, segmentIDs = node.GetKey2Segments() - for i := 0; i < len(*ids); i++ { + for i := 0; i <= len(*ids); i++ { id := (*ids)[i] timestamp := (*timestamps)[i] segmentID := (*segmentIDs)[i] for _, r := range node.deletePreprocessData.deleteRecords { if r.timestamp == timestamp && r.entityID == id { r.segmentID = segmentID - atomic.AddInt32(&node.deletePreprocessData.count, -1) + node.deletePreprocessData.count <- <- node.deletePreprocessData.count - 1 } } } + if <- node.deletePreprocessData.count == 0 { + return msgPb.Status{ErrorCode: msgPb.ErrorCode_SUCCESS} + } } } @@ -355,7 +276,6 @@ func (node *QueryNode) DoInsertAndDelete() msgPb.Status { for segmentID, deleteIDs := range node.deleteData.deleteIDs { wg.Add(1) var deleteTimestamps = node.deleteData.deleteTimestamps[segmentID] - fmt.Println("Doing delete......") go node.DoDelete(segmentID, &deleteIDs, &deleteTimestamps, &wg) } @@ -404,11 +324,11 @@ func (node *QueryNode) DoDelete(segmentID int64, deleteIDs *[]int64, deleteTimes } func (node *QueryNode) Search(searchMessages []*msgPb.SearchMsg) msgPb.Status { - var clientId = (*(searchMessages[0])).ClientId + var clientId = searchMessages[0].ClientId type SearchResultTmp struct { - ResultId int64 - ResultDistance float32 + ResultId int64 + ResultDistance float32 } // Traverse all messages in the current messageClient. @@ -421,7 +341,7 @@ func (node *QueryNode) Search(searchMessages []*msgPb.SearchMsg) msgPb.Status { return msgPb.Status{ErrorCode: 1} } - var resultsTmp = make([]SearchResultTmp, 0) + var resultsTmp []SearchResultTmp // TODO: get top-k's k from queryString const TopK = 1 @@ -430,9 +350,9 @@ func (node *QueryNode) Search(searchMessages []*msgPb.SearchMsg) msgPb.Status { // 1. Timestamp check // TODO: return or wait? Or adding graceful time - //if timestamp > node.queryNodeTimeSync.SearchTimeSync { - // return msgPb.Status{ErrorCode: 1} - //} + if timestamp > node.queryNodeTimeSync.SearchTimeSync { + return msgPb.Status{ErrorCode: 1} + } // 2. Do search in all segments for _, partition := range targetCollection.Partitions { @@ -442,8 +362,7 @@ func (node *QueryNode) Search(searchMessages []*msgPb.SearchMsg) msgPb.Status { fmt.Println(err.Error()) return msgPb.Status{ErrorCode: 1} } - fmt.Println(res.ResultIds) - for i := 0; i < len(res.ResultIds); i++ { + for i := 0; i <= len(res.ResultIds); i++ { resultsTmp = append(resultsTmp, SearchResultTmp{ResultId: res.ResultIds[i], ResultDistance: res.ResultDistances[i]}) } } @@ -464,22 +383,12 @@ func (node *QueryNode) Search(searchMessages []*msgPb.SearchMsg) msgPb.Status { return resultsTmp[i].ResultDistance < resultsTmp[j].ResultDistance }) resultsTmp = resultsTmp[:TopK] - var entities = msgPb.Entities{ - Ids: make([]int64, 0), - } - var results = msgPb.QueryResult{ - Entities: &entities, - Distances: make([]float32, 0), - QueryId: msg.Uid, - } + var results msgPb.QueryResult for _, res := range resultsTmp { results.Entities.Ids = append(results.Entities.Ids, res.ResultId) results.Distances = append(results.Distances, res.ResultDistance) - results.Scores = append(results.Distances, float32(0)) } - results.RowNum = int64(len(results.Distances)) - // 3. publish result to pulsar node.PublishSearchResult(&results, clientId) } diff --git a/reader/reader.go b/reader/reader.go index 6a1b612c41..dc9c58fe8a 100644 --- a/reader/reader.go +++ b/reader/reader.go @@ -3,9 +3,9 @@ package reader func startQueryNode() { qn := NewQueryNode(0, 0) qn.InitQueryNodeCollection() - //go qn.SegmentService() + go qn.SegmentService() qn.StartMessageClient() + go qn.RunInsertDelete() go qn.RunSearch() - qn.RunInsertDelete() } diff --git a/reader/result.go b/reader/result.go index 785cc3116e..8c4129e540 100644 --- a/reader/result.go +++ b/reader/result.go @@ -3,7 +3,7 @@ package reader import ( "context" "fmt" - msgPb "github.com/czs007/suvlim/pkg/master/grpc/message" + msgPb "github.com/czs007/suvlim/pkg/message" "strconv" ) diff --git a/reader/result_test.go b/reader/result_test.go index 5ce12ba873..af854d29ad 100644 --- a/reader/result_test.go +++ b/reader/result_test.go @@ -1,7 +1,7 @@ package reader import ( - msgPb "github.com/czs007/suvlim/pkg/master/grpc/message" + msgPb "github.com/czs007/suvlim/pkg/message" "testing" ) diff --git a/reader/segment.go b/reader/segment.go index 420f95d087..43c5269993 100644 --- a/reader/segment.go +++ b/reader/segment.go @@ -13,9 +13,8 @@ package reader */ import "C" import ( - "fmt" "github.com/czs007/suvlim/errors" - schema "github.com/czs007/suvlim/pkg/master/grpc/message" + schema "github.com/czs007/suvlim/pkg/message" "strconv" "unsafe" ) @@ -110,19 +109,16 @@ func (s *Segment) SegmentInsert(offset int64, entityIDs *[]int64, timestamps *[] signed long int count); */ // Blobs to one big blob - var numOfRow = len(*entityIDs) - var sizeofPerRow = len((*records)[0]) - - var rawData = make([]byte, numOfRow * sizeofPerRow) + var rawData []byte for i := 0; i < len(*records); i++ { copy(rawData, (*records)[i]) } var cOffset = C.long(offset) - var cNumOfRows = C.long(numOfRow) + var cNumOfRows = C.long(len(*entityIDs)) var cEntityIdsPtr = (*C.long)(&(*entityIDs)[0]) var cTimestampsPtr = (*C.ulong)(&(*timestamps)[0]) - var cSizeofPerRow = C.int(sizeofPerRow) + var cSizeofPerRow = C.int(len((*records)[0])) var cRawDataVoidPtr = unsafe.Pointer(&rawData[0]) var status = C.Insert(s.SegmentPtr, @@ -174,7 +170,7 @@ func (s *Segment) SegmentSearch(queryString string, timestamp uint64, vectorReco float* result_distances); */ // TODO: get top-k's k from queryString - const TopK = 10 + const TopK = 1 resultIds := make([]int64, TopK) resultDistances := make([]float32, TopK) @@ -190,7 +186,5 @@ func (s *Segment) SegmentSearch(queryString string, timestamp uint64, vectorReco return nil, errors.New("Search failed, error code = " + strconv.Itoa(int(status))) } - fmt.Println("Search Result---- Ids =", resultIds, ", Distances =", resultDistances) - return &SearchResult{ResultIds: resultIds, ResultDistances: resultDistances}, nil } diff --git a/reader/segment_test.go b/reader/segment_test.go index 1b24166b09..75342d5435 100644 --- a/reader/segment_test.go +++ b/reader/segment_test.go @@ -1,10 +1,8 @@ package reader import ( - "encoding/binary" "fmt" "github.com/stretchr/testify/assert" - "math" "testing" ) @@ -29,32 +27,28 @@ func TestSegment_SegmentInsert(t *testing.T) { var segment = partition.NewSegment(0) // 2. Create ids and timestamps - ids := []int64{1, 2, 3} - timestamps := []uint64{0, 0, 0} + ids :=[] int64{1, 2, 3} + timestamps :=[] uint64 {0, 0, 0} // 3. Create records, use schema below: // schema_tmp->AddField("fakeVec", DataType::VECTOR_FLOAT, 16); // schema_tmp->AddField("age", DataType::INT32); - const DIM = 16 + const DIM = 4 const N = 3 - var vec = [DIM]float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} + var vec = [DIM]float32{1.1, 2.2, 3.3, 4.4} var rawData []byte for _, ele := range vec { - buf := make([]byte, 4) - binary.LittleEndian.PutUint32(buf, math.Float32bits(ele)) - rawData = append(rawData, buf...) + rawData=append(rawData, byte(ele)) } - bs := make([]byte, 4) - binary.LittleEndian.PutUint32(bs, 1) - rawData = append(rawData, bs...) + rawData=append(rawData, byte(1)) var records [][]byte - for i := 0; i < N; i++ { + for i:= 0; i < N; i++ { records = append(records, rawData) } // 4. Do PreInsert var offset = segment.SegmentPreInsert(N) - assert.GreaterOrEqual(t, offset, int64(0)) + assert.Greater(t, offset, 0) // 5. Do Insert var err = segment.SegmentInsert(offset, &ids, ×tamps, &records) @@ -74,12 +68,12 @@ func TestSegment_SegmentDelete(t *testing.T) { var segment = partition.NewSegment(0) // 2. Create ids and timestamps - ids := []int64{1, 2, 3} - timestamps := []uint64{0, 0, 0} + ids :=[] int64{1, 2, 3} + timestamps :=[] uint64 {0, 0, 0} // 3. Do PreDelete var offset = segment.SegmentPreDelete(10) - assert.GreaterOrEqual(t, offset, int64(0)) + assert.Greater(t, offset, 0) // 4. Do Delete var err = segment.SegmentDelete(offset, &ids, ×tamps) @@ -99,32 +93,28 @@ func TestSegment_SegmentSearch(t *testing.T) { var segment = partition.NewSegment(0) // 2. Create ids and timestamps - ids := []int64{1, 2, 3} - timestamps := []uint64{0, 0, 0} + ids :=[] int64{1, 2, 3} + timestamps :=[] uint64 {0, 0, 0} // 3. Create records, use schema below: // schema_tmp->AddField("fakeVec", DataType::VECTOR_FLOAT, 16); // schema_tmp->AddField("age", DataType::INT32); - const DIM = 16 + const DIM = 4 const N = 3 - var vec = [DIM]float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} + var vec = [DIM]float32{1.1, 2.2, 3.3, 4.4} var rawData []byte for _, ele := range vec { - buf := make([]byte, 4) - binary.LittleEndian.PutUint32(buf, math.Float32bits(ele)) - rawData = append(rawData, buf...) + rawData=append(rawData, byte(ele)) } - bs := make([]byte, 4) - binary.LittleEndian.PutUint32(bs, 1) - rawData = append(rawData, bs...) + rawData=append(rawData, byte(1)) var records [][]byte - for i := 0; i < N; i++ { + for i:= 0; i < N; i++ { records = append(records, rawData) } // 4. Do PreInsert var offset = segment.SegmentPreInsert(N) - assert.GreaterOrEqual(t, offset, int64(0)) + assert.Greater(t, offset, 0) // 5. Do Insert var err = segment.SegmentInsert(offset, &ids, ×tamps, &records) @@ -150,7 +140,7 @@ func TestSegment_SegmentPreInsert(t *testing.T) { // 2. Do PreInsert var offset = segment.SegmentPreInsert(10) - assert.GreaterOrEqual(t, offset, int64(0)) + assert.Greater(t, offset, 0) // 3. Destruct node, collection, and segment partition.DeleteSegment(segment) @@ -167,7 +157,7 @@ func TestSegment_SegmentPreDelete(t *testing.T) { // 2. Do PreDelete var offset = segment.SegmentPreDelete(10) - assert.GreaterOrEqual(t, offset, int64(0)) + assert.Greater(t, offset, 0) // 3. Destruct node, collection, and segment partition.DeleteSegment(segment) @@ -219,32 +209,28 @@ func TestSegment_GetRowCount(t *testing.T) { var segment = partition.NewSegment(0) // 2. Create ids and timestamps - ids := []int64{1, 2, 3} - timestamps := []uint64{0, 0, 0} + ids :=[] int64{1, 2, 3} + timestamps :=[] uint64 {0, 0, 0} // 3. Create records, use schema below: // schema_tmp->AddField("fakeVec", DataType::VECTOR_FLOAT, 16); // schema_tmp->AddField("age", DataType::INT32); - const DIM = 16 + const DIM = 4 const N = 3 - var vec = [DIM]float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} + var vec = [DIM]float32{1.1, 2.2, 3.3, 4.4} var rawData []byte for _, ele := range vec { - buf := make([]byte, 4) - binary.LittleEndian.PutUint32(buf, math.Float32bits(ele)) - rawData = append(rawData, buf...) + rawData=append(rawData, byte(ele)) } - bs := make([]byte, 4) - binary.LittleEndian.PutUint32(bs, 1) - rawData = append(rawData, bs...) + rawData=append(rawData, byte(1)) var records [][]byte - for i := 0; i < N; i++ { + for i:= 0; i < N; i++ { records = append(records, rawData) } // 4. Do PreInsert var offset = segment.SegmentPreInsert(N) - assert.GreaterOrEqual(t, offset, int64(0)) + assert.Greater(t, offset, 0) // 5. Do Insert var err = segment.SegmentInsert(offset, &ids, ×tamps, &records) @@ -268,12 +254,12 @@ func TestSegment_GetDeletedCount(t *testing.T) { var segment = partition.NewSegment(0) // 2. Create ids and timestamps - ids := []int64{1, 2, 3} - timestamps := []uint64{0, 0, 0} + ids :=[] int64{1, 2, 3} + timestamps :=[] uint64 {0, 0, 0} // 3. Do PreDelete var offset = segment.SegmentPreDelete(10) - assert.GreaterOrEqual(t, offset, int64(0)) + assert.Greater(t, offset, 0) // 4. Do Delete var err = segment.SegmentDelete(offset, &ids, ×tamps) diff --git a/reader/util_functions.go b/reader/util_functions.go index 8496b1ffc2..3e169bb619 100644 --- a/reader/util_functions.go +++ b/reader/util_functions.go @@ -7,13 +7,13 @@ import ( // Function `GetSegmentByEntityId` should return entityIDs, timestamps and segmentIDs func (node *QueryNode) GetKey2Segments() (*[]int64, *[]uint64, *[]int64) { - var entityIDs = make([]int64, 0) - var timestamps = make([]uint64, 0) - var segmentIDs = make([]int64, 0) + var entityIDs []int64 + var timestamps []uint64 + var segmentIDs []int64 - var key2SegMsg = node.messageClient.Key2SegMsg - for _, msg := range key2SegMsg { - for _, segmentID := range msg.SegmentId { + var key2SegMsg = &node.messageClient.Key2SegMsg + for _, msg := range *key2SegMsg { + for _, segmentID := range (*msg).SegmentId { entityIDs = append(entityIDs, msg.Uid) timestamps = append(timestamps, msg.Timestamp) segmentIDs = append(segmentIDs, segmentID) diff --git a/writer/message_client/message_client.go b/writer/message_client/message_client.go index 27d4713674..f39e12eebc 100644 --- a/writer/message_client/message_client.go +++ b/writer/message_client/message_client.go @@ -2,8 +2,8 @@ package message_client import ( "context" - "github.com/apache/pulsar-client-go/pulsar" - msgpb "github.com/czs007/suvlim/pkg/master/grpc/message" + "github.com/apache/pulsar/pulsar-client-go/pulsar" + msgpb "github.com/czs007/suvlim/pkg/message" "github.com/golang/protobuf/proto" "log" ) @@ -30,9 +30,8 @@ type MessageClient struct { } func (mc *MessageClient) Send(ctx context.Context, msg msgpb.Key2SegMsg) { - var msgBuffer, _ = proto.Marshal(&msg) - if _, err := mc.key2segProducer.Send(ctx, &pulsar.ProducerMessage{ - Payload: msgBuffer, + if err := mc.key2segProducer.Send(ctx, pulsar.ProducerMessage{ + Payload: []byte(msg.String()), }); err != nil { log.Fatal(err) } diff --git a/writer/write_node/writer_node.go b/writer/write_node/writer_node.go index 9189141296..213ee6fbf0 100644 --- a/writer/write_node/writer_node.go +++ b/writer/write_node/writer_node.go @@ -3,7 +3,7 @@ package write_node import ( "context" "fmt" - msgpb "github.com/czs007/suvlim/pkg/master/grpc/message" + msgpb "github.com/czs007/suvlim/pkg/message" storage "github.com/czs007/suvlim/storage/pkg" "github.com/czs007/suvlim/storage/pkg/types" "github.com/czs007/suvlim/writer/message_client" @@ -85,7 +85,6 @@ func (wn *WriteNode) DeleteBatchData(ctx context.Context, data []*msgpb.InsertOr segmentInfo := msgpb.Key2SegMsg{ Uid: data[i].Uid, SegmentId: segmentIds, - Timestamp: data[i].Timestamp, } wn.MessageClient.Send(ctx, segmentInfo) }