Support compile and run on Mac (#15491)

Co-authored-by: jaime <yun.zhang@zilliz.com>
Co-authored-by: Cai Yudong <yudong.cai@zilliz.com>
Co-authored-by: Jenny Li <jing.li@zilliz.com>
Co-authored-by: Nemo <yuchen.gao@zilliz.com>
Signed-off-by: yun.zhang <yun.zhang@zilliz.com>

Co-authored-by: Cai Yudong <yudong.cai@zilliz.com>
Co-authored-by: Jenny Li <jing.li@zilliz.com>
Co-authored-by: Nemo <yuchen.gao@zilliz.com>
This commit is contained in:
jaime 2022-02-09 14:27:46 +08:00 committed by GitHub
parent d413f653b7
commit 307a8ce535
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
132 changed files with 849 additions and 25982 deletions

View File

@ -22,32 +22,48 @@ on:
jobs:
ubuntu:
name: Code Checker AMD64 Ubuntu ${{ matrix.ubuntu }}
runs-on: ubuntu-${{ matrix.ubuntu }}
name: ${{ matrix.name }}
runs-on: ${{ matrix.os }}
timeout-minutes: 60
strategy:
fail-fast: false
matrix:
ubuntu: [18.04]
include:
- name: Code Checker AMD64 Ubuntu 18.04
os: ubuntu-18.04
- name: Code Checker MacOS 11
os: macos-11
env:
UBUNTU: ${{ matrix.ubuntu }}
UBUNTU: 18.04
steps:
- name: Checkout
uses: actions/checkout@v2
- name: Cache CCache Volumes
if: ${{ matrix.os == 'ubuntu-18.04' }}
uses: actions/cache@v1
with:
path: .docker/amd64-ubuntu${{ matrix.ubuntu }}-ccache
key: ubuntu${{ matrix.ubuntu }}-ccache-${{ hashFiles('internal/core/**') }}
restore-keys: ubuntu${{ matrix.ubuntu }}-ccache-
path: .docker/amd64-ubuntu18.04-ccache
key: ubuntu18.04-ccache-${{ hashFiles('internal/core/**') }}
restore-keys: ubuntu18.04-ccache-
- name: Cache Go Mod Volumes
if: ${{ matrix.os == 'ubuntu-18.04' }}
uses: actions/cache@v1
with:
path: .docker/amd64-ubuntu${{ matrix.ubuntu }}-go-mod
key: ubuntu${{ matrix.ubuntu }}-go-mod-${{ hashFiles('**/go.sum') }}
restore-keys: ubuntu${{ matrix.ubuntu }}-go-mod-
path: .docker/amd64-ubuntu18.04-go-mod
key: ubuntu18.04-go-mod-${{ hashFiles('**/go.sum') }}
restore-keys: ubuntu18.04-go-mod-
- name: Code Check
if: ${{ matrix.os == 'ubuntu-18.04' }}
env:
CHECK_BUILDER: "1"
run: |
./build/builder.sh /bin/bash -c "make check-proto-product && make verifiers"
./build/builder.sh /bin/bash -c "make check-proto-product && make verifiers"
- name: Code Check
if: ${{ matrix.os == 'macos-11' }}
run: |
bash < <(curl -s -S -L https://raw.githubusercontent.com/moovweb/gvm/master/binscripts/gvm-installer)
source ~/.gvm/scripts/gvm
gvm install go1.17.2
gvm use go1.17.2
brew install boost libomp ninja tbb
make check-proto-product && make verifiers

View File

@ -1,6 +1,6 @@
# Development
This document will help to set up your development environment and run tests. If you encounter a problem, please file an issue.
This document will help to set up your Milvus development environment and to run tests. Please [file an issue](https://github.com/milvus-io/milvus/issues/new/choose) if there's a problem.
Table of contents
=================
@ -14,27 +14,27 @@ Table of contents
- [Docker & Docker Compose](#docker--docker-compose)
- [Building Milvus](#building-milvus)
- [A Quick Start for Testing Milvus](#a-quick-start-for-testing-milvus)
- [Presubmission Verification](#presubmission-verification)
- [Pre-submission Verification](#pre-submission-verification)
- [Unit Tests](#unit-tests)
- [Code coverage](#code-coverage)
- [E2E Tests](#e2e-tests)
- [Test on local branch](#test-on-local-branch)
- [On Linux](#on-linux)
- [With Linux and MacOS](#with-linux-and-macos)
- [With docker](#with-docker)
- [GitHub Flow](#github-flow)
## Building Milvus with Docker
Official releases are built with Docker containers. To build Milvus with Docker please follow [these instructions](https://github.com/milvus-io/milvus/blob/master/build/README.md).
Our official Milvus versions are releases as Docker images. To build Milvus Docker on your own, please follow [these instructions](https://github.com/milvus-io/milvus/blob/master/build/README.md).
## Building Milvus on a local OS/shell environment
The details below outline the hardware and software requirements for building on Linux.
The details below outline the hardware and software requirements for building on Linux and MacOS.
### Hardware Requirements
Milvus is written in Go and C++, it requires a lot of resources to compile it. We recommend the following for any physical or virtual machine being used for building Milvus.
The following specification (either physical or virtual machine resources) is recommended for Milvus to build and run from source code.
```
- 8GB of RAM
@ -43,50 +43,48 @@ Milvus is written in Go and C++, it requires a lot of resources to compile it. W
### Software Requirements
In fact, all Linux distributions are available to develop Milvus. The following only contains commands on Ubuntu and CentOS, because we mainly use them. If you develop Milvus on other distributions, you are welcome to improve this document.
All Linux distributions are available for Milvus development. However a majority of our contributor worked with Ubuntu or CentOS systems, with a small portion of Mac (both x86_64 and Apple Silicon) contributors. If you would like Milvus to build and run on other distributions, you are more than welcome to file an issue and contribute!
#### Dependencies
- Debian/Ubuntu
Here's a list of verified OS types where Milvus can successfully build and run:
```shell
$ sudo apt update
$ sudo apt install -y build-essential ccache gfortran \
libssl-dev zlib1g-dev python3-dev libcurl4-openssl-dev libtbb-dev\
libboost-regex-dev libboost-program-options-dev libboost-system-dev \
libboost-filesystem-dev libboost-serialization-dev libboost-python-dev
* Debian/Ubuntu
* CentOS
* MacOS (x86_64)
* MacOS (Apple Silicon)
#### Prerequisites
Linux systems (Recommend Ubuntu 18.04 or later):
```bash
go: >= 1.15
cmake: >= 3.18
gcc: 7.5
```
- CentOS
```shell
$ sudo yum install -y epel-release centos-release-scl-rh && \
$ sudo yum install -y git make automake openssl-devel zlib-devel \
libcurl-devel python3-devel \
devtoolset-7-gcc devtoolset-7-gcc-c++ devtoolset-7-gcc-gfortran \
llvm-toolset-7.0-clang llvm-toolset-7.0-clang-tools-extra \
ccache lcov
$ echo "source scl_source enable devtoolset-7" | sudo tee -a /etc/profile.d/devtoolset-7.sh
$ echo "source scl_source enable llvm-toolset-7.0" | sudo tee -a /etc/profile.d/llvm-toolset-7.sh
$ echo "export CLANG_TOOLS_PATH=/opt/rh/llvm-toolset-7.0/root/usr/bin" | sudo tee -a /etc/profile.d/llvm-toolset-7.sh
$ source "/etc/profile.d/llvm-toolset-7.sh"
# Install tbb
$ git clone https://github.com/wjakob/tbb.git && \
cd tbb/build && \
cmake .. && make -j && \
sudo make install && \
cd ../../ && rm -rf tbb/
# Install boost
$ wget -q https://boostorg.jfrog.io/artifactory/main/release/1.65.1/source/boost_1_65_1.tar.gz && \
tar zxf boost_1_65_1.tar.gz && cd boost_1_65_1 && \
./bootstrap.sh --prefix=/usr/local --with-toolset=gcc --without-libraries=python && \
sudo ./b2 -j2 --prefix=/usr/local --without-python toolset=gcc install && \
cd ../ && rm -rf ./boost_1_65_1*
MacOS systems with x86_64 (Big Sur 11.5 or later recommended):
```bash
go: >= 1.15
cmake: >= 3.18
llvm: >= 12
```
MacOS systems with Apple Silicon (Monterey 12.0.1 or later recommended):
```bash
go: >= 1.17 (Arch=ARM64)
cmake: >= 3.18
llvm: >= 13
```
#### Installing Dependencies
In the Milvus repository root, simply run:
```bash
$ ./scripts/install_deps.sh
```
#### Caveats
* [Google Test](https://github.com/google/googletest.git) is automatically cloned from GitHub, which in some case could conflict with your local google test library.
Once you have finished, confirm that `gcc` and `make` are installed:
```shell
@ -103,7 +101,7 @@ Confirm that cmake is available:
```shell
$ cmake --version
```
Note: 3.18 or higher cmake version is required to build Milvus.
Note: 3.18 or higher cmake version is required to build Milvus.
#### Go
@ -114,7 +112,7 @@ Confirm that your `GOPATH` and `GOBIN` environment variables are correctly set a
```shell
$ go version
```
Note: go1.15 is required to build Milvus.
Note: go >= 1.15 is required to build Milvus.
#### Docker & Docker Compose
@ -142,11 +140,11 @@ If you want to know more, you can read Makefile.
## A Quick Start for Testing Milvus
### Presubmission Verification
### Pre-submission Verification
Presubmission verification provides a battery of checks and tests to give your pull request the best chance of being accepted. Developers need to run as many verification tests as possible locally.
Pre-submission verification provides a battery of checks and tests to give your pull request the best chance of being accepted. Developers need to run as many verification tests as possible locally.
To run all presubmission verification tests, use this command:
To run all pre-submission verification tests, use this command:
```shell
$ make verifiers
@ -154,49 +152,42 @@ $ make verifiers
### Unit Tests
Pull requests need to pass all unit tests. To run every unit test, use this command:
It is required that all pull request candidates should pass all Milvus unit tests.
Beforce running unit tests, you need to first bring up the Milvus deployment environment.
You may set up a local docker environment with our docker compose yaml file to start unit testing.
For Apple Silicon users (Apple M1):
```shell
$ cd deployments/docker/dev
$ docker-compose -f docker-compose-apple-silicon.yml up -d
$ cd ../../../
$ make unittest
```
Before using `make unittest` command, we should run a milvus's deployment environment which helps us to do go test. Here we use local docker environment, use the following commands:
For others:
```shell
# Using cluster environment
$ cd deployments/docker/dev
$ docker-compose up -d
$ cd ../../../
$ make unittest
# Or using standalone environment
$ cd deployments/docker/standalone
$ docker-compose up -d
$ cd ../../../
$ make unittest
```
To run only cpp test, we can use this command:
To run only cpp test:
```shell
make test-cpp
$ make test-cpp
```
To run only go test, we can use this command:
To run only go test:
```shell
make test-go
$ make test-go
```
To run single test case, for instance, run TestSearchTask in /internal/proxy directory, use
To run a single test case (TestSearchTask in /internal/proxy directory, for example):
```shell
$ go test -v ./internal/proxy/ -test.run TestSearchTask
```
### Code coverage
Before submitting your Pull Request, make sure your code change is covered by unit test. Use the following commands to check code coverage rate:
Install lcov(cpp code coverage tool):
```shell
$ sudo apt-get install lcov
```
Before submitting your pull request, make sure your code change is covered by unit test. Use the following commands to check code coverage rate:
Run unit test and generate code coverage report:
```shell
@ -242,7 +233,7 @@ $ pytest --tags=L0 -n auto
```
### Test on local branch
#### On Linux
#### With Linux and MacOS
After preparing deployment environment, we can start the cluster on your host machine
```shell
@ -264,3 +255,29 @@ $ install with docker compose
## GitHub Flow
To check out code to work on, please refer to the [GitHub Flow](https://guides.github.com/introduction/flow/).
## FAQs
Q: The go building phase fails on Apple Silicon (Mac M1) machines.
A: Please double-check that you have [right Go version](https://go.dev/dl/) installed, i.e. with OS=macOS and Arch=ARM64.
---
Q: "make" fails with "*ld: library not found for -lSystem*" on MacOS.
A: There are a couple of things you could try:
1. Use **Software Update** (from **About this Mac** -> **Overview**) to install updates.
2. Try the following commands:
```bash
sudo rm -rf /Library/Developer/CommandLineTools
sudo xcode-select --install
```
---
Q: Rocksdb fails to compile with "*ld: warning: object file was built for newer macOS version (11.6) than being linked (11.0).*" on MacOS.
A: Use **Software Update** (from **About this Mac** -> **Overview**) to install updates.
---
Q: Some Go unit tests failed.
A: We are aware that some tests can be flaky occasionally. If there's something you believe is abnormal (i.e. tests that fail every single time). You are more than welcome to [file an issue](https://github.com/milvus-io/milvus/issues/new/choose)!

View File

@ -15,10 +15,18 @@ GOPATH := $(shell $(GO) env GOPATH)
INSTALL_PATH := $(PWD)/bin
LIBRARY_PATH := $(PWD)/lib
OS := $(shell uname -s)
ARCH := $(shell arch)
mode = Release
all: build-cpp build-go
mode = Release
pre-proc:
@echo "Running pre-processing"
ifeq ($(OS),Darwin) # MacOS X
@echo "MacOS system identified. Switching to customized gorocksdb fork..."
@go mod edit -replace=github.com/tecbot/gorocksdb=github.com/soothing-rain/gorocksdb@latest
endif
get-build-deps:
@(env bash $(PWD)/scripts/install_deps.sh)
@ -74,9 +82,19 @@ ifdef GO_DIFF_FILES
@${GOPATH}/bin/ruleguard -rules ruleguard.rules.go $(GO_DIFF_FILES)
else
@echo "Running $@ check"
ifeq ($(OS),Darwin) # MacOS X
ifeq ($(ARCH),arm64)
@${GOPATH}/bin/darwin_arm64/ruleguard -rules ruleguard.rules.go ./internal/...
@${GOPATH}/bin/darwin_arm64/ruleguard -rules ruleguard.rules.go ./cmd/...
else
@${GOPATH}/bin/ruleguard -rules ruleguard.rules.go ./internal/...
@${GOPATH}/bin/ruleguard -rules ruleguard.rules.go ./cmd/...
# @${GOPATH}/bin/ruleguard -rules ruleguard.rules.go ./tests/go/...
endif
else
@${GOPATH}/bin/ruleguard -rules ruleguard.rules.go ./internal/...
@${GOPATH}/bin/ruleguard -rules ruleguard.rules.go ./cmd/...
endif
#@${GOPATH}/bin/ruleguard -rules ruleguard.rules.go ./tests/go/...
endif
verifiers: build-cpp getdeps cppcheck fmt static-check ruleguard
@ -87,7 +105,7 @@ binlog:
@mkdir -p $(INSTALL_PATH) && go env -w CGO_ENABLED="1" && GO111MODULE=on $(GO) build -o $(INSTALL_PATH)/binlog $(PWD)/cmd/tools/binlog/main.go 1>/dev/null
BUILD_TAGS = $(shell git describe --tags --always --dirty="-dev")
BUILD_TIME = $(shell date --utc)
BUILD_TIME = $(shell date -u)
GIT_COMMIT = $(shell git rev-parse --short HEAD)
GO_VERSION = $(shell go version)
@ -105,17 +123,17 @@ milvus: build-cpp print-build-info
build-go: milvus
build-cpp:
build-cpp: pre-proc
@echo "Building Milvus cpp library ..."
@(env bash $(PWD)/scripts/core_build.sh -t ${mode} -f "$(CUSTOM_THIRDPARTY_PATH)")
@(env bash $(PWD)/scripts/cwrapper_build.sh -t Release -f "$(CUSTOM_THIRDPARTY_PATH)")
@(env bash $(PWD)/scripts/cwrapper_rocksdb_build.sh -t Release -f "$(CUSTOM_THIRDPARTY_PATH)")
@(env bash $(PWD)/scripts/cwrapper_build.sh -t ${mode} -f "$(CUSTOM_THIRDPARTY_PATH)")
@(env bash $(PWD)/scripts/cwrapper_rocksdb_build.sh -t ${mode} -f "$(CUSTOM_THIRDPARTY_PATH)")
build-cpp-with-unittest:
build-cpp-with-unittest: pre-proc
@echo "Building Milvus cpp library with unittest ..."
@(env bash $(PWD)/scripts/core_build.sh -t ${mode} -u -c -f "$(CUSTOM_THIRDPARTY_PATH)")
@(env bash $(PWD)/scripts/cwrapper_build.sh -t Release -f "$(CUSTOM_THIRDPARTY_PATH)")
@(env bash $(PWD)/scripts/cwrapper_rocksdb_build.sh -t Release -f "$(CUSTOM_THIRDPARTY_PATH)")
@(env bash $(PWD)/scripts/core_build.sh -t ${mode} -u -c -f "$(CUSTOM_THIRDPARTY_PATH)")
@(env bash $(PWD)/scripts/cwrapper_build.sh -t ${mode} -f "$(CUSTOM_THIRDPARTY_PATH)")
@(env bash $(PWD)/scripts/cwrapper_rocksdb_build.sh -t ${mode} -f "$(CUSTOM_THIRDPARTY_PATH)")
# Run the tests.
unittest: test-cpp test-go

View File

@ -67,11 +67,25 @@ Milvus was released under the [open-source Apache License 2.0](https://github.co
Check the requirements first.
Linux systems (Ubuntu 18.04 or later recommended):
```bash
go: 1.15
cmake: >=3.18
go: >= 1.15
cmake: >= 3.18
gcc: 7.5
protobuf: >=3.7
```
MacOS systems with x86_64 (Big Sur 11.5 or later recommended):
```bash
go: >= 1.15
cmake: >= 3.18
llvm: >= 12
```
MacOS systems with Apple Silicon (Monterey 12.0.1 or later recommended):
```bash
go: >= 1.17 (Arch=ARM64)
cmake: >= 3.18
llvm: >= 13
```
Clone Milvus repo and build.

View File

@ -0,0 +1,55 @@
version: '3.5'
services:
etcd:
image: quay.io/coreos/etcd:v3.5.0
volumes:
- ${DOCKER_VOLUME_DIRECTORY:-.}/volumes/etcd:/etcd
command: etcd -listen-peer-urls=http://127.0.0.1:2380 -advertise-client-urls=http://127.0.0.1:2379 -listen-client-urls http://0.0.0.0:2379,http://0.0.0.0:4001 -initial-advertise-peer-urls=http://127.0.0.1:2380 --initial-cluster default=http://127.0.0.1:2380
ports:
- "2379:2379"
- "2380:2380"
- "4001:4001"
pulsar:
image: milvusdb/pulsar:v2.7.3-m1
volumes:
- ${DOCKER_VOLUME_DIRECTORY:-.}/volumes/pulsar:/pulsar/data
environment:
# bin/apply-config-from-env.py script will modify the configuration file based on the environment variables
# nettyMaxFrameSizeBytes must be calculated from maxMessageSize + 10240 (padding)
- nettyMaxFrameSizeBytes=104867840 # this is 104857600 + 10240 (padding)
- defaultRetentionTimeInMinutes=10080
- defaultRetentionSizeInMB=8192
# maxMessageSize is missing from standalone.conf, must use PULSAR_PREFIX_ to get it configured
- PULSAR_PREFIX_maxMessageSize=104857600
- PULSAR_GC=-XX:+UseG1GC
ports:
- "6650:6650"
- "18080:8080"
minio:
image: minio/minio:RELEASE.2020-12-03T00-03-10Z
ports:
- "9000:9000"
environment:
MINIO_ACCESS_KEY: minioadmin
MINIO_SECRET_KEY: minioadmin
volumes:
- ${DOCKER_VOLUME_DIRECTORY:-.}/volumes/minio:/minio_data
command: minio server /minio_data
healthcheck:
test: ["CMD", "curl", "-f", "http://localhost:9000/minio/health/live"]
interval: 30s
timeout: 20s
retries: 3
jaeger:
image: jaegertracing/all-in-one:latest
ports:
- "6831:6831/udp"
- "16686:16686"
networks:
default:
name: milvus_dev

1
go.mod
View File

@ -30,6 +30,7 @@ require (
github.com/pkg/errors v0.9.1
github.com/prometheus/client_golang v1.11.0
github.com/shirou/gopsutil v3.21.8+incompatible
github.com/soothing-rain/gorocksdb v0.0.0-20220113075731-e68e68ed4c62 // indirect
github.com/spaolacci/murmur3 v1.1.0
github.com/spf13/cast v1.3.1
github.com/spf13/viper v1.8.1

4
go.sum
View File

@ -483,6 +483,10 @@ github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9
github.com/soheilhy/cmux v0.1.4/go.mod h1:IM3LyeVVIOuxMH7sFAkER9+bJ4dT7Ms6E4xg4kGIyLM=
github.com/soheilhy/cmux v0.1.5 h1:jjzc5WVemNEDTLwv9tlmemhC73tI08BNOIGwBOo10Js=
github.com/soheilhy/cmux v0.1.5/go.mod h1:T7TcVDs9LWfQgPlPsdngu6I6QIoyIFZDDC6sNE1GqG0=
github.com/soothing-rain/gorocksdb v0.0.0-20220113074926-f4b9b182e17e h1:E1iqJnSWf+BVfBbuBKLaW4rSCZAAMH1ui1uT+IXTgz4=
github.com/soothing-rain/gorocksdb v0.0.0-20220113074926-f4b9b182e17e/go.mod h1:OEi1TEIyGy/kDfY3ZwoHfn+/dIK0ZOOKJ2ReaYf4sao=
github.com/soothing-rain/gorocksdb v0.0.0-20220113075731-e68e68ed4c62 h1:EAmHsU58jrwMjgnVVxgnlqlBG+1uFuv+Om1qXr5GfII=
github.com/soothing-rain/gorocksdb v0.0.0-20220113075731-e68e68ed4c62/go.mod h1:OEi1TEIyGy/kDfY3ZwoHfn+/dIK0ZOOKJ2ReaYf4sao=
github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA=
github.com/spaolacci/murmur3 v1.1.0 h1:7c1g84S4BPRrfL5Xrdp6fOJ206sU9y293DDHaoy0bLI=
github.com/spaolacci/murmur3 v1.1.0/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA=

View File

@ -16,10 +16,53 @@
cmake_minimum_required( VERSION 3.18 )
if ( APPLE )
set( CMAKE_CROSSCOMPILING TRUE )
set( RUN_HAVE_GNU_POSIX_REGEX 0 )
set( CMAKE_C_COMPILER "/usr/local/opt/llvm/bin/clang" )
set( CMAKE_CXX_COMPILER "/usr/local/opt/llvm/bin/clang++" )
endif ()
add_definitions(-DELPP_THREAD_SAFE)
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
message( STATUS "Building using CMake version: ${CMAKE_VERSION}" )
project(core)
include(CheckCXXCompilerFlag)
if ( APPLE )
message(STATUS "==============Darwin Environment==============")
check_cxx_compiler_flag(-std=c++11 HAS_STD_CPP11_FLAG)
if(HAS_STD_CPP11_FLAG)
add_compile_options(-std=c++11)
endif()
if(CMAKE_C_COMPILER_ID MATCHES "Clang")
set(OpenMP_C "${CMAKE_C_COMPILER}" CACHE STRING "" FORCE)
set(OpenMP_C_FLAGS "-fopenmp=libomp -Wno-unused-command-line-argument" CACHE STRING "" FORCE)
set(OpenMP_C_LIB_NAMES "libomp" "libgomp" "libiomp5" CACHE STRING "" FORCE)
set(OpenMP_libomp_LIBRARY ${OpenMP_C_LIB_NAMES} CACHE STRING "" FORCE)
set(OpenMP_libgomp_LIBRARY ${OpenMP_C_LIB_NAMES} CACHE STRING "" FORCE)
set(OpenMP_libiomp5_LIBRARY ${OpenMP_C_LIB_NAMES} CACHE STRING "" FORCE)
endif()
if(CMAKE_CXX_COMPILER_ID MATCHES "Clang")
set(OpenMP_CXX "${CMAKE_CXX_COMPILER}" CACHE STRING "" FORCE)
set(OpenMP_CXX_FLAGS "-fopenmp=libomp -Wno-unused-command-line-argument" CACHE STRING "" FORCE)
set(OpenMP_CXX_LIB_NAMES "libomp" "libgomp" "libiomp5" CACHE STRING "" FORCE)
set(OpenMP_libomp_LIBRARY ${OpenMP_CXX_LIB_NAMES} CACHE STRING "" FORCE)
set(OpenMP_libgomp_LIBRARY ${OpenMP_CXX_LIB_NAMES} CACHE STRING "" FORCE)
set(OpenMP_libiomp5_LIBRARY ${OpenMP_CXX_LIB_NAMES} CACHE STRING "" FORCE)
endif()
elseif (${CMAKE_SYSTEM_NAME} MATCHES "Linux")
message(STATUS "==============Linux Environment===============")
set(LINUX TRUE)
else ()
message(FATAL_ERROR "Unsupported platform!" )
endif ()
find_package(OpenMP REQUIRED)
if (OPENMP_FOUND)
set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}")
endif()
set( CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_CURRENT_SOURCE_DIR}/cmake" )
include( Utils )
@ -61,26 +104,43 @@ include( DefineOptions )
using_ccache_if_defined(MILVUS_USE_CCACHE)
include( ExternalProject )
include( GNUInstallDirs )
include( FetchContent )
include_directories(thirdparty)
set( FETCHCONTENT_BASE_DIR ${MILVUS_BINARY_DIR}/3rdparty_download )
set(FETCHCONTENT_QUIET OFF)
include( ThirdPartyPackages )
find_package(OpenMP REQUIRED)
if (OPENMP_FOUND)
set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}")
endif()
# **************************** Compiler arguments ****************************
message( STATUS "Building Milvus CPU version" )
if (LINUX)
append_flags( CMAKE_CXX_FLAGS
FLAGS
"-fPIC"
"-DELPP_THREAD_SAFE"
"-fopenmp"
"-Werror"
)
"-fPIC"
"-DELPP_THREAD_SAFE"
"-fopenmp"
"-Werror"
)
endif ()
if ( APPLE )
append_flags( CMAKE_CXX_FLAGS
FLAGS
"-fPIC"
"-DELPP_THREAD_SAFE"
"-fopenmp"
"-Wno-error"
"-Wsign-compare"
"-Wall"
"-pedantic"
"-Wno-unused-command-line-argument"
"-Wextra"
"-Wno-unused-parameter"
"-Wno-deprecated"
"-DBOOST_STACKTRACE_GNU_SOURCE_NOT_REQUIRED=1"
)
endif ()
# **************************** Coding style check tools ****************************
find_package( ClangTools )
@ -107,20 +167,18 @@ if ( NOT LINT_EXCLUSIONS_FILE )
set( LINT_EXCLUSIONS_FILE ${BUILD_SUPPORT_DIR}/lint_exclusions.txt )
endif ()
find_program( CPPLINT_BIN NAMES cpplint cpplint.py HINTS ${BUILD_SUPPORT_DIR} )
message( STATUS "Found cpplint executable at ${CPPLINT_BIN}" )
#
# "make lint" targets
#
add_custom_target( lint
add_custom_target(lint
${PYTHON_EXECUTABLE} ${BUILD_SUPPORT_DIR}/run_cpplint.py
--cpplint_binary ${CPPLINT_BIN}
--exclude_globs ${LINT_EXCLUSIONS_FILE}
--source_dir ${CMAKE_CURRENT_SOURCE_DIR}
--source_dir ${CMAKE_CURRENT_SOURCE_DIR}/src
${MILVUS_LINT_QUIET}
)
find_program( CPPLINT_BIN NAMES cpplint cpplint.py HINTS ${BUILD_SUPPORT_DIR} )
message( STATUS "Found cpplint executable at ${CPPLINT_BIN}" )
#
# "make clang-format" and "make check-clang-format" targets
#
@ -190,7 +248,9 @@ add_subdirectory( src )
if ( BUILD_UNIT_TEST STREQUAL "ON" )
append_flags( CMAKE_CXX_FLAGS FLAGS "-DELPP_DISABLE_LOGS")
add_subdirectory(unittest)
add_subdirectory(bench)
if (LINUX)
add_subdirectory(bench)
endif ()
endif ()
add_custom_target( Clean-All COMMAND ${CMAKE_BUILD_TOOL} clean )
@ -205,7 +265,7 @@ set( GPU_ENABLE "false" )
# Install segcore
install(
DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/src/segcore/
DESTINATION include/segcore/
DESTINATION include/segcore
FILES_MATCHING PATTERN "*_c.h"
)
@ -217,7 +277,7 @@ install(
# Install indexbuilder
install(
DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/src/indexbuilder/
DESTINATION include/indexbuilder/
DESTINATION include/indexbuilder
FILES_MATCHING PATTERN "*_c.h"
)
@ -228,6 +288,6 @@ install(FILES ${CMAKE_BINARY_DIR}/src/indexbuilder/libmilvus_indexbuilder${CMAKE
# Install common
install(
DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/src/common/
DESTINATION include/common/
DESTINATION include/common
FILES_MATCHING PATTERN "*_c.h"
)
)

View File

@ -11,7 +11,6 @@
include_directories(${CMAKE_HOME_DIRECTORY}/src)
include_directories(${CMAKE_HOME_DIRECTORY}/unittest)
include_directories(${CMAKE_HOME_DIRECTORY}/src/index/knowhere)
set(bench_srcs
bench_naive.cpp
@ -37,6 +36,8 @@ target_link_libraries(indexbuilder_bench
milvus_indexbuilder
log
pthread
knowhere
milvus_utils
)
target_link_libraries(indexbuilder_bench benchmark::benchmark_main)

View File

@ -13,10 +13,10 @@
#include <tuple>
#include <map>
#include <google/protobuf/text_format.h>
#include <knowhere/index/vector_index/helpers/IndexParameter.h>
#include <knowhere/index/vector_index/adapter/VectorAdapter.h>
#include "pb/index_cgo_msg.pb.h"
#include "index/knowhere/knowhere/index/vector_index/helpers/IndexParameter.h"
#include "index/knowhere/knowhere/index/vector_index/adapter/VectorAdapter.h"
#include "indexbuilder/IndexWrapper.h"
#include "indexbuilder/index_c.h"
#include "indexbuilder/utils.h"

View File

@ -23,8 +23,7 @@ while [ -h "$SOURCE" ]; do # resolve $SOURCE until the file is no longer a symli
done
SCRIPTS_DIR="$( cd -P "$( dirname "$SOURCE" )" && pwd )"
BUILD_OUTPUT_DIR="${SCRIPTS_DIR}/../../cmake_build"
BUILD_OUTPUT_DIR="./cmake_build"
BUILD_TYPE="Release"
BUILD_UNITTEST="OFF"
INSTALL_PREFIX="${SCRIPTS_DIR}/output"
@ -130,6 +129,19 @@ if [[ ${MAKE_CLEAN} == "ON" ]]; then
exit 0
fi
unameOut="$(uname -s)"
case "${unameOut}" in
Darwin*)
llvm_prefix="$(brew --prefix llvm)"
export CLANG_TOOLS_PATH="${llvm_prefix}/bin"
export CC="${llvm_prefix}/bin/clang"
export CXX="${llvm_prefix}/bin/clang++"
export LDFLAGS="-L${llvm_prefix}/lib -L/usr/local/opt/libomp/lib"
export CXXFLAGS="-I${llvm_prefix}/include -I/usr/local/include -I/usr/local/opt/libomp/include"
;;
*) echo "==System:${unameOut}";
esac
CMAKE_CMD="cmake \
-DBUILD_UNIT_TEST=${BUILD_UNITTEST} \
-DCMAKE_INSTALL_PREFIX=${INSTALL_PREFIX}
@ -143,6 +155,7 @@ CMAKE_CMD="cmake \
-DMILVUS_WITH_PROMETHEUS=${WITH_PROMETHEUS} \
-DMILVUS_CUDA_ARCH=${CUDA_ARCH} \
-DCUSTOM_THIRDPARTY_DOWNLOAD_PATH=${CUSTOM_THIRDPARTY_PATH} \
-DKNOWHERE_GPU_VERSION=${SUPPORT_GPU} \
${SCRIPTS_DIR}"
echo ${CMAKE_CMD}
${CMAKE_CMD}

View File

@ -85,7 +85,7 @@ set(THIRDPARTY_DIR "${MILVUS_SOURCE_DIR}/thirdparty")
# ----------------------------------------------------------------------
# ExternalProject options
string(TOUPPER ${CMAKE_BUILD_TYPE} UPPERCASE_BUILD_TYPE)
#string(TOUPPER ${CMAKE_BUILD_TYPE} UPPERCASE_BUILD_TYPE)
set(EP_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${CMAKE_CXX_FLAGS_${UPPERCASE_BUILD_TYPE}}")
set(EP_C_FLAGS "${CMAKE_C_FLAGS} ${CMAKE_C_FLAGS_${UPPERCASE_BUILD_TYPE}}")

View File

@ -17,15 +17,7 @@
include_directories(${MILVUS_ENGINE_SRC})
include_directories(${MILVUS_THIRDPARTY_SRC})
set(FOUND_OPENBLAS "unknown")
add_subdirectory(index)
set(INDEX_INCLUDE_DIRS ${INDEX_INCLUDE_DIRS} PARENT_SCOPE)
foreach (DIR ${INDEX_INCLUDE_DIRS})
include_directories(${DIR})
endforeach ()
add_subdirectory( exceptions )
add_subdirectory( utils )
add_subdirectory( log )

View File

@ -18,4 +18,5 @@ set(COMMON_SRC
add_library(milvus_common
${COMMON_SRC}
)
target_link_libraries(milvus_common milvus_proto yaml-cpp)
target_link_libraries(milvus_common knowhere milvus_proto yaml-cpp )

View File

@ -24,8 +24,8 @@
#include <boost/align/aligned_allocator.hpp>
#include <NamedType/named_type.hpp>
#include "faiss/utils/BitsetView.h"
#include "faiss/MetricType.h"
#include "knowhere/utils/BitsetView.h"
#include "knowhere/common/MetricType.h"
#include "pb/schema.pb.h"
#include "utils/Types.h"

View File

@ -22,7 +22,7 @@ add_library(milvus_config
${CONFIG_SRC}
)
target_link_libraries(milvus_config
knowhere
milvus_proto
milvus_utils
knowhere
)

View File

@ -19,7 +19,6 @@
#include "ConfigKnowhere.h"
#include "exceptions/EasyAssert.h"
#include "easyloggingpp/easylogging++.h"
#include "faiss/FaissHook.h"
#include "log/Log.h"
#include "knowhere/archive/KnowhereConfig.h"

0
internal/core/src/index/build.sh Executable file → Normal file
View File

View File

@ -1,3 +0,0 @@
if( ${UNIX} )
add_subdirectory(${PROJECT_SOURCE_DIR}/lib/NGT)
endif()

View File

@ -1,89 +0,0 @@
//
// Copyright (C) 2015-2020 Yahoo Japan Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
#include "ArrayFile.h"
#include <iostream>
#include <assert.h>
class ItemID {
public:
void serialize(std::ostream &os, NGT::ObjectSpace *ospace = 0) {
os.write((char*)&value, sizeof(value));
}
void deserialize(std::istream &is, NGT::ObjectSpace *ospace = 0) {
is.read((char*)&value, sizeof(value));
}
static size_t getSerializedDataSize() {
return sizeof(uint64_t);
}
uint64_t value;
};
void
sampleForUsage() {
{
ArrayFile<ItemID> itemIDFile;
itemIDFile.create("test.data", ItemID::getSerializedDataSize());
itemIDFile.open("test.data");
ItemID itemID;
size_t id;
id = 1;
itemID.value = 4910002490100;
itemIDFile.put(id, itemID);
itemID.value = 0;
itemIDFile.get(id, itemID);
std::cerr << "value=" << itemID.value << std::endl;
assert(itemID.value == 4910002490100);
id = 2;
itemID.value = 4910002490101;
itemIDFile.put(id, itemID);
itemID.value = 0;
itemIDFile.get(id, itemID);
std::cerr << "value=" << itemID.value << std::endl;
assert(itemID.value == 4910002490101);
itemID.value = 4910002490102;
id = itemIDFile.insert(itemID);
itemID.value = 0;
itemIDFile.get(id, itemID);
std::cerr << "value=" << itemID.value << std::endl;
assert(itemID.value == 4910002490102);
itemIDFile.close();
}
{
ArrayFile<ItemID> itemIDFile;
itemIDFile.create("test.data", ItemID::getSerializedDataSize());
itemIDFile.open("test.data");
ItemID itemID;
size_t id;
id = 10;
itemIDFile.get(id, itemID);
std::cerr << "value=" << itemID.value << std::endl;
assert(itemID.value == 4910002490100);
id = 20;
itemIDFile.get(id, itemID);
std::cerr << "value=" << itemID.value << std::endl;
assert(itemID.value == 4910002490101);
}
}

View File

@ -1,220 +0,0 @@
//
// Copyright (C) 2015-2020 Yahoo Japan Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
#pragma once
#include <fstream>
#include <string>
#include <cstddef>
#include <stdint.h>
#include <iostream>
#include <stdexcept>
#include <cerrno>
#include <cstring>
namespace NGT {
class ObjectSpace;
};
template <class TYPE>
class ArrayFile {
private:
struct FileHeadStruct {
size_t recordSize;
uint64_t extraData; // reserve
};
struct RecordStruct {
bool deleteFlag;
uint64_t extraData; // reserve
};
bool _isOpen;
std::fstream _stream;
FileHeadStruct _fileHead;
bool _readFileHead();
pthread_mutex_t _mutex;
public:
ArrayFile();
~ArrayFile();
bool create(const std::string &file, size_t recordSize);
bool open(const std::string &file);
void close();
size_t insert(TYPE &data, NGT::ObjectSpace *objectSpace = 0);
void put(const size_t id, TYPE &data, NGT::ObjectSpace *objectSpace = 0);
bool get(const size_t id, TYPE &data, NGT::ObjectSpace *objectSpace = 0);
void remove(const size_t id);
bool isOpen() const;
size_t size();
size_t getRecordSize() { return _fileHead.recordSize; }
};
// constructor
template <class TYPE>
ArrayFile<TYPE>::ArrayFile()
: _isOpen(false), _mutex((pthread_mutex_t)PTHREAD_MUTEX_INITIALIZER){
if(pthread_mutex_init(&_mutex, NULL) < 0) throw std::runtime_error("pthread init error.");
}
// destructor
template <class TYPE>
ArrayFile<TYPE>::~ArrayFile() {
pthread_mutex_destroy(&_mutex);
close();
}
template <class TYPE>
bool ArrayFile<TYPE>::create(const std::string &file, size_t recordSize) {
std::fstream tmpstream;
tmpstream.open(file.c_str());
if(tmpstream){
return false;
}
tmpstream.open(file.c_str(), std::ios::out);
tmpstream.seekp(0, std::ios::beg);
FileHeadStruct fileHead = {recordSize, 0};
tmpstream.write((char *)(&fileHead), sizeof(FileHeadStruct));
tmpstream.close();
return true;
}
template <class TYPE>
bool ArrayFile<TYPE>::open(const std::string &file) {
_stream.open(file.c_str(), std::ios::in | std::ios::out);
if(!_stream){
_isOpen = false;
return false;
}
_isOpen = true;
bool ret = _readFileHead();
return ret;
}
template <class TYPE>
void ArrayFile<TYPE>::close(){
_stream.close();
_isOpen = false;
}
template <class TYPE>
size_t ArrayFile<TYPE>::insert(TYPE &data, NGT::ObjectSpace *objectSpace) {
_stream.seekp(sizeof(RecordStruct), std::ios::end);
int64_t write_pos = _stream.tellg();
for(size_t i = 0; i < _fileHead.recordSize; i++) { _stream.write("", 1); }
_stream.seekp(write_pos, std::ios::beg);
data.serialize(_stream, objectSpace);
int64_t offset_pos = _stream.tellg();
offset_pos -= sizeof(FileHeadStruct);
size_t id = offset_pos / (sizeof(RecordStruct) + _fileHead.recordSize);
if(offset_pos % (sizeof(RecordStruct) + _fileHead.recordSize) == 0){
id -= 1;
}
return id;
}
template <class TYPE>
void ArrayFile<TYPE>::put(const size_t id, TYPE &data, NGT::ObjectSpace *objectSpace) {
uint64_t offset_pos = (id * (sizeof(RecordStruct) + _fileHead.recordSize)) + sizeof(FileHeadStruct);
offset_pos += sizeof(RecordStruct);
_stream.seekp(offset_pos, std::ios::beg);
for(size_t i = 0; i < _fileHead.recordSize; i++) { _stream.write("", 1); }
_stream.seekp(offset_pos, std::ios::beg);
data.serialize(_stream, objectSpace);
}
template <class TYPE>
bool ArrayFile<TYPE>::get(const size_t id, TYPE &data, NGT::ObjectSpace *objectSpace) {
pthread_mutex_lock(&_mutex);
if( size() <= id ){
pthread_mutex_unlock(&_mutex);
return false;
}
uint64_t offset_pos = (id * (sizeof(RecordStruct) + _fileHead.recordSize)) + sizeof(FileHeadStruct);
offset_pos += sizeof(RecordStruct);
_stream.seekg(offset_pos, std::ios::beg);
if (!_stream.fail()) {
data.deserialize(_stream, objectSpace);
}
if (_stream.fail()) {
const int trialCount = 10;
for (int tc = 0; tc < trialCount; tc++) {
_stream.clear();
_stream.seekg(offset_pos, std::ios::beg);
if (_stream.fail()) {
continue;
}
data.deserialize(_stream, objectSpace);
if (_stream.fail()) {
continue;
} else {
break;
}
}
if (_stream.fail()) {
throw std::runtime_error("ArrayFile::get: Error!");
}
}
pthread_mutex_unlock(&_mutex);
return true;
}
template <class TYPE>
void ArrayFile<TYPE>::remove(const size_t id) {
uint64_t offset_pos = (id * (sizeof(RecordStruct) + _fileHead.recordSize)) + sizeof(FileHeadStruct);
_stream.seekp(offset_pos, std::ios::beg);
RecordStruct recordHead = {1, 0};
_stream.write((char *)(&recordHead), sizeof(RecordStruct));
}
template <class TYPE>
bool ArrayFile<TYPE>::isOpen() const
{
return _isOpen;
}
template <class TYPE>
size_t ArrayFile<TYPE>::size()
{
_stream.seekp(0, std::ios::end);
int64_t offset_pos = _stream.tellg();
offset_pos -= sizeof(FileHeadStruct);
size_t num = offset_pos / (sizeof(RecordStruct) + _fileHead.recordSize);
return num;
}
template <class TYPE>
bool ArrayFile<TYPE>::_readFileHead() {
_stream.seekp(0, std::ios::beg);
_stream.read((char *)(&_fileHead), sizeof(FileHeadStruct));
if(_stream.bad()){
return false;
}
return true;
}

View File

@ -1,40 +0,0 @@
if( ${UNIX} )
option(NGT_SHARED_MEMORY_ALLOCATOR "enable shared memory" OFF)
configure_file(${CMAKE_CURRENT_SOURCE_DIR}/defines.h.in ${CMAKE_CURRENT_BINARY_DIR}/defines.h)
include_directories("${CMAKE_CURRENT_BINARY_DIR}" "${PROJECT_SOURCE_DIR}/lib" "${PROJECT_BINARY_DIR}/lib/")
include_directories("${PROJECT_SOURCE_DIR}/../")
file(GLOB NGT_SOURCES *.cpp)
file(GLOB HEADER_FILES *.h *.hpp)
file(GLOB NGTQ_HEADER_FILES NGTQ/*.h NGTQ/*.hpp)
add_library(ngtstatic STATIC ${NGT_SOURCES})
set_target_properties(ngtstatic PROPERTIES OUTPUT_NAME ngt)
set_target_properties(ngtstatic PROPERTIES COMPILE_FLAGS "-fPIC")
target_link_libraries(ngtstatic)
if(CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang")
target_link_libraries(ngtstatic OpenMP::OpenMP_CXX)
endif()
add_library(ngt SHARED ${NGT_SOURCES})
set_target_properties(ngt PROPERTIES VERSION ${ngt_VERSION})
set_target_properties(ngt PROPERTIES SOVERSION ${ngt_SOVERSION})
add_dependencies(ngt ngtstatic)
if(${APPLE})
if(CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang")
target_link_libraries(ngt OpenMP::OpenMP_CXX)
else()
target_link_libraries(ngt gomp)
endif()
else(${APPLE})
target_link_libraries(ngt gomp rt)
endif(${APPLE})
install(TARGETS
ngt
ngtstatic
RUNTIME DESTINATION bin
LIBRARY DESTINATION lib
ARCHIVE DESTINATION lib)
endif()

View File

@ -1,988 +0,0 @@
//
// Copyright (C) 2015-2020 Yahoo Japan Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
#include <string>
#include <iostream>
#include <sstream>
#include "NGT/Index.h"
#include "NGT/GraphOptimizer.h"
#include "Capi.h"
static bool operate_error_string_(const std::stringstream &ss, NGTError error){
if(error != NULL){
try{
std::string *error_str = static_cast<std::string*>(error);
*error_str = ss.str();
}catch(std::exception &err){
std::cerr << ss.str() << " > " << err.what() << std::endl;
return false;
}
}else{
std::cerr << ss.str() << std::endl;
}
return true;
}
NGTIndex ngt_open_index(const char *index_path, NGTError error) {
try{
std::string index_path_str(index_path);
NGT::Index *index = new NGT::Index(index_path_str);
index->disableLog();
return static_cast<NGTIndex>(index);
}catch(std::exception &err){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
operate_error_string_(ss, error);
return NULL;
}
}
NGTIndex ngt_create_graph_and_tree(const char *database, NGTProperty prop, NGTError error) {
NGT::Index *index = NULL;
try{
std::string database_str(database);
NGT::Property prop_i = *(static_cast<NGT::Property*>(prop));
NGT::Index::createGraphAndTree(database_str, prop_i, true);
index = new NGT::Index(database_str);
index->disableLog();
return static_cast<NGTIndex>(index);
}catch(std::exception &err){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
operate_error_string_(ss, error);
delete index;
return NULL;
}
}
NGTIndex ngt_create_graph_and_tree_in_memory(NGTProperty prop, NGTError error) {
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : Error: " << __FUNCTION__ << " is unavailable for shared-memory-type NGT.";
operate_error_string_(ss, error);
return NULL;
#else
try{
NGT::Index *index = new NGT::GraphAndTreeIndex(*(static_cast<NGT::Property*>(prop)));
index->disableLog();
return static_cast<NGTIndex>(index);
}catch(std::exception &err){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
operate_error_string_(ss, error);
return NULL;
}
#endif
}
NGTProperty ngt_create_property(NGTError error) {
try{
return static_cast<NGTProperty>(new NGT::Property());
}catch(std::exception &err){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
operate_error_string_(ss, error);
return NULL;
}
}
bool ngt_save_index(const NGTIndex index, const char *database, NGTError error) {
try{
std::string database_str(database);
(static_cast<NGT::Index*>(index))->saveIndex(database_str);
}catch(std::exception &err) {
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
operate_error_string_(ss, error);
return false;
}
return true;
}
bool ngt_get_property(NGTIndex index, NGTProperty prop, NGTError error) {
if(index == NULL || prop == NULL){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: index = " << index << " prop = " << prop;
operate_error_string_(ss, error);
return false;
}
try{
(static_cast<NGT::Index*>(index))->getProperty(*(static_cast<NGT::Property*>(prop)));
}catch(std::exception &err) {
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
operate_error_string_(ss, error);
return false;
}
return true;
}
int32_t ngt_get_property_dimension(NGTProperty prop, NGTError error) {
if(prop == NULL){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: prop = " << prop;
operate_error_string_(ss, error);
return -1;
}
return (*static_cast<NGT::Property*>(prop)).dimension;
}
bool ngt_set_property_dimension(NGTProperty prop, int32_t value, NGTError error) {
if(prop == NULL){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: prop = " << prop;
operate_error_string_(ss, error);
return false;
}
(*static_cast<NGT::Property*>(prop)).dimension = value;
return true;
}
bool ngt_set_property_edge_size_for_creation(NGTProperty prop, int16_t value, NGTError error) {
if(prop == NULL){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: prop = " << prop;
operate_error_string_(ss, error);
return false;
}
(*static_cast<NGT::Property*>(prop)).edgeSizeForCreation = value;
return true;
}
bool ngt_set_property_edge_size_for_search(NGTProperty prop, int16_t value, NGTError error) {
if(prop == NULL){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: prop = " << prop;
operate_error_string_(ss, error);
return false;
}
(*static_cast<NGT::Property*>(prop)).edgeSizeForSearch = value;
return true;
}
int32_t ngt_get_property_object_type(NGTProperty prop, NGTError error) {
if(prop == NULL){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: prop = " << prop;
operate_error_string_(ss, error);
return -1;
}
return (*static_cast<NGT::Property*>(prop)).objectType;
}
bool ngt_is_property_object_type_float(int32_t object_type) {
return (object_type == NGT::ObjectSpace::ObjectType::Float);
}
bool ngt_is_property_object_type_integer(int32_t object_type) {
return (object_type == NGT::ObjectSpace::ObjectType::Uint8);
}
bool ngt_set_property_object_type_float(NGTProperty prop, NGTError error) {
if(prop == NULL){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: prop = " << prop;
operate_error_string_(ss, error);
return false;
}
(*static_cast<NGT::Property*>(prop)).objectType = NGT::ObjectSpace::ObjectType::Float;
return true;
}
bool ngt_set_property_object_type_integer(NGTProperty prop, NGTError error) {
if(prop == NULL){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: prop = " << prop;
operate_error_string_(ss, error);
return false;
}
(*static_cast<NGT::Property*>(prop)).objectType = NGT::ObjectSpace::ObjectType::Uint8;
return true;
}
bool ngt_set_property_distance_type_l1(NGTProperty prop, NGTError error) {
if(prop == NULL){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: prop = " << prop;
operate_error_string_(ss, error);
return false;
}
(*static_cast<NGT::Property*>(prop)).distanceType = NGT::Index::Property::DistanceType::DistanceTypeL1;
return true;
}
bool ngt_set_property_distance_type_l2(NGTProperty prop, NGTError error) {
if(prop == NULL){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: prop = " << prop;
operate_error_string_(ss, error);
return false;
}
(*static_cast<NGT::Property*>(prop)).distanceType = NGT::Index::Property::DistanceType::DistanceTypeL2;
return true;
}
bool ngt_set_property_distance_type_angle(NGTProperty prop, NGTError error) {
if(prop == NULL){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: prop = " << prop;
operate_error_string_(ss, error);
return false;
}
(*static_cast<NGT::Property*>(prop)).distanceType = NGT::Index::Property::DistanceType::DistanceTypeAngle;
return true;
}
bool ngt_set_property_distance_type_hamming(NGTProperty prop, NGTError error) {
if(prop == NULL){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: prop = " << prop;
operate_error_string_(ss, error);
return false;
}
(*static_cast<NGT::Property*>(prop)).distanceType = NGT::Index::Property::DistanceType::DistanceTypeHamming;
return true;
}
bool ngt_set_property_distance_type_jaccard(NGTProperty prop, NGTError error) {
if(prop == NULL){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: prop = " << prop;
operate_error_string_(ss, error);
return false;
}
(*static_cast<NGT::Property*>(prop)).distanceType = NGT::Index::Property::DistanceType::DistanceTypeJaccard;
return true;
}
bool ngt_set_property_distance_type_cosine(NGTProperty prop, NGTError error) {
if(prop == NULL){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: prop = " << prop;
operate_error_string_(ss, error);
return false;
}
(*static_cast<NGT::Property*>(prop)).distanceType = NGT::Index::Property::DistanceType::DistanceTypeCosine;
return true;
}
bool ngt_set_property_distance_type_normalized_angle(NGTProperty prop, NGTError error) {
if(prop == NULL){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: prop = " << prop;
operate_error_string_(ss, error);
return false;
}
(*static_cast<NGT::Property*>(prop)).distanceType = NGT::Index::Property::DistanceType::DistanceTypeNormalizedAngle;
return true;
}
bool ngt_set_property_distance_type_normalized_cosine(NGTProperty prop, NGTError error) {
if(prop == NULL){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: prop = " << prop;
operate_error_string_(ss, error);
return false;
}
(*static_cast<NGT::Property*>(prop)).distanceType = NGT::Index::Property::DistanceType::DistanceTypeNormalizedCosine;
return true;
}
NGTObjectDistances ngt_create_empty_results(NGTError error) {
try{
return static_cast<NGTObjectDistances>(new NGT::ObjectDistances());
}catch(std::exception &err) {
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
operate_error_string_(ss, error);
return NULL;
}
}
static bool ngt_search_index_(NGT::Index* pindex, NGT::Object *ngtquery, size_t size, float epsilon, float radius, NGTObjectDistances results, int edge_size = INT_MIN) {
// set search prameters.
NGT::SearchContainer sc(*ngtquery); // search parametera container.
sc.setResults(static_cast<NGT::ObjectDistances*>(results)); // set the result set.
sc.setSize(size); // the number of resultant objects.
sc.setRadius(radius); // search radius.
sc.setEpsilon(epsilon); // set exploration coefficient.
if (edge_size != INT_MIN) {
sc.setEdgeSize(edge_size);// set # of edges for each node
}
pindex->search(sc);
// delete the query object.
pindex->deleteObject(ngtquery);
return true;
}
bool ngt_search_index(NGTIndex index, double *query, int32_t query_dim, size_t size, float epsilon, float radius, NGTObjectDistances results, NGTError error) {
if(index == NULL || query == NULL || results == NULL || query_dim <= 0){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: index = " << index << " query = " << query << " results = " << results << " query_dim = " << query_dim;
operate_error_string_(ss, error);
return false;
}
NGT::Index* pindex = static_cast<NGT::Index*>(index);
NGT::Object *ngtquery = NULL;
if(radius < 0.0){
radius = FLT_MAX;
}
try{
std::vector<double> vquery(&query[0], &query[query_dim]);
ngtquery = pindex->allocateObject(vquery);
ngt_search_index_(pindex, ngtquery, size, epsilon, radius, results);
}catch(std::exception &err) {
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
operate_error_string_(ss, error);
if(ngtquery != NULL){
pindex->deleteObject(ngtquery);
}
return false;
}
return true;
}
bool ngt_search_index_as_float(NGTIndex index, float *query, int32_t query_dim, size_t size, float epsilon, float radius, NGTObjectDistances results, NGTError error) {
if(index == NULL || query == NULL || results == NULL || query_dim <= 0){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: index = " << index << " query = " << query << " results = " << results << " query_dim = " << query_dim;
operate_error_string_(ss, error);
return false;
}
NGT::Index* pindex = static_cast<NGT::Index*>(index);
NGT::Object *ngtquery = NULL;
if(radius < 0.0){
radius = FLT_MAX;
}
try{
std::vector<float> vquery(&query[0], &query[query_dim]);
ngtquery = pindex->allocateObject(vquery);
ngt_search_index_(pindex, ngtquery, size, epsilon, radius, results);
}catch(std::exception &err) {
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
operate_error_string_(ss, error);
if(ngtquery != NULL){
pindex->deleteObject(ngtquery);
}
return false;
}
return true;
}
bool ngt_search_index_with_query(NGTIndex index, NGTQuery query, NGTObjectDistances results, NGTError error) {
if(index == NULL || query.query == NULL || results == NULL){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: index = " << index << " query = " << query.query << " results = " << results;
operate_error_string_(ss, error);
return false;
}
NGT::Index* pindex = static_cast<NGT::Index*>(index);
int32_t dim = pindex->getObjectSpace().getDimension();
NGT::Object *ngtquery = NULL;
if(query.radius < 0.0){
query.radius = FLT_MAX;
}
try{
std::vector<float> vquery(&query.query[0], &query.query[dim]);
ngtquery = pindex->allocateObject(vquery);
ngt_search_index_(pindex, ngtquery, query.size, query.epsilon, query.radius, results, query.edge_size);
}catch(std::exception &err) {
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
operate_error_string_(ss, error);
if(ngtquery != NULL){
pindex->deleteObject(ngtquery);
}
return false;
}
return true;
}
// * deprecated *
int32_t ngt_get_size(NGTObjectDistances results, NGTError error) {
if(results == NULL){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: results = " << results;
operate_error_string_(ss, error);
return -1;
}
return (static_cast<NGT::ObjectDistances*>(results))->size();
}
uint32_t ngt_get_result_size(NGTObjectDistances results, NGTError error) {
if(results == NULL){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: results = " << results;
operate_error_string_(ss, error);
return 0;
}
return (static_cast<NGT::ObjectDistances*>(results))->size();
}
NGTObjectDistance ngt_get_result(const NGTObjectDistances results, const uint32_t i, NGTError error) {
try{
NGT::ObjectDistances objects = *(static_cast<NGT::ObjectDistances*>(results));
NGTObjectDistance ret_val = {objects[i].id, objects[i].distance};
return ret_val;
}catch(std::exception &err) {
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
operate_error_string_(ss, error);
NGTObjectDistance err_val = {0};
return err_val;
}
}
ObjectID ngt_insert_index(NGTIndex index, double *obj, uint32_t obj_dim, NGTError error) {
if(index == NULL || obj == NULL || obj_dim == 0){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: index = " << index << " obj = " << obj << " obj_dim = " << obj_dim;
operate_error_string_(ss, error);
return 0;
}
try{
NGT::Index* pindex = static_cast<NGT::Index*>(index);
std::vector<double> vobj(&obj[0], &obj[obj_dim]);
return pindex->insert(vobj);
}catch(std::exception &err) {
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
operate_error_string_(ss, error);
return 0;
}
}
ObjectID ngt_append_index(NGTIndex index, double *obj, uint32_t obj_dim, NGTError error) {
if(index == NULL || obj == NULL || obj_dim == 0){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: index = " << index << " obj = " << obj << " obj_dim = " << obj_dim;
operate_error_string_(ss, error);
return 0;
}
try{
NGT::Index* pindex = static_cast<NGT::Index*>(index);
std::vector<double> vobj(&obj[0], &obj[obj_dim]);
return pindex->append(vobj);
}catch(std::exception &err) {
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
operate_error_string_(ss, error);
return 0;
}
}
ObjectID ngt_insert_index_as_float(NGTIndex index, float *obj, uint32_t obj_dim, NGTError error) {
if(index == NULL || obj == NULL || obj_dim == 0){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: index = " << index << " obj = " << obj << " obj_dim = " << obj_dim;
operate_error_string_(ss, error);
return 0;
}
try{
NGT::Index* pindex = static_cast<NGT::Index*>(index);
std::vector<float> vobj(&obj[0], &obj[obj_dim]);
return pindex->insert(vobj);
}catch(std::exception &err) {
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
operate_error_string_(ss, error);
return 0;
}
}
ObjectID ngt_append_index_as_float(NGTIndex index, float *obj, uint32_t obj_dim, NGTError error) {
if(index == NULL || obj == NULL || obj_dim == 0){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: index = " << index << " obj = " << obj << " obj_dim = " << obj_dim;
operate_error_string_(ss, error);
return 0;
}
try{
NGT::Index* pindex = static_cast<NGT::Index*>(index);
std::vector<float> vobj(&obj[0], &obj[obj_dim]);
return pindex->append(vobj);
}catch(std::exception &err) {
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
operate_error_string_(ss, error);
return 0;
}
}
bool ngt_batch_append_index(NGTIndex index, float *obj, uint32_t data_count, NGTError error) {
try{
NGT::Index* pindex = static_cast<NGT::Index*>(index);
pindex->append(obj, data_count);
return true;
}catch(std::exception &err) {
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
operate_error_string_(ss, error);
return false;
}
}
bool ngt_batch_insert_index(NGTIndex index, float *obj, uint32_t data_count, uint32_t *ids, NGTError error) {
NGT::Index* pindex = static_cast<NGT::Index*>(index);
int32_t dim = pindex->getObjectSpace().getDimension();
bool status = true;
float *objptr = obj;
for (size_t idx = 0; idx < data_count; idx++, objptr += dim) {
try{
std::vector<double> vobj(objptr, objptr + dim);
ids[idx] = pindex->insert(vobj);
}catch(std::exception &err) {
status = false;
ids[idx] = 0;
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
operate_error_string_(ss, error);
}
}
return status;
}
bool ngt_create_index(NGTIndex index, uint32_t pool_size, NGTError error) {
if(index == NULL){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: idnex = " << index;
operate_error_string_(ss, error);
return false;
}
try{
(static_cast<NGT::Index*>(index))->createIndex(pool_size);
}catch(std::exception &err) {
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
operate_error_string_(ss, error);
return false;
}
return true;
}
bool ngt_remove_index(NGTIndex index, ObjectID id, NGTError error) {
if(index == NULL){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: idnex = " << index;
operate_error_string_(ss, error);
return false;
}
try{
(static_cast<NGT::Index*>(index))->remove(id);
}catch(std::exception &err) {
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
operate_error_string_(ss, error);
return false;
}
return true;
}
NGTObjectSpace ngt_get_object_space(NGTIndex index, NGTError error) {
if(index == NULL){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: idnex = " << index;
operate_error_string_(ss, error);
return NULL;
}
try{
return static_cast<NGTObjectSpace>(&(static_cast<NGT::Index*>(index))->getObjectSpace());
}catch(std::exception &err) {
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
operate_error_string_(ss, error);
return NULL;
}
}
float* ngt_get_object_as_float(NGTObjectSpace object_space, ObjectID id, NGTError error) {
if(object_space == NULL){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: object_space = " << object_space;
operate_error_string_(ss, error);
return NULL;
}
try{
return static_cast<float*>((static_cast<NGT::ObjectSpace*>(object_space))->getObject(id));
}catch(std::exception &err) {
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
operate_error_string_(ss, error);
return NULL;
}
}
uint8_t* ngt_get_object_as_integer(NGTObjectSpace object_space, ObjectID id, NGTError error) {
if(object_space == NULL){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: object_space = " << object_space;
operate_error_string_(ss, error);
return NULL;
}
try{
return static_cast<uint8_t*>((static_cast<NGT::ObjectSpace*>(object_space))->getObject(id));
}catch(std::exception &err) {
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
operate_error_string_(ss, error);
return NULL;
}
}
void ngt_destroy_results(NGTObjectDistances results) {
if(results == NULL) return;
delete(static_cast<NGT::ObjectDistances*>(results));
}
void ngt_destroy_property(NGTProperty prop) {
if(prop == NULL) return;
delete(static_cast<NGT::Property*>(prop));
}
void ngt_close_index(NGTIndex index) {
if(index == NULL) return;
(static_cast<NGT::Index*>(index))->close();
delete(static_cast<NGT::Index*>(index));
}
int16_t ngt_get_property_edge_size_for_creation(NGTProperty prop, NGTError error) {
if(prop == NULL){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: prop = " << prop;
operate_error_string_(ss, error);
return -1;
}
return (*static_cast<NGT::Property*>(prop)).edgeSizeForCreation;
}
int16_t ngt_get_property_edge_size_for_search(NGTProperty prop, NGTError error) {
if(prop == NULL){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: prop = " << prop;
operate_error_string_(ss, error);
return -1;
}
return (*static_cast<NGT::Property*>(prop)).edgeSizeForSearch;
}
int32_t ngt_get_property_distance_type(NGTProperty prop, NGTError error){
if(prop == NULL){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: prop = " << prop;
operate_error_string_(ss, error);
return -1;
}
return (*static_cast<NGT::Property*>(prop)).distanceType;
}
NGTError ngt_create_error_object()
{
try{
std::string *error_str = new std::string();
return static_cast<NGTError>(error_str);
}catch(std::exception &err){
std::cerr << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
return NULL;
}
}
const char *ngt_get_error_string(const NGTError error)
{
std::string *error_str = static_cast<std::string*>(error);
return error_str->c_str();
}
void ngt_clear_error_string(NGTError error)
{
std::string *error_str = static_cast<std::string*>(error);
*error_str = "";
}
void ngt_destroy_error_object(NGTError error)
{
std::string *error_str = static_cast<std::string*>(error);
delete error_str;
}
NGTOptimizer ngt_create_optimizer(bool logDisabled, NGTError error)
{
try{
return static_cast<NGTOptimizer>(new NGT::GraphOptimizer(logDisabled));
}catch(std::exception &err){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
operate_error_string_(ss, error);
return NULL;
}
}
bool ngt_optimizer_adjust_search_coefficients(NGTOptimizer optimizer, const char *index, NGTError error) {
if(optimizer == NULL){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: optimizer = " << optimizer;
operate_error_string_(ss, error);
return false;
}
try{
(static_cast<NGT::GraphOptimizer*>(optimizer))->adjustSearchCoefficients(std::string(index));
}catch(std::exception &err) {
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
operate_error_string_(ss, error);
return false;
}
return true;
}
bool ngt_optimizer_execute(NGTOptimizer optimizer, const char *inIndex, const char *outIndex, NGTError error) {
if(optimizer == NULL){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: optimizer = " << optimizer;
operate_error_string_(ss, error);
return false;
}
try{
(static_cast<NGT::GraphOptimizer*>(optimizer))->execute(std::string(inIndex), std::string(outIndex));
}catch(std::exception &err) {
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
operate_error_string_(ss, error);
return false;
}
return true;
}
// obsolute because of a lack of a parameter
bool ngt_optimizer_set(NGTOptimizer optimizer, int outgoing, int incoming, int nofqs,
float baseAccuracyFrom, float baseAccuracyTo,
float rateAccuracyFrom, float rateAccuracyTo,
double gte, double m, NGTError error) {
if(optimizer == NULL){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: optimizer = " << optimizer;
operate_error_string_(ss, error);
return false;
}
try{
(static_cast<NGT::GraphOptimizer*>(optimizer))->set(outgoing, incoming, nofqs, baseAccuracyFrom, baseAccuracyTo,
rateAccuracyFrom, rateAccuracyTo, gte, m);
}catch(std::exception &err) {
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
operate_error_string_(ss, error);
return false;
}
return true;
}
bool ngt_optimizer_set_minimum(NGTOptimizer optimizer, int outgoing, int incoming,
int nofqs, int nofrs, NGTError error) {
if(optimizer == NULL){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: optimizer = " << optimizer;
operate_error_string_(ss, error);
return false;
}
try{
(static_cast<NGT::GraphOptimizer*>(optimizer))->set(outgoing, incoming, nofqs, nofrs);
}catch(std::exception &err) {
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
operate_error_string_(ss, error);
return false;
}
return true;
}
bool ngt_optimizer_set_extension(NGTOptimizer optimizer,
float baseAccuracyFrom, float baseAccuracyTo,
float rateAccuracyFrom, float rateAccuracyTo,
double gte, double m, NGTError error) {
if(optimizer == NULL){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: optimizer = " << optimizer;
operate_error_string_(ss, error);
return false;
}
try{
(static_cast<NGT::GraphOptimizer*>(optimizer))->setExtension(baseAccuracyFrom, baseAccuracyTo,
rateAccuracyFrom, rateAccuracyTo, gte, m);
}catch(std::exception &err) {
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
operate_error_string_(ss, error);
return false;
}
return true;
}
bool ngt_optimizer_set_processing_modes(NGTOptimizer optimizer, bool searchParameter,
bool prefetchParameter, bool accuracyTable, NGTError error)
{
if(optimizer == NULL){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: optimizer = " << optimizer;
operate_error_string_(ss, error);
return false;
}
(static_cast<NGT::GraphOptimizer*>(optimizer))->setProcessingModes(searchParameter, prefetchParameter,
accuracyTable);
return true;
}
void ngt_destroy_optimizer(NGTOptimizer optimizer)
{
if(optimizer == NULL) return;
delete(static_cast<NGT::GraphOptimizer*>(optimizer));
}
bool ngt_refine_anng(NGTIndex index, float epsilon, float accuracy, int noOfEdges, int exploreEdgeSize, size_t batchSize, NGTError error)
{
NGT::Index* pindex = static_cast<NGT::Index*>(index);
try {
NGT::GraphReconstructor::refineANNG(*pindex, true, epsilon, accuracy, noOfEdges, exploreEdgeSize, batchSize);
} catch(std::exception &err) {
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
operate_error_string_(ss, error);
return false;
}
return true;
}
bool ngt_get_edges(NGTIndex index, ObjectID id, NGTObjectDistances edges, NGTError error)
{
if(index == NULL){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: index = " << index;
operate_error_string_(ss, error);
return false;
}
NGT::Index* pindex = static_cast<NGT::Index*>(index);
NGT::GraphIndex &graph = static_cast<NGT::GraphIndex&>(pindex->getIndex());
try {
NGT::ObjectDistances &objects = *static_cast<NGT::ObjectDistances*>(edges);
objects = *graph.getNode(id);
}catch(std::exception &err){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
operate_error_string_(ss, error);
return false;
}
return true;
}
uint32_t ngt_get_object_repository_size(NGTIndex index, NGTError error)
{
if(index == NULL){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: index = " << index;
operate_error_string_(ss, error);
return false;
}
NGT::Index& pindex = *static_cast<NGT::Index*>(index);
return pindex.getObjectRepositorySize();
}
NGTAnngEdgeOptimizationParameter ngt_get_anng_edge_optimization_parameter()
{
NGT::GraphOptimizer::ANNGEdgeOptimizationParameter gp;
NGTAnngEdgeOptimizationParameter parameter;
parameter.no_of_queries = gp.noOfQueries;
parameter.no_of_results = gp.noOfResults;
parameter.no_of_threads = gp.noOfThreads;
parameter.target_accuracy = gp.targetAccuracy;
parameter.target_no_of_objects = gp.targetNoOfObjects;
parameter.no_of_sample_objects = gp.noOfSampleObjects;
parameter.max_of_no_of_edges = gp.maxNoOfEdges;
parameter.log = false;
return parameter;
}
bool ngt_optimize_number_of_edges(const char *indexPath, NGTAnngEdgeOptimizationParameter parameter, NGTError error)
{
NGT::GraphOptimizer::ANNGEdgeOptimizationParameter p;
p.noOfQueries = parameter.no_of_queries;
p.noOfResults = parameter.no_of_results;
p.noOfThreads = parameter.no_of_threads;
p.targetAccuracy = parameter.target_accuracy;
p.targetNoOfObjects = parameter.target_no_of_objects;
p.noOfSampleObjects = parameter.no_of_sample_objects;
p.maxNoOfEdges = parameter.max_of_no_of_edges;
try {
NGT::GraphOptimizer graphOptimizer(!parameter.log); // false=log
std::string path(indexPath);
auto edge = graphOptimizer.optimizeNumberOfEdgesForANNG(path, p);
if (parameter.log) {
std::cerr << "the optimized number of edges is" << edge.first << "(" << edge.second << ")" << std::endl;
}
}catch(std::exception &err) {
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
operate_error_string_(ss, error);
return false;
}
return true;
}

View File

@ -1,210 +0,0 @@
//
// Copyright (C) 2015-2020 Yahoo Japan Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
#pragma once
#ifdef __cplusplus
extern "C" {
#endif
#include <stdio.h>
#include <stdint.h>
#include <stdbool.h>
typedef unsigned int ObjectID;
typedef void* NGTIndex;
typedef void* NGTProperty;
typedef void* NGTObjectSpace;
typedef void* NGTObjectDistances;
typedef void* NGTError;
typedef void* NGTOptimizer;
typedef struct {
ObjectID id;
float distance;
} NGTObjectDistance;
typedef struct {
float *query;
size_t size; // # of returned objects
float epsilon;
float accuracy; // expected accuracy
float radius;
size_t edge_size; // # of edges to explore for each node
} NGTQuery;
typedef struct {
size_t no_of_queries;
size_t no_of_results;
size_t no_of_threads;
float target_accuracy;
size_t target_no_of_objects;
size_t no_of_sample_objects;
size_t max_of_no_of_edges;
bool log;
} NGTAnngEdgeOptimizationParameter;
NGTIndex ngt_open_index(const char *, NGTError);
NGTIndex ngt_create_graph_and_tree(const char *, NGTProperty, NGTError);
NGTIndex ngt_create_graph_and_tree_in_memory(NGTProperty, NGTError);
NGTProperty ngt_create_property(NGTError);
bool ngt_save_index(const NGTIndex, const char *, NGTError);
bool ngt_get_property(const NGTIndex, NGTProperty, NGTError);
int32_t ngt_get_property_dimension(NGTProperty, NGTError);
bool ngt_set_property_dimension(NGTProperty, int32_t, NGTError);
bool ngt_set_property_edge_size_for_creation(NGTProperty, int16_t, NGTError);
bool ngt_set_property_edge_size_for_search(NGTProperty, int16_t, NGTError);
int32_t ngt_get_property_object_type(NGTProperty, NGTError);
bool ngt_is_property_object_type_float(int32_t);
bool ngt_is_property_object_type_integer(int32_t);
bool ngt_set_property_object_type_float(NGTProperty, NGTError);
bool ngt_set_property_object_type_integer(NGTProperty, NGTError);
bool ngt_set_property_distance_type_l1(NGTProperty, NGTError);
bool ngt_set_property_distance_type_l2(NGTProperty, NGTError);
bool ngt_set_property_distance_type_angle(NGTProperty, NGTError);
bool ngt_set_property_distance_type_hamming(NGTProperty, NGTError);
bool ngt_set_property_distance_type_jaccard(NGTProperty, NGTError);
bool ngt_set_property_distance_type_cosine(NGTProperty, NGTError);
bool ngt_set_property_distance_type_normalized_angle(NGTProperty, NGTError);
bool ngt_set_property_distance_type_normalized_cosine(NGTProperty, NGTError);
NGTObjectDistances ngt_create_empty_results(NGTError);
bool ngt_search_index(NGTIndex, double*, int32_t, size_t, float, float, NGTObjectDistances, NGTError);
bool ngt_search_index_as_float(NGTIndex, float*, int32_t, size_t, float, float, NGTObjectDistances, NGTError);
bool ngt_search_index_with_query(NGTIndex, NGTQuery, NGTObjectDistances, NGTError);
int32_t ngt_get_size(NGTObjectDistances, NGTError); // deprecated
uint32_t ngt_get_result_size(NGTObjectDistances, NGTError);
NGTObjectDistance ngt_get_result(const NGTObjectDistances, const uint32_t, NGTError);
ObjectID ngt_insert_index(NGTIndex, double*, uint32_t, NGTError);
ObjectID ngt_append_index(NGTIndex, double*, uint32_t, NGTError);
ObjectID ngt_insert_index_as_float(NGTIndex, float*, uint32_t, NGTError);
ObjectID ngt_append_index_as_float(NGTIndex, float*, uint32_t, NGTError);
bool ngt_batch_append_index(NGTIndex, float*, uint32_t, NGTError);
bool ngt_batch_insert_index(NGTIndex, float*, uint32_t, uint32_t *, NGTError);
bool ngt_create_index(NGTIndex, uint32_t, NGTError);
bool ngt_remove_index(NGTIndex, ObjectID, NGTError);
NGTObjectSpace ngt_get_object_space(NGTIndex, NGTError);
float* ngt_get_object_as_float(NGTObjectSpace, ObjectID, NGTError);
uint8_t* ngt_get_object_as_integer(NGTObjectSpace, ObjectID, NGTError);
void ngt_destroy_results(NGTObjectDistances);
void ngt_destroy_property(NGTProperty);
void ngt_close_index(NGTIndex);
int16_t ngt_get_property_edge_size_for_creation(NGTProperty, NGTError);
int16_t ngt_get_property_edge_size_for_search(NGTProperty, NGTError);
int32_t ngt_get_property_distance_type(NGTProperty, NGTError);
NGTError ngt_create_error_object();
const char *ngt_get_error_string(const NGTError);
void ngt_clear_error_string(NGTError);
void ngt_destroy_error_object(NGTError);
NGTOptimizer ngt_create_optimizer(bool logDisabled, NGTError);
bool ngt_optimizer_adjust_search_coefficients(NGTOptimizer, const char *, NGTError);
bool ngt_optimizer_execute(NGTOptimizer, const char *, const char *, NGTError);
bool ngt_optimizer_set(NGTOptimizer optimizer, int outgoing, int incoming, int nofqs,
float baseAccuracyFrom, float baseAccuracyTo,
float rateAccuracyFrom, float rateAccuracyTo,
double gte, double m, NGTError error);
bool ngt_optimizer_set_minimum(NGTOptimizer optimizer, int outgoing, int incoming,
int nofqs, int nofrs, NGTError error);
bool ngt_optimizer_set_extension(NGTOptimizer optimizer,
float baseAccuracyFrom, float baseAccuracyTo,
float rateAccuracyFrom, float rateAccuracyTo,
double gte, double m, NGTError error);
bool ngt_optimizer_set_processing_modes(NGTOptimizer optimizer, bool searchParameter,
bool prefetchParameter, bool accuracyTable, NGTError error);
void ngt_destroy_optimizer(NGTOptimizer);
// refine: the specified index by searching each node.
// epsilon, exepectedAccuracy and edgeSize: the same as the prameters for search. but if edgeSize is INT_MIN, default is used.
// noOfEdges: if this is not 0, kNNG with k = noOfEdges is build
// batchSize: batch size for parallelism.
bool ngt_refine_anng(NGTIndex index, float epsilon, float expectedAccuracy,
int noOfEdges, int edgeSize, size_t batchSize, NGTError error);
// get edges of the node that is specified with id.
bool ngt_get_edges(NGTIndex index, ObjectID id, NGTObjectDistances edges, NGTError error);
// get the size of the specified object repository.
// Since the size includes empty objects, the size is not the number of objects.
// The size is mostly the largest ID of the objects - 1;
uint32_t ngt_get_object_repository_size(NGTIndex index, NGTError error);
// return parameters for ngt_optimize_number_of_edges. You can customize them before calling ngt_optimize_number_of_edges.
NGTAnngEdgeOptimizationParameter ngt_get_anng_edge_optimization_parameter();
// optimize the number of initial edges for ANNG that is specified with indexPath.
// The parameter should be a struct which is returned by nt_get_optimization_parameter.
bool ngt_optimize_number_of_edges(const char *indexPath, NGTAnngEdgeOptimizationParameter parameter, NGTError error);
#ifdef __cplusplus
}
#endif

View File

@ -1,880 +0,0 @@
//
// Copyright (C) 2015-2020 Yahoo Japan Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
#pragma once
#include "NGT/Index.h"
#include "defines.h"
using namespace std;
#if defined(NGT_AVX_DISABLED)
#define NGT_CLUSTER_NO_AVX
#else
#if defined(__AVX2__)
#define NGT_CLUSTER_AVX2
#else
#define NGT_CLUSTER_NO_AVX
#endif
#endif
#if defined(NGT_CLUSTER_NO_AVX)
// #warning "*** SIMD is *NOT* available! ***"
#else
#include <immintrin.h>
#endif
#include <omp.h>
#include <random>
namespace NGT {
class Clustering {
public:
enum InitializationMode {
InitializationModeHead = 0,
InitializationModeRandom = 1,
InitializationModeKmeansPlusPlus = 2
};
enum ClusteringType {
ClusteringTypeKmeansWithNGT = 0,
ClusteringTypeKmeansWithoutNGT = 1,
ClusteringTypeKmeansWithIteration = 2,
ClusteringTypeKmeansWithNGTForCentroids = 3
};
class Entry {
public:
Entry() : vectorID(0), centroidID(0), distance(0.0) {
}
Entry(size_t vid, size_t cid, double d) : vectorID(vid), centroidID(cid), distance(d) {
}
bool
operator<(const Entry& e) const {
return distance > e.distance;
}
uint32_t vectorID;
uint32_t centroidID;
double distance;
};
class DescendingEntry {
public:
DescendingEntry(size_t vid, double d) : vectorID(vid), distance(d) {
}
bool
operator<(const DescendingEntry& e) const {
return distance < e.distance;
}
size_t vectorID;
double distance;
};
class Cluster {
public:
Cluster(std::vector<float>& c) : centroid(c), radius(0.0) {
}
Cluster(const Cluster& c) {
*this = c;
}
Cluster&
operator=(const Cluster& c) {
members = c.members;
centroid = c.centroid;
radius = c.radius;
return *this;
}
std::vector<Entry> members;
std::vector<float> centroid;
double radius;
};
Clustering(InitializationMode im = InitializationModeHead, ClusteringType ct = ClusteringTypeKmeansWithNGT,
size_t mi = 100)
: clusteringType(ct), initializationMode(im), maximumIteration(mi) {
initialize();
}
void
initialize() {
epsilonFrom = 0.12;
epsilonTo = epsilonFrom;
epsilonStep = 0.04;
resultSizeCoefficient = 5;
}
static void
convert(std::vector<std::string>& strings, std::vector<float>& vector) {
vector.clear();
for (auto it = strings.begin(); it != strings.end(); ++it) {
vector.push_back(stod(*it));
}
}
static void
extractVector(const std::string& str, std::vector<float>& vec) {
std::vector<std::string> tokens;
NGT::Common::tokenize(str, tokens, " \t");
convert(tokens, vec);
}
static void
loadVectors(const std::string& file, std::vector<std::vector<float> >& vectors) {
std::ifstream is(file);
if (!is) {
throw std::runtime_error("loadVectors::Cannot open " + file);
}
std::string line;
while (getline(is, line)) {
std::vector<float> v;
extractVector(line, v);
vectors.push_back(v);
}
}
static void
saveVectors(const std::string& file, std::vector<std::vector<float> >& vectors) {
std::ofstream os(file);
for (auto vit = vectors.begin(); vit != vectors.end(); ++vit) {
std::vector<float>& v = *vit;
for (auto it = v.begin(); it != v.end(); ++it) {
os << std::setprecision(9) << (*it);
if (it + 1 != v.end()) {
os << "\t";
}
}
os << std::endl;
}
}
static void
saveVector(const std::string& file, std::vector<size_t>& vectors) {
std::ofstream os(file);
for (auto vit = vectors.begin(); vit != vectors.end(); ++vit) {
os << *vit << std::endl;
}
}
static void
loadClusters(const std::string& file, std::vector<Cluster>& clusters, size_t numberOfClusters = 0) {
std::ifstream is(file);
if (!is) {
throw std::runtime_error("loadClusters::Cannot open " + file);
}
std::string line;
while (getline(is, line)) {
std::vector<float> v;
extractVector(line, v);
clusters.push_back(v);
if ((numberOfClusters != 0) && (clusters.size() >= numberOfClusters)) {
break;
}
}
if ((numberOfClusters != 0) && (clusters.size() < numberOfClusters)) {
// std::cerr << "initial cluster data are not enough. " << clusters.size() << ":" << numberOfClusters
// << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("initial cluster data are not enough. " + std::to_string(clusters.size()) + ":" + std::to_string(numberOfClusters));
exit(1);
}
}
#if !defined(NGT_CLUSTER_NO_AVX)
static double
sumOfSquares(float* a, float* b, size_t size) {
__m256 sum = _mm256_setzero_ps();
float* last = a + size;
float* lastgroup = last - 7;
while (a < lastgroup) {
__m256 v = _mm256_sub_ps(_mm256_loadu_ps(a), _mm256_loadu_ps(b));
sum = _mm256_add_ps(sum, _mm256_mul_ps(v, v));
a += 8;
b += 8;
}
__attribute__((aligned(32))) float f[8];
_mm256_store_ps(f, sum);
double s = f[0] + f[1] + f[2] + f[3] + f[4] + f[5] + f[6] + f[7];
while (a < last) {
double d = *a++ - *b++;
s += d * d;
}
return s;
}
#else // !defined(NGT_AVX_DISABLED) && defined(__AVX__)
static double
sumOfSquares(float* a, float* b, size_t size) {
double csum = 0.0;
float* x = a;
float* y = b;
for (size_t i = 0; i < size; i++) {
double d = (double)*x++ - (double)*y++;
csum += d * d;
}
return csum;
}
#endif // !defined(NGT_AVX_DISABLED) && defined(__AVX__)
static double
distanceL2(std::vector<float>& vector1, std::vector<float>& vector2) {
return sqrt(sumOfSquares(&vector1[0], &vector2[0], vector1.size()));
}
static double
distanceL2(std::vector<std::vector<float> >& vector1, std::vector<std::vector<float> >& vector2) {
assert(vector1.size() == vector2.size());
double distance = 0.0;
for (size_t i = 0; i < vector1.size(); i++) {
distance += distanceL2(vector1[i], vector2[i]);
}
distance /= (double)vector1.size();
return distance;
}
static double
meanSumOfSquares(std::vector<float>& vector1, std::vector<float>& vector2) {
return sumOfSquares(&vector1[0], &vector2[0], vector1.size()) / (double)vector1.size();
}
static void
subtract(std::vector<float>& a, std::vector<float>& b) {
assert(a.size() == b.size());
auto bit = b.begin();
for (auto ait = a.begin(); ait != a.end(); ++ait, ++bit) {
*ait = *ait - *bit;
}
}
static void
getInitialCentroidsFromHead(std::vector<std::vector<float> >& vectors, std::vector<Cluster>& clusters,
size_t size) {
size = size > vectors.size() ? vectors.size() : size;
clusters.clear();
for (size_t i = 0; i < size; i++) {
clusters.push_back(Cluster(vectors[i]));
}
}
static void
getInitialCentroidsRandomly(std::vector<std::vector<float> >& vectors, std::vector<Cluster>& clusters, size_t size,
size_t seed) {
clusters.clear();
std::random_device rnd;
if (seed == 0) {
seed = rnd();
}
std::mt19937 mt(seed);
for (size_t i = 0; i < size; i++) {
size_t idx = mt() * vectors.size() / mt.max();
if (idx >= size) {
i--;
continue;
}
clusters.push_back(Cluster(vectors[idx]));
}
assert(clusters.size() == size);
}
static void
getInitialCentroidsKmeansPlusPlus(std::vector<std::vector<float> >& vectors, std::vector<Cluster>& clusters,
size_t size) {
size = size > vectors.size() ? vectors.size() : size;
clusters.clear();
std::random_device rnd;
std::mt19937 mt(rnd());
size_t idx = (long long)mt() * (long long)vectors.size() / (long long)mt.max();
clusters.push_back(Cluster(vectors[idx]));
NGT::Timer timer;
for (size_t k = 1; k < size; k++) {
double sum = 0;
std::priority_queue<DescendingEntry> sortedObjects;
// get d^2 and sort
#pragma omp parallel for
for (size_t vi = 0; vi < vectors.size(); vi++) {
auto vit = vectors.begin() + vi;
double mind = DBL_MAX;
for (auto cit = clusters.begin(); cit != clusters.end(); ++cit) {
double d = distanceL2(*vit, (*cit).centroid);
d *= d;
if (d < mind) {
mind = d;
}
}
#pragma omp critical
{
sortedObjects.push(DescendingEntry(distance(vectors.begin(), vit), mind));
sum += mind;
}
}
double l = (double)mt() / (double)mt.max() * sum;
while (!sortedObjects.empty()) {
sum -= sortedObjects.top().distance;
if (l >= sum) {
clusters.push_back(Cluster(vectors[sortedObjects.top().vectorID]));
break;
}
sortedObjects.pop();
}
}
}
static void
assign(std::vector<std::vector<float> >& vectors, std::vector<Cluster>& clusters,
size_t clusterSize = std::numeric_limits<size_t>::max()) {
// compute distances to the nearest clusters, and construct heap by the distances.
NGT::Timer timer;
timer.start();
std::vector<Entry> sortedObjects(vectors.size());
#pragma omp parallel for
for (size_t vi = 0; vi < vectors.size(); vi++) {
auto vit = vectors.begin() + vi;
{
double mind = DBL_MAX;
size_t mincidx = -1;
for (auto cit = clusters.begin(); cit != clusters.end(); ++cit) {
double d = distanceL2(*vit, (*cit).centroid);
if (d < mind) {
mind = d;
mincidx = distance(clusters.begin(), cit);
}
}
sortedObjects[vi] = Entry(vi, mincidx, mind);
}
}
std::sort(sortedObjects.begin(), sortedObjects.end());
// clear
for (auto cit = clusters.begin(); cit != clusters.end(); ++cit) {
(*cit).members.clear();
}
// distribute objects to the nearest clusters in the same size constraint.
for (auto soi = sortedObjects.rbegin(); soi != sortedObjects.rend();) {
Entry& entry = *soi;
if (entry.centroidID >= clusters.size()) {
// std::cerr << "Something wrong. " << entry.centroidID << ":" << clusters.size() << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("Something wrong. " + std::to_string(entry.centroidID) + ":" + std::to_string(clusters.size()));
soi++;
continue;
}
if (clusters[entry.centroidID].members.size() < clusterSize) {
clusters[entry.centroidID].members.push_back(entry);
soi++;
} else {
double mind = DBL_MAX;
size_t mincidx = -1;
for (auto cit = clusters.begin(); cit != clusters.end(); ++cit) {
if ((*cit).members.size() >= clusterSize) {
continue;
}
double d = distanceL2(vectors[entry.vectorID], (*cit).centroid);
if (d < mind) {
mind = d;
mincidx = distance(clusters.begin(), cit);
}
}
entry = Entry(entry.vectorID, mincidx, mind);
int pt = distance(sortedObjects.rbegin(), soi);
std::sort(sortedObjects.begin(), soi.base());
soi = sortedObjects.rbegin() + pt;
assert(pt == distance(sortedObjects.rbegin(), soi));
}
}
moveFartherObjectsToEmptyClusters(clusters);
}
static void
moveFartherObjectsToEmptyClusters(std::vector<Cluster>& clusters) {
size_t emptyClusterCount = 0;
for (auto cit = clusters.begin(); cit != clusters.end(); ++cit) {
if ((*cit).members.size() == 0) {
emptyClusterCount++;
double max = 0.0;
auto maxit = clusters.begin();
for (auto scit = clusters.begin(); scit != clusters.end(); ++scit) {
if ((*scit).members.size() >= 2 && (*scit).members.back().distance > max) {
maxit = scit;
max = (*scit).members.back().distance;
}
}
(*cit).members.push_back((*maxit).members.back());
(*cit).members.back().centroidID = distance(clusters.begin(), cit);
(*maxit).members.pop_back();
}
}
emptyClusterCount = 0;
for (auto cit = clusters.begin(); cit != clusters.end(); ++cit) {
if ((*cit).members.size() == 0) {
emptyClusterCount++;
}
}
}
static void
assignWithNGT(NGT::Index& index, std::vector<std::vector<float> >& vectors, std::vector<Cluster>& clusters,
float& radius, size_t& resultSize, float epsilon = 0.12, size_t notRetrievedObjectCount = 0) {
size_t dataSize = vectors.size();
assert(index.getObjectRepositorySize() - 1 == vectors.size());
vector<vector<Entry> > results(clusters.size());
#pragma omp parallel for
for (size_t ci = 0; ci < clusters.size(); ci++) {
auto cit = clusters.begin() + ci;
NGT::ObjectDistances objects; // result set
NGT::Object* query = 0;
query = index.allocateObject((*cit).centroid);
// set search prameters.
NGT::SearchContainer sc(*query); // search parametera container.
sc.setResults(&objects); // set the result set.
sc.setEpsilon(epsilon); // set exploration coefficient.
if (radius > 0.0) {
sc.setRadius(radius);
sc.setSize(dataSize / 2);
} else {
sc.setSize(resultSize); // the number of resultant objects.
}
index.search(sc);
results[ci].reserve(objects.size());
for (size_t idx = 0; idx < objects.size(); idx++) {
size_t oidx = objects[idx].id - 1;
results[ci].push_back(Entry(oidx, ci, objects[idx].distance));
}
index.deleteObject(query);
}
size_t resultCount = 0;
for (auto ri = results.begin(); ri != results.end(); ++ri) {
resultCount += (*ri).size();
}
vector<Entry> sortedResults;
sortedResults.reserve(resultCount);
for (auto ri = results.begin(); ri != results.end(); ++ri) {
auto end = (*ri).begin();
for (; end != (*ri).end(); ++end) {
}
std::copy((*ri).begin(), end, std::back_inserter(sortedResults));
}
vector<bool> processedObjects(dataSize, false);
for (auto i = sortedResults.begin(); i != sortedResults.end(); ++i) {
processedObjects[(*i).vectorID] = true;
}
notRetrievedObjectCount = 0;
vector<uint32_t> notRetrievedObjectIDs;
for (size_t idx = 0; idx < dataSize; idx++) {
if (!processedObjects[idx]) {
notRetrievedObjectCount++;
notRetrievedObjectIDs.push_back(idx);
}
}
sort(sortedResults.begin(), sortedResults.end());
for (auto cit = clusters.begin(); cit != clusters.end(); ++cit) {
(*cit).members.clear();
}
for (auto i = sortedResults.rbegin(); i != sortedResults.rend(); ++i) {
size_t objectID = (*i).vectorID;
size_t clusterID = (*i).centroidID;
if (processedObjects[objectID]) {
processedObjects[objectID] = false;
clusters[clusterID].members.push_back(*i);
clusters[clusterID].members.back().centroidID = clusterID;
radius = (*i).distance;
}
}
vector<Entry> notRetrievedObjects(notRetrievedObjectIDs.size());
#pragma omp parallel for
for (size_t vi = 0; vi < notRetrievedObjectIDs.size(); vi++) {
auto vit = notRetrievedObjectIDs.begin() + vi;
{
double mind = DBL_MAX;
size_t mincidx = -1;
for (auto cit = clusters.begin(); cit != clusters.end(); ++cit) {
double d = distanceL2(vectors[*vit], (*cit).centroid);
if (d < mind) {
mind = d;
mincidx = distance(clusters.begin(), cit);
}
}
notRetrievedObjects[vi] = Entry(*vit, mincidx, mind); // Entry(vectorID, centroidID, distance)
}
}
sort(notRetrievedObjects.begin(), notRetrievedObjects.end());
for (auto nroit = notRetrievedObjects.begin(); nroit != notRetrievedObjects.end(); ++nroit) {
clusters[(*nroit).centroidID].members.push_back(*nroit);
}
moveFartherObjectsToEmptyClusters(clusters);
}
static double
calculateCentroid(std::vector<std::vector<float> >& vectors, std::vector<Cluster>& clusters) {
double distance = 0;
size_t memberCount = 0;
for (auto it = clusters.begin(); it != clusters.end(); ++it) {
memberCount += (*it).members.size();
if ((*it).members.size() != 0) {
std::vector<float> mean(vectors[0].size(), 0.0);
for (auto memit = (*it).members.begin(); memit != (*it).members.end(); ++memit) {
auto mit = mean.begin();
auto& v = vectors[(*memit).vectorID];
for (auto vit = v.begin(); vit != v.end(); ++vit, ++mit) {
*mit += *vit;
}
}
for (auto mit = mean.begin(); mit != mean.end(); ++mit) {
*mit /= (*it).members.size();
}
distance += distanceL2((*it).centroid, mean);
(*it).centroid = mean;
} else {
// cerr << "Clustering: Fatal Error. No member!" << endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("Clustering: Fatal Error. No member!");
abort();
}
}
return distance;
}
static void
saveClusters(const std::string& file, std::vector<Cluster>& clusters) {
std::ofstream os(file);
for (auto cit = clusters.begin(); cit != clusters.end(); ++cit) {
std::vector<float>& v = (*cit).centroid;
for (auto it = v.begin(); it != v.end(); ++it) {
os << std::setprecision(9) << (*it);
if (it + 1 != v.end()) {
os << "\t";
}
}
os << std::endl;
}
}
double
kmeansWithoutNGT(std::vector<std::vector<float> >& vectors, size_t numberOfClusters,
std::vector<Cluster>& clusters) {
size_t clusterSize = std::numeric_limits<size_t>::max();
if (clusterSizeConstraint) {
clusterSize = ceil((double)vectors.size() / (double)numberOfClusters);
}
double diff = 0;
for (size_t i = 0; i < maximumIteration; i++) {
// std::cerr << "iteration=" << i << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("iteration=" + std::to_string(i));
assign(vectors, clusters, clusterSize);
// centroid is recomputed.
// diff is distance between the current centroids and the previous centroids.
diff = calculateCentroid(vectors, clusters);
if (diff == 0) {
break;
}
}
return diff == 0;
}
double
kmeansWithNGT(NGT::Index& index, std::vector<std::vector<float> >& vectors, size_t numberOfClusters,
std::vector<Cluster>& clusters, float epsilon) {
diffHistory.clear();
NGT::Timer timer;
timer.start();
float radius;
double diff = 0.0;
size_t resultSize;
resultSize = resultSizeCoefficient * vectors.size() / clusters.size();
for (size_t i = 0; i < maximumIteration; i++) {
size_t notRetrievedObjectCount = 0;
radius = -1.0;
assignWithNGT(index, vectors, clusters, radius, resultSize, epsilon, notRetrievedObjectCount);
// centroid is recomputed.
// diff is distance between the current centroids and the previous centroids.
std::vector<Cluster> prevClusters = clusters;
diff = calculateCentroid(vectors, clusters);
timer.stop();
// std::cerr << "iteration=" << i << " time=" << timer << " diff=" << diff << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("iteration=" + std::to_string(i) + " time=" + std::to_string(timer.time)+ " diff=" + std::to_string(diff));
timer.start();
diffHistory.push_back(diff);
if (diff == 0) {
break;
}
}
return diff;
}
double
kmeansWithNGT(std::vector<std::vector<float> >& vectors, size_t numberOfClusters, std::vector<Cluster>& clusters) {
pid_t pid = getpid();
std::stringstream str;
str << "cluster-ngt." << pid;
string database = str.str();
string dataFile;
size_t dataSize = 0;
size_t dim = clusters.front().centroid.size();
NGT::Property property;
property.dimension = dim;
property.graphType = NGT::Property::GraphType::GraphTypeANNG;
property.objectType = NGT::Index::Property::ObjectType::Float;
property.distanceType = NGT::Index::Property::DistanceType::DistanceTypeL2;
NGT::Index::createGraphAndTree(database, property, dataFile, dataSize);
float* data = new float[vectors.size() * dim];
float* ptr = data;
dataSize = vectors.size();
for (auto vi = vectors.begin(); vi != vectors.end(); ++vi) {
memcpy(ptr, &((*vi)[0]), dim * sizeof(float));
ptr += dim;
}
size_t threadSize = 20;
NGT::Index::append(database, data, dataSize, threadSize);
delete[] data;
NGT::Index index(database);
return kmeansWithNGT(index, vectors, numberOfClusters, clusters, epsilonFrom);
}
double
kmeansWithNGT(NGT::Index& index, size_t numberOfClusters, std::vector<Cluster>& clusters) {
NGT::GraphIndex& graph = static_cast<NGT::GraphIndex&>(index.getIndex());
NGT::ObjectSpace& os = graph.getObjectSpace();
size_t size = os.getRepository().size();
std::vector<std::vector<float> > vectors(size - 1);
for (size_t idx = 1; idx < size; idx++) {
try {
os.getObject(idx, vectors[idx - 1]);
} catch (...) {
// cerr << "Cannot get object " << idx << endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("Cannot get object " + std::to_string(idx));
}
}
// cerr << "# of data for clustering=" << vectors.size() << endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("# of data for clustering=" + std::to_string(vectors.size()));
double diff = DBL_MAX;
clusters.clear();
setupInitialClusters(vectors, numberOfClusters, clusters);
for (float epsilon = epsilonFrom; epsilon <= epsilonTo; epsilon += epsilonStep) {
// cerr << "epsilon=" << epsilon << endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("epsilon=" + std::to_string(epsilon));
diff = kmeansWithNGT(index, vectors, numberOfClusters, clusters, epsilon);
if (diff == 0.0) {
return diff;
}
}
return diff;
}
double
kmeansWithNGT(NGT::Index& index, size_t numberOfClusters, NGT::Index& outIndex) {
std::vector<Cluster> clusters;
double diff = kmeansWithNGT(index, numberOfClusters, clusters);
for (auto i = clusters.begin(); i != clusters.end(); ++i) {
outIndex.insert((*i).centroid);
}
outIndex.createIndex(16);
return diff;
}
double
kmeansWithNGT(NGT::Index& index, size_t numberOfClusters) {
NGT::Property prop;
index.getProperty(prop);
string path = index.getPath();
index.save();
index.close();
string outIndexName = path;
string inIndexName = path + ".tmp";
std::rename(outIndexName.c_str(), inIndexName.c_str());
NGT::Index::createGraphAndTree(outIndexName, prop);
index.open(outIndexName);
NGT::Index inIndex(inIndexName);
double diff = kmeansWithNGT(inIndex, numberOfClusters, index);
inIndex.close();
NGT::Index::destroy(inIndexName);
return diff;
}
double
kmeansWithNGT(string& indexName, size_t numberOfClusters) {
NGT::Index inIndex(indexName);
double diff = kmeansWithNGT(inIndex, numberOfClusters);
inIndex.save();
inIndex.close();
return diff;
}
static double
calculateMSE(std::vector<std::vector<float> >& vectors, std::vector<Cluster>& clusters) {
double mse = 0.0;
size_t count = 0;
for (auto cit = clusters.begin(); cit != clusters.end(); ++cit) {
count += (*cit).members.size();
for (auto mit = (*cit).members.begin(); mit != (*cit).members.end(); ++mit) {
mse += meanSumOfSquares((*cit).centroid, vectors[(*mit).vectorID]);
}
}
assert(vectors.size() == count);
return mse / (double)vectors.size();
}
static double
calculateML2(std::vector<std::vector<float> >& vectors, std::vector<Cluster>& clusters) {
double d = 0.0;
size_t count = 0;
for (auto cit = clusters.begin(); cit != clusters.end(); ++cit) {
count += (*cit).members.size();
double localD = 0.0;
for (auto mit = (*cit).members.begin(); mit != (*cit).members.end(); ++mit) {
double distance = distanceL2((*cit).centroid, vectors[(*mit).vectorID]);
d += distance;
localD += distance;
}
}
if (vectors.size() != count) {
// std::cerr << "Warning! vectors.size() != count" << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("Warning! vectors.size() != count");
}
return d / (double)vectors.size();
}
static double
calculateML2FromSpecifiedCentroids(std::vector<std::vector<float> >& vectors, std::vector<Cluster>& clusters,
std::vector<size_t>& centroidIds) {
double d = 0.0;
size_t count = 0;
for (auto it = centroidIds.begin(); it != centroidIds.end(); ++it) {
Cluster& cluster = clusters[(*it)];
count += cluster.members.size();
for (auto mit = cluster.members.begin(); mit != cluster.members.end(); ++mit) {
d += distanceL2(cluster.centroid, vectors[(*mit).vectorID]);
}
}
return d / (double)vectors.size();
}
void
setupInitialClusters(std::vector<std::vector<float> >& vectors, size_t numberOfClusters,
std::vector<Cluster>& clusters) {
if (clusters.empty()) {
switch (initializationMode) {
case InitializationModeHead: {
getInitialCentroidsFromHead(vectors, clusters, numberOfClusters);
break;
}
case InitializationModeRandom: {
getInitialCentroidsRandomly(vectors, clusters, numberOfClusters, 0);
break;
}
case InitializationModeKmeansPlusPlus: {
getInitialCentroidsKmeansPlusPlus(vectors, clusters, numberOfClusters);
break;
}
default:
// std::cerr << "proper initMode is not specified." << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("proper initMode is not specified.");
exit(1);
}
}
}
bool
kmeans(std::vector<std::vector<float> >& vectors, size_t numberOfClusters, std::vector<Cluster>& clusters) {
setupInitialClusters(vectors, numberOfClusters, clusters);
switch (clusteringType) {
case ClusteringTypeKmeansWithoutNGT:
return kmeansWithoutNGT(vectors, numberOfClusters, clusters);
break;
case ClusteringTypeKmeansWithNGT:
return kmeansWithNGT(vectors, numberOfClusters, clusters);
break;
default:
// cerr << "kmeans::fatal error!. invalid clustering type. " << clusteringType << endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("kmeans::fatal error!. invalid clustering type. " + std::to_string(clusteringType));
abort();
break;
}
}
static void
evaluate(std::vector<std::vector<float> >& vectors, std::vector<Cluster>& clusters, char mode,
std::vector<size_t> centroidIds = std::vector<size_t>()) {
size_t clusterSize = std::numeric_limits<size_t>::max();
assign(vectors, clusters, clusterSize);
// std::cout << "The number of vectors=" << vectors.size() << std::endl;
// std::cout << "The number of centroids=" << clusters.size() << std::endl;
if (centroidIds.size() == 0) {
switch (mode) {
case 'e':
// std::cout << "MSE=" << calculateMSE(vectors, clusters) << std::endl;
break;
case '2':
default:
// std::cout << "ML2=" << calculateML2(vectors, clusters) << std::endl;
break;
}
} else {
switch (mode) {
case 'e':
break;
case '2':
default:
// std::cout << "ML2=" << calculateML2FromSpecifiedCentroids(vectors, clusters, centroidIds)
// << std::endl;
break;
}
}
}
ClusteringType clusteringType;
InitializationMode initializationMode;
size_t numberOfClusters;
bool clusterSizeConstraint;
size_t maximumIteration;
float epsilonFrom;
float epsilonTo;
float epsilonStep;
size_t resultSizeCoefficient;
vector<double> diffHistory;
};
} // namespace NGT

File diff suppressed because it is too large Load Diff

View File

@ -1,127 +0,0 @@
//
// Copyright (C) 2015-2020 Yahoo Japan Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
#pragma once
#include "NGT/Index.h"
namespace NGT {
class Command {
public:
class SearchParameter {
public:
SearchParameter() {
openMode = 'r';
query = "";
querySize = 0;
indexType = 't';
size = 20;
edgeSize = -1;
outputMode = "-";
radius = FLT_MAX;
step = 0;
trial = 1;
beginOfEpsilon = endOfEpsilon = stepOfEpsilon = 0.1;
accuracy = 0.0;
}
SearchParameter(Args &args) { parse(args); }
void parse(Args &args) {
openMode = args.getChar("m", 'r');
try {
query = args.get("#2");
} catch (...) {
NGTThrowException("ngt: Error: Query is not specified");
}
querySize = args.getl("Q", 0);
indexType = args.getChar("i", 't');
size = args.getl("n", 20);
// edgeSize
// -1(default) : using the size which was specified at the index creation.
// 0 : no limitation for the edge size.
// -2('e') : automatically set it according to epsilon.
if (args.getChar("E", '-') == 'e') {
edgeSize = -2;
} else {
edgeSize = args.getl("E", -1);
}
outputMode = args.getString("o", "-");
radius = args.getf("r", FLT_MAX);
trial = args.getl("t", 1);
{
beginOfEpsilon = endOfEpsilon = stepOfEpsilon = 0.1;
std::string epsilon = args.getString("e", "0.1");
std::vector<std::string> tokens;
NGT::Common::tokenize(epsilon, tokens, ":");
if (tokens.size() >= 1) { beginOfEpsilon = endOfEpsilon = NGT::Common::strtod(tokens[0]); }
if (tokens.size() >= 2) { endOfEpsilon = NGT::Common::strtod(tokens[1]); }
if (tokens.size() >= 3) { stepOfEpsilon = NGT::Common::strtod(tokens[2]); }
step = 0;
if (tokens.size() >= 4) { step = NGT::Common::strtol(tokens[3]); }
}
accuracy = args.getf("a", 0.0);
}
char openMode;
std::string query;
size_t querySize;
char indexType;
int size;
long edgeSize;
std::string outputMode;
float radius;
float beginOfEpsilon;
float endOfEpsilon;
float stepOfEpsilon;
float accuracy;
size_t step;
size_t trial;
};
Command():debugLevel(0) {}
void create(Args &args);
void append(Args &args);
static void search(NGT::Index &index, SearchParameter &searchParameter, std::ostream &stream)
{
std::ifstream is(searchParameter.query);
if (!is) {
std::cerr << "Cannot open the specified file. " << searchParameter.query << std::endl;
return;
}
search(index, searchParameter, is, stream);
}
static void search(NGT::Index &index, SearchParameter &searchParameter, std::istream &is, std::ostream &stream);
void search(Args &args);
void remove(Args &args);
void exportIndex(Args &args);
void importIndex(Args &args);
void prune(Args &args);
void reconstructGraph(Args &args);
void optimizeSearchParameters(Args &args);
void optimizeNumberOfEdgesForANNG(Args &args);
void refineANNG(Args &args);
void repair(Args &args);
void info(Args &args);
void setDebugLevel(int level) { debugLevel = level; }
int getDebugLevel() { return debugLevel; }
protected:
int debugLevel;
};
}; // NGT

View File

@ -1,24 +0,0 @@
//
// Copyright (C) 2015-2020 Yahoo Japan Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
#include "NGT/lib/NGT/Common.h"
#include "NGT/lib/NGT/ObjectSpace.h"
int64_t
NGT::SearchContainer::memSize() {
auto workres_size = workingResult.size() == 0 ? 0 : workingResult.size() * workingResult.top().memSize();
return sizeof(size_t) * 3 + sizeof(float) * 3 + result->memSize() + 1 + workres_size + Container::memSize();
}

File diff suppressed because it is too large Load Diff

View File

@ -1,15 +0,0 @@
#include "NGT/GetCoreNumber.h"
namespace NGT
{
int getCoreNumber()
{
#ifndef __linux__
SYSTEM_INFO sys_info;
GetSystemInfo(&sys_info);
return sysInfo.dwNumberOfProcessors;
#else
return get_nprocs();
#endif
}
}

View File

@ -1,12 +0,0 @@
#ifndef __linux__
# include "windows.h"
#else
# include "sys/sysinfo.h"
# include "unistd.h"
#endif
namespace NGT
{
int getCoreNumber();
}

File diff suppressed because it is too large Load Diff

View File

@ -1,967 +0,0 @@
//
// Copyright (C) 2015-2020 Yahoo Japan Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
#pragma once
#include <bitset>
#include <sstream>
#include "NGT/defines.h"
#include "NGT/Common.h"
#include "NGT/ObjectSpaceRepository.h"
#include "faiss/utils/BitsetView.h"
#include "NGT/HashBasedBooleanSet.h"
#ifndef NGT_GRAPH_CHECK_VECTOR
#include <unordered_set>
#endif
#ifdef NGT_GRAPH_UNCHECK_STACK
#include <stack>
#endif
#ifndef NGT_EXPLORATION_COEFFICIENT
#define NGT_EXPLORATION_COEFFICIENT 1.1
#endif
#ifndef NGT_INSERTION_EXPLORATION_COEFFICIENT
#define NGT_INSERTION_EXPLORATION_COEFFICIENT 1.1
#endif
#ifndef NGT_TRUNCATION_THRESHOLD
#define NGT_TRUNCATION_THRESHOLD 50
#endif
#ifndef NGT_SEED_SIZE
#define NGT_SEED_SIZE 10
#endif
#ifndef NGT_CREATION_EDGE_SIZE
#define NGT_CREATION_EDGE_SIZE 10
#endif
namespace NGT {
class Property;
typedef GraphNode GRAPH_NODE;
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
class GraphRepository: public PersistentRepository<GRAPH_NODE> {
#else
class GraphRepository: public Repository<GRAPH_NODE> {
#endif
public:
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
typedef PersistentRepository<GRAPH_NODE> VECTOR;
#else
typedef Repository<GRAPH_NODE> VECTOR;
GraphRepository() {
prevsize = new vector<unsigned short>;
}
virtual ~GraphRepository() {
deleteAll();
if (prevsize != 0) {
delete prevsize;
}
}
#endif
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
void open(const std::string &file, size_t sharedMemorySize) {
SharedMemoryAllocator &allocator = VECTOR::getAllocator();
off_t *entryTable = (off_t*)allocator.construct(file, sharedMemorySize);
if (entryTable == 0) {
entryTable = (off_t*)construct();
allocator.setEntry(entryTable);
}
assert(entryTable != 0);
this->initialize(entryTable);
}
void *construct() {
SharedMemoryAllocator &allocator = VECTOR::getAllocator();
off_t *entryTable = new(allocator) off_t[2];
entryTable[0] = allocator.getOffset(PersistentRepository<GRAPH_NODE>::construct());
entryTable[1] = allocator.getOffset(new(allocator) Vector<unsigned short>);
return entryTable;
}
void initialize(void *e) {
SharedMemoryAllocator &allocator = VECTOR::getAllocator();
off_t *entryTable = (off_t*)e;
array = (ARRAY*)allocator.getAddr(entryTable[0]);
PersistentRepository<GRAPH_NODE>::initialize(allocator.getAddr(entryTable[0]));
prevsize = (Vector<unsigned short>*)allocator.getAddr(entryTable[1]);
}
#endif
void insert(ObjectID id, ObjectDistances &objects) {
GRAPH_NODE *r = allocate();
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
(*r).copy(objects, VECTOR::getAllocator());
#else
*r = objects;
#endif
try {
put(id, r);
} catch (Exception &exp) {
delete r;
throw exp;
}
if (id >= prevsize->size()) {
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
prevsize->resize(id + 1, VECTOR::getAllocator(), 0);
#else
prevsize->resize(id + 1, 0);
#endif
} else {
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
(*prevsize).at(id, VECTOR::getAllocator()) = 0;
#else
(*prevsize)[id] = 0;
#endif
}
return;
}
inline GRAPH_NODE *get(ObjectID fid, size_t &minsize) {
GRAPH_NODE *rs = VECTOR::get(fid);
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
minsize = (*prevsize).at(fid, VECTOR::getAllocator());
#else
minsize = (*prevsize)[fid];
#endif
return rs;
}
void serialize(std::ofstream &os) {
VECTOR::serialize(os);
Serializer::write(os, *prevsize);
}
// for milvus
void serialize(std::stringstream & grp)
{
VECTOR::serialize(grp);
Serializer::write(grp, *prevsize);
}
void deserialize(std::ifstream &is) {
VECTOR::deserialize(is);
Serializer::read(is, *prevsize);
}
// for milvus
void deserialize(std::stringstream & is)
{
VECTOR::deserialize(is);
Serializer::read(is, *prevsize);
}
void show() {
for (size_t i = 0; i < this->size(); i++) {
std::cout << "Show graph " << i << " ";
if ((*this)[i] == 0) {
std::cout << std::endl;
continue;
}
for (size_t j = 0; j < (*this)[i]->size(); j++) {
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
std::cout << (*this)[i]->at(j, VECTOR::getAllocator()).id << ":" << (*this)[i]->at(j, VECTOR::getAllocator()).distance << " ";
#else
std::cout << (*this)[i]->at(j).id << ":" << (*this)[i]->at(j).distance << " ";
#endif
}
std::cout << std::endl;
}
}
virtual int64_t memSize() {
int64_t ret = prevsize->size() * sizeof(unsigned short);
for (size_t i = 1; i < this->size(); ++ i) {
ret += (*this)[i]->memSize();
}
return ret;
}
public:
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
Vector<unsigned short> *prevsize;
#else
std::vector<unsigned short> *prevsize;
#endif
};
#ifdef NGT_GRAPH_READ_ONLY_GRAPH
class ReadOnlyGraphNode : public std::vector<std::pair<uint64_t, PersistentObject*>> {
public:
ReadOnlyGraphNode():reservedSize(0), usedSize(0) {}
void reserve(size_t s) {
reservedSize = ((s & 7) == 0) ? s : (s & 0xFFFFFFFFFFFFFFF8) + 8;
resize(reservedSize);
for (size_t i = (reservedSize & 0xFFFFFFFFFFFFFFF8); i < reservedSize; i++) {
(*this)[i].first = 0;
}
}
void push_back(std::pair<uint32_t, PersistentObject*> node) {
(*this)[usedSize] = node;
usedSize++;
}
size_t size() { return usedSize; }
virtual int64_t memSize() { return reservedSize * (sizeof(uint64_t) + (*this)[0].second->memSize()); }
size_t reservedSize;
size_t usedSize;
};
class SearchGraphRepository : public std::vector<ReadOnlyGraphNode> {
public:
SearchGraphRepository() {}
bool isEmpty(size_t idx) { return (*this)[idx].empty(); }
virtual int64_t memSize() {
int64_t ret = 0;
for (size_t i = 1; i < this->size(); ++ i) {
ret += (*this)[i].memSize();
}
return ret;
}
void deserialize(std::ifstream &is, ObjectRepository &objectRepository) {
if (!is.is_open()) {
NGTThrowException("NGT::SearchGraph: Not open the specified stream yet.");
}
clear();
size_t s;
NGT::Serializer::read(is, s);
resize(s);
for (size_t id = 0; id < s; id++) {
char type;
NGT::Serializer::read(is, type);
switch(type) {
case '-':
break;
case '+':
{
ObjectDistances node;
node.deserialize(is);
ReadOnlyGraphNode &searchNode = at(id);
searchNode.reserve(node.size());
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
for (auto ni = node.begin(); ni != node.end(); ni++) {
std::cerr << "not implement" << std::endl;
abort();
}
#else
for (auto ni = node.begin(); ni != node.end(); ni++) {
searchNode.push_back(std::pair<uint32_t, Object*>((*ni).id, objectRepository.get((*ni).id)));
}
#endif
}
break;
default:
{
assert(type == '-' || type == '+');
break;
}
}
}
}
};
#endif // NGT_GRAPH_READ_ONLY_GRAPH
class NeighborhoodGraph {
public:
enum GraphType {
GraphTypeNone = 0,
GraphTypeANNG = 1,
GraphTypeKNNG = 2,
GraphTypeBKNNG = 3,
GraphTypeONNG = 4,
GraphTypeIANNG = 5, // Improved ANNG
GraphTypeDNNG = 6
};
enum SeedType {
SeedTypeNone = 0,
SeedTypeRandomNodes = 1,
SeedTypeFixedNodes = 2,
SeedTypeFirstNode = 3,
SeedTypeAllLeafNodes = 4
};
#ifdef NGT_GRAPH_READ_ONLY_GRAPH
class Search {
public:
static void (*getMethod(NGT::ObjectSpace::DistanceType dtype, NGT::ObjectSpace::ObjectType otype, size_t size))(NGT::NeighborhoodGraph&, NGT::SearchContainer&, NGT::ObjectDistances&) {
if (size < 5000000) {
switch (otype) {
default:
case NGT::ObjectSpace::Float:
switch (dtype) {
case NGT::ObjectSpace::DistanceTypeNormalizedCosine : return normalizedCosineSimilarityFloat;
case NGT::ObjectSpace::DistanceTypeCosine : return cosineSimilarityFloat;
case NGT::ObjectSpace::DistanceTypeNormalizedAngle : return normalizedAngleFloat;
case NGT::ObjectSpace::DistanceTypeAngle : return angleFloat;
case NGT::ObjectSpace::DistanceTypeL2 : return l2Float;
case NGT::ObjectSpace::DistanceTypeL1 : return l1Float;
case NGT::ObjectSpace::DistanceTypeSparseJaccard : return sparseJaccardFloat;
default: return l2Float;
}
break;
case NGT::ObjectSpace::Uint8:
switch (dtype) {
case NGT::ObjectSpace::DistanceTypeHamming : return hammingUint8;
case NGT::ObjectSpace::DistanceTypeJaccard : return jaccardUint8;
case NGT::ObjectSpace::DistanceTypeL2 : return l2Uint8;
case NGT::ObjectSpace::DistanceTypeL1 : return l1Uint8;
default : return l2Uint8;
}
break;
}
return l1Uint8;
} else {
switch (otype) {
default:
case NGT::ObjectSpace::Float:
switch (dtype) {
case NGT::ObjectSpace::DistanceTypeNormalizedCosine : return normalizedCosineSimilarityFloatForLargeDataset;
case NGT::ObjectSpace::DistanceTypeCosine : return cosineSimilarityFloatForLargeDataset;
case NGT::ObjectSpace::DistanceTypeNormalizedAngle : return normalizedAngleFloatForLargeDataset;
case NGT::ObjectSpace::DistanceTypeAngle : return angleFloatForLargeDataset;
case NGT::ObjectSpace::DistanceTypeL2 : return l2FloatForLargeDataset;
case NGT::ObjectSpace::DistanceTypeL1 : return l1FloatForLargeDataset;
case NGT::ObjectSpace::DistanceTypeSparseJaccard : return sparseJaccardFloatForLargeDataset;
default: return l2FloatForLargeDataset;
}
break;
case NGT::ObjectSpace::Uint8:
switch (dtype) {
case NGT::ObjectSpace::DistanceTypeHamming : return hammingUint8ForLargeDataset;
case NGT::ObjectSpace::DistanceTypeJaccard : return jaccardUint8ForLargeDataset;
case NGT::ObjectSpace::DistanceTypeL2 : return l2Uint8ForLargeDataset;
case NGT::ObjectSpace::DistanceTypeL1 : return l1Uint8ForLargeDataset;
default : return l2Uint8ForLargeDataset;
}
break;
}
return l1Uint8ForLargeDataset;
}
}
static void l1Uint8(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds);
static void l2Uint8(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds);
static void l1Float(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds);
static void l2Float(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds);
static void hammingUint8(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds);
static void jaccardUint8(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds);
static void sparseJaccardFloat(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds);
static void cosineSimilarityFloat(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds);
static void angleFloat(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds);
static void normalizedCosineSimilarityFloat(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds);
static void normalizedAngleFloat(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds);
static void l1Uint8ForLargeDataset(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds);
static void l2Uint8ForLargeDataset(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds);
static void l1FloatForLargeDataset(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds);
static void l2FloatForLargeDataset(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds);
static void hammingUint8ForLargeDataset(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds);
static void jaccardUint8ForLargeDataset(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds);
static void sparseJaccardFloatForLargeDataset(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds);
static void cosineSimilarityFloatForLargeDataset(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds);
static void angleFloatForLargeDataset(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds);
static void normalizedCosineSimilarityFloatForLargeDataset(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds);
static void normalizedAngleFloatForLargeDataset(NeighborhoodGraph &graph, NGT::SearchContainer &sc, ObjectDistances &seeds);
};
#endif
class Property {
public:
Property() { setDefault(); }
void setDefault() {
truncationThreshold = 0;
edgeSizeForCreation = NGT_CREATION_EDGE_SIZE;
edgeSizeForSearch = 0;
edgeSizeLimitForCreation = 5;
insertionRadiusCoefficient = NGT_INSERTION_EXPLORATION_COEFFICIENT;
seedSize = NGT_SEED_SIZE;
seedType = SeedTypeNone;
truncationThreadPoolSize = 8;
batchSizeForCreation = 200;
graphType = GraphTypeANNG;
dynamicEdgeSizeBase = 30;
dynamicEdgeSizeRate = 20;
buildTimeLimit = 0.0;
outgoingEdge = 10;
incomingEdge = 80;
}
void clear() {
truncationThreshold = -1;
edgeSizeForCreation = -1;
edgeSizeForSearch = -1;
edgeSizeLimitForCreation = -1;
insertionRadiusCoefficient = -1;
seedSize = -1;
seedType = SeedTypeNone;
truncationThreadPoolSize = -1;
batchSizeForCreation = -1;
graphType = GraphTypeNone;
dynamicEdgeSizeBase = -1;
dynamicEdgeSizeRate = -1;
buildTimeLimit = -1;
outgoingEdge = -1;
incomingEdge = -1;
}
void set(NGT::Property &prop);
void get(NGT::Property &prop);
void exportProperty(NGT::PropertySet &p) {
p.set("IncrimentalEdgeSizeLimitForTruncation", truncationThreshold);
p.set("EdgeSizeForCreation", edgeSizeForCreation);
p.set("EdgeSizeForSearch", edgeSizeForSearch);
p.set("EdgeSizeLimitForCreation", edgeSizeLimitForCreation);
assert(insertionRadiusCoefficient >= 1.0);
p.set("EpsilonForCreation", insertionRadiusCoefficient - 1.0);
p.set("BatchSizeForCreation", batchSizeForCreation);
p.set("SeedSize", seedSize);
p.set("TruncationThreadPoolSize", truncationThreadPoolSize);
p.set("DynamicEdgeSizeBase", dynamicEdgeSizeBase);
p.set("DynamicEdgeSizeRate", dynamicEdgeSizeRate);
p.set("BuildTimeLimit", buildTimeLimit);
p.set("OutgoingEdge", outgoingEdge);
p.set("IncomingEdge", incomingEdge);
switch (graphType) {
case NeighborhoodGraph::GraphTypeKNNG: p.set("GraphType", "KNNG"); break;
case NeighborhoodGraph::GraphTypeANNG: p.set("GraphType", "ANNG"); break;
case NeighborhoodGraph::GraphTypeBKNNG: p.set("GraphType", "BKNNG"); break;
case NeighborhoodGraph::GraphTypeONNG: p.set("GraphType", "ONNG"); break;
case NeighborhoodGraph::GraphTypeIANNG: p.set("GraphType", "IANNG"); break;
default: std::cerr << "Graph::exportProperty: Fatal error! Invalid Graph Type." << std::endl; abort();
}
switch (seedType) {
case NeighborhoodGraph::SeedTypeRandomNodes: p.set("SeedType", "RandomNodes"); break;
case NeighborhoodGraph::SeedTypeFixedNodes: p.set("SeedType", "FixedNodes"); break;
case NeighborhoodGraph::SeedTypeFirstNode: p.set("SeedType", "FirstNode"); break;
case NeighborhoodGraph::SeedTypeNone: p.set("SeedType", "None"); break;
case NeighborhoodGraph::SeedTypeAllLeafNodes: p.set("SeedType", "AllLeafNodes"); break;
default: std::cerr << "Graph::exportProperty: Fatal error! Invalid Seed Type." << std::endl; abort();
}
}
void importProperty(NGT::PropertySet &p) {
setDefault();
truncationThreshold = p.getl("IncrimentalEdgeSizeLimitForTruncation", truncationThreshold);
edgeSizeForCreation = p.getl("EdgeSizeForCreation", edgeSizeForCreation);
edgeSizeForSearch = p.getl("EdgeSizeForSearch", edgeSizeForSearch);
edgeSizeLimitForCreation = p.getl("EdgeSizeLimitForCreation", edgeSizeLimitForCreation);
insertionRadiusCoefficient = p.getf("EpsilonForCreation", insertionRadiusCoefficient);
insertionRadiusCoefficient += 1.0;
batchSizeForCreation = p.getl("BatchSizeForCreation", batchSizeForCreation);
seedSize = p.getl("SeedSize", seedSize);
truncationThreadPoolSize = p.getl("TruncationThreadPoolSize", truncationThreadPoolSize);
dynamicEdgeSizeBase = p.getl("DynamicEdgeSizeBase", dynamicEdgeSizeBase);
dynamicEdgeSizeRate = p.getl("DynamicEdgeSizeRate", dynamicEdgeSizeRate);
buildTimeLimit = p.getf("BuildTimeLimit", buildTimeLimit);
outgoingEdge = p.getl("OutgoingEdge", outgoingEdge);
incomingEdge = p.getl("IncomingEdge", incomingEdge);
PropertySet::iterator it = p.find("GraphType");
if (it != p.end()) {
if (it->second == "KNNG") graphType = NeighborhoodGraph::GraphTypeKNNG;
else if (it->second == "ANNG") graphType = NeighborhoodGraph::GraphTypeANNG;
else if (it->second == "BKNNG") graphType = NeighborhoodGraph::GraphTypeBKNNG;
else if (it->second == "ONNG") graphType = NeighborhoodGraph::GraphTypeONNG;
else if (it->second == "IANNG") graphType = NeighborhoodGraph::GraphTypeIANNG;
else { std::cerr << "Graph::importProperty: Fatal error! Invalid Graph Type. " << it->second << std::endl; abort(); }
}
it = p.find("SeedType");
if (it != p.end()) {
if (it->second == "RandomNodes") seedType = NeighborhoodGraph::SeedTypeRandomNodes;
else if (it->second == "FixedNodes") seedType = NeighborhoodGraph::SeedTypeFixedNodes;
else if (it->second == "FirstNode") seedType = NeighborhoodGraph::SeedTypeFirstNode;
else if (it->second == "None") seedType = NeighborhoodGraph::SeedTypeNone;
else if (it->second == "AllLeafNodes") seedType = NeighborhoodGraph::SeedTypeAllLeafNodes;
else { std::cerr << "Graph::importProperty: Fatal error! Invalid Seed Type. " << it->second << std::endl; abort(); }
}
}
friend std::ostream & operator<<(std::ostream& os, const Property& p) {
os << "truncationThreshold=" << p.truncationThreshold << std::endl;
os << "edgeSizeForCreation=" << p.edgeSizeForCreation << std::endl;
os << "edgeSizeForSearch=" << p.edgeSizeForSearch << std::endl;
os << "edgeSizeLimitForCreation=" << p.edgeSizeLimitForCreation << std::endl;
os << "insertionRadiusCoefficient=" << p.insertionRadiusCoefficient << std::endl;
os << "insertionRadiusCoefficient=" << p.insertionRadiusCoefficient << std::endl;
os << "seedSize=" << p.seedSize << std::endl;
os << "seedType=" << p.seedType << std::endl;
os << "truncationThreadPoolSize=" << p.truncationThreadPoolSize << std::endl;
os << "batchSizeForCreation=" << p.batchSizeForCreation << std::endl;
os << "graphType=" << p.graphType << std::endl;
os << "dynamicEdgeSizeBase=" << p.dynamicEdgeSizeBase << std::endl;
os << "dynamicEdgeSizeRate=" << p.dynamicEdgeSizeRate << std::endl;
os << "outgoingEdge=" << p.outgoingEdge << std::endl;
os << "incomingEdge=" << p.incomingEdge << std::endl;
return os;
}
int64_t memSize() { return sizeof(*this); }
int16_t truncationThreshold;
int16_t edgeSizeForCreation;
int16_t edgeSizeForSearch;
int16_t edgeSizeLimitForCreation;
double insertionRadiusCoefficient;
int16_t seedSize;
SeedType seedType;
int16_t truncationThreadPoolSize;
int16_t batchSizeForCreation;
GraphType graphType;
int16_t dynamicEdgeSizeBase;
int16_t dynamicEdgeSizeRate;
float buildTimeLimit;
int16_t outgoingEdge;
int16_t incomingEdge;
};
NeighborhoodGraph(): objectSpace(0) {
property.truncationThreshold = NGT_TRUNCATION_THRESHOLD;
// initialize random to generate random seeds
#ifdef NGT_DISABLE_SRAND_FOR_RANDOM
struct timeval randTime;
gettimeofday(&randTime, 0);
srand(randTime.tv_usec);
#endif
}
inline GraphNode *getNode(ObjectID fid, size_t &minsize) { return repository.get(fid, minsize); }
inline GraphNode *getNode(ObjectID fid) { return repository.VECTOR::get(fid); }
void insertNode(ObjectID id, ObjectDistances &objects) {
switch (property.graphType) {
case GraphTypeANNG:
insertANNGNode(id, objects);
break;
case GraphTypeIANNG:
insertIANNGNode(id, objects);
break;
case GraphTypeONNG:
insertONNGNode(id, objects);
break;
case GraphTypeKNNG:
insertKNNGNode(id, objects);
break;
case GraphTypeBKNNG:
insertBKNNGNode(id, objects);
break;
case GraphTypeNone:
NGTThrowException("NGT::insertNode: GraphType is not specified.");
break;
default:
NGTThrowException("NGT::insertNode: GraphType is invalid.");
break;
}
}
void insertBKNNGNode(ObjectID id, ObjectDistances &results) {
if (repository.isEmpty(id)) {
repository.insert(id, results);
} else {
GraphNode &rs = *getNode(id);
for (ObjectDistances::iterator ri = results.begin(); ri != results.end(); ri++) {
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
rs.push_back((*ri), repository.allocator);
#else
rs.push_back((*ri));
#endif
}
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
std::sort(rs.begin(repository.allocator), rs.end(repository.allocator));
ObjectID prev = 0;
for (GraphNode::iterator ri = rs.begin(repository.allocator); ri != rs.end(repository.allocator);) {
if (prev == (*ri).id) {
ri = rs.erase(ri, repository.allocator);
continue;
}
prev = (*ri).id;
ri++;
}
#else
std::sort(rs.begin(), rs.end());
ObjectID prev = 0;
for (GraphNode::iterator ri = rs.begin(); ri != rs.end();) {
if (prev == (*ri).id) {
ri = rs.erase(ri);
continue;
}
prev = (*ri).id;
ri++;
}
#endif
}
for (ObjectDistances::iterator ri = results.begin(); ri != results.end(); ri++) {
assert(id != (*ri).id);
addBKNNGEdge((*ri).id, id, (*ri).distance);
}
return;
}
void insertKNNGNode(ObjectID id, ObjectDistances &results) {
repository.insert(id, results);
}
void insertANNGNode(ObjectID id, ObjectDistances &results) {
repository.insert(id, results);
std::queue<ObjectID> truncateQueue;
for (ObjectDistances::iterator ri = results.begin(); ri != results.end(); ri++) {
assert(id != (*ri).id);
if (addEdge((*ri).id, id, (*ri).distance)) {
truncateQueue.push((*ri).id);
}
}
while (!truncateQueue.empty()) {
ObjectID tid = truncateQueue.front();
truncateEdges(tid);
truncateQueue.pop();
}
return;
}
void insertIANNGNode(ObjectID id, ObjectDistances &results) {
repository.insert(id, results);
for (ObjectDistances::iterator ri = results.begin(); ri != results.end(); ri++) {
assert(id != (*ri).id);
addEdgeDeletingExcessEdges((*ri).id, id, (*ri).distance);
}
return;
}
void insertONNGNode(ObjectID id, ObjectDistances &results) {
if (property.truncationThreshold != 0) {
std::stringstream msg;
msg << "NGT::insertONNGNode: truncation should be disabled!" << std::endl;
NGTThrowException(msg);
}
int count = 0;
for (ObjectDistances::iterator ri = results.begin(); ri != results.end(); ri++, count++) {
assert(id != (*ri).id);
if (count >= property.incomingEdge) {
break;
}
addEdge((*ri).id, id, (*ri).distance);
}
if (static_cast<int>(results.size()) > property.outgoingEdge) {
results.resize(property.outgoingEdge);
}
repository.insert(id, results);
}
void removeEdgesReliably(ObjectID id);
int truncateEdgesOptimally(ObjectID id, GraphNode &results, size_t truncationSize);
int truncateEdges(ObjectID id) {
GraphNode &results = *getNode(id);
if (results.size() == 0) {
return -1;
}
size_t truncationSize = NGT_TRUNCATION_THRESHOLD;
if (truncationSize < (size_t)property.edgeSizeForCreation) {
truncationSize = property.edgeSizeForCreation;
}
return truncateEdgesOptimally(id, results, truncationSize);
}
// setup edgeSize
inline size_t getEdgeSize(NGT::SearchContainer &sc) {
size_t edgeSize = INT_MAX;
if (sc.edgeSize < 0) {
if (sc.edgeSize == -2) {
double add = pow(10, (sc.explorationCoefficient - 1.0) * static_cast<float>(property.dynamicEdgeSizeRate));
edgeSize = add >= static_cast<double>(INT_MAX) ? INT_MAX : property.dynamicEdgeSizeBase + add;
} else {
edgeSize = property.edgeSizeForSearch == 0 ? INT_MAX : property.edgeSizeForSearch;
}
} else {
edgeSize = sc.edgeSize == 0 ? INT_MAX : sc.edgeSize;
}
return edgeSize;
}
void search(NGT::SearchContainer &sc, ObjectDistances &seeds);
// for milvus
void search(NGT::SearchContainer & sc, ObjectDistances & seeds, const faiss::BitsetView bitset);
#ifdef NGT_GRAPH_READ_ONLY_GRAPH
template <typename COMPARATOR, typename CHECK_LIST> void searchReadOnlyGraph(NGT::SearchContainer &sc, ObjectDistances &seeds);
#endif
void removeEdge(ObjectID fid, ObjectID rmid) {
GraphNode &rs = *getNode(fid);
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
for (GraphNode::iterator ri = rs.begin(repository.allocator); ri != rs.end(repository.allocator); ri++) {
if ((*ri).id == rmid) {
rs.erase(ri, repository.allocator);
break;
}
}
#else
for (GraphNode::iterator ri = rs.begin(); ri != rs.end(); ri++) {
if ((*ri).id == rmid) {
rs.erase(ri);
break;
}
}
#endif
}
void removeEdge(GraphNode &node, ObjectDistance &edge) {
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
GraphNode::iterator ni = std::lower_bound(node.begin(repository.allocator), node.end(repository.allocator), edge);
if (ni != node.end(repository.allocator) && (*ni).id == edge.id) {
node.erase(ni, repository.allocator);
#else
GraphNode::iterator ni = std::lower_bound(node.begin(), node.end(), edge);
if (ni != node.end() && (*ni).id == edge.id) {
node.erase(ni);
#endif
return;
}
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
if (ni == node.end(repository.allocator)) {
#else
if (ni == node.end()) {
#endif
std::stringstream msg;
msg << "NGT::removeEdge: Cannot found " << edge.id;
NGTThrowException(msg);
} else {
std::stringstream msg;
msg << "NGT::removeEdge: Cannot found " << (*ni).id << ":" << edge.id;
NGTThrowException(msg);
}
}
void
removeNode(ObjectID id) {
repository.erase(id);
}
class BooleanVector : public std::vector<bool> {
public:
inline BooleanVector(size_t s):std::vector<bool>(s, false) {}
inline void insert(size_t i) { std::vector<bool>::operator[](i) = true; }
};
#ifdef NGT_GRAPH_VECTOR_RESULT
typedef ObjectDistances ResultSet;
#else
typedef std::priority_queue<ObjectDistance, std::vector<ObjectDistance>, std::less<ObjectDistance> > ResultSet;
#endif
#if defined(NGT_GRAPH_CHECK_BOOLEANSET)
typedef BooleanSet DistanceCheckedSet;
#elif defined(NGT_GRAPH_CHECK_VECTOR)
typedef BooleanVector DistanceCheckedSet;
#elif defined(NGT_GRAPH_CHECK_HASH_BASED_BOOLEAN_SET)
typedef HashBasedBooleanSet DistanceCheckedSet;
#else
class DistanceCheckedSet : public unordered_set<ObjectID> {
public:
bool operator[](ObjectID id) { return find(id) != end(); }
};
#endif
typedef HashBasedBooleanSet DistanceCheckedSetForLargeDataset;
class NodeWithPosition : public ObjectDistance {
public:
NodeWithPosition(uint32_t p = 0):position(p){}
NodeWithPosition(ObjectDistance &o):ObjectDistance(o), position(0){}
NodeWithPosition &operator=(const NodeWithPosition &n) {
ObjectDistance::operator=(static_cast<const ObjectDistance&>(n));
position = n.position;
assert(id != 0);
return *this;
}
uint32_t position;
};
#ifdef NGT_GRAPH_UNCHECK_STACK
typedef std::stack<ObjectDistance> UncheckedSet;
#else
#ifdef NGT_GRAPH_BETTER_FIRST_RESTORE
typedef std::priority_queue<NodeWithPosition, std::vector<NodeWithPosition>, std::greater<NodeWithPosition> > UncheckedSet;
#else
typedef std::priority_queue<ObjectDistance, std::vector<ObjectDistance>, std::greater<ObjectDistance> > UncheckedSet;
#endif
#endif
void setupDistances(NGT::SearchContainer &sc, ObjectDistances &seeds);
void setupDistances(NGT::SearchContainer &sc, ObjectDistances &seeds, double (&comparator)(const void*, const void*, size_t));
void setupSeeds(SearchContainer &sc, ObjectDistances &seeds, ResultSet &results,
UncheckedSet &unchecked, DistanceCheckedSet &distanceChecked);
#if !defined(NGT_GRAPH_CHECK_HASH_BASED_BOOLEAN_SET)
void setupSeeds(SearchContainer &sc, ObjectDistances &seeds, ResultSet &results,
UncheckedSet &unchecked, DistanceCheckedSetForLargeDataset &distanceChecked);
#endif
int getEdgeSize() {return property.edgeSizeForCreation;}
ObjectRepository &getObjectRepository() { return objectSpace->getRepository(); }
ObjectSpace &getObjectSpace() { return *objectSpace; }
void deleteInMemory() {
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
assert(0);
#else
for (std::vector<NGT::GraphNode*>::iterator i = repository.begin(); i != repository.end(); i++) {
if ((*i) != 0) {
delete (*i);
}
}
repository.clear();
#endif
}
static double (*getComparator())(const void*, const void*, size_t);
protected:
void
addBKNNGEdge(ObjectID target, ObjectID addID, Distance addDistance) {
if (repository.isEmpty(target)) {
ObjectDistances objs;
objs.push_back(ObjectDistance(addID, addDistance));
repository.insert(target, objs);
return;
}
addEdge(target, addID, addDistance, false);
}
public:
void addEdge(GraphNode &node, ObjectID addID, Distance addDistance, bool identityCheck = true) {
ObjectDistance obj(addID, addDistance);
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
GraphNode::iterator ni = std::lower_bound(node.begin(repository.allocator), node.end(repository.allocator), obj);
if ((ni != node.end(repository.allocator)) && ((*ni).id == addID)) {
if (identityCheck) {
std::stringstream msg;
msg << "NGT::addEdge: already existed! " << (*ni).id << ":" << addID;
NGTThrowException(msg);
}
return;
}
#else
GraphNode::iterator ni = std::lower_bound(node.begin(), node.end(), obj);
if ((ni != node.end()) && ((*ni).id == addID)) {
if (identityCheck) {
std::stringstream msg;
msg << "NGT::addEdge: already existed! " << (*ni).id << ":" << addID;
NGTThrowException(msg);
}
return;
}
#endif
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
node.insert(ni, obj, repository.allocator);
#else
node.insert(ni, obj);
#endif
}
// identityCheck is checking whether the same edge has already added to the node.
// return whether truncation is needed that means the node has too many edges.
bool addEdge(ObjectID target, ObjectID addID, Distance addDistance, bool identityCheck = true) {
size_t minsize = 0;
GraphNode &node = property.truncationThreshold == 0 ? *getNode(target) : *getNode(target, minsize);
addEdge(node, addID, addDistance, identityCheck);
if ((size_t)property.truncationThreshold != 0 && node.size() - minsize >
(size_t)property.truncationThreshold) {
return true;
}
return false;
}
void addEdgeDeletingExcessEdges(ObjectID target, ObjectID addID, Distance addDistance, bool identityCheck = true) {
GraphNode &node = *getNode(target);
size_t kEdge = property.edgeSizeForCreation - 1;
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
if (node.size() > kEdge && node.at(kEdge, repository.allocator).distance >= addDistance) {
GraphNode &linkedNode = *getNode(node.at(kEdge, repository.allocator).id);
ObjectDistance linkedNodeEdge(target, node.at(kEdge, repository.allocator).distance);
if ((linkedNode.size() > kEdge) && node.at(kEdge, repository.allocator).distance >=
linkedNode.at(kEdge, repository.allocator).distance) {
#else
if (node.size() > kEdge && node[kEdge].distance >= addDistance) {
GraphNode &linkedNode = *getNode(node[kEdge].id);
ObjectDistance linkedNodeEdge(target, node[kEdge].distance);
if ((linkedNode.size() > kEdge) && node[kEdge].distance >= linkedNode[kEdge].distance) {
#endif
try {
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
removeEdge(node, node.at(kEdge, repository.allocator));
#else
removeEdge(node, node[kEdge]);
#endif
} catch (Exception &exp) {
std::stringstream msg;
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
msg << "addEdge: Cannot remove. (a) " << target << "," << addID << "," << node.at(kEdge, repository.allocator).id << "," << node.at(kEdge, repository.allocator).distance;
#else
msg << "addEdge: Cannot remove. (a) " << target << "," << addID << "," << node[kEdge].id << "," << node[kEdge].distance;
#endif
msg << ":" << exp.what();
NGTThrowException(msg.str());
}
try {
removeEdge(linkedNode, linkedNodeEdge);
} catch (Exception &exp) {
std::stringstream msg;
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
msg << "addEdge: Cannot remove. (b) " << target << "," << addID << "," << node.at(kEdge, repository.allocator).id << "," << node.at(kEdge, repository.allocator).distance;
#else
msg << "addEdge: Cannot remove. (b) " << target << "," << addID << "," << node[kEdge].id << "," << node[kEdge].distance;
#endif
msg << ":" << exp.what();
NGTThrowException(msg.str());
}
}
}
addEdge(node, addID, addDistance, identityCheck);
}
#ifdef NGT_GRAPH_READ_ONLY_GRAPH
void loadSearchGraph(const std::string &database) {
std::ifstream isg(database + "/grp");
NeighborhoodGraph::searchRepository.deserialize(isg, NeighborhoodGraph::getObjectRepository());
}
#endif
public:
virtual int64_t memSize() { return repository.memSize() + searchRepository.memSize() + property.memSize() + objectSpace->memSize(); }
GraphRepository repository;
ObjectSpace *objectSpace;
#ifdef NGT_GRAPH_READ_ONLY_GRAPH
SearchGraphRepository searchRepository;
#endif
NeighborhoodGraph::Property property;
}; // NeighborhoodGraph
} // NGT

View File

@ -1,823 +0,0 @@
//
// Copyright (C) 2015-2020 Yahoo Japan Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
#include "GraphReconstructor.h"
#include "Optimizer.h"
namespace NGT {
class GraphOptimizer {
public:
class ANNGEdgeOptimizationParameter {
public:
ANNGEdgeOptimizationParameter() {
initialize();
}
void initialize() {
noOfQueries = 200;
noOfResults = 50;
noOfThreads = 16;
targetAccuracy = 0.9; // when epsilon is 0.0 and all of the edges are used
targetNoOfObjects = 0;
noOfSampleObjects = 100000;
maxNoOfEdges = 100;
}
size_t noOfQueries;
size_t noOfResults;
size_t noOfThreads;
float targetAccuracy;
size_t targetNoOfObjects;
size_t noOfSampleObjects;
size_t maxNoOfEdges;
};
GraphOptimizer(bool unlog = false) {
init();
logDisabled = unlog;
}
GraphOptimizer(int outgoing, int incoming, int nofqs, int nofrs,
float baseAccuracyFrom, float baseAccuracyTo,
float rateAccuracyFrom, float rateAccuracyTo,
double gte, double m,
bool unlog // stderr log is disabled.
) {
init();
set(outgoing, incoming, nofqs, nofrs, baseAccuracyFrom, baseAccuracyTo,
rateAccuracyFrom, rateAccuracyTo, gte, m);
logDisabled = unlog;
}
void init() {
numOfOutgoingEdges = 10;
numOfIncomingEdges= 120;
numOfQueries = 100;
numOfResults = 20;
baseAccuracyRange = std::pair<float, float>(0.30, 0.50);
rateAccuracyRange = std::pair<float, float>(0.80, 0.90);
gtEpsilon = 0.1;
margin = 0.2;
logDisabled = false;
shortcutReduction = true;
searchParameterOptimization = true;
prefetchParameterOptimization = true;
accuracyTableGeneration = true;
}
void adjustSearchCoefficients(const std::string indexPath){
NGT::Index index(indexPath);
NGT::GraphIndex &graph = static_cast<NGT::GraphIndex&>(index.getIndex());
NGT::Optimizer optimizer(index);
if (logDisabled) {
optimizer.disableLog();
} else {
optimizer.enableLog();
}
try {
auto coefficients = optimizer.adjustSearchEdgeSize(baseAccuracyRange, rateAccuracyRange, numOfQueries, gtEpsilon, margin);
NGT::NeighborhoodGraph::Property &prop = graph.getGraphProperty();
prop.dynamicEdgeSizeBase = coefficients.first;
prop.dynamicEdgeSizeRate = coefficients.second;
prop.edgeSizeForSearch = -2;
} catch(NGT::Exception &err) {
std::stringstream msg;
msg << "Optimizer::adjustSearchCoefficients: Cannot adjust the search coefficients. " << err.what();
NGTThrowException(msg);
}
graph.saveIndex(indexPath);
}
static double measureQueryTime(NGT::Index &index, size_t start) {
NGT::ObjectSpace &objectSpace = index.getObjectSpace();
NGT::ObjectRepository &objectRepository = objectSpace.getRepository();
size_t nQueries = 200;
nQueries = objectRepository.size() - 1 < nQueries ? objectRepository.size() - 1 : nQueries;
size_t step = objectRepository.size() / nQueries;
assert(step != 0);
std::vector<size_t> ids;
for (size_t startID = start; startID < step; startID++) {
for (size_t id = startID; id < objectRepository.size(); id += step) {
if (!objectRepository.isEmpty(id)) {
ids.push_back(id);
}
}
if (ids.size() >= nQueries) {
ids.resize(nQueries);
break;
}
}
if (nQueries > ids.size()) {
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("# of Queries is not enough.");
// std::cerr << "# of Queries is not enough." << std::endl;
return DBL_MAX;
}
NGT::Timer timer;
timer.reset();
for (auto id = ids.begin(); id != ids.end(); id++) {
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
NGT::Object *obj = objectSpace.allocateObject(*objectRepository.get(*id));
NGT::SearchContainer searchContainer(*obj);
#else
NGT::SearchContainer searchContainer(*objectRepository.get(*id));
#endif
NGT::ObjectDistances objects;
searchContainer.setResults(&objects);
searchContainer.setSize(10);
searchContainer.setEpsilon(0.1);
timer.restart();
index.search(searchContainer);
timer.stop();
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
objectSpace.deleteObject(obj);
#endif
}
return timer.time * 1000.0;
}
static std::pair<size_t, double> searchMinimumQueryTime(NGT::Index &index, size_t prefetchOffset,
int maxPrefetchSize, size_t seedID) {
NGT::ObjectSpace &objectSpace = index.getObjectSpace();
int step = 256;
int prevPrefetchSize = 64;
size_t minPrefetchSize = 0;
double minTime = DBL_MAX;
for (step = 256; step != 32; step /= 2) {
double prevTime = DBL_MAX;
for (int prefetchSize = prevPrefetchSize - step < 64 ? 64 : prevPrefetchSize - step; prefetchSize <= maxPrefetchSize; prefetchSize += step) {
objectSpace.setPrefetchOffset(prefetchOffset);
objectSpace.setPrefetchSize(prefetchSize);
double time = measureQueryTime(index, seedID);
if (prevTime < time) {
break;
}
prevTime = time;
prevPrefetchSize = prefetchSize;
}
if (minTime > prevTime) {
minTime = prevTime;
minPrefetchSize = prevPrefetchSize;
}
}
return std::make_pair(minPrefetchSize, minTime);
}
static std::pair<size_t, size_t> adjustPrefetchParameters(NGT::Index &index) {
bool gridSearch = false;
{
double time = measureQueryTime(index, 1);
if (time < 500.0) {
gridSearch = true;
}
}
size_t prefetchOffset = 0;
size_t prefetchSize = 0;
std::vector<std::pair<size_t, size_t>> mins;
NGT::ObjectSpace &objectSpace = index.getObjectSpace();
int maxSize = objectSpace.getByteSizeOfObject() * 4;
maxSize = maxSize < 64 * 28 ? maxSize : 64 * 28;
for (int trial = 0; trial < 10; trial++) {
size_t minps = 0;
size_t minpo = 0;
if (gridSearch) {
double minTime = DBL_MAX;
for (size_t po = 1; po <= 10; po++) {
auto min = searchMinimumQueryTime(index, po, maxSize, trial + 1);
if (minTime > min.second) {
minTime = min.second;
minps = min.first;
minpo = po;
}
}
} else {
double prevTime = DBL_MAX;
for (size_t po = 1; po <= 10; po++) {
auto min = searchMinimumQueryTime(index, po, maxSize, trial + 1);
if (prevTime < min.second) {
break;
}
prevTime = min.second;
minps = min.first;
minpo = po;
}
}
if (std::find(mins.begin(), mins.end(), std::make_pair(minpo, minps)) != mins.end()) {
prefetchOffset = minpo;
prefetchSize = minps;
mins.push_back(std::make_pair(minpo, minps));
break;
}
mins.push_back(std::make_pair(minpo, minps));
}
return std::make_pair(prefetchOffset, prefetchSize);
}
void execute(NGT::Index & index_)
{
NGT::GraphIndex & graphIndex = static_cast<NGT::GraphIndex &>(index_.getIndex());
if (numOfOutgoingEdges > 0 || numOfIncomingEdges > 0)
{
if (!logDisabled)
{
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("GraphOptimizer: adjusting outgoing and incoming edges...");
// std::cerr << "GraphOptimizer: adjusting outgoing and incoming edges..." << std::endl;
}
NGT::Timer timer;
timer.start();
std::vector<NGT::ObjectDistances> graph;
try
{
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("Optimizer::execute: Extract the graph data.");
// std::cerr << "Optimizer::execute: Extract the graph data." << std::endl;
// extract only edges from the index to reduce the memory usage.
NGT::GraphReconstructor::extractGraph(graph, graphIndex);
NeighborhoodGraph::Property & prop = graphIndex.getGraphProperty();
if (prop.graphType != NGT::NeighborhoodGraph::GraphTypeANNG)
{
NGT::GraphReconstructor::convertToANNG(graph);
}
NGT::GraphReconstructor::reconstructGraph(graph, graphIndex, numOfOutgoingEdges, numOfIncomingEdges);
timer.stop();
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("Optimizer::execute: Graph reconstruction time=" + std::to_string(timer.time) + " (sec) ");
// std::cerr << "Optimizer::execute: Graph reconstruction time=" << timer.time << " (sec) " << std::endl;
prop.graphType = NGT::NeighborhoodGraph::GraphTypeONNG;
}
catch (NGT::Exception & err)
{
throw(err);
}
}
if (shortcutReduction)
{
if (!logDisabled)
{
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("GraphOptimizer: redusing shortcut edges...");
// std::cerr << "GraphOptimizer: redusing shortcut edges..." << std::endl;
}
try
{
NGT::Timer timer;
timer.start();
NGT::GraphReconstructor::adjustPathsEffectively(graphIndex);
timer.stop();
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("Optimizer::execute: Path adjustment time=" + std::to_string(timer.time) + " (sec) ");
// std::cerr << "Optimizer::execute: Path adjustment time=" << timer.time << " (sec) " << std::endl;
}
catch (NGT::Exception & err)
{
throw(err);
}
}
}
void optimizeSearchParameters(NGT::Index & outIndex)
{
if (searchParameterOptimization)
{
if (!logDisabled)
{
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("GraphOptimizer: optimizing search parameters...");
// std::cerr << "GraphOptimizer: optimizing search parameters..." << std::endl;
}
NGT::GraphIndex & outGraph = static_cast<NGT::GraphIndex &>(outIndex.getIndex());
NGT::Optimizer optimizer(outIndex);
if (logDisabled)
{
optimizer.disableLog();
}
else
{
optimizer.enableLog();
}
try
{
auto coefficients = optimizer.adjustSearchEdgeSize(baseAccuracyRange, rateAccuracyRange, numOfQueries, gtEpsilon, margin);
NGT::NeighborhoodGraph::Property & prop = outGraph.getGraphProperty();
prop.dynamicEdgeSizeBase = coefficients.first;
prop.dynamicEdgeSizeRate = coefficients.second;
prop.edgeSizeForSearch = -2;
}
catch (NGT::Exception & err)
{
std::stringstream msg;
msg << "Optimizer::execute: Cannot adjust the search coefficients. " << err.what();
NGTThrowException(msg);
}
}
if (searchParameterOptimization || prefetchParameterOptimization || accuracyTableGeneration)
{
// NGT::GraphIndex & outGraph = static_cast<NGT::GraphIndex &>(*outIndex.getIndex());
if (prefetchParameterOptimization)
{
if (!logDisabled)
{
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("GraphOptimizer: optimizing prefetch parameters...");
// std::cerr << "GraphOptimizer: optimizing prefetch parameters..." << std::endl;
}
try
{
auto prefetch = adjustPrefetchParameters(outIndex);
NGT::Property prop;
outIndex.getProperty(prop);
prop.prefetchOffset = prefetch.first;
prop.prefetchSize = prefetch.second;
outIndex.setProperty(prop);
}
catch (NGT::Exception & err)
{
std::stringstream msg;
msg << "Optimizer::execute: Cannot adjust prefetch parameters. " << err.what();
NGTThrowException(msg);
}
}
if (accuracyTableGeneration)
{
if (!logDisabled)
{
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("GraphOptimizer: generating the accuracy table...");
// std::cerr << "GraphOptimizer: generating the accuracy table..." << std::endl;
}
try
{
auto table = NGT::Optimizer::generateAccuracyTable(outIndex, numOfResults, numOfQueries);
NGT::Index::AccuracyTable accuracyTable(table);
NGT::Property prop;
outIndex.getProperty(prop);
prop.accuracyTable = accuracyTable.getString();
outIndex.setProperty(prop);
}
catch (NGT::Exception & err)
{
std::stringstream msg;
msg << "Optimizer::execute: Cannot generate the accuracy table. " << err.what();
NGTThrowException(msg);
}
}
}
}
void execute(
const std::string inIndexPath,
const std::string outIndexPath
){
if (access(outIndexPath.c_str(), 0) == 0) {
std::stringstream msg;
msg << "Optimizer::execute: The specified index exists. " << outIndexPath;
NGTThrowException(msg);
}
const std::string com = "cp -r " + inIndexPath + " " + outIndexPath;
int stat = system(com.c_str());
if (stat != 0) {
std::stringstream msg;
msg << "Optimizer::execute: Cannot create the specified index. " << outIndexPath;
NGTThrowException(msg);
}
{
NGT::StdOstreamRedirector redirector(logDisabled);
NGT::GraphIndex graphIndex(outIndexPath, false);
if (numOfOutgoingEdges > 0 || numOfIncomingEdges > 0) {
if (!logDisabled) {
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("GraphOptimizer: adjusting outgoing and incoming edges...");
// std::cerr << "GraphOptimizer: adjusting outgoing and incoming edges..." << std::endl;
}
redirector.begin();
NGT::Timer timer;
timer.start();
std::vector<NGT::ObjectDistances> graph;
try {
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("Optimizer::execute: Extract the graph data.");
// std::cerr << "Optimizer::execute: Extract the graph data." << std::endl;
// extract only edges from the index to reduce the memory usage.
NGT::GraphReconstructor::extractGraph(graph, graphIndex);
NeighborhoodGraph::Property &prop = graphIndex.getGraphProperty();
if (prop.graphType != NGT::NeighborhoodGraph::GraphTypeANNG) {
NGT::GraphReconstructor::convertToANNG(graph);
}
NGT::GraphReconstructor::reconstructGraph(graph, graphIndex, numOfOutgoingEdges, numOfIncomingEdges);
timer.stop();
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("Optimizer::execute: Graph reconstruction time=" + std::to_string(timer.time) + " (sec) ");
// std::cerr << "Optimizer::execute: Graph reconstruction time=" << timer.time << " (sec) " << std::endl;
graphIndex.saveGraph(outIndexPath);
prop.graphType = NGT::NeighborhoodGraph::GraphTypeONNG;
graphIndex.saveProperty(outIndexPath);
} catch (NGT::Exception &err) {
redirector.end();
throw(err);
}
}
if (shortcutReduction) {
if (!logDisabled) {
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("GraphOptimizer: redusing shortcut edges...");
// std::cerr << "GraphOptimizer: redusing shortcut edges..." << std::endl;
}
try {
NGT::Timer timer;
timer.start();
NGT::GraphReconstructor::adjustPathsEffectively(graphIndex);
timer.stop();
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("optimizer::execute: path adjustment time=" + std::to_string(timer.time) + " (sec) ");
// std::cerr << "optimizer::execute: path adjustment time=" << timer.time << " (sec) " << std::endl;
graphIndex.saveGraph(outIndexPath);
} catch (NGT::Exception &err) {
redirector.end();
throw(err);
}
}
redirector.end();
}
optimizeSearchParameters(outIndexPath);
}
void optimizeSearchParameters(const std::string outIndexPath)
{
if (searchParameterOptimization) {
if (!logDisabled) {
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("GraphOptimizer: optimizing search parameters...");
// std::cerr << "GraphOptimizer: optimizing search parameters..." << std::endl;
}
NGT::Index outIndex(outIndexPath);
NGT::GraphIndex &outGraph = static_cast<NGT::GraphIndex&>(outIndex.getIndex());
NGT::Optimizer optimizer(outIndex);
if (logDisabled) {
optimizer.disableLog();
} else {
optimizer.enableLog();
}
try {
auto coefficients = optimizer.adjustSearchEdgeSize(baseAccuracyRange, rateAccuracyRange, numOfQueries, gtEpsilon, margin);
NGT::NeighborhoodGraph::Property &prop = outGraph.getGraphProperty();
prop.dynamicEdgeSizeBase = coefficients.first;
prop.dynamicEdgeSizeRate = coefficients.second;
prop.edgeSizeForSearch = -2;
outGraph.saveProperty(outIndexPath);
} catch(NGT::Exception &err) {
std::stringstream msg;
msg << "Optimizer::execute: Cannot adjust the search coefficients. " << err.what();
NGTThrowException(msg);
}
}
if (searchParameterOptimization || prefetchParameterOptimization || accuracyTableGeneration) {
NGT::StdOstreamRedirector redirector(logDisabled);
redirector.begin();
NGT::Index outIndex(outIndexPath, true);
NGT::GraphIndex &outGraph = static_cast<NGT::GraphIndex&>(outIndex.getIndex());
if (prefetchParameterOptimization) {
if (!logDisabled) {
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("GraphOptimizer: optimizing prefetch parameters...");
// std::cerr << "GraphOptimizer: optimizing prefetch parameters..." << std::endl;
}
try {
auto prefetch = adjustPrefetchParameters(outIndex);
NGT::Property prop;
outIndex.getProperty(prop);
prop.prefetchOffset = prefetch.first;
prop.prefetchSize = prefetch.second;
outIndex.setProperty(prop);
outGraph.saveProperty(outIndexPath);
} catch(NGT::Exception &err) {
redirector.end();
std::stringstream msg;
msg << "Optimizer::execute: Cannot adjust prefetch parameters. " << err.what();
NGTThrowException(msg);
}
}
if (accuracyTableGeneration) {
if (!logDisabled) {
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("GraphOptimizer: generating the accuracy table...");
// std::cerr << "GraphOptimizer: generating the accuracy table..." << std::endl;
}
try {
auto table = NGT::Optimizer::generateAccuracyTable(outIndex, numOfResults, numOfQueries);
NGT::Index::AccuracyTable accuracyTable(table);
NGT::Property prop;
outIndex.getProperty(prop);
prop.accuracyTable = accuracyTable.getString();
outIndex.setProperty(prop);
} catch(NGT::Exception &err) {
redirector.end();
std::stringstream msg;
msg << "Optimizer::execute: Cannot generate the accuracy table. " << err.what();
NGTThrowException(msg);
}
}
try {
outGraph.saveProperty(outIndexPath);
redirector.end();
} catch(NGT::Exception &err) {
redirector.end();
std::stringstream msg;
msg << "Optimizer::execute: Cannot save the index. " << outIndexPath << err.what();
NGTThrowException(msg);
}
}
}
static std::tuple<size_t, double, double> // optimized # of edges, accuracy, accuracy gain per edge
optimizeNumberOfEdgesForANNG(NGT::Optimizer &optimizer, std::vector<std::vector<float>> &queries,
size_t nOfResults, float targetAccuracy, size_t maxNoOfEdges) {
NGT::Index &index = optimizer.index;
std::stringstream queryStream;
std::stringstream gtStream;
float maxEpsilon = 0.0;
optimizer.generatePseudoGroundTruth(queries, maxEpsilon, queryStream, gtStream);
size_t nOfEdges = 0;
double accuracy = 0.0;
size_t prevEdge = 0;
double prevAccuracy = 0.0;
double gain = 0.0;
{
std::vector<NGT::ObjectDistances> graph;
NGT::GraphReconstructor::extractGraph(graph, static_cast<NGT::GraphIndex&>(index.getIndex()));
float epsilon = 0.0;
for (size_t edgeSize = 5; edgeSize <= maxNoOfEdges; edgeSize += (edgeSize >= 10 ? 10 : 5) ) {
NGT::GraphReconstructor::reconstructANNGFromANNG(graph, index, edgeSize);
NGT::Command::SearchParameter searchParameter;
searchParameter.size = nOfResults;
searchParameter.outputMode = 'e';
searchParameter.edgeSize = 0;
searchParameter.beginOfEpsilon = searchParameter.endOfEpsilon = epsilon;
queryStream.clear();
queryStream.seekg(0, std::ios_base::beg);
std::vector<NGT::Optimizer::MeasuredValue> acc;
NGT::Optimizer::search(index, queryStream, gtStream, searchParameter, acc);
if (acc.size() == 0) {
NGTThrowException("Fatal error! Cannot get any accuracy value.");
}
accuracy = acc[0].meanAccuracy;
nOfEdges = edgeSize;
if (prevEdge != 0) {
gain = (accuracy - prevAccuracy) / (edgeSize - prevEdge);
}
if (accuracy >= targetAccuracy) {
break;
}
prevEdge = edgeSize;
prevAccuracy = accuracy;
}
}
return std::make_tuple(nOfEdges, accuracy, gain);
}
static std::pair<size_t, float>
optimizeNumberOfEdgesForANNG(NGT::Index &index, ANNGEdgeOptimizationParameter &parameter)
{
if (parameter.targetNoOfObjects == 0) {
parameter.targetNoOfObjects = index.getObjectRepositorySize();
}
NGT::Optimizer optimizer(index, parameter.noOfResults);
NGT::ObjectRepository &objectRepository = index.getObjectSpace().getRepository();
NGT::GraphIndex &graphIndex = static_cast<NGT::GraphIndex&>(index.getIndex());
NGT::GraphAndTreeIndex &treeIndex = static_cast<NGT::GraphAndTreeIndex&>(index.getIndex());
NGT::GraphRepository &graphRepository = graphIndex.NeighborhoodGraph::repository;
//float targetAccuracy = parameter.targetAccuracy + FLT_EPSILON;
std::vector<std::vector<float>> queries;
optimizer.extractAndRemoveRandomQueries(parameter.noOfQueries, queries);
{
graphRepository.deleteAll();
treeIndex.DVPTree::deleteAll();
treeIndex.DVPTree::insertNode(treeIndex.DVPTree::leafNodes.allocate());
}
NGT::NeighborhoodGraph::Property &prop = graphIndex.getGraphProperty();
prop.edgeSizeForCreation = parameter.maxNoOfEdges;
std::vector<std::pair<size_t, std::tuple<size_t, double, double>>> transition;
size_t targetNo = 12500;
for (;targetNo <= objectRepository.size() && targetNo <= parameter.noOfSampleObjects;
targetNo *= 2) {
ObjectID id = 0;
size_t noOfObjects = 0;
for (id = 1; id < objectRepository.size(); ++id) {
if (!objectRepository.isEmpty(id)) {
noOfObjects++;
}
if (noOfObjects >= targetNo) {
break;
}
}
id++;
index.createIndex(parameter.noOfThreads, id);
auto edge = NGT::GraphOptimizer::optimizeNumberOfEdgesForANNG(optimizer, queries, parameter.noOfResults, parameter.targetAccuracy, parameter.maxNoOfEdges);
transition.push_back(make_pair(noOfObjects, edge));
}
if (transition.size() < 2) {
std::stringstream msg;
msg << "Optimizer::optimizeNumberOfEdgesForANNG: Cannot optimize the number of edges. Too small object set. # of objects=" << objectRepository.size() << " target No.=" << targetNo;
NGTThrowException(msg);
}
double edgeRate = 0.0;
double accuracyRate = 0.0;
for (auto i = transition.begin(); i != transition.end() - 1; ++i) {
edgeRate += std::get<0>((*(i + 1)).second) - std::get<0>((*i).second);
accuracyRate += std::get<1>((*(i + 1)).second) - std::get<1>((*i).second);
}
edgeRate /= (transition.size() - 1);
accuracyRate /= (transition.size() - 1);
size_t estimatedEdge = std::get<0>(transition[0].second) +
edgeRate * (log2(parameter.targetNoOfObjects) - log2(transition[0].first));
float estimatedAccuracy = std::get<1>(transition[0].second) +
accuracyRate * (log2(parameter.targetNoOfObjects) - log2(transition[0].first));
if (estimatedAccuracy < parameter.targetAccuracy) {
estimatedEdge += (parameter.targetAccuracy - estimatedAccuracy) / std::get<2>(transition.back().second);
estimatedAccuracy = parameter.targetAccuracy;
}
if (estimatedEdge == 0) {
std::stringstream msg;
msg << "Optimizer::optimizeNumberOfEdgesForANNG: Cannot optimize the number of edges. "
<< estimatedEdge << ":" << estimatedAccuracy << " # of objects=" << objectRepository.size();
NGTThrowException(msg);
}
return std::make_pair(estimatedEdge, estimatedAccuracy);
}
std::pair<size_t, float>
optimizeNumberOfEdgesForANNG(const std::string indexPath, GraphOptimizer::ANNGEdgeOptimizationParameter &parameter) {
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
NGTThrowException("Not implemented for NGT with the shared memory option.");
#endif
NGT::StdOstreamRedirector redirector(logDisabled);
redirector.begin();
try {
NGT::Index index(indexPath, false);
auto optimizedEdge = NGT::GraphOptimizer::optimizeNumberOfEdgesForANNG(index, parameter);
NGT::GraphIndex &graph = static_cast<NGT::GraphIndex&>(index.getIndex());
size_t noOfEdges = (optimizedEdge.first + 10) / 5 * 5;
if (noOfEdges > parameter.maxNoOfEdges) {
noOfEdges = parameter.maxNoOfEdges;
}
NGT::NeighborhoodGraph::Property &prop = graph.getGraphProperty();
prop.edgeSizeForCreation = noOfEdges;
static_cast<NGT::GraphIndex&>(index.getIndex()).saveProperty(indexPath);
optimizedEdge.first = noOfEdges;
redirector.end();
return optimizedEdge;
} catch (NGT::Exception &err) {
redirector.end();
throw(err);
}
}
void set(int outgoing, int incoming, int nofqs, int nofrs,
float baseAccuracyFrom, float baseAccuracyTo,
float rateAccuracyFrom, float rateAccuracyTo,
double gte, double m
) {
set(outgoing, incoming, nofqs, nofrs);
setExtension(baseAccuracyFrom, baseAccuracyTo, rateAccuracyFrom, rateAccuracyTo, gte, m);
}
void set(int outgoing, int incoming, int nofqs, int nofrs) {
if (outgoing >= 0) {
numOfOutgoingEdges = outgoing;
}
if (incoming >= 0) {
numOfIncomingEdges = incoming;
}
if (nofqs > 0) {
numOfQueries = nofqs;
}
if (nofrs > 0) {
numOfResults = nofrs;
}
}
void setExtension(float baseAccuracyFrom, float baseAccuracyTo,
float rateAccuracyFrom, float rateAccuracyTo,
double gte, double m
) {
if (baseAccuracyFrom > 0.0) {
baseAccuracyRange.first = baseAccuracyFrom;
}
if (baseAccuracyTo > 0.0) {
baseAccuracyRange.second = baseAccuracyTo;
}
if (rateAccuracyFrom > 0.0) {
rateAccuracyRange.first = rateAccuracyFrom;
}
if (rateAccuracyTo > 0.0) {
rateAccuracyRange.second = rateAccuracyTo;
}
if (gte >= -1.0) {
gtEpsilon = gte;
}
if (m > 0.0) {
margin = m;
}
}
// obsolete because of a lack of a parameter
void set(int outgoing, int incoming, int nofqs,
float baseAccuracyFrom, float baseAccuracyTo,
float rateAccuracyFrom, float rateAccuracyTo,
double gte, double m
) {
if (outgoing >= 0) {
numOfOutgoingEdges = outgoing;
}
if (incoming >= 0) {
numOfIncomingEdges = incoming;
}
if (nofqs > 0) {
numOfQueries = nofqs;
}
if (baseAccuracyFrom > 0.0) {
baseAccuracyRange.first = baseAccuracyFrom;
}
if (baseAccuracyTo > 0.0) {
baseAccuracyRange.second = baseAccuracyTo;
}
if (rateAccuracyFrom > 0.0) {
rateAccuracyRange.first = rateAccuracyFrom;
}
if (rateAccuracyTo > 0.0) {
rateAccuracyRange.second = rateAccuracyTo;
}
if (gte >= -1.0) {
gtEpsilon = gte;
}
if (m > 0.0) {
margin = m;
}
}
void setProcessingModes(bool shortcut = true, bool searchParameter = true, bool prefetchParameter = true,
bool accuracyTable = true) {
shortcutReduction = shortcut;
searchParameterOptimization = searchParameter;
prefetchParameterOptimization = prefetchParameter;
accuracyTableGeneration = accuracyTable;
}
size_t numOfOutgoingEdges;
size_t numOfIncomingEdges;
std::pair<float, float> baseAccuracyRange;
std::pair<float, float> rateAccuracyRange;
size_t numOfQueries;
size_t numOfResults;
double gtEpsilon;
double margin;
bool logDisabled;
bool shortcutReduction;
bool searchParameterOptimization;
bool prefetchParameterOptimization;
bool accuracyTableGeneration;
};
}; // NGT

View File

@ -1,988 +0,0 @@
//
// Copyright (C) 2015-2020 Yahoo Japan Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
#pragma once
#include <unordered_map>
#include <unordered_set>
#include <list>
#include "defines.h"
#ifdef _OPENMP
#include <omp.h>
#else
#warning "*** OMP is *NOT* available! ***"
#endif
namespace NGT {
class GraphReconstructor {
public:
static void extractGraph(std::vector<NGT::ObjectDistances> &graph, NGT::GraphIndex &graphIndex) {
graph.reserve(graphIndex.repository.size());
for (size_t id = 1; id < graphIndex.repository.size(); id++) {
if (id % 1000000 == 0) {
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("GraphReconstructor::extractGraph: Processed " + std::to_string(id) + " objects.");
// std::cerr << "GraphReconstructor::extractGraph: Processed " << id << " objects." << std::endl;
}
try {
NGT::GraphNode &node = *graphIndex.getNode(id);
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
NGT::ObjectDistances nd;
nd.reserve(node.size());
for (auto n = node.begin(graphIndex.repository.allocator); n != node.end(graphIndex.repository.allocator); ++n) {
nd.push_back(ObjectDistance((*n).id, (*n).distance));
}
graph.push_back(nd);
#else
graph.push_back(node);
#endif
if (graph.back().size() != graph.back().capacity()) {
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("GraphReconstructor::extractGraph: Warning! The graph size must be the same as the capacity. " + std::to_string(id));
// std::cerr << "GraphReconstructor::extractGraph: Warning! The graph size must be the same as the capacity. " << id << std::endl;
}
} catch(NGT::Exception &err) {
graph.push_back(NGT::ObjectDistances());
continue;
}
}
}
static void
adjustPaths(NGT::Index &outIndex)
{
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("construct index is not implemented.");
// std::cerr << "construct index is not implemented." << std::endl;
exit(1);
#else
NGT::GraphIndex &outGraph = dynamic_cast<NGT::GraphIndex&>(outIndex.getIndex());
size_t rStartRank = 0;
std::list<std::pair<size_t, NGT::GraphNode> > tmpGraph;
for (size_t id = 1; id < outGraph.repository.size(); id++) {
NGT::GraphNode &node = *outGraph.getNode(id);
tmpGraph.push_back(std::pair<size_t, NGT::GraphNode>(id, node));
if (node.size() > rStartRank) {
node.resize(rStartRank);
}
}
size_t removeCount = 0;
for (size_t rank = rStartRank; ; rank++) {
bool edge = false;
Timer timer;
for (auto it = tmpGraph.begin(); it != tmpGraph.end();) {
size_t id = (*it).first;
try {
NGT::GraphNode &node = (*it).second;
if (rank >= node.size()) {
it = tmpGraph.erase(it);
continue;
}
edge = true;
if (rank >= 1 && node[rank - 1].distance > node[rank].distance) {
// std::cerr << "distance order is wrong!" << std::endl;
// std::cerr << id << ":" << rank << ":" << node[rank - 1].id << ":" << node[rank].id << std::endl;
if (NGT_LOG_DEBUG_) {
(*NGT_LOG_DEBUG_)("distance order is wrong!");
(*NGT_LOG_DEBUG_)(std::to_string(id) + ":" + std::to_string(rank) + ":" + std::to_string(node[rank - 1].id) + ":" + std::to_string(node[rank].id));
}
}
NGT::GraphNode &tn = *outGraph.getNode(id);
volatile bool found = false;
if (rank < 1000) {
for (size_t tni = 0; tni < tn.size() && !found; tni++) {
if (tn[tni].id == node[rank].id) {
continue;
}
NGT::GraphNode &dstNode = *outGraph.getNode(tn[tni].id);
for (size_t dni = 0; dni < dstNode.size(); dni++) {
if ((dstNode[dni].id == node[rank].id) && (dstNode[dni].distance < node[rank].distance)) {
found = true;
break;
}
}
}
} else {
#ifdef _OPENMP
#pragma omp parallel for num_threads(10)
#endif
for (size_t tni = 0; tni < tn.size(); tni++) {
if (found) {
continue;
}
if (tn[tni].id == node[rank].id) {
continue;
}
NGT::GraphNode &dstNode = *outGraph.getNode(tn[tni].id);
for (size_t dni = 0; dni < dstNode.size(); dni++) {
if ((dstNode[dni].id == node[rank].id) && (dstNode[dni].distance < node[rank].distance)) {
found = true;
}
}
}
}
if (!found) {
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
outGraph.addEdge(id, node.at(i, outGraph.repository.allocator).id,
node.at(i, outGraph.repository.allocator).distance, true);
#else
tn.push_back(NGT::ObjectDistance(node[rank].id, node[rank].distance));
#endif
} else {
removeCount++;
}
} catch(NGT::Exception &err) {
// std::cerr << "GraphReconstructor: Warning. Cannot get the node. ID=" << id << ":" << err.what() << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("GraphReconstructor: Warning. Cannot get the node. ID=" + std::to_string(id) + ":" + err.what());
it++;
continue;
}
it++;
}
if (edge == false) {
break;
}
}
#endif // NGT_SHARED_MEMORY_ALLOCATOR
}
static void
adjustPathsEffectively(NGT::Index &outIndex)
{
NGT::GraphIndex &outGraph = dynamic_cast<NGT::GraphIndex&>(outIndex.getIndex());
adjustPathsEffectively(outGraph);
}
static bool edgeComp(NGT::ObjectDistance a, NGT::ObjectDistance b) {
return a.id < b.id;
}
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
static void insert(NGT::GraphNode &node, size_t edgeID, NGT::Distance edgeDistance, NGT::GraphIndex &graph) {
NGT::ObjectDistance edge(edgeID, edgeDistance);
GraphNode::iterator ni = std::lower_bound(node.begin(graph.repository.allocator), node.end(graph.repository.allocator), edge, edgeComp);
node.insert(ni, edge, graph.repository.allocator);
}
static bool hasEdge(NGT::GraphIndex &graph, size_t srcNodeID, size_t dstNodeID)
{
NGT::GraphNode &srcNode = *graph.getNode(srcNodeID);
GraphNode::iterator ni = std::lower_bound(srcNode.begin(graph.repository.allocator), srcNode.end(graph.repository.allocator), ObjectDistance(dstNodeID, 0.0), edgeComp);
return (ni != srcNode.end(graph.repository.allocator)) && ((*ni).id == dstNodeID);
}
#else
static void insert(NGT::GraphNode &node, size_t edgeID, NGT::Distance edgeDistance) {
NGT::ObjectDistance edge(edgeID, edgeDistance);
GraphNode::iterator ni = std::lower_bound(node.begin(), node.end(), edge, edgeComp);
node.insert(ni, edge);
}
static bool hasEdge(NGT::GraphIndex &graph, size_t srcNodeID, size_t dstNodeID)
{
NGT::GraphNode &srcNode = *graph.getNode(srcNodeID);
GraphNode::iterator ni = std::lower_bound(srcNode.begin(), srcNode.end(), ObjectDistance(dstNodeID, 0.0), edgeComp);
return (ni != srcNode.end()) && ((*ni).id == dstNodeID);
}
#endif
static void
adjustPathsEffectively(NGT::GraphIndex &outGraph)
{
Timer timer;
timer.start();
std::vector<NGT::GraphNode> tmpGraph;
for (size_t id = 1; id < outGraph.repository.size(); id++) {
try {
NGT::GraphNode &node = *outGraph.getNode(id);
tmpGraph.push_back(node);
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
node.clear(outGraph.repository.allocator);
#else
node.clear();
#endif
} catch(NGT::Exception &err) {
// std::cerr << "GraphReconstructor: Warning. Cannot get the node. ID=" << id << ":" << err.what() << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("GraphReconstructor: Warning. Cannot get the node. ID=" + std::to_string(id) + ":" + err.what());
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
tmpGraph.push_back(NGT::GraphNode(outGraph.repository.allocator));
#else
tmpGraph.push_back(NGT::GraphNode());
#endif
}
}
if (outGraph.repository.size() != tmpGraph.size() + 1) {
std::stringstream msg;
msg << "GraphReconstructor: Fatal inner error. " << outGraph.repository.size() << ":" << tmpGraph.size();
NGTThrowException(msg);
}
timer.stop();
// std::cerr << "GraphReconstructor::adjustPaths: graph preparing time=" << timer << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("GraphReconstructor::adjustPaths: graph preparing time=" + std::to_string(timer.time));
timer.reset();
timer.start();
std::vector<std::vector<std::pair<uint32_t, uint32_t> > > removeCandidates(tmpGraph.size());
int removeCandidateCount = 0;
#ifdef _OPENMP
#pragma omp parallel for
#endif
for (size_t idx = 0; idx < tmpGraph.size(); ++idx) {
auto it = tmpGraph.begin() + idx;
size_t id = idx + 1;
try {
NGT::GraphNode &srcNode = *it;
std::unordered_map<uint32_t, std::pair<size_t, double> > neighbors;
for (size_t sni = 0; sni < srcNode.size(); ++sni) {
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
neighbors[srcNode.at(sni, outGraph.repository.allocator).id] = std::pair<size_t, double>(sni, srcNode.at(sni, outGraph.repository.allocator).distance);
#else
neighbors[srcNode[sni].id] = std::pair<size_t, double>(sni, srcNode[sni].distance);
#endif
}
std::vector<std::pair<int, std::pair<uint32_t, uint32_t> > > candidates;
for (size_t sni = 0; sni < srcNode.size(); sni++) {
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
NGT::GraphNode &pathNode = tmpGraph[srcNode.at(sni, outGraph.repository.allocator).id - 1];
#else
NGT::GraphNode &pathNode = tmpGraph[srcNode[sni].id - 1];
#endif
for (size_t pni = 0; pni < pathNode.size(); pni++) {
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
auto dstNodeID = pathNode.at(pni, outGraph.repository.allocator).id;
#else
auto dstNodeID = pathNode[pni].id;
#endif
auto dstNode = neighbors.find(dstNodeID);
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
if (dstNode != neighbors.end()
&& srcNode.at(sni, outGraph.repository.allocator).distance < (*dstNode).second.second
&& pathNode.at(pni, outGraph.repository.allocator).distance < (*dstNode).second.second
) {
#else
if (dstNode != neighbors.end()
&& srcNode[sni].distance < (*dstNode).second.second
&& pathNode[pni].distance < (*dstNode).second.second
) {
#endif
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
candidates.push_back(std::pair<int, std::pair<uint32_t, uint32_t> >((*dstNode).second.first, std::pair<uint32_t, uint32_t>(srcNode.at(sni, outGraph.repository.allocator).id, dstNodeID)));
#else
candidates.push_back(std::pair<int, std::pair<uint32_t, uint32_t> >((*dstNode).second.first, std::pair<uint32_t, uint32_t>(srcNode[sni].id, dstNodeID)));
#endif
removeCandidateCount++;
}
}
}
sort(candidates.begin(), candidates.end(), std::greater<std::pair<int, std::pair<uint32_t, uint32_t>>>());
removeCandidates[id - 1].reserve(candidates.size());
for (size_t i = 0; i < candidates.size(); i++) {
removeCandidates[id - 1].push_back(candidates[i].second);
}
} catch(NGT::Exception &err) {
// std::cerr << "GraphReconstructor: Warning. Cannot get the node. ID=" << id << ":" << err.what() << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("GraphReconstructor: Warning. Cannot get the node. ID=" + std::to_string(id) + ":" + err.what());
continue;
}
}
timer.stop();
// std::cerr << "GraphReconstructor::adjustPaths: extracting removed edge candidates time=" << timer << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("GraphReconstructor::adjustPaths: extracting removed edge candidates time=" + std::to_string(timer.time));
timer.reset();
timer.start();
std::list<size_t> ids;
for (size_t idx = 0; idx < tmpGraph.size(); ++idx) {
ids.push_back(idx + 1);
}
int removeCount = 0;
removeCandidateCount = 0;
for (size_t rank = 0; ids.size() != 0; rank++) {
for (auto it = ids.begin(); it != ids.end(); ) {
size_t id = *it;
size_t idx = id - 1;
try {
NGT::GraphNode &srcNode = tmpGraph[idx];
if (rank >= srcNode.size()) {
if (!removeCandidates[idx].empty()) {
// std::cerr << "Something wrong! ID=" << id << " # of remaining candidates=" << removeCandidates[idx].size() << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("Something wrong! ID=" + std::to_string(id) + " # of remaining candidates=" + std::to_string(removeCandidates[idx].size()));
abort();
}
#if !defined(NGT_SHARED_MEMORY_ALLOCATOR)
NGT::GraphNode empty;
tmpGraph[idx] = empty;
#endif
it = ids.erase(it);
continue;
}
if (removeCandidates[idx].size() > 0) {
removeCandidateCount++;
bool pathExist = false;
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
while (!removeCandidates[idx].empty() && (removeCandidates[idx].back().second == srcNode.at(rank, outGraph.repository.allocator).id)) {
#else
while (!removeCandidates[idx].empty() && (removeCandidates[idx].back().second == srcNode[rank].id)) {
#endif
size_t path = removeCandidates[idx].back().first;
size_t dst = removeCandidates[idx].back().second;
removeCandidates[idx].pop_back();
if (removeCandidates[idx].empty()) {
std::vector<std::pair<uint32_t, uint32_t>> empty;
removeCandidates[idx] = empty;
}
if ((hasEdge(outGraph, id, path)) && (hasEdge(outGraph, path, dst))) {
pathExist = true;
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
while (!removeCandidates[idx].empty() && (removeCandidates[idx].back().second == srcNode.at(rank, outGraph.repository.allocator).id)) {
#else
while (!removeCandidates[idx].empty() && (removeCandidates[idx].back().second == srcNode[rank].id)) {
#endif
removeCandidates[idx].pop_back();
if (removeCandidates[idx].empty()) {
std::vector<std::pair<uint32_t, uint32_t>> empty;
removeCandidates[idx] = empty;
}
}
break;
}
}
if (pathExist) {
removeCount++;
it++;
continue;
}
}
NGT::GraphNode &outSrcNode = *outGraph.getNode(id);
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
insert(outSrcNode, srcNode.at(rank, outGraph.repository.allocator).id, srcNode.at(rank, outGraph.repository.allocator).distance, outGraph);
#else
insert(outSrcNode, srcNode[rank].id, srcNode[rank].distance);
#endif
} catch(NGT::Exception &err) {
// std::cerr << "GraphReconstructor: Warning. Cannot get the node. ID=" << id << ":" << err.what() << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("GraphReconstructor: Warning. Cannot get the node. ID=" + std::to_string(id) + ":" + err.what());
it++;
continue;
}
it++;
}
}
for (size_t id = 1; id < outGraph.repository.size(); id++) {
try {
NGT::GraphNode &node = *outGraph.getNode(id);
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
std::sort(node.begin(outGraph.repository.allocator), node.end(outGraph.repository.allocator));
#else
std::sort(node.begin(), node.end());
#endif
} catch(...) {}
}
}
static
void convertToANNG(std::vector<NGT::ObjectDistances> &graph)
{
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("convertToANNG is not implemented for shared memory.");
// std::cerr << "convertToANNG is not implemented for shared memory." << std::endl;
return;
#else
// std::cerr << "convertToANNG begin" << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("convertToANNG begin");
for (size_t idx = 0; idx < graph.size(); idx++) {
NGT::GraphNode &node = graph[idx];
for (auto ni = node.begin(); ni != node.end(); ++ni) {
graph[(*ni).id - 1].push_back(NGT::ObjectDistance(idx + 1, (*ni).distance));
}
}
for (size_t idx = 0; idx < graph.size(); idx++) {
NGT::GraphNode &node = graph[idx];
if (node.size() == 0) {
continue;
}
std::sort(node.begin(), node.end());
NGT::ObjectID prev = 0;
for (auto it = node.begin(); it != node.end();) {
if (prev == (*it).id) {
it = node.erase(it);
continue;
}
prev = (*it).id;
it++;
}
NGT::GraphNode tmp = node;
node.swap(tmp);
}
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("convertToANNG end");
// std::cerr << "convertToANNG end" << std::endl;
#endif
}
static
void reconstructGraph(std::vector<NGT::ObjectDistances> &graph, NGT::GraphIndex &outGraph, size_t originalEdgeSize, size_t reverseEdgeSize)
{
if (reverseEdgeSize > 10000) {
// std::cerr << "something wrong. Edge size=" << reverseEdgeSize << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("something wrong. Edge size=" + std::to_string(reverseEdgeSize));
exit(1);
}
NGT::Timer originalEdgeTimer, reverseEdgeTimer, normalizeEdgeTimer;
originalEdgeTimer.start();
for (size_t id = 1; id < outGraph.repository.size(); id++) {
try {
NGT::GraphNode &node = *outGraph.getNode(id);
if (originalEdgeSize == 0) {
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
node.clear(outGraph.repository.allocator);
#else
NGT::GraphNode empty;
node.swap(empty);
#endif
} else {
NGT::ObjectDistances n = graph[id - 1];
if (n.size() < originalEdgeSize) {
// std::cerr << "GraphReconstructor: Warning. The edges are too few. " << n.size() << ":" << originalEdgeSize << " for " << id << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("GraphReconstructor: Warning. The edges are too few. " + std::to_string(n.size()) + ":" + std::to_string(originalEdgeSize) + " for " + std::to_string(id));
continue;
}
n.resize(originalEdgeSize);
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
node.copy(n, outGraph.repository.allocator);
#else
node.swap(n);
#endif
}
} catch(NGT::Exception &err) {
// std::cerr << "GraphReconstructor: Warning. Cannot get the node. ID=" << id << ":" << err.what() << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("GraphReconstructor: Warning. Cannot get the node. ID=" + std::to_string(id) + ":" + err.what());
continue;
}
}
originalEdgeTimer.stop();
reverseEdgeTimer.start();
int insufficientNodeCount = 0;
for (size_t id = 1; id <= graph.size(); ++id) {
try {
NGT::ObjectDistances &node = graph[id - 1];
size_t rsize = reverseEdgeSize;
if (rsize > node.size()) {
insufficientNodeCount++;
rsize = node.size();
}
for (size_t i = 0; i < rsize; ++i) {
NGT::Distance distance = node[i].distance;
size_t nodeID = node[i].id;
try {
NGT::GraphNode &n = *outGraph.getNode(nodeID);
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
n.push_back(NGT::ObjectDistance(id, distance), outGraph.repository.allocator);
#else
n.push_back(NGT::ObjectDistance(id, distance));
#endif
} catch(...) {}
}
} catch(NGT::Exception &err) {
// std::cerr << "GraphReconstructor: Warning. Cannot get the node. ID=" << id << ":" << err.what() << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("GraphReconstructor: Warning. Cannot get the node. ID=" + std::to_string(id) + ":" + err.what());
continue;
}
}
reverseEdgeTimer.stop();
if (insufficientNodeCount != 0) {
// std::cerr << "# of the nodes edges of which are in short = " << insufficientNodeCount << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("# of the nodes edges of which are in short = " + std::to_string(insufficientNodeCount));
}
normalizeEdgeTimer.start();
for (size_t id = 1; id < outGraph.repository.size(); id++) {
try {
NGT::GraphNode &n = *outGraph.getNode(id);
if (id % 100000 == 0) {
// std::cerr << "Processed " << id << " nodes" << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("Processed " + std::to_string(id) + " nodes");
}
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
std::sort(n.begin(outGraph.repository.allocator), n.end(outGraph.repository.allocator));
#else
std::sort(n.begin(), n.end());
#endif
NGT::ObjectID prev = 0;
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
for (auto it = n.begin(outGraph.repository.allocator); it != n.end(outGraph.repository.allocator);) {
#else
for (auto it = n.begin(); it != n.end();) {
#endif
if (prev == (*it).id) {
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
it = n.erase(it, outGraph.repository.allocator);
#else
it = n.erase(it);
#endif
continue;
}
prev = (*it).id;
it++;
}
#if !defined(NGT_SHARED_MEMORY_ALLOCATOR)
NGT::GraphNode tmp = n;
n.swap(tmp);
#endif
} catch(NGT::Exception &err) {
// std::cerr << "GraphReconstructor: Warning. Cannot get the node. ID=" << id << ":" << err.what() << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("GraphReconstructor: Warning. Cannot get the node. ID=" + std::to_string(id) + ":" + err.what());
continue;
}
}
normalizeEdgeTimer.stop();
// std::cerr << "Reconstruction time=" << originalEdgeTimer.time << ":" << reverseEdgeTimer.time
// << ":" << normalizeEdgeTimer.time << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("Reconstruction time=" + std::to_string(originalEdgeTimer.time) + ":" + std::to_string(reverseEdgeTimer.time)
+ ":" + std::to_string(normalizeEdgeTimer.time));
NGT::Property prop;
outGraph.getProperty().get(prop);
prop.graphType = NGT::NeighborhoodGraph::GraphTypeONNG;
outGraph.getProperty().set(prop);
}
static
void reconstructGraphWithConstraint(std::vector<NGT::ObjectDistances> &graph, NGT::GraphIndex &outGraph,
size_t originalEdgeSize, size_t reverseEdgeSize,
char mode = 'a')
{
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("reconstructGraphWithConstraint is not implemented.");
// std::cerr << "reconstructGraphWithConstraint is not implemented." << std::endl;
abort();
#else
NGT::Timer originalEdgeTimer, reverseEdgeTimer, normalizeEdgeTimer;
if (reverseEdgeSize > 10000) {
// std::cerr << "something wrong. Edge size=" << reverseEdgeSize << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("something wrong. Edge size=" + std::to_string(reverseEdgeSize));
exit(1);
}
for (size_t id = 1; id < outGraph.repository.size(); id++) {
if (id % 1000000 == 0) {
// std::cerr << "Processed " << id << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("Processed " + std::to_string(id));
}
try {
NGT::GraphNode &node = *outGraph.getNode(id);
if (node.size() == 0) {
continue;
}
node.clear();
NGT::GraphNode empty;
node.swap(empty);
} catch(NGT::Exception &err) {
// std::cerr << "GraphReconstructor: Warning. Cannot get the node. ID=" << id << ":" << err.what() << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("GraphReconstructor: Warning. Cannot get the node. ID=" + std::to_string(id) + ":" + err.what());
continue;
}
}
NGT::GraphIndex::showStatisticsOfGraph(outGraph);
std::vector<ObjectDistances> reverse(graph.size() + 1);
for (size_t id = 1; id <= graph.size(); ++id) {
try {
NGT::GraphNode &node = graph[id - 1];
if (id % 100000 == 0) {
// std::cerr << "Processed (summing up) " << id << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("Processed (summing up) " + std::to_string(id));
}
for (size_t rank = 0; rank < node.size(); rank++) {
reverse[node[rank].id].push_back(ObjectDistance(id, node[rank].distance));
}
} catch(NGT::Exception &err) {
// std::cerr << "GraphReconstructor: Warning. Cannot get the node. ID=" << id << ":" << err.what() << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("GraphReconstructor: Warning. Cannot get the node. ID=" + std::to_string(id) + ":" + err.what());
continue;
}
}
std::vector<std::pair<size_t, size_t> > reverseSize(graph.size() + 1);
reverseSize[0] = std::pair<size_t, size_t>(0, 0);
for (size_t rid = 1; rid <= graph.size(); ++rid) {
reverseSize[rid] = std::pair<size_t, size_t>(reverse[rid].size(), rid);
}
std::sort(reverseSize.begin(), reverseSize.end());
std::vector<uint32_t> indegreeCount(graph.size(), 0);
size_t zeroCount = 0;
for (size_t sizerank = 0; sizerank <= reverseSize.size(); sizerank++) {
if (reverseSize[sizerank].first == 0) {
zeroCount++;
continue;
}
size_t rid = reverseSize[sizerank].second;
ObjectDistances &rnode = reverse[rid];
for (auto rni = rnode.begin(); rni != rnode.end(); ++rni) {
if (indegreeCount[(*rni).id] >= reverseEdgeSize) {
continue;
}
NGT::GraphNode &node = *outGraph.getNode(rid);
if (indegreeCount[(*rni).id] > 0 && node.size() >= originalEdgeSize) {
continue;
}
node.push_back(NGT::ObjectDistance((*rni).id, (*rni).distance));
indegreeCount[(*rni).id]++;
}
}
reverseEdgeTimer.stop();
// std::cerr << "The number of nodes with zero outdegree by reverse edges=" << zeroCount << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("The number of nodes with zero outdegree by reverse edges=" + std::to_string(zeroCount));
NGT::GraphIndex::showStatisticsOfGraph(outGraph);
normalizeEdgeTimer.start();
for (size_t id = 1; id < outGraph.repository.size(); id++) {
try {
NGT::GraphNode &n = *outGraph.getNode(id);
if (id % 100000 == 0) {
// std::cerr << "Processed " << id << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("Processed " + std::to_string(id));
}
std::sort(n.begin(), n.end());
NGT::ObjectID prev = 0;
for (auto it = n.begin(); it != n.end();) {
if (prev == (*it).id) {
it = n.erase(it);
continue;
}
prev = (*it).id;
it++;
}
NGT::GraphNode tmp = n;
n.swap(tmp);
} catch(NGT::Exception &err) {
// std::cerr << "GraphReconstructor: Warning. Cannot get the node. ID=" << id << ":" << err.what() << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("GraphReconstructor: Warning. Cannot get the node. ID=" + std::to_string(id) + ":" + err.what());
continue;
}
}
normalizeEdgeTimer.stop();
NGT::GraphIndex::showStatisticsOfGraph(outGraph);
originalEdgeTimer.start();
for (size_t id = 1; id < outGraph.repository.size(); id++) {
if (id % 1000000 == 0) {
// std::cerr << "Processed " << id << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("Processed " + std::to_string(id));
}
NGT::GraphNode &node = graph[id - 1];
try {
NGT::GraphNode &onode = *outGraph.getNode(id);
bool stop = false;
for (size_t rank = 0; (rank < node.size() && rank < originalEdgeSize) && stop == false; rank++) {
switch (mode) {
case 'a':
if (onode.size() >= originalEdgeSize) {
stop = true;
continue;
}
break;
case 'c':
break;
}
NGT::Distance distance = node[rank].distance;
size_t nodeID = node[rank].id;
outGraph.addEdge(id, nodeID, distance, false);
}
} catch(NGT::Exception &err) {
// std::cerr << "GraphReconstructor: Warning. Cannot get the node. ID=" << id << ":" << err.what() << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("GraphReconstructor: Warning. Cannot get the node. ID=" + std::to_string(id) + ":" + err.what());
continue;
}
}
originalEdgeTimer.stop();
NGT::GraphIndex::showStatisticsOfGraph(outGraph);
// std::cerr << "Reconstruction time=" << originalEdgeTimer.time << ":" << reverseEdgeTimer.time
// << ":" << normalizeEdgeTimer.time << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("Reconstruction time=" + std::to_string(originalEdgeTimer.time) + ":" + std::to_string(reverseEdgeTimer.time)
+ ":" + std::to_string(normalizeEdgeTimer.time));
#endif
}
// reconstruct a pseudo ANNG with a fewer edges from an actual ANNG with more edges.
// graph is a source ANNG
// index is an index with a reconstructed ANNG
static
void reconstructANNGFromANNG(std::vector<NGT::ObjectDistances> &graph, NGT::Index &index, size_t edgeSize)
{
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("reconstructANNGFromANNG is not implemented.");
// std::cerr << "reconstructANNGFromANNG is not implemented." << std::endl;
abort();
#else
NGT::GraphIndex &outGraph = dynamic_cast<NGT::GraphIndex&>(index.getIndex());
// remove all edges in the index.
for (size_t id = 1; id < outGraph.repository.size(); id++) {
if (id % 1000000 == 0) {
// std::cerr << "Processed " << id << " nodes." << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("Processed " + std::to_string(id) + " nodes.");
}
try {
NGT::GraphNode &node = *outGraph.getNode(id);
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
node.clear(outGraph.repository.allocator);
#else
NGT::GraphNode empty;
node.swap(empty);
#endif
} catch(NGT::Exception &err) {
}
}
for (size_t id = 1; id <= graph.size(); ++id) {
size_t edgeCount = 0;
try {
NGT::ObjectDistances &node = graph[id - 1];
NGT::GraphNode &n = *outGraph.getNode(id);
NGT::Distance prevDistance = 0.0;
assert(n.size() == 0);
for (size_t i = 0; i < node.size(); ++i) {
NGT::Distance distance = node[i].distance;
if (prevDistance > distance) {
NGTThrowException("Edge distance order is invalid");
}
prevDistance = distance;
size_t nodeID = node[i].id;
if (node[i].id < id) {
try {
NGT::GraphNode &dn = *outGraph.getNode(nodeID);
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
n.push_back(NGT::ObjectDistance(nodeID, distance), outGraph.repository.allocator);
dn.push_back(NGT::ObjectDistance(id, distance), outGraph.repository.allocator);
#else
n.push_back(NGT::ObjectDistance(nodeID, distance));
dn.push_back(NGT::ObjectDistance(id, distance));
#endif
} catch(...) {}
edgeCount++;
}
if (edgeCount >= edgeSize) {
break;
}
}
} catch(NGT::Exception &err) {
}
}
for (size_t id = 1; id < outGraph.repository.size(); id++) {
try {
NGT::GraphNode &n = *outGraph.getNode(id);
std::sort(n.begin(), n.end());
NGT::ObjectID prev = 0;
for (auto it = n.begin(); it != n.end();) {
if (prev == (*it).id) {
it = n.erase(it);
continue;
}
prev = (*it).id;
it++;
}
NGT::GraphNode tmp = n;
n.swap(tmp);
} catch (...) {
}
}
#endif
}
static void refineANNG(NGT::Index &index, bool unlog, float epsilon = 0.1, float accuracy = 0.0, int noOfEdges = 0, int exploreEdgeSize = INT_MIN, size_t batchSize = 10000) {
NGT::StdOstreamRedirector redirector(unlog);
redirector.begin();
try {
refineANNG(index, epsilon, accuracy, noOfEdges, exploreEdgeSize, batchSize);
} catch (NGT::Exception &err) {
redirector.end();
throw(err);
}
}
static void refineANNG(NGT::Index &index, float epsilon = 0.1, float accuracy = 0.0, int noOfEdges = 0, int exploreEdgeSize = INT_MIN, size_t batchSize = 10000) {
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
NGTThrowException("GraphReconstructor::refineANNG: Not implemented for the shared memory option.");
#else
auto prop = static_cast<GraphIndex&>(index.getIndex()).getGraphProperty();
NGT::ObjectRepository &objectRepository = index.getObjectSpace().getRepository();
NGT::GraphIndex &graphIndex = static_cast<GraphIndex&>(index.getIndex());
size_t nOfObjects = objectRepository.size();
bool error = false;
std::string errorMessage;
for (size_t bid = 1; bid < nOfObjects; bid += batchSize) {
NGT::ObjectDistances results[batchSize];
// search
#pragma omp parallel for
for (size_t idx = 0; idx < batchSize; idx++) {
size_t id = bid + idx;
if (id % 100000 == 0) {
// std::cerr << "# of processed objects=" << id << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("# of processed objects=" + std::to_string(id));
}
if (objectRepository.isEmpty(id)) {
continue;
}
NGT::SearchContainer searchContainer(*objectRepository.get(id));
searchContainer.setResults(&results[idx]);
assert(prop.edgeSizeForCreation > 0);
searchContainer.setSize(noOfEdges > prop.edgeSizeForCreation ? noOfEdges : prop.edgeSizeForCreation);
if (accuracy > 0.0) {
searchContainer.setExpectedAccuracy(accuracy);
} else {
searchContainer.setEpsilon(epsilon);
}
if (exploreEdgeSize != INT_MIN) {
searchContainer.setEdgeSize(exploreEdgeSize);
}
if (!error) {
try {
index.search(searchContainer);
} catch (NGT::Exception &err) {
#pragma omp critical
{
error = true;
errorMessage = err.what();
}
}
}
}
if (error) {
std::stringstream msg;
msg << "GraphReconstructor::refineANNG: " << errorMessage;
NGTThrowException(msg);
}
// outgoing edges
#pragma omp parallel for
for (size_t idx = 0; idx < batchSize; idx++) {
size_t id = bid + idx;
if (objectRepository.isEmpty(id)) {
continue;
}
NGT::GraphNode &node = *graphIndex.getNode(id);
for (auto i = results[idx].begin(); i != results[idx].end(); ++i) {
if ((*i).id != id) {
node.push_back(*i);
}
}
std::sort(node.begin(), node.end());
// dedupe
ObjectID prev = 0;
for (GraphNode::iterator ni = node.begin(); ni != node.end();) {
if (prev == (*ni).id) {
ni = node.erase(ni);
continue;
}
prev = (*ni).id;
ni++;
}
}
// incomming edges
if (noOfEdges != 0) {
continue;
}
for (size_t idx = 0; idx < batchSize; idx++) {
size_t id = bid + idx;
if (id % 10000 == 0) {
// std::cerr << "# of processed objects=" << id << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("# of processed objects=" + std::to_string(id));
}
for (auto i = results[idx].begin(); i != results[idx].end(); ++i) {
if ((*i).id != id) {
NGT::GraphNode &node = *graphIndex.getNode((*i).id);
graphIndex.addEdge(node, id, (*i).distance, false);
}
}
}
}
if (noOfEdges != 0) {
// prune to build knng
size_t nedges = noOfEdges < 0 ? -noOfEdges : noOfEdges;
#pragma omp parallel for
for (ObjectID id = 1; id < nOfObjects; ++id) {
if (objectRepository.isEmpty(id)) {
continue;
}
NGT::GraphNode &node = *graphIndex.getNode(id);
if (node.size() > nedges) {
node.resize(nedges);
}
}
}
#endif // defined(NGT_SHARED_MEMORY_ALLOCATOR)
}
};
}; // NGT

View File

@ -1,113 +0,0 @@
//
// Copyright (C) 2015-2020 Yahoo Japan Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
#pragma once
#include "defines.h"
#include <iostream>
#include <cstring>
#include <stdint.h>
#include <climits>
#include <unordered_set>
class HashBasedBooleanSet{
private:
uint32_t *_table;
uint32_t _tableSize;
uint32_t _mask;
std::unordered_set<uint32_t> _stlHash;
inline uint32_t _hash1(const uint32_t value){
return value & _mask;
}
public:
HashBasedBooleanSet():_table(NULL), _tableSize(0), _mask(0) {}
HashBasedBooleanSet(const uint64_t size):_table(NULL), _tableSize(0), _mask(0) {
size_t bitSize = 0;
size_t bit = size;
while (bit != 0) {
bitSize++;
bit >>= 1;
}
size_t bucketSize = 0x1 << ((bitSize + 4) / 2 + 3);
initialize(bucketSize);
}
void initialize(const uint32_t tableSize) {
_tableSize = tableSize;
_mask = _tableSize - 1;
const uint32_t checkValue = _hash1(tableSize);
if(checkValue != 0){
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("[WARN] table size is not 2^N : " + std::to_string(tableSize));
// std::cerr << "[WARN] table size is not 2^N : " << tableSize << std::endl;
}
_table = new uint32_t[tableSize];
memset(_table, 0, tableSize * sizeof(uint32_t));
}
~HashBasedBooleanSet(){
delete[] _table;
_stlHash.clear();
}
inline bool operator[](const uint32_t num){
const uint32_t hashValue = _hash1(num);
auto v = _table[hashValue];
if (v == num){
return true;
}
if (v == 0){
return false;
}
if (_stlHash.count(num) <= 0) {
return false;
}
return true;
}
inline void set(const uint32_t num){
uint32_t &value = _table[_hash1(num)];
if(value == 0){
value = num;
}else{
if(value != num){
_stlHash.insert(num);
}
}
}
inline void insert(const uint32_t num){
set(num);
}
inline void reset(const uint32_t num){
const uint32_t hashValue = _hash1(num);
if(_table[hashValue] != 0){
if(_table[hashValue] != num){
_stlHash.erase(num);
}else{
_table[hashValue] = UINT_MAX;
}
}
}
};

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -1,457 +0,0 @@
//
// Copyright (C) 2015-2020 Yahoo Japan Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
#include "MmapManagerImpl.hpp"
namespace MemoryManager{
// static method ---
void MmapManager::setDefaultOptionValue(init_option_st &optionst)
{
optionst.use_expand = MMAP_DEFAULT_ALLOW_EXPAND;
optionst.reuse_type = REUSE_DATA_CLASSIFY;
}
size_t MmapManager::getAlignSize(size_t size){
if((size % MMAP_MEMORY_ALIGN) == 0){
return size;
}else{
return ( (size >> MMAP_MEMORY_ALIGN_EXP ) + 1 ) * MMAP_MEMORY_ALIGN;
}
}
// static method ---
MmapManager::MmapManager():_impl(new MmapManager::Impl(*this))
{
for(uint64_t i = 0; i < MMAP_MAX_UNIT_NUM; ++i){
_impl->mmapDataAddr[i] = NULL;
}
}
MmapManager::~MmapManager() = default;
void MmapManager::dumpHeap() const
{
_impl->dumpHeap();
}
bool MmapManager::isOpen() const
{
return _impl->isOpen;
}
void *MmapManager::getEntryHook() const {
return getAbsAddr(_impl->mmapCntlHead->entry_p);
}
void MmapManager::setEntryHook(const void *entry_p){
_impl->mmapCntlHead->entry_p = getRelAddr(entry_p);
}
bool MmapManager::init(const std::string &filePath, size_t size, const init_option_st *optionst) const
{
try{
const std::string controlFile = filePath + MMAP_CNTL_FILE_SUFFIX;
struct stat st;
if(stat(controlFile.c_str(), &st) == 0){
return false;
}
if(filePath.length() > MMAP_MAX_FILE_NAME_LENGTH){
std::cerr << "too long filepath" << std::endl;
return false;
}
if((size % sysconf(_SC_PAGESIZE) != 0) || ( size < MMAP_LOWER_SIZE )){
std::cerr << "input size error" << std::endl;
return false;
}
int32_t fd = _impl->formatFile(controlFile, MMAP_CNTL_FILE_SIZE);
assert(fd >= 0);
errno = 0;
char *cntl_p = (char *)mmap(NULL, MMAP_CNTL_FILE_SIZE, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0);
if(cntl_p == MAP_FAILED){
const std::string err_str = getErrorStr(errno);
if(close(fd) == -1) std::cerr << controlFile << "[WARN] : filedescript cannot close" << std::endl;
throw MmapManagerException(controlFile + " " + err_str);
}
if(close(fd) == -1) std::cerr << controlFile << "[WARN] : filedescript cannot close" << std::endl;
try {
fd = _impl->formatFile(filePath, size);
} catch (MmapManagerException &err) {
if(munmap(cntl_p, MMAP_CNTL_FILE_SIZE) == -1) {
throw MmapManagerException("[ERR] : munmap error : " + getErrorStr(errno) +
" : Through the exception : " + err.what());
}
throw err;
}
if(close(fd) == -1) std::cerr << controlFile << "[WARN] : filedescript cannot close" << std::endl;
boot_st bootStruct = {0};
control_st controlStruct = {0};
_impl->initBootStruct(bootStruct, size);
_impl->initControlStruct(controlStruct, size);
char *cntl_head = cntl_p;
cntl_head += sizeof(boot_st);
if(optionst != NULL){
controlStruct.use_expand = optionst->use_expand;
controlStruct.reuse_type = optionst->reuse_type;
}
memcpy(cntl_p, (char *)&bootStruct, sizeof(boot_st));
memcpy(cntl_head, (char *)&controlStruct, sizeof(control_st));
errno = 0;
if(munmap(cntl_p, MMAP_CNTL_FILE_SIZE) == -1) throw MmapManagerException("munmap error : " + getErrorStr(errno));
return true;
}catch(MmapManagerException &e){
std::cerr << "init error. " << e.what() << std::endl;
throw e;
}
}
bool MmapManager::openMemory(const std::string &filePath)
{
try{
if(_impl->isOpen == true){
std::string err_str = "[ERROR] : openMemory error (double open).";
throw MmapManagerException(err_str);
}
const std::string controlFile = filePath + MMAP_CNTL_FILE_SUFFIX;
_impl->filePath = filePath;
int32_t fd;
errno = 0;
if((fd = open(controlFile.c_str(), O_RDWR, 0666)) == -1){
const std::string err_str = getErrorStr(errno);
throw MmapManagerException("file open error" + err_str);
}
errno = 0;
boot_st *boot_p = (boot_st*)mmap(NULL, MMAP_CNTL_FILE_SIZE, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0);
if(boot_p == MAP_FAILED){
const std::string err_str = getErrorStr(errno);
if(close(fd) == -1) std::cerr << controlFile << "[WARN] : filedescript cannot close" << std::endl;
throw MmapManagerException(controlFile + " " + err_str);
}
if(close(fd) == -1) std::cerr << controlFile << "[WARN] : filedescript cannot close" << std::endl;
if(boot_p->version != MMAP_MANAGER_VERSION){
std::cerr << "[WARN] : version error" << std::endl;
errno = 0;
if(munmap(boot_p, MMAP_CNTL_FILE_SIZE) == -1) throw MmapManagerException("munmap error : " + getErrorStr(errno));
throw MmapManagerException("MemoryManager version error");
}
errno = 0;
if((fd = open(filePath.c_str(), O_RDWR, 0666)) == -1){
const std::string err_str = getErrorStr(errno);
errno = 0;
if(munmap(boot_p, MMAP_CNTL_FILE_SIZE) == -1) throw MmapManagerException("munmap error : " + getErrorStr(errno));
throw MmapManagerException("file open error = " + std::string(filePath.c_str()) + err_str);
}
_impl->mmapCntlHead = (control_st*)( (char *)boot_p + sizeof(boot_st));
_impl->mmapCntlAddr = (void *)boot_p;
for(uint64_t i = 0; i < _impl->mmapCntlHead->unit_num; i++){
off_t offset = _impl->mmapCntlHead->base_size * i;
errno = 0;
_impl->mmapDataAddr[i] = mmap(NULL, _impl->mmapCntlHead->base_size, PROT_READ | PROT_WRITE, MAP_SHARED, fd, offset);
if(_impl->mmapDataAddr[i] == MAP_FAILED){
if (errno == EINVAL) {
std::cerr << "MmapManager::openMemory: Fatal error. EINVAL" << std::endl
<< " If you use valgrind, this error might occur when the DB is created." << std::endl
<< " In the case of that, reduce bsize in SharedMemoryAllocator." << std::endl;
assert(errno != EINVAL);
}
const std::string err_str = getErrorStr(errno);
if(close(fd) == -1) std::cerr << controlFile << "[WARN] : filedescript cannot close" << std::endl;
closeMemory(true);
throw MmapManagerException(err_str);
}
}
if(close(fd) == -1) std::cerr << controlFile << "[WARN] : filedescript cannot close" << std::endl;
_impl->isOpen = true;
return true;
}catch(MmapManagerException &e){
std::cerr << "open error" << std::endl;
throw e;
}
}
void MmapManager::closeMemory(const bool force)
{
try{
if(force || _impl->isOpen){
uint16_t count = 0;
void *error_ids[MMAP_MAX_UNIT_NUM] = {0};
for(uint16_t i = 0; i < _impl->mmapCntlHead->unit_num; i++){
if(_impl->mmapDataAddr[i] != NULL){
if(munmap(_impl->mmapDataAddr[i], _impl->mmapCntlHead->base_size) == -1){
error_ids[i] = _impl->mmapDataAddr[i];;
count++;
}
_impl->mmapDataAddr[i] = NULL;
}
}
if(count > 0){
std::string msg = "";
for(uint16_t i = 0; i < count; i++){
std::stringstream ss;
ss << error_ids[i];
msg += ss.str() + ", ";
}
throw MmapManagerException("unmap error : ids = " + msg);
}
if(_impl->mmapCntlAddr != NULL){
if(munmap(_impl->mmapCntlAddr, MMAP_CNTL_FILE_SIZE) == -1) throw MmapManagerException("munmap error : " + getErrorStr(errno));
_impl->mmapCntlAddr = NULL;
}
_impl->isOpen = false;
}
}catch(MmapManagerException &e){
std::cerr << "close error" << std::endl;
throw e;
}
}
off_t MmapManager::alloc(const size_t size, const bool not_reuse_flag)
{
try{
if(!_impl->isOpen){
std::cerr << "not open this file" << std::endl;
return -1;
}
size_t alloc_size = getAlignSize(size);
if( (alloc_size + sizeof(chunk_head_st)) >= _impl->mmapCntlHead->base_size ){
std::cerr << "alloc size over. size=" << size << "." << std::endl;
return -1;
}
if(!not_reuse_flag){
if( _impl->mmapCntlHead->reuse_type == REUSE_DATA_CLASSIFY
|| _impl->mmapCntlHead->reuse_type == REUSE_DATA_QUEUE
|| _impl->mmapCntlHead->reuse_type == REUSE_DATA_QUEUE_PLUS){
off_t ret_offset;
reuse_state_t reuse_state = REUSE_STATE_OK;
ret_offset = reuse(alloc_size, reuse_state);
if(reuse_state != REUSE_STATE_ALLOC){
return ret_offset;
}
}
}
head_st *unit_header = &_impl->mmapCntlHead->data_headers[_impl->mmapCntlHead->active_unit];
if((unit_header->break_p + sizeof(chunk_head_st) + alloc_size) >= _impl->mmapCntlHead->base_size){
if(_impl->mmapCntlHead->use_expand == true){
if(_impl->expandMemory() == false){
std::cerr << __func__ << ": cannot expand" << std::endl;
return -1;
}
unit_header = &_impl->mmapCntlHead->data_headers[_impl->mmapCntlHead->active_unit];
}else{
std::cerr << __func__ << ": total size over" << std::endl;
return -1;
}
}
const off_t file_offset = _impl->mmapCntlHead->active_unit * _impl->mmapCntlHead->base_size;
const off_t ret_p = file_offset + ( unit_header->break_p + sizeof(chunk_head_st) );
chunk_head_st *chunk_head = (chunk_head_st*)(unit_header->break_p + (char *)_impl->mmapDataAddr[_impl->mmapCntlHead->active_unit]);
_impl->setupChunkHead(chunk_head, false, _impl->mmapCntlHead->active_unit, -1, alloc_size);
unit_header->break_p += alloc_size + sizeof(chunk_head_st);
unit_header->chunk_num++;
return ret_p;
}catch(MmapManagerException &e){
std::cerr << "allocation error" << std::endl;
throw e;
}
}
void MmapManager::free(const off_t p)
{
switch(_impl->mmapCntlHead->reuse_type){
case REUSE_DATA_CLASSIFY:
_impl->free_data_classify(p);
break;
case REUSE_DATA_QUEUE:
_impl->free_data_queue(p);
break;
case REUSE_DATA_QUEUE_PLUS:
_impl->free_data_queue_plus(p);
break;
default:
_impl->free_data_classify(p);
break;
}
}
off_t MmapManager::reuse(const size_t size, reuse_state_t &reuse_state)
{
off_t ret_off;
switch(_impl->mmapCntlHead->reuse_type){
case REUSE_DATA_CLASSIFY:
ret_off = _impl->reuse_data_classify(size, reuse_state);
break;
case REUSE_DATA_QUEUE:
ret_off = _impl->reuse_data_queue(size, reuse_state);
break;
case REUSE_DATA_QUEUE_PLUS:
ret_off = _impl->reuse_data_queue_plus(size, reuse_state);
break;
default:
ret_off = _impl->reuse_data_classify(size, reuse_state);
break;
}
return ret_off;
}
void *MmapManager::getAbsAddr(off_t p) const
{
if(p < 0){
return NULL;
}
const uint16_t unit_id = p / _impl->mmapCntlHead->base_size;
const off_t file_offset = unit_id * _impl->mmapCntlHead->base_size;
const off_t ret_p = p - file_offset;
return ABS_ADDR(ret_p, _impl->mmapDataAddr[unit_id]);
}
off_t MmapManager::getRelAddr(const void *p) const
{
const chunk_head_st *chunk_head = (chunk_head_st *)((char *)p - sizeof(chunk_head_st));
const uint16_t unit_id = chunk_head->unit_id;
const off_t file_offset = unit_id * _impl->mmapCntlHead->base_size;
off_t ret_p = (off_t)((char *)p - (char *)_impl->mmapDataAddr[unit_id]);
ret_p += file_offset;
return ret_p;
}
std::string getErrorStr(int32_t err_num){
char err_msg[256];
#ifdef _GNU_SOURCE
char *msg = strerror_r(err_num, err_msg, 256);
return std::string(msg);
#else
strerror_r(err_num, err_msg, 256);
return std::string(err_msg);
#endif
}
size_t MmapManager::getTotalSize() const
{
const uint16_t active_unit = _impl->mmapCntlHead->active_unit;
const size_t ret_size = ((_impl->mmapCntlHead->unit_num - 1) * _impl->mmapCntlHead->base_size) + _impl->mmapCntlHead->data_headers[active_unit].break_p;
return ret_size;
}
size_t MmapManager::getUseSize() const
{
size_t total_size = 0;
void *ref_addr = (void *)&total_size;
_impl->scanAllData(ref_addr, CHECK_STATS_USE_SIZE);
return total_size;
}
uint64_t MmapManager::getUseNum() const
{
uint64_t total_chunk_num = 0;
void *ref_addr = (void *)&total_chunk_num;
_impl->scanAllData(ref_addr, CHECK_STATS_USE_NUM);
return total_chunk_num;
}
size_t MmapManager::getFreeSize() const
{
size_t total_size = 0;
void *ref_addr = (void *)&total_size;
_impl->scanAllData(ref_addr, CHECK_STATS_FREE_SIZE);
return total_size;
}
uint64_t MmapManager::getFreeNum() const
{
uint64_t total_chunk_num = 0;
void *ref_addr = (void *)&total_chunk_num;
_impl->scanAllData(ref_addr, CHECK_STATS_FREE_NUM);
return total_chunk_num;
}
uint16_t MmapManager::getUnitNum() const
{
return _impl->mmapCntlHead->unit_num;
}
size_t MmapManager::getQueueCapacity() const
{
free_queue_st *free_queue = &_impl->mmapCntlHead->free_queue;
return free_queue->capacity;
}
uint64_t MmapManager::getQueueNum() const
{
free_queue_st *free_queue = &_impl->mmapCntlHead->free_queue;
return free_queue->tail;
}
uint64_t MmapManager::getLargeListNum() const
{
uint64_t count = 0;
free_list_st *free_list = &_impl->mmapCntlHead->free_data.large_list;
if(free_list->free_p == -1){
return count;
}
off_t current_off = free_list->free_p;
chunk_head_st *current_chunk_head = (chunk_head_st *)getAbsAddr(current_off);
while(current_chunk_head != NULL){
count++;
current_off = current_chunk_head->free_next;
current_chunk_head = (chunk_head_st *)getAbsAddr(current_off);
}
return count;
}
}

View File

@ -1,95 +0,0 @@
//
// Copyright (C) 2015-2020 Yahoo Japan Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
#pragma once
#include <sys/types.h>
#include <stdint.h>
#include <string>
#include <memory>
#define ABS_ADDR(x, y) (void *)(x + (char *)y);
#define USE_MMAP_MANAGER
namespace MemoryManager{
typedef enum _option_reuse_t{
REUSE_DATA_CLASSIFY,
REUSE_DATA_QUEUE,
REUSE_DATA_QUEUE_PLUS,
}option_reuse_t;
typedef enum _reuse_state_t{
REUSE_STATE_OK,
REUSE_STATE_FALSE,
REUSE_STATE_ALLOC,
}reuse_state_t;
typedef enum _check_statistics_t{
CHECK_STATS_USE_SIZE,
CHECK_STATS_USE_NUM,
CHECK_STATS_FREE_SIZE,
CHECK_STATS_FREE_NUM,
}check_statistics_t;
typedef struct _init_option_st{
bool use_expand;
option_reuse_t reuse_type;
}init_option_st;
class MmapManager{
public:
MmapManager();
~MmapManager();
bool init(const std::string &filePath, size_t size, const init_option_st *optionst = NULL) const;
bool openMemory(const std::string &filePath);
void closeMemory(const bool force = false);
off_t alloc(const size_t size, const bool not_reuse_flag = false);
void free(const off_t p);
off_t reuse(const size_t size, reuse_state_t &reuse_state);
void *getAbsAddr(off_t p) const;
off_t getRelAddr(const void *p) const;
size_t getTotalSize() const;
size_t getUseSize() const;
uint64_t getUseNum() const;
size_t getFreeSize() const;
uint64_t getFreeNum() const;
uint16_t getUnitNum() const;
size_t getQueueCapacity() const;
uint64_t getQueueNum() const;
uint64_t getLargeListNum() const;
void dumpHeap() const;
bool isOpen() const;
void *getEntryHook() const;
void setEntryHook(const void *entry_p);
// static method ---
static void setDefaultOptionValue(init_option_st &optionst);
static size_t getAlignSize(size_t size);
private:
class Impl;
std::unique_ptr<Impl> _impl;
};
std::string getErrorStr(int32_t err_num);
}

View File

@ -1,98 +0,0 @@
//
// Copyright (C) 2015-2020 Yahoo Japan Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
#pragma once
#include "MmapManager.h"
#include <unistd.h>
namespace MemoryManager{
const uint64_t MMAP_MANAGER_VERSION = 5;
const bool MMAP_DEFAULT_ALLOW_EXPAND = false;
const uint64_t MMAP_CNTL_FILE_RANGE = 16;
const size_t MMAP_CNTL_FILE_SIZE = MMAP_CNTL_FILE_RANGE * sysconf(_SC_PAGESIZE);
const uint64_t MMAP_MAX_FILE_NAME_LENGTH = 1024;
const std::string MMAP_CNTL_FILE_SUFFIX = "c";
const size_t MMAP_LOWER_SIZE = 1;
const size_t MMAP_MEMORY_ALIGN = 8;
const size_t MMAP_MEMORY_ALIGN_EXP = 3;
#ifndef MMANAGER_TEST_MODE
const uint64_t MMAP_MAX_UNIT_NUM = 1024;
#else
const uint64_t MMAP_MAX_UNIT_NUM = 8;
#endif
const uint64_t MMAP_FREE_QUEUE_SIZE = 1024;
const uint64_t MMAP_FREE_LIST_NUM = 64;
typedef struct _boot_st{
uint32_t version;
uint64_t reserve;
size_t size;
}boot_st;
typedef struct _head_st{
off_t break_p;
uint64_t chunk_num;
uint64_t reserve;
}head_st;
typedef struct _free_list_st{
off_t free_p;
off_t free_last_p;
}free_list_st;
typedef struct _free_st{
free_list_st large_list;
free_list_st free_lists[MMAP_FREE_LIST_NUM];
}free_st;
typedef struct _free_queue_st{
off_t data;
size_t capacity;
uint64_t tail;
}free_queue_st;
typedef struct _control_st{
bool use_expand;
uint16_t unit_num;
uint16_t active_unit;
uint64_t reserve;
size_t base_size;
off_t entry_p;
option_reuse_t reuse_type;
free_st free_data;
free_queue_st free_queue;
head_st data_headers[MMAP_MAX_UNIT_NUM];
}control_st;
typedef struct _chunk_head_st{
bool delete_flg;
uint16_t unit_id;
off_t free_next;
size_t size;
}chunk_head_st;
}

View File

@ -1,28 +0,0 @@
//
// Copyright (C) 2015-2020 Yahoo Japan Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
#pragma once
#include <iostream>
#include <exception>
#include <stdexcept>
namespace MemoryManager{
class MmapManagerException : public std::domain_error{
public:
MmapManagerException(const std::string &msg) : std::domain_error(msg){}
};
}

View File

@ -1,655 +0,0 @@
//
// Copyright (C) 2015-2020 Yahoo Japan Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
#pragma once
#include "defines.h"
#include "MmapManagerDefs.h"
#include "MmapManagerException.h"
#include <sys/mman.h>
#include <sys/types.h>
#include <sys/stat.h>
#include <fcntl.h>
#include <iostream>
#include <sstream>
#include <cstring>
#include <cassert>
namespace MemoryManager{
class MmapManager::Impl{
public:
Impl() = delete;
Impl(MmapManager &ommanager);
virtual ~Impl(){}
MmapManager &mmanager;
bool isOpen;
void *mmapCntlAddr;
control_st *mmapCntlHead;
std::string filePath;
void *mmapDataAddr[MMAP_MAX_UNIT_NUM];
void initBootStruct(boot_st &bst, size_t size) const;
void initFreeStruct(free_st &fst) const;
void initFreeQueue(free_queue_st &fqst) const;
void initControlStruct(control_st &cntlst, size_t size) const;
void setupChunkHead(chunk_head_st *chunk_head, const bool delete_flg, const uint16_t unit_id, const off_t free_next, const size_t size) const;
bool expandMemory();
int32_t formatFile(const std::string &targetFile, size_t size) const;
void clearChunk(const off_t chunk_off) const;
void free_data_classify(const off_t p, const bool force_large_list = false) const;
off_t reuse_data_classify(const size_t size, reuse_state_t &reuse_state, const bool force_large_list = false) const;
void free_data_queue(const off_t p);
off_t reuse_data_queue(const size_t size, reuse_state_t &reuse_state);
void free_data_queue_plus(const off_t p);
off_t reuse_data_queue_plus(const size_t size, reuse_state_t &reuse_state);
bool scanAllData(void *target, const check_statistics_t stats_type) const;
void upHeap(free_queue_st *free_queue, uint64_t index) const;
void downHeap(free_queue_st *free_queue)const;
bool insertHeap(free_queue_st *free_queue, const off_t p) const;
bool getHeap(free_queue_st *free_queue, off_t *p) const;
size_t getMaxHeapValue(free_queue_st *free_queue) const;
void dumpHeap() const;
void divChunk(const off_t chunk_offset, const size_t size);
};
MmapManager::Impl::Impl(MmapManager &ommanager):mmanager(ommanager), isOpen(false), mmapCntlAddr(NULL), mmapCntlHead(NULL){}
void MmapManager::Impl::initBootStruct(boot_st &bst, size_t size) const
{
bst.version = MMAP_MANAGER_VERSION;
bst.reserve = 0;
bst.size = size;
}
void MmapManager::Impl::initFreeStruct(free_st &fst) const
{
fst.large_list.free_p = -1;
fst.large_list.free_last_p = -1;
for(uint32_t i = 0; i < MMAP_FREE_LIST_NUM; ++i){
fst.free_lists[i].free_p = -1;
fst.free_lists[i].free_last_p = -1;
}
}
void MmapManager::Impl::initFreeQueue(free_queue_st &fqst) const
{
fqst.data = -1;
fqst.capacity = MMAP_FREE_QUEUE_SIZE;
fqst.tail = 1;
}
void MmapManager::Impl::initControlStruct(control_st &cntlst, size_t size) const
{
cntlst.use_expand = MMAP_DEFAULT_ALLOW_EXPAND;
cntlst.unit_num = 1;
cntlst.active_unit = 0;
cntlst.reserve = 0;
cntlst.base_size = size;
cntlst.entry_p = 0;
cntlst.reuse_type = REUSE_DATA_CLASSIFY;
initFreeStruct(cntlst.free_data);
initFreeQueue(cntlst.free_queue);
memset(cntlst.data_headers, 0, sizeof(head_st) * MMAP_MAX_UNIT_NUM);
}
void MmapManager::Impl::setupChunkHead(chunk_head_st *chunk_head, const bool delete_flg, const uint16_t unit_id, const off_t free_next, const size_t size) const
{
chunk_head_st chunk_buffer;
chunk_buffer.delete_flg = delete_flg;
chunk_buffer.unit_id = unit_id;
chunk_buffer.free_next = free_next;
chunk_buffer.size = size;
memcpy(chunk_head, &chunk_buffer, sizeof(chunk_head_st));
}
bool MmapManager::Impl::expandMemory()
{
const uint16_t new_unit_num = mmapCntlHead->unit_num + 1;
const size_t new_file_size = mmapCntlHead->base_size * new_unit_num;
const off_t old_file_size = mmapCntlHead->base_size * mmapCntlHead->unit_num;
if(new_unit_num >= MMAP_MAX_UNIT_NUM){
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("over max unit num");
// std::cerr << "over max unit num" << std::endl;
return false;
}
int32_t fd = formatFile(filePath, new_file_size);
assert(fd >= 0);
const off_t offset = mmapCntlHead->base_size * mmapCntlHead->unit_num;
errno = 0;
void *new_area = mmap(NULL, mmapCntlHead->base_size, PROT_READ | PROT_WRITE, MAP_SHARED, fd, offset);
if(new_area == MAP_FAILED){
const std::string err_str = getErrorStr(errno);
errno = 0;
if(ftruncate(fd, old_file_size) == -1){
const std::string err_str = getErrorStr(errno);
throw MmapManagerException("truncate error" + err_str);
}
if(close(fd) == -1 && NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)(filePath + "[WARN] : filedescript cannot close");
// std::cerr << filePath << "[WARN] : filedescript cannot close" << std::endl;
throw MmapManagerException("mmap error" + err_str);
}
if(close(fd) == -1 && NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)(filePath + "[WARN] : filedescript cannot close");
// std::cerr << filePath << "[WARN] : filedescript cannot close" << std::endl;
mmapDataAddr[mmapCntlHead->unit_num] = new_area;
mmapCntlHead->unit_num = new_unit_num;
mmapCntlHead->active_unit++;
return true;
}
int32_t MmapManager::Impl::formatFile(const std::string &targetFile, size_t size) const
{
const char *c = "";
int32_t fd;
errno = 0;
if((fd = open(targetFile.c_str(), O_RDWR|O_CREAT, 0666)) == -1){
std::stringstream ss;
ss << "[ERR] Cannot open the file. " << targetFile << " " << getErrorStr(errno);
throw MmapManagerException(ss.str());
}
errno = 0;
if(lseek(fd, (off_t)size-1, SEEK_SET) < 0){
std::stringstream ss;
ss << "[ERR] Cannot seek the file. " << targetFile << " " << getErrorStr(errno);
if(close(fd) == -1 && NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)(targetFile + "[WARN] : filedescript cannot close");
// std::cerr << targetFile << "[WARN] : filedescript cannot close" << std::endl;
throw MmapManagerException(ss.str());
}
errno = 0;
if(write(fd, &c, sizeof(char)) == -1){
std::stringstream ss;
ss << "[ERR] Cannot write the file. Check the disk space. " << targetFile << " " << getErrorStr(errno);
if(close(fd) == -1 && NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)(targetFile + "[WARN] : filedescript cannot close");
// std::cerr << targetFile << "[WARN] : filedescript cannot close" << std::endl;
throw MmapManagerException(ss.str());
}
return fd;
}
void MmapManager::Impl::clearChunk(const off_t chunk_off) const
{
chunk_head_st *chunk_head = (chunk_head_st *)mmanager.getAbsAddr(chunk_off);
const off_t payload_off = chunk_off + sizeof(chunk_head_st);
chunk_head->delete_flg = false;
chunk_head->free_next = -1;
char *payload_addr = (char *)mmanager.getAbsAddr(payload_off);
memset(payload_addr, 0, chunk_head->size);
}
void MmapManager::Impl::free_data_classify(const off_t p, const bool force_large_list) const
{
const off_t chunk_offset = p - sizeof(chunk_head_st);
chunk_head_st *chunk_head = (chunk_head_st *)mmanager.getAbsAddr(chunk_offset);
const size_t p_size = chunk_head->size;
const size_t border_size = MMAP_MEMORY_ALIGN * MMAP_FREE_LIST_NUM;
free_list_st *free_list;
if(p_size <= border_size && force_large_list == false){
uint32_t index = (p_size / MMAP_MEMORY_ALIGN) - 1;
free_list = &mmapCntlHead->free_data.free_lists[index];
}else{
free_list = &mmapCntlHead->free_data.large_list;
}
if(free_list->free_p == -1){
free_list->free_p = free_list->free_last_p = chunk_offset;
}else{
off_t last_off = free_list->free_last_p;
chunk_head_st *tmp_chunk_head = (chunk_head_st *)mmanager.getAbsAddr(last_off);
free_list->free_last_p = tmp_chunk_head->free_next = chunk_offset;
}
chunk_head->delete_flg = true;
}
off_t MmapManager::Impl::reuse_data_classify(const size_t size, reuse_state_t &reuse_state, const bool force_large_list) const
{
const size_t border_size = MMAP_MEMORY_ALIGN * MMAP_FREE_LIST_NUM;
free_list_st *free_list;
if(size <= border_size && force_large_list == false){
uint32_t index = (size / MMAP_MEMORY_ALIGN) - 1;
free_list = &mmapCntlHead->free_data.free_lists[index];
}else{
free_list = &mmapCntlHead->free_data.large_list;
}
if(free_list->free_p == -1){
reuse_state = REUSE_STATE_ALLOC;
return -1;
}
off_t current_off = free_list->free_p;
off_t ret_off = 0;
chunk_head_st *current_chunk_head = (chunk_head_st *)mmanager.getAbsAddr(current_off);
chunk_head_st *ret_chunk_head = NULL;
if( (size <= border_size) && (free_list->free_last_p == free_list->free_p) ){
ret_off = current_off;
ret_chunk_head = current_chunk_head;
free_list->free_p = free_list->free_last_p = -1;
}else{
off_t ret_before_off = -1, before_off = -1;
bool found_candidate_flag = false;
while(current_chunk_head != NULL){
if( current_chunk_head->size >= size ) found_candidate_flag = true;
if(found_candidate_flag){
ret_off = current_off;
ret_chunk_head = current_chunk_head;
ret_before_off = before_off;
break;
}
before_off = current_off;
current_off = current_chunk_head->free_next;
current_chunk_head = (chunk_head_st *)mmanager.getAbsAddr(current_off);
}
if(!found_candidate_flag){
reuse_state = REUSE_STATE_ALLOC;
return -1;
}
const off_t free_next = ret_chunk_head->free_next;
if(free_list->free_p == ret_off){
free_list->free_p = free_next;
}else{
chunk_head_st *before_chunk = (chunk_head_st *)mmanager.getAbsAddr(ret_before_off);
before_chunk->free_next = free_next;
}
if(free_list->free_last_p == ret_off){
free_list->free_last_p = ret_before_off;
}
}
clearChunk(ret_off);
ret_off = ret_off + sizeof(chunk_head_st);
return ret_off;
}
void MmapManager::Impl::free_data_queue(const off_t p)
{
free_queue_st *free_queue = &mmapCntlHead->free_queue;
if(free_queue->data == -1){
const size_t queue_size = sizeof(off_t) * free_queue->capacity;
const off_t alloc_offset = mmanager.alloc(queue_size);
if(alloc_offset == -1){
return free_data_classify(p, true);
}
free_queue->data = alloc_offset;
}else if(free_queue->tail >= free_queue->capacity){
const off_t tmp_old_queue = free_queue->data;
const size_t old_size = sizeof(off_t) * free_queue->capacity;
const size_t new_capacity = free_queue->capacity * 2;
const size_t new_size = sizeof(off_t) * new_capacity;
if(new_size > mmapCntlHead->base_size){
return free_data_classify(p, true);
}else{
const off_t alloc_offset = mmanager.alloc(new_size);
if(alloc_offset == -1){
return free_data_classify(p, true);
}
free_queue->data = alloc_offset;
const off_t *old_data = (off_t *)mmanager.getAbsAddr(tmp_old_queue);
off_t *new_data = (off_t *)mmanager.getAbsAddr(free_queue->data);
memcpy(new_data, old_data, old_size);
free_queue->capacity = new_capacity;
mmanager.free(tmp_old_queue);
}
}
const off_t chunk_offset = p - sizeof(chunk_head_st);
if(!insertHeap(free_queue, chunk_offset)){
return;
}
chunk_head_st *chunk_head = (chunk_head_st*)mmanager.getAbsAddr(chunk_offset);
chunk_head->delete_flg = 1;
return;
}
off_t MmapManager::Impl::reuse_data_queue(const size_t size, reuse_state_t &reuse_state)
{
free_queue_st *free_queue = &mmapCntlHead->free_queue;
if(free_queue->data == -1){
reuse_state = REUSE_STATE_ALLOC;
return -1;
}
if(getMaxHeapValue(free_queue) < size){
reuse_state = REUSE_STATE_ALLOC;
return -1;
}
off_t ret_off;
if(!getHeap(free_queue, &ret_off)){
reuse_state = REUSE_STATE_ALLOC;
return -1;
}
reuse_state_t list_state = REUSE_STATE_OK;
off_t candidate_off = reuse_data_classify(MMAP_MEMORY_ALIGN, list_state, true);
if(list_state == REUSE_STATE_OK){
mmanager.free(candidate_off);
}
const off_t c_ret_off = ret_off;
divChunk(c_ret_off, size);
clearChunk(ret_off);
ret_off = ret_off + sizeof(chunk_head_st);
return ret_off;
}
void MmapManager::Impl::free_data_queue_plus(const off_t p)
{
const off_t chunk_offset = p - sizeof(chunk_head_st);
chunk_head_st *chunk_head = (chunk_head_st *)mmanager.getAbsAddr(chunk_offset);
const size_t p_size = chunk_head->size;
const size_t border_size = MMAP_MEMORY_ALIGN * MMAP_FREE_LIST_NUM;
if(p_size <= border_size){
free_data_classify(p);
}else{
free_data_queue(p);
}
}
off_t MmapManager::Impl::reuse_data_queue_plus(const size_t size, reuse_state_t &reuse_state)
{
const size_t border_size = MMAP_MEMORY_ALIGN * MMAP_FREE_LIST_NUM;
off_t ret_off;
if(size <= border_size){
ret_off = reuse_data_classify(size, reuse_state);
if(reuse_state == REUSE_STATE_ALLOC){
reuse_state = REUSE_STATE_OK;
ret_off = reuse_data_queue(size, reuse_state);
}
}else{
ret_off = reuse_data_queue(size, reuse_state);
}
return ret_off;
}
bool MmapManager::Impl::scanAllData(void *target, const check_statistics_t stats_type) const
{
const uint16_t unit_num = mmapCntlHead->unit_num;
size_t total_size = 0;
uint64_t total_chunk_num = 0;
for(int i = 0; i < unit_num; i++){
const head_st *target_unit_head = &mmapCntlHead->data_headers[i];
const uint64_t chunk_num = target_unit_head->chunk_num;
const off_t base_offset = i * mmapCntlHead->base_size;
off_t target_offset = base_offset;
chunk_head_st *target_chunk;
for(uint64_t j = 0; j < chunk_num; j++){
target_chunk = (chunk_head_st*)mmanager.getAbsAddr(target_offset);
if(stats_type == CHECK_STATS_USE_SIZE){
if(target_chunk->delete_flg == false){
total_size += target_chunk->size;
}
}else if(stats_type == CHECK_STATS_USE_NUM){
if(target_chunk->delete_flg == false){
total_chunk_num++;
}
}else if(stats_type == CHECK_STATS_FREE_SIZE){
if(target_chunk->delete_flg == true){
total_size += target_chunk->size;
}
}else if(stats_type == CHECK_STATS_FREE_NUM){
if(target_chunk->delete_flg == true){
total_chunk_num++;
}
}
const size_t chunk_size = sizeof(chunk_head_st) + target_chunk->size;
target_offset += chunk_size;
}
}
if(stats_type == CHECK_STATS_USE_SIZE || stats_type == CHECK_STATS_FREE_SIZE){
size_t *tmp_size = (size_t *)target;
*tmp_size = total_size;
}else if(stats_type == CHECK_STATS_USE_NUM || stats_type == CHECK_STATS_FREE_NUM){
uint64_t *tmp_chunk_num = (uint64_t *)target;
*tmp_chunk_num = total_chunk_num;
}
return true;
}
void MmapManager::Impl::upHeap(free_queue_st *free_queue, uint64_t index) const
{
off_t *queue = (off_t *)mmanager.getAbsAddr(free_queue->data);
while(index > 1){
uint64_t parent = index / 2;
const off_t parent_chunk_offset = queue[parent];
const off_t index_chunk_offset = queue[index];
const chunk_head_st *parent_chunk_head = (chunk_head_st *)mmanager.getAbsAddr(parent_chunk_offset);
const chunk_head_st *index_chunk_head = (chunk_head_st *)mmanager.getAbsAddr(index_chunk_offset);
if(parent_chunk_head->size < index_chunk_head->size){
const off_t tmp = queue[parent];
queue[parent] = queue[index];
queue[index] = tmp;
}
index = parent;
}
}
void MmapManager::Impl::downHeap(free_queue_st *free_queue)const
{
off_t *queue = (off_t *)mmanager.getAbsAddr(free_queue->data);
uint64_t index = 1;
while(index * 2 <= free_queue->tail){
uint64_t child = index * 2;
const off_t index_chunk_offset = queue[index];
const chunk_head_st *index_chunk_head = (chunk_head_st *)mmanager.getAbsAddr(index_chunk_offset);
if(child + 1 < free_queue->tail){
const off_t left_chunk_offset = queue[child];
const off_t right_chunk_offset = queue[child+1];
const chunk_head_st *left_chunk_head = (chunk_head_st *)mmanager.getAbsAddr(left_chunk_offset);
const chunk_head_st *right_chunk_head = (chunk_head_st *)mmanager.getAbsAddr(right_chunk_offset);
if(left_chunk_head->size < right_chunk_head->size){
child = child + 1;
}
}
const off_t child_chunk_offset = queue[child];
const chunk_head_st *child_chunk_head = (chunk_head_st *)mmanager.getAbsAddr(child_chunk_offset);
if(child_chunk_head->size > index_chunk_head->size){
const off_t tmp = queue[child];
queue[child] = queue[index];
queue[index] = tmp;
index = child;
}else{
break;
}
}
}
bool MmapManager::Impl::insertHeap(free_queue_st *free_queue, const off_t p) const
{
off_t *queue = (off_t *)mmanager.getAbsAddr(free_queue->data);
uint64_t index;
if(free_queue->capacity < free_queue->tail){
return false;
}
index = free_queue->tail;
queue[index] = p;
free_queue->tail += 1;
upHeap(free_queue, index);
return true;
}
bool MmapManager::Impl::getHeap(free_queue_st *free_queue, off_t *p) const
{
if( (free_queue->tail - 1) <= 0){
return false;
}
off_t *queue = (off_t *)mmanager.getAbsAddr(free_queue->data);
*p = queue[1];
free_queue->tail -= 1;
queue[1] = queue[free_queue->tail];
downHeap(free_queue);
return true;
}
size_t MmapManager::Impl::getMaxHeapValue(free_queue_st *free_queue) const
{
if(free_queue->data == -1){
return 0;
}
const off_t *queue = (off_t *)mmanager.getAbsAddr(free_queue->data);
const chunk_head_st *chunk_head = (chunk_head_st *)mmanager.getAbsAddr(queue[1]);
return chunk_head->size;
}
void MmapManager::Impl::dumpHeap() const
{
free_queue_st *free_queue = &mmapCntlHead->free_queue;
if(free_queue->data == -1){
std::cout << "heap unused" << std::endl;
return;
}
off_t *queue = (off_t *)mmanager.getAbsAddr(free_queue->data);
for(uint32_t i = 1; i < free_queue->tail; ++i){
const off_t chunk_offset = queue[i];
const off_t payload_offset = chunk_offset + sizeof(chunk_head_st);
const chunk_head_st *chunk_head = (chunk_head_st *)mmanager.getAbsAddr(chunk_offset);
const size_t size = chunk_head->size;
std::cout << "[" << chunk_offset << "(" << payload_offset << "), " << size << "] ";
}
std::cout << std::endl;
}
void MmapManager::Impl::divChunk(const off_t chunk_offset, const size_t size)
{
if((mmapCntlHead->reuse_type != REUSE_DATA_QUEUE)
&& (mmapCntlHead->reuse_type != REUSE_DATA_QUEUE_PLUS)){
return;
}
chunk_head_st *chunk_head = (chunk_head_st *)mmanager.getAbsAddr(chunk_offset);
const size_t border_size = sizeof(chunk_head_st) + MMAP_MEMORY_ALIGN;
const size_t align_size = getAlignSize(size);
const size_t rest_size = chunk_head->size - align_size;
if(rest_size < border_size){
return;
}
chunk_head->size = align_size;
const off_t new_chunk_offset = chunk_offset + sizeof(chunk_head_st) + align_size;
chunk_head_st *new_chunk_head = (chunk_head_st *)mmanager.getAbsAddr(new_chunk_offset);
const size_t new_size = rest_size - sizeof(chunk_head_st);
setupChunkHead(new_chunk_head, true, chunk_head->unit_id, -1, new_size);
head_st *unit_header = &mmapCntlHead->data_headers[mmapCntlHead->active_unit];
unit_header->chunk_num++;
const off_t payload_offset = new_chunk_offset + sizeof(chunk_head_st);
mmanager.free(payload_offset);
return;
}
}

View File

@ -1,602 +0,0 @@
//
// Copyright (C) 2016-2020 Yahoo Japan Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
#include "NGT/NGTQ/Quantizer.h"
#define NGTQ_SEARCH_CODEBOOK_SIZE_FLUCTUATION
namespace NGTQ {
class Command {
public:
Command():debugLevel(0) {}
void
create(NGT::Args &args)
{
const string usage = "Usage: ngtq create "
"[-o object-type (f:float|c:unsigned char)] [-D distance-function] [-n data-size] "
"[-p #-of-thread] [-d dimension] [-R global-codebook-range] [-r local-codebook-range] "
"[-C global-codebook-size-limit] [-c local-codebook-size-limit] [-N local-division-no] "
"[-T single-local-centroid (t|f)] [-e epsilon] [-i index-type (t:Tree|g:Graph)] "
"[-M global-centroid-creation-mode (d|s)] [-L global-centroid-creation-mode (d|k|s)] "
"[-S local-sample-coefficient] "
"index(output) data.tsv(input)";
string database;
try {
database = args.get("#1");
} catch (...) {
cerr << "DB is not specified." << endl;
cerr << usage << endl;
return;
}
string data;
try {
data = args.get("#2");
} catch (...) {
cerr << "Data is not specified." << endl;
}
char objectType = args.getChar("o", 'f');
char distanceType = args.getChar("D", '2');
size_t dataSize = args.getl("n", 0);
NGTQ::Property property;
property.threadSize = args.getl("p", 24);
property.dimension = args.getl("d", 0);
property.globalRange = args.getf("R", 0);
property.localRange = args.getf("r", 0);
property.globalCentroidLimit = args.getl("C", 1000000);
property.localCentroidLimit = args.getl("c", 65000);
property.localDivisionNo = args.getl("N", 8);
property.batchSize = args.getl("b", 1000);
property.localClusteringSampleCoefficient = args.getl("S", 10);
{
char localCentroidType = args.getChar("T", 'f');
property.singleLocalCodebook = localCentroidType == 't' ? true : false;
}
{
char centroidCreationMode = args.getChar("M", 'd');
switch(centroidCreationMode) {
case 'd': property.centroidCreationMode = NGTQ::CentroidCreationModeDynamic; break;
case 's': property.centroidCreationMode = NGTQ::CentroidCreationModeStatic; break;
default:
cerr << "ngt: Invalid centroid creation mode. " << centroidCreationMode << endl;
cerr << usage << endl;
return;
}
}
{
char localCentroidCreationMode = args.getChar("L", 'd');
switch(localCentroidCreationMode) {
case 'd': property.localCentroidCreationMode = NGTQ::CentroidCreationModeDynamic; break;
case 's': property.localCentroidCreationMode = NGTQ::CentroidCreationModeStatic; break;
case 'k': property.localCentroidCreationMode = NGTQ::CentroidCreationModeDynamicKmeans; break;
default:
cerr << "ngt: Invalid centroid creation mode. " << localCentroidCreationMode << endl;
cerr << usage << endl;
return;
}
}
NGT::Property globalProperty;
NGT::Property localProperty;
{
char indexType = args.getChar("i", 't');
globalProperty.indexType = indexType == 't' ? NGT::Property::GraphAndTree : NGT::Property::Graph;
localProperty.indexType = globalProperty.indexType;
}
globalProperty.insertionRadiusCoefficient = args.getf("e", 0.1) + 1.0;
localProperty.insertionRadiusCoefficient = globalProperty.insertionRadiusCoefficient;
if (debugLevel >= 1) {
cerr << "epsilon=" << globalProperty.insertionRadiusCoefficient << endl;
cerr << "data size=" << dataSize << endl;
cerr << "dimension=" << property.dimension << endl;
cerr << "thread size=" << property.threadSize << endl;
cerr << "batch size=" << localProperty.batchSizeForCreation << endl;;
cerr << "index type=" << globalProperty.indexType << endl;
}
switch (objectType) {
case 'f': property.dataType = NGTQ::DataTypeFloat; break;
case 'c': property.dataType = NGTQ::DataTypeUint8; break;
default:
cerr << "ngt: Invalid object type. " << objectType << endl;
cerr << usage << endl;
return;
}
switch (distanceType) {
case '2': property.distanceType = NGTQ::DistanceTypeL2; break;
case '1': property.distanceType = NGTQ::DistanceTypeL1; break;
case 'a': property.distanceType = NGTQ::DistanceTypeAngle; break;
default:
cerr << "ngt: Invalid distance type. " << distanceType << endl;
cerr << usage << endl;
return;
}
cerr << "ngtq: Create" << endl;
NGTQ::Index::create(database, property, globalProperty, localProperty);
cerr << "ngtq: Append" << endl;
NGTQ::Index::append(database, data, dataSize);
}
void
rebuild(NGT::Args &args)
{
const string usage = "Usage: ngtq rebuild "
"[-o object-type (f:float|c:unsigned char)] [-D distance-function] [-n data-size] "
"[-p #-of-thread] [-d dimension] [-R global-codebook-range] [-r local-codebook-range] "
"[-C global-codebook-size-limit] [-c local-codebook-size-limit] [-N local-division-no] "
"[-T single-local-centroid (t|f)] [-e epsilon] [-i index-type (t:Tree|g:Graph)] "
"[-M centroid-creation_mode (d|s)] "
"index(output) data.tsv(input)";
string srcIndex;
try {
srcIndex = args.get("#1");
} catch (...) {
cerr << "DB is not specified." << endl;
cerr << usage << endl;
return;
}
string rebuiltIndex = srcIndex + ".tmp";
NGTQ::Property property;
NGT::Property globalProperty;
NGT::Property localProperty;
{
NGTQ::Index index(srcIndex);
property = index.getQuantizer().property;
index.getQuantizer().globalCodebook.getProperty(globalProperty);
index.getQuantizer().getLocalCodebook(0).getProperty(localProperty);
}
property.globalRange = args.getf("R", property.globalRange);
property.localRange = args.getf("r", property.localRange);
property.globalCentroidLimit = args.getl("C", property.globalCentroidLimit);
property.localCentroidLimit = args.getl("c", property.localCentroidLimit);
property.localDivisionNo = args.getl("N", property.localDivisionNo);
{
char localCentroidType = args.getChar("T", '-');
if (localCentroidType != '-') {
property.singleLocalCodebook = localCentroidType == 't' ? true : false;
}
}
{
char centroidCreationMode = args.getChar("M", '-');
if (centroidCreationMode != '-') {
property.centroidCreationMode = centroidCreationMode == 'd' ?
NGTQ::CentroidCreationModeDynamic : NGTQ::CentroidCreationModeStatic;
}
}
cerr << "global range=" << property.globalRange << endl;
cerr << "local range=" << property.localRange << endl;
cerr << "global centroid limit=" << property.globalCentroidLimit << endl;
cerr << "local centroid limit=" << property.localCentroidLimit << endl;
cerr << "local division no=" << property.localDivisionNo << endl;
NGTQ::Index::create(rebuiltIndex, property, globalProperty, localProperty);
cerr << "created a new db" << endl;
cerr << "start rebuilding..." << endl;
NGTQ::Index::rebuild(srcIndex, rebuiltIndex);
{
string src = srcIndex;
string dst = srcIndex + ".org";
if (std::rename(src.c_str(), dst.c_str()) != 0) {
stringstream msg;
msg << "ngtq::rebuild: Cannot rename. " << src << "=>" << dst ;
NGTThrowException(msg);
}
}
{
string src = rebuiltIndex;
string dst = srcIndex;
if (std::rename(src.c_str(), dst.c_str()) != 0) {
stringstream msg;
msg << "ngtq::rebuild: Cannot rename. " << src << "=>" << dst ;
NGTThrowException(msg);
}
}
}
void
append(NGT::Args &args)
{
const string usage = "Usage: ngtq append [-n data-size] "
"index(output) data.tsv(input)";
string index;
try {
index = args.get("#1");
} catch (...) {
cerr << "DB is not specified." << endl;
cerr << usage << endl;
return;
}
string data;
try {
data = args.get("#2");
} catch (...) {
cerr << "Data is not specified." << endl;
}
size_t dataSize = args.getl("n", 0);
if (debugLevel >= 1) {
cerr << "data size=" << dataSize << endl;
}
NGTQ::Index::append(index, data, dataSize);
}
void
search(NGT::Args &args)
{
const string usage = "Usage: ngtq search [-i g|t|s] [-n result-size] [-e epsilon] [-m mode(r|l|c|a)] "
"[-E edge-size] [-o output-mode] [-b result expansion(begin:end:[x]step)] "
"index(input) query.tsv(input)";
string database;
try {
database = args.get("#1");
} catch (...) {
cerr << "DB is not specified" << endl;
cerr << usage << endl;
return;
}
string query;
try {
query = args.get("#2");
} catch (...) {
cerr << "Query is not specified" << endl;
cerr << usage << endl;
return;
}
int size = args.getl("n", 20);
char outputMode = args.getChar("o", '-');
float epsilon = 0.1;
char mode = args.getChar("m", '-');
NGTQ::AggregationMode aggregationMode;
switch (mode) {
case 'r': aggregationMode = NGTQ::AggregationModeExactDistanceThroughApproximateDistance; break; // refine
case 'e': aggregationMode = NGTQ::AggregationModeExactDistance; break; // refine
case 'l': aggregationMode = NGTQ::AggregationModeApproximateDistanceWithLookupTable; break; // lookup
case 'c': aggregationMode = NGTQ::AggregationModeApproximateDistanceWithCache; break; // cache
case '-':
case 'a': aggregationMode = NGTQ::AggregationModeApproximateDistance; break; // cache
default:
cerr << "Invalid aggregation mode. " << mode << endl;
cerr << usage << endl;
return;
}
if (args.getString("e", "none") == "-") {
// linear search
epsilon = FLT_MAX;
} else {
epsilon = args.getf("e", 0.1);
}
size_t beginOfResultExpansion, endOfResultExpansion, stepOfResultExpansion;
bool mulStep = false;
{
beginOfResultExpansion = stepOfResultExpansion = 1;
endOfResultExpansion = 0;
string str = args.getString("b", "16");
vector<string> tokens;
NGT::Common::tokenize(str, tokens, ":");
if (tokens.size() >= 1) { beginOfResultExpansion = NGT::Common::strtod(tokens[0]); }
if (tokens.size() >= 2) { endOfResultExpansion = NGT::Common::strtod(tokens[1]); }
if (tokens.size() >= 3) {
if (tokens[2][0] == 'x') {
mulStep = true;
stepOfResultExpansion = NGT::Common::strtod(tokens[2].substr(1));
} else {
stepOfResultExpansion = NGT::Common::strtod(tokens[2]);
}
}
}
if (debugLevel >= 1) {
cerr << "size=" << size << endl;
cerr << "result expansion=" << beginOfResultExpansion << "->" << endOfResultExpansion << "," << stepOfResultExpansion << endl;
}
NGTQ::Index index(database);
try {
ifstream is(query);
if (!is) {
cerr << "Cannot open the specified file. " << query << endl;
return;
}
if (outputMode == 's') { cout << "# Beginning of Evaluation" << endl; }
string line;
double totalTime = 0;
int queryCount = 0;
while(getline(is, line)) {
NGT::Object *query = index.allocateObject(line, " \t", 0);
queryCount++;
size_t resultExpansion = 0;
for (size_t base = beginOfResultExpansion;
resultExpansion <= endOfResultExpansion;
base = mulStep ? base * stepOfResultExpansion : base + stepOfResultExpansion) {
resultExpansion = base;
NGT::ObjectDistances objects;
if (outputMode == 'e') {
index.search(query, objects, size, resultExpansion, aggregationMode, epsilon);
objects.clear();
}
NGT::Timer timer;
timer.start();
// size : # of final resultant objects
// resultExpansion : # of resultant objects by using codebook search
index.search(query, objects, size, resultExpansion, aggregationMode, epsilon);
timer.stop();
totalTime += timer.time;
if (outputMode == 'e') {
cout << "# Query No.=" << queryCount << endl;
cout << "# Query=" << line.substr(0, 20) + " ..." << endl;
cout << "# Index Type=" << "----" << endl;
cout << "# Size=" << size << endl;
cout << "# Epsilon=" << epsilon << endl;
cout << "# Result expansion=" << resultExpansion << endl;
cout << "# Distance Computation=" << index.getQuantizer().distanceComputationCount << endl;
cout << "# Query Time (msec)=" << timer.time * 1000.0 << endl;
} else {
cout << "Query No." << queryCount << endl;
cout << "Rank\tIN-ID\tID\tDistance" << endl;
}
for (size_t i = 0; i < objects.size(); i++) {
cout << i + 1 << "\t" << objects[i].id << "\t";
cout << objects[i].distance << endl;
}
if (outputMode == 'e') {
cout << "# End of Search" << endl;
} else {
cout << "Query Time= " << timer.time << " (sec), " << timer.time * 1000.0 << " (msec)" << endl;
}
}
if (outputMode == 'e') {
cout << "# End of Query" << endl;
}
index.deleteObject(query);
}
if (outputMode == 'e') {
cout << "# Average Query Time (msec)=" << totalTime * 1000.0 / (double)queryCount << endl;
cout << "# Number of queries=" << queryCount << endl;
cout << "# End of Evaluation" << endl;
} else {
cout << "Average Query Time= " << totalTime / (double)queryCount << " (sec), "
<< totalTime * 1000.0 / (double)queryCount << " (msec), ("
<< totalTime << "/" << queryCount << ")" << endl;
}
} catch (NGT::Exception &err) {
cerr << "Error " << err.what() << endl;
cerr << usage << endl;
} catch (...) {
cerr << "Error" << endl;
cerr << usage << endl;
}
index.close();
}
void
remove(NGT::Args &args)
{
const string usage = "Usage: ngtq remove [-d object-ID-type(f|d)] index(input) object-ID(input)";
string database;
try {
database = args.get("#1");
} catch (...) {
cerr << "DB is not specified" << endl;
cerr << usage << endl;
return;
}
try {
args.get("#2");
} catch (...) {
cerr << "ID is not specified" << endl;
cerr << usage << endl;
return;
}
char dataType = args.getChar("d", 'f');
if (debugLevel >= 1) {
cerr << "dataType=" << dataType << endl;
}
try {
vector<NGT::ObjectID> objects;
if (dataType == 'f') {
string ids;
try {
ids = args.get("#2");
} catch (...) {
cerr << "Data file is not specified" << endl;
cerr << usage << endl;
return;
}
ifstream is(ids);
if (!is) {
cerr << "Cannot open the specified file. " << ids << endl;
return;
}
string line;
int count = 0;
while(getline(is, line)) {
count++;
vector<string> tokens;
NGT::Common::tokenize(line, tokens, "\t ");
if (tokens.size() == 0 || tokens[0].size() == 0) {
continue;
}
char *e;
size_t id;
try {
id = strtol(tokens[0].c_str(), &e, 10);
objects.push_back(id);
} catch (...) {
cerr << "Illegal data. " << tokens[0] << endl;
}
if (*e != 0) {
cerr << "Illegal data. " << e << endl;
}
cerr << "removed ID=" << id << endl;
}
} else {
size_t id = args.getl("#2", 0);
cerr << "removed ID=" << id << endl;
objects.push_back(id);
}
NGT::Index::remove(database, objects);
} catch (NGT::Exception &err) {
cerr << "Error " << err.what() << endl;
cerr << usage << endl;
} catch (...) {
cerr << "Error" << endl;
cerr << usage << endl;
}
}
void
info(NGT::Args &args)
{
const string usage = "Usage: ngtq info index";
string database;
try {
database = args.get("#1");
} catch (...) {
cerr << "DB is not specified" << endl;
cerr << usage << endl;
return;
}
NGTQ::Index index(database);
index.info(cout);
}
void
validate(NGT::Args &args)
{
const string usage = "parameter";
string database;
try {
database = args.get("#1");
} catch (...) {
cerr << "DB is not specified" << endl;
cerr << usage << endl;
return;
}
NGTQ::Index index(database);
index.getQuantizer().validate();
}
#ifdef NGTQ_SHARED_INVERTED_INDEX
void
compress(NGT::Args &args)
{
const string usage = "Usage: ngtq compress index)";
string database;
try {
database = args.get("#1");
} catch (...) {
cerr << "DB is not specified" << endl;
cerr << usage << endl;
return;
}
try {
NGTQ::Index::compress(database);
} catch (NGT::Exception &err) {
cerr << "Error " << err.what() << endl;
cerr << usage << endl;
} catch (...) {
cerr << "Error" << endl;
cerr << usage << endl;
}
}
#endif
void help() {
cerr << "Usage : ngtq command database data" << endl;
cerr << " command : create search remove append export import" << endl;
}
void execute(NGT::Args args) {
string command;
try {
command = args.get("#0");
} catch(...) {
help();
return;
}
debugLevel = args.getl("X", 0);
try {
if (debugLevel >= 1) {
cerr << "ngt::command=" << command << endl;
}
if (command == "search") {
search(args);
} else if (command == "create") {
create(args);
} else if (command == "append") {
append(args);
} else if (command == "remove") {
remove(args);
} else if (command == "info") {
info(args);
} else if (command == "validate") {
validate(args);
} else if (command == "rebuild") {
rebuild(args);
#ifdef NGTQ_SHARED_INVERTED_INDEX
} else if (command == "compress") {
compress(args);
#endif
} else {
cerr << "Illegal command. " << command << endl;
}
} catch(NGT::Exception &err) {
cerr << "ngt: Fatal error: " << err.what() << endl;
}
}
int debugLevel;
};
};

File diff suppressed because it is too large Load Diff

View File

@ -1,338 +0,0 @@
//
// Copyright (C) 2015-2020 Yahoo Japan Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
#include "NGT/defines.h"
#include "NGT/Node.h"
#include "NGT/Tree.h"
#include <algorithm>
using namespace std;
const double NGT::Node::Object::Pivot = -1.0;
using namespace NGT;
void
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
InternalNode::updateChild(DVPTree &dvptree, Node::ID src, Node::ID dst,
SharedMemoryAllocator &allocator) {
#else
InternalNode::updateChild(DVPTree &dvptree, Node::ID src, Node::ID dst) {
#endif
int cs = dvptree.internalChildrenSize;
for (int i = 0; i < cs; i++) {
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
if (getChildren(allocator)[i] == src) {
getChildren(allocator)[i] = dst;
#else
if (getChildren()[i] == src) {
getChildren()[i] = dst;
#endif
return;
}
}
}
int
LeafNode::selectPivotByMaxDistance(Container &c, Node::Objects &fs)
{
DVPTree::InsertContainer &iobj = (DVPTree::InsertContainer&)c;
int fsize = fs.size();
Distance maxd = 0.0;
int maxid = 0;
for (int i = 1; i < fsize; i++) {
Distance d = iobj.vptree->objectSpace->getComparator()(*fs[0].object, *fs[i].object);
if (d >= maxd) {
maxd = d;
maxid = i;
}
}
int aid = maxid;
maxd = 0.0;
maxid = 0;
for (int i = 0; i < fsize; i++) {
Distance d = iobj.vptree->objectSpace->getComparator()(*fs[aid].object, *fs[i].object);
if (i == aid) {
continue;
}
if (d >= maxd) {
maxd = d;
maxid = i;
}
}
int bid = maxid;
maxd = 0.0;
maxid = 0;
for (int i = 0; i < fsize; i++) {
Distance d = iobj.vptree->objectSpace->getComparator()(*fs[bid].object, *fs[i].object);
if (i == bid) {
continue;
}
if (d >= maxd) {
maxd = d;
maxid = i;
}
}
return maxid;
}
int
LeafNode::selectPivotByMaxVariance(Container &c, Node::Objects &fs)
{
DVPTree::InsertContainer &iobj = (DVPTree::InsertContainer&)c;
int fsize = fs.size();
Distance *distance = new Distance[fsize * fsize];
for (int i = 0; i < fsize; i++) {
distance[i * fsize + i] = 0;
}
for (int i = 0; i < fsize; i++) {
for (int j = i + 1; j < fsize; j++) {
Distance d = iobj.vptree->objectSpace->getComparator()(*fs[i].object, *fs[j].object);
distance[i * fsize + j] = d;
distance[j * fsize + i] = d;
}
}
double *variance = new double[fsize];
for (int i = 0; i < fsize; i++) {
double avg = 0.0;
for (int j = 0; j < fsize; j++) {
avg += distance[i * fsize + j];
}
avg /= (double)fsize;
double v = 0.0;
for (int j = 0; j < fsize; j++) {
v += pow(distance[i * fsize + j] - avg, 2.0);
}
variance[i] = v / (double)fsize;
}
double maxv = variance[0];
int maxid = 0;
for (int i = 0; i < fsize; i++) {
if (variance[i] > maxv) {
maxv = variance[i];
maxid = i;
}
}
delete [] variance;
delete [] distance;
return maxid;
}
void
LeafNode::splitObjects(Container &c, Objects &fs, int pv)
{
DVPTree::InsertContainer &iobj = (DVPTree::InsertContainer&)c;
// sort the objects by distance
int fsize = fs.size();
for (int i = 0; i < fsize; i++) {
if (i == pv) {
fs[i].distance = 0;
} else {
Distance d = iobj.vptree->objectSpace->getComparator()(*fs[pv].object, *fs[i].object);
fs[i].distance = d;
}
}
sort(fs.begin(), fs.end());
int childrenSize = iobj.vptree->internalChildrenSize;
int cid = childrenSize - 1;
int cms = (fsize * cid) / childrenSize;
// divide the objects into child clusters.
fs[fsize - 1].clusterID = cid;
for (int i = fsize - 2; i >= 0; i--) {
if (i < cms && cid > 0) {
if (fs[i].distance != fs[i + 1].distance) {
cid--;
cms = (fsize * cid) / childrenSize;
}
}
fs[i].clusterID = cid;
}
if (cid != 0) {
// the required number of child nodes could not be acquired
stringstream msg;
msg << "LeafNode::splitObjects: Too many same distances. Reduce internal children size for the tree index or not use the tree index." << endl;
msg << " internalChildrenSize=" << childrenSize << endl;
msg << " # of the children=" << (childrenSize - cid) << endl;
msg << " Size=" << fsize << endl;
msg << " pivot=" << pv << endl;
msg << " cluster id=" << cid << endl;
msg << " Show distances for debug." << endl;
for (int i = 0; i < fsize; i++) {
msg << " " << fs[i].id << ":" << fs[i].distance << endl;
msg << " ";
PersistentObject &po = *fs[i].object;
iobj.vptree->objectSpace->show(msg, po);
msg << endl;
}
if (fs[fsize - 1].clusterID == cid) {
msg << "LeafNode::splitObjects: All of the object distances are the same!" << endl;;
NGTThrowException(msg.str());
} else {
cerr << msg.str() << endl;
cerr << "LeafNode::splitObjects: Anyway, continue..." << endl;
// sift the cluster IDs to start from 0 to continue.
for (int i = 0; i < fsize; i++) {
fs[i].clusterID -= cid;
}
}
}
long long *pivots = new long long[childrenSize];
for (int i = 0; i < childrenSize; i++) {
pivots[i] = -1;
}
// find the boundaries for the subspaces
for (int i = 0; i < fsize; i++) {
if (pivots[fs[i].clusterID] == -1) {
pivots[fs[i].clusterID] = i;
fs[i].leafDistance = Object::Pivot;
} else {
Distance d = iobj.vptree->objectSpace->getComparator()(*fs[pivots[fs[i].clusterID]].object, *fs[i].object);
fs[i].leafDistance = d;
}
}
delete[] pivots;
return;
}
void
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
LeafNode::removeObject(size_t id, size_t replaceId, SharedMemoryAllocator &allocator) {
#else
LeafNode::removeObject(size_t id, size_t replaceId) {
#endif
size_t fsize = getObjectSize();
size_t idx;
for (idx = 0; idx < fsize; idx++) {
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
if (getObjectIDs(allocator)[idx].id == id) {
if (replaceId != 0) {
getObjectIDs(allocator)[idx].id = replaceId;
#else
if (getObjectIDs()[idx].id == id) {
if (replaceId != 0) {
getObjectIDs()[idx].id = replaceId;
#endif
return;
} else {
break;
}
}
}
if (idx == fsize) {
if (pivot == 0) {
NGTThrowException("LeafNode::removeObject: Internal error!. the pivot is illegal.");
}
stringstream msg;
msg << "VpTree::Leaf::remove: Warning. Cannot find the specified object. ID=" << id << "," << replaceId << " idx=" << idx << " If the same objects were inserted into the index, ignore this message.";
NGTThrowException(msg.str());
}
#ifdef NGT_NODE_USE_VECTOR
for (; idx < objectIDs.size() - 1; idx++) {
getObjectIDs()[idx] = getObjectIDs()[idx + 1];
}
objectIDs.pop_back();
#else
objectSize--;
for (; idx < objectSize; idx++) {
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
getObjectIDs(allocator)[idx] = getObjectIDs(allocator)[idx + 1];
#else
getObjectIDs()[idx] = getObjectIDs()[idx + 1];
#endif
}
#endif
return;
}
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
bool InternalNode::verify(PersistentRepository<InternalNode> &internalNodes, PersistentRepository<LeafNode> &leafNodes,
SharedMemoryAllocator &allocator) {
#else
bool InternalNode::verify(Repository<InternalNode> &internalNodes, Repository<LeafNode> &leafNodes) {
#endif
size_t isize = internalNodes.size();
size_t lsize = leafNodes.size();
bool valid = true;
for (size_t i = 0; i < childrenSize; i++) {
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
size_t nid = getChildren(allocator)[i].getID();
ID::Type type = getChildren(allocator)[i].getType();
#else
size_t nid = getChildren()[i].getID();
ID::Type type = getChildren()[i].getType();
#endif
size_t size = type == ID::Leaf ? lsize : isize;
if (nid >= size) {
cerr << "Error! Internal children node id is too big." << nid << ":" << size << endl;
valid = false;
}
try {
if (type == ID::Leaf) {
leafNodes.get(nid);
} else {
internalNodes.get(nid);
}
} catch (...) {
cerr << "Error! Cannot get the node. " << ((type == ID::Leaf) ? "Leaf" : "Internal") << endl;
valid = false;
}
}
return valid;
}
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
bool LeafNode::verify(size_t nobjs, vector<uint8_t> &status, SharedMemoryAllocator &allocator) {
#else
bool LeafNode::verify(size_t nobjs, vector<uint8_t> &status) {
#endif
bool valid = true;
for (size_t i = 0; i < objectSize; i++) {
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
size_t nid = getObjectIDs(allocator)[i].id;
#else
size_t nid = getObjectIDs()[i].id;
#endif
if (nid > nobjs) {
cerr << "Error! Object id is too big. " << nid << ":" << nobjs << endl;
valid =false;
continue;
}
status[nid] |= 0x04;
}
return valid;
}

View File

@ -1,779 +0,0 @@
//
// Copyright (C) 2015-2020 Yahoo Japan Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
#pragma once
#include <algorithm>
#include <sstream>
#include "NGT/Common.h"
#include "NGT/ObjectSpaceRepository.h"
#include "NGT/defines.h"
namespace NGT {
class DVPTree;
class InternalNode;
class LeafNode;
class Node {
public:
typedef unsigned int NodeID;
class ID {
public:
enum Type {
Leaf = 1,
Internal = 0
};
ID():id(0) {}
ID &operator=(const ID &n) {
id = n.id;
return *this;
}
ID &operator=(int i) {
setID(i);
return *this;
}
bool operator==(ID &n) { return id == n.id; }
bool operator<(ID &n) { return id < n.id; }
Type getType() { return (Type)((0x80000000 & id) >> 31); }
NodeID getID() { return 0x7fffffff & id; }
NodeID get() { return id; }
void setID(NodeID i) { id = (0x80000000 & id) | i; }
void setType(Type t) { id = (t << 31) | getID(); }
void setRaw(NodeID i) { id = i; }
void setNull() { id = 0; }
// for milvus
void serialize(std::stringstream & os) { NGT::Serializer::write(os, id); }
void serialize(std::ofstream &os) { NGT::Serializer::write(os, id); }
void deserialize(std::ifstream &is) { NGT::Serializer::read(is, id); }
// for milvus
void deserialize(std::stringstream & is) { NGT::Serializer::read(is, id); }
void serializeAsText(std::ofstream &os) { NGT::Serializer::writeAsText(os, id); }
void deserializeAsText(std::ifstream &is) { NGT::Serializer::readAsText(is, id); }
virtual int64_t memSize() { return sizeof(id); }
protected:
NodeID id;
};
class Object {
public:
Object():object(0) {}
bool operator<(const Object &o) const { return distance < o.distance; }
virtual int64_t memSize() { return sizeof(*this) + object->memSize(); } // size of object cannot be decided accurately
static const double Pivot;
ObjectID id;
PersistentObject *object;
Distance distance;
Distance leafDistance;
int clusterID;
};
typedef std::vector<Object> Objects;
Node() {
parent.setNull();
id.setNull();
}
virtual ~Node() {}
Node &operator=(const Node &n) {
id = n.id;
parent = n.parent;
return *this;
}
// for milvus
void serialize(std::stringstream & os)
{
id.serialize(os);
parent.serialize(os);
}
void serialize(std::ofstream &os) {
id.serialize(os);
parent.serialize(os);
}
void deserialize(std::ifstream &is) {
id.deserialize(is);
parent.deserialize(is);
}
void deserialize(std::stringstream & is)
{
id.deserialize(is);
parent.deserialize(is);
}
void serializeAsText(std::ofstream &os) {
id.serializeAsText(os);
os << " ";
parent.serializeAsText(os);
}
void deserializeAsText(std::ifstream &is) {
id.deserializeAsText(is);
parent.deserializeAsText(is);
}
virtual int64_t memSize() { return id.memSize() * 2 + pivot->memSize(); }
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
void setPivot(PersistentObject &f, ObjectSpace &os, SharedMemoryAllocator &allocator) {
if (pivot == 0) {
pivot = NGT::PersistentObject::allocate(os);
}
getPivot(os).set(f, os);
}
PersistentObject &getPivot(ObjectSpace &os) {
return *(PersistentObject*)os.getRepository().getAllocator().getAddr(pivot);
}
void deletePivot(ObjectSpace &os, SharedMemoryAllocator &allocator) {
os.deleteObject(&getPivot(os));
}
#else // NGT_SHARED_MEMORY_ALLOCATOR
void setPivot(NGT::Object &f, ObjectSpace &os) {
if (pivot == 0) {
pivot = NGT::Object::allocate(os);
}
os.copy(getPivot(), f);
}
NGT::Object &getPivot() { return *pivot; }
void deletePivot(ObjectSpace &os) {
os.deleteObject(pivot);
}
#endif // NGT_SHARED_MEMORY_ALLOCATOR
bool pivotIsEmpty() {
return pivot == 0;
}
ID id;
ID parent;
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
off_t pivot;
#else
NGT::Object *pivot;
#endif
};
class InternalNode : public Node {
public:
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
InternalNode(size_t csize, SharedMemoryAllocator &allocator) : childrenSize(csize) { initialize(allocator); }
InternalNode(SharedMemoryAllocator &allocator, NGT::ObjectSpace *os = 0) : childrenSize(5) { initialize(allocator); }
#else
InternalNode(size_t csize) : childrenSize(csize) { initialize(); }
InternalNode(NGT::ObjectSpace *os = 0) : childrenSize(5) { initialize(); }
#endif
~InternalNode() {
#ifndef NGT_SHARED_MEMORY_ALLOCATOR
if (children != 0) {
delete[] children;
}
if (borders != 0) {
delete[] borders;
}
#endif
}
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
void initialize(SharedMemoryAllocator &allocator) {
#else
void initialize() {
#endif
id = 0;
id.setType(ID::Internal);
pivot = 0;
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
children = allocator.getOffset(new(allocator) ID[childrenSize]);
#else
children = new ID[childrenSize];
#endif
for (size_t i = 0; i < childrenSize; i++) {
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
getChildren(allocator)[i] = 0;
#else
getChildren()[i] = 0;
#endif
}
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
borders = allocator.getOffset(new(allocator) Distance[childrenSize - 1]);
#else
borders = new Distance[childrenSize - 1];
#endif
for (size_t i = 0; i < childrenSize - 1; i++) {
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
getBorders(allocator)[i] = 0;
#else
getBorders()[i] = 0;
#endif
}
}
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
void updateChild(DVPTree &dvptree, ID src, ID dst, SharedMemoryAllocator &allocator);
#else
void updateChild(DVPTree &dvptree, ID src, ID dst);
#endif
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
ID *getChildren(SharedMemoryAllocator &allocator) { return (ID*)allocator.getAddr(children); }
Distance *getBorders(SharedMemoryAllocator &allocator) { return (Distance*)allocator.getAddr(borders); }
#else // NGT_SHARED_MEMORY_ALLOCATOR
ID *getChildren() { return children; }
Distance *getBorders() { return borders; }
#endif // NGT_SHARED_MEMORY_ALLOCATOR
// for milvus
void serialize(std::stringstream & os, ObjectSpace * objectspace = 0)
{
Node::serialize(os);
if (pivot == 0)
{
NGTThrowException("Node::write: pivot is null!");
}
assert(objectspace != 0);
getPivot().serialize(os, objectspace);
NGT::Serializer::write(os, childrenSize);
for (size_t i = 0; i < childrenSize; i++)
{
getChildren()[i].serialize(os);
}
for (size_t i = 0; i < childrenSize - 1; i++)
{
NGT::Serializer::write(os, getBorders()[i]);
}
}
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
void serialize(std::ofstream &os, SharedMemoryAllocator &allocator, ObjectSpace *objectspace = 0) {
#else
void serialize(std::ofstream &os, ObjectSpace *objectspace = 0) {
#endif
Node::serialize(os);
if (pivot == 0) {
NGTThrowException("Node::write: pivot is null!");
}
assert(objectspace != 0);
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
getPivot(*objectspace).serialize(os, allocator, objectspace);
#else
getPivot().serialize(os, objectspace);
#endif
NGT::Serializer::write(os, childrenSize);
for (size_t i = 0; i < childrenSize; i++) {
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
getChildren(allocator)[i].serialize(os);
#else
getChildren()[i].serialize(os);
#endif
}
for (size_t i = 0; i < childrenSize - 1; i++) {
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
NGT::Serializer::write(os, getBorders(allocator)[i]);
#else
NGT::Serializer::write(os, getBorders()[i]);
#endif
}
}
void deserialize(std::ifstream &is, ObjectSpace *objectspace = 0) {
Node::deserialize(is);
if (pivot == 0) {
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
pivot = PersistentObject::allocate(*objectspace);
#else
pivot = PersistentObject::allocate(*objectspace);
#endif
}
assert(objectspace != 0);
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
std::cerr << "not implemented" << std::endl;
assert(0);
#else
getPivot().deserialize(is, objectspace);
#endif
NGT::Serializer::read(is, childrenSize);
assert(children != 0);
for (size_t i = 0; i < childrenSize; i++) {
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
assert(0);
#else
getChildren()[i].deserialize(is);
#endif
}
assert(borders != 0);
for (size_t i = 0; i < childrenSize - 1; i++) {
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
assert(0);
#else
NGT::Serializer::read(is, getBorders()[i]);
#endif
}
}
// for milvus
void deserialize(std::stringstream & is, ObjectSpace * objectspace = 0)
{
Node::deserialize(is);
if (pivot == 0)
{
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
pivot = PersistentObject::allocate(*objectspace);
#else
pivot = PersistentObject::allocate(*objectspace);
#endif
}
assert(objectspace != 0);
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
std::cerr << "not implemented" << std::endl;
assert(0);
#else
getPivot().deserialize(is, objectspace);
#endif
NGT::Serializer::read(is, childrenSize);
assert(children != 0);
for (size_t i = 0; i < childrenSize; i++) {
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
assert(0);
#else
getChildren()[i].deserialize(is);
#endif
}
assert(borders != 0);
for (size_t i = 0; i < childrenSize - 1; i++) {
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
assert(0);
#else
NGT::Serializer::read(is, getBorders()[i]);
#endif
}
}
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
void serializeAsText(std::ofstream &os, SharedMemoryAllocator &allocator, ObjectSpace *objectspace = 0) {
#else
void serializeAsText(std::ofstream &os, ObjectSpace *objectspace = 0) {
#endif
Node::serializeAsText(os);
if (pivot == 0) {
NGTThrowException("Node::write: pivot is null!");
}
os << " ";
assert(objectspace != 0);
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
getPivot(*objectspace).serializeAsText(os, objectspace);
#else
getPivot().serializeAsText(os, objectspace);
#endif
os << " ";
NGT::Serializer::writeAsText(os, childrenSize);
os << " ";
for (size_t i = 0; i < childrenSize; i++) {
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
getChildren(allocator)[i].serializeAsText(os);
#else
getChildren()[i].serializeAsText(os);
#endif
os << " ";
}
for (size_t i = 0; i < childrenSize - 1; i++) {
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
NGT::Serializer::writeAsText(os, getBorders(allocator)[i]);
#else
NGT::Serializer::writeAsText(os, getBorders()[i]);
#endif
os << " ";
}
}
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
void deserializeAsText(std::ifstream &is, SharedMemoryAllocator &allocator, ObjectSpace *objectspace = 0) {
#else
void deserializeAsText(std::ifstream &is, ObjectSpace *objectspace = 0) {
#endif
Node::deserializeAsText(is);
if (pivot == 0) {
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
pivot = PersistentObject::allocate(*objectspace);
#else
pivot = PersistentObject::allocate(*objectspace);
#endif
}
assert(objectspace != 0);
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
getPivot(*objectspace).deserializeAsText(is, objectspace);
#else
getPivot().deserializeAsText(is, objectspace);
#endif
size_t csize;
NGT::Serializer::readAsText(is, csize);
assert(children != 0);
assert(childrenSize == csize);
for (size_t i = 0; i < childrenSize; i++) {
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
getChildren(allocator)[i].deserializeAsText(is);
#else
getChildren()[i].deserializeAsText(is);
#endif
}
assert(borders != 0);
for (size_t i = 0; i < childrenSize - 1; i++) {
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
NGT::Serializer::readAsText(is, getBorders(allocator)[i]);
#else
NGT::Serializer::readAsText(is, getBorders()[i]);
#endif
}
}
virtual int64_t memSize() { return sizeof(childrenSize) + children->memSize() + childrenSize * sizeof(Distance) + Node::memSize(); }
void show() {
std::cout << "Show internal node " << childrenSize << ":";
for (size_t i = 0; i < childrenSize; i++) {
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
assert(0);
#else
std::cout << getChildren()[i].getID() << " ";
#endif
}
std::cout << std::endl;
}
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
bool verify(PersistentRepository<InternalNode> &internalNodes, PersistentRepository<LeafNode> &leafNodes,
SharedMemoryAllocator &allocator);
#else
bool verify(Repository<InternalNode> &internalNodes, Repository<LeafNode> &leafNodes);
#endif
static const int InternalChildrenSizeMax = 5;
const size_t childrenSize;
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
off_t children;
off_t borders;
#else
ID *children;
Distance *borders;
#endif
};
class LeafNode : public Node {
public:
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
LeafNode(SharedMemoryAllocator &allocator, NGT::ObjectSpace *os = 0) {
#else
LeafNode(NGT::ObjectSpace *os = 0) {
#endif
id = 0;
id.setType(ID::Leaf);
pivot = 0;
#ifdef NGT_NODE_USE_VECTOR
objectIDs.reserve(LeafObjectsSizeMax);
#else
objectSize = 0;
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
objectIDs = allocator.getOffset(new(allocator) Object[LeafObjectsSizeMax]);
#else
objectIDs = new NGT::ObjectDistance[LeafObjectsSizeMax];
#endif
#endif
}
~LeafNode() {
#ifndef NGT_SHARED_MEMORY_ALLOCATOR
#ifndef NGT_NODE_USE_VECTOR
if (objectIDs != 0) {
delete[] objectIDs;
}
#endif
#endif
}
static int
selectPivotByMaxDistance(Container &iobj, Node::Objects &fs);
static int
selectPivotByMaxVariance(Container &iobj, Node::Objects &fs);
static void
splitObjects(Container &insertedObject, Objects &splitObjectSet, int pivot);
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
void removeObject(size_t id, size_t replaceId, SharedMemoryAllocator &allocator);
#else
void removeObject(size_t id, size_t replaceId);
#endif
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
#ifndef NGT_NODE_USE_VECTOR
NGT::ObjectDistance *getObjectIDs(SharedMemoryAllocator &allocator) {
return (NGT::ObjectDistance *)allocator.getAddr(objectIDs);
}
#endif
#else // NGT_SHARED_MEMORY_ALLOCATOR
NGT::ObjectDistance *getObjectIDs() { return objectIDs; }
#endif // NGT_SHARED_MEMORY_ALLOCATOR
// for milvus
void serialize(std::stringstream & os, ObjectSpace * objectspace = 0)
{
Node::serialize(os);
NGT::Serializer::write(os, objectSize);
for (int i = 0; i < objectSize; i++)
{
objectIDs[i].serialize(os);
}
if (pivot == 0)
{
// Before insertion, parent ID == 0 and object size == 0, that indicates an empty index
if (parent.getID() != 0 || objectSize != 0)
{
NGTThrowException("Node::write: pivot is null!");
}
}
else
{
assert(objectspace != 0);
pivot->serialize(os, objectspace);
}
}
void serialize(std::ofstream &os, ObjectSpace *objectspace = 0) {
Node::serialize(os);
#ifdef NGT_NODE_USE_VECTOR
NGT::Serializer::write(os, objectIDs);
#else
NGT::Serializer::write(os, objectSize);
for (int i = 0; i < objectSize; i++) {
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
std::cerr << "not implemented" << std::endl;
assert(0);
#else
objectIDs[i].serialize(os);
#endif
}
#endif // NGT_NODE_USE_VECTOR
if (pivot == 0) {
// Before insertion, parent ID == 0 and object size == 0, that indicates an empty index
if (parent.getID() != 0 || objectSize != 0) {
NGTThrowException("Node::write: pivot is null!");
}
} else {
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
std::cerr << "not implemented" << std::endl;
assert(0);
#else
assert(objectspace != 0);
pivot->serialize(os, objectspace);
#endif
}
}
void deserialize(std::ifstream &is, ObjectSpace *objectspace = 0) {
Node::deserialize(is);
#ifdef NGT_NODE_USE_VECTOR
objectIDs.clear();
NGT::Serializer::read(is, objectIDs);
#else
assert(objectIDs != 0);
NGT::Serializer::read(is, objectSize);
for (int i = 0; i < objectSize; i++) {
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
std::cerr << "not implemented" << std::endl;
assert(0);
#else
getObjectIDs()[i].deserialize(is);
#endif
}
#endif
if (parent.getID() == 0 && objectSize == 0) {
// The index is empty
return;
}
if (pivot == 0) {
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
pivot = PersistentObject::allocate(*objectspace);
#else
pivot = PersistentObject::allocate(*objectspace);
assert(pivot != 0);
#endif
}
assert(objectspace != 0);
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
std::cerr << "not implemented" << std::endl;
assert(0);
#else
getPivot().deserialize(is, objectspace);
#endif
}
// for milvus
void deserialize(std::stringstream & is, ObjectSpace * objectspace = 0)
{
Node::deserialize(is);
#ifdef NGT_NODE_USE_VECTOR
objectIDs.clear();
NGT::Serializer::read(is, objectIDs);
#else
assert(objectIDs != 0);
NGT::Serializer::read(is, objectSize);
for (int i = 0; i < objectSize; i++) {
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
std::cerr << "not implemented" << std::endl;
assert(0);
#else
getObjectIDs()[i].deserialize(is);
#endif
}
#endif
if (parent.getID() == 0 && objectSize == 0) {
// The index is empty
return;
}
if (pivot == 0) {
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
pivot = PersistentObject::allocate(*objectspace);
#else
pivot = PersistentObject::allocate(*objectspace);
assert(pivot != 0);
#endif
}
assert(objectspace != 0);
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
std::cerr << "not implemented" << std::endl;
assert(0);
#else
getPivot().deserialize(is, objectspace);
#endif
}
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
void serializeAsText(std::ofstream &os, SharedMemoryAllocator &allocator, ObjectSpace *objectspace = 0) {
#else
void serializeAsText(std::ofstream &os, ObjectSpace *objectspace = 0) {
#endif
Node::serializeAsText(os);
os << " ";
if (pivot == 0) {
NGTThrowException("Node::write: pivot is null!");
}
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
getPivot(*objectspace).serializeAsText(os, objectspace);
#else
assert(pivot != 0);
assert(objectspace != 0);
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
pivot->serializeAsText(os, allocator, objectspace);
#else
pivot->serializeAsText(os, objectspace);
#endif
#endif
os << " ";
#ifdef NGT_NODE_USE_VECTOR
NGT::Serializer::writeAsText(os, objectIDs);
#else
NGT::Serializer::writeAsText(os, objectSize);
for (int i = 0; i < objectSize; i++) {
os << " ";
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
getObjectIDs(allocator)[i].serializeAsText(os);
#else
objectIDs[i].serializeAsText(os);
#endif
}
#endif
}
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
void deserializeAsText(std::ifstream &is, SharedMemoryAllocator &allocator, ObjectSpace *objectspace = 0) {
#else
void deserializeAsText(std::ifstream &is, ObjectSpace *objectspace = 0) {
#endif
Node::deserializeAsText(is);
if (pivot == 0) {
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
pivot = PersistentObject::allocate(*objectspace);
#else
pivot = PersistentObject::allocate(*objectspace);
#endif
}
assert(objectspace != 0);
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
getPivot(*objectspace).deserializeAsText(is, objectspace);
#else
getPivot().deserializeAsText(is, objectspace);
#endif
#ifdef NGT_NODE_USE_VECTOR
objectIDs.clear();
NGT::Serializer::readAsText(is, objectIDs);
#else
assert(objectIDs != 0);
NGT::Serializer::readAsText(is, objectSize);
for (int i = 0; i < objectSize; i++) {
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
getObjectIDs(allocator)[i].deserializeAsText(is);
#else
getObjectIDs()[i].deserializeAsText(is);
#endif
}
#endif
}
void show() {
std::cout << "Show leaf node " << objectSize << ":";
for (int i = 0; i < objectSize; i++) {
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
std::cerr << "not implemented" << std::endl;
assert(0);
#else
std::cout << getObjectIDs()[i].id << "," << getObjectIDs()[i].distance << " ";
#endif
}
std::cout << std::endl;
}
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
bool verify(size_t nobjs, std::vector<uint8_t> &status, SharedMemoryAllocator &allocator);
#else
bool verify(size_t nobjs, std::vector<uint8_t> &status);
#endif
virtual int64_t memSize() { return sizeof(objectSize) + objectIDs->memSize() * objectSize + Node::memSize(); }
#ifdef NGT_NODE_USE_VECTOR
size_t getObjectSize() { return objectIDs.size(); }
#else
size_t getObjectSize() { return objectSize; }
#endif
static const size_t LeafObjectsSizeMax = 100;
#ifdef NGT_NODE_USE_VECTOR
std::vector<Object> objectIDs;
#else
unsigned short objectSize;
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
off_t objectIDs;
#else
ObjectDistance *objectIDs;
#endif
#endif
};
}

View File

@ -1,428 +0,0 @@
//
// Copyright (C) 2015-2020 Yahoo Japan Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
#pragma once
#include <sstream>
#include "defines.h"
namespace NGT {
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
class ObjectRepository :
public PersistentRepository<PersistentObject> {
public:
typedef PersistentRepository<PersistentObject> Parent;
void open(const std::string &smfile, size_t sharedMemorySize) {
std::string file = smfile;
file.append("po");
Parent::open(file, sharedMemorySize);
}
#else
class ObjectRepository : public Repository<Object> {
public:
typedef Repository<Object> Parent;
#endif
ObjectRepository(size_t dim, const std::type_info &ot):dimension(dim), type(ot), sparse(false) { }
void initialize() {
deleteAll();
Parent::push_back((PersistentObject*)0);
}
// for milvus
void serialize(std::stringstream & obj, ObjectSpace * ospace) { Parent::serialize(obj, ospace); }
void serialize(const std::string &ofile, ObjectSpace *ospace) {
std::ofstream objs(ofile);
if (!objs.is_open()) {
std::stringstream msg;
msg << "NGT::ObjectSpace: Cannot open the specified file " << ofile << ".";
NGTThrowException(msg);
}
Parent::serialize(objs, ospace);
}
void deserialize(std::stringstream & obj, ObjectSpace * ospace)
{
assert(ospace != 0);
Parent::deserialize(obj, ospace);
}
void deserialize(const std::string &ifile, ObjectSpace *ospace) {
assert(ospace != 0);
std::ifstream objs(ifile);
if (!objs.is_open()) {
std::stringstream msg;
msg << "NGT::ObjectSpace: Cannot open the specified file " << ifile << ".";
NGTThrowException(msg);
}
Parent::deserialize(objs, ospace);
}
void serializeAsText(const std::string &ofile, ObjectSpace *ospace) {
std::ofstream objs(ofile);
if (!objs.is_open()) {
std::stringstream msg;
msg << "NGT::ObjectSpace: Cannot open the specified file " << ofile << ".";
NGTThrowException(msg);
}
Parent::serializeAsText(objs, ospace);
}
void deserializeAsText(const std::string &ifile, ObjectSpace *ospace) {
std::ifstream objs(ifile);
if (!objs.is_open()) {
std::stringstream msg;
msg << "NGT::ObjectSpace: Cannot open the specified file " << ifile << ".";
NGTThrowException(msg);
}
Parent::deserializeAsText(objs, ospace);
}
void readText(std::istream &is, size_t dataSize = 0) {
initialize();
appendText(is, dataSize);
}
// For milvus
template <typename T>
void readRawData(const T * raw_data, size_t dataSize)
{
initialize();
append(raw_data, dataSize);
}
virtual PersistentObject *allocateNormalizedPersistentObject(const std::vector<double> &obj) {
// std::cerr << "ObjectRepository::allocateNormalizedPersistentObject(double): Fatal error! Something wrong!" << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("ObjectRepository::allocateNormalizedPersistentObject(double): Fatal error! Something wrong!");
abort();
}
virtual PersistentObject *allocateNormalizedPersistentObject(const std::vector<float> &obj) {
// std::cerr << "ObjectRepository::allocateNormalizedPersistentObject(float): Fatal error! Something wrong!" << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("ObjectRepository::allocateNormalizedPersistentObject(float): Fatal error! Something wrong!");
abort();
}
virtual PersistentObject *allocateNormalizedPersistentObject(const std::vector<uint8_t> &obj) {
// std::cerr << "ObjectRepository::allocateNormalizedPersistentObject(uint8_t): Fatal error! Something wrong!" << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("ObjectRepository::allocateNormalizedPersistentObject(uint8_t): Fatal error! Something wrong!");
abort();
}
virtual PersistentObject *allocateNormalizedPersistentObject(const float *obj, size_t size) {
// std::cerr << "ObjectRepository::allocateNormalizedPersistentObject: Fatal error! Something wrong!" << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("ObjectRepository::allocateNormalizedPersistentObject: Fatal error! Something wrong!");
abort();
}
void appendText(std::istream &is, size_t dataSize = 0) {
if (dimension == 0) {
NGTThrowException("ObjectSpace::readText: Dimension is not specified.");
}
if (size() == 0) {
// First entry should be always a dummy entry.
// If it is empty, the dummy entry should be inserted.
push_back((PersistentObject*)0);
}
size_t prevDataSize = size();
if (dataSize > 0) {
reserve(size() + dataSize);
}
std::string line;
size_t lineNo = 0;
while (getline(is, line)) {
lineNo++;
if (dataSize > 0 && (dataSize <= size() - prevDataSize)) {
// std::cerr << "The size of data reached the specified size. The remaining data in the file are not inserted. "
// << dataSize << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("The size of data reached the specified size. The remaining data in the file are not inserted. "
+ std::to_string(dataSize));
break;
}
std::vector<double> object;
try {
extractObjectFromText(line, "\t ", object);
PersistentObject *obj = 0;
try {
obj = allocateNormalizedPersistentObject(object);
} catch (Exception &err) {
// std::cerr << err.what() << " continue..." << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)(std::string(err.what()) + " continue...");
obj = allocatePersistentObject(object);
}
push_back(obj);
} catch (Exception &err) {
// std::cerr << "ObjectSpace::readText: Warning! Invalid line. [" << line << "] Skip the line " << lineNo << " and continue." << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("ObjectSpace::readText: Warning! Invalid line. [" + line + "] Skip the line " + std::to_string(lineNo) + " and continue.");
}
}
}
template <typename T>
void append(T *data, size_t objectCount) {
if (dimension == 0) {
NGTThrowException("ObjectSpace::readText: Dimension is not specified.");
}
if (size() == 0) {
// First entry should be always a dummy entry.
// If it is empty, the dummy entry should be inserted.
push_back((PersistentObject*)0);
}
if (objectCount > 0) {
reserve(size() + objectCount);
}
for (size_t idx = 0; idx < objectCount; idx++, data += dimension) {
std::vector<double> object;
object.reserve(dimension);
for (size_t dataidx = 0; dataidx < dimension; dataidx++) {
object.push_back(data[dataidx]);
}
try {
PersistentObject *obj = 0;
try {
obj = allocateNormalizedPersistentObject(object);
} catch (Exception &err) {
// std::cerr << err.what() << " continue..." << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)(std::string(err.what()) + " continue...");
obj = allocatePersistentObject(object);
}
push_back(obj);
} catch (Exception &err) {
// std::cerr << "ObjectSpace::readText: Warning! Invalid data. Skip the data no. " << idx << " and continue." << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("ObjectSpace::readText: Warning! Invalid data. Skip the data no. " + std::to_string(idx) + " and continue.");
}
}
}
Object *allocateObject() {
return (Object*) new Object(paddedByteSize);
}
// This method is called during search to generate query.
// Therefore the object is not persistent.
Object *allocateObject(const std::string &textLine, const std::string &sep) {
std::vector<double> object;
extractObjectFromText(textLine, sep, object);
Object *po = (Object*)allocateObject(object);
return (Object*)po;
}
void extractObjectFromText(const std::string &textLine, const std::string &sep, std::vector<double> &object) {
object.resize(dimension);
std::vector<std::string> tokens;
NGT::Common::tokenize(textLine, tokens, sep);
if (dimension > tokens.size()) {
std::stringstream msg;
msg << "ObjectSpace::allocate: too few dimension. " << tokens.size() << ":" << dimension << ". "
<< textLine;
NGTThrowException(msg);
}
size_t idx;
for (idx = 0; idx < dimension; idx++) {
if (tokens[idx].size() == 0) {
std::stringstream msg;
msg << "ObjectSpace::allocate: too few dimension. " << tokens.size() << ":"
<< dimension << ". " << textLine;
NGTThrowException(msg);
}
char *e;
object[idx] = strtod(tokens[idx].c_str(), &e);
if (*e != 0) {
// std::cerr << "ObjectSpace::readText: Warning! Not numerical value. [" << e << "]" << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("ObjectSpace::readText: Warning! Not numerical value. [" + std::string(e) + "]");
break;
}
}
}
template <typename T>
Object *allocateObject(T *o, size_t size) {
size_t osize = paddedByteSize;
if (sparse) {
size_t vsize = size * (type == typeid(float) ? 4 : 1);
osize = osize < vsize ? vsize : osize;
} else {
if (dimension != size) {
// std::cerr << "ObjectSpace::allocateObject: Fatal error! dimension is invalid. The indexed objects="
// << dimension << " The specified object=" << size << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("ObjectSpace::allocateObject: Fatal error! dimension is invalid. The indexed objects="
+ std::to_string(dimension) + " The specified object=" + std::to_string(size));
assert(dimension == size);
}
}
Object *po = new Object(osize);
void *object = static_cast<void*>(&(*po)[0]);
if (type == typeid(uint8_t)) {
uint8_t *obj = static_cast<uint8_t*>(object);
for (size_t i = 0; i < size; i++) {
obj[i] = static_cast<uint8_t>(o[i]);
}
} else if (type == typeid(float)) {
float *obj = static_cast<float*>(object);
for (size_t i = 0; i < size; i++) {
obj[i] = static_cast<float>(o[i]);
}
} else {
// std::cerr << "ObjectSpace::allocate: Fatal error: unsupported type!" << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("ObjectSpace::allocate: Fatal error: unsupported type!");
abort();
}
return po;
}
template <typename T>
Object *allocateObject(const std::vector<T> &o) {
return allocateObject(o.data(), o.size());
}
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
PersistentObject *allocatePersistentObject(Object &o) {
SharedMemoryAllocator &objectAllocator = getAllocator();
size_t cpsize = dimension;
if (type == typeid(uint8_t)) {
cpsize *= sizeof(uint8_t);
} else if (type == typeid(float)) {
cpsize *= sizeof(float);
} else {
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("ObjectSpace::allocate: Fatal error: unsupported type!");
// std::cerr << "ObjectSpace::allocate: Fatal error: unsupported type!" << std::endl;
abort();
}
PersistentObject *po = new (objectAllocator) PersistentObject(objectAllocator, paddedByteSize);
void *dsto = &(*po).at(0, allocator);
void *srco = &o[0];
memcpy(dsto, srco, cpsize);
return po;
}
template <typename T>
PersistentObject *allocatePersistentObject(T *o, size_t size) {
SharedMemoryAllocator &objectAllocator = getAllocator();
PersistentObject *po = new (objectAllocator) PersistentObject(objectAllocator, paddedByteSize);
if (size != 0 && dimension != size) {
std::stringstream msg;
msg << "ObjectSpace::allocatePersistentObject: Fatal error! The dimensionality is invalid. The specified dimensionality="
<< (sparse ? dimension - 1 : dimension) << ". The specified object=" << (sparse ? size - 1 : size) << ".";
NGTThrowException(msg);
}
void *object = static_cast<void*>(&(*po).at(0, allocator));
if (type == typeid(uint8_t)) {
uint8_t *obj = static_cast<uint8_t*>(object);
for (size_t i = 0; i < dimension; i++) {
obj[i] = static_cast<uint8_t>(o[i]);
}
} else if (type == typeid(float)) {
float *obj = static_cast<float*>(object);
for (size_t i = 0; i < dimension; i++) {
obj[i] = static_cast<float>(o[i]);
}
} else {
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("ObjectSpace::allocate: Fatal error: unsupported type!");
// std::cerr << "ObjectSpace::allocate: Fatal error: unsupported type!" << std::endl;
abort();
}
return po;
}
template <typename T>
PersistentObject *allocatePersistentObject(const std::vector<T> &o) {
return allocatePersistentObject(o.data(), o.size());
}
#else
template <typename T>
PersistentObject *allocatePersistentObject(T *o, size_t size) {
if (size != 0 && dimension != size) {
std::stringstream msg;
msg << "ObjectSpace::allocatePersistentObject: Fatal error! The dimensionality is invalid. The specified dimensionality="
<< (sparse ? dimension - 1 : dimension) << ". The specified object=" << (sparse ? size - 1 : size) << ".";
NGTThrowException(msg);
}
return allocateObject(o, size);
}
template <typename T>
PersistentObject *allocatePersistentObject(const std::vector<T> &o) {
return allocatePersistentObject(o.data(), o.size());
}
#endif
void deleteObject(Object *po) {
delete po;
}
private:
void extractObject(void *object, std::vector<double> &d) {
if (type == typeid(uint8_t)) {
uint8_t *obj = (uint8_t*)object;
for (size_t i = 0; i < dimension; i++) {
d.push_back(obj[i]);
}
} else if (type == typeid(float)) {
float *obj = (float*)object;
for (size_t i = 0; i < dimension; i++) {
d.push_back(obj[i]);
}
} else {
// std::cerr << "ObjectSpace::allocate: Fatal error: unsupported type!" << std::endl;
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("ObjectSpace::allocate: Fatal error: unsupported type!");
abort();
}
}
public:
void extractObject(Object *o, std::vector<double> &d) {
void *object = (void*)(&(*o)[0]);
extractObject(object, d);
}
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
void extractObject(PersistentObject *o, std::vector<double> &d) {
SharedMemoryAllocator &objectAllocator = getAllocator();
void *object = (void*)(&(*o).at(0, objectAllocator));
extractObject(object, d);
}
#endif
void setLength(size_t l) { byteSize = l; }
void setPaddedLength(size_t l) { paddedByteSize = l; }
void setSparse() { sparse = true; }
size_t getByteSize() { return byteSize; }
size_t insert(PersistentObject *obj) { return Parent::insert(obj); }
const size_t dimension;
const std::type_info &type;
protected:
size_t byteSize; // the length of all of elements.
size_t paddedByteSize;
bool sparse; // sparse data format
};
} // namespace NGT

View File

@ -1,496 +0,0 @@
//
// Copyright (C) 2015-2020 Yahoo Japan Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
#pragma once
#include <cstring>
#include "PrimitiveComparator.h"
class ObjectSpace;
namespace NGT {
class PersistentObjectDistances;
class ObjectDistances : public std::vector<ObjectDistance> {
public:
ObjectDistances(NGT::ObjectSpace *os = 0) {}
// for milvus
void serialize(std::stringstream & os, ObjectSpace * objspace = 0) { NGT::Serializer::write(os, (std::vector<ObjectDistance> &)*this); }
void serialize(std::ofstream &os, ObjectSpace *objspace = 0) { NGT::Serializer::write(os, (std::vector<ObjectDistance>&)*this);}
// for milvus
void deserialize(std::stringstream & is, ObjectSpace * objspace = 0)
{
NGT::Serializer::read(is, (std::vector<ObjectDistance> &)*this);
}
void deserialize(std::ifstream &is, ObjectSpace *objspace = 0) { NGT::Serializer::read(is, (std::vector<ObjectDistance>&)*this);}
void serializeAsText(std::ofstream &os, ObjectSpace *objspace = 0) {
NGT::Serializer::writeAsText(os, size());
os << " ";
for (size_t i = 0; i < size(); i++) {
(*this)[i].serializeAsText(os);
os << " ";
}
}
void deserializeAsText(std::ifstream &is, ObjectSpace *objspace = 0) {
size_t s;
NGT::Serializer::readAsText(is, s);
resize(s);
for (size_t i = 0; i < size(); i++) {
(*this)[i].deserializeAsText(is);
}
}
void moveFrom(std::priority_queue<ObjectDistance, std::vector<ObjectDistance>, std::less<ObjectDistance> > &pq) {
this->clear();
this->resize(pq.size());
for (int i = pq.size() - 1; i >= 0; i--) {
(*this)[i] = pq.top();
pq.pop();
}
assert(pq.size() == 0);
}
void moveFrom(std::priority_queue<ObjectDistance, std::vector<ObjectDistance>, std::less<ObjectDistance> > &pq, double (&f)(double)) {
this->clear();
this->resize(pq.size());
for (int i = pq.size() - 1; i >= 0; i--) {
(*this)[i] = pq.top();
(*this)[i].distance = f((*this)[i].distance);
pq.pop();
}
assert(pq.size() == 0);
}
void moveFrom(std::priority_queue<ObjectDistance, std::vector<ObjectDistance>, std::less<ObjectDistance> > &pq, unsigned int id) {
this->clear();
if (pq.size() == 0) {
return;
}
this->resize(id == 0 ? pq.size() : pq.size() - 1);
int i = this->size() - 1;
while (pq.size() != 0 && i >= 0) {
if (pq.top().id != id) {
(*this)[i] = pq.top();
i--;
}
pq.pop();
}
if (pq.size() != 0 && pq.top().id != id) {
std::cerr << "moveFrom: Fatal error: somethig wrong! " << pq.size() << ":" << this->size() << ":" << id << ":" << pq.top().id << std::endl;
assert(pq.size() == 0 || pq.top().id == id);
}
}
int64_t memSize() const {
// auto obj = (std::vector<ObjectDistance>)(*this);
if (this->size() == 0)
return 0;
else {
return (*this)[0].memSize() * this->size();
}
// return this->size() == 0 ? 0 : (*this)[0].memSize() * (this->size());
}
ObjectDistances &operator=(PersistentObjectDistances &objs);
};
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
class PersistentObjectDistances : public Vector<ObjectDistance> {
public:
PersistentObjectDistances(SharedMemoryAllocator &allocator, NGT::ObjectSpace *os = 0) {}
void serialize(std::ofstream &os, ObjectSpace *objectspace = 0) { NGT::Serializer::write(os, (Vector<ObjectDistance>&)*this); }
void deserialize(std::ifstream &is, ObjectSpace *objectspace = 0) { NGT::Serializer::read(is, (Vector<ObjectDistance>&)*this); }
void serializeAsText(std::ofstream &os, SharedMemoryAllocator &allocator, ObjectSpace *objspace = 0) {
NGT::Serializer::writeAsText(os, size());
os << " ";
for (size_t i = 0; i < size(); i++) {
(*this).at(i, allocator).serializeAsText(os);
os << " ";
}
}
void deserializeAsText(std::ifstream &is, SharedMemoryAllocator &allocator, ObjectSpace *objspace = 0) {
size_t s;
is >> s;
resize(s, allocator);
for (size_t i = 0; i < size(); i++) {
(*this).at(i, allocator).deserializeAsText(is);
}
}
PersistentObjectDistances &copy(ObjectDistances &objs, SharedMemoryAllocator &allocator) {
clear(allocator);
reserve(objs.size(), allocator);
for (ObjectDistances::iterator i = objs.begin(); i != objs.end(); i++) {
push_back(*i, allocator);
}
return *this;
}
};
typedef PersistentObjectDistances GraphNode;
inline ObjectDistances &ObjectDistances::operator=(PersistentObjectDistances &objs)
{
clear();
reserve(objs.size());
std::cerr << "not implemented" << std::endl;
assert(0);
return *this;
}
#else // NGT_SHARED_MEMORY_ALLOCATOR
typedef ObjectDistances GraphNode;
#endif // NGT_SHARED_MEMORY_ALLOCATOR
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
class PersistentObject;
#else
typedef Object PersistentObject;
#endif
class ObjectRepository;
class ObjectSpace {
public:
class Comparator {
public:
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
Comparator(size_t d, SharedMemoryAllocator &a) : dimension(d), allocator(a) {}
#else
Comparator(size_t d) : dimension(d) {}
#endif
virtual double operator()(Object &objecta, Object &objectb) = 0;
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
virtual double operator()(Object &objecta, PersistentObject &objectb) = 0;
virtual double operator()(PersistentObject &objecta, PersistentObject &objectb) = 0;
#endif
size_t dimension;
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
SharedMemoryAllocator &allocator;
#endif
virtual ~Comparator(){}
int64_t memSize() { return sizeof(size_t); }
};
enum DistanceType {
DistanceTypeNone = -1,
DistanceTypeL1 = 0,
DistanceTypeL2 = 1,
DistanceTypeHamming = 2,
DistanceTypeAngle = 3,
DistanceTypeCosine = 4,
DistanceTypeNormalizedAngle = 5,
DistanceTypeNormalizedCosine = 6,
DistanceTypeJaccard = 7,
DistanceTypeSparseJaccard = 8,
DistanceTypeIP = 9
};
enum ObjectType {
ObjectTypeNone = 0,
Uint8 = 1,
Float = 2
};
typedef std::priority_queue<ObjectDistance, std::vector<ObjectDistance>, std::less<ObjectDistance> > ResultSet;
ObjectSpace(size_t d):dimension(d), distanceType(DistanceTypeNone), comparator(0), normalization(false) {}
virtual ~ObjectSpace() { if (comparator != 0) { delete comparator; } }
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
virtual void open(const std::string &f, size_t shareMemorySize) = 0;
virtual Object *allocateObject(Object &o) = 0;
virtual Object *allocateObject(PersistentObject &o) = 0;
virtual PersistentObject *allocatePersistentObject(Object &obj) = 0;
virtual void deleteObject(PersistentObject *) = 0;
virtual void copy(PersistentObject &objecta, PersistentObject &objectb) = 0;
virtual void show(std::ostream &os, PersistentObject &object) = 0;
virtual size_t insert(PersistentObject *obj) = 0;
#else
virtual size_t insert(Object *obj) = 0;
#endif
Comparator &getComparator() { return *comparator; }
virtual void serialize(const std::string &of) = 0;
// for milvus
virtual void serialize(std::stringstream & obj) = 0;
// for milvus
virtual void deserialize(std::stringstream & obj) = 0;
virtual void deserialize(const std::string &ifile) = 0;
virtual void serializeAsText(const std::string &of) = 0;
virtual void deserializeAsText(const std::string &of) = 0;
//for milvus
virtual void readRawData(const float * raw_data, size_t dataSize) = 0;
virtual void readText(std::istream &is, size_t dataSize) = 0;
virtual void appendText(std::istream &is, size_t dataSize) = 0;
virtual void append(const float *data, size_t dataSize) = 0;
virtual void append(const double *data, size_t dataSize) = 0;
virtual void copy(Object &objecta, Object &objectb) = 0;
virtual void linearSearch(Object &query, double radius, size_t size,
ObjectSpace::ResultSet &results) = 0;
virtual const std::type_info &getObjectType() = 0;
virtual void show(std::ostream &os, Object &object) = 0;
virtual size_t getSize() = 0;
virtual size_t getSizeOfElement() = 0;
virtual size_t getByteSizeOfObject() = 0;
virtual Object *allocateNormalizedObject(const std::string &textLine, const std::string &sep) = 0;
virtual Object *allocateNormalizedObject(const std::vector<double> &obj) = 0;
virtual Object *allocateNormalizedObject(const std::vector<float> &obj) = 0;
virtual Object *allocateNormalizedObject(const std::vector<uint8_t> &obj) = 0;
virtual Object *allocateNormalizedObject(const float *obj, size_t size) = 0;
virtual PersistentObject *allocateNormalizedPersistentObject(const std::vector<double> &obj) = 0;
virtual PersistentObject *allocateNormalizedPersistentObject(const std::vector<float> &obj) = 0;
virtual void deleteObject(Object *po) = 0;
virtual Object *allocateObject() = 0;
virtual void remove(size_t id) = 0;
virtual ObjectRepository &getRepository() = 0;
virtual void setDistanceType(DistanceType t) = 0;
virtual DistanceType getDistanceType() = 0;
virtual void *getObject(size_t idx) = 0;
virtual void getObject(size_t idx, std::vector<float> &v) = 0;
virtual void getObjects(const std::vector<size_t> &idxs, std::vector<std::vector<float>> &vs) = 0;
size_t getDimension() { return dimension; }
size_t getPaddedDimension() { return ((dimension - 1) / 16 + 1) * 16; }
virtual int64_t memSize() { return sizeof(dimension) + sizeof(distanceType) + sizeof(prefetchOffset) * 2 + sizeof(normalization) + comparator->memSize(); };
template <typename T>
void normalize(T *data, size_t dim) {
double sum = 0.0;
for (size_t i = 0; i < dim; i++) {
sum += (double)data[i] * (double)data[i];
}
if (sum == 0.0) {
std::stringstream msg;
msg << "ObjectSpace::normalize: Error! the object is an invalid zero vector for the cosine similarity or angle distance.";
NGTThrowException(msg);
}
sum = sqrt(sum);
for (size_t i = 0; i < dim; i++) {
data[i] = (double)data[i] / sum;
}
}
uint32_t getPrefetchOffset() { return prefetchOffset; }
uint32_t setPrefetchOffset(size_t offset) {
if (offset == 0) {
prefetchOffset = floor(300.0 / (static_cast<float>(getPaddedDimension()) + 30.0) + 1.0);
} else {
prefetchOffset = offset;
}
return prefetchOffset;
}
uint32_t getPrefetchSize() { return prefetchSize; }
uint32_t setPrefetchSize(size_t size) {
if (size == 0) {
prefetchSize = getByteSizeOfObject();
} else {
prefetchSize = size;
}
return prefetchSize;
}
protected:
const size_t dimension;
DistanceType distanceType;
Comparator *comparator;
bool normalization;
uint32_t prefetchOffset;
uint32_t prefetchSize;
};
class BaseObject {
public:
virtual uint8_t &operator[](size_t idx) const = 0;
void serialize(std::ostream &os, ObjectSpace *objectspace = 0) {
assert(objectspace != 0);
if(objectspace == 0) return; // make compiler happy;
size_t byteSize = objectspace->getByteSizeOfObject();
NGT::Serializer::write(os, (uint8_t*)&(*this)[0], byteSize);
}
void deserialize(std::istream &is, ObjectSpace *objectspace = 0) {
assert(objectspace != 0);
if(objectspace == 0) return; // make compiler happy;
size_t byteSize = objectspace->getByteSizeOfObject();
assert(&(*this)[0] != 0);
NGT::Serializer::read(is, (uint8_t*)&(*this)[0], byteSize);
}
void serializeAsText(std::ostream &os, ObjectSpace *objectspace = 0) {
assert(objectspace != 0);
if(objectspace == 0) return; // make compiler happy;
const std::type_info &t = objectspace->getObjectType();
size_t dimension = objectspace->getDimension();
void *ref = (void*)&(*this)[0];
if (t == typeid(uint8_t)) {
NGT::Serializer::writeAsText(os, (uint8_t*)ref, dimension);
} else if (t == typeid(float)) {
NGT::Serializer::writeAsText(os, (float*)ref, dimension);
} else if (t == typeid(double)) {
NGT::Serializer::writeAsText(os, (double*)ref, dimension);
} else if (t == typeid(uint16_t)) {
NGT::Serializer::writeAsText(os, (uint16_t*)ref, dimension);
} else if (t == typeid(uint32_t)) {
NGT::Serializer::writeAsText(os, (uint32_t*)ref, dimension);
} else {
std::cerr << "Object::serializeAsText: not supported data type. [" << t.name() << "]" << std::endl;
assert(0);
}
}
void deserializeAsText(std::ifstream &is, ObjectSpace *objectspace = 0) {
assert(objectspace != 0);
if(objectspace == 0) return;
const std::type_info &t = objectspace->getObjectType();
size_t dimension = objectspace->getDimension();
void *ref = (void*)&(*this)[0];
assert(ref != 0);
if (t == typeid(uint8_t)) {
NGT::Serializer::readAsText(is, (uint8_t*)ref, dimension);
} else if (t == typeid(float)) {
NGT::Serializer::readAsText(is, (float*)ref, dimension);
} else if (t == typeid(double)) {
NGT::Serializer::readAsText(is, (double*)ref, dimension);
} else if (t == typeid(uint16_t)) {
NGT::Serializer::readAsText(is, (uint16_t*)ref, dimension);
} else if (t == typeid(uint32_t)) {
NGT::Serializer::readAsText(is, (uint32_t*)ref, dimension);
} else {
std::cerr << "Object::deserializeAsText: not supported data type. [" << t.name() << "]" << std::endl;
assert(0);
}
}
};
class Object : public BaseObject {
public:
Object(NGT::ObjectSpace *os = 0):vector(0) {
assert(os != 0);
if(os == 0) return;
size_t s = os->getByteSizeOfObject();
construct(s);
}
Object(size_t s):vector(0) {
assert(s != 0);
construct(s);
}
void copy(Object &o, size_t s) {
assert(vector != 0);
for (size_t i = 0; i < s; i++) {
vector[i] = o[i];
}
}
virtual ~Object() { clear(); }
uint8_t &operator[](size_t idx) const { return vector[idx]; }
void *getPointer(size_t idx = 0) const { return vector + idx; }
static Object *allocate(ObjectSpace &objectspace) { return new Object(&objectspace); }
virtual int64_t memSize() { return std::strlen((char*)vector); }
private:
void clear() {
if (vector != 0) {
MemoryCache::alignedFree(vector);
}
vector = 0;
}
void construct(size_t s) {
assert(vector == 0);
size_t allocsize = ((s - 1) / 64 + 1) * 64;
vector = static_cast<uint8_t*>(MemoryCache::alignedAlloc(allocsize));
memset(vector, 0, allocsize);
}
uint8_t* vector;
};
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
class PersistentObject : public BaseObject {
public:
PersistentObject(SharedMemoryAllocator &allocator, NGT::ObjectSpace *os = 0):array(0) {
assert(os != 0);
size_t s = os->getByteSizeOfObject();
construct(s, allocator);
}
PersistentObject(SharedMemoryAllocator &allocator, size_t s):array(0) {
assert(s != 0);
construct(s, allocator);
}
~PersistentObject() {}
uint8_t &at(size_t idx, SharedMemoryAllocator &allocator) const {
uint8_t *a = (uint8_t *)allocator.getAddr(array);
return a[idx];
}
uint8_t &operator[](size_t idx) const {
std::cerr << "not implemented" << std::endl;
assert(0);
uint8_t *a = 0;
return a[idx];
}
void *getPointer(size_t idx, SharedMemoryAllocator &allocator) {
uint8_t *a = (uint8_t *)allocator.getAddr(array);
return a + idx;
}
// set v in objectspace to this object using allocator.
void set(PersistentObject &po, ObjectSpace &objectspace);
static off_t allocate(ObjectSpace &objectspace);
void serializeAsText(std::ostream &os, SharedMemoryAllocator &allocator,
ObjectSpace *objectspace = 0) {
serializeAsText(os, objectspace);
}
void serializeAsText(std::ostream &os, ObjectSpace *objectspace = 0);
void deserializeAsText(std::ifstream &is, SharedMemoryAllocator &allocator,
ObjectSpace *objectspace = 0) {
deserializeAsText(is, objectspace);
}
void deserializeAsText(std::ifstream &is, ObjectSpace *objectspace = 0);
void serialize(std::ostream &os, SharedMemoryAllocator &allocator,
ObjectSpace *objectspace = 0) {
std::cerr << "serialize is not implemented" << std::endl;
assert(0);
}
private:
void construct(size_t s, SharedMemoryAllocator &allocator) {
assert(array == 0);
assert(s != 0);
size_t allocsize = ((s - 1) / 64 + 1) * 64;
array = allocator.getOffset(new(allocator) uint8_t[allocsize]);
memset(getPointer(0, allocator), 0, allocsize);
}
off_t array;
};
#endif // NGT_SHARED_MEMORY_ALLOCATOR
}

View File

@ -1,652 +0,0 @@
//
// Copyright (C) 2015-2020 Yahoo Japan Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
#pragma once
#include <sstream>
#include "Common.h"
#include "ObjectSpace.h"
#include "ObjectRepository.h"
#include "PrimitiveComparator.h"
class ObjectSpace;
namespace NGT {
template <typename OBJECT_TYPE, typename COMPARE_TYPE>
class ObjectSpaceRepository : public ObjectSpace, public ObjectRepository {
public:
class ComparatorL1 : public Comparator {
public:
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
ComparatorL1(size_t d, SharedMemoryAllocator &a) : Comparator(d, a) {}
double operator()(Object &objecta, Object &objectb) {
return PrimitiveComparator::compareL1((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb[0], dimension);
}
double operator()(Object &objecta, PersistentObject &objectb) {
return PrimitiveComparator::compareL1((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb.at(0, allocator), dimension);
}
double operator()(PersistentObject &objecta, PersistentObject &objectb) {
return PrimitiveComparator::compareL1((OBJECT_TYPE*)&objecta.at(0, allocator), (OBJECT_TYPE*)&objectb.at(0, allocator), dimension);
}
#else
ComparatorL1(size_t d) : Comparator(d) {}
double operator()(Object &objecta, Object &objectb) {
return PrimitiveComparator::compareL1((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb[0], dimension);
}
#endif
};
class ComparatorL2 : public Comparator {
public:
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
ComparatorL2(size_t d, SharedMemoryAllocator &a) : Comparator(d, a) {}
double operator()(Object &objecta, Object &objectb) {
return PrimitiveComparator::compareL2((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb[0], dimension);
}
double operator()(Object &objecta, PersistentObject &objectb) {
return PrimitiveComparator::compareL2((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb.at(0, allocator), dimension);
}
double operator()(PersistentObject &objecta, PersistentObject &objectb) {
return PrimitiveComparator::compareL2((OBJECT_TYPE*)&objecta.at(0, allocator), (OBJECT_TYPE*)&objectb.at(0, allocator), dimension);
}
#else
ComparatorL2(size_t d) : Comparator(d) {}
double operator()(Object &objecta, Object &objectb) {
return PrimitiveComparator::compareL2((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb[0], dimension);
}
#endif
};
class ComparatorIP : public Comparator {
public:
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
ComparatorIP(size_t d, SharedMemoryAllocator &a) : Comparator(d, a) {}
double operator()(Object &objecta, Object &objectb) {
return PrimitiveComparator::compareIP((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb[0], dimension);
}
double operator()(Object &objecta, PersistentObject &objectb) {
return PrimitiveComparator::compareIP((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb.at(0, allocator), dimension);
}
double operator()(PersistentObject &objecta, PersistentObject &objectb) {
return PrimitiveComparator::compareIP((OBJECT_TYPE*)&objecta.at(0, allocator), (OBJECT_TYPE*)&objectb.at(0, allocator), dimension);
}
#else
ComparatorIP(size_t d) : Comparator(d) {}
double operator()(Object &objecta, Object &objectb) {
return PrimitiveComparator::compareInnerProduct((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb[0], dimension);
}
#endif
};
class ComparatorHammingDistance : public Comparator {
public:
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
ComparatorHammingDistance(size_t d, SharedMemoryAllocator &a) : Comparator(d, a) {}
double operator()(Object &objecta, Object &objectb) {
return PrimitiveComparator::compareHammingDistance((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb[0], dimension);
}
double operator()(Object &objecta, PersistentObject &objectb) {
return PrimitiveComparator::compareHammingDistance((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb.at(0, allocator), dimension);
}
double operator()(PersistentObject &objecta, PersistentObject &objectb) {
return PrimitiveComparator::compareHammingDistance((OBJECT_TYPE*)&objecta.at(0, allocator), (OBJECT_TYPE*)&objectb.at(0, allocator), dimension);
}
#else
ComparatorHammingDistance(size_t d) : Comparator(d) {}
double operator()(Object &objecta, Object &objectb) {
return PrimitiveComparator::compareHammingDistance((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb[0], dimension);
}
#endif
};
class ComparatorJaccardDistance : public Comparator {
public:
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
ComparatorJaccardDistance(size_t d, SharedMemoryAllocator &a) : Comparator(d, a) {}
double operator()(Object &objecta, Object &objectb) {
return PrimitiveComparator::compareJaccardDistance((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb[0], dimension);
}
double operator()(Object &objecta, PersistentObject &objectb) {
return PrimitiveComparator::compareJaccardDistance((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb.at(0, allocator), dimension);
}
double operator()(PersistentObject &objecta, PersistentObject &objectb) {
return PrimitiveComparator::compareJaccardDistance((OBJECT_TYPE*)&objecta.at(0, allocator), (OBJECT_TYPE*)&objectb.at(0, allocator), dimension);
}
#else
ComparatorJaccardDistance(size_t d) : Comparator(d) {}
double operator()(Object &objecta, Object &objectb) {
return PrimitiveComparator::compareJaccardDistance((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb[0], dimension);
}
#endif
};
class ComparatorSparseJaccardDistance : public Comparator {
public:
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
ComparatorSparseJaccardDistance(size_t d, SharedMemoryAllocator &a) : Comparator(d, a) {}
double operator()(Object &objecta, Object &objectb) {
return PrimitiveComparator::compareSparseJaccardDistance((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb[0], dimension);
}
double operator()(Object &objecta, PersistentObject &objectb) {
return PrimitiveComparator::compareSparseJaccardDistance((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb.at(0, allocator), dimension);
}
double operator()(PersistentObject &objecta, PersistentObject &objectb) {
return PrimitiveComparator::compareSparseJaccardDistance((OBJECT_TYPE*)&objecta.at(0, allocator), (OBJECT_TYPE*)&objectb.at(0, allocator), dimension);
}
#else
ComparatorSparseJaccardDistance(size_t d) : Comparator(d) {}
double operator()(Object &objecta, Object &objectb) {
return PrimitiveComparator::compareSparseJaccardDistance((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb[0], dimension);
}
#endif
};
class ComparatorAngleDistance : public Comparator {
public:
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
ComparatorAngleDistance(size_t d, SharedMemoryAllocator &a) : Comparator(d, a) {}
double operator()(Object &objecta, Object &objectb) {
return PrimitiveComparator::compareAngleDistance((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb[0], dimension);
}
double operator()(Object &objecta, PersistentObject &objectb) {
return PrimitiveComparator::compareAngleDistance((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb.at(0, allocator), dimension);
}
double operator()(PersistentObject &objecta, PersistentObject &objectb) {
return PrimitiveComparator::compareAngleDistance((OBJECT_TYPE*)&objecta.at(0, allocator), (OBJECT_TYPE*)&objectb.at(0, allocator), dimension);
}
#else
ComparatorAngleDistance(size_t d) : Comparator(d) {}
double operator()(Object &objecta, Object &objectb) {
return PrimitiveComparator::compareAngleDistance((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb[0], dimension);
}
#endif
};
class ComparatorNormalizedAngleDistance : public Comparator {
public:
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
ComparatorNormalizedAngleDistance(size_t d, SharedMemoryAllocator &a) : Comparator(d, a) {}
double operator()(Object &objecta, Object &objectb) {
return PrimitiveComparator::compareNormalizedAngleDistance((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb[0], dimension);
}
double operator()(Object &objecta, PersistentObject &objectb) {
return PrimitiveComparator::compareNormalizedAngleDistance((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb.at(0, allocator), dimension);
}
double operator()(PersistentObject &objecta, PersistentObject &objectb) {
return PrimitiveComparator::compareNormalizedAngleDistance((OBJECT_TYPE*)&objecta.at(0, allocator), (OBJECT_TYPE*)&objectb.at(0, allocator), dimension);
}
#else
ComparatorNormalizedAngleDistance(size_t d) : Comparator(d) {}
double operator()(Object &objecta, Object &objectb) {
return PrimitiveComparator::compareNormalizedAngleDistance((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb[0], dimension);
}
#endif
};
class ComparatorCosineSimilarity : public Comparator {
public:
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
ComparatorCosineSimilarity(size_t d, SharedMemoryAllocator &a) : Comparator(d, a) {}
double operator()(Object &objecta, Object &objectb) {
return PrimitiveComparator::compareCosineSimilarity((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb[0], dimension);
}
double operator()(Object &objecta, PersistentObject &objectb) {
return PrimitiveComparator::compareCosineSimilarity((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb.at(0, allocator), dimension);
}
double operator()(PersistentObject &objecta, PersistentObject &objectb) {
return PrimitiveComparator::compareCosineSimilarity((OBJECT_TYPE*)&objecta.at(0, allocator), (OBJECT_TYPE*)&objectb.at(0, allocator), dimension);
}
#else
ComparatorCosineSimilarity(size_t d) : Comparator(d) {}
double operator()(Object &objecta, Object &objectb) {
return PrimitiveComparator::compareCosineSimilarity((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb[0], dimension);
}
#endif
};
class ComparatorNormalizedCosineSimilarity : public Comparator {
public:
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
ComparatorNormalizedCosineSimilarity(size_t d, SharedMemoryAllocator &a) : Comparator(d, a) {}
double operator()(Object &objecta, Object &objectb) {
return PrimitiveComparator::compareNormalizedCosineSimilarity((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb[0], dimension);
}
double operator()(Object &objecta, PersistentObject &objectb) {
return PrimitiveComparator::compareNormalizedCosineSimilarity((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb.at(0, allocator), dimension);
}
double operator()(PersistentObject &objecta, PersistentObject &objectb) {
return PrimitiveComparator::compareNormalizedCosineSimilarity((OBJECT_TYPE*)&objecta.at(0, allocator), (OBJECT_TYPE*)&objectb.at(0, allocator), dimension);
}
#else
ComparatorNormalizedCosineSimilarity(size_t d) : Comparator(d) {}
double operator()(Object &objecta, Object &objectb) {
return PrimitiveComparator::compareNormalizedCosineSimilarity((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb[0], dimension);
}
#endif
};
ObjectSpaceRepository(size_t d, const std::type_info &ot, DistanceType t) : ObjectSpace(d), ObjectRepository(d, ot) {
size_t objectSize = 0;
if (ot == typeid(uint8_t)) {
objectSize = sizeof(uint8_t);
} else if (ot == typeid(float)) {
objectSize = sizeof(float);
} else {
std::stringstream msg;
msg << "ObjectSpace::constructor: Not supported type. " << ot.name();
NGTThrowException(msg);
}
setLength(objectSize * d);
setPaddedLength(objectSize * ObjectSpace::getPaddedDimension());
setDistanceType(t);
}
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
void open(const std::string &f, size_t sharedMemorySize) { ObjectRepository::open(f, sharedMemorySize); }
void copy(PersistentObject &objecta, PersistentObject &objectb) { objecta = objectb; }
void show(std::ostream &os, PersistentObject &object) {
const std::type_info &t = getObjectType();
if (t == typeid(uint8_t)) {
unsigned char *optr = static_cast<unsigned char*>(&object.at(0,allocator));
for (size_t i = 0; i < getDimension(); i++) {
os << (int)optr[i] << " ";
}
} else if (t == typeid(float)) {
float *optr = reinterpret_cast<float*>(&object.at(0,allocator));
for (size_t i = 0; i < getDimension(); i++) {
os << optr[i] << " ";
}
} else {
os << " not implement for the type.";
}
}
Object *allocateObject(Object &o) {
Object *po = new Object(getByteSizeOfObject());
for (size_t i = 0; i < getByteSizeOfObject(); i++) {
(*po)[i] = o[i];
}
return po;
}
Object *allocateObject(PersistentObject &o) {
PersistentObject &spo = (PersistentObject &)o;
Object *po = new Object(getByteSizeOfObject());
for (size_t i = 0; i < getByteSizeOfObject(); i++) {
(*po)[i] = spo.at(i,ObjectRepository::allocator);
}
return (Object*)po;
}
void deleteObject(PersistentObject *po) {
delete po;
}
#endif // NGT_SHARED_MEMORY_ALLOCATOR
void copy(Object &objecta, Object &objectb) {
objecta.copy(objectb, getByteSizeOfObject());
}
DistanceType getDistanceType() { return distanceType; }
void setDistanceType(DistanceType t) {
if (comparator != 0) {
delete comparator;
}
assert(ObjectSpace::dimension != 0);
distanceType = t;
switch (distanceType) {
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
case DistanceTypeL1:
comparator = new ObjectSpaceRepository::ComparatorL1(ObjectSpace::getPaddedDimension(), ObjectRepository::allocator);
break;
case DistanceTypeL2:
comparator = new ObjectSpaceRepository::ComparatorL2(ObjectSpace::getPaddedDimension(), ObjectRepository::allocator);
break;
case DistanceTypeHamming:
comparator = new ObjectSpaceRepository::ComparatorHammingDistance(ObjectSpace::getPaddedDimension(), ObjectRepository::allocator);
break;
case DistanceTypeJaccard:
comparator = new ObjectSpaceRepository::ComparatorJaccardDistance(ObjectSpace::getPaddedDimension(), ObjectRepository::allocator);
break;
case DistanceTypeSparseJaccard:
comparator = new ObjectSpaceRepository::ComparatorSparseJaccardDistance(ObjectSpace::getPaddedDimension(), ObjectRepository::allocator);
setSparse();
break;
case DistanceTypeAngle:
comparator = new ObjectSpaceRepository::ComparatorAngleDistance(ObjectSpace::getPaddedDimension(), ObjectRepository::allocator);
break;
case DistanceTypeCosine:
comparator = new ObjectSpaceRepository::ComparatorCosineSimilarity(ObjectSpace::getPaddedDimension(), ObjectRepository::allocator);
break;
case DistanceTypeNormalizedAngle:
comparator = new ObjectSpaceRepository::ComparatorNormalizedAngleDistance(ObjectSpace::getPaddedDimension(), ObjectRepository::allocator);
normalization = true;
break;
case DistanceTypeNormalizedCosine:
comparator = new ObjectSpaceRepository::ComparatorNormalizedCosineSimilarity(ObjectSpace::getPaddedDimension(), ObjectRepository::allocator);
normalization = true;
break;
#else
case DistanceTypeL1:
comparator = new ObjectSpaceRepository::ComparatorL1(ObjectSpace::getPaddedDimension());
break;
case DistanceTypeL2:
comparator = new ObjectSpaceRepository::ComparatorL2(ObjectSpace::getPaddedDimension());
break;
case DistanceTypeIP:
comparator = new ObjectSpaceRepository::ComparatorIP(ObjectSpace::getPaddedDimension());
break;
case DistanceTypeHamming:
comparator = new ObjectSpaceRepository::ComparatorHammingDistance(ObjectSpace::getPaddedDimension());
break;
case DistanceTypeJaccard:
comparator = new ObjectSpaceRepository::ComparatorJaccardDistance(ObjectSpace::getPaddedDimension());
break;
case DistanceTypeSparseJaccard:
comparator = new ObjectSpaceRepository::ComparatorSparseJaccardDistance(ObjectSpace::getPaddedDimension());
setSparse();
break;
case DistanceTypeAngle:
comparator = new ObjectSpaceRepository::ComparatorAngleDistance(ObjectSpace::getPaddedDimension());
break;
case DistanceTypeCosine:
comparator = new ObjectSpaceRepository::ComparatorCosineSimilarity(ObjectSpace::getPaddedDimension());
break;
case DistanceTypeNormalizedAngle:
comparator = new ObjectSpaceRepository::ComparatorNormalizedAngleDistance(ObjectSpace::getPaddedDimension());
normalization = true;
break;
case DistanceTypeNormalizedCosine:
comparator = new ObjectSpaceRepository::ComparatorNormalizedCosineSimilarity(ObjectSpace::getPaddedDimension());
normalization = true;
break;
#endif
default:
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("Distance type is not specified");
// std::cerr << "Distance type is not specified" << std::endl;
assert(distanceType != DistanceTypeNone);
abort();
}
}
void serialize(const std::string & ofile) { ObjectRepository::serialize(ofile, this); }
// for milvus
void serialize(std::stringstream & obj) { ObjectRepository::serialize(obj, this); }
// for milvus
void deserialize(std::stringstream & obj) { ObjectRepository::deserialize(obj, this); }
void deserialize(const std::string &ifile) { ObjectRepository::deserialize(ifile, this); }
void serializeAsText(const std::string &ofile) { ObjectRepository::serializeAsText(ofile, this); }
void deserializeAsText(const std::string &ifile) { ObjectRepository::deserializeAsText(ifile, this); }
// For milvus
void readRawData(const float * raw_data, size_t dataSize) { ObjectRepository::readRawData<float>(raw_data, dataSize); }
void readText(std::istream &is, size_t dataSize) { ObjectRepository::readText(is, dataSize); }
void appendText(std::istream &is, size_t dataSize) { ObjectRepository::appendText(is, dataSize); }
void append(const float *data, size_t dataSize) { ObjectRepository::append(data, dataSize); }
void append(const double *data, size_t dataSize) { ObjectRepository::append(data, dataSize); }
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
PersistentObject *allocatePersistentObject(Object &obj) {
return ObjectRepository::allocatePersistentObject(obj);
}
size_t insert(PersistentObject *obj) { return ObjectRepository::insert(obj); }
#else
size_t insert(Object *obj) { return ObjectRepository::insert(obj); }
#endif
void remove(size_t id) { ObjectRepository::remove(id); }
void linearSearch(Object &query, double radius, size_t size, ObjectSpace::ResultSet &results) {
if (!results.empty()) {
NGTThrowException("lenearSearch: results is not empty");
}
#ifndef NGT_PREFETCH_DISABLED
size_t byteSizeOfObject = getByteSizeOfObject();
const size_t prefetchOffset = getPrefetchOffset();
#endif
ObjectRepository &rep = *this;
for (size_t idx = 0; idx < rep.size(); idx++) {
#ifndef NGT_PREFETCH_DISABLED
if (idx + prefetchOffset < rep.size() && rep[idx + prefetchOffset] != 0) {
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
MemoryCache::prefetch((unsigned char*)&(*static_cast<PersistentObject*>(ObjectRepository::get(idx + prefetchOffset))), byteSizeOfObject);
#else
MemoryCache::prefetch((unsigned char*)&(*static_cast<PersistentObject*>(rep[idx + prefetchOffset]))[0], byteSizeOfObject);
#endif
}
#endif
if (rep[idx] == 0) {
continue;
}
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
Distance d = (*comparator)((Object&)query, (PersistentObject&)*rep[idx]);
#else
Distance d = (*comparator)((Object&)query, (Object&)*rep[idx]);
#endif
if (radius < 0.0 || d <= radius) {
NGT::ObjectDistance obj(idx, d);
results.push(obj);
if (results.size() > size) {
results.pop();
}
}
}
return;
}
void *getObject(size_t idx) {
if (isEmpty(idx)) {
std::stringstream msg;
msg << "NGT::ObjectSpaceRepository: The specified ID is out of the range. The object ID should be greater than zero. " << idx << ":" << ObjectRepository::size() << ".";
NGTThrowException(msg);
}
PersistentObject &obj = *(*this)[idx];
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
return reinterpret_cast<OBJECT_TYPE*>(&obj.at(0, allocator));
#else
return reinterpret_cast<OBJECT_TYPE*>(&obj[0]);
#endif
}
void getObject(size_t idx, std::vector<float> &v) {
OBJECT_TYPE *obj = static_cast<OBJECT_TYPE*>(getObject(idx));
size_t dim = getDimension();
v.resize(dim);
for (size_t i = 0; i < dim; i++) {
v[i] = static_cast<float>(obj[i]);
}
}
void getObjects(const std::vector<size_t> &idxs, std::vector<std::vector<float>> &vs) {
vs.resize(idxs.size());
auto v = vs.begin();
for (auto idx = idxs.begin(); idx != idxs.end(); idx++, v++) {
getObject(*idx, *v);
}
}
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
void normalize(PersistentObject &object) {
OBJECT_TYPE *obj = (OBJECT_TYPE*)&object.at(0, getRepository().getAllocator());
ObjectSpace::normalize(obj, ObjectSpace::dimension);
}
#endif
void normalize(Object &object) {
OBJECT_TYPE *obj = (OBJECT_TYPE*)&object[0];
ObjectSpace::normalize(obj, ObjectSpace::dimension);
}
Object *allocateObject() { return ObjectRepository::allocateObject(); }
void deleteObject(Object *po) { ObjectRepository::deleteObject(po); }
Object *allocateNormalizedObject(const std::string &textLine, const std::string &sep) {
Object *allocatedObject = ObjectRepository::allocateObject(textLine, sep);
if (normalization) {
normalize(*allocatedObject);
}
return allocatedObject;
}
Object *allocateNormalizedObject(const std::vector<double> &obj) {
Object *allocatedObject = ObjectRepository::allocateObject(obj);
if (normalization) {
normalize(*allocatedObject);
}
return allocatedObject;
}
Object *allocateNormalizedObject(const std::vector<float> &obj) {
Object *allocatedObject = ObjectRepository::allocateObject(obj);
if (normalization) {
normalize(*allocatedObject);
}
return allocatedObject;
}
Object *allocateNormalizedObject(const std::vector<uint8_t> &obj) {
Object *allocatedObject = ObjectRepository::allocateObject(obj);
if (normalization) {
normalize(*allocatedObject);
}
return allocatedObject;
}
Object *allocateNormalizedObject(const float *obj, size_t size) {
Object *allocatedObject = ObjectRepository::allocateObject(obj, size);
if (normalization) {
normalize(*allocatedObject);
}
return allocatedObject;
}
PersistentObject *allocateNormalizedPersistentObject(const std::vector<double> &obj) {
PersistentObject *allocatedObject = ObjectRepository::allocatePersistentObject(obj);
if (normalization) {
normalize(*allocatedObject);
}
return allocatedObject;
}
PersistentObject *allocateNormalizedPersistentObject(const std::vector<float> &obj) {
PersistentObject *allocatedObject = ObjectRepository::allocatePersistentObject(obj);
if (normalization) {
normalize(*allocatedObject);
}
return allocatedObject;
}
PersistentObject *allocateNormalizedPersistentObject(const std::vector<uint8_t> &obj) {
PersistentObject *allocatedObject = ObjectRepository::allocatePersistentObject(obj);
if (normalization) {
normalize(*allocatedObject);
}
return allocatedObject;
}
size_t getSize() { return ObjectRepository::size(); }
size_t getSizeOfElement() { return sizeof(OBJECT_TYPE); }
const std::type_info &getObjectType() { return typeid(OBJECT_TYPE); };
size_t getByteSizeOfObject() { return getByteSize(); }
ObjectRepository &getRepository() { return *this; };
void show(std::ostream &os, Object &object) {
const std::type_info &t = getObjectType();
if (t == typeid(uint8_t)) {
unsigned char *optr = static_cast<unsigned char*>(&object[0]);
for (size_t i = 0; i < getDimension(); i++) {
os << (int)optr[i] << " ";
}
} else if (t == typeid(float)) {
float *optr = reinterpret_cast<float*>(&object[0]);
for (size_t i = 0; i < getDimension(); i++) {
os << optr[i] << " ";
}
} else {
os << " not implement for the type.";
}
}
};
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
// set v in objectspace to this object using allocator.
inline void PersistentObject::set(PersistentObject &po, ObjectSpace &objectspace) {
SharedMemoryAllocator &allocator = objectspace.getRepository().getAllocator();
uint8_t *src = (uint8_t *)&po.at(0, allocator);
uint8_t *dst = (uint8_t *)&(*this).at(0, allocator);
memcpy(dst, src, objectspace.getByteSizeOfObject());
}
inline off_t PersistentObject::allocate(ObjectSpace &objectspace) {
SharedMemoryAllocator &allocator = objectspace.getRepository().getAllocator();
return allocator.getOffset(new(allocator) PersistentObject(allocator, &objectspace));
}
inline void PersistentObject::serializeAsText(std::ostream &os, ObjectSpace *objectspace) {
assert(objectspace != 0);
SharedMemoryAllocator &allocator = objectspace->getRepository().getAllocator();
const std::type_info &t = objectspace->getObjectType();
void *ref = &(*this).at(0, allocator);
size_t dimension = objectspace->getDimension();
if (t == typeid(uint8_t)) {
NGT::Serializer::writeAsText(os, (uint8_t*)ref, dimension);
} else if (t == typeid(float)) {
NGT::Serializer::writeAsText(os, (float*)ref, dimension);
} else if (t == typeid(double)) {
NGT::Serializer::writeAsText(os, (double*)ref, dimension);
} else if (t == typeid(uint16_t)) {
NGT::Serializer::writeAsText(os, (uint16_t*)ref, dimension);
} else if (t == typeid(uint32_t)) {
NGT::Serializer::writeAsText(os, (uint32_t*)ref, dimension);
} else {
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("ObjectT::serializeAsText: not supported data type. [" + std::to_string(t.name()) + "]");
// std::cerr << "ObjectT::serializeAsText: not supported data type. [" << t.name() << "]" << std::endl;
assert(0);
}
}
inline void PersistentObject::deserializeAsText(std::ifstream &is, ObjectSpace *objectspace) {
assert(objectspace != 0);
SharedMemoryAllocator &allocator = objectspace->getRepository().getAllocator();
const std::type_info &t = objectspace->getObjectType();
size_t dimension = objectspace->getDimension();
void *ref = &(*this).at(0, allocator);
assert(ref != 0);
if (t == typeid(uint8_t)) {
NGT::Serializer::readAsText(is, (uint8_t*)ref, dimension);
} else if (t == typeid(float)) {
NGT::Serializer::readAsText(is, (float*)ref, dimension);
} else if (t == typeid(double)) {
NGT::Serializer::readAsText(is, (double*)ref, dimension);
} else if (t == typeid(uint16_t)) {
NGT::Serializer::readAsText(is, (uint16_t*)ref, dimension);
} else if (t == typeid(uint32_t)) {
NGT::Serializer::readAsText(is, (uint32_t*)ref, dimension);
} else {
if (NGT_LOG_DEBUG_)
(*NGT_LOG_DEBUG_)("Object::deserializeAsText: not supported data type. [" + std::to_string(t.name()) + "]");
// std::cerr << "Object::deserializeAsText: not supported data type. [" << t.name() << "]" << std::endl;
assert(0);
}
}
#endif
} // namespace NGT

File diff suppressed because it is too large Load Diff

View File

@ -1,840 +0,0 @@
//
// Copyright (C) 2015-2020 Yahoo Japan Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
#pragma once
#include "NGT/defines.h"
#if defined(NGT_NO_AVX)
// #warning "*** SIMD is *NOT* available! ***"
#else
#include <immintrin.h>
#endif
namespace NGT {
class MemoryCache {
public:
inline static void
prefetch(unsigned char* ptr, const size_t byteSizeOfObject) {
#if !defined(NGT_NO_AVX)
switch ((byteSizeOfObject - 1) >> 6) {
default:
case 28:
_mm_prefetch(ptr, _MM_HINT_T0);
ptr += 64;
case 27:
_mm_prefetch(ptr, _MM_HINT_T0);
ptr += 64;
case 26:
_mm_prefetch(ptr, _MM_HINT_T0);
ptr += 64;
case 25:
_mm_prefetch(ptr, _MM_HINT_T0);
ptr += 64;
case 24:
_mm_prefetch(ptr, _MM_HINT_T0);
ptr += 64;
case 23:
_mm_prefetch(ptr, _MM_HINT_T0);
ptr += 64;
case 22:
_mm_prefetch(ptr, _MM_HINT_T0);
ptr += 64;
case 21:
_mm_prefetch(ptr, _MM_HINT_T0);
ptr += 64;
case 20:
_mm_prefetch(ptr, _MM_HINT_T0);
ptr += 64;
case 19:
_mm_prefetch(ptr, _MM_HINT_T0);
ptr += 64;
case 18:
_mm_prefetch(ptr, _MM_HINT_T0);
ptr += 64;
case 17:
_mm_prefetch(ptr, _MM_HINT_T0);
ptr += 64;
case 16:
_mm_prefetch(ptr, _MM_HINT_T0);
ptr += 64;
case 15:
_mm_prefetch(ptr, _MM_HINT_T0);
ptr += 64;
case 14:
_mm_prefetch(ptr, _MM_HINT_T0);
ptr += 64;
case 13:
_mm_prefetch(ptr, _MM_HINT_T0);
ptr += 64;
case 12:
_mm_prefetch(ptr, _MM_HINT_T0);
ptr += 64;
case 11:
_mm_prefetch(ptr, _MM_HINT_T0);
ptr += 64;
case 10:
_mm_prefetch(ptr, _MM_HINT_T0);
ptr += 64;
case 9:
_mm_prefetch(ptr, _MM_HINT_T0);
ptr += 64;
case 8:
_mm_prefetch(ptr, _MM_HINT_T0);
ptr += 64;
case 7:
_mm_prefetch(ptr, _MM_HINT_T0);
ptr += 64;
case 6:
_mm_prefetch(ptr, _MM_HINT_T0);
ptr += 64;
case 5:
_mm_prefetch(ptr, _MM_HINT_T0);
ptr += 64;
case 4:
_mm_prefetch(ptr, _MM_HINT_T0);
ptr += 64;
case 3:
_mm_prefetch(ptr, _MM_HINT_T0);
ptr += 64;
case 2:
_mm_prefetch(ptr, _MM_HINT_T0);
ptr += 64;
case 1:
_mm_prefetch(ptr, _MM_HINT_T0);
ptr += 64;
case 0:
_mm_prefetch(ptr, _MM_HINT_T0);
ptr += 64;
break;
}
#endif
}
inline static void*
alignedAlloc(const size_t allocSize) {
#ifdef NGT_NO_AVX
return new uint8_t[allocSize];
#else
#if defined(NGT_AVX512)
size_t alignment = 64;
uint64_t mask = 0xFFFFFFFFFFFFFFC0;
#elif defined(NGT_AVX2)
size_t alignment = 32;
uint64_t mask = 0xFFFFFFFFFFFFFFE0;
#else
size_t alignment = 16;
uint64_t mask = 0xFFFFFFFFFFFFFFF0;
#endif
uint8_t* p = new uint8_t[allocSize + alignment];
uint8_t* ptr = p + alignment;
ptr = reinterpret_cast<uint8_t*>((reinterpret_cast<uint64_t>(ptr) & mask));
*p++ = 0xAB;
while (p != ptr) *p++ = 0xCD;
return ptr;
#endif
}
inline static void
alignedFree(void* ptr) {
#ifdef NGT_NO_AVX
delete[] static_cast<uint8_t*>(ptr);
#else
uint8_t* p = static_cast<uint8_t*>(ptr);
p--;
while (*p == 0xCD) p--;
if (*p != 0xAB) {
NGTThrowException("MemoryCache::alignedFree: Fatal Error! Cannot find allocated address.");
}
delete[] p;
#endif
}
};
class PrimitiveComparator {
public:
static double
absolute(double v) {
return fabs(v);
}
static int
absolute(int v) {
return abs(v);
}
#if defined(NGT_NO_AVX)
template <typename OBJECT_TYPE, typename COMPARE_TYPE>
inline static double
compareL2(const OBJECT_TYPE* a, const OBJECT_TYPE* b, size_t size) {
const OBJECT_TYPE* last = a + size;
const OBJECT_TYPE* lastgroup = last - 3;
COMPARE_TYPE diff0, diff1, diff2, diff3;
double d = 0.0;
while (a < lastgroup) {
diff0 = (COMPARE_TYPE)(a[0] - b[0]);
diff1 = (COMPARE_TYPE)(a[1] - b[1]);
diff2 = (COMPARE_TYPE)(a[2] - b[2]);
diff3 = (COMPARE_TYPE)(a[3] - b[3]);
d += diff0 * diff0 + diff1 * diff1 + diff2 * diff2 + diff3 * diff3;
a += 4;
b += 4;
}
while (a < last) {
diff0 = (COMPARE_TYPE)(*a++ - *b++);
d += diff0 * diff0;
}
// return sqrt((double)d);
return d;
}
inline static double
compareL2(const uint8_t* a, const uint8_t* b, size_t size) {
return compareL2<uint8_t, int>(a, b, size);
}
inline static double
compareL2(const float* a, const float* b, size_t size) {
return compareL2<float, double>(a, b, size);
}
#else
inline static double
compareL2(const float* a, const float* b, size_t size) {
const float* last = a + size;
#if defined(NGT_AVX512)
__m512 sum512 = _mm512_setzero_ps();
while (a < last) {
__m512 v = _mm512_sub_ps(_mm512_loadu_ps(a), _mm512_loadu_ps(b));
sum512 = _mm512_add_ps(sum512, _mm512_mul_ps(v, v));
a += 16;
b += 16;
}
__m256 sum256 = _mm256_add_ps(_mm512_extractf32x8_ps(sum512, 0), _mm512_extractf32x8_ps(sum512, 1));
__m128 sum128 = _mm_add_ps(_mm256_extractf128_ps(sum256, 0), _mm256_extractf128_ps(sum256, 1));
#elif defined(NGT_AVX2)
__m256 sum256 = _mm256_setzero_ps();
__m256 v;
while (a < last) {
v = _mm256_sub_ps(_mm256_loadu_ps(a), _mm256_loadu_ps(b));
sum256 = _mm256_add_ps(sum256, _mm256_mul_ps(v, v));
a += 8;
b += 8;
v = _mm256_sub_ps(_mm256_loadu_ps(a), _mm256_loadu_ps(b));
sum256 = _mm256_add_ps(sum256, _mm256_mul_ps(v, v));
a += 8;
b += 8;
}
__m128 sum128 = _mm_add_ps(_mm256_extractf128_ps(sum256, 0), _mm256_extractf128_ps(sum256, 1));
#else
__m128 sum128 = _mm_setzero_ps();
__m128 v;
while (a < last) {
v = _mm_sub_ps(_mm_loadu_ps(a), _mm_loadu_ps(b));
sum128 = _mm_add_ps(sum128, _mm_mul_ps(v, v));
a += 4;
b += 4;
v = _mm_sub_ps(_mm_loadu_ps(a), _mm_loadu_ps(b));
sum128 = _mm_add_ps(sum128, _mm_mul_ps(v, v));
a += 4;
b += 4;
v = _mm_sub_ps(_mm_loadu_ps(a), _mm_loadu_ps(b));
sum128 = _mm_add_ps(sum128, _mm_mul_ps(v, v));
a += 4;
b += 4;
v = _mm_sub_ps(_mm_loadu_ps(a), _mm_loadu_ps(b));
sum128 = _mm_add_ps(sum128, _mm_mul_ps(v, v));
a += 4;
b += 4;
}
#endif
__attribute__((aligned(32))) float f[4];
_mm_store_ps(f, sum128);
double s = f[0] + f[1] + f[2] + f[3];
// return sqrt(s);
return s;
}
inline static double
compareL2(const unsigned char* a, const unsigned char* b, size_t size) {
__m128 sum = _mm_setzero_ps();
const unsigned char* last = a + size;
const unsigned char* lastgroup = last - 7;
const __m128i zero = _mm_setzero_si128();
while (a < lastgroup) {
__m128i x1 = _mm_cvtepu8_epi16(_mm_loadu_si128((__m128i const*)a));
__m128i x2 = _mm_cvtepu8_epi16(_mm_loadu_si128((__m128i const*)b));
x1 = _mm_subs_epi16(x1, x2);
__m128i v = _mm_mullo_epi16(x1, x1);
sum = _mm_add_ps(sum, _mm_cvtepi32_ps(_mm_unpacklo_epi16(v, zero)));
sum = _mm_add_ps(sum, _mm_cvtepi32_ps(_mm_unpackhi_epi16(v, zero)));
a += 8;
b += 8;
}
__attribute__((aligned(32))) float f[4];
_mm_store_ps(f, sum);
double s = f[0] + f[1] + f[2] + f[3];
while (a < last) {
int d = (int)*a++ - (int)*b++;
s += d * d;
}
// return sqrt(s);
return s;
}
#endif
#if defined(NGT_NO_AVX)
template <typename OBJECT_TYPE, typename COMPARE_TYPE>
static double
compareL1(const OBJECT_TYPE* a, const OBJECT_TYPE* b, size_t size) {
const OBJECT_TYPE* last = a + size;
const OBJECT_TYPE* lastgroup = last - 3;
COMPARE_TYPE diff0, diff1, diff2, diff3;
double d = 0.0;
while (a < lastgroup) {
diff0 = (COMPARE_TYPE)(a[0] - b[0]);
diff1 = (COMPARE_TYPE)(a[1] - b[1]);
diff2 = (COMPARE_TYPE)(a[2] - b[2]);
diff3 = (COMPARE_TYPE)(a[3] - b[3]);
d += absolute(diff0) + absolute(diff1) + absolute(diff2) + absolute(diff3);
a += 4;
b += 4;
}
while (a < last) {
diff0 = (COMPARE_TYPE)*a++ - (COMPARE_TYPE)*b++;
d += absolute(diff0);
}
return d;
}
inline static double
compareL1(const uint8_t* a, const uint8_t* b, size_t size) {
return compareL1<uint8_t, int>(a, b, size);
}
inline static double
compareL1(const float* a, const float* b, size_t size) {
return compareL1<float, double>(a, b, size);
}
#else
inline static double
compareL1(const float* a, const float* b, size_t size) {
__m256 sum = _mm256_setzero_ps();
const float* last = a + size;
const float* lastgroup = last - 7;
while (a < lastgroup) {
__m256 x1 = _mm256_sub_ps(_mm256_loadu_ps(a), _mm256_loadu_ps(b));
const __m256 mask = _mm256_set1_ps(-0.0f);
__m256 v = _mm256_andnot_ps(mask, x1);
sum = _mm256_add_ps(sum, v);
a += 8;
b += 8;
}
__attribute__((aligned(32))) float f[8];
_mm256_store_ps(f, sum);
double s = f[0] + f[1] + f[2] + f[3] + f[4] + f[5] + f[6] + f[7];
while (a < last) {
double d = fabs(*a++ - *b++);
s += d;
}
return s;
}
inline static double
compareL1(const unsigned char* a, const unsigned char* b, size_t size) {
__m128 sum = _mm_setzero_ps();
const unsigned char* last = a + size;
const unsigned char* lastgroup = last - 7;
const __m128i zero = _mm_setzero_si128();
while (a < lastgroup) {
__m128i x1 = _mm_cvtepu8_epi16(_mm_loadu_si128((__m128i const*)a));
__m128i x2 = _mm_cvtepu8_epi16(_mm_loadu_si128((__m128i const*)b));
x1 = _mm_subs_epi16(x1, x2);
x1 = _mm_sign_epi16(x1, x1);
sum = _mm_add_ps(sum, _mm_cvtepi32_ps(_mm_unpacklo_epi16(x1, zero)));
sum = _mm_add_ps(sum, _mm_cvtepi32_ps(_mm_unpackhi_epi16(x1, zero)));
a += 8;
b += 8;
}
__attribute__((aligned(32))) float f[4];
_mm_store_ps(f, sum);
double s = f[0] + f[1] + f[2] + f[3];
while (a < last) {
double d = fabs((double)*a++ - (double)*b++);
s += d;
}
return s;
}
#endif
#if defined(NGT_NO_AVX) || !defined(__POPCNT__)
inline static double
popCount(uint32_t x) {
x = (x & 0x55555555) + (x >> 1 & 0x55555555);
x = (x & 0x33333333) + (x >> 2 & 0x33333333);
x = (x & 0x0F0F0F0F) + (x >> 4 & 0x0F0F0F0F);
x = (x & 0x00FF00FF) + (x >> 8 & 0x00FF00FF);
x = (x & 0x0000FFFF) + (x >> 16 & 0x0000FFFF);
return x;
}
template <typename OBJECT_TYPE>
inline static double
compareHammingDistance(const OBJECT_TYPE* a, const OBJECT_TYPE* b, size_t size) {
const uint32_t* last = reinterpret_cast<const uint32_t*>(a + size);
const uint32_t* uinta = reinterpret_cast<const uint32_t*>(a);
const uint32_t* uintb = reinterpret_cast<const uint32_t*>(b);
size_t count = 0;
while (uinta < last) {
count += popCount(*uinta++ ^ *uintb++);
}
return static_cast<double>(count);
}
#else
template <typename OBJECT_TYPE>
inline static double
compareHammingDistance(const OBJECT_TYPE* a, const OBJECT_TYPE* b, size_t size) {
const uint64_t* last = reinterpret_cast<const uint64_t*>(a + size);
const uint64_t* uinta = reinterpret_cast<const uint64_t*>(a);
const uint64_t* uintb = reinterpret_cast<const uint64_t*>(b);
size_t count = 0;
while (uinta < last) {
count += _mm_popcnt_u64(*uinta++ ^ *uintb++);
count += _mm_popcnt_u64(*uinta++ ^ *uintb++);
}
return static_cast<double>(count);
}
#endif
#if defined(NGT_NO_AVX) || !defined(__POPCNT__)
template <typename OBJECT_TYPE>
inline static double
compareJaccardDistance(const OBJECT_TYPE* a, const OBJECT_TYPE* b, size_t size) {
const uint32_t* last = reinterpret_cast<const uint32_t*>(a + size);
const uint32_t* uinta = reinterpret_cast<const uint32_t*>(a);
const uint32_t* uintb = reinterpret_cast<const uint32_t*>(b);
size_t count = 0;
size_t countDe = 0;
while (uinta < last) {
count += popCount(*uinta & *uintb);
countDe += popCount(*uinta++ | *uintb++);
count += popCount(*uinta & *uintb);
countDe += popCount(*uinta++ | *uintb++);
}
return 1.0 - static_cast<double>(count) / static_cast<double>(countDe);
}
#else
template <typename OBJECT_TYPE>
inline static double
compareJaccardDistance(const OBJECT_TYPE* a, const OBJECT_TYPE* b, size_t size) {
const uint64_t* last = reinterpret_cast<const uint64_t*>(a + size);
const uint64_t* uinta = reinterpret_cast<const uint64_t*>(a);
const uint64_t* uintb = reinterpret_cast<const uint64_t*>(b);
size_t count = 0;
size_t countDe = 0;
while (uinta < last) {
count += _mm_popcnt_u64(*uinta & *uintb);
countDe += _mm_popcnt_u64(*uinta++ | *uintb++);
count += _mm_popcnt_u64(*uinta & *uintb);
countDe += _mm_popcnt_u64(*uinta++ | *uintb++);
}
return 1.0 - static_cast<double>(count) / static_cast<double>(countDe);
}
#endif
inline static double
compareSparseJaccardDistance(const unsigned char* a, unsigned char* b, size_t size) {
abort();
}
inline static double
compareSparseJaccardDistance(const float* a, const float* b, size_t size) {
size_t loca = 0;
size_t locb = 0;
const uint32_t* ai = reinterpret_cast<const uint32_t*>(a);
const uint32_t* bi = reinterpret_cast<const uint32_t*>(b);
size_t count = 0;
while (locb < size && ai[loca] != 0 && bi[loca] != 0) {
int64_t sub = static_cast<int64_t>(ai[loca]) - static_cast<int64_t>(bi[locb]);
count += sub == 0;
loca += sub <= 0;
locb += sub >= 0;
}
while (ai[loca] != 0) {
loca++;
}
while (locb < size && bi[locb] != 0) {
locb++;
}
return 1.0 - static_cast<double>(count) / static_cast<double>(loca + locb - count);
}
#if defined(NGT_NO_AVX)
template <typename OBJECT_TYPE>
inline static double
compareDotProduct(const OBJECT_TYPE* a, const OBJECT_TYPE* b, size_t size) {
double sum = 0.0;
for (size_t loc = 0; loc < size; loc++) {
sum += (double)a[loc] * (double)b[loc];
}
return sum;
}
template <typename OBJECT_TYPE>
inline static double
compareInnerProduct(const OBJECT_TYPE* a, const OBJECT_TYPE* b, size_t size) {
double sum = 0.0;
for (size_t loc = 0; loc < size; loc++) {
sum += (double)a[loc] * (double)b[loc];
// sum += a[loc] * b[loc];
}
return -sum;
}
template <typename OBJECT_TYPE>
inline static double
compareCosine(const OBJECT_TYPE* a, const OBJECT_TYPE* b, size_t size) {
double normA = 0.0;
double normB = 0.0;
double sum = 0.0;
for (size_t loc = 0; loc < size; loc++) {
normA += (double)a[loc] * (double)a[loc];
normB += (double)b[loc] * (double)b[loc];
sum += (double)a[loc] * (double)b[loc];
}
double cosine = sum / sqrt(normA * normB);
return cosine;
}
#else
inline static double
compareDotProduct(const float* a, const float* b, size_t size) {
const float* last = a + size;
#if defined(NGT_AVX512)
__m512 sum512 = _mm512_setzero_ps();
while (a < last) {
sum512 = _mm512_add_ps(sum512, _mm512_mul_ps(_mm512_loadu_ps(a), _mm512_loadu_ps(b)));
a += 16;
b += 16;
}
__m256 sum256 = _mm256_add_ps(_mm512_extractf32x8_ps(sum512, 0), _mm512_extractf32x8_ps(sum512, 1));
__m128 sum128 = _mm_add_ps(_mm256_extractf128_ps(sum256, 0), _mm256_extractf128_ps(sum256, 1));
#elif defined(NGT_AVX2)
__m256 sum256 = _mm256_setzero_ps();
while (a < last) {
sum256 = _mm256_add_ps(sum256, _mm256_mul_ps(_mm256_loadu_ps(a), _mm256_loadu_ps(b)));
a += 8;
b += 8;
}
__m128 sum128 = _mm_add_ps(_mm256_extractf128_ps(sum256, 0), _mm256_extractf128_ps(sum256, 1));
#else
__m128 sum128 = _mm_setzero_ps();
while (a < last) {
sum128 = _mm_add_ps(sum128, _mm_mul_ps(_mm_loadu_ps(a), _mm_loadu_ps(b)));
a += 4;
b += 4;
}
#endif
__attribute__((aligned(32))) float f[4];
_mm_store_ps(f, sum128);
double s = f[0] + f[1] + f[2] + f[3];
return s;
}
inline static double
compareInnerProduct(const float* a, const float* b, size_t size) {
const float* last = a + size;
#if defined(NGT_AVX512)
__m512 sum512 = _mm512_setzero_ps();
while (a < last) {
sum512 = _mm512_add_ps(sum512, _mm512_mul_ps(_mm512_loadu_ps(a), _mm512_loadu_ps(b)));
a += 16;
b += 16;
}
__m256 sum256 = _mm256_add_ps(_mm512_extractf32x8_ps(sum512, 0), _mm512_extractf32x8_ps(sum512, 1));
__m128 sum128 = _mm_add_ps(_mm256_extractf128_ps(sum256, 0), _mm256_extractf128_ps(sum256, 1));
#elif defined(NGT_AVX2)
__m256 sum256 = _mm256_setzero_ps();
while (a < last) {
sum256 = _mm256_add_ps(sum256, _mm256_mul_ps(_mm256_loadu_ps(a), _mm256_loadu_ps(b)));
a += 8;
b += 8;
}
__m128 sum128 = _mm_add_ps(_mm256_extractf128_ps(sum256, 0), _mm256_extractf128_ps(sum256, 1));
#else
__m128 sum128 = _mm_setzero_ps();
while (a < last) {
sum128 = _mm_add_ps(sum128, _mm_mul_ps(_mm_loadu_ps(a), _mm_loadu_ps(b)));
a += 4;
b += 4;
}
#endif
__attribute__((aligned(32))) float f[4];
_mm_store_ps(f, sum128);
double s = f[0] + f[1] + f[2] + f[3];
return -s;
}
inline static double
compareDotProduct(const unsigned char* a, const unsigned char* b, size_t size) {
double sum = 0.0;
for (size_t loc = 0; loc < size; loc++) {
sum += (double)a[loc] * (double)b[loc];
}
return sum;
}
inline static double
compareInnerProduct(const unsigned char* a, const unsigned char* b, size_t size) {
double sum = 0.0;
for (size_t loc = 0; loc < size; loc++) {
sum += (double)a[loc] * (double)b[loc];
}
return -sum;
}
inline static double
compareCosine(const float* a, const float* b, size_t size) {
const float* last = a + size;
#if defined(NGT_AVX512)
__m512 normA = _mm512_setzero_ps();
__m512 normB = _mm512_setzero_ps();
__m512 sum = _mm512_setzero_ps();
while (a < last) {
__m512 am = _mm512_loadu_ps(a);
__m512 bm = _mm512_loadu_ps(b);
normA = _mm512_add_ps(normA, _mm512_mul_ps(am, am));
normB = _mm512_add_ps(normB, _mm512_mul_ps(bm, bm));
sum = _mm512_add_ps(sum, _mm512_mul_ps(am, bm));
a += 16;
b += 16;
}
__m256 am256 = _mm256_add_ps(_mm512_extractf32x8_ps(normA, 0), _mm512_extractf32x8_ps(normA, 1));
__m256 bm256 = _mm256_add_ps(_mm512_extractf32x8_ps(normB, 0), _mm512_extractf32x8_ps(normB, 1));
__m256 s256 = _mm256_add_ps(_mm512_extractf32x8_ps(sum, 0), _mm512_extractf32x8_ps(sum, 1));
__m128 am128 = _mm_add_ps(_mm256_extractf128_ps(am256, 0), _mm256_extractf128_ps(am256, 1));
__m128 bm128 = _mm_add_ps(_mm256_extractf128_ps(bm256, 0), _mm256_extractf128_ps(bm256, 1));
__m128 s128 = _mm_add_ps(_mm256_extractf128_ps(s256, 0), _mm256_extractf128_ps(s256, 1));
#elif defined(NGT_AVX2)
__m256 normA = _mm256_setzero_ps();
__m256 normB = _mm256_setzero_ps();
__m256 sum = _mm256_setzero_ps();
__m256 am, bm;
while (a < last) {
am = _mm256_loadu_ps(a);
bm = _mm256_loadu_ps(b);
normA = _mm256_add_ps(normA, _mm256_mul_ps(am, am));
normB = _mm256_add_ps(normB, _mm256_mul_ps(bm, bm));
sum = _mm256_add_ps(sum, _mm256_mul_ps(am, bm));
a += 8;
b += 8;
}
__m128 am128 = _mm_add_ps(_mm256_extractf128_ps(normA, 0), _mm256_extractf128_ps(normA, 1));
__m128 bm128 = _mm_add_ps(_mm256_extractf128_ps(normB, 0), _mm256_extractf128_ps(normB, 1));
__m128 s128 = _mm_add_ps(_mm256_extractf128_ps(sum, 0), _mm256_extractf128_ps(sum, 1));
#else
__m128 am128 = _mm_setzero_ps();
__m128 bm128 = _mm_setzero_ps();
__m128 s128 = _mm_setzero_ps();
__m128 am, bm;
while (a < last) {
am = _mm_loadu_ps(a);
bm = _mm_loadu_ps(b);
am128 = _mm_add_ps(am128, _mm_mul_ps(am, am));
bm128 = _mm_add_ps(bm128, _mm_mul_ps(bm, bm));
s128 = _mm_add_ps(s128, _mm_mul_ps(am, bm));
a += 4;
b += 4;
}
#endif
__attribute__((aligned(32))) float f[4];
_mm_store_ps(f, am128);
double na = f[0] + f[1] + f[2] + f[3];
_mm_store_ps(f, bm128);
double nb = f[0] + f[1] + f[2] + f[3];
_mm_store_ps(f, s128);
double s = f[0] + f[1] + f[2] + f[3];
double cosine = s / sqrt(na * nb);
return cosine;
}
inline static double
compareCosine(const unsigned char* a, const unsigned char* b, size_t size) {
double normA = 0.0;
double normB = 0.0;
double sum = 0.0;
for (size_t loc = 0; loc < size; loc++) {
normA += (double)a[loc] * (double)a[loc];
normB += (double)b[loc] * (double)b[loc];
sum += (double)a[loc] * (double)b[loc];
}
double cosine = sum / sqrt(normA * normB);
return cosine;
}
#endif // #if defined(NGT_NO_AVX)
template <typename OBJECT_TYPE>
inline static double
compareAngleDistance(const OBJECT_TYPE* a, const OBJECT_TYPE* b, size_t size) {
double cosine = compareCosine(a, b, size);
if (cosine >= 1.0) {
return 0.0;
} else if (cosine <= -1.0) {
return acos(-1.0);
} else {
return acos(cosine);
}
}
template <typename OBJECT_TYPE>
inline static double
compareNormalizedAngleDistance(const OBJECT_TYPE* a, const OBJECT_TYPE* b, size_t size) {
double cosine = compareDotProduct(a, b, size);
if (cosine >= 1.0) {
return 0.0;
} else if (cosine <= -1.0) {
return acos(-1.0);
} else {
return acos(cosine);
}
}
template <typename OBJECT_TYPE>
inline static double
compareCosineSimilarity(const OBJECT_TYPE* a, const OBJECT_TYPE* b, size_t size) {
return 1.0 - compareCosine(a, b, size);
}
template <typename OBJECT_TYPE>
inline static double
compareNormalizedCosineSimilarity(const OBJECT_TYPE* a, const OBJECT_TYPE* b, size_t size) {
double v = 1.0 - compareDotProduct(a, b, size);
return v < 0.0 ? 0.0 : v;
}
class L1Uint8 {
public:
inline static double
compare(const void* a, const void* b, size_t size) {
return PrimitiveComparator::compareL1((const uint8_t*)a, (const uint8_t*)b, size);
}
};
class L2Uint8 {
public:
inline static double
compare(const void* a, const void* b, size_t size) {
return PrimitiveComparator::compareL2((const uint8_t*)a, (const uint8_t*)b, size);
}
};
class HammingUint8 {
public:
inline static double
compare(const void* a, const void* b, size_t size) {
return PrimitiveComparator::compareHammingDistance((const uint8_t*)a, (const uint8_t*)b, size);
}
};
class JaccardUint8 {
public:
inline static double
compare(const void* a, const void* b, size_t size) {
return PrimitiveComparator::compareJaccardDistance((const uint8_t*)a, (const uint8_t*)b, size);
}
};
class SparseJaccardFloat {
public:
inline static double
compare(const void* a, const void* b, size_t size) {
return PrimitiveComparator::compareSparseJaccardDistance((const float*)a, (const float*)b, size);
}
};
class L2Float {
public:
inline static double
compare(const void* a, const void* b, size_t size) {
#if defined(NGT_NO_AVX)
return PrimitiveComparator::compareL2<float, double>((const float*)a, (const float*)b, size);
#else
return PrimitiveComparator::compareL2((const float*)a, (const float*)b, size);
#endif
}
};
class L1Float {
public:
inline static double
compare(const void* a, const void* b, size_t size) {
return PrimitiveComparator::compareL1((const float*)a, (const float*)b, size);
}
};
class CosineSimilarityFloat {
public:
inline static double
compare(const void* a, const void* b, size_t size) {
return PrimitiveComparator::compareCosineSimilarity((const float*)a, (const float*)b, size);
}
};
class AngleFloat {
public:
inline static double
compare(const void* a, const void* b, size_t size) {
return PrimitiveComparator::compareAngleDistance((const float*)a, (const float*)b, size);
}
};
class NormalizedCosineSimilarityFloat {
public:
inline static double
compare(const void* a, const void* b, size_t size) {
return PrimitiveComparator::compareNormalizedCosineSimilarity((const float*)a, (const float*)b, size);
}
};
class NormalizedAngleFloat {
public:
inline static double
compare(const void* a, const void* b, size_t size) {
return PrimitiveComparator::compareNormalizedAngleDistance((const float*)a, (const float*)b, size);
}
};
};
} // namespace NGT

View File

@ -1,40 +0,0 @@
//
// Copyright (C) 2015-2020 Yahoo Japan Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
#include "NGT/SharedMemoryAllocator.h"
void* operator
new(size_t size, SharedMemoryAllocator &allocator)
{
void *addr = allocator.allocate(size);
#ifdef MEMORY_ALLOCATOR_INFO
std::cerr << "new:" << size << " " << addr << " " << allocator.getTotalSize() << std::endl;
#endif
return addr;
}
void* operator
new[](size_t size, SharedMemoryAllocator &allocator)
{
void *addr = allocator.allocate(size);
#ifdef MEMORY_ALLOCATOR_INFO
std::cerr << "new[]:" << size << " " << addr << " " << allocator.getTotalSize() << std::endl;
#endif
return addr;
}

View File

@ -1,209 +0,0 @@
//
// Copyright (C) 2015-2020 Yahoo Japan Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
#pragma once
#include "NGT/defines.h"
#include "NGT/MmapManager.h"
#include <unistd.h>
#include <cstdlib>
#include <cstring>
#include <string>
#include <iostream>
#include <vector>
#include <exception>
#include <cassert>
#define MMAP_MANAGER
///////////////////////////////////////////////////////////////////////
class SharedMemoryAllocator {
public:
enum GetMemorySizeType {
GetTotalMemorySize = 0,
GetAllocatedMemorySize = 1,
GetFreedMemorySize = 2
};
SharedMemoryAllocator():isValid(false) {
#ifdef SMA_TRACE
std::cerr << "SharedMemoryAllocatorSiglton::constructor" << std::endl;
#endif
}
SharedMemoryAllocator(const SharedMemoryAllocator& a){}
SharedMemoryAllocator& operator=(const SharedMemoryAllocator& a){ return *this; }
public:
void* allocate(size_t size) {
if (isValid == false) {
std::cerr << "SharedMemoryAllocator::allocate: Fatal error! " << std::endl;
assert(isValid);
}
#ifdef SMA_TRACE
std::cerr << "SharedMemoryAllocator::allocate: size=" << size << std::endl;
std::cerr << "SharedMemoryAllocator::allocate: before " << getTotalSize() << ":" << getAllocatedSize() << ":" << getFreedSize() << std::endl;
#endif
#if defined(MMAP_MANAGER) && !defined(NOT_USE_MMAP_ALLOCATOR)
if(!isValid){
return NULL;
}
off_t file_offset = mmanager->alloc(size, true);
if (file_offset == -1) {
std::cerr << "Fatal Error: Allocating memory size is too big for this settings." << std::endl;
std::cerr << " Max allocation size should be enlarged." << std::endl;
abort();
}
void *p = mmanager->getAbsAddr(file_offset);
std::memset(p, 0, size);
#ifdef SMA_TRACE
std::cerr << "SharedMemoryAllocator::allocate: end" <<std::endl;
#endif
return p;
#else
void *ptr = std::malloc(size);
std::memset(ptr, 0, size);
return ptr;
#endif
}
void free(void *ptr) {
#ifdef SMA_TRACE
std::cerr << "SharedMemoryAllocator::free: ptr=" << ptr << std::endl;
#endif
if (ptr == 0) {
std::cerr << "SharedMemoryAllocator::free: ptr is invalid! ptr=" << ptr << std::endl;
}
if (ptr == 0) {
return;
}
#if defined(MMAP_MANAGER) && !defined(NOT_USE_MMAP_ALLOCATOR)
off_t file_offset = mmanager->getRelAddr(ptr);
mmanager->free(file_offset);
#else
std::free(ptr);
#endif
}
void *construct(const std::string &filePath, size_t memorysize = 0) {
file = filePath; // debug
#ifdef SMA_TRACE
std::cerr << "ObjectSharedMemoryAllocator::construct: file " << filePath << std::endl;
#endif
void *hook = 0;
#ifdef MMAP_MANAGER
mmanager = new MemoryManager::MmapManager();
// msize is the maximum allocated size (M byte) at once.
size_t msize = memorysize;
if (msize == 0) {
msize = NGT_SHARED_MEMORY_MAX_SIZE;
}
size_t bsize = msize * 1048576 / sysconf(_SC_PAGESIZE) + 1; // 1048576=1M
uint64_t size = bsize * sysconf(_SC_PAGESIZE);
MemoryManager::init_option_st option;
MemoryManager::MmapManager::setDefaultOptionValue(option);
option.use_expand = true;
option.reuse_type = MemoryManager::REUSE_DATA_CLASSIFY;
bool create = true;
if(!mmanager->init(filePath, size, &option)){
#ifdef SMA_TRACE
std::cerr << "SMA: info. already existed." << std::endl;
#endif
create = false;
} else {
#ifdef SMA_TRACE
std::cerr << "SMA::construct: msize=" << msize << ":" << memorysize << std::endl;
#endif
}
if(!mmanager->openMemory(filePath)){
std::cerr << "SMA: open error" << std::endl;
return 0;
}
if (!create) {
#ifdef SMA_TRACE
std::cerr << "SMA: get hook to initialize data structure" << std::endl;
#endif
hook = mmanager->getEntryHook();
assert(hook != 0);
}
#endif
isValid = true;
#ifdef SMA_TRACE
std::cerr << "SharedMemoryAllocator::construct: " << filePath << " total="
<< getTotalSize() << " allocated=" << getAllocatedSize() << " freed="
<< getFreedSize() << " (" << (double)getFreedSize() / (double)getTotalSize() << ") " << std::endl;
#endif
return hook;
}
void destruct() {
if (!isValid) {
return;
}
isValid = false;
#ifdef MMAP_MANAGER
mmanager->closeMemory();
delete mmanager;
#endif
};
void setEntry(void *entry) {
#ifdef MMAP_MANAGER
mmanager->setEntryHook(entry);
#endif
}
void *getAddr(off_t oft) {
if (oft == 0) {
return 0;
}
assert(oft > 0);
#if defined(MMAP_MANAGER) && !defined(NOT_USE_MMAP_ALLOCATOR)
return mmanager->getAbsAddr(oft);
#else
return (void*)oft;
#endif
}
off_t getOffset(void *adr) {
if (adr == 0) {
return 0;
}
#if defined(MMAP_MANAGER) && !defined(NOT_USE_MMAP_ALLOCATOR)
return mmanager->getRelAddr(adr);
#else
return (off_t)adr;
#endif
}
size_t getMemorySize(GetMemorySizeType t) {
switch (t) {
case GetTotalMemorySize : return getTotalSize();
case GetAllocatedMemorySize : return getAllocatedSize();
case GetFreedMemorySize : return getFreedSize();
}
return getTotalSize();
}
size_t getTotalSize() { return mmanager->getTotalSize(); }
size_t getAllocatedSize() { return mmanager->getUseSize(); }
size_t getFreedSize() { return mmanager->getFreeSize(); }
bool isValid;
std::string file;
#ifdef MMAP_MANAGER
MemoryManager::MmapManager *mmanager;
#endif
};
/////////////////////////////////////////////////////////////////////////
void* operator new(size_t size, SharedMemoryAllocator &allocator);
void* operator new[](size_t size, SharedMemoryAllocator &allocator);

View File

@ -1,128 +0,0 @@
//
// Copyright (C) 2015-2020 Yahoo Japan Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
#include <pthread.h>
#include "Thread.h"
using namespace std;
using namespace NGT;
namespace NGT {
class ThreadInfo {
public:
pthread_t threadid;
pthread_attr_t threadAttr;
};
class ThreadMutex {
public:
pthread_mutex_t mutex;
pthread_cond_t condition;
};
}
Thread::Thread() {
threadInfo = new ThreadInfo;
threadInfo->threadid = 0;
threadNo = -1;
isTerminate = false;
}
Thread::~Thread() {
if (threadInfo != 0) {
delete threadInfo;
}
}
ThreadMutex *
Thread::constructThreadMutex()
{
return new ThreadMutex;
}
void
Thread::destructThreadMutex(ThreadMutex *t)
{
if (t != 0) {
pthread_mutex_destroy(&(t->mutex));
pthread_cond_destroy(&(t->condition));
delete t;
}
}
int
Thread::start()
{
pthread_attr_init(&(threadInfo->threadAttr));
size_t stackSize = 0;
pthread_attr_getstacksize(&(threadInfo->threadAttr), &stackSize);
if (stackSize < 0xa00000) { // 64bit stack size
stackSize *= 4;
}
pthread_attr_setstacksize(&(threadInfo->threadAttr), stackSize);
pthread_attr_getstacksize(&(threadInfo->threadAttr), &stackSize);
return pthread_create(&(threadInfo->threadid), &(threadInfo->threadAttr), Thread::startThread, this);
}
int
Thread::join()
{
return pthread_join(threadInfo->threadid, 0);
}
void
Thread::lock(ThreadMutex &m)
{
pthread_mutex_lock(&m.mutex);
}
void
Thread::unlock(ThreadMutex &m)
{
pthread_mutex_unlock(&m.mutex);
}
void
Thread::signal(ThreadMutex &m)
{
pthread_cond_signal(&m.condition);
}
void
Thread::wait(ThreadMutex &m)
{
if (pthread_cond_wait(&m.condition, &m.mutex) != 0) {
cerr << "waitForSignalFromThread: internal error" << endl;
NGTThrowException("waitForSignalFromThread: internal error");
}
}
void
Thread::broadcast(ThreadMutex &m)
{
pthread_cond_broadcast(&m.condition);
}
void
Thread::mutexInit(ThreadMutex &m)
{
if (pthread_mutex_init(&m.mutex, NULL) != 0) {
NGTThrowException("Thread::mutexInit: Cannot initialize mutex");
}
if (pthread_cond_init(&m.condition, NULL) != 0) {
NGTThrowException("Thread::mutexInit: Cannot initialize condition");
}
}

View File

@ -1,291 +0,0 @@
//
// Copyright (C) 2015-2020 Yahoo Japan Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
#pragma once
#include "NGT/Common.h"
#include <cstdio>
#include <cstdlib>
#include <sys/time.h>
#include <unistd.h>
#include <iostream>
#include <deque>
namespace NGT {
void * evaluate_responce(void *);
class ThreadTerminationException : public Exception {
public:
ThreadTerminationException(const std::string &file, size_t line, std::stringstream &m) { set(file, line, m.str()); }
ThreadTerminationException(const std::string &file, size_t line, const std::string &m) { set(file, line, m); }
};
class ThreadInfo;
class ThreadMutex;
class Thread
{
public:
Thread();
virtual ~Thread();
virtual int start();
virtual int join();
static ThreadMutex *constructThreadMutex();
static void destructThreadMutex(ThreadMutex *t);
static void mutexInit(ThreadMutex &m);
static void lock(ThreadMutex &m);
static void unlock(ThreadMutex &m);
static void signal(ThreadMutex &m);
static void wait(ThreadMutex &m);
static void broadcast(ThreadMutex &m);
protected:
virtual int run() {
return 0;
}
private:
static void* startThread(void *thread) {
if (thread == 0) {
return 0;
}
Thread* p = (Thread*)thread;
p->run();
return thread;
}
public:
int threadNo;
bool isTerminate;
protected:
ThreadInfo *threadInfo;
};
template <class JOB, class SHARED_DATA, class THREAD>
class ThreadPool {
public:
class JobQueue : public std::deque<JOB> {
public:
JobQueue() {
threadMutex = Thread::constructThreadMutex();
Thread::mutexInit(*threadMutex);
}
~JobQueue() {
Thread::destructThreadMutex(threadMutex);
}
bool isDeficient() { return std::deque<JOB>::size() <= requestSize; }
bool isEmpty() { return std::deque<JOB>::size() == 0; }
bool isFull() { return std::deque<JOB>::size() >= maxSize; }
void setRequestSize(int s) { requestSize = s; }
void setMaxSize(int s) { maxSize = s; }
void lock() { Thread::lock(*threadMutex); }
void unlock() { Thread::unlock(*threadMutex); }
void signal() { Thread::signal(*threadMutex); }
void wait() { Thread::wait(*threadMutex); }
void wait(JobQueue &q) { wait(*q.threadMutex); }
void broadcast() { Thread::broadcast(*threadMutex); }
unsigned int requestSize;
unsigned int maxSize;
ThreadMutex *threadMutex;
};
class InputJobQueue : public JobQueue {
public:
InputJobQueue() {
isTerminate = false;
underPushing = false;
pushedSize = 0;
}
void popFront(JOB &d) {
JobQueue::lock();
while (JobQueue::isEmpty()) {
if (isTerminate) {
JobQueue::unlock();
NGTThrowSpecificException("Thread::termination", ThreadTerminationException);
}
JobQueue::wait();
}
d = std::deque<JOB>::front();
std::deque<JOB>::pop_front();
JobQueue::unlock();
return;
}
void popFront(std::deque<JOB> &d, size_t s) {
JobQueue::lock();
while (JobQueue::isEmpty()) {
if (isTerminate) {
JobQueue::unlock();
NGTThrowSpecificException("Thread::termination", ThreadTerminationException);
}
JobQueue::wait();
}
for (size_t i = 0; i < s; i++) {
d.push_back(std::deque<JOB>::front());
std::deque<JOB>::pop_front();
if (JobQueue::isEmpty()) {
break;
}
}
JobQueue::unlock();
return;
}
void pushBack(JOB &data) {
JobQueue::lock();
if (!underPushing) {
underPushing = true;
pushedSize = 0;
}
pushedSize++;
std::deque<JOB>::push_back(data);
JobQueue::unlock();
JobQueue::signal();
}
void pushBackEnd() {
underPushing = false;
}
void terminate() {
JobQueue::lock();
if (underPushing || !JobQueue::isEmpty()) {
JobQueue::unlock();
NGTThrowException("Thread::teminate:Under pushing!");
}
isTerminate = true;
JobQueue::unlock();
JobQueue::broadcast();
}
bool isTerminate;
bool underPushing;
size_t pushedSize;
};
class OutputJobQueue : public JobQueue {
public:
void waitForFull() {
JobQueue::wait();
JobQueue::unlock();
}
void pushBack(JOB &data) {
JobQueue::lock();
std::deque<JOB>::push_back(data);
if (!JobQueue::isFull()) {
JobQueue::unlock();
return;
}
JobQueue::unlock();
JobQueue::signal();
}
};
class SharedData {
public:
SharedData():isAvailable(false) {
inputJobs.requestSize = 5;
inputJobs.maxSize = 50;
}
SHARED_DATA sharedData;
InputJobQueue inputJobs;
OutputJobQueue outputJobs;
bool isAvailable;
};
class Thread : public THREAD {
public:
SHARED_DATA &getSharedData() {
if (threadPool->sharedData.isAvailable) {
return threadPool->sharedData.sharedData;
} else {
NGTThrowException("Thread::getSharedData: Shared data is unavailable. No set yet.");
}
}
InputJobQueue &getInputJobQueue() {
return threadPool->sharedData.inputJobs;
}
OutputJobQueue &getOutputJobQueue() {
return threadPool->sharedData.outputJobs;
}
ThreadPool *threadPool;
};
ThreadPool(int s) {
size = s;
threads = new Thread[s];
}
~ThreadPool() {
delete[] threads;
}
void setSharedData(SHARED_DATA d) {
sharedData.sharedData = d;
sharedData.isAvailable = true;
}
void create() {
for (unsigned int i = 0; i < size; i++) {
threads[i].threadPool = this;
threads[i].threadNo = i;
threads[i].start();
}
}
void pushInputQueue(JOB &data) {
if (!sharedData.inputJobs.underPushing) {
sharedData.outputJobs.lock();
}
sharedData.inputJobs.pushBack(data);
}
void waitForFinish() {
sharedData.inputJobs.pushBackEnd();
sharedData.outputJobs.setMaxSize(sharedData.inputJobs.pushedSize);
sharedData.inputJobs.pushedSize = 0;
sharedData.outputJobs.waitForFull();
}
void terminate() {
sharedData.inputJobs.terminate();
for (unsigned int i = 0; i < size; i++) {
threads[i].join();
}
}
InputJobQueue &getInputJobQueue() { return sharedData.inputJobs; }
OutputJobQueue &getOutputJobQueue() { return sharedData.outputJobs; }
SharedData sharedData; // shared data
Thread *threads; // thread set
unsigned int size; // thread size
};
}

View File

@ -1,564 +0,0 @@
//
// Copyright (C) 2015-2020 Yahoo Japan Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
#include "NGT/defines.h"
#include "NGT/Tree.h"
#include "NGT/Node.h"
#include <vector>
using namespace std;
using namespace NGT;
void
DVPTree::insert(InsertContainer &iobj) {
SearchContainer q(iobj.object);
q.mode = SearchContainer::SearchLeaf;
q.vptree = this;
q.radius = 0.0;
search(q);
iobj.vptree = this;
assert(q.nodeID.getType() == Node::ID::Leaf);
LeafNode *ln = (LeafNode*)getNode(q.nodeID);
insert(iobj, ln);
return;
}
void
DVPTree::insert(InsertContainer &iobj, LeafNode *leafNode)
{
LeafNode &leaf = *leafNode;
size_t fsize = leaf.getObjectSize();
if (fsize != 0) {
NGT::ObjectSpace::Comparator &comparator = objectSpace->getComparator();
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
Distance d = comparator(iobj.object, leaf.getPivot(*objectSpace));
#else
Distance d = comparator(iobj.object, leaf.getPivot());
#endif
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
NGT::ObjectDistance *objects = leaf.getObjectIDs(leafNodes.allocator);
#else
NGT::ObjectDistance *objects = leaf.getObjectIDs();
#endif
for (size_t i = 0; i < fsize; i++) {
if (objects[i].distance == d) {
Distance idd = 0.0;
ObjectID loid;
try {
loid = objects[i].id;
idd = comparator(iobj.object, *getObjectRepository().get(loid));
} catch (Exception &e) {
stringstream msg;
msg << "LeafNode::insert: Cannot find object which belongs to a leaf node. id="
<< objects[i].id << ":" << e.what() << endl;
NGTThrowException(msg.str());
}
if (idd == 0.0) {
if (loid == iobj.id) {
stringstream msg;
msg << "DVPTree::insert:already existed. " << iobj.id;
NGTThrowException(msg);
}
return;
}
}
}
}
if (leaf.getObjectSize() >= leafObjectsSize) {
split(iobj, leaf);
} else {
insertObject(iobj, leaf);
}
return;
}
Node::ID
DVPTree::split(InsertContainer &iobj, LeafNode &leaf)
{
Node::Objects *fs = getObjects(leaf, iobj);
int pv = DVPTree::MaxVariance;
switch (splitMode) {
case DVPTree::MaxVariance:
pv = LeafNode::selectPivotByMaxVariance(iobj, *fs);
break;
case DVPTree::MaxDistance:
pv = LeafNode::selectPivotByMaxDistance(iobj, *fs);
break;
}
LeafNode::splitObjects(iobj, *fs, pv);
Node::ID nid = recombineNodes(iobj, *fs, leaf);
delete fs;
return nid;
}
Node::ID
DVPTree::recombineNodes(InsertContainer &ic, Node::Objects &fs, LeafNode &leaf)
{
LeafNode *ln[internalChildrenSize];
Node::ID targetParent = leaf.parent;
Node::ID targetId = leaf.id;
ln[0] = &leaf;
ln[0]->objectSize = 0;
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
for (size_t i = 1; i < internalChildrenSize; i++) {
ln[i] = new(leafNodes.allocator) LeafNode(leafNodes.allocator);
}
#else
for (size_t i = 1; i < internalChildrenSize; i++) {
ln[i] = new LeafNode;
}
#endif
InternalNode *in = createInternalNode();
Node::ID inid = in->id;
try {
if (targetParent.getID() != 0) {
InternalNode &pnode = *(InternalNode*)getNode(targetParent);
for (size_t i = 0; i < internalChildrenSize; i++) {
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
if (pnode.getChildren(internalNodes.allocator)[i] == targetId) {
pnode.getChildren(internalNodes.allocator)[i] = inid;
#else
if (pnode.getChildren()[i] == targetId) {
pnode.getChildren()[i] = inid;
#endif
break;
}
}
}
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
in->setPivot(*getObjectRepository().get(fs[0].id), *objectSpace, internalNodes.allocator);
#else
in->setPivot(*getObjectRepository().get(fs[0].id), *objectSpace);
#endif
in->parent = targetParent;
int fsize = fs.size();
int cid = fs[0].clusterID;
#ifdef NGT_NODE_USE_VECTOR
LeafNode::ObjectIDs fid;
fid.id = fs[0].id;
fid.distance = 0.0;
ln[cid]->objectIDs.push_back(fid);
#else
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
ln[cid]->getObjectIDs(leafNodes.allocator)[ln[cid]->objectSize].id = fs[0].id;
ln[cid]->getObjectIDs(leafNodes.allocator)[ln[cid]->objectSize++].distance = 0.0;
#else
ln[cid]->getObjectIDs()[ln[cid]->objectSize].id = fs[0].id;
ln[cid]->getObjectIDs()[ln[cid]->objectSize++].distance = 0.0;
#endif
#endif
if (fs[0].leafDistance == Node::Object::Pivot) {
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
ln[cid]->setPivot(*getObjectRepository().get(fs[0].id), *objectSpace, leafNodes.allocator);
#else
ln[cid]->setPivot(*getObjectRepository().get(fs[0].id), *objectSpace);
#endif
} else {
NGTThrowException("recombineNodes: internal error : illegal pivot.");
}
ln[cid]->parent = inid;
int maxClusterID = cid;
for (int i = 1; i < fsize; i++) {
int clusterID = fs[i].clusterID;
if (clusterID > maxClusterID) {
maxClusterID = clusterID;
}
Distance ld;
if (fs[i].leafDistance == Node::Object::Pivot) {
// pivot
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
ln[clusterID]->setPivot(*getObjectRepository().get(fs[i].id), *objectSpace, leafNodes.allocator);
#else
ln[clusterID]->setPivot(*getObjectRepository().get(fs[i].id), *objectSpace);
#endif
ld = 0.0;
} else {
ld = fs[i].leafDistance;
}
#ifdef NGT_NODE_USE_VECTOR
fid.id = fs[i].id;
fid.distance = ld;
ln[clusterID]->objectIDs.push_back(fid);
#else
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
ln[clusterID]->getObjectIDs(leafNodes.allocator)[ln[clusterID]->objectSize].id = fs[i].id;
ln[clusterID]->getObjectIDs(leafNodes.allocator)[ln[clusterID]->objectSize++].distance = ld;
#else
ln[clusterID]->getObjectIDs()[ln[clusterID]->objectSize].id = fs[i].id;
ln[clusterID]->getObjectIDs()[ln[clusterID]->objectSize++].distance = ld;
#endif
#endif
ln[clusterID]->parent = inid;
if (clusterID != cid) {
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
in->getBorders(internalNodes.allocator)[cid] = fs[i].distance;
#else
in->getBorders()[cid] = fs[i].distance;
#endif
cid = fs[i].clusterID;
}
}
// When the number of the children is less than the expected,
// proper values are set to the empty children.
for (size_t i = maxClusterID + 1; i < internalChildrenSize; i++) {
ln[i]->parent = inid;
// dummy
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
ln[i]->setPivot(*getObjectRepository().get(fs[0].id), *objectSpace, leafNodes.allocator);
#else
ln[i]->setPivot(*getObjectRepository().get(fs[0].id), *objectSpace);
#endif
if (i < (internalChildrenSize - 1)) {
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
in->getBorders(internalNodes.allocator)[i] = FLT_MAX;
#else
in->getBorders()[i] = FLT_MAX;
#endif
}
}
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
in->getChildren(internalNodes.allocator)[0] = targetId;
#else
in->getChildren()[0] = targetId;
#endif
for (size_t i = 1; i < internalChildrenSize; i++) {
insertNode(ln[i]);
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
in->getChildren(internalNodes.allocator)[i] = ln[i]->id;
#else
in->getChildren()[i] = ln[i]->id;
#endif
}
} catch(Exception &e) {
throw e;
}
return inid;
}
void
DVPTree::insertObject(InsertContainer &ic, LeafNode &leaf) {
if (leaf.getObjectSize() == 0) {
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
leaf.setPivot(*getObjectRepository().get(ic.id), *objectSpace, leafNodes.allocator);
#else
leaf.setPivot(*getObjectRepository().get(ic.id), *objectSpace);
#endif
#ifdef NGT_NODE_USE_VECTOR
LeafNode::ObjectIDs fid;
fid.id = ic.id;
fid.distance = 0;
leaf.objectIDs.push_back(fid);
#else
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
leaf.getObjectIDs(leafNodes.allocator)[leaf.objectSize].id = ic.id;
leaf.getObjectIDs(leafNodes.allocator)[leaf.objectSize++].distance = 0;
#else
leaf.getObjectIDs()[leaf.objectSize].id = ic.id;
leaf.getObjectIDs()[leaf.objectSize++].distance = 0;
#endif
#endif
} else {
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
Distance d = objectSpace->getComparator()(ic.object, leaf.getPivot(*objectSpace));
#else
Distance d = objectSpace->getComparator()(ic.object, leaf.getPivot());
#endif
#ifdef NGT_NODE_USE_VECTOR
LeafNode::ObjectIDs fid;
fid.id = ic.id;
fid.distance = d;
leaf.objectIDs.push_back(fid);
std::sort(leaf.objectIDs.begin(), leaf.objectIDs.end(), LeafNode::ObjectIDs());
#else
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
leaf.getObjectIDs(leafNodes.allocator)[leaf.objectSize].id = ic.id;
leaf.getObjectIDs(leafNodes.allocator)[leaf.objectSize++].distance = d;
#else
leaf.getObjectIDs()[leaf.objectSize].id = ic.id;
leaf.getObjectIDs()[leaf.objectSize++].distance = d;
#endif
#endif
}
}
Node::Objects *
DVPTree::getObjects(LeafNode &n, Container &iobj)
{
int size = n.getObjectSize() + 1;
Node::Objects *fs = new Node::Objects(size);
for (size_t i = 0; i < n.getObjectSize(); i++) {
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
(*fs)[i].object = getObjectRepository().get(n.getObjectIDs(leafNodes.allocator)[i].id);
(*fs)[i].id = n.getObjectIDs(leafNodes.allocator)[i].id;
#else
(*fs)[i].object = getObjectRepository().get(n.getObjectIDs()[i].id);
(*fs)[i].id = n.getObjectIDs()[i].id;
#endif
}
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
(*fs)[n.getObjectSize()].object = getObjectRepository().get(iobj.id);
#else
(*fs)[n.getObjectSize()].object = &iobj.object;
#endif
(*fs)[n.getObjectSize()].id = iobj.id;
return fs;
}
void
DVPTree::removeEmptyNodes(InternalNode &inode) {
int csize = internalChildrenSize;
InternalNode *target = &inode;
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
Node::ID *children = target->getChildren(internalNodes.allocator);
#else
Node::ID *children = target->getChildren();
#endif
for(;;) {
for (int i = 0; i < csize; i++) {
if (children[i].getType() == Node::ID::Internal) {
return;
}
LeafNode &ln = *static_cast<LeafNode*>(getNode(children[i]));
if (ln.getObjectSize() != 0) {
return;
}
}
for (int i = 0; i < csize; i++) {
removeNode(children[i]);
}
if (target->parent.getID() == 0) {
removeNode(target->id);
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
LeafNode *root = new(leafNodes.allocator) LeafNode(leafNodes.allocator);
#else
LeafNode *root = new LeafNode;
#endif
insertNode(root);
if (root->id.getID() != 1) {
NGTThrowException("Root id Error");
}
return;
}
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
LeafNode *ln = new(leafNodes.allocator) LeafNode(leafNodes.allocator);
#else
LeafNode *ln = new LeafNode;
#endif
ln->parent = target->parent;
insertNode(ln);
InternalNode &in = *(InternalNode*)getNode(ln->parent);
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
in.updateChild(*this, target->id, ln->id, internalNodes.allocator);
#else
in.updateChild(*this, target->id, ln->id);
#endif
removeNode(target->id);
target = &in;
}
return;
}
void
DVPTree::search(SearchContainer &sc, InternalNode &node, UncheckedNode &uncheckedNode)
{
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
Distance d = objectSpace->getComparator()(sc.object, node.getPivot(*objectSpace));
#else
Distance d = objectSpace->getComparator()(sc.object, node.getPivot());
#endif
#ifdef NGT_DISTANCE_COMPUTATION_COUNT
sc.distanceComputationCount++;
#endif
int bsize = internalChildrenSize - 1;
vector<ObjectDistance> regions;
regions.reserve(internalChildrenSize);
ObjectDistance child;
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
Distance *borders = node.getBorders(internalNodes.allocator);
#else
Distance *borders = node.getBorders();
#endif
int mid;
for (mid = 0; mid < bsize; mid++) {
if (d < borders[mid]) {
child.id = mid;
child.distance = 0.0;
regions.push_back(child);
if (d + sc.radius < borders[mid]) {
break;
} else {
continue;
}
} else {
if (d < borders[mid] + sc.radius) {
child.id = mid;
child.distance = d - borders[mid];
regions.push_back(child);
continue;
} else {
continue;
}
}
}
if (mid == bsize) {
if (d >= borders[mid - 1]) {
child.id = mid;
child.distance = 0.0;
regions.push_back(child);
} else {
child.id = mid;
child.distance = borders[mid - 1] - d;
regions.push_back(child);
}
}
sort(regions.begin(), regions.end());
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
Node::ID *children = node.getChildren(internalNodes.allocator);
#else
Node::ID *children = node.getChildren();
#endif
vector<ObjectDistance>::iterator i;
if (sc.mode == DVPTree::SearchContainer::SearchLeaf) {
if (children[regions.front().id].getType() == Node::ID::Leaf) {
sc.nodeID.setRaw(children[regions.front().id].get());
assert(uncheckedNode.empty());
} else {
uncheckedNode.push(children[regions.front().id]);
}
} else {
for (i = regions.begin(); i != regions.end(); i++) {
uncheckedNode.push(children[i->id]);
}
}
}
void
DVPTree::search(SearchContainer &so, LeafNode &node, UncheckedNode &uncheckedNode)
{
DVPTree::SearchContainer &q = (DVPTree::SearchContainer&)so;
if (node.getObjectSize() == 0) {
return;
}
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
Distance pq = objectSpace->getComparator()(q.object, node.getPivot(*objectSpace));
#else
Distance pq = objectSpace->getComparator()(q.object, node.getPivot());
#endif
#ifdef NGT_DISTANCE_COMPUTATION_COUNT
so.distanceComputationCount++;
#endif
ObjectDistance r;
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
NGT::ObjectDistance *objects = node.getObjectIDs(leafNodes.allocator);
#else
NGT::ObjectDistance *objects = node.getObjectIDs();
#endif
for (size_t i = 0; i < node.getObjectSize(); i++) {
if ((objects[i].distance <= pq + q.radius) &&
(objects[i].distance >= pq - q.radius)) {
Distance d = 0;
try {
d = objectSpace->getComparator()(q.object, *q.vptree->getObjectRepository().get(objects[i].id));
#ifdef NGT_DISTANCE_COMPUTATION_COUNT
so.distanceComputationCount++;
#endif
} catch(...) {
NGTThrowException("VpTree::LeafNode::search: Internal fatal error : Cannot get object");
}
if (d <= q.radius) {
r.id = objects[i].id;
r.distance = d;
so.getResult().push_back(r);
std::sort(so.getResult().begin(), so.getResult().end());
if (so.getResult().size() > q.size) {
so.getResult().resize(q.size);
}
}
}
}
}
void
DVPTree::search(SearchContainer &sc) {
((SearchContainer&)sc).vptree = this;
Node *root = getRootNode();
assert(root != 0);
if (sc.mode == DVPTree::SearchContainer::SearchLeaf) {
if (root->id.getType() == Node::ID::Leaf) {
sc.nodeID.setRaw(root->id.get());
return;
}
}
UncheckedNode uncheckedNode;
uncheckedNode.push(root->id);
while (!uncheckedNode.empty()) {
Node::ID nodeid = uncheckedNode.top();
uncheckedNode.pop();
Node *cnode = getNode(nodeid);
if (cnode == 0) {
cerr << "Error! child node is null. but continue." << endl;
continue;
}
if (cnode->id.getType() == Node::ID::Internal) {
search(sc, (InternalNode&)*cnode, uncheckedNode);
} else if (cnode->id.getType() == Node::ID::Leaf) {
search(sc, (LeafNode&)*cnode, uncheckedNode);
} else {
cerr << "Tree: Inner fatal error!: Node type error!" << endl;
abort();
}
}
}

View File

@ -1,513 +0,0 @@
//
// Copyright (C) 2015-2020 Yahoo Japan Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
#pragma once
#include "NGT/Common.h"
#include "NGT/Node.h"
#include "NGT/defines.h"
#include "faiss/utils/BitsetView.h"
#include <sstream>
#include <string>
#include <vector>
#include <stack>
#include <set>
namespace NGT {
class DVPTree {
public:
enum SplitMode {
MaxDistance = 0,
MaxVariance = 1
};
typedef std::vector<Node::ID> IDVector;
class Container : public NGT::Container {
public:
Container(Object &f, ObjectID i):NGT::Container(f, i) {}
DVPTree *vptree;
};
class SearchContainer : public NGT::SearchContainer {
public:
enum Mode {
SearchLeaf = 0,
SearchObject = 1
};
SearchContainer(Object &f, ObjectID i):NGT::SearchContainer(f, i) {}
SearchContainer(Object &f):NGT::SearchContainer(f, 0) {}
DVPTree *vptree;
Mode mode;
Node::ID nodeID;
};
class InsertContainer : public Container {
public:
InsertContainer(Object &f, ObjectID i):Container(f, i) {}
};
class RemoveContainer : public Container {
public:
RemoveContainer(Object &f, ObjectID i):Container(f, i) {}
};
DVPTree() {
leafObjectsSize = LeafNode::LeafObjectsSizeMax;
internalChildrenSize = InternalNode::InternalChildrenSizeMax;
splitMode = MaxVariance;
#ifndef NGT_SHARED_MEMORY_ALLOCATOR
insertNode(new LeafNode);
#endif
}
virtual ~DVPTree() {
#ifndef NGT_SHARED_MEMORY_ALLOCATOR
deleteAll();
#endif
}
void deleteAll() {
for (size_t i = 0; i < leafNodes.size(); i++) {
if (leafNodes[i] != 0) {
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
leafNodes[i]->deletePivot(*objectSpace, leafNodes.allocator);
#else
leafNodes[i]->deletePivot(*objectSpace);
#endif
delete leafNodes[i];
}
}
leafNodes.clear();
for (size_t i = 0; i < internalNodes.size(); i++) {
if (internalNodes[i] != 0) {
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
internalNodes[i]->deletePivot(*objectSpace, internalNodes.allocator);
#else
internalNodes[i]->deletePivot(*objectSpace);
#endif
delete internalNodes[i];
}
}
internalNodes.clear();
}
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
void open(const std::string &f, size_t sharedMemorySize) {
// If no file, then create a new file.
leafNodes.open(f + "l", sharedMemorySize);
internalNodes.open(f + "i", sharedMemorySize);
if (leafNodes.size() == 0) {
if (internalNodes.size() != 0) {
NGTThrowException("Tree::Open: Internal error. Internal and leaf are inconsistent.");
}
LeafNode *ln = leafNodes.allocate();
insertNode(ln);
}
}
#endif // NGT_SHARED_MEMORY_ALLOCATOR
void insert(InsertContainer &iobj);
void insert(InsertContainer &iobj, LeafNode *n);
Node::ID split(InsertContainer &iobj, LeafNode &leaf);
Node::ID recombineNodes(InsertContainer &ic, Node::Objects &fs, LeafNode &leaf);
void insertObject(InsertContainer &obj, LeafNode &leaf);
typedef std::stack<Node::ID> UncheckedNode;
void search(SearchContainer &so);
void search(SearchContainer &so, InternalNode &node, UncheckedNode &uncheckedNode);
void search(SearchContainer &so, LeafNode &node, UncheckedNode &uncheckedNode);
bool searchObject(ObjectID id) {
LeafNode &ln = getLeaf(id);
for (size_t i = 0; i < ln.getObjectSize(); i++) {
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
if (ln.getObjectIDs(leafNodes.allocator)[i].id == id) {
#else
if (ln.getObjectIDs()[i].id == id) {
#endif
return true;
}
}
return false;
}
LeafNode &getLeaf(ObjectID id) {
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
Object *qobject = objectSpace->allocateObject(*getObjectRepository().get(id));
SearchContainer q(*qobject);
#else
SearchContainer q(*getObjectRepository().get(id));
#endif
q.mode = SearchContainer::SearchLeaf;
q.vptree = this;
q.radius = 0.0;
q.size = 1;
search(q);
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
objectSpace->deleteObject(qobject);
#endif
return *(LeafNode*)getNode(q.nodeID);
}
void replace(ObjectID id, ObjectID replacedId) { remove(id, replacedId); }
// remove the specified object.
void remove(ObjectID id, ObjectID replaceId = 0) {
LeafNode &ln = getLeaf(id);
try {
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
ln.removeObject(id, replaceId, leafNodes.allocator);
#else
ln.removeObject(id, replaceId);
#endif
} catch(Exception &err) {
std::stringstream msg;
msg << "VpTree::remove: Inner error. Cannot remove object. leafNode=" << ln.id.getID() << ":" << err.what();
NGTThrowException(msg);
}
if (ln.getObjectSize() == 0) {
if (ln.parent.getID() != 0) {
InternalNode &inode = *(InternalNode*)getNode(ln.parent);
removeEmptyNodes(inode);
}
}
return;
}
void removeNaively(ObjectID id, ObjectID replaceId = 0) {
for (size_t i = 0; i < leafNodes.size(); i++) {
if (leafNodes[i] != 0) {
try {
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
leafNodes[i]->removeObject(id, replaceId, leafNodes.allocator);
#else
leafNodes[i]->removeObject(id, replaceId);
#endif
break;
} catch(...) {}
}
}
}
Node *getRootNode() {
size_t nid = 1;
Node *root;
try {
root = internalNodes.get(nid);
} catch(Exception &err) {
try {
root = leafNodes.get(nid);
} catch(Exception &e) {
std::stringstream msg;
msg << "VpTree::getRootNode: Inner error. Cannot get a leaf root node. " << nid << ":" << e.what();
NGTThrowException(msg);
}
}
return root;
}
InternalNode *createInternalNode() {
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
InternalNode *n = new(internalNodes.allocator) InternalNode(internalChildrenSize, internalNodes.allocator);
#else
InternalNode *n = new InternalNode(internalChildrenSize);
#endif
insertNode(n);
return n;
}
void
removeNode(Node::ID id) {
size_t idx = id.getID();
if (id.getType() == Node::ID::Leaf) {
leafNodes.remove(idx);
} else {
internalNodes.remove(idx);
}
}
void removeEmptyNodes(InternalNode &node);
Node::Objects * getObjects(LeafNode &n, Container &iobj);
// for milvus
void
getObjectIDsFromLeaf(Node::ID nid, ObjectDistances& rl, const faiss::BitsetView bitset) {
LeafNode& ln = *(LeafNode*)getNode(nid);
rl.clear();
ObjectDistance r;
for (size_t i = 0; i < ln.getObjectSize(); i++) {
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
r.id = ln.getObjectIDs(leafNodes.allocator)[i].id;
r.distance = ln.getObjectIDs(leafNodes.allocator)[i].distance;
#else
r.id = ln.getObjectIDs()[i].id;
r.distance = ln.getObjectIDs()[i].distance;
#endif
if (!bitset.empty() && bitset.test(r.id - 1)) {
continue;
}
rl.push_back(r);
}
return;
}
void getObjectIDsFromLeaf(Node::ID nid, ObjectDistances &rl) {
LeafNode &ln = *(LeafNode*)getNode(nid);
rl.clear();
ObjectDistance r;
for (size_t i = 0; i < ln.getObjectSize(); i++) {
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
r.id = ln.getObjectIDs(leafNodes.allocator)[i].id;
r.distance = ln.getObjectIDs(leafNodes.allocator)[i].distance;
#else
r.id = ln.getObjectIDs()[i].id;
r.distance = ln.getObjectIDs()[i].distance;
#endif
rl.push_back(r);
}
return;
}
void
insertNode(LeafNode *n) {
size_t id = leafNodes.insert(n);
n->id.setID(id);
n->id.setType(Node::ID::Leaf);
}
// replace
void replaceNode(LeafNode *n) {
int id = n->id.getID();
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
leafNodes.set(id, n);
#else
leafNodes[id] = n;
#endif
}
void
insertNode(InternalNode *n)
{
size_t id = internalNodes.insert(n);
n->id.setID(id);
n->id.setType(Node::ID::Internal);
}
Node *getNode(Node::ID &id) {
Node *n = 0;
Node::NodeID idx = id.getID();
if (id.getType() == Node::ID::Leaf) {
n = leafNodes.get(idx);
} else {
n = internalNodes.get(idx);
}
return n;
}
void getAllLeafNodeIDs(std::vector<Node::ID> &leafIDs) {
leafIDs.clear();
Node *root = getRootNode();
if (root->id.getType() == Node::ID::Leaf) {
leafIDs.push_back(root->id);
return;
}
UncheckedNode uncheckedNode;
uncheckedNode.push(root->id);
while (!uncheckedNode.empty()) {
Node::ID nodeid = uncheckedNode.top();
uncheckedNode.pop();
Node *cnode = getNode(nodeid);
if (cnode->id.getType() == Node::ID::Internal) {
InternalNode &inode = static_cast<InternalNode&>(*cnode);
for (size_t ci = 0; ci < internalChildrenSize; ci++) {
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
uncheckedNode.push(inode.getChildren(internalNodes.allocator)[ci]);
#else
uncheckedNode.push(inode.getChildren()[ci]);
#endif
}
} else if (cnode->id.getType() == Node::ID::Leaf) {
leafIDs.push_back(static_cast<LeafNode&>(*cnode).id);
} else {
std::cerr << "Tree: Inner fatal error!: Node type error!" << std::endl;
abort();
}
}
}
// for milvus
void serialize(std::stringstream & os)
{
leafNodes.serialize(os, objectSpace);
internalNodes.serialize(os, objectSpace);
}
void serialize(std::ofstream &os) {
leafNodes.serialize(os, objectSpace);
internalNodes.serialize(os, objectSpace);
}
void deserialize(std::ifstream &is) {
leafNodes.deserialize(is, objectSpace);
internalNodes.deserialize(is, objectSpace);
}
void deserialize(std::stringstream & is)
{
leafNodes.deserialize(is, objectSpace);
internalNodes.deserialize(is, objectSpace);
}
void serializeAsText(std::ofstream &os) {
leafNodes.serializeAsText(os, objectSpace);
internalNodes.serializeAsText(os, objectSpace);
}
void deserializeAsText(std::ifstream &is) {
leafNodes.deserializeAsText(is, objectSpace);
internalNodes.deserializeAsText(is, objectSpace);
}
void show() {
std::cout << "Show tree " << std::endl;
for (size_t i = 0; i < leafNodes.size(); i++) {
if (leafNodes[i] != 0) {
std::cout << i << ":";
(*leafNodes[i]).show();
}
}
for (size_t i = 0; i < internalNodes.size(); i++) {
if (internalNodes[i] != 0) {
std::cout << i << ":";
(*internalNodes[i]).show();
}
}
}
bool verify(size_t objCount, std::vector<uint8_t> &status) {
std::cerr << "Started verifying internal nodes. size=" << internalNodes.size() << "..." << std::endl;
bool valid = true;
for (size_t i = 0; i < internalNodes.size(); i++) {
if (internalNodes[i] != 0) {
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
valid = valid && (*internalNodes[i]).verify(internalNodes, leafNodes, internalNodes.allocator);
#else
valid = valid && (*internalNodes[i]).verify(internalNodes, leafNodes);
#endif
}
}
std::cerr << "Started verifying leaf nodes. size=" << leafNodes.size() << " ..." << std::endl;
for (size_t i = 0; i < leafNodes.size(); i++) {
if (leafNodes[i] != 0) {
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
valid = valid && (*leafNodes[i]).verify(objCount, status, leafNodes.allocator);
#else
valid = valid && (*leafNodes[i]).verify(objCount, status);
#endif
}
}
return valid;
}
void deleteInMemory() {
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
assert(0);
#else
for (std::vector<NGT::LeafNode*>::iterator i = leafNodes.begin(); i != leafNodes.end(); i++) {
if ((*i) != 0) {
delete (*i);
}
}
leafNodes.clear();
for (std::vector<NGT::InternalNode*>::iterator i = internalNodes.begin(); i != internalNodes.end(); i++) {
if ((*i) != 0) {
delete (*i);
}
}
internalNodes.clear();
#endif
}
ObjectRepository &getObjectRepository() { return objectSpace->getRepository(); }
size_t getSharedMemorySize(std::ostream &os, SharedMemoryAllocator::GetMemorySizeType t) {
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
size_t isize = internalNodes.getAllocator().getMemorySize(t);
os << "internal=" << isize << std::endl;
size_t lsize = leafNodes.getAllocator().getMemorySize(t);
os << "leaf=" << lsize << std::endl;
return isize + lsize;
#else
return 0;
#endif
}
void getAllObjectIDs(std::set<ObjectID> &ids) {
for (size_t i = 0; i < leafNodes.size(); i++) {
if (leafNodes[i] != 0) {
LeafNode &ln = *leafNodes[i];
#if defined(NGT_SHARED_MEMORY_ALLOCATOR)
auto objs = ln.getObjectIDs(leafNodes.allocator);
#else
auto objs = ln.getObjectIDs();
#endif
for (size_t idx = 0; idx < ln.objectSize; ++idx) {
ids.insert(objs[idx].id);
}
}
}
}
virtual int64_t memSize() { return sizeof(size_t) * 2 + sizeof(splitMode) + name.size() + leafNodes.memSize() + internalNodes.memSize(); }
public:
size_t internalChildrenSize;
size_t leafObjectsSize;
SplitMode splitMode;
std::string name;
#ifdef NGT_SHARED_MEMORY_ALLOCATOR
PersistentRepository<LeafNode> leafNodes;
PersistentRepository<InternalNode> internalNodes;
#else
Repository<LeafNode> leafNodes;
Repository<InternalNode> internalNodes;
#endif
ObjectSpace *objectSpace;
};
} // namespace DVPTree

View File

@ -1,58 +0,0 @@
//
// Copyright (C) 2015-2020 Yahoo Japan Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
#include "NGT/Version.h"
void
NGT::Version::get(std::ostream &os)
{
os << " Version:" << NGT::Version::getVersion() << std::endl;
os << " Built date:" << NGT::Version::getBuildDate() << std::endl;
os << " The last git tag:" << Version::getGitTag() << std::endl;
os << " The last git commit hash:" << Version::getGitHash() << std::endl;
os << " The last git commit date:" << Version::getGitDate() << std::endl;
}
const std::string
NGT::Version::getVersion()
{
return NGT_VERSION;
}
const std::string
NGT::Version::getBuildDate()
{
return NGT_BUILD_DATE;
}
const std::string
NGT::Version::getGitHash()
{
return NGT_GIT_HASH;
}
const std::string
NGT::Version::getGitDate()
{
return NGT_GIT_DATE;
}
const std::string
NGT::Version::getGitTag()
{
return NGT_GIT_TAG;
}

View File

@ -1,61 +0,0 @@
//
// Copyright (C) 2015-2020 Yahoo Japan Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
#pragma once
#include <ostream>
#include <string>
#ifndef NGT_VERSION
#define NGT_VERSION "-"
#endif
#ifndef NGT_BUILD_DATE
#define NGT_BUILD_DATE "-"
#endif
#ifndef NGT_GIT_HASH
#define NGT_GIT_HASH "-"
#endif
#ifndef NGT_GIT_DATE
#define NGT_GIT_DATE "-"
#endif
#ifndef NGT_GIT_TAG
#define NGT_GIT_TAG "-"
#endif
namespace NGT {
class Version {
public:
static void
get(std::ostream& os);
static const std::string
getVersion();
static const std::string
getBuildDate();
static const std::string
getGitHash();
static const std::string
getGitDate();
static const std::string
getGitTag();
static const std::string
get();
};
}; // namespace NGT
#ifdef NGT_VERSION_FOR_HEADER
#include "Version.cpp"
#endif

View File

@ -1,29 +0,0 @@
//
// Copyright (C) 2015-2020 Yahoo Japan Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
#include <string>
void (*NGT_LOG_TRACE_)(const std::string&);
void (*NGT_LOG_DEBUG_)(const std::string&);
void (*NGT_LOG_INFO_)(const std::string&);
void (*NGT_LOG_WARNING_)(const std::string&);
void (*NGT_LOG_FATAL_)(const std::string&);
void (*NGT_LOG_ERROR_)(const std::string&);

View File

@ -1,73 +0,0 @@
//
// Copyright (C) 2015-2020 Yahoo Japan Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
#pragma once
#include <string>
// Begin of cmake defines
#if 0
#cmakedefine NGT_SHARED_MEMORY_ALLOCATOR // use shared memory for indexes
#cmakedefine NGT_GRAPH_CHECK_VECTOR // use vector to check whether accessed
#cmakedefine NGT_AVX_DISABLED // not use avx to compare
#cmakedefine NGT_LARGE_DATASET // more than 10M objects
#cmakedefine NGT_DISTANCE_COMPUTATION_COUNT // count # of distance computations
#endif
// End of cmake defines
//////////////////////////////////////////////////////////////////////////
// Release Definitions for OSS
//#define NGT_DISTANCE_COMPUTATION_COUNT
#define NGT_CREATION_EDGE_SIZE 10
#define NGT_EXPLORATION_COEFFICIENT 1.1
#define NGT_INSERTION_EXPLORATION_COEFFICIENT 1.1
#define NGT_SHARED_MEMORY_MAX_SIZE 1024 // MB
#define NGT_FORCED_REMOVE // When errors occur due to the index inconsistency, ignore them.
#define NGT_COMPACT_VECTOR
#define NGT_GRAPH_READ_ONLY_GRAPH
#ifdef NGT_LARGE_DATASET
#define NGT_GRAPH_CHECK_HASH_BASED_BOOLEAN_SET
#else
#define NGT_GRAPH_CHECK_VECTOR
#endif
#if defined(NGT_AVX_DISABLED)
#define NGT_NO_AVX
#else
#if defined(__AVX512F__) && defined(__AVX512DQ__)
#define NGT_AVX512
#elif defined(__AVX2__)
#define NGT_AVX2
#else
#define NGT_NO_AVX
#endif
#endif
extern void (*NGT_LOG_TRACE_)(const std::string&);
extern void (*NGT_LOG_DEBUG_)(const std::string&);
extern void (*NGT_LOG_INFO_)(const std::string&);
extern void (*NGT_LOG_WARNING_)(const std::string&);
extern void (*NGT_LOG_FATAL_)(const std::string&);
extern void (*NGT_LOG_ERROR_)(const std::string&);

View File

@ -1,58 +0,0 @@
//
// Copyright (C) 2015-2020 Yahoo Japan Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
#pragma once
// Begin of cmake defines
#cmakedefine NGT_SHARED_MEMORY_ALLOCATOR // use shared memory for indexes
#cmakedefine NGT_GRAPH_CHECK_VECTOR // use vector to check whether accessed
#cmakedefine NGT_AVX_DISABLED // not use avx to compare
#cmakedefine NGT_LARGE_DATASET // more than 10M objects
#cmakedefine NGT_DISTANCE_COMPUTATION_COUNT // count # of distance computations
// End of cmake defines
//////////////////////////////////////////////////////////////////////////
// Release Definitions for OSS
//#define NGT_DISTANCE_COMPUTATION_COUNT
#define NGT_CREATION_EDGE_SIZE 10
#define NGT_EXPLORATION_COEFFICIENT 1.1
#define NGT_INSERTION_EXPLORATION_COEFFICIENT 1.1
#define NGT_SHARED_MEMORY_MAX_SIZE 1024 // MB
#define NGT_FORCED_REMOVE // When errors occur due to the index inconsistency, ignore them.
#define NGT_COMPACT_VECTOR
#define NGT_GRAPH_READ_ONLY_GRAPH
#ifdef NGT_LARGE_DATASET
#define NGT_GRAPH_CHECK_HASH_BASED_BOOLEAN_SET
#else
#define NGT_GRAPH_CHECK_VECTOR
#endif
#if defined(NGT_AVX_DISABLED)
#define NGT_NO_AVX
#else
#if defined(__AVX512F__) && defined(__AVX512DQ__)
#define NGT_AVX512
#elif defined(__AVX2__)
#define NGT_AVX2
#else
#define NGT_NO_AVX
#endif
#endif

0
internal/core/src/index/thirdparty/annoy/examples/s_compile_cpp.sh vendored Executable file → Normal file
View File

0
internal/core/src/index/thirdparty/build.sh vendored Executable file → Normal file
View File

0
internal/core/src/index/thirdparty/faiss/Clustering.cpp vendored Executable file → Normal file
View File

0
internal/core/src/index/thirdparty/faiss/Clustering.h vendored Executable file → Normal file
View File

0
internal/core/src/index/thirdparty/faiss/build-aux/config.guess vendored Executable file → Normal file
View File

0
internal/core/src/index/thirdparty/faiss/build-aux/config.sub vendored Executable file → Normal file
View File

0
internal/core/src/index/thirdparty/faiss/build-aux/install-sh vendored Executable file → Normal file
View File

0
internal/core/src/index/thirdparty/faiss/build.sh vendored Executable file → Normal file
View File

0
internal/core/src/index/thirdparty/faiss/configure vendored Executable file → Normal file
View File

View File

@ -1 +0,0 @@
.

View File

@ -0,0 +1 @@
.

View File

@ -19,6 +19,7 @@ add_library(milvus_indexbuilder SHARED
${INDEXBUILDER_FILES}
)
find_library(TBB NAMES tbb)
set(PLATFORM_LIBS dl)
if (MSYS)
set(PLATFORM_LIBS )
@ -26,12 +27,12 @@ endif ()
# link order matters
target_link_libraries(milvus_indexbuilder
knowhere
milvus_config
milvus_common
milvus_utils
milvus_proto
knowhere
tbb
${TBB}
log
${PLATFORM_LIBS}
pthread

View File

@ -10,8 +10,10 @@
// or implied. See the License for the specific language governing permissions and limitations under the License
#include <string>
#ifndef __APPLE__
#include <malloc.h>
#include "index/knowhere/knowhere/index/vector_index/adapter/VectorAdapter.h"
#endif
#include "knowhere/index/vector_index/adapter/VectorAdapter.h"
#include "indexbuilder/IndexWrapper.h"
#include "indexbuilder/index_c.h"
@ -47,7 +49,9 @@ void
DeleteIndex(CIndex index) {
auto cIndex = (milvus::indexbuilder::IndexWrapper*)index;
delete cIndex;
#ifndef __APPLE__
malloc_trim(0);
#endif
}
CStatus

View File

@ -16,7 +16,7 @@
#include <tuple>
#include <vector>
#include "index/knowhere/knowhere/index/IndexType.h"
#include "knowhere/index/IndexType.h"
namespace milvus::indexbuilder {

View File

@ -21,9 +21,3 @@ set(LOG_FILES ${MILVUS_ENGINE_SRC}/log/Log.cpp
add_library(log STATIC ${LOG_FILES})
set_target_properties(log PROPERTIES RULE_LAUNCH_COMPILE "")
set_target_properties(log PROPERTIES RULE_LAUNCH_LINK "")
if(MSYS)
target_link_libraries( log PRIVATE )
else()
target_link_libraries( log PRIVATE fiu )
endif()

View File

@ -46,7 +46,13 @@ LogOut(const char* pattern, ...) {
void
SetThreadName(const std::string& name) {
// Note: the name cannot exceed 16 bytes
#ifdef __APPLE__
pthread_setname_np(name.c_str());
#elif __linux__
pthread_setname_np(pthread_self(), name.c_str());
#else
#error "Unsupported SetThreadName";
#endif
}
std::string
@ -84,17 +90,26 @@ get_system_boottime() {
int64_t
get_thread_starttime() {
#ifdef __APPLE__
uint64_t tid;
pthread_threadid_np(NULL, &tid);
#elif __linux__
int64_t tid = gettid();
#else
#error "Unsupported SetThreadName";
#endif
int64_t pid = getpid();
char filename[256];
snprintf(filename, sizeof(filename), "/proc/%ld/task/%ld/stat", pid, tid);
snprintf(filename, sizeof(filename), "/proc/%lld/task/%lld/stat", (long long)pid, (long long)tid); // NOLINT
int64_t val = 0;
char comm[16], state;
FILE* thread_stat = fopen(filename, "r");
auto ret = fscanf(thread_stat, "%ld %s %s ", &val, comm, &state);
auto ret = fscanf(thread_stat, "%lld %s %s ", (long long*)&val, comm, &state); // NOLINT
for (auto i = 4; i < 23; i++) {
ret = fscanf(thread_stat, "%ld ", &val);
ret = fscanf(thread_stat, "%lld ", (long long*)&val); // NOLINT
if (i == 22) {
break;
}

View File

@ -14,7 +14,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <fiu/fiu-local.h>
#include <libgen.h>
#include <cctype>
#include <string>
@ -152,22 +151,16 @@ LogMgr::Filename(const std::string& filename) {
LogMgr&
LogMgr::Level(std::unordered_map<std::string, bool>& enables) {
fiu_do_on("LogMgr.Level.trace_enable_to_false", enables["trace"] = false);
enable(el_config_, el::Level::Trace, enables["trace"]);
fiu_do_on("LogMgr.Level.info_enable_to_false", enables["info"] = false);
enable(el_config_, el::Level::Info, enables["info"]);
fiu_do_on("LogMgr.Level.debug_enable_to_false", enables["debug"] = false);
enable(el_config_, el::Level::Debug, enables["debug"]);
fiu_do_on("LogMgr.Level.warning_enable_to_false", enables["warning"] = false);
enable(el_config_, el::Level::Warning, enables["warning"]);
fiu_do_on("LogMgr.Level.error_enable_to_false", enables["error"] = false);
enable(el_config_, el::Level::Error, enables["error"]);
fiu_do_on("LogMgr.Level.fatal_enable_to_false", enables["fatal"] = false);
enable(el_config_, el::Level::Fatal, enables["fatal"]);
return *this;
@ -183,7 +176,6 @@ LogMgr::To(bool log_to_stdout, bool log_to_file) {
LogMgr&
LogMgr::Rotate(int64_t max_log_file_size, int64_t log_rotate_num) {
fiu_do_on("LogMgr.Rotate.set_max_log_size_small_than_min", max_log_file_size = MAX_LOG_FILE_SIZE_MIN - 1);
if (max_log_file_size < MAX_LOG_FILE_SIZE_MIN || max_log_file_size > MAX_LOG_FILE_SIZE_MAX) {
std::string msg = "max_log_file_size must in range[" + std::to_string(MAX_LOG_FILE_SIZE_MIN) + ", " +
std::to_string(MAX_LOG_FILE_SIZE_MAX) + "], now is " + std::to_string(max_log_file_size);
@ -197,7 +189,6 @@ LogMgr::Rotate(int64_t max_log_file_size, int64_t log_rotate_num) {
// set delete_exceeds = 0 means disable throw away log file even they reach certain limit.
if (log_rotate_num != 0) {
fiu_do_on("LogMgr.Rotate.delete_exceeds_small_than_min", log_rotate_num = LOG_ROTATE_NUM_MIN - 1);
if (log_rotate_num < LOG_ROTATE_NUM_MIN || log_rotate_num > LOG_ROTATE_NUM_MAX) {
std::string msg = "log_rotate_num must in range[" + std::to_string(LOG_ROTATE_NUM_MIN) + ", " +
std::to_string(LOG_ROTATE_NUM_MAX) + "], now is " + std::to_string(log_rotate_num);

View File

@ -12,8 +12,15 @@
#include <string>
#include <vector>
#include <faiss/utils/distances.h>
#ifdef __APPLE__
#include "knowhere/index/vector_index/impl/bruteforce/BruteForce.h"
#include "knowhere/common/Heap.h"
#elif __linux__
#include <faiss/utils/BinaryDistance.h>
#include <faiss/utils/distances.h>
#else
#error "Unsupported OS environment.";
#endif
#include "SearchBruteForce.h"
#include "SubSearchResult.h"
@ -35,6 +42,7 @@ raw_search(MetricType metric_type,
float* D,
idx_t* labels,
const BitsetView bitset) {
#ifdef __linux__
using namespace faiss; // NOLINT
if (metric_type == METRIC_Jaccard || metric_type == METRIC_Tanimoto) {
float_maxheap_array_t res = {size_t(n), size_t(k), labels, D};
@ -62,6 +70,9 @@ raw_search(MetricType metric_type,
std::string("binary search not support metric type: ") + segcore::MetricTypeToString(metric_type);
PanicInfo(msg);
}
#else
PanicInfo("Unsupported brute force for binary search on current OS environment!");
#endif
}
SubSearchResult
@ -100,18 +111,29 @@ FloatSearchBruteForce(const dataset::SearchDataset& dataset,
SubSearchResult sub_qr(num_queries, topk, metric_type, round_decimal);
auto query_data = reinterpret_cast<const float*>(dataset.query_data);
auto chunk_data = reinterpret_cast<const float*>(chunk_data_raw);
#ifdef __APPLE__
if (metric_type == MetricType::METRIC_L2) {
knowhere::float_maxheap_array_t buf{(size_t)num_queries, (size_t)topk, sub_qr.get_ids(),
sub_qr.get_distances()};
knowhere::knn_L2sqr_sse(query_data, chunk_data, dim, num_queries, size_per_chunk, &buf, bitset);
} else {
knowhere::float_minheap_array_t buf{(size_t)num_queries, (size_t)topk, sub_qr.get_ids(),
sub_qr.get_distances()};
knowhere::knn_inner_product_sse(query_data, chunk_data, dim, num_queries, size_per_chunk, &buf, bitset);
}
#elif __linux__
if (metric_type == MetricType::METRIC_L2) {
faiss::float_maxheap_array_t buf{(size_t)num_queries, (size_t)topk, sub_qr.get_ids(), sub_qr.get_distances()};
faiss::knn_L2sqr(query_data, chunk_data, dim, num_queries, size_per_chunk, &buf, bitset);
sub_qr.round_values();
return sub_qr;
} else {
faiss::float_minheap_array_t buf{(size_t)num_queries, (size_t)topk, sub_qr.get_ids(), sub_qr.get_distances()};
faiss::knn_inner_product(query_data, chunk_data, dim, num_queries, size_per_chunk, &buf, bitset);
sub_qr.round_values();
return sub_qr;
}
#else
#error "Unsupported OS environment!";
#endif
sub_qr.round_values();
return sub_qr;
}
SubSearchResult

View File

@ -11,8 +11,6 @@
#pragma once
#include <faiss/utils/BinaryDistance.h>
#include "common/Schema.h"
#include "query/SubSearchResult.h"
#include "query/helper.h"

View File

@ -28,7 +28,7 @@ class ExecPlanNodeVisitor : public PlanNodeVisitor {
visit(BinaryVectorANNS& node) override;
void
visit(RetrievePlanNode& node);
visit(RetrievePlanNode& node) override;
public:
using RetType = SearchResult;

View File

@ -25,7 +25,7 @@ class ExtractInfoPlanNodeVisitor : public PlanNodeVisitor {
visit(BinaryVectorANNS& node) override;
void
visit(RetrievePlanNode& node);
visit(RetrievePlanNode& node) override;
public:
explicit ExtractInfoPlanNodeVisitor(ExtractedPlanInfo& plan_info) : plan_info_(plan_info) {

View File

@ -32,6 +32,7 @@ add_library(milvus_segcore SHARED
${SEGCORE_FILES}
)
find_library(TBB NAMES tbb)
set(PLATFORM_LIBS dl)
if (MSYS)
set(PLATFORM_LIBS )
@ -41,7 +42,7 @@ target_link_libraries(milvus_segcore
${PLATFORM_LIBS}
log
pthread
tbb
${TBB}
${OpenMP_CXX_FLAGS}
knowhere
milvus_common

View File

@ -16,7 +16,6 @@
#include "AckResponder.h"
#include "common/Schema.h"
#include "knowhere/index/vector_index/IndexIVF.h"
#include "segcore/Record.h"
namespace milvus::segcore {

View File

@ -13,7 +13,9 @@
#include <thread>
#include "common/SystemProperty.h"
#ifdef __linux__
#include "knowhere/index/vector_index/IndexIVF.h"
#endif
#include "knowhere/index/vector_index/adapter/VectorAdapter.h"
#include "segcore/FieldIndexing.h"
@ -21,6 +23,7 @@ namespace milvus::segcore {
void
VectorFieldIndexing::BuildIndexRange(int64_t ack_beg, int64_t ack_end, const VectorBase* vec_base) {
#ifdef __linux__
AssertInfo(field_meta_.get_data_type() == DataType::VECTOR_FLOAT, "Data type of vector field is not VECTOR_FLOAT");
auto dim = field_meta_.get_dim();
@ -39,6 +42,9 @@ VectorFieldIndexing::BuildIndexRange(int64_t ack_beg, int64_t ack_end, const Vec
indexing->AddWithoutIds(dataset, conf);
data_[chunk_id] = std::move(indexing);
}
#else
throw std::runtime_error("Unsupported BuildIndexRange on current platform!");
#endif
}
knowhere::Config

View File

@ -21,7 +21,7 @@
#include "InsertRecord.h"
#include "common/Schema.h"
#include "knowhere/index/structured_index_simple/StructuredIndexSort.h"
#include "knowhere/index/vector_index/IndexIVF.h"
#include "knowhere/index/vector_index/VecIndex.h"
#include "segcore/SegcoreConfig.h"
namespace milvus::segcore {
@ -36,6 +36,7 @@ class FieldIndexing {
FieldIndexing(const FieldIndexing&) = delete;
FieldIndexing&
operator=(const FieldIndexing&) = delete;
virtual ~FieldIndexing() = default;
// Do this in parallel
virtual void

View File

@ -119,7 +119,7 @@ class SegmentGrowingImpl : public SegmentGrowing {
enable_small_index_ = false;
}
ssize_t
int64_t
get_row_count() const override {
return record_.ack_responder_.GetAck();
}

View File

@ -11,7 +11,7 @@
#pragma once
#include "index/thirdparty/faiss/MetricType.h"
#include <knowhere/common/MetricType.h>
namespace milvus::segcore {
static inline bool

View File

@ -13,8 +13,7 @@
#include <string>
#include <exception>
#include <stdexcept>
#include "index/thirdparty/faiss/MetricType.h"
#include <knowhere/common/MetricType.h>
namespace milvus::segcore {

View File

@ -9,8 +9,11 @@
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#include <iostream>
#ifndef __APPLE__
#include <malloc.h>
#endif
#include <iostream>
#include "segcore/collection_c.h"
#include "segcore/Collection.h"
@ -25,7 +28,9 @@ void
DeleteCollection(CCollection collection) {
auto col = (milvus::segcore::Collection*)collection;
delete col;
#ifndef __APPLE__
malloc_trim(0);
#endif
}
const char*

View File

@ -11,8 +11,8 @@
#include "common/LoadInfo.h"
#include "exceptions/EasyAssert.h"
#include "index/knowhere/knowhere/common/BinarySet.h"
#include "index/knowhere/knowhere/index/vector_index/VecIndexFactory.h"
#include "knowhere/common/BinarySet.h"
#include "knowhere/index/vector_index/VecIndexFactory.h"
#include "segcore/load_index_c.h"
CStatus

View File

@ -15,8 +15,4 @@ aux_source_directory( ${MILVUS_ENGINE_SRC}/utils UTILS_FILES )
add_library( milvus_utils STATIC ${UTILS_FILES} )
if(MSYS)
target_link_libraries( milvus_utils PRIVATE milvus_exceptions)
else()
target_link_libraries( milvus_utils PRIVATE fiu milvus_exceptions )
endif()

View File

@ -29,9 +29,7 @@ set( VALUE_SRCS config/ConfigInit.cpp
ValueType.cpp
)
set( VALUE_LIBS yaml-cpp
fiu
)
set( VALUE_LIBS yaml-cpp)
create_library(
TARGET value
@ -41,20 +39,17 @@ create_library(
if ( BUILD_UNIT_TEST )
create_library(
TARGET value-fiu
TARGET value-test
SRCS ${VALUE_SRCS}
LIBS ${VALUE_LIBS}
DEFS FIU_ENABLE
)
target_compile_definitions(value-fiu PRIVATE FIU_ENABLE)
LIBS ${VALUE_LIBS})
set(GTEST_LIBS gtest gtest_main gmock gmock_main)
create_executable(
TARGET ConfigMgrTest
SRCS config/ConfigMgrTest
LIBS value-fiu ${GTEST_LIBS}
DEFS FIU_ENABLE
LIBS value-test ${GTEST_LIBS}
DEFS ""
)
add_test ( NAME ConfigMgrTest
@ -65,7 +60,7 @@ if ( BUILD_UNIT_TEST )
TARGET ServerConfigTest
SRCS config/ServerConfigTest
LIBS value-fiu ${GTEST_LIBS}
DEFS FIU_ENABLE
DEFS ""
)
add_test ( NAME ServerConfigTest
@ -75,8 +70,8 @@ if ( BUILD_UNIT_TEST )
create_executable(
TARGET ValueTypeTest
SRCS ValueTypeTest1 ValueTypeTest2
LIBS value-fiu ${GTEST_LIBS}
DEFS FIU_ENABLE
LIBS value-test ${GTEST_LIBS}
DEFS ""
)
add_test ( NAME ValueTypeTest

View File

@ -1,34 +0,0 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
#include <fiu-control.h>
#include <fiu/fiu-local.h>
#include <gtest/gtest.h>
#include "ConfigMgr.h"
#include "value/config/ServerConfig.h"
namespace milvus {
// TODO: need a safe directory for testing
// TEST(ConfigMgrTest, set_version) {
// ConfigMgr::GetInstance().Init();
// ConfigMgr::GetInstance().LoadMemory(R"(
// version: 0.1
//)");
// ConfigMgr::GetInstance().FilePath() = "/tmp/milvus_unittest_configmgr.yaml";
//
// ASSERT_EQ(ConfigMgr::GetInstance().Get("version"), "0.1");
// ConfigMgr::GetInstance().Set("version", "100.0");
// ASSERT_EQ(ConfigMgr::GetInstance().Get("version"), "100.0");
//}
} // namespace milvus

View File

@ -9,8 +9,6 @@
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
#include <fiu/fiu-local.h>
#include <iostream>
#include <sstream>
#include <string>
@ -45,7 +43,6 @@ ParseGPUDevices(const std::string& str) {
std::string device;
while (std::getline(ss, device, ',')) {
fiu_do_on("ParseGPUDevices.invalid_format", device = "");
if (device.length() < 4) {
/* Invalid format string */
return {};

View File

@ -9,15 +9,11 @@
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
#include <fiu-control.h>
#include <fiu/fiu-local.h>
#include <gtest/gtest.h>
#include "value/config/ServerConfig.h"
TEST(ServerConfigTest, parse_invalid_devices) {
fiu_init(0);
fiu_enable("ParseGPUDevices.invalid_format", 1, nullptr, 0);
auto collections = milvus::ParseGPUDevices("gpu0,gpu1");
ASSERT_EQ(collections.size(), 0);
}

View File

@ -38,10 +38,15 @@ message( STATUS "Thirdparty downloaded file path: ${THIRDPARTY_DOWNLOAD_PATH}" )
set( THREADS_PREFER_PTHREAD_FLAG ON )
find_package( Threads REQUIRED )
add_subdirectory( knowhere )
# ****************************** Thirdparty googletest ***************************************
if ( MILVUS_BUILD_TESTS )
if ( MILVUS_BUILD_TESTS)
add_subdirectory( gtest )
add_subdirectory( google_benchmark )
endif()
if ( MILVUS_BUILD_TESTS AND LINUX )
add_subdirectory( profiler )
endif()
@ -55,5 +60,4 @@ if ( MILVUS_WITH_OPENTRACING )
endif()
add_subdirectory( protobuf )
add_subdirectory( boost_ext )
add_subdirectory( fiu )
add_subdirectory( boost_ext )

View File

@ -1,66 +0,0 @@
#-------------------------------------------------------------------------------
# Copyright (C) 2019-2020 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied. See the License for the specific language governing permissions and limitations under the License.
#-------------------------------------------------------------------------------
if(NOT DEFINED FIU_VERSION)
set(FIU_VERSION 1.00)
endif()
if ( DEFINED ENV{KNOWHERE_FIU_URL} )
set( FIU_SOURCE_URL "$ENV{MILVUS_FIU_URL}" )
else ()
set( FIU_SOURCE_URL "https://github.com/albertito/libfiu/archive/${FIU_VERSION}.tar.gz" )
endif ()
macro( build_fiu )
message( STATUS "Building FIU-${FIU_VERSION} from source" )
ExternalProject_Add(
fiu_ep
PREFIX ${CMAKE_BINARY_DIR}/3rdparty_download/fiu-subbuild
DOWNLOAD_DIR ${THIRDPARTY_DOWNLOAD_PATH}
INSTALL_DIR ${CMAKE_CURRENT_BINARY_DIR}
URL ${FIU_SOURCE_URL}
URL_MD5 "75f9d076daf964c9410611701f07c61b"
CONFIGURE_COMMAND ""
BUILD_IN_SOURCE 1
BUILD_COMMAND ${MAKE}
INSTALL_COMMAND ${MAKE} "PREFIX=<INSTALL_DIR>" install
${EP_LOG_OPTIONS}
)
ExternalProject_Get_Property( fiu_ep INSTALL_DIR )
if( NOT IS_DIRECTORY ${INSTALL_DIR}/include )
file( MAKE_DIRECTORY "${INSTALL_DIR}/include" )
endif()
add_library( fiu SHARED IMPORTED )
set_target_properties( fiu
PROPERTIES
IMPORTED_GLOBAL TRUE
IMPORTED_LOCATION ${INSTALL_DIR}/lib/libfiu.so
INTERFACE_INCLUDE_DIRECTORIES ${INSTALL_DIR}/include )
add_dependencies(fiu fiu_ep)
endmacro()
if (WIN32)
# nothing
message("skip building fiu on windows")
else ()
build_fiu()
install( FILES ${INSTALL_DIR}/lib/libfiu.so
${INSTALL_DIR}/lib/libfiu.so.0
${INSTALL_DIR}/lib/libfiu.so.1.00
DESTINATION lib )
get_target_property( var fiu INTERFACE_INCLUDE_DIRECTORIES )
message( STATUS ${var} )
set_directory_properties( PROPERTY INCLUDE_DIRECTORIES ${var} )
endif ()

Some files were not shown because too many files have changed in this diff Show More