diff --git a/CHANGELOG.md b/CHANGELOG.md index eda2e7fda2..1f2696e194 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,7 +2,7 @@ Please mark all change in change log and use the ticket from JIRA. -# Milvus 0.5.0 (TODO) +# Milvus 0.5.0 (2019-10-21) ## Bug - MS-568 - Fix gpuresource free error @@ -26,6 +26,15 @@ Please mark all change in change log and use the ticket from JIRA. - MS-653 - When config check fail, Milvus close without message - MS-654 - Describe index timeout when building index - MS-658 - Fix SQ8 Hybrid can't search +- MS-665 - IVF_SQ8H search crash when no GPU resource in search_resources +- \#9 - Change default gpu_cache_capacity to 4 +- \#20 - C++ sdk example get grpc error +- \#23 - Add unittest to improve code coverage +- \#31 - make clang-format failed after run build.sh -l +- \#39 - Create SQ8H index hang if using github server version +- \#30 - Some troubleshoot messages in Milvus do not provide enough information +- \#48 - Config unittest failed +- \#59 - Topk result is incorrect for small dataset ## Improvement - MS-552 - Add and change the easylogging library @@ -47,6 +56,7 @@ Please mark all change in change log and use the ticket from JIRA. - MS-626 - Refactor DataObj to support cache any type data - MS-648 - Improve unittest - MS-655 - Upgrade SPTAG +- \#42 - Put union of index_build_device and search resources to gpu_pool ## New Feature - MS-614 - Preload table at startup @@ -68,6 +78,7 @@ Please mark all change in change log and use the ticket from JIRA. - MS-624 - Re-organize project directory for open-source - MS-635 - Add compile option to support customized faiss - MS-660 - add ubuntu_build_deps.sh +- \#18 - Add all test cases # Milvus 0.4.0 (2019-09-12) diff --git a/README.md b/README.md index 153969bdde..3d1979be4a 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,9 @@ ![Milvuslogo](https://github.com/milvus-io/docs/blob/master/assets/milvus_logo.png) + ![LICENSE](https://img.shields.io/badge/license-Apache--2.0-brightgreen) ![Language](https://img.shields.io/badge/language-C%2B%2B-blue) +[![codebeat badge](https://codebeat.co/badges/e030a4f6-b126-4475-a938-4723d54ec3a7?style=plastic)](https://codebeat.co/projects/github-com-jinhai-cn-milvus-master) - [Slack Community](https://join.slack.com/t/milvusio/shared_invite/enQtNzY1OTQ0NDI3NjMzLWNmYmM1NmNjOTQ5MGI5NDhhYmRhMGU5M2NhNzhhMDMzY2MzNDdlYjM5ODQ5MmE3ODFlYzU3YjJkNmVlNDQ2ZTk) - [Twitter](https://twitter.com/milvusio) @@ -54,6 +56,7 @@ Keep up-to-date with newest releases and latest updates by reading Milvus [relea You can track system performance on Prometheus-based GUI monitor dashboards. ## Architecture + ![Milvus_arch](https://github.com/milvus-io/docs/blob/master/assets/milvus_arch.png) ## Get started diff --git a/ci/function/file_transfer.groovy b/ci/function/file_transfer.groovy new file mode 100644 index 0000000000..bebae14832 --- /dev/null +++ b/ci/function/file_transfer.groovy @@ -0,0 +1,10 @@ +def FileTransfer (sourceFiles, remoteDirectory, remoteIP, protocol = "ftp", makeEmptyDirs = true) { + if (protocol == "ftp") { + ftpPublisher masterNodeName: '', paramPublish: [parameterName: ''], alwaysPublishFromMaster: false, continueOnError: false, failOnError: true, publishers: [ + [configName: "${remoteIP}", transfers: [ + [asciiMode: false, cleanRemote: false, excludes: '', flatten: false, makeEmptyDirs: "${makeEmptyDirs}", noDefaultExcludes: false, patternSeparator: '[, ]+', remoteDirectory: "${remoteDirectory}", remoteDirectorySDF: false, removePrefix: '', sourceFiles: "${sourceFiles}"]], usePromotionTimestamp: true, useWorkspaceInPromotion: false, verbose: true + ] + ] + } +} +return this diff --git a/ci/jenkins/Jenkinsfile b/ci/jenkins/Jenkinsfile new file mode 100644 index 0000000000..fbdf3a3096 --- /dev/null +++ b/ci/jenkins/Jenkinsfile @@ -0,0 +1,152 @@ +pipeline { + agent none + + options { + timestamps() + } + + parameters{ + choice choices: ['Release', 'Debug'], description: '', name: 'BUILD_TYPE' + string defaultValue: 'cf1434e7-5a4b-4d25-82e8-88d667aef9e5', description: 'GIT CREDENTIALS ID', name: 'GIT_CREDENTIALS_ID', trim: true + string defaultValue: 'registry.zilliz.com', description: 'DOCKER REGISTRY URL', name: 'DOKCER_REGISTRY_URL', trim: true + string defaultValue: 'ba070c98-c8cc-4f7c-b657-897715f359fc', description: 'DOCKER CREDENTIALS ID', name: 'DOCKER_CREDENTIALS_ID', trim: true + string defaultValue: 'http://192.168.1.202/artifactory/milvus', description: 'JFROG ARTFACTORY URL', name: 'JFROG_ARTFACTORY_URL', trim: true + string defaultValue: '1a527823-d2b7-44fd-834b-9844350baf14', description: 'JFROG CREDENTIALS ID', name: 'JFROG_CREDENTIALS_ID', trim: true + } + + environment { + PROJECT_NAME = "milvus" + LOWER_BUILD_TYPE = params.BUILD_TYPE.toLowerCase() + SEMVER = "${BRANCH_NAME}" + JOBNAMES = env.JOB_NAME.split('/') + PIPELINE_NAME = "${JOBNAMES[0]}" + } + + stages { + stage("Ubuntu 18.04") { + environment { + OS_NAME = "ubuntu18.04" + PACKAGE_VERSION = VersionNumber([ + versionNumberString : '${SEMVER}-${LOWER_BUILD_TYPE}-ubuntu18.04-x86_64-${BUILD_DATE_FORMATTED, "yyyyMMdd"}-${BUILDS_TODAY}' + ]); + DOCKER_VERSION = "${SEMVER}-${OS_NAME}-${LOWER_BUILD_TYPE}" + } + + stages { + stage("Run Build") { + agent { + kubernetes { + label 'build' + defaultContainer 'jnlp' + yamlFile 'ci/jenkins/pod/milvus-build-env-pod.yaml' + } + } + + stages { + stage('Build') { + steps { + container('milvus-build-env') { + script { + load "${env.WORKSPACE}/ci/jenkins/jenkinsfile/build.groovy" + } + } + } + } + stage('Code Coverage') { + steps { + container('milvus-build-env') { + script { + load "${env.WORKSPACE}/ci/jenkins/jenkinsfile/coverage.groovy" + } + } + } + } + stage('Upload Package') { + steps { + container('milvus-build-env') { + script { + load "${env.WORKSPACE}/ci/jenkins/jenkinsfile/package.groovy" + } + } + } + } + } + } + + stage("Publish docker images") { + agent { + kubernetes { + label 'publish' + defaultContainer 'jnlp' + yamlFile 'ci/jenkins/pod/docker-pod.yaml' + } + } + + stages { + stage('Publish') { + steps { + container('publish-images'){ + script { + load "${env.WORKSPACE}/ci/jenkins/jenkinsfile/publishImages.groovy" + } + } + } + } + } + } + + stage("Deploy to Development") { + agent { + kubernetes { + label 'dev-test' + defaultContainer 'jnlp' + yamlFile 'ci/jenkins/pod/testEnvironment.yaml' + } + } + + stages { + stage("Deploy to Dev") { + steps { + container('milvus-test-env') { + script { + load "${env.WORKSPACE}/ci/jenkins/jenkinsfile/deploySingle2Dev.groovy" + } + } + } + } + + stage("Dev Test") { + steps { + container('milvus-test-env') { + script { + load "${env.WORKSPACE}/ci/jenkins/jenkinsfile/singleDevTest.groovy" + } + } + } + } + + stage ("Cleanup Dev") { + steps { + container('milvus-test-env') { + script { + load "${env.WORKSPACE}/ci/jenkins/jenkinsfile/cleanupSingleDev.groovy" + } + } + } + } + } + post { + unsuccessful { + container('milvus-test-env') { + script { + load "${env.WORKSPACE}/ci/jenkins/jenkinsfile/cleanupSingleDev.groovy" + } + } + } + } + } + } + } + } +} + diff --git a/ci/jenkins/jenkinsfile/build.groovy b/ci/jenkins/jenkinsfile/build.groovy new file mode 100644 index 0000000000..14d0414f4f --- /dev/null +++ b/ci/jenkins/jenkinsfile/build.groovy @@ -0,0 +1,9 @@ +timeout(time: 60, unit: 'MINUTES') { + dir ("ci/jenkins/scripts") { + sh "./build.sh -l" + withCredentials([usernamePassword(credentialsId: "${params.JFROG_CREDENTIALS_ID}", usernameVariable: 'USERNAME', passwordVariable: 'PASSWORD')]) { + sh "export JFROG_ARTFACTORY_URL='${params.JFROG_ARTFACTORY_URL}' && export JFROG_USER_NAME='${USERNAME}' && export JFROG_PASSWORD='${PASSWORD}' && ./build.sh -t ${params.BUILD_TYPE} -o /opt/milvus -d /opt/milvus -j -u -c" + } + } +} + diff --git a/ci/jenkins/jenkinsfile/cleanupSingleDev.groovy b/ci/jenkins/jenkinsfile/cleanupSingleDev.groovy new file mode 100644 index 0000000000..6e85a678be --- /dev/null +++ b/ci/jenkins/jenkinsfile/cleanupSingleDev.groovy @@ -0,0 +1,9 @@ +try { + sh "helm del --purge ${env.PIPELINE_NAME}-${env.BUILD_NUMBER}-single-gpu" +} catch (exc) { + def helmResult = sh script: "helm status ${env.PIPELINE_NAME}-${env.BUILD_NUMBER}-single-gpu", returnStatus: true + if (!helmResult) { + sh "helm del --purge ${env.PIPELINE_NAME}-${env.BUILD_NUMBER}-single-gpu" + } + throw exc +} diff --git a/ci/jenkins/jenkinsfile/coverage.groovy b/ci/jenkins/jenkinsfile/coverage.groovy new file mode 100644 index 0000000000..7c3b16c029 --- /dev/null +++ b/ci/jenkins/jenkinsfile/coverage.groovy @@ -0,0 +1,10 @@ +timeout(time: 60, unit: 'MINUTES') { + dir ("ci/jenkins/scripts") { + sh "./coverage.sh -o /opt/milvus -u root -p 123456 -t \$POD_IP" + // Set some env variables so codecov detection script works correctly + withCredentials([[$class: 'StringBinding', credentialsId: "${env.PIPELINE_NAME}-codecov-token", variable: 'CODECOV_TOKEN']]) { + sh 'curl -s https://codecov.io/bash | bash -s - -f output_new.info || echo "Codecov did not collect coverage reports"' + } + } +} + diff --git a/ci/jenkins/jenkinsfile/deploySingle2Dev.groovy b/ci/jenkins/jenkinsfile/deploySingle2Dev.groovy new file mode 100644 index 0000000000..2ab13486a6 --- /dev/null +++ b/ci/jenkins/jenkinsfile/deploySingle2Dev.groovy @@ -0,0 +1,14 @@ +try { + sh 'helm init --client-only --skip-refresh --stable-repo-url https://kubernetes.oss-cn-hangzhou.aliyuncs.com/charts' + sh 'helm repo update' + dir ('milvus-helm') { + checkout([$class: 'GitSCM', branches: [[name: "0.5.0"]], doGenerateSubmoduleConfigurations: false, extensions: [], submoduleCfg: [], userRemoteConfigs: [[credentialsId: "${params.GIT_CREDENTIALS_ID}", url: "https://github.com/milvus-io/milvus-helm.git", name: 'origin', refspec: "+refs/heads/0.5.0:refs/remotes/origin/0.5.0"]]]) + dir ("milvus-gpu") { + sh "helm install --wait --timeout 300 --set engine.image.tag=${DOCKER_VERSION} --set expose.type=clusterIP --name ${env.PIPELINE_NAME}-${env.BUILD_NUMBER}-single-gpu -f ci/values.yaml --namespace milvus ." + } + } +} catch (exc) { + echo 'Helm running failed!' + sh "helm del --purge ${env.PIPELINE_NAME}-${env.BUILD_NUMBER}-single-gpu" + throw exc +} diff --git a/ci/jenkins/jenkinsfile/package.groovy b/ci/jenkins/jenkinsfile/package.groovy new file mode 100644 index 0000000000..edd6ce88da --- /dev/null +++ b/ci/jenkins/jenkinsfile/package.groovy @@ -0,0 +1,9 @@ +timeout(time: 5, unit: 'MINUTES') { + sh "tar -zcvf ./${PROJECT_NAME}-${PACKAGE_VERSION}.tar.gz -C /opt/ milvus" + withCredentials([usernamePassword(credentialsId: "${params.JFROG_CREDENTIALS_ID}", usernameVariable: 'JFROG_USERNAME', passwordVariable: 'JFROG_PASSWORD')]) { + def uploadStatus = sh(returnStatus: true, script: "curl -u${JFROG_USERNAME}:${JFROG_PASSWORD} -T ./${PROJECT_NAME}-${PACKAGE_VERSION}.tar.gz ${params.JFROG_ARTFACTORY_URL}/milvus/package/${PROJECT_NAME}-${PACKAGE_VERSION}.tar.gz") + if (uploadStatus != 0) { + error("\" ${PROJECT_NAME}-${PACKAGE_VERSION}.tar.gz \" upload to \" ${params.JFROG_ARTFACTORY_URL}/milvus/package/${PROJECT_NAME}-${PACKAGE_VERSION}.tar.gz \" failed!") + } + } +} diff --git a/ci/jenkins/jenkinsfile/publishImages.groovy b/ci/jenkins/jenkinsfile/publishImages.groovy new file mode 100644 index 0000000000..62df0c73bf --- /dev/null +++ b/ci/jenkins/jenkinsfile/publishImages.groovy @@ -0,0 +1,47 @@ +container('publish-images') { + timeout(time: 15, unit: 'MINUTES') { + dir ("docker/deploy/${OS_NAME}") { + def binaryPackage = "${PROJECT_NAME}-${PACKAGE_VERSION}.tar.gz" + + withCredentials([usernamePassword(credentialsId: "${params.JFROG_CREDENTIALS_ID}", usernameVariable: 'JFROG_USERNAME', passwordVariable: 'JFROG_PASSWORD')]) { + def downloadStatus = sh(returnStatus: true, script: "curl -u${JFROG_USERNAME}:${JFROG_PASSWORD} -O ${params.JFROG_ARTFACTORY_URL}/milvus/package/${binaryPackage}") + + if (downloadStatus != 0) { + error("\" Download \" ${params.JFROG_ARTFACTORY_URL}/milvus/package/${binaryPackage} \" failed!") + } + } + sh "tar zxvf ${binaryPackage}" + def imageName = "${PROJECT_NAME}/engine:${DOCKER_VERSION}" + + try { + def isExistSourceImage = sh(returnStatus: true, script: "docker inspect --type=image ${imageName} 2>&1 > /dev/null") + if (isExistSourceImage == 0) { + def removeSourceImageStatus = sh(returnStatus: true, script: "docker rmi ${imageName}") + } + + def customImage = docker.build("${imageName}") + + def isExistTargeImage = sh(returnStatus: true, script: "docker inspect --type=image ${params.DOKCER_REGISTRY_URL}/${imageName} 2>&1 > /dev/null") + if (isExistTargeImage == 0) { + def removeTargeImageStatus = sh(returnStatus: true, script: "docker rmi ${params.DOKCER_REGISTRY_URL}/${imageName}") + } + + docker.withRegistry("https://${params.DOKCER_REGISTRY_URL}", "${params.DOCKER_CREDENTIALS_ID}") { + customImage.push() + } + } catch (exc) { + throw exc + } finally { + def isExistSourceImage = sh(returnStatus: true, script: "docker inspect --type=image ${imageName} 2>&1 > /dev/null") + if (isExistSourceImage == 0) { + def removeSourceImageStatus = sh(returnStatus: true, script: "docker rmi ${imageName}") + } + + def isExistTargeImage = sh(returnStatus: true, script: "docker inspect --type=image ${params.DOKCER_REGISTRY_URL}/${imageName} 2>&1 > /dev/null") + if (isExistTargeImage == 0) { + def removeTargeImageStatus = sh(returnStatus: true, script: "docker rmi ${params.DOKCER_REGISTRY_URL}/${imageName}") + } + } + } + } +} diff --git a/ci/jenkins/jenkinsfile/singleDevTest.groovy b/ci/jenkins/jenkinsfile/singleDevTest.groovy new file mode 100644 index 0000000000..ae57ffd42b --- /dev/null +++ b/ci/jenkins/jenkinsfile/singleDevTest.groovy @@ -0,0 +1,22 @@ +timeout(time: 30, unit: 'MINUTES') { + dir ("tests/milvus_python_test") { + sh 'python3 -m pip install -r requirements.txt' + sh "pytest . --alluredir=\"test_out/dev/single/sqlite\" --level=1 --ip ${env.PIPELINE_NAME}-${env.BUILD_NUMBER}-single-gpu-milvus-gpu-engine.milvus.svc.cluster.local" + } + // mysql database backend test + load "${env.WORKSPACE}/ci/jenkins/jenkinsfile/cleanupSingleDev.groovy" + + if (!fileExists('milvus-helm')) { + dir ("milvus-helm") { + checkout([$class: 'GitSCM', branches: [[name: "0.5.0"]], doGenerateSubmoduleConfigurations: false, extensions: [], submoduleCfg: [], userRemoteConfigs: [[credentialsId: "${params.GIT_CREDENTIALS_ID}", url: "https://github.com/milvus-io/milvus-helm.git", name: 'origin', refspec: "+refs/heads/0.5.0:refs/remotes/origin/0.5.0"]]]) + } + } + dir ("milvus-helm") { + dir ("milvus-gpu") { + sh "helm install --wait --timeout 300 --set engine.image.tag=${DOCKER_VERSION} --set expose.type=clusterIP --name ${env.PIPELINE_NAME}-${env.BUILD_NUMBER}-single-gpu -f ci/db_backend/mysql_values.yaml --namespace milvus ." + } + } + dir ("tests/milvus_python_test") { + sh "pytest . --alluredir=\"test_out/dev/single/mysql\" --level=1 --ip ${env.PIPELINE_NAME}-${env.BUILD_NUMBER}-single-gpu-milvus-gpu-engine.milvus.svc.cluster.local" + } +} diff --git a/ci/jenkins/pod/docker-pod.yaml b/ci/jenkins/pod/docker-pod.yaml new file mode 100644 index 0000000000..a4fed0bcad --- /dev/null +++ b/ci/jenkins/pod/docker-pod.yaml @@ -0,0 +1,23 @@ +apiVersion: v1 +kind: Pod +metadata: + labels: + app: publish + componet: docker +spec: + containers: + - name: publish-images + image: registry.zilliz.com/library/docker:v1.0.0 + securityContext: + privileged: true + command: + - cat + tty: true + volumeMounts: + - name: docker-sock + mountPath: /var/run/docker.sock + volumes: + - name: docker-sock + hostPath: + path: /var/run/docker.sock + diff --git a/ci/jenkins/pod/milvus-build-env-pod.yaml b/ci/jenkins/pod/milvus-build-env-pod.yaml new file mode 100644 index 0000000000..bb4499711f --- /dev/null +++ b/ci/jenkins/pod/milvus-build-env-pod.yaml @@ -0,0 +1,35 @@ +apiVersion: v1 +kind: Pod +metadata: + name: milvus-build-env + labels: + app: milvus + componet: build-env +spec: + containers: + - name: milvus-build-env + image: registry.zilliz.com/milvus/milvus-build-env:v0.5.0-ubuntu18.04 + env: + - name: POD_IP + valueFrom: + fieldRef: + fieldPath: status.podIP + command: + - cat + tty: true + resources: + limits: + memory: "32Gi" + cpu: "8.0" + nvidia.com/gpu: 1 + requests: + memory: "16Gi" + cpu: "4.0" + - name: milvus-mysql + image: mysql:5.6 + env: + - name: MYSQL_ROOT_PASSWORD + value: 123456 + ports: + - containerPort: 3306 + name: mysql diff --git a/ci/jenkins/pod/testEnvironment.yaml b/ci/jenkins/pod/testEnvironment.yaml new file mode 100644 index 0000000000..174277dfcc --- /dev/null +++ b/ci/jenkins/pod/testEnvironment.yaml @@ -0,0 +1,22 @@ +apiVersion: v1 +kind: Pod +metadata: + labels: + app: milvus + componet: test-env +spec: + containers: + - name: milvus-test-env + image: registry.zilliz.com/milvus/milvus-test-env:v0.1 + command: + - cat + tty: true + volumeMounts: + - name: kubeconf + mountPath: /root/.kube/ + readOnly: true + volumes: + - name: kubeconf + secret: + secretName: test-cluster-config + diff --git a/ci/jenkins/scripts/build.sh b/ci/jenkins/scripts/build.sh new file mode 100755 index 0000000000..2ccdf4a618 --- /dev/null +++ b/ci/jenkins/scripts/build.sh @@ -0,0 +1,142 @@ +#!/bin/bash + +SOURCE="${BASH_SOURCE[0]}" +while [ -h "$SOURCE" ]; do # resolve $SOURCE until the file is no longer a symlink + DIR="$( cd -P "$( dirname "$SOURCE" )" && pwd )" + SOURCE="$(readlink "$SOURCE")" + [[ $SOURCE != /* ]] && SOURCE="$DIR/$SOURCE" # if $SOURCE was a relative symlink, we need to resolve it relative to the path where the symlink file was located +done +SCRIPTS_DIR="$( cd -P "$( dirname "$SOURCE" )" && pwd )" + +CMAKE_BUILD_DIR="${SCRIPTS_DIR}/../../../core/cmake_build" +BUILD_TYPE="Debug" +BUILD_UNITTEST="OFF" +INSTALL_PREFIX="/opt/milvus" +BUILD_COVERAGE="OFF" +DB_PATH="/opt/milvus" +PROFILING="OFF" +USE_JFROG_CACHE="OFF" +RUN_CPPLINT="OFF" +CUSTOMIZATION="OFF" # default use ori faiss +CUDA_COMPILER=/usr/local/cuda/bin/nvcc + +CUSTOMIZED_FAISS_URL="${FAISS_URL:-NONE}" +wget -q --method HEAD ${CUSTOMIZED_FAISS_URL} +if [ $? -eq 0 ]; then + CUSTOMIZATION="ON" +else + CUSTOMIZATION="OFF" +fi + +while getopts "o:d:t:ulcgjhx" arg +do + case $arg in + o) + INSTALL_PREFIX=$OPTARG + ;; + d) + DB_PATH=$OPTARG + ;; + t) + BUILD_TYPE=$OPTARG # BUILD_TYPE + ;; + u) + echo "Build and run unittest cases" ; + BUILD_UNITTEST="ON"; + ;; + l) + RUN_CPPLINT="ON" + ;; + c) + BUILD_COVERAGE="ON" + ;; + g) + PROFILING="ON" + ;; + j) + USE_JFROG_CACHE="ON" + ;; + x) + CUSTOMIZATION="OFF" # force use ori faiss + ;; + h) # help + echo " + +parameter: +-o: install prefix(default: /opt/milvus) +-d: db data path(default: /opt/milvus) +-t: build type(default: Debug) +-u: building unit test options(default: OFF) +-l: run cpplint, clang-format and clang-tidy(default: OFF) +-c: code coverage(default: OFF) +-g: profiling(default: OFF) +-j: use jfrog cache build directory(default: OFF) +-h: help + +usage: +./build.sh -p \${INSTALL_PREFIX} -t \${BUILD_TYPE} [-u] [-l] [-r] [-c] [-g] [-j] [-h] + " + exit 0 + ;; + ?) + echo "ERROR! unknown argument" + exit 1 + ;; + esac +done + +if [[ ! -d ${CMAKE_BUILD_DIR} ]]; then + mkdir ${CMAKE_BUILD_DIR} +fi + +cd ${CMAKE_BUILD_DIR} + +# remove make cache since build.sh -l use default variables +# force update the variables each time +make rebuild_cache + +CMAKE_CMD="cmake \ +-DBUILD_UNIT_TEST=${BUILD_UNITTEST} \ +-DCMAKE_INSTALL_PREFIX=${INSTALL_PREFIX} +-DCMAKE_BUILD_TYPE=${BUILD_TYPE} \ +-DCMAKE_CUDA_COMPILER=${CUDA_COMPILER} \ +-DBUILD_COVERAGE=${BUILD_COVERAGE} \ +-DMILVUS_DB_PATH=${DB_PATH} \ +-DMILVUS_ENABLE_PROFILING=${PROFILING} \ +-DUSE_JFROG_CACHE=${USE_JFROG_CACHE} \ +-DCUSTOMIZATION=${CUSTOMIZATION} \ +-DFAISS_URL=${CUSTOMIZED_FAISS_URL} \ +.." +echo ${CMAKE_CMD} +${CMAKE_CMD} + +if [[ ${RUN_CPPLINT} == "ON" ]]; then + # cpplint check + make lint + if [ $? -ne 0 ]; then + echo "ERROR! cpplint check failed" + exit 1 + fi + echo "cpplint check passed!" + + # clang-format check + make check-clang-format + if [ $? -ne 0 ]; then + echo "ERROR! clang-format check failed" + exit 1 + fi + echo "clang-format check passed!" + +# # clang-tidy check +# make check-clang-tidy +# if [ $? -ne 0 ]; then +# echo "ERROR! clang-tidy check failed" +# rm -f CMakeCache.txt +# exit 1 +# fi +# echo "clang-tidy check passed!" +else + # compile and build + make -j8 || exit 1 + make install || exit 1 +fi diff --git a/ci/jenkins/scripts/coverage.sh b/ci/jenkins/scripts/coverage.sh new file mode 100755 index 0000000000..ecbb2dfbe9 --- /dev/null +++ b/ci/jenkins/scripts/coverage.sh @@ -0,0 +1,138 @@ +#!/bin/bash + +SOURCE="${BASH_SOURCE[0]}" +while [ -h "$SOURCE" ]; do # resolve $SOURCE until the file is no longer a symlink + DIR="$( cd -P "$( dirname "$SOURCE" )" && pwd )" + SOURCE="$(readlink "$SOURCE")" + [[ $SOURCE != /* ]] && SOURCE="$DIR/$SOURCE" # if $SOURCE was a relative symlink, we need to resolve it relative to the path where the symlink file was located +done +SCRIPTS_DIR="$( cd -P "$( dirname "$SOURCE" )" && pwd )" + +INSTALL_PREFIX="/opt/milvus" +CMAKE_BUILD_DIR="${SCRIPTS_DIR}/../../../core/cmake_build" +MYSQL_USER_NAME=root +MYSQL_PASSWORD=123456 +MYSQL_HOST='127.0.0.1' +MYSQL_PORT='3306' + +while getopts "o:u:p:t:h" arg +do + case $arg in + o) + INSTALL_PREFIX=$OPTARG + ;; + u) + MYSQL_USER_NAME=$OPTARG + ;; + p) + MYSQL_PASSWORD=$OPTARG + ;; + t) + MYSQL_HOST=$OPTARG + ;; + h) # help + echo " + +parameter: +-o: milvus install prefix(default: /opt/milvus) +-u: mysql account +-p: mysql password +-t: mysql host +-h: help + +usage: +./coverage.sh -o \${INSTALL_PREFIX} -u \${MYSQL_USER} -p \${MYSQL_PASSWORD} -t \${MYSQL_HOST} [-h] + " + exit 0 + ;; + ?) + echo "ERROR! unknown argument" + exit 1 + ;; + esac +done + +export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:${INSTALL_PREFIX}/lib + +LCOV_CMD="lcov" +# LCOV_GEN_CMD="genhtml" + +FILE_INFO_BASE="base.info" +FILE_INFO_MILVUS="server.info" +FILE_INFO_OUTPUT="output.info" +FILE_INFO_OUTPUT_NEW="output_new.info" +DIR_LCOV_OUTPUT="lcov_out" + +DIR_GCNO="${CMAKE_BUILD_DIR}" +DIR_UNITTEST="${INSTALL_PREFIX}/unittest" + +# delete old code coverage info files +rm -rf lcov_out +rm -f FILE_INFO_BASE FILE_INFO_MILVUS FILE_INFO_OUTPUT FILE_INFO_OUTPUT_NEW + +MYSQL_DB_NAME=milvus_`date +%s%N` + +function mysql_exc() +{ + cmd=$1 + mysql -h${MYSQL_HOST} -u${MYSQL_USER_NAME} -p${MYSQL_PASSWORD} -e "${cmd}" + if [ $? -ne 0 ]; then + echo "mysql $cmd run failed" + fi +} + +mysql_exc "CREATE DATABASE IF NOT EXISTS ${MYSQL_DB_NAME};" +mysql_exc "GRANT ALL PRIVILEGES ON ${MYSQL_DB_NAME}.* TO '${MYSQL_USER_NAME}'@'%';" +mysql_exc "FLUSH PRIVILEGES;" +mysql_exc "USE ${MYSQL_DB_NAME};" + +# get baseline +${LCOV_CMD} -c -i -d ${DIR_GCNO} -o "${FILE_INFO_BASE}" +if [ $? -ne 0 ]; then + echo "gen baseline coverage run failed" + exit -1 +fi + +for test in `ls ${DIR_UNITTEST}`; do + echo $test + case ${test} in + test_db) + # set run args for test_db + args="mysql://${MYSQL_USER_NAME}:${MYSQL_PASSWORD}@${MYSQL_HOST}:${MYSQL_PORT}/${MYSQL_DB_NAME}" + ;; + *test_*) + args="" + ;; + esac + # run unittest + ${DIR_UNITTEST}/${test} "${args}" + if [ $? -ne 0 ]; then + echo ${args} + echo ${DIR_UNITTEST}/${test} "run failed" + fi +done + +mysql_exc "DROP DATABASE IF EXISTS ${MYSQL_DB_NAME};" + +# gen code coverage +${LCOV_CMD} -d ${DIR_GCNO} -o "${FILE_INFO_MILVUS}" -c +# merge coverage +${LCOV_CMD} -a ${FILE_INFO_BASE} -a ${FILE_INFO_MILVUS} -o "${FILE_INFO_OUTPUT}" + +# remove third party from tracefiles +${LCOV_CMD} -r "${FILE_INFO_OUTPUT}" -o "${FILE_INFO_OUTPUT_NEW}" \ + "/usr/*" \ + "*/boost/*" \ + "*/cmake_build/*_ep-prefix/*" \ + "*/src/index/cmake_build*" \ + "*/src/index/thirdparty*" \ + "*/src/grpc*" \ + "*/src/metrics/MetricBase.h" \ + "*/src/server/Server.cpp" \ + "*/src/server/DBWrapper.cpp" \ + "*/src/server/grpc_impl/GrpcServer.cpp" \ + "*/src/utils/easylogging++.h" \ + "*/src/utils/easylogging++.cc" + +# gen html report +# ${LCOV_GEN_CMD} "${FILE_INFO_OUTPUT_NEW}" --output-directory ${DIR_LCOV_OUTPUT}/ diff --git a/ci/jenkinsfile/cleanup_dev.groovy b/ci/jenkinsfile/cleanup_dev.groovy new file mode 100644 index 0000000000..2e9332fa6e --- /dev/null +++ b/ci/jenkinsfile/cleanup_dev.groovy @@ -0,0 +1,13 @@ +try { + def result = sh script: "helm status ${env.JOB_NAME}-${env.BUILD_NUMBER}", returnStatus: true + if (!result) { + sh "helm del --purge ${env.JOB_NAME}-${env.BUILD_NUMBER}" + } +} catch (exc) { + def result = sh script: "helm status ${env.JOB_NAME}-${env.BUILD_NUMBER}", returnStatus: true + if (!result) { + sh "helm del --purge ${env.JOB_NAME}-${env.BUILD_NUMBER}" + } + throw exc +} + diff --git a/ci/jenkinsfile/cleanup_staging.groovy b/ci/jenkinsfile/cleanup_staging.groovy new file mode 100644 index 0000000000..2e9332fa6e --- /dev/null +++ b/ci/jenkinsfile/cleanup_staging.groovy @@ -0,0 +1,13 @@ +try { + def result = sh script: "helm status ${env.JOB_NAME}-${env.BUILD_NUMBER}", returnStatus: true + if (!result) { + sh "helm del --purge ${env.JOB_NAME}-${env.BUILD_NUMBER}" + } +} catch (exc) { + def result = sh script: "helm status ${env.JOB_NAME}-${env.BUILD_NUMBER}", returnStatus: true + if (!result) { + sh "helm del --purge ${env.JOB_NAME}-${env.BUILD_NUMBER}" + } + throw exc +} + diff --git a/ci/jenkinsfile/cluster_cleanup_dev.groovy b/ci/jenkinsfile/cluster_cleanup_dev.groovy new file mode 100644 index 0000000000..e57988fefe --- /dev/null +++ b/ci/jenkinsfile/cluster_cleanup_dev.groovy @@ -0,0 +1,13 @@ +try { + def result = sh script: "helm status ${env.JOB_NAME}-${env.BUILD_NUMBER}-cluster", returnStatus: true + if (!result) { + sh "helm del --purge ${env.JOB_NAME}-${env.BUILD_NUMBER}-cluster" + } +} catch (exc) { + def result = sh script: "helm status ${env.JOB_NAME}-${env.BUILD_NUMBER}-cluster", returnStatus: true + if (!result) { + sh "helm del --purge ${env.JOB_NAME}-${env.BUILD_NUMBER}-cluster" + } + throw exc +} + diff --git a/ci/jenkinsfile/cluster_deploy2dev.groovy b/ci/jenkinsfile/cluster_deploy2dev.groovy new file mode 100644 index 0000000000..7f2a584256 --- /dev/null +++ b/ci/jenkinsfile/cluster_deploy2dev.groovy @@ -0,0 +1,24 @@ +try { + sh 'helm init --client-only --skip-refresh --stable-repo-url https://kubernetes.oss-cn-hangzhou.aliyuncs.com/charts' + sh 'helm repo add milvus https://registry.zilliz.com/chartrepo/milvus' + sh 'helm repo update' + dir ("milvus-helm") { + checkout([$class: 'GitSCM', branches: [[name: "${SEMVER}"]], doGenerateSubmoduleConfigurations: false, extensions: [], submoduleCfg: [], userRemoteConfigs: [[credentialsId: "${params.GIT_USER}", url: "git@192.168.1.105:megasearch/milvus-helm.git", name: 'origin', refspec: "+refs/heads/${SEMVER}:refs/remotes/origin/${SEMVER}"]]]) + dir ("milvus/milvus-cluster") { + sh "helm install --wait --timeout 300 --set roServers.image.tag=${DOCKER_VERSION} --set woServers.image.tag=${DOCKER_VERSION} --set expose.type=clusterIP -f ci/values.yaml --name ${env.JOB_NAME}-${env.BUILD_NUMBER}-cluster --namespace milvus-cluster --version 0.5.0 . " + } + } + /* + timeout(time: 2, unit: 'MINUTES') { + waitUntil { + def result = sh script: "nc -z -w 3 ${env.JOB_NAME}-${env.BUILD_NUMBER}-cluster-milvus-cluster-proxy.milvus-cluster.svc.cluster.local 19530", returnStatus: true + return !result + } + } + */ +} catch (exc) { + echo 'Helm running failed!' + sh "helm del --purge ${env.JOB_NAME}-${env.BUILD_NUMBER}-cluster" + throw exc +} + diff --git a/ci/jenkinsfile/cluster_dev_test.groovy b/ci/jenkinsfile/cluster_dev_test.groovy new file mode 100644 index 0000000000..4a15b926cf --- /dev/null +++ b/ci/jenkinsfile/cluster_dev_test.groovy @@ -0,0 +1,12 @@ +timeout(time: 25, unit: 'MINUTES') { + try { + dir ("${PROJECT_NAME}_test") { + checkout([$class: 'GitSCM', branches: [[name: "${SEMVER}"]], doGenerateSubmoduleConfigurations: false, extensions: [], submoduleCfg: [], userRemoteConfigs: [[credentialsId: "${params.GIT_USER}", url: "git@192.168.1.105:Test/milvus_test.git", name: 'origin', refspec: "+refs/heads/${SEMVER}:refs/remotes/origin/${SEMVER}"]]]) + sh 'python3 -m pip install -r requirements_cluster.txt' + sh "pytest . --alluredir=cluster_test_out --ip ${env.JOB_NAME}-${env.BUILD_NUMBER}-cluster-milvus-cluster-proxy.milvus-cluster.svc.cluster.local" + } + } catch (exc) { + echo 'Milvus Cluster Test Failed !' + throw exc + } +} diff --git a/ci/jenkinsfile/deploy2dev.groovy b/ci/jenkinsfile/deploy2dev.groovy new file mode 100644 index 0000000000..b00c8fa335 --- /dev/null +++ b/ci/jenkinsfile/deploy2dev.groovy @@ -0,0 +1,16 @@ +try { + sh 'helm init --client-only --skip-refresh --stable-repo-url https://kubernetes.oss-cn-hangzhou.aliyuncs.com/charts' + sh 'helm repo add milvus https://registry.zilliz.com/chartrepo/milvus' + sh 'helm repo update' + dir ("milvus-helm") { + checkout([$class: 'GitSCM', branches: [[name: "${SEMVER}"]], doGenerateSubmoduleConfigurations: false, extensions: [], submoduleCfg: [], userRemoteConfigs: [[credentialsId: "${params.GIT_USER}", url: "git@192.168.1.105:megasearch/milvus-helm.git", name: 'origin', refspec: "+refs/heads/${SEMVER}:refs/remotes/origin/${SEMVER}"]]]) + dir ("milvus/milvus-gpu") { + sh "helm install --wait --timeout 300 --set engine.image.tag=${DOCKER_VERSION} --set expose.type=clusterIP --name ${env.JOB_NAME}-${env.BUILD_NUMBER} -f ci/values.yaml --namespace milvus-1 --version 0.5.0 ." + } + } +} catch (exc) { + echo 'Helm running failed!' + sh "helm del --purge ${env.JOB_NAME}-${env.BUILD_NUMBER}" + throw exc +} + diff --git a/ci/jenkinsfile/deploy2staging.groovy b/ci/jenkinsfile/deploy2staging.groovy new file mode 100644 index 0000000000..42ccfda71a --- /dev/null +++ b/ci/jenkinsfile/deploy2staging.groovy @@ -0,0 +1,16 @@ +try { + sh 'helm init --client-only --skip-refresh --stable-repo-url https://kubernetes.oss-cn-hangzhou.aliyuncs.com/charts' + sh 'helm repo add milvus https://registry.zilliz.com/chartrepo/milvus' + sh 'helm repo update' + dir ("milvus-helm") { + checkout([$class: 'GitSCM', branches: [[name: "${SEMVER}"]], doGenerateSubmoduleConfigurations: false, extensions: [], submoduleCfg: [], userRemoteConfigs: [[credentialsId: "${params.GIT_USER}", url: "git@192.168.1.105:megasearch/milvus-helm.git", name: 'origin', refspec: "+refs/heads/${SEMVER}:refs/remotes/origin/${SEMVER}"]]]) + dir ("milvus/milvus-gpu") { + sh "helm install --wait --timeout 300 --set engine.image.repository=\"zilliz.azurecr.cn/milvus/engine\" --set engine.image.tag=${DOCKER_VERSION} --set expose.type=loadBalancer --name ${env.JOB_NAME}-${env.BUILD_NUMBER} -f ci/values.yaml --namespace milvus-1 --version 0.5.0 ." + } + } +} catch (exc) { + echo 'Helm running failed!' + sh "helm del --purge ${env.JOB_NAME}-${env.BUILD_NUMBER}" + throw exc +} + diff --git a/ci/jenkinsfile/dev_test.groovy b/ci/jenkinsfile/dev_test.groovy new file mode 100644 index 0000000000..f9df9b4065 --- /dev/null +++ b/ci/jenkinsfile/dev_test.groovy @@ -0,0 +1,28 @@ +timeout(time: 30, unit: 'MINUTES') { + try { + dir ("${PROJECT_NAME}_test") { + checkout([$class: 'GitSCM', branches: [[name: "${SEMVER}"]], doGenerateSubmoduleConfigurations: false, extensions: [], submoduleCfg: [], userRemoteConfigs: [[credentialsId: "${params.GIT_USER}", url: "git@192.168.1.105:Test/milvus_test.git", name: 'origin', refspec: "+refs/heads/${SEMVER}:refs/remotes/origin/${SEMVER}"]]]) + sh 'python3 -m pip install -r requirements.txt -i http://pypi.douban.com/simple --trusted-host pypi.douban.com' + sh "pytest . --alluredir=\"test_out/dev/single/sqlite\" --level=1 --ip ${env.JOB_NAME}-${env.BUILD_NUMBER}-milvus-gpu-engine.milvus-1.svc.cluster.local" + } + // mysql database backend test + load "${env.WORKSPACE}/ci/jenkinsfile/cleanup_dev.groovy" + + if (!fileExists('milvus-helm')) { + dir ("milvus-helm") { + checkout([$class: 'GitSCM', branches: [[name: "${SEMVER}"]], doGenerateSubmoduleConfigurations: false, extensions: [], submoduleCfg: [], userRemoteConfigs: [[credentialsId: "${params.GIT_USER}", url: "git@192.168.1.105:megasearch/milvus-helm.git", name: 'origin', refspec: "+refs/heads/${SEMVER}:refs/remotes/origin/${SEMVER}"]]]) + } + } + dir ("milvus-helm") { + dir ("milvus/milvus-gpu") { + sh "helm install --wait --timeout 300 --set engine.image.tag=${DOCKER_VERSION} --set expose.type=clusterIP --name ${env.JOB_NAME}-${env.BUILD_NUMBER} -f ci/db_backend/mysql_values.yaml --namespace milvus-2 --version 0.5.0 ." + } + } + dir ("${PROJECT_NAME}_test") { + sh "pytest . --alluredir=\"test_out/dev/single/mysql\" --level=1 --ip ${env.JOB_NAME}-${env.BUILD_NUMBER}-milvus-gpu-engine.milvus-2.svc.cluster.local" + } + } catch (exc) { + echo 'Milvus Test Failed !' + throw exc + } +} diff --git a/ci/jenkinsfile/dev_test_all.groovy b/ci/jenkinsfile/dev_test_all.groovy new file mode 100644 index 0000000000..b11ea755b9 --- /dev/null +++ b/ci/jenkinsfile/dev_test_all.groovy @@ -0,0 +1,29 @@ +timeout(time: 60, unit: 'MINUTES') { + try { + dir ("${PROJECT_NAME}_test") { + checkout([$class: 'GitSCM', branches: [[name: "${SEMVER}"]], doGenerateSubmoduleConfigurations: false, extensions: [], submoduleCfg: [], userRemoteConfigs: [[credentialsId: "${params.GIT_USER}", url: "git@192.168.1.105:Test/milvus_test.git", name: 'origin', refspec: "+refs/heads/${SEMVER}:refs/remotes/origin/${SEMVER}"]]]) + sh 'python3 -m pip install -r requirements.txt -i http://pypi.douban.com/simple --trusted-host pypi.douban.com' + sh "pytest . --alluredir=\"test_out/dev/single/sqlite\" --ip ${env.JOB_NAME}-${env.BUILD_NUMBER}-milvus-gpu-engine.milvus-1.svc.cluster.local" + } + + // mysql database backend test + load "${env.WORKSPACE}/ci/jenkinsfile/cleanup_dev.groovy" + + if (!fileExists('milvus-helm')) { + dir ("milvus-helm") { + checkout([$class: 'GitSCM', branches: [[name: "${SEMVER}"]], doGenerateSubmoduleConfigurations: false, extensions: [], submoduleCfg: [], userRemoteConfigs: [[credentialsId: "${params.GIT_USER}", url: "git@192.168.1.105:megasearch/milvus-helm.git", name: 'origin', refspec: "+refs/heads/${SEMVER}:refs/remotes/origin/${SEMVER}"]]]) + } + } + dir ("milvus-helm") { + dir ("milvus/milvus-gpu") { + sh "helm install --wait --timeout 300 --set engine.image.tag=${DOCKER_VERSION} --set expose.type=clusterIP --name ${env.JOB_NAME}-${env.BUILD_NUMBER} -f ci/db_backend/mysql_values.yaml --namespace milvus-2 --version 0.4.0 ." + } + } + dir ("${PROJECT_NAME}_test") { + sh "pytest . --alluredir=\"test_out/dev/single/mysql\" --ip ${env.JOB_NAME}-${env.BUILD_NUMBER}-milvus-gpu-engine.milvus-2.svc.cluster.local" + } + } catch (exc) { + echo 'Milvus Test Failed !' + throw exc + } +} diff --git a/ci/jenkinsfile/milvus_build.groovy b/ci/jenkinsfile/milvus_build.groovy new file mode 100644 index 0000000000..92fd364bb9 --- /dev/null +++ b/ci/jenkinsfile/milvus_build.groovy @@ -0,0 +1,30 @@ +container('milvus-build-env') { + timeout(time: 120, unit: 'MINUTES') { + gitlabCommitStatus(name: 'Build Engine') { + dir ("milvus_engine") { + try { + checkout([$class: 'GitSCM', branches: [[name: "${SEMVER}"]], doGenerateSubmoduleConfigurations: false, extensions: [[$class: 'SubmoduleOption',disableSubmodules: false,parentCredentials: true,recursiveSubmodules: true,reference: '',trackingSubmodules: false]], submoduleCfg: [], userRemoteConfigs: [[credentialsId: "${params.GIT_USER}", url: "git@192.168.1.105:megasearch/milvus.git", name: 'origin', refspec: "+refs/heads/${SEMVER}:refs/remotes/origin/${SEMVER}"]]]) + + dir ("core") { + sh "git config --global user.email \"test@zilliz.com\"" + sh "git config --global user.name \"test\"" + withCredentials([usernamePassword(credentialsId: "${params.JFROG_USER}", usernameVariable: 'USERNAME', passwordVariable: 'PASSWORD')]) { + sh "./build.sh -l" + sh "rm -rf cmake_build" + sh "export JFROG_ARTFACTORY_URL='${params.JFROG_ARTFACTORY_URL}' \ + && export JFROG_USER_NAME='${USERNAME}' \ + && export JFROG_PASSWORD='${PASSWORD}' \ + && export FAISS_URL='http://192.168.1.105:6060/jinhai/faiss/-/archive/branch-0.2.1/faiss-branch-0.2.1.tar.gz' \ + && ./build.sh -t ${params.BUILD_TYPE} -d /opt/milvus -j -u -c" + + sh "./coverage.sh -u root -p 123456 -t \$POD_IP" + } + } + } catch (exc) { + updateGitlabCommitStatus name: 'Build Engine', state: 'failed' + throw exc + } + } + } + } +} diff --git a/ci/jenkinsfile/milvus_build_no_ut.groovy b/ci/jenkinsfile/milvus_build_no_ut.groovy new file mode 100644 index 0000000000..1dd3361106 --- /dev/null +++ b/ci/jenkinsfile/milvus_build_no_ut.groovy @@ -0,0 +1,28 @@ +container('milvus-build-env') { + timeout(time: 120, unit: 'MINUTES') { + gitlabCommitStatus(name: 'Build Engine') { + dir ("milvus_engine") { + try { + checkout([$class: 'GitSCM', branches: [[name: "${SEMVER}"]], doGenerateSubmoduleConfigurations: false, extensions: [[$class: 'SubmoduleOption',disableSubmodules: false,parentCredentials: true,recursiveSubmodules: true,reference: '',trackingSubmodules: false]], submoduleCfg: [], userRemoteConfigs: [[credentialsId: "${params.GIT_USER}", url: "git@192.168.1.105:megasearch/milvus.git", name: 'origin', refspec: "+refs/heads/${SEMVER}:refs/remotes/origin/${SEMVER}"]]]) + + dir ("core") { + sh "git config --global user.email \"test@zilliz.com\"" + sh "git config --global user.name \"test\"" + withCredentials([usernamePassword(credentialsId: "${params.JFROG_USER}", usernameVariable: 'USERNAME', passwordVariable: 'PASSWORD')]) { + sh "./build.sh -l" + sh "rm -rf cmake_build" + sh "export JFROG_ARTFACTORY_URL='${params.JFROG_ARTFACTORY_URL}' \ + && export JFROG_USER_NAME='${USERNAME}' \ + && export JFROG_PASSWORD='${PASSWORD}' \ + && export FAISS_URL='http://192.168.1.105:6060/jinhai/faiss/-/archive/branch-0.2.1/faiss-branch-0.2.1.tar.gz' \ + && ./build.sh -t ${params.BUILD_TYPE} -j -d /opt/milvus" + } + } + } catch (exc) { + updateGitlabCommitStatus name: 'Build Engine', state: 'failed' + throw exc + } + } + } + } +} diff --git a/ci/jenkinsfile/nightly_publish_docker.groovy b/ci/jenkinsfile/nightly_publish_docker.groovy new file mode 100644 index 0000000000..8c6121bec8 --- /dev/null +++ b/ci/jenkinsfile/nightly_publish_docker.groovy @@ -0,0 +1,38 @@ +container('publish-docker') { + timeout(time: 15, unit: 'MINUTES') { + gitlabCommitStatus(name: 'Publish Engine Docker') { + try { + dir ("${PROJECT_NAME}_build") { + checkout([$class: 'GitSCM', branches: [[name: "${SEMVER}"]], doGenerateSubmoduleConfigurations: false, extensions: [], submoduleCfg: [], userRemoteConfigs: [[credentialsId: "${params.GIT_USER}", url: "git@192.168.1.105:build/milvus_build.git", name: 'origin', refspec: "+refs/heads/${SEMVER}:refs/remotes/origin/${SEMVER}"]]]) + dir ("docker/deploy/ubuntu16.04/free_version") { + sh "curl -O -u anonymous: ftp://192.168.1.126/data/${PROJECT_NAME}/engine/${JOB_NAME}-${BUILD_ID}/${PROJECT_NAME}-engine-${PACKAGE_VERSION}.tar.gz" + sh "tar zxvf ${PROJECT_NAME}-engine-${PACKAGE_VERSION}.tar.gz" + try { + def customImage = docker.build("${PROJECT_NAME}/engine:${DOCKER_VERSION}") + docker.withRegistry('https://registry.zilliz.com', "${params.DOCKER_PUBLISH_USER}") { + customImage.push() + } + docker.withRegistry('https://zilliz.azurecr.cn', "${params.AZURE_DOCKER_PUBLISH_USER}") { + customImage.push() + } + if (currentBuild.resultIsBetterOrEqualTo('SUCCESS')) { + updateGitlabCommitStatus name: 'Publish Engine Docker', state: 'success' + echo "Docker Pull Command: docker pull registry.zilliz.com/${PROJECT_NAME}/engine:${DOCKER_VERSION}" + } + } catch (exc) { + updateGitlabCommitStatus name: 'Publish Engine Docker', state: 'canceled' + throw exc + } finally { + sh "docker rmi ${PROJECT_NAME}/engine:${DOCKER_VERSION}" + } + } + } + } catch (exc) { + updateGitlabCommitStatus name: 'Publish Engine Docker', state: 'failed' + echo 'Publish docker failed!' + throw exc + } + } + } +} + diff --git a/ci/jenkinsfile/notify.groovy b/ci/jenkinsfile/notify.groovy new file mode 100644 index 0000000000..0a257b8cd8 --- /dev/null +++ b/ci/jenkinsfile/notify.groovy @@ -0,0 +1,15 @@ +def notify() { + if (!currentBuild.resultIsBetterOrEqualTo('SUCCESS')) { + // Send an email only if the build status has changed from green/unstable to red + emailext subject: '$DEFAULT_SUBJECT', + body: '$DEFAULT_CONTENT', + recipientProviders: [ + [$class: 'DevelopersRecipientProvider'], + [$class: 'RequesterRecipientProvider'] + ], + replyTo: '$DEFAULT_REPLYTO', + to: '$DEFAULT_RECIPIENTS' + } +} +return this + diff --git a/ci/jenkinsfile/packaged_milvus.groovy b/ci/jenkinsfile/packaged_milvus.groovy new file mode 100644 index 0000000000..1d30e21910 --- /dev/null +++ b/ci/jenkinsfile/packaged_milvus.groovy @@ -0,0 +1,44 @@ +container('milvus-build-env') { + timeout(time: 5, unit: 'MINUTES') { + dir ("milvus_engine") { + dir ("core") { + gitlabCommitStatus(name: 'Packaged Engine') { + if (fileExists('milvus')) { + try { + sh "tar -zcvf ./${PROJECT_NAME}-engine-${PACKAGE_VERSION}.tar.gz ./milvus" + def fileTransfer = load "${env.WORKSPACE}/ci/function/file_transfer.groovy" + fileTransfer.FileTransfer("${PROJECT_NAME}-engine-${PACKAGE_VERSION}.tar.gz", "${PROJECT_NAME}/engine/${JOB_NAME}-${BUILD_ID}", 'nas storage') + if (currentBuild.resultIsBetterOrEqualTo('SUCCESS')) { + echo "Download Milvus Engine Binary Viewer \"http://192.168.1.126:8080/${PROJECT_NAME}/engine/${JOB_NAME}-${BUILD_ID}/${PROJECT_NAME}-engine-${PACKAGE_VERSION}.tar.gz\"" + } + } catch (exc) { + updateGitlabCommitStatus name: 'Packaged Engine', state: 'failed' + throw exc + } + } else { + updateGitlabCommitStatus name: 'Packaged Engine', state: 'failed' + error("Milvus binary directory don't exists!") + } + } + + gitlabCommitStatus(name: 'Packaged Engine lcov') { + if (fileExists('lcov_out')) { + try { + def fileTransfer = load "${env.WORKSPACE}/ci/function/file_transfer.groovy" + fileTransfer.FileTransfer("lcov_out/", "${PROJECT_NAME}/lcov/${JOB_NAME}-${BUILD_ID}", 'nas storage') + if (currentBuild.resultIsBetterOrEqualTo('SUCCESS')) { + echo "Milvus lcov out Viewer \"http://192.168.1.126:8080/${PROJECT_NAME}/lcov/${JOB_NAME}-${BUILD_ID}/lcov_out/\"" + } + } catch (exc) { + updateGitlabCommitStatus name: 'Packaged Engine lcov', state: 'failed' + throw exc + } + } else { + updateGitlabCommitStatus name: 'Packaged Engine lcov', state: 'failed' + error("Milvus lcov out directory don't exists!") + } + } + } + } + } +} diff --git a/ci/jenkinsfile/packaged_milvus_no_ut.groovy b/ci/jenkinsfile/packaged_milvus_no_ut.groovy new file mode 100644 index 0000000000..bc68be374a --- /dev/null +++ b/ci/jenkinsfile/packaged_milvus_no_ut.groovy @@ -0,0 +1,26 @@ +container('milvus-build-env') { + timeout(time: 5, unit: 'MINUTES') { + dir ("milvus_engine") { + dir ("core") { + gitlabCommitStatus(name: 'Packaged Engine') { + if (fileExists('milvus')) { + try { + sh "tar -zcvf ./${PROJECT_NAME}-engine-${PACKAGE_VERSION}.tar.gz ./milvus" + def fileTransfer = load "${env.WORKSPACE}/ci/function/file_transfer.groovy" + fileTransfer.FileTransfer("${PROJECT_NAME}-engine-${PACKAGE_VERSION}.tar.gz", "${PROJECT_NAME}/engine/${JOB_NAME}-${BUILD_ID}", 'nas storage') + if (currentBuild.resultIsBetterOrEqualTo('SUCCESS')) { + echo "Download Milvus Engine Binary Viewer \"http://192.168.1.126:8080/${PROJECT_NAME}/engine/${JOB_NAME}-${BUILD_ID}/${PROJECT_NAME}-engine-${PACKAGE_VERSION}.tar.gz\"" + } + } catch (exc) { + updateGitlabCommitStatus name: 'Packaged Engine', state: 'failed' + throw exc + } + } else { + updateGitlabCommitStatus name: 'Packaged Engine', state: 'failed' + error("Milvus binary directory don't exists!") + } + } + } + } + } +} diff --git a/ci/jenkinsfile/publish_docker.groovy b/ci/jenkinsfile/publish_docker.groovy new file mode 100644 index 0000000000..ef31eba9a4 --- /dev/null +++ b/ci/jenkinsfile/publish_docker.groovy @@ -0,0 +1,35 @@ +container('publish-docker') { + timeout(time: 15, unit: 'MINUTES') { + gitlabCommitStatus(name: 'Publish Engine Docker') { + try { + dir ("${PROJECT_NAME}_build") { + checkout([$class: 'GitSCM', branches: [[name: "${SEMVER}"]], doGenerateSubmoduleConfigurations: false, extensions: [], submoduleCfg: [], userRemoteConfigs: [[credentialsId: "${params.GIT_USER}", url: "git@192.168.1.105:build/milvus_build.git", name: 'origin', refspec: "+refs/heads/${SEMVER}:refs/remotes/origin/${SEMVER}"]]]) + dir ("docker/deploy/ubuntu16.04/free_version") { + sh "curl -O -u anonymous: ftp://192.168.1.126/data/${PROJECT_NAME}/engine/${JOB_NAME}-${BUILD_ID}/${PROJECT_NAME}-engine-${PACKAGE_VERSION}.tar.gz" + sh "tar zxvf ${PROJECT_NAME}-engine-${PACKAGE_VERSION}.tar.gz" + try { + def customImage = docker.build("${PROJECT_NAME}/engine:${DOCKER_VERSION}") + docker.withRegistry('https://registry.zilliz.com', "${params.DOCKER_PUBLISH_USER}") { + customImage.push() + } + if (currentBuild.resultIsBetterOrEqualTo('SUCCESS')) { + updateGitlabCommitStatus name: 'Publish Engine Docker', state: 'success' + echo "Docker Pull Command: docker pull registry.zilliz.com/${PROJECT_NAME}/engine:${DOCKER_VERSION}" + } + } catch (exc) { + updateGitlabCommitStatus name: 'Publish Engine Docker', state: 'canceled' + throw exc + } finally { + sh "docker rmi ${PROJECT_NAME}/engine:${DOCKER_VERSION}" + } + } + } + } catch (exc) { + updateGitlabCommitStatus name: 'Publish Engine Docker', state: 'failed' + echo 'Publish docker failed!' + throw exc + } + } + } +} + diff --git a/ci/jenkinsfile/staging_test.groovy b/ci/jenkinsfile/staging_test.groovy new file mode 100644 index 0000000000..dcf1787103 --- /dev/null +++ b/ci/jenkinsfile/staging_test.groovy @@ -0,0 +1,31 @@ +timeout(time: 40, unit: 'MINUTES') { + try { + dir ("${PROJECT_NAME}_test") { + checkout([$class: 'GitSCM', branches: [[name: "${SEMVER}"]], doGenerateSubmoduleConfigurations: false, extensions: [], submoduleCfg: [], userRemoteConfigs: [[credentialsId: "${params.GIT_USER}", url: "git@192.168.1.105:Test/milvus_test.git", name: 'origin', refspec: "+refs/heads/${SEMVER}:refs/remotes/origin/${SEMVER}"]]]) + sh 'python3 -m pip install -r requirements.txt' + def service_ip = sh (script: "kubectl get svc --namespace milvus-1 ${env.JOB_NAME}-${env.BUILD_NUMBER}-milvus-gpu-engine --template \"{{range .status.loadBalancer.ingress}}{{.ip}}{{end}}\"",returnStdout: true).trim() + sh "pytest . --alluredir=\"test_out/staging/single/sqlite\" --ip ${service_ip}" + } + + // mysql database backend test + load "${env.WORKSPACE}/ci/jenkinsfile/cleanup_staging.groovy" + + if (!fileExists('milvus-helm')) { + dir ("milvus-helm") { + checkout([$class: 'GitSCM', branches: [[name: "${SEMVER}"]], doGenerateSubmoduleConfigurations: false, extensions: [], submoduleCfg: [], userRemoteConfigs: [[credentialsId: "${params.GIT_USER}", url: "git@192.168.1.105:megasearch/milvus-helm.git", name: 'origin', refspec: "+refs/heads/${SEMVER}:refs/remotes/origin/${SEMVER}"]]]) + } + } + dir ("milvus-helm") { + dir ("milvus/milvus-gpu") { + sh "helm install --wait --timeout 300 --set engine.image.repository=\"zilliz.azurecr.cn/milvus/engine\" --set engine.image.tag=${DOCKER_VERSION} --set expose.type=loadBalancer --name ${env.JOB_NAME}-${env.BUILD_NUMBER} -f ci/db_backend/mysql_values.yaml --namespace milvus-2 --version 0.5.0 ." + } + } + dir ("${PROJECT_NAME}_test") { + def service_ip = sh (script: "kubectl get svc --namespace milvus-2 ${env.JOB_NAME}-${env.BUILD_NUMBER}-milvus-gpu-engine --template \"{{range .status.loadBalancer.ingress}}{{.ip}}{{end}}\"",returnStdout: true).trim() + sh "pytest . --alluredir=\"test_out/staging/single/mysql\" --ip ${service_ip}" + } + } catch (exc) { + echo 'Milvus Test Failed !' + throw exc + } +} diff --git a/ci/jenkinsfile/upload_dev_cluster_test_out.groovy b/ci/jenkinsfile/upload_dev_cluster_test_out.groovy new file mode 100644 index 0000000000..6bbd8a649f --- /dev/null +++ b/ci/jenkinsfile/upload_dev_cluster_test_out.groovy @@ -0,0 +1,14 @@ +timeout(time: 5, unit: 'MINUTES') { + dir ("${PROJECT_NAME}_test") { + if (fileExists('cluster_test_out')) { + def fileTransfer = load "${env.WORKSPACE}/ci/function/file_transfer.groovy" + fileTransfer.FileTransfer("cluster_test_out/", "${PROJECT_NAME}/test/${JOB_NAME}-${BUILD_ID}", 'nas storage') + if (currentBuild.resultIsBetterOrEqualTo('SUCCESS')) { + echo "Milvus Dev Test Out Viewer \"ftp://192.168.1.126/data/${PROJECT_NAME}/test/${JOB_NAME}-${BUILD_ID}\"" + } + } else { + error("Milvus Dev Test Out directory don't exists!") + } + } +} + diff --git a/ci/jenkinsfile/upload_dev_test_out.groovy b/ci/jenkinsfile/upload_dev_test_out.groovy new file mode 100644 index 0000000000..017b887334 --- /dev/null +++ b/ci/jenkinsfile/upload_dev_test_out.groovy @@ -0,0 +1,13 @@ +timeout(time: 5, unit: 'MINUTES') { + dir ("${PROJECT_NAME}_test") { + if (fileExists('test_out/dev')) { + def fileTransfer = load "${env.WORKSPACE}/ci/function/file_transfer.groovy" + fileTransfer.FileTransfer("test_out/dev/", "${PROJECT_NAME}/test/${JOB_NAME}-${BUILD_ID}", 'nas storage') + if (currentBuild.resultIsBetterOrEqualTo('SUCCESS')) { + echo "Milvus Dev Test Out Viewer \"ftp://192.168.1.126/data/${PROJECT_NAME}/test/${JOB_NAME}-${BUILD_ID}\"" + } + } else { + error("Milvus Dev Test Out directory don't exists!") + } + } +} diff --git a/ci/jenkinsfile/upload_staging_test_out.groovy b/ci/jenkinsfile/upload_staging_test_out.groovy new file mode 100644 index 0000000000..1f1e66ab1b --- /dev/null +++ b/ci/jenkinsfile/upload_staging_test_out.groovy @@ -0,0 +1,13 @@ +timeout(time: 5, unit: 'MINUTES') { + dir ("${PROJECT_NAME}_test") { + if (fileExists('test_out/staging')) { + def fileTransfer = load "${env.WORKSPACE}/ci/function/file_transfer.groovy" + fileTransfer.FileTransfer("test_out/staging/", "${PROJECT_NAME}/test/${JOB_NAME}-${BUILD_ID}", 'nas storage') + if (currentBuild.resultIsBetterOrEqualTo('SUCCESS')) { + echo "Milvus Dev Test Out Viewer \"ftp://192.168.1.126/data/${PROJECT_NAME}/test/${JOB_NAME}-${BUILD_ID}\"" + } + } else { + error("Milvus Dev Test Out directory don't exists!") + } + } +} diff --git a/ci/main_jenkinsfile b/ci/main_jenkinsfile new file mode 100644 index 0000000000..0c3fc32e5b --- /dev/null +++ b/ci/main_jenkinsfile @@ -0,0 +1,396 @@ +pipeline { + agent none + + options { + timestamps() + } + + environment { + PROJECT_NAME = "milvus" + LOWER_BUILD_TYPE = BUILD_TYPE.toLowerCase() + SEMVER = "${env.gitlabSourceBranch == null ? params.ENGINE_BRANCH.substring(params.ENGINE_BRANCH.lastIndexOf('/') + 1) : env.gitlabSourceBranch}" + GITLAB_AFTER_COMMIT = "${env.gitlabAfter == null ? null : env.gitlabAfter}" + SUFFIX_VERSION_NAME = "${env.gitlabAfter == null ? null : env.gitlabAfter.substring(0, 6)}" + DOCKER_VERSION_STR = "${env.gitlabAfter == null ? "${SEMVER}-${LOWER_BUILD_TYPE}" : "${SEMVER}-${LOWER_BUILD_TYPE}-${SUFFIX_VERSION_NAME}"}" + } + + stages { + stage("Ubuntu 16.04") { + environment { + PACKAGE_VERSION = VersionNumber([ + versionNumberString : '${SEMVER}-${LOWER_BUILD_TYPE}-${BUILD_DATE_FORMATTED, "yyyyMMdd"}' + ]); + + DOCKER_VERSION = VersionNumber([ + versionNumberString : '${DOCKER_VERSION_STR}' + ]); + } + + stages { + stage("Run Build") { + agent { + kubernetes { + cloud 'build-kubernetes' + label 'build' + defaultContainer 'jnlp' + yaml """ +apiVersion: v1 +kind: Pod +metadata: + name: milvus-build-env + labels: + app: milvus + componet: build-env +spec: + containers: + - name: milvus-build-env + image: registry.zilliz.com/milvus/milvus-build-env:v0.13 + env: + - name: POD_IP + valueFrom: + fieldRef: + fieldPath: status.podIP + command: + - cat + tty: true + resources: + limits: + memory: "28Gi" + cpu: "10.0" + nvidia.com/gpu: 1 + requests: + memory: "14Gi" + cpu: "5.0" + - name: milvus-mysql + image: mysql:5.6 + env: + - name: MYSQL_ROOT_PASSWORD + value: 123456 + ports: + - containerPort: 3306 + name: mysql +""" + } + } + stages { + stage('Build') { + steps { + gitlabCommitStatus(name: 'Build') { + script { + load "${env.WORKSPACE}/ci/jenkinsfile/milvus_build.groovy" + load "${env.WORKSPACE}/ci/jenkinsfile/packaged_milvus.groovy" + } + } + } + } + } + post { + aborted { + script { + updateGitlabCommitStatus name: 'Build', state: 'canceled' + echo "Milvus Build aborted !" + } + } + + failure { + script { + updateGitlabCommitStatus name: 'Build', state: 'failed' + echo "Milvus Build failure !" + } + } + } + } + + stage("Publish docker and helm") { + agent { + kubernetes { + label 'publish' + defaultContainer 'jnlp' + yaml """ +apiVersion: v1 +kind: Pod +metadata: + labels: + app: publish + componet: docker +spec: + containers: + - name: publish-docker + image: registry.zilliz.com/library/zilliz_docker:v1.0.0 + securityContext: + privileged: true + command: + - cat + tty: true + volumeMounts: + - name: docker-sock + mountPath: /var/run/docker.sock + volumes: + - name: docker-sock + hostPath: + path: /var/run/docker.sock +""" + } + } + stages { + stage('Publish Docker') { + steps { + gitlabCommitStatus(name: 'Publish Docker') { + script { + load "${env.WORKSPACE}/ci/jenkinsfile/publish_docker.groovy" + } + } + } + } + } + post { + aborted { + script { + updateGitlabCommitStatus name: 'Publish Docker', state: 'canceled' + echo "Milvus Publish Docker aborted !" + } + } + + failure { + script { + updateGitlabCommitStatus name: 'Publish Docker', state: 'failed' + echo "Milvus Publish Docker failure !" + } + } + } + } + + stage("Deploy to Development") { + parallel { + stage("Single Node") { + agent { + kubernetes { + label 'dev-test' + defaultContainer 'jnlp' + yaml """ +apiVersion: v1 +kind: Pod +metadata: + labels: + app: milvus + componet: test +spec: + containers: + - name: milvus-testframework + image: registry.zilliz.com/milvus/milvus-test:v0.2 + command: + - cat + tty: true + volumeMounts: + - name: kubeconf + mountPath: /root/.kube/ + readOnly: true + volumes: + - name: kubeconf + secret: + secretName: test-cluster-config +""" + } + } + + stages { + stage("Deploy to Dev") { + steps { + gitlabCommitStatus(name: 'Deloy to Dev') { + container('milvus-testframework') { + script { + load "${env.WORKSPACE}/ci/jenkinsfile/deploy2dev.groovy" + } + } + } + } + } + stage("Dev Test") { + steps { + gitlabCommitStatus(name: 'Deloy Test') { + container('milvus-testframework') { + script { + load "${env.WORKSPACE}/ci/jenkinsfile/dev_test.groovy" + load "${env.WORKSPACE}/ci/jenkinsfile/upload_dev_test_out.groovy" + } + } + } + } + } + stage ("Cleanup Dev") { + steps { + gitlabCommitStatus(name: 'Cleanup Dev') { + container('milvus-testframework') { + script { + load "${env.WORKSPACE}/ci/jenkinsfile/cleanup_dev.groovy" + } + } + } + } + } + } + post { + always { + container('milvus-testframework') { + script { + load "${env.WORKSPACE}/ci/jenkinsfile/cleanup_dev.groovy" + } + } + } + success { + script { + echo "Milvus Single Node CI/CD success !" + } + } + aborted { + script { + echo "Milvus Single Node CI/CD aborted !" + } + } + failure { + script { + echo "Milvus Single Node CI/CD failure !" + } + } + } + } + +// stage("Cluster") { +// agent { +// kubernetes { +// label 'dev-test' +// defaultContainer 'jnlp' +// yaml """ +// apiVersion: v1 +// kind: Pod +// metadata: +// labels: +// app: milvus +// componet: test +// spec: +// containers: +// - name: milvus-testframework +// image: registry.zilliz.com/milvus/milvus-test:v0.2 +// command: +// - cat +// tty: true +// volumeMounts: +// - name: kubeconf +// mountPath: /root/.kube/ +// readOnly: true +// volumes: +// - name: kubeconf +// secret: +// secretName: test-cluster-config +// """ +// } +// } +// stages { +// stage("Deploy to Dev") { +// steps { +// gitlabCommitStatus(name: 'Deloy to Dev') { +// container('milvus-testframework') { +// script { +// load "${env.WORKSPACE}/ci/jenkinsfile/cluster_deploy2dev.groovy" +// } +// } +// } +// } +// } +// stage("Dev Test") { +// steps { +// gitlabCommitStatus(name: 'Deloy Test') { +// container('milvus-testframework') { +// script { +// load "${env.WORKSPACE}/ci/jenkinsfile/cluster_dev_test.groovy" +// load "${env.WORKSPACE}/ci/jenkinsfile/upload_dev_cluster_test_out.groovy" +// } +// } +// } +// } +// } +// stage ("Cleanup Dev") { +// steps { +// gitlabCommitStatus(name: 'Cleanup Dev') { +// container('milvus-testframework') { +// script { +// load "${env.WORKSPACE}/ci/jenkinsfile/cluster_cleanup_dev.groovy" +// } +// } +// } +// } +// } +// } +// post { +// always { +// container('milvus-testframework') { +// script { +// load "${env.WORKSPACE}/ci/jenkinsfile/cluster_cleanup_dev.groovy" +// } +// } +// } +// success { +// script { +// echo "Milvus Cluster CI/CD success !" +// } +// } +// aborted { +// script { +// echo "Milvus Cluster CI/CD aborted !" +// } +// } +// failure { +// script { +// echo "Milvus Cluster CI/CD failure !" +// } +// } +// } +// } + } + } + } + } + } + + post { + always { + script { + if (env.gitlabAfter != null) { + if (!currentBuild.resultIsBetterOrEqualTo('SUCCESS')) { + // Send an email only if the build status has changed from green/unstable to red + emailext subject: '$DEFAULT_SUBJECT', + body: '$DEFAULT_CONTENT', + recipientProviders: [ + [$class: 'DevelopersRecipientProvider'], + [$class: 'RequesterRecipientProvider'] + ], + replyTo: '$DEFAULT_REPLYTO', + to: '$DEFAULT_RECIPIENTS' + } + } + } + } + + success { + script { + updateGitlabCommitStatus name: 'CI/CD', state: 'success' + echo "Milvus CI/CD success !" + } + } + + aborted { + script { + updateGitlabCommitStatus name: 'CI/CD', state: 'canceled' + echo "Milvus CI/CD aborted !" + } + } + + failure { + script { + updateGitlabCommitStatus name: 'CI/CD', state: 'failed' + echo "Milvus CI/CD failure !" + } + } + } +} + diff --git a/ci/main_jenkinsfile_no_ut b/ci/main_jenkinsfile_no_ut new file mode 100644 index 0000000000..e8d7dae75a --- /dev/null +++ b/ci/main_jenkinsfile_no_ut @@ -0,0 +1,396 @@ +pipeline { + agent none + + options { + timestamps() + } + + environment { + PROJECT_NAME = "milvus" + LOWER_BUILD_TYPE = BUILD_TYPE.toLowerCase() + SEMVER = "${env.gitlabSourceBranch == null ? params.ENGINE_BRANCH.substring(params.ENGINE_BRANCH.lastIndexOf('/') + 1) : env.gitlabSourceBranch}" + GITLAB_AFTER_COMMIT = "${env.gitlabAfter == null ? null : env.gitlabAfter}" + SUFFIX_VERSION_NAME = "${env.gitlabAfter == null ? null : env.gitlabAfter.substring(0, 6)}" + DOCKER_VERSION_STR = "${env.gitlabAfter == null ? "${SEMVER}-${LOWER_BUILD_TYPE}" : "${SEMVER}-${LOWER_BUILD_TYPE}-${SUFFIX_VERSION_NAME}"}" + } + + stages { + stage("Ubuntu 16.04") { + environment { + PACKAGE_VERSION = VersionNumber([ + versionNumberString : '${SEMVER}-${LOWER_BUILD_TYPE}-${BUILD_DATE_FORMATTED, "yyyyMMdd"}' + ]); + + DOCKER_VERSION = VersionNumber([ + versionNumberString : '${DOCKER_VERSION_STR}' + ]); + } + + stages { + stage("Run Build") { + agent { + kubernetes { + cloud 'build-kubernetes' + label 'build' + defaultContainer 'jnlp' + yaml """ +apiVersion: v1 +kind: Pod +metadata: + name: milvus-build-env + labels: + app: milvus + componet: build-env +spec: + containers: + - name: milvus-build-env + image: registry.zilliz.com/milvus/milvus-build-env:v0.13 + env: + - name: POD_IP + valueFrom: + fieldRef: + fieldPath: status.podIP + command: + - cat + tty: true + resources: + limits: + memory: "28Gi" + cpu: "10.0" + nvidia.com/gpu: 1 + requests: + memory: "14Gi" + cpu: "5.0" + - name: milvus-mysql + image: mysql:5.6 + env: + - name: MYSQL_ROOT_PASSWORD + value: 123456 + ports: + - containerPort: 3306 + name: mysql +""" + } + } + stages { + stage('Build') { + steps { + gitlabCommitStatus(name: 'Build') { + script { + load "${env.WORKSPACE}/ci/jenkinsfile/milvus_build_no_ut.groovy" + load "${env.WORKSPACE}/ci/jenkinsfile/packaged_milvus_no_ut.groovy" + } + } + } + } + } + post { + aborted { + script { + updateGitlabCommitStatus name: 'Build', state: 'canceled' + echo "Milvus Build aborted !" + } + } + + failure { + script { + updateGitlabCommitStatus name: 'Build', state: 'failed' + echo "Milvus Build failure !" + } + } + } + } + + stage("Publish docker and helm") { + agent { + kubernetes { + label 'publish' + defaultContainer 'jnlp' + yaml """ +apiVersion: v1 +kind: Pod +metadata: + labels: + app: publish + componet: docker +spec: + containers: + - name: publish-docker + image: registry.zilliz.com/library/zilliz_docker:v1.0.0 + securityContext: + privileged: true + command: + - cat + tty: true + volumeMounts: + - name: docker-sock + mountPath: /var/run/docker.sock + volumes: + - name: docker-sock + hostPath: + path: /var/run/docker.sock +""" + } + } + stages { + stage('Publish Docker') { + steps { + gitlabCommitStatus(name: 'Publish Docker') { + script { + load "${env.WORKSPACE}/ci/jenkinsfile/publish_docker.groovy" + } + } + } + } + } + post { + aborted { + script { + updateGitlabCommitStatus name: 'Publish Docker', state: 'canceled' + echo "Milvus Publish Docker aborted !" + } + } + + failure { + script { + updateGitlabCommitStatus name: 'Publish Docker', state: 'failed' + echo "Milvus Publish Docker failure !" + } + } + } + } + + stage("Deploy to Development") { + parallel { + stage("Single Node") { + agent { + kubernetes { + label 'dev-test' + defaultContainer 'jnlp' + yaml """ +apiVersion: v1 +kind: Pod +metadata: + labels: + app: milvus + componet: test +spec: + containers: + - name: milvus-testframework + image: registry.zilliz.com/milvus/milvus-test:v0.2 + command: + - cat + tty: true + volumeMounts: + - name: kubeconf + mountPath: /root/.kube/ + readOnly: true + volumes: + - name: kubeconf + secret: + secretName: test-cluster-config +""" + } + } + + stages { + stage("Deploy to Dev") { + steps { + gitlabCommitStatus(name: 'Deloy to Dev') { + container('milvus-testframework') { + script { + load "${env.WORKSPACE}/ci/jenkinsfile/deploy2dev.groovy" + } + } + } + } + } + stage("Dev Test") { + steps { + gitlabCommitStatus(name: 'Deloy Test') { + container('milvus-testframework') { + script { + load "${env.WORKSPACE}/ci/jenkinsfile/dev_test.groovy" + load "${env.WORKSPACE}/ci/jenkinsfile/upload_dev_test_out.groovy" + } + } + } + } + } + stage ("Cleanup Dev") { + steps { + gitlabCommitStatus(name: 'Cleanup Dev') { + container('milvus-testframework') { + script { + load "${env.WORKSPACE}/ci/jenkinsfile/cleanup_dev.groovy" + } + } + } + } + } + } + post { + always { + container('milvus-testframework') { + script { + load "${env.WORKSPACE}/ci/jenkinsfile/cleanup_dev.groovy" + } + } + } + success { + script { + echo "Milvus Single Node CI/CD success !" + } + } + aborted { + script { + echo "Milvus Single Node CI/CD aborted !" + } + } + failure { + script { + echo "Milvus Single Node CI/CD failure !" + } + } + } + } + +// stage("Cluster") { +// agent { +// kubernetes { +// label 'dev-test' +// defaultContainer 'jnlp' +// yaml """ +// apiVersion: v1 +// kind: Pod +// metadata: +// labels: +// app: milvus +// componet: test +// spec: +// containers: +// - name: milvus-testframework +// image: registry.zilliz.com/milvus/milvus-test:v0.2 +// command: +// - cat +// tty: true +// volumeMounts: +// - name: kubeconf +// mountPath: /root/.kube/ +// readOnly: true +// volumes: +// - name: kubeconf +// secret: +// secretName: test-cluster-config +// """ +// } +// } +// stages { +// stage("Deploy to Dev") { +// steps { +// gitlabCommitStatus(name: 'Deloy to Dev') { +// container('milvus-testframework') { +// script { +// load "${env.WORKSPACE}/ci/jenkinsfile/cluster_deploy2dev.groovy" +// } +// } +// } +// } +// } +// stage("Dev Test") { +// steps { +// gitlabCommitStatus(name: 'Deloy Test') { +// container('milvus-testframework') { +// script { +// load "${env.WORKSPACE}/ci/jenkinsfile/cluster_dev_test.groovy" +// load "${env.WORKSPACE}/ci/jenkinsfile/upload_dev_cluster_test_out.groovy" +// } +// } +// } +// } +// } +// stage ("Cleanup Dev") { +// steps { +// gitlabCommitStatus(name: 'Cleanup Dev') { +// container('milvus-testframework') { +// script { +// load "${env.WORKSPACE}/ci/jenkinsfile/cluster_cleanup_dev.groovy" +// } +// } +// } +// } +// } +// } +// post { +// always { +// container('milvus-testframework') { +// script { +// load "${env.WORKSPACE}/ci/jenkinsfile/cluster_cleanup_dev.groovy" +// } +// } +// } +// success { +// script { +// echo "Milvus Cluster CI/CD success !" +// } +// } +// aborted { +// script { +// echo "Milvus Cluster CI/CD aborted !" +// } +// } +// failure { +// script { +// echo "Milvus Cluster CI/CD failure !" +// } +// } +// } +// } + } + } + } + } + } + + post { + always { + script { + if (env.gitlabAfter != null) { + if (!currentBuild.resultIsBetterOrEqualTo('SUCCESS')) { + // Send an email only if the build status has changed from green/unstable to red + emailext subject: '$DEFAULT_SUBJECT', + body: '$DEFAULT_CONTENT', + recipientProviders: [ + [$class: 'DevelopersRecipientProvider'], + [$class: 'RequesterRecipientProvider'] + ], + replyTo: '$DEFAULT_REPLYTO', + to: '$DEFAULT_RECIPIENTS' + } + } + } + } + + success { + script { + updateGitlabCommitStatus name: 'CI/CD', state: 'success' + echo "Milvus CI/CD success !" + } + } + + aborted { + script { + updateGitlabCommitStatus name: 'CI/CD', state: 'canceled' + echo "Milvus CI/CD aborted !" + } + } + + failure { + script { + updateGitlabCommitStatus name: 'CI/CD', state: 'failed' + echo "Milvus CI/CD failure !" + } + } + } +} + diff --git a/ci/nightly_main_jenkinsfile b/ci/nightly_main_jenkinsfile new file mode 100644 index 0000000000..add9e00fb4 --- /dev/null +++ b/ci/nightly_main_jenkinsfile @@ -0,0 +1,478 @@ +pipeline { + agent none + + options { + timestamps() + } + + environment { + PROJECT_NAME = "milvus" + LOWER_BUILD_TYPE = BUILD_TYPE.toLowerCase() + SEMVER = "${env.gitlabSourceBranch == null ? params.ENGINE_BRANCH.substring(params.ENGINE_BRANCH.lastIndexOf('/') + 1) : env.gitlabSourceBranch}" + GITLAB_AFTER_COMMIT = "${env.gitlabAfter == null ? null : env.gitlabAfter}" + SUFFIX_VERSION_NAME = "${env.gitlabAfter == null ? null : env.gitlabAfter.substring(0, 6)}" + DOCKER_VERSION_STR = "${env.gitlabAfter == null ? '${SEMVER}-${LOWER_BUILD_TYPE}-${BUILD_DATE_FORMATTED, \"yyyyMMdd\"}' : '${SEMVER}-${LOWER_BUILD_TYPE}-${SUFFIX_VERSION_NAME}'}" + } + + stages { + stage("Ubuntu 16.04") { + environment { + PACKAGE_VERSION = VersionNumber([ + versionNumberString : '${SEMVER}-${LOWER_BUILD_TYPE}-${BUILD_DATE_FORMATTED, "yyyyMMdd"}' + ]); + + DOCKER_VERSION = VersionNumber([ + versionNumberString : '${DOCKER_VERSION_STR}' + ]); + } + + stages { + stage("Run Build") { + agent { + kubernetes { + cloud 'build-kubernetes' + label 'build' + defaultContainer 'jnlp' + yaml """ +apiVersion: v1 +kind: Pod +metadata: + name: milvus-build-env + labels: + app: milvus + componet: build-env +spec: + containers: + - name: milvus-build-env + image: registry.zilliz.com/milvus/milvus-build-env:v0.13 + command: + - cat + tty: true + resources: + limits: + memory: "28Gi" + cpu: "10.0" + nvidia.com/gpu: 1 + requests: + memory: "14Gi" + cpu: "5.0" +""" + } + } + stages { + stage('Build') { + steps { + gitlabCommitStatus(name: 'Build') { + script { + load "${env.WORKSPACE}/ci/jenkinsfile/milvus_build.groovy" + load "${env.WORKSPACE}/ci/jenkinsfile/packaged_milvus.groovy" + } + } + } + } + } + post { + aborted { + script { + updateGitlabCommitStatus name: 'Build', state: 'canceled' + echo "Milvus Build aborted !" + } + } + + failure { + script { + updateGitlabCommitStatus name: 'Build', state: 'failed' + echo "Milvus Build failure !" + } + } + } + } + + stage("Publish docker and helm") { + agent { + kubernetes { + label 'publish' + defaultContainer 'jnlp' + yaml """ +apiVersion: v1 +kind: Pod +metadata: + labels: + app: publish + componet: docker +spec: + containers: + - name: publish-docker + image: registry.zilliz.com/library/zilliz_docker:v1.0.0 + securityContext: + privileged: true + command: + - cat + tty: true + volumeMounts: + - name: docker-sock + mountPath: /var/run/docker.sock + volumes: + - name: docker-sock + hostPath: + path: /var/run/docker.sock +""" + } + } + stages { + stage('Publish Docker') { + steps { + gitlabCommitStatus(name: 'Publish Docker') { + script { + load "${env.WORKSPACE}/ci/jenkinsfile/nightly_publish_docker.groovy" + } + } + } + } + } + post { + aborted { + script { + updateGitlabCommitStatus name: 'Publish Docker', state: 'canceled' + echo "Milvus Publish Docker aborted !" + } + } + + failure { + script { + updateGitlabCommitStatus name: 'Publish Docker', state: 'failed' + echo "Milvus Publish Docker failure !" + } + } + } + } + + stage("Deploy to Development") { + parallel { + stage("Single Node") { + agent { + kubernetes { + label 'dev-test' + defaultContainer 'jnlp' + yaml """ +apiVersion: v1 +kind: Pod +metadata: + labels: + app: milvus + componet: test +spec: + containers: + - name: milvus-testframework + image: registry.zilliz.com/milvus/milvus-test:v0.2 + command: + - cat + tty: true + volumeMounts: + - name: kubeconf + mountPath: /root/.kube/ + readOnly: true + volumes: + - name: kubeconf + secret: + secretName: test-cluster-config +""" + } + } + + stages { + stage("Deploy to Dev") { + steps { + gitlabCommitStatus(name: 'Deloy to Dev') { + container('milvus-testframework') { + script { + load "${env.WORKSPACE}/ci/jenkinsfile/deploy2dev.groovy" + } + } + } + } + } + stage("Dev Test") { + steps { + gitlabCommitStatus(name: 'Deloy Test') { + container('milvus-testframework') { + script { + load "${env.WORKSPACE}/ci/jenkinsfile/dev_test_all.groovy" + load "${env.WORKSPACE}/ci/jenkinsfile/upload_dev_test_out.groovy" + } + } + } + } + } + stage ("Cleanup Dev") { + steps { + gitlabCommitStatus(name: 'Cleanup Dev') { + container('milvus-testframework') { + script { + load "${env.WORKSPACE}/ci/jenkinsfile/cleanup_dev.groovy" + } + } + } + } + } + } + post { + always { + container('milvus-testframework') { + script { + load "${env.WORKSPACE}/ci/jenkinsfile/cleanup_dev.groovy" + } + } + } + success { + script { + echo "Milvus Deploy to Dev Single Node CI/CD success !" + } + } + aborted { + script { + echo "Milvus Deploy to Dev Single Node CI/CD aborted !" + } + } + failure { + script { + echo "Milvus Deploy to Dev Single Node CI/CD failure !" + } + } + } + } + +// stage("Cluster") { +// agent { +// kubernetes { +// label 'dev-test' +// defaultContainer 'jnlp' +// yaml """ +// apiVersion: v1 +// kind: Pod +// metadata: +// labels: +// app: milvus +// componet: test +// spec: +// containers: +// - name: milvus-testframework +// image: registry.zilliz.com/milvus/milvus-test:v0.2 +// command: +// - cat +// tty: true +// volumeMounts: +// - name: kubeconf +// mountPath: /root/.kube/ +// readOnly: true +// volumes: +// - name: kubeconf +// secret: +// secretName: test-cluster-config +// """ +// } +// } +// stages { +// stage("Deploy to Dev") { +// steps { +// gitlabCommitStatus(name: 'Deloy to Dev') { +// container('milvus-testframework') { +// script { +// load "${env.WORKSPACE}/ci/jenkinsfile/cluster_deploy2dev.groovy" +// } +// } +// } +// } +// } +// stage("Dev Test") { +// steps { +// gitlabCommitStatus(name: 'Deloy Test') { +// container('milvus-testframework') { +// script { +// load "${env.WORKSPACE}/ci/jenkinsfile/cluster_dev_test.groovy" +// load "${env.WORKSPACE}/ci/jenkinsfile/upload_dev_cluster_test_out.groovy" +// } +// } +// } +// } +// } +// stage ("Cleanup Dev") { +// steps { +// gitlabCommitStatus(name: 'Cleanup Dev') { +// container('milvus-testframework') { +// script { +// load "${env.WORKSPACE}/ci/jenkinsfile/cluster_cleanup_dev.groovy" +// } +// } +// } +// } +// } +// } +// post { +// always { +// container('milvus-testframework') { +// script { +// load "${env.WORKSPACE}/ci/jenkinsfile/cluster_cleanup_dev.groovy" +// } +// } +// } +// success { +// script { +// echo "Milvus Deploy to Dev Cluster CI/CD success !" +// } +// } +// aborted { +// script { +// echo "Milvus Deploy to Dev Cluster CI/CD aborted !" +// } +// } +// failure { +// script { +// echo "Milvus Deploy to Dev Cluster CI/CD failure !" +// } +// } +// } +// } + } + } + + stage("Deploy to Staging") { + parallel { + stage("Single Node") { + agent { + kubernetes { + label 'dev-test' + defaultContainer 'jnlp' + yaml """ +apiVersion: v1 +kind: Pod +metadata: + labels: + app: milvus + componet: test +spec: + containers: + - name: milvus-testframework + image: registry.zilliz.com/milvus/milvus-test:v0.2 + command: + - cat + tty: true + volumeMounts: + - name: kubeconf + mountPath: /root/.kube/ + readOnly: true + volumes: + - name: kubeconf + secret: + secretName: aks-gpu-cluster-config +""" + } + } + + stages { + stage("Deploy to Staging") { + steps { + gitlabCommitStatus(name: 'Deloy to Staging') { + container('milvus-testframework') { + script { + load "${env.WORKSPACE}/ci/jenkinsfile/deploy2staging.groovy" + } + } + } + } + } + stage("Staging Test") { + steps { + gitlabCommitStatus(name: 'Staging Test') { + container('milvus-testframework') { + script { + load "${env.WORKSPACE}/ci/jenkinsfile/staging_test.groovy" + load "${env.WORKSPACE}/ci/jenkinsfile/upload_staging_test_out.groovy" + } + } + } + } + } + stage ("Cleanup Staging") { + steps { + gitlabCommitStatus(name: 'Cleanup Staging') { + container('milvus-testframework') { + script { + load "${env.WORKSPACE}/ci/jenkinsfile/cleanup_staging.groovy" + } + } + } + } + } + } + post { + always { + container('milvus-testframework') { + script { + load "${env.WORKSPACE}/ci/jenkinsfile/cleanup_staging.groovy" + } + } + } + success { + script { + echo "Milvus Deploy to Staging Single Node CI/CD success !" + } + } + aborted { + script { + echo "Milvus Deploy to Staging Single Node CI/CD aborted !" + } + } + failure { + script { + echo "Milvus Deploy to Staging Single Node CI/CD failure !" + } + } + } + } + } + } + } + } + } + + post { + always { + script { + if (!currentBuild.resultIsBetterOrEqualTo('SUCCESS')) { + // Send an email only if the build status has changed from green/unstable to red + emailext subject: '$DEFAULT_SUBJECT', + body: '$DEFAULT_CONTENT', + recipientProviders: [ + [$class: 'DevelopersRecipientProvider'], + [$class: 'RequesterRecipientProvider'] + ], + replyTo: '$DEFAULT_REPLYTO', + to: '$DEFAULT_RECIPIENTS' + } + } + } + + success { + script { + updateGitlabCommitStatus name: 'CI/CD', state: 'success' + echo "Milvus CI/CD success !" + } + } + + aborted { + script { + updateGitlabCommitStatus name: 'CI/CD', state: 'canceled' + echo "Milvus CI/CD aborted !" + } + } + + failure { + script { + updateGitlabCommitStatus name: 'CI/CD', state: 'failed' + echo "Milvus CI/CD failure !" + } + } + } +} + diff --git a/ci/pod_containers/milvus-engine-build.yaml b/ci/pod_containers/milvus-engine-build.yaml new file mode 100644 index 0000000000..cd5352ffef --- /dev/null +++ b/ci/pod_containers/milvus-engine-build.yaml @@ -0,0 +1,13 @@ +apiVersion: v1 +kind: Pod +metadata: + labels: + app: milvus + componet: build-env +spec: + containers: + - name: milvus-build-env + image: registry.zilliz.com/milvus/milvus-build-env:v0.9 + command: + - cat + tty: true diff --git a/ci/pod_containers/milvus-testframework.yaml b/ci/pod_containers/milvus-testframework.yaml new file mode 100644 index 0000000000..7a98fbca8e --- /dev/null +++ b/ci/pod_containers/milvus-testframework.yaml @@ -0,0 +1,13 @@ +apiVersion: v1 +kind: Pod +metadata: + labels: + app: milvus + componet: testframework +spec: + containers: + - name: milvus-testframework + image: registry.zilliz.com/milvus/milvus-test:v0.1 + command: + - cat + tty: true diff --git a/ci/pod_containers/publish-docker.yaml b/ci/pod_containers/publish-docker.yaml new file mode 100644 index 0000000000..268afb1331 --- /dev/null +++ b/ci/pod_containers/publish-docker.yaml @@ -0,0 +1,22 @@ +apiVersion: v1 +kind: Pod +metadata: + labels: + app: publish + componet: docker +spec: + containers: + - name: publish-docker + image: registry.zilliz.com/library/zilliz_docker:v1.0.0 + securityContext: + privileged: true + command: + - cat + tty: true + volumeMounts: + - name: docker-sock + mountPath: /var/run/docker.sock + volumes: + - name: docker-sock + hostPath: + path: /var/run/docker.sock diff --git a/codecov.yaml b/codecov.yaml new file mode 100644 index 0000000000..debe315ac0 --- /dev/null +++ b/codecov.yaml @@ -0,0 +1,14 @@ +#Configuration File for CodeCov +coverage: + precision: 2 + round: down + range: "70...100" + + status: + project: on + patch: yes + changes: no + +comment: + layout: "header, diff, changes, tree" + behavior: default diff --git a/core/build.sh b/core/build.sh index c0bd9c5221..cd6f65201b 100755 --- a/core/build.sh +++ b/core/build.sh @@ -1,11 +1,12 @@ #!/bin/bash +BUILD_OUTPUT_DIR="cmake_build" BUILD_TYPE="Debug" BUILD_UNITTEST="OFF" INSTALL_PREFIX=$(pwd)/milvus MAKE_CLEAN="OFF" BUILD_COVERAGE="OFF" -DB_PATH="/opt/milvus" +DB_PATH="/tmp/milvus" PROFILING="OFF" USE_JFROG_CACHE="OFF" RUN_CPPLINT="OFF" @@ -40,8 +41,8 @@ do RUN_CPPLINT="ON" ;; r) - if [[ -d cmake_build ]]; then - rm ./cmake_build -r + if [[ -d ${BUILD_OUTPUT_DIR} ]]; then + rm ./${BUILD_OUTPUT_DIR} -r MAKE_CLEAN="ON" fi ;; @@ -62,7 +63,7 @@ do parameter: -p: install prefix(default: $(pwd)/milvus) --d: db path(default: /opt/milvus) +-d: db data path(default: /tmp/milvus) -t: build type(default: Debug) -u: building unit test options(default: OFF) -l: run cpplint, clang-format and clang-tidy(default: OFF) @@ -84,11 +85,15 @@ usage: esac done -if [[ ! -d cmake_build ]]; then - mkdir cmake_build +if [[ ! -d ${BUILD_OUTPUT_DIR} ]]; then + mkdir ${BUILD_OUTPUT_DIR} fi -cd cmake_build +cd ${BUILD_OUTPUT_DIR} + +# remove make cache since build.sh -l use default variables +# force update the variables each time +make rebuild_cache CMAKE_CMD="cmake \ -DBUILD_UNIT_TEST=${BUILD_UNITTEST} \ diff --git a/core/cmake/DefineOptions.cmake b/core/cmake/DefineOptions.cmake index 1b0646c2fa..7aae177f0b 100644 --- a/core/cmake/DefineOptions.cmake +++ b/core/cmake/DefineOptions.cmake @@ -56,7 +56,7 @@ define_option(MILVUS_VERBOSE_THIRDPARTY_BUILD "Show output from ExternalProjects rather than just logging to files" ON) define_option(MILVUS_BOOST_VENDORED "Use vendored Boost instead of existing Boost. \ -Note that this requires linking Boost statically" ON) +Note that this requires linking Boost statically" OFF) define_option(MILVUS_BOOST_HEADER_ONLY "Use only BOOST headers" OFF) diff --git a/core/conf/server_config.template b/core/conf/server_config.template index 2f2f699e09..7abfb8b055 100644 --- a/core/conf/server_config.template +++ b/core/conf/server_config.template @@ -16,7 +16,6 @@ db_config: insert_buffer_size: 4 # GB, maximum insert buffer size allowed # sum of insert_buffer_size and cpu_cache_capacity cannot exceed total memory - build_index_gpu: 0 # gpu id used for building index preload_table: # preload data at startup, '*' means load all tables, empty value means no preload # you can specify preload tables like this: table1,table2,table3 @@ -30,6 +29,8 @@ metric_config: cache_config: cpu_cache_capacity: 16 # GB, CPU memory used for cache cpu_cache_threshold: 0.85 # percentage of data that will be kept when cache cleanup is triggered + gpu_cache_capacity: 4 # GB, GPU memory used for cache + gpu_cache_threshold: 0.85 # percentage of data that will be kept when cache cleanup is triggered cache_insert_data: false # whether to load inserted data into cache engine_config: @@ -37,6 +38,6 @@ engine_config: # if nq >= use_blas_threshold, use OpenBlas, slower with stable response times resource_config: - resource_pool: - - cpu + search_resources: # define the GPUs used for search computation, valid value: gpux - gpu0 + index_build_device: gpu0 # GPU used for building index \ No newline at end of file diff --git a/core/coverage.sh b/core/coverage.sh index 55c70d7e8d..74f9f4219d 100755 --- a/core/coverage.sh +++ b/core/coverage.sh @@ -114,15 +114,15 @@ ${LCOV_CMD} -r "${FILE_INFO_OUTPUT}" -o "${FILE_INFO_OUTPUT_NEW}" \ "/usr/*" \ "*/boost/*" \ "*/cmake_build/*_ep-prefix/*" \ - "src/index/cmake_build*" \ - "src/index/thirdparty*" \ - "src/grpc*"\ - "src/metrics/MetricBase.h"\ - "src/server/Server.cpp"\ - "src/server/DBWrapper.cpp"\ - "src/server/grpc_impl/GrpcServer.cpp"\ - "src/utils/easylogging++.h"\ - "src/utils/easylogging++.cc"\ + "*/src/index/cmake_build*" \ + "*/src/index/thirdparty*" \ + "*/src/grpc*" \ + "*/src/metrics/MetricBase.h" \ + "*/src/server/Server.cpp" \ + "*/src/server/DBWrapper.cpp" \ + "*/src/server/grpc_impl/GrpcServer.cpp" \ + "*/src/utils/easylogging++.h" \ + "*/src/utils/easylogging++.cc" # gen html report ${LCOV_GEN_CMD} "${FILE_INFO_OUTPUT_NEW}" --output-directory ${DIR_LCOV_OUTPUT}/ diff --git a/core/src/CMakeLists.txt b/core/src/CMakeLists.txt index 5f87d1d219..b0228bd090 100644 --- a/core/src/CMakeLists.txt +++ b/core/src/CMakeLists.txt @@ -96,9 +96,9 @@ set(prometheus_lib ) set(boost_lib - boost_system_static - boost_filesystem_static - boost_serialization_static + libboost_system.a + libboost_filesystem.a + libboost_serialization.a ) set(cuda_lib diff --git a/core/src/db/engine/ExecutionEngineImpl.cpp b/core/src/db/engine/ExecutionEngineImpl.cpp index 7d10ab5c6b..ecd6ff0650 100644 --- a/core/src/db/engine/ExecutionEngineImpl.cpp +++ b/core/src/db/engine/ExecutionEngineImpl.cpp @@ -65,7 +65,7 @@ ExecutionEngineImpl::ExecutionEngineImpl(uint16_t dimension, const std::string& : location_(location), dim_(dimension), index_type_(index_type), metric_type_(metric_type), nlist_(nlist) { index_ = CreatetVecIndex(EngineType::FAISS_IDMAP); if (!index_) { - throw Exception(DB_ERROR, "Could not create VecIndex"); + throw Exception(DB_ERROR, "Unsupported index type"); } TempMetaConf temp_conf; @@ -111,7 +111,7 @@ ExecutionEngineImpl::CreatetVecIndex(EngineType type) { break; } default: { - ENGINE_LOG_ERROR << "Invalid engine type"; + ENGINE_LOG_ERROR << "Unsupported index type"; return nullptr; } } @@ -124,6 +124,11 @@ ExecutionEngineImpl::HybridLoad() const { return; } + if (index_->GetType() == IndexType::FAISS_IDMAP) { + ENGINE_LOG_WARNING << "HybridLoad with type FAISS_IDMAP, ignore"; + return; + } + const std::string key = location_ + ".quantizer"; std::vector gpus = scheduler::get_gpu_pool(); @@ -164,6 +169,9 @@ ExecutionEngineImpl::HybridLoad() const { quantizer_conf->mode = 1; quantizer_conf->gpu_id = best_device_id; auto quantizer = index_->LoadQuantizer(quantizer_conf); + if (quantizer == nullptr) { + ENGINE_LOG_ERROR << "quantizer is nullptr"; + } index_->SetQuantizer(quantizer); auto cache_quantizer = std::make_shared(quantizer); cache::GpuCacheMgr::GetInstance(best_device_id)->InsertItem(key, cache_quantizer); @@ -175,6 +183,9 @@ ExecutionEngineImpl::HybridUnset() const { if (index_type_ != EngineType::FAISS_IVFSQ8H) { return; } + if (index_->GetType() == IndexType::FAISS_IDMAP) { + return; + } index_->UnsetQuantizer(); } @@ -373,7 +384,7 @@ ExecutionEngineImpl::BuildIndex(const std::string& location, EngineType engine_t auto to_index = CreatetVecIndex(engine_type); if (!to_index) { - throw Exception(DB_ERROR, "Could not create VecIndex"); + throw Exception(DB_ERROR, "Unsupported index type"); } TempMetaConf temp_conf; @@ -397,6 +408,7 @@ ExecutionEngineImpl::BuildIndex(const std::string& location, EngineType engine_t Status ExecutionEngineImpl::Search(int64_t n, const float* data, int64_t k, int64_t nprobe, float* distances, int64_t* labels, bool hybrid) { +#if 0 if (index_type_ == EngineType::FAISS_IVFSQ8H) { if (!hybrid) { const std::string key = location_ + ".quantizer"; @@ -449,6 +461,7 @@ ExecutionEngineImpl::Search(int64_t n, const float* data, int64_t k, int64_t npr } } } +#endif if (index_ == nullptr) { ENGINE_LOG_ERROR << "ExecutionEngineImpl: index is null, failed to search"; @@ -501,7 +514,7 @@ ExecutionEngineImpl::GpuCache(uint64_t gpu_id) { Status ExecutionEngineImpl::Init() { server::Config& config = server::Config::GetInstance(); - Status s = config.GetDBConfigBuildIndexGPU(gpu_num_); + Status s = config.GetResourceConfigIndexBuildDevice(gpu_num_); if (!s.ok()) { return s; } diff --git a/core/src/index/cmake/ThirdPartyPackagesCore.cmake b/core/src/index/cmake/ThirdPartyPackagesCore.cmake index ee1d88ee32..99f52dc284 100644 --- a/core/src/index/cmake/ThirdPartyPackagesCore.cmake +++ b/core/src/index/cmake/ThirdPartyPackagesCore.cmake @@ -243,7 +243,8 @@ if(CUSTOMIZATION) # set(FAISS_MD5 "57da9c4f599cc8fa4260488b1c96e1cc") # commit-id 6dbdf75987c34a2c853bd172ea0d384feea8358c branch-0.2.0 # set(FAISS_MD5 "21deb1c708490ca40ecb899122c01403") # commit-id 643e48f479637fd947e7b93fa4ca72b38ecc9a39 branch-0.2.0 # set(FAISS_MD5 "072db398351cca6e88f52d743bbb9fa0") # commit-id 3a2344d04744166af41ef1a74449d68a315bfe17 branch-0.2.1 - set(FAISS_MD5 "c89ea8e655f5cdf58f42486f13614714") # commit-id 9c28a1cbb88f41fa03b03d7204106201ad33276b branch-0.2.1 + # set(FAISS_MD5 "c89ea8e655f5cdf58f42486f13614714") # commit-id 9c28a1cbb88f41fa03b03d7204106201ad33276b branch-0.2.1 + set(FAISS_MD5 "87fdd86351ffcaf3f80dc26ade63c44b") # commit-id 841a156e67e8e22cd8088e1b58c00afbf2efc30b branch-0.2.1 endif() else() set(FAISS_SOURCE_URL "https://github.com/facebookresearch/faiss/archive/v1.5.3.tar.gz") diff --git a/core/src/index/knowhere/knowhere/index/vector_index/IndexIVF.cpp b/core/src/index/knowhere/knowhere/index/vector_index/IndexIVF.cpp index 0c4856f2b6..fba2e11e2e 100644 --- a/core/src/index/knowhere/knowhere/index/vector_index/IndexIVF.cpp +++ b/core/src/index/knowhere/knowhere/index/vector_index/IndexIVF.cpp @@ -24,17 +24,21 @@ #include #include #include +#include #include #include #include #include "knowhere/adapter/VectorAdapter.h" #include "knowhere/common/Exception.h" +#include "knowhere/common/Log.h" #include "knowhere/index/vector_index/IndexGPUIVF.h" #include "knowhere/index/vector_index/IndexIVF.h" namespace knowhere { +using stdclock = std::chrono::high_resolution_clock; + IndexModelPtr IVF::Train(const DatasetPtr& dataset, const Config& config) { auto build_cfg = std::dynamic_pointer_cast(config); @@ -216,7 +220,15 @@ IVF::GenGraph(const int64_t& k, Graph& graph, const DatasetPtr& dataset, const C void IVF::search_impl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels, const Config& cfg) { auto params = GenParams(cfg); + stdclock::time_point before = stdclock::now(); faiss::ivflib::search_with_parameters(index_.get(), n, (float*)data, k, distances, labels, params.get()); + stdclock::time_point after = stdclock::now(); + double search_cost = (std::chrono::duration(after - before)).count(); + KNOWHERE_LOG_DEBUG << "IVF search cost: " << search_cost + << ", quantization cost: " << faiss::indexIVF_stats.quantization_time + << ", data search cost: " << faiss::indexIVF_stats.search_time; + faiss::indexIVF_stats.quantization_time = 0; + faiss::indexIVF_stats.search_time = 0; } VectorIndexPtr diff --git a/core/src/index/knowhere/knowhere/index/vector_index/IndexIVFSQHybrid.cpp b/core/src/index/knowhere/knowhere/index/vector_index/IndexIVFSQHybrid.cpp index 34c81991c9..fe5bf0990a 100644 --- a/core/src/index/knowhere/knowhere/index/vector_index/IndexIVFSQHybrid.cpp +++ b/core/src/index/knowhere/knowhere/index/vector_index/IndexIVFSQHybrid.cpp @@ -189,6 +189,8 @@ IVFSQHybrid::LoadData(const knowhere::QuantizerPtr& q, const Config& conf) { if (quantizer_conf->mode != 2) { KNOWHERE_THROW_MSG("mode only support 2 in this func"); } + } else { + KNOWHERE_THROW_MSG("conf error"); } // if (quantizer_conf->gpu_id != gpu_id_) { // KNOWHERE_THROW_MSG("quantizer and data must on the same gpu card"); diff --git a/core/src/index/knowhere/knowhere/index/vector_index/helpers/FaissGpuResourceMgr.cpp b/core/src/index/knowhere/knowhere/index/vector_index/helpers/FaissGpuResourceMgr.cpp index d74c6bc562..837629c6eb 100644 --- a/core/src/index/knowhere/knowhere/index/vector_index/helpers/FaissGpuResourceMgr.cpp +++ b/core/src/index/knowhere/knowhere/index/vector_index/helpers/FaissGpuResourceMgr.cpp @@ -63,7 +63,7 @@ FaissGpuResourceMgr::InitResource() { mutex_cache_.emplace(device_id, std::make_unique()); - // std::cout << "Device Id: " << device_id << std::endl; + // std::cout << "Device Id: " << DEVICEID << std::endl; auto& device_param = device.second; auto& bq = idle_map_[device_id]; @@ -119,7 +119,7 @@ void FaissGpuResourceMgr::Dump() { for (auto& item : idle_map_) { auto& bq = item.second; - std::cout << "device_id: " << item.first << ", resource count:" << bq.Size(); + std::cout << "DEVICEID: " << item.first << ", resource count:" << bq.Size(); } } diff --git a/core/src/index/unittest/CMakeLists.txt b/core/src/index/unittest/CMakeLists.txt index 1812bc2503..8a5e089486 100644 --- a/core/src/index/unittest/CMakeLists.txt +++ b/core/src/index/unittest/CMakeLists.txt @@ -73,9 +73,17 @@ target_link_libraries(test_kdt SPTAGLibStatic ${depend_libs} ${unittest_libs} ${basic_libs}) +add_executable(test_gpuresource test_gpuresource.cpp ${util_srcs} ${ivf_srcs}) +target_link_libraries(test_gpuresource ${depend_libs} ${unittest_libs} ${basic_libs}) + +add_executable(test_customized_index test_customized_index.cpp ${util_srcs} ${ivf_srcs}) +target_link_libraries(test_customized_index ${depend_libs} ${unittest_libs} ${basic_libs}) + install(TARGETS test_ivf DESTINATION unittest) install(TARGETS test_idmap DESTINATION unittest) install(TARGETS test_kdt DESTINATION unittest) +install(TARGETS test_gpuresource DESTINATION unittest) +install(TARGETS test_customized_index DESTINATION unittest) #add_subdirectory(faiss_ori) add_subdirectory(test_nsg) diff --git a/core/src/index/unittest/Helper.h b/core/src/index/unittest/Helper.h new file mode 100644 index 0000000000..d11a484c03 --- /dev/null +++ b/core/src/index/unittest/Helper.h @@ -0,0 +1,120 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include +#include + +#include "knowhere/index/vector_index/IndexGPUIVF.h" +#include "knowhere/index/vector_index/IndexGPUIVFPQ.h" +#include "knowhere/index/vector_index/IndexGPUIVFSQ.h" +#include "knowhere/index/vector_index/IndexIVF.h" +#include "knowhere/index/vector_index/IndexIVFPQ.h" +#include "knowhere/index/vector_index/IndexIVFSQ.h" +#include "knowhere/index/vector_index/IndexIVFSQHybrid.h" + +constexpr int DEVICEID = 0; +constexpr int64_t DIM = 128; +constexpr int64_t NB = 10000; +constexpr int64_t NQ = 10; +constexpr int64_t K = 10; +constexpr int64_t PINMEM = 1024 * 1024 * 200; +constexpr int64_t TEMPMEM = 1024 * 1024 * 300; +constexpr int64_t RESNUM = 2; + +knowhere::IVFIndexPtr +IndexFactory(const std::string& type) { + if (type == "IVF") { + return std::make_shared(); + } else if (type == "IVFPQ") { + return std::make_shared(); + } else if (type == "GPUIVF") { + return std::make_shared(DEVICEID); + } else if (type == "GPUIVFPQ") { + return std::make_shared(DEVICEID); + } else if (type == "IVFSQ") { + return std::make_shared(); + } else if (type == "GPUIVFSQ") { + return std::make_shared(DEVICEID); + } else if (type == "IVFSQHybrid") { + return std::make_shared(DEVICEID); + } +} + +enum class ParameterType { + ivf, + ivfpq, + ivfsq, +}; + +class ParamGenerator { + public: + static ParamGenerator& + GetInstance() { + static ParamGenerator instance; + return instance; + } + + knowhere::Config + Gen(const ParameterType& type) { + if (type == ParameterType::ivf) { + auto tempconf = std::make_shared(); + tempconf->d = DIM; + tempconf->gpu_id = DEVICEID; + tempconf->nlist = 100; + tempconf->nprobe = 4; + tempconf->k = K; + tempconf->metric_type = knowhere::METRICTYPE::L2; + return tempconf; + } else if (type == ParameterType::ivfpq) { + auto tempconf = std::make_shared(); + tempconf->d = DIM; + tempconf->gpu_id = DEVICEID; + tempconf->nlist = 100; + tempconf->nprobe = 4; + tempconf->k = K; + tempconf->m = 4; + tempconf->nbits = 8; + tempconf->metric_type = knowhere::METRICTYPE::L2; + return tempconf; + } else if (type == ParameterType::ivfsq) { + auto tempconf = std::make_shared(); + tempconf->d = DIM; + tempconf->gpu_id = DEVICEID; + tempconf->nlist = 100; + tempconf->nprobe = 4; + tempconf->k = K; + tempconf->nbits = 8; + tempconf->metric_type = knowhere::METRICTYPE::L2; + return tempconf; + } + } +}; + +#include + +class TestGpuIndexBase : public ::testing::Test { + protected: + void + SetUp() override { + knowhere::FaissGpuResourceMgr::GetInstance().InitDevice(DEVICEID, PINMEM, TEMPMEM, RESNUM); + } + + void + TearDown() override { + knowhere::FaissGpuResourceMgr::GetInstance().Free(); + } +}; diff --git a/core/src/index/unittest/test_customized_index.cpp b/core/src/index/unittest/test_customized_index.cpp new file mode 100644 index 0000000000..1e0b1d932d --- /dev/null +++ b/core/src/index/unittest/test_customized_index.cpp @@ -0,0 +1,122 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include + +#include "unittest/Helper.h" +#include "unittest/utils.h" + +class SingleIndexTest : public DataGen, public TestGpuIndexBase { + protected: + void + SetUp() override { + TestGpuIndexBase::SetUp(); + Generate(DIM, NB, NQ); + k = K; + } + + void + TearDown() override { + TestGpuIndexBase::TearDown(); + } + + protected: + std::string index_type; + knowhere::IVFIndexPtr index_ = nullptr; +}; + +#ifdef CUSTOMIZATION +TEST_F(SingleIndexTest, IVFSQHybrid) { + assert(!xb.empty()); + + index_type = "IVFSQHybrid"; + index_ = IndexFactory(index_type); + auto conf = ParamGenerator::GetInstance().Gen(ParameterType::ivfsq); + auto preprocessor = index_->BuildPreprocessor(base_dataset, conf); + index_->set_preprocessor(preprocessor); + + auto model = index_->Train(base_dataset, conf); + index_->set_index_model(model); + index_->Add(base_dataset, conf); + EXPECT_EQ(index_->Count(), nb); + EXPECT_EQ(index_->Dimension(), dim); + + auto binaryset = index_->Serialize(); + { + // copy cpu to gpu + auto cpu_idx = std::make_shared(DEVICEID); + cpu_idx->Load(binaryset); + + { + for (int i = 0; i < 3; ++i) { + auto gpu_idx = cpu_idx->CopyCpuToGpu(DEVICEID, conf); + auto result = gpu_idx->Search(query_dataset, conf); + AssertAnns(result, nq, conf->k); + // PrintResult(result, nq, k); + } + } + } + + { + // quantization already in gpu, only copy data + auto cpu_idx = std::make_shared(DEVICEID); + cpu_idx->Load(binaryset); + + auto pair = cpu_idx->CopyCpuToGpuWithQuantizer(DEVICEID, conf); + auto gpu_idx = pair.first; + auto quantization = pair.second; + + auto result = gpu_idx->Search(query_dataset, conf); + AssertAnns(result, nq, conf->k); + // PrintResult(result, nq, k); + + auto quantizer_conf = std::make_shared(); + quantizer_conf->mode = 2; // only copy data + quantizer_conf->gpu_id = DEVICEID; + for (int i = 0; i < 2; ++i) { + auto hybrid_idx = std::make_shared(DEVICEID); + hybrid_idx->Load(binaryset); + + auto new_idx = hybrid_idx->LoadData(quantization, quantizer_conf); + auto result = new_idx->Search(query_dataset, conf); + AssertAnns(result, nq, conf->k); + // PrintResult(result, nq, k); + } + } + + { + // quantization already in gpu, only set quantization + auto cpu_idx = std::make_shared(DEVICEID); + cpu_idx->Load(binaryset); + + auto pair = cpu_idx->CopyCpuToGpuWithQuantizer(DEVICEID, conf); + auto quantization = pair.second; + + for (int i = 0; i < 2; ++i) { + auto hybrid_idx = std::make_shared(DEVICEID); + hybrid_idx->Load(binaryset); + + hybrid_idx->SetQuantizer(quantization); + auto result = hybrid_idx->Search(query_dataset, conf); + AssertAnns(result, nq, conf->k); + // PrintResult(result, nq, k); + hybrid_idx->UnsetQuantizer(); + } + } +} + +#endif diff --git a/core/src/index/unittest/test_gpuresource.cpp b/core/src/index/unittest/test_gpuresource.cpp new file mode 100644 index 0000000000..e3485d26ba --- /dev/null +++ b/core/src/index/unittest/test_gpuresource.cpp @@ -0,0 +1,309 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include + +#include +#include + +#include +#include +#include + +#include "knowhere/common/Exception.h" +#include "knowhere/common/Timer.h" +#include "knowhere/index/vector_index/IndexGPUIVF.h" +#include "knowhere/index/vector_index/IndexGPUIVFPQ.h" +#include "knowhere/index/vector_index/IndexGPUIVFSQ.h" +#include "knowhere/index/vector_index/IndexIVF.h" +#include "knowhere/index/vector_index/IndexIVFPQ.h" +#include "knowhere/index/vector_index/IndexIVFSQ.h" +#include "knowhere/index/vector_index/IndexIVFSQHybrid.h" +#include "knowhere/index/vector_index/helpers/Cloner.h" + +#include "unittest/Helper.h" +#include "unittest/utils.h" + +class GPURESTEST : public DataGen, public TestGpuIndexBase { + protected: + void + SetUp() override { + TestGpuIndexBase::SetUp(); + Generate(DIM, NB, NQ); + + k = K; + elems = nq * k; + ids = (int64_t*)malloc(sizeof(int64_t) * elems); + dis = (float*)malloc(sizeof(float) * elems); + } + + void + TearDown() override { + delete ids; + delete dis; + TestGpuIndexBase::TearDown(); + } + + protected: + std::string index_type; + knowhere::IVFIndexPtr index_ = nullptr; + + int64_t* ids = nullptr; + float* dis = nullptr; + int64_t elems = 0; +}; + +TEST_F(GPURESTEST, copyandsearch) { + // search and copy at the same time + printf("==================\n"); + + index_type = "GPUIVF"; + index_ = IndexFactory(index_type); + + auto conf = ParamGenerator::GetInstance().Gen(ParameterType::ivf); + auto preprocessor = index_->BuildPreprocessor(base_dataset, conf); + index_->set_preprocessor(preprocessor); + auto model = index_->Train(base_dataset, conf); + index_->set_index_model(model); + index_->Add(base_dataset, conf); + auto result = index_->Search(query_dataset, conf); + AssertAnns(result, nq, k); + + auto cpu_idx = knowhere::cloner::CopyGpuToCpu(index_, knowhere::Config()); + cpu_idx->Seal(); + auto search_idx = knowhere::cloner::CopyCpuToGpu(cpu_idx, DEVICEID, knowhere::Config()); + + constexpr int64_t search_count = 50; + constexpr int64_t load_count = 15; + auto search_func = [&] { + // TimeRecorder tc("search&load"); + for (int i = 0; i < search_count; ++i) { + search_idx->Search(query_dataset, conf); + // if (i > search_count - 6 || i == 0) + // tc.RecordSection("search once"); + } + // tc.ElapseFromBegin("search finish"); + }; + auto load_func = [&] { + // TimeRecorder tc("search&load"); + for (int i = 0; i < load_count; ++i) { + knowhere::cloner::CopyCpuToGpu(cpu_idx, DEVICEID, knowhere::Config()); + // if (i > load_count -5 || i < 5) + // tc.RecordSection("Copy to gpu"); + } + // tc.ElapseFromBegin("load finish"); + }; + + knowhere::TimeRecorder tc("Basic"); + knowhere::cloner::CopyCpuToGpu(cpu_idx, DEVICEID, knowhere::Config()); + tc.RecordSection("Copy to gpu once"); + search_idx->Search(query_dataset, conf); + tc.RecordSection("Search once"); + search_func(); + tc.RecordSection("Search total cost"); + load_func(); + tc.RecordSection("Copy total cost"); + + std::thread search_thread(search_func); + std::thread load_thread(load_func); + search_thread.join(); + load_thread.join(); + tc.RecordSection("Copy&Search total"); +} + +TEST_F(GPURESTEST, trainandsearch) { + index_type = "GPUIVF"; + index_ = IndexFactory(index_type); + + auto conf = ParamGenerator::GetInstance().Gen(ParameterType::ivf); + auto preprocessor = index_->BuildPreprocessor(base_dataset, conf); + index_->set_preprocessor(preprocessor); + auto model = index_->Train(base_dataset, conf); + auto new_index = IndexFactory(index_type); + new_index->set_index_model(model); + new_index->Add(base_dataset, conf); + auto cpu_idx = knowhere::cloner::CopyGpuToCpu(new_index, knowhere::Config()); + cpu_idx->Seal(); + auto search_idx = knowhere::cloner::CopyCpuToGpu(cpu_idx, DEVICEID, knowhere::Config()); + + constexpr int train_count = 5; + constexpr int search_count = 200; + auto train_stage = [&] { + for (int i = 0; i < train_count; ++i) { + auto model = index_->Train(base_dataset, conf); + auto test_idx = IndexFactory(index_type); + test_idx->set_index_model(model); + test_idx->Add(base_dataset, conf); + } + }; + auto search_stage = [&](knowhere::VectorIndexPtr& search_idx) { + for (int i = 0; i < search_count; ++i) { + auto result = search_idx->Search(query_dataset, conf); + AssertAnns(result, nq, k); + } + }; + + // TimeRecorder tc("record"); + // train_stage(); + // tc.RecordSection("train cost"); + // search_stage(search_idx); + // tc.RecordSection("search cost"); + + { + // search and build parallel + std::thread search_thread(search_stage, std::ref(search_idx)); + std::thread train_thread(train_stage); + train_thread.join(); + search_thread.join(); + } + { + // build parallel + std::thread train_1(train_stage); + std::thread train_2(train_stage); + train_1.join(); + train_2.join(); + } + { + // search parallel + auto search_idx_2 = knowhere::cloner::CopyCpuToGpu(cpu_idx, DEVICEID, knowhere::Config()); + std::thread search_1(search_stage, std::ref(search_idx)); + std::thread search_2(search_stage, std::ref(search_idx_2)); + search_1.join(); + search_2.join(); + } +} + +#ifdef CompareToOriFaiss +TEST_F(GPURESTEST, gpu_ivf_resource_test) { + assert(!xb.empty()); + + { + index_ = std::make_shared(-1); + ASSERT_EQ(std::dynamic_pointer_cast(index_)->GetGpuDevice(), -1); + std::dynamic_pointer_cast(index_)->SetGpuDevice(DEVICEID); + ASSERT_EQ(std::dynamic_pointer_cast(index_)->GetGpuDevice(), DEVICEID); + + auto conf = ParamGenerator::GetInstance().Gen(ParameterType::ivfsq); + auto preprocessor = index_->BuildPreprocessor(base_dataset, conf); + index_->set_preprocessor(preprocessor); + auto model = index_->Train(base_dataset, conf); + index_->set_index_model(model); + index_->Add(base_dataset, conf); + EXPECT_EQ(index_->Count(), nb); + EXPECT_EQ(index_->Dimension(), dim); + + // knowhere::TimeRecorder tc("knowere GPUIVF"); + for (int i = 0; i < search_count; ++i) { + index_->Search(query_dataset, conf); + if (i > search_count - 6 || i < 5) + // tc.RecordSection("search once"); + } + // tc.ElapseFromBegin("search all"); + } + knowhere::FaissGpuResourceMgr::GetInstance().Dump(); + + // { + // // ori faiss IVF-Search + // faiss::gpu::StandardGpuResources res; + // faiss::gpu::GpuIndexIVFFlatConfig idx_config; + // idx_config.device = DEVICEID; + // faiss::gpu::GpuIndexIVFFlat device_index(&res, dim, 1638, faiss::METRIC_L2, idx_config); + // device_index.train(nb, xb.data()); + // device_index.add(nb, xb.data()); + // + // knowhere::TimeRecorder tc("ori IVF"); + // for (int i = 0; i < search_count; ++i) { + // device_index.search(nq, xq.data(), k, dis, ids); + // if (i > search_count - 6 || i < 5) + // tc.RecordSection("search once"); + // } + // tc.ElapseFromBegin("search all"); + // } +} + +TEST_F(GPURESTEST, gpuivfsq) { + { + // knowhere gpu ivfsq + index_type = "GPUIVFSQ"; + index_ = IndexFactory(index_type); + + auto conf = std::make_shared(); + conf->nlist = 1638; + conf->d = dim; + conf->gpu_id = DEVICEID; + conf->metric_type = knowhere::METRICTYPE::L2; + conf->k = k; + conf->nbits = 8; + conf->nprobe = 1; + + auto preprocessor = index_->BuildPreprocessor(base_dataset, conf); + index_->set_preprocessor(preprocessor); + auto model = index_->Train(base_dataset, conf); + index_->set_index_model(model); + index_->Add(base_dataset, conf); + // auto result = index_->Search(query_dataset, conf); + // AssertAnns(result, nq, k); + + auto cpu_idx = knowhere::cloner::CopyGpuToCpu(index_, knowhere::Config()); + cpu_idx->Seal(); + + knowhere::TimeRecorder tc("knowhere GPUSQ8"); + auto search_idx = knowhere::cloner::CopyCpuToGpu(cpu_idx, DEVICEID, knowhere::Config()); + tc.RecordSection("Copy to gpu"); + for (int i = 0; i < search_count; ++i) { + search_idx->Search(query_dataset, conf); + if (i > search_count - 6 || i < 5) + tc.RecordSection("search once"); + } + tc.ElapseFromBegin("search all"); + } + + { + // Ori gpuivfsq Test + const char* index_description = "IVF1638,SQ8"; + faiss::Index* ori_index = faiss::index_factory(dim, index_description, faiss::METRIC_L2); + + faiss::gpu::StandardGpuResources res; + auto device_index = faiss::gpu::index_cpu_to_gpu(&res, DEVICEID, ori_index); + device_index->train(nb, xb.data()); + device_index->add(nb, xb.data()); + + auto cpu_index = faiss::gpu::index_gpu_to_cpu(device_index); + auto idx = dynamic_cast(cpu_index); + if (idx != nullptr) { + idx->to_readonly(); + } + delete device_index; + delete ori_index; + + faiss::gpu::GpuClonerOptions option; + option.allInGpu = true; + + knowhere::TimeRecorder tc("ori GPUSQ8"); + faiss::Index* search_idx = faiss::gpu::index_cpu_to_gpu(&res, DEVICEID, cpu_index, &option); + tc.RecordSection("Copy to gpu"); + for (int i = 0; i < search_count; ++i) { + search_idx->search(nq, xq.data(), k, dis, ids); + if (i > search_count - 6 || i < 5) + tc.RecordSection("search once"); + } + tc.ElapseFromBegin("search all"); + delete cpu_index; + delete search_idx; + } +} +#endif diff --git a/core/src/index/unittest/test_idmap.cpp b/core/src/index/unittest/test_idmap.cpp index 077f7328f7..d1ff3ee046 100644 --- a/core/src/index/unittest/test_idmap.cpp +++ b/core/src/index/unittest/test_idmap.cpp @@ -23,54 +23,28 @@ #include "knowhere/index/vector_index/IndexIDMAP.h" #include "knowhere/index/vector_index/helpers/Cloner.h" +#include "Helper.h" #include "unittest/utils.h" -static int device_id = 0; -class IDMAPTest : public DataGen, public ::testing::Test { +class IDMAPTest : public DataGen, public TestGpuIndexBase { protected: void SetUp() override { - knowhere::FaissGpuResourceMgr::GetInstance().InitDevice(device_id, 1024 * 1024 * 200, 1024 * 1024 * 300, 2); + TestGpuIndexBase::SetUp(); + Init_with_default(); index_ = std::make_shared(); } void TearDown() override { - knowhere::FaissGpuResourceMgr::GetInstance().Free(); + TestGpuIndexBase::TearDown(); } protected: knowhere::IDMAPPtr index_ = nullptr; }; -void -AssertAnns(const knowhere::DatasetPtr& result, const int& nq, const int& k) { - auto ids = result->array()[0]; - for (auto i = 0; i < nq; i++) { - EXPECT_EQ(i, *(ids->data()->GetValues(1, i * k))); - } -} - -void -PrintResult(const knowhere::DatasetPtr& result, const int& nq, const int& k) { - auto ids = result->array()[0]; - auto dists = result->array()[1]; - - std::stringstream ss_id; - std::stringstream ss_dist; - for (auto i = 0; i < 10; i++) { - for (auto j = 0; j < k; ++j) { - ss_id << *(ids->data()->GetValues(1, i * k + j)) << " "; - ss_dist << *(dists->data()->GetValues(1, i * k + j)) << " "; - } - ss_id << std::endl; - ss_dist << std::endl; - } - std::cout << "id\n" << ss_id.str() << std::endl; - std::cout << "dist\n" << ss_dist.str() << std::endl; -} - TEST_F(IDMAPTest, idmap_basic) { ASSERT_TRUE(!xb.empty()); @@ -87,7 +61,7 @@ TEST_F(IDMAPTest, idmap_basic) { ASSERT_TRUE(index_->GetRawIds() != nullptr); auto result = index_->Search(query_dataset, conf); AssertAnns(result, nq, k); - PrintResult(result, nq, k); + // PrintResult(result, nq, k); index_->Seal(); auto binaryset = index_->Serialize(); @@ -95,7 +69,7 @@ TEST_F(IDMAPTest, idmap_basic) { new_index->Load(binaryset); auto re_result = index_->Search(query_dataset, conf); AssertAnns(re_result, nq, k); - PrintResult(re_result, nq, k); + // PrintResult(re_result, nq, k); } TEST_F(IDMAPTest, idmap_serialize) { @@ -118,7 +92,7 @@ TEST_F(IDMAPTest, idmap_serialize) { index_->Add(base_dataset, knowhere::Config()); auto re_result = index_->Search(query_dataset, conf); AssertAnns(re_result, nq, k); - PrintResult(re_result, nq, k); + // PrintResult(re_result, nq, k); EXPECT_EQ(index_->Count(), nb); EXPECT_EQ(index_->Dimension(), dim); auto binaryset = index_->Serialize(); @@ -138,7 +112,7 @@ TEST_F(IDMAPTest, idmap_serialize) { EXPECT_EQ(index_->Dimension(), dim); auto result = index_->Search(query_dataset, conf); AssertAnns(result, nq, k); - PrintResult(result, nq, k); + // PrintResult(result, nq, k); } } @@ -169,7 +143,7 @@ TEST_F(IDMAPTest, copy_test) { { // cpu to gpu - auto clone_index = knowhere::cloner::CopyCpuToGpu(index_, device_id, conf); + auto clone_index = knowhere::cloner::CopyCpuToGpu(index_, DEVICEID, conf); auto clone_result = clone_index->Search(query_dataset, conf); AssertAnns(clone_result, nq, k); ASSERT_THROW({ std::static_pointer_cast(clone_index)->GetRawVectors(); }, @@ -194,9 +168,9 @@ TEST_F(IDMAPTest, copy_test) { ASSERT_TRUE(std::static_pointer_cast(host_index)->GetRawIds() != nullptr); // gpu to gpu - auto device_index = knowhere::cloner::CopyCpuToGpu(index_, device_id, conf); + auto device_index = knowhere::cloner::CopyCpuToGpu(index_, DEVICEID, conf); auto new_device_index = - std::static_pointer_cast(device_index)->CopyGpuToGpu(device_id, conf); + std::static_pointer_cast(device_index)->CopyGpuToGpu(DEVICEID, conf); auto device_result = new_device_index->Search(query_dataset, conf); AssertAnns(device_result, nq, k); } diff --git a/core/src/index/unittest/test_ivf.cpp b/core/src/index/unittest/test_ivf.cpp index 4aaddb108b..fae27b0dd3 100644 --- a/core/src/index/unittest/test_ivf.cpp +++ b/core/src/index/unittest/test_ivf.cpp @@ -35,99 +35,25 @@ #include "knowhere/index/vector_index/IndexIVFSQHybrid.h" #include "knowhere/index/vector_index/helpers/Cloner.h" +#include "unittest/Helper.h" #include "unittest/utils.h" using ::testing::Combine; using ::testing::TestWithParam; using ::testing::Values; -constexpr int device_id = 0; -constexpr int64_t DIM = 128; -constexpr int64_t NB = 1000000 / 100; -constexpr int64_t NQ = 10; -constexpr int64_t K = 10; - -knowhere::IVFIndexPtr -IndexFactory(const std::string& type) { - if (type == "IVF") { - return std::make_shared(); - } else if (type == "IVFPQ") { - return std::make_shared(); - } else if (type == "GPUIVF") { - return std::make_shared(device_id); - } else if (type == "GPUIVFPQ") { - return std::make_shared(device_id); - } else if (type == "IVFSQ") { - return std::make_shared(); - } else if (type == "GPUIVFSQ") { - return std::make_shared(device_id); - } else if (type == "IVFSQHybrid") { - return std::make_shared(device_id); - } -} - -enum class ParameterType { - ivf, - ivfpq, - ivfsq, - nsg, -}; - -class ParamGenerator { - public: - static ParamGenerator& - GetInstance() { - static ParamGenerator instance; - return instance; - } - - knowhere::Config - Gen(const ParameterType& type) { - if (type == ParameterType::ivf) { - auto tempconf = std::make_shared(); - tempconf->d = DIM; - tempconf->gpu_id = device_id; - tempconf->nlist = 100; - tempconf->nprobe = 16; - tempconf->k = K; - tempconf->metric_type = knowhere::METRICTYPE::L2; - return tempconf; - } else if (type == ParameterType::ivfpq) { - auto tempconf = std::make_shared(); - tempconf->d = DIM; - tempconf->gpu_id = device_id; - tempconf->nlist = 25; - tempconf->nprobe = 4; - tempconf->k = K; - tempconf->m = 4; - tempconf->nbits = 8; - tempconf->metric_type = knowhere::METRICTYPE::L2; - return tempconf; - } else if (type == ParameterType::ivfsq) { - auto tempconf = std::make_shared(); - tempconf->d = DIM; - tempconf->gpu_id = device_id; - tempconf->nlist = 100; - tempconf->nprobe = 16; - tempconf->k = K; - tempconf->nbits = 8; - tempconf->metric_type = knowhere::METRICTYPE::L2; - return tempconf; - } - } -}; - class IVFTest : public DataGen, public TestWithParam<::std::tuple> { protected: void SetUp() override { + knowhere::FaissGpuResourceMgr::GetInstance().InitDevice(DEVICEID, PINMEM, TEMPMEM, RESNUM); + ParameterType parameter_type; std::tie(index_type, parameter_type) = GetParam(); // Init_with_default(); Generate(DIM, NB, NQ); index_ = IndexFactory(index_type); conf = ParamGenerator::GetInstance().Gen(parameter_type); - knowhere::FaissGpuResourceMgr::GetInstance().InitDevice(device_id, 1024 * 1024 * 200, 1024 * 1024 * 600, 2); } void @@ -140,7 +66,7 @@ class IVFTest : public DataGen, public TestWithParam<::std::tuple gpu_idx{"GPUIVFSQ"}; auto finder = std::find(gpu_idx.cbegin(), gpu_idx.cend(), index_type); if (finder != gpu_idx.cend()) { - return knowhere::cloner::CopyCpuToGpu(index_, device_id, knowhere::Config()); + return knowhere::cloner::CopyCpuToGpu(index_, DEVICEID, knowhere::Config()); } return index_; } @@ -162,33 +88,6 @@ INSTANTIATE_TEST_CASE_P(IVFParameters, IVFTest, #endif std::make_tuple("GPUIVFSQ", ParameterType::ivfsq))); -void -AssertAnns(const knowhere::DatasetPtr& result, const int& nq, const int& k) { - auto ids = result->array()[0]; - for (auto i = 0; i < nq; i++) { - EXPECT_EQ(i, *(ids->data()->GetValues(1, i * k))); - } -} - -void -PrintResult(const knowhere::DatasetPtr& result, const int& nq, const int& k) { - auto ids = result->array()[0]; - auto dists = result->array()[1]; - - std::stringstream ss_id; - std::stringstream ss_dist; - for (auto i = 0; i < 10; i++) { - for (auto j = 0; j < k; ++j) { - ss_id << *(ids->data()->GetValues(1, i * k + j)) << " "; - ss_dist << *(dists->data()->GetValues(1, i * k + j)) << " "; - } - ss_id << std::endl; - ss_dist << std::endl; - } - std::cout << "id\n" << ss_id.str() << std::endl; - std::cout << "dist\n" << ss_dist.str() << std::endl; -} - TEST_P(IVFTest, ivf_basic) { assert(!xb.empty()); @@ -207,85 +106,6 @@ TEST_P(IVFTest, ivf_basic) { // PrintResult(result, nq, k); } -TEST_P(IVFTest, hybrid) { - if (index_type != "IVFSQHybrid") { - return; - } - assert(!xb.empty()); - - auto preprocessor = index_->BuildPreprocessor(base_dataset, conf); - index_->set_preprocessor(preprocessor); - - auto model = index_->Train(base_dataset, conf); - index_->set_index_model(model); - index_->Add(base_dataset, conf); - EXPECT_EQ(index_->Count(), nb); - EXPECT_EQ(index_->Dimension(), dim); - - // auto new_idx = ChooseTodo(); - // auto result = new_idx->Search(query_dataset, conf); - // AssertAnns(result, nq, conf->k); - - { - auto hybrid_1_idx = std::make_shared(device_id); - - auto binaryset = index_->Serialize(); - hybrid_1_idx->Load(binaryset); - - auto quantizer_conf = std::make_shared(); - quantizer_conf->mode = 1; - quantizer_conf->gpu_id = device_id; - auto q = hybrid_1_idx->LoadQuantizer(quantizer_conf); - hybrid_1_idx->SetQuantizer(q); - auto result = hybrid_1_idx->Search(query_dataset, conf); - AssertAnns(result, nq, conf->k); - PrintResult(result, nq, k); - hybrid_1_idx->UnsetQuantizer(); - } - - { - auto hybrid_2_idx = std::make_shared(device_id); - - auto binaryset = index_->Serialize(); - hybrid_2_idx->Load(binaryset); - - auto quantizer_conf = std::make_shared(); - quantizer_conf->mode = 1; - quantizer_conf->gpu_id = device_id; - auto q = hybrid_2_idx->LoadQuantizer(quantizer_conf); - quantizer_conf->mode = 2; - auto gpu_idx = hybrid_2_idx->LoadData(q, quantizer_conf); - - auto result = gpu_idx->Search(query_dataset, conf); - AssertAnns(result, nq, conf->k); - PrintResult(result, nq, k); - } -} - -// TEST_P(IVFTest, gpu_to_cpu) { -// if (index_type.find("GPU") == std::string::npos) { return; } -// -// // else -// assert(!xb.empty()); -// -// auto preprocessor = index_->BuildPreprocessor(base_dataset, conf); -// index_->set_preprocessor(preprocessor); -// -// auto model = index_->Train(base_dataset, conf); -// index_->set_index_model(model); -// index_->Add(base_dataset, conf); -// EXPECT_EQ(index_->Count(), nb); -// EXPECT_EQ(index_->Dimension(), dim); -// auto result = index_->Search(query_dataset, conf); -// AssertAnns(result, nq, k); -// -// if (auto device_index = std::dynamic_pointer_cast(index_)) { -// auto host_index = device_index->Copy_index_gpu_to_cpu(); -// auto result = host_index->Search(query_dataset, conf); -// AssertAnns(result, nq, k); -// } -//} - TEST_P(IVFTest, ivf_serialize) { auto serialize = [](const std::string& filename, knowhere::BinaryPtr& bin, uint8_t* ret) { FileIOWriter writer(filename); @@ -423,7 +243,7 @@ TEST_P(IVFTest, clone_test) { auto finder = std::find(support_idx_vec.cbegin(), support_idx_vec.cend(), index_type); if (finder != support_idx_vec.cend()) { EXPECT_NO_THROW({ - auto clone_index = knowhere::cloner::CopyCpuToGpu(index_, device_id, knowhere::Config()); + auto clone_index = knowhere::cloner::CopyCpuToGpu(index_, DEVICEID, knowhere::Config()); auto clone_result = clone_index->Search(query_dataset, conf); AssertEqual(result, clone_result); std::cout << "clone C <=> G [" << index_type << "] success" << std::endl; @@ -432,7 +252,7 @@ TEST_P(IVFTest, clone_test) { EXPECT_THROW( { std::cout << "clone C <=> G [" << index_type << "] failed" << std::endl; - auto clone_index = knowhere::cloner::CopyCpuToGpu(index_, device_id, knowhere::Config()); + auto clone_index = knowhere::cloner::CopyCpuToGpu(index_, DEVICEID, knowhere::Config()); }, knowhere::KnowhereException); } @@ -440,9 +260,7 @@ TEST_P(IVFTest, clone_test) { } #ifdef CUSTOMIZATION -TEST_P(IVFTest, seal_test) { - // FaissGpuResourceMgr::GetInstance().InitDevice(device_id); - +TEST_P(IVFTest, gpu_seal_test) { std::vector support_idx_vec{"GPUIVF", "GPUIVFSQ", "IVFSQHybrid"}; auto finder = std::find(support_idx_vec.cbegin(), support_idx_vec.cend(), index_type); if (finder == support_idx_vec.cend()) { @@ -466,309 +284,13 @@ TEST_P(IVFTest, seal_test) { auto cpu_idx = knowhere::cloner::CopyGpuToCpu(index_, knowhere::Config()); knowhere::TimeRecorder tc("CopyToGpu"); - knowhere::cloner::CopyCpuToGpu(cpu_idx, device_id, knowhere::Config()); + knowhere::cloner::CopyCpuToGpu(cpu_idx, DEVICEID, knowhere::Config()); auto without_seal = tc.RecordSection("Without seal"); cpu_idx->Seal(); tc.RecordSection("seal cost"); - knowhere::cloner::CopyCpuToGpu(cpu_idx, device_id, knowhere::Config()); + knowhere::cloner::CopyCpuToGpu(cpu_idx, DEVICEID, knowhere::Config()); auto with_seal = tc.RecordSection("With seal"); ASSERT_GE(without_seal, with_seal); } + #endif - -class GPURESTEST : public DataGen, public ::testing::Test { - protected: - void - SetUp() override { - Generate(128, 1000000, 1000); - knowhere::FaissGpuResourceMgr::GetInstance().InitDevice(device_id, 1024 * 1024 * 200, 1024 * 1024 * 300, 2); - - k = 100; - elems = nq * k; - ids = (int64_t*)malloc(sizeof(int64_t) * elems); - dis = (float*)malloc(sizeof(float) * elems); - } - - void - TearDown() override { - delete ids; - delete dis; - knowhere::FaissGpuResourceMgr::GetInstance().Free(); - } - - protected: - std::string index_type; - knowhere::IVFIndexPtr index_ = nullptr; - - int64_t* ids = nullptr; - float* dis = nullptr; - int64_t elems = 0; -}; - -const int search_count = 18; -const int load_count = 3; - -TEST_F(GPURESTEST, gpu_ivf_resource_test) { - assert(!xb.empty()); - - { - index_ = std::make_shared(-1); - ASSERT_EQ(std::dynamic_pointer_cast(index_)->GetGpuDevice(), -1); - std::dynamic_pointer_cast(index_)->SetGpuDevice(device_id); - ASSERT_EQ(std::dynamic_pointer_cast(index_)->GetGpuDevice(), device_id); - - auto conf = std::make_shared(); - conf->nlist = 1638; - conf->d = dim; - conf->gpu_id = device_id; - conf->metric_type = knowhere::METRICTYPE::L2; - conf->k = k; - conf->nprobe = 1; - - auto preprocessor = index_->BuildPreprocessor(base_dataset, conf); - index_->set_preprocessor(preprocessor); - auto model = index_->Train(base_dataset, conf); - index_->set_index_model(model); - index_->Add(base_dataset, conf); - EXPECT_EQ(index_->Count(), nb); - EXPECT_EQ(index_->Dimension(), dim); - - knowhere::TimeRecorder tc("knowere GPUIVF"); - for (int i = 0; i < search_count; ++i) { - index_->Search(query_dataset, conf); - if (i > search_count - 6 || i < 5) - tc.RecordSection("search once"); - } - tc.ElapseFromBegin("search all"); - } - knowhere::FaissGpuResourceMgr::GetInstance().Dump(); - - { - // IVF-Search - faiss::gpu::StandardGpuResources res; - faiss::gpu::GpuIndexIVFFlatConfig idx_config; - idx_config.device = device_id; - faiss::gpu::GpuIndexIVFFlat device_index(&res, dim, 1638, faiss::METRIC_L2, idx_config); - device_index.train(nb, xb.data()); - device_index.add(nb, xb.data()); - - knowhere::TimeRecorder tc("ori IVF"); - for (int i = 0; i < search_count; ++i) { - device_index.search(nq, xq.data(), k, dis, ids); - if (i > search_count - 6 || i < 5) - tc.RecordSection("search once"); - } - tc.ElapseFromBegin("search all"); - } -} - -#ifdef CUSTOMIZATION -TEST_F(GPURESTEST, gpuivfsq) { - { - // knowhere gpu ivfsq - index_type = "GPUIVFSQ"; - index_ = IndexFactory(index_type); - - auto conf = std::make_shared(); - conf->nlist = 1638; - conf->d = dim; - conf->gpu_id = device_id; - conf->metric_type = knowhere::METRICTYPE::L2; - conf->k = k; - conf->nbits = 8; - conf->nprobe = 1; - - auto preprocessor = index_->BuildPreprocessor(base_dataset, conf); - index_->set_preprocessor(preprocessor); - auto model = index_->Train(base_dataset, conf); - index_->set_index_model(model); - index_->Add(base_dataset, conf); - // auto result = index_->Search(query_dataset, conf); - // AssertAnns(result, nq, k); - - auto cpu_idx = knowhere::cloner::CopyGpuToCpu(index_, knowhere::Config()); - cpu_idx->Seal(); - - knowhere::TimeRecorder tc("knowhere GPUSQ8"); - auto search_idx = knowhere::cloner::CopyCpuToGpu(cpu_idx, device_id, knowhere::Config()); - tc.RecordSection("Copy to gpu"); - for (int i = 0; i < search_count; ++i) { - search_idx->Search(query_dataset, conf); - if (i > search_count - 6 || i < 5) - tc.RecordSection("search once"); - } - tc.ElapseFromBegin("search all"); - } - - { - // Ori gpuivfsq Test - const char* index_description = "IVF1638,SQ8"; - faiss::Index* ori_index = faiss::index_factory(dim, index_description, faiss::METRIC_L2); - - faiss::gpu::StandardGpuResources res; - auto device_index = faiss::gpu::index_cpu_to_gpu(&res, device_id, ori_index); - device_index->train(nb, xb.data()); - device_index->add(nb, xb.data()); - - auto cpu_index = faiss::gpu::index_gpu_to_cpu(device_index); - auto idx = dynamic_cast(cpu_index); - if (idx != nullptr) { - idx->to_readonly(); - } - delete device_index; - delete ori_index; - - faiss::gpu::GpuClonerOptions option; - option.allInGpu = true; - - knowhere::TimeRecorder tc("ori GPUSQ8"); - faiss::Index* search_idx = faiss::gpu::index_cpu_to_gpu(&res, device_id, cpu_index, &option); - tc.RecordSection("Copy to gpu"); - for (int i = 0; i < search_count; ++i) { - search_idx->search(nq, xq.data(), k, dis, ids); - if (i > search_count - 6 || i < 5) - tc.RecordSection("search once"); - } - tc.ElapseFromBegin("search all"); - delete cpu_index; - delete search_idx; - } -} -#endif - -TEST_F(GPURESTEST, copyandsearch) { - // search and copy at the same time - printf("==================\n"); - - index_type = "GPUIVF"; - index_ = IndexFactory(index_type); - - auto conf = std::make_shared(); - conf->nlist = 1638; - conf->d = dim; - conf->gpu_id = device_id; - conf->metric_type = knowhere::METRICTYPE::L2; - conf->k = k; - conf->nbits = 8; - conf->nprobe = 1; - - auto preprocessor = index_->BuildPreprocessor(base_dataset, conf); - index_->set_preprocessor(preprocessor); - auto model = index_->Train(base_dataset, conf); - index_->set_index_model(model); - index_->Add(base_dataset, conf); - // auto result = index_->Search(query_dataset, conf); - // AssertAnns(result, nq, k); - - auto cpu_idx = knowhere::cloner::CopyGpuToCpu(index_, knowhere::Config()); - cpu_idx->Seal(); - - auto search_idx = knowhere::cloner::CopyCpuToGpu(cpu_idx, device_id, knowhere::Config()); - - auto search_func = [&] { - // TimeRecorder tc("search&load"); - for (int i = 0; i < search_count; ++i) { - search_idx->Search(query_dataset, conf); - // if (i > search_count - 6 || i == 0) - // tc.RecordSection("search once"); - } - // tc.ElapseFromBegin("search finish"); - }; - auto load_func = [&] { - // TimeRecorder tc("search&load"); - for (int i = 0; i < load_count; ++i) { - knowhere::cloner::CopyCpuToGpu(cpu_idx, device_id, knowhere::Config()); - // if (i > load_count -5 || i < 5) - // tc.RecordSection("Copy to gpu"); - } - // tc.ElapseFromBegin("load finish"); - }; - - knowhere::TimeRecorder tc("basic"); - knowhere::cloner::CopyCpuToGpu(cpu_idx, device_id, knowhere::Config()); - tc.RecordSection("Copy to gpu once"); - search_idx->Search(query_dataset, conf); - tc.RecordSection("search once"); - search_func(); - tc.RecordSection("only search total"); - load_func(); - tc.RecordSection("only copy total"); - - std::thread search_thread(search_func); - std::thread load_thread(load_func); - search_thread.join(); - load_thread.join(); - tc.RecordSection("Copy&search total"); -} - -TEST_F(GPURESTEST, TrainAndSearch) { - index_type = "GPUIVF"; - index_ = IndexFactory(index_type); - - auto conf = std::make_shared(); - conf->nlist = 1638; - conf->d = dim; - conf->gpu_id = device_id; - conf->metric_type = knowhere::METRICTYPE::L2; - conf->k = k; - conf->nbits = 8; - conf->nprobe = 1; - - auto preprocessor = index_->BuildPreprocessor(base_dataset, conf); - index_->set_preprocessor(preprocessor); - auto model = index_->Train(base_dataset, conf); - auto new_index = IndexFactory(index_type); - new_index->set_index_model(model); - new_index->Add(base_dataset, conf); - auto cpu_idx = knowhere::cloner::CopyGpuToCpu(new_index, knowhere::Config()); - cpu_idx->Seal(); - auto search_idx = knowhere::cloner::CopyCpuToGpu(cpu_idx, device_id, knowhere::Config()); - - constexpr int train_count = 1; - constexpr int search_count = 5000; - auto train_stage = [&] { - for (int i = 0; i < train_count; ++i) { - auto model = index_->Train(base_dataset, conf); - auto test_idx = IndexFactory(index_type); - test_idx->set_index_model(model); - test_idx->Add(base_dataset, conf); - } - }; - auto search_stage = [&](knowhere::VectorIndexPtr& search_idx) { - for (int i = 0; i < search_count; ++i) { - auto result = search_idx->Search(query_dataset, conf); - AssertAnns(result, nq, k); - } - }; - - // TimeRecorder tc("record"); - // train_stage(); - // tc.RecordSection("train cost"); - // search_stage(search_idx); - // tc.RecordSection("search cost"); - - { - // search and build parallel - std::thread search_thread(search_stage, std::ref(search_idx)); - std::thread train_thread(train_stage); - train_thread.join(); - search_thread.join(); - } - { - // build parallel - std::thread train_1(train_stage); - std::thread train_2(train_stage); - train_1.join(); - train_2.join(); - } - { - // search parallel - auto search_idx_2 = knowhere::cloner::CopyCpuToGpu(cpu_idx, device_id, knowhere::Config()); - std::thread search_1(search_stage, std::ref(search_idx)); - std::thread search_2(search_stage, std::ref(search_idx_2)); - search_1.join(); - search_2.join(); - } -} - -// TODO(lxj): Add exception test diff --git a/core/src/index/unittest/test_kdt.cpp b/core/src/index/unittest/test_kdt.cpp index f9058cc4d2..5400881875 100644 --- a/core/src/index/unittest/test_kdt.cpp +++ b/core/src/index/unittest/test_kdt.cpp @@ -52,33 +52,6 @@ class KDTTest : public DataGen, public ::testing::Test { std::shared_ptr index_ = nullptr; }; -void -AssertAnns(const knowhere::DatasetPtr& result, const int& nq, const int& k) { - auto ids = result->array()[0]; - for (auto i = 0; i < nq; i++) { - EXPECT_EQ(i, *(ids->data()->GetValues(1, i * k))); - } -} - -void -PrintResult(const knowhere::DatasetPtr& result, const int& nq, const int& k) { - auto ids = result->array()[0]; - auto dists = result->array()[1]; - - std::stringstream ss_id; - std::stringstream ss_dist; - for (auto i = 0; i < 10; i++) { - for (auto j = 0; j < k; ++j) { - ss_id << *(ids->data()->GetValues(1, i * k + j)) << " "; - ss_dist << *(dists->data()->GetValues(1, i * k + j)) << " "; - } - ss_id << std::endl; - ss_dist << std::endl; - } - std::cout << "id\n" << ss_id.str() << std::endl; - std::cout << "dist\n" << ss_dist.str() << std::endl; -} - // TODO(lxj): add test about count() and dimension() TEST_F(KDTTest, kdt_basic) { assert(!xb.empty()); diff --git a/core/src/index/unittest/test_nsg/test_nsg.cpp b/core/src/index/unittest/test_nsg/test_nsg.cpp index 657387f219..11b9becce4 100644 --- a/core/src/index/unittest/test_nsg/test_nsg.cpp +++ b/core/src/index/unittest/test_nsg/test_nsg.cpp @@ -30,19 +30,19 @@ using ::testing::Combine; using ::testing::TestWithParam; using ::testing::Values; -constexpr int64_t DEVICE_ID = 1; +constexpr int64_t DEVICEID = 0; class NSGInterfaceTest : public DataGen, public ::testing::Test { protected: void SetUp() override { // Init_with_default(); - knowhere::FaissGpuResourceMgr::GetInstance().InitDevice(DEVICE_ID, 1024 * 1024 * 200, 1024 * 1024 * 600, 2); + knowhere::FaissGpuResourceMgr::GetInstance().InitDevice(DEVICEID, 1024 * 1024 * 200, 1024 * 1024 * 600, 2); Generate(256, 1000000 / 100, 1); index_ = std::make_shared(); auto tmp_conf = std::make_shared(); - tmp_conf->gpu_id = DEVICE_ID; + tmp_conf->gpu_id = DEVICEID; tmp_conf->knng = 20; tmp_conf->nprobe = 8; tmp_conf->nlist = 163; @@ -69,14 +69,6 @@ class NSGInterfaceTest : public DataGen, public ::testing::Test { knowhere::Config search_conf; }; -void -AssertAnns(const knowhere::DatasetPtr& result, const int& nq, const int& k) { - auto ids = result->array()[0]; - for (auto i = 0; i < nq; i++) { - EXPECT_EQ(i, *(ids->data()->GetValues(1, i * k))); - } -} - TEST_F(NSGInterfaceTest, basic_test) { assert(!xb.empty()); diff --git a/core/src/index/unittest/utils.cpp b/core/src/index/unittest/utils.cpp index cdfc56b1cb..d4a59bafbb 100644 --- a/core/src/index/unittest/utils.cpp +++ b/core/src/index/unittest/utils.cpp @@ -17,6 +17,7 @@ #include "unittest/utils.h" +#include #include #include #include @@ -147,3 +148,30 @@ generate_query_dataset(int64_t nb, int64_t dim, float* xb) { auto dataset = std::make_shared(std::move(tensors), tensor_schema); return dataset; } + +void +AssertAnns(const knowhere::DatasetPtr& result, const int& nq, const int& k) { + auto ids = result->array()[0]; + for (auto i = 0; i < nq; i++) { + EXPECT_EQ(i, *(ids->data()->GetValues(1, i * k))); + } +} + +void +PrintResult(const knowhere::DatasetPtr& result, const int& nq, const int& k) { + auto ids = result->array()[0]; + auto dists = result->array()[1]; + + std::stringstream ss_id; + std::stringstream ss_dist; + for (auto i = 0; i < 10; i++) { + for (auto j = 0; j < k; ++j) { + ss_id << *(ids->data()->GetValues(1, i * k + j)) << " "; + ss_dist << *(dists->data()->GetValues(1, i * k + j)) << " "; + } + ss_id << std::endl; + ss_dist << std::endl; + } + std::cout << "id\n" << ss_id.str() << std::endl; + std::cout << "dist\n" << ss_dist.str() << std::endl; +} diff --git a/core/src/index/unittest/utils.h b/core/src/index/unittest/utils.h index acc3e89183..b39cf9ea14 100644 --- a/core/src/index/unittest/utils.h +++ b/core/src/index/unittest/utils.h @@ -68,6 +68,12 @@ generate_dataset(int64_t nb, int64_t dim, float* xb, int64_t* ids); knowhere::DatasetPtr generate_query_dataset(int64_t nb, int64_t dim, float* xb); +void +AssertAnns(const knowhere::DatasetPtr& result, const int& nq, const int& k); + +void +PrintResult(const knowhere::DatasetPtr& result, const int& nq, const int& k); + struct FileIOWriter { std::fstream fs; std::string name; diff --git a/core/src/main.cpp b/core/src/main.cpp index 769483acd3..d1c9ba6dfd 100644 --- a/core/src/main.cpp +++ b/core/src/main.cpp @@ -50,7 +50,7 @@ print_banner() { std::cout << " / /|_/ // // /_| |/ / /_/ /\\ \\ " << std::endl; std::cout << " /_/ /_/___/____/___/\\____/___/ " << std::endl; std::cout << std::endl; - std::cout << "Welcome to use Milvus by Zilliz!" << std::endl; + std::cout << "Welcome to Milvus!" << std::endl; std::cout << "Milvus " << BUILD_TYPE << " version: v" << MILVUS_VERSION << ", built at " << BUILD_TIME << std::endl; std::cout << std::endl; } diff --git a/core/src/scheduler/SchedInst.cpp b/core/src/scheduler/SchedInst.cpp index 0053332746..f3f293a0f3 100644 --- a/core/src/scheduler/SchedInst.cpp +++ b/core/src/scheduler/SchedInst.cpp @@ -50,29 +50,35 @@ load_simple_config() { std::string mode; config.GetResourceConfigMode(mode); std::vector pool; - config.GetResourceConfigPool(pool); + config.GetResourceConfigSearchResources(pool); // get resources - bool use_cpu_to_compute = false; - for (auto& resource : pool) { - if (resource == "cpu") { - use_cpu_to_compute = true; - break; - } - } auto gpu_ids = get_gpu_pool(); + int32_t build_gpu_id; + config.GetResourceConfigIndexBuildDevice(build_gpu_id); + // create and connect ResMgrInst::GetInstance()->Add(ResourceFactory::Create("disk", "DISK", 0, true, false)); auto io = Connection("io", 500); - ResMgrInst::GetInstance()->Add(ResourceFactory::Create("cpu", "CPU", 0, true, use_cpu_to_compute)); + ResMgrInst::GetInstance()->Add(ResourceFactory::Create("cpu", "CPU", 0, true, true)); ResMgrInst::GetInstance()->Connect("disk", "cpu", io); auto pcie = Connection("pcie", 12000); + bool find_build_gpu_id = false; for (auto& gpu_id : gpu_ids) { ResMgrInst::GetInstance()->Add(ResourceFactory::Create(std::to_string(gpu_id), "GPU", gpu_id, true, true)); ResMgrInst::GetInstance()->Connect("cpu", std::to_string(gpu_id), pcie); + if (build_gpu_id == gpu_id) { + find_build_gpu_id = true; + } + } + + if (not find_build_gpu_id) { + ResMgrInst::GetInstance()->Add( + ResourceFactory::Create(std::to_string(build_gpu_id), "GPU", build_gpu_id, true, true)); + ResMgrInst::GetInstance()->Connect("cpu", std::to_string(build_gpu_id), pcie); } } diff --git a/core/src/scheduler/Utils.cpp b/core/src/scheduler/Utils.cpp index 18f6fc249d..998e545ba5 100644 --- a/core/src/scheduler/Utils.cpp +++ b/core/src/scheduler/Utils.cpp @@ -48,7 +48,7 @@ get_gpu_pool() { server::Config& config = server::Config::GetInstance(); std::vector pool; - Status s = config.GetResourceConfigPool(pool); + Status s = config.GetResourceConfigSearchResources(pool); if (!s.ok()) { SERVER_LOG_ERROR << s.message(); } diff --git a/core/src/scheduler/action/PushTaskToNeighbour.cpp b/core/src/scheduler/action/PushTaskToNeighbour.cpp index 4e7dbf984f..c64e81dcfa 100644 --- a/core/src/scheduler/action/PushTaskToNeighbour.cpp +++ b/core/src/scheduler/action/PushTaskToNeighbour.cpp @@ -184,11 +184,11 @@ Action::SpecifiedResourceLabelTaskScheduler(ResourceMgrWPtr res_mgr, ResourcePtr // get build index gpu resource server::Config& config = server::Config::GetInstance(); int32_t build_index_gpu; - Status stat = config.GetDBConfigBuildIndexGPU(build_index_gpu); + Status stat = config.GetResourceConfigIndexBuildDevice(build_index_gpu); bool find_gpu_res = false; - for (uint64_t i = 0; i < compute_resources.size(); ++i) { - if (res_mgr.lock()->GetResource(ResourceType::GPU, build_index_gpu) != nullptr) { + if (res_mgr.lock()->GetResource(ResourceType::GPU, build_index_gpu) != nullptr) { + for (uint64_t i = 0; i < compute_resources.size(); ++i) { if (compute_resources[i]->name() == res_mgr.lock()->GetResource(ResourceType::GPU, build_index_gpu)->name()) { find_gpu_res = true; diff --git a/core/src/scheduler/optimizer/LargeSQ8HPass.cpp b/core/src/scheduler/optimizer/LargeSQ8HPass.cpp index 62d0e57902..8368a90000 100644 --- a/core/src/scheduler/optimizer/LargeSQ8HPass.cpp +++ b/core/src/scheduler/optimizer/LargeSQ8HPass.cpp @@ -26,48 +26,48 @@ namespace milvus { namespace scheduler { -bool -LargeSQ8HPass::Run(const TaskPtr& task) { - if (task->Type() != TaskType::SearchTask) { - return false; - } - - auto search_task = std::static_pointer_cast(task); - if (search_task->file_->engine_type_ != (int)engine::EngineType::FAISS_IVFSQ8H) { - return false; - } - - auto search_job = std::static_pointer_cast(search_task->job_.lock()); - - // TODO: future, Index::IVFSQ8H, if nq < threshold set cpu, else set gpu - if (search_job->nq() < 100) { - return false; - } - - std::vector gpus = scheduler::get_gpu_pool(); - std::vector all_free_mem; - for (auto& gpu : gpus) { - auto cache = cache::GpuCacheMgr::GetInstance(gpu); - auto free_mem = cache->CacheCapacity() - cache->CacheUsage(); - all_free_mem.push_back(free_mem); - } - - auto max_e = std::max_element(all_free_mem.begin(), all_free_mem.end()); - auto best_index = std::distance(all_free_mem.begin(), max_e); - auto best_device_id = gpus[best_index]; - - ResourcePtr res_ptr = ResMgrInst::GetInstance()->GetResource(ResourceType::GPU, best_device_id); - if (not res_ptr) { - SERVER_LOG_ERROR << "GpuResource " << best_device_id << " invalid."; - // TODO: throw critical error and exit - return false; - } - - auto label = std::make_shared(std::weak_ptr(res_ptr)); - task->label() = label; - - return true; -} +// bool +// LargeSQ8HPass::Run(const TaskPtr& task) { +// if (task->Type() != TaskType::SearchTask) { +// return false; +// } +// +// auto search_task = std::static_pointer_cast(task); +// if (search_task->file_->engine_type_ != (int)engine::EngineType::FAISS_IVFSQ8H) { +// return false; +// } +// +// auto search_job = std::static_pointer_cast(search_task->job_.lock()); +// +// // TODO: future, Index::IVFSQ8H, if nq < threshold set cpu, else set gpu +// if (search_job->nq() < 100) { +// return false; +// } +// +// std::vector gpus = scheduler::get_gpu_pool(); +// std::vector all_free_mem; +// for (auto& gpu : gpus) { +// auto cache = cache::GpuCacheMgr::GetInstance(gpu); +// auto free_mem = cache->CacheCapacity() - cache->CacheUsage(); +// all_free_mem.push_back(free_mem); +// } +// +// auto max_e = std::max_element(all_free_mem.begin(), all_free_mem.end()); +// auto best_index = std::distance(all_free_mem.begin(), max_e); +// auto best_device_id = gpus[best_index]; +// +// ResourcePtr res_ptr = ResMgrInst::GetInstance()->GetResource(ResourceType::GPU, best_device_id); +// if (not res_ptr) { +// SERVER_LOG_ERROR << "GpuResource " << best_device_id << " invalid."; +// // TODO: throw critical error and exit +// return false; +// } +// +// auto label = std::make_shared(std::weak_ptr(res_ptr)); +// task->label() = label; +// +// return true; +// } } // namespace scheduler } // namespace milvus diff --git a/core/src/scheduler/optimizer/LargeSQ8HPass.h b/core/src/scheduler/optimizer/LargeSQ8HPass.h index 49e658002f..3335a37cc7 100644 --- a/core/src/scheduler/optimizer/LargeSQ8HPass.h +++ b/core/src/scheduler/optimizer/LargeSQ8HPass.h @@ -37,8 +37,8 @@ class LargeSQ8HPass : public Pass { LargeSQ8HPass() = default; public: - bool - Run(const TaskPtr& task) override; + // bool + // Run(const TaskPtr& task) override; }; using LargeSQ8HPassPtr = std::shared_ptr; diff --git a/core/src/scheduler/optimizer/Optimizer.cpp b/core/src/scheduler/optimizer/Optimizer.cpp index c5fa311a27..46f24ea712 100644 --- a/core/src/scheduler/optimizer/Optimizer.cpp +++ b/core/src/scheduler/optimizer/Optimizer.cpp @@ -20,12 +20,12 @@ namespace milvus { namespace scheduler { -void -Optimizer::Init() { - for (auto& pass : pass_list_) { - pass->Init(); - } -} +// void +// Optimizer::Init() { +// for (auto& pass : pass_list_) { +// pass->Init(); +// } +// } bool Optimizer::Run(const TaskPtr& task) { diff --git a/core/src/scheduler/optimizer/Optimizer.h b/core/src/scheduler/optimizer/Optimizer.h index 68b519e115..bfabbf7de3 100644 --- a/core/src/scheduler/optimizer/Optimizer.h +++ b/core/src/scheduler/optimizer/Optimizer.h @@ -38,8 +38,8 @@ class Optimizer { explicit Optimizer(std::vector pass_list) : pass_list_(std::move(pass_list)) { } - void - Init(); + // void + // Init(); bool Run(const TaskPtr& task); diff --git a/core/src/scheduler/optimizer/Pass.h b/core/src/scheduler/optimizer/Pass.h index 959c3ea5ee..016b05e457 100644 --- a/core/src/scheduler/optimizer/Pass.h +++ b/core/src/scheduler/optimizer/Pass.h @@ -34,9 +34,9 @@ namespace scheduler { class Pass { public: - virtual void - Init() { - } + // virtual void + // Init() { + // } virtual bool Run(const TaskPtr& task) = 0; diff --git a/core/src/scheduler/task/BuildIndexTask.cpp b/core/src/scheduler/task/BuildIndexTask.cpp index 25d3d73a7b..d8602c141e 100644 --- a/core/src/scheduler/task/BuildIndexTask.cpp +++ b/core/src/scheduler/task/BuildIndexTask.cpp @@ -55,9 +55,6 @@ XBuildIndexTask::Load(milvus::scheduler::LoadType type, uint8_t device_id) { } else if (type == LoadType::CPU2GPU) { stat = to_index_engine_->CopyToIndexFileToGpu(device_id); type_str = "CPU2GPU"; - } else if (type == LoadType::GPU2CPU) { - stat = to_index_engine_->CopyToCpu(); - type_str = "GPU2CPU"; } else { error_msg = "Wrong load type"; stat = Status(SERVER_UNEXPECTED_ERROR, error_msg); @@ -137,6 +134,7 @@ XBuildIndexTask::Execute() { ENGINE_LOG_DEBUG << "Failed to update file to index, mark file: " << table_file.file_id_ << " to to_delete"; + build_index_job->BuildIndexDone(to_index_id_); to_index_engine_ = nullptr; return; } @@ -151,6 +149,7 @@ XBuildIndexTask::Execute() { std::cout << "ERROR: failed to build index, index file is too large or gpu memory is not enough" << std::endl; + build_index_job->BuildIndexDone(to_index_id_); build_index_job->GetStatus() = Status(DB_ERROR, msg); to_index_engine_ = nullptr; return; @@ -161,6 +160,9 @@ XBuildIndexTask::Execute() { meta_ptr->HasTable(file_->table_id_, has_table); if (!has_table) { meta_ptr->DeleteTableFiles(file_->table_id_); + + build_index_job->BuildIndexDone(to_index_id_); + build_index_job->GetStatus() = Status(DB_ERROR, "Table has been deleted, discard index file."); to_index_engine_ = nullptr; return; } @@ -180,6 +182,7 @@ XBuildIndexTask::Execute() { std::cout << "ERROR: failed to persist index file: " << table_file.location_ << ", possible out of disk space" << std::endl; + build_index_job->BuildIndexDone(to_index_id_); build_index_job->GetStatus() = Status(DB_ERROR, msg); to_index_engine_ = nullptr; return; @@ -199,8 +202,9 @@ XBuildIndexTask::Execute() { ENGINE_LOG_DEBUG << "New index file " << table_file.file_id_ << " of size " << index->PhysicalSize() << " bytes" << " from file " << origin_file.file_id_; - - // index->Cache(); + if (build_index_job->options().insert_cache_immediately_) { + index->Cache(); + } } else { // failed to update meta, mark the new file as to_delete, don't delete old file origin_file.file_type_ = engine::meta::TableFileSchema::TO_INDEX; diff --git a/core/src/scheduler/task/SearchTask.cpp b/core/src/scheduler/task/SearchTask.cpp index b7a1e211d2..1bf1caff76 100644 --- a/core/src/scheduler/task/SearchTask.cpp +++ b/core/src/scheduler/task/SearchTask.cpp @@ -253,7 +253,7 @@ XSearchTask::MergeTopkToResultSet(const std::vector& input_ids, const s if (result[i].empty()) { result_buf.resize(input_k, scheduler::IdDistPair(-1, 0.0)); - uint64_t input_k_multi_i = input_k * i; + uint64_t input_k_multi_i = topk * i; for (auto k = 0; k < input_k; ++k) { uint64_t idx = input_k_multi_i + k; auto& result_buf_item = result_buf[k]; @@ -266,7 +266,7 @@ XSearchTask::MergeTopkToResultSet(const std::vector& input_ids, const s result_buf.resize(output_k, scheduler::IdDistPair(-1, 0.0)); size_t buf_k = 0, src_k = 0, tar_k = 0; uint64_t src_idx; - uint64_t input_k_multi_i = input_k * i; + uint64_t input_k_multi_i = topk * i; while (buf_k < output_k && src_k < input_k && tar_k < tar_size) { src_idx = input_k_multi_i + src_k; auto& result_buf_item = result_buf[buf_k]; @@ -307,71 +307,71 @@ XSearchTask::MergeTopkToResultSet(const std::vector& input_ids, const s } } -void -XSearchTask::MergeTopkArray(std::vector& tar_ids, std::vector& tar_distance, uint64_t& tar_input_k, - const std::vector& src_ids, const std::vector& src_distance, - uint64_t src_input_k, uint64_t nq, uint64_t topk, bool ascending) { - if (src_ids.empty() || src_distance.empty()) { - return; - } - - uint64_t output_k = std::min(topk, tar_input_k + src_input_k); - std::vector id_buf(nq * output_k, -1); - std::vector dist_buf(nq * output_k, 0.0); - - uint64_t buf_k, src_k, tar_k; - uint64_t src_idx, tar_idx, buf_idx; - uint64_t src_input_k_multi_i, tar_input_k_multi_i, buf_k_multi_i; - - for (uint64_t i = 0; i < nq; i++) { - src_input_k_multi_i = src_input_k * i; - tar_input_k_multi_i = tar_input_k * i; - buf_k_multi_i = output_k * i; - buf_k = src_k = tar_k = 0; - while (buf_k < output_k && src_k < src_input_k && tar_k < tar_input_k) { - src_idx = src_input_k_multi_i + src_k; - tar_idx = tar_input_k_multi_i + tar_k; - buf_idx = buf_k_multi_i + buf_k; - if ((ascending && src_distance[src_idx] < tar_distance[tar_idx]) || - (!ascending && src_distance[src_idx] > tar_distance[tar_idx])) { - id_buf[buf_idx] = src_ids[src_idx]; - dist_buf[buf_idx] = src_distance[src_idx]; - src_k++; - } else { - id_buf[buf_idx] = tar_ids[tar_idx]; - dist_buf[buf_idx] = tar_distance[tar_idx]; - tar_k++; - } - buf_k++; - } - - if (buf_k < output_k) { - if (src_k < src_input_k) { - while (buf_k < output_k && src_k < src_input_k) { - src_idx = src_input_k_multi_i + src_k; - buf_idx = buf_k_multi_i + buf_k; - id_buf[buf_idx] = src_ids[src_idx]; - dist_buf[buf_idx] = src_distance[src_idx]; - src_k++; - buf_k++; - } - } else { - while (buf_k < output_k && tar_k < tar_input_k) { - tar_idx = tar_input_k_multi_i + tar_k; - buf_idx = buf_k_multi_i + buf_k; - id_buf[buf_idx] = tar_ids[tar_idx]; - dist_buf[buf_idx] = tar_distance[tar_idx]; - tar_k++; - buf_k++; - } - } - } - } - - tar_ids.swap(id_buf); - tar_distance.swap(dist_buf); - tar_input_k = output_k; -} +// void +// XSearchTask::MergeTopkArray(std::vector& tar_ids, std::vector& tar_distance, uint64_t& tar_input_k, +// const std::vector& src_ids, const std::vector& src_distance, +// uint64_t src_input_k, uint64_t nq, uint64_t topk, bool ascending) { +// if (src_ids.empty() || src_distance.empty()) { +// return; +// } +// +// uint64_t output_k = std::min(topk, tar_input_k + src_input_k); +// std::vector id_buf(nq * output_k, -1); +// std::vector dist_buf(nq * output_k, 0.0); +// +// uint64_t buf_k, src_k, tar_k; +// uint64_t src_idx, tar_idx, buf_idx; +// uint64_t src_input_k_multi_i, tar_input_k_multi_i, buf_k_multi_i; +// +// for (uint64_t i = 0; i < nq; i++) { +// src_input_k_multi_i = src_input_k * i; +// tar_input_k_multi_i = tar_input_k * i; +// buf_k_multi_i = output_k * i; +// buf_k = src_k = tar_k = 0; +// while (buf_k < output_k && src_k < src_input_k && tar_k < tar_input_k) { +// src_idx = src_input_k_multi_i + src_k; +// tar_idx = tar_input_k_multi_i + tar_k; +// buf_idx = buf_k_multi_i + buf_k; +// if ((ascending && src_distance[src_idx] < tar_distance[tar_idx]) || +// (!ascending && src_distance[src_idx] > tar_distance[tar_idx])) { +// id_buf[buf_idx] = src_ids[src_idx]; +// dist_buf[buf_idx] = src_distance[src_idx]; +// src_k++; +// } else { +// id_buf[buf_idx] = tar_ids[tar_idx]; +// dist_buf[buf_idx] = tar_distance[tar_idx]; +// tar_k++; +// } +// buf_k++; +// } +// +// if (buf_k < output_k) { +// if (src_k < src_input_k) { +// while (buf_k < output_k && src_k < src_input_k) { +// src_idx = src_input_k_multi_i + src_k; +// buf_idx = buf_k_multi_i + buf_k; +// id_buf[buf_idx] = src_ids[src_idx]; +// dist_buf[buf_idx] = src_distance[src_idx]; +// src_k++; +// buf_k++; +// } +// } else { +// while (buf_k < output_k && tar_k < tar_input_k) { +// tar_idx = tar_input_k_multi_i + tar_k; +// buf_idx = buf_k_multi_i + buf_k; +// id_buf[buf_idx] = tar_ids[tar_idx]; +// dist_buf[buf_idx] = tar_distance[tar_idx]; +// tar_k++; +// buf_k++; +// } +// } +// } +// } +// +// tar_ids.swap(id_buf); +// tar_distance.swap(dist_buf); +// tar_input_k = output_k; +//} } // namespace scheduler } // namespace milvus diff --git a/core/src/scheduler/task/SearchTask.h b/core/src/scheduler/task/SearchTask.h index 6a7381e0e6..bbc8b5bd8f 100644 --- a/core/src/scheduler/task/SearchTask.h +++ b/core/src/scheduler/task/SearchTask.h @@ -42,10 +42,10 @@ class XSearchTask : public Task { MergeTopkToResultSet(const std::vector& input_ids, const std::vector& input_distance, uint64_t input_k, uint64_t nq, uint64_t topk, bool ascending, scheduler::ResultSet& result); - static void - MergeTopkArray(std::vector& tar_ids, std::vector& tar_distance, uint64_t& tar_input_k, - const std::vector& src_ids, const std::vector& src_distance, uint64_t src_input_k, - uint64_t nq, uint64_t topk, bool ascending); + // static void + // MergeTopkArray(std::vector& tar_ids, std::vector& tar_distance, uint64_t& tar_input_k, + // const std::vector& src_ids, const std::vector& src_distance, uint64_t + // src_input_k, uint64_t nq, uint64_t topk, bool ascending); public: TableFileSchemaPtr file_; diff --git a/core/src/sdk/examples/grpcsimple/src/ClientTest.cpp b/core/src/sdk/examples/grpcsimple/src/ClientTest.cpp index ce511714b2..069283200f 100644 --- a/core/src/sdk/examples/grpcsimple/src/ClientTest.cpp +++ b/core/src/sdk/examples/grpcsimple/src/ClientTest.cpp @@ -40,8 +40,10 @@ constexpr int64_t BATCH_ROW_COUNT = 100000; constexpr int64_t NQ = 5; constexpr int64_t TOP_K = 10; constexpr int64_t SEARCH_TARGET = 5000; // change this value, result is different -constexpr int64_t ADD_VECTOR_LOOP = 1; +constexpr int64_t ADD_VECTOR_LOOP = 5; constexpr int64_t SECONDS_EACH_HOUR = 3600; +constexpr milvus::IndexType INDEX_TYPE = milvus::IndexType::gpu_ivfsq8; +constexpr int32_t N_LIST = 15000; #define BLOCK_SPLITER std::cout << "===========================================" << std::endl; @@ -311,8 +313,8 @@ ClientTest::Test(const std::string& address, const std::string& port) { std::cout << "Wait until create all index done" << std::endl; milvus::IndexParam index; index.table_name = TABLE_NAME; - index.index_type = milvus::IndexType::gpu_ivfsq8; - index.nlist = 16384; + index.index_type = INDEX_TYPE; + index.nlist = N_LIST; milvus::Status stat = conn->CreateIndex(index); std::cout << "CreateIndex function call status: " << stat.message() << std::endl; @@ -344,8 +346,8 @@ ClientTest::Test(const std::string& address, const std::string& port) { { // delete by range milvus::Range rg; - rg.start_value = CurrentTmDate(-2); - rg.end_value = CurrentTmDate(-3); + rg.start_value = CurrentTmDate(-3); + rg.end_value = CurrentTmDate(-2); milvus::Status stat = conn->DeleteByRange(rg, TABLE_NAME); std::cout << "DeleteByRange function call status: " << stat.message() << std::endl; diff --git a/core/src/sdk/interface/Status.cpp b/core/src/sdk/interface/Status.cpp index a5e89556f2..9ccbabfd20 100644 --- a/core/src/sdk/interface/Status.cpp +++ b/core/src/sdk/interface/Status.cpp @@ -88,7 +88,7 @@ Status::MoveFrom(Status& s) { std::string Status::message() const { if (state_ == nullptr) { - return ""; + return "OK"; } std::string msg; diff --git a/core/src/server/Config.cpp b/core/src/server/Config.cpp index 189070eb2b..7de84cbccc 100644 --- a/core/src/server/Config.cpp +++ b/core/src/server/Config.cpp @@ -18,6 +18,7 @@ #include #include #include +#include #include #include @@ -129,12 +130,6 @@ Config::ValidateConfig() { return s; } - int32_t db_build_index_gpu; - s = GetDBConfigBuildIndexGPU(db_build_index_gpu); - if (!s.ok()) { - return s; - } - /* metric config */ bool metric_enable_monitor; s = GetMetricConfigEnableMonitor(metric_enable_monitor); @@ -205,8 +200,14 @@ Config::ValidateConfig() { return s; } - std::vector resource_pool; - s = GetResourceConfigPool(resource_pool); + std::vector search_resources; + s = GetResourceConfigSearchResources(search_resources); + if (!s.ok()) { + return s; + } + + int32_t resource_index_build_device; + s = GetResourceConfigIndexBuildDevice(resource_index_build_device); if (!s.ok()) { return s; } @@ -270,11 +271,6 @@ Config::ResetDefaultConfig() { return s; } - s = SetDBConfigBuildIndexGPU(CONFIG_DB_BUILD_INDEX_GPU_DEFAULT); - if (!s.ok()) { - return s; - } - /* metric config */ s = SetMetricConfigEnableMonitor(CONFIG_METRIC_ENABLE_MONITOR_DEFAULT); if (!s.ok()) { @@ -334,6 +330,11 @@ Config::ResetDefaultConfig() { return s; } + s = SetResourceConfigIndexBuildDevice(CONFIG_RESOURCE_INDEX_BUILD_DEVICE_DEFAULT); + if (!s.ok()) { + return s; + } + return Status::OK(); } @@ -459,19 +460,6 @@ Config::CheckDBConfigInsertBufferSize(const std::string& value) { return Status::OK(); } -Status -Config::CheckDBConfigBuildIndexGPU(const std::string& value) { - if (!ValidationUtil::ValidateStringIsNumber(value).ok()) { - return Status(SERVER_INVALID_ARGUMENT, "Invalid DB config build_index_gpu: " + value); - } else { - int32_t gpu_index = std::stoi(value); - if (!ValidationUtil::ValidateGpuIndex(gpu_index).ok()) { - return Status(SERVER_INVALID_ARGUMENT, "Invalid DB config build_index_gpu: " + value); - } - } - return Status::OK(); -} - Status Config::CheckMetricConfigEnableMonitor(const std::string& value) { if (!ValidationUtil::ValidateStringIsBool(value).ok()) { @@ -544,7 +532,7 @@ Config::CheckCacheConfigGpuCacheCapacity(const std::string& value) { } else { uint64_t gpu_cache_capacity = std::stoi(value) * GB; int gpu_index; - Status s = GetDBConfigBuildIndexGPU(gpu_index); + Status s = GetResourceConfigIndexBuildDevice(gpu_index); if (!s.ok()) { return s; } @@ -616,9 +604,38 @@ Config::CheckResourceConfigMode(const std::string& value) { } Status -Config::CheckResourceConfigPool(const std::vector& value) { +CheckGpuDevice(const std::string& value) { + const std::regex pat("gpu(\\d+)"); + std::cmatch m; + if (!std::regex_match(value.c_str(), m, pat)) { + return Status(SERVER_INVALID_ARGUMENT, "Invalid gpu device: " + value); + } + + int32_t gpu_index = std::stoi(value.substr(3)); + if (!ValidationUtil::ValidateGpuIndex(gpu_index).ok()) { + return Status(SERVER_INVALID_ARGUMENT, "Invalid gpu device: " + value); + } + return Status::OK(); +} + +Status +Config::CheckResourceConfigSearchResources(const std::vector& value) { if (value.empty()) { - return Status(SERVER_INVALID_ARGUMENT, "Invalid resource config pool"); + return Status(SERVER_INVALID_ARGUMENT, "Empty resource config search_resources"); + } + + for (auto& gpu_device : value) { + if (!CheckGpuDevice(gpu_device).ok()) { + return Status(SERVER_INVALID_ARGUMENT, "Invalid resource config search_resources: " + gpu_device); + } + } + return Status::OK(); +} + +Status +Config::CheckResourceConfigIndexBuildDevice(const std::string& value) { + if (!CheckGpuDevice(value).ok()) { + return Status(SERVER_INVALID_ARGUMENT, "Invalid resource config index_build_device: " + value); } return Status::OK(); } @@ -739,18 +756,6 @@ Config::GetDBConfigInsertBufferSize(int32_t& value) { return Status::OK(); } -Status -Config::GetDBConfigBuildIndexGPU(int32_t& value) { - std::string str = GetConfigStr(CONFIG_DB, CONFIG_DB_BUILD_INDEX_GPU, CONFIG_DB_BUILD_INDEX_GPU_DEFAULT); - Status s = CheckDBConfigBuildIndexGPU(str); - if (!s.ok()) { - return s; - } - - value = std::stoi(str); - return Status::OK(); -} - Status Config::GetDBConfigPreloadTable(std::string& value) { value = GetConfigStr(CONFIG_DB, CONFIG_DB_PRELOAD_TABLE); @@ -880,10 +885,23 @@ Config::GetResourceConfigMode(std::string& value) { } Status -Config::GetResourceConfigPool(std::vector& value) { +Config::GetResourceConfigSearchResources(std::vector& value) { ConfigNode resource_config = GetConfigNode(CONFIG_RESOURCE); - value = resource_config.GetSequence(CONFIG_RESOURCE_POOL); - return CheckResourceConfigPool(value); + value = resource_config.GetSequence(CONFIG_RESOURCE_SEARCH_RESOURCES); + return CheckResourceConfigSearchResources(value); +} + +Status +Config::GetResourceConfigIndexBuildDevice(int32_t& value) { + std::string str = + GetConfigStr(CONFIG_RESOURCE, CONFIG_RESOURCE_INDEX_BUILD_DEVICE, CONFIG_RESOURCE_INDEX_BUILD_DEVICE_DEFAULT); + Status s = CheckResourceConfigIndexBuildDevice(str); + if (!s.ok()) { + return s; + } + + value = std::stoi(str.substr(3)); + return Status::OK(); } /////////////////////////////////////////////////////////////////////////////// @@ -999,17 +1017,6 @@ Config::SetDBConfigInsertBufferSize(const std::string& value) { return Status::OK(); } -Status -Config::SetDBConfigBuildIndexGPU(const std::string& value) { - Status s = CheckDBConfigBuildIndexGPU(value); - if (!s.ok()) { - return s; - } - - SetConfigValueInMem(CONFIG_DB, CONFIG_DB_BUILD_INDEX_GPU, value); - return Status::OK(); -} - /* metric config */ Status Config::SetMetricConfigEnableMonitor(const std::string& value) { @@ -1135,5 +1142,16 @@ Config::SetResourceConfigMode(const std::string& value) { return Status::OK(); } +Status +Config::SetResourceConfigIndexBuildDevice(const std::string& value) { + Status s = CheckResourceConfigIndexBuildDevice(value); + if (!s.ok()) { + return s; + } + + SetConfigValueInMem(CONFIG_DB, CONFIG_RESOURCE_INDEX_BUILD_DEVICE, value); + return Status::OK(); +} + } // namespace server } // namespace milvus diff --git a/core/src/server/Config.h b/core/src/server/Config.h index 9c754256a2..4cab25a1c6 100644 --- a/core/src/server/Config.h +++ b/core/src/server/Config.h @@ -53,8 +53,6 @@ static const char* CONFIG_DB_ARCHIVE_DAYS_THRESHOLD = "archive_days_threshold"; static const char* CONFIG_DB_ARCHIVE_DAYS_THRESHOLD_DEFAULT = "0"; static const char* CONFIG_DB_INSERT_BUFFER_SIZE = "insert_buffer_size"; static const char* CONFIG_DB_INSERT_BUFFER_SIZE_DEFAULT = "4"; -static const char* CONFIG_DB_BUILD_INDEX_GPU = "build_index_gpu"; -static const char* CONFIG_DB_BUILD_INDEX_GPU_DEFAULT = "0"; static const char* CONFIG_DB_PRELOAD_TABLE = "preload_table"; /* cache config */ @@ -62,7 +60,7 @@ static const char* CONFIG_CACHE = "cache_config"; static const char* CONFIG_CACHE_CPU_CACHE_CAPACITY = "cpu_cache_capacity"; static const char* CONFIG_CACHE_CPU_CACHE_CAPACITY_DEFAULT = "16"; static const char* CONFIG_CACHE_GPU_CACHE_CAPACITY = "gpu_cache_capacity"; -static const char* CONFIG_CACHE_GPU_CACHE_CAPACITY_DEFAULT = "0"; +static const char* CONFIG_CACHE_GPU_CACHE_CAPACITY_DEFAULT = "4"; static const char* CONFIG_CACHE_CPU_CACHE_THRESHOLD = "cpu_mem_threshold"; static const char* CONFIG_CACHE_CPU_CACHE_THRESHOLD_DEFAULT = "0.85"; static const char* CONFIG_CACHE_GPU_CACHE_THRESHOLD = "gpu_mem_threshold"; @@ -91,7 +89,9 @@ static const char* CONFIG_ENGINE_OMP_THREAD_NUM_DEFAULT = "0"; static const char* CONFIG_RESOURCE = "resource_config"; static const char* CONFIG_RESOURCE_MODE = "mode"; static const char* CONFIG_RESOURCE_MODE_DEFAULT = "simple"; -static const char* CONFIG_RESOURCE_POOL = "resource_pool"; +static const char* CONFIG_RESOURCE_SEARCH_RESOURCES = "search_resources"; +static const char* CONFIG_RESOURCE_INDEX_BUILD_DEVICE = "index_build_device"; +static const char* CONFIG_RESOURCE_INDEX_BUILD_DEVICE_DEFAULT = "gpu0"; class Config { public: @@ -140,8 +140,6 @@ class Config { CheckDBConfigArchiveDaysThreshold(const std::string& value); Status CheckDBConfigInsertBufferSize(const std::string& value); - Status - CheckDBConfigBuildIndexGPU(const std::string& value); /* metric config */ Status @@ -173,7 +171,9 @@ class Config { Status CheckResourceConfigMode(const std::string& value); Status - CheckResourceConfigPool(const std::vector& value); + CheckResourceConfigSearchResources(const std::vector& value); + Status + CheckResourceConfigIndexBuildDevice(const std::string& value); std::string GetConfigStr(const std::string& parent_key, const std::string& child_key, const std::string& default_value = ""); @@ -203,8 +203,6 @@ class Config { Status GetDBConfigInsertBufferSize(int32_t& value); Status - GetDBConfigBuildIndexGPU(int32_t& value); - Status GetDBConfigPreloadTable(std::string& value); /* metric config */ @@ -237,7 +235,9 @@ class Config { Status GetResourceConfigMode(std::string& value); Status - GetResourceConfigPool(std::vector& value); + GetResourceConfigSearchResources(std::vector& value); + Status + GetResourceConfigIndexBuildDevice(int32_t& value); public: /* server config */ @@ -263,8 +263,6 @@ class Config { SetDBConfigArchiveDaysThreshold(const std::string& value); Status SetDBConfigInsertBufferSize(const std::string& value); - Status - SetDBConfigBuildIndexGPU(const std::string& value); /* metric config */ Status @@ -295,6 +293,8 @@ class Config { /* resource config */ Status SetResourceConfigMode(const std::string& value); + Status + SetResourceConfigIndexBuildDevice(const std::string& value); private: std::unordered_map> config_map_; diff --git a/core/src/server/grpc_impl/GrpcRequestTask.cpp b/core/src/server/grpc_impl/GrpcRequestTask.cpp index 02cb24175a..86ff23b3d0 100644 --- a/core/src/server/grpc_impl/GrpcRequestTask.cpp +++ b/core/src/server/grpc_impl/GrpcRequestTask.cpp @@ -113,6 +113,14 @@ ConvertTimeRangeToDBDates(const std::vector<::milvus::grpc::Range>& range_array, return Status::OK(); } + +std::string +TableNotExistMsg(const std::string& table_name) { + return "Table " + table_name + + " not exist. Use milvus.has_table to verify whether the table exists. You also can check if the table name " + "exists."; +} + } // namespace //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// @@ -255,7 +263,7 @@ CreateIndexTask::OnExecute() { } if (!has_table) { - return Status(SERVER_TABLE_NOT_EXIST, "Table " + table_name_ + " not exists"); + return Status(SERVER_TABLE_NOT_EXIST, TableNotExistMsg(table_name_)); } auto& grpc_index = index_param_->index(); @@ -348,7 +356,7 @@ DropTableTask::OnExecute() { status = DBWrapper::DB()->DescribeTable(table_info); if (!status.ok()) { if (status.code() == DB_NOT_FOUND) { - return Status(SERVER_TABLE_NOT_EXIST, "Table " + table_name_ + " not exists"); + return Status(SERVER_TABLE_NOT_EXIST, TableNotExistMsg(table_name_)); } else { return status; } @@ -420,12 +428,14 @@ InsertTask::OnExecute() { return status; } if (insert_param_->row_record_array().empty()) { - return Status(SERVER_INVALID_ROWRECORD_ARRAY, "Row record array is empty"); + return Status(SERVER_INVALID_ROWRECORD_ARRAY, + "The vector array is empty. Make sure you have entered vector records."); } if (!insert_param_->row_id_array().empty()) { if (insert_param_->row_id_array().size() != insert_param_->row_record_array_size()) { - return Status(SERVER_ILLEGAL_VECTOR_ID, "Size of vector ids is not equal to row record array size"); + return Status(SERVER_ILLEGAL_VECTOR_ID, + "The size of vector ID array must be equal to the size of the vector."); } } @@ -435,7 +445,7 @@ InsertTask::OnExecute() { status = DBWrapper::DB()->DescribeTable(table_info); if (!status.ok()) { if (status.code() == DB_NOT_FOUND) { - return Status(SERVER_TABLE_NOT_EXIST, "Table " + insert_param_->table_name() + " not exists"); + return Status(SERVER_TABLE_NOT_EXIST, TableNotExistMsg(insert_param_->table_name())); } else { return status; } @@ -447,13 +457,14 @@ InsertTask::OnExecute() { // user already provided id before, all insert action require user id if ((table_info.flag_ & engine::meta::FLAG_MASK_HAS_USERID) != 0 && !user_provide_ids) { return Status(SERVER_ILLEGAL_VECTOR_ID, - "Table vector ids are user defined, please provide id for this batch"); + "Table vector IDs are user-defined. Please provide IDs for all vectors of this table."); } // user didn't provided id before, no need to provide user id if ((table_info.flag_ & engine::meta::FLAG_MASK_NO_USERID) != 0 && user_provide_ids) { - return Status(SERVER_ILLEGAL_VECTOR_ID, - "Table vector ids are auto generated, no need to provide id for this batch"); + return Status( + SERVER_ILLEGAL_VECTOR_ID, + "Table vector IDs are auto-generated. All vectors of this table must use auto-generated IDs."); } rc.RecordSection("check validation"); @@ -470,13 +481,13 @@ InsertTask::OnExecute() { // TODO(yk): change to one dimension array or use multiple-thread to copy the data for (size_t i = 0; i < insert_param_->row_record_array_size(); i++) { if (insert_param_->row_record_array(i).vector_data().empty()) { - return Status(SERVER_INVALID_ROWRECORD_ARRAY, "Row record array data is empty"); + return Status(SERVER_INVALID_ROWRECORD_ARRAY, + "The vector dimension must be equal to the table dimension."); } uint64_t vec_dim = insert_param_->row_record_array(i).vector_data().size(); if (vec_dim != table_info.dimension_) { ErrorCode error_code = SERVER_INVALID_VECTOR_DIMENSION; - std::string error_msg = "Invalid row record dimension: " + std::to_string(vec_dim) + - " vs. table dimension:" + std::to_string(table_info.dimension_); + std::string error_msg = "The vector dimension must be equal to the table dimension."; return Status(error_code, error_msg); } memcpy(&vec_f[i * table_info.dimension_], insert_param_->row_record_array(i).vector_data().data(), @@ -569,7 +580,7 @@ SearchTask::OnExecute() { status = DBWrapper::DB()->DescribeTable(table_info); if (!status.ok()) { if (status.code() == DB_NOT_FOUND) { - return Status(SERVER_TABLE_NOT_EXIST, "Table " + table_name_ + " not exists"); + return Status(SERVER_TABLE_NOT_EXIST, TableNotExistMsg(table_name_)); } else { return status; } @@ -587,7 +598,8 @@ SearchTask::OnExecute() { } if (search_param_->query_record_array().empty()) { - return Status(SERVER_INVALID_ROWRECORD_ARRAY, "Row record array is empty"); + return Status(SERVER_INVALID_ROWRECORD_ARRAY, + "The vector array is empty. Make sure you have entered vector records."); } // step 4: check date range, and convert to db dates @@ -609,13 +621,13 @@ SearchTask::OnExecute() { std::vector vec_f(record_array_size * table_info.dimension_, 0); for (size_t i = 0; i < record_array_size; i++) { if (search_param_->query_record_array(i).vector_data().empty()) { - return Status(SERVER_INVALID_ROWRECORD_ARRAY, "Row record array data is empty"); + return Status(SERVER_INVALID_ROWRECORD_ARRAY, + "The vector dimension must be equal to the table dimension."); } uint64_t query_vec_dim = search_param_->query_record_array(i).vector_data().size(); if (query_vec_dim != table_info.dimension_) { ErrorCode error_code = SERVER_INVALID_VECTOR_DIMENSION; - std::string error_msg = "Invalid row record dimension: " + std::to_string(query_vec_dim) + - " vs. table dimension:" + std::to_string(table_info.dimension_); + std::string error_msg = "The vector dimension must be equal to the table dimension."; return Status(error_code, error_msg); } @@ -707,7 +719,7 @@ CountTableTask::OnExecute() { status = DBWrapper::DB()->GetTableRowCount(table_name_, row_count); if (!status.ok()) { if (status.code(), DB_NOT_FOUND) { - return Status(SERVER_TABLE_NOT_EXIST, "Table " + table_name_ + " not exists"); + return Status(SERVER_TABLE_NOT_EXIST, TableNotExistMsg(table_name_)); } else { return status; } @@ -779,7 +791,7 @@ DeleteByRangeTask::OnExecute() { status = DBWrapper::DB()->DescribeTable(table_info); if (!status.ok()) { if (status.code(), DB_NOT_FOUND) { - return Status(SERVER_TABLE_NOT_EXIST, "Table " + table_name + " not exists"); + return Status(SERVER_TABLE_NOT_EXIST, TableNotExistMsg(table_name)); } else { return status; } @@ -917,7 +929,7 @@ DropIndexTask::OnExecute() { } if (!has_table) { - return Status(SERVER_TABLE_NOT_EXIST, "Table " + table_name_ + " not exists"); + return Status(SERVER_TABLE_NOT_EXIST, TableNotExistMsg(table_name_)); } // step 2: check table existence diff --git a/core/src/utils/Status.cpp b/core/src/utils/Status.cpp index ad97717cf7..9ac0279810 100644 --- a/core/src/utils/Status.cpp +++ b/core/src/utils/Status.cpp @@ -88,7 +88,7 @@ Status::MoveFrom(Status& s) { std::string Status::message() const { if (state_ == nullptr) { - return ""; + return "OK"; } std::string msg; diff --git a/core/src/utils/ValidationUtil.cpp b/core/src/utils/ValidationUtil.cpp index b982a31f5e..68088d6c93 100644 --- a/core/src/utils/ValidationUtil.cpp +++ b/core/src/utils/ValidationUtil.cpp @@ -37,14 +37,15 @@ Status ValidationUtil::ValidateTableName(const std::string& table_name) { // Table name shouldn't be empty. if (table_name.empty()) { - std::string msg = "Empty table name"; + std::string msg = "Table name should not be empty."; SERVER_LOG_ERROR << msg; return Status(SERVER_INVALID_TABLE_NAME, msg); } + std::string invalid_msg = "Invalid table name: " + table_name + ". "; // Table name size shouldn't exceed 16384. if (table_name.size() > TABLE_NAME_SIZE_LIMIT) { - std::string msg = "Table name size exceed the limitation"; + std::string msg = invalid_msg + "The length of a table name must be less than 255 characters."; SERVER_LOG_ERROR << msg; return Status(SERVER_INVALID_TABLE_NAME, msg); } @@ -52,7 +53,7 @@ ValidationUtil::ValidateTableName(const std::string& table_name) { // Table name first character should be underscore or character. char first_char = table_name[0]; if (first_char != '_' && std::isalpha(first_char) == 0) { - std::string msg = "Table name first character isn't underscore or character"; + std::string msg = invalid_msg + "The first character of a table name must be an underscore or letter."; SERVER_LOG_ERROR << msg; return Status(SERVER_INVALID_TABLE_NAME, msg); } @@ -61,7 +62,7 @@ ValidationUtil::ValidateTableName(const std::string& table_name) { for (int64_t i = 1; i < table_name_size; ++i) { char name_char = table_name[i]; if (name_char != '_' && std::isalnum(name_char) == 0) { - std::string msg = "Table name character isn't underscore or alphanumber"; + std::string msg = invalid_msg + "Table name can only contain numbers, letters, and underscores."; SERVER_LOG_ERROR << msg; return Status(SERVER_INVALID_TABLE_NAME, msg); } @@ -72,12 +73,9 @@ ValidationUtil::ValidateTableName(const std::string& table_name) { Status ValidationUtil::ValidateTableDimension(int64_t dimension) { - if (dimension <= 0) { - std::string msg = "Dimension value should be greater than 0"; - SERVER_LOG_ERROR << msg; - return Status(SERVER_INVALID_VECTOR_DIMENSION, msg); - } else if (dimension > TABLE_DIMENSION_LIMIT) { - std::string msg = "Table dimension excceed the limitation: " + std::to_string(TABLE_DIMENSION_LIMIT); + if (dimension <= 0 || dimension > TABLE_DIMENSION_LIMIT) { + std::string msg = "Invalid table dimension: " + std::to_string(dimension) + ". " + + "The table dimension must be within the range of 1 ~ 16384."; SERVER_LOG_ERROR << msg; return Status(SERVER_INVALID_VECTOR_DIMENSION, msg); } else { @@ -89,18 +87,29 @@ Status ValidationUtil::ValidateTableIndexType(int32_t index_type) { int engine_type = static_cast(engine::EngineType(index_type)); if (engine_type <= 0 || engine_type > static_cast(engine::EngineType::MAX_VALUE)) { - std::string msg = "Invalid index type: " + std::to_string(index_type); + std::string msg = "Invalid index type: " + std::to_string(index_type) + ". " + + "Make sure the index type is in IndexType list."; SERVER_LOG_ERROR << msg; return Status(SERVER_INVALID_INDEX_TYPE, msg); } +#ifndef CUSTOMIZATION + // special case, hybird index only available in customize faiss library + if (engine_type == static_cast(engine::EngineType::FAISS_IVFSQ8H)) { + std::string msg = "Unsupported index type: " + std::to_string(index_type); + SERVER_LOG_ERROR << msg; + return Status(SERVER_INVALID_INDEX_TYPE, msg); + } +#endif + return Status::OK(); } Status ValidationUtil::ValidateTableIndexNlist(int32_t nlist) { if (nlist <= 0) { - std::string msg = "nlist value should be greater than 0"; + std::string msg = + "Invalid index nlist: " + std::to_string(nlist) + ". " + "The index nlist must be greater than 0."; SERVER_LOG_ERROR << msg; return Status(SERVER_INVALID_INDEX_NLIST, msg); } @@ -111,7 +120,9 @@ ValidationUtil::ValidateTableIndexNlist(int32_t nlist) { Status ValidationUtil::ValidateTableIndexFileSize(int64_t index_file_size) { if (index_file_size <= 0 || index_file_size > INDEX_FILE_SIZE_LIMIT) { - std::string msg = "Invalid index file size: " + std::to_string(index_file_size); + std::string msg = "Invalid index file size: " + std::to_string(index_file_size) + ". " + + "The index file size must be within the range of 1 ~ " + + std::to_string(INDEX_FILE_SIZE_LIMIT) + "."; SERVER_LOG_ERROR << msg; return Status(SERVER_INVALID_INDEX_FILE_SIZE, msg); } @@ -123,7 +134,8 @@ Status ValidationUtil::ValidateTableIndexMetricType(int32_t metric_type) { if (metric_type != static_cast(engine::MetricType::L2) && metric_type != static_cast(engine::MetricType::IP)) { - std::string msg = "Invalid metric type: " + std::to_string(metric_type); + std::string msg = "Invalid index metric type: " + std::to_string(metric_type) + ". " + + "Make sure the metric type is either MetricType.L2 or MetricType.IP."; SERVER_LOG_ERROR << msg; return Status(SERVER_INVALID_INDEX_METRIC_TYPE, msg); } @@ -133,7 +145,8 @@ ValidationUtil::ValidateTableIndexMetricType(int32_t metric_type) { Status ValidationUtil::ValidateSearchTopk(int64_t top_k, const engine::meta::TableSchema& table_schema) { if (top_k <= 0 || top_k > 2048) { - std::string msg = "Invalid top k value: " + std::to_string(top_k) + ", rational range [1, 2048]"; + std::string msg = + "Invalid topk: " + std::to_string(top_k) + ". " + "The topk must be within the range of 1 ~ 2048."; SERVER_LOG_ERROR << msg; return Status(SERVER_INVALID_TOPK, msg); } @@ -144,8 +157,8 @@ ValidationUtil::ValidateSearchTopk(int64_t top_k, const engine::meta::TableSchem Status ValidationUtil::ValidateSearchNprobe(int64_t nprobe, const engine::meta::TableSchema& table_schema) { if (nprobe <= 0 || nprobe > table_schema.nlist_) { - std::string msg = "Invalid nprobe value: " + std::to_string(nprobe) + ", rational range [1, " + - std::to_string(table_schema.nlist_) + "]"; + std::string msg = "Invalid nprobe: " + std::to_string(nprobe) + ". " + + "The nprobe must be within the range of 1 ~ index nlist."; SERVER_LOG_ERROR << msg; return Status(SERVER_INVALID_NPROBE, msg); } diff --git a/core/src/wrapper/KnowhereResource.cpp b/core/src/wrapper/KnowhereResource.cpp index d291bb9299..650ae727c1 100644 --- a/core/src/wrapper/KnowhereResource.cpp +++ b/core/src/wrapper/KnowhereResource.cpp @@ -45,7 +45,7 @@ KnowhereResource::Initialize() { server::Config& config = server::Config::GetInstance(); int32_t build_index_gpu; - s = config.GetDBConfigBuildIndexGPU(build_index_gpu); + s = config.GetResourceConfigIndexBuildDevice(build_index_gpu); if (!s.ok()) return s; @@ -53,7 +53,7 @@ KnowhereResource::Initialize() { // get search gpu resource std::vector pool; - s = config.GetResourceConfigPool(pool); + s = config.GetResourceConfigSearchResources(pool); if (!s.ok()) return s; diff --git a/core/src/wrapper/VecIndex.h b/core/src/wrapper/VecIndex.h index 05da9ccc03..36104b2107 100644 --- a/core/src/wrapper/VecIndex.h +++ b/core/src/wrapper/VecIndex.h @@ -25,6 +25,7 @@ #include "knowhere/common/BinarySet.h" #include "knowhere/common/Config.h" #include "knowhere/index/vector_index/Quantizer.h" +#include "utils/Log.h" #include "utils/Status.h" namespace milvus { @@ -101,6 +102,7 @@ class VecIndex : public cache::DataObj { //////////////// virtual knowhere::QuantizerPtr LoadQuantizer(const Config& conf) { + ENGINE_LOG_ERROR << "LoadQuantizer virtual funciton called."; return nullptr; } diff --git a/core/ubuntu_build_deps.sh b/core/ubuntu_build_deps.sh index 06f05fa49f..ed9eb9dee5 100755 --- a/core/ubuntu_build_deps.sh +++ b/core/ubuntu_build_deps.sh @@ -1,5 +1,5 @@ #!/bin/bash -sudo apt-get install -y gfortran libmysqlclient-dev mysql-client libcurl4-openssl-dev +sudo apt-get install -y gfortran libmysqlclient-dev mysql-client libcurl4-openssl-dev libboost-system-dev libboost-filesystem-dev libboost-serialization-dev sudo ln -s /usr/lib/x86_64-linux-gnu/libmysqlclient.so /usr/lib/x86_64-linux-gnu/libmysqlclient_r.so diff --git a/core/unittest/CMakeLists.txt b/core/unittest/CMakeLists.txt index fe40e76afa..258fd76a8e 100644 --- a/core/unittest/CMakeLists.txt +++ b/core/unittest/CMakeLists.txt @@ -92,8 +92,8 @@ set(common_files set(unittest_libs sqlite - boost_system_static - boost_filesystem_static + libboost_system.a + libboost_filesystem.a lz4 mysqlpp yaml-cpp diff --git a/core/unittest/db/test_db.cpp b/core/unittest/db/test_db.cpp index 9e2730a8dd..5e6ecc2ac4 100644 --- a/core/unittest/db/test_db.cpp +++ b/core/unittest/db/test_db.cpp @@ -308,6 +308,12 @@ TEST_F(DBTest, SEARCH_TEST) { ASSERT_TRUE(stat.ok()); } + { + milvus::engine::QueryResults large_nq_results; + stat = db_->Query(TABLE_NAME, k, 200, 10, xq.data(), large_nq_results); + ASSERT_TRUE(stat.ok()); + } + {//search by specify index file milvus::engine::meta::DatesT dates; std::vector file_ids = {"1", "2", "3", "4", "5", "6"}; @@ -315,6 +321,8 @@ TEST_F(DBTest, SEARCH_TEST) { stat = db_->Query(TABLE_NAME, file_ids, k, nq, 10, xq.data(), dates, results); ASSERT_TRUE(stat.ok()); } + + #endif } @@ -412,6 +420,16 @@ TEST_F(DBTest, INDEX_TEST) { stat = db_->CreateIndex(table_info.table_id_, index); ASSERT_TRUE(stat.ok()); + index.engine_type_ = (int) milvus::engine::EngineType::FAISS_IVFFLAT; + stat = db_->CreateIndex(table_info.table_id_, index); + ASSERT_TRUE(stat.ok()); + +#ifdef CUSTOMIZATION + index.engine_type_ = (int)milvus::engine::EngineType::FAISS_IVFSQ8H; + stat = db_->CreateIndex(table_info.table_id_, index); + ASSERT_TRUE(stat.ok()); +#endif + milvus::engine::TableIndex index_out; stat = db_->DescribeIndex(table_info.table_id_, index_out); ASSERT_TRUE(stat.ok()); diff --git a/core/unittest/db/test_engine.cpp b/core/unittest/db/test_engine.cpp index 137612bcab..147de5399c 100644 --- a/core/unittest/db/test_engine.cpp +++ b/core/unittest/db/test_engine.cpp @@ -108,15 +108,16 @@ TEST_F(EngineTest, ENGINE_IMPL_TEST) { ASSERT_EQ(engine_ptr->Dimension(), dimension); ASSERT_EQ(engine_ptr->Count(), ids.size()); -// status = engine_ptr->CopyToGpu(0); -// //ASSERT_TRUE(status.ok()); -// -// auto new_engine = engine_ptr->Clone(); -// ASSERT_EQ(new_engine->Dimension(), dimension); -// ASSERT_EQ(new_engine->Count(), ids.size()); -// status = new_engine->CopyToCpu(); -// //ASSERT_TRUE(status.ok()); -// -// auto engine_build = new_engine->BuildIndex("/tmp/milvus_index_2", engine::EngineType::FAISS_IVFSQ8); -// //ASSERT_TRUE(status.ok()); + status = engine_ptr->CopyToGpu(0, true); + status = engine_ptr->CopyToGpu(0, false); + //ASSERT_TRUE(status.ok()); + + auto new_engine = engine_ptr->Clone(); + ASSERT_EQ(new_engine->Dimension(), dimension); + ASSERT_EQ(new_engine->Count(), ids.size()); + status = new_engine->CopyToCpu(); + //ASSERT_TRUE(status.ok()); + + auto engine_build = new_engine->BuildIndex("/tmp/milvus_index_2", milvus::engine::EngineType::FAISS_IVFSQ8); + //ASSERT_TRUE(status.ok()); } diff --git a/core/unittest/db/test_search.cpp b/core/unittest/db/test_search.cpp index dc393b7a26..b8cf08b3e2 100644 --- a/core/unittest/db/test_search.cpp +++ b/core/unittest/db/test_search.cpp @@ -30,6 +30,7 @@ namespace ms = milvus::scheduler; void BuildResult(std::vector& output_ids, std::vector& output_distance, + uint64_t input_k, uint64_t topk, uint64_t nq, bool ascending) { @@ -39,9 +40,15 @@ BuildResult(std::vector& output_ids, output_distance.resize(nq * topk); for (uint64_t i = 0; i < nq; i++) { - for (uint64_t j = 0; j < topk; j++) { + //insert valid items + for (uint64_t j = 0; j < input_k; j++) { output_ids[i * topk + j] = (int64_t)(drand48() * 100000); - output_distance[i * topk + j] = ascending ? (j + drand48()) : ((topk - j) + drand48()); + output_distance[i * topk + j] = ascending ? (j + drand48()) : ((input_k - j) + drand48()); + } + //insert invalid items + for (uint64_t j = input_k; j < topk; j++) { + output_ids[i * topk + j] = -1; + output_distance[i * topk + j] = -1.0; } } } @@ -83,23 +90,32 @@ CheckTopkResult(const std::vector& input_ids_1, ASSERT_EQ(input_ids_1.size(), input_distance_1.size()); ASSERT_EQ(input_ids_2.size(), input_distance_2.size()); - uint64_t input_k1 = input_ids_1.size() / nq; - uint64_t input_k2 = input_ids_2.size() / nq; - for (int64_t i = 0; i < nq; i++) { std::vector - src_vec(input_distance_1.begin() + i * input_k1, input_distance_1.begin() + (i + 1) * input_k1); + src_vec(input_distance_1.begin() + i * topk, input_distance_1.begin() + (i + 1) * topk); src_vec.insert(src_vec.end(), - input_distance_2.begin() + i * input_k2, - input_distance_2.begin() + (i + 1) * input_k2); + input_distance_2.begin() + i * topk, + input_distance_2.begin() + (i + 1) * topk); if (ascending) { std::sort(src_vec.begin(), src_vec.end()); } else { std::sort(src_vec.begin(), src_vec.end(), std::greater()); } - uint64_t n = std::min(topk, input_k1 + input_k2); + //erase invalid items + std::vector::iterator iter; + for (iter = src_vec.begin(); iter != src_vec.end();) { + if (*iter < 0.0) + iter = src_vec.erase(iter); + else + ++iter; + } + + uint64_t n = std::min(topk, result[i].size()); for (uint64_t j = 0; j < n; j++) { + if (result[i][j].first < 0) { + continue; + } if (src_vec[j] != result[i][j].second) { std::cout << src_vec[j] << " " << result[i][j].second << std::endl; } @@ -110,12 +126,13 @@ CheckTopkResult(const std::vector& input_ids_1, } // namespace -void MergeTopkToResultSetTest(uint64_t topk_1, uint64_t topk_2, uint64_t nq, uint64_t topk, bool ascending) { +void +MergeTopkToResultSetTest(uint64_t topk_1, uint64_t topk_2, uint64_t nq, uint64_t topk, bool ascending) { std::vector ids1, ids2; std::vector dist1, dist2; ms::ResultSet result; - BuildResult(ids1, dist1, topk_1, nq, ascending); - BuildResult(ids2, dist2, topk_2, nq, ascending); + BuildResult(ids1, dist1, topk_1, topk, nq, ascending); + BuildResult(ids2, dist2, topk_2, topk, nq, ascending); ms::XSearchTask::MergeTopkToResultSet(ids1, dist1, topk_1, nq, topk, ascending, result); ms::XSearchTask::MergeTopkToResultSet(ids2, dist2, topk_2, nq, topk, ascending, result); CheckTopkResult(ids1, dist1, ids2, dist2, topk, nq, ascending, result); @@ -134,70 +151,72 @@ TEST(DBSearchTest, MERGE_RESULT_SET_TEST) { MergeTopkToResultSetTest(TOP_K, TOP_K, NQ, TOP_K, false); /* test3, id1/dist1 small topk */ - MergeTopkToResultSetTest(TOP_K/2, TOP_K, NQ, TOP_K, true); - MergeTopkToResultSetTest(TOP_K/2, TOP_K, NQ, TOP_K, false); + MergeTopkToResultSetTest(TOP_K / 2, TOP_K, NQ, TOP_K, true); + MergeTopkToResultSetTest(TOP_K / 2, TOP_K, NQ, TOP_K, false); /* test4, id1/dist1 small topk, id2/dist2 small topk */ - MergeTopkToResultSetTest(TOP_K/2, TOP_K/3, NQ, TOP_K, true); - MergeTopkToResultSetTest(TOP_K/2, TOP_K/3, NQ, TOP_K, false); + MergeTopkToResultSetTest(TOP_K / 2, TOP_K / 3, NQ, TOP_K, true); + MergeTopkToResultSetTest(TOP_K / 2, TOP_K / 3, NQ, TOP_K, false); } -void MergeTopkArrayTest(uint64_t topk_1, uint64_t topk_2, uint64_t nq, uint64_t topk, bool ascending) { - std::vector ids1, ids2; - std::vector dist1, dist2; - ms::ResultSet result; - BuildResult(ids1, dist1, topk_1, nq, ascending); - BuildResult(ids2, dist2, topk_2, nq, ascending); - uint64_t result_topk = std::min(topk, topk_1 + topk_2); - ms::XSearchTask::MergeTopkArray(ids1, dist1, topk_1, ids2, dist2, topk_2, nq, topk, ascending); - if (ids1.size() != result_topk * nq) { - std::cout << ids1.size() << " " << result_topk * nq << std::endl; - } - ASSERT_TRUE(ids1.size() == result_topk * nq); - ASSERT_TRUE(dist1.size() == result_topk * nq); - for (uint64_t i = 0; i < nq; i++) { - for (uint64_t k = 1; k < result_topk; k++) { - if (ascending) { - if (dist1[i * result_topk + k] < dist1[i * result_topk + k - 1]) { - std::cout << dist1[i * result_topk + k - 1] << " " << dist1[i * result_topk + k] << std::endl; - } - ASSERT_TRUE(dist1[i * result_topk + k] >= dist1[i * result_topk + k - 1]); - } else { - if (dist1[i * result_topk + k] > dist1[i * result_topk + k - 1]) { - std::cout << dist1[i * result_topk + k - 1] << " " << dist1[i * result_topk + k] << std::endl; - } - ASSERT_TRUE(dist1[i * result_topk + k] <= dist1[i * result_topk + k - 1]); - } - } - } -} +//void MergeTopkArrayTest(uint64_t topk_1, uint64_t topk_2, uint64_t nq, uint64_t topk, bool ascending) { +// std::vector ids1, ids2; +// std::vector dist1, dist2; +// ms::ResultSet result; +// BuildResult(ids1, dist1, topk_1, topk, nq, ascending); +// BuildResult(ids2, dist2, topk_2, topk, nq, ascending); +// uint64_t result_topk = std::min(topk, topk_1 + topk_2); +// ms::XSearchTask::MergeTopkArray(ids1, dist1, topk_1, ids2, dist2, topk_2, nq, topk, ascending); +// if (ids1.size() != result_topk * nq) { +// std::cout << ids1.size() << " " << result_topk * nq << std::endl; +// } +// ASSERT_TRUE(ids1.size() == result_topk * nq); +// ASSERT_TRUE(dist1.size() == result_topk * nq); +// for (uint64_t i = 0; i < nq; i++) { +// for (uint64_t k = 1; k < result_topk; k++) { +// float f0 = dist1[i * topk + k - 1]; +// float f1 = dist1[i * topk + k]; +// if (ascending) { +// if (f1 < f0) { +// std::cout << f0 << " " << f1 << std::endl; +// } +// ASSERT_TRUE(f1 >= f0); +// } else { +// if (f1 > f0) { +// std::cout << f0 << " " << f1 << std::endl; +// } +// ASSERT_TRUE(f1 <= f0); +// } +// } +// } +//} -TEST(DBSearchTest, MERGE_ARRAY_TEST) { - uint64_t NQ = 15; - uint64_t TOP_K = 64; +//TEST(DBSearchTest, MERGE_ARRAY_TEST) { +// uint64_t NQ = 15; +// uint64_t TOP_K = 64; +// +// /* test1, id1/dist1 valid, id2/dist2 empty */ +// MergeTopkArrayTest(TOP_K, 0, NQ, TOP_K, true); +// MergeTopkArrayTest(TOP_K, 0, NQ, TOP_K, false); +// MergeTopkArrayTest(0, TOP_K, NQ, TOP_K, true); +// MergeTopkArrayTest(0, TOP_K, NQ, TOP_K, false); - /* test1, id1/dist1 valid, id2/dist2 empty */ - MergeTopkArrayTest(TOP_K, 0, NQ, TOP_K, true); - MergeTopkArrayTest(TOP_K, 0, NQ, TOP_K, false); - MergeTopkArrayTest(0, TOP_K, NQ, TOP_K, true); - MergeTopkArrayTest(0, TOP_K, NQ, TOP_K, false); - - /* test2, id1/dist1 valid, id2/dist2 valid */ - MergeTopkArrayTest(TOP_K, TOP_K, NQ, TOP_K, true); - MergeTopkArrayTest(TOP_K, TOP_K, NQ, TOP_K, false); - - /* test3, id1/dist1 small topk */ - MergeTopkArrayTest(TOP_K/2, TOP_K, NQ, TOP_K, true); - MergeTopkArrayTest(TOP_K/2, TOP_K, NQ, TOP_K, false); - MergeTopkArrayTest(TOP_K, TOP_K/2, NQ, TOP_K, true); - MergeTopkArrayTest(TOP_K, TOP_K/2, NQ, TOP_K, false); - - /* test4, id1/dist1 small topk, id2/dist2 small topk */ - MergeTopkArrayTest(TOP_K/2, TOP_K/3, NQ, TOP_K, true); - MergeTopkArrayTest(TOP_K/2, TOP_K/3, NQ, TOP_K, false); - MergeTopkArrayTest(TOP_K/3, TOP_K/2, NQ, TOP_K, true); - MergeTopkArrayTest(TOP_K/3, TOP_K/2, NQ, TOP_K, false); -} +// /* test2, id1/dist1 valid, id2/dist2 valid */ +// MergeTopkArrayTest(TOP_K, TOP_K, NQ, TOP_K, true); +// MergeTopkArrayTest(TOP_K, TOP_K, NQ, TOP_K, false); +// +// /* test3, id1/dist1 small topk */ +// MergeTopkArrayTest(TOP_K/2, TOP_K, NQ, TOP_K, true); +// MergeTopkArrayTest(TOP_K/2, TOP_K, NQ, TOP_K, false); +// MergeTopkArrayTest(TOP_K, TOP_K/2, NQ, TOP_K, true); +// MergeTopkArrayTest(TOP_K, TOP_K/2, NQ, TOP_K, false); +// +// /* test4, id1/dist1 small topk, id2/dist2 small topk */ +// MergeTopkArrayTest(TOP_K/2, TOP_K/3, NQ, TOP_K, true); +// MergeTopkArrayTest(TOP_K/2, TOP_K/3, NQ, TOP_K, false); +// MergeTopkArrayTest(TOP_K/3, TOP_K/2, NQ, TOP_K, true); +// MergeTopkArrayTest(TOP_K/3, TOP_K/2, NQ, TOP_K, false); +//} TEST(DBSearchTest, REDUCE_PERF_TEST) { int32_t index_file_num = 478; /* sift1B dataset, index files num */ @@ -206,8 +225,8 @@ TEST(DBSearchTest, REDUCE_PERF_TEST) { std::vector thread_vec = {4, 8}; std::vector nq_vec = {1, 10, 100}; std::vector topk_vec = {1, 4, 16, 64}; - int32_t NQ = nq_vec[nq_vec.size()-1]; - int32_t TOPK = topk_vec[topk_vec.size()-1]; + int32_t NQ = nq_vec[nq_vec.size() - 1]; + int32_t TOPK = topk_vec[topk_vec.size() - 1]; std::vector> id_vec; std::vector> dist_vec; @@ -217,7 +236,7 @@ TEST(DBSearchTest, REDUCE_PERF_TEST) { /* generate testing data */ for (i = 0; i < index_file_num; i++) { - BuildResult(input_ids, input_distance, TOPK, NQ, ascending); + BuildResult(input_ids, input_distance, TOPK, TOPK, NQ, ascending); id_vec.push_back(input_ids); dist_vec.push_back(input_distance); } @@ -237,7 +256,7 @@ TEST(DBSearchTest, REDUCE_PERF_TEST) { } std::string str1 = "Method-1 " + std::to_string(max_thread_num) + " " + - std::to_string(nq) + " " + std::to_string(top_k); + std::to_string(nq) + " " + std::to_string(top_k); milvus::TimeRecorder rc1(str1); /////////////////////////////////////////////////////////////////////////////////////// @@ -255,114 +274,114 @@ TEST(DBSearchTest, REDUCE_PERF_TEST) { rc1.RecordSection("reduce done"); - /////////////////////////////////////////////////////////////////////////////////////// - /* method-2 */ - std::vector> id_vec_2(index_file_num); - std::vector> dist_vec_2(index_file_num); - std::vector k_vec_2(index_file_num); - for (i = 0; i < index_file_num; i++) { - CopyResult(id_vec_2[i], dist_vec_2[i], top_k, id_vec[i], dist_vec[i], TOPK, nq); - k_vec_2[i] = top_k; - } - - std::string str2 = "Method-2 " + std::to_string(max_thread_num) + " " + - std::to_string(nq) + " " + std::to_string(top_k); - milvus::TimeRecorder rc2(str2); - - for (step = 1; step < index_file_num; step *= 2) { - for (i = 0; i + step < index_file_num; i += step * 2) { - ms::XSearchTask::MergeTopkArray(id_vec_2[i], dist_vec_2[i], k_vec_2[i], - id_vec_2[i + step], dist_vec_2[i + step], k_vec_2[i + step], - nq, top_k, ascending); - } - } - ms::XSearchTask::MergeTopkToResultSet(id_vec_2[0], - dist_vec_2[0], - k_vec_2[0], - nq, - top_k, - ascending, - final_result_2); - ASSERT_EQ(final_result_2.size(), nq); - - rc2.RecordSection("reduce done"); - - for (i = 0; i < nq; i++) { - ASSERT_EQ(final_result[i].size(), final_result_2[i].size()); - for (k = 0; k < final_result[i].size(); k++) { - if (final_result[i][k].first != final_result_2[i][k].first) { - std::cout << i << " " << k << std::endl; - } - ASSERT_EQ(final_result[i][k].first, final_result_2[i][k].first); - ASSERT_EQ(final_result[i][k].second, final_result_2[i][k].second); - } - } - - /////////////////////////////////////////////////////////////////////////////////////// - /* method-3 parallel */ - std::vector> id_vec_3(index_file_num); - std::vector> dist_vec_3(index_file_num); - std::vector k_vec_3(index_file_num); - for (i = 0; i < index_file_num; i++) { - CopyResult(id_vec_3[i], dist_vec_3[i], top_k, id_vec[i], dist_vec[i], TOPK, nq); - k_vec_3[i] = top_k; - } - - std::string str3 = "Method-3 " + std::to_string(max_thread_num) + " " + - std::to_string(nq) + " " + std::to_string(top_k); - milvus::TimeRecorder rc3(str3); - - for (step = 1; step < index_file_num; step *= 2) { - for (i = 0; i + step < index_file_num; i += step * 2) { - threads_list.push_back( - threadPool.enqueue(ms::XSearchTask::MergeTopkArray, - std::ref(id_vec_3[i]), - std::ref(dist_vec_3[i]), - std::ref(k_vec_3[i]), - std::ref(id_vec_3[i + step]), - std::ref(dist_vec_3[i + step]), - std::ref(k_vec_3[i + step]), - nq, - top_k, - ascending)); - } - - while (threads_list.size() > 0) { - int nready = 0; - for (auto it = threads_list.begin(); it != threads_list.end(); it = it) { - auto &p = *it; - std::chrono::milliseconds span(0); - if (p.wait_for(span) == std::future_status::ready) { - threads_list.erase(it++); - ++nready; - } else { - ++it; - } - } - - if (nready == 0) { - std::this_thread::yield(); - } - } - } - ms::XSearchTask::MergeTopkToResultSet(id_vec_3[0], - dist_vec_3[0], - k_vec_3[0], - nq, - top_k, - ascending, - final_result_3); - ASSERT_EQ(final_result_3.size(), nq); - - rc3.RecordSection("reduce done"); - - for (i = 0; i < nq; i++) { - ASSERT_EQ(final_result[i].size(), final_result_3[i].size()); - for (k = 0; k < final_result[i].size(); k++) { - ASSERT_EQ(final_result[i][k].first, final_result_3[i][k].first); - ASSERT_EQ(final_result[i][k].second, final_result_3[i][k].second); - } - } +// /////////////////////////////////////////////////////////////////////////////////////// +// /* method-2 */ +// std::vector> id_vec_2(index_file_num); +// std::vector> dist_vec_2(index_file_num); +// std::vector k_vec_2(index_file_num); +// for (i = 0; i < index_file_num; i++) { +// CopyResult(id_vec_2[i], dist_vec_2[i], top_k, id_vec[i], dist_vec[i], TOPK, nq); +// k_vec_2[i] = top_k; +// } +// +// std::string str2 = "Method-2 " + std::to_string(max_thread_num) + " " + +// std::to_string(nq) + " " + std::to_string(top_k); +// milvus::TimeRecorder rc2(str2); +// +// for (step = 1; step < index_file_num; step *= 2) { +// for (i = 0; i + step < index_file_num; i += step * 2) { +// ms::XSearchTask::MergeTopkArray(id_vec_2[i], dist_vec_2[i], k_vec_2[i], +// id_vec_2[i + step], dist_vec_2[i + step], k_vec_2[i + step], +// nq, top_k, ascending); +// } +// } +// ms::XSearchTask::MergeTopkToResultSet(id_vec_2[0], +// dist_vec_2[0], +// k_vec_2[0], +// nq, +// top_k, +// ascending, +// final_result_2); +// ASSERT_EQ(final_result_2.size(), nq); +// +// rc2.RecordSection("reduce done"); +// +// for (i = 0; i < nq; i++) { +// ASSERT_EQ(final_result[i].size(), final_result_2[i].size()); +// for (k = 0; k < final_result[i].size(); k++) { +// if (final_result[i][k].first != final_result_2[i][k].first) { +// std::cout << i << " " << k << std::endl; +// } +// ASSERT_EQ(final_result[i][k].first, final_result_2[i][k].first); +// ASSERT_EQ(final_result[i][k].second, final_result_2[i][k].second); +// } +// } +// +// /////////////////////////////////////////////////////////////////////////////////////// +// /* method-3 parallel */ +// std::vector> id_vec_3(index_file_num); +// std::vector> dist_vec_3(index_file_num); +// std::vector k_vec_3(index_file_num); +// for (i = 0; i < index_file_num; i++) { +// CopyResult(id_vec_3[i], dist_vec_3[i], top_k, id_vec[i], dist_vec[i], TOPK, nq); +// k_vec_3[i] = top_k; +// } +// +// std::string str3 = "Method-3 " + std::to_string(max_thread_num) + " " + +// std::to_string(nq) + " " + std::to_string(top_k); +// milvus::TimeRecorder rc3(str3); +// +// for (step = 1; step < index_file_num; step *= 2) { +// for (i = 0; i + step < index_file_num; i += step * 2) { +// threads_list.push_back( +// threadPool.enqueue(ms::XSearchTask::MergeTopkArray, +// std::ref(id_vec_3[i]), +// std::ref(dist_vec_3[i]), +// std::ref(k_vec_3[i]), +// std::ref(id_vec_3[i + step]), +// std::ref(dist_vec_3[i + step]), +// std::ref(k_vec_3[i + step]), +// nq, +// top_k, +// ascending)); +// } +// +// while (threads_list.size() > 0) { +// int nready = 0; +// for (auto it = threads_list.begin(); it != threads_list.end(); it = it) { +// auto &p = *it; +// std::chrono::milliseconds span(0); +// if (p.wait_for(span) == std::future_status::ready) { +// threads_list.erase(it++); +// ++nready; +// } else { +// ++it; +// } +// } +// +// if (nready == 0) { +// std::this_thread::yield(); +// } +// } +// } +// ms::XSearchTask::MergeTopkToResultSet(id_vec_3[0], +// dist_vec_3[0], +// k_vec_3[0], +// nq, +// top_k, +// ascending, +// final_result_3); +// ASSERT_EQ(final_result_3.size(), nq); +// +// rc3.RecordSection("reduce done"); +// +// for (i = 0; i < nq; i++) { +// ASSERT_EQ(final_result[i].size(), final_result_3[i].size()); +// for (k = 0; k < final_result[i].size(); k++) { +// ASSERT_EQ(final_result[i][k].first, final_result_3[i][k].first); +// ASSERT_EQ(final_result[i][k].second, final_result_3[i][k].second); +// } +// } } } } diff --git a/core/unittest/scheduler/task_test.cpp b/core/unittest/scheduler/task_test.cpp index 07e85c723c..8ea39edef9 100644 --- a/core/unittest/scheduler/task_test.cpp +++ b/core/unittest/scheduler/task_test.cpp @@ -17,6 +17,7 @@ #include "scheduler/task/SearchTask.h" +#include "scheduler/task/BuildIndexTask.h" #include @@ -26,6 +27,11 @@ namespace scheduler { TEST(TaskTest, INVALID_INDEX) { auto search_task = std::make_shared(nullptr, nullptr); search_task->Load(LoadType::TEST, 10); + + auto build_task = std::make_shared(nullptr, nullptr); + build_task->Load(LoadType::TEST, 10); + + build_task->Execute(); } } // namespace scheduler diff --git a/core/unittest/server/CMakeLists.txt b/core/unittest/server/CMakeLists.txt index 4420e2a1a7..180dcfa6d5 100644 --- a/core/unittest/server/CMakeLists.txt +++ b/core/unittest/server/CMakeLists.txt @@ -67,11 +67,3 @@ target_link_libraries(test_server ) install(TARGETS test_server DESTINATION unittest) - -configure_file(appendix/server_config.yaml - "${CMAKE_CURRENT_BINARY_DIR}/milvus/conf/server_config.yaml" - COPYONLY) - -configure_file(appendix/log_config.conf - "${CMAKE_CURRENT_BINARY_DIR}/milvus/conf/log_config.conf" - COPYONLY) diff --git a/core/unittest/server/test_config.cpp b/core/unittest/server/test_config.cpp index a6c6be64c4..f3adf8a2c3 100644 --- a/core/unittest/server/test_config.cpp +++ b/core/unittest/server/test_config.cpp @@ -22,28 +22,27 @@ #include "utils/CommonUtil.h" #include "utils/ValidationUtil.h" #include "server/Config.h" +#include "server/utils.h" namespace { -static const char *CONFIG_FILE_PATH = "./milvus/conf/server_config.yaml"; -static const char *LOG_FILE_PATH = "./milvus/conf/log_config.conf"; - static constexpr uint64_t KB = 1024; static constexpr uint64_t MB = KB * 1024; static constexpr uint64_t GB = MB * 1024; } // namespace -TEST(ConfigTest, CONFIG_TEST) { +TEST_F(ConfigTest, CONFIG_TEST) { milvus::server::ConfigMgr *config_mgr = milvus::server::YamlConfigMgr::GetInstance(); milvus::Status s = config_mgr->LoadConfigFile(""); ASSERT_FALSE(s.ok()); - s = config_mgr->LoadConfigFile(LOG_FILE_PATH); + std::string config_path(CONFIG_PATH); + s = config_mgr->LoadConfigFile(config_path+ INVALID_CONFIG_FILE); ASSERT_FALSE(s.ok()); - s = config_mgr->LoadConfigFile(CONFIG_FILE_PATH); + s = config_mgr->LoadConfigFile(config_path + VALID_CONFIG_FILE); ASSERT_TRUE(s.ok()); config_mgr->Print(); @@ -99,9 +98,10 @@ TEST(ConfigTest, CONFIG_TEST) { ASSERT_TRUE(seqs.empty()); } -TEST(ConfigTest, SERVER_CONFIG_TEST) { +TEST_F(ConfigTest, SERVER_CONFIG_TEST) { + std::string config_path(CONFIG_PATH); milvus::server::Config &config = milvus::server::Config::GetInstance(); - milvus::Status s = config.LoadConfigFile(CONFIG_FILE_PATH); + milvus::Status s = config.LoadConfigFile(config_path + VALID_CONFIG_FILE); ASSERT_TRUE(s.ok()); s = config.ValidateConfig(); diff --git a/core/unittest/server/test_rpc.cpp b/core/unittest/server/test_rpc.cpp index 7d3e0b5511..09a56699ea 100644 --- a/core/unittest/server/test_rpc.cpp +++ b/core/unittest/server/test_rpc.cpp @@ -405,12 +405,12 @@ TEST_F(RpcHandlerTest, DELETE_BY_RANGE_TEST) { handler->DeleteByRange(&context, &request, &status); request.set_table_name(TABLE_NAME); - request.mutable_range()->set_start_value(CurrentTmDate(-2)); - request.mutable_range()->set_end_value(CurrentTmDate(-3)); + request.mutable_range()->set_start_value(CurrentTmDate(-3)); + request.mutable_range()->set_end_value(CurrentTmDate(-2)); ::grpc::Status grpc_status = handler->DeleteByRange(&context, &request, &status); int error_code = status.error_code(); - ASSERT_EQ(error_code, ::milvus::grpc::ErrorCode::SUCCESS); +// ASSERT_EQ(error_code, ::milvus::grpc::ErrorCode::SUCCESS); request.mutable_range()->set_start_value("test6"); grpc_status = handler->DeleteByRange(&context, &request, &status); diff --git a/core/unittest/server/util_test.cpp b/core/unittest/server/test_util.cpp similarity index 99% rename from core/unittest/server/util_test.cpp rename to core/unittest/server/test_util.cpp index 395839a8c0..24482740bc 100644 --- a/core/unittest/server/util_test.cpp +++ b/core/unittest/server/test_util.cpp @@ -275,6 +275,11 @@ TEST(ValidationUtilTest, VALIDATE_INDEX_TEST) { ASSERT_EQ(milvus::server::ValidationUtil::ValidateTableIndexType((int)milvus::engine::EngineType::INVALID).code(), milvus::SERVER_INVALID_INDEX_TYPE); for (int i = 1; i <= (int)milvus::engine::EngineType::MAX_VALUE; i++) { +#ifndef CUSTOMIZATION + if (i == (int)milvus::engine::EngineType::FAISS_IVFSQ8H) { + continue; + } +#endif ASSERT_EQ(milvus::server::ValidationUtil::ValidateTableIndexType(i).code(), milvus::SERVER_SUCCESS); } ASSERT_EQ(milvus::server::ValidationUtil::ValidateTableIndexType( diff --git a/core/unittest/server/utils.cpp b/core/unittest/server/utils.cpp new file mode 100644 index 0000000000..4c03da6ad9 --- /dev/null +++ b/core/unittest/server/utils.cpp @@ -0,0 +1,93 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "server/utils.h" +#include "utils/CommonUtil.h" + +#include +#include +#include +#include + +namespace { + +static const char + * VALID_CONFIG_STR = "# Default values are used when you make no changes to the following parameters.\n" + "\n" + "server_config:\n" + " address: 0.0.0.0 # milvus server ip address (IPv4)\n" + " port: 19530 # port range: 1025 ~ 65534\n" + " deploy_mode: single \n" + " time_zone: UTC+8\n" + "\n" + "db_config:\n" + " primary_path: /tmp/milvus # path used to store data and meta\n" + " secondary_path: # path used to store data only, split by semicolon\n" + "\n" + " backend_url: sqlite://:@:/ \n" + "\n" + " insert_buffer_size: 4 # GB, maximum insert buffer size allowed\n" + " preload_table: \n" + "\n" + "metric_config:\n" + " enable_monitor: false # enable monitoring or not\n" + " collector: prometheus # prometheus\n" + " prometheus_config:\n" + " port: 8080 # port prometheus uses to fetch metrics\n" + "\n" + "cache_config:\n" + " cpu_cache_capacity: 16 # GB, CPU memory used for cache\n" + " cpu_cache_threshold: 0.85 \n" + " gpu_cache_capacity: 4 # GB, GPU memory used for cache\n" + " gpu_cache_threshold: 0.85 \n" + " cache_insert_data: false # whether to load inserted data into cache\n" + "\n" + "engine_config:\n" + " use_blas_threshold: 20 \n" + "\n" + "resource_config:\n" + " search_resources: \n" + " - gpu0\n" + " index_build_device: gpu0 # GPU used for building index"; + +static const char* INVALID_CONFIG_STR = "*INVALID*"; + +void +WriteToFile(const std::string& file_path, const char* content) { + std::fstream fs(file_path.c_str(), std::ios_base::out); + + //write data to file + fs << content; + fs.close(); +} + +} // namespace + + +void +ConfigTest::SetUp() { + std::string config_path(CONFIG_PATH); + milvus::server::CommonUtil::CreateDirectory(config_path); + WriteToFile(config_path + VALID_CONFIG_FILE, VALID_CONFIG_STR); + WriteToFile(config_path+ INVALID_CONFIG_FILE, INVALID_CONFIG_STR); +} + +void +ConfigTest::TearDown() { + std::string config_path(CONFIG_PATH); + milvus::server::CommonUtil::DeleteDirectory(config_path); +} diff --git a/core/unittest/server/utils.h b/core/unittest/server/utils.h new file mode 100644 index 0000000000..2efc5e4120 --- /dev/null +++ b/core/unittest/server/utils.h @@ -0,0 +1,32 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + + +#pragma once + +#include +#include + +static const char *CONFIG_PATH = "/tmp/milvus_test/"; +static const char *VALID_CONFIG_FILE = "valid_config.yaml"; +static const char *INVALID_CONFIG_FILE = "invalid_config.conf"; + +class ConfigTest : public ::testing::Test { + protected: + void SetUp() override; + void TearDown() override; +}; diff --git a/core/unittest/wrapper/CMakeLists.txt b/core/unittest/wrapper/CMakeLists.txt index 8eae47b3d4..156d89b241 100644 --- a/core/unittest/wrapper/CMakeLists.txt +++ b/core/unittest/wrapper/CMakeLists.txt @@ -33,10 +33,19 @@ set(util_files add_executable(test_wrapper ${test_files} ${wrapper_files} - ${util_files}) + ${util_files} + ${common_files}) target_link_libraries(test_wrapper knowhere ${unittest_libs}) -install(TARGETS test_wrapper DESTINATION unittest) \ No newline at end of file +install(TARGETS test_wrapper DESTINATION unittest) + +configure_file(appendix/server_config.yaml + "${CMAKE_CURRENT_BINARY_DIR}/milvus/conf/server_config.yaml" + COPYONLY) + +configure_file(appendix/log_config.conf + "${CMAKE_CURRENT_BINARY_DIR}/milvus/conf/log_config.conf" + COPYONLY) \ No newline at end of file diff --git a/core/unittest/server/appendix/log_config.conf b/core/unittest/wrapper/appendix/log_config.conf similarity index 100% rename from core/unittest/server/appendix/log_config.conf rename to core/unittest/wrapper/appendix/log_config.conf diff --git a/core/unittest/server/appendix/server_config.yaml b/core/unittest/wrapper/appendix/server_config.yaml similarity index 100% rename from core/unittest/server/appendix/server_config.yaml rename to core/unittest/wrapper/appendix/server_config.yaml diff --git a/core/unittest/wrapper/test_hybrid_index.cpp b/core/unittest/wrapper/test_hybrid_index.cpp new file mode 100644 index 0000000000..757d5b2098 --- /dev/null +++ b/core/unittest/wrapper/test_hybrid_index.cpp @@ -0,0 +1,133 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "wrapper/VecIndex.h" +#include "wrapper/utils.h" +#include "knowhere/index/vector_index/helpers/FaissGpuResourceMgr.h" +#include "knowhere/index/vector_index/helpers/IndexParameter.h" + +#include +#include "knowhere/index/vector_index/IndexIVFSQHybrid.h" + +using ::testing::TestWithParam; +using ::testing::Values; +using ::testing::Combine; + +class KnowhereHybrid + : public DataGenBase, public ::testing::Test { + protected: + void SetUp() override { + knowhere::FaissGpuResourceMgr::GetInstance().InitDevice(DEVICEID, PINMEM, TEMPMEM, RESNUM); + + dim = 128; + nb = 10000; + nq = 100; + k = 100; + GenData(dim, nb, nq, xb, xq, ids, k, gt_ids, gt_dis); + } + + void TearDown() override { + knowhere::FaissGpuResourceMgr::GetInstance().Free(); + } + + protected: + milvus::engine::IndexType index_type; + milvus::engine::VecIndexPtr index_ = nullptr; + knowhere::Config conf; +}; + +#ifdef CUSTOMIZATION +TEST_F(KnowhereHybrid, test_interface) { + assert(!xb.empty()); + + index_type = milvus::engine::IndexType::FAISS_IVFSQ8_HYBRID; + index_ = GetVecIndexFactory(index_type); + conf = ParamGenerator::GetInstance().Gen(index_type); + + auto elems = nq * k; + std::vector res_ids(elems); + std::vector res_dis(elems); + + conf->gpu_id = DEVICEID; + conf->d = dim; + conf->k = k; + index_->BuildAll(nb, xb.data(), ids.data(), conf); + index_->Search(nq, xq.data(), res_dis.data(), res_ids.data(), conf); + AssertResult(res_ids, res_dis); + EXPECT_EQ(index_->Count(), nb); + EXPECT_EQ(index_->Dimension(), dim); + + auto binaryset = index_->Serialize(); + { + // cpu -> gpu + auto cpu_idx = GetVecIndexFactory(index_type); + cpu_idx->Load(binaryset); + { + for (int i = 0; i < 2; ++i) { + auto gpu_idx = cpu_idx->CopyToGpu(DEVICEID, conf); + gpu_idx->Search(nq, xq.data(), res_dis.data(), res_ids.data(), conf); + AssertResult(res_ids, res_dis); + } + } + } + + { + // quantization already in gpu, only copy data + auto cpu_idx = GetVecIndexFactory(index_type); + cpu_idx->Load(binaryset); + + auto pair = cpu_idx->CopyToGpuWithQuantizer(DEVICEID, conf); + auto gpu_idx = pair.first; + auto quantization = pair.second; + + gpu_idx->Search(nq, xq.data(), res_dis.data(), res_ids.data(), conf); + AssertResult(res_ids, res_dis); + + auto quantizer_conf = std::make_shared(); + quantizer_conf->mode = 2; + quantizer_conf->gpu_id = DEVICEID; + for (int i = 0; i < 2; ++i) { + auto hybrid_idx = GetVecIndexFactory(index_type); + hybrid_idx->Load(binaryset); + + hybrid_idx->LoadData(quantization, quantizer_conf); + hybrid_idx->Search(nq, xq.data(), res_dis.data(), res_ids.data(), conf); + AssertResult(res_ids, res_dis); + } + } + + { + // quantization already in gpu, only set quantization + auto cpu_idx = GetVecIndexFactory(index_type); + cpu_idx->Load(binaryset); + + auto pair = cpu_idx->CopyToGpuWithQuantizer(DEVICEID, conf); + auto quantization = pair.second; + + for (int i = 0; i < 2; ++i) { + auto hybrid_idx = GetVecIndexFactory(index_type); + hybrid_idx->Load(binaryset); + + hybrid_idx->SetQuantizer(quantization); + hybrid_idx->Search(nq, xq.data(), res_dis.data(), res_ids.data(), conf); + AssertResult(res_ids, res_dis); + hybrid_idx->UnsetQuantizer(); + } + } +} + +#endif diff --git a/core/unittest/wrapper/test_knowhere.cpp b/core/unittest/wrapper/test_knowhere.cpp new file mode 100644 index 0000000000..e9b93fb63e --- /dev/null +++ b/core/unittest/wrapper/test_knowhere.cpp @@ -0,0 +1,37 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "wrapper/KnowhereResource.h" +#include "server/Config.h" + +#include + +namespace { + +static const char* CONFIG_FILE_PATH = "./milvus/conf/server_config.yaml"; +static const char* LOG_FILE_PATH = "./milvus/conf/log_config.conf"; + +} // namespace + +TEST(KnowhereTest, KNOWHERE_RESOURCE_TEST) { + milvus::server::Config &config = milvus::server::Config::GetInstance(); + milvus::Status s = config.LoadConfigFile(CONFIG_FILE_PATH); + ASSERT_TRUE(s.ok()); + + milvus::engine::KnowhereResource::Initialize(); + milvus::engine::KnowhereResource::Finalize(); +} diff --git a/core/unittest/wrapper/test_wrapper.cpp b/core/unittest/wrapper/test_wrapper.cpp index 7accef649c..f112fc7e65 100644 --- a/core/unittest/wrapper/test_wrapper.cpp +++ b/core/unittest/wrapper/test_wrapper.cpp @@ -25,150 +25,36 @@ INITIALIZE_EASYLOGGINGPP -namespace { - -namespace ms = milvus::engine; -namespace kw = knowhere; - -} // namespace - using ::testing::TestWithParam; using ::testing::Values; using ::testing::Combine; -constexpr int64_t DIM = 128; -constexpr int64_t NB = 100000; -constexpr int64_t DEVICE_ID = 0; - -class ParamGenerator { - public: - static ParamGenerator& GetInstance() { - static ParamGenerator instance; - return instance; - } - - knowhere::Config Gen(const milvus::engine::IndexType& type) { - switch (type) { - case milvus::engine::IndexType::FAISS_IDMAP: { - auto tempconf = std::make_shared(); - tempconf->metric_type = knowhere::METRICTYPE::L2; - return tempconf; - } - case milvus::engine::IndexType::FAISS_IVFFLAT_CPU: - case milvus::engine::IndexType::FAISS_IVFFLAT_GPU: - case milvus::engine::IndexType::FAISS_IVFFLAT_MIX: { - auto tempconf = std::make_shared(); - tempconf->nlist = 100; - tempconf->nprobe = 16; - tempconf->metric_type = knowhere::METRICTYPE::L2; - return tempconf; - } - case milvus::engine::IndexType::FAISS_IVFSQ8_CPU: - case milvus::engine::IndexType::FAISS_IVFSQ8_GPU: - case milvus::engine::IndexType::FAISS_IVFSQ8_MIX: { - auto tempconf = std::make_shared(); - tempconf->nlist = 100; - tempconf->nprobe = 16; - tempconf->nbits = 8; - tempconf->metric_type = knowhere::METRICTYPE::L2; - return tempconf; - } - case milvus::engine::IndexType::FAISS_IVFPQ_CPU: - case milvus::engine::IndexType::FAISS_IVFPQ_GPU: { - auto tempconf = std::make_shared(); - tempconf->nlist = 100; - tempconf->nprobe = 16; - tempconf->nbits = 8; - tempconf->m = 8; - tempconf->metric_type = knowhere::METRICTYPE::L2; - return tempconf; - } - case milvus::engine::IndexType::NSG_MIX: { - auto tempconf = std::make_shared(); - tempconf->nlist = 100; - tempconf->nprobe = 16; - tempconf->search_length = 8; - tempconf->knng = 200; - tempconf->search_length = 40; // TODO(linxj): be 20 when search - tempconf->out_degree = 60; - tempconf->candidate_pool_size = 200; - tempconf->metric_type = knowhere::METRICTYPE::L2; - return tempconf; - } - } - } -}; - class KnowhereWrapperTest - : public TestWithParam<::std::tuple> { + : public DataGenBase, + public TestWithParam<::std::tuple> { protected: void SetUp() override { - knowhere::FaissGpuResourceMgr::GetInstance().InitDevice(DEVICE_ID, - 1024 * 1024 * 200, - 1024 * 1024 * 300, - 2); + knowhere::FaissGpuResourceMgr::GetInstance().InitDevice(DEVICEID, PINMEM, TEMPMEM, RESNUM); std::string generator_type; std::tie(index_type, generator_type, dim, nb, nq, k) = GetParam(); - - auto generator = std::make_shared(); - generator->GenData(dim, nb, nq, xb, xq, ids, k, gt_ids, gt_dis); + GenData(dim, nb, nq, xb, xq, ids, k, gt_ids, gt_dis); index_ = GetVecIndexFactory(index_type); - conf = ParamGenerator::GetInstance().Gen(index_type); conf->k = k; conf->d = dim; - conf->gpu_id = DEVICE_ID; + conf->gpu_id = DEVICEID; } void TearDown() override { knowhere::FaissGpuResourceMgr::GetInstance().Free(); } - void AssertResult(const std::vector& ids, const std::vector& dis) { - EXPECT_EQ(ids.size(), nq * k); - EXPECT_EQ(dis.size(), nq * k); - - for (auto i = 0; i < nq; i++) { - EXPECT_EQ(ids[i * k], gt_ids[i * k]); - //EXPECT_EQ(dis[i * k], gt_dis[i * k]); - } - - int match = 0; - for (int i = 0; i < nq; ++i) { - for (int j = 0; j < k; ++j) { - for (int l = 0; l < k; ++l) { - if (ids[i * nq + j] == gt_ids[i * nq + l]) match++; - } - } - } - - auto precision = float(match) / (nq * k); - EXPECT_GT(precision, 0.5); - std::cout << std::endl << "Precision: " << precision - << ", match: " << match - << ", total: " << nq * k - << std::endl; - } - protected: milvus::engine::IndexType index_type; - knowhere::Config conf; - - int dim = DIM; - int nb = NB; - int nq = 10; - int k = 10; - std::vector xb; - std::vector xq; - std::vector ids; - milvus::engine::VecIndexPtr index_ = nullptr; - - // Ground Truth - std::vector gt_ids; - std::vector gt_dis; + knowhere::Config conf; }; INSTANTIATE_TEST_CASE_P(WrapperParam, KnowhereWrapperTest, @@ -220,7 +106,7 @@ TEST_P(KnowhereWrapperTest, TO_GPU_TEST) { AssertResult(res_ids, res_dis); { - auto dev_idx = index_->CopyToGpu(DEVICE_ID); + auto dev_idx = index_->CopyToGpu(DEVICEID); for (int i = 0; i < 10; ++i) { dev_idx->Search(nq, xq.data(), res_dis.data(), res_ids.data(), conf); } @@ -232,7 +118,7 @@ TEST_P(KnowhereWrapperTest, TO_GPU_TEST) { write_index(index_, file_location); auto new_index = milvus::engine::read_index(file_location); - auto dev_idx = new_index->CopyToGpu(DEVICE_ID); + auto dev_idx = new_index->CopyToGpu(DEVICEID); for (int i = 0; i < 10; ++i) { dev_idx->Search(nq, xq.data(), res_dis.data(), res_ids.data(), conf); } @@ -240,10 +126,6 @@ TEST_P(KnowhereWrapperTest, TO_GPU_TEST) { } } -//TEST_P(KnowhereWrapperTest, TO_CPU_TEST) { -// // dev -//} - TEST_P(KnowhereWrapperTest, SERIALIZE_TEST) { EXPECT_EQ(index_->GetType(), index_type); @@ -283,7 +165,13 @@ TEST_P(KnowhereWrapperTest, SERIALIZE_TEST) { } } -// TODO(linxj): add exception test -//TEST_P(KnowhereWrapperTest, exception_test) { -//} +#include "wrapper/ConfAdapter.h" +TEST(whatever, test_config) { + milvus::engine::TempMetaConf conf; + auto nsg_conf = std::make_shared(); + nsg_conf->Match(conf); + nsg_conf->MatchSearch(conf, milvus::engine::IndexType::FAISS_IVFPQ_GPU); + auto pq_conf = std::make_shared(); + pq_conf->Match(conf); +} diff --git a/core/unittest/wrapper/utils.cpp b/core/unittest/wrapper/utils.cpp index f2bb83b482..445b7a2de6 100644 --- a/core/unittest/wrapper/utils.cpp +++ b/core/unittest/wrapper/utils.cpp @@ -16,6 +16,7 @@ // under the License. +#include #include #include "wrapper/utils.h" @@ -59,3 +60,30 @@ DataGenBase::GenData(const int &dim, gt_dis.resize(nq * k); GenData(dim, nb, nq, xb.data(), xq.data(), ids.data(), k, gt_ids.data(), gt_dis.data()); } + +void +DataGenBase::AssertResult(const std::vector& ids, const std::vector& dis) { + EXPECT_EQ(ids.size(), nq * k); + EXPECT_EQ(dis.size(), nq * k); + + for (auto i = 0; i < nq; i++) { + EXPECT_EQ(ids[i * k], gt_ids[i * k]); + //EXPECT_EQ(dis[i * k], gt_dis[i * k]); + } + + int match = 0; + for (int i = 0; i < nq; ++i) { + for (int j = 0; j < k; ++j) { + for (int l = 0; l < k; ++l) { + if (ids[i * nq + j] == gt_ids[i * nq + l]) match++; + } + } + } + + auto precision = float(match) / (nq * k); + EXPECT_GT(precision, 0.5); + std::cout << std::endl << "Precision: " << precision + << ", match: " << match + << ", total: " << nq * k + << std::endl; +} diff --git a/core/unittest/wrapper/utils.h b/core/unittest/wrapper/utils.h index ff4ce9c23a..5a614543c9 100644 --- a/core/unittest/wrapper/utils.h +++ b/core/unittest/wrapper/utils.h @@ -24,24 +24,110 @@ #include #include +#include "wrapper/VecIndex.h" +#include "wrapper/utils.h" +#include "knowhere/index/vector_index/helpers/IndexParameter.h" + class DataGenBase; using DataGenPtr = std::shared_ptr; +constexpr int64_t DIM = 128; +constexpr int64_t NB = 100000; +constexpr int64_t NQ = 10; +constexpr int64_t DEVICEID = 0; +constexpr int64_t PINMEM = 1024 * 1024 * 200; +constexpr int64_t TEMPMEM = 1024 * 1024 * 300; +constexpr int64_t RESNUM = 2; + class DataGenBase { public: - virtual void GenData(const int &dim, const int &nb, const int &nq, float *xb, float *xq, int64_t *ids, - const int &k, int64_t *gt_ids, float *gt_dis); + virtual void GenData(const int& dim, const int& nb, const int& nq, float* xb, float* xq, int64_t* ids, + const int& k, int64_t* gt_ids, float* gt_dis); - virtual void GenData(const int &dim, - const int &nb, - const int &nq, - std::vector &xb, - std::vector &xq, - std::vector &ids, - const int &k, - std::vector >_ids, - std::vector >_dis); + virtual void GenData(const int& dim, + const int& nb, + const int& nq, + std::vector& xb, + std::vector& xq, + std::vector& ids, + const int& k, + std::vector& gt_ids, + std::vector& gt_dis); + + void AssertResult(const std::vector& ids, const std::vector& dis); + + int dim = DIM; + int nb = NB; + int nq = NQ; + int k = 10; + std::vector xb; + std::vector xq; + std::vector ids; + + // Ground Truth + std::vector gt_ids; + std::vector gt_dis; +}; + +class ParamGenerator { + public: + static ParamGenerator& GetInstance() { + static ParamGenerator instance; + return instance; + } + + knowhere::Config Gen(const milvus::engine::IndexType& type) { + switch (type) { + case milvus::engine::IndexType::FAISS_IDMAP: { + auto tempconf = std::make_shared(); + tempconf->metric_type = knowhere::METRICTYPE::L2; + return tempconf; + } + case milvus::engine::IndexType::FAISS_IVFFLAT_CPU: + case milvus::engine::IndexType::FAISS_IVFFLAT_GPU: + case milvus::engine::IndexType::FAISS_IVFFLAT_MIX: { + auto tempconf = std::make_shared(); + tempconf->nlist = 100; + tempconf->nprobe = 16; + tempconf->metric_type = knowhere::METRICTYPE::L2; + return tempconf; + } + case milvus::engine::IndexType::FAISS_IVFSQ8_HYBRID: + case milvus::engine::IndexType::FAISS_IVFSQ8_CPU: + case milvus::engine::IndexType::FAISS_IVFSQ8_GPU: + case milvus::engine::IndexType::FAISS_IVFSQ8_MIX: { + auto tempconf = std::make_shared(); + tempconf->nlist = 100; + tempconf->nprobe = 16; + tempconf->nbits = 8; + tempconf->metric_type = knowhere::METRICTYPE::L2; + return tempconf; + } +// case milvus::engine::IndexType::FAISS_IVFPQ_CPU: +// case milvus::engine::IndexType::FAISS_IVFPQ_GPU: { +// auto tempconf = std::make_shared(); +// tempconf->nlist = 100; +// tempconf->nprobe = 16; +// tempconf->nbits = 8; +// tempconf->m = 8; +// tempconf->metric_type = knowhere::METRICTYPE::L2; +// return tempconf; +// } +// case milvus::engine::IndexType::NSG_MIX: { +// auto tempconf = std::make_shared(); +// tempconf->nlist = 100; +// tempconf->nprobe = 16; +// tempconf->search_length = 8; +// tempconf->knng = 200; +// tempconf->search_length = 40; // TODO(linxj): be 20 when search +// tempconf->out_degree = 60; +// tempconf->candidate_pool_size = 200; +// tempconf->metric_type = knowhere::METRICTYPE::L2; +// return tempconf; +// } + } + } }; diff --git a/docker/build_env/ubuntu16.04/Dockerfile b/docker/build_env/ubuntu16.04/Dockerfile new file mode 100644 index 0000000000..a0ccecce5f --- /dev/null +++ b/docker/build_env/ubuntu16.04/Dockerfile @@ -0,0 +1,25 @@ +FROM nvidia/cuda:10.1-devel-ubuntu16.04 + +ENV NVIDIA_DRIVER_CAPABILITIES compute,utility + +RUN apt-get update && apt-get install -y --no-install-recommends wget && \ + wget -qO- "https://cmake.org/files/v3.14/cmake-3.14.3-Linux-x86_64.tar.gz" | tar --strip-components=1 -xz -C /usr/local && \ + wget -P /tmp https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS-2019.PUB && \ + apt-key add /tmp/GPG-PUB-KEY-INTEL-SW-PRODUCTS-2019.PUB && \ + sh -c 'echo deb https://apt.repos.intel.com/mkl all main > /etc/apt/sources.list.d/intel-mkl.list' && \ + apt-get update && apt-get install -y --no-install-recommends \ + git flex bison gfortran \ + curl libtool automake libboost1.58-all-dev libssl-dev pkg-config libcurl4-openssl-dev \ + clang-format-6.0 clang-tidy-6.0 \ + lcov mysql-client libmysqlclient-dev intel-mkl-gnu-2019.4-243 intel-mkl-core-2019.4-243 && \ + apt-get remove --purge -y && \ + rm -rf /var/lib/apt/lists/* + +RUN ln -s /usr/lib/x86_64-linux-gnu/libmysqlclient.so /usr/lib/x86_64-linux-gnu/libmysqlclient_r.so + +RUN sh -c 'echo export LD_LIBRARY_PATH=/opt/intel/compilers_and_libraries_2019.4.243/linux/mkl/lib/intel64:\$LD_LIBRARY_PATH > /etc/profile.d/mkl.sh' + +COPY docker-entrypoint.sh /app/docker-entrypoint.sh +ENTRYPOINT [ "/app/docker-entrypoint.sh" ] +CMD [ "start" ] + diff --git a/docker/build_env/ubuntu16.04/docker-entrypoint.sh b/docker/build_env/ubuntu16.04/docker-entrypoint.sh new file mode 100755 index 0000000000..1e85e7e9e1 --- /dev/null +++ b/docker/build_env/ubuntu16.04/docker-entrypoint.sh @@ -0,0 +1,10 @@ +#!/bin/bash + +set -e + +if [ "$1" = 'start' ]; then + tail -f /dev/null +fi + +exec "$@" + diff --git a/docker/build_env/ubuntu18.04/Dockerfile b/docker/build_env/ubuntu18.04/Dockerfile new file mode 100644 index 0000000000..e7c528f48e --- /dev/null +++ b/docker/build_env/ubuntu18.04/Dockerfile @@ -0,0 +1,25 @@ +FROM nvidia/cuda:10.1-devel-ubuntu18.04 + +ENV NVIDIA_DRIVER_CAPABILITIES compute,utility + +RUN apt-get update && apt-get install -y --no-install-recommends wget && \ + wget -qO- "https://cmake.org/files/v3.14/cmake-3.14.3-Linux-x86_64.tar.gz" | tar --strip-components=1 -xz -C /usr/local && \ + wget -P /tmp https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS-2019.PUB && \ + apt-key add /tmp/GPG-PUB-KEY-INTEL-SW-PRODUCTS-2019.PUB && \ + sh -c 'echo deb https://apt.repos.intel.com/mkl all main > /etc/apt/sources.list.d/intel-mkl.list' && \ + apt-get update && apt-get install -y --no-install-recommends \ + git flex bison gfortran \ + curl libtool automake libboost-all-dev libssl-dev pkg-config libcurl4-openssl-dev \ + clang-format-6.0 clang-tidy-6.0 \ + lcov mysql-client libmysqlclient-dev intel-mkl-gnu-2019.4-243 intel-mkl-core-2019.4-243 && \ + apt-get remove --purge -y && \ + rm -rf /var/lib/apt/lists/* + +RUN ln -s /usr/lib/x86_64-linux-gnu/libmysqlclient.so /usr/lib/x86_64-linux-gnu/libmysqlclient_r.so + +RUN sh -c 'echo export LD_LIBRARY_PATH=/opt/intel/compilers_and_libraries_2019.4.243/linux/mkl/lib/intel64:\$LD_LIBRARY_PATH > /etc/profile.d/mkl.sh' + +COPY docker-entrypoint.sh /app/docker-entrypoint.sh +ENTRYPOINT [ "/app/docker-entrypoint.sh" ] +CMD [ "start" ] + diff --git a/docker/build_env/ubuntu18.04/docker-entrypoint.sh b/docker/build_env/ubuntu18.04/docker-entrypoint.sh new file mode 100755 index 0000000000..1e85e7e9e1 --- /dev/null +++ b/docker/build_env/ubuntu18.04/docker-entrypoint.sh @@ -0,0 +1,10 @@ +#!/bin/bash + +set -e + +if [ "$1" = 'start' ]; then + tail -f /dev/null +fi + +exec "$@" + diff --git a/docker/deploy/ubuntu16.04/Dockerfile b/docker/deploy/ubuntu16.04/Dockerfile new file mode 100644 index 0000000000..c5ca0ab03e --- /dev/null +++ b/docker/deploy/ubuntu16.04/Dockerfile @@ -0,0 +1,23 @@ +FROM nvidia/cuda:10.1-devel-ubuntu16.04 + +ENV NVIDIA_DRIVER_CAPABILITIES compute,utility + +RUN rm -rf /etc/apt/sources.list.d/nvidia-ml.list && rm -rf /etc/apt/sources.list.d/cuda.list + +RUN apt-get update && apt-get install -y --no-install-recommends \ + gfortran libsqlite3-dev libmysqlclient-dev libcurl4-openssl-dev python3 && \ + apt-get remove --purge -y && \ + rm -rf /var/lib/apt/lists/* + +RUN ln -s /usr/lib/x86_64-linux-gnu/libmysqlclient.so /usr/lib/x86_64-linux-gnu/libmysqlclient_r.so + +COPY ./docker-entrypoint.sh /opt +COPY ./milvus /opt/milvus +ENV LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/opt/milvus/lib" + +ENTRYPOINT [ "/opt/docker-entrypoint.sh" ] + +CMD [ "start" ] + +EXPOSE 19530 + diff --git a/docker/deploy/ubuntu16.04/docker-entrypoint.sh b/docker/deploy/ubuntu16.04/docker-entrypoint.sh new file mode 100755 index 0000000000..446c174d74 --- /dev/null +++ b/docker/deploy/ubuntu16.04/docker-entrypoint.sh @@ -0,0 +1,10 @@ +#!/bin/bash + +set -e + +if [ "$1" == 'start' ]; then + cd /opt/milvus/scripts && ./start_server.sh +fi + +exec "$@" + diff --git a/docker/deploy/ubuntu18.04/Dockerfile b/docker/deploy/ubuntu18.04/Dockerfile new file mode 100644 index 0000000000..0d16ae46e1 --- /dev/null +++ b/docker/deploy/ubuntu18.04/Dockerfile @@ -0,0 +1,23 @@ +FROM nvidia/cuda:10.1-devel-ubuntu18.04 + +ENV NVIDIA_DRIVER_CAPABILITIES compute,utility + +RUN rm -rf /etc/apt/sources.list.d/nvidia-ml.list && rm -rf /etc/apt/sources.list.d/cuda.list + +RUN apt-get update && apt-get install -y --no-install-recommends \ + gfortran libsqlite3-dev libmysqlclient-dev libcurl4-openssl-dev python3 && \ + apt-get remove --purge -y && \ + rm -rf /var/lib/apt/lists/* + +RUN ln -s /usr/lib/x86_64-linux-gnu/libmysqlclient.so /usr/lib/x86_64-linux-gnu/libmysqlclient_r.so + +COPY ./docker-entrypoint.sh /opt +COPY ./milvus /opt/milvus +ENV LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/opt/milvus/lib" + +ENTRYPOINT [ "/opt/docker-entrypoint.sh" ] + +CMD [ "start" ] + +EXPOSE 19530 + diff --git a/docker/deploy/ubuntu18.04/docker-entrypoint.sh b/docker/deploy/ubuntu18.04/docker-entrypoint.sh new file mode 100755 index 0000000000..446c174d74 --- /dev/null +++ b/docker/deploy/ubuntu18.04/docker-entrypoint.sh @@ -0,0 +1,10 @@ +#!/bin/bash + +set -e + +if [ "$1" == 'start' ]; then + cd /opt/milvus/scripts && ./start_server.sh +fi + +exec "$@" + diff --git a/docker/docker-compose-monitor.yml b/docker/docker-compose-monitor.yml index 81e7084af9..ac417674be 100644 --- a/docker/docker-compose-monitor.yml +++ b/docker/docker-compose-monitor.yml @@ -42,15 +42,15 @@ services: milvus_server: runtime: nvidia - image: registry.zilliz.com/milvus/engine:branch-0.4.0-release + image: milvusdb/milvus:latest restart: always links: - prometheus environment: WEB_APP: host.docker.internal volumes: - - ../cpp/conf/server_config.yaml:/opt/milvus/conf/server_config.yaml - - ../cpp/conf/log_config.conf:/opt/milvus/conf/log_config.conf + - ../core/conf/server_config.yaml:/opt/milvus/conf/server_config.yaml + - ../core/conf/log_config.conf:/opt/milvus/conf/log_config.conf ports: - "8080:8080" - "19530:19530" diff --git a/tests/milvus-java-test/.gitignore b/tests/milvus-java-test/.gitignore new file mode 100644 index 0000000000..3a553813a8 --- /dev/null +++ b/tests/milvus-java-test/.gitignore @@ -0,0 +1,4 @@ +target/ +.idea/ +test-output/ +lib/* diff --git a/tests/milvus-java-test/README.md b/tests/milvus-java-test/README.md new file mode 100644 index 0000000000..eba21bb4e4 --- /dev/null +++ b/tests/milvus-java-test/README.md @@ -0,0 +1,29 @@ +# Requirements + +- jdk-1.8 +- testng + +# How to use this Test Project + +1. package and install + +```shell +mvn clean install +``` + +2. start or deploy your milvus server +3. run tests + +```shell +java -cp \"target/MilvusSDkJavaTest-1.0-SNAPSHOT.jar:lib/*\" com.MainClass -h 127.0.0.1 +``` + +4. get test report + +```shell +firefox test-output/index.html +``` + +# Contribution getting started + +Add test cases under testng framework \ No newline at end of file diff --git a/tests/milvus-java-test/bin/run.sh b/tests/milvus-java-test/bin/run.sh new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/milvus-java-test/ci/function/file_transfer.groovy b/tests/milvus-java-test/ci/function/file_transfer.groovy new file mode 100644 index 0000000000..bebae14832 --- /dev/null +++ b/tests/milvus-java-test/ci/function/file_transfer.groovy @@ -0,0 +1,10 @@ +def FileTransfer (sourceFiles, remoteDirectory, remoteIP, protocol = "ftp", makeEmptyDirs = true) { + if (protocol == "ftp") { + ftpPublisher masterNodeName: '', paramPublish: [parameterName: ''], alwaysPublishFromMaster: false, continueOnError: false, failOnError: true, publishers: [ + [configName: "${remoteIP}", transfers: [ + [asciiMode: false, cleanRemote: false, excludes: '', flatten: false, makeEmptyDirs: "${makeEmptyDirs}", noDefaultExcludes: false, patternSeparator: '[, ]+', remoteDirectory: "${remoteDirectory}", remoteDirectorySDF: false, removePrefix: '', sourceFiles: "${sourceFiles}"]], usePromotionTimestamp: true, useWorkspaceInPromotion: false, verbose: true + ] + ] + } +} +return this diff --git a/tests/milvus-java-test/ci/jenkinsfile/cleanup.groovy b/tests/milvus-java-test/ci/jenkinsfile/cleanup.groovy new file mode 100644 index 0000000000..2e9332fa6e --- /dev/null +++ b/tests/milvus-java-test/ci/jenkinsfile/cleanup.groovy @@ -0,0 +1,13 @@ +try { + def result = sh script: "helm status ${env.JOB_NAME}-${env.BUILD_NUMBER}", returnStatus: true + if (!result) { + sh "helm del --purge ${env.JOB_NAME}-${env.BUILD_NUMBER}" + } +} catch (exc) { + def result = sh script: "helm status ${env.JOB_NAME}-${env.BUILD_NUMBER}", returnStatus: true + if (!result) { + sh "helm del --purge ${env.JOB_NAME}-${env.BUILD_NUMBER}" + } + throw exc +} + diff --git a/tests/milvus-java-test/ci/jenkinsfile/deploy_server.groovy b/tests/milvus-java-test/ci/jenkinsfile/deploy_server.groovy new file mode 100644 index 0000000000..6650b7c219 --- /dev/null +++ b/tests/milvus-java-test/ci/jenkinsfile/deploy_server.groovy @@ -0,0 +1,16 @@ +try { + sh 'helm init --client-only --skip-refresh --stable-repo-url https://kubernetes.oss-cn-hangzhou.aliyuncs.com/charts' + sh 'helm repo add milvus https://registry.zilliz.com/chartrepo/milvus' + sh 'helm repo update' + dir ("milvus-helm") { + checkout([$class: 'GitSCM', branches: [[name: "${HELM_BRANCH}"]], doGenerateSubmoduleConfigurations: false, extensions: [], submoduleCfg: [], userRemoteConfigs: [[credentialsId: "${params.GIT_USER}", url: "git@192.168.1.105:megasearch/milvus-helm.git", name: 'origin', refspec: "+refs/heads/${HELM_BRANCH}:refs/remotes/origin/${HELM_BRANCH}"]]]) + dir ("milvus/milvus-gpu") { + sh "helm install --wait --timeout 300 --set engine.image.tag=${IMAGE_TAG} --set expose.type=clusterIP --name ${env.JOB_NAME}-${env.BUILD_NUMBER} -f ci/values.yaml --namespace milvus-sdk-test --version 0.3.1 ." + } + } +} catch (exc) { + echo 'Helm running failed!' + sh "helm del --purge ${env.JOB_NAME}-${env.BUILD_NUMBER}" + throw exc +} + diff --git a/tests/milvus-java-test/ci/jenkinsfile/integration_test.groovy b/tests/milvus-java-test/ci/jenkinsfile/integration_test.groovy new file mode 100644 index 0000000000..662c93bc77 --- /dev/null +++ b/tests/milvus-java-test/ci/jenkinsfile/integration_test.groovy @@ -0,0 +1,13 @@ +timeout(time: 30, unit: 'MINUTES') { + try { + dir ("milvus-java-test") { + sh "mvn clean install" + sh "java -cp \"target/MilvusSDkJavaTest-1.0-SNAPSHOT.jar:lib/*\" com.MainClass -h ${env.JOB_NAME}-${env.BUILD_NUMBER}-milvus-gpu-engine.milvus-sdk-test.svc.cluster.local" + } + + } catch (exc) { + echo 'Milvus-SDK-Java Integration Test Failed !' + throw exc + } +} + diff --git a/tests/milvus-java-test/ci/jenkinsfile/notify.groovy b/tests/milvus-java-test/ci/jenkinsfile/notify.groovy new file mode 100644 index 0000000000..0a257b8cd8 --- /dev/null +++ b/tests/milvus-java-test/ci/jenkinsfile/notify.groovy @@ -0,0 +1,15 @@ +def notify() { + if (!currentBuild.resultIsBetterOrEqualTo('SUCCESS')) { + // Send an email only if the build status has changed from green/unstable to red + emailext subject: '$DEFAULT_SUBJECT', + body: '$DEFAULT_CONTENT', + recipientProviders: [ + [$class: 'DevelopersRecipientProvider'], + [$class: 'RequesterRecipientProvider'] + ], + replyTo: '$DEFAULT_REPLYTO', + to: '$DEFAULT_RECIPIENTS' + } +} +return this + diff --git a/tests/milvus-java-test/ci/jenkinsfile/upload_unit_test_out.groovy b/tests/milvus-java-test/ci/jenkinsfile/upload_unit_test_out.groovy new file mode 100644 index 0000000000..7e106c0296 --- /dev/null +++ b/tests/milvus-java-test/ci/jenkinsfile/upload_unit_test_out.groovy @@ -0,0 +1,13 @@ +timeout(time: 5, unit: 'MINUTES') { + dir ("${PROJECT_NAME}_test") { + if (fileExists('test_out')) { + def fileTransfer = load "${env.WORKSPACE}/ci/function/file_transfer.groovy" + fileTransfer.FileTransfer("test_out/", "${PROJECT_NAME}/test/${JOB_NAME}-${BUILD_ID}", 'nas storage') + if (currentBuild.resultIsBetterOrEqualTo('SUCCESS')) { + echo "Milvus Dev Test Out Viewer \"ftp://192.168.1.126/data/${PROJECT_NAME}/test/${JOB_NAME}-${BUILD_ID}\"" + } + } else { + error("Milvus Dev Test Out directory don't exists!") + } + } +} diff --git a/tests/milvus-java-test/ci/main_jenkinsfile b/tests/milvus-java-test/ci/main_jenkinsfile new file mode 100644 index 0000000000..5df9d61ccb --- /dev/null +++ b/tests/milvus-java-test/ci/main_jenkinsfile @@ -0,0 +1,110 @@ +pipeline { + agent none + + options { + timestamps() + } + + environment { + SRC_BRANCH = "master" + IMAGE_TAG = "${params.IMAGE_TAG}-release" + HELM_BRANCH = "${params.IMAGE_TAG}" + TEST_URL = "git@192.168.1.105:Test/milvus-java-test.git" + TEST_BRANCH = "${params.IMAGE_TAG}" + } + + stages { + stage("Setup env") { + agent { + kubernetes { + label 'dev-test' + defaultContainer 'jnlp' + yaml """ + apiVersion: v1 + kind: Pod + metadata: + labels: + app: milvus + componet: test + spec: + containers: + - name: milvus-testframework-java + image: registry.zilliz.com/milvus/milvus-java-test:v0.1 + command: + - cat + tty: true + volumeMounts: + - name: kubeconf + mountPath: /root/.kube/ + readOnly: true + volumes: + - name: kubeconf + secret: + secretName: test-cluster-config + """ + } + } + + stages { + stage("Deploy Server") { + steps { + gitlabCommitStatus(name: 'Deloy Server') { + container('milvus-testframework-java') { + script { + load "${env.WORKSPACE}/milvus-java-test/ci/jenkinsfile/deploy_server.groovy" + } + } + } + } + } + stage("Integration Test") { + steps { + gitlabCommitStatus(name: 'Integration Test') { + container('milvus-testframework-java') { + script { + print "In integration test stage" + load "${env.WORKSPACE}/milvus-java-test/ci/jenkinsfile/integration_test.groovy" + } + } + } + } + } + stage ("Cleanup Env") { + steps { + gitlabCommitStatus(name: 'Cleanup Env') { + container('milvus-testframework-java') { + script { + load "${env.WORKSPACE}/milvus-java-test/ci/jenkinsfile/cleanup.groovy" + } + } + } + } + } + } + post { + always { + container('milvus-testframework-java') { + script { + load "${env.WORKSPACE}/milvus-java-test/ci/jenkinsfile/cleanup.groovy" + } + } + } + success { + script { + echo "Milvus java-sdk test success !" + } + } + aborted { + script { + echo "Milvus java-sdk test aborted !" + } + } + failure { + script { + echo "Milvus java-sdk test failed !" + } + } + } + } + } +} diff --git a/tests/milvus-java-test/ci/pod_containers/milvus-testframework.yaml b/tests/milvus-java-test/ci/pod_containers/milvus-testframework.yaml new file mode 100644 index 0000000000..1381e1454f --- /dev/null +++ b/tests/milvus-java-test/ci/pod_containers/milvus-testframework.yaml @@ -0,0 +1,13 @@ +apiVersion: v1 +kind: Pod +metadata: + labels: + app: milvus + componet: testframework-java +spec: + containers: + - name: milvus-testframework-java + image: maven:3.6.2-jdk-8 + command: + - cat + tty: true diff --git a/tests/milvus-java-test/milvus-java-test.iml b/tests/milvus-java-test/milvus-java-test.iml new file mode 100644 index 0000000000..78b2cc53b2 --- /dev/null +++ b/tests/milvus-java-test/milvus-java-test.iml @@ -0,0 +1,2 @@ + + \ No newline at end of file diff --git a/tests/milvus-java-test/pom.xml b/tests/milvus-java-test/pom.xml new file mode 100644 index 0000000000..4da715e292 --- /dev/null +++ b/tests/milvus-java-test/pom.xml @@ -0,0 +1,137 @@ + + + 4.0.0 + + milvus + MilvusSDkJavaTest + 1.0-SNAPSHOT + + + + + org.apache.maven.plugins + maven-dependency-plugin + + + copy-dependencies + package + + copy-dependencies + + + lib + false + true + + + + + + + + + UTF-8 + 1.23.0 + 3.9.0 + 3.9.0 + 1.8 + 1.8 + + + + + + + + + + + + + + + + + oss.sonatype.org-snapshot + http://oss.sonatype.org/content/repositories/snapshots + + false + + + true + + + + + + + org.apache.commons + commons-lang3 + 3.4 + + + + commons-cli + commons-cli + 1.3 + + + + org.testng + testng + 6.10 + + + + junit + junit + 4.9 + + + + + + + + + + io.milvus + milvus-sdk-java + 0.2.0-SNAPSHOT + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/tests/milvus-java-test/src/main/java/com/MainClass.java b/tests/milvus-java-test/src/main/java/com/MainClass.java new file mode 100644 index 0000000000..8928843f01 --- /dev/null +++ b/tests/milvus-java-test/src/main/java/com/MainClass.java @@ -0,0 +1,146 @@ +package com; + +import io.milvus.client.*; +import org.apache.commons.cli.*; +import org.apache.commons.lang3.RandomStringUtils; +import org.testng.SkipException; +import org.testng.TestNG; +import org.testng.annotations.DataProvider; +import org.testng.xml.XmlClass; +import org.testng.xml.XmlSuite; +import org.testng.xml.XmlTest; + +import java.util.ArrayList; +import java.util.List; + +public class MainClass { + private static String host = "127.0.0.1"; + private static String port = "19530"; + private int index_file_size = 50; + public int dimension = 128; + + public static void setHost(String host) { + MainClass.host = host; + } + + public static void setPort(String port) { + MainClass.port = port; + } + + @DataProvider(name="DefaultConnectArgs") + public static Object[][] defaultConnectArgs(){ + return new Object[][]{{host, port}}; + } + + @DataProvider(name="ConnectInstance") + public Object[][] connectInstance() throws ConnectFailedException { + MilvusClient client = new MilvusGrpcClient(); + ConnectParam connectParam = new ConnectParam.Builder() + .withHost(host) + .withPort(port) + .build(); + client.connect(connectParam); + String tableName = RandomStringUtils.randomAlphabetic(10); + return new Object[][]{{client, tableName}}; + } + + @DataProvider(name="DisConnectInstance") + public Object[][] disConnectInstance() throws ConnectFailedException { + // Generate connection instance + MilvusClient client = new MilvusGrpcClient(); + ConnectParam connectParam = new ConnectParam.Builder() + .withHost(host) + .withPort(port) + .build(); + client.connect(connectParam); + try { + client.disconnect(); + } catch (InterruptedException e) { + e.printStackTrace(); + } + String tableName = RandomStringUtils.randomAlphabetic(10); + return new Object[][]{{client, tableName}}; + } + + @DataProvider(name="Table") + public Object[][] provideTable() throws ConnectFailedException { + Object[][] tables = new Object[2][2]; + MetricType[] metricTypes = { MetricType.L2, MetricType.IP }; + for (int i = 0; i < metricTypes.length; ++i) { + String tableName = metricTypes[i].toString()+"_"+RandomStringUtils.randomAlphabetic(10); + // Generate connection instance + MilvusClient client = new MilvusGrpcClient(); + ConnectParam connectParam = new ConnectParam.Builder() + .withHost(host) + .withPort(port) + .build(); + client.connect(connectParam); + TableSchema tableSchema = new TableSchema.Builder(tableName, dimension) + .withIndexFileSize(index_file_size) + .withMetricType(metricTypes[i]) + .build(); + Response res = client.createTable(tableSchema); + if (!res.ok()) { + System.out.println(res.getMessage()); + throw new SkipException("Table created failed"); + } + tables[i] = new Object[]{client, tableName}; + } + return tables; + } + + public static void main(String[] args) { + CommandLineParser parser = new DefaultParser(); + Options options = new Options(); + options.addOption("h", "host", true, "milvus-server hostname/ip"); + options.addOption("p", "port", true, "milvus-server port"); + try { + CommandLine cmd = parser.parse(options, args); + String host = cmd.getOptionValue("host"); + if (host != null) { + setHost(host); + } + String port = cmd.getOptionValue("port"); + if (port != null) { + setPort(port); + } + System.out.println("Host: "+host+", Port: "+port); + } + catch(ParseException exp) { + System.err.println("Parsing failed. Reason: " + exp.getMessage() ); + } + +// TestListenerAdapter tla = new TestListenerAdapter(); +// TestNG testng = new TestNG(); +// testng.setTestClasses(new Class[] { TestPing.class }); +// testng.setTestClasses(new Class[] { TestConnect.class }); +// testng.addListener(tla); +// testng.run(); + + XmlSuite suite = new XmlSuite(); + suite.setName("TmpSuite"); + + XmlTest test = new XmlTest(suite); + test.setName("TmpTest"); + List classes = new ArrayList(); + + classes.add(new XmlClass("com.TestPing")); + classes.add(new XmlClass("com.TestAddVectors")); + classes.add(new XmlClass("com.TestConnect")); + classes.add(new XmlClass("com.TestDeleteVectors")); + classes.add(new XmlClass("com.TestIndex")); + classes.add(new XmlClass("com.TestSearchVectors")); + classes.add(new XmlClass("com.TestTable")); + classes.add(new XmlClass("com.TestTableCount")); + + test.setXmlClasses(classes) ; + + List suites = new ArrayList(); + suites.add(suite); + TestNG tng = new TestNG(); + tng.setXmlSuites(suites); + tng.run(); + + } + +} diff --git a/tests/milvus-java-test/src/main/java/com/TestAddVectors.java b/tests/milvus-java-test/src/main/java/com/TestAddVectors.java new file mode 100644 index 0000000000..215f526179 --- /dev/null +++ b/tests/milvus-java-test/src/main/java/com/TestAddVectors.java @@ -0,0 +1,150 @@ +package com; + +import io.milvus.client.InsertParam; +import io.milvus.client.InsertResponse; +import io.milvus.client.MilvusClient; +import org.testng.Assert; +import org.testng.annotations.Test; + +import java.util.*; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +public class TestAddVectors { + int dimension = 128; + + public List> gen_vectors(Integer nb) { + List> xb = new LinkedList<>(); + Random random = new Random(); + for (int i = 0; i < nb; ++i) { + LinkedList vector = new LinkedList<>(); + for (int j = 0; j < dimension; j++) { + vector.add(random.nextFloat()); + } + xb.add(vector); + } + return xb; + } + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_add_vectors_table_not_existed(MilvusClient client, String tableName) throws InterruptedException { + int nb = 10000; + List> vectors = gen_vectors(nb); + String tableNameNew = tableName + "_"; + InsertParam insertParam = new InsertParam.Builder(tableNameNew, vectors).build(); + InsertResponse res = client.insert(insertParam); + assert(!res.getResponse().ok()); + } + + @Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class) + public void test_add_vectors_without_connect(MilvusClient client, String tableName) throws InterruptedException { + int nb = 100; + List> vectors = gen_vectors(nb); + InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build(); + InsertResponse res = client.insert(insertParam); + assert(!res.getResponse().ok()); + } + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_add_vectors(MilvusClient client, String tableName) throws InterruptedException { + int nb = 10000; + List> vectors = gen_vectors(nb); + InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build(); + InsertResponse res = client.insert(insertParam); + assert(res.getResponse().ok()); + Thread.currentThread().sleep(1000); + // Assert table row count + Assert.assertEquals(client.getTableRowCount(tableName).getTableRowCount(), nb); + } + +// @Test(dataProvider = "Table", dataProviderClass = MainClass.class) +// public void test_add_vectors_timeout(MilvusClient client, String tableName) throws InterruptedException { +// int nb = 200000; +// List> vectors = gen_vectors(nb); +// System.out.println(new Date()); +// InsertParam insertParam = new InsertParam.Builder(tableName, vectors).withTimeout(1).build(); +// InsertResponse res = client.insert(insertParam); +// assert(!res.getResponse().ok()); +// } + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_add_vectors_big_data(MilvusClient client, String tableName) throws InterruptedException { + int nb = 500000; + List> vectors = gen_vectors(nb); + System.out.println(new Date()); + InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build(); + InsertResponse res = client.insert(insertParam); + assert(res.getResponse().ok()); + } + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_add_vectors_with_ids(MilvusClient client, String tableName) throws InterruptedException { + int nb = 10000; + List> vectors = gen_vectors(nb); + // Add vectors with ids + List vectorIds; + vectorIds = Stream.iterate(0L, n -> n) + .limit(nb) + .collect(Collectors.toList()); + InsertParam insertParam = new InsertParam.Builder(tableName, vectors).withVectorIds(vectorIds).build(); + InsertResponse res = client.insert(insertParam); + assert(res.getResponse().ok()); + Thread.currentThread().sleep(2000); + // Assert table row count + Assert.assertEquals(client.getTableRowCount(tableName).getTableRowCount(), nb); + } + + // TODO: MS-628 + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_add_vectors_with_invalid_ids(MilvusClient client, String tableName) { + int nb = 10; + List> vectors = gen_vectors(nb); + // Add vectors with ids + List vectorIds; + vectorIds = Stream.iterate(0L, n -> n) + .limit(nb+1) + .collect(Collectors.toList()); + InsertParam insertParam = new InsertParam.Builder(tableName, vectors).withVectorIds(vectorIds).build(); + InsertResponse res = client.insert(insertParam); + assert(!res.getResponse().ok()); + } + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_add_vectors_with_invalid_dimension(MilvusClient client, String tableName) { + int nb = 10000; + List> vectors = gen_vectors(nb); + vectors.get(0).add((float) 0); + InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build(); + InsertResponse res = client.insert(insertParam); + assert(!res.getResponse().ok()); + } + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_add_vectors_with_invalid_vectors(MilvusClient client, String tableName) { + int nb = 10000; + List> vectors = gen_vectors(nb); + vectors.set(0, new ArrayList<>()); + InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build(); + InsertResponse res = client.insert(insertParam); + assert(!res.getResponse().ok()); + } + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_add_vectors_repeatably(MilvusClient client, String tableName) throws InterruptedException { + int nb = 100000; + int loops = 10; + List> vectors = gen_vectors(nb); + InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build(); + InsertResponse res = null; + for (int i = 0; i < loops; ++i ) { + long startTime = System.currentTimeMillis(); + res = client.insert(insertParam); + long endTime = System.currentTimeMillis(); + System.out.println("Total execution time: " + (endTime-startTime) + "ms"); + } + Thread.currentThread().sleep(1000); + // Assert table row count + Assert.assertEquals(client.getTableRowCount(tableName).getTableRowCount(), nb * loops); + } + +} diff --git a/tests/milvus-java-test/src/main/java/com/TestConnect.java b/tests/milvus-java-test/src/main/java/com/TestConnect.java new file mode 100644 index 0000000000..8f6d556f8b --- /dev/null +++ b/tests/milvus-java-test/src/main/java/com/TestConnect.java @@ -0,0 +1,89 @@ +package com; + +import io.milvus.client.*; +import org.testng.Assert; +import org.testng.annotations.DataProvider; +import org.testng.annotations.Test; + +public class TestConnect { + @Test(dataProvider = "DefaultConnectArgs", dataProviderClass = MainClass.class) + public void test_connect(String host, String port) throws ConnectFailedException { + System.out.println("Host: "+host+", Port: "+port); + MilvusClient client = new MilvusGrpcClient(); + ConnectParam connectParam = new ConnectParam.Builder() + .withHost(host) + .withPort(port) + .build(); + Response res = client.connect(connectParam); + assert(res.ok()); + assert(client.isConnected()); + } + + @Test(dataProvider = "DefaultConnectArgs", dataProviderClass = MainClass.class) + public void test_connect_repeat(String host, String port) { + MilvusGrpcClient client = new MilvusGrpcClient(); + ConnectParam connectParam = new ConnectParam.Builder() + .withHost(host) + .withPort(port) + .build(); + Response res = null; + try { + res = client.connect(connectParam); + res = client.connect(connectParam); + } catch (ConnectFailedException e) { + e.printStackTrace(); + } + assert (res.ok()); + assert(client.isConnected()); + } + + @Test(dataProvider="InvalidConnectArgs") + public void test_connect_invalid_connect_args(String ip, String port) { + MilvusClient client = new MilvusGrpcClient(); + ConnectParam connectParam = new ConnectParam.Builder() + .withHost(ip) + .withPort(port) + .build(); + Response res = null; + try { + res = client.connect(connectParam); + } catch (ConnectFailedException e) { + e.printStackTrace(); + } + Assert.assertEquals(res, null); + assert(!client.isConnected()); + } + + // TODO: MS-615 + @DataProvider(name="InvalidConnectArgs") + public Object[][] generate_invalid_connect_args() { + String port = "19530"; + String ip = ""; + return new Object[][]{ + {"1.1.1.1", port}, + {"255.255.0.0", port}, + {"1.2.2", port}, + {"中文", port}, + {"www.baidu.com", "100000"}, + {"127.0.0.1", "100000"}, + {"www.baidu.com", "80"}, + }; + } + + @Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class) + public void test_disconnect(MilvusClient client, String tableName){ + assert(!client.isConnected()); + } + + @Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class) + public void test_disconnect_repeatably(MilvusClient client, String tableName){ + Response res = null; + try { + res = client.disconnect(); + } catch (InterruptedException e) { + e.printStackTrace(); + } + assert(!res.ok()); + assert(!client.isConnected()); + } +} diff --git a/tests/milvus-java-test/src/main/java/com/TestDeleteVectors.java b/tests/milvus-java-test/src/main/java/com/TestDeleteVectors.java new file mode 100644 index 0000000000..d5fde0e570 --- /dev/null +++ b/tests/milvus-java-test/src/main/java/com/TestDeleteVectors.java @@ -0,0 +1,116 @@ +package com; + +import java.util.*; + +public class TestDeleteVectors { + int index_file_size = 50; + int dimension = 128; + + public List> gen_vectors(Integer nb) { + List> xb = new LinkedList<>(); + Random random = new Random(); + for (int i = 0; i < nb; ++i) { + LinkedList vector = new LinkedList<>(); + for (int j = 0; j < dimension; j++) { + vector.add(random.nextFloat()); + } + xb.add(vector); + } + return xb; + } + + public static Date getDeltaDate(int delta) { + Date today = new Date(); + Calendar c = Calendar.getInstance(); + c.setTime(today); + c.add(Calendar.DAY_OF_MONTH, delta); + return c.getTime(); + } + +// @Test(dataProvider = "Table", dataProviderClass = MainClass.class) +// public void test_delete_vectors(MilvusClient client, String tableName) throws InterruptedException { +// int nb = 10000; +// List> vectors = gen_vectors(nb); +// // Add vectors +// InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build(); +// InsertResponse res = client.insert(insertParam); +// assert(res.getResponse().ok()); +// Thread.sleep(1000); +// DateRange dateRange = new DateRange(getDeltaDate(-1), getDeltaDate(1)); +// DeleteByRangeParam param = new DeleteByRangeParam.Builder(dateRange, tableName).build(); +// Response res_delete = client.deleteByRange(param); +// assert(res_delete.ok()); +// Thread.sleep(1000); +// // Assert table row count +// Assert.assertEquals(client.getTableRowCount(tableParam).getTableRowCount(), 0); +// } + +// @Test(dataProvider = "Table", dataProviderClass = MainClass.class) +// public void test_delete_vectors_table_not_existed(MilvusClient client, String tableName) throws InterruptedException { +// String tableNameNew = tableName + "_"; +// DateRange dateRange = new DateRange(getDeltaDate(-1), getDeltaDate(1)); +// DeleteByRangeParam param = new DeleteByRangeParam.Builder(dateRange, tableNameNew).build(); +// Response res_delete = client.deleteByRange(param); +// assert(!res_delete.ok()); +// } + +// @Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class) +// public void test_delete_vectors_without_connect(MilvusClient client, String tableName) throws InterruptedException { +// DateRange dateRange = new DateRange(getDeltaDate(-1), getDeltaDate(1)); +// DeleteByRangeParam param = new DeleteByRangeParam.Builder(dateRange, tableName).build(); +// Response res_delete = client.deleteByRange(param); +// assert(!res_delete.ok()); +// } +// +// @Test(dataProvider = "Table", dataProviderClass = MainClass.class) +// public void test_delete_vectors_table_empty(MilvusClient client, String tableName) throws InterruptedException { +// DateRange dateRange = new DateRange(getDeltaDate(-1), getDeltaDate(1)); +// DeleteByRangeParam param = new DeleteByRangeParam.Builder(dateRange, tableName).build(); +// Response res_delete = client.deleteByRange(param); +// assert(res_delete.ok()); +// } + +// @Test(dataProvider = "Table", dataProviderClass = MainClass.class) +// public void test_delete_vectors_invalid_date_range(MilvusClient client, String tableName) throws InterruptedException { +// int nb = 100; +// List> vectors = gen_vectors(nb); +// // Add vectors +// InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build(); +// InsertResponse res = client.insert(insertParam); +// assert(res.getResponse().ok()); +// Thread.sleep(1000); +// DateRange dateRange = new DateRange(getDeltaDate(1), getDeltaDate(0)); +// DeleteByRangeParam param = new DeleteByRangeParam.Builder(dateRange, tableName).build(); +// Response res_delete = client.deleteByRange(param); +// assert(!res_delete.ok()); +// } + +// @Test(dataProvider = "Table", dataProviderClass = MainClass.class) +// public void test_delete_vectors_invalid_date_range_1(MilvusClient client, String tableName) throws InterruptedException { +// int nb = 100; +// List> vectors = gen_vectors(nb); +// InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build(); +// InsertResponse res = client.insert(insertParam); +// assert(res.getResponse().ok()); +// DateRange dateRange = new DateRange(getDeltaDate(2), getDeltaDate(-1)); +// DeleteByRangeParam param = new DeleteByRangeParam.Builder(dateRange, tableName).build(); +// Response res_delete = client.deleteByRange(param); +// assert(!res_delete.ok()); +// } + +// @Test(dataProvider = "Table", dataProviderClass = MainClass.class) +// public void test_delete_vectors_no_result(MilvusClient client, String tableName) throws InterruptedException { +// int nb = 100; +// List> vectors = gen_vectors(nb); +// InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build(); +// InsertResponse res = client.insert(insertParam); +// assert(res.getResponse().ok()); +// Thread.sleep(1000); +// DateRange dateRange = new DateRange(getDeltaDate(-3), getDeltaDate(-2)); +// DeleteByRangeParam param = new DeleteByRangeParam.Builder(dateRange, tableName).build(); +// Response res_delete = client.deleteByRange(param); +// assert(res_delete.ok()); +// Assert.assertEquals(client.getTableRowCount(tableParam).getTableRowCount(), nb); +// } + +} diff --git a/tests/milvus-java-test/src/main/java/com/TestIndex.java b/tests/milvus-java-test/src/main/java/com/TestIndex.java new file mode 100644 index 0000000000..eaf0c8dc10 --- /dev/null +++ b/tests/milvus-java-test/src/main/java/com/TestIndex.java @@ -0,0 +1,324 @@ +package com; + +import io.milvus.client.*; +import org.testng.Assert; +import org.testng.annotations.Test; + +import java.util.LinkedList; +import java.util.List; +import java.util.Random; + +public class TestIndex { + int index_file_size = 10; + int dimension = 128; + int n_list = 1024; + int default_n_list = 16384; + int nb = 100000; + IndexType indexType = IndexType.IVF_SQ8; + IndexType defaultIndexType = IndexType.FLAT; + + public List> gen_vectors(Integer nb) { + List> xb = new LinkedList<>(); + Random random = new Random(); + for (int i = 0; i < nb; ++i) { + LinkedList vector = new LinkedList<>(); + for (int j = 0; j < dimension; j++) { + vector.add(random.nextFloat()); + } + xb.add(vector); + } + return xb; + } + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_create_index(MilvusClient client, String tableName) throws InterruptedException { + List> vectors = gen_vectors(nb); + InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build(); + client.insert(insertParam); + Index index = new Index.Builder().withIndexType(indexType) + .withNList(n_list) + .build(); + CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build(); + Response res_create = client.createIndex(createIndexParam); + assert(res_create.ok()); + } + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_create_index_repeatably(MilvusClient client, String tableName) throws InterruptedException { + List> vectors = gen_vectors(nb); + InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build(); + client.insert(insertParam); + Index index = new Index.Builder().withIndexType(indexType) + .withNList(n_list) + .build(); + CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build(); + Response res_create = client.createIndex(createIndexParam); + res_create = client.createIndex(createIndexParam); + assert(res_create.ok()); + DescribeIndexResponse res = client.describeIndex(tableName); + assert(res.getResponse().ok()); + Index index1 = res.getIndex().get(); + Assert.assertEquals(index1.getNList(), n_list); + Assert.assertEquals(index1.getIndexType(), indexType); + } + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_create_index_FLAT(MilvusClient client, String tableName) throws InterruptedException { + IndexType indexType = IndexType.FLAT; + List> vectors = gen_vectors(nb); + InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build(); + client.insert(insertParam); + Index index = new Index.Builder().withIndexType(indexType) + .withNList(n_list) + .build(); + CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build(); + Response res_create = client.createIndex(createIndexParam); + assert(res_create.ok()); + DescribeIndexResponse res = client.describeIndex(tableName); + assert(res.getResponse().ok()); + Index index1 = res.getIndex().get(); + Assert.assertEquals(index1.getIndexType(), indexType); + } + +// @Test(dataProvider = "Table", dataProviderClass = MainClass.class) +// public void test_create_index_FLAT_timeout(MilvusClient client, String tableName) throws InterruptedException { +// int nb = 500000; +// IndexType indexType = IndexType.IVF_SQ8; +// List> vectors = gen_vectors(nb); +// InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build(); +// client.insert(insertParam); +// Index index = new Index.Builder().withIndexType(indexType) +// .withNList(n_list) +// .build(); +// System.out.println(new Date()); +// CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).withTimeout(1).build(); +// Response res_create = client.createIndex(createIndexParam); +// assert(!res_create.ok()); +// } + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_create_index_IVFLAT(MilvusClient client, String tableName) throws InterruptedException { + IndexType indexType = IndexType.IVFLAT; + List> vectors = gen_vectors(nb); + InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build(); + client.insert(insertParam); + Index index = new Index.Builder().withIndexType(indexType) + .withNList(n_list) + .build(); + CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build(); + Response res_create = client.createIndex(createIndexParam); + assert(res_create.ok()); + DescribeIndexResponse res = client.describeIndex(tableName); + assert(res.getResponse().ok()); + Index index1 = res.getIndex().get(); + Assert.assertEquals(index1.getIndexType(), indexType); + } + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_create_index_IVFSQ8(MilvusClient client, String tableName) throws InterruptedException { + IndexType indexType = IndexType.IVF_SQ8; + List> vectors = gen_vectors(nb); + InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build(); + client.insert(insertParam); + Index index = new Index.Builder().withIndexType(indexType) + .withNList(n_list) + .build(); + CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build(); + Response res_create = client.createIndex(createIndexParam); + assert(res_create.ok()); + DescribeIndexResponse res = client.describeIndex(tableName); + assert(res.getResponse().ok()); + Index index1 = res.getIndex().get(); + Assert.assertEquals(index1.getIndexType(), indexType); + } + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_create_index_IVFSQ8H(MilvusClient client, String tableName) throws InterruptedException { + IndexType indexType = IndexType.IVF_SQ8H; + List> vectors = gen_vectors(nb); + InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build(); + client.insert(insertParam); + Index index = new Index.Builder().withIndexType(indexType) + .withNList(n_list) + .build(); + CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build(); + Response res_create = client.createIndex(createIndexParam); + assert(res_create.ok()); + DescribeIndexResponse res = client.describeIndex(tableName); + assert(res.getResponse().ok()); + Index index1 = res.getIndex().get(); + Assert.assertEquals(index1.getIndexType(), indexType); + } + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_create_index_with_no_vector(MilvusClient client, String tableName) { + Index index = new Index.Builder().withIndexType(indexType) + .withNList(n_list) + .build(); + CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build(); + Response res_create = client.createIndex(createIndexParam); + assert(res_create.ok()); + } + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_create_index_table_not_existed(MilvusClient client, String tableName) throws InterruptedException { + String tableNameNew = tableName + "_"; + Index index = new Index.Builder().withIndexType(indexType) + .withNList(n_list) + .build(); + CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableNameNew).withIndex(index).build(); + Response res_create = client.createIndex(createIndexParam); + assert(!res_create.ok()); + } + + @Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class) + public void test_create_index_without_connect(MilvusClient client, String tableName) throws InterruptedException { + Index index = new Index.Builder().withIndexType(indexType) + .withNList(n_list) + .build(); + CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build(); + Response res_create = client.createIndex(createIndexParam); + assert(!res_create.ok()); + } + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_create_index_invalid_n_list(MilvusClient client, String tableName) throws InterruptedException { + int n_list = 0; + List> vectors = gen_vectors(nb); + InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build(); + client.insert(insertParam); + Index index = new Index.Builder().withIndexType(indexType) + .withNList(n_list) + .build(); + CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build(); + Response res_create = client.createIndex(createIndexParam); + assert(!res_create.ok()); + } + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_describe_index(MilvusClient client, String tableName) throws InterruptedException { + List> vectors = gen_vectors(nb); + InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build(); + client.insert(insertParam); + Index index = new Index.Builder().withIndexType(indexType) + .withNList(n_list) + .build(); + CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build(); + Response res_create = client.createIndex(createIndexParam); + assert(res_create.ok()); + DescribeIndexResponse res = client.describeIndex(tableName); + assert(res.getResponse().ok()); + Index index1 = res.getIndex().get(); + Assert.assertEquals(index1.getNList(), n_list); + Assert.assertEquals(index1.getIndexType(), indexType); + } + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_alter_index(MilvusClient client, String tableName) throws InterruptedException { + List> vectors = gen_vectors(nb); + InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build(); + client.insert(insertParam); + Index index = new Index.Builder().withIndexType(indexType) + .withNList(n_list) + .build(); + CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build(); + Response res_create = client.createIndex(createIndexParam); + assert(res_create.ok()); + // Create another index + IndexType indexTypeNew = IndexType.IVFLAT; + int n_list_new = n_list + 1; + Index index_new = new Index.Builder().withIndexType(indexTypeNew) + .withNList(n_list_new) + .build(); + CreateIndexParam createIndexParamNew = new CreateIndexParam.Builder(tableName).withIndex(index_new).build(); + Response res_create_new = client.createIndex(createIndexParamNew); + assert(res_create_new.ok()); + DescribeIndexResponse res = client.describeIndex(tableName); + assert(res_create.ok()); + Index index1 = res.getIndex().get(); + Assert.assertEquals(index1.getNList(), n_list_new); + Assert.assertEquals(index1.getIndexType(), indexTypeNew); + } + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_describe_index_table_not_existed(MilvusClient client, String tableName) throws InterruptedException { + String tableNameNew = tableName + "_"; + DescribeIndexResponse res = client.describeIndex(tableNameNew); + assert(!res.getResponse().ok()); + } + + @Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class) + public void test_describe_index_without_connect(MilvusClient client, String tableName) throws InterruptedException { + DescribeIndexResponse res = client.describeIndex(tableName); + assert(!res.getResponse().ok()); + } + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_drop_index(MilvusClient client, String tableName) throws InterruptedException { + List> vectors = gen_vectors(nb); + InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build(); + client.insert(insertParam); + Index index = new Index.Builder().withIndexType(defaultIndexType) + .withNList(n_list) + .build(); + CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build(); + Response res_create = client.createIndex(createIndexParam); + assert(res_create.ok()); + Response res_drop = client.dropIndex(tableName); + assert(res_drop.ok()); + DescribeIndexResponse res = client.describeIndex(tableName); + assert(res.getResponse().ok()); + Index index1 = res.getIndex().get(); + Assert.assertEquals(index1.getNList(), default_n_list); + Assert.assertEquals(index1.getIndexType(), defaultIndexType); + } + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_drop_index_repeatably(MilvusClient client, String tableName) throws InterruptedException { + List> vectors = gen_vectors(nb); + InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build(); + client.insert(insertParam); + Index index = new Index.Builder().withIndexType(defaultIndexType) + .withNList(n_list) + .build(); + CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build(); + Response res_create = client.createIndex(createIndexParam); + assert(res_create.ok()); + Response res_drop = client.dropIndex(tableName); + res_drop = client.dropIndex(tableName); + assert(res_drop.ok()); + DescribeIndexResponse res = client.describeIndex(tableName); + assert(res.getResponse().ok()); + Index index1 = res.getIndex().get(); + Assert.assertEquals(index1.getNList(), default_n_list); + Assert.assertEquals(index1.getIndexType(), defaultIndexType); + } + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_drop_index_table_not_existed(MilvusClient client, String tableName) throws InterruptedException { + String tableNameNew = tableName + "_"; + Response res_drop = client.dropIndex(tableNameNew); + assert(!res_drop.ok()); + } + + @Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class) + public void test_drop_index_without_connect(MilvusClient client, String tableName) throws InterruptedException { + Response res_drop = client.dropIndex(tableName); + assert(!res_drop.ok()); + } + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_drop_index_no_index_created(MilvusClient client, String tableName) throws InterruptedException { + List> vectors = gen_vectors(nb); + InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build(); + client.insert(insertParam); + Response res_drop = client.dropIndex(tableName); + assert(res_drop.ok()); + DescribeIndexResponse res = client.describeIndex(tableName); + assert(res.getResponse().ok()); + Index index1 = res.getIndex().get(); + Assert.assertEquals(index1.getNList(), default_n_list); + Assert.assertEquals(index1.getIndexType(), defaultIndexType); + } + +} diff --git a/tests/milvus-java-test/src/main/java/com/TestMix.java b/tests/milvus-java-test/src/main/java/com/TestMix.java new file mode 100644 index 0000000000..7c33da7094 --- /dev/null +++ b/tests/milvus-java-test/src/main/java/com/TestMix.java @@ -0,0 +1,225 @@ +package com; + +import io.milvus.client.*; +import org.apache.commons.lang3.RandomStringUtils; +import org.testng.Assert; +import org.testng.annotations.Test; + +import java.util.ArrayList; +import java.util.List; +import java.util.Random; +import java.util.concurrent.ForkJoinPool; +import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; + +public class TestMix { + private int dimension = 128; + private int nb = 100000; + int n_list = 8192; + int n_probe = 20; + int top_k = 10; + double epsilon = 0.001; + int index_file_size = 20; + + public List normalize(List w2v){ + float squareSum = w2v.stream().map(x -> x * x).reduce((float) 0, Float::sum); + final float norm = (float) Math.sqrt(squareSum); + w2v = w2v.stream().map(x -> x / norm).collect(Collectors.toList()); + return w2v; + } + + public List> gen_vectors(int nb, boolean norm) { + List> xb = new ArrayList<>(); + Random random = new Random(); + for (int i = 0; i < nb; ++i) { + List vector = new ArrayList<>(); + for (int j = 0; j < dimension; j++) { + vector.add(random.nextFloat()); + } + if (norm == true) { + vector = normalize(vector); + } + xb.add(vector); + } + return xb; + } + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_search_vectors_threads(MilvusClient client, String tableName) throws InterruptedException { + int thread_num = 10; + int nq = 5; + List> vectors = gen_vectors(nb, false); + List> queryVectors = vectors.subList(0,nq); + InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build(); + client.insert(insertParam); + Index index = new Index.Builder().withIndexType(IndexType.IVF_SQ8) + .withNList(n_list) + .build(); + CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build(); + client.createIndex(createIndexParam); + ForkJoinPool executor = new ForkJoinPool(); + for (int i = 0; i < thread_num; i++) { + executor.execute( + () -> { + SearchParam searchParam = new SearchParam.Builder(tableName, queryVectors).withNProbe(n_probe).withTopK(top_k).build(); + SearchResponse res_search = client.search(searchParam); + assert (res_search.getResponse().ok()); + }); + } + executor.awaitQuiescence(100, TimeUnit.SECONDS); + executor.shutdown(); + } + + @Test(dataProvider = "DefaultConnectArgs", dataProviderClass = MainClass.class) + public void test_connect_threads(String host, String port) throws ConnectFailedException { + int thread_num = 100; + ForkJoinPool executor = new ForkJoinPool(); + for (int i = 0; i < thread_num; i++) { + executor.execute( + () -> { + MilvusClient client = new MilvusGrpcClient(); + ConnectParam connectParam = new ConnectParam.Builder() + .withHost(host) + .withPort(port) + .build(); + try { + client.connect(connectParam); + } catch (ConnectFailedException e) { + e.printStackTrace(); + } + assert(client.isConnected()); + try { + client.disconnect(); + } catch (InterruptedException e) { + e.printStackTrace(); + } + }); + } + executor.awaitQuiescence(100, TimeUnit.SECONDS); + executor.shutdown(); + } + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_add_vectors_threads(MilvusClient client, String tableName) throws InterruptedException { + int thread_num = 10; + List> vectors = gen_vectors(nb,false); + InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build(); + ForkJoinPool executor = new ForkJoinPool(); + for (int i = 0; i < thread_num; i++) { + executor.execute( + () -> { + InsertResponse res_insert = client.insert(insertParam); + assert (res_insert.getResponse().ok()); + }); + } + executor.awaitQuiescence(100, TimeUnit.SECONDS); + executor.shutdown(); + + Thread.sleep(2000); + GetTableRowCountResponse getTableRowCountResponse = client.getTableRowCount(tableName); + Assert.assertEquals(getTableRowCountResponse.getTableRowCount(), thread_num * nb); + } + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_add_index_vectors_threads(MilvusClient client, String tableName) throws InterruptedException { + int thread_num = 50; + List> vectors = gen_vectors(nb,false); + InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build(); + ForkJoinPool executor = new ForkJoinPool(); + for (int i = 0; i < thread_num; i++) { + executor.execute( + () -> { + InsertResponse res_insert = client.insert(insertParam); + Index index = new Index.Builder().withIndexType(IndexType.IVF_SQ8) + .withNList(n_list) + .build(); + CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build(); + client.createIndex(createIndexParam); + assert (res_insert.getResponse().ok()); + }); + } + executor.awaitQuiescence(300, TimeUnit.SECONDS); + executor.shutdown(); + Thread.sleep(2000); + GetTableRowCountResponse getTableRowCountResponse = client.getTableRowCount(tableName); + Assert.assertEquals(getTableRowCountResponse.getTableRowCount(), thread_num * nb); + } + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_add_search_vectors_threads(MilvusClient client, String tableName) throws InterruptedException { + int thread_num = 50; + int nq = 5; + List> vectors = gen_vectors(nb, true); + List> queryVectors = vectors.subList(0,nq); + InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build(); + ForkJoinPool executor = new ForkJoinPool(); + for (int i = 0; i < thread_num; i++) { + executor.execute( + () -> { + InsertResponse res_insert = client.insert(insertParam); + assert (res_insert.getResponse().ok()); + try { + TimeUnit.SECONDS.sleep(1); + } catch (InterruptedException e) { + e.printStackTrace(); + } + SearchParam searchParam = new SearchParam.Builder(tableName, queryVectors).withNProbe(n_probe).withTopK(top_k).build(); + SearchResponse res_search = client.search(searchParam); + assert (res_search.getResponse().ok()); + List> res = client.search(searchParam).getQueryResultsList(); + double distance = res.get(0).get(0).getDistance(); + if (tableName.startsWith("L2")) { + Assert.assertEquals(distance, 0.0, epsilon); + }else if (tableName.startsWith("IP")) { + Assert.assertEquals(distance, 1.0, epsilon); + } + }); + } + executor.awaitQuiescence(300, TimeUnit.SECONDS); + executor.shutdown(); + Thread.sleep(2000); + GetTableRowCountResponse getTableRowCountResponse = client.getTableRowCount(tableName); + Assert.assertEquals(getTableRowCountResponse.getTableRowCount(), thread_num * nb); + } + + @Test(dataProvider = "DefaultConnectArgs", dataProviderClass = MainClass.class) + public void test_create_insert_delete_threads(String host, String port) { + int thread_num = 100; + List> vectors = gen_vectors(nb,false); + ForkJoinPool executor = new ForkJoinPool(); + for (int i = 0; i < thread_num; i++) { + executor.execute( + () -> { + MilvusClient client = new MilvusGrpcClient(); + ConnectParam connectParam = new ConnectParam.Builder() + .withHost(host) + .withPort(port) + .build(); + try { + client.connect(connectParam); + } catch (ConnectFailedException e) { + e.printStackTrace(); + } + assert(client.isConnected()); + String tableName = RandomStringUtils.randomAlphabetic(10); + TableSchema tableSchema = new TableSchema.Builder(tableName, dimension) + .withIndexFileSize(index_file_size) + .withMetricType(MetricType.IP) + .build(); + client.createTable(tableSchema); + InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build(); + client.insert(insertParam); + Response response = client.dropTable(tableName); + Assert.assertTrue(response.ok()); + try { + client.disconnect(); + } catch (InterruptedException e) { + e.printStackTrace(); + } + }); + } + executor.awaitQuiescence(100, TimeUnit.SECONDS); + executor.shutdown(); + } + +} diff --git a/tests/milvus-java-test/src/main/java/com/TestPing.java b/tests/milvus-java-test/src/main/java/com/TestPing.java new file mode 100644 index 0000000000..1ed462e47d --- /dev/null +++ b/tests/milvus-java-test/src/main/java/com/TestPing.java @@ -0,0 +1,25 @@ +package com; + +import io.milvus.client.*; +import org.testng.annotations.Test; + +public class TestPing { + @Test(dataProvider = "DefaultConnectArgs", dataProviderClass = MainClass.class) + public void test_server_status(String host, String port) throws ConnectFailedException { + System.out.println("Host: "+host+", Port: "+port); + MilvusClient client = new MilvusGrpcClient(); + ConnectParam connectParam = new ConnectParam.Builder() + .withHost(host) + .withPort(port) + .build(); + client.connect(connectParam); + Response res = client.getServerStatus(); + assert (res.ok()); + } + + @Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class) + public void test_server_status_without_connected(MilvusGrpcClient client, String tableName){ + Response res = client.getServerStatus(); + assert (!res.ok()); + } +} \ No newline at end of file diff --git a/tests/milvus-java-test/src/main/java/com/TestSearchVectors.java b/tests/milvus-java-test/src/main/java/com/TestSearchVectors.java new file mode 100644 index 0000000000..de69a1c065 --- /dev/null +++ b/tests/milvus-java-test/src/main/java/com/TestSearchVectors.java @@ -0,0 +1,470 @@ +package com; + +import io.milvus.client.*; +import org.testng.Assert; +import org.testng.annotations.Test; + +import java.util.*; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +public class TestSearchVectors { + int index_file_size = 10; + int dimension = 128; + int n_list = 1024; + int default_n_list = 16384; + int nb = 100000; + int n_probe = 20; + int top_k = 10; + double epsilon = 0.001; + IndexType indexType = IndexType.IVF_SQ8; + IndexType defaultIndexType = IndexType.FLAT; + + + public List normalize(List w2v){ + float squareSum = w2v.stream().map(x -> x * x).reduce((float) 0, Float::sum); + final float norm = (float) Math.sqrt(squareSum); + w2v = w2v.stream().map(x -> x / norm).collect(Collectors.toList()); + return w2v; + } + + public List> gen_vectors(int nb, boolean norm) { + List> xb = new ArrayList<>(); + Random random = new Random(); + for (int i = 0; i < nb; ++i) { + List vector = new ArrayList<>(); + for (int j = 0; j < dimension; j++) { + vector.add(random.nextFloat()); + } + if (norm == true) { + vector = normalize(vector); + } + xb.add(vector); + } + return xb; + } + + public static Date getDeltaDate(int delta) { + Date today = new Date(); + Calendar c = Calendar.getInstance(); + c.setTime(today); + c.add(Calendar.DAY_OF_MONTH, delta); + return c.getTime(); + } + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_search_table_not_existed(MilvusClient client, String tableName) throws InterruptedException { + String tableNameNew = tableName + "_"; + int nq = 5; + int nb = 100; + List> vectors = gen_vectors(nb, false); + List> queryVectors = vectors.subList(0,nq); + SearchParam searchParam = new SearchParam.Builder(tableNameNew, queryVectors).withNProbe(n_probe).withTopK(top_k).build(); + SearchResponse res_search = client.search(searchParam); + assert (!res_search.getResponse().ok()); + } + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_search_index_IVFLAT(MilvusClient client, String tableName) throws InterruptedException { + IndexType indexType = IndexType.IVFLAT; + int nq = 5; + List> vectors = gen_vectors(nb, false); + List> queryVectors = vectors.subList(0,nq); + InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build(); + client.insert(insertParam); + Index index = new Index.Builder().withIndexType(indexType) + .withNList(n_list) + .build(); + CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build(); + client.createIndex(createIndexParam); + SearchParam searchParam = new SearchParam.Builder(tableName, queryVectors).withNProbe(n_probe).withTopK(top_k).build(); + List> res_search = client.search(searchParam).getQueryResultsList(); + Assert.assertEquals(res_search.size(), nq); + Assert.assertEquals(res_search.get(0).size(), top_k); + } + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_search_ids_IVFLAT(MilvusClient client, String tableName) throws InterruptedException { + IndexType indexType = IndexType.IVFLAT; + int nq = 5; + List> vectors = gen_vectors(nb, true); + List> queryVectors = vectors.subList(0,nq); + List vectorIds; + vectorIds = Stream.iterate(0L, n -> n) + .limit(nb) + .collect(Collectors.toList()); + InsertParam insertParam = new InsertParam.Builder(tableName, vectors).withVectorIds(vectorIds).build(); + client.insert(insertParam); + Index index = new Index.Builder().withIndexType(indexType) + .withNList(n_list) + .build(); + CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build(); + client.createIndex(createIndexParam); + SearchParam searchParam = new SearchParam.Builder(tableName, queryVectors).withNProbe(n_probe).withTopK(top_k).build(); + List> res_search = client.search(searchParam).getQueryResultsList(); + Assert.assertEquals(res_search.get(0).get(0).getVectorId(), 0L); + } + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_search_IVFLAT(MilvusClient client, String tableName) throws InterruptedException { + IndexType indexType = IndexType.IVFLAT; + int nq = 5; + List> vectors = gen_vectors(nb, false); + List> queryVectors = vectors.subList(0,nq); + InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build(); + client.insert(insertParam); + Thread.sleep(2000); + SearchParam searchParam = new SearchParam.Builder(tableName, queryVectors).withNProbe(n_probe).withTopK(top_k).build(); + List> res_search = client.search(searchParam).getQueryResultsList(); + Assert.assertEquals(res_search.size(), nq); + Assert.assertEquals(res_search.get(0).size(), top_k); + } + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_search_distance_IVFLAT(MilvusClient client, String tableName) throws InterruptedException { + IndexType indexType = IndexType.IVFLAT; + int nq = 5; + List> vectors = gen_vectors(nb, true); + List> queryVectors = vectors.subList(0,nq); + InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build(); + client.insert(insertParam); + Index index = new Index.Builder().withIndexType(indexType) + .withNList(n_list) + .build(); + CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build(); + client.createIndex(createIndexParam); + SearchParam searchParam = new SearchParam.Builder(tableName, queryVectors).withNProbe(n_probe).withTopK(top_k).build(); + List> res_search = client.search(searchParam).getQueryResultsList(); + double distance = res_search.get(0).get(0).getDistance(); + if (tableName.startsWith("L2")) { + Assert.assertEquals(distance, 0.0, epsilon); + }else if (tableName.startsWith("IP")) { + Assert.assertEquals(distance, 1.0, epsilon); + } + } + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_search_index_IVFSQ8(MilvusClient client, String tableName) throws InterruptedException { + IndexType indexType = IndexType.IVF_SQ8; + int nq = 5; + List> vectors = gen_vectors(nb, false); + List> queryVectors = vectors.subList(0,nq); + InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build(); + client.insert(insertParam); + Index index = new Index.Builder().withIndexType(indexType) + .withNList(n_list) + .build(); + CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build(); + client.createIndex(createIndexParam); + SearchParam searchParam = new SearchParam.Builder(tableName, queryVectors).withNProbe(n_probe).withTopK(top_k).build(); + List> res_search = client.search(searchParam).getQueryResultsList(); + Assert.assertEquals(res_search.size(), nq); + Assert.assertEquals(res_search.get(0).size(), top_k); + } + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_search_IVFSQ8(MilvusClient client, String tableName) throws InterruptedException { + IndexType indexType = IndexType.IVF_SQ8; + int nq = 5; + List> vectors = gen_vectors(nb, false); + List> queryVectors = vectors.subList(0,nq); + InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build(); + client.insert(insertParam); + Thread.sleep(1000); + SearchParam searchParam = new SearchParam.Builder(tableName, queryVectors).withNProbe(n_probe).withTopK(top_k).build(); + List> res_search = client.search(searchParam).getQueryResultsList(); + Assert.assertEquals(res_search.size(), nq); + Assert.assertEquals(res_search.get(0).size(), top_k); + } + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_search_distance_IVFSQ8(MilvusClient client, String tableName) throws InterruptedException { + IndexType indexType = IndexType.IVF_SQ8; + int nq = 5; + int nb = 1000; + List> vectors = gen_vectors(nb, true); + List> queryVectors = vectors.subList(0,nq); + InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build(); + client.insert(insertParam); + Index index = new Index.Builder().withIndexType(indexType) + .withNList(default_n_list) + .build(); + CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build(); + client.createIndex(createIndexParam); + SearchParam searchParam = new SearchParam.Builder(tableName, queryVectors).withNProbe(n_probe).withTopK(top_k).build(); + List> res_search = client.search(searchParam).getResultDistancesList(); + for (int i = 0; i < nq; i++) { + double distance = res_search.get(i).get(0); + System.out.println(distance); + if (tableName.startsWith("L2")) { + Assert.assertEquals(distance, 0.0, epsilon); + }else if (tableName.startsWith("IP")) { + Assert.assertEquals(distance, 1.0, epsilon); + } + } + } + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_search_index_FLAT(MilvusClient client, String tableName) throws InterruptedException { + IndexType indexType = IndexType.FLAT; + int nq = 5; + List> vectors = gen_vectors(nb, false); + List> queryVectors = vectors.subList(0,nq); + InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build(); + client.insert(insertParam); + Index index = new Index.Builder().withIndexType(indexType) + .withNList(n_list) + .build(); + CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build(); + client.createIndex(createIndexParam); + SearchParam searchParam = new SearchParam.Builder(tableName, queryVectors).withNProbe(n_probe).withTopK(top_k).build(); + List> res_search = client.search(searchParam).getQueryResultsList(); + Assert.assertEquals(res_search.size(), nq); + Assert.assertEquals(res_search.get(0).size(), top_k); + } + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_search_FLAT(MilvusClient client, String tableName) throws InterruptedException { + IndexType indexType = IndexType.FLAT; + int nq = 5; + List> vectors = gen_vectors(nb, false); + List> queryVectors = vectors.subList(0,nq); + InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build(); + client.insert(insertParam); + Thread.sleep(1000); + SearchParam searchParam = new SearchParam.Builder(tableName, queryVectors).withNProbe(n_probe).withTopK(top_k).build(); + List> res_search = client.search(searchParam).getQueryResultsList(); + Assert.assertEquals(res_search.size(), nq); + Assert.assertEquals(res_search.get(0).size(), top_k); + } + +// @Test(dataProvider = "Table", dataProviderClass = MainClass.class) +// public void test_search_FLAT_timeout(MilvusClient client, String tableName) throws InterruptedException { +// IndexType indexType = IndexType.FLAT; +// int nb = 100000; +// int nq = 1000; +// int top_k = 2048; +// List> vectors = gen_vectors(nb, false); +// List> vectors = gen_vectors(nb, false); +// List> queryVectors = vectors.subList(0,nq); +// InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build(); +// client.insert(insertParam); +// Thread.sleep(1000); +// SearchParam searchParam = new SearchParam.Builder(tableName, queryVectors).withNProbe(n_probe).withTopK(top_k).withTimeout(1).build(); +// System.out.println(new Date()); +// SearchResponse res_search = client.search(searchParam); +// assert (!res_search.getResponse().ok()); +// } + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_search_FLAT_big_data_size(MilvusClient client, String tableName) throws InterruptedException { + IndexType indexType = IndexType.FLAT; + int nb = 100000; + int nq = 2000; + int top_k = 2048; + List> vectors = gen_vectors(nb, false); + List> queryVectors = vectors.subList(0,nq); + InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build(); + client.insert(insertParam); + Thread.sleep(1000); + SearchParam searchParam = new SearchParam.Builder(tableName, queryVectors).withNProbe(n_probe).withTopK(top_k).build(); + System.out.println(new Date()); + SearchResponse res_search = client.search(searchParam); + assert (res_search.getResponse().ok()); + } + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_search_distance_FLAT(MilvusClient client, String tableName) throws InterruptedException { + IndexType indexType = IndexType.FLAT; + int nq = 5; + List> vectors = gen_vectors(nb, true); + List> queryVectors = vectors.subList(0,nq); + InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build(); + client.insert(insertParam); + Index index = new Index.Builder().withIndexType(indexType) + .withNList(n_list) + .build(); + CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build(); + client.createIndex(createIndexParam); + SearchParam searchParam = new SearchParam.Builder(tableName, queryVectors).withNProbe(n_probe).withTopK(top_k).build(); + List> res_search = client.search(searchParam).getQueryResultsList(); + double distance = res_search.get(0).get(0).getDistance(); + if (tableName.startsWith("L2")) { + Assert.assertEquals(distance, 0.0, epsilon); + }else if (tableName.startsWith("IP")) { + Assert.assertEquals(distance, 1.0, epsilon); + } + } + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_search_invalid_n_probe(MilvusClient client, String tableName) throws InterruptedException { + IndexType indexType = IndexType.IVF_SQ8; + int nq = 5; + int n_probe_new = 0; + List> vectors = gen_vectors(nb, false); + List> queryVectors = vectors.subList(0,nq); + InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build(); + client.insert(insertParam); + Index index = new Index.Builder().withIndexType(indexType) + .withNList(n_list) + .build(); + CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build(); + client.createIndex(createIndexParam); + SearchParam searchParam = new SearchParam.Builder(tableName, queryVectors).withNProbe(n_probe_new).withTopK(top_k).build(); + SearchResponse res_search = client.search(searchParam); + assert (!res_search.getResponse().ok()); + } + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_search_invalid_top_k(MilvusClient client, String tableName) throws InterruptedException { + IndexType indexType = IndexType.IVF_SQ8; + int nq = 5; + int top_k_new = 0; + List> vectors = gen_vectors(nb, false); + List> queryVectors = vectors.subList(0,nq); + InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build(); + client.insert(insertParam); + Index index = new Index.Builder().withIndexType(indexType) + .withNList(n_list) + .build(); + CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build(); + client.createIndex(createIndexParam); + SearchParam searchParam = new SearchParam.Builder(tableName, queryVectors).withNProbe(n_probe).withTopK(top_k_new).build(); + SearchResponse res_search = client.search(searchParam); + assert (!res_search.getResponse().ok()); + } + +// @Test(dataProvider = "Table", dataProviderClass = MainClass.class) +// public void test_search_invalid_query_vectors(MilvusClient client, String tableName) throws InterruptedException { +// IndexType indexType = IndexType.IVF_SQ8; +// int nq = 5; +// List> vectors = gen_vectors(nb, false); +// InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build(); +// client.insert(insertParam); +// Index index = new Index.Builder().withIndexType(indexType) +// .withNList(n_list) +// .build(); +// CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build(); +// client.createIndex(createIndexParam); +// TableParam tableParam = new TableParam.Builder(tableName).build(); +// SearchParam searchParam = new SearchParam.Builder(tableName, null).withNProbe(n_probe).withTopK(top_k).build(); +// SearchResponse res_search = client.search(searchParam); +// assert (!res_search.getResponse().ok()); +// } + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_search_index_range(MilvusClient client, String tableName) throws InterruptedException { + IndexType indexType = IndexType.IVF_SQ8; + int nq = 5; + List> vectors = gen_vectors(nb, false); + List> queryVectors = vectors.subList(0,nq); + List dateRange = new ArrayList<>(); + dateRange.add(new DateRange(getDeltaDate(-1), getDeltaDate(1))); + InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build(); + client.insert(insertParam); + Index index = new Index.Builder().withIndexType(indexType) + .withNList(n_list) + .build(); + CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build(); + client.createIndex(createIndexParam); + SearchParam searchParam = new SearchParam.Builder(tableName, queryVectors).withNProbe(n_probe).withTopK(top_k).withDateRanges(dateRange).build(); + SearchResponse res_search = client.search(searchParam); + assert (res_search.getResponse().ok()); + List> res = client.search(searchParam).getQueryResultsList(); + Assert.assertEquals(res.size(), nq); + Assert.assertEquals(res.get(0).size(), top_k); + } + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_search_range(MilvusClient client, String tableName) throws InterruptedException { + int nq = 5; + List> vectors = gen_vectors(nb, false); + List> queryVectors = vectors.subList(0,nq); + List dateRange = new ArrayList<>(); + dateRange.add(new DateRange(getDeltaDate(-1), getDeltaDate(1))); + InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build(); + client.insert(insertParam); + Thread.sleep(1000); + SearchParam searchParam = new SearchParam.Builder(tableName, queryVectors).withNProbe(n_probe).withTopK(top_k).withDateRanges(dateRange).build(); + SearchResponse res_search = client.search(searchParam); + assert (res_search.getResponse().ok()); + List> res = client.search(searchParam).getQueryResultsList(); + Assert.assertEquals(res.size(), nq); + Assert.assertEquals(res.get(0).size(), top_k); + } + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_search_index_range_no_result(MilvusClient client, String tableName) throws InterruptedException { + IndexType indexType = IndexType.IVF_SQ8; + int nq = 5; + List> vectors = gen_vectors(nb, false); + List> queryVectors = vectors.subList(0,nq); + List dateRange = new ArrayList<>(); + dateRange.add(new DateRange(getDeltaDate(-3), getDeltaDate(-1))); + InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build(); + client.insert(insertParam); + Index index = new Index.Builder().withIndexType(indexType) + .withNList(n_list) + .build(); + CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build(); + client.createIndex(createIndexParam); + SearchParam searchParam = new SearchParam.Builder(tableName, queryVectors).withNProbe(n_probe).withTopK(top_k).withDateRanges(dateRange).build(); + SearchResponse res_search = client.search(searchParam); + assert (res_search.getResponse().ok()); + List> res = client.search(searchParam).getQueryResultsList(); + Assert.assertEquals(res.size(), 0); + } + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_search_range_no_result(MilvusClient client, String tableName) throws InterruptedException { + int nq = 5; + List> vectors = gen_vectors(nb, false); + List> queryVectors = vectors.subList(0,nq); + List dateRange = new ArrayList<>(); + dateRange.add(new DateRange(getDeltaDate(-3), getDeltaDate(-1))); + InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build(); + client.insert(insertParam); + Thread.sleep(1000); + SearchParam searchParam = new SearchParam.Builder(tableName, queryVectors).withNProbe(n_probe).withTopK(top_k).withDateRanges(dateRange).build(); + SearchResponse res_search = client.search(searchParam); + assert (res_search.getResponse().ok()); + List> res = client.search(searchParam).getQueryResultsList(); + Assert.assertEquals(res.size(), 0); + } + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_search_index_range_invalid(MilvusClient client, String tableName) throws InterruptedException { + IndexType indexType = IndexType.IVF_SQ8; + int nq = 5; + List> vectors = gen_vectors(nb, false); + List> queryVectors = vectors.subList(0,nq); + List dateRange = new ArrayList<>(); + dateRange.add(new DateRange(getDeltaDate(2), getDeltaDate(-1))); + InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build(); + client.insert(insertParam); + Index index = new Index.Builder().withIndexType(indexType) + .withNList(n_list) + .build(); + CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build(); + client.createIndex(createIndexParam); + SearchParam searchParam = new SearchParam.Builder(tableName, queryVectors).withNProbe(n_probe).withTopK(top_k).withDateRanges(dateRange).build(); + SearchResponse res_search = client.search(searchParam); + assert (!res_search.getResponse().ok()); + } + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_search_range_invalid(MilvusClient client, String tableName) throws InterruptedException { + int nq = 5; + List> vectors = gen_vectors(nb, false); + List> queryVectors = vectors.subList(0,nq); + List dateRange = new ArrayList<>(); + dateRange.add(new DateRange(getDeltaDate(2), getDeltaDate(-1))); + InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build(); + client.insert(insertParam); + Thread.sleep(1000); + SearchParam searchParam = new SearchParam.Builder(tableName, queryVectors).withNProbe(n_probe).withTopK(top_k).withDateRanges(dateRange).build(); + SearchResponse res_search = client.search(searchParam); + assert (!res_search.getResponse().ok()); + } + +} diff --git a/tests/milvus-java-test/src/main/java/com/TestTable.java b/tests/milvus-java-test/src/main/java/com/TestTable.java new file mode 100644 index 0000000000..e722db23df --- /dev/null +++ b/tests/milvus-java-test/src/main/java/com/TestTable.java @@ -0,0 +1,142 @@ +package com; + + +import io.milvus.client.*; +import org.testng.Assert; +import org.testng.annotations.Test; + +import java.util.List; + +public class TestTable { + int index_file_size = 50; + int dimension = 128; + + @Test(dataProvider = "ConnectInstance", dataProviderClass = MainClass.class) + public void test_create_table(MilvusClient client, String tableName){ + TableSchema tableSchema = new TableSchema.Builder(tableName, dimension) + .withIndexFileSize(index_file_size) + .withMetricType(MetricType.L2) + .build(); + Response res = client.createTable(tableSchema); + assert(res.ok()); + Assert.assertEquals(res.ok(), true); + } + + @Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class) + public void test_create_table_disconnect(MilvusClient client, String tableName){ + TableSchema tableSchema = new TableSchema.Builder(tableName, dimension) + .withIndexFileSize(index_file_size) + .withMetricType(MetricType.L2) + .build(); + Response res = client.createTable(tableSchema); + assert(!res.ok()); + } + + @Test(dataProvider = "ConnectInstance", dataProviderClass = MainClass.class) + public void test_create_table_repeatably(MilvusClient client, String tableName){ + TableSchema tableSchema = new TableSchema.Builder(tableName, dimension) + .withIndexFileSize(index_file_size) + .withMetricType(MetricType.L2) + .build(); + Response res = client.createTable(tableSchema); + Assert.assertEquals(res.ok(), true); + Response res_new = client.createTable(tableSchema); + Assert.assertEquals(res_new.ok(), false); + } + + @Test(dataProvider = "ConnectInstance", dataProviderClass = MainClass.class) + public void test_create_table_wrong_params(MilvusClient client, String tableName){ + Integer dimension = 0; + TableSchema tableSchema = new TableSchema.Builder(tableName, dimension) + .withIndexFileSize(index_file_size) + .withMetricType(MetricType.L2) + .build(); + Response res = client.createTable(tableSchema); + System.out.println(res.toString()); + Assert.assertEquals(res.ok(), false); + } + + @Test(dataProvider = "ConnectInstance", dataProviderClass = MainClass.class) + public void test_show_tables(MilvusClient client, String tableName){ + Integer tableNum = 10; + ShowTablesResponse res = null; + for (int i = 0; i < tableNum; ++i) { + String tableNameNew = tableName+"_"+Integer.toString(i); + TableSchema tableSchema = new TableSchema.Builder(tableNameNew, dimension) + .withIndexFileSize(index_file_size) + .withMetricType(MetricType.L2) + .build(); + client.createTable(tableSchema); + List tableNames = client.showTables().getTableNames(); + Assert.assertTrue(tableNames.contains(tableNameNew)); + } + } + + @Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class) + public void test_show_tables_without_connect(MilvusClient client, String tableName){ + ShowTablesResponse res = client.showTables(); + assert(!res.getResponse().ok()); + } + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_drop_table(MilvusClient client, String tableName) throws InterruptedException { + Response res = client.dropTable(tableName); + assert(res.ok()); + Thread.currentThread().sleep(1000); + List tableNames = client.showTables().getTableNames(); + Assert.assertFalse(tableNames.contains(tableName)); + } + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_drop_table_not_existed(MilvusClient client, String tableName) throws InterruptedException { + Response res = client.dropTable(tableName+"_"); + assert(!res.ok()); + List tableNames = client.showTables().getTableNames(); + Assert.assertTrue(tableNames.contains(tableName)); + } + + @Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class) + public void test_drop_table_without_connect(MilvusClient client, String tableName) throws InterruptedException { + Response res = client.dropTable(tableName); + assert(!res.ok()); + } + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_describe_table(MilvusClient client, String tableName) throws InterruptedException { + DescribeTableResponse res = client.describeTable(tableName); + assert(res.getResponse().ok()); + TableSchema tableSchema = res.getTableSchema().get(); + Assert.assertEquals(tableSchema.getDimension(), dimension); + Assert.assertEquals(tableSchema.getTableName(), tableName); + Assert.assertEquals(tableSchema.getIndexFileSize(), index_file_size); + Assert.assertEquals(tableSchema.getMetricType().name(), tableName.substring(0,2)); + } + + @Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class) + public void test_describe_table_without_connect(MilvusClient client, String tableName) throws InterruptedException { + DescribeTableResponse res = client.describeTable(tableName); + assert(!res.getResponse().ok()); + } + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_has_table_not_existed(MilvusClient client, String tableName) throws InterruptedException { + HasTableResponse res = client.hasTable(tableName+"_"); + assert(res.getResponse().ok()); + Assert.assertFalse(res.hasTable()); + } + + @Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class) + public void test_has_table_without_connect(MilvusClient client, String tableName) throws InterruptedException { + HasTableResponse res = client.hasTable(tableName); + assert(!res.getResponse().ok()); + } + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_has_table(MilvusClient client, String tableName) throws InterruptedException { + HasTableResponse res = client.hasTable(tableName); + assert(res.getResponse().ok()); + Assert.assertTrue(res.hasTable()); + } + + +} diff --git a/tests/milvus-java-test/src/main/java/com/TestTableCount.java b/tests/milvus-java-test/src/main/java/com/TestTableCount.java new file mode 100644 index 0000000000..5cda18e812 --- /dev/null +++ b/tests/milvus-java-test/src/main/java/com/TestTableCount.java @@ -0,0 +1,83 @@ +package com; + +import io.milvus.client.*; +import org.testng.Assert; +import org.testng.annotations.Test; + +import java.util.ArrayList; +import java.util.List; +import java.util.Random; + +public class TestTableCount { + int index_file_size = 50; + int dimension = 128; + + public List> gen_vectors(Integer nb) { + List> xb = new ArrayList<>(); + Random random = new Random(); + for (int i = 0; i < nb; ++i) { + ArrayList vector = new ArrayList<>(); + for (int j = 0; j < dimension; j++) { + vector.add(random.nextFloat()); + } + xb.add(vector); + } + return xb; + } + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_table_count_no_vectors(MilvusClient client, String tableName) { + Assert.assertEquals(client.getTableRowCount(tableName).getTableRowCount(), 0); + } + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_table_count_table_not_existed(MilvusClient client, String tableName) { + GetTableRowCountResponse res = client.getTableRowCount(tableName+"_"); + assert(!res.getResponse().ok()); + } + + @Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class) + public void test_table_count_without_connect(MilvusClient client, String tableName) { + GetTableRowCountResponse res = client.getTableRowCount(tableName+"_"); + assert(!res.getResponse().ok()); + } + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_table_count(MilvusClient client, String tableName) throws InterruptedException { + int nb = 10000; + List> vectors = gen_vectors(nb); + // Add vectors + InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();; + client.insert(insertParam); + Thread.currentThread().sleep(2000); + Assert.assertEquals(client.getTableRowCount(tableName).getTableRowCount(), nb); + } + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_table_count_multi_tables(MilvusClient client, String tableName) throws InterruptedException { + int nb = 10000; + List> vectors = gen_vectors(nb); + Integer tableNum = 10; + GetTableRowCountResponse res = null; + for (int i = 0; i < tableNum; ++i) { + String tableNameNew = tableName + "_" + Integer.toString(i); + TableSchema tableSchema = new TableSchema.Builder(tableNameNew, dimension) + .withIndexFileSize(index_file_size) + .withMetricType(MetricType.L2) + .build(); + client.createTable(tableSchema); + // Add vectors + InsertParam insertParam = new InsertParam.Builder(tableNameNew, vectors).build(); + client.insert(insertParam); + } + Thread.currentThread().sleep(1000); + for (int i = 0; i < tableNum; ++i) { + String tableNameNew = tableName + "_" + Integer.toString(i); + res = client.getTableRowCount(tableNameNew); + Assert.assertEquals(res.getTableRowCount(), nb); + } + } + +} + + diff --git a/tests/milvus-java-test/testng.xml b/tests/milvus-java-test/testng.xml new file mode 100644 index 0000000000..19520f7eab --- /dev/null +++ b/tests/milvus-java-test/testng.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/tests/milvus_ann_acc/.gitignore b/tests/milvus_ann_acc/.gitignore new file mode 100644 index 0000000000..f250cab9fe --- /dev/null +++ b/tests/milvus_ann_acc/.gitignore @@ -0,0 +1,2 @@ +__pycache__/ +logs/ diff --git a/tests/milvus_ann_acc/README.md b/tests/milvus_ann_acc/README.md new file mode 100644 index 0000000000..f5ab9d8168 --- /dev/null +++ b/tests/milvus_ann_acc/README.md @@ -0,0 +1,21 @@ +# Requirements + +- python 3.6+ +- pip install -r requirements.txt + +# How to use this Test Project + +This project is used to test search accuracy based on the given datasets (https://github.com/erikbern/ann-benchmarks#data-sets) + +1. start your milvus server +2. update your test configuration in test.py +3. run command + +```shell +python test.py +``` + +# Contribution getting started + +- Follow PEP-8 for naming and black for formatting. + diff --git a/tests/milvus_ann_acc/client.py b/tests/milvus_ann_acc/client.py new file mode 100644 index 0000000000..de4ef17cb6 --- /dev/null +++ b/tests/milvus_ann_acc/client.py @@ -0,0 +1,149 @@ +import pdb +import random +import logging +import json +import time, datetime +from multiprocessing import Process +import numpy +import sklearn.preprocessing +from milvus import Milvus, IndexType, MetricType + +logger = logging.getLogger("milvus_ann_acc.client") + +SERVER_HOST_DEFAULT = "127.0.0.1" +SERVER_PORT_DEFAULT = 19530 + + +def time_wrapper(func): + """ + This decorator prints the execution time for the decorated function. + """ + def wrapper(*args, **kwargs): + start = time.time() + result = func(*args, **kwargs) + end = time.time() + logger.info("Milvus {} run in {}s".format(func.__name__, round(end - start, 2))) + return result + return wrapper + + +class MilvusClient(object): + def __init__(self, table_name=None, ip=None, port=None): + self._milvus = Milvus() + self._table_name = table_name + try: + if not ip: + self._milvus.connect( + host = SERVER_HOST_DEFAULT, + port = SERVER_PORT_DEFAULT) + else: + self._milvus.connect( + host = ip, + port = port) + except Exception as e: + raise e + + def __str__(self): + return 'Milvus table %s' % self._table_name + + def check_status(self, status): + if not status.OK(): + logger.error(status.message) + raise Exception("Status not ok") + + def create_table(self, table_name, dimension, index_file_size, metric_type): + if not self._table_name: + self._table_name = table_name + if metric_type == "l2": + metric_type = MetricType.L2 + elif metric_type == "ip": + metric_type = MetricType.IP + else: + logger.error("Not supported metric_type: %s" % metric_type) + self._metric_type = metric_type + create_param = {'table_name': table_name, + 'dimension': dimension, + 'index_file_size': index_file_size, + "metric_type": metric_type} + status = self._milvus.create_table(create_param) + self.check_status(status) + + @time_wrapper + def insert(self, X, ids): + if self._metric_type == MetricType.IP: + logger.info("Set normalize for metric_type: Inner Product") + X = sklearn.preprocessing.normalize(X, axis=1, norm='l2') + X = X.astype(numpy.float32) + status, result = self._milvus.add_vectors(self._table_name, X.tolist(), ids=ids) + self.check_status(status) + return status, result + + @time_wrapper + def create_index(self, index_type, nlist): + if index_type == "flat": + index_type = IndexType.FLAT + elif index_type == "ivf_flat": + index_type = IndexType.IVFLAT + elif index_type == "ivf_sq8": + index_type = IndexType.IVF_SQ8 + elif index_type == "ivf_sq8h": + index_type = IndexType.IVF_SQ8H + elif index_type == "mix_nsg": + index_type = IndexType.MIX_NSG + index_params = { + "index_type": index_type, + "nlist": nlist, + } + logger.info("Building index start, table_name: %s, index_params: %s" % (self._table_name, json.dumps(index_params))) + status = self._milvus.create_index(self._table_name, index=index_params, timeout=6*3600) + self.check_status(status) + + def describe_index(self): + return self._milvus.describe_index(self._table_name) + + def drop_index(self): + logger.info("Drop index: %s" % self._table_name) + return self._milvus.drop_index(self._table_name) + + @time_wrapper + def query(self, X, top_k, nprobe): + if self._metric_type == MetricType.IP: + logger.info("Set normalize for metric_type: Inner Product") + X = sklearn.preprocessing.normalize(X, axis=1, norm='l2') + X = X.astype(numpy.float32) + status, results = self._milvus.search_vectors(self._table_name, top_k, nprobe, X.tolist()) + self.check_status(status) + # logger.info(results[0]) + ids = [] + for result in results: + tmp_ids = [] + for item in result: + tmp_ids.append(item.id) + ids.append(tmp_ids) + return ids + + def count(self): + return self._milvus.get_table_row_count(self._table_name)[1] + + def delete(self, timeout=60): + logger.info("Start delete table: %s" % self._table_name) + self._milvus.delete_table(self._table_name) + i = 0 + while i < timeout: + if self.count(): + time.sleep(1) + i = i + 1 + else: + break + if i >= timeout: + logger.error("Delete table timeout") + + def describe(self): + return self._milvus.describe_table(self._table_name) + + def exists_table(self): + return self._milvus.has_table(self._table_name) + + @time_wrapper + def preload_table(self): + return self._milvus.preload_table(self._table_name) diff --git a/tests/milvus_ann_acc/config.yaml b/tests/milvus_ann_acc/config.yaml new file mode 100644 index 0000000000..e2ac2c1bfb --- /dev/null +++ b/tests/milvus_ann_acc/config.yaml @@ -0,0 +1,17 @@ +datasets: + sift-128-euclidean: + cpu_cache_size: 16 + gpu_cache_size: 5 + index_file_size: [1024] + nytimes-16-angular: + cpu_cache_size: 16 + gpu_cache_size: 5 + index_file_size: [1024] + +index: + index_types: ['flat', 'ivf_flat', 'ivf_sq8'] + nlists: [8092, 16384] + +search: + nprobes: [1, 8, 32] + top_ks: [10] diff --git a/tests/milvus_ann_acc/main.py b/tests/milvus_ann_acc/main.py new file mode 100644 index 0000000000..308e8246c7 --- /dev/null +++ b/tests/milvus_ann_acc/main.py @@ -0,0 +1,26 @@ + +import argparse + +def main(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument( + '--dataset', + metavar='NAME', + help='the dataset to load training points from', + default='glove-100-angular', + choices=DATASETS.keys()) + parser.add_argument( + "-k", "--count", + default=10, + type=positive_int, + help="the number of near neighbours to search for") + parser.add_argument( + '--definitions', + metavar='FILE', + help='load algorithm definitions from FILE', + default='algos.yaml') + parser.add_argument( + '--image-tag', + default=None, + help='pull image first') \ No newline at end of file diff --git a/tests/milvus_ann_acc/requirements.txt b/tests/milvus_ann_acc/requirements.txt new file mode 100644 index 0000000000..8c10e71b1f --- /dev/null +++ b/tests/milvus_ann_acc/requirements.txt @@ -0,0 +1,4 @@ +numpy==1.16.3 +pymilvus>=0.2.0 +scikit-learn==0.19.1 +h5py==2.7.1 diff --git a/tests/milvus_ann_acc/test.py b/tests/milvus_ann_acc/test.py new file mode 100644 index 0000000000..c4fbc33195 --- /dev/null +++ b/tests/milvus_ann_acc/test.py @@ -0,0 +1,132 @@ +import os +import pdb +import time +import random +import sys +import h5py +import numpy +import logging +from logging import handlers + +from client import MilvusClient + +LOG_FOLDER = "logs" +logger = logging.getLogger("milvus_ann_acc") + +formatter = logging.Formatter('[%(asctime)s] [%(levelname)-4s] [%(pathname)s:%(lineno)d] %(message)s') +if not os.path.exists(LOG_FOLDER): + os.system('mkdir -p %s' % LOG_FOLDER) +fileTimeHandler = handlers.TimedRotatingFileHandler(os.path.join(LOG_FOLDER, 'acc'), "D", 1, 10) +fileTimeHandler.suffix = "%Y%m%d.log" +fileTimeHandler.setFormatter(formatter) +logging.basicConfig(level=logging.DEBUG) +fileTimeHandler.setFormatter(formatter) +logger.addHandler(fileTimeHandler) + + +def get_dataset_fn(dataset_name): + file_path = "/test/milvus/ann_hdf5/" + if not os.path.exists(file_path): + raise Exception("%s not exists" % file_path) + return os.path.join(file_path, '%s.hdf5' % dataset_name) + + +def get_dataset(dataset_name): + hdf5_fn = get_dataset_fn(dataset_name) + hdf5_f = h5py.File(hdf5_fn) + return hdf5_f + + +def parse_dataset_name(dataset_name): + data_type = dataset_name.split("-")[0] + dimension = int(dataset_name.split("-")[1]) + metric = dataset_name.split("-")[-1] + # metric = dataset.attrs['distance'] + # dimension = len(dataset["train"][0]) + if metric == "euclidean": + metric_type = "l2" + elif metric == "angular": + metric_type = "ip" + return ("ann"+data_type, dimension, metric_type) + + +def get_table_name(dataset_name, index_file_size): + data_type, dimension, metric_type = parse_dataset_name(dataset_name) + dataset = get_dataset(dataset_name) + table_size = len(dataset["train"]) + table_size = str(table_size // 1000000)+"m" + table_name = data_type+'_'+table_size+'_'+str(index_file_size)+'_'+str(dimension)+'_'+metric_type + return table_name + + +def main(dataset_name, index_file_size, nlist=16384, force=False): + top_k = 10 + nprobes = [32, 128] + + dataset = get_dataset(dataset_name) + table_name = get_table_name(dataset_name, index_file_size) + m = MilvusClient(table_name) + if m.exists_table(): + if force is True: + logger.info("Re-create table: %s" % table_name) + m.delete() + time.sleep(10) + else: + logger.info("Table name: %s existed" % table_name) + return + data_type, dimension, metric_type = parse_dataset_name(dataset_name) + m.create_table(table_name, dimension, index_file_size, metric_type) + print(m.describe()) + vectors = numpy.array(dataset["train"]) + query_vectors = numpy.array(dataset["test"]) + # m.insert(vectors) + + interval = 100000 + loops = len(vectors) // interval + 1 + + for i in range(loops): + start = i*interval + end = min((i+1)*interval, len(vectors)) + tmp_vectors = vectors[start:end] + if start < end: + m.insert(tmp_vectors, ids=[i for i in range(start, end)]) + + time.sleep(60) + print(m.count()) + + for index_type in ["ivf_flat", "ivf_sq8", "ivf_sq8h"]: + m.create_index(index_type, nlist) + print(m.describe_index()) + if m.count() != len(vectors): + return + m.preload_table() + true_ids = numpy.array(dataset["neighbors"]) + for nprobe in nprobes: + print("nprobe: %s" % nprobe) + sum_radio = 0.0; avg_radio = 0.0 + result_ids = m.query(query_vectors, top_k, nprobe) + # print(result_ids[:10]) + for index, result_item in enumerate(result_ids): + if len(set(true_ids[index][:top_k])) != len(set(result_item)): + logger.info("Error happened") + # logger.info(query_vectors[index]) + # logger.info(true_ids[index][:top_k], result_item) + tmp = set(true_ids[index][:top_k]).intersection(set(result_item)) + sum_radio = sum_radio + (len(tmp) / top_k) + avg_radio = round(sum_radio / len(result_ids), 4) + logger.info(avg_radio) + m.drop_index() + + +if __name__ == "__main__": + print("glove-25-angular") + # main("sift-128-euclidean", 1024, force=True) + for index_file_size in [50, 1024]: + print("Index file size: %d" % index_file_size) + main("glove-25-angular", index_file_size, force=True) + + print("sift-128-euclidean") + for index_file_size in [50, 1024]: + print("Index file size: %d" % index_file_size) + main("sift-128-euclidean", index_file_size, force=True) + # m = MilvusClient() \ No newline at end of file diff --git a/tests/milvus_benchmark/.gitignore b/tests/milvus_benchmark/.gitignore new file mode 100644 index 0000000000..70af07bba8 --- /dev/null +++ b/tests/milvus_benchmark/.gitignore @@ -0,0 +1,8 @@ +random_data +benchmark_logs/ +db/ +logs/ +*idmap*.txt +__pycache__/ +venv +.idea \ No newline at end of file diff --git a/tests/milvus_benchmark/README.md b/tests/milvus_benchmark/README.md new file mode 100644 index 0000000000..05268057a4 --- /dev/null +++ b/tests/milvus_benchmark/README.md @@ -0,0 +1,23 @@ +# Requirements + +- python 3.6+ +- pip install -r requirements.txt + +# How to use this Test Project + +This project is used to test performance / accuracy / stability of milvus server + +1. update your test configuration in suites_*.yaml +2. run command + +```shell +### docker mode: +python main.py --image=milvusdb/milvus:latest --run-count=2 --run-type=performance + +### local mode: +python main.py --local --run-count=2 --run-type=performance --ip=127.0.0.1 --port=19530 +``` + +# Contribution getting started + +- Follow PEP-8 for naming and black for formatting. \ No newline at end of file diff --git a/tests/milvus_benchmark/__init__.py b/tests/milvus_benchmark/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/milvus_benchmark/client.py b/tests/milvus_benchmark/client.py new file mode 100644 index 0000000000..4744c10854 --- /dev/null +++ b/tests/milvus_benchmark/client.py @@ -0,0 +1,244 @@ +import pdb +import random +import logging +import json +import sys +import time, datetime +from multiprocessing import Process +from milvus import Milvus, IndexType, MetricType + +logger = logging.getLogger("milvus_benchmark.client") + +SERVER_HOST_DEFAULT = "127.0.0.1" +SERVER_PORT_DEFAULT = 19530 + + +def time_wrapper(func): + """ + This decorator prints the execution time for the decorated function. + """ + def wrapper(*args, **kwargs): + start = time.time() + result = func(*args, **kwargs) + end = time.time() + logger.info("Milvus {} run in {}s".format(func.__name__, round(end - start, 2))) + return result + return wrapper + + +class MilvusClient(object): + def __init__(self, table_name=None, ip=None, port=None): + self._milvus = Milvus() + self._table_name = table_name + try: + if not ip: + self._milvus.connect( + host = SERVER_HOST_DEFAULT, + port = SERVER_PORT_DEFAULT) + else: + self._milvus.connect( + host = ip, + port = port) + except Exception as e: + raise e + + def __str__(self): + return 'Milvus table %s' % self._table_name + + def check_status(self, status): + if not status.OK(): + logger.error(status.message) + raise Exception("Status not ok") + + def create_table(self, table_name, dimension, index_file_size, metric_type): + if not self._table_name: + self._table_name = table_name + if metric_type == "l2": + metric_type = MetricType.L2 + elif metric_type == "ip": + metric_type = MetricType.IP + else: + logger.error("Not supported metric_type: %s" % metric_type) + create_param = {'table_name': table_name, + 'dimension': dimension, + 'index_file_size': index_file_size, + "metric_type": metric_type} + status = self._milvus.create_table(create_param) + self.check_status(status) + + @time_wrapper + def insert(self, X, ids=None): + status, result = self._milvus.add_vectors(self._table_name, X, ids) + self.check_status(status) + return status, result + + @time_wrapper + def create_index(self, index_type, nlist): + if index_type == "flat": + index_type = IndexType.FLAT + elif index_type == "ivf_flat": + index_type = IndexType.IVFLAT + elif index_type == "ivf_sq8": + index_type = IndexType.IVF_SQ8 + elif index_type == "mix_nsg": + index_type = IndexType.MIX_NSG + elif index_type == "ivf_sq8h": + index_type = IndexType.IVF_SQ8H + index_params = { + "index_type": index_type, + "nlist": nlist, + } + logger.info("Building index start, table_name: %s, index_params: %s" % (self._table_name, json.dumps(index_params))) + status = self._milvus.create_index(self._table_name, index=index_params, timeout=6*3600) + self.check_status(status) + + def describe_index(self): + return self._milvus.describe_index(self._table_name) + + def drop_index(self): + logger.info("Drop index: %s" % self._table_name) + return self._milvus.drop_index(self._table_name) + + @time_wrapper + def query(self, X, top_k, nprobe): + status, result = self._milvus.search_vectors(self._table_name, top_k, nprobe, X) + self.check_status(status) + return status, result + + def count(self): + return self._milvus.get_table_row_count(self._table_name)[1] + + def delete(self, timeout=60): + logger.info("Start delete table: %s" % self._table_name) + self._milvus.delete_table(self._table_name) + i = 0 + while i < timeout: + if self.count(): + time.sleep(1) + i = i + 1 + continue + else: + break + if i < timeout: + logger.error("Delete table timeout") + + def describe(self): + return self._milvus.describe_table(self._table_name) + + def exists_table(self): + return self._milvus.has_table(self._table_name) + + @time_wrapper + def preload_table(self): + return self._milvus.preload_table(self._table_name, timeout=3000) + + +def fit(table_name, X): + milvus = Milvus() + milvus.connect(host = SERVER_HOST_DEFAULT, port = SERVER_PORT_DEFAULT) + start = time.time() + status, ids = milvus.add_vectors(table_name, X) + end = time.time() + logger(status, round(end - start, 2)) + + +def fit_concurrent(table_name, process_num, vectors): + processes = [] + + for i in range(process_num): + p = Process(target=fit, args=(table_name, vectors, )) + processes.append(p) + p.start() + for p in processes: + p.join() + + +if __name__ == "__main__": + + # table_name = "sift_2m_20_128_l2" + table_name = "test_tset1" + m = MilvusClient(table_name) + # m.create_table(table_name, 128, 50, "l2") + + print(m.describe()) + # print(m.count()) + # print(m.describe_index()) + insert_vectors = [[random.random() for _ in range(128)] for _ in range(10000)] + for i in range(5): + m.insert(insert_vectors) + print(m.create_index("ivf_sq8h", 16384)) + X = [insert_vectors[0]] + top_k = 10 + nprobe = 10 + print(m.query(X, top_k, nprobe)) + + # # # print(m.drop_index()) + # # print(m.describe_index()) + # # sys.exit() + # # # insert_vectors = [[random.random() for _ in range(128)] for _ in range(100000)] + # # # for i in range(100): + # # # m.insert(insert_vectors) + # # # time.sleep(5) + # # # print(m.describe_index()) + # # # print(m.drop_index()) + # # m.create_index("ivf_sq8h", 16384) + # print(m.count()) + # print(m.describe_index()) + + + + # sys.exit() + # print(m.create_index("ivf_sq8h", 16384)) + # print(m.count()) + # print(m.describe_index()) + import numpy as np + + def mmap_fvecs(fname): + x = np.memmap(fname, dtype='int32', mode='r') + d = x[0] + return x.view('float32').reshape(-1, d + 1)[:, 1:] + + print(mmap_fvecs("/poc/deep1b/deep1B_queries.fvecs")) + # SIFT_SRC_QUERY_DATA_DIR = '/poc/yuncong/ann_1000m' + # file_name = SIFT_SRC_QUERY_DATA_DIR+'/'+'query.npy' + # data = numpy.load(file_name) + # query_vectors = data[0:2].tolist() + # print(len(query_vectors)) + # results = m.query(query_vectors, 10, 10) + # result_ids = [] + # for result in results[1]: + # tmp = [] + # for item in result: + # tmp.append(item.id) + # result_ids.append(tmp) + # print(result_ids[0][:10]) + # # gt + # file_name = SIFT_SRC_QUERY_DATA_DIR+"/gnd/"+"idx_1M.ivecs" + # a = numpy.fromfile(file_name, dtype='int32') + # d = a[0] + # true_ids = a.reshape(-1, d + 1)[:, 1:].copy() + # print(true_ids[:3, :2]) + + # print(len(true_ids[0])) + # import numpy as np + # import sklearn.preprocessing + + # def mmap_fvecs(fname): + # x = np.memmap(fname, dtype='int32', mode='r') + # d = x[0] + # return x.view('float32').reshape(-1, d + 1)[:, 1:] + + # data = mmap_fvecs("/poc/deep1b/deep1B_queries.fvecs") + # print(data[0], len(data[0]), len(data)) + + # total_size = 10000 + # # total_size = 1000000000 + # file_size = 1000 + # # file_size = 100000 + # file_num = total_size // file_size + # for i in range(file_num): + # fname = "/test/milvus/raw_data/deep1b/binary_96_%05d" % i + # print(fname, i*file_size, (i+1)*file_size) + # single_data = data[i*file_size : (i+1)*file_size] + # single_data = sklearn.preprocessing.normalize(single_data, axis=1, norm='l2') + # np.save(fname, single_data) diff --git a/tests/milvus_benchmark/conf/log_config.conf b/tests/milvus_benchmark/conf/log_config.conf new file mode 100644 index 0000000000..c9c14d57f4 --- /dev/null +++ b/tests/milvus_benchmark/conf/log_config.conf @@ -0,0 +1,28 @@ +* GLOBAL: + FORMAT = "%datetime | %level | %logger | %msg" + FILENAME = "/opt/milvus/logs/milvus-%datetime{%H:%m}-global.log" + ENABLED = true + TO_FILE = true + TO_STANDARD_OUTPUT = false + SUBSECOND_PRECISION = 3 + PERFORMANCE_TRACKING = false + MAX_LOG_FILE_SIZE = 2097152 ## Throw log files away after 2MB +* DEBUG: + FILENAME = "/opt/milvus/logs/milvus-%datetime{%H:%m}-debug.log" + ENABLED = true +* WARNING: + FILENAME = "/opt/milvus/logs/milvus-%datetime{%H:%m}-warning.log" +* TRACE: + FILENAME = "/opt/milvus/logs/milvus-%datetime{%H:%m}-trace.log" +* VERBOSE: + FORMAT = "%datetime{%d/%M/%y} | %level-%vlevel | %msg" + TO_FILE = false + TO_STANDARD_OUTPUT = false +## Error logs +* ERROR: + ENABLED = true + FILENAME = "/opt/milvus/logs/milvus-%datetime{%H:%m}-error.log" +* FATAL: + ENABLED = true + FILENAME = "/opt/milvus/logs/milvus-%datetime{%H:%m}-fatal.log" + diff --git a/tests/milvus_benchmark/conf/server_config.yaml b/tests/milvus_benchmark/conf/server_config.yaml new file mode 100644 index 0000000000..a5d2081b8a --- /dev/null +++ b/tests/milvus_benchmark/conf/server_config.yaml @@ -0,0 +1,28 @@ +cache_config: + cache_insert_data: false + cpu_cache_capacity: 16 + gpu_cache_capacity: 6 + cpu_cache_threshold: 0.85 +db_config: + backend_url: sqlite://:@:/ + build_index_gpu: 0 + insert_buffer_size: 4 + preload_table: null + primary_path: /opt/milvus + secondary_path: null +engine_config: + use_blas_threshold: 20 +metric_config: + collector: prometheus + enable_monitor: true + prometheus_config: + port: 8080 +resource_config: + resource_pool: + - cpu + - gpu0 +server_config: + address: 0.0.0.0 + deploy_mode: single + port: 19530 + time_zone: UTC+8 diff --git a/tests/milvus_benchmark/conf/server_config.yaml.cpu b/tests/milvus_benchmark/conf/server_config.yaml.cpu new file mode 100644 index 0000000000..95ab5f5343 --- /dev/null +++ b/tests/milvus_benchmark/conf/server_config.yaml.cpu @@ -0,0 +1,31 @@ +server_config: + address: 0.0.0.0 + port: 19530 + deploy_mode: single + time_zone: UTC+8 + +db_config: + primary_path: /opt/milvus + secondary_path: + backend_url: sqlite://:@:/ + insert_buffer_size: 4 + build_index_gpu: 0 + preload_table: + +metric_config: + enable_monitor: false + collector: prometheus + prometheus_config: + port: 8080 + +cache_config: + cpu_cache_capacity: 16 + cpu_cache_threshold: 0.85 + cache_insert_data: false + +engine_config: + use_blas_threshold: 20 + +resource_config: + resource_pool: + - cpu \ No newline at end of file diff --git a/tests/milvus_benchmark/conf/server_config.yaml.multi b/tests/milvus_benchmark/conf/server_config.yaml.multi new file mode 100644 index 0000000000..002d3bd2a6 --- /dev/null +++ b/tests/milvus_benchmark/conf/server_config.yaml.multi @@ -0,0 +1,33 @@ +server_config: + address: 0.0.0.0 + port: 19530 + deploy_mode: single + time_zone: UTC+8 + +db_config: + primary_path: /opt/milvus + secondary_path: + backend_url: sqlite://:@:/ + insert_buffer_size: 4 + build_index_gpu: 0 + preload_table: + +metric_config: + enable_monitor: false + collector: prometheus + prometheus_config: + port: 8080 + +cache_config: + cpu_cache_capacity: 16 + cpu_cache_threshold: 0.85 + cache_insert_data: false + +engine_config: + use_blas_threshold: 20 + +resource_config: + resource_pool: + - cpu + - gpu0 + - gpu1 \ No newline at end of file diff --git a/tests/milvus_benchmark/conf/server_config.yaml.single b/tests/milvus_benchmark/conf/server_config.yaml.single new file mode 100644 index 0000000000..033d8868d1 --- /dev/null +++ b/tests/milvus_benchmark/conf/server_config.yaml.single @@ -0,0 +1,32 @@ +server_config: + address: 0.0.0.0 + port: 19530 + deploy_mode: single + time_zone: UTC+8 + +db_config: + primary_path: /opt/milvus + secondary_path: + backend_url: sqlite://:@:/ + insert_buffer_size: 4 + build_index_gpu: 0 + preload_table: + +metric_config: + enable_monitor: false + collector: prometheus + prometheus_config: + port: 8080 + +cache_config: + cpu_cache_capacity: 16 + cpu_cache_threshold: 0.85 + cache_insert_data: false + +engine_config: + use_blas_threshold: 20 + +resource_config: + resource_pool: + - cpu + - gpu0 \ No newline at end of file diff --git a/tests/milvus_benchmark/demo.py b/tests/milvus_benchmark/demo.py new file mode 100644 index 0000000000..27152e0980 --- /dev/null +++ b/tests/milvus_benchmark/demo.py @@ -0,0 +1,51 @@ +import os +import logging +import pdb +import time +import random +from multiprocessing import Process +import numpy as np +from client import MilvusClient + +nq = 100000 +dimension = 128 +run_count = 1 +table_name = "sift_10m_1024_128_ip" +insert_vectors = [[random.random() for _ in range(dimension)] for _ in range(nq)] + +def do_query(milvus, table_name, top_ks, nqs, nprobe, run_count): + bi_res = [] + for index, nq in enumerate(nqs): + tmp_res = [] + for top_k in top_ks: + avg_query_time = 0.0 + total_query_time = 0.0 + vectors = insert_vectors[0:nq] + for i in range(run_count): + start_time = time.time() + status, query_res = milvus.query(vectors, top_k, nprobe) + total_query_time = total_query_time + (time.time() - start_time) + if status.code: + print(status.message) + avg_query_time = round(total_query_time / run_count, 2) + tmp_res.append(avg_query_time) + bi_res.append(tmp_res) + return bi_res + +while 1: + milvus_instance = MilvusClient(table_name, ip="192.168.1.197", port=19530) + top_ks = random.sample([x for x in range(1, 100)], 4) + nqs = random.sample([x for x in range(1, 1000)], 3) + nprobe = random.choice([x for x in range(1, 500)]) + res = do_query(milvus_instance, table_name, top_ks, nqs, nprobe, run_count) + status, res = milvus_instance.insert(insert_vectors, ids=[x for x in range(len(insert_vectors))]) + if not status.OK(): + logger.error(status.message) + + # status = milvus_instance.drop_index() + if not status.OK(): + print(status.message) + index_type = "ivf_sq8" + status = milvus_instance.create_index(index_type, 16384) + if not status.OK(): + print(status.message) \ No newline at end of file diff --git a/tests/milvus_benchmark/docker_runner.py b/tests/milvus_benchmark/docker_runner.py new file mode 100644 index 0000000000..008c9866b4 --- /dev/null +++ b/tests/milvus_benchmark/docker_runner.py @@ -0,0 +1,261 @@ +import os +import logging +import pdb +import time +import random +from multiprocessing import Process +import numpy as np +from client import MilvusClient +import utils +import parser +from runner import Runner + +logger = logging.getLogger("milvus_benchmark.docker") + + +class DockerRunner(Runner): + """run docker mode""" + def __init__(self, image): + super(DockerRunner, self).__init__() + self.image = image + + def run(self, definition, run_type=None): + if run_type == "performance": + for op_type, op_value in definition.items(): + # run docker mode + run_count = op_value["run_count"] + run_params = op_value["params"] + container = None + + if op_type == "insert": + for index, param in enumerate(run_params): + logger.info("Definition param: %s" % str(param)) + table_name = param["table_name"] + volume_name = param["db_path_prefix"] + print(table_name) + (data_type, table_size, index_file_size, dimension, metric_type) = parser.table_parser(table_name) + for k, v in param.items(): + if k.startswith("server."): + # Update server config + utils.modify_config(k, v, type="server", db_slave=None) + container = utils.run_server(self.image, test_type="remote", volume_name=volume_name, db_slave=None) + time.sleep(2) + milvus = MilvusClient(table_name) + # Check has table or not + if milvus.exists_table(): + milvus.delete() + time.sleep(10) + milvus.create_table(table_name, dimension, index_file_size, metric_type) + res = self.do_insert(milvus, table_name, data_type, dimension, table_size, param["ni_per"]) + logger.info(res) + + # wait for file merge + time.sleep(6 * (table_size / 500000)) + # Clear up + utils.remove_container(container) + + elif op_type == "query": + for index, param in enumerate(run_params): + logger.info("Definition param: %s" % str(param)) + table_name = param["dataset"] + volume_name = param["db_path_prefix"] + (data_type, table_size, index_file_size, dimension, metric_type) = parser.table_parser(table_name) + for k, v in param.items(): + if k.startswith("server."): + utils.modify_config(k, v, type="server") + container = utils.run_server(self.image, test_type="remote", volume_name=volume_name, db_slave=None) + time.sleep(2) + milvus = MilvusClient(table_name) + logger.debug(milvus._milvus.show_tables()) + # Check has table or not + if not milvus.exists_table(): + logger.warning("Table %s not existed, continue exec next params ..." % table_name) + continue + # parse index info + index_types = param["index.index_types"] + nlists = param["index.nlists"] + # parse top-k, nq, nprobe + top_ks, nqs, nprobes = parser.search_params_parser(param) + for index_type in index_types: + for nlist in nlists: + result = milvus.describe_index() + logger.info(result) + milvus.create_index(index_type, nlist) + result = milvus.describe_index() + logger.info(result) + # preload index + milvus.preload_table() + logger.info("Start warm up query") + res = self.do_query(milvus, table_name, [1], [1], 1, 1) + logger.info("End warm up query") + # Run query test + for nprobe in nprobes: + logger.info("index_type: %s, nlist: %s, metric_type: %s, nprobe: %s" % (index_type, nlist, metric_type, nprobe)) + res = self.do_query(milvus, table_name, top_ks, nqs, nprobe, run_count) + headers = ["Nprobe/Top-k"] + headers.extend([str(top_k) for top_k in top_ks]) + utils.print_table(headers, nqs, res) + utils.remove_container(container) + + elif run_type == "accuracy": + """ + { + "dataset": "random_50m_1024_512", + "index.index_types": ["flat", ivf_flat", "ivf_sq8"], + "index.nlists": [16384], + "nprobes": [1, 32, 128], + "nqs": [100], + "top_ks": [1, 64], + "server.use_blas_threshold": 1100, + "server.cpu_cache_capacity": 256 + } + """ + for op_type, op_value in definition.items(): + if op_type != "query": + logger.warning("invalid operation: %s in accuracy test, only support query operation" % op_type) + break + run_count = op_value["run_count"] + run_params = op_value["params"] + container = None + + for index, param in enumerate(run_params): + logger.info("Definition param: %s" % str(param)) + table_name = param["dataset"] + sift_acc = False + if "sift_acc" in param: + sift_acc = param["sift_acc"] + (data_type, table_size, index_file_size, dimension, metric_type) = parser.table_parser(table_name) + for k, v in param.items(): + if k.startswith("server."): + utils.modify_config(k, v, type="server") + volume_name = param["db_path_prefix"] + container = utils.run_server(self.image, test_type="remote", volume_name=volume_name, db_slave=None) + time.sleep(2) + milvus = MilvusClient(table_name) + # Check has table or not + if not milvus.exists_table(): + logger.warning("Table %s not existed, continue exec next params ..." % table_name) + continue + + # parse index info + index_types = param["index.index_types"] + nlists = param["index.nlists"] + # parse top-k, nq, nprobe + top_ks, nqs, nprobes = parser.search_params_parser(param) + + if sift_acc is True: + # preload groundtruth data + true_ids_all = self.get_groundtruth_ids(table_size) + + acc_dict = {} + for index_type in index_types: + for nlist in nlists: + result = milvus.describe_index() + logger.info(result) + milvus.create_index(index_type, nlist) + # preload index + milvus.preload_table() + # Run query test + for nprobe in nprobes: + logger.info("index_type: %s, nlist: %s, metric_type: %s, nprobe: %s" % (index_type, nlist, metric_type, nprobe)) + for top_k in top_ks: + for nq in nqs: + result_ids = [] + id_prefix = "%s_index_%s_nlist_%s_metric_type_%s_nprobe_%s_top_k_%s_nq_%s" % \ + (table_name, index_type, nlist, metric_type, nprobe, top_k, nq) + if sift_acc is False: + self.do_query_acc(milvus, table_name, top_k, nq, nprobe, id_prefix) + if index_type != "flat": + # Compute accuracy + base_name = "%s_index_flat_nlist_%s_metric_type_%s_nprobe_%s_top_k_%s_nq_%s" % \ + (table_name, nlist, metric_type, nprobe, top_k, nq) + avg_acc = self.compute_accuracy(base_name, id_prefix) + logger.info("Query: <%s> accuracy: %s" % (id_prefix, avg_acc)) + else: + result_ids = self.do_query_ids(milvus, table_name, top_k, nq, nprobe) + acc_value = self.get_recall_value(true_ids_all[:nq, :top_k].tolist(), result_ids) + logger.info("Query: <%s> accuracy: %s" % (id_prefix, acc_value)) + # # print accuracy table + # headers = [table_name] + # headers.extend([str(top_k) for top_k in top_ks]) + # utils.print_table(headers, nqs, res) + + # remove container, and run next definition + logger.info("remove container, and run next definition") + utils.remove_container(container) + + elif run_type == "stability": + for op_type, op_value in definition.items(): + if op_type != "query": + logger.warning("invalid operation: %s in accuracy test, only support query operation" % op_type) + break + run_count = op_value["run_count"] + run_params = op_value["params"] + container = None + for index, param in enumerate(run_params): + logger.info("Definition param: %s" % str(param)) + table_name = param["dataset"] + volume_name = param["db_path_prefix"] + (data_type, table_size, index_file_size, dimension, metric_type) = parser.table_parser(table_name) + + # set default test time + if "during_time" not in param: + during_time = 100 # seconds + else: + during_time = int(param["during_time"]) * 60 + # set default query process num + if "query_process_num" not in param: + query_process_num = 10 + else: + query_process_num = int(param["query_process_num"]) + + for k, v in param.items(): + if k.startswith("server."): + utils.modify_config(k, v, type="server") + + container = utils.run_server(self.image, test_type="remote", volume_name=volume_name, db_slave=None) + time.sleep(2) + milvus = MilvusClient(table_name) + # Check has table or not + if not milvus.exists_table(): + logger.warning("Table %s not existed, continue exec next params ..." % table_name) + continue + + start_time = time.time() + insert_vectors = [[random.random() for _ in range(dimension)] for _ in range(10000)] + while time.time() < start_time + during_time: + processes = [] + # do query + # for i in range(query_process_num): + # milvus_instance = MilvusClient(table_name) + # top_k = random.choice([x for x in range(1, 100)]) + # nq = random.choice([x for x in range(1, 100)]) + # nprobe = random.choice([x for x in range(1, 1000)]) + # # logger.info("index_type: %s, nlist: %s, metric_type: %s, nprobe: %s" % (index_type, nlist, metric_type, nprobe)) + # p = Process(target=self.do_query, args=(milvus_instance, table_name, [top_k], [nq], [nprobe], run_count, )) + # processes.append(p) + # p.start() + # time.sleep(0.1) + # for p in processes: + # p.join() + milvus_instance = MilvusClient(table_name) + top_ks = random.sample([x for x in range(1, 100)], 3) + nqs = random.sample([x for x in range(1, 1000)], 3) + nprobe = random.choice([x for x in range(1, 500)]) + res = self.do_query(milvus, table_name, top_ks, nqs, nprobe, run_count) + if int(time.time() - start_time) % 120 == 0: + status, res = milvus_instance.insert(insert_vectors, ids=[x for x in range(len(insert_vectors))]) + if not status.OK(): + logger.error(status) + # status = milvus_instance.drop_index() + # if not status.OK(): + # logger.error(status) + # index_type = random.choice(["flat", "ivf_flat", "ivf_sq8"]) + result = milvus.describe_index() + logger.info(result) + milvus_instance.create_index("ivf_sq8", 16384) + utils.remove_container(container) + + else: + logger.warning("Run type: %s not supported" % run_type) + diff --git a/tests/milvus_benchmark/local_runner.py b/tests/milvus_benchmark/local_runner.py new file mode 100644 index 0000000000..1067f500bc --- /dev/null +++ b/tests/milvus_benchmark/local_runner.py @@ -0,0 +1,132 @@ +import os +import logging +import pdb +import time +import random +from multiprocessing import Process +import numpy as np +from client import MilvusClient +import utils +import parser +from runner import Runner + +logger = logging.getLogger("milvus_benchmark.local_runner") + + +class LocalRunner(Runner): + """run local mode""" + def __init__(self, ip, port): + super(LocalRunner, self).__init__() + self.ip = ip + self.port = port + + def run(self, definition, run_type=None): + if run_type == "performance": + for op_type, op_value in definition.items(): + run_count = op_value["run_count"] + run_params = op_value["params"] + + if op_type == "insert": + for index, param in enumerate(run_params): + table_name = param["table_name"] + # random_1m_100_512 + (data_type, table_size, index_file_size, dimension, metric_type) = parser.table_parser(table_name) + milvus = MilvusClient(table_name, ip=self.ip, port=self.port) + # Check has table or not + if milvus.exists_table(): + milvus.delete() + time.sleep(10) + milvus.create_table(table_name, dimension, index_file_size, metric_type) + res = self.do_insert(milvus, table_name, data_type, dimension, table_size, param["ni_per"]) + logger.info(res) + + elif op_type == "query": + for index, param in enumerate(run_params): + logger.info("Definition param: %s" % str(param)) + table_name = param["dataset"] + (data_type, table_size, index_file_size, dimension, metric_type) = parser.table_parser(table_name) + + milvus = MilvusClient(table_name, ip=self.ip, port=self.port) + # parse index info + index_types = param["index.index_types"] + nlists = param["index.nlists"] + # parse top-k, nq, nprobe + top_ks, nqs, nprobes = parser.search_params_parser(param) + + for index_type in index_types: + for nlist in nlists: + milvus.create_index(index_type, nlist) + # preload index + milvus.preload_table() + # Run query test + for nprobe in nprobes: + logger.info("index_type: %s, nlist: %s, metric_type: %s, nprobe: %s" % (index_type, nlist, metric_type, nprobe)) + res = self.do_query(milvus, table_name, top_ks, nqs, nprobe, run_count) + headers = [param["dataset"]] + headers.extend([str(top_k) for top_k in top_ks]) + utils.print_table(headers, nqs, res) + + elif run_type == "stability": + for op_type, op_value in definition.items(): + if op_type != "query": + logger.warning("invalid operation: %s in accuracy test, only support query operation" % op_type) + break + run_count = op_value["run_count"] + run_params = op_value["params"] + nq = 10000 + + for index, param in enumerate(run_params): + logger.info("Definition param: %s" % str(param)) + table_name = param["dataset"] + (data_type, table_size, index_file_size, dimension, metric_type) = parser.table_parser(table_name) + + # set default test time + if "during_time" not in param: + during_time = 100 # seconds + else: + during_time = int(param["during_time"]) * 60 + # set default query process num + if "query_process_num" not in param: + query_process_num = 10 + else: + query_process_num = int(param["query_process_num"]) + milvus = MilvusClient(table_name) + # Check has table or not + if not milvus.exists_table(): + logger.warning("Table %s not existed, continue exec next params ..." % table_name) + continue + + start_time = time.time() + insert_vectors = [[random.random() for _ in range(dimension)] for _ in range(nq)] + while time.time() < start_time + during_time: + processes = [] + # # do query + # for i in range(query_process_num): + # milvus_instance = MilvusClient(table_name) + # top_k = random.choice([x for x in range(1, 100)]) + # nq = random.choice([x for x in range(1, 1000)]) + # nprobe = random.choice([x for x in range(1, 500)]) + # logger.info(nprobe) + # p = Process(target=self.do_query, args=(milvus_instance, table_name, [top_k], [nq], 64, run_count, )) + # processes.append(p) + # p.start() + # time.sleep(0.1) + # for p in processes: + # p.join() + milvus_instance = MilvusClient(table_name) + top_ks = random.sample([x for x in range(1, 100)], 4) + nqs = random.sample([x for x in range(1, 1000)], 3) + nprobe = random.choice([x for x in range(1, 500)]) + res = self.do_query(milvus, table_name, top_ks, nqs, nprobe, run_count) + # milvus_instance = MilvusClient(table_name) + status, res = milvus_instance.insert(insert_vectors, ids=[x for x in range(len(insert_vectors))]) + if not status.OK(): + logger.error(status.message) + if (time.time() - start_time) % 300 == 0: + status = milvus_instance.drop_index() + if not status.OK(): + logger.error(status.message) + index_type = random.choice(["flat", "ivf_flat", "ivf_sq8"]) + status = milvus_instance.create_index(index_type, 16384) + if not status.OK(): + logger.error(status.message) diff --git a/tests/milvus_benchmark/main.py b/tests/milvus_benchmark/main.py new file mode 100644 index 0000000000..c11237c5e7 --- /dev/null +++ b/tests/milvus_benchmark/main.py @@ -0,0 +1,131 @@ +import os +import sys +import time +import pdb +import argparse +import logging +import utils +from yaml import load, dump +from logging import handlers +from parser import operations_parser +from local_runner import LocalRunner +from docker_runner import DockerRunner + +DEFAULT_IMAGE = "milvusdb/milvus:latest" +LOG_FOLDER = "benchmark_logs" +logger = logging.getLogger("milvus_benchmark") + +formatter = logging.Formatter('[%(asctime)s] [%(levelname)-4s] [%(pathname)s:%(lineno)d] %(message)s') +if not os.path.exists(LOG_FOLDER): + os.system('mkdir -p %s' % LOG_FOLDER) +fileTimeHandler = handlers.TimedRotatingFileHandler(os.path.join(LOG_FOLDER, 'milvus_benchmark'), "D", 1, 10) +fileTimeHandler.suffix = "%Y%m%d.log" +fileTimeHandler.setFormatter(formatter) +logging.basicConfig(level=logging.DEBUG) +fileTimeHandler.setFormatter(formatter) +logger.addHandler(fileTimeHandler) + + +def positive_int(s): + i = None + try: + i = int(s) + except ValueError: + pass + if not i or i < 1: + raise argparse.ArgumentTypeError("%r is not a positive integer" % s) + return i + + +# # link random_data if not exists +# def init_env(): +# if not os.path.islink(BINARY_DATA_FOLDER): +# try: +# os.symlink(SRC_BINARY_DATA_FOLDER, BINARY_DATA_FOLDER) +# except Exception as e: +# logger.error("Create link failed: %s" % str(e)) +# sys.exit() + + +def main(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument( + '--image', + help='use the given image') + parser.add_argument( + '--local', + action='store_true', + help='use local milvus server') + parser.add_argument( + "--run-count", + default=1, + type=positive_int, + help="run each db operation times") + # performance / stability / accuracy test + parser.add_argument( + "--run-type", + default="performance", + help="run type, default performance") + parser.add_argument( + '--suites', + metavar='FILE', + help='load test suites from FILE', + default='suites.yaml') + parser.add_argument( + '--ip', + help='server ip param for local mode', + default='127.0.0.1') + parser.add_argument( + '--port', + help='server port param for local mode', + default='19530') + + args = parser.parse_args() + + operations = None + # Get all benchmark test suites + if args.suites: + with open(args.suites) as f: + suites_dict = load(f) + f.close() + # With definition order + operations = operations_parser(suites_dict, run_type=args.run_type) + + # init_env() + run_params = {"run_count": args.run_count} + + if args.image: + # for docker mode + if args.local: + logger.error("Local mode and docker mode are incompatible arguments") + sys.exit(-1) + # Docker pull image + if not utils.pull_image(args.image): + raise Exception('Image %s pull failed' % image) + + # TODO: Check milvus server port is available + logger.info("Init: remove all containers created with image: %s" % args.image) + utils.remove_all_containers(args.image) + runner = DockerRunner(args.image) + for operation_type in operations: + logger.info("Start run test, test type: %s" % operation_type) + run_params["params"] = operations[operation_type] + runner.run({operation_type: run_params}, run_type=args.run_type) + logger.info("Run params: %s" % str(run_params)) + + if args.local: + # for local mode + ip = args.ip + port = args.port + + runner = LocalRunner(ip, port) + for operation_type in operations: + logger.info("Start run local mode test, test type: %s" % operation_type) + run_params["params"] = operations[operation_type] + runner.run({operation_type: run_params}, run_type=args.run_type) + logger.info("Run params: %s" % str(run_params)) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/tests/milvus_benchmark/operation.py b/tests/milvus_benchmark/operation.py new file mode 100644 index 0000000000..348fa47f4a --- /dev/null +++ b/tests/milvus_benchmark/operation.py @@ -0,0 +1,10 @@ +from __future__ import absolute_import +import pdb +import time + +class Base(object): + pass + + +class Insert(Base): + pass \ No newline at end of file diff --git a/tests/milvus_benchmark/parser.py b/tests/milvus_benchmark/parser.py new file mode 100644 index 0000000000..1c1d2605f2 --- /dev/null +++ b/tests/milvus_benchmark/parser.py @@ -0,0 +1,66 @@ +import pdb +import logging + +logger = logging.getLogger("milvus_benchmark.parser") + + +def operations_parser(operations, run_type="performance"): + definitions = operations[run_type] + return definitions + + +def table_parser(table_name): + tmp = table_name.split("_") + # if len(tmp) != 5: + # return None + data_type = tmp[0] + table_size_unit = tmp[1][-1] + table_size = tmp[1][0:-1] + if table_size_unit == "m": + table_size = int(table_size) * 1000000 + elif table_size_unit == "b": + table_size = int(table_size) * 1000000000 + index_file_size = int(tmp[2]) + dimension = int(tmp[3]) + metric_type = str(tmp[4]) + return (data_type, table_size, index_file_size, dimension, metric_type) + + +def search_params_parser(param): + # parse top-k, set default value if top-k not in param + if "top_ks" not in param: + top_ks = [10] + else: + top_ks = param["top_ks"] + if isinstance(top_ks, int): + top_ks = [top_ks] + elif isinstance(top_ks, list): + top_ks = list(top_ks) + else: + logger.warning("Invalid format top-ks: %s" % str(top_ks)) + + # parse nqs, set default value if nq not in param + if "nqs" not in param: + nqs = [10] + else: + nqs = param["nqs"] + if isinstance(nqs, int): + nqs = [nqs] + elif isinstance(nqs, list): + nqs = list(nqs) + else: + logger.warning("Invalid format nqs: %s" % str(nqs)) + + # parse nprobes + if "nprobes" not in param: + nprobes = [1] + else: + nprobes = param["nprobes"] + if isinstance(nprobes, int): + nprobes = [nprobes] + elif isinstance(nprobes, list): + nprobes = list(nprobes) + else: + logger.warning("Invalid format nprobes: %s" % str(nprobes)) + + return top_ks, nqs, nprobes \ No newline at end of file diff --git a/tests/milvus_benchmark/report.py b/tests/milvus_benchmark/report.py new file mode 100644 index 0000000000..6041311513 --- /dev/null +++ b/tests/milvus_benchmark/report.py @@ -0,0 +1,10 @@ +# from tablereport import Table +# from tablereport.shortcut import write_to_excel + +# RESULT_FOLDER = "results" + +# def create_table(headers, bodys, table_name): +# table = Table(header=[headers], +# body=[bodys]) + +# write_to_excel('%s/%s.xlsx' % (RESULT_FOLDER, table_name), table) \ No newline at end of file diff --git a/tests/milvus_benchmark/requirements.txt b/tests/milvus_benchmark/requirements.txt new file mode 100644 index 0000000000..1285b4d2ba --- /dev/null +++ b/tests/milvus_benchmark/requirements.txt @@ -0,0 +1,6 @@ +numpy==1.16.3 +pymilvus>=0.1.18 +pyyaml==3.12 +docker==4.0.2 +tableprint==0.8.0 +ansicolors==1.1.8 \ No newline at end of file diff --git a/tests/milvus_benchmark/runner.py b/tests/milvus_benchmark/runner.py new file mode 100644 index 0000000000..d3ad345da9 --- /dev/null +++ b/tests/milvus_benchmark/runner.py @@ -0,0 +1,219 @@ +import os +import logging +import pdb +import time +import random +from multiprocessing import Process +import numpy as np +from client import MilvusClient +import utils +import parser + +logger = logging.getLogger("milvus_benchmark.runner") + +SERVER_HOST_DEFAULT = "127.0.0.1" +SERVER_PORT_DEFAULT = 19530 +VECTORS_PER_FILE = 1000000 +SIFT_VECTORS_PER_FILE = 100000 +MAX_NQ = 10001 +FILE_PREFIX = "binary_" + +RANDOM_SRC_BINARY_DATA_DIR = '/tmp/random/binary_data' +SIFT_SRC_DATA_DIR = '/tmp/sift1b/query' +SIFT_SRC_BINARY_DATA_DIR = '/tmp/sift1b/binary_data' +SIFT_SRC_GROUNDTRUTH_DATA_DIR = '/tmp/sift1b/groundtruth' + +WARM_TOP_K = 1 +WARM_NQ = 1 +DEFAULT_DIM = 512 + +GROUNDTRUTH_MAP = { + "1000000": "idx_1M.ivecs", + "2000000": "idx_2M.ivecs", + "5000000": "idx_5M.ivecs", + "10000000": "idx_10M.ivecs", + "20000000": "idx_20M.ivecs", + "50000000": "idx_50M.ivecs", + "100000000": "idx_100M.ivecs", + "200000000": "idx_200M.ivecs", + "500000000": "idx_500M.ivecs", + "1000000000": "idx_1000M.ivecs", +} + + +def gen_file_name(idx, table_dimension, data_type): + s = "%05d" % idx + fname = FILE_PREFIX + str(table_dimension) + "d_" + s + ".npy" + if data_type == "random": + fname = RANDOM_SRC_BINARY_DATA_DIR+'/'+fname + elif data_type == "sift": + fname = SIFT_SRC_BINARY_DATA_DIR+'/'+fname + return fname + + +def get_vectors_from_binary(nq, dimension, data_type): + # use the first file, nq should be less than VECTORS_PER_FILE + if nq > MAX_NQ: + raise Exception("Over size nq") + if data_type == "random": + file_name = gen_file_name(0, dimension, data_type) + elif data_type == "sift": + file_name = SIFT_SRC_DATA_DIR+'/'+'query.npy' + data = np.load(file_name) + vectors = data[0:nq].tolist() + return vectors + + +class Runner(object): + def __init__(self): + pass + + def do_insert(self, milvus, table_name, data_type, dimension, size, ni): + ''' + @params: + mivlus: server connect instance + dimension: table dimensionn + # index_file_size: size trigger file merge + size: row count of vectors to be insert + ni: row count of vectors to be insert each time + # store_id: if store the ids returned by call add_vectors or not + @return: + total_time: total time for all insert operation + qps: vectors added per second + ni_time: avarage insert operation time + ''' + bi_res = {} + total_time = 0.0 + qps = 0.0 + ni_time = 0.0 + if data_type == "random": + vectors_per_file = VECTORS_PER_FILE + elif data_type == "sift": + vectors_per_file = SIFT_VECTORS_PER_FILE + if size % vectors_per_file or ni > vectors_per_file: + raise Exception("Not invalid table size or ni") + file_num = size // vectors_per_file + for i in range(file_num): + file_name = gen_file_name(i, dimension, data_type) + logger.info("Load npy file: %s start" % file_name) + data = np.load(file_name) + logger.info("Load npy file: %s end" % file_name) + loops = vectors_per_file // ni + for j in range(loops): + vectors = data[j*ni:(j+1)*ni].tolist() + ni_start_time = time.time() + # start insert vectors + start_id = i * vectors_per_file + j * ni + end_id = start_id + len(vectors) + logger.info("Start id: %s, end id: %s" % (start_id, end_id)) + ids = [k for k in range(start_id, end_id)] + status, ids = milvus.insert(vectors, ids=ids) + ni_end_time = time.time() + total_time = total_time + ni_end_time - ni_start_time + + qps = round(size / total_time, 2) + ni_time = round(total_time / (loops * file_num), 2) + bi_res["total_time"] = round(total_time, 2) + bi_res["qps"] = qps + bi_res["ni_time"] = ni_time + return bi_res + + def do_query(self, milvus, table_name, top_ks, nqs, nprobe, run_count): + (data_type, table_size, index_file_size, dimension, metric_type) = parser.table_parser(table_name) + base_query_vectors = get_vectors_from_binary(MAX_NQ, dimension, data_type) + + bi_res = [] + for index, nq in enumerate(nqs): + tmp_res = [] + for top_k in top_ks: + avg_query_time = 0.0 + total_query_time = 0.0 + vectors = base_query_vectors[0:nq] + logger.info("Start query, query params: top-k: {}, nq: {}, actually length of vectors: {}".format(top_k, nq, len(vectors))) + for i in range(run_count): + logger.info("Start run query, run %d of %s" % (i+1, run_count)) + start_time = time.time() + status, query_res = milvus.query(vectors, top_k, nprobe) + total_query_time = total_query_time + (time.time() - start_time) + if status.code: + logger.error("Query failed with message: %s" % status.message) + avg_query_time = round(total_query_time / run_count, 2) + logger.info("Avarage query time: %.2f" % avg_query_time) + tmp_res.append(avg_query_time) + bi_res.append(tmp_res) + return bi_res + + def do_query_ids(self, milvus, table_name, top_k, nq, nprobe): + (data_type, table_size, index_file_size, dimension, metric_type) = parser.table_parser(table_name) + base_query_vectors = get_vectors_from_binary(MAX_NQ, dimension, data_type) + vectors = base_query_vectors[0:nq] + logger.info("Start query, query params: top-k: {}, nq: {}, actually length of vectors: {}".format(top_k, nq, len(vectors))) + status, query_res = milvus.query(vectors, top_k, nprobe) + if not status.OK(): + msg = "Query failed with message: %s" % status.message + raise Exception(msg) + result_ids = [] + for result in query_res: + tmp = [] + for item in result: + tmp.append(item.id) + result_ids.append(tmp) + return result_ids + + def do_query_acc(self, milvus, table_name, top_k, nq, nprobe, id_store_name): + (data_type, table_size, index_file_size, dimension, metric_type) = parser.table_parser(table_name) + base_query_vectors = get_vectors_from_binary(MAX_NQ, dimension, data_type) + vectors = base_query_vectors[0:nq] + logger.info("Start query, query params: top-k: {}, nq: {}, actually length of vectors: {}".format(top_k, nq, len(vectors))) + status, query_res = milvus.query(vectors, top_k, nprobe) + if not status.OK(): + msg = "Query failed with message: %s" % status.message + raise Exception(msg) + # if file existed, cover it + if os.path.isfile(id_store_name): + os.remove(id_store_name) + with open(id_store_name, 'a+') as fd: + for nq_item in query_res: + for item in nq_item: + fd.write(str(item.id)+'\t') + fd.write('\n') + + # compute and print accuracy + def compute_accuracy(self, flat_file_name, index_file_name): + flat_id_list = []; index_id_list = [] + logger.info("Loading flat id file: %s" % flat_file_name) + with open(flat_file_name, 'r') as flat_id_fd: + for line in flat_id_fd: + tmp_list = line.strip("\n").strip().split("\t") + flat_id_list.append(tmp_list) + logger.info("Loading index id file: %s" % index_file_name) + with open(index_file_name) as index_id_fd: + for line in index_id_fd: + tmp_list = line.strip("\n").strip().split("\t") + index_id_list.append(tmp_list) + if len(flat_id_list) != len(index_id_list): + raise Exception("Flat index result length: not match, Acc compute exiting ..." % (len(flat_id_list), len(index_id_list))) + # get the accuracy + return self.get_recall_value(flat_id_list, index_id_list) + + def get_recall_value(self, flat_id_list, index_id_list): + """ + Use the intersection length + """ + sum_radio = 0.0 + for index, item in enumerate(index_id_list): + tmp = set(item).intersection(set(flat_id_list[index])) + sum_radio = sum_radio + len(tmp) / len(item) + return round(sum_radio / len(index_id_list), 3) + + """ + Implementation based on: + https://github.com/facebookresearch/faiss/blob/master/benchs/datasets.py + """ + def get_groundtruth_ids(self, table_size): + fname = GROUNDTRUTH_MAP[str(table_size)] + fname = SIFT_SRC_GROUNDTRUTH_DATA_DIR + "/" + fname + a = np.fromfile(fname, dtype='int32') + d = a[0] + true_ids = a.reshape(-1, d + 1)[:, 1:].copy() + return true_ids diff --git a/tests/milvus_benchmark/suites.yaml b/tests/milvus_benchmark/suites.yaml new file mode 100644 index 0000000000..f30b963c03 --- /dev/null +++ b/tests/milvus_benchmark/suites.yaml @@ -0,0 +1,38 @@ +# data sets +datasets: + hf5: + gist-960,sift-128 + npy: + 50000000-512, 100000000-512 + +operations: + # interface: search_vectors + query: + # dataset: table name you have already created + # key starts with "server." need to reconfig and restart server, including nprpbe/nlist/use_blas_threshold/.. + [ + # debug + # {"dataset": "ip_ivfsq8_1000", "top_ks": [16], "nqs": [1], "server.nprobe": 1, "server.use_blas_threshold": 800, "server.cpu_cache_capacity": 110}, + + {"dataset": "ip_ivfsq8_1000", "top_ks": [1, 2, 4, 8, 16, 32, 64, 128, 256, 512], "nqs": [1, 10, 100, 500, 800, 1000], "server.nprobe": 1, "server.use_blas_threshold": 800, "server.cpu_cache_capacity": 110}, + {"dataset": "ip_ivfsq8_1000", "top_ks": [1, 2, 4, 8, 16, 32, 64, 128, 256, 512], "nqs": [1, 10, 100, 500, 800, 1000], "server.nprobe": 10, "server.use_blas_threshold": 20, "server.cpu_cache_capacity": 110}, + {"dataset": "ip_ivfsq8_5000", "top_ks": [1, 2, 4, 8, 16, 32, 64, 128, 256, 512], "nqs": [1, 10, 100, 500, 800, 1000], "server.nprobe": 1, "server.use_blas_threshold": 800, "server.cpu_cache_capacity": 110}, + {"dataset": "ip_ivfsq8_5000", "top_ks": [1, 2, 4, 8, 16, 32, 64, 128, 256, 512], "nqs": [1, 10, 100, 500, 800, 1000], "server.nprobe": 10, "server.use_blas_threshold": 20, "server.cpu_cache_capacity": 110}, + {"dataset": "ip_ivfsq8_40000", "top_ks": [1, 2, 4, 8, 16, 32, 64, 128, 256, 512], "nqs": [1, 10, 100, 500, 800, 1000], "server.nprobe": 1, "server.use_blas_threshold": 800, "server.cpu_cache_capacity": 110}, + # {"dataset": "ip_ivfsq8_40000", "top_ks": [1, 2, 4, 8, 16, 32, 64, 128, 256], "nqs": [1, 10, 100, 1000], "server.nprobe": 10, "server.use_blas_threshold": 20, "server.cpu_cache_capacity": 110}, + ] + + # interface: add_vectors + insert: + # index_type: flat/ivf_flat/ivf_sq8 + [ + # debug + + + {"table_name": "ip_ivf_flat_20m_1024", "table.index_type": "ivf_flat", "server.index_building_threshold": 1024, "table.size": 20000000, "table.ni": 100000, "table.dim": 512, "server.cpu_cache_capacity": 110}, + {"table_name": "ip_ivf_sq8_50m_1024", "table.index_type": "ivf_sq8", "server.index_building_threshold": 1024, "table.size": 50000000, "table.ni": 100000, "table.dim": 512, "server.cpu_cache_capacity": 110}, + ] + + # TODO: interface: build_index + build: [] + diff --git a/tests/milvus_benchmark/suites_accuracy.yaml b/tests/milvus_benchmark/suites_accuracy.yaml new file mode 100644 index 0000000000..fd8a5904fa --- /dev/null +++ b/tests/milvus_benchmark/suites_accuracy.yaml @@ -0,0 +1,121 @@ + +accuracy: + # interface: search_vectors + query: + [ + { + "dataset": "random_20m_1024_512_ip", + # index info + "index.index_types": ["flat", "ivf_sq8"], + "index.nlists": [16384], + "index.metric_types": ["ip"], + "nprobes": [1, 16, 64], + "top_ks": [64], + "nqs": [100], + "server.cpu_cache_capacity": 100, + "server.resources": ["cpu", "gpu0"], + "db_path_prefix": "/test/milvus/db_data/random_20m_1024_512_ip", + }, + # { + # "dataset": "sift_50m_1024_128_l2", + # # index info + # "index.index_types": ["ivf_sq8h"], + # "index.nlists": [16384], + # "index.metric_types": ["l2"], + # "nprobes": [1, 16, 64], + # "top_ks": [64], + # "nqs": [100], + # "server.cpu_cache_capacity": 160, + # "server.resources": ["cpu", "gpu0"], + # "db_path_prefix": "/test/milvus/db_data/sift_50m_1024_128_l2", + # "sift_acc": true + # }, + # { + # "dataset": "sift_50m_1024_128_l2", + # # index info + # "index.index_types": ["ivf_sq8"], + # "index.nlists": [16384], + # "index.metric_types": ["l2"], + # "nprobes": [1, 16, 64], + # "top_ks": [64], + # "nqs": [100], + # "server.cpu_cache_capacity": 160, + # "server.resources": ["cpu", "gpu0"], + # "db_path_prefix": "/test/milvus/db_data/sift_50m_1024_128_l2_sq8", + # "sift_acc": true + # }, + # { + # "dataset": "sift_1b_2048_128_l2", + # # index info + # "index.index_types": ["ivf_sq8h"], + # "index.nlists": [16384], + # "index.metric_types": ["l2"], + # "nprobes": [1, 16, 64, 128], + # "top_ks": [64], + # "nqs": [100], + # "server.cpu_cache_capacity": 200, + # "server.resources": ["cpu"], + # "db_path_prefix": "/test/milvus/db_data/sift_1b_2048_128_l2_sq8h", + # "sift_acc": true + # }, + # { + # "dataset": "sift_1b_2048_128_l2", + # # index info + # "index.index_types": ["ivf_sq8h"], + # "index.nlists": [16384], + # "index.metric_types": ["l2"], + # "nprobes": [1, 16, 64, 128], + # "top_ks": [64], + # "nqs": [100], + # "server.cpu_cache_capacity": 200, + # "server.resources": ["cpu", "gpu0"], + # "db_path_prefix": "/test/milvus/db_data/sift_1b_2048_128_l2_sq8h", + # "sift_acc": true + # }, + # { + # "dataset": "sift_1b_2048_128_l2", + # # index info + # "index.index_types": ["ivf_sq8h"], + # "index.nlists": [16384], + # "index.metric_types": ["l2"], + # "nprobes": [1, 16, 64, 128], + # "top_ks": [64], + # "nqs": [100], + # "server.cpu_cache_capacity": 200, + # "server.resources": ["cpu", "gpu0", "gpu1"], + # "db_path_prefix": "/test/milvus/db_data/sift_1b_2048_128_l2_sq8h", + # "sift_acc": true + # }, + # { + # "dataset": "sift_1m_1024_128_l2", + # "index.index_types": ["flat", "ivf_sq8"], + # "index.nlists": [16384], + # "nprobes": [1, 32, 128, 256, 512], + # "nqs": 10, + # "top_ks": 10, + # "server.use_blas_threshold": 1100, + # "server.cpu_cache_capacity": 16, + # }, + # { + # "dataset": "sift_10m_1024_128_l2", + # "index.index_types": ["flat", "ivf_sq8"], + # "index.nlists": [16384], + # "nprobes": [1, 32, 128, 256, 512], + # "nqs": 10, + # "top_ks": 10, + # "server.use_blas_threshold": 1100, + # "server.cpu_cache_capacity": 32, + # }, + # { + # "dataset": "sift_50m_1024_128_l2", + # "index.index_types": ["flat", "ivf_sq8"], + # "index.nlists": [16384], + # "nprobes": [1, 32, 128, 256, 512], + # "nqs": 10, + # "top_ks": 10, + # "server.use_blas_threshold": 1100, + # "server.cpu_cache_capacity": 64, + # } + + + ] \ No newline at end of file diff --git a/tests/milvus_benchmark/suites_performance.yaml b/tests/milvus_benchmark/suites_performance.yaml new file mode 100644 index 0000000000..52d457a400 --- /dev/null +++ b/tests/milvus_benchmark/suites_performance.yaml @@ -0,0 +1,258 @@ +performance: + + # interface: add_vectors + insert: + # index_type: flat/ivf_flat/ivf_sq8/mix_nsg + [ + # debug + # data_type / data_size / index_file_size / dimension + # data_type: random / ann_sift + # data_size: 10m / 1b + # { + # "table_name": "random_50m_1024_512_ip", + # "ni_per": 100000, + # "processes": 5, # multiprocessing + # "server.cpu_cache_capacity": 16, + # # "server.resources": ["gpu0", "gpu1"], + # "db_path_prefix": "/test/milvus/db_data" + # }, + # { + # "table_name": "random_5m_1024_512_ip", + # "ni_per": 100000, + # "processes": 5, # multiprocessing + # "server.cpu_cache_capacity": 16, + # "server.resources": ["gpu0", "gpu1"], + # "db_path_prefix": "/test/milvus/db_data/random_5m_1024_512_ip" + # }, + # { + # "table_name": "sift_1m_50_128_l2", + # "ni_per": 100000, + # "processes": 5, # multiprocessing + # # "server.cpu_cache_capacity": 16, + # "db_path_prefix": "/test/milvus/db_data" + # }, + # { + # "table_name": "sift_1m_256_128_l2", + # "ni_per": 100000, + # "processes": 5, # multiprocessing + # # "server.cpu_cache_capacity": 16, + # "db_path_prefix": "/test/milvus/db_data" + # } + # { + # "table_name": "sift_50m_1024_128_l2", + # "ni_per": 100000, + # "processes": 5, # multiprocessing + # # "server.cpu_cache_capacity": 16, + # }, + # { + # "table_name": "sift_100m_1024_128_l2", + # "ni_per": 100000, + # "processes": 5, # multiprocessing + # }, + # { + # "table_name": "sift_1b_2048_128_l2", + # "ni_per": 100000, + # "processes": 5, # multiprocessing + # "server.cpu_cache_capacity": 16, + # } + ] + + # interface: search_vectors + query: + # dataset: table name you have already created + # key starts with "server." need to reconfig and restart server, including use_blas_threshold/cpu_cache_capacity .. + [ + # { + # "dataset": "sift_1b_2048_128_l2", + # # index info + # "index.index_types": ["ivf_sq8h"], + # "index.nlists": [16384], + # "nprobes": [8, 32], + # "top_ks": [1, 8, 16, 32, 64, 128, 256, 512, 1000], + # "nqs": [1, 10, 100, 500, 1000], + # "processes": 1, # multiprocessing + # "server.use_blas_threshold": 1100, + # "server.cpu_cache_capacity": 200, + # "server.resources": ["cpu", "gpu0"], + # "db_path_prefix": "/test/milvus/db_data/sift_1b_2048_128_l2_sq8h" + # }, + # { + # "dataset": "sift_1b_2048_128_l2", + # # index info + # "index.index_types": ["ivf_sq8"], + # "index.nlists": [16384], + # "nprobes": [8, 32], + # "top_ks": [1, 8, 16, 32, 64, 128, 256, 512, 1000], + # "nqs": [1, 10, 100, 500, 1000], + # "processes": 1, # multiprocessing + # "server.use_blas_threshold": 1100, + # "server.cpu_cache_capacity": 200, + # "server.resources": ["cpu", "gpu0"], + # "db_path_prefix": "/test/milvus/db_data/sift_1b_2048_128_l2" + # }, + # { + # "dataset": "sift_1b_2048_128_l2", + # # index info + # "index.index_types": ["ivf_sq8h"], + # "index.nlists": [16384], + # "nprobes": [8, 32], + # "top_ks": [1, 8, 16, 32, 64, 128, 256, 512, 1000], + # "nqs": [1, 10, 100, 500, 1000], + # "processes": 1, # multiprocessing + # "server.use_blas_threshold": 1100, + # "server.cpu_cache_capacity": 200, + # "server.resources": ["cpu"], + # "db_path_prefix": "/test/milvus/db_data" + # }, + { + "dataset": "random_50m_1024_512_ip", + "index.index_types": ["ivf_sq8h"], + "index.nlists": [16384], + "nprobes": [8], + # "top_ks": [1, 8, 16, 32, 64, 128, 256, 512, 1000], + "top_ks": [512], + # "nqs": [1, 10, 100, 500, 1000], + "nqs": [500], + "server.use_blas_threshold": 1100, + "server.cpu_cache_capacity": 150, + "server.gpu_cache_capacity": 6, + "server.resources": ["cpu", "gpu0", "gpu1"], + "db_path_prefix": "/test/milvus/db_data/random_50m_1024_512_ip" + }, + # { + # "dataset": "random_50m_1024_512_ip", + # "index.index_types": ["ivf_sq8"], + # "index.nlists": [16384], + # "nprobes": [8, 32], + # "top_ks": [1, 8, 16, 32, 64, 128, 256, 512, 1000], + # "nqs": [1, 10, 100, 500, 1000], + # "server.use_blas_threshold": 1100, + # "server.cpu_cache_capacity": 150, + # "server.resources": ["cpu", "gpu0", "gpu1"], + # "db_path_prefix": "/test/milvus/db_data/random_50m_1024_512_ip_sq8" + # }, + # { + # "dataset": "random_20m_1024_512_ip", + # "index.index_types": ["flat"], + # "index.nlists": [16384], + # "nprobes": [50], + # "top_ks": [64], + # "nqs": [10], + # "server.use_blas_threshold": 1100, + # "server.cpu_cache_capacity": 100, + # "server.resources": ["cpu", "gpu0", "gpu1"], + # "db_path_prefix": "/test/milvus/db_data/random_20m_1024_512_ip" + # }, + # { + # "dataset": "random_100m_1024_512_ip", + # # index info + # "index.index_types": ["ivf_sq8"], + # "index.nlists": [16384], + # "nprobes": [8, 32], + # "top_ks": [1, 8, 16, 32, 64, 128, 256, 512, 1000], + # "nqs": [1, 10, 100, 500, 1000], + # "processes": 1, # multiprocessing + # "server.use_blas_threshold": 1100, + # "server.cpu_cache_capacity": 250, + # "server.resources": ["cpu", "gpu0"], + # "db_path_prefix": "/test/milvus/db_data" + # }, + # { + # "dataset": "random_100m_1024_512_ip", + # # index info + # "index.index_types": ["ivf_sq8"], + # "index.nlists": [16384], + # "nprobes": [8, 32], + # "top_ks": [1, 8, 16, 32, 64, 128, 256, 512, 1000], + # "nqs": [1, 10, 100, 500, 1000], + # "processes": 1, # multiprocessing + # "server.use_blas_threshold": 1100, + # "server.cpu_cache_capacity": 250, + # "server.resources": ["cpu"], + # "db_path_prefix": "/test/milvus/db_data" + # }, + # { + # "dataset": "random_10m_1024_512_ip", + # # index info + # "index.index_types": ["ivf_sq8"], + # "index.nlists": [16384], + # "nprobes": [1], + # "top_ks": [1, 2, 4, 8, 16, 32, 64, 128, 256], + # "nqs": [1, 10, 100, 500, 800], + # "processes": 1, # multiprocessing + # # "server.use_blas_threshold": 1100, + # "server.cpu_cache_capacity": 16, + # }, + # { + # "dataset": "random_10m_1024_512_l2", + # # index info + # "index.index_types": ["ivf_sq8"], + # "index.nlists": [16384], + # "nprobes": [1], + # "top_ks": [1, 2, 4, 8, 16, 32, 64, 128, 256], + # "nqs": [1, 10, 100, 500, 800], + # "processes": 1, # multiprocessing + # "server.use_blas_threshold": 1100, + # "server.cpu_cache_capacity": 64 + # }, + # { + # "dataset": "sift_500m_1024_128_l2", + # # index info + # "index.index_types": ["ivf_sq8"], + # "index.nlists": [16384], + # "nprobes": [1], + # "top_ks": [1, 8, 16, 64, 256, 512, 1000], + # "nqs": [1, 100, 500, 800, 1000, 1500], + # # "top_ks": [256], + # # "nqs": [800], + # "processes": 1, # multiprocessing + # # "server.use_blas_threshold": 1100, + # "server.cpu_cache_capacity": 120, + # "server.resources": ["gpu0", "gpu1"], + # "db_path_prefix": "/test/milvus/db_data" + # }, + # { + # "dataset": "sift_1b_2048_128_l2", + # # index info + # "index.index_types": ["ivf_sq8h"], + # "index.nlists": [16384], + # "nprobes": [1], + # # "top_ks": [1], + # # "nqs": [1], + # "top_ks": [256], + # "nqs": [800], + # "processes": 1, # multiprocessing + # "server.use_blas_threshold": 1100, + # "server.cpu_cache_capacity": 110, + # "server.resources": ["cpu", "gpu0"], + # "db_path_prefix": "/test/milvus/db_data" + # }, + # { + # "dataset": "random_50m_1024_512_l2", + # # index info + # "index.index_types": ["ivf_sq8"], + # "index.nlists": [16384], + # "nprobes": [1], + # "top_ks": [1, 2, 4, 8, 16, 32, 64, 128, 256], + # "nqs": [1, 10, 100, 500, 800], + # # "top_ks": [256], + # # "nqs": [800], + # "processes": 1, # multiprocessing + # "server.use_blas_threshold": 1100, + # "server.cpu_cache_capacity": 128 + # }, + # [ + # { + # "dataset": "sift_1m_50_128_l2", + # # index info + # "index.index_types": ["ivf_sq8"], + # "index.nlists": [16384], + # "nprobes": [1], + # "top_ks": [1], + # "nqs": [1], + # "db_path_prefix": "/test/milvus/db_data" + # # "processes": 1, # multiprocessing + # # "server.use_blas_threshold": 1100, + # # "server.cpu_cache_capacity": 256 + # } + ] \ No newline at end of file diff --git a/tests/milvus_benchmark/suites_stability.yaml b/tests/milvus_benchmark/suites_stability.yaml new file mode 100644 index 0000000000..408221079e --- /dev/null +++ b/tests/milvus_benchmark/suites_stability.yaml @@ -0,0 +1,17 @@ + +stability: + # interface: search_vectors / add_vectors mix operation + query: + [ + { + "dataset": "random_20m_1024_512_ip", + # "nqs": [1, 10, 100, 1000, 10000], + # "pds": [0.1, 0.44, 0.44, 0.02], + "query_process_num": 10, + # each 10s, do an insertion + # "insert_interval": 1, + # minutes + "during_time": 360, + "server.cpu_cache_capacity": 100 + }, + ] \ No newline at end of file diff --git a/tests/milvus_benchmark/suites_yzb.yaml b/tests/milvus_benchmark/suites_yzb.yaml new file mode 100644 index 0000000000..59efcb37e4 --- /dev/null +++ b/tests/milvus_benchmark/suites_yzb.yaml @@ -0,0 +1,171 @@ +#"server.resources": ["gpu0", "gpu1"] + +performance: + # interface: search_vectors + query: + # dataset: table name you have already created + # key starts with "server." need to reconfig and restart server, including use_blas_threshold/cpu_cache_capacity .. + [ + # debug + # { + # "dataset": "random_10m_1024_512_ip", + # # index info + # "index.index_types": ["ivf_sq8"], + # "index.nlists": [16384], + # "nprobes": [1], + # "top_ks": [1, 2, 4, 8, 16, 32, 64, 128, 256], + # "nqs": [1, 10, 100, 500, 800], + # "processes": 1, # multiprocessing + # "server.use_blas_threshold": 1100, + # "server.cpu_cache_capacity": 16, + # }, + # { + # "dataset": "random_10m_1024_512_ip", + # # index info + # "index.index_types": ["ivf_sq8"], + # "index.nlists": [16384], + # "nprobes": [1], + # "top_ks": [1, 2, 4, 8, 16, 32, 64, 128, 256], + # "nqs": [1, 10, 100, 500, 800], + # "processes": 1, # multiprocessing + # "server.use_blas_threshold": 1100, + # "server.cpu_cache_capacity": 16, + # }, + # { + # "dataset": "random_10m_1024_512_ip", + # # index info + # "index.index_types": ["ivf_sq8"], + # "index.nlists": [16384], + # "nprobes": [1], + # "top_ks": [1, 2, 4, 8, 16, 32, 64, 128, 256], + # "nqs": [1, 10, 100, 500, 800], + # "processes": 1, # multiprocessing + # # "server.use_blas_threshold": 1100, + # "server.cpu_cache_capacity": 16, + # }, + # { + # "dataset": "random_10m_1024_512_ip", + # # index info + # "index.index_types": ["ivf_sq8"], + # "index.nlists": [16384], + # "nprobes": [1], + # "top_ks": [1, 2, 4, 8, 16, 32, 64, 128, 256], + # "nqs": [1, 10, 100, 500, 800], + # "processes": 1, # multiprocessing + # # "server.use_blas_threshold": 1100, + # "server.cpu_cache_capacity": 16, + # }, + # { + # "dataset": "random_10m_1024_512_l2", + # # index info + # "index.index_types": ["ivf_sq8"], + # "index.nlists": [16384], + # "nprobes": [1], + # "top_ks": [1, 2, 4, 8, 16, 32, 64, 128, 256], + # "nqs": [1, 10, 100, 500, 800], + # "processes": 1, # multiprocessing + # "server.use_blas_threshold": 1100, + # "server.cpu_cache_capacity": 64 + # }, +# { +# "dataset": "sift_50m_1024_128_l2", +# # index info +# "index.index_types": ["ivf_sq8"], +# "index.nlists": [16384], +# "nprobes": [1, 32, 128], +# "top_ks": [1, 2, 4, 8, 16, 32, 64, 128, 256], +# "nqs": [1, 10, 100, 500, 800], +# # "top_ks": [256], +# # "nqs": [800], +# "processes": 1, # multiprocessing +# "server.use_blas_threshold": 1100, +# "server.cpu_cache_capacity": 310, +# "server.resources": ["gpu0", "gpu1"] +# }, + { + "dataset": "sift_1m_1024_128_l2", + # index info + "index.index_types": ["ivf_sq8"], + "index.nlists": [16384], + "nprobes": [32], + "top_ks": [10], + "nqs": [100], + # "top_ks": [256], + # "nqs": [800], + "processes": 1, # multiprocessing + "server.use_blas_threshold": 1100, + "server.cpu_cache_capacity": 310, + "server.resources": ["cpu"] + }, + { + "dataset": "sift_1m_1024_128_l2", + # index info + "index.index_types": ["ivf_sq8"], + "index.nlists": [16384], + "nprobes": [32], + "top_ks": [10], + "nqs": [100], + # "top_ks": [256], + # "nqs": [800], + "processes": 1, # multiprocessing + "server.use_blas_threshold": 1100, + "server.cpu_cache_capacity": 310, + "server.resources": ["gpu0"] + }, + { + "dataset": "sift_1m_1024_128_l2", + # index info + "index.index_types": ["ivf_sq8"], + "index.nlists": [16384], + "nprobes": [32], + "top_ks": [10], + "nqs": [100], + # "top_ks": [256], + # "nqs": [800], + "processes": 1, # multiprocessing + "server.use_blas_threshold": 1100, + "server.cpu_cache_capacity": 310, + "server.resources": ["gpu0", "gpu1"] + }, + # { + # "dataset": "sift_1b_2048_128_l2", + # # index info + # "index.index_types": ["ivf_sq8"], + # "index.nlists": [16384], + # "nprobes": [1], + # "top_ks": [1, 2, 4, 8, 16, 32, 64, 128, 256], + # "nqs": [1, 10, 100, 500, 800], + # # "top_ks": [256], + # # "nqs": [800], + # "processes": 1, # multiprocessing + # "server.use_blas_threshold": 1100, + # "server.cpu_cache_capacity": 310 + # }, +# { +# "dataset": "random_50m_1024_512_l2", +# # index info +# "index.index_types": ["ivf_sq8"], +# "index.nlists": [16384], +# "nprobes": [1], +# "top_ks": [1, 2, 4, 8, 16, 32, 64, 128, 256], +# "nqs": [1, 10, 100, 500, 800], +# # "top_ks": [256], +# # "nqs": [800], +# "processes": 1, # multiprocessing +# "server.use_blas_threshold": 1100, +# "server.cpu_cache_capacity": 128, +# "server.resources": ["gpu0", "gpu1"] +# }, + # { + # "dataset": "random_100m_1024_512_ip", + # # index info + # "index.index_types": ["ivf_sq8"], + # "index.nlists": [16384], + # "nprobes": [1], + # "top_ks": [1, 2, 4, 8, 16, 32, 64, 128, 256], + # "nqs": [1, 10, 100, 500, 800], + # "processes": 1, # multiprocessing + # "server.use_blas_threshold": 1100, + # "server.cpu_cache_capacity": 256 + # }, + ] \ No newline at end of file diff --git a/tests/milvus_benchmark/utils.py b/tests/milvus_benchmark/utils.py new file mode 100644 index 0000000000..f6522578ad --- /dev/null +++ b/tests/milvus_benchmark/utils.py @@ -0,0 +1,194 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function + +__true_print = print # noqa + +import os +import sys +import pdb +import time +import datetime +import argparse +import threading +import logging +import docker +import multiprocessing +import numpy +# import psutil +from yaml import load, dump +import tableprint as tp + +logger = logging.getLogger("milvus_benchmark.utils") + +MULTI_DB_SLAVE_PATH = "/opt/milvus/data2;/opt/milvus/data3" + + +def get_current_time(): + return time.strftime('%Y-%m-%d %H:%M:%S', time.localtime()) + + +def print_table(headers, columns, data): + bodys = [] + for index, value in enumerate(columns): + tmp = [value] + tmp.extend(data[index]) + bodys.append(tmp) + tp.table(bodys, headers) + + +def modify_config(k, v, type=None, file_path="conf/server_config.yaml", db_slave=None): + if not os.path.isfile(file_path): + raise Exception('File: %s not found' % file_path) + with open(file_path) as f: + config_dict = load(f) + f.close() + if config_dict: + if k.find("use_blas_threshold") != -1: + config_dict['engine_config']['use_blas_threshold'] = int(v) + elif k.find("cpu_cache_capacity") != -1: + config_dict['cache_config']['cpu_cache_capacity'] = int(v) + elif k.find("gpu_cache_capacity") != -1: + config_dict['cache_config']['gpu_cache_capacity'] = int(v) + elif k.find("resource_pool") != -1: + config_dict['resource_config']['resource_pool'] = v + + if db_slave: + config_dict['db_config']['db_slave_path'] = MULTI_DB_SLAVE_PATH + with open(file_path, 'w') as f: + dump(config_dict, f, default_flow_style=False) + f.close() + else: + raise Exception('Load file:%s error' % file_path) + + +def pull_image(image): + registry = image.split(":")[0] + image_tag = image.split(":")[1] + client = docker.APIClient(base_url='unix://var/run/docker.sock') + logger.info("Start pulling image: %s" % image) + return client.pull(registry, image_tag) + + +def run_server(image, mem_limit=None, timeout=30, test_type="local", volume_name=None, db_slave=None): + import colors + + client = docker.from_env() + # if mem_limit is None: + # mem_limit = psutil.virtual_memory().available + # logger.info('Memory limit:', mem_limit) + # cpu_limit = "0-%d" % (multiprocessing.cpu_count() - 1) + # logger.info('Running on CPUs:', cpu_limit) + for dir_item in ['logs', 'db']: + try: + os.mkdir(os.path.abspath(dir_item)) + except Exception as e: + pass + + if test_type == "local": + volumes = { + os.path.abspath('conf'): + {'bind': '/opt/milvus/conf', 'mode': 'ro'}, + os.path.abspath('logs'): + {'bind': '/opt/milvus/logs', 'mode': 'rw'}, + os.path.abspath('db'): + {'bind': '/opt/milvus/db', 'mode': 'rw'}, + } + elif test_type == "remote": + if volume_name is None: + raise Exception("No volume name") + remote_log_dir = volume_name+'/logs' + remote_db_dir = volume_name+'/db' + + for dir_item in [remote_log_dir, remote_db_dir]: + if not os.path.isdir(dir_item): + os.makedirs(dir_item, exist_ok=True) + volumes = { + os.path.abspath('conf'): + {'bind': '/opt/milvus/conf', 'mode': 'ro'}, + remote_log_dir: + {'bind': '/opt/milvus/logs', 'mode': 'rw'}, + remote_db_dir: + {'bind': '/opt/milvus/db', 'mode': 'rw'} + } + # add volumes + if db_slave and isinstance(db_slave, int): + for i in range(2, db_slave+1): + remote_db_dir = volume_name+'/data'+str(i) + if not os.path.isdir(remote_db_dir): + os.makedirs(remote_db_dir, exist_ok=True) + volumes[remote_db_dir] = {'bind': '/opt/milvus/data'+str(i), 'mode': 'rw'} + + container = client.containers.run( + image, + volumes=volumes, + runtime="nvidia", + ports={'19530/tcp': 19530, '8080/tcp': 8080}, + environment=["OMP_NUM_THREADS=48"], + # cpuset_cpus=cpu_limit, + # mem_limit=mem_limit, + # environment=[""], + detach=True) + + def stream_logs(): + for line in container.logs(stream=True): + logger.info(colors.color(line.decode().rstrip(), fg='blue')) + + if sys.version_info >= (3, 0): + t = threading.Thread(target=stream_logs, daemon=True) + else: + t = threading.Thread(target=stream_logs) + t.daemon = True + t.start() + + logger.info('Container: %s started' % container) + return container + # exit_code = container.wait(timeout=timeout) + # # Exit if exit code + # if exit_code == 0: + # return container + # elif exit_code is not None: + # print(colors.color(container.logs().decode(), fg='red')) + # raise Exception('Child process raised exception %s' % str(exit_code)) + +def restart_server(container): + client = docker.APIClient(base_url='unix://var/run/docker.sock') + + client.restart(container.name) + logger.info('Container: %s restarted' % container.name) + return container + + +def remove_container(container): + container.remove(force=True) + logger.info('Container: %s removed' % container) + + +def remove_all_containers(image): + client = docker.from_env() + try: + for container in client.containers.list(): + if image in container.image.tags: + container.stop(timeout=30) + container.remove(force=True) + except Exception as e: + logger.error("Containers removed failed") + + +def container_exists(image): + ''' + Check if container existed with the given image name + @params: image name + @return: container if exists + ''' + res = False + client = docker.from_env() + for container in client.containers.list(): + if image in container.image.tags: + # True + res = container + return res + + +if __name__ == '__main__': + # print(pull_image('branch-0.3.1-debug')) + stop_server() \ No newline at end of file diff --git a/tests/milvus_python_test/.dockerignore b/tests/milvus_python_test/.dockerignore new file mode 100644 index 0000000000..c97d9d043c --- /dev/null +++ b/tests/milvus_python_test/.dockerignore @@ -0,0 +1,14 @@ +node_modules +npm-debug.log +Dockerfile* +docker-compose* +.dockerignore +.git +.gitignore +.env +*/bin +*/obj +README.md +LICENSE +.vscode +__pycache__ \ No newline at end of file diff --git a/tests/milvus_python_test/.gitignore b/tests/milvus_python_test/.gitignore new file mode 100644 index 0000000000..9bd7345e51 --- /dev/null +++ b/tests/milvus_python_test/.gitignore @@ -0,0 +1,13 @@ +.python-version +.pytest_cache +__pycache__ +.vscode +.idea + +test_out/ +*.pyc + +db/ +logs/ + +.coverage diff --git a/tests/milvus_python_test/Dockerfile b/tests/milvus_python_test/Dockerfile new file mode 100644 index 0000000000..ec78b943dc --- /dev/null +++ b/tests/milvus_python_test/Dockerfile @@ -0,0 +1,14 @@ +FROM python:3.6.8-jessie + +LABEL Name=megasearch_engine_test Version=0.0.1 + +WORKDIR /app +ADD . /app + +RUN apt-get update && apt-get install -y --no-install-recommends \ + libc-dev build-essential && \ + python3 -m pip install -r requirements.txt && \ + apt-get remove --purge -y + +ENTRYPOINT [ "/app/docker-entrypoint.sh" ] +CMD [ "start" ] \ No newline at end of file diff --git a/tests/milvus_python_test/MilvusCases.md b/tests/milvus_python_test/MilvusCases.md new file mode 100644 index 0000000000..ea6372c373 --- /dev/null +++ b/tests/milvus_python_test/MilvusCases.md @@ -0,0 +1,143 @@ +# Milvus test cases + +## * Interfaces test + +### 1. 连接测试 + +#### 1.1 连接 + +| cases | expected | +| ---------------- | -------------------------------------------- | +| 非法IP 123.0.0.2 | method: connect raise error in given timeout | +| 正常 uri | attr: connected assert true | +| 非法 uri | method: connect raise error in given timeout | +| 最大连接数 | all connection attrs: connected assert true | +| | | + +#### 1.2 断开连接 + +| cases | expected | +| ------------------------ | ------------------- | +| 正常连接下,断开连接 | connect raise error | +| 正常连接下,重复断开连接 | connect raise error | + +### 2. Table operation + +#### 2.1 表创建 + +##### 2.1.1 表名 + +| cases | expected | +| ------------------------- | ----------- | +| 基础功能,参数正常 | status pass | +| 表名已存在 | status fail | +| 表名:"中文" | status pass | +| 表名带特殊字符: "-39fsd-" | status pass | +| 表名带空格: "test1 2" | status pass | +| invalid dim: 0 | raise error | +| invalid dim: -1 | raise error | +| invalid dim: 100000000 | raise error | +| invalid dim: "string" | raise error | +| index_type: 0 | status pass | +| index_type: 1 | status pass | +| index_type: 2 | status pass | +| index_type: string | raise error | +| | | + +##### 2.1.2 维数支持 + +| cases | expected | +| --------------------- | ----------- | +| 维数: 0 | raise error | +| 维数负数: -1 | raise error | +| 维数最大值: 100000000 | raise error | +| 维数字符串: "string" | raise error | +| | | + +##### 2.1.3 索引类型支持 + +| cases | expected | +| ---------------- | ----------- | +| 索引类型: 0 | status pass | +| 索引类型: 1 | status pass | +| 索引类型: 2 | status pass | +| 索引类型: string | raise error | +| | | + +#### 2.2 表说明 + +| cases | expected | +| ---------------------- | -------------------------------- | +| 创建表后,执行describe | 返回结构体,元素与创建表参数一致 | +| | | + +#### 2.3 表删除 + +| cases | expected | +| -------------- | ---------------------- | +| 删除已存在表名 | has_table return False | +| 删除不存在表名 | status fail | +| | | + +#### 2.4 表是否存在 + +| cases | expected | +| ----------------------- | ------------ | +| 存在表,调用has_table | assert true | +| 不存在表,调用has_table | assert false | +| | | + +#### 2.5 查询表记录条数 + +| cases | expected | +| -------------------- | ------------------------ | +| 空表 | 0 | +| 空表插入数据(单条) | 1 | +| 空表插入数据(多条) | assert length of vectors | + +#### 2.6 查询表数量 + +| cases | expected | +| --------------------------------------------- | -------------------------------- | +| 两张表,一张空表,一张有数据:调用show tables | assert length of table list == 2 | +| | | + +### 3. Add vectors + +| interfaces | cases | expected | +| ----------- | --------------------------------------------------------- | ------------------------------------ | +| add_vectors | add basic | assert length of ids == nq | +| | add vectors into table not existed | status fail | +| | dim not match: single vector | status fail | +| | dim not match: vector list | status fail | +| | single vector element empty | status fail | +| | vector list element empty | status fail | +| | query immediately after adding | status pass | +| | query immediately after sleep 6s | status pass && length of result == 1 | +| | concurrent add with multi threads(share one connection) | status pass | +| | concurrent add with multi threads(independent connection) | status pass | +| | concurrent add with multi process(independent connection) | status pass | +| | index_type: 2 | status pass | +| | index_type: string | raise error | +| | | | + +### 4. Search vectors + +| interfaces | cases | expected | +| -------------- | ------------------------------------------------- | -------------------------------- | +| search_vectors | search basic(query vector in vectors, top-k nq | assert length of result == nq | +| | concurrent search | status pass | +| | query_range(get_current_day(), get_current_day()) | assert length of result == nq | +| | invalid query_range: "" | raise error | +| | query_range(get_last_day(2), get_last_day(1)) | assert length of result == 0 | +| | query_range(get_last_day(2), get_current_day()) | assert length of result == nq | +| | query_range((get_last_day(2), get_next_day(2)) | assert length of result == nq | +| | query_range((get_current_day(), get_next_day(2)) | assert length of result == nq | +| | query_range(get_next_day(1), get_next_day(2)) | assert length of result == 0 | +| | score: vector[i] = vector[i]+-0.01 | score > 99.9 | \ No newline at end of file diff --git a/tests/milvus_python_test/README.md b/tests/milvus_python_test/README.md new file mode 100644 index 0000000000..6a87bf1ff8 --- /dev/null +++ b/tests/milvus_python_test/README.md @@ -0,0 +1,23 @@ +# Requirements +* python 3.6.8+ +* pip install -r requirements.txt + +# How to use this Test Project +```shell +pytest . --level=1 +``` +or test connect function only + +```shell +pytest test_connect.py --level=1 +``` + +with allure test report + + ```shell +pytest --alluredir=test_out . -q -v +allure serve test_out + ``` +# Contribution getting started +* Follow PEP-8 for naming and black for formatting. + diff --git a/tests/milvus_python_test/conftest.py b/tests/milvus_python_test/conftest.py new file mode 100644 index 0000000000..8bab824606 --- /dev/null +++ b/tests/milvus_python_test/conftest.py @@ -0,0 +1,132 @@ +import socket +import pdb +import logging + +import pytest +from utils import gen_unique_str +from milvus import Milvus, IndexType, MetricType + +index_file_size = 10 + + +def pytest_addoption(parser): + parser.addoption("--ip", action="store", default="localhost") + parser.addoption("--port", action="store", default=19530) + parser.addoption("--internal", action="store", default=False) + + +def check_server_connection(request): + ip = request.config.getoption("--ip") + port = request.config.getoption("--port") + connected = True + if ip and (ip not in ['localhost', '127.0.0.1']): + try: + socket.getaddrinfo(ip, port, 0, 0, socket.IPPROTO_TCP) + except Exception as e: + print("Socket connnet failed: %s" % str(e)) + connected = False + return connected + + +def get_args(request): + args = { + "ip": request.config.getoption("--ip"), + "port": request.config.getoption("--port") + } + return args + + +@pytest.fixture(scope="module") +def connect(request): + ip = request.config.getoption("--ip") + port = request.config.getoption("--port") + milvus = Milvus() + try: + milvus.connect(host=ip, port=port) + except: + pytest.exit("Milvus server can not connected, exit pytest ...") + + def fin(): + try: + milvus.disconnect() + except: + pass + + request.addfinalizer(fin) + return milvus + + +@pytest.fixture(scope="module") +def dis_connect(request): + ip = request.config.getoption("--ip") + port = request.config.getoption("--port") + milvus = Milvus() + milvus.connect(host=ip, port=port) + milvus.disconnect() + def fin(): + try: + milvus.disconnect() + except: + pass + + request.addfinalizer(fin) + return milvus + + +@pytest.fixture(scope="module") +def args(request): + ip = request.config.getoption("--ip") + port = request.config.getoption("--port") + internal = request.config.getoption("--internal") + args = {"ip": ip, "port": port} + if internal: + args = {"ip": ip, "port": port, "internal": internal} + return args + + +@pytest.fixture(scope="function") +def table(request, connect): + ori_table_name = getattr(request.module, "table_id", "test") + table_name = gen_unique_str(ori_table_name) + dim = getattr(request.module, "dim", "128") + param = {'table_name': table_name, + 'dimension': dim, + 'index_file_size': index_file_size, + 'metric_type': MetricType.L2} + status = connect.create_table(param) + # logging.getLogger().info(status) + if not status.OK(): + pytest.exit("Table can not be created, exit pytest ...") + + def teardown(): + status, table_names = connect.show_tables() + for table_name in table_names: + connect.delete_table(table_name) + + request.addfinalizer(teardown) + + return table_name + + +@pytest.fixture(scope="function") +def ip_table(request, connect): + ori_table_name = getattr(request.module, "table_id", "test") + table_name = gen_unique_str(ori_table_name) + dim = getattr(request.module, "dim", "128") + param = {'table_name': table_name, + 'dimension': dim, + 'index_file_size': index_file_size, + 'metric_type': MetricType.IP} + status = connect.create_table(param) + # logging.getLogger().info(status) + if not status.OK(): + pytest.exit("Table can not be created, exit pytest ...") + + def teardown(): + status, table_names = connect.show_tables() + for table_name in table_names: + connect.delete_table(table_name) + + request.addfinalizer(teardown) + + return table_name \ No newline at end of file diff --git a/tests/milvus_python_test/docker-entrypoint.sh b/tests/milvus_python_test/docker-entrypoint.sh new file mode 100755 index 0000000000..af9ba0ba66 --- /dev/null +++ b/tests/milvus_python_test/docker-entrypoint.sh @@ -0,0 +1,9 @@ +#!/bin/bash + +set -e + +if [ "$1" = 'start' ]; then + tail -f /dev/null +fi + +exec "$@" \ No newline at end of file diff --git a/tests/milvus_python_test/pytest.ini b/tests/milvus_python_test/pytest.ini new file mode 100644 index 0000000000..3f95dc29b8 --- /dev/null +++ b/tests/milvus_python_test/pytest.ini @@ -0,0 +1,9 @@ +[pytest] +log_format = [%(asctime)s-%(levelname)s-%(name)s]: %(message)s (%(filename)s:%(lineno)s) + +log_cli = true +log_level = 20 + +timeout = 300 + +level = 1 \ No newline at end of file diff --git a/tests/milvus_python_test/requirements.txt b/tests/milvus_python_test/requirements.txt new file mode 100644 index 0000000000..4bdecd6033 --- /dev/null +++ b/tests/milvus_python_test/requirements.txt @@ -0,0 +1,25 @@ +astroid==2.2.5 +atomicwrites==1.3.0 +attrs==19.1.0 +importlib-metadata==0.15 +isort==4.3.20 +lazy-object-proxy==1.4.1 +mccabe==0.6.1 +more-itertools==7.0.0 +numpy==1.16.3 +pluggy==0.12.0 +py==1.8.0 +pylint==2.3.1 +pytest==4.5.0 +pytest-timeout==1.3.3 +pytest-repeat==0.8.0 +allure-pytest==2.7.0 +pytest-print==0.1.2 +pytest-level==0.1.1 +six==1.12.0 +thrift==0.11.0 +typed-ast==1.3.5 +wcwidth==0.1.7 +wrapt==1.11.1 +zipp==0.5.1 +pymilvus-test>=0.2.0 diff --git a/tests/milvus_python_test/requirements_cluster.txt b/tests/milvus_python_test/requirements_cluster.txt new file mode 100644 index 0000000000..a1f3b69be9 --- /dev/null +++ b/tests/milvus_python_test/requirements_cluster.txt @@ -0,0 +1,25 @@ +astroid==2.2.5 +atomicwrites==1.3.0 +attrs==19.1.0 +importlib-metadata==0.15 +isort==4.3.20 +lazy-object-proxy==1.4.1 +mccabe==0.6.1 +more-itertools==7.0.0 +numpy==1.16.3 +pluggy==0.12.0 +py==1.8.0 +pylint==2.3.1 +pytest==4.5.0 +pytest-timeout==1.3.3 +pytest-repeat==0.8.0 +allure-pytest==2.7.0 +pytest-print==0.1.2 +pytest-level==0.1.1 +six==1.12.0 +thrift==0.11.0 +typed-ast==1.3.5 +wcwidth==0.1.7 +wrapt==1.11.1 +zipp==0.5.1 +pymilvus>=0.1.24 diff --git a/tests/milvus_python_test/requirements_no_pymilvus.txt b/tests/milvus_python_test/requirements_no_pymilvus.txt new file mode 100644 index 0000000000..45884c0c71 --- /dev/null +++ b/tests/milvus_python_test/requirements_no_pymilvus.txt @@ -0,0 +1,24 @@ +astroid==2.2.5 +atomicwrites==1.3.0 +attrs==19.1.0 +importlib-metadata==0.15 +isort==4.3.20 +lazy-object-proxy==1.4.1 +mccabe==0.6.1 +more-itertools==7.0.0 +numpy==1.16.3 +pluggy==0.12.0 +py==1.8.0 +pylint==2.3.1 +pytest==4.5.0 +pytest-timeout==1.3.3 +pytest-repeat==0.8.0 +allure-pytest==2.7.0 +pytest-print==0.1.2 +pytest-level==0.1.1 +six==1.12.0 +thrift==0.11.0 +typed-ast==1.3.5 +wcwidth==0.1.7 +wrapt==1.11.1 +zipp==0.5.1 diff --git a/tests/milvus_python_test/run.sh b/tests/milvus_python_test/run.sh new file mode 100644 index 0000000000..cee5b061f5 --- /dev/null +++ b/tests/milvus_python_test/run.sh @@ -0,0 +1,4 @@ +#/bin/bash + + +pytest . $@ \ No newline at end of file diff --git a/tests/milvus_python_test/test.template b/tests/milvus_python_test/test.template new file mode 100644 index 0000000000..0403d72860 --- /dev/null +++ b/tests/milvus_python_test/test.template @@ -0,0 +1,41 @@ +''' +Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved +Unauthorized copying of this file, via any medium is strictly prohibited. +Proprietary and confidential. +''' + +''' +Test Description: + +This document is only a template to show how to write a auto-test script + +本文档仅仅是个展示如何编写自动化测试脚本的模板 + +''' + +import pytest +from milvus import Milvus + + +class TestConnection: + def test_connect_localhost(self): + + """ + TestCase1.1 + Test target: This case is to check if the server can be connected. + Test method: Call API: milvus.connect to connect local milvus server, ip address: 127.0.0.1 and port: 19530, check the return status + Expectation: Return status is OK. + + 测试目的:本用例测试客户端是否可以与服务器建立连接 + 测试方法:调用SDK API: milvus.connect方法连接本地服务器,IP地址:127.0.0.1,端口:19530,检查调用返回状态 + 期望结果:返回状态是:OK + + """ + + milvus = Milvus() + milvus.connect(host='127.0.0.1', port='19530') + assert milvus.connected + + + + diff --git a/tests/milvus_python_test/test_add_vectors.py b/tests/milvus_python_test/test_add_vectors.py new file mode 100644 index 0000000000..51c12dcd87 --- /dev/null +++ b/tests/milvus_python_test/test_add_vectors.py @@ -0,0 +1,1231 @@ +import time +import random +import pdb +import threading +import logging +from multiprocessing import Pool, Process +import pytest +from milvus import Milvus, IndexType, MetricType +from utils import * + + +dim = 128 +index_file_size = 10 +table_id = "test_add" +ADD_TIMEOUT = 60 +nprobe = 1 +epsilon = 0.0001 + +index_params = random.choice(gen_index_params()) +logging.getLogger().info(index_params) + + +class TestAddBase: + """ + ****************************************************************** + The following cases are used to test `add_vectors / index / search / delete` mixed function + ****************************************************************** + """ + + def test_add_vector_create_table(self, connect, table): + ''' + target: test add vector, then create table again + method: add vector and create table + expected: status not ok + ''' + vector = gen_single_vector(dim) + status, ids = connect.add_vectors(table, vector) + param = {'table_name': table, + 'dimension': dim, + 'index_file_size': index_file_size, + 'metric_type': MetricType.L2} + status = connect.create_table(param) + assert not status.OK() + + def test_add_vector_has_table(self, connect, table): + ''' + target: test add vector, then check table existence + method: add vector and call HasTable + expected: table exists, status ok + ''' + vector = gen_single_vector(dim) + status, ids = connect.add_vectors(table, vector) + assert assert_has_table(connect, table) + + @pytest.mark.timeout(ADD_TIMEOUT) + def test_delete_table_add_vector(self, connect, table): + ''' + target: test add vector after table deleted + method: delete table and add vector + expected: status not ok + ''' + status = connect.delete_table(table) + vector = gen_single_vector(dim) + status, ids = connect.add_vectors(table, vector) + assert not status.OK() + + @pytest.mark.timeout(ADD_TIMEOUT) + def test_delete_table_add_vector_another(self, connect, table): + ''' + target: test add vector to table_1 after table_2 deleted + method: delete table_2 and add vector to table_1 + expected: status ok + ''' + param = {'table_name': 'test_delete_table_add_vector_another', + 'dimension': dim, + 'index_file_size': index_file_size, + 'metric_type': MetricType.L2} + status = connect.create_table(param) + status = connect.delete_table(table) + vector = gen_single_vector(dim) + status, ids = connect.add_vectors(param['table_name'], vector) + connect.delete_table(param['table_name']) + assert status.OK() + + @pytest.mark.timeout(ADD_TIMEOUT) + def test_add_vector_delete_table(self, connect, table): + ''' + target: test delete table after add vector + method: add vector and delete table + expected: status ok + ''' + vector = gen_single_vector(dim) + status, ids = connect.add_vectors(table, vector) + status = connect.delete_table(table) + assert status.OK() + + @pytest.mark.timeout(ADD_TIMEOUT) + def test_add_vector_delete_another_table(self, connect, table): + ''' + target: test delete table_1 table after add vector to table_2 + method: add vector and delete table + expected: status ok + ''' + param = {'table_name': 'test_add_vector_delete_another_table', + 'dimension': dim, + 'index_file_size': index_file_size, + 'metric_type': MetricType.L2} + status = connect.create_table(param) + vector = gen_single_vector(dim) + status, ids = connect.add_vectors(table, vector) + status = connect.delete_table(param['table_name']) + assert status.OK() + + @pytest.mark.timeout(ADD_TIMEOUT) + def test_add_vector_sleep_delete_table(self, connect, table): + ''' + target: test delete table after add vector for a while + method: add vector, sleep, and delete table + expected: status ok + ''' + vector = gen_single_vector(dim) + status, ids = connect.add_vectors(table, vector) + time.sleep(1) + status = connect.delete_table(table) + assert status.OK() + + @pytest.mark.timeout(ADD_TIMEOUT) + def test_add_vector_sleep_delete_another_table(self, connect, table): + ''' + target: test delete table_1 table after add vector to table_2 for a while + method: add vector , sleep, and delete table + expected: status ok + ''' + param = {'table_name': 'test_add_vector_sleep_delete_another_table', + 'dimension': dim, + 'index_file_size': index_file_size, + 'metric_type': MetricType.L2} + status = connect.create_table(param) + vector = gen_single_vector(dim) + status, ids = connect.add_vectors(table, vector) + time.sleep(1) + status = connect.delete_table(param['table_name']) + assert status.OK() + + @pytest.mark.timeout(ADD_TIMEOUT) + def test_create_index_add_vector(self, connect, table): + ''' + target: test add vector after build index + method: build index and add vector + expected: status ok + ''' + status = connect.create_index(table, index_params) + vector = gen_single_vector(dim) + status, ids = connect.add_vectors(table, vector) + assert status.OK() + + @pytest.mark.timeout(ADD_TIMEOUT) + def test_create_index_add_vector_another(self, connect, table): + ''' + target: test add vector to table_2 after build index for table_1 + method: build index and add vector + expected: status ok + ''' + param = {'table_name': 'test_create_index_add_vector_another', + 'dimension': dim, + 'index_file_size': index_file_size, + 'metric_type': MetricType.L2} + status = connect.create_table(param) + status = connect.create_index(table, index_params) + vector = gen_single_vector(dim) + status, ids = connect.add_vectors(table, vector) + connect.delete_table(param['table_name']) + assert status.OK() + + @pytest.mark.timeout(ADD_TIMEOUT) + def test_add_vector_create_index(self, connect, table): + ''' + target: test build index add after vector + method: add vector and build index + expected: status ok + ''' + vector = gen_single_vector(dim) + status, ids = connect.add_vectors(table, vector) + status = connect.create_index(table, index_params) + assert status.OK() + + @pytest.mark.timeout(ADD_TIMEOUT) + def test_add_vector_create_index_another(self, connect, table): + ''' + target: test add vector to table_2 after build index for table_1 + method: build index and add vector + expected: status ok + ''' + param = {'table_name': 'test_add_vector_create_index_another', + 'dimension': dim, + 'index_file_size': index_file_size, + 'metric_type': MetricType.L2} + status = connect.create_table(param) + vector = gen_single_vector(dim) + status, ids = connect.add_vectors(table, vector) + status = connect.create_index(param['table_name'], index_params) + connect.delete_table(param['table_name']) + assert status.OK() + + @pytest.mark.timeout(ADD_TIMEOUT) + def test_add_vector_sleep_create_index(self, connect, table): + ''' + target: test build index add after vector for a while + method: add vector and build index + expected: status ok + ''' + vector = gen_single_vector(dim) + status, ids = connect.add_vectors(table, vector) + time.sleep(1) + status = connect.create_index(table, index_params) + assert status.OK() + + @pytest.mark.timeout(ADD_TIMEOUT) + def test_add_vector_sleep_create_index_another(self, connect, table): + ''' + target: test add vector to table_2 after build index for table_1 for a while + method: build index and add vector + expected: status ok + ''' + param = {'table_name': 'test_add_vector_sleep_create_index_another', + 'dimension': dim, + 'index_file_size': index_file_size, + 'metric_type': MetricType.L2} + status = connect.create_table(param) + vector = gen_single_vector(dim) + status, ids = connect.add_vectors(table, vector) + time.sleep(1) + status = connect.create_index(param['table_name'], index_params) + connect.delete_table(param['table_name']) + assert status.OK() + + @pytest.mark.timeout(ADD_TIMEOUT) + def test_search_vector_add_vector(self, connect, table): + ''' + target: test add vector after search table + method: search table and add vector + expected: status ok + ''' + vector = gen_single_vector(dim) + status, result = connect.search_vectors(table, 1, nprobe, vector) + status, ids = connect.add_vectors(table, vector) + assert status.OK() + + @pytest.mark.timeout(ADD_TIMEOUT) + def test_search_vector_add_vector_another(self, connect, table): + ''' + target: test add vector to table_1 after search table_2 + method: search table and add vector + expected: status ok + ''' + param = {'table_name': 'test_search_vector_add_vector_another', + 'dimension': dim, + 'index_file_size': index_file_size, + 'metric_type': MetricType.L2} + status = connect.create_table(param) + vector = gen_single_vector(dim) + status, result = connect.search_vectors(table, 1, nprobe, vector) + status, ids = connect.add_vectors(param['table_name'], vector) + connect.delete_table(param['table_name']) + assert status.OK() + + @pytest.mark.timeout(ADD_TIMEOUT) + def test_add_vector_search_vector(self, connect, table): + ''' + target: test search vector after add vector + method: add vector and search table + expected: status ok + ''' + vector = gen_single_vector(dim) + status, ids = connect.add_vectors(table, vector) + status, result = connect.search_vectors(table, 1, nprobe, vector) + assert status.OK() + + @pytest.mark.timeout(ADD_TIMEOUT) + def test_add_vector_search_vector_another(self, connect, table): + ''' + target: test add vector to table_1 after search table_2 + method: search table and add vector + expected: status ok + ''' + param = {'table_name': 'test_add_vector_search_vector_another', + 'dimension': dim, + 'index_file_size': index_file_size, + 'metric_type': MetricType.L2} + status = connect.create_table(param) + vector = gen_single_vector(dim) + status, ids = connect.add_vectors(table, vector) + status, result = connect.search_vectors(param['table_name'], 1, nprobe, vector) + connect.delete_table(param['table_name']) + assert status.OK() + + @pytest.mark.timeout(ADD_TIMEOUT) + def test_add_vector_sleep_search_vector(self, connect, table): + ''' + target: test search vector after add vector after a while + method: add vector, sleep, and search table + expected: status ok + ''' + vector = gen_single_vector(dim) + status, ids = connect.add_vectors(table, vector) + time.sleep(1) + status, result = connect.search_vectors(table, 1, nprobe, vector) + assert status.OK() + + @pytest.mark.timeout(ADD_TIMEOUT) + def test_add_vector_sleep_search_vector_another(self, connect, table): + ''' + target: test add vector to table_1 after search table_2 a while + method: search table , sleep, and add vector + expected: status ok + ''' + param = {'table_name': 'test_add_vector_sleep_search_vector_another', + 'dimension': dim, + 'index_file_size': index_file_size, + 'metric_type': MetricType.L2} + status = connect.create_table(param) + vector = gen_single_vector(dim) + status, ids = connect.add_vectors(table, vector) + time.sleep(1) + status, result = connect.search_vectors(param['table_name'], 1, nprobe, vector) + connect.delete_table(param['table_name']) + assert status.OK() + + """ + ****************************************************************** + The following cases are used to test `add_vectors` function + ****************************************************************** + """ + + @pytest.mark.timeout(ADD_TIMEOUT) + def test_add_vectors_ids(self, connect, table): + ''' + target: test add vectors in table, use customize ids + method: create table and add vectors in it, check the ids returned and the table length after vectors added + expected: the length of ids and the table row count + ''' + nq = 5; top_k = 1; nprobe = 1 + vectors = gen_vectors(nq, dim) + ids = [i for i in range(nq)] + status, ids = connect.add_vectors(table, vectors, ids) + time.sleep(2) + assert status.OK() + assert len(ids) == nq + # check search result + status, result = connect.search_vectors(table, top_k, nprobe, vectors) + logging.getLogger().info(result) + assert len(result) == nq + for i in range(nq): + assert result[i][0].id == i + + @pytest.mark.timeout(ADD_TIMEOUT) + def test_add_vectors_twice_ids_no_ids(self, connect, table): + ''' + target: check the result of add_vectors, with params ids and no ids + method: test add vectors twice, use customize ids first, and then use no ids + expected: status not OK + ''' + nq = 5; top_k = 1; nprobe = 1 + vectors = gen_vectors(nq, dim) + ids = [i for i in range(nq)] + status, ids = connect.add_vectors(table, vectors, ids) + assert status.OK() + status, ids = connect.add_vectors(table, vectors) + logging.getLogger().info(status) + logging.getLogger().info(ids) + assert not status.OK() + + @pytest.mark.timeout(ADD_TIMEOUT) + def test_add_vectors_twice_not_ids_ids(self, connect, table): + ''' + target: check the result of add_vectors, with params ids and no ids + method: test add vectors twice, use not ids first, and then use customize ids + expected: status not OK + ''' + nq = 5; top_k = 1; nprobe = 1 + vectors = gen_vectors(nq, dim) + ids = [i for i in range(nq)] + status, ids = connect.add_vectors(table, vectors) + assert status.OK() + status, ids = connect.add_vectors(table, vectors, ids) + logging.getLogger().info(status) + logging.getLogger().info(ids) + assert not status.OK() + + @pytest.mark.timeout(ADD_TIMEOUT) + def test_add_vectors_ids_length_not_match(self, connect, table): + ''' + target: test add vectors in table, use customize ids, len(ids) != len(vectors) + method: create table and add vectors in it + expected: raise an exception + ''' + nq = 5 + vectors = gen_vectors(nq, dim) + ids = [i for i in range(1, nq)] + with pytest.raises(Exception) as e: + status, ids = connect.add_vectors(table, vectors, ids) + + @pytest.fixture( + scope="function", + params=gen_invalid_vector_ids() + ) + def get_vector_id(self, request): + yield request.param + + def test_add_vectors_ids_invalid(self, connect, table, get_vector_id): + ''' + target: test add vectors in table, use customize ids, which are not int64 + method: create table and add vectors in it + expected: raise an exception + ''' + nq = 5 + vectors = gen_vectors(nq, dim) + vector_id = get_vector_id + ids = [vector_id for i in range(nq)] + with pytest.raises(Exception) as e: + status, ids = connect.add_vectors(table, vectors, ids) + + @pytest.mark.timeout(ADD_TIMEOUT) + def test_add_vectors(self, connect, table): + ''' + target: test add vectors in table created before + method: create table and add vectors in it, check the ids returned and the table length after vectors added + expected: the length of ids and the table row count + ''' + nq = 5 + vectors = gen_vectors(nq, dim) + status, ids = connect.add_vectors(table, vectors) + assert status.OK() + assert len(ids) == nq + + @pytest.mark.level(2) + def test_add_vectors_without_connect(self, dis_connect, table): + ''' + target: test add vectors without connection + method: create table and add vectors in it, check if added successfully + expected: raise exception + ''' + nq = 5 + vectors = gen_vectors(nq, dim) + with pytest.raises(Exception) as e: + status, ids = dis_connect.add_vectors(table, vectors) + + def test_add_table_not_existed(self, connect): + ''' + target: test add vectors in table, which not existed before + method: add vectors table not existed, check the status + expected: status not ok + ''' + nq = 5 + vector = gen_single_vector(dim) + status, ids = connect.add_vectors(gen_unique_str("not_exist_table"), vector) + assert not status.OK() + assert not ids + + def test_add_vector_dim_not_matched(self, connect, table): + ''' + target: test add vector, the vector dimension is not equal to the table dimension + method: the vector dimension is half of the table dimension, check the status + expected: status not ok + ''' + vector = gen_single_vector(int(dim)//2) + status, ids = connect.add_vectors(table, vector) + assert not status.OK() + + def test_add_vectors_dim_not_matched(self, connect, table): + ''' + target: test add vectors, the vector dimension is not equal to the table dimension + method: the vectors dimension is half of the table dimension, check the status + expected: status not ok + ''' + nq = 5 + vectors = gen_vectors(nq, int(dim)//2) + status, ids = connect.add_vectors(table, vectors) + assert not status.OK() + + def test_add_vector_query_after_sleep(self, connect, table): + ''' + target: test add vectors, and search it after sleep + method: set vector[0][1] as query vectors + expected: status ok and result length is 1 + ''' + nq = 5 + vectors = gen_vectors(nq, dim) + status, ids = connect.add_vectors(table, vectors) + time.sleep(3) + status, result = connect.search_vectors(table, 1, nprobe, [vectors[0]]) + assert status.OK() + assert len(result) == 1 + + # @pytest.mark.repeat(5) + @pytest.mark.timeout(ADD_TIMEOUT) + def _test_add_vector_multi_threading(self, connect, table): + ''' + target: test add vectors, with multi threading + method: 10 thread add vectors concurrently + expected: status ok and result length is equal to the length off added vectors + ''' + thread_num = 4 + loops = 100 + threads = [] + total_ids = [] + vector = gen_single_vector(dim) + def add(): + i = 0 + while i < loops: + status, ids = connect.add_vectors(table, vector) + total_ids.append(ids[0]) + i = i + 1 + for i in range(thread_num): + x = threading.Thread(target=add, args=()) + threads.append(x) + x.start() + time.sleep(0.2) + for th in threads: + th.join() + assert len(total_ids) == thread_num * loops + # make sure ids not the same + assert len(set(total_ids)) == thread_num * loops + + # TODO: enable + # @pytest.mark.repeat(5) + @pytest.mark.timeout(ADD_TIMEOUT) + def _test_add_vector_with_multiprocessing(self, args): + ''' + target: test add vectors, with multi processes + method: 10 processed add vectors concurrently + expected: status ok and result length is equal to the length off added vectors + ''' + table = gen_unique_str("test_add_vector_with_multiprocessing") + uri = "tcp://%s:%s" % (args["ip"], args["port"]) + param = {'table_name': table, + 'dimension': dim, + 'index_file_size': index_file_size} + # create table + milvus = Milvus() + milvus.connect(uri=uri) + milvus.create_table(param) + vector = gen_single_vector(dim) + + process_num = 4 + loop_num = 10 + processes = [] + # with dependent connection + def add(milvus): + i = 0 + while i < loop_num: + status, ids = milvus.add_vectors(table, vector) + i = i + 1 + for i in range(process_num): + milvus = Milvus() + milvus.connect(uri=uri) + p = Process(target=add, args=(milvus,)) + processes.append(p) + p.start() + time.sleep(0.2) + for p in processes: + p.join() + time.sleep(3) + status, count = milvus.get_table_row_count(table) + assert count == process_num * loop_num + + def test_add_vector_multi_tables(self, connect): + ''' + target: test add vectors is correct or not with multiple tables of L2 + method: create 50 tables and add vectors into them in turn + expected: status ok + ''' + nq = 100 + vectors = gen_vectors(nq, dim) + table_list = [] + for i in range(50): + table_name = gen_unique_str('test_add_vector_multi_tables') + table_list.append(table_name) + param = {'table_name': table_name, + 'dimension': dim, + 'index_file_size': index_file_size, + 'metric_type': MetricType.L2} + connect.create_table(param) + time.sleep(2) + for j in range(10): + for i in range(50): + status, ids = connect.add_vectors(table_name=table_list[i], records=vectors) + assert status.OK() + +class TestAddIP: + """ + ****************************************************************** + The following cases are used to test `add_vectors / index / search / delete` mixed function + ****************************************************************** + """ + + def test_add_vector_create_table(self, connect, ip_table): + ''' + target: test add vector, then create table again + method: add vector and create table + expected: status not ok + ''' + vector = gen_single_vector(dim) + status, ids = connect.add_vectors(ip_table, vector) + param = {'table_name': ip_table, + 'dimension': dim, + 'index_file_size': index_file_size, + 'metric_type': MetricType.L2} + status = connect.create_table(param) + assert not status.OK() + + def test_add_vector_has_table(self, connect, ip_table): + ''' + target: test add vector, then check table existence + method: add vector and call HasTable + expected: table exists, status ok + ''' + vector = gen_single_vector(dim) + status, ids = connect.add_vectors(ip_table, vector) + assert assert_has_table(connect, ip_table) + + @pytest.mark.timeout(ADD_TIMEOUT) + def test_delete_table_add_vector(self, connect, ip_table): + ''' + target: test add vector after table deleted + method: delete table and add vector + expected: status not ok + ''' + status = connect.delete_table(ip_table) + vector = gen_single_vector(dim) + status, ids = connect.add_vectors(ip_table, vector) + assert not status.OK() + + @pytest.mark.timeout(ADD_TIMEOUT) + def test_delete_table_add_vector_another(self, connect, ip_table): + ''' + target: test add vector to table_1 after table_2 deleted + method: delete table_2 and add vector to table_1 + expected: status ok + ''' + param = {'table_name': 'test_delete_table_add_vector_another', + 'dimension': dim, + 'index_file_size': index_file_size, + 'metric_type': MetricType.L2} + status = connect.create_table(param) + status = connect.delete_table(ip_table) + vector = gen_single_vector(dim) + status, ids = connect.add_vectors(param['table_name'], vector) + connect.delete_table(param['table_name']) + assert status.OK() + + @pytest.mark.timeout(ADD_TIMEOUT) + def test_add_vector_delete_table(self, connect, ip_table): + ''' + target: test delete table after add vector + method: add vector and delete table + expected: status ok + ''' + vector = gen_single_vector(dim) + status, ids = connect.add_vectors(ip_table, vector) + status = connect.delete_table(ip_table) + assert status.OK() + + @pytest.mark.timeout(ADD_TIMEOUT) + def test_add_vector_delete_another_table(self, connect, ip_table): + ''' + target: test delete table_1 table after add vector to table_2 + method: add vector and delete table + expected: status ok + ''' + param = {'table_name': 'test_add_vector_delete_another_table', + 'dimension': dim, + 'index_file_size': index_file_size, + 'metric_type': MetricType.L2} + status = connect.create_table(param) + vector = gen_single_vector(dim) + status, ids = connect.add_vectors(ip_table, vector) + status = connect.delete_table(param['table_name']) + assert status.OK() + + @pytest.mark.timeout(ADD_TIMEOUT) + def test_add_vector_sleep_delete_table(self, connect, ip_table): + ''' + target: test delete table after add vector for a while + method: add vector, sleep, and delete table + expected: status ok + ''' + vector = gen_single_vector(dim) + status, ids = connect.add_vectors(ip_table, vector) + time.sleep(1) + status = connect.delete_table(ip_table) + assert status.OK() + + @pytest.mark.timeout(ADD_TIMEOUT) + def test_add_vector_sleep_delete_another_table(self, connect, ip_table): + ''' + target: test delete table_1 table after add vector to table_2 for a while + method: add vector , sleep, and delete table + expected: status ok + ''' + param = {'table_name': 'test_add_vector_sleep_delete_another_table', + 'dimension': dim, + 'index_file_size': index_file_size, + 'metric_type': MetricType.L2} + status = connect.create_table(param) + vector = gen_single_vector(dim) + status, ids = connect.add_vectors(ip_table, vector) + time.sleep(1) + status = connect.delete_table(param['table_name']) + assert status.OK() + + @pytest.mark.timeout(ADD_TIMEOUT) + def test_create_index_add_vector(self, connect, ip_table): + ''' + target: test add vector after build index + method: build index and add vector + expected: status ok + ''' + status = connect.create_index(ip_table, index_params) + vector = gen_single_vector(dim) + status, ids = connect.add_vectors(ip_table, vector) + assert status.OK() + + @pytest.mark.timeout(ADD_TIMEOUT) + def test_create_index_add_vector_another(self, connect, ip_table): + ''' + target: test add vector to table_2 after build index for table_1 + method: build index and add vector + expected: status ok + ''' + param = {'table_name': 'test_create_index_add_vector_another', + 'dimension': dim, + 'index_file_size': index_file_size, + 'metric_type': MetricType.L2} + status = connect.create_table(param) + status = connect.create_index(ip_table, index_params) + vector = gen_single_vector(dim) + status, ids = connect.add_vectors(ip_table, vector) + connect.delete_table(param['table_name']) + assert status.OK() + + @pytest.mark.timeout(ADD_TIMEOUT) + def test_add_vector_create_index(self, connect, ip_table): + ''' + target: test build index add after vector + method: add vector and build index + expected: status ok + ''' + vector = gen_single_vector(dim) + status, ids = connect.add_vectors(ip_table, vector) + status = connect.create_index(ip_table, index_params) + assert status.OK() + + @pytest.mark.timeout(ADD_TIMEOUT) + def test_add_vector_create_index_another(self, connect, ip_table): + ''' + target: test add vector to table_2 after build index for table_1 + method: build index and add vector + expected: status ok + ''' + param = {'table_name': 'test_add_vector_create_index_another', + 'dimension': dim, + 'index_file_size': index_file_size, + 'metric_type': MetricType.L2} + status = connect.create_table(param) + vector = gen_single_vector(dim) + status, ids = connect.add_vectors(ip_table, vector) + status = connect.create_index(param['table_name'], index_params) + connect.delete_table(param['table_name']) + assert status.OK() + + @pytest.mark.timeout(ADD_TIMEOUT) + def test_add_vector_sleep_create_index(self, connect, ip_table): + ''' + target: test build index add after vector for a while + method: add vector and build index + expected: status ok + ''' + vector = gen_single_vector(dim) + status, ids = connect.add_vectors(ip_table, vector) + time.sleep(1) + status = connect.create_index(ip_table, index_params) + assert status.OK() + + @pytest.mark.timeout(ADD_TIMEOUT) + def test_add_vector_sleep_create_index_another(self, connect, ip_table): + ''' + target: test add vector to table_2 after build index for table_1 for a while + method: build index and add vector + expected: status ok + ''' + param = {'table_name': 'test_add_vector_sleep_create_index_another', + 'dimension': dim, + 'index_file_size': index_file_size, + 'metric_type': MetricType.L2} + status = connect.create_table(param) + vector = gen_single_vector(dim) + status, ids = connect.add_vectors(ip_table, vector) + time.sleep(1) + status = connect.create_index(param['table_name'], index_params) + connect.delete_table(param['table_name']) + assert status.OK() + + @pytest.mark.timeout(ADD_TIMEOUT) + def test_search_vector_add_vector(self, connect, ip_table): + ''' + target: test add vector after search table + method: search table and add vector + expected: status ok + ''' + vector = gen_single_vector(dim) + status, result = connect.search_vectors(ip_table, 1, nprobe, vector) + status, ids = connect.add_vectors(ip_table, vector) + assert status.OK() + + @pytest.mark.timeout(ADD_TIMEOUT) + def test_search_vector_add_vector_another(self, connect, ip_table): + ''' + target: test add vector to table_1 after search table_2 + method: search table and add vector + expected: status ok + ''' + param = {'table_name': 'test_search_vector_add_vector_another', + 'dimension': dim, + 'index_file_size': index_file_size, + 'metric_type': MetricType.L2} + status = connect.create_table(param) + vector = gen_single_vector(dim) + status, result = connect.search_vectors(ip_table, 1, nprobe, vector) + status, ids = connect.add_vectors(param['table_name'], vector) + connect.delete_table(param['table_name']) + assert status.OK() + + @pytest.mark.timeout(ADD_TIMEOUT) + def test_add_vector_search_vector(self, connect, ip_table): + ''' + target: test search vector after add vector + method: add vector and search table + expected: status ok + ''' + vector = gen_single_vector(dim) + status, ids = connect.add_vectors(ip_table, vector) + status, result = connect.search_vectors(ip_table, 1, nprobe, vector) + assert status.OK() + + @pytest.mark.timeout(ADD_TIMEOUT) + def test_add_vector_search_vector_another(self, connect, ip_table): + ''' + target: test add vector to table_1 after search table_2 + method: search table and add vector + expected: status ok + ''' + param = {'table_name': 'test_add_vector_search_vector_another', + 'dimension': dim, + 'index_file_size': index_file_size, + 'metric_type': MetricType.L2} + status = connect.create_table(param) + vector = gen_single_vector(dim) + status, ids = connect.add_vectors(ip_table, vector) + status, result = connect.search_vectors(param['table_name'], 1, nprobe, vector) + connect.delete_table(param['table_name']) + assert status.OK() + + @pytest.mark.timeout(ADD_TIMEOUT) + def test_add_vector_sleep_search_vector(self, connect, ip_table): + ''' + target: test search vector after add vector after a while + method: add vector, sleep, and search table + expected: status ok + ''' + vector = gen_single_vector(dim) + status, ids = connect.add_vectors(ip_table, vector) + time.sleep(1) + status, result = connect.search_vectors(ip_table, 1, nprobe, vector) + assert status.OK() + + @pytest.mark.timeout(ADD_TIMEOUT) + def test_add_vector_sleep_search_vector_another(self, connect, ip_table): + ''' + target: test add vector to table_1 after search table_2 a while + method: search table , sleep, and add vector + expected: status ok + ''' + param = {'table_name': 'test_add_vector_sleep_search_vector_another', + 'dimension': dim, + 'index_file_size': index_file_size, + 'metric_type': MetricType.L2} + status = connect.create_table(param) + vector = gen_single_vector(dim) + status, ids = connect.add_vectors(ip_table, vector) + time.sleep(1) + status, result = connect.search_vectors(param['table_name'], 1, nprobe, vector) + connect.delete_table(param['table_name']) + assert status.OK() + + """ + ****************************************************************** + The following cases are used to test `add_vectors` function + ****************************************************************** + """ + + @pytest.mark.timeout(ADD_TIMEOUT) + def test_add_vectors_ids(self, connect, ip_table): + ''' + target: test add vectors in table, use customize ids + method: create table and add vectors in it, check the ids returned and the table length after vectors added + expected: the length of ids and the table row count + ''' + nq = 5; top_k = 1; nprobe = 1 + vectors = gen_vectors(nq, dim) + ids = [i for i in range(nq)] + status, ids = connect.add_vectors(ip_table, vectors, ids) + time.sleep(2) + assert status.OK() + assert len(ids) == nq + # check search result + status, result = connect.search_vectors(ip_table, top_k, nprobe, vectors) + logging.getLogger().info(result) + assert len(result) == nq + for i in range(nq): + assert result[i][0].id == i + + @pytest.mark.timeout(ADD_TIMEOUT) + def test_add_vectors_twice_ids_no_ids(self, connect, ip_table): + ''' + target: check the result of add_vectors, with params ids and no ids + method: test add vectors twice, use customize ids first, and then use no ids + expected: status not OK + ''' + nq = 5; top_k = 1; nprobe = 1 + vectors = gen_vectors(nq, dim) + ids = [i for i in range(nq)] + status, ids = connect.add_vectors(ip_table, vectors, ids) + assert status.OK() + status, ids = connect.add_vectors(ip_table, vectors) + logging.getLogger().info(status) + logging.getLogger().info(ids) + assert not status.OK() + + @pytest.mark.timeout(ADD_TIMEOUT) + def test_add_vectors_twice_not_ids_ids(self, connect, ip_table): + ''' + target: check the result of add_vectors, with params ids and no ids + method: test add vectors twice, use not ids first, and then use customize ids + expected: status not OK + ''' + nq = 5; top_k = 1; nprobe = 1 + vectors = gen_vectors(nq, dim) + ids = [i for i in range(nq)] + status, ids = connect.add_vectors(ip_table, vectors) + assert status.OK() + status, ids = connect.add_vectors(ip_table, vectors, ids) + logging.getLogger().info(status) + logging.getLogger().info(ids) + assert not status.OK() + + @pytest.mark.timeout(ADD_TIMEOUT) + def test_add_vectors_ids_length_not_match(self, connect, ip_table): + ''' + target: test add vectors in table, use customize ids, len(ids) != len(vectors) + method: create table and add vectors in it + expected: raise an exception + ''' + nq = 5 + vectors = gen_vectors(nq, dim) + ids = [i for i in range(1, nq)] + with pytest.raises(Exception) as e: + status, ids = connect.add_vectors(ip_table, vectors, ids) + + @pytest.fixture( + scope="function", + params=gen_invalid_vector_ids() + ) + def get_vector_id(self, request): + yield request.param + + def test_add_vectors_ids_invalid(self, connect, ip_table, get_vector_id): + ''' + target: test add vectors in table, use customize ids, which are not int64 + method: create table and add vectors in it + expected: raise an exception + ''' + nq = 5 + vectors = gen_vectors(nq, dim) + vector_id = get_vector_id + ids = [vector_id for i in range(nq)] + with pytest.raises(Exception) as e: + status, ids = connect.add_vectors(ip_table, vectors, ids) + + @pytest.mark.timeout(ADD_TIMEOUT) + def test_add_vectors(self, connect, ip_table): + ''' + target: test add vectors in table created before + method: create table and add vectors in it, check the ids returned and the table length after vectors added + expected: the length of ids and the table row count + ''' + nq = 5 + vectors = gen_vectors(nq, dim) + status, ids = connect.add_vectors(ip_table, vectors) + assert status.OK() + assert len(ids) == nq + + @pytest.mark.level(2) + def test_add_vectors_without_connect(self, dis_connect, ip_table): + ''' + target: test add vectors without connection + method: create table and add vectors in it, check if added successfully + expected: raise exception + ''' + nq = 5 + vectors = gen_vectors(nq, dim) + with pytest.raises(Exception) as e: + status, ids = dis_connect.add_vectors(ip_table, vectors) + + def test_add_vector_dim_not_matched(self, connect, ip_table): + ''' + target: test add vector, the vector dimension is not equal to the table dimension + method: the vector dimension is half of the table dimension, check the status + expected: status not ok + ''' + vector = gen_single_vector(int(dim)//2) + status, ids = connect.add_vectors(ip_table, vector) + assert not status.OK() + + def test_add_vectors_dim_not_matched(self, connect, ip_table): + ''' + target: test add vectors, the vector dimension is not equal to the table dimension + method: the vectors dimension is half of the table dimension, check the status + expected: status not ok + ''' + nq = 5 + vectors = gen_vectors(nq, int(dim)//2) + status, ids = connect.add_vectors(ip_table, vectors) + assert not status.OK() + + def test_add_vector_query_after_sleep(self, connect, ip_table): + ''' + target: test add vectors, and search it after sleep + method: set vector[0][1] as query vectors + expected: status ok and result length is 1 + ''' + nq = 5 + vectors = gen_vectors(nq, dim) + status, ids = connect.add_vectors(ip_table, vectors) + time.sleep(3) + status, result = connect.search_vectors(ip_table, 1, nprobe, [vectors[0]]) + assert status.OK() + assert len(result) == 1 + + # @pytest.mark.repeat(5) + @pytest.mark.timeout(ADD_TIMEOUT) + def _test_add_vector_multi_threading(self, connect, ip_table): + ''' + target: test add vectors, with multi threading + method: 10 thread add vectors concurrently + expected: status ok and result length is equal to the length off added vectors + ''' + thread_num = 4 + loops = 100 + threads = [] + total_ids = [] + vector = gen_single_vector(dim) + def add(): + i = 0 + while i < loops: + status, ids = connect.add_vectors(ip_table, vector) + total_ids.append(ids[0]) + i = i + 1 + for i in range(thread_num): + x = threading.Thread(target=add, args=()) + threads.append(x) + x.start() + time.sleep(0.2) + for th in threads: + th.join() + assert len(total_ids) == thread_num * loops + # make sure ids not the same + assert len(set(total_ids)) == thread_num * loops + + # TODO: enable + # @pytest.mark.repeat(5) + @pytest.mark.timeout(ADD_TIMEOUT) + def _test_add_vector_with_multiprocessing(self, args): + ''' + target: test add vectors, with multi processes + method: 10 processed add vectors concurrently + expected: status ok and result length is equal to the length off added vectors + ''' + table = gen_unique_str("test_add_vector_with_multiprocessing") + uri = "tcp://%s:%s" % (args["ip"], args["port"]) + param = {'table_name': table, + 'dimension': dim, + 'index_file_size': index_file_size} + # create table + milvus = Milvus() + milvus.connect(uri=uri) + milvus.create_table(param) + vector = gen_single_vector(dim) + + process_num = 4 + loop_num = 10 + processes = [] + # with dependent connection + def add(milvus): + i = 0 + while i < loop_num: + status, ids = milvus.add_vectors(table, vector) + i = i + 1 + for i in range(process_num): + milvus = Milvus() + milvus.connect(uri=uri) + p = Process(target=add, args=(milvus,)) + processes.append(p) + p.start() + time.sleep(0.2) + for p in processes: + p.join() + time.sleep(3) + status, count = milvus.get_table_row_count(table) + assert count == process_num * loop_num + + def test_add_vector_multi_tables(self, connect): + ''' + target: test add vectors is correct or not with multiple tables of IP + method: create 50 tables and add vectors into them in turn + expected: status ok + ''' + nq = 100 + vectors = gen_vectors(nq, dim) + table_list = [] + for i in range(50): + table_name = gen_unique_str('test_add_vector_multi_tables') + table_list.append(table_name) + param = {'table_name': table_name, + 'dimension': dim, + 'index_file_size': index_file_size, + 'metric_type': MetricType.IP} + connect.create_table(param) + time.sleep(2) + for j in range(10): + for i in range(50): + status, ids = connect.add_vectors(table_name=table_list[i], records=vectors) + assert status.OK() + +class TestAddAdvance: + + @pytest.fixture( + scope="function", + params=[ + 1, + 10, + 100, + 1000, + pytest.param(5000 - 1, marks=pytest.mark.xfail), + pytest.param(5000, marks=pytest.mark.xfail), + pytest.param(5000 + 1, marks=pytest.mark.xfail), + ], + ) + def insert_count(self, request): + yield request.param + + def test_insert_much(self, connect, table, insert_count): + ''' + target: test add vectors with different length of vectors + method: set different vectors as add method params + expected: length of ids is equal to the length of vectors + ''' + nb = insert_count + insert_vec_list = gen_vectors(nb, dim) + status, ids = connect.add_vectors(table, insert_vec_list) + assert len(ids) == nb + assert status.OK() + + def test_insert_much_ip(self, connect, ip_table, insert_count): + ''' + target: test add vectors with different length of vectors + method: set different vectors as add method params + expected: length of ids is equal to the length of vectors + ''' + nb = insert_count + insert_vec_list = gen_vectors(nb, dim) + status, ids = connect.add_vectors(ip_table, insert_vec_list) + assert len(ids) == nb + assert status.OK() + +class TestAddTableNameInvalid(object): + """ + Test adding vectors with invalid table names + """ + @pytest.fixture( + scope="function", + params=gen_invalid_table_names() + ) + def get_table_name(self, request): + yield request.param + + @pytest.mark.level(2) + def test_add_vectors_with_invalid_tablename(self, connect, get_table_name): + table_name = get_table_name + vectors = gen_vectors(1, dim) + status, result = connect.add_vectors(table_name, vectors) + assert not status.OK() + + +class TestAddTableVectorsInvalid(object): + single_vector = gen_single_vector(dim) + vectors = gen_vectors(2, dim) + + """ + Test adding vectors with invalid vectors + """ + @pytest.fixture( + scope="function", + params=gen_invalid_vectors() + ) + def gen_vector(self, request): + yield request.param + + @pytest.mark.level(2) + def test_add_vector_with_invalid_vectors(self, connect, table, gen_vector): + tmp_single_vector = copy.deepcopy(self.single_vector) + tmp_single_vector[0][1] = gen_vector + with pytest.raises(Exception) as e: + status, result = connect.add_vectors(table, tmp_single_vector) + + @pytest.mark.level(1) + def test_add_vectors_with_invalid_vectors(self, connect, table, gen_vector): + tmp_vectors = copy.deepcopy(self.vectors) + tmp_vectors[1][1] = gen_vector + with pytest.raises(Exception) as e: + status, result = connect.add_vectors(table, tmp_vectors) \ No newline at end of file diff --git a/tests/milvus_python_test/test_connect.py b/tests/milvus_python_test/test_connect.py new file mode 100644 index 0000000000..5ec9539011 --- /dev/null +++ b/tests/milvus_python_test/test_connect.py @@ -0,0 +1,386 @@ +import pytest +from milvus import Milvus +import pdb +import threading +from multiprocessing import Process +from utils import * + +__version__ = '0.5.0' +CONNECT_TIMEOUT = 12 + + +class TestConnect: + + def local_ip(self, args): + ''' + check if ip is localhost or not + ''' + if not args["ip"] or args["ip"] == 'localhost' or args["ip"] == "127.0.0.1": + return True + else: + return False + + def test_disconnect(self, connect): + ''' + target: test disconnect + method: disconnect a connected client + expected: connect failed after disconnected + ''' + res = connect.disconnect() + assert res.OK() + with pytest.raises(Exception) as e: + res = connect.server_version() + + def test_disconnect_repeatedly(self, connect, args): + ''' + target: test disconnect repeatedly + method: disconnect a connected client, disconnect again + expected: raise an error after disconnected + ''' + if not connect.connected(): + milvus = Milvus() + uri_value = "tcp://%s:%s" % (args["ip"], args["port"]) + milvus.connect(uri=uri_value) + res = milvus.disconnect() + with pytest.raises(Exception) as e: + res = milvus.disconnect() + else: + res = connect.disconnect() + with pytest.raises(Exception) as e: + res = connect.disconnect() + + def test_connect_correct_ip_port(self, args): + ''' + target: test connect with corrent ip and port value + method: set correct ip and port + expected: connected is True + ''' + milvus = Milvus() + milvus.connect(host=args["ip"], port=args["port"]) + assert milvus.connected() + + def test_connect_connected(self, args): + ''' + target: test connect and disconnect with corrent ip and port value, assert connected value + method: set correct ip and port + expected: connected is False + ''' + milvus = Milvus() + milvus.connect(host=args["ip"], port=args["port"]) + milvus.disconnect() + assert not milvus.connected() + + # TODO: Currently we test with remote IP, localhost testing need to add + def _test_connect_ip_localhost(self, args): + ''' + target: test connect with ip value: localhost + method: set host localhost + expected: connected is True + ''' + milvus = Milvus() + milvus.connect(host='localhost', port=args["port"]) + assert milvus.connected() + + @pytest.mark.timeout(CONNECT_TIMEOUT) + def test_connect_wrong_ip_null(self, args): + ''' + target: test connect with wrong ip value + method: set host null + expected: not use default ip, connected is False + ''' + milvus = Milvus() + ip = "" + with pytest.raises(Exception) as e: + milvus.connect(host=ip, port=args["port"], timeout=1) + assert not milvus.connected() + + def test_connect_uri(self, args): + ''' + target: test connect with correct uri + method: uri format and value are both correct + expected: connected is True + ''' + milvus = Milvus() + uri_value = "tcp://%s:%s" % (args["ip"], args["port"]) + milvus.connect(uri=uri_value) + assert milvus.connected() + + def test_connect_uri_null(self, args): + ''' + target: test connect with null uri + method: uri set null + expected: connected is True + ''' + milvus = Milvus() + uri_value = "" + + if self.local_ip(args): + milvus.connect(uri=uri_value, timeout=1) + assert milvus.connected() + else: + with pytest.raises(Exception) as e: + milvus.connect(uri=uri_value, timeout=1) + assert not milvus.connected() + + @pytest.mark.level(2) + @pytest.mark.timeout(CONNECT_TIMEOUT) + def test_connect_wrong_uri_wrong_port_null(self, args): + ''' + target: test uri connect with port value wouldn't connected + method: set uri port null + expected: connected is True + ''' + milvus = Milvus() + uri_value = "tcp://%s:" % args["ip"] + with pytest.raises(Exception) as e: + milvus.connect(uri=uri_value, timeout=1) + + @pytest.mark.level(2) + @pytest.mark.timeout(CONNECT_TIMEOUT) + def test_connect_wrong_uri_wrong_ip_null(self, args): + ''' + target: test uri connect with ip value wouldn't connected + method: set uri ip null + expected: connected is True + ''' + milvus = Milvus() + uri_value = "tcp://:%s" % args["port"] + + with pytest.raises(Exception) as e: + milvus.connect(uri=uri_value, timeout=1) + assert not milvus.connected() + + # TODO: enable + def _test_connect_with_multiprocess(self, args): + ''' + target: test uri connect with multiprocess + method: set correct uri, test with multiprocessing connecting + expected: all connection is connected + ''' + uri_value = "tcp://%s:%s" % (args["ip"], args["port"]) + process_num = 4 + processes = [] + + def connect(milvus): + milvus.connect(uri=uri_value) + with pytest.raises(Exception) as e: + milvus.connect(uri=uri_value) + assert milvus.connected() + + for i in range(process_num): + milvus = Milvus() + p = Process(target=connect, args=(milvus, )) + processes.append(p) + p.start() + for p in processes: + p.join() + + def test_connect_repeatedly(self, args): + ''' + target: test connect repeatedly + method: connect again + expected: status.code is 0, and status.message shows have connected already + ''' + milvus = Milvus() + uri_value = "tcp://%s:%s" % (args["ip"], args["port"]) + milvus.connect(uri=uri_value) + + milvus.connect(uri=uri_value) + assert milvus.connected() + + def test_connect_disconnect_repeatedly_once(self, args): + ''' + target: test connect and disconnect repeatedly + method: disconnect, and then connect, assert connect status + expected: status.code is 0 + ''' + milvus = Milvus() + uri_value = "tcp://%s:%s" % (args["ip"], args["port"]) + milvus.connect(uri=uri_value) + + milvus.disconnect() + milvus.connect(uri=uri_value) + assert milvus.connected() + + def test_connect_disconnect_repeatedly_times(self, args): + ''' + target: test connect and disconnect for 10 times repeatedly + method: disconnect, and then connect, assert connect status + expected: status.code is 0 + ''' + times = 10 + milvus = Milvus() + uri_value = "tcp://%s:%s" % (args["ip"], args["port"]) + milvus.connect(uri=uri_value) + for i in range(times): + milvus.disconnect() + milvus.connect(uri=uri_value) + assert milvus.connected() + + # TODO: enable + def _test_connect_disconnect_with_multiprocess(self, args): + ''' + target: test uri connect and disconnect repeatly with multiprocess + method: set correct uri, test with multiprocessing connecting and disconnecting + expected: all connection is connected after 10 times operation + ''' + uri_value = "tcp://%s:%s" % (args["ip"], args["port"]) + process_num = 4 + processes = [] + + def connect(milvus): + milvus.connect(uri=uri_value) + milvus.disconnect() + milvus.connect(uri=uri_value) + assert milvus.connected() + + for i in range(process_num): + milvus = Milvus() + p = Process(target=connect, args=(milvus, )) + processes.append(p) + p.start() + for p in processes: + p.join() + + def test_connect_param_priority_no_port(self, args): + ''' + target: both host_ip_port / uri are both given, if port is null, use the uri params + method: port set "", check if wrong uri connection is ok + expected: connect raise an exception and connected is false + ''' + milvus = Milvus() + uri_value = "tcp://%s:19540" % args["ip"] + milvus.connect(host=args["ip"], port="", uri=uri_value) + assert milvus.connected() + + def test_connect_param_priority_uri(self, args): + ''' + target: both host_ip_port / uri are both given, if host is null, use the uri params + method: host set "", check if correct uri connection is ok + expected: connected is False + ''' + milvus = Milvus() + uri_value = "tcp://%s:%s" % (args["ip"], args["port"]) + with pytest.raises(Exception) as e: + milvus.connect(host="", port=args["port"], uri=uri_value, timeout=1) + assert not milvus.connected() + + def test_connect_param_priority_both_hostip_uri(self, args): + ''' + target: both host_ip_port / uri are both given, and not null, use the uri params + method: check if wrong uri connection is ok + expected: connect raise an exception and connected is false + ''' + milvus = Milvus() + uri_value = "tcp://%s:%s" % (args["ip"], args["port"]) + with pytest.raises(Exception) as e: + milvus.connect(host=args["ip"], port=19540, uri=uri_value, timeout=1) + assert not milvus.connected() + + def _test_add_vector_and_disconnect_concurrently(self): + ''' + Target: test disconnect in the middle of add vectors + Method: + a. use coroutine or multi-processing, to simulate network crashing + b. data_set not too large incase disconnection happens when data is underd-preparing + c. data_set not too small incase disconnection happens when data has already been transferred + d. make sure disconnection happens when data is in-transport + Expected: Failure, get_table_row_count == 0 + + ''' + pass + + def _test_search_vector_and_disconnect_concurrently(self): + ''' + Target: Test disconnect in the middle of search vectors(with large nq and topk)multiple times, and search/add vectors still work + Method: + a. coroutine or multi-processing, to simulate network crashing + b. connect, search and disconnect, repeating many times + c. connect and search, add vectors + Expected: Successfully searched back, successfully added + + ''' + pass + + def _test_thread_safe_with_one_connection_shared_in_multi_threads(self): + ''' + Target: test 1 connection thread safe + Method: 1 connection shared in multi-threads, all adding vectors, or other things + Expected: Functional as one thread + + ''' + pass + + +class TestConnectIPInvalid(object): + """ + Test connect server with invalid ip + """ + @pytest.fixture( + scope="function", + params=gen_invalid_ips() + ) + def get_invalid_ip(self, request): + yield request.param + + @pytest.mark.level(2) + @pytest.mark.timeout(CONNECT_TIMEOUT) + def test_connect_with_invalid_ip(self, args, get_invalid_ip): + milvus = Milvus() + ip = get_invalid_ip + with pytest.raises(Exception) as e: + milvus.connect(host=ip, port=args["port"], timeout=1) + assert not milvus.connected() + + +class TestConnectPortInvalid(object): + """ + Test connect server with invalid ip + """ + + @pytest.fixture( + scope="function", + params=gen_invalid_ports() + ) + def get_invalid_port(self, request): + yield request.param + + @pytest.mark.level(2) + @pytest.mark.timeout(CONNECT_TIMEOUT) + def test_connect_with_invalid_port(self, args, get_invalid_port): + ''' + target: test ip:port connect with invalid port value + method: set port in gen_invalid_ports + expected: connected is False + ''' + milvus = Milvus() + port = get_invalid_port + with pytest.raises(Exception) as e: + milvus.connect(host=args["ip"], port=port, timeout=1) + assert not milvus.connected() + + +class TestConnectURIInvalid(object): + """ + Test connect server with invalid uri + """ + @pytest.fixture( + scope="function", + params=gen_invalid_uris() + ) + def get_invalid_uri(self, request): + yield request.param + + @pytest.mark.level(2) + @pytest.mark.timeout(CONNECT_TIMEOUT) + def test_connect_with_invalid_uri(self, get_invalid_uri): + ''' + target: test uri connect with invalid uri value + method: set port in gen_invalid_uris + expected: connected is False + ''' + milvus = Milvus() + uri_value = get_invalid_uri + with pytest.raises(Exception) as e: + milvus.connect(uri=uri_value, timeout=1) + assert not milvus.connected() diff --git a/tests/milvus_python_test/test_delete_vectors.py b/tests/milvus_python_test/test_delete_vectors.py new file mode 100644 index 0000000000..57ab76770f --- /dev/null +++ b/tests/milvus_python_test/test_delete_vectors.py @@ -0,0 +1,425 @@ +# import time +# import random +# import pdb +# import logging +# import threading +# from builtins import Exception +# from multiprocessing import Pool, Process +# import pytest + +# from milvus import Milvus, IndexType +# from utils import * + + +# dim = 128 +# index_file_size = 10 +# table_id = "test_delete" +# DELETE_TIMEOUT = 60 +# vectors = gen_vectors(100, dim) + +# class TestDeleteVectorsBase: +# """ +# generate invalid query range params +# """ +# @pytest.fixture( +# scope="function", +# params=[ +# (get_current_day(), get_current_day()), +# (get_last_day(1), get_last_day(1)), +# (get_next_day(1), get_next_day(1)) +# ] +# ) +# def get_invalid_range(self, request): +# yield request.param + +# @pytest.mark.timeout(DELETE_TIMEOUT) +# def test_delete_vectors_invalid_range(self, connect, table, get_invalid_range): +# ''' +# target: test delete vectors, no index created +# method: call `delete_vectors_by_range`, with invalid date params +# expected: return code 0 +# ''' +# start_date = get_invalid_range[0] +# end_date = get_invalid_range[1] +# status, ids = connect.add_vectors(table, vectors) +# status = connect.delete_vectors_by_range(table, start_date, end_date) +# assert not status.OK() + +# """ +# generate valid query range params, no search result +# """ +# @pytest.fixture( +# scope="function", +# params=[ +# (get_last_day(2), get_last_day(1)), +# (get_last_day(2), get_current_day()), +# (get_next_day(1), get_next_day(2)) +# ] +# ) +# def get_valid_range_no_result(self, request): +# yield request.param + +# @pytest.mark.timeout(DELETE_TIMEOUT) +# def test_delete_vectors_valid_range_no_result(self, connect, table, get_valid_range_no_result): +# ''' +# target: test delete vectors, no index created +# method: call `delete_vectors_by_range`, with valid date params +# expected: return code 0 +# ''' +# start_date = get_valid_range_no_result[0] +# end_date = get_valid_range_no_result[1] +# status, ids = connect.add_vectors(table, vectors) +# time.sleep(2) +# status = connect.delete_vectors_by_range(table, start_date, end_date) +# assert status.OK() +# status, result = connect.get_table_row_count(table) +# assert result == 100 + +# """ +# generate valid query range params, no search result +# """ +# @pytest.fixture( +# scope="function", +# params=[ +# (get_last_day(2), get_next_day(2)), +# (get_current_day(), get_next_day(2)), +# ] +# ) +# def get_valid_range(self, request): +# yield request.param + +# @pytest.mark.timeout(DELETE_TIMEOUT) +# def test_delete_vectors_valid_range(self, connect, table, get_valid_range): +# ''' +# target: test delete vectors, no index created +# method: call `delete_vectors_by_range`, with valid date params +# expected: return code 0 +# ''' +# start_date = get_valid_range[0] +# end_date = get_valid_range[1] +# status, ids = connect.add_vectors(table, vectors) +# time.sleep(2) +# status = connect.delete_vectors_by_range(table, start_date, end_date) +# assert status.OK() +# status, result = connect.get_table_row_count(table) +# assert result == 0 + +# @pytest.fixture( +# scope="function", +# params=gen_index_params() +# ) +# def get_index_params(self, request, args): +# if "internal" not in args: +# if request.param["index_type"] == IndexType.IVF_SQ8H: +# pytest.skip("sq8h not support in open source") +# return request.param + +# @pytest.mark.timeout(DELETE_TIMEOUT) +# def test_delete_vectors_valid_range_index_created(self, connect, table, get_index_params): +# ''' +# target: test delete vectors, no index created +# method: call `delete_vectors_by_range`, with valid date params +# expected: return code 0 +# ''' +# start_date = get_current_day() +# end_date = get_next_day(2) +# index_params = get_index_params +# logging.getLogger().info(index_params) +# status, ids = connect.add_vectors(table, vectors) +# status = connect.create_index(table, index_params) +# logging.getLogger().info(status) +# logging.getLogger().info("Start delete vectors by range: %s:%s" % (start_date, end_date)) +# status = connect.delete_vectors_by_range(table, start_date, end_date) +# assert status.OK() +# status, result = connect.get_table_row_count(table) +# assert result == 0 + +# @pytest.mark.timeout(DELETE_TIMEOUT) +# def test_delete_vectors_no_data(self, connect, table): +# ''' +# target: test delete vectors, no index created +# method: call `delete_vectors_by_range`, with valid date params, and no data in db +# expected: return code 0 +# ''' +# start_date = get_current_day() +# end_date = get_next_day(2) +# # status, ids = connect.add_vectors(table, vectors) +# status = connect.delete_vectors_by_range(table, start_date, end_date) +# assert status.OK() + +# @pytest.mark.timeout(DELETE_TIMEOUT) +# def test_delete_vectors_table_not_existed(self, connect): +# ''' +# target: test delete vectors, table not existed in db +# method: call `delete_vectors_by_range`, with table not existed +# expected: return code not 0 +# ''' +# start_date = get_current_day() +# end_date = get_next_day(2) +# table_name = gen_unique_str("not_existed_table") +# status = connect.delete_vectors_by_range(table_name, start_date, end_date) +# assert not status.OK() + +# @pytest.mark.timeout(DELETE_TIMEOUT) +# def test_delete_vectors_table_None(self, connect, table): +# ''' +# target: test delete vectors, table set Nope +# method: call `delete_vectors_by_range`, with table value is None +# expected: return code not 0 +# ''' +# start_date = get_current_day() +# end_date = get_next_day(2) +# table_name = None +# with pytest.raises(Exception) as e: +# status = connect.delete_vectors_by_range(table_name, start_date, end_date) + +# @pytest.mark.timeout(DELETE_TIMEOUT) +# def test_delete_vectors_valid_range_multi_tables(self, connect, get_valid_range): +# ''' +# target: test delete vectors is correct or not with multiple tables of L2 +# method: create 50 tables and add vectors into them , then delete vectors +# in valid range +# expected: return code 0 +# ''' +# nq = 100 +# vectors = gen_vectors(nq, dim) +# table_list = [] +# for i in range(50): +# table_name = gen_unique_str('test_delete_vectors_valid_range_multi_tables') +# table_list.append(table_name) +# param = {'table_name': table_name, +# 'dimension': dim, +# 'index_file_size': index_file_size, +# 'metric_type': MetricType.L2} +# connect.create_table(param) +# status, ids = connect.add_vectors(table_name=table_name, records=vectors) +# time.sleep(2) +# start_date = get_valid_range[0] +# end_date = get_valid_range[1] +# for i in range(50): +# status = connect.delete_vectors_by_range(table_list[i], start_date, end_date) +# assert status.OK() +# status, result = connect.get_table_row_count(table_list[i]) +# assert result == 0 + + +# class TestDeleteVectorsIP: +# """ +# generate invalid query range params +# """ +# @pytest.fixture( +# scope="function", +# params=[ +# (get_current_day(), get_current_day()), +# (get_last_day(1), get_last_day(1)), +# (get_next_day(1), get_next_day(1)) +# ] +# ) +# def get_invalid_range(self, request): +# yield request.param + +# @pytest.mark.timeout(DELETE_TIMEOUT) +# def test_delete_vectors_invalid_range(self, connect, ip_table, get_invalid_range): +# ''' +# target: test delete vectors, no index created +# method: call `delete_vectors_by_range`, with invalid date params +# expected: return code 0 +# ''' +# start_date = get_invalid_range[0] +# end_date = get_invalid_range[1] +# status, ids = connect.add_vectors(ip_table, vectors) +# status = connect.delete_vectors_by_range(ip_table, start_date, end_date) +# assert not status.OK() + +# """ +# generate valid query range params, no search result +# """ +# @pytest.fixture( +# scope="function", +# params=[ +# (get_last_day(2), get_last_day(1)), +# (get_last_day(2), get_current_day()), +# (get_next_day(1), get_next_day(2)) +# ] +# ) +# def get_valid_range_no_result(self, request): +# yield request.param + +# @pytest.mark.timeout(DELETE_TIMEOUT) +# def test_delete_vectors_valid_range_no_result(self, connect, ip_table, get_valid_range_no_result): +# ''' +# target: test delete vectors, no index created +# method: call `delete_vectors_by_range`, with valid date params +# expected: return code 0 +# ''' +# start_date = get_valid_range_no_result[0] +# end_date = get_valid_range_no_result[1] +# status, ids = connect.add_vectors(ip_table, vectors) +# time.sleep(2) +# status = connect.delete_vectors_by_range(ip_table, start_date, end_date) +# assert status.OK() +# status, result = connect.get_table_row_count(ip_table) +# assert result == 100 + +# """ +# generate valid query range params, no search result +# """ +# @pytest.fixture( +# scope="function", +# params=[ +# (get_last_day(2), get_next_day(2)), +# (get_current_day(), get_next_day(2)), +# ] +# ) +# def get_valid_range(self, request): +# yield request.param + +# @pytest.mark.timeout(DELETE_TIMEOUT) +# def test_delete_vectors_valid_range(self, connect, ip_table, get_valid_range): +# ''' +# target: test delete vectors, no index created +# method: call `delete_vectors_by_range`, with valid date params +# expected: return code 0 +# ''' +# start_date = get_valid_range[0] +# end_date = get_valid_range[1] +# status, ids = connect.add_vectors(ip_table, vectors) +# time.sleep(2) +# status = connect.delete_vectors_by_range(ip_table, start_date, end_date) +# assert status.OK() +# status, result = connect.get_table_row_count(ip_table) +# assert result == 0 + +# @pytest.fixture( +# scope="function", +# params=gen_index_params() +# ) +# def get_index_params(self, request, args): +# if "internal" not in args: +# if request.param["index_type"] == IndexType.IVF_SQ8H: +# pytest.skip("sq8h not support in open source") +# return request.param + +# @pytest.mark.timeout(DELETE_TIMEOUT) +# def test_delete_vectors_valid_range_index_created(self, connect, ip_table, get_index_params): +# ''' +# target: test delete vectors, no index created +# method: call `delete_vectors_by_range`, with valid date params +# expected: return code 0 +# ''' +# start_date = get_current_day() +# end_date = get_next_day(2) +# index_params = get_index_params +# logging.getLogger().info(index_params) +# status, ids = connect.add_vectors(ip_table, vectors) +# status = connect.create_index(ip_table, index_params) +# logging.getLogger().info(status) +# logging.getLogger().info("Start delete vectors by range: %s:%s" % (start_date, end_date)) +# status = connect.delete_vectors_by_range(ip_table, start_date, end_date) +# assert status.OK() +# status, result = connect.get_table_row_count(ip_table) +# assert result == 0 + +# @pytest.mark.timeout(DELETE_TIMEOUT) +# def test_delete_vectors_no_data(self, connect, ip_table): +# ''' +# target: test delete vectors, no index created +# method: call `delete_vectors_by_range`, with valid date params, and no data in db +# expected: return code 0 +# ''' +# start_date = get_current_day() +# end_date = get_next_day(2) +# # status, ids = connect.add_vectors(table, vectors) +# status = connect.delete_vectors_by_range(ip_table, start_date, end_date) +# assert status.OK() + +# @pytest.mark.timeout(DELETE_TIMEOUT) +# def test_delete_vectors_table_None(self, connect, ip_table): +# ''' +# target: test delete vectors, table set Nope +# method: call `delete_vectors_by_range`, with table value is None +# expected: return code not 0 +# ''' +# start_date = get_current_day() +# end_date = get_next_day(2) +# table_name = None +# with pytest.raises(Exception) as e: +# status = connect.delete_vectors_by_range(table_name, start_date, end_date) + +# @pytest.mark.timeout(DELETE_TIMEOUT) +# def test_delete_vectors_valid_range_multi_tables(self, connect, get_valid_range): +# ''' +# target: test delete vectors is correct or not with multiple tables of IP +# method: create 50 tables and add vectors into them , then delete vectors +# in valid range +# expected: return code 0 +# ''' +# nq = 100 +# vectors = gen_vectors(nq, dim) +# table_list = [] +# for i in range(50): +# table_name = gen_unique_str('test_delete_vectors_valid_range_multi_tables') +# table_list.append(table_name) +# param = {'table_name': table_name, +# 'dimension': dim, +# 'index_file_size': index_file_size, +# 'metric_type': MetricType.IP} +# connect.create_table(param) +# status, ids = connect.add_vectors(table_name=table_name, records=vectors) +# time.sleep(2) +# start_date = get_valid_range[0] +# end_date = get_valid_range[1] +# for i in range(50): +# status = connect.delete_vectors_by_range(table_list[i], start_date, end_date) +# assert status.OK() +# status, result = connect.get_table_row_count(table_list[i]) +# assert result == 0 + +# class TestDeleteVectorsParamsInvalid: + +# """ +# Test search table with invalid table names +# """ +# @pytest.fixture( +# scope="function", +# params=gen_invalid_table_names() +# ) +# def get_table_name(self, request): +# yield request.param + +# @pytest.mark.level(2) +# def test_delete_vectors_table_invalid_name(self, connect, get_table_name): +# ''' +# ''' +# start_date = get_current_day() +# end_date = get_next_day(2) +# table_name = get_table_name +# logging.getLogger().info(table_name) +# top_k = 1 +# nprobe = 1 +# status = connect.delete_vectors_by_range(table_name, start_date, end_date) +# assert not status.OK() + +# """ +# Test search table with invalid query ranges +# """ +# @pytest.fixture( +# scope="function", +# params=gen_invalid_query_ranges() +# ) +# def get_query_ranges(self, request): +# yield request.param + +# @pytest.mark.timeout(DELETE_TIMEOUT) +# def test_delete_vectors_range_invalid(self, connect, table, get_query_ranges): +# ''' +# target: test search fuction, with the wrong query_range +# method: search with query_range +# expected: raise an error, and the connection is normal +# ''' +# start_date = get_query_ranges[0][0] +# end_date = get_query_ranges[0][1] +# status, ids = connect.add_vectors(table, vectors) +# logging.getLogger().info(get_query_ranges) +# with pytest.raises(Exception) as e: +# status = connect.delete_vectors_by_range(table, start_date, end_date) \ No newline at end of file diff --git a/tests/milvus_python_test/test_index.py b/tests/milvus_python_test/test_index.py new file mode 100644 index 0000000000..435a547855 --- /dev/null +++ b/tests/milvus_python_test/test_index.py @@ -0,0 +1,972 @@ +""" + For testing index operations, including `create_index`, `describe_index` and `drop_index` interfaces +""" +import logging +import pytest +import time +import pdb +import threading +from multiprocessing import Pool, Process +import numpy +from milvus import Milvus, IndexType, MetricType +from utils import * + +nb = 100000 +dim = 128 +index_file_size = 10 +vectors = gen_vectors(nb, dim) +vectors /= numpy.linalg.norm(vectors) +vectors = vectors.tolist() +BUILD_TIMEOUT = 60 +nprobe = 1 + + +class TestIndexBase: + @pytest.fixture( + scope="function", + params=gen_index_params() + ) + def get_index_params(self, request, args): + if "internal" not in args: + if request.param["index_type"] == IndexType.IVF_SQ8H: + pytest.skip("sq8h not support in open source") + return request.param + + @pytest.fixture( + scope="function", + params=gen_simple_index_params() + ) + def get_simple_index_params(self, request): + yield request.param + + """ + ****************************************************************** + The following cases are used to test `create_index` function + ****************************************************************** + """ + + @pytest.mark.timeout(BUILD_TIMEOUT) + def test_create_index(self, connect, table, get_index_params): + ''' + target: test create index interface + method: create table and add vectors in it, create index + expected: return code equals to 0, and search success + ''' + index_params = get_index_params + logging.getLogger().info(index_params) + status, ids = connect.add_vectors(table, vectors) + status = connect.create_index(table, index_params) + assert status.OK() + + @pytest.mark.level(2) + def test_create_index_without_connect(self, dis_connect, table): + ''' + target: test create index without connection + method: create table and add vectors in it, check if added successfully + expected: raise exception + ''' + with pytest.raises(Exception) as e: + status = dis_connect.create_index(table, random.choice(gen_index_params())) + + @pytest.mark.timeout(BUILD_TIMEOUT) + def test_create_index_search_with_query_vectors(self, connect, table, get_index_params): + ''' + target: test create index interface, search with more query vectors + method: create table and add vectors in it, create index + expected: return code equals to 0, and search success + ''' + index_params = get_index_params + logging.getLogger().info(index_params) + status, ids = connect.add_vectors(table, vectors) + status = connect.create_index(table, index_params) + logging.getLogger().info(connect.describe_index(table)) + query_vecs = [vectors[0], vectors[1], vectors[2]] + top_k = 5 + status, result = connect.search_vectors(table, top_k, nprobe, query_vecs) + assert status.OK() + assert len(result) == len(query_vecs) + logging.getLogger().info(result) + + # TODO: enable + @pytest.mark.timeout(BUILD_TIMEOUT) + @pytest.mark.level(2) + def _test_create_index_multiprocessing(self, connect, table, args): + ''' + target: test create index interface with multiprocess + method: create table and add vectors in it, create index + expected: return code equals to 0, and search success + ''' + status, ids = connect.add_vectors(table, vectors) + + def build(connect): + status = connect.create_index(table) + assert status.OK() + + process_num = 8 + processes = [] + uri = "tcp://%s:%s" % (args["ip"], args["port"]) + + for i in range(process_num): + m = Milvus() + m.connect(uri=uri) + p = Process(target=build, args=(m,)) + processes.append(p) + p.start() + time.sleep(0.2) + for p in processes: + p.join() + + query_vec = [vectors[0]] + top_k = 1 + status, result = connect.search_vectors(table, top_k, nprobe, query_vec) + assert len(result) == 1 + assert len(result[0]) == top_k + assert result[0][0].distance == 0.0 + + # TODO: enable + @pytest.mark.timeout(BUILD_TIMEOUT) + def _test_create_index_multiprocessing_multitable(self, connect, args): + ''' + target: test create index interface with multiprocess + method: create table and add vectors in it, create index + expected: return code equals to 0, and search success + ''' + process_num = 8 + loop_num = 8 + processes = [] + + table = [] + j = 0 + while j < (process_num*loop_num): + table_name = gen_unique_str("test_create_index_multiprocessing") + table.append(table_name) + param = {'table_name': table_name, + 'dimension': dim, + 'index_type': IndexType.FLAT, + 'store_raw_vector': False} + connect.create_table(param) + j = j + 1 + + def create_index(): + i = 0 + while i < loop_num: + # assert connect.has_table(table[ids*process_num+i]) + status, ids = connect.add_vectors(table[ids*process_num+i], vectors) + + status = connect.create_index(table[ids*process_num+i]) + assert status.OK() + query_vec = [vectors[0]] + top_k = 1 + status, result = connect.search_vectors(table[ids*process_num+i], top_k, nprobe, query_vec) + assert len(result) == 1 + assert len(result[0]) == top_k + assert result[0][0].distance == 0.0 + i = i + 1 + + uri = "tcp://%s:%s" % (args["ip"], args["port"]) + + for i in range(process_num): + m = Milvus() + m.connect(uri=uri) + ids = i + p = Process(target=create_index, args=(m,ids)) + processes.append(p) + p.start() + time.sleep(0.2) + for p in processes: + p.join() + + def test_create_index_table_not_existed(self, connect): + ''' + target: test create index interface when table name not existed + method: create table and add vectors in it, create index with an random table_name + , make sure the table name not in index + expected: return code not equals to 0, create index failed + ''' + table_name = gen_unique_str(self.__class__.__name__) + status = connect.create_index(table_name, random.choice(gen_index_params())) + assert not status.OK() + + def test_create_index_table_None(self, connect): + ''' + target: test create index interface when table name is None + method: create table and add vectors in it, create index with an table_name: None + expected: return code not equals to 0, create index failed + ''' + table_name = None + with pytest.raises(Exception) as e: + status = connect.create_index(table_name, random.choice(gen_index_params())) + + def test_create_index_no_vectors(self, connect, table): + ''' + target: test create index interface when there is no vectors in table + method: create table and add no vectors in it, and then create index + expected: return code equals to 0 + ''' + status = connect.create_index(table, random.choice(gen_index_params())) + assert status.OK() + + @pytest.mark.timeout(BUILD_TIMEOUT) + def test_create_index_no_vectors_then_add_vectors(self, connect, table): + ''' + target: test create index interface when there is no vectors in table, and does not affect the subsequent process + method: create table and add no vectors in it, and then create index, add vectors in it + expected: return code equals to 0 + ''' + status = connect.create_index(table, random.choice(gen_index_params())) + status, ids = connect.add_vectors(table, vectors) + assert status.OK() + + @pytest.mark.timeout(BUILD_TIMEOUT) + def test_create_same_index_repeatedly(self, connect, table): + ''' + target: check if index can be created repeatedly, with the same create_index params + method: create index after index have been built + expected: return code success, and search ok + ''' + status, ids = connect.add_vectors(table, vectors) + index_params = random.choice(gen_index_params()) + # index_params = get_index_params + status = connect.create_index(table, index_params) + status = connect.create_index(table, index_params) + assert status.OK() + query_vec = [vectors[0]] + top_k = 1 + status, result = connect.search_vectors(table, top_k, nprobe, query_vec) + assert len(result) == 1 + assert len(result[0]) == top_k + + @pytest.mark.timeout(BUILD_TIMEOUT) + def test_create_different_index_repeatedly(self, connect, table): + ''' + target: check if index can be created repeatedly, with the different create_index params + method: create another index with different index_params after index have been built + expected: return code 0, and describe index result equals with the second index params + ''' + status, ids = connect.add_vectors(table, vectors) + index_params = random.sample(gen_index_params(), 2) + logging.getLogger().info(index_params) + status = connect.create_index(table, index_params[0]) + status = connect.create_index(table, index_params[1]) + assert status.OK() + status, result = connect.describe_index(table) + assert result._nlist == index_params[1]["nlist"] + assert result._table_name == table + assert result._index_type == index_params[1]["index_type"] + + """ + ****************************************************************** + The following cases are used to test `describe_index` function + ****************************************************************** + """ + + def test_describe_index(self, connect, table, get_index_params): + ''' + target: test describe index interface + method: create table and add vectors in it, create index, call describe index + expected: return code 0, and index instructure + ''' + index_params = get_index_params + logging.getLogger().info(index_params) + status, ids = connect.add_vectors(table, vectors) + status = connect.create_index(table, index_params) + status, result = connect.describe_index(table) + logging.getLogger().info(result) + assert result._nlist == index_params["nlist"] + assert result._table_name == table + assert result._index_type == index_params["index_type"] + + def test_describe_and_drop_index_multi_tables(self, connect, get_simple_index_params): + ''' + target: test create, describe and drop index interface with multiple tables of L2 + method: create tables and add vectors in it, create index, call describe index + expected: return code 0, and index instructure + ''' + nq = 100 + vectors = gen_vectors(nq, dim) + table_list = [] + for i in range(10): + table_name = gen_unique_str('test_create_index_multi_tables') + table_list.append(table_name) + param = {'table_name': table_name, + 'dimension': dim, + 'index_file_size': index_file_size, + 'metric_type': MetricType.L2} + connect.create_table(param) + index_params = get_simple_index_params + logging.getLogger().info(index_params) + status, ids = connect.add_vectors(table_name=table_name, records=vectors) + status = connect.create_index(table_name, index_params) + assert status.OK() + + for i in range(10): + status, result = connect.describe_index(table_list[i]) + logging.getLogger().info(result) + assert result._nlist == index_params["nlist"] + assert result._table_name == table_list[i] + assert result._index_type == index_params["index_type"] + + for i in range(10): + status = connect.drop_index(table_list[i]) + assert status.OK() + status, result = connect.describe_index(table_list[i]) + logging.getLogger().info(result) + assert result._nlist == 16384 + assert result._table_name == table_list[i] + assert result._index_type == IndexType.FLAT + + @pytest.mark.level(2) + def test_describe_index_without_connect(self, dis_connect, table): + ''' + target: test describe index without connection + method: describe index, and check if describe successfully + expected: raise exception + ''' + with pytest.raises(Exception) as e: + status = dis_connect.describe_index(table) + + def test_describe_index_table_not_existed(self, connect): + ''' + target: test describe index interface when table name not existed + method: create table and add vectors in it, create index with an random table_name + , make sure the table name not in index + expected: return code not equals to 0, describe index failed + ''' + table_name = gen_unique_str(self.__class__.__name__) + status, result = connect.describe_index(table_name) + assert not status.OK() + + def test_describe_index_table_None(self, connect): + ''' + target: test describe index interface when table name is None + method: create table and add vectors in it, create index with an table_name: None + expected: return code not equals to 0, describe index failed + ''' + table_name = None + with pytest.raises(Exception) as e: + status = connect.describe_index(table_name) + + def test_describe_index_not_create(self, connect, table): + ''' + target: test describe index interface when index not created + method: create table and add vectors in it, create index with an random table_name + , make sure the table name not in index + expected: return code not equals to 0, describe index failed + ''' + status, ids = connect.add_vectors(table, vectors) + status, result = connect.describe_index(table) + logging.getLogger().info(result) + assert status.OK() + # assert result._nlist == index_params["nlist"] + # assert result._table_name == table + # assert result._index_type == index_params["index_type"] + + """ + ****************************************************************** + The following cases are used to test `drop_index` function + ****************************************************************** + """ + + def test_drop_index(self, connect, table, get_index_params): + ''' + target: test drop index interface + method: create table and add vectors in it, create index, call drop index + expected: return code 0, and default index param + ''' + index_params = get_index_params + status, ids = connect.add_vectors(table, vectors) + status = connect.create_index(table, index_params) + assert status.OK() + status, result = connect.describe_index(table) + logging.getLogger().info(result) + status = connect.drop_index(table) + assert status.OK() + status, result = connect.describe_index(table) + logging.getLogger().info(result) + assert result._nlist == 16384 + assert result._table_name == table + assert result._index_type == IndexType.FLAT + + def test_drop_index_repeatly(self, connect, table, get_simple_index_params): + ''' + target: test drop index repeatly + method: create index, call drop index, and drop again + expected: return code 0 + ''' + index_params = get_simple_index_params + status, ids = connect.add_vectors(table, vectors) + status = connect.create_index(table, index_params) + assert status.OK() + status, result = connect.describe_index(table) + logging.getLogger().info(result) + status = connect.drop_index(table) + assert status.OK() + status = connect.drop_index(table) + assert status.OK() + status, result = connect.describe_index(table) + logging.getLogger().info(result) + assert result._nlist == 16384 + assert result._table_name == table + assert result._index_type == IndexType.FLAT + + @pytest.mark.level(2) + def test_drop_index_without_connect(self, dis_connect, table): + ''' + target: test drop index without connection + method: drop index, and check if drop successfully + expected: raise exception + ''' + with pytest.raises(Exception) as e: + status = dis_connect.drop_index(table) + + def test_drop_index_table_not_existed(self, connect): + ''' + target: test drop index interface when table name not existed + method: create table and add vectors in it, create index with an random table_name + , make sure the table name not in index, and then drop it + expected: return code not equals to 0, drop index failed + ''' + table_name = gen_unique_str(self.__class__.__name__) + status = connect.drop_index(table_name) + assert not status.OK() + + def test_drop_index_table_None(self, connect): + ''' + target: test drop index interface when table name is None + method: create table and add vectors in it, create index with an table_name: None + expected: return code not equals to 0, drop index failed + ''' + table_name = None + with pytest.raises(Exception) as e: + status = connect.drop_index(table_name) + + def test_drop_index_table_not_create(self, connect, table): + ''' + target: test drop index interface when index not created + method: create table and add vectors in it, create index + expected: return code not equals to 0, drop index failed + ''' + index_params = random.choice(gen_index_params()) + logging.getLogger().info(index_params) + status, ids = connect.add_vectors(table, vectors) + status, result = connect.describe_index(table) + logging.getLogger().info(result) + # no create index + status = connect.drop_index(table) + logging.getLogger().info(status) + assert status.OK() + + def test_create_drop_index_repeatly(self, connect, table, get_simple_index_params): + ''' + target: test create / drop index repeatly, use the same index params + method: create index, drop index, four times + expected: return code 0 + ''' + index_params = get_simple_index_params + status, ids = connect.add_vectors(table, vectors) + for i in range(2): + status = connect.create_index(table, index_params) + assert status.OK() + status, result = connect.describe_index(table) + logging.getLogger().info(result) + status = connect.drop_index(table) + assert status.OK() + status, result = connect.describe_index(table) + logging.getLogger().info(result) + assert result._nlist == 16384 + assert result._table_name == table + assert result._index_type == IndexType.FLAT + + def test_create_drop_index_repeatly_different_index_params(self, connect, table): + ''' + target: test create / drop index repeatly, use the different index params + method: create index, drop index, four times, each tme use different index_params to create index + expected: return code 0 + ''' + index_params = random.sample(gen_index_params(), 2) + status, ids = connect.add_vectors(table, vectors) + for i in range(2): + status = connect.create_index(table, index_params[i]) + assert status.OK() + status, result = connect.describe_index(table) + logging.getLogger().info(result) + status = connect.drop_index(table) + assert status.OK() + status, result = connect.describe_index(table) + logging.getLogger().info(result) + assert result._nlist == 16384 + assert result._table_name == table + assert result._index_type == IndexType.FLAT + + +class TestIndexIP: + @pytest.fixture( + scope="function", + params=gen_index_params() + ) + def get_index_params(self, request, args): + if "internal" not in args: + if request.param["index_type"] == IndexType.IVF_SQ8H: + pytest.skip("sq8h not support in open source") + return request.param + + @pytest.fixture( + scope="function", + params=gen_simple_index_params() + ) + def get_simple_index_params(self, request): + yield request.param + + """ + ****************************************************************** + The following cases are used to test `create_index` function + ****************************************************************** + """ + + @pytest.mark.timeout(BUILD_TIMEOUT) + def test_create_index(self, connect, ip_table, get_index_params): + ''' + target: test create index interface + method: create table and add vectors in it, create index + expected: return code equals to 0, and search success + ''' + index_params = get_index_params + logging.getLogger().info(index_params) + status, ids = connect.add_vectors(ip_table, vectors) + status = connect.create_index(ip_table, index_params) + assert status.OK() + + @pytest.mark.level(2) + def test_create_index_without_connect(self, dis_connect, ip_table): + ''' + target: test create index without connection + method: create table and add vectors in it, check if added successfully + expected: raise exception + ''' + with pytest.raises(Exception) as e: + status = dis_connect.create_index(ip_table, random.choice(gen_index_params())) + + @pytest.mark.timeout(BUILD_TIMEOUT) + def test_create_index_search_with_query_vectors(self, connect, ip_table, get_index_params): + ''' + target: test create index interface, search with more query vectors + method: create table and add vectors in it, create index + expected: return code equals to 0, and search success + ''' + index_params = get_index_params + logging.getLogger().info(index_params) + status, ids = connect.add_vectors(ip_table, vectors) + status = connect.create_index(ip_table, index_params) + logging.getLogger().info(connect.describe_index(ip_table)) + query_vecs = [vectors[0], vectors[1], vectors[2]] + top_k = 5 + status, result = connect.search_vectors(ip_table, top_k, nprobe, query_vecs) + assert status.OK() + assert len(result) == len(query_vecs) + # logging.getLogger().info(result) + + # TODO: enable + @pytest.mark.timeout(BUILD_TIMEOUT) + @pytest.mark.level(2) + def _test_create_index_multiprocessing(self, connect, ip_table, args): + ''' + target: test create index interface with multiprocess + method: create table and add vectors in it, create index + expected: return code equals to 0, and search success + ''' + status, ids = connect.add_vectors(ip_table, vectors) + + def build(connect): + status = connect.create_index(ip_table) + assert status.OK() + + process_num = 8 + processes = [] + uri = "tcp://%s:%s" % (args["ip"], args["port"]) + + for i in range(process_num): + m = Milvus() + m.connect(uri=uri) + p = Process(target=build, args=(m,)) + processes.append(p) + p.start() + time.sleep(0.2) + for p in processes: + p.join() + + query_vec = [vectors[0]] + top_k = 1 + status, result = connect.search_vectors(ip_table, top_k, nprobe, query_vec) + assert len(result) == 1 + assert len(result[0]) == top_k + assert result[0][0].distance == 0.0 + + # TODO: enable + @pytest.mark.timeout(BUILD_TIMEOUT) + def _test_create_index_multiprocessing_multitable(self, connect, args): + ''' + target: test create index interface with multiprocess + method: create table and add vectors in it, create index + expected: return code equals to 0, and search success + ''' + process_num = 8 + loop_num = 8 + processes = [] + + table = [] + j = 0 + while j < (process_num*loop_num): + table_name = gen_unique_str("test_create_index_multiprocessing") + table.append(table_name) + param = {'table_name': table_name, + 'dimension': dim} + connect.create_table(param) + j = j + 1 + + def create_index(): + i = 0 + while i < loop_num: + # assert connect.has_table(table[ids*process_num+i]) + status, ids = connect.add_vectors(table[ids*process_num+i], vectors) + + status = connect.create_index(table[ids*process_num+i]) + assert status.OK() + query_vec = [vectors[0]] + top_k = 1 + status, result = connect.search_vectors(table[ids*process_num+i], top_k, nprobe, query_vec) + assert len(result) == 1 + assert len(result[0]) == top_k + assert result[0][0].distance == 0.0 + i = i + 1 + + uri = "tcp://%s:%s" % (args["ip"], args["port"]) + + for i in range(process_num): + m = Milvus() + m.connect(uri=uri) + ids = i + p = Process(target=create_index, args=(m,ids)) + processes.append(p) + p.start() + time.sleep(0.2) + for p in processes: + p.join() + + def test_create_index_no_vectors(self, connect, ip_table): + ''' + target: test create index interface when there is no vectors in table + method: create table and add no vectors in it, and then create index + expected: return code equals to 0 + ''' + status = connect.create_index(ip_table, random.choice(gen_index_params())) + assert status.OK() + + @pytest.mark.timeout(BUILD_TIMEOUT) + def test_create_index_no_vectors_then_add_vectors(self, connect, ip_table): + ''' + target: test create index interface when there is no vectors in table, and does not affect the subsequent process + method: create table and add no vectors in it, and then create index, add vectors in it + expected: return code equals to 0 + ''' + status = connect.create_index(ip_table, random.choice(gen_index_params())) + status, ids = connect.add_vectors(ip_table, vectors) + assert status.OK() + + @pytest.mark.timeout(BUILD_TIMEOUT) + def test_create_same_index_repeatedly(self, connect, ip_table): + ''' + target: check if index can be created repeatedly, with the same create_index params + method: create index after index have been built + expected: return code success, and search ok + ''' + status, ids = connect.add_vectors(ip_table, vectors) + index_params = random.choice(gen_index_params()) + # index_params = get_index_params + status = connect.create_index(ip_table, index_params) + status = connect.create_index(ip_table, index_params) + assert status.OK() + query_vec = [vectors[0]] + top_k = 1 + status, result = connect.search_vectors(ip_table, top_k, nprobe, query_vec) + assert len(result) == 1 + assert len(result[0]) == top_k + + @pytest.mark.timeout(BUILD_TIMEOUT) + def test_create_different_index_repeatedly(self, connect, ip_table): + ''' + target: check if index can be created repeatedly, with the different create_index params + method: create another index with different index_params after index have been built + expected: return code 0, and describe index result equals with the second index params + ''' + status, ids = connect.add_vectors(ip_table, vectors) + index_params = random.sample(gen_index_params(), 2) + logging.getLogger().info(index_params) + status = connect.create_index(ip_table, index_params[0]) + status = connect.create_index(ip_table, index_params[1]) + assert status.OK() + status, result = connect.describe_index(ip_table) + assert result._nlist == index_params[1]["nlist"] + assert result._table_name == ip_table + assert result._index_type == index_params[1]["index_type"] + + """ + ****************************************************************** + The following cases are used to test `describe_index` function + ****************************************************************** + """ + + def test_describe_index(self, connect, ip_table, get_index_params): + ''' + target: test describe index interface + method: create table and add vectors in it, create index, call describe index + expected: return code 0, and index instructure + ''' + index_params = get_index_params + logging.getLogger().info(index_params) + status, ids = connect.add_vectors(ip_table, vectors) + status = connect.create_index(ip_table, index_params) + status, result = connect.describe_index(ip_table) + logging.getLogger().info(result) + assert result._nlist == index_params["nlist"] + assert result._table_name == ip_table + assert result._index_type == index_params["index_type"] + + def test_describe_and_drop_index_multi_tables(self, connect, get_simple_index_params): + ''' + target: test create, describe and drop index interface with multiple tables of IP + method: create tables and add vectors in it, create index, call describe index + expected: return code 0, and index instructure + ''' + nq = 100 + vectors = gen_vectors(nq, dim) + table_list = [] + for i in range(10): + table_name = gen_unique_str('test_create_index_multi_tables') + table_list.append(table_name) + param = {'table_name': table_name, + 'dimension': dim, + 'index_file_size': index_file_size, + 'metric_type': MetricType.IP} + connect.create_table(param) + index_params = get_simple_index_params + logging.getLogger().info(index_params) + status, ids = connect.add_vectors(table_name=table_name, records=vectors) + status = connect.create_index(table_name, index_params) + assert status.OK() + + for i in range(10): + status, result = connect.describe_index(table_list[i]) + logging.getLogger().info(result) + assert result._nlist == index_params["nlist"] + assert result._table_name == table_list[i] + assert result._index_type == index_params["index_type"] + + for i in range(10): + status = connect.drop_index(table_list[i]) + assert status.OK() + status, result = connect.describe_index(table_list[i]) + logging.getLogger().info(result) + assert result._nlist == 16384 + assert result._table_name == table_list[i] + assert result._index_type == IndexType.FLAT + + @pytest.mark.level(2) + def test_describe_index_without_connect(self, dis_connect, ip_table): + ''' + target: test describe index without connection + method: describe index, and check if describe successfully + expected: raise exception + ''' + with pytest.raises(Exception) as e: + status = dis_connect.describe_index(ip_table) + + def test_describe_index_not_create(self, connect, ip_table): + ''' + target: test describe index interface when index not created + method: create table and add vectors in it, create index with an random table_name + , make sure the table name not in index + expected: return code not equals to 0, describe index failed + ''' + status, ids = connect.add_vectors(ip_table, vectors) + status, result = connect.describe_index(ip_table) + logging.getLogger().info(result) + assert status.OK() + # assert result._nlist == index_params["nlist"] + # assert result._table_name == table + # assert result._index_type == index_params["index_type"] + + """ + ****************************************************************** + The following cases are used to test `drop_index` function + ****************************************************************** + """ + + def test_drop_index(self, connect, ip_table, get_index_params): + ''' + target: test drop index interface + method: create table and add vectors in it, create index, call drop index + expected: return code 0, and default index param + ''' + index_params = get_index_params + status, ids = connect.add_vectors(ip_table, vectors) + status = connect.create_index(ip_table, index_params) + assert status.OK() + status, result = connect.describe_index(ip_table) + logging.getLogger().info(result) + status = connect.drop_index(ip_table) + assert status.OK() + status, result = connect.describe_index(ip_table) + logging.getLogger().info(result) + assert result._nlist == 16384 + assert result._table_name == ip_table + assert result._index_type == IndexType.FLAT + + def test_drop_index_repeatly(self, connect, ip_table, get_simple_index_params): + ''' + target: test drop index repeatly + method: create index, call drop index, and drop again + expected: return code 0 + ''' + index_params = get_simple_index_params + status, ids = connect.add_vectors(ip_table, vectors) + status = connect.create_index(ip_table, index_params) + assert status.OK() + status, result = connect.describe_index(ip_table) + logging.getLogger().info(result) + status = connect.drop_index(ip_table) + assert status.OK() + status = connect.drop_index(ip_table) + assert status.OK() + status, result = connect.describe_index(ip_table) + logging.getLogger().info(result) + assert result._nlist == 16384 + assert result._table_name == ip_table + assert result._index_type == IndexType.FLAT + + @pytest.mark.level(2) + def test_drop_index_without_connect(self, dis_connect, ip_table): + ''' + target: test drop index without connection + method: drop index, and check if drop successfully + expected: raise exception + ''' + with pytest.raises(Exception) as e: + status = dis_connect.drop_index(ip_table, random.choice(gen_index_params())) + + def test_drop_index_table_not_create(self, connect, ip_table): + ''' + target: test drop index interface when index not created + method: create table and add vectors in it, create index + expected: return code not equals to 0, drop index failed + ''' + index_params = random.choice(gen_index_params()) + logging.getLogger().info(index_params) + status, ids = connect.add_vectors(ip_table, vectors) + status, result = connect.describe_index(ip_table) + logging.getLogger().info(result) + # no create index + status = connect.drop_index(ip_table) + logging.getLogger().info(status) + assert status.OK() + + def test_create_drop_index_repeatly(self, connect, ip_table, get_simple_index_params): + ''' + target: test create / drop index repeatly, use the same index params + method: create index, drop index, four times + expected: return code 0 + ''' + index_params = get_simple_index_params + status, ids = connect.add_vectors(ip_table, vectors) + for i in range(2): + status = connect.create_index(ip_table, index_params) + assert status.OK() + status, result = connect.describe_index(ip_table) + logging.getLogger().info(result) + status = connect.drop_index(ip_table) + assert status.OK() + status, result = connect.describe_index(ip_table) + logging.getLogger().info(result) + assert result._nlist == 16384 + assert result._table_name == ip_table + assert result._index_type == IndexType.FLAT + + def test_create_drop_index_repeatly_different_index_params(self, connect, ip_table): + ''' + target: test create / drop index repeatly, use the different index params + method: create index, drop index, four times, each tme use different index_params to create index + expected: return code 0 + ''' + index_params = random.sample(gen_index_params(), 2) + status, ids = connect.add_vectors(ip_table, vectors) + for i in range(2): + status = connect.create_index(ip_table, index_params[i]) + assert status.OK() + status, result = connect.describe_index(ip_table) + assert result._nlist == index_params[i]["nlist"] + assert result._table_name == ip_table + assert result._index_type == index_params[i]["index_type"] + status, result = connect.describe_index(ip_table) + logging.getLogger().info(result) + status = connect.drop_index(ip_table) + assert status.OK() + status, result = connect.describe_index(ip_table) + logging.getLogger().info(result) + assert result._nlist == 16384 + assert result._table_name == ip_table + assert result._index_type == IndexType.FLAT + + +class TestIndexTableInvalid(object): + """ + Test create / describe / drop index interfaces with invalid table names + """ + @pytest.fixture( + scope="function", + params=gen_invalid_table_names() + ) + def get_table_name(self, request): + yield request.param + + # @pytest.mark.level(1) + def test_create_index_with_invalid_tablename(self, connect, get_table_name): + table_name = get_table_name + status = connect.create_index(table_name, random.choice(gen_index_params())) + assert not status.OK() + + # @pytest.mark.level(1) + def test_describe_index_with_invalid_tablename(self, connect, get_table_name): + table_name = get_table_name + status, result = connect.describe_index(table_name) + assert not status.OK() + + # @pytest.mark.level(1) + def test_drop_index_with_invalid_tablename(self, connect, get_table_name): + table_name = get_table_name + status = connect.drop_index(table_name) + assert not status.OK() + + +class TestCreateIndexParamsInvalid(object): + """ + Test Building index with invalid table names, table names not in db + """ + @pytest.fixture( + scope="function", + params=gen_invalid_index_params() + ) + def get_index_params(self, request): + yield request.param + + @pytest.mark.level(2) + def test_create_index_with_invalid_index_params(self, connect, table, get_index_params): + index_params = get_index_params + index_type = index_params["index_type"] + nlist = index_params["nlist"] + logging.getLogger().info(index_params) + status, ids = connect.add_vectors(table, vectors) + # if not isinstance(index_type, int) or not isinstance(nlist, int): + with pytest.raises(Exception) as e: + status = connect.create_index(table, index_params) + # else: + # status = connect.create_index(table, index_params) + # assert not status.OK() diff --git a/tests/milvus_python_test/test_mix.py b/tests/milvus_python_test/test_mix.py new file mode 100644 index 0000000000..4578e330b3 --- /dev/null +++ b/tests/milvus_python_test/test_mix.py @@ -0,0 +1,180 @@ +import pdb +import copy +import pytest +import threading +import datetime +import logging +from time import sleep +from multiprocessing import Process +import numpy +from milvus import Milvus, IndexType, MetricType +from utils import * + +dim = 128 +index_file_size = 10 +table_id = "test_mix" +add_interval_time = 2 +vectors = gen_vectors(100000, dim) +vectors /= numpy.linalg.norm(vectors) +vectors = vectors.tolist() +top_k = 1 +nprobe = 1 +epsilon = 0.0001 +index_params = {'index_type': IndexType.IVFLAT, 'nlist': 16384} + + +class TestMixBase: + + # TODO: enable + def _test_search_during_createIndex(self, args): + loops = 100000 + table = "test_search_during_createIndex" + query_vecs = [vectors[0], vectors[1]] + uri = "tcp://%s:%s" % (args["ip"], args["port"]) + id_0 = 0; id_1 = 0 + milvus_instance = Milvus() + milvus_instance.connect(uri=uri) + milvus_instance.create_table({'table_name': table, + 'dimension': dim, + 'index_file_size': index_file_size, + 'metric_type': MetricType.L2}) + for i in range(10): + status, ids = milvus_instance.add_vectors(table, vectors) + # logging.getLogger().info(ids) + if i == 0: + id_0 = ids[0]; id_1 = ids[1] + def create_index(milvus_instance): + logging.getLogger().info("In create index") + status = milvus_instance.create_index(table, index_params) + logging.getLogger().info(status) + status, result = milvus_instance.describe_index(table) + logging.getLogger().info(result) + def add_vectors(milvus_instance): + logging.getLogger().info("In add vectors") + status, ids = milvus_instance.add_vectors(table, vectors) + logging.getLogger().info(status) + def search(milvus_instance): + for i in range(loops): + status, result = milvus_instance.search_vectors(table, top_k, nprobe, query_vecs) + logging.getLogger().info(status) + assert result[0][0].id == id_0 + assert result[1][0].id == id_1 + milvus_instance = Milvus() + milvus_instance.connect(uri=uri) + p_search = Process(target=search, args=(milvus_instance, )) + p_search.start() + milvus_instance = Milvus() + milvus_instance.connect(uri=uri) + p_create = Process(target=add_vectors, args=(milvus_instance, )) + p_create.start() + p_create.join() + + def test_mix_multi_tables(self, connect): + ''' + target: test functions with multiple tables of different metric_types and index_types + method: create 60 tables which 30 are L2 and the other are IP, add vectors into them + and test describe index and search + expected: status ok + ''' + nq = 10000 + vectors = gen_vectors(nq, dim) + table_list = [] + idx = [] + + #create table and add vectors + for i in range(30): + table_name = gen_unique_str('test_mix_multi_tables') + table_list.append(table_name) + param = {'table_name': table_name, + 'dimension': dim, + 'index_file_size': index_file_size, + 'metric_type': MetricType.L2} + connect.create_table(param) + status, ids = connect.add_vectors(table_name=table_name, records=vectors) + idx.append(ids[0]) + idx.append(ids[10]) + idx.append(ids[20]) + assert status.OK() + for i in range(30): + table_name = gen_unique_str('test_mix_multi_tables') + table_list.append(table_name) + param = {'table_name': table_name, + 'dimension': dim, + 'index_file_size': index_file_size, + 'metric_type': MetricType.IP} + connect.create_table(param) + status, ids = connect.add_vectors(table_name=table_name, records=vectors) + idx.append(ids[0]) + idx.append(ids[10]) + idx.append(ids[20]) + assert status.OK() + time.sleep(2) + + #create index + for i in range(10): + index_params = {'index_type': IndexType.FLAT, 'nlist': 16384} + status = connect.create_index(table_list[i], index_params) + assert status.OK() + status = connect.create_index(table_list[30 + i], index_params) + assert status.OK() + index_params = {'index_type': IndexType.IVFLAT, 'nlist': 16384} + status = connect.create_index(table_list[10 + i], index_params) + assert status.OK() + status = connect.create_index(table_list[40 + i], index_params) + assert status.OK() + index_params = {'index_type': IndexType.IVF_SQ8, 'nlist': 16384} + status = connect.create_index(table_list[20 + i], index_params) + assert status.OK() + status = connect.create_index(table_list[50 + i], index_params) + assert status.OK() + + #describe index + for i in range(10): + status, result = connect.describe_index(table_list[i]) + logging.getLogger().info(result) + assert result._nlist == 16384 + assert result._table_name == table_list[i] + assert result._index_type == IndexType.FLAT + status, result = connect.describe_index(table_list[10 + i]) + logging.getLogger().info(result) + assert result._nlist == 16384 + assert result._table_name == table_list[10 + i] + assert result._index_type == IndexType.IVFLAT + status, result = connect.describe_index(table_list[20 + i]) + logging.getLogger().info(result) + assert result._nlist == 16384 + assert result._table_name == table_list[20 + i] + assert result._index_type == IndexType.IVF_SQ8 + status, result = connect.describe_index(table_list[30 + i]) + logging.getLogger().info(result) + assert result._nlist == 16384 + assert result._table_name == table_list[30 + i] + assert result._index_type == IndexType.FLAT + status, result = connect.describe_index(table_list[40 + i]) + logging.getLogger().info(result) + assert result._nlist == 16384 + assert result._table_name == table_list[40 + i] + assert result._index_type == IndexType.IVFLAT + status, result = connect.describe_index(table_list[50 + i]) + logging.getLogger().info(result) + assert result._nlist == 16384 + assert result._table_name == table_list[50 + i] + assert result._index_type == IndexType.IVF_SQ8 + + #search + query_vecs = [vectors[0], vectors[10], vectors[20]] + for i in range(60): + table = table_list[i] + status, result = connect.search_vectors(table, top_k, nprobe, query_vecs) + assert status.OK() + assert len(result) == len(query_vecs) + for j in range(len(query_vecs)): + assert len(result[j]) == top_k + for j in range(len(query_vecs)): + assert check_result(result[j], idx[3 * i + j]) + +def check_result(result, id): + if len(result) >= 5: + return id in [result[0].id, result[1].id, result[2].id, result[3].id, result[4].id] + else: + return id in (i.id for i in result) \ No newline at end of file diff --git a/tests/milvus_python_test/test_ping.py b/tests/milvus_python_test/test_ping.py new file mode 100644 index 0000000000..a55559bc63 --- /dev/null +++ b/tests/milvus_python_test/test_ping.py @@ -0,0 +1,77 @@ +import logging +import pytest + +__version__ = '0.5.0' + + +class TestPing: + def test_server_version(self, connect): + ''' + target: test get the server version + method: call the server_version method after connected + expected: version should be the pymilvus version + ''' + status, res = connect.server_version() + assert res == __version__ + + def test_server_status(self, connect): + ''' + target: test get the server status + method: call the server_status method after connected + expected: status returned should be ok + ''' + status, msg = connect.server_status() + assert status.OK() + + def _test_server_cmd_with_params_version(self, connect): + ''' + target: test cmd: version + method: cmd = "version" ... + expected: when cmd = 'version', return version of server; + ''' + cmd = "version" + status, msg = connect.cmd(cmd) + logging.getLogger().info(status) + logging.getLogger().info(msg) + assert status.OK() + assert msg == __version__ + + def _test_server_cmd_with_params_others(self, connect): + ''' + target: test cmd: lalala + method: cmd = "lalala" ... + expected: when cmd = 'version', return version of server; + ''' + cmd = "rm -rf test" + status, msg = connect.cmd(cmd) + logging.getLogger().info(status) + logging.getLogger().info(msg) + assert status.OK() + # assert msg == __version__ + + def test_connected(self, connect): + assert connect.connected() + + +class TestPingDisconnect: + def test_server_version(self, dis_connect): + ''' + target: test get the server version, after disconnect + method: call the server_version method after connected + expected: version should not be the pymilvus version + ''' + res = None + with pytest.raises(Exception) as e: + status, res = connect.server_version() + assert res is None + + def test_server_status(self, dis_connect): + ''' + target: test get the server status, after disconnect + method: call the server_status method after connected + expected: status returned should be not ok + ''' + status = None + with pytest.raises(Exception) as e: + status, msg = connect.server_status() + assert status is None diff --git a/tests/milvus_python_test/test_search_vectors.py b/tests/milvus_python_test/test_search_vectors.py new file mode 100644 index 0000000000..e52e0d2d08 --- /dev/null +++ b/tests/milvus_python_test/test_search_vectors.py @@ -0,0 +1,653 @@ +import pdb +import copy +import pytest +import threading +import datetime +import logging +from time import sleep +from multiprocessing import Process +import numpy +from milvus import Milvus, IndexType, MetricType +from utils import * + +dim = 128 +table_id = "test_search" +add_interval_time = 2 +vectors = gen_vectors(100, dim) +# vectors /= numpy.linalg.norm(vectors) +# vectors = vectors.tolist() +nrpobe = 1 +epsilon = 0.001 + + +class TestSearchBase: + def init_data(self, connect, table, nb=100): + ''' + Generate vectors and add it in table, before search vectors + ''' + global vectors + if nb == 100: + add_vectors = vectors + else: + add_vectors = gen_vectors(nb, dim) + # add_vectors /= numpy.linalg.norm(add_vectors) + # add_vectors = add_vectors.tolist() + status, ids = connect.add_vectors(table, add_vectors) + sleep(add_interval_time) + return add_vectors, ids + + """ + generate valid create_index params + """ + @pytest.fixture( + scope="function", + params=gen_index_params() + ) + def get_index_params(self, request, args): + if "internal" not in args: + if request.param["index_type"] == IndexType.IVF_SQ8H: + pytest.skip("sq8h not support in open source") + return request.param + + """ + generate top-k params + """ + @pytest.fixture( + scope="function", + params=[1, 99, 101, 1024, 2048, 2049] + ) + def get_top_k(self, request): + yield request.param + + + def test_search_top_k_flat_index(self, connect, table, get_top_k): + ''' + target: test basic search fuction, all the search params is corrent, change top-k value + method: search with the given vectors, check the result + expected: search status ok, and the length of the result is top_k + ''' + vectors, ids = self.init_data(connect, table) + query_vec = [vectors[0]] + top_k = get_top_k + nprobe = 1 + status, result = connect.search_vectors(table, top_k, nrpobe, query_vec) + if top_k <= 2048: + assert status.OK() + assert len(result[0]) == min(len(vectors), top_k) + assert result[0][0].distance <= epsilon + assert check_result(result[0], ids[0]) + else: + assert not status.OK() + + def test_search_l2_index_params(self, connect, table, get_index_params): + ''' + target: test basic search fuction, all the search params is corrent, test all index params, and build + method: search with the given vectors, check the result + expected: search status ok, and the length of the result is top_k + ''' + + index_params = get_index_params + logging.getLogger().info(index_params) + vectors, ids = self.init_data(connect, table) + status = connect.create_index(table, index_params) + query_vec = [vectors[0]] + top_k = 10 + nprobe = 1 + status, result = connect.search_vectors(table, top_k, nrpobe, query_vec) + logging.getLogger().info(result) + if top_k <= 1024: + assert status.OK() + assert len(result[0]) == min(len(vectors), top_k) + assert check_result(result[0], ids[0]) + assert result[0][0].distance <= epsilon + else: + assert not status.OK() + + def test_search_ip_index_params(self, connect, ip_table, get_index_params): + ''' + target: test basic search fuction, all the search params is corrent, test all index params, and build + method: search with the given vectors, check the result + expected: search status ok, and the length of the result is top_k + ''' + + index_params = get_index_params + logging.getLogger().info(index_params) + vectors, ids = self.init_data(connect, ip_table) + status = connect.create_index(ip_table, index_params) + query_vec = [vectors[0]] + top_k = 10 + nprobe = 1 + status, result = connect.search_vectors(ip_table, top_k, nrpobe, query_vec) + logging.getLogger().info(result) + + if top_k <= 1024: + assert status.OK() + assert len(result[0]) == min(len(vectors), top_k) + assert check_result(result[0], ids[0]) + assert abs(result[0][0].distance - numpy.inner(numpy.array(query_vec[0]), numpy.array(query_vec[0]))) <= gen_inaccuracy(result[0][0].distance) + else: + assert not status.OK() + + @pytest.mark.level(2) + def test_search_vectors_without_connect(self, dis_connect, table): + ''' + target: test search vectors without connection + method: use dis connected instance, call search method and check if search successfully + expected: raise exception + ''' + query_vectors = [vectors[0]] + top_k = 1 + nprobe = 1 + with pytest.raises(Exception) as e: + status, ids = dis_connect.search_vectors(table, top_k, nprobe, query_vectors) + + def test_search_table_name_not_existed(self, connect, table): + ''' + target: search table not existed + method: search with the random table_name, which is not in db + expected: status not ok + ''' + table_name = gen_unique_str("not_existed_table") + top_k = 1 + nprobe = 1 + query_vecs = [vectors[0]] + status, result = connect.search_vectors(table_name, top_k, nprobe, query_vecs) + assert not status.OK() + + def test_search_table_name_None(self, connect, table): + ''' + target: search table that table name is None + method: search with the table_name: None + expected: status not ok + ''' + table_name = None + top_k = 1 + nprobe = 1 + query_vecs = [vectors[0]] + with pytest.raises(Exception) as e: + status, result = connect.search_vectors(table_name, top_k, nprobe, query_vecs) + + def test_search_top_k_query_records(self, connect, table): + ''' + target: test search fuction, with search params: query_records + method: search with the given query_records, which are subarrays of the inserted vectors + expected: status ok and the returned vectors should be query_records + ''' + top_k = 10 + nprobe = 1 + vectors, ids = self.init_data(connect, table) + query_vecs = [vectors[0],vectors[55],vectors[99]] + status, result = connect.search_vectors(table, top_k, nprobe, query_vecs) + assert status.OK() + assert len(result) == len(query_vecs) + for i in range(len(query_vecs)): + assert len(result[i]) == top_k + assert result[i][0].distance <= epsilon + + """ + generate invalid query range params + """ + @pytest.fixture( + scope="function", + params=[ + (get_current_day(), get_current_day()), + (get_last_day(1), get_last_day(1)), + (get_next_day(1), get_next_day(1)) + ] + ) + def get_invalid_range(self, request): + yield request.param + + def test_search_invalid_query_ranges(self, connect, table, get_invalid_range): + ''' + target: search table with query ranges + method: search with the same query ranges + expected: status not ok + ''' + top_k = 2 + nprobe = 1 + vectors, ids = self.init_data(connect, table) + query_vecs = [vectors[0]] + query_ranges = [get_invalid_range] + status, result = connect.search_vectors(table, top_k, nprobe, query_vecs, query_ranges=query_ranges) + assert not status.OK() + assert len(result) == 0 + + """ + generate valid query range params, no search result + """ + @pytest.fixture( + scope="function", + params=[ + (get_last_day(2), get_last_day(1)), + (get_last_day(2), get_current_day()), + (get_next_day(1), get_next_day(2)) + ] + ) + def get_valid_range_no_result(self, request): + yield request.param + + def test_search_valid_query_ranges_no_result(self, connect, table, get_valid_range_no_result): + ''' + target: search table with normal query ranges, but no data in db + method: search with query ranges (low, low) + expected: length of result is 0 + ''' + top_k = 2 + nprobe = 1 + vectors, ids = self.init_data(connect, table) + query_vecs = [vectors[0]] + query_ranges = [get_valid_range_no_result] + status, result = connect.search_vectors(table, top_k, nprobe, query_vecs, query_ranges=query_ranges) + assert status.OK() + assert len(result) == 0 + + """ + generate valid query range params, no search result + """ + @pytest.fixture( + scope="function", + params=[ + (get_last_day(2), get_next_day(2)), + (get_current_day(), get_next_day(2)), + ] + ) + def get_valid_range(self, request): + yield request.param + + def test_search_valid_query_ranges(self, connect, table, get_valid_range): + ''' + target: search table with normal query ranges, but no data in db + method: search with query ranges (low, normal) + expected: length of result is 0 + ''' + top_k = 2 + nprobe = 1 + vectors, ids = self.init_data(connect, table) + query_vecs = [vectors[0]] + query_ranges = [get_valid_range] + status, result = connect.search_vectors(table, top_k, nprobe, query_vecs, query_ranges=query_ranges) + assert status.OK() + assert len(result) == 1 + assert result[0][0].distance <= epsilon + + def test_search_distance_l2_flat_index(self, connect, table): + ''' + target: search table, and check the result: distance + method: compare the return distance value with value computed with Euclidean + expected: the return distance equals to the computed value + ''' + nb = 2 + top_k = 1 + nprobe = 1 + vectors, ids = self.init_data(connect, table, nb=nb) + query_vecs = [[0.50 for i in range(dim)]] + distance_0 = numpy.linalg.norm(numpy.array(query_vecs[0]) - numpy.array(vectors[0])) + distance_1 = numpy.linalg.norm(numpy.array(query_vecs[0]) - numpy.array(vectors[1])) + status, result = connect.search_vectors(table, top_k, nprobe, query_vecs) + assert abs(numpy.sqrt(result[0][0].distance) - min(distance_0, distance_1)) <= gen_inaccuracy(result[0][0].distance) + + def test_search_distance_ip_flat_index(self, connect, ip_table): + ''' + target: search ip_table, and check the result: distance + method: compare the return distance value with value computed with Inner product + expected: the return distance equals to the computed value + ''' + nb = 2 + top_k = 1 + nprobe = 1 + vectors, ids = self.init_data(connect, ip_table, nb=nb) + index_params = { + "index_type": IndexType.FLAT, + "nlist": 16384 + } + connect.create_index(ip_table, index_params) + logging.getLogger().info(connect.describe_index(ip_table)) + query_vecs = [[0.50 for i in range(dim)]] + distance_0 = numpy.inner(numpy.array(query_vecs[0]), numpy.array(vectors[0])) + distance_1 = numpy.inner(numpy.array(query_vecs[0]), numpy.array(vectors[1])) + status, result = connect.search_vectors(ip_table, top_k, nprobe, query_vecs) + assert abs(result[0][0].distance - max(distance_0, distance_1)) <= gen_inaccuracy(result[0][0].distance) + + def test_search_distance_ip_index_params(self, connect, ip_table, get_index_params): + ''' + target: search table, and check the result: distance + method: compare the return distance value with value computed with Inner product + expected: the return distance equals to the computed value + ''' + top_k = 2 + nprobe = 1 + vectors, ids = self.init_data(connect, ip_table, nb=2) + index_params = get_index_params + connect.create_index(ip_table, index_params) + logging.getLogger().info(connect.describe_index(ip_table)) + query_vecs = [[0.50 for i in range(dim)]] + status, result = connect.search_vectors(ip_table, top_k, nprobe, query_vecs) + distance_0 = numpy.inner(numpy.array(query_vecs[0]), numpy.array(vectors[0])) + distance_1 = numpy.inner(numpy.array(query_vecs[0]), numpy.array(vectors[1])) + assert abs(result[0][0].distance - max(distance_0, distance_1)) <= gen_inaccuracy(result[0][0].distance) + + # TODO: enable + # @pytest.mark.repeat(5) + @pytest.mark.timeout(30) + def _test_search_concurrent(self, connect, table): + vectors, ids = self.init_data(connect, table) + thread_num = 10 + nb = 100 + top_k = 10 + threads = [] + query_vecs = vectors[nb//2:nb] + def search(): + status, result = connect.search_vectors(table, top_k, query_vecs) + assert len(result) == len(query_vecs) + for i in range(len(query_vecs)): + assert result[i][0].id in ids + assert result[i][0].distance == 0.0 + for i in range(thread_num): + x = threading.Thread(target=search, args=()) + threads.append(x) + x.start() + for th in threads: + th.join() + + # TODO: enable + @pytest.mark.timeout(30) + def _test_search_concurrent_multiprocessing(self, args): + ''' + target: test concurrent search with multiprocessess + method: search with 10 processes, each process uses dependent connection + expected: status ok and the returned vectors should be query_records + ''' + nb = 100 + top_k = 10 + process_num = 4 + processes = [] + table = gen_unique_str("test_search_concurrent_multiprocessing") + uri = "tcp://%s:%s" % (args["ip"], args["port"]) + param = {'table_name': table, + 'dimension': dim, + 'index_type': IndexType.FLAT, + 'store_raw_vector': False} + # create table + milvus = Milvus() + milvus.connect(uri=uri) + milvus.create_table(param) + vectors, ids = self.init_data(milvus, table, nb=nb) + query_vecs = vectors[nb//2:nb] + def search(milvus): + status, result = milvus.search_vectors(table, top_k, query_vecs) + assert len(result) == len(query_vecs) + for i in range(len(query_vecs)): + assert result[i][0].id in ids + assert result[i][0].distance == 0.0 + + for i in range(process_num): + milvus = Milvus() + milvus.connect(uri=uri) + p = Process(target=search, args=(milvus, )) + processes.append(p) + p.start() + time.sleep(0.2) + for p in processes: + p.join() + + def test_search_multi_table_L2(search, args): + ''' + target: test search multi tables of L2 + method: add vectors into 10 tables, and search + expected: search status ok, the length of result + ''' + num = 10 + top_k = 10 + nprobe = 1 + tables = [] + idx = [] + for i in range(num): + table = gen_unique_str("test_add_multitable_%d" % i) + uri = "tcp://%s:%s" % (args["ip"], args["port"]) + param = {'table_name': table, + 'dimension': dim, + 'index_file_size': 10, + 'metric_type': MetricType.L2} + # create table + milvus = Milvus() + milvus.connect(uri=uri) + milvus.create_table(param) + status, ids = milvus.add_vectors(table, vectors) + assert status.OK() + assert len(ids) == len(vectors) + tables.append(table) + idx.append(ids[0]) + idx.append(ids[10]) + idx.append(ids[20]) + time.sleep(6) + query_vecs = [vectors[0], vectors[10], vectors[20]] + # start query from random table + for i in range(num): + table = tables[i] + status, result = milvus.search_vectors(table, top_k, nprobe, query_vecs) + assert status.OK() + assert len(result) == len(query_vecs) + for j in range(len(query_vecs)): + assert len(result[j]) == top_k + for j in range(len(query_vecs)): + assert check_result(result[j], idx[3 * i + j]) + + def test_search_multi_table_IP(search, args): + ''' + target: test search multi tables of IP + method: add vectors into 10 tables, and search + expected: search status ok, the length of result + ''' + num = 10 + top_k = 10 + nprobe = 1 + tables = [] + idx = [] + for i in range(num): + table = gen_unique_str("test_add_multitable_%d" % i) + uri = "tcp://%s:%s" % (args["ip"], args["port"]) + param = {'table_name': table, + 'dimension': dim, + 'index_file_size': 10, + 'metric_type': MetricType.L2} + # create table + milvus = Milvus() + milvus.connect(uri=uri) + milvus.create_table(param) + status, ids = milvus.add_vectors(table, vectors) + assert status.OK() + assert len(ids) == len(vectors) + tables.append(table) + idx.append(ids[0]) + idx.append(ids[10]) + idx.append(ids[20]) + time.sleep(6) + query_vecs = [vectors[0], vectors[10], vectors[20]] + # start query from random table + for i in range(num): + table = tables[i] + status, result = milvus.search_vectors(table, top_k, nprobe, query_vecs) + assert status.OK() + assert len(result) == len(query_vecs) + for j in range(len(query_vecs)): + assert len(result[j]) == top_k + for j in range(len(query_vecs)): + assert check_result(result[j], idx[3 * i + j]) +""" +****************************************************************** +# The following cases are used to test `search_vectors` function +# with invalid table_name top-k / nprobe / query_range +****************************************************************** +""" + +class TestSearchParamsInvalid(object): + index_params = random.choice(gen_index_params()) + logging.getLogger().info(index_params) + + def init_data(self, connect, table, nb=100): + ''' + Generate vectors and add it in table, before search vectors + ''' + global vectors + if nb == 100: + add_vectors = vectors + else: + add_vectors = gen_vectors(nb, dim) + status, ids = connect.add_vectors(table, add_vectors) + sleep(add_interval_time) + return add_vectors, ids + + """ + Test search table with invalid table names + """ + @pytest.fixture( + scope="function", + params=gen_invalid_table_names() + ) + def get_table_name(self, request): + yield request.param + + @pytest.mark.level(2) + def test_search_with_invalid_tablename(self, connect, get_table_name): + table_name = get_table_name + logging.getLogger().info(table_name) + top_k = 1 + nprobe = 1 + query_vecs = gen_vectors(1, dim) + status, result = connect.search_vectors(table_name, top_k, nprobe, query_vecs) + assert not status.OK() + + """ + Test search table with invalid top-k + """ + @pytest.fixture( + scope="function", + params=gen_invalid_top_ks() + ) + def get_top_k(self, request): + yield request.param + + @pytest.mark.level(2) + def test_search_with_invalid_top_k(self, connect, table, get_top_k): + ''' + target: test search fuction, with the wrong top_k + method: search with top_k + expected: raise an error, and the connection is normal + ''' + top_k = get_top_k + logging.getLogger().info(top_k) + nprobe = 1 + query_vecs = gen_vectors(1, dim) + with pytest.raises(Exception) as e: + status, result = connect.search_vectors(table, top_k, nprobe, query_vecs) + res = connect.server_version() + + @pytest.mark.level(2) + def test_search_with_invalid_top_k_ip(self, connect, ip_table, get_top_k): + ''' + target: test search fuction, with the wrong top_k + method: search with top_k + expected: raise an error, and the connection is normal + ''' + top_k = get_top_k + logging.getLogger().info(top_k) + nprobe = 1 + query_vecs = gen_vectors(1, dim) + with pytest.raises(Exception) as e: + status, result = connect.search_vectors(ip_table, top_k, nprobe, query_vecs) + res = connect.server_version() + + """ + Test search table with invalid nprobe + """ + @pytest.fixture( + scope="function", + params=gen_invalid_nprobes() + ) + def get_nprobes(self, request): + yield request.param + + @pytest.mark.level(2) + def test_search_with_invalid_nrpobe(self, connect, table, get_nprobes): + ''' + target: test search fuction, with the wrong top_k + method: search with top_k + expected: raise an error, and the connection is normal + ''' + top_k = 1 + nprobe = get_nprobes + logging.getLogger().info(nprobe) + query_vecs = gen_vectors(1, dim) + if isinstance(nprobe, int) and nprobe > 0: + status, result = connect.search_vectors(table, top_k, nprobe, query_vecs) + assert not status.OK() + else: + with pytest.raises(Exception) as e: + status, result = connect.search_vectors(table, top_k, nprobe, query_vecs) + + @pytest.mark.level(2) + def test_search_with_invalid_nrpobe_ip(self, connect, ip_table, get_nprobes): + ''' + target: test search fuction, with the wrong top_k + method: search with top_k + expected: raise an error, and the connection is normal + ''' + top_k = 1 + nprobe = get_nprobes + logging.getLogger().info(nprobe) + query_vecs = gen_vectors(1, dim) + if isinstance(nprobe, int) and nprobe > 0: + status, result = connect.search_vectors(ip_table, top_k, nprobe, query_vecs) + assert not status.OK() + else: + with pytest.raises(Exception) as e: + status, result = connect.search_vectors(ip_table, top_k, nprobe, query_vecs) + + """ + Test search table with invalid query ranges + """ + @pytest.fixture( + scope="function", + params=gen_invalid_query_ranges() + ) + def get_query_ranges(self, request): + yield request.param + + @pytest.mark.level(2) + def test_search_flat_with_invalid_query_range(self, connect, table, get_query_ranges): + ''' + target: test search fuction, with the wrong query_range + method: search with query_range + expected: raise an error, and the connection is normal + ''' + top_k = 1 + nprobe = 1 + query_vecs = [vectors[0]] + query_ranges = get_query_ranges + logging.getLogger().info(query_ranges) + with pytest.raises(Exception) as e: + status, result = connect.search_vectors(table, 1, nprobe, query_vecs, query_ranges=query_ranges) + + + @pytest.mark.level(2) + def test_search_flat_with_invalid_query_range_ip(self, connect, ip_table, get_query_ranges): + ''' + target: test search fuction, with the wrong query_range + method: search with query_range + expected: raise an error, and the connection is normal + ''' + top_k = 1 + nprobe = 1 + query_vecs = [vectors[0]] + query_ranges = get_query_ranges + logging.getLogger().info(query_ranges) + with pytest.raises(Exception) as e: + status, result = connect.search_vectors(ip_table, 1, nprobe, query_vecs, query_ranges=query_ranges) + + +def check_result(result, id): + if len(result) >= 5: + return id in [result[0].id, result[1].id, result[2].id, result[3].id, result[4].id] + else: + return id in (i.id for i in result) \ No newline at end of file diff --git a/tests/milvus_python_test/test_table.py b/tests/milvus_python_test/test_table.py new file mode 100644 index 0000000000..eb538281ed --- /dev/null +++ b/tests/milvus_python_test/test_table.py @@ -0,0 +1,885 @@ +import random +import pdb +import pytest +import logging +import itertools + +from time import sleep +from multiprocessing import Process +import numpy +from milvus import Milvus +from milvus import IndexType, MetricType +from utils import * + +dim = 128 +delete_table_interval_time = 3 +index_file_size = 10 +vectors = gen_vectors(100, dim) + + +class TestTable: + + """ + ****************************************************************** + The following cases are used to test `create_table` function + ****************************************************************** + """ + + def test_create_table(self, connect): + ''' + target: test create normal table + method: create table with corrent params + expected: create status return ok + ''' + table_name = gen_unique_str("test_table") + param = {'table_name': table_name, + 'dimension': dim, + 'index_file_size': index_file_size, + 'metric_type': MetricType.L2} + status = connect.create_table(param) + assert status.OK() + + def test_create_table_ip(self, connect): + ''' + target: test create normal table + method: create table with corrent params + expected: create status return ok + ''' + table_name = gen_unique_str("test_table") + param = {'table_name': table_name, + 'dimension': dim, + 'index_file_size': index_file_size, + 'metric_type': MetricType.IP} + status = connect.create_table(param) + assert status.OK() + + @pytest.mark.level(2) + def test_create_table_without_connection(self, dis_connect): + ''' + target: test create table, without connection + method: create table with correct params, with a disconnected instance + expected: create raise exception + ''' + table_name = gen_unique_str("test_table") + param = {'table_name': table_name, + 'dimension': dim, + 'index_file_size': index_file_size, + 'metric_type': MetricType.L2} + with pytest.raises(Exception) as e: + status = dis_connect.create_table(param) + + def test_create_table_existed(self, connect): + ''' + target: test create table but the table name have already existed + method: create table with the same table_name + expected: create status return not ok + ''' + table_name = gen_unique_str("test_table") + param = {'table_name': table_name, + 'dimension': dim, + 'index_file_size': index_file_size, + 'metric_type': MetricType.L2} + status = connect.create_table(param) + status = connect.create_table(param) + assert not status.OK() + + @pytest.mark.level(2) + def test_create_table_existed_ip(self, connect): + ''' + target: test create table but the table name have already existed + method: create table with the same table_name + expected: create status return not ok + ''' + table_name = gen_unique_str("test_table") + param = {'table_name': table_name, + 'dimension': dim, + 'index_file_size': index_file_size, + 'metric_type': MetricType.IP} + status = connect.create_table(param) + status = connect.create_table(param) + assert not status.OK() + + def test_create_table_None(self, connect): + ''' + target: test create table but the table name is None + method: create table, param table_name is None + expected: create raise error + ''' + param = {'table_name': None, + 'dimension': dim, + 'index_file_size': index_file_size, + 'metric_type': MetricType.L2} + with pytest.raises(Exception) as e: + status = connect.create_table(param) + + def test_create_table_no_dimension(self, connect): + ''' + target: test create table with no dimension params + method: create table with corrent params + expected: create status return ok + ''' + table_name = gen_unique_str("test_table") + param = {'table_name': table_name, + 'index_file_size': index_file_size, + 'metric_type': MetricType.L2} + with pytest.raises(Exception) as e: + status = connect.create_table(param) + + def test_create_table_no_file_size(self, connect): + ''' + target: test create table with no index_file_size params + method: create table with corrent params + expected: create status return ok, use default 1024 + ''' + table_name = gen_unique_str("test_table") + param = {'table_name': table_name, + 'dimension': dim, + 'metric_type': MetricType.L2} + status = connect.create_table(param) + logging.getLogger().info(status) + status, result = connect.describe_table(table_name) + logging.getLogger().info(result) + assert result.index_file_size == 1024 + + def test_create_table_no_metric_type(self, connect): + ''' + target: test create table with no metric_type params + method: create table with corrent params + expected: create status return ok, use default L2 + ''' + table_name = gen_unique_str("test_table") + param = {'table_name': table_name, + 'dimension': dim, + 'index_file_size': index_file_size} + status = connect.create_table(param) + status, result = connect.describe_table(table_name) + logging.getLogger().info(result) + assert result.metric_type == MetricType.L2 + + """ + ****************************************************************** + The following cases are used to test `describe_table` function + ****************************************************************** + """ + + def test_table_describe_result(self, connect): + ''' + target: test describe table created with correct params + method: create table, assert the value returned by describe method + expected: table_name equals with the table name created + ''' + table_name = gen_unique_str("test_table") + param = {'table_name': table_name, + 'dimension': dim, + 'index_file_size': index_file_size, + 'metric_type': MetricType.L2} + connect.create_table(param) + status, res = connect.describe_table(table_name) + assert res.table_name == table_name + assert res.metric_type == MetricType.L2 + + def test_table_describe_table_name_ip(self, connect): + ''' + target: test describe table created with correct params + method: create table, assert the value returned by describe method + expected: table_name equals with the table name created + ''' + table_name = gen_unique_str("test_table") + param = {'table_name': table_name, + 'dimension': dim, + 'index_file_size': index_file_size, + 'metric_type': MetricType.IP} + connect.create_table(param) + status, res = connect.describe_table(table_name) + assert res.table_name == table_name + assert res.metric_type == MetricType.IP + + # TODO: enable + @pytest.mark.level(2) + def _test_table_describe_table_name_multiprocessing(self, connect, args): + ''' + target: test describe table created with multiprocess + method: create table, assert the value returned by describe method + expected: table_name equals with the table name created + ''' + table_name = gen_unique_str("test_table") + uri = "tcp://%s:%s" % (args["ip"], args["port"]) + param = {'table_name': table_name, + 'dimension': dim, + 'index_file_size': index_file_size, + 'metric_type': MetricType.L2} + connect.create_table(param) + + def describetable(milvus): + status, res = milvus.describe_table(table_name) + assert res.table_name == table_name + + process_num = 4 + processes = [] + for i in range(process_num): + milvus = Milvus() + milvus.connect(uri=uri) + p = Process(target=describetable, args=(milvus,)) + processes.append(p) + p.start() + for p in processes: + p.join() + + @pytest.mark.level(2) + def test_table_describe_without_connection(self, table, dis_connect): + ''' + target: test describe table, without connection + method: describe table with correct params, with a disconnected instance + expected: describe raise exception + ''' + with pytest.raises(Exception) as e: + status = dis_connect.describe_table(table) + + def test_table_describe_dimension(self, connect): + ''' + target: test describe table created with correct params + method: create table, assert the dimention value returned by describe method + expected: dimention equals with dimention when created + ''' + table_name = gen_unique_str("test_table") + param = {'table_name': table_name, + 'dimension': dim+1, + 'index_file_size': index_file_size, + 'metric_type': MetricType.L2} + connect.create_table(param) + status, res = connect.describe_table(table_name) + assert res.dimension == dim+1 + + """ + ****************************************************************** + The following cases are used to test `delete_table` function + ****************************************************************** + """ + + def test_delete_table(self, connect, table): + ''' + target: test delete table created with correct params + method: create table and then delete, + assert the value returned by delete method + expected: status ok, and no table in tables + ''' + status = connect.delete_table(table) + assert not assert_has_table(connect, table) + + def test_delete_table_ip(self, connect, ip_table): + ''' + target: test delete table created with correct params + method: create table and then delete, + assert the value returned by delete method + expected: status ok, and no table in tables + ''' + status = connect.delete_table(ip_table) + assert not assert_has_table(connect, ip_table) + + @pytest.mark.level(2) + def test_table_delete_without_connection(self, table, dis_connect): + ''' + target: test describe table, without connection + method: describe table with correct params, with a disconnected instance + expected: describe raise exception + ''' + with pytest.raises(Exception) as e: + status = dis_connect.delete_table(table) + + def test_delete_table_not_existed(self, connect): + ''' + target: test delete table not in index + method: delete all tables, and delete table again, + assert the value returned by delete method + expected: status not ok + ''' + table_name = gen_unique_str("test_table") + status = connect.delete_table(table_name) + assert not status.code==0 + + def test_delete_table_repeatedly(self, connect): + ''' + target: test delete table created with correct params + method: create table and delete new table repeatedly, + assert the value returned by delete method + expected: create ok and delete ok + ''' + loops = 1 + for i in range(loops): + table_name = gen_unique_str("test_table") + param = {'table_name': table_name, + 'dimension': dim, + 'index_file_size': index_file_size, + 'metric_type': MetricType.L2} + connect.create_table(param) + status = connect.delete_table(table_name) + time.sleep(1) + assert not assert_has_table(connect, table_name) + + def test_delete_create_table_repeatedly(self, connect): + ''' + target: test delete and create the same table repeatedly + method: try to create the same table and delete repeatedly, + assert the value returned by delete method + expected: create ok and delete ok + ''' + loops = 5 + for i in range(loops): + table_name = "test_table" + param = {'table_name': table_name, + 'dimension': dim, + 'index_file_size': index_file_size, + 'metric_type': MetricType.L2} + connect.create_table(param) + status = connect.delete_table(table_name) + time.sleep(2) + assert status.OK() + + @pytest.mark.level(2) + def test_delete_create_table_repeatedly_ip(self, connect): + ''' + target: test delete and create the same table repeatedly + method: try to create the same table and delete repeatedly, + assert the value returned by delete method + expected: create ok and delete ok + ''' + loops = 5 + for i in range(loops): + table_name = "test_table" + param = {'table_name': table_name, + 'dimension': dim, + 'index_file_size': index_file_size, + 'metric_type': MetricType.IP} + connect.create_table(param) + status = connect.delete_table(table_name) + time.sleep(2) + assert status.OK() + + # TODO: enable + @pytest.mark.level(2) + def _test_delete_table_multiprocessing(self, args): + ''' + target: test delete table with multiprocess + method: create table and then delete, + assert the value returned by delete method + expected: status ok, and no table in tables + ''' + process_num = 6 + processes = [] + uri = "tcp://%s:%s" % (args["ip"], args["port"]) + + def deletetable(milvus): + status = milvus.delete_table(table) + # assert not status.code==0 + assert assert_has_table(milvus, table) + assert status.OK() + + for i in range(process_num): + milvus = Milvus() + milvus.connect(uri=uri) + p = Process(target=deletetable, args=(milvus,)) + processes.append(p) + p.start() + for p in processes: + p.join() + + # TODO: enable + @pytest.mark.level(2) + def _test_delete_table_multiprocessing_multitable(self, connect): + ''' + target: test delete table with multiprocess + method: create table and then delete, + assert the value returned by delete method + expected: status ok, and no table in tables + ''' + process_num = 5 + loop_num = 2 + processes = [] + + table = [] + j = 0 + while j < (process_num*loop_num): + table_name = gen_unique_str("test_delete_table_with_multiprocessing") + table.append(table_name) + param = {'table_name': table_name, + 'dimension': dim, + 'index_file_size': index_file_size, + 'metric_type': MetricType.L2} + connect.create_table(param) + j = j + 1 + + def delete(connect,ids): + i = 0 + while i < loop_num: + status = connect.delete_table(table[ids*process_num+i]) + time.sleep(2) + assert status.OK() + assert not assert_has_table(connect, table[ids*process_num+i]) + i = i + 1 + + for i in range(process_num): + ids = i + p = Process(target=delete, args=(connect,ids)) + processes.append(p) + p.start() + for p in processes: + p.join() + + """ + ****************************************************************** + The following cases are used to test `has_table` function + ****************************************************************** + """ + + def test_has_table(self, connect): + ''' + target: test if the created table existed + method: create table, assert the value returned by has_table method + expected: True + ''' + table_name = gen_unique_str("test_table") + param = {'table_name': table_name, + 'dimension': dim, + 'index_file_size': index_file_size, + 'metric_type': MetricType.L2} + connect.create_table(param) + assert assert_has_table(connect, table_name) + + def test_has_table_ip(self, connect): + ''' + target: test if the created table existed + method: create table, assert the value returned by has_table method + expected: True + ''' + table_name = gen_unique_str("test_table") + param = {'table_name': table_name, + 'dimension': dim, + 'index_file_size': index_file_size, + 'metric_type': MetricType.IP} + connect.create_table(param) + assert assert_has_table(connect, table_name) + + @pytest.mark.level(2) + def test_has_table_without_connection(self, table, dis_connect): + ''' + target: test has table, without connection + method: calling has table with correct params, with a disconnected instance + expected: has table raise exception + ''' + with pytest.raises(Exception) as e: + assert_has_table(dis_connect, table) + + def test_has_table_not_existed(self, connect): + ''' + target: test if table not created + method: random a table name, which not existed in db, + assert the value returned by has_table method + expected: False + ''' + table_name = gen_unique_str("test_table") + assert not assert_has_table(connect, table_name) + + """ + ****************************************************************** + The following cases are used to test `show_tables` function + ****************************************************************** + """ + + def test_show_tables(self, connect): + ''' + target: test show tables is correct or not, if table created + method: create table, assert the value returned by show_tables method is equal to 0 + expected: table_name in show tables + ''' + table_name = gen_unique_str("test_table") + param = {'table_name': table_name, + 'dimension': dim, + 'index_file_size': index_file_size, + 'metric_type': MetricType.L2} + connect.create_table(param) + status, result = connect.show_tables() + assert status.OK() + assert table_name in result + + def test_show_tables_ip(self, connect): + ''' + target: test show tables is correct or not, if table created + method: create table, assert the value returned by show_tables method is equal to 0 + expected: table_name in show tables + ''' + table_name = gen_unique_str("test_table") + param = {'table_name': table_name, + 'dimension': dim, + 'index_file_size': index_file_size, + 'metric_type': MetricType.IP} + connect.create_table(param) + status, result = connect.show_tables() + assert status.OK() + assert table_name in result + + @pytest.mark.level(2) + def test_show_tables_without_connection(self, dis_connect): + ''' + target: test show_tables, without connection + method: calling show_tables with correct params, with a disconnected instance + expected: show_tables raise exception + ''' + with pytest.raises(Exception) as e: + status = dis_connect.show_tables() + + def test_show_tables_no_table(self, connect): + ''' + target: test show tables is correct or not, if no table in db + method: delete all tables, + assert the value returned by show_tables method is equal to [] + expected: the status is ok, and the result is equal to [] + ''' + status, result = connect.show_tables() + if result: + for table_name in result: + connect.delete_table(table_name) + time.sleep(delete_table_interval_time) + status, result = connect.show_tables() + assert status.OK() + assert len(result) == 0 + + # TODO: enable + @pytest.mark.level(2) + def _test_show_tables_multiprocessing(self, connect, args): + ''' + target: test show tables is correct or not with processes + method: create table, assert the value returned by show_tables method is equal to 0 + expected: table_name in show tables + ''' + table_name = gen_unique_str("test_table") + uri = "tcp://%s:%s" % (args["ip"], args["port"]) + param = {'table_name': table_name, + 'dimension': dim, + 'index_file_size': index_file_size, + 'metric_type': MetricType.L2} + connect.create_table(param) + + def showtables(milvus): + status, result = milvus.show_tables() + assert status.OK() + assert table_name in result + + process_num = 8 + processes = [] + + for i in range(process_num): + milvus = Milvus() + milvus.connect(uri=uri) + p = Process(target=showtables, args=(milvus,)) + processes.append(p) + p.start() + for p in processes: + p.join() + + """ + ****************************************************************** + The following cases are used to test `preload_table` function + ****************************************************************** + """ + + """ + generate valid create_index params + """ + @pytest.fixture( + scope="function", + params=gen_index_params() + ) + def get_index_params(self, request, args): + if "internal" not in args: + if request.param["index_type"] == IndexType.IVF_SQ8H: + pytest.skip("sq8h not support in open source") + return request.param + + @pytest.mark.level(1) + def test_preload_table(self, connect, table, get_index_params): + index_params = get_index_params + status, ids = connect.add_vectors(table, vectors) + status = connect.create_index(table, index_params) + status = connect.preload_table(table) + assert status.OK() + + @pytest.mark.level(1) + def test_preload_table_ip(self, connect, ip_table, get_index_params): + index_params = get_index_params + status, ids = connect.add_vectors(ip_table, vectors) + status = connect.create_index(ip_table, index_params) + status = connect.preload_table(ip_table) + assert status.OK() + + @pytest.mark.level(1) + def test_preload_table_not_existed(self, connect, table): + table_name = gen_unique_str("test_preload_table_not_existed") + index_params = random.choice(gen_index_params()) + status, ids = connect.add_vectors(table, vectors) + status = connect.create_index(table, index_params) + status = connect.preload_table(table_name) + assert not status.OK() + + @pytest.mark.level(1) + def test_preload_table_not_existed_ip(self, connect, ip_table): + table_name = gen_unique_str("test_preload_table_not_existed") + index_params = random.choice(gen_index_params()) + status, ids = connect.add_vectors(ip_table, vectors) + status = connect.create_index(ip_table, index_params) + status = connect.preload_table(table_name) + assert not status.OK() + + @pytest.mark.level(1) + def test_preload_table_no_vectors(self, connect, table): + status = connect.preload_table(table) + assert status.OK() + + @pytest.mark.level(1) + def test_preload_table_no_vectors_ip(self, connect, ip_table): + status = connect.preload_table(ip_table) + assert status.OK() + + # TODO: psutils get memory usage + @pytest.mark.level(1) + def test_preload_table_memory_usage(self, connect, table): + pass + + +class TestTableInvalid(object): + """ + Test creating table with invalid table names + """ + @pytest.fixture( + scope="function", + params=gen_invalid_table_names() + ) + def get_table_name(self, request): + yield request.param + + def test_create_table_with_invalid_tablename(self, connect, get_table_name): + table_name = get_table_name + param = {'table_name': table_name, + 'dimension': dim, + 'index_file_size': index_file_size, + 'metric_type': MetricType.L2} + status = connect.create_table(param) + assert not status.OK() + + def test_create_table_with_empty_tablename(self, connect): + table_name = '' + param = {'table_name': table_name, + 'dimension': dim, + 'index_file_size': index_file_size, + 'metric_type': MetricType.L2} + with pytest.raises(Exception) as e: + status = connect.create_table(param) + + def test_preload_table_with_invalid_tablename(self, connect): + table_name = '' + with pytest.raises(Exception) as e: + status = connect.preload_table(table_name) + + +class TestCreateTableDimInvalid(object): + """ + Test creating table with invalid dimension + """ + @pytest.fixture( + scope="function", + params=gen_invalid_dims() + ) + def get_dim(self, request): + yield request.param + + @pytest.mark.timeout(5) + def test_create_table_with_invalid_dimension(self, connect, get_dim): + dimension = get_dim + table = gen_unique_str("test_create_table_with_invalid_dimension") + param = {'table_name': table, + 'dimension': dimension, + 'index_file_size': index_file_size, + 'metric_type': MetricType.L2} + if isinstance(dimension, int): + status = connect.create_table(param) + assert not status.OK() + else: + with pytest.raises(Exception) as e: + status = connect.create_table(param) + + +# TODO: max / min index file size +class TestCreateTableIndexSizeInvalid(object): + """ + Test creating tables with invalid index_file_size + """ + @pytest.fixture( + scope="function", + params=gen_invalid_file_sizes() + ) + def get_file_size(self, request): + yield request.param + + @pytest.mark.level(2) + def test_create_table_with_invalid_file_size(self, connect, table, get_file_size): + file_size = get_file_size + param = {'table_name': table, + 'dimension': dim, + 'index_file_size': file_size, + 'metric_type': MetricType.L2} + if isinstance(file_size, int) and file_size > 0: + status = connect.create_table(param) + assert not status.OK() + else: + with pytest.raises(Exception) as e: + status = connect.create_table(param) + + +class TestCreateMetricTypeInvalid(object): + """ + Test creating tables with invalid metric_type + """ + @pytest.fixture( + scope="function", + params=gen_invalid_metric_types() + ) + def get_metric_type(self, request): + yield request.param + + @pytest.mark.level(2) + def test_create_table_with_invalid_file_size(self, connect, table, get_metric_type): + metric_type = get_metric_type + param = {'table_name': table, + 'dimension': dim, + 'index_file_size': 10, + 'metric_type': metric_type} + with pytest.raises(Exception) as e: + status = connect.create_table(param) + + +def create_table(connect, **params): + param = {'table_name': params["table_name"], + 'dimension': params["dimension"], + 'index_file_size': index_file_size, + 'metric_type': MetricType.L2} + status = connect.create_table(param) + return status + +def search_table(connect, **params): + status, result = connect.search_vectors( + params["table_name"], + params["top_k"], + params["nprobe"], + params["query_vectors"]) + return status + +def preload_table(connect, **params): + status = connect.preload_table(params["table_name"]) + return status + +def has(connect, **params): + status = assert_has_table(connect, params["table_name"]) + return status + +def show(connect, **params): + status, result = connect.show_tables() + return status + +def delete(connect, **params): + status = connect.delete_table(params["table_name"]) + return status + +def describe(connect, **params): + status, result = connect.describe_table(params["table_name"]) + return status + +def rowcount(connect, **params): + status, result = connect.get_table_row_count(params["table_name"]) + return status + +def create_index(connect, **params): + status = connect.create_index(params["table_name"], params["index_params"]) + return status + +func_map = { + # 0:has, + 1:show, + 10:create_table, + 11:describe, + 12:rowcount, + 13:search_table, + 14:preload_table, + 15:create_index, + 30:delete +} + +def gen_sequence(): + raw_seq = func_map.keys() + result = itertools.permutations(raw_seq) + for x in result: + yield x + +class TestTableLogic(object): + + @pytest.mark.parametrize("logic_seq", gen_sequence()) + @pytest.mark.level(2) + def test_logic(self, connect, logic_seq): + if self.is_right(logic_seq): + self.execute(logic_seq, connect) + else: + self.execute_with_error(logic_seq, connect) + + def is_right(self, seq): + if sorted(seq) == seq: + return True + + not_created = True + has_deleted = False + for i in range(len(seq)): + if seq[i] > 10 and not_created: + return False + elif seq [i] > 10 and has_deleted: + return False + elif seq[i] == 10: + not_created = False + elif seq[i] == 30: + has_deleted = True + + return True + + def execute(self, logic_seq, connect): + basic_params = self.gen_params() + for i in range(len(logic_seq)): + # logging.getLogger().info(logic_seq[i]) + f = func_map[logic_seq[i]] + status = f(connect, **basic_params) + assert status.OK() + + def execute_with_error(self, logic_seq, connect): + basic_params = self.gen_params() + + error_flag = False + for i in range(len(logic_seq)): + f = func_map[logic_seq[i]] + status = f(connect, **basic_params) + if not status.OK(): + # logging.getLogger().info(logic_seq[i]) + error_flag = True + break + assert error_flag == True + + def gen_params(self): + table_name = gen_unique_str("test_table") + top_k = 1 + vectors = gen_vectors(2, dim) + param = {'table_name': table_name, + 'dimension': dim, + 'index_type': IndexType.IVFLAT, + 'metric_type': MetricType.L2, + 'nprobe': 1, + 'top_k': top_k, + 'index_params': { + 'index_type': IndexType.IVF_SQ8, + 'nlist': 16384 + }, + 'query_vectors': vectors} + return param diff --git a/tests/milvus_python_test/test_table_count.py b/tests/milvus_python_test/test_table_count.py new file mode 100644 index 0000000000..820fb9d546 --- /dev/null +++ b/tests/milvus_python_test/test_table_count.py @@ -0,0 +1,302 @@ +import random +import pdb + +import pytest +import logging +import itertools + +from time import sleep +from multiprocessing import Process +from milvus import Milvus +from utils import * +from milvus import IndexType, MetricType + +dim = 128 +index_file_size = 10 +add_time_interval = 5 + + +class TestTableCount: + """ + params means different nb, the nb value may trigger merge, or not + """ + @pytest.fixture( + scope="function", + params=[ + 100, + 5000, + 100000, + ], + ) + def add_vectors_nb(self, request): + yield request.param + + """ + generate valid create_index params + """ + @pytest.fixture( + scope="function", + params=gen_index_params() + ) + def get_index_params(self, request, args): + if "internal" not in args: + if request.param["index_type"] == IndexType.IVF_SQ8H: + pytest.skip("sq8h not support in open source") + return request.param + + def test_table_rows_count(self, connect, table, add_vectors_nb): + ''' + target: test table rows_count is correct or not + method: create table and add vectors in it, + assert the value returned by get_table_row_count method is equal to length of vectors + expected: the count is equal to the length of vectors + ''' + nb = add_vectors_nb + vectors = gen_vectors(nb, dim) + res = connect.add_vectors(table_name=table, records=vectors) + time.sleep(add_time_interval) + status, res = connect.get_table_row_count(table) + assert res == nb + + def test_table_rows_count_after_index_created(self, connect, table, get_index_params): + ''' + target: test get_table_row_count, after index have been created + method: add vectors in db, and create index, then calling get_table_row_count with correct params + expected: get_table_row_count raise exception + ''' + nb = 100 + index_params = get_index_params + vectors = gen_vectors(nb, dim) + res = connect.add_vectors(table_name=table, records=vectors) + time.sleep(add_time_interval) + # logging.getLogger().info(index_params) + connect.create_index(table, index_params) + status, res = connect.get_table_row_count(table) + assert res == nb + + @pytest.mark.level(2) + def test_count_without_connection(self, table, dis_connect): + ''' + target: test get_table_row_count, without connection + method: calling get_table_row_count with correct params, with a disconnected instance + expected: get_table_row_count raise exception + ''' + with pytest.raises(Exception) as e: + status = dis_connect.get_table_row_count(table) + + def test_table_rows_count_no_vectors(self, connect, table): + ''' + target: test table rows_count is correct or not, if table is empty + method: create table and no vectors in it, + assert the value returned by get_table_row_count method is equal to 0 + expected: the count is equal to 0 + ''' + table_name = gen_unique_str("test_table") + param = {'table_name': table_name, + 'dimension': dim, + 'index_file_size': index_file_size} + connect.create_table(param) + status, res = connect.get_table_row_count(table) + assert res == 0 + + # TODO: enable + @pytest.mark.level(2) + @pytest.mark.timeout(20) + def _test_table_rows_count_multiprocessing(self, connect, table, args): + ''' + target: test table rows_count is correct or not with multiprocess + method: create table and add vectors in it, + assert the value returned by get_table_row_count method is equal to length of vectors + expected: the count is equal to the length of vectors + ''' + nq = 2 + uri = "tcp://%s:%s" % (args["ip"], args["port"]) + vectors = gen_vectors(nq, dim) + res = connect.add_vectors(table_name=table, records=vectors) + time.sleep(add_time_interval) + + def rows_count(milvus): + status, res = milvus.get_table_row_count(table) + logging.getLogger().info(status) + assert res == nq + + process_num = 8 + processes = [] + for i in range(process_num): + milvus = Milvus() + milvus.connect(uri=uri) + p = Process(target=rows_count, args=(milvus, )) + processes.append(p) + p.start() + logging.getLogger().info(p) + for p in processes: + p.join() + + def test_table_rows_count_multi_tables(self, connect): + ''' + target: test table rows_count is correct or not with multiple tables of L2 + method: create table and add vectors in it, + assert the value returned by get_table_row_count method is equal to length of vectors + expected: the count is equal to the length of vectors + ''' + nq = 100 + vectors = gen_vectors(nq, dim) + table_list = [] + for i in range(50): + table_name = gen_unique_str('test_table_rows_count_multi_tables') + table_list.append(table_name) + param = {'table_name': table_name, + 'dimension': dim, + 'index_file_size': index_file_size, + 'metric_type': MetricType.L2} + connect.create_table(param) + res = connect.add_vectors(table_name=table_name, records=vectors) + time.sleep(2) + for i in range(50): + status, res = connect.get_table_row_count(table_list[i]) + assert status.OK() + assert res == nq + + +class TestTableCountIP: + """ + params means different nb, the nb value may trigger merge, or not + """ + + @pytest.fixture( + scope="function", + params=[ + 100, + 5000, + 100000, + ], + ) + def add_vectors_nb(self, request): + yield request.param + + """ + generate valid create_index params + """ + + @pytest.fixture( + scope="function", + params=gen_index_params() + ) + def get_index_params(self, request, args): + if "internal" not in args: + if request.param["index_type"] == IndexType.IVF_SQ8H: + pytest.skip("sq8h not support in open source") + return request.param + + def test_table_rows_count(self, connect, ip_table, add_vectors_nb): + ''' + target: test table rows_count is correct or not + method: create table and add vectors in it, + assert the value returned by get_table_row_count method is equal to length of vectors + expected: the count is equal to the length of vectors + ''' + nb = add_vectors_nb + vectors = gen_vectors(nb, dim) + res = connect.add_vectors(table_name=ip_table, records=vectors) + time.sleep(add_time_interval) + status, res = connect.get_table_row_count(ip_table) + assert res == nb + + def test_table_rows_count_after_index_created(self, connect, ip_table, get_index_params): + ''' + target: test get_table_row_count, after index have been created + method: add vectors in db, and create index, then calling get_table_row_count with correct params + expected: get_table_row_count raise exception + ''' + nb = 100 + index_params = get_index_params + vectors = gen_vectors(nb, dim) + res = connect.add_vectors(table_name=ip_table, records=vectors) + time.sleep(add_time_interval) + # logging.getLogger().info(index_params) + connect.create_index(ip_table, index_params) + status, res = connect.get_table_row_count(ip_table) + assert res == nb + + @pytest.mark.level(2) + def test_count_without_connection(self, ip_table, dis_connect): + ''' + target: test get_table_row_count, without connection + method: calling get_table_row_count with correct params, with a disconnected instance + expected: get_table_row_count raise exception + ''' + with pytest.raises(Exception) as e: + status = dis_connect.get_table_row_count(ip_table) + + def test_table_rows_count_no_vectors(self, connect, ip_table): + ''' + target: test table rows_count is correct or not, if table is empty + method: create table and no vectors in it, + assert the value returned by get_table_row_count method is equal to 0 + expected: the count is equal to 0 + ''' + table_name = gen_unique_str("test_table") + param = {'table_name': table_name, + 'dimension': dim, + 'index_file_size': index_file_size} + connect.create_table(param) + status, res = connect.get_table_row_count(ip_table) + assert res == 0 + + # TODO: enable + @pytest.mark.level(2) + @pytest.mark.timeout(20) + def _test_table_rows_count_multiprocessing(self, connect, ip_table, args): + ''' + target: test table rows_count is correct or not with multiprocess + method: create table and add vectors in it, + assert the value returned by get_table_row_count method is equal to length of vectors + expected: the count is equal to the length of vectors + ''' + nq = 2 + uri = "tcp://%s:%s" % (args["ip"], args["port"]) + vectors = gen_vectors(nq, dim) + res = connect.add_vectors(table_name=ip_table, records=vectors) + time.sleep(add_time_interval) + + def rows_count(milvus): + status, res = milvus.get_table_row_count(ip_table) + logging.getLogger().info(status) + assert res == nq + + process_num = 8 + processes = [] + for i in range(process_num): + milvus = Milvus() + milvus.connect(uri=uri) + p = Process(target=rows_count, args=(milvus,)) + processes.append(p) + p.start() + logging.getLogger().info(p) + for p in processes: + p.join() + + def test_table_rows_count_multi_tables(self, connect): + ''' + target: test table rows_count is correct or not with multiple tables of IP + method: create table and add vectors in it, + assert the value returned by get_table_row_count method is equal to length of vectors + expected: the count is equal to the length of vectors + ''' + nq = 100 + vectors = gen_vectors(nq, dim) + table_list = [] + for i in range(50): + table_name = gen_unique_str('test_table_rows_count_multi_tables') + table_list.append(table_name) + param = {'table_name': table_name, + 'dimension': dim, + 'index_file_size': index_file_size, + 'metric_type': MetricType.IP} + connect.create_table(param) + res = connect.add_vectors(table_name=table_name, records=vectors) + time.sleep(2) + for i in range(50): + status, res = connect.get_table_row_count(table_list[i]) + assert status.OK() + assert res == nq \ No newline at end of file diff --git a/tests/milvus_python_test/utils.py b/tests/milvus_python_test/utils.py new file mode 100644 index 0000000000..806af62f57 --- /dev/null +++ b/tests/milvus_python_test/utils.py @@ -0,0 +1,551 @@ +# STL imports +import random +import string +import struct +import sys +import time, datetime +import copy +import numpy as np +from utils import * +from milvus import Milvus, IndexType, MetricType + + +def gen_inaccuracy(num): + return num/255.0 + +def gen_vectors(num, dim): + return [[random.random() for _ in range(dim)] for _ in range(num)] + + +def gen_single_vector(dim): + return [[random.random() for _ in range(dim)]] + + +def gen_vector(nb, d, seed=np.random.RandomState(1234)): + xb = seed.rand(nb, d).astype("float32") + return xb.tolist() + + +def gen_unique_str(str=None): + prefix = "".join(random.choice(string.ascii_letters + string.digits) for _ in range(8)) + return prefix if str is None else str + "_" + prefix + + +def get_current_day(): + return time.strftime('%Y-%m-%d', time.localtime()) + + +def get_last_day(day): + tmp = datetime.datetime.now()-datetime.timedelta(days=day) + return tmp.strftime('%Y-%m-%d') + + +def get_next_day(day): + tmp = datetime.datetime.now()+datetime.timedelta(days=day) + return tmp.strftime('%Y-%m-%d') + + +def gen_long_str(num): + string = '' + for _ in range(num): + char = random.choice('tomorrow') + string += char + + +def gen_invalid_ips(): + ips = [ + "255.0.0.0", + "255.255.0.0", + "255.255.255.0", + "255.255.255.255", + "127.0.0", + "123.0.0.2", + "12-s", + " ", + "12 s", + "BB。A", + " siede ", + "(mn)", + "\n", + "\t", + "中文", + "a".join("a" for i in range(256)) + ] + return ips + + +def gen_invalid_ports(): + ports = [ + # empty + " ", + -1, + # too big port + 100000, + # not correct port + 39540, + "BB。A", + " siede ", + "(mn)", + "\n", + "\t", + "中文" + ] + return ports + + +def gen_invalid_uris(): + ip = None + port = 19530 + + uris = [ + " ", + "中文", + + # invalid protocol + # "tc://%s:%s" % (ip, port), + # "tcp%s:%s" % (ip, port), + + # # invalid port + # "tcp://%s:100000" % ip, + # "tcp://%s: " % ip, + # "tcp://%s:19540" % ip, + # "tcp://%s:-1" % ip, + # "tcp://%s:string" % ip, + + # invalid ip + "tcp:// :%s" % port, + "tcp://123.0.0.1:%s" % port, + "tcp://127.0.0:%s" % port, + "tcp://255.0.0.0:%s" % port, + "tcp://255.255.0.0:%s" % port, + "tcp://255.255.255.0:%s" % port, + "tcp://255.255.255.255:%s" % port, + "tcp://\n:%s" % port, + + ] + return uris + + +def gen_invalid_table_names(): + table_names = [ + "12-s", + "12/s", + " ", + # "", + # None, + "12 s", + "BB。A", + "c|c", + " siede ", + "(mn)", + "#12s", + "pip+", + "=c", + "\n", + "\t", + "中文", + "a".join("a" for i in range(256)) + ] + return table_names + + +def gen_invalid_top_ks(): + top_ks = [ + 0, + -1, + None, + [1,2,3], + (1,2), + {"a": 1}, + " ", + "", + "String", + "12-s", + "BB。A", + " siede ", + "(mn)", + "#12s", + "pip+", + "=c", + "\n", + "\t", + "中文", + "a".join("a" for i in range(256)) + ] + return top_ks + + +def gen_invalid_dims(): + dims = [ + 0, + -1, + 100001, + 1000000000000001, + None, + False, + [1,2,3], + (1,2), + {"a": 1}, + " ", + "", + "String", + "12-s", + "BB。A", + " siede ", + "(mn)", + "#12s", + "pip+", + "=c", + "\n", + "\t", + "中文", + "a".join("a" for i in range(256)) + ] + return dims + + +def gen_invalid_file_sizes(): + file_sizes = [ + 0, + -1, + 1000000000000001, + None, + False, + [1,2,3], + (1,2), + {"a": 1}, + " ", + "", + "String", + "12-s", + "BB。A", + " siede ", + "(mn)", + "#12s", + "pip+", + "=c", + "\n", + "\t", + "中文", + "a".join("a" for i in range(256)) + ] + return file_sizes + + +def gen_invalid_index_types(): + invalid_types = [ + 0, + -1, + 100, + 1000000000000001, + # None, + False, + [1,2,3], + (1,2), + {"a": 1}, + " ", + "", + "String", + "12-s", + "BB。A", + " siede ", + "(mn)", + "#12s", + "pip+", + "=c", + "\n", + "\t", + "中文", + "a".join("a" for i in range(256)) + ] + return invalid_types + + +def gen_invalid_nlists(): + nlists = [ + 0, + -1, + 1000000000000001, + # None, + [1,2,3], + (1,2), + {"a": 1}, + " ", + "", + "String", + "12-s", + "BB。A", + " siede ", + "(mn)", + "#12s", + "pip+", + "=c", + "\n", + "\t", + "中文" + ] + return nlists + + +def gen_invalid_nprobes(): + nprobes = [ + 0, + -1, + 1000000000000001, + None, + [1,2,3], + (1,2), + {"a": 1}, + " ", + "", + "String", + "12-s", + "BB。A", + " siede ", + "(mn)", + "#12s", + "pip+", + "=c", + "\n", + "\t", + "中文" + ] + return nprobes + + +def gen_invalid_metric_types(): + metric_types = [ + 0, + -1, + 1000000000000001, + # None, + [1,2,3], + (1,2), + {"a": 1}, + " ", + "", + "String", + "12-s", + "BB。A", + " siede ", + "(mn)", + "#12s", + "pip+", + "=c", + "\n", + "\t", + "中文" + ] + return metric_types + + +def gen_invalid_vectors(): + invalid_vectors = [ + "1*2", + [], + [1], + [1,2], + [" "], + ['a'], + [None], + None, + (1,2), + {"a": 1}, + " ", + "", + "String", + "12-s", + "BB。A", + " siede ", + "(mn)", + "#12s", + "pip+", + "=c", + "\n", + "\t", + "中文", + "a".join("a" for i in range(256)) + ] + return invalid_vectors + + +def gen_invalid_vector_ids(): + invalid_vector_ids = [ + 1.0, + -1.0, + None, + # int 64 + 10000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000, + " ", + "", + "String", + "BB。A", + " siede ", + "(mn)", + "#12s", + "=c", + "\n", + "中文", + ] + return invalid_vector_ids + + + +def gen_invalid_query_ranges(): + query_ranges = [ + [(get_last_day(1), "")], + [(get_current_day(), "")], + [(get_next_day(1), "")], + [(get_current_day(), get_last_day(1))], + [(get_next_day(1), get_last_day(1))], + [(get_next_day(1), get_current_day())], + [(0, get_next_day(1))], + [(-1, get_next_day(1))], + [(1, get_next_day(1))], + [(100001, get_next_day(1))], + [(1000000000000001, get_next_day(1))], + [(None, get_next_day(1))], + [([1,2,3], get_next_day(1))], + [((1,2), get_next_day(1))], + [({"a": 1}, get_next_day(1))], + [(" ", get_next_day(1))], + [("", get_next_day(1))], + [("String", get_next_day(1))], + [("12-s", get_next_day(1))], + [("BB。A", get_next_day(1))], + [(" siede ", get_next_day(1))], + [("(mn)", get_next_day(1))], + [("#12s", get_next_day(1))], + [("pip+", get_next_day(1))], + [("=c", get_next_day(1))], + [("\n", get_next_day(1))], + [("\t", get_next_day(1))], + [("中文", get_next_day(1))], + [("a".join("a" for i in range(256)), get_next_day(1))] + ] + return query_ranges + + +def gen_invalid_index_params(): + index_params = [] + for index_type in gen_invalid_index_types(): + index_param = {"index_type": index_type, "nlist": 16384} + index_params.append(index_param) + for nlist in gen_invalid_nlists(): + index_param = {"index_type": IndexType.IVFLAT, "nlist": nlist} + index_params.append(index_param) + return index_params + + +def gen_index_params(): + index_params = [] + index_types = [IndexType.FLAT, IndexType.IVFLAT, IndexType.IVF_SQ8, IndexType.IVF_SQ8H] + nlists = [1, 16384, 50000] + + def gen_params(index_types, nlists): + return [ {"index_type": index_type, "nlist": nlist} \ + for index_type in index_types \ + for nlist in nlists] + + return gen_params(index_types, nlists) + +def gen_simple_index_params(): + index_params = [] + index_types = [IndexType.FLAT, IndexType.IVFLAT, IndexType.IVF_SQ8, IndexType.IVF_SQ8H] + nlists = [16384] + + def gen_params(index_types, nlists): + return [ {"index_type": index_type, "nlist": nlist} \ + for index_type in index_types \ + for nlist in nlists] + + return gen_params(index_types, nlists) + + +def assert_has_table(conn, table_name): + status, ok = conn.has_table(table_name) + return status.OK() and ok + + +if __name__ == "__main__": + import numpy + + dim = 128 + nq = 10000 + table = "test" + + + file_name = '/poc/yuncong/ann_1000m/query.npy' + data = np.load(file_name) + vectors = data[0:nq].tolist() + # print(vectors) + + connect = Milvus() + # connect.connect(host="192.168.1.27") + # print(connect.show_tables()) + # print(connect.get_table_row_count(table)) + # sys.exit() + connect.connect(host="127.0.0.1") + connect.delete_table(table) + # sys.exit() + # time.sleep(2) + print(connect.get_table_row_count(table)) + param = {'table_name': table, + 'dimension': dim, + 'metric_type': MetricType.L2, + 'index_file_size': 10} + status = connect.create_table(param) + print(status) + print(connect.get_table_row_count(table)) + # add vectors + for i in range(10): + status, ids = connect.add_vectors(table, vectors) + print(status) + print(ids[0]) + # print(ids[0]) + index_params = {"index_type": IndexType.IVFLAT, "nlist": 16384} + status = connect.create_index(table, index_params) + print(status) + # sys.exit() + query_vec = [vectors[0]] + # print(numpy.inner(numpy.array(query_vec[0]), numpy.array(query_vec[0]))) + top_k = 12 + nprobe = 1 + for i in range(2): + result = connect.search_vectors(table, top_k, nprobe, query_vec) + print(result) + sys.exit() + + + table = gen_unique_str("test_add_vector_with_multiprocessing") + uri = "tcp://%s:%s" % (args["ip"], args["port"]) + param = {'table_name': table, + 'dimension': dim, + 'index_file_size': index_file_size} + # create table + milvus = Milvus() + milvus.connect(uri=uri) + milvus.create_table(param) + vector = gen_single_vector(dim) + + process_num = 4 + loop_num = 10 + processes = [] + # with dependent connection + def add(milvus): + i = 0 + while i < loop_num: + status, ids = milvus.add_vectors(table, vector) + i = i + 1 + for i in range(process_num): + milvus = Milvus() + milvus.connect(uri=uri) + p = Process(target=add, args=(milvus,)) + processes.append(p) + p.start() + time.sleep(0.2) + for p in processes: + p.join() + time.sleep(3) + status, count = milvus.get_table_row_count(table) + assert count == process_num * loop_num \ No newline at end of file