From 3b0ca7160283b23dd4b855a603e1182b26c6adc8 Mon Sep 17 00:00:00 2001 From: JinHai-CN Date: Wed, 16 Oct 2019 18:40:31 +0800 Subject: [PATCH] #18 Add all test cases Former-commit-id: ac930b6af9c664da4382e97722fed11a70bb2c99 --- CHANGELOG.md | 1 + tests/milvus-java-test/.gitignore | 4 + tests/milvus-java-test/bin/run.sh | 0 .../ci/function/file_transfer.groovy | 10 + .../ci/jenkinsfile/cleanup.groovy | 13 + .../ci/jenkinsfile/deploy_server.groovy | 16 + .../ci/jenkinsfile/integration_test.groovy | 13 + .../ci/jenkinsfile/notify.groovy | 15 + .../jenkinsfile/upload_unit_test_out.groovy | 13 + tests/milvus-java-test/ci/main_jenkinsfile | 110 ++ .../pod_containers/milvus-testframework.yaml | 13 + tests/milvus-java-test/milvus-java-test.iml | 2 + tests/milvus-java-test/pom.xml | 137 ++ .../src/main/java/com/MainClass.java | 147 ++ .../src/main/java/com/TestAddVectors.java | 154 ++ .../src/main/java/com/TestConnect.java | 80 ++ .../src/main/java/com/TestDeleteVectors.java | 122 ++ .../src/main/java/com/TestIndex.java | 340 +++++ .../src/main/java/com/TestMix.java | 221 +++ .../src/main/java/com/TestPing.java | 28 + .../src/main/java/com/TestSearchVectors.java | 480 +++++++ .../src/main/java/com/TestTable.java | 155 +++ .../src/main/java/com/TestTableCount.java | 89 ++ tests/milvus-java-test/testng.xml | 8 + tests/milvus_ann_acc/.gitignore | 2 + tests/milvus_ann_acc/client.py | 149 ++ tests/milvus_ann_acc/config.yaml | 17 + tests/milvus_ann_acc/main.py | 26 + tests/milvus_ann_acc/test.py | 132 ++ tests/milvus_benchmark/.gitignore | 8 + tests/milvus_benchmark/README.md | 57 + tests/milvus_benchmark/__init__.py | 0 tests/milvus_benchmark/client.py | 244 ++++ tests/milvus_benchmark/conf/log_config.conf | 28 + .../milvus_benchmark/conf/server_config.yaml | 28 + .../conf/server_config.yaml.cpu | 31 + .../conf/server_config.yaml.multi | 33 + .../conf/server_config.yaml.single | 32 + tests/milvus_benchmark/demo.py | 51 + tests/milvus_benchmark/docker_runner.py | 261 ++++ 
tests/milvus_benchmark/local_runner.py | 132 ++ tests/milvus_benchmark/main.py | 131 ++ tests/milvus_benchmark/operation.py | 10 + tests/milvus_benchmark/parser.py | 66 + tests/milvus_benchmark/report.py | 10 + tests/milvus_benchmark/requirements.txt | 6 + tests/milvus_benchmark/runner.py | 219 +++ tests/milvus_benchmark/suites.yaml | 38 + tests/milvus_benchmark/suites_accuracy.yaml | 121 ++ .../milvus_benchmark/suites_performance.yaml | 258 ++++ tests/milvus_benchmark/suites_stability.yaml | 17 + tests/milvus_benchmark/suites_yzb.yaml | 171 +++ tests/milvus_benchmark/utils.py | 194 +++ tests/milvus_python_test/.dockerignore | 14 + tests/milvus_python_test/.gitignore | 13 + tests/milvus_python_test/Dockerfile | 14 + tests/milvus_python_test/MilvusCases.md | 143 ++ tests/milvus_python_test/README.md | 14 + tests/milvus_python_test/conf/log_config.conf | 27 + .../conf/server_config.yaml | 32 + tests/milvus_python_test/conftest.py | 128 ++ tests/milvus_python_test/docker-entrypoint.sh | 9 + tests/milvus_python_test/pytest.ini | 9 + tests/milvus_python_test/requirements.txt | 25 + .../requirements_cluster.txt | 25 + .../requirements_no_pymilvus.txt | 24 + tests/milvus_python_test/run.sh | 4 + tests/milvus_python_test/test.template | 41 + tests/milvus_python_test/test_add_vectors.py | 1233 +++++++++++++++++ tests/milvus_python_test/test_connect.py | 386 ++++++ .../milvus_python_test/test_delete_vectors.py | 419 ++++++ tests/milvus_python_test/test_index.py | 966 +++++++++++++ tests/milvus_python_test/test_mix.py | 180 +++ tests/milvus_python_test/test_ping.py | 77 + .../milvus_python_test/test_search_vectors.py | 650 +++++++++ tests/milvus_python_test/test_table.py | 883 ++++++++++++ tests/milvus_python_test/test_table_count.py | 296 ++++ tests/milvus_python_test/utils.py | 545 ++++++++ 78 files changed, 10800 insertions(+) create mode 100644 tests/milvus-java-test/.gitignore create mode 100644 tests/milvus-java-test/bin/run.sh create mode 100644 
tests/milvus-java-test/ci/function/file_transfer.groovy create mode 100644 tests/milvus-java-test/ci/jenkinsfile/cleanup.groovy create mode 100644 tests/milvus-java-test/ci/jenkinsfile/deploy_server.groovy create mode 100644 tests/milvus-java-test/ci/jenkinsfile/integration_test.groovy create mode 100644 tests/milvus-java-test/ci/jenkinsfile/notify.groovy create mode 100644 tests/milvus-java-test/ci/jenkinsfile/upload_unit_test_out.groovy create mode 100644 tests/milvus-java-test/ci/main_jenkinsfile create mode 100644 tests/milvus-java-test/ci/pod_containers/milvus-testframework.yaml create mode 100644 tests/milvus-java-test/milvus-java-test.iml create mode 100644 tests/milvus-java-test/pom.xml create mode 100644 tests/milvus-java-test/src/main/java/com/MainClass.java create mode 100644 tests/milvus-java-test/src/main/java/com/TestAddVectors.java create mode 100644 tests/milvus-java-test/src/main/java/com/TestConnect.java create mode 100644 tests/milvus-java-test/src/main/java/com/TestDeleteVectors.java create mode 100644 tests/milvus-java-test/src/main/java/com/TestIndex.java create mode 100644 tests/milvus-java-test/src/main/java/com/TestMix.java create mode 100644 tests/milvus-java-test/src/main/java/com/TestPing.java create mode 100644 tests/milvus-java-test/src/main/java/com/TestSearchVectors.java create mode 100644 tests/milvus-java-test/src/main/java/com/TestTable.java create mode 100644 tests/milvus-java-test/src/main/java/com/TestTableCount.java create mode 100644 tests/milvus-java-test/testng.xml create mode 100644 tests/milvus_ann_acc/.gitignore create mode 100644 tests/milvus_ann_acc/client.py create mode 100644 tests/milvus_ann_acc/config.yaml create mode 100644 tests/milvus_ann_acc/main.py create mode 100644 tests/milvus_ann_acc/test.py create mode 100644 tests/milvus_benchmark/.gitignore create mode 100644 tests/milvus_benchmark/README.md create mode 100644 tests/milvus_benchmark/__init__.py create mode 100644 tests/milvus_benchmark/client.py create 
mode 100644 tests/milvus_benchmark/conf/log_config.conf create mode 100644 tests/milvus_benchmark/conf/server_config.yaml create mode 100644 tests/milvus_benchmark/conf/server_config.yaml.cpu create mode 100644 tests/milvus_benchmark/conf/server_config.yaml.multi create mode 100644 tests/milvus_benchmark/conf/server_config.yaml.single create mode 100644 tests/milvus_benchmark/demo.py create mode 100644 tests/milvus_benchmark/docker_runner.py create mode 100644 tests/milvus_benchmark/local_runner.py create mode 100644 tests/milvus_benchmark/main.py create mode 100644 tests/milvus_benchmark/operation.py create mode 100644 tests/milvus_benchmark/parser.py create mode 100644 tests/milvus_benchmark/report.py create mode 100644 tests/milvus_benchmark/requirements.txt create mode 100644 tests/milvus_benchmark/runner.py create mode 100644 tests/milvus_benchmark/suites.yaml create mode 100644 tests/milvus_benchmark/suites_accuracy.yaml create mode 100644 tests/milvus_benchmark/suites_performance.yaml create mode 100644 tests/milvus_benchmark/suites_stability.yaml create mode 100644 tests/milvus_benchmark/suites_yzb.yaml create mode 100644 tests/milvus_benchmark/utils.py create mode 100644 tests/milvus_python_test/.dockerignore create mode 100644 tests/milvus_python_test/.gitignore create mode 100644 tests/milvus_python_test/Dockerfile create mode 100644 tests/milvus_python_test/MilvusCases.md create mode 100644 tests/milvus_python_test/README.md create mode 100644 tests/milvus_python_test/conf/log_config.conf create mode 100644 tests/milvus_python_test/conf/server_config.yaml create mode 100644 tests/milvus_python_test/conftest.py create mode 100755 tests/milvus_python_test/docker-entrypoint.sh create mode 100644 tests/milvus_python_test/pytest.ini create mode 100644 tests/milvus_python_test/requirements.txt create mode 100644 tests/milvus_python_test/requirements_cluster.txt create mode 100644 tests/milvus_python_test/requirements_no_pymilvus.txt create mode 100644 
tests/milvus_python_test/run.sh create mode 100644 tests/milvus_python_test/test.template create mode 100644 tests/milvus_python_test/test_add_vectors.py create mode 100644 tests/milvus_python_test/test_connect.py create mode 100644 tests/milvus_python_test/test_delete_vectors.py create mode 100644 tests/milvus_python_test/test_index.py create mode 100644 tests/milvus_python_test/test_mix.py create mode 100644 tests/milvus_python_test/test_ping.py create mode 100644 tests/milvus_python_test/test_search_vectors.py create mode 100644 tests/milvus_python_test/test_table.py create mode 100644 tests/milvus_python_test/test_table_count.py create mode 100644 tests/milvus_python_test/utils.py diff --git a/CHANGELOG.md b/CHANGELOG.md index eda2e7fda2..7be9f8a436 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -68,6 +68,7 @@ Please mark all change in change log and use the ticket from JIRA. - MS-624 - Re-organize project directory for open-source - MS-635 - Add compile option to support customized faiss - MS-660 - add ubuntu_build_deps.sh +- \#18 - Add all test cases # Milvus 0.4.0 (2019-09-12) diff --git a/tests/milvus-java-test/.gitignore b/tests/milvus-java-test/.gitignore new file mode 100644 index 0000000000..3a553813a8 --- /dev/null +++ b/tests/milvus-java-test/.gitignore @@ -0,0 +1,4 @@ +target/ +.idea/ +test-output/ +lib/* diff --git a/tests/milvus-java-test/bin/run.sh b/tests/milvus-java-test/bin/run.sh new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/milvus-java-test/ci/function/file_transfer.groovy b/tests/milvus-java-test/ci/function/file_transfer.groovy new file mode 100644 index 0000000000..bebae14832 --- /dev/null +++ b/tests/milvus-java-test/ci/function/file_transfer.groovy @@ -0,0 +1,10 @@ +def FileTransfer (sourceFiles, remoteDirectory, remoteIP, protocol = "ftp", makeEmptyDirs = true) { + if (protocol == "ftp") { + ftpPublisher masterNodeName: '', paramPublish: [parameterName: ''], alwaysPublishFromMaster: false, continueOnError: false, 
failOnError: true, publishers: [ + [configName: "${remoteIP}", transfers: [ + [asciiMode: false, cleanRemote: false, excludes: '', flatten: false, makeEmptyDirs: "${makeEmptyDirs}", noDefaultExcludes: false, patternSeparator: '[, ]+', remoteDirectory: "${remoteDirectory}", remoteDirectorySDF: false, removePrefix: '', sourceFiles: "${sourceFiles}"]], usePromotionTimestamp: true, useWorkspaceInPromotion: false, verbose: true + ] + ] + } +} +return this diff --git a/tests/milvus-java-test/ci/jenkinsfile/cleanup.groovy b/tests/milvus-java-test/ci/jenkinsfile/cleanup.groovy new file mode 100644 index 0000000000..2e9332fa6e --- /dev/null +++ b/tests/milvus-java-test/ci/jenkinsfile/cleanup.groovy @@ -0,0 +1,13 @@ +try { + def result = sh script: "helm status ${env.JOB_NAME}-${env.BUILD_NUMBER}", returnStatus: true + if (!result) { + sh "helm del --purge ${env.JOB_NAME}-${env.BUILD_NUMBER}" + } +} catch (exc) { + def result = sh script: "helm status ${env.JOB_NAME}-${env.BUILD_NUMBER}", returnStatus: true + if (!result) { + sh "helm del --purge ${env.JOB_NAME}-${env.BUILD_NUMBER}" + } + throw exc +} + diff --git a/tests/milvus-java-test/ci/jenkinsfile/deploy_server.groovy b/tests/milvus-java-test/ci/jenkinsfile/deploy_server.groovy new file mode 100644 index 0000000000..6650b7c219 --- /dev/null +++ b/tests/milvus-java-test/ci/jenkinsfile/deploy_server.groovy @@ -0,0 +1,16 @@ +try { + sh 'helm init --client-only --skip-refresh --stable-repo-url https://kubernetes.oss-cn-hangzhou.aliyuncs.com/charts' + sh 'helm repo add milvus https://registry.zilliz.com/chartrepo/milvus' + sh 'helm repo update' + dir ("milvus-helm") { + checkout([$class: 'GitSCM', branches: [[name: "${HELM_BRANCH}"]], doGenerateSubmoduleConfigurations: false, extensions: [], submoduleCfg: [], userRemoteConfigs: [[credentialsId: "${params.GIT_USER}", url: "git@192.168.1.105:megasearch/milvus-helm.git", name: 'origin', refspec: "+refs/heads/${HELM_BRANCH}:refs/remotes/origin/${HELM_BRANCH}"]]]) + dir 
("milvus/milvus-gpu") { + sh "helm install --wait --timeout 300 --set engine.image.tag=${IMAGE_TAG} --set expose.type=clusterIP --name ${env.JOB_NAME}-${env.BUILD_NUMBER} -f ci/values.yaml --namespace milvus-sdk-test --version 0.3.1 ." + } + } +} catch (exc) { + echo 'Helm running failed!' + sh "helm del --purge ${env.JOB_NAME}-${env.BUILD_NUMBER}" + throw exc +} + diff --git a/tests/milvus-java-test/ci/jenkinsfile/integration_test.groovy b/tests/milvus-java-test/ci/jenkinsfile/integration_test.groovy new file mode 100644 index 0000000000..662c93bc77 --- /dev/null +++ b/tests/milvus-java-test/ci/jenkinsfile/integration_test.groovy @@ -0,0 +1,13 @@ +timeout(time: 30, unit: 'MINUTES') { + try { + dir ("milvus-java-test") { + sh "mvn clean install" + sh "java -cp \"target/MilvusSDkJavaTest-1.0-SNAPSHOT.jar:lib/*\" com.MainClass -h ${env.JOB_NAME}-${env.BUILD_NUMBER}-milvus-gpu-engine.milvus-sdk-test.svc.cluster.local" + } + + } catch (exc) { + echo 'Milvus-SDK-Java Integration Test Failed !' 
+ throw exc + } +} + diff --git a/tests/milvus-java-test/ci/jenkinsfile/notify.groovy b/tests/milvus-java-test/ci/jenkinsfile/notify.groovy new file mode 100644 index 0000000000..0a257b8cd8 --- /dev/null +++ b/tests/milvus-java-test/ci/jenkinsfile/notify.groovy @@ -0,0 +1,15 @@ +def notify() { + if (!currentBuild.resultIsBetterOrEqualTo('SUCCESS')) { + // Send an email only if the build status has changed from green/unstable to red + emailext subject: '$DEFAULT_SUBJECT', + body: '$DEFAULT_CONTENT', + recipientProviders: [ + [$class: 'DevelopersRecipientProvider'], + [$class: 'RequesterRecipientProvider'] + ], + replyTo: '$DEFAULT_REPLYTO', + to: '$DEFAULT_RECIPIENTS' + } +} +return this + diff --git a/tests/milvus-java-test/ci/jenkinsfile/upload_unit_test_out.groovy b/tests/milvus-java-test/ci/jenkinsfile/upload_unit_test_out.groovy new file mode 100644 index 0000000000..7e106c0296 --- /dev/null +++ b/tests/milvus-java-test/ci/jenkinsfile/upload_unit_test_out.groovy @@ -0,0 +1,13 @@ +timeout(time: 5, unit: 'MINUTES') { + dir ("${PROJECT_NAME}_test") { + if (fileExists('test_out')) { + def fileTransfer = load "${env.WORKSPACE}/ci/function/file_transfer.groovy" + fileTransfer.FileTransfer("test_out/", "${PROJECT_NAME}/test/${JOB_NAME}-${BUILD_ID}", 'nas storage') + if (currentBuild.resultIsBetterOrEqualTo('SUCCESS')) { + echo "Milvus Dev Test Out Viewer \"ftp://192.168.1.126/data/${PROJECT_NAME}/test/${JOB_NAME}-${BUILD_ID}\"" + } + } else { + error("Milvus Dev Test Out directory don't exists!") + } + } +} diff --git a/tests/milvus-java-test/ci/main_jenkinsfile b/tests/milvus-java-test/ci/main_jenkinsfile new file mode 100644 index 0000000000..5df9d61ccb --- /dev/null +++ b/tests/milvus-java-test/ci/main_jenkinsfile @@ -0,0 +1,110 @@ +pipeline { + agent none + + options { + timestamps() + } + + environment { + SRC_BRANCH = "master" + IMAGE_TAG = "${params.IMAGE_TAG}-release" + HELM_BRANCH = "${params.IMAGE_TAG}" + TEST_URL = 
"git@192.168.1.105:Test/milvus-java-test.git" + TEST_BRANCH = "${params.IMAGE_TAG}" + } + + stages { + stage("Setup env") { + agent { + kubernetes { + label 'dev-test' + defaultContainer 'jnlp' + yaml """ + apiVersion: v1 + kind: Pod + metadata: + labels: + app: milvus + componet: test + spec: + containers: + - name: milvus-testframework-java + image: registry.zilliz.com/milvus/milvus-java-test:v0.1 + command: + - cat + tty: true + volumeMounts: + - name: kubeconf + mountPath: /root/.kube/ + readOnly: true + volumes: + - name: kubeconf + secret: + secretName: test-cluster-config + """ + } + } + + stages { + stage("Deploy Server") { + steps { + gitlabCommitStatus(name: 'Deloy Server') { + container('milvus-testframework-java') { + script { + load "${env.WORKSPACE}/milvus-java-test/ci/jenkinsfile/deploy_server.groovy" + } + } + } + } + } + stage("Integration Test") { + steps { + gitlabCommitStatus(name: 'Integration Test') { + container('milvus-testframework-java') { + script { + print "In integration test stage" + load "${env.WORKSPACE}/milvus-java-test/ci/jenkinsfile/integration_test.groovy" + } + } + } + } + } + stage ("Cleanup Env") { + steps { + gitlabCommitStatus(name: 'Cleanup Env') { + container('milvus-testframework-java') { + script { + load "${env.WORKSPACE}/milvus-java-test/ci/jenkinsfile/cleanup.groovy" + } + } + } + } + } + } + post { + always { + container('milvus-testframework-java') { + script { + load "${env.WORKSPACE}/milvus-java-test/ci/jenkinsfile/cleanup.groovy" + } + } + } + success { + script { + echo "Milvus java-sdk test success !" + } + } + aborted { + script { + echo "Milvus java-sdk test aborted !" + } + } + failure { + script { + echo "Milvus java-sdk test failed !" 
+ } + } + } + } + } +} diff --git a/tests/milvus-java-test/ci/pod_containers/milvus-testframework.yaml b/tests/milvus-java-test/ci/pod_containers/milvus-testframework.yaml new file mode 100644 index 0000000000..1381e1454f --- /dev/null +++ b/tests/milvus-java-test/ci/pod_containers/milvus-testframework.yaml @@ -0,0 +1,13 @@ +apiVersion: v1 +kind: Pod +metadata: + labels: + app: milvus + componet: testframework-java +spec: + containers: + - name: milvus-testframework-java + image: maven:3.6.2-jdk-8 + command: + - cat + tty: true diff --git a/tests/milvus-java-test/milvus-java-test.iml b/tests/milvus-java-test/milvus-java-test.iml new file mode 100644 index 0000000000..78b2cc53b2 --- /dev/null +++ b/tests/milvus-java-test/milvus-java-test.iml @@ -0,0 +1,2 @@ + + \ No newline at end of file diff --git a/tests/milvus-java-test/pom.xml b/tests/milvus-java-test/pom.xml new file mode 100644 index 0000000000..db02ff2c00 --- /dev/null +++ b/tests/milvus-java-test/pom.xml @@ -0,0 +1,137 @@ + + + 4.0.0 + + milvus + MilvusSDkJavaTest + 1.0-SNAPSHOT + + + + + org.apache.maven.plugins + maven-dependency-plugin + + + copy-dependencies + package + + copy-dependencies + + + lib + false + true + + + + + + + + + UTF-8 + 1.23.0 + 3.9.0 + 3.9.0 + 1.8 + 1.8 + + + + + + + + + + + + + + + + + oss.sonatype.org-snapshot + http://oss.sonatype.org/content/repositories/snapshots + + false + + + true + + + + + + + org.apache.commons + commons-lang3 + 3.4 + + + + commons-cli + commons-cli + 1.3 + + + + org.testng + testng + 6.10 + + + + junit + junit + 4.9 + + + + + + + + + + io.milvus + milvus-sdk-java + 0.1.1-SNAPSHOT + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/tests/milvus-java-test/src/main/java/com/MainClass.java b/tests/milvus-java-test/src/main/java/com/MainClass.java new file mode 100644 index 0000000000..0e61e5f3d0 --- /dev/null +++ b/tests/milvus-java-test/src/main/java/com/MainClass.java @@ -0,0 +1,147 @@ +package com; 
+ +import io.milvus.client.*; +import org.apache.commons.cli.*; +import org.apache.commons.lang3.RandomStringUtils; +import org.testng.SkipException; +import org.testng.TestNG; +import org.testng.annotations.DataProvider; +import org.testng.xml.XmlClass; +import org.testng.xml.XmlSuite; +import org.testng.xml.XmlTest; + +import java.util.ArrayList; +import java.util.List; + +public class MainClass { + private static String host = "127.0.0.1"; + private static String port = "19530"; + public Integer index_file_size = 50; + public Integer dimension = 128; + + public static void setHost(String host) { + MainClass.host = host; + } + + public static void setPort(String port) { + MainClass.port = port; + } + + @DataProvider(name="DefaultConnectArgs") + public static Object[][] defaultConnectArgs(){ + return new Object[][]{{host, port}}; + } + + @DataProvider(name="ConnectInstance") + public Object[][] connectInstance(){ + MilvusClient client = new MilvusGrpcClient(); + ConnectParam connectParam = new ConnectParam.Builder() + .withHost(host) + .withPort(port) + .build(); + client.connect(connectParam); + String tableName = RandomStringUtils.randomAlphabetic(10); + return new Object[][]{{client, tableName}}; + } + + @DataProvider(name="DisConnectInstance") + public Object[][] disConnectInstance(){ + // Generate connection instance + MilvusClient client = new MilvusGrpcClient(); + ConnectParam connectParam = new ConnectParam.Builder() + .withHost(host) + .withPort(port) + .build(); + client.connect(connectParam); + try { + client.disconnect(); + } catch (InterruptedException e) { + e.printStackTrace(); + } + String tableName = RandomStringUtils.randomAlphabetic(10); + return new Object[][]{{client, tableName}}; + } + + @DataProvider(name="Table") + public Object[][] provideTable(){ + Object[][] tables = new Object[2][2]; + MetricType metricTypes[] = { MetricType.L2, MetricType.IP }; + for (Integer i = 0; i < metricTypes.length; ++i) { + String tableName = 
metricTypes[i].toString()+"_"+RandomStringUtils.randomAlphabetic(10); + // Generate connection instance + MilvusClient client = new MilvusGrpcClient(); + ConnectParam connectParam = new ConnectParam.Builder() + .withHost(host) + .withPort(port) + .build(); + client.connect(connectParam); + TableSchema tableSchema = new TableSchema.Builder(tableName, dimension) + .withIndexFileSize(index_file_size) + .withMetricType(metricTypes[i]) + .build(); + TableSchemaParam tableSchemaParam = new TableSchemaParam.Builder(tableSchema).build(); + Response res = client.createTable(tableSchemaParam); + if (!res.ok()) { + System.out.println(res.getMessage()); + throw new SkipException("Table created failed"); + } + tables[i] = new Object[]{client, tableName}; + } + return tables; + } + + public static void main(String[] args) { + CommandLineParser parser = new DefaultParser(); + Options options = new Options(); + options.addOption("h", "host", true, "milvus-server hostname/ip"); + options.addOption("p", "port", true, "milvus-server port"); + try { + CommandLine cmd = parser.parse(options, args); + String host = cmd.getOptionValue("host"); + if (host != null) { + setHost(host); + } + String port = cmd.getOptionValue("port"); + if (port != null) { + setPort(port); + } + System.out.println("Host: "+host+", Port: "+port); + } + catch(ParseException exp) { + System.err.println("Parsing failed. 
Reason: " + exp.getMessage() ); + } + +// TestListenerAdapter tla = new TestListenerAdapter(); +// TestNG testng = new TestNG(); +// testng.setTestClasses(new Class[] { TestPing.class }); +// testng.setTestClasses(new Class[] { TestConnect.class }); +// testng.addListener(tla); +// testng.run(); + + XmlSuite suite = new XmlSuite(); + suite.setName("TmpSuite"); + + XmlTest test = new XmlTest(suite); + test.setName("TmpTest"); + List classes = new ArrayList(); + + classes.add(new XmlClass("com.TestPing")); + classes.add(new XmlClass("com.TestAddVectors")); + classes.add(new XmlClass("com.TestConnect")); + classes.add(new XmlClass("com.TestDeleteVectors")); + classes.add(new XmlClass("com.TestIndex")); + classes.add(new XmlClass("com.TestSearchVectors")); + classes.add(new XmlClass("com.TestTable")); + classes.add(new XmlClass("com.TestTableCount")); + + test.setXmlClasses(classes) ; + + List suites = new ArrayList(); + suites.add(suite); + TestNG tng = new TestNG(); + tng.setXmlSuites(suites); + tng.run(); + + } + +} diff --git a/tests/milvus-java-test/src/main/java/com/TestAddVectors.java b/tests/milvus-java-test/src/main/java/com/TestAddVectors.java new file mode 100644 index 0000000000..038b8d2a8d --- /dev/null +++ b/tests/milvus-java-test/src/main/java/com/TestAddVectors.java @@ -0,0 +1,154 @@ +package com; + +import io.milvus.client.InsertParam; +import io.milvus.client.InsertResponse; +import io.milvus.client.MilvusClient; +import io.milvus.client.TableParam; +import org.testng.Assert; +import org.testng.annotations.Test; + +import java.util.*; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +public class TestAddVectors { + int dimension = 128; + + public List> gen_vectors(Integer nb) { + List> xb = new LinkedList<>(); + Random random = new Random(); + for (int i = 0; i < nb; ++i) { + LinkedList vector = new LinkedList<>(); + for (int j = 0; j < dimension; j++) { + vector.add(random.nextFloat()); + } + xb.add(vector); + } + return xb; + 
} + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_add_vectors_table_not_existed(MilvusClient client, String tableName) throws InterruptedException { + int nb = 10000; + List> vectors = gen_vectors(nb); + String tableNameNew = tableName + "_"; + InsertParam insertParam = new InsertParam.Builder(tableNameNew, vectors).build(); + InsertResponse res = client.insert(insertParam); + assert(!res.getResponse().ok()); + } + + @Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class) + public void test_add_vectors_without_connect(MilvusClient client, String tableName) throws InterruptedException { + int nb = 100; + List> vectors = gen_vectors(nb); + InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build(); + InsertResponse res = client.insert(insertParam); + assert(!res.getResponse().ok()); + } + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_add_vectors(MilvusClient client, String tableName) throws InterruptedException { + int nb = 10000; + List> vectors = gen_vectors(nb); + InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build(); + InsertResponse res = client.insert(insertParam); + assert(res.getResponse().ok()); + Thread.currentThread().sleep(1000); + // Assert table row count + TableParam tableParam = new TableParam.Builder(tableName).build(); + Assert.assertEquals(client.getTableRowCount(tableParam).getTableRowCount(), nb); + } + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_add_vectors_timeout(MilvusClient client, String tableName) throws InterruptedException { + int nb = 200000; + List> vectors = gen_vectors(nb); + System.out.println(new Date()); + InsertParam insertParam = new InsertParam.Builder(tableName, vectors).withTimeout(1).build(); + InsertResponse res = client.insert(insertParam); + assert(!res.getResponse().ok()); + } + + @Test(dataProvider = "Table", dataProviderClass = 
MainClass.class) + public void test_add_vectors_big_data(MilvusClient client, String tableName) throws InterruptedException { + int nb = 500000; + List> vectors = gen_vectors(nb); + System.out.println(new Date()); + InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build(); + InsertResponse res = client.insert(insertParam); + assert(res.getResponse().ok()); + } + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_add_vectors_with_ids(MilvusClient client, String tableName) throws InterruptedException { + int nb = 10000; + List> vectors = gen_vectors(nb); + // Add vectors with ids + List vectorIds; + vectorIds = Stream.iterate(0L, n -> n) + .limit(nb) + .collect(Collectors.toList()); + InsertParam insertParam = new InsertParam.Builder(tableName, vectors).withVectorIds(vectorIds).build(); + InsertResponse res = client.insert(insertParam); + assert(res.getResponse().ok()); + Thread.currentThread().sleep(1000); + // Assert table row count + TableParam tableParam = new TableParam.Builder(tableName).build(); + Assert.assertEquals(client.getTableRowCount(tableParam).getTableRowCount(), nb); + } + + // TODO: MS-628 + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_add_vectors_with_invalid_ids(MilvusClient client, String tableName) { + int nb = 10; + List> vectors = gen_vectors(nb); + // Add vectors with ids + List vectorIds; + vectorIds = Stream.iterate(0L, n -> n) + .limit(nb+1) + .collect(Collectors.toList()); + InsertParam insertParam = new InsertParam.Builder(tableName, vectors).withVectorIds(vectorIds).build(); + InsertResponse res = client.insert(insertParam); + assert(!res.getResponse().ok()); + } + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_add_vectors_with_invalid_dimension(MilvusClient client, String tableName) { + int nb = 10000; + List> vectors = gen_vectors(nb); + vectors.get(0).add((float) 0); + InsertParam insertParam = 
new InsertParam.Builder(tableName, vectors).build(); + InsertResponse res = client.insert(insertParam); + assert(!res.getResponse().ok()); + } + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_add_vectors_with_invalid_vectors(MilvusClient client, String tableName) { + int nb = 10000; + List> vectors = gen_vectors(nb); + vectors.set(0, new ArrayList<>()); + InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build(); + InsertResponse res = client.insert(insertParam); + assert(!res.getResponse().ok()); + } + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_add_vectors_repeatably(MilvusClient client, String tableName) throws InterruptedException { + int nb = 100000; + int loops = 10; + List> vectors = gen_vectors(nb); + InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build(); + InsertResponse res = null; + for (int i = 0; i < loops; ++i ) { + long startTime = System.currentTimeMillis(); + res = client.insert(insertParam); + long endTime = System.currentTimeMillis(); + System.out.println("Total execution time: " + (endTime-startTime) + "ms"); + } + Thread.currentThread().sleep(1000); + // Assert table row count + TableParam tableParam = new TableParam.Builder(tableName).build(); + Assert.assertEquals(client.getTableRowCount(tableParam).getTableRowCount(), nb * loops); + } + +} diff --git a/tests/milvus-java-test/src/main/java/com/TestConnect.java b/tests/milvus-java-test/src/main/java/com/TestConnect.java new file mode 100644 index 0000000000..77b6fe6a33 --- /dev/null +++ b/tests/milvus-java-test/src/main/java/com/TestConnect.java @@ -0,0 +1,80 @@ +package com; + +import io.milvus.client.ConnectParam; +import io.milvus.client.MilvusClient; +import io.milvus.client.MilvusGrpcClient; +import io.milvus.client.Response; +import org.testng.annotations.DataProvider; +import org.testng.annotations.Test; + +public class TestConnect { + @Test(dataProvider = 
"DefaultConnectArgs", dataProviderClass = MainClass.class) + public void test_connect(String host, String port){ + System.out.println("Host: "+host+", Port: "+port); + MilvusClient client = new MilvusGrpcClient(); + ConnectParam connectParam = new ConnectParam.Builder() + .withHost(host) + .withPort(port) + .build(); + Response res = client.connect(connectParam); + assert(res.ok()); + assert(client.connected()); + } + + @Test(dataProvider = "DefaultConnectArgs", dataProviderClass = MainClass.class) + public void test_connect_repeat(String host, String port){ + MilvusGrpcClient client = new MilvusGrpcClient(); + ConnectParam connectParam = new ConnectParam.Builder() + .withHost(host) + .withPort(port) + .build(); + client.connect(connectParam); + Response res = client.connect(connectParam); + assert(!res.ok()); + assert(client.connected()); + } + + @Test(dataProvider="InvalidConnectArgs") + public void test_connect_invalid_connect_args(String ip, String port) throws InterruptedException { + MilvusClient client = new MilvusGrpcClient(); + ConnectParam connectParam = new ConnectParam.Builder() + .withHost(ip) + .withPort(port) + .build(); + client.connect(connectParam); + assert(!client.connected()); + } + + // TODO: MS-615 + @DataProvider(name="InvalidConnectArgs") + public Object[][] generate_invalid_connect_args() { + String port = "19530"; + String ip = ""; + return new Object[][]{ + {"1.1.1.1", port}, + {"255.255.0.0", port}, + {"1.2.2", port}, + {"中文", port}, + {"www.baidu.com", "100000"}, + {"127.0.0.1", "100000"}, + {"www.baidu.com", "80"}, + }; + } + + @Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class) + public void test_disconnect(MilvusClient client, String tableName){ + assert(!client.connected()); + } + + @Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class) + public void test_disconnect_repeatably(MilvusClient client, String tableNam){ + Response res = null; + try { + res = client.disconnect(); + } 
catch (InterruptedException e) { + e.printStackTrace(); + } + assert(res.ok()); + assert(!client.connected()); + } +} diff --git a/tests/milvus-java-test/src/main/java/com/TestDeleteVectors.java b/tests/milvus-java-test/src/main/java/com/TestDeleteVectors.java new file mode 100644 index 0000000000..69b5e41434 --- /dev/null +++ b/tests/milvus-java-test/src/main/java/com/TestDeleteVectors.java @@ -0,0 +1,122 @@ +package com; + +import io.milvus.client.*; +import org.testng.Assert; +import org.testng.annotations.Test; + +import java.util.*; + +public class TestDeleteVectors { + int index_file_size = 50; + int dimension = 128; + + public List> gen_vectors(Integer nb) { + List> xb = new LinkedList<>(); + Random random = new Random(); + for (int i = 0; i < nb; ++i) { + LinkedList vector = new LinkedList<>(); + for (int j = 0; j < dimension; j++) { + vector.add(random.nextFloat()); + } + xb.add(vector); + } + return xb; + } + + public static Date getDeltaDate(int delta) { + Date today = new Date(); + Calendar c = Calendar.getInstance(); + c.setTime(today); + c.add(Calendar.DAY_OF_MONTH, delta); + return c.getTime(); + } + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_delete_vectors(MilvusClient client, String tableName) throws InterruptedException { + int nb = 10000; + List> vectors = gen_vectors(nb); + // Add vectors + InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build(); + InsertResponse res = client.insert(insertParam); + assert(res.getResponse().ok()); + Thread.sleep(1000); + DateRange dateRange = new DateRange(getDeltaDate(-1), getDeltaDate(1)); + DeleteByRangeParam param = new DeleteByRangeParam.Builder(dateRange, tableName).build(); + Response res_delete = client.deleteByRange(param); + assert(res_delete.ok()); + Thread.sleep(1000); + // Assert table row count + TableParam tableParam = new TableParam.Builder(tableName).build(); + 
Assert.assertEquals(client.getTableRowCount(tableParam).getTableRowCount(), 0); + } + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_delete_vectors_table_not_existed(MilvusClient client, String tableName) throws InterruptedException { + String tableNameNew = tableName + "_"; + DateRange dateRange = new DateRange(getDeltaDate(-1), getDeltaDate(1)); + DeleteByRangeParam param = new DeleteByRangeParam.Builder(dateRange, tableNameNew).build(); + Response res_delete = client.deleteByRange(param); + assert(!res_delete.ok()); + } + + @Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class) + public void test_delete_vectors_without_connect(MilvusClient client, String tableName) throws InterruptedException { + DateRange dateRange = new DateRange(getDeltaDate(-1), getDeltaDate(1)); + DeleteByRangeParam param = new DeleteByRangeParam.Builder(dateRange, tableName).build(); + Response res_delete = client.deleteByRange(param); + assert(!res_delete.ok()); + } + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_delete_vectors_table_empty(MilvusClient client, String tableName) throws InterruptedException { + DateRange dateRange = new DateRange(getDeltaDate(-1), getDeltaDate(1)); + DeleteByRangeParam param = new DeleteByRangeParam.Builder(dateRange, tableName).build(); + Response res_delete = client.deleteByRange(param); + assert(res_delete.ok()); + } + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_delete_vectors_invalid_date_range(MilvusClient client, String tableName) throws InterruptedException { + int nb = 100; + List> vectors = gen_vectors(nb); + // Add vectors + InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build(); + InsertResponse res = client.insert(insertParam); + assert(res.getResponse().ok()); + Thread.sleep(1000); + DateRange dateRange = new DateRange(getDeltaDate(1), getDeltaDate(0)); + 
DeleteByRangeParam param = new DeleteByRangeParam.Builder(dateRange, tableName).build(); + Response res_delete = client.deleteByRange(param); + assert(!res_delete.ok()); + } + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_delete_vectors_invalid_date_range_1(MilvusClient client, String tableName) throws InterruptedException { + int nb = 100; + List> vectors = gen_vectors(nb); + InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build(); + InsertResponse res = client.insert(insertParam); + assert(res.getResponse().ok()); + DateRange dateRange = new DateRange(getDeltaDate(2), getDeltaDate(-1)); + DeleteByRangeParam param = new DeleteByRangeParam.Builder(dateRange, tableName).build(); + Response res_delete = client.deleteByRange(param); + assert(!res_delete.ok()); + } + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_delete_vectors_no_result(MilvusClient client, String tableName) throws InterruptedException { + int nb = 100; + List> vectors = gen_vectors(nb); + InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build(); + InsertResponse res = client.insert(insertParam); + assert(res.getResponse().ok()); + Thread.sleep(1000); + DateRange dateRange = new DateRange(getDeltaDate(-3), getDeltaDate(-2)); + DeleteByRangeParam param = new DeleteByRangeParam.Builder(dateRange, tableName).build(); + Response res_delete = client.deleteByRange(param); + assert(res_delete.ok()); + TableParam tableParam = new TableParam.Builder(tableName).build(); + Assert.assertEquals(client.getTableRowCount(tableParam).getTableRowCount(), nb); + } + +} diff --git a/tests/milvus-java-test/src/main/java/com/TestIndex.java b/tests/milvus-java-test/src/main/java/com/TestIndex.java new file mode 100644 index 0000000000..d003771b0b --- /dev/null +++ b/tests/milvus-java-test/src/main/java/com/TestIndex.java @@ -0,0 +1,340 @@ +package com; + +import io.milvus.client.*; +import 
org.testng.Assert; +import org.testng.annotations.Test; + +import java.util.Date; +import java.util.LinkedList; +import java.util.List; +import java.util.Random; + +public class TestIndex { + int index_file_size = 10; + int dimension = 128; + int n_list = 1024; + int default_n_list = 16384; + int nb = 100000; + IndexType indexType = IndexType.IVF_SQ8; + IndexType defaultIndexType = IndexType.FLAT; + + public List> gen_vectors(Integer nb) { + List> xb = new LinkedList<>(); + Random random = new Random(); + for (int i = 0; i < nb; ++i) { + LinkedList vector = new LinkedList<>(); + for (int j = 0; j < dimension; j++) { + vector.add(random.nextFloat()); + } + xb.add(vector); + } + return xb; + } + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_create_index(MilvusClient client, String tableName) throws InterruptedException { + List> vectors = gen_vectors(nb); + InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build(); + client.insert(insertParam); + Index index = new Index.Builder().withIndexType(indexType) + .withNList(n_list) + .build(); + CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build(); + Response res_create = client.createIndex(createIndexParam); + assert(res_create.ok()); + } + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_create_index_repeatably(MilvusClient client, String tableName) throws InterruptedException { + List> vectors = gen_vectors(nb); + InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build(); + client.insert(insertParam); + Index index = new Index.Builder().withIndexType(indexType) + .withNList(n_list) + .build(); + CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build(); + Response res_create = client.createIndex(createIndexParam); + res_create = client.createIndex(createIndexParam); + assert(res_create.ok()); + TableParam 
tableParam = new TableParam.Builder(tableName).build(); + DescribeIndexResponse res = client.describeIndex(tableParam); + assert(res.getResponse().ok()); + Index index1 = res.getIndex().get(); + Assert.assertEquals(index1.getNList(), n_list); + Assert.assertEquals(index1.getIndexType(), indexType); + } + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_create_index_FLAT(MilvusClient client, String tableName) throws InterruptedException { + IndexType indexType = IndexType.FLAT; + List> vectors = gen_vectors(nb); + InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build(); + client.insert(insertParam); + Index index = new Index.Builder().withIndexType(indexType) + .withNList(n_list) + .build(); + CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build(); + Response res_create = client.createIndex(createIndexParam); + assert(res_create.ok()); + TableParam tableParam = new TableParam.Builder(tableName).build(); + DescribeIndexResponse res = client.describeIndex(tableParam); + assert(res.getResponse().ok()); + Index index1 = res.getIndex().get(); + Assert.assertEquals(index1.getIndexType(), indexType); + } + + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_create_index_FLAT_timeout(MilvusClient client, String tableName) throws InterruptedException { + int nb = 500000; + IndexType indexType = IndexType.IVF_SQ8; + List> vectors = gen_vectors(nb); + InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build(); + client.insert(insertParam); + Index index = new Index.Builder().withIndexType(indexType) + .withNList(n_list) + .build(); + System.out.println(new Date()); + CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).withTimeout(1).build(); + Response res_create = client.createIndex(createIndexParam); + assert(!res_create.ok()); + } + + @Test(dataProvider = "Table", 
dataProviderClass = MainClass.class) + public void test_create_index_IVFLAT(MilvusClient client, String tableName) throws InterruptedException { + IndexType indexType = IndexType.IVFLAT; + List> vectors = gen_vectors(nb); + InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build(); + client.insert(insertParam); + Index index = new Index.Builder().withIndexType(indexType) + .withNList(n_list) + .build(); + CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build(); + Response res_create = client.createIndex(createIndexParam); + assert(res_create.ok()); + TableParam tableParam = new TableParam.Builder(tableName).build(); + DescribeIndexResponse res = client.describeIndex(tableParam); + assert(res.getResponse().ok()); + Index index1 = res.getIndex().get(); + Assert.assertEquals(index1.getIndexType(), indexType); + } + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_create_index_IVFSQ8(MilvusClient client, String tableName) throws InterruptedException { + IndexType indexType = IndexType.IVF_SQ8; + List> vectors = gen_vectors(nb); + InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build(); + client.insert(insertParam); + Index index = new Index.Builder().withIndexType(indexType) + .withNList(n_list) + .build(); + CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build(); + Response res_create = client.createIndex(createIndexParam); + assert(res_create.ok()); + TableParam tableParam = new TableParam.Builder(tableName).build(); + DescribeIndexResponse res = client.describeIndex(tableParam); + assert(res.getResponse().ok()); + Index index1 = res.getIndex().get(); + Assert.assertEquals(index1.getIndexType(), indexType); + } + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_create_index_IVFSQ8H(MilvusClient client, String tableName) throws InterruptedException { + 
IndexType indexType = IndexType.IVF_SQ8_H; + List> vectors = gen_vectors(nb); + InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build(); + client.insert(insertParam); + Index index = new Index.Builder().withIndexType(indexType) + .withNList(n_list) + .build(); + CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build(); + Response res_create = client.createIndex(createIndexParam); + assert(res_create.ok()); + TableParam tableParam = new TableParam.Builder(tableName).build(); + DescribeIndexResponse res = client.describeIndex(tableParam); + assert(res.getResponse().ok()); + Index index1 = res.getIndex().get(); + Assert.assertEquals(index1.getIndexType(), indexType); + } + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_create_index_with_no_vector(MilvusClient client, String tableName) { + Index index = new Index.Builder().withIndexType(indexType) + .withNList(n_list) + .build(); + CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build(); + Response res_create = client.createIndex(createIndexParam); + assert(res_create.ok()); + } + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_create_index_table_not_existed(MilvusClient client, String tableName) throws InterruptedException { + String tableNameNew = tableName + "_"; + Index index = new Index.Builder().withIndexType(indexType) + .withNList(n_list) + .build(); + CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableNameNew).withIndex(index).build(); + Response res_create = client.createIndex(createIndexParam); + assert(!res_create.ok()); + } + + @Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class) + public void test_create_index_without_connect(MilvusClient client, String tableName) throws InterruptedException { + Index index = new Index.Builder().withIndexType(indexType) + 
.withNList(n_list) + .build(); + CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build(); + Response res_create = client.createIndex(createIndexParam); + assert(!res_create.ok()); + } + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_create_index_invalid_n_list(MilvusClient client, String tableName) throws InterruptedException { + int n_list = 0; + List> vectors = gen_vectors(nb); + InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build(); + client.insert(insertParam); + Index index = new Index.Builder().withIndexType(indexType) + .withNList(n_list) + .build(); + CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build(); + Response res_create = client.createIndex(createIndexParam); + assert(!res_create.ok()); + } + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_describe_index(MilvusClient client, String tableName) throws InterruptedException { + List> vectors = gen_vectors(nb); + InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build(); + client.insert(insertParam); + Index index = new Index.Builder().withIndexType(indexType) + .withNList(n_list) + .build(); + CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build(); + Response res_create = client.createIndex(createIndexParam); + assert(res_create.ok()); + TableParam tableParam = new TableParam.Builder(tableName).build(); + DescribeIndexResponse res = client.describeIndex(tableParam); + assert(res.getResponse().ok()); + Index index1 = res.getIndex().get(); + Assert.assertEquals(index1.getNList(), n_list); + Assert.assertEquals(index1.getIndexType(), indexType); + } + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_alter_index(MilvusClient client, String tableName) throws InterruptedException { + List> vectors = 
gen_vectors(nb); + InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build(); + client.insert(insertParam); + Index index = new Index.Builder().withIndexType(indexType) + .withNList(n_list) + .build(); + CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build(); + Response res_create = client.createIndex(createIndexParam); + assert(res_create.ok()); + TableParam tableParam = new TableParam.Builder(tableName).build(); + // Create another index + IndexType indexTypeNew = IndexType.IVFLAT; + int n_list_new = n_list + 1; + Index index_new = new Index.Builder().withIndexType(indexTypeNew) + .withNList(n_list_new) + .build(); + CreateIndexParam createIndexParamNew = new CreateIndexParam.Builder(tableName).withIndex(index_new).build(); + Response res_create_new = client.createIndex(createIndexParamNew); + assert(res_create_new.ok()); + DescribeIndexResponse res = client.describeIndex(tableParam); + assert(res_create.ok()); + Index index1 = res.getIndex().get(); + Assert.assertEquals(index1.getNList(), n_list_new); + Assert.assertEquals(index1.getIndexType(), indexTypeNew); + } + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_describe_index_table_not_existed(MilvusClient client, String tableName) throws InterruptedException { + String tableNameNew = tableName + "_"; + TableParam tableParam = new TableParam.Builder(tableNameNew).build(); + DescribeIndexResponse res = client.describeIndex(tableParam); + assert(!res.getResponse().ok()); + } + + @Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class) + public void test_describe_index_without_connect(MilvusClient client, String tableName) throws InterruptedException { + TableParam tableParam = new TableParam.Builder(tableName).build(); + DescribeIndexResponse res = client.describeIndex(tableParam); + assert(!res.getResponse().ok()); + } + + @Test(dataProvider = "Table", dataProviderClass = 
MainClass.class) + public void test_drop_index(MilvusClient client, String tableName) throws InterruptedException { + List> vectors = gen_vectors(nb); + InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build(); + client.insert(insertParam); + Index index = new Index.Builder().withIndexType(defaultIndexType) + .withNList(n_list) + .build(); + CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build(); + Response res_create = client.createIndex(createIndexParam); + assert(res_create.ok()); + TableParam tableParam = new TableParam.Builder(tableName).build(); + Response res_drop = client.dropIndex(tableParam); + assert(res_drop.ok()); + DescribeIndexResponse res = client.describeIndex(tableParam); + assert(res.getResponse().ok()); + Index index1 = res.getIndex().get(); + Assert.assertEquals(index1.getNList(), default_n_list); + Assert.assertEquals(index1.getIndexType(), defaultIndexType); + } + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_drop_index_repeatably(MilvusClient client, String tableName) throws InterruptedException { + List> vectors = gen_vectors(nb); + InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build(); + client.insert(insertParam); + Index index = new Index.Builder().withIndexType(defaultIndexType) + .withNList(n_list) + .build(); + CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build(); + Response res_create = client.createIndex(createIndexParam); + assert(res_create.ok()); + TableParam tableParam = new TableParam.Builder(tableName).build(); + Response res_drop = client.dropIndex(tableParam); + res_drop = client.dropIndex(tableParam); + assert(res_drop.ok()); + DescribeIndexResponse res = client.describeIndex(tableParam); + assert(res.getResponse().ok()); + Index index1 = res.getIndex().get(); + Assert.assertEquals(index1.getNList(), default_n_list); + 
Assert.assertEquals(index1.getIndexType(), defaultIndexType); + } + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_drop_index_table_not_existed(MilvusClient client, String tableName) throws InterruptedException { + String tableNameNew = tableName + "_"; + TableParam tableParam = new TableParam.Builder(tableNameNew).build(); + Response res_drop = client.dropIndex(tableParam); + assert(!res_drop.ok()); + } + + @Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class) + public void test_drop_index_without_connect(MilvusClient client, String tableName) throws InterruptedException { + TableParam tableParam = new TableParam.Builder(tableName).build(); + Response res_drop = client.dropIndex(tableParam); + assert(!res_drop.ok()); + } + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_drop_index_no_index_created(MilvusClient client, String tableName) throws InterruptedException { + List> vectors = gen_vectors(nb); + InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build(); + client.insert(insertParam); + TableParam tableParam = new TableParam.Builder(tableName).build(); + Response res_drop = client.dropIndex(tableParam); + assert(res_drop.ok()); + DescribeIndexResponse res = client.describeIndex(tableParam); + assert(res.getResponse().ok()); + Index index1 = res.getIndex().get(); + Assert.assertEquals(index1.getNList(), default_n_list); + Assert.assertEquals(index1.getIndexType(), defaultIndexType); + } + +} diff --git a/tests/milvus-java-test/src/main/java/com/TestMix.java b/tests/milvus-java-test/src/main/java/com/TestMix.java new file mode 100644 index 0000000000..795ab8630e --- /dev/null +++ b/tests/milvus-java-test/src/main/java/com/TestMix.java @@ -0,0 +1,221 @@ +package com; + +import io.milvus.client.*; +import org.apache.commons.lang3.RandomStringUtils; +import org.testng.Assert; +import org.testng.annotations.Test; + +import 
java.util.ArrayList; +import java.util.List; +import java.util.Random; +import java.util.concurrent.ForkJoinPool; +import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; + +public class TestMix { + private int dimension = 128; + private int nb = 100000; + int n_list = 8192; + int n_probe = 20; + int top_k = 10; + double epsilon = 0.001; + int index_file_size = 20; + + public List normalize(List w2v){ + float squareSum = w2v.stream().map(x -> x * x).reduce((float) 0, Float::sum); + final float norm = (float) Math.sqrt(squareSum); + w2v = w2v.stream().map(x -> x / norm).collect(Collectors.toList()); + return w2v; + } + + public List> gen_vectors(int nb, boolean norm) { + List> xb = new ArrayList<>(); + Random random = new Random(); + for (int i = 0; i < nb; ++i) { + List vector = new ArrayList<>(); + for (int j = 0; j < dimension; j++) { + vector.add(random.nextFloat()); + } + if (norm == true) { + vector = normalize(vector); + } + xb.add(vector); + } + return xb; + } + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_search_vectors_threads(MilvusClient client, String tableName) throws InterruptedException { + int thread_num = 10; + int nq = 5; + List> vectors = gen_vectors(nb, false); + List> queryVectors = vectors.subList(0,nq); + InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build(); + client.insert(insertParam); + Index index = new Index.Builder().withIndexType(IndexType.IVF_SQ8) + .withNList(n_list) + .build(); + CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build(); + client.createIndex(createIndexParam); + ForkJoinPool executor = new ForkJoinPool(); + for (int i = 0; i < thread_num; i++) { + executor.execute( + () -> { + SearchParam searchParam = new SearchParam.Builder(tableName, queryVectors).withNProbe(n_probe).withTopK(top_k).build(); + SearchResponse res_search = client.search(searchParam); + assert 
(res_search.getResponse().ok()); + }); + } + executor.awaitQuiescence(100, TimeUnit.SECONDS); + executor.shutdown(); + } + + @Test(dataProvider = "DefaultConnectArgs", dataProviderClass = MainClass.class) + public void test_connect_threads(String host, String port) throws InterruptedException { + int thread_num = 100; + ForkJoinPool executor = new ForkJoinPool(); + for (int i = 0; i < thread_num; i++) { + executor.execute( + () -> { + MilvusClient client = new MilvusGrpcClient(); + ConnectParam connectParam = new ConnectParam.Builder() + .withHost(host) + .withPort(port) + .build(); + client.connect(connectParam); + assert(client.connected()); + try { + client.disconnect(); + } catch (InterruptedException e) { + e.printStackTrace(); + } + }); + } + executor.awaitQuiescence(100, TimeUnit.SECONDS); + executor.shutdown(); + } + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_add_vectors_threads(MilvusClient client, String tableName) throws InterruptedException { + int thread_num = 10; + List> vectors = gen_vectors(nb,false); + InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build(); + ForkJoinPool executor = new ForkJoinPool(); + for (int i = 0; i < thread_num; i++) { + executor.execute( + () -> { + InsertResponse res_insert = client.insert(insertParam); + assert (res_insert.getResponse().ok()); + }); + } + executor.awaitQuiescence(100, TimeUnit.SECONDS); + executor.shutdown(); + + Thread.sleep(2000); + TableParam tableParam = new TableParam.Builder(tableName).build(); + GetTableRowCountResponse getTableRowCountResponse = client.getTableRowCount(tableParam); + Assert.assertEquals(getTableRowCountResponse.getTableRowCount(), thread_num * nb); + } + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_add_index_vectors_threads(MilvusClient client, String tableName) throws InterruptedException { + int thread_num = 50; + List> vectors = gen_vectors(nb,false); + InsertParam 
insertParam = new InsertParam.Builder(tableName, vectors).build(); + ForkJoinPool executor = new ForkJoinPool(); + for (int i = 0; i < thread_num; i++) { + executor.execute( + () -> { + InsertResponse res_insert = client.insert(insertParam); + Index index = new Index.Builder().withIndexType(IndexType.IVF_SQ8) + .withNList(n_list) + .build(); + CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build(); + client.createIndex(createIndexParam); + assert (res_insert.getResponse().ok()); + }); + } + executor.awaitQuiescence(300, TimeUnit.SECONDS); + executor.shutdown(); + Thread.sleep(2000); + TableParam tableParam = new TableParam.Builder(tableName).build(); + GetTableRowCountResponse getTableRowCountResponse = client.getTableRowCount(tableParam); + Assert.assertEquals(getTableRowCountResponse.getTableRowCount(), thread_num * nb); + } + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_add_search_vectors_threads(MilvusClient client, String tableName) throws InterruptedException { + int thread_num = 50; + int nq = 5; + List> vectors = gen_vectors(nb, true); + List> queryVectors = vectors.subList(0,nq); + InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build(); + ForkJoinPool executor = new ForkJoinPool(); + for (int i = 0; i < thread_num; i++) { + executor.execute( + () -> { + InsertResponse res_insert = client.insert(insertParam); + assert (res_insert.getResponse().ok()); + try { + TimeUnit.SECONDS.sleep(1); + } catch (InterruptedException e) { + e.printStackTrace(); + } + SearchParam searchParam = new SearchParam.Builder(tableName, queryVectors).withNProbe(n_probe).withTopK(top_k).build(); + SearchResponse res_search = client.search(searchParam); + assert (res_search.getResponse().ok()); + List> res = client.search(searchParam).getQueryResultsList(); + double distance = res.get(0).get(0).getDistance(); + if (tableName.startsWith("L2")) { + 
Assert.assertEquals(distance, 0.0, epsilon); + }else if (tableName.startsWith("IP")) { + Assert.assertEquals(distance, 1.0, epsilon); + } + }); + } + executor.awaitQuiescence(300, TimeUnit.SECONDS); + executor.shutdown(); + Thread.sleep(2000); + TableParam tableParam = new TableParam.Builder(tableName).build(); + GetTableRowCountResponse getTableRowCountResponse = client.getTableRowCount(tableParam); + Assert.assertEquals(getTableRowCountResponse.getTableRowCount(), thread_num * nb); + } + + @Test(dataProvider = "DefaultConnectArgs", dataProviderClass = MainClass.class) + public void test_create_insert_delete_threads(String host, String port) throws InterruptedException { + int thread_num = 100; + List> vectors = gen_vectors(nb,false); + ForkJoinPool executor = new ForkJoinPool(); + for (int i = 0; i < thread_num; i++) { + executor.execute( + () -> { + MilvusClient client = new MilvusGrpcClient(); + ConnectParam connectParam = new ConnectParam.Builder() + .withHost(host) + .withPort(port) + .build(); + client.connect(connectParam); + String tableName = RandomStringUtils.randomAlphabetic(10); + TableSchema tableSchema = new TableSchema.Builder(tableName, dimension) + .withIndexFileSize(index_file_size) + .withMetricType(MetricType.IP) + .build(); + TableSchemaParam tableSchemaParam = new TableSchemaParam.Builder(tableSchema).build(); + client.createTable(tableSchemaParam); + InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build(); + client.insert(insertParam); + TableParam tableParam = new TableParam.Builder(tableName).build(); + Response response = client.dropTable(tableParam); + Assert.assertTrue(response.ok()); + try { + client.disconnect(); + } catch (InterruptedException e) { + e.printStackTrace(); + } + }); + } + executor.awaitQuiescence(100, TimeUnit.SECONDS); + executor.shutdown(); + } + +} diff --git a/tests/milvus-java-test/src/main/java/com/TestPing.java b/tests/milvus-java-test/src/main/java/com/TestPing.java new file mode 100644 
index 0000000000..46850f4a17 --- /dev/null +++ b/tests/milvus-java-test/src/main/java/com/TestPing.java @@ -0,0 +1,28 @@ +package com; + +import io.milvus.client.ConnectParam; +import io.milvus.client.MilvusClient; +import io.milvus.client.MilvusGrpcClient; +import io.milvus.client.Response; +import org.testng.annotations.Test; + +public class TestPing { + @Test(dataProvider = "DefaultConnectArgs", dataProviderClass = MainClass.class) + public void test_server_status(String host, String port){ + System.out.println("Host: "+host+", Port: "+port); + MilvusClient client = new MilvusGrpcClient(); + ConnectParam connectParam = new ConnectParam.Builder() + .withHost(host) + .withPort(port) + .build(); + client.connect(connectParam); + Response res = client.serverStatus(); + assert (res.ok()); + } + + @Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class) + public void test_server_status_without_connected(MilvusGrpcClient client, String tableName){ + Response res = client.serverStatus(); + assert (!res.ok()); + } +} \ No newline at end of file diff --git a/tests/milvus-java-test/src/main/java/com/TestSearchVectors.java b/tests/milvus-java-test/src/main/java/com/TestSearchVectors.java new file mode 100644 index 0000000000..c574298652 --- /dev/null +++ b/tests/milvus-java-test/src/main/java/com/TestSearchVectors.java @@ -0,0 +1,480 @@ +package com; + +import io.milvus.client.*; +import org.testng.Assert; +import org.testng.annotations.Test; + +import java.util.*; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +public class TestSearchVectors { + int index_file_size = 10; + int dimension = 128; + int n_list = 1024; + int default_n_list = 16384; + int nb = 100000; + int n_probe = 20; + int top_k = 10; + double epsilon = 0.001; + IndexType indexType = IndexType.IVF_SQ8; + IndexType defaultIndexType = IndexType.FLAT; + + + public List normalize(List w2v){ + float squareSum = w2v.stream().map(x -> x * x).reduce((float) 0, 
Float::sum); + final float norm = (float) Math.sqrt(squareSum); + w2v = w2v.stream().map(x -> x / norm).collect(Collectors.toList()); + return w2v; + } + + public List> gen_vectors(int nb, boolean norm) { + List> xb = new ArrayList<>(); + Random random = new Random(); + for (int i = 0; i < nb; ++i) { + List vector = new ArrayList<>(); + for (int j = 0; j < dimension; j++) { + vector.add(random.nextFloat()); + } + if (norm == true) { + vector = normalize(vector); + } + xb.add(vector); + } + return xb; + } + + public static Date getDeltaDate(int delta) { + Date today = new Date(); + Calendar c = Calendar.getInstance(); + c.setTime(today); + c.add(Calendar.DAY_OF_MONTH, delta); + return c.getTime(); + } + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_search_table_not_existed(MilvusClient client, String tableName) throws InterruptedException { + String tableNameNew = tableName + "_"; + int nq = 5; + int nb = 100; + List> vectors = gen_vectors(nb, false); + List> queryVectors = vectors.subList(0,nq); + SearchParam searchParam = new SearchParam.Builder(tableNameNew, queryVectors).withNProbe(n_probe).withTopK(top_k).build(); + SearchResponse res_search = client.search(searchParam); + assert (!res_search.getResponse().ok()); + } + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_search_index_IVFLAT(MilvusClient client, String tableName) throws InterruptedException { + IndexType indexType = IndexType.IVFLAT; + int nq = 5; + List> vectors = gen_vectors(nb, false); + List> queryVectors = vectors.subList(0,nq); + InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build(); + client.insert(insertParam); + Index index = new Index.Builder().withIndexType(indexType) + .withNList(n_list) + .build(); + CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build(); + client.createIndex(createIndexParam); + TableParam tableParam = new 
TableParam.Builder(tableName).build(); + SearchParam searchParam = new SearchParam.Builder(tableName, queryVectors).withNProbe(n_probe).withTopK(top_k).build(); + List> res_search = client.search(searchParam).getQueryResultsList(); + Assert.assertEquals(res_search.size(), nq); + Assert.assertEquals(res_search.get(0).size(), top_k); + } + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_search_ids_IVFLAT(MilvusClient client, String tableName) throws InterruptedException { + IndexType indexType = IndexType.IVFLAT; + int nq = 5; + List> vectors = gen_vectors(nb, true); + List> queryVectors = vectors.subList(0,nq); + List vectorIds; + vectorIds = Stream.iterate(0L, n -> n) + .limit(nb) + .collect(Collectors.toList()); + InsertParam insertParam = new InsertParam.Builder(tableName, vectors).withVectorIds(vectorIds).build(); + client.insert(insertParam); + Index index = new Index.Builder().withIndexType(indexType) + .withNList(n_list) + .build(); + CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build(); + client.createIndex(createIndexParam); + TableParam tableParam = new TableParam.Builder(tableName).build(); + SearchParam searchParam = new SearchParam.Builder(tableName, queryVectors).withNProbe(n_probe).withTopK(top_k).build(); + List> res_search = client.search(searchParam).getQueryResultsList(); + Assert.assertEquals(res_search.get(0).get(0).getVectorId(), 0L); + } + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_search_IVFLAT(MilvusClient client, String tableName) throws InterruptedException { + IndexType indexType = IndexType.IVFLAT; + int nq = 5; + List> vectors = gen_vectors(nb, false); + List> queryVectors = vectors.subList(0,nq); + InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build(); + client.insert(insertParam); + Thread.sleep(1000); + SearchParam searchParam = new SearchParam.Builder(tableName, 
queryVectors).withNProbe(n_probe).withTopK(top_k).build(); + List> res_search = client.search(searchParam).getQueryResultsList(); + Assert.assertEquals(res_search.size(), nq); + Assert.assertEquals(res_search.get(0).size(), top_k); + } + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_search_distance_IVFLAT(MilvusClient client, String tableName) throws InterruptedException { + IndexType indexType = IndexType.IVFLAT; + int nq = 5; + List> vectors = gen_vectors(nb, true); + List> queryVectors = vectors.subList(0,nq); + InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build(); + client.insert(insertParam); + Index index = new Index.Builder().withIndexType(indexType) + .withNList(n_list) + .build(); + CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build(); + client.createIndex(createIndexParam); + TableParam tableParam = new TableParam.Builder(tableName).build(); + SearchParam searchParam = new SearchParam.Builder(tableName, queryVectors).withNProbe(n_probe).withTopK(top_k).build(); + List> res_search = client.search(searchParam).getQueryResultsList(); + double distance = res_search.get(0).get(0).getDistance(); + if (tableName.startsWith("L2")) { + Assert.assertEquals(distance, 0.0, epsilon); + }else if (tableName.startsWith("IP")) { + Assert.assertEquals(distance, 1.0, epsilon); + } + } + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_search_index_IVFSQ8(MilvusClient client, String tableName) throws InterruptedException { + IndexType indexType = IndexType.IVF_SQ8; + int nq = 5; + List> vectors = gen_vectors(nb, false); + List> queryVectors = vectors.subList(0,nq); + InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build(); + client.insert(insertParam); + Index index = new Index.Builder().withIndexType(indexType) + .withNList(n_list) + .build(); + CreateIndexParam createIndexParam = new 
CreateIndexParam.Builder(tableName).withIndex(index).build(); + client.createIndex(createIndexParam); + TableParam tableParam = new TableParam.Builder(tableName).build(); + SearchParam searchParam = new SearchParam.Builder(tableName, queryVectors).withNProbe(n_probe).withTopK(top_k).build(); + List> res_search = client.search(searchParam).getQueryResultsList(); + Assert.assertEquals(res_search.size(), nq); + Assert.assertEquals(res_search.get(0).size(), top_k); + } + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_search_IVFSQ8(MilvusClient client, String tableName) throws InterruptedException { + IndexType indexType = IndexType.IVF_SQ8; + int nq = 5; + List> vectors = gen_vectors(nb, false); + List> queryVectors = vectors.subList(0,nq); + InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build(); + client.insert(insertParam); + Thread.sleep(1000); + SearchParam searchParam = new SearchParam.Builder(tableName, queryVectors).withNProbe(n_probe).withTopK(top_k).build(); + List> res_search = client.search(searchParam).getQueryResultsList(); + Assert.assertEquals(res_search.size(), nq); + Assert.assertEquals(res_search.get(0).size(), top_k); + } + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_search_distance_IVFSQ8(MilvusClient client, String tableName) throws InterruptedException { + IndexType indexType = IndexType.IVF_SQ8; + int nq = 5; + int nb = 1000; + List> vectors = gen_vectors(nb, true); + List> queryVectors = vectors.subList(0,nq); + InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build(); + client.insert(insertParam); + Index index = new Index.Builder().withIndexType(indexType) + .withNList(default_n_list) + .build(); + CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build(); + client.createIndex(createIndexParam); + TableParam tableParam = new TableParam.Builder(tableName).build(); + 
SearchParam searchParam = new SearchParam.Builder(tableName, queryVectors).withNProbe(n_probe).withTopK(top_k).build(); + List> res_search = client.search(searchParam).getResultDistancesList(); + for (int i = 0; i < nq; i++) { + double distance = res_search.get(i).get(0); + System.out.println(distance); + if (tableName.startsWith("L2")) { + Assert.assertEquals(distance, 0.0, epsilon); + }else if (tableName.startsWith("IP")) { + Assert.assertEquals(distance, 1.0, epsilon); + } + } + } + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_search_index_FLAT(MilvusClient client, String tableName) throws InterruptedException { + IndexType indexType = IndexType.FLAT; + int nq = 5; + List> vectors = gen_vectors(nb, false); + List> queryVectors = vectors.subList(0,nq); + InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build(); + client.insert(insertParam); + Index index = new Index.Builder().withIndexType(indexType) + .withNList(n_list) + .build(); + CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build(); + client.createIndex(createIndexParam); + TableParam tableParam = new TableParam.Builder(tableName).build(); + SearchParam searchParam = new SearchParam.Builder(tableName, queryVectors).withNProbe(n_probe).withTopK(top_k).build(); + List> res_search = client.search(searchParam).getQueryResultsList(); + Assert.assertEquals(res_search.size(), nq); + Assert.assertEquals(res_search.get(0).size(), top_k); + } + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_search_FLAT(MilvusClient client, String tableName) throws InterruptedException { + IndexType indexType = IndexType.FLAT; + int nq = 5; + List> vectors = gen_vectors(nb, false); + List> queryVectors = vectors.subList(0,nq); + InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build(); + client.insert(insertParam); + Thread.sleep(1000); + SearchParam searchParam = 
new SearchParam.Builder(tableName, queryVectors).withNProbe(n_probe).withTopK(top_k).build(); + List> res_search = client.search(searchParam).getQueryResultsList(); + Assert.assertEquals(res_search.size(), nq); + Assert.assertEquals(res_search.get(0).size(), top_k); + } + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_search_FLAT_timeout(MilvusClient client, String tableName) throws InterruptedException { + IndexType indexType = IndexType.FLAT; + int nb = 100000; + int nq = 1000; + int top_k = 2048; + List> vectors = gen_vectors(nb, false); + List> queryVectors = vectors.subList(0,nq); + InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build(); + client.insert(insertParam); + Thread.sleep(1000); + SearchParam searchParam = new SearchParam.Builder(tableName, queryVectors).withNProbe(n_probe).withTopK(top_k).withTimeout(1).build(); + System.out.println(new Date()); + SearchResponse res_search = client.search(searchParam); + assert (!res_search.getResponse().ok()); + } + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_search_FLAT_big_data_size(MilvusClient client, String tableName) throws InterruptedException { + IndexType indexType = IndexType.FLAT; + int nb = 100000; + int nq = 2000; + int top_k = 2048; + List> vectors = gen_vectors(nb, false); + List> queryVectors = vectors.subList(0,nq); + InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build(); + client.insert(insertParam); + Thread.sleep(1000); + SearchParam searchParam = new SearchParam.Builder(tableName, queryVectors).withNProbe(n_probe).withTopK(top_k).build(); + System.out.println(new Date()); + SearchResponse res_search = client.search(searchParam); + assert (res_search.getResponse().ok()); + } + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_search_distance_FLAT(MilvusClient client, String tableName) throws InterruptedException { + IndexType 
indexType = IndexType.FLAT; + int nq = 5; + List> vectors = gen_vectors(nb, true); + List> queryVectors = vectors.subList(0,nq); + InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build(); + client.insert(insertParam); + Index index = new Index.Builder().withIndexType(indexType) + .withNList(n_list) + .build(); + CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build(); + client.createIndex(createIndexParam); + SearchParam searchParam = new SearchParam.Builder(tableName, queryVectors).withNProbe(n_probe).withTopK(top_k).build(); + List> res_search = client.search(searchParam).getQueryResultsList(); + double distance = res_search.get(0).get(0).getDistance(); + if (tableName.startsWith("L2")) { + Assert.assertEquals(distance, 0.0, epsilon); + }else if (tableName.startsWith("IP")) { + Assert.assertEquals(distance, 1.0, epsilon); + } + } + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_search_invalid_n_probe(MilvusClient client, String tableName) throws InterruptedException { + IndexType indexType = IndexType.IVF_SQ8; + int nq = 5; + int n_probe_new = 0; + List> vectors = gen_vectors(nb, false); + List> queryVectors = vectors.subList(0,nq); + InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build(); + client.insert(insertParam); + Index index = new Index.Builder().withIndexType(indexType) + .withNList(n_list) + .build(); + CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build(); + client.createIndex(createIndexParam); + TableParam tableParam = new TableParam.Builder(tableName).build(); + SearchParam searchParam = new SearchParam.Builder(tableName, queryVectors).withNProbe(n_probe_new).withTopK(top_k).build(); + SearchResponse res_search = client.search(searchParam); + assert (!res_search.getResponse().ok()); + } + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void 
test_search_invalid_top_k(MilvusClient client, String tableName) throws InterruptedException { + IndexType indexType = IndexType.IVF_SQ8; + int nq = 5; + int top_k_new = 0; + List> vectors = gen_vectors(nb, false); + List> queryVectors = vectors.subList(0,nq); + InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build(); + client.insert(insertParam); + Index index = new Index.Builder().withIndexType(indexType) + .withNList(n_list) + .build(); + CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build(); + client.createIndex(createIndexParam); + TableParam tableParam = new TableParam.Builder(tableName).build(); + SearchParam searchParam = new SearchParam.Builder(tableName, queryVectors).withNProbe(n_probe).withTopK(top_k_new).build(); + SearchResponse res_search = client.search(searchParam); + assert (!res_search.getResponse().ok()); + } + +// @Test(dataProvider = "Table", dataProviderClass = MainClass.class) +// public void test_search_invalid_query_vectors(MilvusClient client, String tableName) throws InterruptedException { +// IndexType indexType = IndexType.IVF_SQ8; +// int nq = 5; +// List> vectors = gen_vectors(nb, false); +// InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build(); +// client.insert(insertParam); +// Index index = new Index.Builder().withIndexType(indexType) +// .withNList(n_list) +// .build(); +// CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build(); +// client.createIndex(createIndexParam); +// TableParam tableParam = new TableParam.Builder(tableName).build(); +// SearchParam searchParam = new SearchParam.Builder(tableName, null).withNProbe(n_probe).withTopK(top_k).build(); +// SearchResponse res_search = client.search(searchParam); +// assert (!res_search.getResponse().ok()); +// } + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_search_index_range(MilvusClient client, 
String tableName) throws InterruptedException { + IndexType indexType = IndexType.IVF_SQ8; + int nq = 5; + List> vectors = gen_vectors(nb, false); + List> queryVectors = vectors.subList(0,nq); + List dateRange = new ArrayList<>(); + dateRange.add(new DateRange(getDeltaDate(-1), getDeltaDate(1))); + InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build(); + client.insert(insertParam); + Index index = new Index.Builder().withIndexType(indexType) + .withNList(n_list) + .build(); + CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build(); + client.createIndex(createIndexParam); + TableParam tableParam = new TableParam.Builder(tableName).build(); + SearchParam searchParam = new SearchParam.Builder(tableName, queryVectors).withNProbe(n_probe).withTopK(top_k).withDateRanges(dateRange).build(); + SearchResponse res_search = client.search(searchParam); + assert (res_search.getResponse().ok()); + List> res = client.search(searchParam).getQueryResultsList(); + Assert.assertEquals(res.size(), nq); + Assert.assertEquals(res.get(0).size(), top_k); + } + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_search_range(MilvusClient client, String tableName) throws InterruptedException { + int nq = 5; + List> vectors = gen_vectors(nb, false); + List> queryVectors = vectors.subList(0,nq); + List dateRange = new ArrayList<>(); + dateRange.add(new DateRange(getDeltaDate(-1), getDeltaDate(1))); + InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build(); + client.insert(insertParam); + Thread.sleep(1000); + SearchParam searchParam = new SearchParam.Builder(tableName, queryVectors).withNProbe(n_probe).withTopK(top_k).withDateRanges(dateRange).build(); + SearchResponse res_search = client.search(searchParam); + assert (res_search.getResponse().ok()); + List> res = client.search(searchParam).getQueryResultsList(); + Assert.assertEquals(res.size(), nq); + 
Assert.assertEquals(res.get(0).size(), top_k); + } + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_search_index_range_no_result(MilvusClient client, String tableName) throws InterruptedException { + IndexType indexType = IndexType.IVF_SQ8; + int nq = 5; + List> vectors = gen_vectors(nb, false); + List> queryVectors = vectors.subList(0,nq); + List dateRange = new ArrayList<>(); + dateRange.add(new DateRange(getDeltaDate(-3), getDeltaDate(-1))); + InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build(); + client.insert(insertParam); + Index index = new Index.Builder().withIndexType(indexType) + .withNList(n_list) + .build(); + CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build(); + client.createIndex(createIndexParam); + TableParam tableParam = new TableParam.Builder(tableName).build(); + SearchParam searchParam = new SearchParam.Builder(tableName, queryVectors).withNProbe(n_probe).withTopK(top_k).withDateRanges(dateRange).build(); + SearchResponse res_search = client.search(searchParam); + assert (res_search.getResponse().ok()); + List> res = client.search(searchParam).getQueryResultsList(); + Assert.assertEquals(res.size(), 0); + } + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_search_range_no_result(MilvusClient client, String tableName) throws InterruptedException { + int nq = 5; + List> vectors = gen_vectors(nb, false); + List> queryVectors = vectors.subList(0,nq); + List dateRange = new ArrayList<>(); + dateRange.add(new DateRange(getDeltaDate(-3), getDeltaDate(-1))); + InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build(); + client.insert(insertParam); + Thread.sleep(1000); + SearchParam searchParam = new SearchParam.Builder(tableName, queryVectors).withNProbe(n_probe).withTopK(top_k).withDateRanges(dateRange).build(); + SearchResponse res_search = client.search(searchParam); + assert 
(res_search.getResponse().ok()); + List> res = client.search(searchParam).getQueryResultsList(); + Assert.assertEquals(res.size(), 0); + } + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_search_index_range_invalid(MilvusClient client, String tableName) throws InterruptedException { + IndexType indexType = IndexType.IVF_SQ8; + int nq = 5; + List> vectors = gen_vectors(nb, false); + List> queryVectors = vectors.subList(0,nq); + List dateRange = new ArrayList<>(); + dateRange.add(new DateRange(getDeltaDate(2), getDeltaDate(-1))); + InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build(); + client.insert(insertParam); + Index index = new Index.Builder().withIndexType(indexType) + .withNList(n_list) + .build(); + CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build(); + client.createIndex(createIndexParam); + TableParam tableParam = new TableParam.Builder(tableName).build(); + SearchParam searchParam = new SearchParam.Builder(tableName, queryVectors).withNProbe(n_probe).withTopK(top_k).withDateRanges(dateRange).build(); + SearchResponse res_search = client.search(searchParam); + assert (!res_search.getResponse().ok()); + } + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_search_range_invalid(MilvusClient client, String tableName) throws InterruptedException { + int nq = 5; + List> vectors = gen_vectors(nb, false); + List> queryVectors = vectors.subList(0,nq); + List dateRange = new ArrayList<>(); + dateRange.add(new DateRange(getDeltaDate(2), getDeltaDate(-1))); + InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build(); + client.insert(insertParam); + Thread.sleep(1000); + SearchParam searchParam = new SearchParam.Builder(tableName, queryVectors).withNProbe(n_probe).withTopK(top_k).withDateRanges(dateRange).build(); + SearchResponse res_search = client.search(searchParam); + assert 
(!res_search.getResponse().ok()); + } + +} diff --git a/tests/milvus-java-test/src/main/java/com/TestTable.java b/tests/milvus-java-test/src/main/java/com/TestTable.java new file mode 100644 index 0000000000..12b80c654b --- /dev/null +++ b/tests/milvus-java-test/src/main/java/com/TestTable.java @@ -0,0 +1,155 @@ +package com; + + +import io.milvus.client.*; +import org.testng.Assert; +import org.testng.annotations.Test; + +import java.util.List; + +public class TestTable { + int index_file_size = 50; + int dimension = 128; + + @Test(dataProvider = "ConnectInstance", dataProviderClass = MainClass.class) + public void test_create_table(MilvusClient client, String tableName){ + TableSchema tableSchema = new TableSchema.Builder(tableName, dimension) + .withIndexFileSize(index_file_size) + .withMetricType(MetricType.L2) + .build(); + TableSchemaParam tableSchemaParam = new TableSchemaParam.Builder(tableSchema).build(); + Response res = client.createTable(tableSchemaParam); + assert(res.ok()); + Assert.assertEquals(res.ok(), true); + } + + @Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class) + public void test_create_table_disconnect(MilvusClient client, String tableName){ + TableSchema tableSchema = new TableSchema.Builder(tableName, dimension) + .withIndexFileSize(index_file_size) + .withMetricType(MetricType.L2) + .build(); + TableSchemaParam tableSchemaParam = new TableSchemaParam.Builder(tableSchema).build(); + Response res = client.createTable(tableSchemaParam); + assert(!res.ok()); + } + + @Test(dataProvider = "ConnectInstance", dataProviderClass = MainClass.class) + public void test_create_table_repeatably(MilvusClient client, String tableName){ + TableSchema tableSchema = new TableSchema.Builder(tableName, dimension) + .withIndexFileSize(index_file_size) + .withMetricType(MetricType.L2) + .build(); + TableSchemaParam tableSchemaParam = new TableSchemaParam.Builder(tableSchema).build(); + Response res = 
client.createTable(tableSchemaParam); + Assert.assertEquals(res.ok(), true); + Response res_new = client.createTable(tableSchemaParam); + Assert.assertEquals(res_new.ok(), false); + } + + @Test(dataProvider = "ConnectInstance", dataProviderClass = MainClass.class) + public void test_create_table_wrong_params(MilvusClient client, String tableName){ + Integer dimension = 0; + TableSchema tableSchema = new TableSchema.Builder(tableName, dimension) + .withIndexFileSize(index_file_size) + .withMetricType(MetricType.L2) + .build(); + TableSchemaParam tableSchemaParam = new TableSchemaParam.Builder(tableSchema).build(); + Response res = client.createTable(tableSchemaParam); + System.out.println(res.toString()); + Assert.assertEquals(res.ok(), false); + } + + @Test(dataProvider = "ConnectInstance", dataProviderClass = MainClass.class) + public void test_show_tables(MilvusClient client, String tableName){ + Integer tableNum = 10; + ShowTablesResponse res = null; + for (int i = 0; i < tableNum; ++i) { + String tableNameNew = tableName+"_"+Integer.toString(i); + TableSchema tableSchema = new TableSchema.Builder(tableNameNew, dimension) + .withIndexFileSize(index_file_size) + .withMetricType(MetricType.L2) + .build(); + TableSchemaParam tableSchemaParam = new TableSchemaParam.Builder(tableSchema).build(); + client.createTable(tableSchemaParam); + List tableNames = client.showTables().getTableNames(); + Assert.assertTrue(tableNames.contains(tableNameNew)); + } + } + + @Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class) + public void test_show_tables_without_connect(MilvusClient client, String tableName){ + ShowTablesResponse res = client.showTables(); + assert(!res.getResponse().ok()); + } + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_drop_table(MilvusClient client, String tableName) throws InterruptedException { + TableParam tableParam = new TableParam.Builder(tableName).build(); + Response res = 
client.dropTable(tableParam); + assert(res.ok()); + Thread.currentThread().sleep(1000); + List tableNames = client.showTables().getTableNames(); + Assert.assertFalse(tableNames.contains(tableName)); + } + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_drop_table_not_existed(MilvusClient client, String tableName) throws InterruptedException { + TableParam tableParam = new TableParam.Builder(tableName+"_").build(); + Response res = client.dropTable(tableParam); + assert(!res.ok()); + List tableNames = client.showTables().getTableNames(); + Assert.assertTrue(tableNames.contains(tableName)); + } + + @Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class) + public void test_drop_table_without_connect(MilvusClient client, String tableName) throws InterruptedException { + TableParam tableParam = new TableParam.Builder(tableName).build(); + Response res = client.dropTable(tableParam); + assert(!res.ok()); + } + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_describe_table(MilvusClient client, String tableName) throws InterruptedException { + TableParam tableParam = new TableParam.Builder(tableName).build(); + DescribeTableResponse res = client.describeTable(tableParam); + assert(res.getResponse().ok()); + TableSchema tableSchema = res.getTableSchema().get(); + Assert.assertEquals(tableSchema.getDimension(), dimension); + Assert.assertEquals(tableSchema.getTableName(), tableName); + Assert.assertEquals(tableSchema.getIndexFileSize(), index_file_size); + Assert.assertEquals(tableSchema.getMetricType().name(), tableName.substring(0,2)); + } + + @Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class) + public void test_describe_table_without_connect(MilvusClient client, String tableName) throws InterruptedException { + TableParam tableParam = new TableParam.Builder(tableName).build(); + DescribeTableResponse res = client.describeTable(tableParam); + 
assert(!res.getResponse().ok()); + } + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_has_table_not_existed(MilvusClient client, String tableName) throws InterruptedException { + TableParam tableParam = new TableParam.Builder(tableName+"_").build(); + HasTableResponse res = client.hasTable(tableParam); + assert(res.getResponse().ok()); + Assert.assertFalse(res.hasTable()); + } + + @Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class) + public void test_has_table_without_connect(MilvusClient client, String tableName) throws InterruptedException { + TableParam tableParam = new TableParam.Builder(tableName).build(); + HasTableResponse res = client.hasTable(tableParam); + assert(!res.getResponse().ok()); + } + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_has_table(MilvusClient client, String tableName) throws InterruptedException { + TableParam tableParam = new TableParam.Builder(tableName).build(); + HasTableResponse res = client.hasTable(tableParam); + assert(res.getResponse().ok()); + Assert.assertTrue(res.hasTable()); + } + + +} diff --git a/tests/milvus-java-test/src/main/java/com/TestTableCount.java b/tests/milvus-java-test/src/main/java/com/TestTableCount.java new file mode 100644 index 0000000000..afb16c471a --- /dev/null +++ b/tests/milvus-java-test/src/main/java/com/TestTableCount.java @@ -0,0 +1,89 @@ +package com; + +import io.milvus.client.*; +import org.testng.Assert; +import org.testng.annotations.Test; + +import java.util.ArrayList; +import java.util.List; +import java.util.Random; + +public class TestTableCount { + int index_file_size = 50; + int dimension = 128; + + public List> gen_vectors(Integer nb) { + List> xb = new ArrayList<>(); + Random random = new Random(); + for (int i = 0; i < nb; ++i) { + ArrayList vector = new ArrayList<>(); + for (int j = 0; j < dimension; j++) { + vector.add(random.nextFloat()); + } + xb.add(vector); + } + 
return xb; + } + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_table_count_no_vectors(MilvusClient client, String tableName) { + TableParam tableParam = new TableParam.Builder(tableName).build(); + Assert.assertEquals(client.getTableRowCount(tableParam).getTableRowCount(), 0); + } + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_table_count_table_not_existed(MilvusClient client, String tableName) { + TableParam tableParam = new TableParam.Builder(tableName+"_").build(); + GetTableRowCountResponse res = client.getTableRowCount(tableParam); + assert(!res.getResponse().ok()); + } + + @Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class) + public void test_table_count_without_connect(MilvusClient client, String tableName) { + TableParam tableParam = new TableParam.Builder(tableName+"_").build(); + GetTableRowCountResponse res = client.getTableRowCount(tableParam); + assert(!res.getResponse().ok()); + } + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_table_count(MilvusClient client, String tableName) throws InterruptedException { + int nb = 10000; + List> vectors = gen_vectors(nb); + // Add vectors + InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();; + client.insert(insertParam); + Thread.currentThread().sleep(1000); + TableParam tableParam = new TableParam.Builder(tableName).build(); + Assert.assertEquals(client.getTableRowCount(tableParam).getTableRowCount(), nb); + } + + @Test(dataProvider = "Table", dataProviderClass = MainClass.class) + public void test_table_count_multi_tables(MilvusClient client, String tableName) throws InterruptedException { + int nb = 10000; + List> vectors = gen_vectors(nb); + Integer tableNum = 10; + GetTableRowCountResponse res = null; + for (int i = 0; i < tableNum; ++i) { + String tableNameNew = tableName + "_" + Integer.toString(i); + TableSchema tableSchema = 
new TableSchema.Builder(tableNameNew, dimension) + .withIndexFileSize(index_file_size) + .withMetricType(MetricType.L2) + .build(); + TableSchemaParam tableSchemaParam = new TableSchemaParam.Builder(tableSchema).build(); + client.createTable(tableSchemaParam); + // Add vectors + InsertParam insertParam = new InsertParam.Builder(tableNameNew, vectors).build(); + client.insert(insertParam); + } + Thread.currentThread().sleep(1000); + for (int i = 0; i < tableNum; ++i) { + String tableNameNew = tableName + "_" + Integer.toString(i); + TableParam tableParam = new TableParam.Builder(tableNameNew).build(); + res = client.getTableRowCount(tableParam); + Assert.assertEquals(res.getTableRowCount(), nb); + } + } + +} + + diff --git a/tests/milvus-java-test/testng.xml b/tests/milvus-java-test/testng.xml new file mode 100644 index 0000000000..19520f7eab --- /dev/null +++ b/tests/milvus-java-test/testng.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/tests/milvus_ann_acc/.gitignore b/tests/milvus_ann_acc/.gitignore new file mode 100644 index 0000000000..f250cab9fe --- /dev/null +++ b/tests/milvus_ann_acc/.gitignore @@ -0,0 +1,2 @@ +__pycache__/ +logs/ diff --git a/tests/milvus_ann_acc/client.py b/tests/milvus_ann_acc/client.py new file mode 100644 index 0000000000..de4ef17cb6 --- /dev/null +++ b/tests/milvus_ann_acc/client.py @@ -0,0 +1,149 @@ +import pdb +import random +import logging +import json +import time, datetime +from multiprocessing import Process +import numpy +import sklearn.preprocessing +from milvus import Milvus, IndexType, MetricType + +logger = logging.getLogger("milvus_ann_acc.client") + +SERVER_HOST_DEFAULT = "127.0.0.1" +SERVER_PORT_DEFAULT = 19530 + + +def time_wrapper(func): + """ + This decorator prints the execution time for the decorated function. 
+ """ + def wrapper(*args, **kwargs): + start = time.time() + result = func(*args, **kwargs) + end = time.time() + logger.info("Milvus {} run in {}s".format(func.__name__, round(end - start, 2))) + return result + return wrapper + + +class MilvusClient(object): + def __init__(self, table_name=None, ip=None, port=None): + self._milvus = Milvus() + self._table_name = table_name + try: + if not ip: + self._milvus.connect( + host = SERVER_HOST_DEFAULT, + port = SERVER_PORT_DEFAULT) + else: + self._milvus.connect( + host = ip, + port = port) + except Exception as e: + raise e + + def __str__(self): + return 'Milvus table %s' % self._table_name + + def check_status(self, status): + if not status.OK(): + logger.error(status.message) + raise Exception("Status not ok") + + def create_table(self, table_name, dimension, index_file_size, metric_type): + if not self._table_name: + self._table_name = table_name + if metric_type == "l2": + metric_type = MetricType.L2 + elif metric_type == "ip": + metric_type = MetricType.IP + else: + logger.error("Not supported metric_type: %s" % metric_type) + self._metric_type = metric_type + create_param = {'table_name': table_name, + 'dimension': dimension, + 'index_file_size': index_file_size, + "metric_type": metric_type} + status = self._milvus.create_table(create_param) + self.check_status(status) + + @time_wrapper + def insert(self, X, ids): + if self._metric_type == MetricType.IP: + logger.info("Set normalize for metric_type: Inner Product") + X = sklearn.preprocessing.normalize(X, axis=1, norm='l2') + X = X.astype(numpy.float32) + status, result = self._milvus.add_vectors(self._table_name, X.tolist(), ids=ids) + self.check_status(status) + return status, result + + @time_wrapper + def create_index(self, index_type, nlist): + if index_type == "flat": + index_type = IndexType.FLAT + elif index_type == "ivf_flat": + index_type = IndexType.IVFLAT + elif index_type == "ivf_sq8": + index_type = IndexType.IVF_SQ8 + elif index_type == 
"ivf_sq8h": + index_type = IndexType.IVF_SQ8H + elif index_type == "mix_nsg": + index_type = IndexType.MIX_NSG + index_params = { + "index_type": index_type, + "nlist": nlist, + } + logger.info("Building index start, table_name: %s, index_params: %s" % (self._table_name, json.dumps(index_params))) + status = self._milvus.create_index(self._table_name, index=index_params, timeout=6*3600) + self.check_status(status) + + def describe_index(self): + return self._milvus.describe_index(self._table_name) + + def drop_index(self): + logger.info("Drop index: %s" % self._table_name) + return self._milvus.drop_index(self._table_name) + + @time_wrapper + def query(self, X, top_k, nprobe): + if self._metric_type == MetricType.IP: + logger.info("Set normalize for metric_type: Inner Product") + X = sklearn.preprocessing.normalize(X, axis=1, norm='l2') + X = X.astype(numpy.float32) + status, results = self._milvus.search_vectors(self._table_name, top_k, nprobe, X.tolist()) + self.check_status(status) + # logger.info(results[0]) + ids = [] + for result in results: + tmp_ids = [] + for item in result: + tmp_ids.append(item.id) + ids.append(tmp_ids) + return ids + + def count(self): + return self._milvus.get_table_row_count(self._table_name)[1] + + def delete(self, timeout=60): + logger.info("Start delete table: %s" % self._table_name) + self._milvus.delete_table(self._table_name) + i = 0 + while i < timeout: + if self.count(): + time.sleep(1) + i = i + 1 + else: + break + if i >= timeout: + logger.error("Delete table timeout") + + def describe(self): + return self._milvus.describe_table(self._table_name) + + def exists_table(self): + return self._milvus.has_table(self._table_name) + + @time_wrapper + def preload_table(self): + return self._milvus.preload_table(self._table_name) diff --git a/tests/milvus_ann_acc/config.yaml b/tests/milvus_ann_acc/config.yaml new file mode 100644 index 0000000000..e2ac2c1bfb --- /dev/null +++ b/tests/milvus_ann_acc/config.yaml @@ -0,0 +1,17 @@ 
+datasets: + sift-128-euclidean: + cpu_cache_size: 16 + gpu_cache_size: 5 + index_file_size: [1024] + nytimes-16-angular: + cpu_cache_size: 16 + gpu_cache_size: 5 + index_file_size: [1024] + +index: + index_types: ['flat', 'ivf_flat', 'ivf_sq8'] + nlists: [8092, 16384] + +search: + nprobes: [1, 8, 32] + top_ks: [10] diff --git a/tests/milvus_ann_acc/main.py b/tests/milvus_ann_acc/main.py new file mode 100644 index 0000000000..308e8246c7 --- /dev/null +++ b/tests/milvus_ann_acc/main.py @@ -0,0 +1,26 @@ + +import argparse + +def main(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument( + '--dataset', + metavar='NAME', + help='the dataset to load training points from', + default='glove-100-angular', + choices=DATASETS.keys()) + parser.add_argument( + "-k", "--count", + default=10, + type=positive_int, + help="the number of near neighbours to search for") + parser.add_argument( + '--definitions', + metavar='FILE', + help='load algorithm definitions from FILE', + default='algos.yaml') + parser.add_argument( + '--image-tag', + default=None, + help='pull image first') \ No newline at end of file diff --git a/tests/milvus_ann_acc/test.py b/tests/milvus_ann_acc/test.py new file mode 100644 index 0000000000..c4fbc33195 --- /dev/null +++ b/tests/milvus_ann_acc/test.py @@ -0,0 +1,132 @@ +import os +import pdb +import time +import random +import sys +import h5py +import numpy +import logging +from logging import handlers + +from client import MilvusClient + +LOG_FOLDER = "logs" +logger = logging.getLogger("milvus_ann_acc") + +formatter = logging.Formatter('[%(asctime)s] [%(levelname)-4s] [%(pathname)s:%(lineno)d] %(message)s') +if not os.path.exists(LOG_FOLDER): + os.system('mkdir -p %s' % LOG_FOLDER) +fileTimeHandler = handlers.TimedRotatingFileHandler(os.path.join(LOG_FOLDER, 'acc'), "D", 1, 10) +fileTimeHandler.suffix = "%Y%m%d.log" +fileTimeHandler.setFormatter(formatter) 
+logging.basicConfig(level=logging.DEBUG) +fileTimeHandler.setFormatter(formatter) +logger.addHandler(fileTimeHandler) + + +def get_dataset_fn(dataset_name): + file_path = "/test/milvus/ann_hdf5/" + if not os.path.exists(file_path): + raise Exception("%s not exists" % file_path) + return os.path.join(file_path, '%s.hdf5' % dataset_name) + + +def get_dataset(dataset_name): + hdf5_fn = get_dataset_fn(dataset_name) + hdf5_f = h5py.File(hdf5_fn) + return hdf5_f + + +def parse_dataset_name(dataset_name): + data_type = dataset_name.split("-")[0] + dimension = int(dataset_name.split("-")[1]) + metric = dataset_name.split("-")[-1] + # metric = dataset.attrs['distance'] + # dimension = len(dataset["train"][0]) + if metric == "euclidean": + metric_type = "l2" + elif metric == "angular": + metric_type = "ip" + return ("ann"+data_type, dimension, metric_type) + + +def get_table_name(dataset_name, index_file_size): + data_type, dimension, metric_type = parse_dataset_name(dataset_name) + dataset = get_dataset(dataset_name) + table_size = len(dataset["train"]) + table_size = str(table_size // 1000000)+"m" + table_name = data_type+'_'+table_size+'_'+str(index_file_size)+'_'+str(dimension)+'_'+metric_type + return table_name + + +def main(dataset_name, index_file_size, nlist=16384, force=False): + top_k = 10 + nprobes = [32, 128] + + dataset = get_dataset(dataset_name) + table_name = get_table_name(dataset_name, index_file_size) + m = MilvusClient(table_name) + if m.exists_table(): + if force is True: + logger.info("Re-create table: %s" % table_name) + m.delete() + time.sleep(10) + else: + logger.info("Table name: %s existed" % table_name) + return + data_type, dimension, metric_type = parse_dataset_name(dataset_name) + m.create_table(table_name, dimension, index_file_size, metric_type) + print(m.describe()) + vectors = numpy.array(dataset["train"]) + query_vectors = numpy.array(dataset["test"]) + # m.insert(vectors) + + interval = 100000 + loops = len(vectors) // interval + 1 + + 
for i in range(loops): + start = i*interval + end = min((i+1)*interval, len(vectors)) + tmp_vectors = vectors[start:end] + if start < end: + m.insert(tmp_vectors, ids=[i for i in range(start, end)]) + + time.sleep(60) + print(m.count()) + + for index_type in ["ivf_flat", "ivf_sq8", "ivf_sq8h"]: + m.create_index(index_type, nlist) + print(m.describe_index()) + if m.count() != len(vectors): + return + m.preload_table() + true_ids = numpy.array(dataset["neighbors"]) + for nprobe in nprobes: + print("nprobe: %s" % nprobe) + sum_radio = 0.0; avg_radio = 0.0 + result_ids = m.query(query_vectors, top_k, nprobe) + # print(result_ids[:10]) + for index, result_item in enumerate(result_ids): + if len(set(true_ids[index][:top_k])) != len(set(result_item)): + logger.info("Error happened") + # logger.info(query_vectors[index]) + # logger.info(true_ids[index][:top_k], result_item) + tmp = set(true_ids[index][:top_k]).intersection(set(result_item)) + sum_radio = sum_radio + (len(tmp) / top_k) + avg_radio = round(sum_radio / len(result_ids), 4) + logger.info(avg_radio) + m.drop_index() + + +if __name__ == "__main__": + print("glove-25-angular") + # main("sift-128-euclidean", 1024, force=True) + for index_file_size in [50, 1024]: + print("Index file size: %d" % index_file_size) + main("glove-25-angular", index_file_size, force=True) + + print("sift-128-euclidean") + for index_file_size in [50, 1024]: + print("Index file size: %d" % index_file_size) + main("sift-128-euclidean", index_file_size, force=True) + # m = MilvusClient() \ No newline at end of file diff --git a/tests/milvus_benchmark/.gitignore b/tests/milvus_benchmark/.gitignore new file mode 100644 index 0000000000..70af07bba8 --- /dev/null +++ b/tests/milvus_benchmark/.gitignore @@ -0,0 +1,8 @@ +random_data +benchmark_logs/ +db/ +logs/ +*idmap*.txt +__pycache__/ +venv +.idea \ No newline at end of file diff --git a/tests/milvus_benchmark/README.md b/tests/milvus_benchmark/README.md new file mode 100644 index 
0000000000..72abd1264a --- /dev/null +++ b/tests/milvus_benchmark/README.md @@ -0,0 +1,57 @@ +# Quick start + +## 运行 + +### 运行示例: + +`python3 main.py --image=registry.zilliz.com/milvus/engine:branch-0.3.1-release --run-count=2 --run-type=performance` + +### 运行参数: + +--image: 容器模式,传入镜像名称,如传入,则运行测试时,会先进行pull image,基于image生成milvus server容器 + +--local: 与image参数互斥,本地模式,连接使用本地启动的milvus server进行测试 + +--run-count: 重复运行次数 + +--suites: 测试集配置文件,默认使用suites.yaml + +--run-type: 测试类型,包括性能--performance、准确性测试--accuracy以及稳定性--stability + +### 测试集配置文件: + +`operations: + + insert: + +​ [ +​ {"table.index_type": "ivf_flat", "server.index_building_threshold": 300, "table.size": 2000000, "table.ni": 100000, "table.dim": 512}, +​ ] + + build: [] + + query: + +​ [ +​ {"dataset": "ip_ivfsq8_1000", "top_ks": [10], "nqs": [10, 100], "server.nprobe": 1, "server.use_blas_threshold": 800}, +​ {"dataset": "ip_ivfsq8_1000", "top_ks": [10], "nqs": [10, 100], "server.nprobe": 10, "server.use_blas_threshold": 20}, +​ ]` + +## 测试结果: + +性能: + +`INFO:milvus_benchmark.runner:Start warm query, query params: top-k: 1, nq: 1 + +INFO:milvus_benchmark.client:query run in 19.19s +INFO:milvus_benchmark.runner:Start query, query params: top-k: 64, nq: 10, actually length of vectors: 10 +INFO:milvus_benchmark.runner:Start run query, run 1 of 1 +INFO:milvus_benchmark.client:query run in 0.2s +INFO:milvus_benchmark.runner:Avarage query time: 0.20 +INFO:milvus_benchmark.runner:[[0.2]]` + +**│ 10 │ 0.2 │** + +准确率: + +`INFO:milvus_benchmark.runner:Avarage accuracy: 1.0` \ No newline at end of file diff --git a/tests/milvus_benchmark/__init__.py b/tests/milvus_benchmark/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/milvus_benchmark/client.py b/tests/milvus_benchmark/client.py new file mode 100644 index 0000000000..4744c10854 --- /dev/null +++ b/tests/milvus_benchmark/client.py @@ -0,0 +1,244 @@ +import pdb +import random +import logging +import json +import sys +import time, datetime 
+from multiprocessing import Process +from milvus import Milvus, IndexType, MetricType + +logger = logging.getLogger("milvus_benchmark.client") + +SERVER_HOST_DEFAULT = "127.0.0.1" +SERVER_PORT_DEFAULT = 19530 + + +def time_wrapper(func): + """ + This decorator prints the execution time for the decorated function. + """ + def wrapper(*args, **kwargs): + start = time.time() + result = func(*args, **kwargs) + end = time.time() + logger.info("Milvus {} run in {}s".format(func.__name__, round(end - start, 2))) + return result + return wrapper + + +class MilvusClient(object): + def __init__(self, table_name=None, ip=None, port=None): + self._milvus = Milvus() + self._table_name = table_name + try: + if not ip: + self._milvus.connect( + host = SERVER_HOST_DEFAULT, + port = SERVER_PORT_DEFAULT) + else: + self._milvus.connect( + host = ip, + port = port) + except Exception as e: + raise e + + def __str__(self): + return 'Milvus table %s' % self._table_name + + def check_status(self, status): + if not status.OK(): + logger.error(status.message) + raise Exception("Status not ok") + + def create_table(self, table_name, dimension, index_file_size, metric_type): + if not self._table_name: + self._table_name = table_name + if metric_type == "l2": + metric_type = MetricType.L2 + elif metric_type == "ip": + metric_type = MetricType.IP + else: + logger.error("Not supported metric_type: %s" % metric_type) + create_param = {'table_name': table_name, + 'dimension': dimension, + 'index_file_size': index_file_size, + "metric_type": metric_type} + status = self._milvus.create_table(create_param) + self.check_status(status) + + @time_wrapper + def insert(self, X, ids=None): + status, result = self._milvus.add_vectors(self._table_name, X, ids) + self.check_status(status) + return status, result + + @time_wrapper + def create_index(self, index_type, nlist): + if index_type == "flat": + index_type = IndexType.FLAT + elif index_type == "ivf_flat": + index_type = IndexType.IVFLAT + elif 
index_type == "ivf_sq8": + index_type = IndexType.IVF_SQ8 + elif index_type == "mix_nsg": + index_type = IndexType.MIX_NSG + elif index_type == "ivf_sq8h": + index_type = IndexType.IVF_SQ8H + index_params = { + "index_type": index_type, + "nlist": nlist, + } + logger.info("Building index start, table_name: %s, index_params: %s" % (self._table_name, json.dumps(index_params))) + status = self._milvus.create_index(self._table_name, index=index_params, timeout=6*3600) + self.check_status(status) + + def describe_index(self): + return self._milvus.describe_index(self._table_name) + + def drop_index(self): + logger.info("Drop index: %s" % self._table_name) + return self._milvus.drop_index(self._table_name) + + @time_wrapper + def query(self, X, top_k, nprobe): + status, result = self._milvus.search_vectors(self._table_name, top_k, nprobe, X) + self.check_status(status) + return status, result + + def count(self): + return self._milvus.get_table_row_count(self._table_name)[1] + + def delete(self, timeout=60): + logger.info("Start delete table: %s" % self._table_name) + self._milvus.delete_table(self._table_name) + i = 0 + while i < timeout: + if self.count(): + time.sleep(1) + i = i + 1 + continue + else: + break + if i < timeout: + logger.error("Delete table timeout") + + def describe(self): + return self._milvus.describe_table(self._table_name) + + def exists_table(self): + return self._milvus.has_table(self._table_name) + + @time_wrapper + def preload_table(self): + return self._milvus.preload_table(self._table_name, timeout=3000) + + +def fit(table_name, X): + milvus = Milvus() + milvus.connect(host = SERVER_HOST_DEFAULT, port = SERVER_PORT_DEFAULT) + start = time.time() + status, ids = milvus.add_vectors(table_name, X) + end = time.time() + logger(status, round(end - start, 2)) + + +def fit_concurrent(table_name, process_num, vectors): + processes = [] + + for i in range(process_num): + p = Process(target=fit, args=(table_name, vectors, )) + processes.append(p) + 
p.start() + for p in processes: + p.join() + + +if __name__ == "__main__": + + # table_name = "sift_2m_20_128_l2" + table_name = "test_tset1" + m = MilvusClient(table_name) + # m.create_table(table_name, 128, 50, "l2") + + print(m.describe()) + # print(m.count()) + # print(m.describe_index()) + insert_vectors = [[random.random() for _ in range(128)] for _ in range(10000)] + for i in range(5): + m.insert(insert_vectors) + print(m.create_index("ivf_sq8h", 16384)) + X = [insert_vectors[0]] + top_k = 10 + nprobe = 10 + print(m.query(X, top_k, nprobe)) + + # # # print(m.drop_index()) + # # print(m.describe_index()) + # # sys.exit() + # # # insert_vectors = [[random.random() for _ in range(128)] for _ in range(100000)] + # # # for i in range(100): + # # # m.insert(insert_vectors) + # # # time.sleep(5) + # # # print(m.describe_index()) + # # # print(m.drop_index()) + # # m.create_index("ivf_sq8h", 16384) + # print(m.count()) + # print(m.describe_index()) + + + + # sys.exit() + # print(m.create_index("ivf_sq8h", 16384)) + # print(m.count()) + # print(m.describe_index()) + import numpy as np + + def mmap_fvecs(fname): + x = np.memmap(fname, dtype='int32', mode='r') + d = x[0] + return x.view('float32').reshape(-1, d + 1)[:, 1:] + + print(mmap_fvecs("/poc/deep1b/deep1B_queries.fvecs")) + # SIFT_SRC_QUERY_DATA_DIR = '/poc/yuncong/ann_1000m' + # file_name = SIFT_SRC_QUERY_DATA_DIR+'/'+'query.npy' + # data = numpy.load(file_name) + # query_vectors = data[0:2].tolist() + # print(len(query_vectors)) + # results = m.query(query_vectors, 10, 10) + # result_ids = [] + # for result in results[1]: + # tmp = [] + # for item in result: + # tmp.append(item.id) + # result_ids.append(tmp) + # print(result_ids[0][:10]) + # # gt + # file_name = SIFT_SRC_QUERY_DATA_DIR+"/gnd/"+"idx_1M.ivecs" + # a = numpy.fromfile(file_name, dtype='int32') + # d = a[0] + # true_ids = a.reshape(-1, d + 1)[:, 1:].copy() + # print(true_ids[:3, :2]) + + # print(len(true_ids[0])) + # import numpy as np + # import 
sklearn.preprocessing + + # def mmap_fvecs(fname): + # x = np.memmap(fname, dtype='int32', mode='r') + # d = x[0] + # return x.view('float32').reshape(-1, d + 1)[:, 1:] + + # data = mmap_fvecs("/poc/deep1b/deep1B_queries.fvecs") + # print(data[0], len(data[0]), len(data)) + + # total_size = 10000 + # # total_size = 1000000000 + # file_size = 1000 + # # file_size = 100000 + # file_num = total_size // file_size + # for i in range(file_num): + # fname = "/test/milvus/raw_data/deep1b/binary_96_%05d" % i + # print(fname, i*file_size, (i+1)*file_size) + # single_data = data[i*file_size : (i+1)*file_size] + # single_data = sklearn.preprocessing.normalize(single_data, axis=1, norm='l2') + # np.save(fname, single_data) diff --git a/tests/milvus_benchmark/conf/log_config.conf b/tests/milvus_benchmark/conf/log_config.conf new file mode 100644 index 0000000000..c9c14d57f4 --- /dev/null +++ b/tests/milvus_benchmark/conf/log_config.conf @@ -0,0 +1,28 @@ +* GLOBAL: + FORMAT = "%datetime | %level | %logger | %msg" + FILENAME = "/opt/milvus/logs/milvus-%datetime{%H:%m}-global.log" + ENABLED = true + TO_FILE = true + TO_STANDARD_OUTPUT = false + SUBSECOND_PRECISION = 3 + PERFORMANCE_TRACKING = false + MAX_LOG_FILE_SIZE = 2097152 ## Throw log files away after 2MB +* DEBUG: + FILENAME = "/opt/milvus/logs/milvus-%datetime{%H:%m}-debug.log" + ENABLED = true +* WARNING: + FILENAME = "/opt/milvus/logs/milvus-%datetime{%H:%m}-warning.log" +* TRACE: + FILENAME = "/opt/milvus/logs/milvus-%datetime{%H:%m}-trace.log" +* VERBOSE: + FORMAT = "%datetime{%d/%M/%y} | %level-%vlevel | %msg" + TO_FILE = false + TO_STANDARD_OUTPUT = false +## Error logs +* ERROR: + ENABLED = true + FILENAME = "/opt/milvus/logs/milvus-%datetime{%H:%m}-error.log" +* FATAL: + ENABLED = true + FILENAME = "/opt/milvus/logs/milvus-%datetime{%H:%m}-fatal.log" + diff --git a/tests/milvus_benchmark/conf/server_config.yaml b/tests/milvus_benchmark/conf/server_config.yaml new file mode 100644 index 0000000000..a5d2081b8a --- 
/dev/null +++ b/tests/milvus_benchmark/conf/server_config.yaml @@ -0,0 +1,28 @@ +cache_config: + cache_insert_data: false + cpu_cache_capacity: 16 + gpu_cache_capacity: 6 + cpu_cache_threshold: 0.85 +db_config: + backend_url: sqlite://:@:/ + build_index_gpu: 0 + insert_buffer_size: 4 + preload_table: null + primary_path: /opt/milvus + secondary_path: null +engine_config: + use_blas_threshold: 20 +metric_config: + collector: prometheus + enable_monitor: true + prometheus_config: + port: 8080 +resource_config: + resource_pool: + - cpu + - gpu0 +server_config: + address: 0.0.0.0 + deploy_mode: single + port: 19530 + time_zone: UTC+8 diff --git a/tests/milvus_benchmark/conf/server_config.yaml.cpu b/tests/milvus_benchmark/conf/server_config.yaml.cpu new file mode 100644 index 0000000000..95ab5f5343 --- /dev/null +++ b/tests/milvus_benchmark/conf/server_config.yaml.cpu @@ -0,0 +1,31 @@ +server_config: + address: 0.0.0.0 + port: 19530 + deploy_mode: single + time_zone: UTC+8 + +db_config: + primary_path: /opt/milvus + secondary_path: + backend_url: sqlite://:@:/ + insert_buffer_size: 4 + build_index_gpu: 0 + preload_table: + +metric_config: + enable_monitor: false + collector: prometheus + prometheus_config: + port: 8080 + +cache_config: + cpu_cache_capacity: 16 + cpu_cache_threshold: 0.85 + cache_insert_data: false + +engine_config: + use_blas_threshold: 20 + +resource_config: + resource_pool: + - cpu \ No newline at end of file diff --git a/tests/milvus_benchmark/conf/server_config.yaml.multi b/tests/milvus_benchmark/conf/server_config.yaml.multi new file mode 100644 index 0000000000..002d3bd2a6 --- /dev/null +++ b/tests/milvus_benchmark/conf/server_config.yaml.multi @@ -0,0 +1,33 @@ +server_config: + address: 0.0.0.0 + port: 19530 + deploy_mode: single + time_zone: UTC+8 + +db_config: + primary_path: /opt/milvus + secondary_path: + backend_url: sqlite://:@:/ + insert_buffer_size: 4 + build_index_gpu: 0 + preload_table: + +metric_config: + enable_monitor: false + 
collector: prometheus + prometheus_config: + port: 8080 + +cache_config: + cpu_cache_capacity: 16 + cpu_cache_threshold: 0.85 + cache_insert_data: false + +engine_config: + use_blas_threshold: 20 + +resource_config: + resource_pool: + - cpu + - gpu0 + - gpu1 \ No newline at end of file diff --git a/tests/milvus_benchmark/conf/server_config.yaml.single b/tests/milvus_benchmark/conf/server_config.yaml.single new file mode 100644 index 0000000000..033d8868d1 --- /dev/null +++ b/tests/milvus_benchmark/conf/server_config.yaml.single @@ -0,0 +1,32 @@ +server_config: + address: 0.0.0.0 + port: 19530 + deploy_mode: single + time_zone: UTC+8 + +db_config: + primary_path: /opt/milvus + secondary_path: + backend_url: sqlite://:@:/ + insert_buffer_size: 4 + build_index_gpu: 0 + preload_table: + +metric_config: + enable_monitor: false + collector: prometheus + prometheus_config: + port: 8080 + +cache_config: + cpu_cache_capacity: 16 + cpu_cache_threshold: 0.85 + cache_insert_data: false + +engine_config: + use_blas_threshold: 20 + +resource_config: + resource_pool: + - cpu + - gpu0 \ No newline at end of file diff --git a/tests/milvus_benchmark/demo.py b/tests/milvus_benchmark/demo.py new file mode 100644 index 0000000000..27152e0980 --- /dev/null +++ b/tests/milvus_benchmark/demo.py @@ -0,0 +1,51 @@ +import os +import logging +import pdb +import time +import random +from multiprocessing import Process +import numpy as np +from client import MilvusClient + +nq = 100000 +dimension = 128 +run_count = 1 +table_name = "sift_10m_1024_128_ip" +insert_vectors = [[random.random() for _ in range(dimension)] for _ in range(nq)] + +def do_query(milvus, table_name, top_ks, nqs, nprobe, run_count): + bi_res = [] + for index, nq in enumerate(nqs): + tmp_res = [] + for top_k in top_ks: + avg_query_time = 0.0 + total_query_time = 0.0 + vectors = insert_vectors[0:nq] + for i in range(run_count): + start_time = time.time() + status, query_res = milvus.query(vectors, top_k, nprobe) + 
total_query_time = total_query_time + (time.time() - start_time) + if status.code: + print(status.message) + avg_query_time = round(total_query_time / run_count, 2) + tmp_res.append(avg_query_time) + bi_res.append(tmp_res) + return bi_res + +while 1: + milvus_instance = MilvusClient(table_name, ip="192.168.1.197", port=19530) + top_ks = random.sample([x for x in range(1, 100)], 4) + nqs = random.sample([x for x in range(1, 1000)], 3) + nprobe = random.choice([x for x in range(1, 500)]) + res = do_query(milvus_instance, table_name, top_ks, nqs, nprobe, run_count) + status, res = milvus_instance.insert(insert_vectors, ids=[x for x in range(len(insert_vectors))]) + if not status.OK(): + logger.error(status.message) + + # status = milvus_instance.drop_index() + if not status.OK(): + print(status.message) + index_type = "ivf_sq8" + status = milvus_instance.create_index(index_type, 16384) + if not status.OK(): + print(status.message) \ No newline at end of file diff --git a/tests/milvus_benchmark/docker_runner.py b/tests/milvus_benchmark/docker_runner.py new file mode 100644 index 0000000000..008c9866b4 --- /dev/null +++ b/tests/milvus_benchmark/docker_runner.py @@ -0,0 +1,261 @@ +import os +import logging +import pdb +import time +import random +from multiprocessing import Process +import numpy as np +from client import MilvusClient +import utils +import parser +from runner import Runner + +logger = logging.getLogger("milvus_benchmark.docker") + + +class DockerRunner(Runner): + """run docker mode""" + def __init__(self, image): + super(DockerRunner, self).__init__() + self.image = image + + def run(self, definition, run_type=None): + if run_type == "performance": + for op_type, op_value in definition.items(): + # run docker mode + run_count = op_value["run_count"] + run_params = op_value["params"] + container = None + + if op_type == "insert": + for index, param in enumerate(run_params): + logger.info("Definition param: %s" % str(param)) + table_name = 
param["table_name"] + volume_name = param["db_path_prefix"] + print(table_name) + (data_type, table_size, index_file_size, dimension, metric_type) = parser.table_parser(table_name) + for k, v in param.items(): + if k.startswith("server."): + # Update server config + utils.modify_config(k, v, type="server", db_slave=None) + container = utils.run_server(self.image, test_type="remote", volume_name=volume_name, db_slave=None) + time.sleep(2) + milvus = MilvusClient(table_name) + # Check has table or not + if milvus.exists_table(): + milvus.delete() + time.sleep(10) + milvus.create_table(table_name, dimension, index_file_size, metric_type) + res = self.do_insert(milvus, table_name, data_type, dimension, table_size, param["ni_per"]) + logger.info(res) + + # wait for file merge + time.sleep(6 * (table_size / 500000)) + # Clear up + utils.remove_container(container) + + elif op_type == "query": + for index, param in enumerate(run_params): + logger.info("Definition param: %s" % str(param)) + table_name = param["dataset"] + volume_name = param["db_path_prefix"] + (data_type, table_size, index_file_size, dimension, metric_type) = parser.table_parser(table_name) + for k, v in param.items(): + if k.startswith("server."): + utils.modify_config(k, v, type="server") + container = utils.run_server(self.image, test_type="remote", volume_name=volume_name, db_slave=None) + time.sleep(2) + milvus = MilvusClient(table_name) + logger.debug(milvus._milvus.show_tables()) + # Check has table or not + if not milvus.exists_table(): + logger.warning("Table %s not existed, continue exec next params ..." 
% table_name) + continue + # parse index info + index_types = param["index.index_types"] + nlists = param["index.nlists"] + # parse top-k, nq, nprobe + top_ks, nqs, nprobes = parser.search_params_parser(param) + for index_type in index_types: + for nlist in nlists: + result = milvus.describe_index() + logger.info(result) + milvus.create_index(index_type, nlist) + result = milvus.describe_index() + logger.info(result) + # preload index + milvus.preload_table() + logger.info("Start warm up query") + res = self.do_query(milvus, table_name, [1], [1], 1, 1) + logger.info("End warm up query") + # Run query test + for nprobe in nprobes: + logger.info("index_type: %s, nlist: %s, metric_type: %s, nprobe: %s" % (index_type, nlist, metric_type, nprobe)) + res = self.do_query(milvus, table_name, top_ks, nqs, nprobe, run_count) + headers = ["Nprobe/Top-k"] + headers.extend([str(top_k) for top_k in top_ks]) + utils.print_table(headers, nqs, res) + utils.remove_container(container) + + elif run_type == "accuracy": + """ + { + "dataset": "random_50m_1024_512", + "index.index_types": ["flat", ivf_flat", "ivf_sq8"], + "index.nlists": [16384], + "nprobes": [1, 32, 128], + "nqs": [100], + "top_ks": [1, 64], + "server.use_blas_threshold": 1100, + "server.cpu_cache_capacity": 256 + } + """ + for op_type, op_value in definition.items(): + if op_type != "query": + logger.warning("invalid operation: %s in accuracy test, only support query operation" % op_type) + break + run_count = op_value["run_count"] + run_params = op_value["params"] + container = None + + for index, param in enumerate(run_params): + logger.info("Definition param: %s" % str(param)) + table_name = param["dataset"] + sift_acc = False + if "sift_acc" in param: + sift_acc = param["sift_acc"] + (data_type, table_size, index_file_size, dimension, metric_type) = parser.table_parser(table_name) + for k, v in param.items(): + if k.startswith("server."): + utils.modify_config(k, v, type="server") + volume_name = 
param["db_path_prefix"] + container = utils.run_server(self.image, test_type="remote", volume_name=volume_name, db_slave=None) + time.sleep(2) + milvus = MilvusClient(table_name) + # Check has table or not + if not milvus.exists_table(): + logger.warning("Table %s not existed, continue exec next params ..." % table_name) + continue + + # parse index info + index_types = param["index.index_types"] + nlists = param["index.nlists"] + # parse top-k, nq, nprobe + top_ks, nqs, nprobes = parser.search_params_parser(param) + + if sift_acc is True: + # preload groundtruth data + true_ids_all = self.get_groundtruth_ids(table_size) + + acc_dict = {} + for index_type in index_types: + for nlist in nlists: + result = milvus.describe_index() + logger.info(result) + milvus.create_index(index_type, nlist) + # preload index + milvus.preload_table() + # Run query test + for nprobe in nprobes: + logger.info("index_type: %s, nlist: %s, metric_type: %s, nprobe: %s" % (index_type, nlist, metric_type, nprobe)) + for top_k in top_ks: + for nq in nqs: + result_ids = [] + id_prefix = "%s_index_%s_nlist_%s_metric_type_%s_nprobe_%s_top_k_%s_nq_%s" % \ + (table_name, index_type, nlist, metric_type, nprobe, top_k, nq) + if sift_acc is False: + self.do_query_acc(milvus, table_name, top_k, nq, nprobe, id_prefix) + if index_type != "flat": + # Compute accuracy + base_name = "%s_index_flat_nlist_%s_metric_type_%s_nprobe_%s_top_k_%s_nq_%s" % \ + (table_name, nlist, metric_type, nprobe, top_k, nq) + avg_acc = self.compute_accuracy(base_name, id_prefix) + logger.info("Query: <%s> accuracy: %s" % (id_prefix, avg_acc)) + else: + result_ids = self.do_query_ids(milvus, table_name, top_k, nq, nprobe) + acc_value = self.get_recall_value(true_ids_all[:nq, :top_k].tolist(), result_ids) + logger.info("Query: <%s> accuracy: %s" % (id_prefix, acc_value)) + # # print accuracy table + # headers = [table_name] + # headers.extend([str(top_k) for top_k in top_ks]) + # utils.print_table(headers, nqs, res) + + # remove 
container, and run next definition + logger.info("remove container, and run next definition") + utils.remove_container(container) + + elif run_type == "stability": + for op_type, op_value in definition.items(): + if op_type != "query": + logger.warning("invalid operation: %s in accuracy test, only support query operation" % op_type) + break + run_count = op_value["run_count"] + run_params = op_value["params"] + container = None + for index, param in enumerate(run_params): + logger.info("Definition param: %s" % str(param)) + table_name = param["dataset"] + volume_name = param["db_path_prefix"] + (data_type, table_size, index_file_size, dimension, metric_type) = parser.table_parser(table_name) + + # set default test time + if "during_time" not in param: + during_time = 100 # seconds + else: + during_time = int(param["during_time"]) * 60 + # set default query process num + if "query_process_num" not in param: + query_process_num = 10 + else: + query_process_num = int(param["query_process_num"]) + + for k, v in param.items(): + if k.startswith("server."): + utils.modify_config(k, v, type="server") + + container = utils.run_server(self.image, test_type="remote", volume_name=volume_name, db_slave=None) + time.sleep(2) + milvus = MilvusClient(table_name) + # Check has table or not + if not milvus.exists_table(): + logger.warning("Table %s not existed, continue exec next params ..." 
% table_name) + continue + + start_time = time.time() + insert_vectors = [[random.random() for _ in range(dimension)] for _ in range(10000)] + while time.time() < start_time + during_time: + processes = [] + # do query + # for i in range(query_process_num): + # milvus_instance = MilvusClient(table_name) + # top_k = random.choice([x for x in range(1, 100)]) + # nq = random.choice([x for x in range(1, 100)]) + # nprobe = random.choice([x for x in range(1, 1000)]) + # # logger.info("index_type: %s, nlist: %s, metric_type: %s, nprobe: %s" % (index_type, nlist, metric_type, nprobe)) + # p = Process(target=self.do_query, args=(milvus_instance, table_name, [top_k], [nq], [nprobe], run_count, )) + # processes.append(p) + # p.start() + # time.sleep(0.1) + # for p in processes: + # p.join() + milvus_instance = MilvusClient(table_name) + top_ks = random.sample([x for x in range(1, 100)], 3) + nqs = random.sample([x for x in range(1, 1000)], 3) + nprobe = random.choice([x for x in range(1, 500)]) + res = self.do_query(milvus, table_name, top_ks, nqs, nprobe, run_count) + if int(time.time() - start_time) % 120 == 0: + status, res = milvus_instance.insert(insert_vectors, ids=[x for x in range(len(insert_vectors))]) + if not status.OK(): + logger.error(status) + # status = milvus_instance.drop_index() + # if not status.OK(): + # logger.error(status) + # index_type = random.choice(["flat", "ivf_flat", "ivf_sq8"]) + result = milvus.describe_index() + logger.info(result) + milvus_instance.create_index("ivf_sq8", 16384) + utils.remove_container(container) + + else: + logger.warning("Run type: %s not supported" % run_type) + diff --git a/tests/milvus_benchmark/local_runner.py b/tests/milvus_benchmark/local_runner.py new file mode 100644 index 0000000000..1067f500bc --- /dev/null +++ b/tests/milvus_benchmark/local_runner.py @@ -0,0 +1,132 @@ +import os +import logging +import pdb +import time +import random +from multiprocessing import Process +import numpy as np +from client import 
logger = logging.getLogger("milvus_benchmark.local_runner")


class LocalRunner(Runner):
    """Benchmark runner that drives an already-running (local) Milvus server."""

    def __init__(self, ip, port):
        super(LocalRunner, self).__init__()
        self.ip = ip      # Milvus server host
        self.port = port  # Milvus server port

    def run(self, definition, run_type=None):
        """Execute one suite definition against the local server.

        definition: mapping of operation type ("insert" / "query") to
            {"run_count": int, "params": [param dict, ...]} (built in main.py).
        run_type: "performance" or "stability"; anything else is logged
            and ignored.
        """
        if run_type == "performance":
            for op_type, op_value in definition.items():
                run_count = op_value["run_count"]
                run_params = op_value["params"]

                if op_type == "insert":
                    for index, param in enumerate(run_params):
                        # table name encodes data_type/size/index_file_size/dim/metric,
                        # e.g. random_1m_100_512_ip
                        table_name = param["table_name"]
                        (data_type, table_size, index_file_size, dimension, metric_type) = parser.table_parser(table_name)
                        milvus = MilvusClient(table_name, ip=self.ip, port=self.port)
                        # drop any pre-existing table so insert timings start from scratch
                        if milvus.exists_table():
                            milvus.delete()
                            time.sleep(10)
                        milvus.create_table(table_name, dimension, index_file_size, metric_type)
                        res = self.do_insert(milvus, table_name, data_type, dimension, table_size, param["ni_per"])
                        logger.info(res)

                elif op_type == "query":
                    for index, param in enumerate(run_params):
                        logger.info("Definition param: %s" % str(param))
                        table_name = param["dataset"]
                        (data_type, table_size, index_file_size, dimension, metric_type) = parser.table_parser(table_name)

                        milvus = MilvusClient(table_name, ip=self.ip, port=self.port)
                        # index-build parameters to sweep
                        index_types = param["index.index_types"]
                        nlists = param["index.nlists"]
                        # search parameters to sweep
                        top_ks, nqs, nprobes = parser.search_params_parser(param)

                        for index_type in index_types:
                            for nlist in nlists:
                                milvus.create_index(index_type, nlist)
                                # load the index into memory before timing queries
                                milvus.preload_table()
                                for nprobe in nprobes:
                                    logger.info("index_type: %s, nlist: %s, metric_type: %s, nprobe: %s" % (index_type, nlist, metric_type, nprobe))
                                    res = self.do_query(milvus, table_name, top_ks, nqs, nprobe, run_count)
                                    headers = [param["dataset"]]
                                    headers.extend([str(top_k) for top_k in top_ks])
                                    utils.print_table(headers, nqs, res)

        elif run_type == "stability":
            for op_type, op_value in definition.items():
                if op_type != "query":
                    # BUGFIX: message previously said "accuracy test"
                    logger.warning("invalid operation: %s in stability test, only support query operation" % op_type)
                    break
                run_count = op_value["run_count"]
                run_params = op_value["params"]
                nq = 10000  # vectors per periodic insert below

                for index, param in enumerate(run_params):
                    logger.info("Definition param: %s" % str(param))
                    table_name = param["dataset"]
                    (data_type, table_size, index_file_size, dimension, metric_type) = parser.table_parser(table_name)

                    # total test duration: default 100 seconds, otherwise minutes
                    # taken from the suite definition
                    if "during_time" not in param:
                        during_time = 100  # seconds
                    else:
                        during_time = int(param["during_time"]) * 60
                    # reserved: concurrent query process count (currently unused;
                    # the multiprocessing fan-out was disabled upstream)
                    if "query_process_num" not in param:
                        query_process_num = 10
                    else:
                        query_process_num = int(param["query_process_num"])
                    # BUGFIX: previously connected with default host/port instead
                    # of the configured self.ip/self.port
                    milvus = MilvusClient(table_name, ip=self.ip, port=self.port)
                    if not milvus.exists_table():
                        logger.warning("Table %s not existed, continue exec next params ..." % table_name)
                        continue

                    start_time = time.time()
                    insert_vectors = [[random.random() for _ in range(dimension)] for _ in range(nq)]
                    # mixed workload: random searches, plus inserts every pass and
                    # a periodic drop/rebuild of the index
                    while time.time() < start_time + during_time:
                        milvus_instance = MilvusClient(table_name, ip=self.ip, port=self.port)
                        top_ks = random.sample([x for x in range(1, 100)], 4)
                        nqs = random.sample([x for x in range(1, 1000)], 3)
                        nprobe = random.choice([x for x in range(1, 500)])
                        res = self.do_query(milvus, table_name, top_ks, nqs, nprobe, run_count)
                        status, res = milvus_instance.insert(insert_vectors, ids=[x for x in range(len(insert_vectors))])
                        if not status.OK():
                            logger.error(status.message)
                        # BUGFIX: elapsed time is a float, so `% 300 == 0` was almost
                        # never true; truncate to int so the rebuild actually triggers
                        # (same pattern as docker_runner).
                        if int(time.time() - start_time) % 300 == 0:
                            status = milvus_instance.drop_index()
                            if not status.OK():
                                logger.error(status.message)
                            index_type = random.choice(["flat", "ivf_flat", "ivf_sq8"])
                            status = milvus_instance.create_index(index_type, 16384)
                            if not status.OK():
                                logger.error(status.message)

        else:
            logger.warning("Run type: %s not supported" % run_type)
LOG_FOLDER = "benchmark_logs"
logger = logging.getLogger("milvus_benchmark")

# One time-rotated log file per day, keeping the last 10 days.
formatter = logging.Formatter('[%(asctime)s] [%(levelname)-4s] [%(pathname)s:%(lineno)d] %(message)s')
if not os.path.exists(LOG_FOLDER):
    os.makedirs(LOG_FOLDER)  # stdlib instead of shelling out to `mkdir -p`
fileTimeHandler = handlers.TimedRotatingFileHandler(os.path.join(LOG_FOLDER, 'milvus_benchmark'), "D", 1, 10)
fileTimeHandler.suffix = "%Y%m%d.log"
fileTimeHandler.setFormatter(formatter)
logging.basicConfig(level=logging.DEBUG)
logger.addHandler(fileTimeHandler)


def positive_int(s):
    """argparse `type=` callable: parse *s* as a strictly positive integer.

    Raises argparse.ArgumentTypeError for non-numeric input or values < 1.
    """
    try:
        value = int(s)
    except ValueError:
        raise argparse.ArgumentTypeError("%r is not a positive integer" % s)
    if value < 1:
        raise argparse.ArgumentTypeError("%r is not a positive integer" % s)
    return value


def main():
    """Entry point: parse CLI args, load the suite file, and dispatch the
    benchmark to either the docker runner or the local runner."""
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument(
        '--image',
        help='use the given image')
    parser.add_argument(
        '--local',
        action='store_true',
        help='use local milvus server')
    parser.add_argument(
        "--run-count",
        default=1,
        type=positive_int,
        help="run each db operation times")
    # performance / stability / accuracy test
    parser.add_argument(
        "--run-type",
        default="performance",
        help="run type, default performance")
    parser.add_argument(
        '--suites',
        metavar='FILE',
        help='load test suites from FILE',
        default='suites.yaml')
    parser.add_argument(
        '--ip',
        help='server ip param for local mode',
        default='127.0.0.1')
    parser.add_argument(
        '--port',
        help='server port param for local mode',
        default='19530')

    args = parser.parse_args()

    operations = None
    # Load all benchmark test suites; yaml.load is used on a local, trusted
    # suite file only (pinned pyyaml has no Loader kwarg yet).
    if args.suites:
        with open(args.suites) as f:
            suites_dict = load(f)
        # With definition order
        operations = operations_parser(suites_dict, run_type=args.run_type)

    run_params = {"run_count": args.run_count}

    if args.image:
        # docker mode: spin up a server container per suite entry
        if args.local:
            logger.error("Local mode and docker mode are incompatible arguments")
            sys.exit(-1)
        # Docker pull image
        if not utils.pull_image(args.image):
            # BUGFIX: previously referenced undefined name `image` (NameError)
            raise Exception('Image %s pull failed' % args.image)

        # TODO: Check milvus server port is available
        logger.info("Init: remove all containers created with image: %s" % args.image)
        utils.remove_all_containers(args.image)
        runner = DockerRunner(args.image)
        for operation_type in operations:
            logger.info("Start run test, test type: %s" % operation_type)
            run_params["params"] = operations[operation_type]
            runner.run({operation_type: run_params}, run_type=args.run_type)
            logger.info("Run params: %s" % str(run_params))

    if args.local:
        # local mode: target an already-running server at --ip/--port
        runner = LocalRunner(args.ip, args.port)
        for operation_type in operations:
            logger.info("Start run local mode test, test type: %s" % operation_type)
            run_params["params"] = operations[operation_type]
            runner.run({operation_type: run_params}, run_type=args.run_type)
            logger.info("Run params: %s" % str(run_params))


if __name__ == "__main__":
    main()
logger = logging.getLogger("milvus_benchmark.parser")


def operations_parser(operations, run_type="performance"):
    """Return the suite definitions registered under *run_type*."""
    return operations[run_type]


def table_parser(table_name):
    """Decode a benchmark table name of the form
    ``<data_type>_<size><unit>_<index_file_size>_<dimension>_<metric>``.

    *unit* is "m" (millions) or "b" (billions); any other trailing character
    leaves the size field untouched as a string.
    """
    fields = table_name.split("_")
    data_type = fields[0]
    size_token = fields[1]
    size_unit = size_token[-1]
    table_size = size_token[0:-1]
    if size_unit == "m":
        table_size = int(table_size) * 1000000
    elif size_unit == "b":
        table_size = int(table_size) * 1000000000
    index_file_size = int(fields[2])
    dimension = int(fields[3])
    metric_type = str(fields[4])
    return (data_type, table_size, index_file_size, dimension, metric_type)


def _int_or_list(param, key, default, label):
    """Normalize ``param[key]`` to a list, using *default* when absent.

    An int becomes a one-element list, a list is shallow-copied, and any
    other type is logged as invalid and handed back unchanged.
    """
    if key not in param:
        return default
    value = param[key]
    if isinstance(value, int):
        return [value]
    if isinstance(value, list):
        return list(value)
    logger.warning("Invalid format %s: %s" % (label, str(value)))
    return value


def search_params_parser(param):
    """Extract the (top_ks, nqs, nprobes) sweep lists from a query entry."""
    top_ks = _int_or_list(param, "top_ks", [10], "top-ks")
    nqs = _int_or_list(param, "nqs", [10], "nqs")
    nprobes = _int_or_list(param, "nprobes", [1], "nprobes")
    return top_ks, nqs, nprobes
logger = logging.getLogger("milvus_benchmark.runner")

SERVER_HOST_DEFAULT = "127.0.0.1"
SERVER_PORT_DEFAULT = 19530
VECTORS_PER_FILE = 1000000        # rows per pre-generated "random" npy file
SIFT_VECTORS_PER_FILE = 100000    # rows per pre-generated sift npy file
MAX_NQ = 10001                    # query vectors available in the first file
FILE_PREFIX = "binary_"

RANDOM_SRC_BINARY_DATA_DIR = '/tmp/random/binary_data'
SIFT_SRC_DATA_DIR = '/tmp/sift1b/query'
SIFT_SRC_BINARY_DATA_DIR = '/tmp/sift1b/binary_data'
SIFT_SRC_GROUNDTRUTH_DATA_DIR = '/tmp/sift1b/groundtruth'

WARM_TOP_K = 1
WARM_NQ = 1
DEFAULT_DIM = 512

# table size (row count) -> sift1b ground-truth ivecs file
GROUNDTRUTH_MAP = {
    "1000000": "idx_1M.ivecs",
    "2000000": "idx_2M.ivecs",
    "5000000": "idx_5M.ivecs",
    "10000000": "idx_10M.ivecs",
    "20000000": "idx_20M.ivecs",
    "50000000": "idx_50M.ivecs",
    "100000000": "idx_100M.ivecs",
    "200000000": "idx_200M.ivecs",
    "500000000": "idx_500M.ivecs",
    "1000000000": "idx_1000M.ivecs",
}


def gen_file_name(idx, table_dimension, data_type):
    """Build the path of the idx-th binary data file for *data_type*.

    For an unrecognized data_type the bare file name (no directory) is
    returned, preserving the original fall-through behavior.
    """
    s = "%05d" % idx
    fname = FILE_PREFIX + str(table_dimension) + "d_" + s + ".npy"
    if data_type == "random":
        fname = RANDOM_SRC_BINARY_DATA_DIR + '/' + fname
    elif data_type == "sift":
        fname = SIFT_SRC_BINARY_DATA_DIR + '/' + fname
    return fname


def get_vectors_from_binary(nq, dimension, data_type):
    """Load the first *nq* query vectors from disk as a list of lists.

    Only the first file is used, so nq must not exceed MAX_NQ.
    """
    if nq > MAX_NQ:
        raise Exception("Over size nq")
    if data_type == "random":
        file_name = gen_file_name(0, dimension, data_type)
    elif data_type == "sift":
        file_name = SIFT_SRC_DATA_DIR + '/' + 'query.npy'
    else:
        # BUGFIX: previously fell through with file_name unbound (NameError)
        raise Exception("Unsupported data_type: %s" % data_type)
    data = np.load(file_name)
    vectors = data[0:nq].tolist()
    return vectors


class Runner(object):
    """Shared benchmark operations (insert / query / accuracy helpers) used by
    both the docker runner and the local runner."""

    def __init__(self):
        pass

    def do_insert(self, milvus, table_name, data_type, dimension, size, ni):
        '''
        Insert `size` vectors into `table_name` in batches of `ni`.
        @params:
            milvus: server connect instance
            dimension: table dimension
            size: row count of vectors to be inserted (must be a multiple of
                  the per-file row count)
            ni: row count of vectors inserted per batch
        @return: dict with
            total_time: total time for all insert operations
            qps: vectors added per second
            ni_time: average time of one insert batch
        '''
        bi_res = {}
        total_time = 0.0
        if data_type == "random":
            vectors_per_file = VECTORS_PER_FILE
        elif data_type == "sift":
            vectors_per_file = SIFT_VECTORS_PER_FILE
        else:
            # BUGFIX: previously fell through with vectors_per_file unbound
            raise Exception("Unsupported data_type: %s" % data_type)
        if size % vectors_per_file or ni > vectors_per_file:
            # BUGFIX: message used to read "Not invalid table size or ni"
            raise Exception("Invalid table size or ni")
        file_num = size // vectors_per_file
        loops = vectors_per_file // ni  # batches per file (loop-invariant)
        for i in range(file_num):
            file_name = gen_file_name(i, dimension, data_type)
            logger.info("Load npy file: %s start" % file_name)
            data = np.load(file_name)
            logger.info("Load npy file: %s end" % file_name)
            for j in range(loops):
                vectors = data[j * ni:(j + 1) * ni].tolist()
                ni_start_time = time.time()
                # ids are globally consecutive across files and batches
                start_id = i * vectors_per_file + j * ni
                end_id = start_id + len(vectors)
                logger.info("Start id: %s, end id: %s" % (start_id, end_id))
                ids = [k for k in range(start_id, end_id)]
                status, ids = milvus.insert(vectors, ids=ids)
                ni_end_time = time.time()
                total_time = total_time + ni_end_time - ni_start_time

        qps = round(size / total_time, 2)
        ni_time = round(total_time / (loops * file_num), 2)
        bi_res["total_time"] = round(total_time, 2)
        bi_res["qps"] = qps
        bi_res["ni_time"] = ni_time
        return bi_res

    def do_query(self, milvus, table_name, top_ks, nqs, nprobe, run_count):
        """Time searches for every (nq, top_k) pair; each is run `run_count`
        times and averaged. Returns a matrix indexed [nq][top_k]."""
        (data_type, table_size, index_file_size, dimension, metric_type) = parser.table_parser(table_name)
        base_query_vectors = get_vectors_from_binary(MAX_NQ, dimension, data_type)

        bi_res = []
        for index, nq in enumerate(nqs):
            tmp_res = []
            for top_k in top_ks:
                avg_query_time = 0.0
                total_query_time = 0.0
                vectors = base_query_vectors[0:nq]
                logger.info("Start query, query params: top-k: {}, nq: {}, actually length of vectors: {}".format(top_k, nq, len(vectors)))
                for i in range(run_count):
                    logger.info("Start run query, run %d of %s" % (i + 1, run_count))
                    start_time = time.time()
                    status, query_res = milvus.query(vectors, top_k, nprobe)
                    total_query_time = total_query_time + (time.time() - start_time)
                    if status.code:
                        logger.error("Query failed with message: %s" % status.message)
                avg_query_time = round(total_query_time / run_count, 2)
                logger.info("Avarage query time: %.2f" % avg_query_time)
                tmp_res.append(avg_query_time)
            bi_res.append(tmp_res)
        return bi_res

    def do_query_ids(self, milvus, table_name, top_k, nq, nprobe):
        """Run one search and return only the result ids, one list per query
        vector. Raises on a failed query."""
        (data_type, table_size, index_file_size, dimension, metric_type) = parser.table_parser(table_name)
        base_query_vectors = get_vectors_from_binary(MAX_NQ, dimension, data_type)
        vectors = base_query_vectors[0:nq]
        logger.info("Start query, query params: top-k: {}, nq: {}, actually length of vectors: {}".format(top_k, nq, len(vectors)))
        status, query_res = milvus.query(vectors, top_k, nprobe)
        if not status.OK():
            msg = "Query failed with message: %s" % status.message
            raise Exception(msg)
        result_ids = []
        for result in query_res:
            tmp = []
            for item in result:
                tmp.append(item.id)
            result_ids.append(tmp)
        return result_ids

    def do_query_acc(self, milvus, table_name, top_k, nq, nprobe, id_store_name):
        """Run one search and dump the result ids to `id_store_name`
        (tab-separated per query, one query per line) for later accuracy
        comparison."""
        (data_type, table_size, index_file_size, dimension, metric_type) = parser.table_parser(table_name)
        base_query_vectors = get_vectors_from_binary(MAX_NQ, dimension, data_type)
        vectors = base_query_vectors[0:nq]
        logger.info("Start query, query params: top-k: {}, nq: {}, actually length of vectors: {}".format(top_k, nq, len(vectors)))
        status, query_res = milvus.query(vectors, top_k, nprobe)
        if not status.OK():
            msg = "Query failed with message: %s" % status.message
            raise Exception(msg)
        # if file existed, cover it
        if os.path.isfile(id_store_name):
            os.remove(id_store_name)
        with open(id_store_name, 'a+') as fd:
            for nq_item in query_res:
                for item in nq_item:
                    fd.write(str(item.id) + '\t')
                fd.write('\n')

    def compute_accuracy(self, flat_file_name, index_file_name):
        """Compare two id dumps written by do_query_acc and return the recall
        of the index results against the flat (exact) results."""
        flat_id_list = []
        index_id_list = []
        logger.info("Loading flat id file: %s" % flat_file_name)
        with open(flat_file_name, 'r') as flat_id_fd:
            for line in flat_id_fd:
                tmp_list = line.strip("\n").strip().split("\t")
                flat_id_list.append(tmp_list)
        logger.info("Loading index id file: %s" % index_file_name)
        with open(index_file_name) as index_id_fd:
            for line in index_id_fd:
                tmp_list = line.strip("\n").strip().split("\t")
                index_id_list.append(tmp_list)
        if len(flat_id_list) != len(index_id_list):
            # BUGFIX: the original format string had no placeholders but was
            # %-applied to a 2-tuple, raising TypeError instead of this message
            raise Exception("Flat id length: %d and index id length: %d not match, Acc compute exiting ..."
                            % (len(flat_id_list), len(index_id_list)))
        # get the accuracy
        return self.get_recall_value(flat_id_list, index_id_list)

    def get_recall_value(self, flat_id_list, index_id_list):
        """
        Average, over all queries, the fraction of index result ids that also
        appear in the exact (flat) result ids.
        """
        sum_ratio = 0.0
        for index, item in enumerate(index_id_list):
            tmp = set(item).intersection(set(flat_id_list[index]))
            sum_ratio = sum_ratio + len(tmp) / len(item)
        return round(sum_ratio / len(index_id_list), 3)

    """
    Implementation based on:
    https://github.com/facebookresearch/faiss/blob/master/benchs/datasets.py
    """
    def get_groundtruth_ids(self, table_size):
        fname = GROUNDTRUTH_MAP[str(table_size)]
        fname = SIFT_SRC_GROUNDTRUTH_DATA_DIR + "/" + fname
        # ivecs layout: each row is [d, id_0, ..., id_{d-1}]
        a = np.fromfile(fname, dtype='int32')
        d = a[0]
        true_ids = a.reshape(-1, d + 1)[:, 1:].copy()
        return true_ids
+ [ + # debug + # {"dataset": "ip_ivfsq8_1000", "top_ks": [16], "nqs": [1], "server.nprobe": 1, "server.use_blas_threshold": 800, "server.cpu_cache_capacity": 110}, + + {"dataset": "ip_ivfsq8_1000", "top_ks": [1, 2, 4, 8, 16, 32, 64, 128, 256, 512], "nqs": [1, 10, 100, 500, 800, 1000], "server.nprobe": 1, "server.use_blas_threshold": 800, "server.cpu_cache_capacity": 110}, + {"dataset": "ip_ivfsq8_1000", "top_ks": [1, 2, 4, 8, 16, 32, 64, 128, 256, 512], "nqs": [1, 10, 100, 500, 800, 1000], "server.nprobe": 10, "server.use_blas_threshold": 20, "server.cpu_cache_capacity": 110}, + {"dataset": "ip_ivfsq8_5000", "top_ks": [1, 2, 4, 8, 16, 32, 64, 128, 256, 512], "nqs": [1, 10, 100, 500, 800, 1000], "server.nprobe": 1, "server.use_blas_threshold": 800, "server.cpu_cache_capacity": 110}, + {"dataset": "ip_ivfsq8_5000", "top_ks": [1, 2, 4, 8, 16, 32, 64, 128, 256, 512], "nqs": [1, 10, 100, 500, 800, 1000], "server.nprobe": 10, "server.use_blas_threshold": 20, "server.cpu_cache_capacity": 110}, + {"dataset": "ip_ivfsq8_40000", "top_ks": [1, 2, 4, 8, 16, 32, 64, 128, 256, 512], "nqs": [1, 10, 100, 500, 800, 1000], "server.nprobe": 1, "server.use_blas_threshold": 800, "server.cpu_cache_capacity": 110}, + # {"dataset": "ip_ivfsq8_40000", "top_ks": [1, 2, 4, 8, 16, 32, 64, 128, 256], "nqs": [1, 10, 100, 1000], "server.nprobe": 10, "server.use_blas_threshold": 20, "server.cpu_cache_capacity": 110}, + ] + + # interface: add_vectors + insert: + # index_type: flat/ivf_flat/ivf_sq8 + [ + # debug + + + {"table_name": "ip_ivf_flat_20m_1024", "table.index_type": "ivf_flat", "server.index_building_threshold": 1024, "table.size": 20000000, "table.ni": 100000, "table.dim": 512, "server.cpu_cache_capacity": 110}, + {"table_name": "ip_ivf_sq8_50m_1024", "table.index_type": "ivf_sq8", "server.index_building_threshold": 1024, "table.size": 50000000, "table.ni": 100000, "table.dim": 512, "server.cpu_cache_capacity": 110}, + ] + + # TODO: interface: build_index + build: [] + diff --git 
a/tests/milvus_benchmark/suites_accuracy.yaml b/tests/milvus_benchmark/suites_accuracy.yaml new file mode 100644 index 0000000000..fd8a5904fa --- /dev/null +++ b/tests/milvus_benchmark/suites_accuracy.yaml @@ -0,0 +1,121 @@ + +accuracy: + # interface: search_vectors + query: + [ + { + "dataset": "random_20m_1024_512_ip", + # index info + "index.index_types": ["flat", "ivf_sq8"], + "index.nlists": [16384], + "index.metric_types": ["ip"], + "nprobes": [1, 16, 64], + "top_ks": [64], + "nqs": [100], + "server.cpu_cache_capacity": 100, + "server.resources": ["cpu", "gpu0"], + "db_path_prefix": "/test/milvus/db_data/random_20m_1024_512_ip", + }, + # { + # "dataset": "sift_50m_1024_128_l2", + # # index info + # "index.index_types": ["ivf_sq8h"], + # "index.nlists": [16384], + # "index.metric_types": ["l2"], + # "nprobes": [1, 16, 64], + # "top_ks": [64], + # "nqs": [100], + # "server.cpu_cache_capacity": 160, + # "server.resources": ["cpu", "gpu0"], + # "db_path_prefix": "/test/milvus/db_data/sift_50m_1024_128_l2", + # "sift_acc": true + # }, + # { + # "dataset": "sift_50m_1024_128_l2", + # # index info + # "index.index_types": ["ivf_sq8"], + # "index.nlists": [16384], + # "index.metric_types": ["l2"], + # "nprobes": [1, 16, 64], + # "top_ks": [64], + # "nqs": [100], + # "server.cpu_cache_capacity": 160, + # "server.resources": ["cpu", "gpu0"], + # "db_path_prefix": "/test/milvus/db_data/sift_50m_1024_128_l2_sq8", + # "sift_acc": true + # }, + # { + # "dataset": "sift_1b_2048_128_l2", + # # index info + # "index.index_types": ["ivf_sq8h"], + # "index.nlists": [16384], + # "index.metric_types": ["l2"], + # "nprobes": [1, 16, 64, 128], + # "top_ks": [64], + # "nqs": [100], + # "server.cpu_cache_capacity": 200, + # "server.resources": ["cpu"], + # "db_path_prefix": "/test/milvus/db_data/sift_1b_2048_128_l2_sq8h", + # "sift_acc": true + # }, + # { + # "dataset": "sift_1b_2048_128_l2", + # # index info + # "index.index_types": ["ivf_sq8h"], + # "index.nlists": [16384], + # 
"index.metric_types": ["l2"], + # "nprobes": [1, 16, 64, 128], + # "top_ks": [64], + # "nqs": [100], + # "server.cpu_cache_capacity": 200, + # "server.resources": ["cpu", "gpu0"], + # "db_path_prefix": "/test/milvus/db_data/sift_1b_2048_128_l2_sq8h", + # "sift_acc": true + # }, + # { + # "dataset": "sift_1b_2048_128_l2", + # # index info + # "index.index_types": ["ivf_sq8h"], + # "index.nlists": [16384], + # "index.metric_types": ["l2"], + # "nprobes": [1, 16, 64, 128], + # "top_ks": [64], + # "nqs": [100], + # "server.cpu_cache_capacity": 200, + # "server.resources": ["cpu", "gpu0", "gpu1"], + # "db_path_prefix": "/test/milvus/db_data/sift_1b_2048_128_l2_sq8h", + # "sift_acc": true + # }, + # { + # "dataset": "sift_1m_1024_128_l2", + # "index.index_types": ["flat", "ivf_sq8"], + # "index.nlists": [16384], + # "nprobes": [1, 32, 128, 256, 512], + # "nqs": 10, + # "top_ks": 10, + # "server.use_blas_threshold": 1100, + # "server.cpu_cache_capacity": 16, + # }, + # { + # "dataset": "sift_10m_1024_128_l2", + # "index.index_types": ["flat", "ivf_sq8"], + # "index.nlists": [16384], + # "nprobes": [1, 32, 128, 256, 512], + # "nqs": 10, + # "top_ks": 10, + # "server.use_blas_threshold": 1100, + # "server.cpu_cache_capacity": 32, + # }, + # { + # "dataset": "sift_50m_1024_128_l2", + # "index.index_types": ["flat", "ivf_sq8"], + # "index.nlists": [16384], + # "nprobes": [1, 32, 128, 256, 512], + # "nqs": 10, + # "top_ks": 10, + # "server.use_blas_threshold": 1100, + # "server.cpu_cache_capacity": 64, + # } + + + ] \ No newline at end of file diff --git a/tests/milvus_benchmark/suites_performance.yaml b/tests/milvus_benchmark/suites_performance.yaml new file mode 100644 index 0000000000..52d457a400 --- /dev/null +++ b/tests/milvus_benchmark/suites_performance.yaml @@ -0,0 +1,258 @@ +performance: + + # interface: add_vectors + insert: + # index_type: flat/ivf_flat/ivf_sq8/mix_nsg + [ + # debug + # data_type / data_size / index_file_size / dimension + # data_type: random / 
ann_sift + # data_size: 10m / 1b + # { + # "table_name": "random_50m_1024_512_ip", + # "ni_per": 100000, + # "processes": 5, # multiprocessing + # "server.cpu_cache_capacity": 16, + # # "server.resources": ["gpu0", "gpu1"], + # "db_path_prefix": "/test/milvus/db_data" + # }, + # { + # "table_name": "random_5m_1024_512_ip", + # "ni_per": 100000, + # "processes": 5, # multiprocessing + # "server.cpu_cache_capacity": 16, + # "server.resources": ["gpu0", "gpu1"], + # "db_path_prefix": "/test/milvus/db_data/random_5m_1024_512_ip" + # }, + # { + # "table_name": "sift_1m_50_128_l2", + # "ni_per": 100000, + # "processes": 5, # multiprocessing + # # "server.cpu_cache_capacity": 16, + # "db_path_prefix": "/test/milvus/db_data" + # }, + # { + # "table_name": "sift_1m_256_128_l2", + # "ni_per": 100000, + # "processes": 5, # multiprocessing + # # "server.cpu_cache_capacity": 16, + # "db_path_prefix": "/test/milvus/db_data" + # } + # { + # "table_name": "sift_50m_1024_128_l2", + # "ni_per": 100000, + # "processes": 5, # multiprocessing + # # "server.cpu_cache_capacity": 16, + # }, + # { + # "table_name": "sift_100m_1024_128_l2", + # "ni_per": 100000, + # "processes": 5, # multiprocessing + # }, + # { + # "table_name": "sift_1b_2048_128_l2", + # "ni_per": 100000, + # "processes": 5, # multiprocessing + # "server.cpu_cache_capacity": 16, + # } + ] + + # interface: search_vectors + query: + # dataset: table name you have already created + # key starts with "server." need to reconfig and restart server, including use_blas_threshold/cpu_cache_capacity .. 
+ [ + # { + # "dataset": "sift_1b_2048_128_l2", + # # index info + # "index.index_types": ["ivf_sq8h"], + # "index.nlists": [16384], + # "nprobes": [8, 32], + # "top_ks": [1, 8, 16, 32, 64, 128, 256, 512, 1000], + # "nqs": [1, 10, 100, 500, 1000], + # "processes": 1, # multiprocessing + # "server.use_blas_threshold": 1100, + # "server.cpu_cache_capacity": 200, + # "server.resources": ["cpu", "gpu0"], + # "db_path_prefix": "/test/milvus/db_data/sift_1b_2048_128_l2_sq8h" + # }, + # { + # "dataset": "sift_1b_2048_128_l2", + # # index info + # "index.index_types": ["ivf_sq8"], + # "index.nlists": [16384], + # "nprobes": [8, 32], + # "top_ks": [1, 8, 16, 32, 64, 128, 256, 512, 1000], + # "nqs": [1, 10, 100, 500, 1000], + # "processes": 1, # multiprocessing + # "server.use_blas_threshold": 1100, + # "server.cpu_cache_capacity": 200, + # "server.resources": ["cpu", "gpu0"], + # "db_path_prefix": "/test/milvus/db_data/sift_1b_2048_128_l2" + # }, + # { + # "dataset": "sift_1b_2048_128_l2", + # # index info + # "index.index_types": ["ivf_sq8h"], + # "index.nlists": [16384], + # "nprobes": [8, 32], + # "top_ks": [1, 8, 16, 32, 64, 128, 256, 512, 1000], + # "nqs": [1, 10, 100, 500, 1000], + # "processes": 1, # multiprocessing + # "server.use_blas_threshold": 1100, + # "server.cpu_cache_capacity": 200, + # "server.resources": ["cpu"], + # "db_path_prefix": "/test/milvus/db_data" + # }, + { + "dataset": "random_50m_1024_512_ip", + "index.index_types": ["ivf_sq8h"], + "index.nlists": [16384], + "nprobes": [8], + # "top_ks": [1, 8, 16, 32, 64, 128, 256, 512, 1000], + "top_ks": [512], + # "nqs": [1, 10, 100, 500, 1000], + "nqs": [500], + "server.use_blas_threshold": 1100, + "server.cpu_cache_capacity": 150, + "server.gpu_cache_capacity": 6, + "server.resources": ["cpu", "gpu0", "gpu1"], + "db_path_prefix": "/test/milvus/db_data/random_50m_1024_512_ip" + }, + # { + # "dataset": "random_50m_1024_512_ip", + # "index.index_types": ["ivf_sq8"], + # "index.nlists": [16384], + # 
"nprobes": [8, 32], + # "top_ks": [1, 8, 16, 32, 64, 128, 256, 512, 1000], + # "nqs": [1, 10, 100, 500, 1000], + # "server.use_blas_threshold": 1100, + # "server.cpu_cache_capacity": 150, + # "server.resources": ["cpu", "gpu0", "gpu1"], + # "db_path_prefix": "/test/milvus/db_data/random_50m_1024_512_ip_sq8" + # }, + # { + # "dataset": "random_20m_1024_512_ip", + # "index.index_types": ["flat"], + # "index.nlists": [16384], + # "nprobes": [50], + # "top_ks": [64], + # "nqs": [10], + # "server.use_blas_threshold": 1100, + # "server.cpu_cache_capacity": 100, + # "server.resources": ["cpu", "gpu0", "gpu1"], + # "db_path_prefix": "/test/milvus/db_data/random_20m_1024_512_ip" + # }, + # { + # "dataset": "random_100m_1024_512_ip", + # # index info + # "index.index_types": ["ivf_sq8"], + # "index.nlists": [16384], + # "nprobes": [8, 32], + # "top_ks": [1, 8, 16, 32, 64, 128, 256, 512, 1000], + # "nqs": [1, 10, 100, 500, 1000], + # "processes": 1, # multiprocessing + # "server.use_blas_threshold": 1100, + # "server.cpu_cache_capacity": 250, + # "server.resources": ["cpu", "gpu0"], + # "db_path_prefix": "/test/milvus/db_data" + # }, + # { + # "dataset": "random_100m_1024_512_ip", + # # index info + # "index.index_types": ["ivf_sq8"], + # "index.nlists": [16384], + # "nprobes": [8, 32], + # "top_ks": [1, 8, 16, 32, 64, 128, 256, 512, 1000], + # "nqs": [1, 10, 100, 500, 1000], + # "processes": 1, # multiprocessing + # "server.use_blas_threshold": 1100, + # "server.cpu_cache_capacity": 250, + # "server.resources": ["cpu"], + # "db_path_prefix": "/test/milvus/db_data" + # }, + # { + # "dataset": "random_10m_1024_512_ip", + # # index info + # "index.index_types": ["ivf_sq8"], + # "index.nlists": [16384], + # "nprobes": [1], + # "top_ks": [1, 2, 4, 8, 16, 32, 64, 128, 256], + # "nqs": [1, 10, 100, 500, 800], + # "processes": 1, # multiprocessing + # # "server.use_blas_threshold": 1100, + # "server.cpu_cache_capacity": 16, + # }, + # { + # "dataset": "random_10m_1024_512_l2", + # # 
index info + # "index.index_types": ["ivf_sq8"], + # "index.nlists": [16384], + # "nprobes": [1], + # "top_ks": [1, 2, 4, 8, 16, 32, 64, 128, 256], + # "nqs": [1, 10, 100, 500, 800], + # "processes": 1, # multiprocessing + # "server.use_blas_threshold": 1100, + # "server.cpu_cache_capacity": 64 + # }, + # { + # "dataset": "sift_500m_1024_128_l2", + # # index info + # "index.index_types": ["ivf_sq8"], + # "index.nlists": [16384], + # "nprobes": [1], + # "top_ks": [1, 8, 16, 64, 256, 512, 1000], + # "nqs": [1, 100, 500, 800, 1000, 1500], + # # "top_ks": [256], + # # "nqs": [800], + # "processes": 1, # multiprocessing + # # "server.use_blas_threshold": 1100, + # "server.cpu_cache_capacity": 120, + # "server.resources": ["gpu0", "gpu1"], + # "db_path_prefix": "/test/milvus/db_data" + # }, + # { + # "dataset": "sift_1b_2048_128_l2", + # # index info + # "index.index_types": ["ivf_sq8h"], + # "index.nlists": [16384], + # "nprobes": [1], + # # "top_ks": [1], + # # "nqs": [1], + # "top_ks": [256], + # "nqs": [800], + # "processes": 1, # multiprocessing + # "server.use_blas_threshold": 1100, + # "server.cpu_cache_capacity": 110, + # "server.resources": ["cpu", "gpu0"], + # "db_path_prefix": "/test/milvus/db_data" + # }, + # { + # "dataset": "random_50m_1024_512_l2", + # # index info + # "index.index_types": ["ivf_sq8"], + # "index.nlists": [16384], + # "nprobes": [1], + # "top_ks": [1, 2, 4, 8, 16, 32, 64, 128, 256], + # "nqs": [1, 10, 100, 500, 800], + # # "top_ks": [256], + # # "nqs": [800], + # "processes": 1, # multiprocessing + # "server.use_blas_threshold": 1100, + # "server.cpu_cache_capacity": 128 + # }, + # [ + # { + # "dataset": "sift_1m_50_128_l2", + # # index info + # "index.index_types": ["ivf_sq8"], + # "index.nlists": [16384], + # "nprobes": [1], + # "top_ks": [1], + # "nqs": [1], + # "db_path_prefix": "/test/milvus/db_data" + # # "processes": 1, # multiprocessing + # # "server.use_blas_threshold": 1100, + # # "server.cpu_cache_capacity": 256 + # } + ] \ No 
newline at end of file diff --git a/tests/milvus_benchmark/suites_stability.yaml b/tests/milvus_benchmark/suites_stability.yaml new file mode 100644 index 0000000000..408221079e --- /dev/null +++ b/tests/milvus_benchmark/suites_stability.yaml @@ -0,0 +1,17 @@ + +stability: + # interface: search_vectors / add_vectors mix operation + query: + [ + { + "dataset": "random_20m_1024_512_ip", + # "nqs": [1, 10, 100, 1000, 10000], + # "pds": [0.1, 0.44, 0.44, 0.02], + "query_process_num": 10, + # each 10s, do an insertion + # "insert_interval": 1, + # minutes + "during_time": 360, + "server.cpu_cache_capacity": 100 + }, + ] \ No newline at end of file diff --git a/tests/milvus_benchmark/suites_yzb.yaml b/tests/milvus_benchmark/suites_yzb.yaml new file mode 100644 index 0000000000..59efcb37e4 --- /dev/null +++ b/tests/milvus_benchmark/suites_yzb.yaml @@ -0,0 +1,171 @@ +#"server.resources": ["gpu0", "gpu1"] + +performance: + # interface: search_vectors + query: + # dataset: table name you have already created + # key starts with "server." need to reconfig and restart server, including use_blas_threshold/cpu_cache_capacity .. 
+ [ + # debug + # { + # "dataset": "random_10m_1024_512_ip", + # # index info + # "index.index_types": ["ivf_sq8"], + # "index.nlists": [16384], + # "nprobes": [1], + # "top_ks": [1, 2, 4, 8, 16, 32, 64, 128, 256], + # "nqs": [1, 10, 100, 500, 800], + # "processes": 1, # multiprocessing + # "server.use_blas_threshold": 1100, + # "server.cpu_cache_capacity": 16, + # }, + # { + # "dataset": "random_10m_1024_512_ip", + # # index info + # "index.index_types": ["ivf_sq8"], + # "index.nlists": [16384], + # "nprobes": [1], + # "top_ks": [1, 2, 4, 8, 16, 32, 64, 128, 256], + # "nqs": [1, 10, 100, 500, 800], + # "processes": 1, # multiprocessing + # "server.use_blas_threshold": 1100, + # "server.cpu_cache_capacity": 16, + # }, + # { + # "dataset": "random_10m_1024_512_ip", + # # index info + # "index.index_types": ["ivf_sq8"], + # "index.nlists": [16384], + # "nprobes": [1], + # "top_ks": [1, 2, 4, 8, 16, 32, 64, 128, 256], + # "nqs": [1, 10, 100, 500, 800], + # "processes": 1, # multiprocessing + # # "server.use_blas_threshold": 1100, + # "server.cpu_cache_capacity": 16, + # }, + # { + # "dataset": "random_10m_1024_512_ip", + # # index info + # "index.index_types": ["ivf_sq8"], + # "index.nlists": [16384], + # "nprobes": [1], + # "top_ks": [1, 2, 4, 8, 16, 32, 64, 128, 256], + # "nqs": [1, 10, 100, 500, 800], + # "processes": 1, # multiprocessing + # # "server.use_blas_threshold": 1100, + # "server.cpu_cache_capacity": 16, + # }, + # { + # "dataset": "random_10m_1024_512_l2", + # # index info + # "index.index_types": ["ivf_sq8"], + # "index.nlists": [16384], + # "nprobes": [1], + # "top_ks": [1, 2, 4, 8, 16, 32, 64, 128, 256], + # "nqs": [1, 10, 100, 500, 800], + # "processes": 1, # multiprocessing + # "server.use_blas_threshold": 1100, + # "server.cpu_cache_capacity": 64 + # }, +# { +# "dataset": "sift_50m_1024_128_l2", +# # index info +# "index.index_types": ["ivf_sq8"], +# "index.nlists": [16384], +# "nprobes": [1, 32, 128], +# "top_ks": [1, 2, 4, 8, 16, 32, 64, 128, 
256], +# "nqs": [1, 10, 100, 500, 800], +# # "top_ks": [256], +# # "nqs": [800], +# "processes": 1, # multiprocessing +# "server.use_blas_threshold": 1100, +# "server.cpu_cache_capacity": 310, +# "server.resources": ["gpu0", "gpu1"] +# }, + { + "dataset": "sift_1m_1024_128_l2", + # index info + "index.index_types": ["ivf_sq8"], + "index.nlists": [16384], + "nprobes": [32], + "top_ks": [10], + "nqs": [100], + # "top_ks": [256], + # "nqs": [800], + "processes": 1, # multiprocessing + "server.use_blas_threshold": 1100, + "server.cpu_cache_capacity": 310, + "server.resources": ["cpu"] + }, + { + "dataset": "sift_1m_1024_128_l2", + # index info + "index.index_types": ["ivf_sq8"], + "index.nlists": [16384], + "nprobes": [32], + "top_ks": [10], + "nqs": [100], + # "top_ks": [256], + # "nqs": [800], + "processes": 1, # multiprocessing + "server.use_blas_threshold": 1100, + "server.cpu_cache_capacity": 310, + "server.resources": ["gpu0"] + }, + { + "dataset": "sift_1m_1024_128_l2", + # index info + "index.index_types": ["ivf_sq8"], + "index.nlists": [16384], + "nprobes": [32], + "top_ks": [10], + "nqs": [100], + # "top_ks": [256], + # "nqs": [800], + "processes": 1, # multiprocessing + "server.use_blas_threshold": 1100, + "server.cpu_cache_capacity": 310, + "server.resources": ["gpu0", "gpu1"] + }, + # { + # "dataset": "sift_1b_2048_128_l2", + # # index info + # "index.index_types": ["ivf_sq8"], + # "index.nlists": [16384], + # "nprobes": [1], + # "top_ks": [1, 2, 4, 8, 16, 32, 64, 128, 256], + # "nqs": [1, 10, 100, 500, 800], + # # "top_ks": [256], + # # "nqs": [800], + # "processes": 1, # multiprocessing + # "server.use_blas_threshold": 1100, + # "server.cpu_cache_capacity": 310 + # }, +# { +# "dataset": "random_50m_1024_512_l2", +# # index info +# "index.index_types": ["ivf_sq8"], +# "index.nlists": [16384], +# "nprobes": [1], +# "top_ks": [1, 2, 4, 8, 16, 32, 64, 128, 256], +# "nqs": [1, 10, 100, 500, 800], +# # "top_ks": [256], +# # "nqs": [800], +# "processes": 1, # 
multiprocessing +# "server.use_blas_threshold": 1100, +# "server.cpu_cache_capacity": 128, +# "server.resources": ["gpu0", "gpu1"] +# }, + # { + # "dataset": "random_100m_1024_512_ip", + # # index info + # "index.index_types": ["ivf_sq8"], + # "index.nlists": [16384], + # "nprobes": [1], + # "top_ks": [1, 2, 4, 8, 16, 32, 64, 128, 256], + # "nqs": [1, 10, 100, 500, 800], + # "processes": 1, # multiprocessing + # "server.use_blas_threshold": 1100, + # "server.cpu_cache_capacity": 256 + # }, + ] \ No newline at end of file diff --git a/tests/milvus_benchmark/utils.py b/tests/milvus_benchmark/utils.py new file mode 100644 index 0000000000..f6522578ad --- /dev/null +++ b/tests/milvus_benchmark/utils.py @@ -0,0 +1,194 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function + +__true_print = print # noqa + +import os +import sys +import pdb +import time +import datetime +import argparse +import threading +import logging +import docker +import multiprocessing +import numpy +# import psutil +from yaml import load, dump +import tableprint as tp + +logger = logging.getLogger("milvus_benchmark.utils") + +MULTI_DB_SLAVE_PATH = "/opt/milvus/data2;/opt/milvus/data3" + + +def get_current_time(): + return time.strftime('%Y-%m-%d %H:%M:%S', time.localtime()) + + +def print_table(headers, columns, data): + bodys = [] + for index, value in enumerate(columns): + tmp = [value] + tmp.extend(data[index]) + bodys.append(tmp) + tp.table(bodys, headers) + + +def modify_config(k, v, type=None, file_path="conf/server_config.yaml", db_slave=None): + if not os.path.isfile(file_path): + raise Exception('File: %s not found' % file_path) + with open(file_path) as f: + config_dict = load(f) + f.close() + if config_dict: + if k.find("use_blas_threshold") != -1: + config_dict['engine_config']['use_blas_threshold'] = int(v) + elif k.find("cpu_cache_capacity") != -1: + config_dict['cache_config']['cpu_cache_capacity'] = int(v) + elif k.find("gpu_cache_capacity") != -1: + 
config_dict['cache_config']['gpu_cache_capacity'] = int(v) + elif k.find("resource_pool") != -1: + config_dict['resource_config']['resource_pool'] = v + + if db_slave: + config_dict['db_config']['db_slave_path'] = MULTI_DB_SLAVE_PATH + with open(file_path, 'w') as f: + dump(config_dict, f, default_flow_style=False) + f.close() + else: + raise Exception('Load file:%s error' % file_path) + + +def pull_image(image): + registry = image.split(":")[0] + image_tag = image.split(":")[1] + client = docker.APIClient(base_url='unix://var/run/docker.sock') + logger.info("Start pulling image: %s" % image) + return client.pull(registry, image_tag) + + +def run_server(image, mem_limit=None, timeout=30, test_type="local", volume_name=None, db_slave=None): + import colors + + client = docker.from_env() + # if mem_limit is None: + # mem_limit = psutil.virtual_memory().available + # logger.info('Memory limit:', mem_limit) + # cpu_limit = "0-%d" % (multiprocessing.cpu_count() - 1) + # logger.info('Running on CPUs:', cpu_limit) + for dir_item in ['logs', 'db']: + try: + os.mkdir(os.path.abspath(dir_item)) + except Exception as e: + pass + + if test_type == "local": + volumes = { + os.path.abspath('conf'): + {'bind': '/opt/milvus/conf', 'mode': 'ro'}, + os.path.abspath('logs'): + {'bind': '/opt/milvus/logs', 'mode': 'rw'}, + os.path.abspath('db'): + {'bind': '/opt/milvus/db', 'mode': 'rw'}, + } + elif test_type == "remote": + if volume_name is None: + raise Exception("No volume name") + remote_log_dir = volume_name+'/logs' + remote_db_dir = volume_name+'/db' + + for dir_item in [remote_log_dir, remote_db_dir]: + if not os.path.isdir(dir_item): + os.makedirs(dir_item, exist_ok=True) + volumes = { + os.path.abspath('conf'): + {'bind': '/opt/milvus/conf', 'mode': 'ro'}, + remote_log_dir: + {'bind': '/opt/milvus/logs', 'mode': 'rw'}, + remote_db_dir: + {'bind': '/opt/milvus/db', 'mode': 'rw'} + } + # add volumes + if db_slave and isinstance(db_slave, int): + for i in range(2, db_slave+1): + 
remote_db_dir = volume_name+'/data'+str(i) + if not os.path.isdir(remote_db_dir): + os.makedirs(remote_db_dir, exist_ok=True) + volumes[remote_db_dir] = {'bind': '/opt/milvus/data'+str(i), 'mode': 'rw'} + + container = client.containers.run( + image, + volumes=volumes, + runtime="nvidia", + ports={'19530/tcp': 19530, '8080/tcp': 8080}, + environment=["OMP_NUM_THREADS=48"], + # cpuset_cpus=cpu_limit, + # mem_limit=mem_limit, + # environment=[""], + detach=True) + + def stream_logs(): + for line in container.logs(stream=True): + logger.info(colors.color(line.decode().rstrip(), fg='blue')) + + if sys.version_info >= (3, 0): + t = threading.Thread(target=stream_logs, daemon=True) + else: + t = threading.Thread(target=stream_logs) + t.daemon = True + t.start() + + logger.info('Container: %s started' % container) + return container + # exit_code = container.wait(timeout=timeout) + # # Exit if exit code + # if exit_code == 0: + # return container + # elif exit_code is not None: + # print(colors.color(container.logs().decode(), fg='red')) + # raise Exception('Child process raised exception %s' % str(exit_code)) + +def restart_server(container): + client = docker.APIClient(base_url='unix://var/run/docker.sock') + + client.restart(container.name) + logger.info('Container: %s restarted' % container.name) + return container + + +def remove_container(container): + container.remove(force=True) + logger.info('Container: %s removed' % container) + + +def remove_all_containers(image): + client = docker.from_env() + try: + for container in client.containers.list(): + if image in container.image.tags: + container.stop(timeout=30) + container.remove(force=True) + except Exception as e: + logger.error("Containers removed failed") + + +def container_exists(image): + ''' + Check if container existed with the given image name + @params: image name + @return: container if exists + ''' + res = False + client = docker.from_env() + for container in client.containers.list(): + if image in 
container.image.tags: + # True + res = container + return res + + +if __name__ == '__main__': + # print(pull_image('branch-0.3.1-debug')) + stop_server() \ No newline at end of file diff --git a/tests/milvus_python_test/.dockerignore b/tests/milvus_python_test/.dockerignore new file mode 100644 index 0000000000..c97d9d043c --- /dev/null +++ b/tests/milvus_python_test/.dockerignore @@ -0,0 +1,14 @@ +node_modules +npm-debug.log +Dockerfile* +docker-compose* +.dockerignore +.git +.gitignore +.env +*/bin +*/obj +README.md +LICENSE +.vscode +__pycache__ \ No newline at end of file diff --git a/tests/milvus_python_test/.gitignore b/tests/milvus_python_test/.gitignore new file mode 100644 index 0000000000..9bd7345e51 --- /dev/null +++ b/tests/milvus_python_test/.gitignore @@ -0,0 +1,13 @@ +.python-version +.pytest_cache +__pycache__ +.vscode +.idea + +test_out/ +*.pyc + +db/ +logs/ + +.coverage diff --git a/tests/milvus_python_test/Dockerfile b/tests/milvus_python_test/Dockerfile new file mode 100644 index 0000000000..ec78b943dc --- /dev/null +++ b/tests/milvus_python_test/Dockerfile @@ -0,0 +1,14 @@ +FROM python:3.6.8-jessie + +LABEL Name=megasearch_engine_test Version=0.0.1 + +WORKDIR /app +ADD . /app + +RUN apt-get update && apt-get install -y --no-install-recommends \ + libc-dev build-essential && \ + python3 -m pip install -r requirements.txt && \ + apt-get remove --purge -y + +ENTRYPOINT [ "/app/docker-entrypoint.sh" ] +CMD [ "start" ] \ No newline at end of file diff --git a/tests/milvus_python_test/MilvusCases.md b/tests/milvus_python_test/MilvusCases.md new file mode 100644 index 0000000000..ea6372c373 --- /dev/null +++ b/tests/milvus_python_test/MilvusCases.md @@ -0,0 +1,143 @@ +# Milvus test cases + +## * Interfaces test + +### 1. 
连接测试 + +#### 1.1 连接 + +| cases | expected | +| ---------------- | -------------------------------------------- | +| 非法IP 123.0.0.2 | method: connect raise error in given timeout | +| 正常 uri | attr: connected assert true | +| 非法 uri | method: connect raise error in given timeout | +| 最大连接数 | all connection attrs: connected assert true | +| | | + +#### 1.2 断开连接 + +| cases | expected | +| ------------------------ | ------------------- | +| 正常连接下,断开连接 | connect raise error | +| 正常连接下,重复断开连接 | connect raise error | + +### 2. Table operation + +#### 2.1 表创建 + +##### 2.1.1 表名 + +| cases | expected | +| ------------------------- | ----------- | +| 基础功能,参数正常 | status pass | +| 表名已存在 | status fail | +| 表名:"中文" | status pass | +| 表名带特殊字符: "-39fsd-" | status pass | +| 表名带空格: "test1 2" | status pass | +| invalid dim: 0 | raise error | +| invalid dim: -1 | raise error | +| invalid dim: 100000000 | raise error | +| invalid dim: "string" | raise error | +| index_type: 0 | status pass | +| index_type: 1 | status pass | +| index_type: 2 | status pass | +| index_type: string | raise error | +| | | + +##### 2.1.2 维数支持 + +| cases | expected | +| --------------------- | ----------- | +| 维数: 0 | raise error | +| 维数负数: -1 | raise error | +| 维数最大值: 100000000 | raise error | +| 维数字符串: "string" | raise error | +| | | + +##### 2.1.3 索引类型支持 + +| cases | expected | +| ---------------- | ----------- | +| 索引类型: 0 | status pass | +| 索引类型: 1 | status pass | +| 索引类型: 2 | status pass | +| 索引类型: string | raise error | +| | | + +#### 2.2 表说明 + +| cases | expected | +| ---------------------- | -------------------------------- | +| 创建表后,执行describe | 返回结构体,元素与创建表参数一致 | +| | | + +#### 2.3 表删除 + +| cases | expected | +| -------------- | ---------------------- | +| 删除已存在表名 | has_table return False | +| 删除不存在表名 | status fail | +| | | + +#### 2.4 表是否存在 + +| cases | expected | +| ----------------------- | ------------ | +| 存在表,调用has_table | assert true | +| 不存在表,调用has_table | assert false | +| | | + +#### 2.5 
查询表记录条数 + +| cases | expected | +| -------------------- | ------------------------ | +| 空表 | 0 | +| 空表插入数据(单条) | 1 | +| 空表插入数据(多条) | assert length of vectors | + +#### 2.6 查询表数量 + +| cases | expected | +| --------------------------------------------- | -------------------------------- | +| 两张表,一张空表,一张有数据:调用show tables | assert length of table list == 2 | +| | | + +### 3. Add vectors + +| interfaces | cases | expected | +| ----------- | --------------------------------------------------------- | ------------------------------------ | +| add_vectors | add basic | assert length of ids == nq | +| | add vectors into table not existed | status fail | +| | dim not match: single vector | status fail | +| | dim not match: vector list | status fail | +| | single vector element empty | status fail | +| | vector list element empty | status fail | +| | query immediately after adding | status pass | +| | query immediately after sleep 6s | status pass && length of result == 1 | +| | concurrent add with multi threads(share one connection) | status pass | +| | concurrent add with multi threads(independent connection) | status pass | +| | concurrent add with multi process(independent connection) | status pass | +| | index_type: 2 | status pass | +| | index_type: string | raise error | +| | | | + +### 4. 
Search vectors + +| interfaces | cases | expected | +| -------------- | ------------------------------------------------- | -------------------------------- | +| search_vectors | search basic(query vector in vectors, top-k nq | assert length of result == nq | +| | concurrent search | status pass | +| | query_range(get_current_day(), get_current_day()) | assert length of result == nq | +| | invalid query_range: "" | raise error | +| | query_range(get_last_day(2), get_last_day(1)) | assert length of result == 0 | +| | query_range(get_last_day(2), get_current_day()) | assert length of result == nq | +| | query_range((get_last_day(2), get_next_day(2)) | assert length of result == nq | +| | query_range((get_current_day(), get_next_day(2)) | assert length of result == nq | +| | query_range(get_next_day(1), get_next_day(2)) | assert length of result == 0 | +| | score: vector[i] = vector[i]+-0.01 | score > 99.9 | \ No newline at end of file diff --git a/tests/milvus_python_test/README.md b/tests/milvus_python_test/README.md new file mode 100644 index 0000000000..69b4384d4c --- /dev/null +++ b/tests/milvus_python_test/README.md @@ -0,0 +1,14 @@ +# Requirements +* python 3.6.8 + +# How to use this Test Project +```shell +pytest . -q -v + ``` +with allure test report + ```shell +pytest --alluredir=test_out . -q -v +allure serve test_out + ``` +# Contribution getting started +* Follow PEP-8 for naming and black for formatting. 
\ No newline at end of file diff --git a/tests/milvus_python_test/conf/log_config.conf b/tests/milvus_python_test/conf/log_config.conf new file mode 100644 index 0000000000..c530fa4c60 --- /dev/null +++ b/tests/milvus_python_test/conf/log_config.conf @@ -0,0 +1,27 @@ +* GLOBAL: + FORMAT = "%datetime | %level | %logger | %msg" + FILENAME = "/opt/milvus/logs/milvus-%datetime{%H:%m}-global.log" + ENABLED = true + TO_FILE = true + TO_STANDARD_OUTPUT = false + SUBSECOND_PRECISION = 3 + PERFORMANCE_TRACKING = false + MAX_LOG_FILE_SIZE = 209715200 ## Throw log files away after 200MB +* DEBUG: + FILENAME = "/opt/milvus/logs/milvus-%datetime{%H:%m}-debug.log" + ENABLED = true +* WARNING: + FILENAME = "/opt/milvus/logs/milvus-%datetime{%H:%m}-warning.log" +* TRACE: + FILENAME = "/opt/milvus/logs/milvus-%datetime{%H:%m}-trace.log" +* VERBOSE: + FORMAT = "%datetime{%d/%M/%y} | %level-%vlevel | %msg" + TO_FILE = false + TO_STANDARD_OUTPUT = false +## Error logs +* ERROR: + ENABLED = true + FILENAME = "/opt/milvus/logs/milvus-%datetime{%H:%m}-error.log" +* FATAL: + ENABLED = true + FILENAME = "/opt/milvus/logs/milvus-%datetime{%H:%m}-fatal.log" diff --git a/tests/milvus_python_test/conf/server_config.yaml b/tests/milvus_python_test/conf/server_config.yaml new file mode 100644 index 0000000000..6fe7e05791 --- /dev/null +++ b/tests/milvus_python_test/conf/server_config.yaml @@ -0,0 +1,32 @@ +server_config: + address: 0.0.0.0 + port: 19530 + deploy_mode: single + time_zone: UTC+8 + +db_config: + primary_path: /opt/milvus + secondary_path: + backend_url: sqlite://:@:/ + insert_buffer_size: 4 + build_index_gpu: 0 + preload_table: + +metric_config: + enable_monitor: true + collector: prometheus + prometheus_config: + port: 8080 + +cache_config: + cpu_cache_capacity: 8 + cpu_cache_threshold: 0.85 + cache_insert_data: false + +engine_config: + use_blas_threshold: 20 + +resource_config: + resource_pool: + - cpu + - gpu0 diff --git a/tests/milvus_python_test/conftest.py 
b/tests/milvus_python_test/conftest.py new file mode 100644 index 0000000000..c6046ed56f --- /dev/null +++ b/tests/milvus_python_test/conftest.py @@ -0,0 +1,128 @@ +import socket +import pdb +import logging + +import pytest +from utils import gen_unique_str +from milvus import Milvus, IndexType, MetricType + +index_file_size = 10 + + +def pytest_addoption(parser): + parser.addoption("--ip", action="store", default="localhost") + parser.addoption("--port", action="store", default=19530) + + +def check_server_connection(request): + ip = request.config.getoption("--ip") + port = request.config.getoption("--port") + connected = True + if ip and (ip not in ['localhost', '127.0.0.1']): + try: + socket.getaddrinfo(ip, port, 0, 0, socket.IPPROTO_TCP) + except Exception as e: + print("Socket connnet failed: %s" % str(e)) + connected = False + return connected + + +def get_args(request): + args = { + "ip": request.config.getoption("--ip"), + "port": request.config.getoption("--port") + } + return args + + +@pytest.fixture(scope="module") +def connect(request): + ip = request.config.getoption("--ip") + port = request.config.getoption("--port") + milvus = Milvus() + try: + milvus.connect(host=ip, port=port) + except: + pytest.exit("Milvus server can not connected, exit pytest ...") + + def fin(): + try: + milvus.disconnect() + except: + pass + + request.addfinalizer(fin) + return milvus + + +@pytest.fixture(scope="module") +def dis_connect(request): + ip = request.config.getoption("--ip") + port = request.config.getoption("--port") + milvus = Milvus() + milvus.connect(host=ip, port=port) + milvus.disconnect() + def fin(): + try: + milvus.disconnect() + except: + pass + + request.addfinalizer(fin) + return milvus + + +@pytest.fixture(scope="module") +def args(request): + ip = request.config.getoption("--ip") + port = request.config.getoption("--port") + args = {"ip": ip, "port": port} + return args + + +@pytest.fixture(scope="function") +def table(request, connect): + 
ori_table_name = getattr(request.module, "table_id", "test") + table_name = gen_unique_str(ori_table_name) + dim = getattr(request.module, "dim", "128") + param = {'table_name': table_name, + 'dimension': dim, + 'index_file_size': index_file_size, + 'metric_type': MetricType.L2} + status = connect.create_table(param) + # logging.getLogger().info(status) + if not status.OK(): + pytest.exit("Table can not be created, exit pytest ...") + + def teardown(): + status, table_names = connect.show_tables() + for table_name in table_names: + connect.delete_table(table_name) + + request.addfinalizer(teardown) + + return table_name + + +@pytest.fixture(scope="function") +def ip_table(request, connect): + ori_table_name = getattr(request.module, "table_id", "test") + table_name = gen_unique_str(ori_table_name) + dim = getattr(request.module, "dim", "128") + param = {'table_name': table_name, + 'dimension': dim, + 'index_file_size': index_file_size, + 'metric_type': MetricType.IP} + status = connect.create_table(param) + # logging.getLogger().info(status) + if not status.OK(): + pytest.exit("Table can not be created, exit pytest ...") + + def teardown(): + status, table_names = connect.show_tables() + for table_name in table_names: + connect.delete_table(table_name) + + request.addfinalizer(teardown) + + return table_name \ No newline at end of file diff --git a/tests/milvus_python_test/docker-entrypoint.sh b/tests/milvus_python_test/docker-entrypoint.sh new file mode 100755 index 0000000000..af9ba0ba66 --- /dev/null +++ b/tests/milvus_python_test/docker-entrypoint.sh @@ -0,0 +1,9 @@ +#!/bin/bash + +set -e + +if [ "$1" = 'start' ]; then + tail -f /dev/null +fi + +exec "$@" \ No newline at end of file diff --git a/tests/milvus_python_test/pytest.ini b/tests/milvus_python_test/pytest.ini new file mode 100644 index 0000000000..3f95dc29b8 --- /dev/null +++ b/tests/milvus_python_test/pytest.ini @@ -0,0 +1,9 @@ +[pytest] +log_format = [%(asctime)s-%(levelname)s-%(name)s]: %(message)s 
(%(filename)s:%(lineno)s) + +log_cli = true +log_level = 20 + +timeout = 300 + +level = 1 \ No newline at end of file diff --git a/tests/milvus_python_test/requirements.txt b/tests/milvus_python_test/requirements.txt new file mode 100644 index 0000000000..4bdecd6033 --- /dev/null +++ b/tests/milvus_python_test/requirements.txt @@ -0,0 +1,25 @@ +astroid==2.2.5 +atomicwrites==1.3.0 +attrs==19.1.0 +importlib-metadata==0.15 +isort==4.3.20 +lazy-object-proxy==1.4.1 +mccabe==0.6.1 +more-itertools==7.0.0 +numpy==1.16.3 +pluggy==0.12.0 +py==1.8.0 +pylint==2.3.1 +pytest==4.5.0 +pytest-timeout==1.3.3 +pytest-repeat==0.8.0 +allure-pytest==2.7.0 +pytest-print==0.1.2 +pytest-level==0.1.1 +six==1.12.0 +thrift==0.11.0 +typed-ast==1.3.5 +wcwidth==0.1.7 +wrapt==1.11.1 +zipp==0.5.1 +pymilvus-test>=0.2.0 diff --git a/tests/milvus_python_test/requirements_cluster.txt b/tests/milvus_python_test/requirements_cluster.txt new file mode 100644 index 0000000000..a1f3b69be9 --- /dev/null +++ b/tests/milvus_python_test/requirements_cluster.txt @@ -0,0 +1,25 @@ +astroid==2.2.5 +atomicwrites==1.3.0 +attrs==19.1.0 +importlib-metadata==0.15 +isort==4.3.20 +lazy-object-proxy==1.4.1 +mccabe==0.6.1 +more-itertools==7.0.0 +numpy==1.16.3 +pluggy==0.12.0 +py==1.8.0 +pylint==2.3.1 +pytest==4.5.0 +pytest-timeout==1.3.3 +pytest-repeat==0.8.0 +allure-pytest==2.7.0 +pytest-print==0.1.2 +pytest-level==0.1.1 +six==1.12.0 +thrift==0.11.0 +typed-ast==1.3.5 +wcwidth==0.1.7 +wrapt==1.11.1 +zipp==0.5.1 +pymilvus>=0.1.24 diff --git a/tests/milvus_python_test/requirements_no_pymilvus.txt b/tests/milvus_python_test/requirements_no_pymilvus.txt new file mode 100644 index 0000000000..45884c0c71 --- /dev/null +++ b/tests/milvus_python_test/requirements_no_pymilvus.txt @@ -0,0 +1,24 @@ +astroid==2.2.5 +atomicwrites==1.3.0 +attrs==19.1.0 +importlib-metadata==0.15 +isort==4.3.20 +lazy-object-proxy==1.4.1 +mccabe==0.6.1 +more-itertools==7.0.0 +numpy==1.16.3 +pluggy==0.12.0 +py==1.8.0 +pylint==2.3.1 +pytest==4.5.0 
+pytest-timeout==1.3.3 +pytest-repeat==0.8.0 +allure-pytest==2.7.0 +pytest-print==0.1.2 +pytest-level==0.1.1 +six==1.12.0 +thrift==0.11.0 +typed-ast==1.3.5 +wcwidth==0.1.7 +wrapt==1.11.1 +zipp==0.5.1 diff --git a/tests/milvus_python_test/run.sh b/tests/milvus_python_test/run.sh new file mode 100644 index 0000000000..cee5b061f5 --- /dev/null +++ b/tests/milvus_python_test/run.sh @@ -0,0 +1,4 @@ +#/bin/bash + + +pytest . $@ \ No newline at end of file diff --git a/tests/milvus_python_test/test.template b/tests/milvus_python_test/test.template new file mode 100644 index 0000000000..0403d72860 --- /dev/null +++ b/tests/milvus_python_test/test.template @@ -0,0 +1,41 @@ +''' +Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved +Unauthorized copying of this file, via any medium is strictly prohibited. +Proprietary and confidential. +''' + +''' +Test Description: + +This document is only a template to show how to write a auto-test script + +本文档仅仅是个展示如何编写自动化测试脚本的模板 + +''' + +import pytest +from milvus import Milvus + + +class TestConnection: + def test_connect_localhost(self): + + """ + TestCase1.1 + Test target: This case is to check if the server can be connected. + Test method: Call API: milvus.connect to connect local milvus server, ip address: 127.0.0.1 and port: 19530, check the return status + Expectation: Return status is OK. 
+ + 测试目的:本用例测试客户端是否可以与服务器建立连接 + 测试方法:调用SDK API: milvus.connect方法连接本地服务器,IP地址:127.0.0.1,端口:19530,检查调用返回状态 + 期望结果:返回状态是:OK + + """ + + milvus = Milvus() + milvus.connect(host='127.0.0.1', port='19530') + assert milvus.connected + + + + diff --git a/tests/milvus_python_test/test_add_vectors.py b/tests/milvus_python_test/test_add_vectors.py new file mode 100644 index 0000000000..d9faeede33 --- /dev/null +++ b/tests/milvus_python_test/test_add_vectors.py @@ -0,0 +1,1233 @@ +import time +import random +import pdb +import threading +import logging +from multiprocessing import Pool, Process +import pytest +from milvus import Milvus, IndexType, MetricType +from utils import * + + +dim = 128 +index_file_size = 10 +table_id = "test_add" +ADD_TIMEOUT = 60 +nprobe = 1 +epsilon = 0.0001 + +index_params = random.choice(gen_index_params()) +logging.getLogger().info(index_params) + + +class TestAddBase: + """ + ****************************************************************** + The following cases are used to test `add_vectors / index / search / delete` mixed function + ****************************************************************** + """ + + def test_add_vector_create_table(self, connect, table): + ''' + target: test add vector, then create table again + method: add vector and create table + expected: status not ok + ''' + vector = gen_single_vector(dim) + status, ids = connect.add_vectors(table, vector) + param = {'table_name': table, + 'dimension': dim, + 'index_file_size': index_file_size, + 'metric_type': MetricType.L2} + status = connect.create_table(param) + assert not status.OK() + + def test_add_vector_has_table(self, connect, table): + ''' + target: test add vector, then check table existence + method: add vector and call HasTable + expected: table exists, status ok + ''' + vector = gen_single_vector(dim) + status, ids = connect.add_vectors(table, vector) + ret = connect.has_table(table) + assert ret == True + + @pytest.mark.timeout(ADD_TIMEOUT) + def 
test_delete_table_add_vector(self, connect, table): + ''' + target: test add vector after table deleted + method: delete table and add vector + expected: status not ok + ''' + status = connect.delete_table(table) + vector = gen_single_vector(dim) + status, ids = connect.add_vectors(table, vector) + assert not status.OK() + + @pytest.mark.timeout(ADD_TIMEOUT) + def test_delete_table_add_vector_another(self, connect, table): + ''' + target: test add vector to table_1 after table_2 deleted + method: delete table_2 and add vector to table_1 + expected: status ok + ''' + param = {'table_name': 'test_delete_table_add_vector_another', + 'dimension': dim, + 'index_file_size': index_file_size, + 'metric_type': MetricType.L2} + status = connect.create_table(param) + status = connect.delete_table(table) + vector = gen_single_vector(dim) + status, ids = connect.add_vectors(param['table_name'], vector) + connect.delete_table(param['table_name']) + assert status.OK() + + @pytest.mark.timeout(ADD_TIMEOUT) + def test_add_vector_delete_table(self, connect, table): + ''' + target: test delete table after add vector + method: add vector and delete table + expected: status ok + ''' + vector = gen_single_vector(dim) + status, ids = connect.add_vectors(table, vector) + status = connect.delete_table(table) + assert status.OK() + + @pytest.mark.timeout(ADD_TIMEOUT) + def test_add_vector_delete_another_table(self, connect, table): + ''' + target: test delete table_1 table after add vector to table_2 + method: add vector and delete table + expected: status ok + ''' + param = {'table_name': 'test_add_vector_delete_another_table', + 'dimension': dim, + 'index_file_size': index_file_size, + 'metric_type': MetricType.L2} + status = connect.create_table(param) + vector = gen_single_vector(dim) + status, ids = connect.add_vectors(table, vector) + status = connect.delete_table(param['table_name']) + assert status.OK() + + @pytest.mark.timeout(ADD_TIMEOUT) + def 
test_add_vector_sleep_delete_table(self, connect, table): + ''' + target: test delete table after add vector for a while + method: add vector, sleep, and delete table + expected: status ok + ''' + vector = gen_single_vector(dim) + status, ids = connect.add_vectors(table, vector) + time.sleep(1) + status = connect.delete_table(table) + assert status.OK() + + @pytest.mark.timeout(ADD_TIMEOUT) + def test_add_vector_sleep_delete_another_table(self, connect, table): + ''' + target: test delete table_1 table after add vector to table_2 for a while + method: add vector , sleep, and delete table + expected: status ok + ''' + param = {'table_name': 'test_add_vector_sleep_delete_another_table', + 'dimension': dim, + 'index_file_size': index_file_size, + 'metric_type': MetricType.L2} + status = connect.create_table(param) + vector = gen_single_vector(dim) + status, ids = connect.add_vectors(table, vector) + time.sleep(1) + status = connect.delete_table(param['table_name']) + assert status.OK() + + @pytest.mark.timeout(ADD_TIMEOUT) + def test_create_index_add_vector(self, connect, table): + ''' + target: test add vector after build index + method: build index and add vector + expected: status ok + ''' + status = connect.create_index(table, index_params) + vector = gen_single_vector(dim) + status, ids = connect.add_vectors(table, vector) + assert status.OK() + + @pytest.mark.timeout(ADD_TIMEOUT) + def test_create_index_add_vector_another(self, connect, table): + ''' + target: test add vector to table_2 after build index for table_1 + method: build index and add vector + expected: status ok + ''' + param = {'table_name': 'test_create_index_add_vector_another', + 'dimension': dim, + 'index_file_size': index_file_size, + 'metric_type': MetricType.L2} + status = connect.create_table(param) + status = connect.create_index(table, index_params) + vector = gen_single_vector(dim) + status, ids = connect.add_vectors(table, vector) + connect.delete_table(param['table_name']) + assert 
status.OK() + + @pytest.mark.timeout(ADD_TIMEOUT) + def test_add_vector_create_index(self, connect, table): + ''' + target: test build index add after vector + method: add vector and build index + expected: status ok + ''' + vector = gen_single_vector(dim) + status, ids = connect.add_vectors(table, vector) + status = connect.create_index(table, index_params) + assert status.OK() + + @pytest.mark.timeout(ADD_TIMEOUT) + def test_add_vector_create_index_another(self, connect, table): + ''' + target: test add vector to table_2 after build index for table_1 + method: build index and add vector + expected: status ok + ''' + param = {'table_name': 'test_add_vector_create_index_another', + 'dimension': dim, + 'index_file_size': index_file_size, + 'metric_type': MetricType.L2} + status = connect.create_table(param) + vector = gen_single_vector(dim) + status, ids = connect.add_vectors(table, vector) + status = connect.create_index(param['table_name'], index_params) + connect.delete_table(param['table_name']) + assert status.OK() + + @pytest.mark.timeout(ADD_TIMEOUT) + def test_add_vector_sleep_create_index(self, connect, table): + ''' + target: test build index add after vector for a while + method: add vector and build index + expected: status ok + ''' + vector = gen_single_vector(dim) + status, ids = connect.add_vectors(table, vector) + time.sleep(1) + status = connect.create_index(table, index_params) + assert status.OK() + + @pytest.mark.timeout(ADD_TIMEOUT) + def test_add_vector_sleep_create_index_another(self, connect, table): + ''' + target: test add vector to table_2 after build index for table_1 for a while + method: build index and add vector + expected: status ok + ''' + param = {'table_name': 'test_add_vector_sleep_create_index_another', + 'dimension': dim, + 'index_file_size': index_file_size, + 'metric_type': MetricType.L2} + status = connect.create_table(param) + vector = gen_single_vector(dim) + status, ids = connect.add_vectors(table, vector) + time.sleep(1) 
+ status = connect.create_index(param['table_name'], index_params) + connect.delete_table(param['table_name']) + assert status.OK() + + @pytest.mark.timeout(ADD_TIMEOUT) + def test_search_vector_add_vector(self, connect, table): + ''' + target: test add vector after search table + method: search table and add vector + expected: status ok + ''' + vector = gen_single_vector(dim) + status, result = connect.search_vectors(table, 1, nprobe, vector) + status, ids = connect.add_vectors(table, vector) + assert status.OK() + + @pytest.mark.timeout(ADD_TIMEOUT) + def test_search_vector_add_vector_another(self, connect, table): + ''' + target: test add vector to table_1 after search table_2 + method: search table and add vector + expected: status ok + ''' + param = {'table_name': 'test_search_vector_add_vector_another', + 'dimension': dim, + 'index_file_size': index_file_size, + 'metric_type': MetricType.L2} + status = connect.create_table(param) + vector = gen_single_vector(dim) + status, result = connect.search_vectors(table, 1, nprobe, vector) + status, ids = connect.add_vectors(param['table_name'], vector) + connect.delete_table(param['table_name']) + assert status.OK() + + @pytest.mark.timeout(ADD_TIMEOUT) + def test_add_vector_search_vector(self, connect, table): + ''' + target: test search vector after add vector + method: add vector and search table + expected: status ok + ''' + vector = gen_single_vector(dim) + status, ids = connect.add_vectors(table, vector) + status, result = connect.search_vectors(table, 1, nprobe, vector) + assert status.OK() + + @pytest.mark.timeout(ADD_TIMEOUT) + def test_add_vector_search_vector_another(self, connect, table): + ''' + target: test add vector to table_1 after search table_2 + method: search table and add vector + expected: status ok + ''' + param = {'table_name': 'test_add_vector_search_vector_another', + 'dimension': dim, + 'index_file_size': index_file_size, + 'metric_type': MetricType.L2} + status = 
connect.create_table(param) + vector = gen_single_vector(dim) + status, ids = connect.add_vectors(table, vector) + status, result = connect.search_vectors(param['table_name'], 1, nprobe, vector) + connect.delete_table(param['table_name']) + assert status.OK() + + @pytest.mark.timeout(ADD_TIMEOUT) + def test_add_vector_sleep_search_vector(self, connect, table): + ''' + target: test search vector after add vector after a while + method: add vector, sleep, and search table + expected: status ok + ''' + vector = gen_single_vector(dim) + status, ids = connect.add_vectors(table, vector) + time.sleep(1) + status, result = connect.search_vectors(table, 1, nprobe, vector) + assert status.OK() + + @pytest.mark.timeout(ADD_TIMEOUT) + def test_add_vector_sleep_search_vector_another(self, connect, table): + ''' + target: test add vector to table_1 after search table_2 a while + method: search table , sleep, and add vector + expected: status ok + ''' + param = {'table_name': 'test_add_vector_sleep_search_vector_another', + 'dimension': dim, + 'index_file_size': index_file_size, + 'metric_type': MetricType.L2} + status = connect.create_table(param) + vector = gen_single_vector(dim) + status, ids = connect.add_vectors(table, vector) + time.sleep(1) + status, result = connect.search_vectors(param['table_name'], 1, nprobe, vector) + connect.delete_table(param['table_name']) + assert status.OK() + + """ + ****************************************************************** + The following cases are used to test `add_vectors` function + ****************************************************************** + """ + + @pytest.mark.timeout(ADD_TIMEOUT) + def test_add_vectors_ids(self, connect, table): + ''' + target: test add vectors in table, use customize ids + method: create table and add vectors in it, check the ids returned and the table length after vectors added + expected: the length of ids and the table row count + ''' + nq = 5; top_k = 1; nprobe = 1 + vectors = gen_vectors(nq, dim) + 
ids = [i for i in range(nq)] + status, ids = connect.add_vectors(table, vectors, ids) + time.sleep(2) + assert status.OK() + assert len(ids) == nq + # check search result + status, result = connect.search_vectors(table, top_k, nprobe, vectors) + logging.getLogger().info(result) + assert len(result) == nq + for i in range(nq): + assert result[i][0].id == i + + @pytest.mark.timeout(ADD_TIMEOUT) + def test_add_vectors_twice_ids_no_ids(self, connect, table): + ''' + target: check the result of add_vectors, with params ids and no ids + method: test add vectors twice, use customize ids first, and then use no ids + expected: status not OK + ''' + nq = 5; top_k = 1; nprobe = 1 + vectors = gen_vectors(nq, dim) + ids = [i for i in range(nq)] + status, ids = connect.add_vectors(table, vectors, ids) + assert status.OK() + status, ids = connect.add_vectors(table, vectors) + logging.getLogger().info(status) + logging.getLogger().info(ids) + assert not status.OK() + + @pytest.mark.timeout(ADD_TIMEOUT) + def test_add_vectors_twice_not_ids_ids(self, connect, table): + ''' + target: check the result of add_vectors, with params ids and no ids + method: test add vectors twice, use not ids first, and then use customize ids + expected: status not OK + ''' + nq = 5; top_k = 1; nprobe = 1 + vectors = gen_vectors(nq, dim) + ids = [i for i in range(nq)] + status, ids = connect.add_vectors(table, vectors) + assert status.OK() + status, ids = connect.add_vectors(table, vectors, ids) + logging.getLogger().info(status) + logging.getLogger().info(ids) + assert not status.OK() + + @pytest.mark.timeout(ADD_TIMEOUT) + def test_add_vectors_ids_length_not_match(self, connect, table): + ''' + target: test add vectors in table, use customize ids, len(ids) != len(vectors) + method: create table and add vectors in it + expected: raise an exception + ''' + nq = 5 + vectors = gen_vectors(nq, dim) + ids = [i for i in range(1, nq)] + with pytest.raises(Exception) as e: + status, ids = 
connect.add_vectors(table, vectors, ids) + + @pytest.fixture( + scope="function", + params=gen_invalid_vector_ids() + ) + def get_vector_id(self, request): + yield request.param + + def test_add_vectors_ids_invalid(self, connect, table, get_vector_id): + ''' + target: test add vectors in table, use customize ids, which are not int64 + method: create table and add vectors in it + expected: raise an exception + ''' + nq = 5 + vectors = gen_vectors(nq, dim) + vector_id = get_vector_id + ids = [vector_id for i in range(nq)] + with pytest.raises(Exception) as e: + status, ids = connect.add_vectors(table, vectors, ids) + + @pytest.mark.timeout(ADD_TIMEOUT) + def test_add_vectors(self, connect, table): + ''' + target: test add vectors in table created before + method: create table and add vectors in it, check the ids returned and the table length after vectors added + expected: the length of ids and the table row count + ''' + nq = 5 + vectors = gen_vectors(nq, dim) + status, ids = connect.add_vectors(table, vectors) + assert status.OK() + assert len(ids) == nq + + @pytest.mark.level(2) + def test_add_vectors_without_connect(self, dis_connect, table): + ''' + target: test add vectors without connection + method: create table and add vectors in it, check if added successfully + expected: raise exception + ''' + nq = 5 + vectors = gen_vectors(nq, dim) + with pytest.raises(Exception) as e: + status, ids = dis_connect.add_vectors(table, vectors) + + def test_add_table_not_existed(self, connect): + ''' + target: test add vectors in table, which not existed before + method: add vectors table not existed, check the status + expected: status not ok + ''' + nq = 5 + vector = gen_single_vector(dim) + status, ids = connect.add_vectors(gen_unique_str("not_exist_table"), vector) + assert not status.OK() + assert not ids + + def test_add_vector_dim_not_matched(self, connect, table): + ''' + target: test add vector, the vector dimension is not equal to the table dimension + method: the 
vector dimension is half of the table dimension, check the status + expected: status not ok + ''' + vector = gen_single_vector(int(dim)//2) + status, ids = connect.add_vectors(table, vector) + assert not status.OK() + + def test_add_vectors_dim_not_matched(self, connect, table): + ''' + target: test add vectors, the vector dimension is not equal to the table dimension + method: the vectors dimension is half of the table dimension, check the status + expected: status not ok + ''' + nq = 5 + vectors = gen_vectors(nq, int(dim)//2) + status, ids = connect.add_vectors(table, vectors) + assert not status.OK() + + def test_add_vector_query_after_sleep(self, connect, table): + ''' + target: test add vectors, and search it after sleep + method: set vector[0][1] as query vectors + expected: status ok and result length is 1 + ''' + nq = 5 + vectors = gen_vectors(nq, dim) + status, ids = connect.add_vectors(table, vectors) + time.sleep(3) + status, result = connect.search_vectors(table, 1, nprobe, [vectors[0]]) + assert status.OK() + assert len(result) == 1 + + # @pytest.mark.repeat(5) + @pytest.mark.timeout(ADD_TIMEOUT) + def _test_add_vector_multi_threading(self, connect, table): + ''' + target: test add vectors, with multi threading + method: 10 thread add vectors concurrently + expected: status ok and result length is equal to the length off added vectors + ''' + thread_num = 4 + loops = 100 + threads = [] + total_ids = [] + vector = gen_single_vector(dim) + def add(): + i = 0 + while i < loops: + status, ids = connect.add_vectors(table, vector) + total_ids.append(ids[0]) + i = i + 1 + for i in range(thread_num): + x = threading.Thread(target=add, args=()) + threads.append(x) + x.start() + time.sleep(0.2) + for th in threads: + th.join() + assert len(total_ids) == thread_num * loops + # make sure ids not the same + assert len(set(total_ids)) == thread_num * loops + + # TODO: enable + # @pytest.mark.repeat(5) + @pytest.mark.timeout(ADD_TIMEOUT) + def 
_test_add_vector_with_multiprocessing(self, args): + ''' + target: test add vectors, with multi processes + method: 10 processed add vectors concurrently + expected: status ok and result length is equal to the length off added vectors + ''' + table = gen_unique_str("test_add_vector_with_multiprocessing") + uri = "tcp://%s:%s" % (args["ip"], args["port"]) + param = {'table_name': table, + 'dimension': dim, + 'index_file_size': index_file_size} + # create table + milvus = Milvus() + milvus.connect(uri=uri) + milvus.create_table(param) + vector = gen_single_vector(dim) + + process_num = 4 + loop_num = 10 + processes = [] + # with dependent connection + def add(milvus): + i = 0 + while i < loop_num: + status, ids = milvus.add_vectors(table, vector) + i = i + 1 + for i in range(process_num): + milvus = Milvus() + milvus.connect(uri=uri) + p = Process(target=add, args=(milvus,)) + processes.append(p) + p.start() + time.sleep(0.2) + for p in processes: + p.join() + time.sleep(3) + status, count = milvus.get_table_row_count(table) + assert count == process_num * loop_num + + def test_add_vector_multi_tables(self, connect): + ''' + target: test add vectors is correct or not with multiple tables of L2 + method: create 50 tables and add vectors into them in turn + expected: status ok + ''' + nq = 100 + vectors = gen_vectors(nq, dim) + table_list = [] + for i in range(50): + table_name = gen_unique_str('test_add_vector_multi_tables') + table_list.append(table_name) + param = {'table_name': table_name, + 'dimension': dim, + 'index_file_size': index_file_size, + 'metric_type': MetricType.L2} + connect.create_table(param) + time.sleep(2) + for j in range(10): + for i in range(50): + status, ids = connect.add_vectors(table_name=table_list[i], records=vectors) + assert status.OK() + +class TestAddIP: + """ + ****************************************************************** + The following cases are used to test `add_vectors / index / search / delete` mixed function + 
****************************************************************** + """ + + def test_add_vector_create_table(self, connect, ip_table): + ''' + target: test add vector, then create table again + method: add vector and create table + expected: status not ok + ''' + vector = gen_single_vector(dim) + status, ids = connect.add_vectors(ip_table, vector) + param = {'table_name': ip_table, + 'dimension': dim, + 'index_file_size': index_file_size, + 'metric_type': MetricType.L2} + status = connect.create_table(param) + assert not status.OK() + + def test_add_vector_has_table(self, connect, ip_table): + ''' + target: test add vector, then check table existence + method: add vector and call HasTable + expected: table exists, status ok + ''' + vector = gen_single_vector(dim) + status, ids = connect.add_vectors(ip_table, vector) + ret = connect.has_table(ip_table) + assert ret == True + + @pytest.mark.timeout(ADD_TIMEOUT) + def test_delete_table_add_vector(self, connect, ip_table): + ''' + target: test add vector after table deleted + method: delete table and add vector + expected: status not ok + ''' + status = connect.delete_table(ip_table) + vector = gen_single_vector(dim) + status, ids = connect.add_vectors(ip_table, vector) + assert not status.OK() + + @pytest.mark.timeout(ADD_TIMEOUT) + def test_delete_table_add_vector_another(self, connect, ip_table): + ''' + target: test add vector to table_1 after table_2 deleted + method: delete table_2 and add vector to table_1 + expected: status ok + ''' + param = {'table_name': 'test_delete_table_add_vector_another', + 'dimension': dim, + 'index_file_size': index_file_size, + 'metric_type': MetricType.L2} + status = connect.create_table(param) + status = connect.delete_table(ip_table) + vector = gen_single_vector(dim) + status, ids = connect.add_vectors(param['table_name'], vector) + connect.delete_table(param['table_name']) + assert status.OK() + + @pytest.mark.timeout(ADD_TIMEOUT) + def test_add_vector_delete_table(self, 
connect, ip_table): + ''' + target: test delete table after add vector + method: add vector and delete table + expected: status ok + ''' + vector = gen_single_vector(dim) + status, ids = connect.add_vectors(ip_table, vector) + status = connect.delete_table(ip_table) + assert status.OK() + + @pytest.mark.timeout(ADD_TIMEOUT) + def test_add_vector_delete_another_table(self, connect, ip_table): + ''' + target: test delete table_1 table after add vector to table_2 + method: add vector and delete table + expected: status ok + ''' + param = {'table_name': 'test_add_vector_delete_another_table', + 'dimension': dim, + 'index_file_size': index_file_size, + 'metric_type': MetricType.L2} + status = connect.create_table(param) + vector = gen_single_vector(dim) + status, ids = connect.add_vectors(ip_table, vector) + status = connect.delete_table(param['table_name']) + assert status.OK() + + @pytest.mark.timeout(ADD_TIMEOUT) + def test_add_vector_sleep_delete_table(self, connect, ip_table): + ''' + target: test delete table after add vector for a while + method: add vector, sleep, and delete table + expected: status ok + ''' + vector = gen_single_vector(dim) + status, ids = connect.add_vectors(ip_table, vector) + time.sleep(1) + status = connect.delete_table(ip_table) + assert status.OK() + + @pytest.mark.timeout(ADD_TIMEOUT) + def test_add_vector_sleep_delete_another_table(self, connect, ip_table): + ''' + target: test delete table_1 table after add vector to table_2 for a while + method: add vector , sleep, and delete table + expected: status ok + ''' + param = {'table_name': 'test_add_vector_sleep_delete_another_table', + 'dimension': dim, + 'index_file_size': index_file_size, + 'metric_type': MetricType.L2} + status = connect.create_table(param) + vector = gen_single_vector(dim) + status, ids = connect.add_vectors(ip_table, vector) + time.sleep(1) + status = connect.delete_table(param['table_name']) + assert status.OK() + + @pytest.mark.timeout(ADD_TIMEOUT) + def 
test_create_index_add_vector(self, connect, ip_table): + ''' + target: test add vector after build index + method: build index and add vector + expected: status ok + ''' + status = connect.create_index(ip_table, index_params) + vector = gen_single_vector(dim) + status, ids = connect.add_vectors(ip_table, vector) + assert status.OK() + + @pytest.mark.timeout(ADD_TIMEOUT) + def test_create_index_add_vector_another(self, connect, ip_table): + ''' + target: test add vector to table_2 after build index for table_1 + method: build index and add vector + expected: status ok + ''' + param = {'table_name': 'test_create_index_add_vector_another', + 'dimension': dim, + 'index_file_size': index_file_size, + 'metric_type': MetricType.L2} + status = connect.create_table(param) + status = connect.create_index(ip_table, index_params) + vector = gen_single_vector(dim) + status, ids = connect.add_vectors(ip_table, vector) + connect.delete_table(param['table_name']) + assert status.OK() + + @pytest.mark.timeout(ADD_TIMEOUT) + def test_add_vector_create_index(self, connect, ip_table): + ''' + target: test build index add after vector + method: add vector and build index + expected: status ok + ''' + vector = gen_single_vector(dim) + status, ids = connect.add_vectors(ip_table, vector) + status = connect.create_index(ip_table, index_params) + assert status.OK() + + @pytest.mark.timeout(ADD_TIMEOUT) + def test_add_vector_create_index_another(self, connect, ip_table): + ''' + target: test add vector to table_2 after build index for table_1 + method: build index and add vector + expected: status ok + ''' + param = {'table_name': 'test_add_vector_create_index_another', + 'dimension': dim, + 'index_file_size': index_file_size, + 'metric_type': MetricType.L2} + status = connect.create_table(param) + vector = gen_single_vector(dim) + status, ids = connect.add_vectors(ip_table, vector) + status = connect.create_index(param['table_name'], index_params) + connect.delete_table(param['table_name']) 
+ assert status.OK() + + @pytest.mark.timeout(ADD_TIMEOUT) + def test_add_vector_sleep_create_index(self, connect, ip_table): + ''' + target: test build index add after vector for a while + method: add vector and build index + expected: status ok + ''' + vector = gen_single_vector(dim) + status, ids = connect.add_vectors(ip_table, vector) + time.sleep(1) + status = connect.create_index(ip_table, index_params) + assert status.OK() + + @pytest.mark.timeout(ADD_TIMEOUT) + def test_add_vector_sleep_create_index_another(self, connect, ip_table): + ''' + target: test add vector to table_2 after build index for table_1 for a while + method: build index and add vector + expected: status ok + ''' + param = {'table_name': 'test_add_vector_sleep_create_index_another', + 'dimension': dim, + 'index_file_size': index_file_size, + 'metric_type': MetricType.L2} + status = connect.create_table(param) + vector = gen_single_vector(dim) + status, ids = connect.add_vectors(ip_table, vector) + time.sleep(1) + status = connect.create_index(param['table_name'], index_params) + connect.delete_table(param['table_name']) + assert status.OK() + + @pytest.mark.timeout(ADD_TIMEOUT) + def test_search_vector_add_vector(self, connect, ip_table): + ''' + target: test add vector after search table + method: search table and add vector + expected: status ok + ''' + vector = gen_single_vector(dim) + status, result = connect.search_vectors(ip_table, 1, nprobe, vector) + status, ids = connect.add_vectors(ip_table, vector) + assert status.OK() + + @pytest.mark.timeout(ADD_TIMEOUT) + def test_search_vector_add_vector_another(self, connect, ip_table): + ''' + target: test add vector to table_1 after search table_2 + method: search table and add vector + expected: status ok + ''' + param = {'table_name': 'test_search_vector_add_vector_another', + 'dimension': dim, + 'index_file_size': index_file_size, + 'metric_type': MetricType.L2} + status = connect.create_table(param) + vector = gen_single_vector(dim) + 
status, result = connect.search_vectors(ip_table, 1, nprobe, vector) + status, ids = connect.add_vectors(param['table_name'], vector) + connect.delete_table(param['table_name']) + assert status.OK() + + @pytest.mark.timeout(ADD_TIMEOUT) + def test_add_vector_search_vector(self, connect, ip_table): + ''' + target: test search vector after add vector + method: add vector and search table + expected: status ok + ''' + vector = gen_single_vector(dim) + status, ids = connect.add_vectors(ip_table, vector) + status, result = connect.search_vectors(ip_table, 1, nprobe, vector) + assert status.OK() + + @pytest.mark.timeout(ADD_TIMEOUT) + def test_add_vector_search_vector_another(self, connect, ip_table): + ''' + target: test add vector to table_1 after search table_2 + method: search table and add vector + expected: status ok + ''' + param = {'table_name': 'test_add_vector_search_vector_another', + 'dimension': dim, + 'index_file_size': index_file_size, + 'metric_type': MetricType.L2} + status = connect.create_table(param) + vector = gen_single_vector(dim) + status, ids = connect.add_vectors(ip_table, vector) + status, result = connect.search_vectors(param['table_name'], 1, nprobe, vector) + connect.delete_table(param['table_name']) + assert status.OK() + + @pytest.mark.timeout(ADD_TIMEOUT) + def test_add_vector_sleep_search_vector(self, connect, ip_table): + ''' + target: test search vector after add vector after a while + method: add vector, sleep, and search table + expected: status ok + ''' + vector = gen_single_vector(dim) + status, ids = connect.add_vectors(ip_table, vector) + time.sleep(1) + status, result = connect.search_vectors(ip_table, 1, nprobe, vector) + assert status.OK() + + @pytest.mark.timeout(ADD_TIMEOUT) + def test_add_vector_sleep_search_vector_another(self, connect, ip_table): + ''' + target: test add vector to table_1 after search table_2 a while + method: search table , sleep, and add vector + expected: status ok + ''' + param = {'table_name': 
'test_add_vector_sleep_search_vector_another', + 'dimension': dim, + 'index_file_size': index_file_size, + 'metric_type': MetricType.L2} + status = connect.create_table(param) + vector = gen_single_vector(dim) + status, ids = connect.add_vectors(ip_table, vector) + time.sleep(1) + status, result = connect.search_vectors(param['table_name'], 1, nprobe, vector) + connect.delete_table(param['table_name']) + assert status.OK() + + """ + ****************************************************************** + The following cases are used to test `add_vectors` function + ****************************************************************** + """ + + @pytest.mark.timeout(ADD_TIMEOUT) + def test_add_vectors_ids(self, connect, ip_table): + ''' + target: test add vectors in table, use customize ids + method: create table and add vectors in it, check the ids returned and the table length after vectors added + expected: the length of ids and the table row count + ''' + nq = 5; top_k = 1; nprobe = 1 + vectors = gen_vectors(nq, dim) + ids = [i for i in range(nq)] + status, ids = connect.add_vectors(ip_table, vectors, ids) + time.sleep(2) + assert status.OK() + assert len(ids) == nq + # check search result + status, result = connect.search_vectors(ip_table, top_k, nprobe, vectors) + logging.getLogger().info(result) + assert len(result) == nq + for i in range(nq): + assert result[i][0].id == i + + @pytest.mark.timeout(ADD_TIMEOUT) + def test_add_vectors_twice_ids_no_ids(self, connect, ip_table): + ''' + target: check the result of add_vectors, with params ids and no ids + method: test add vectors twice, use customize ids first, and then use no ids + expected: status not OK + ''' + nq = 5; top_k = 1; nprobe = 1 + vectors = gen_vectors(nq, dim) + ids = [i for i in range(nq)] + status, ids = connect.add_vectors(ip_table, vectors, ids) + assert status.OK() + status, ids = connect.add_vectors(ip_table, vectors) + logging.getLogger().info(status) + logging.getLogger().info(ids) + assert not 
status.OK() + + @pytest.mark.timeout(ADD_TIMEOUT) + def test_add_vectors_twice_not_ids_ids(self, connect, ip_table): + ''' + target: check the result of add_vectors, with params ids and no ids + method: test add vectors twice, use not ids first, and then use customize ids + expected: status not OK + ''' + nq = 5; top_k = 1; nprobe = 1 + vectors = gen_vectors(nq, dim) + ids = [i for i in range(nq)] + status, ids = connect.add_vectors(ip_table, vectors) + assert status.OK() + status, ids = connect.add_vectors(ip_table, vectors, ids) + logging.getLogger().info(status) + logging.getLogger().info(ids) + assert not status.OK() + + @pytest.mark.timeout(ADD_TIMEOUT) + def test_add_vectors_ids_length_not_match(self, connect, ip_table): + ''' + target: test add vectors in table, use customize ids, len(ids) != len(vectors) + method: create table and add vectors in it + expected: raise an exception + ''' + nq = 5 + vectors = gen_vectors(nq, dim) + ids = [i for i in range(1, nq)] + with pytest.raises(Exception) as e: + status, ids = connect.add_vectors(ip_table, vectors, ids) + + @pytest.fixture( + scope="function", + params=gen_invalid_vector_ids() + ) + def get_vector_id(self, request): + yield request.param + + def test_add_vectors_ids_invalid(self, connect, ip_table, get_vector_id): + ''' + target: test add vectors in table, use customize ids, which are not int64 + method: create table and add vectors in it + expected: raise an exception + ''' + nq = 5 + vectors = gen_vectors(nq, dim) + vector_id = get_vector_id + ids = [vector_id for i in range(nq)] + with pytest.raises(Exception) as e: + status, ids = connect.add_vectors(ip_table, vectors, ids) + + @pytest.mark.timeout(ADD_TIMEOUT) + def test_add_vectors(self, connect, ip_table): + ''' + target: test add vectors in table created before + method: create table and add vectors in it, check the ids returned and the table length after vectors added + expected: the length of ids and the table row count + ''' + nq = 5 + vectors 
= gen_vectors(nq, dim) + status, ids = connect.add_vectors(ip_table, vectors) + assert status.OK() + assert len(ids) == nq + + @pytest.mark.level(2) + def test_add_vectors_without_connect(self, dis_connect, ip_table): + ''' + target: test add vectors without connection + method: create table and add vectors in it, check if added successfully + expected: raise exception + ''' + nq = 5 + vectors = gen_vectors(nq, dim) + with pytest.raises(Exception) as e: + status, ids = dis_connect.add_vectors(ip_table, vectors) + + def test_add_vector_dim_not_matched(self, connect, ip_table): + ''' + target: test add vector, the vector dimension is not equal to the table dimension + method: the vector dimension is half of the table dimension, check the status + expected: status not ok + ''' + vector = gen_single_vector(int(dim)//2) + status, ids = connect.add_vectors(ip_table, vector) + assert not status.OK() + + def test_add_vectors_dim_not_matched(self, connect, ip_table): + ''' + target: test add vectors, the vector dimension is not equal to the table dimension + method: the vectors dimension is half of the table dimension, check the status + expected: status not ok + ''' + nq = 5 + vectors = gen_vectors(nq, int(dim)//2) + status, ids = connect.add_vectors(ip_table, vectors) + assert not status.OK() + + def test_add_vector_query_after_sleep(self, connect, ip_table): + ''' + target: test add vectors, and search it after sleep + method: set vector[0][1] as query vectors + expected: status ok and result length is 1 + ''' + nq = 5 + vectors = gen_vectors(nq, dim) + status, ids = connect.add_vectors(ip_table, vectors) + time.sleep(3) + status, result = connect.search_vectors(ip_table, 1, nprobe, [vectors[0]]) + assert status.OK() + assert len(result) == 1 + + # @pytest.mark.repeat(5) + @pytest.mark.timeout(ADD_TIMEOUT) + def _test_add_vector_multi_threading(self, connect, ip_table): + ''' + target: test add vectors, with multi threading + method: 10 thread add vectors concurrently + 
expected: status ok and result length is equal to the length off added vectors + ''' + thread_num = 4 + loops = 100 + threads = [] + total_ids = [] + vector = gen_single_vector(dim) + def add(): + i = 0 + while i < loops: + status, ids = connect.add_vectors(ip_table, vector) + total_ids.append(ids[0]) + i = i + 1 + for i in range(thread_num): + x = threading.Thread(target=add, args=()) + threads.append(x) + x.start() + time.sleep(0.2) + for th in threads: + th.join() + assert len(total_ids) == thread_num * loops + # make sure ids not the same + assert len(set(total_ids)) == thread_num * loops + + # TODO: enable + # @pytest.mark.repeat(5) + @pytest.mark.timeout(ADD_TIMEOUT) + def _test_add_vector_with_multiprocessing(self, args): + ''' + target: test add vectors, with multi processes + method: 10 processed add vectors concurrently + expected: status ok and result length is equal to the length off added vectors + ''' + table = gen_unique_str("test_add_vector_with_multiprocessing") + uri = "tcp://%s:%s" % (args["ip"], args["port"]) + param = {'table_name': table, + 'dimension': dim, + 'index_file_size': index_file_size} + # create table + milvus = Milvus() + milvus.connect(uri=uri) + milvus.create_table(param) + vector = gen_single_vector(dim) + + process_num = 4 + loop_num = 10 + processes = [] + # with dependent connection + def add(milvus): + i = 0 + while i < loop_num: + status, ids = milvus.add_vectors(table, vector) + i = i + 1 + for i in range(process_num): + milvus = Milvus() + milvus.connect(uri=uri) + p = Process(target=add, args=(milvus,)) + processes.append(p) + p.start() + time.sleep(0.2) + for p in processes: + p.join() + time.sleep(3) + status, count = milvus.get_table_row_count(table) + assert count == process_num * loop_num + + def test_add_vector_multi_tables(self, connect): + ''' + target: test add vectors is correct or not with multiple tables of IP + method: create 50 tables and add vectors into them in turn + expected: status ok + ''' + nq = 100 
+ vectors = gen_vectors(nq, dim) + table_list = [] + for i in range(50): + table_name = gen_unique_str('test_add_vector_multi_tables') + table_list.append(table_name) + param = {'table_name': table_name, + 'dimension': dim, + 'index_file_size': index_file_size, + 'metric_type': MetricType.IP} + connect.create_table(param) + time.sleep(2) + for j in range(10): + for i in range(50): + status, ids = connect.add_vectors(table_name=table_list[i], records=vectors) + assert status.OK() + +class TestAddAdvance: + + @pytest.fixture( + scope="function", + params=[ + 1, + 10, + 100, + 1000, + pytest.param(5000 - 1, marks=pytest.mark.xfail), + pytest.param(5000, marks=pytest.mark.xfail), + pytest.param(5000 + 1, marks=pytest.mark.xfail), + ], + ) + def insert_count(self, request): + yield request.param + + def test_insert_much(self, connect, table, insert_count): + ''' + target: test add vectors with different length of vectors + method: set different vectors as add method params + expected: length of ids is equal to the length of vectors + ''' + nb = insert_count + insert_vec_list = gen_vectors(nb, dim) + status, ids = connect.add_vectors(table, insert_vec_list) + assert len(ids) == nb + assert status.OK() + + def test_insert_much_ip(self, connect, ip_table, insert_count): + ''' + target: test add vectors with different length of vectors + method: set different vectors as add method params + expected: length of ids is equal to the length of vectors + ''' + nb = insert_count + insert_vec_list = gen_vectors(nb, dim) + status, ids = connect.add_vectors(ip_table, insert_vec_list) + assert len(ids) == nb + assert status.OK() + +class TestAddTableNameInvalid(object): + """ + Test adding vectors with invalid table names + """ + @pytest.fixture( + scope="function", + params=gen_invalid_table_names() + ) + def get_table_name(self, request): + yield request.param + + @pytest.mark.level(2) + def test_add_vectors_with_invalid_tablename(self, connect, get_table_name): + table_name = 
get_table_name + vectors = gen_vectors(1, dim) + status, result = connect.add_vectors(table_name, vectors) + assert not status.OK() + + +class TestAddTableVectorsInvalid(object): + single_vector = gen_single_vector(dim) + vectors = gen_vectors(2, dim) + + """ + Test adding vectors with invalid vectors + """ + @pytest.fixture( + scope="function", + params=gen_invalid_vectors() + ) + def gen_vector(self, request): + yield request.param + + @pytest.mark.level(2) + def test_add_vector_with_invalid_vectors(self, connect, table, gen_vector): + tmp_single_vector = copy.deepcopy(self.single_vector) + tmp_single_vector[0][1] = gen_vector + with pytest.raises(Exception) as e: + status, result = connect.add_vectors(table, tmp_single_vector) + + @pytest.mark.level(1) + def test_add_vectors_with_invalid_vectors(self, connect, table, gen_vector): + tmp_vectors = copy.deepcopy(self.vectors) + tmp_vectors[1][1] = gen_vector + with pytest.raises(Exception) as e: + status, result = connect.add_vectors(table, tmp_vectors) \ No newline at end of file diff --git a/tests/milvus_python_test/test_connect.py b/tests/milvus_python_test/test_connect.py new file mode 100644 index 0000000000..5ec9539011 --- /dev/null +++ b/tests/milvus_python_test/test_connect.py @@ -0,0 +1,386 @@ +import pytest +from milvus import Milvus +import pdb +import threading +from multiprocessing import Process +from utils import * + +__version__ = '0.5.0' +CONNECT_TIMEOUT = 12 + + +class TestConnect: + + def local_ip(self, args): + ''' + check if ip is localhost or not + ''' + if not args["ip"] or args["ip"] == 'localhost' or args["ip"] == "127.0.0.1": + return True + else: + return False + + def test_disconnect(self, connect): + ''' + target: test disconnect + method: disconnect a connected client + expected: connect failed after disconnected + ''' + res = connect.disconnect() + assert res.OK() + with pytest.raises(Exception) as e: + res = connect.server_version() + + def test_disconnect_repeatedly(self, connect, 
args): + ''' + target: test disconnect repeatedly + method: disconnect a connected client, disconnect again + expected: raise an error after disconnected + ''' + if not connect.connected(): + milvus = Milvus() + uri_value = "tcp://%s:%s" % (args["ip"], args["port"]) + milvus.connect(uri=uri_value) + res = milvus.disconnect() + with pytest.raises(Exception) as e: + res = milvus.disconnect() + else: + res = connect.disconnect() + with pytest.raises(Exception) as e: + res = connect.disconnect() + + def test_connect_correct_ip_port(self, args): + ''' + target: test connect with corrent ip and port value + method: set correct ip and port + expected: connected is True + ''' + milvus = Milvus() + milvus.connect(host=args["ip"], port=args["port"]) + assert milvus.connected() + + def test_connect_connected(self, args): + ''' + target: test connect and disconnect with corrent ip and port value, assert connected value + method: set correct ip and port + expected: connected is False + ''' + milvus = Milvus() + milvus.connect(host=args["ip"], port=args["port"]) + milvus.disconnect() + assert not milvus.connected() + + # TODO: Currently we test with remote IP, localhost testing need to add + def _test_connect_ip_localhost(self, args): + ''' + target: test connect with ip value: localhost + method: set host localhost + expected: connected is True + ''' + milvus = Milvus() + milvus.connect(host='localhost', port=args["port"]) + assert milvus.connected() + + @pytest.mark.timeout(CONNECT_TIMEOUT) + def test_connect_wrong_ip_null(self, args): + ''' + target: test connect with wrong ip value + method: set host null + expected: not use default ip, connected is False + ''' + milvus = Milvus() + ip = "" + with pytest.raises(Exception) as e: + milvus.connect(host=ip, port=args["port"], timeout=1) + assert not milvus.connected() + + def test_connect_uri(self, args): + ''' + target: test connect with correct uri + method: uri format and value are both correct + expected: connected is True + 
''' + milvus = Milvus() + uri_value = "tcp://%s:%s" % (args["ip"], args["port"]) + milvus.connect(uri=uri_value) + assert milvus.connected() + + def test_connect_uri_null(self, args): + ''' + target: test connect with null uri + method: uri set null + expected: connected is True + ''' + milvus = Milvus() + uri_value = "" + + if self.local_ip(args): + milvus.connect(uri=uri_value, timeout=1) + assert milvus.connected() + else: + with pytest.raises(Exception) as e: + milvus.connect(uri=uri_value, timeout=1) + assert not milvus.connected() + + @pytest.mark.level(2) + @pytest.mark.timeout(CONNECT_TIMEOUT) + def test_connect_wrong_uri_wrong_port_null(self, args): + ''' + target: test uri connect with port value wouldn't connected + method: set uri port null + expected: connected is True + ''' + milvus = Milvus() + uri_value = "tcp://%s:" % args["ip"] + with pytest.raises(Exception) as e: + milvus.connect(uri=uri_value, timeout=1) + + @pytest.mark.level(2) + @pytest.mark.timeout(CONNECT_TIMEOUT) + def test_connect_wrong_uri_wrong_ip_null(self, args): + ''' + target: test uri connect with ip value wouldn't connected + method: set uri ip null + expected: connected is True + ''' + milvus = Milvus() + uri_value = "tcp://:%s" % args["port"] + + with pytest.raises(Exception) as e: + milvus.connect(uri=uri_value, timeout=1) + assert not milvus.connected() + + # TODO: enable + def _test_connect_with_multiprocess(self, args): + ''' + target: test uri connect with multiprocess + method: set correct uri, test with multiprocessing connecting + expected: all connection is connected + ''' + uri_value = "tcp://%s:%s" % (args["ip"], args["port"]) + process_num = 4 + processes = [] + + def connect(milvus): + milvus.connect(uri=uri_value) + with pytest.raises(Exception) as e: + milvus.connect(uri=uri_value) + assert milvus.connected() + + for i in range(process_num): + milvus = Milvus() + p = Process(target=connect, args=(milvus, )) + processes.append(p) + p.start() + for p in processes: 
+ p.join() + + def test_connect_repeatedly(self, args): + ''' + target: test connect repeatedly + method: connect again + expected: status.code is 0, and status.message shows have connected already + ''' + milvus = Milvus() + uri_value = "tcp://%s:%s" % (args["ip"], args["port"]) + milvus.connect(uri=uri_value) + + milvus.connect(uri=uri_value) + assert milvus.connected() + + def test_connect_disconnect_repeatedly_once(self, args): + ''' + target: test connect and disconnect repeatedly + method: disconnect, and then connect, assert connect status + expected: status.code is 0 + ''' + milvus = Milvus() + uri_value = "tcp://%s:%s" % (args["ip"], args["port"]) + milvus.connect(uri=uri_value) + + milvus.disconnect() + milvus.connect(uri=uri_value) + assert milvus.connected() + + def test_connect_disconnect_repeatedly_times(self, args): + ''' + target: test connect and disconnect for 10 times repeatedly + method: disconnect, and then connect, assert connect status + expected: status.code is 0 + ''' + times = 10 + milvus = Milvus() + uri_value = "tcp://%s:%s" % (args["ip"], args["port"]) + milvus.connect(uri=uri_value) + for i in range(times): + milvus.disconnect() + milvus.connect(uri=uri_value) + assert milvus.connected() + + # TODO: enable + def _test_connect_disconnect_with_multiprocess(self, args): + ''' + target: test uri connect and disconnect repeatly with multiprocess + method: set correct uri, test with multiprocessing connecting and disconnecting + expected: all connection is connected after 10 times operation + ''' + uri_value = "tcp://%s:%s" % (args["ip"], args["port"]) + process_num = 4 + processes = [] + + def connect(milvus): + milvus.connect(uri=uri_value) + milvus.disconnect() + milvus.connect(uri=uri_value) + assert milvus.connected() + + for i in range(process_num): + milvus = Milvus() + p = Process(target=connect, args=(milvus, )) + processes.append(p) + p.start() + for p in processes: + p.join() + + def test_connect_param_priority_no_port(self, 
args): + ''' + target: both host_ip_port / uri are both given, if port is null, use the uri params + method: port set "", check if wrong uri connection is ok + expected: connect raise an exception and connected is false + ''' + milvus = Milvus() + uri_value = "tcp://%s:19540" % args["ip"] + milvus.connect(host=args["ip"], port="", uri=uri_value) + assert milvus.connected() + + def test_connect_param_priority_uri(self, args): + ''' + target: both host_ip_port / uri are both given, if host is null, use the uri params + method: host set "", check if correct uri connection is ok + expected: connected is False + ''' + milvus = Milvus() + uri_value = "tcp://%s:%s" % (args["ip"], args["port"]) + with pytest.raises(Exception) as e: + milvus.connect(host="", port=args["port"], uri=uri_value, timeout=1) + assert not milvus.connected() + + def test_connect_param_priority_both_hostip_uri(self, args): + ''' + target: both host_ip_port / uri are both given, and not null, use the uri params + method: check if wrong uri connection is ok + expected: connect raise an exception and connected is false + ''' + milvus = Milvus() + uri_value = "tcp://%s:%s" % (args["ip"], args["port"]) + with pytest.raises(Exception) as e: + milvus.connect(host=args["ip"], port=19540, uri=uri_value, timeout=1) + assert not milvus.connected() + + def _test_add_vector_and_disconnect_concurrently(self): + ''' + Target: test disconnect in the middle of add vectors + Method: + a. use coroutine or multi-processing, to simulate network crashing + b. data_set not too large incase disconnection happens when data is underd-preparing + c. data_set not too small incase disconnection happens when data has already been transferred + d. 
make sure disconnection happens when data is in-transport + Expected: Failure, get_table_row_count == 0 + + ''' + pass + + def _test_search_vector_and_disconnect_concurrently(self): + ''' + Target: Test disconnect in the middle of search vectors(with large nq and topk)multiple times, and search/add vectors still work + Method: + a. coroutine or multi-processing, to simulate network crashing + b. connect, search and disconnect, repeating many times + c. connect and search, add vectors + Expected: Successfully searched back, successfully added + + ''' + pass + + def _test_thread_safe_with_one_connection_shared_in_multi_threads(self): + ''' + Target: test 1 connection thread safe + Method: 1 connection shared in multi-threads, all adding vectors, or other things + Expected: Functional as one thread + + ''' + pass + + +class TestConnectIPInvalid(object): + """ + Test connect server with invalid ip + """ + @pytest.fixture( + scope="function", + params=gen_invalid_ips() + ) + def get_invalid_ip(self, request): + yield request.param + + @pytest.mark.level(2) + @pytest.mark.timeout(CONNECT_TIMEOUT) + def test_connect_with_invalid_ip(self, args, get_invalid_ip): + milvus = Milvus() + ip = get_invalid_ip + with pytest.raises(Exception) as e: + milvus.connect(host=ip, port=args["port"], timeout=1) + assert not milvus.connected() + + +class TestConnectPortInvalid(object): + """ + Test connect server with invalid ip + """ + + @pytest.fixture( + scope="function", + params=gen_invalid_ports() + ) + def get_invalid_port(self, request): + yield request.param + + @pytest.mark.level(2) + @pytest.mark.timeout(CONNECT_TIMEOUT) + def test_connect_with_invalid_port(self, args, get_invalid_port): + ''' + target: test ip:port connect with invalid port value + method: set port in gen_invalid_ports + expected: connected is False + ''' + milvus = Milvus() + port = get_invalid_port + with pytest.raises(Exception) as e: + milvus.connect(host=args["ip"], port=port, timeout=1) + assert not 
milvus.connected() + + +class TestConnectURIInvalid(object): + """ + Test connect server with invalid uri + """ + @pytest.fixture( + scope="function", + params=gen_invalid_uris() + ) + def get_invalid_uri(self, request): + yield request.param + + @pytest.mark.level(2) + @pytest.mark.timeout(CONNECT_TIMEOUT) + def test_connect_with_invalid_uri(self, get_invalid_uri): + ''' + target: test uri connect with invalid uri value + method: set port in gen_invalid_uris + expected: connected is False + ''' + milvus = Milvus() + uri_value = get_invalid_uri + with pytest.raises(Exception) as e: + milvus.connect(uri=uri_value, timeout=1) + assert not milvus.connected() diff --git a/tests/milvus_python_test/test_delete_vectors.py b/tests/milvus_python_test/test_delete_vectors.py new file mode 100644 index 0000000000..cb450838a2 --- /dev/null +++ b/tests/milvus_python_test/test_delete_vectors.py @@ -0,0 +1,419 @@ +import time +import random +import pdb +import logging +import threading +from builtins import Exception +from multiprocessing import Pool, Process +import pytest + +from milvus import Milvus, IndexType +from utils import * + + +dim = 128 +index_file_size = 10 +table_id = "test_delete" +DELETE_TIMEOUT = 60 +vectors = gen_vectors(100, dim) + +class TestDeleteVectorsBase: + """ + generate invalid query range params + """ + @pytest.fixture( + scope="function", + params=[ + (get_current_day(), get_current_day()), + (get_last_day(1), get_last_day(1)), + (get_next_day(1), get_next_day(1)) + ] + ) + def get_invalid_range(self, request): + yield request.param + + @pytest.mark.timeout(DELETE_TIMEOUT) + def test_delete_vectors_invalid_range(self, connect, table, get_invalid_range): + ''' + target: test delete vectors, no index created + method: call `delete_vectors_by_range`, with invalid date params + expected: return code 0 + ''' + start_date = get_invalid_range[0] + end_date = get_invalid_range[1] + status, ids = connect.add_vectors(table, vectors) + status = 
connect.delete_vectors_by_range(table, start_date, end_date) + assert not status.OK() + + """ + generate valid query range params, no search result + """ + @pytest.fixture( + scope="function", + params=[ + (get_last_day(2), get_last_day(1)), + (get_last_day(2), get_current_day()), + (get_next_day(1), get_next_day(2)) + ] + ) + def get_valid_range_no_result(self, request): + yield request.param + + @pytest.mark.timeout(DELETE_TIMEOUT) + def test_delete_vectors_valid_range_no_result(self, connect, table, get_valid_range_no_result): + ''' + target: test delete vectors, no index created + method: call `delete_vectors_by_range`, with valid date params + expected: return code 0 + ''' + start_date = get_valid_range_no_result[0] + end_date = get_valid_range_no_result[1] + status, ids = connect.add_vectors(table, vectors) + time.sleep(2) + status = connect.delete_vectors_by_range(table, start_date, end_date) + assert status.OK() + status, result = connect.get_table_row_count(table) + assert result == 100 + + """ + generate valid query range params, no search result + """ + @pytest.fixture( + scope="function", + params=[ + (get_last_day(2), get_next_day(2)), + (get_current_day(), get_next_day(2)), + ] + ) + def get_valid_range(self, request): + yield request.param + + @pytest.mark.timeout(DELETE_TIMEOUT) + def test_delete_vectors_valid_range(self, connect, table, get_valid_range): + ''' + target: test delete vectors, no index created + method: call `delete_vectors_by_range`, with valid date params + expected: return code 0 + ''' + start_date = get_valid_range[0] + end_date = get_valid_range[1] + status, ids = connect.add_vectors(table, vectors) + time.sleep(2) + status = connect.delete_vectors_by_range(table, start_date, end_date) + assert status.OK() + status, result = connect.get_table_row_count(table) + assert result == 0 + + @pytest.fixture( + scope="function", + params=gen_index_params() + ) + def get_index_params(self, request): + yield request.param + + 
@pytest.mark.timeout(DELETE_TIMEOUT) + def test_delete_vectors_valid_range_index_created(self, connect, table, get_index_params): + ''' + target: test delete vectors, no index created + method: call `delete_vectors_by_range`, with valid date params + expected: return code 0 + ''' + start_date = get_current_day() + end_date = get_next_day(2) + index_params = get_index_params + logging.getLogger().info(index_params) + status, ids = connect.add_vectors(table, vectors) + status = connect.create_index(table, index_params) + logging.getLogger().info(status) + logging.getLogger().info("Start delete vectors by range: %s:%s" % (start_date, end_date)) + status = connect.delete_vectors_by_range(table, start_date, end_date) + assert status.OK() + status, result = connect.get_table_row_count(table) + assert result == 0 + + @pytest.mark.timeout(DELETE_TIMEOUT) + def test_delete_vectors_no_data(self, connect, table): + ''' + target: test delete vectors, no index created + method: call `delete_vectors_by_range`, with valid date params, and no data in db + expected: return code 0 + ''' + start_date = get_current_day() + end_date = get_next_day(2) + # status, ids = connect.add_vectors(table, vectors) + status = connect.delete_vectors_by_range(table, start_date, end_date) + assert status.OK() + + @pytest.mark.timeout(DELETE_TIMEOUT) + def test_delete_vectors_table_not_existed(self, connect): + ''' + target: test delete vectors, table not existed in db + method: call `delete_vectors_by_range`, with table not existed + expected: return code not 0 + ''' + start_date = get_current_day() + end_date = get_next_day(2) + table_name = gen_unique_str("not_existed_table") + status = connect.delete_vectors_by_range(table_name, start_date, end_date) + assert not status.OK() + + @pytest.mark.timeout(DELETE_TIMEOUT) + def test_delete_vectors_table_None(self, connect, table): + ''' + target: test delete vectors, table set Nope + method: call `delete_vectors_by_range`, with table value is None + 
expected: return code not 0 + ''' + start_date = get_current_day() + end_date = get_next_day(2) + table_name = None + with pytest.raises(Exception) as e: + status = connect.delete_vectors_by_range(table_name, start_date, end_date) + + @pytest.mark.timeout(DELETE_TIMEOUT) + def test_delete_vectors_valid_range_multi_tables(self, connect, get_valid_range): + ''' + target: test delete vectors is correct or not with multiple tables of L2 + method: create 50 tables and add vectors into them , then delete vectors + in valid range + expected: return code 0 + ''' + nq = 100 + vectors = gen_vectors(nq, dim) + table_list = [] + for i in range(50): + table_name = gen_unique_str('test_delete_vectors_valid_range_multi_tables') + table_list.append(table_name) + param = {'table_name': table_name, + 'dimension': dim, + 'index_file_size': index_file_size, + 'metric_type': MetricType.L2} + connect.create_table(param) + status, ids = connect.add_vectors(table_name=table_name, records=vectors) + time.sleep(2) + start_date = get_valid_range[0] + end_date = get_valid_range[1] + for i in range(50): + status = connect.delete_vectors_by_range(table_list[i], start_date, end_date) + assert status.OK() + status, result = connect.get_table_row_count(table_list[i]) + assert result == 0 + + +class TestDeleteVectorsIP: + """ + generate invalid query range params + """ + @pytest.fixture( + scope="function", + params=[ + (get_current_day(), get_current_day()), + (get_last_day(1), get_last_day(1)), + (get_next_day(1), get_next_day(1)) + ] + ) + def get_invalid_range(self, request): + yield request.param + + @pytest.mark.timeout(DELETE_TIMEOUT) + def test_delete_vectors_invalid_range(self, connect, ip_table, get_invalid_range): + ''' + target: test delete vectors, no index created + method: call `delete_vectors_by_range`, with invalid date params + expected: return code 0 + ''' + start_date = get_invalid_range[0] + end_date = get_invalid_range[1] + status, ids = connect.add_vectors(ip_table, vectors) 
+ status = connect.delete_vectors_by_range(ip_table, start_date, end_date) + assert not status.OK() + + """ + generate valid query range params, no search result + """ + @pytest.fixture( + scope="function", + params=[ + (get_last_day(2), get_last_day(1)), + (get_last_day(2), get_current_day()), + (get_next_day(1), get_next_day(2)) + ] + ) + def get_valid_range_no_result(self, request): + yield request.param + + @pytest.mark.timeout(DELETE_TIMEOUT) + def test_delete_vectors_valid_range_no_result(self, connect, ip_table, get_valid_range_no_result): + ''' + target: test delete vectors, no index created + method: call `delete_vectors_by_range`, with valid date params + expected: return code 0 + ''' + start_date = get_valid_range_no_result[0] + end_date = get_valid_range_no_result[1] + status, ids = connect.add_vectors(ip_table, vectors) + time.sleep(2) + status = connect.delete_vectors_by_range(ip_table, start_date, end_date) + assert status.OK() + status, result = connect.get_table_row_count(ip_table) + assert result == 100 + + """ + generate valid query range params, no search result + """ + @pytest.fixture( + scope="function", + params=[ + (get_last_day(2), get_next_day(2)), + (get_current_day(), get_next_day(2)), + ] + ) + def get_valid_range(self, request): + yield request.param + + @pytest.mark.timeout(DELETE_TIMEOUT) + def test_delete_vectors_valid_range(self, connect, ip_table, get_valid_range): + ''' + target: test delete vectors, no index created + method: call `delete_vectors_by_range`, with valid date params + expected: return code 0 + ''' + start_date = get_valid_range[0] + end_date = get_valid_range[1] + status, ids = connect.add_vectors(ip_table, vectors) + time.sleep(2) + status = connect.delete_vectors_by_range(ip_table, start_date, end_date) + assert status.OK() + status, result = connect.get_table_row_count(ip_table) + assert result == 0 + + @pytest.fixture( + scope="function", + params=gen_index_params() + ) + def get_index_params(self, request): + 
yield request.param + + @pytest.mark.timeout(DELETE_TIMEOUT) + def test_delete_vectors_valid_range_index_created(self, connect, ip_table, get_index_params): + ''' + target: test delete vectors, no index created + method: call `delete_vectors_by_range`, with valid date params + expected: return code 0 + ''' + start_date = get_current_day() + end_date = get_next_day(2) + index_params = get_index_params + logging.getLogger().info(index_params) + status, ids = connect.add_vectors(ip_table, vectors) + status = connect.create_index(ip_table, index_params) + logging.getLogger().info(status) + logging.getLogger().info("Start delete vectors by range: %s:%s" % (start_date, end_date)) + status = connect.delete_vectors_by_range(ip_table, start_date, end_date) + assert status.OK() + status, result = connect.get_table_row_count(ip_table) + assert result == 0 + + @pytest.mark.timeout(DELETE_TIMEOUT) + def test_delete_vectors_no_data(self, connect, ip_table): + ''' + target: test delete vectors, no index created + method: call `delete_vectors_by_range`, with valid date params, and no data in db + expected: return code 0 + ''' + start_date = get_current_day() + end_date = get_next_day(2) + # status, ids = connect.add_vectors(table, vectors) + status = connect.delete_vectors_by_range(ip_table, start_date, end_date) + assert status.OK() + + @pytest.mark.timeout(DELETE_TIMEOUT) + def test_delete_vectors_table_None(self, connect, ip_table): + ''' + target: test delete vectors, table set Nope + method: call `delete_vectors_by_range`, with table value is None + expected: return code not 0 + ''' + start_date = get_current_day() + end_date = get_next_day(2) + table_name = None + with pytest.raises(Exception) as e: + status = connect.delete_vectors_by_range(table_name, start_date, end_date) + + @pytest.mark.timeout(DELETE_TIMEOUT) + def test_delete_vectors_valid_range_multi_tables(self, connect, get_valid_range): + ''' + target: test delete vectors is correct or not with multiple tables of 
IP + method: create 50 tables and add vectors into them , then delete vectors + in valid range + expected: return code 0 + ''' + nq = 100 + vectors = gen_vectors(nq, dim) + table_list = [] + for i in range(50): + table_name = gen_unique_str('test_delete_vectors_valid_range_multi_tables') + table_list.append(table_name) + param = {'table_name': table_name, + 'dimension': dim, + 'index_file_size': index_file_size, + 'metric_type': MetricType.IP} + connect.create_table(param) + status, ids = connect.add_vectors(table_name=table_name, records=vectors) + time.sleep(2) + start_date = get_valid_range[0] + end_date = get_valid_range[1] + for i in range(50): + status = connect.delete_vectors_by_range(table_list[i], start_date, end_date) + assert status.OK() + status, result = connect.get_table_row_count(table_list[i]) + assert result == 0 + +class TestDeleteVectorsParamsInvalid: + + """ + Test search table with invalid table names + """ + @pytest.fixture( + scope="function", + params=gen_invalid_table_names() + ) + def get_table_name(self, request): + yield request.param + + @pytest.mark.level(2) + def test_delete_vectors_table_invalid_name(self, connect, get_table_name): + ''' + ''' + start_date = get_current_day() + end_date = get_next_day(2) + table_name = get_table_name + logging.getLogger().info(table_name) + top_k = 1 + nprobe = 1 + status = connect.delete_vectors_by_range(table_name, start_date, end_date) + assert not status.OK() + + """ + Test search table with invalid query ranges + """ + @pytest.fixture( + scope="function", + params=gen_invalid_query_ranges() + ) + def get_query_ranges(self, request): + yield request.param + + @pytest.mark.timeout(DELETE_TIMEOUT) + def test_delete_vectors_range_invalid(self, connect, table, get_query_ranges): + ''' + target: test search fuction, with the wrong query_range + method: search with query_range + expected: raise an error, and the connection is normal + ''' + start_date = get_query_ranges[0][0] + end_date = 
get_query_ranges[0][1] + status, ids = connect.add_vectors(table, vectors) + logging.getLogger().info(get_query_ranges) + with pytest.raises(Exception) as e: + status = connect.delete_vectors_by_range(table, start_date, end_date) \ No newline at end of file diff --git a/tests/milvus_python_test/test_index.py b/tests/milvus_python_test/test_index.py new file mode 100644 index 0000000000..53a5fe16d2 --- /dev/null +++ b/tests/milvus_python_test/test_index.py @@ -0,0 +1,966 @@ +""" + For testing index operations, including `create_index`, `describe_index` and `drop_index` interfaces +""" +import logging +import pytest +import time +import pdb +import threading +from multiprocessing import Pool, Process +import numpy +from milvus import Milvus, IndexType, MetricType +from utils import * + +nb = 100000 +dim = 128 +index_file_size = 10 +vectors = gen_vectors(nb, dim) +vectors /= numpy.linalg.norm(vectors) +vectors = vectors.tolist() +BUILD_TIMEOUT = 60 +nprobe = 1 + + +class TestIndexBase: + @pytest.fixture( + scope="function", + params=gen_index_params() + ) + def get_index_params(self, request): + yield request.param + + @pytest.fixture( + scope="function", + params=gen_simple_index_params() + ) + def get_simple_index_params(self, request): + yield request.param + + """ + ****************************************************************** + The following cases are used to test `create_index` function + ****************************************************************** + """ + + @pytest.mark.timeout(BUILD_TIMEOUT) + def test_create_index(self, connect, table, get_index_params): + ''' + target: test create index interface + method: create table and add vectors in it, create index + expected: return code equals to 0, and search success + ''' + index_params = get_index_params + logging.getLogger().info(index_params) + status, ids = connect.add_vectors(table, vectors) + status = connect.create_index(table, index_params) + assert status.OK() + + @pytest.mark.level(2) + def 
test_create_index_without_connect(self, dis_connect, table): + ''' + target: test create index without connection + method: create table and add vectors in it, check if added successfully + expected: raise exception + ''' + with pytest.raises(Exception) as e: + status = dis_connect.create_index(table, random.choice(gen_index_params())) + + @pytest.mark.timeout(BUILD_TIMEOUT) + def test_create_index_search_with_query_vectors(self, connect, table, get_index_params): + ''' + target: test create index interface, search with more query vectors + method: create table and add vectors in it, create index + expected: return code equals to 0, and search success + ''' + index_params = get_index_params + logging.getLogger().info(index_params) + status, ids = connect.add_vectors(table, vectors) + status = connect.create_index(table, index_params) + logging.getLogger().info(connect.describe_index(table)) + query_vecs = [vectors[0], vectors[1], vectors[2]] + top_k = 5 + status, result = connect.search_vectors(table, top_k, nprobe, query_vecs) + assert status.OK() + assert len(result) == len(query_vecs) + logging.getLogger().info(result) + + # TODO: enable + @pytest.mark.timeout(BUILD_TIMEOUT) + @pytest.mark.level(2) + def _test_create_index_multiprocessing(self, connect, table, args): + ''' + target: test create index interface with multiprocess + method: create table and add vectors in it, create index + expected: return code equals to 0, and search success + ''' + status, ids = connect.add_vectors(table, vectors) + + def build(connect): + status = connect.create_index(table) + assert status.OK() + + process_num = 8 + processes = [] + uri = "tcp://%s:%s" % (args["ip"], args["port"]) + + for i in range(process_num): + m = Milvus() + m.connect(uri=uri) + p = Process(target=build, args=(m,)) + processes.append(p) + p.start() + time.sleep(0.2) + for p in processes: + p.join() + + query_vec = [vectors[0]] + top_k = 1 + status, result = connect.search_vectors(table, top_k, nprobe, 
query_vec) + assert len(result) == 1 + assert len(result[0]) == top_k + assert result[0][0].distance == 0.0 + + # TODO: enable + @pytest.mark.timeout(BUILD_TIMEOUT) + def _test_create_index_multiprocessing_multitable(self, connect, args): + ''' + target: test create index interface with multiprocess + method: create table and add vectors in it, create index + expected: return code equals to 0, and search success + ''' + process_num = 8 + loop_num = 8 + processes = [] + + table = [] + j = 0 + while j < (process_num*loop_num): + table_name = gen_unique_str("test_create_index_multiprocessing") + table.append(table_name) + param = {'table_name': table_name, + 'dimension': dim, + 'index_type': IndexType.FLAT, + 'store_raw_vector': False} + connect.create_table(param) + j = j + 1 + + def create_index(): + i = 0 + while i < loop_num: + # assert connect.has_table(table[ids*process_num+i]) + status, ids = connect.add_vectors(table[ids*process_num+i], vectors) + + status = connect.create_index(table[ids*process_num+i]) + assert status.OK() + query_vec = [vectors[0]] + top_k = 1 + status, result = connect.search_vectors(table[ids*process_num+i], top_k, nprobe, query_vec) + assert len(result) == 1 + assert len(result[0]) == top_k + assert result[0][0].distance == 0.0 + i = i + 1 + + uri = "tcp://%s:%s" % (args["ip"], args["port"]) + + for i in range(process_num): + m = Milvus() + m.connect(uri=uri) + ids = i + p = Process(target=create_index, args=(m,ids)) + processes.append(p) + p.start() + time.sleep(0.2) + for p in processes: + p.join() + + def test_create_index_table_not_existed(self, connect): + ''' + target: test create index interface when table name not existed + method: create table and add vectors in it, create index with an random table_name + , make sure the table name not in index + expected: return code not equals to 0, create index failed + ''' + table_name = gen_unique_str(self.__class__.__name__) + status = connect.create_index(table_name, 
random.choice(gen_index_params())) + assert not status.OK() + + def test_create_index_table_None(self, connect): + ''' + target: test create index interface when table name is None + method: create table and add vectors in it, create index with an table_name: None + expected: return code not equals to 0, create index failed + ''' + table_name = None + with pytest.raises(Exception) as e: + status = connect.create_index(table_name, random.choice(gen_index_params())) + + def test_create_index_no_vectors(self, connect, table): + ''' + target: test create index interface when there is no vectors in table + method: create table and add no vectors in it, and then create index + expected: return code equals to 0 + ''' + status = connect.create_index(table, random.choice(gen_index_params())) + assert status.OK() + + @pytest.mark.timeout(BUILD_TIMEOUT) + def test_create_index_no_vectors_then_add_vectors(self, connect, table): + ''' + target: test create index interface when there is no vectors in table, and does not affect the subsequent process + method: create table and add no vectors in it, and then create index, add vectors in it + expected: return code equals to 0 + ''' + status = connect.create_index(table, random.choice(gen_index_params())) + status, ids = connect.add_vectors(table, vectors) + assert status.OK() + + @pytest.mark.timeout(BUILD_TIMEOUT) + def test_create_same_index_repeatedly(self, connect, table): + ''' + target: check if index can be created repeatedly, with the same create_index params + method: create index after index have been built + expected: return code success, and search ok + ''' + status, ids = connect.add_vectors(table, vectors) + index_params = random.choice(gen_index_params()) + # index_params = get_index_params + status = connect.create_index(table, index_params) + status = connect.create_index(table, index_params) + assert status.OK() + query_vec = [vectors[0]] + top_k = 1 + status, result = connect.search_vectors(table, top_k, nprobe, 
query_vec) + assert len(result) == 1 + assert len(result[0]) == top_k + + @pytest.mark.timeout(BUILD_TIMEOUT) + def test_create_different_index_repeatedly(self, connect, table): + ''' + target: check if index can be created repeatedly, with the different create_index params + method: create another index with different index_params after index have been built + expected: return code 0, and describe index result equals with the second index params + ''' + status, ids = connect.add_vectors(table, vectors) + index_params = random.sample(gen_index_params(), 2) + logging.getLogger().info(index_params) + status = connect.create_index(table, index_params[0]) + status = connect.create_index(table, index_params[1]) + assert status.OK() + status, result = connect.describe_index(table) + assert result._nlist == index_params[1]["nlist"] + assert result._table_name == table + assert result._index_type == index_params[1]["index_type"] + + """ + ****************************************************************** + The following cases are used to test `describe_index` function + ****************************************************************** + """ + + def test_describe_index(self, connect, table, get_index_params): + ''' + target: test describe index interface + method: create table and add vectors in it, create index, call describe index + expected: return code 0, and index instructure + ''' + index_params = get_index_params + logging.getLogger().info(index_params) + status, ids = connect.add_vectors(table, vectors) + status = connect.create_index(table, index_params) + status, result = connect.describe_index(table) + logging.getLogger().info(result) + assert result._nlist == index_params["nlist"] + assert result._table_name == table + assert result._index_type == index_params["index_type"] + + def test_describe_and_drop_index_multi_tables(self, connect, get_simple_index_params): + ''' + target: test create, describe and drop index interface with multiple tables of L2 + method: 
create tables and add vectors in it, create index, call describe index + expected: return code 0, and index instructure + ''' + nq = 100 + vectors = gen_vectors(nq, dim) + table_list = [] + for i in range(10): + table_name = gen_unique_str('test_create_index_multi_tables') + table_list.append(table_name) + param = {'table_name': table_name, + 'dimension': dim, + 'index_file_size': index_file_size, + 'metric_type': MetricType.L2} + connect.create_table(param) + index_params = get_simple_index_params + logging.getLogger().info(index_params) + status, ids = connect.add_vectors(table_name=table_name, records=vectors) + status = connect.create_index(table_name, index_params) + assert status.OK() + + for i in range(10): + status, result = connect.describe_index(table_list[i]) + logging.getLogger().info(result) + assert result._nlist == index_params["nlist"] + assert result._table_name == table_list[i] + assert result._index_type == index_params["index_type"] + + for i in range(10): + status = connect.drop_index(table_list[i]) + assert status.OK() + status, result = connect.describe_index(table_list[i]) + logging.getLogger().info(result) + assert result._nlist == 16384 + assert result._table_name == table_list[i] + assert result._index_type == IndexType.FLAT + + @pytest.mark.level(2) + def test_describe_index_without_connect(self, dis_connect, table): + ''' + target: test describe index without connection + method: describe index, and check if describe successfully + expected: raise exception + ''' + with pytest.raises(Exception) as e: + status = dis_connect.describe_index(table) + + def test_describe_index_table_not_existed(self, connect): + ''' + target: test describe index interface when table name not existed + method: create table and add vectors in it, create index with an random table_name + , make sure the table name not in index + expected: return code not equals to 0, describe index failed + ''' + table_name = gen_unique_str(self.__class__.__name__) + status, 
result = connect.describe_index(table_name) + assert not status.OK() + + def test_describe_index_table_None(self, connect): + ''' + target: test describe index interface when table name is None + method: create table and add vectors in it, create index with an table_name: None + expected: return code not equals to 0, describe index failed + ''' + table_name = None + with pytest.raises(Exception) as e: + status = connect.describe_index(table_name) + + def test_describe_index_not_create(self, connect, table): + ''' + target: test describe index interface when index not created + method: create table and add vectors in it, create index with an random table_name + , make sure the table name not in index + expected: return code not equals to 0, describe index failed + ''' + status, ids = connect.add_vectors(table, vectors) + status, result = connect.describe_index(table) + logging.getLogger().info(result) + assert status.OK() + # assert result._nlist == index_params["nlist"] + # assert result._table_name == table + # assert result._index_type == index_params["index_type"] + + """ + ****************************************************************** + The following cases are used to test `drop_index` function + ****************************************************************** + """ + + def test_drop_index(self, connect, table, get_index_params): + ''' + target: test drop index interface + method: create table and add vectors in it, create index, call drop index + expected: return code 0, and default index param + ''' + index_params = get_index_params + status, ids = connect.add_vectors(table, vectors) + status = connect.create_index(table, index_params) + assert status.OK() + status, result = connect.describe_index(table) + logging.getLogger().info(result) + status = connect.drop_index(table) + assert status.OK() + status, result = connect.describe_index(table) + logging.getLogger().info(result) + assert result._nlist == 16384 + assert result._table_name == table + 
assert result._index_type == IndexType.FLAT + + def test_drop_index_repeatly(self, connect, table, get_simple_index_params): + ''' + target: test drop index repeatly + method: create index, call drop index, and drop again + expected: return code 0 + ''' + index_params = get_simple_index_params + status, ids = connect.add_vectors(table, vectors) + status = connect.create_index(table, index_params) + assert status.OK() + status, result = connect.describe_index(table) + logging.getLogger().info(result) + status = connect.drop_index(table) + assert status.OK() + status = connect.drop_index(table) + assert status.OK() + status, result = connect.describe_index(table) + logging.getLogger().info(result) + assert result._nlist == 16384 + assert result._table_name == table + assert result._index_type == IndexType.FLAT + + @pytest.mark.level(2) + def test_drop_index_without_connect(self, dis_connect, table): + ''' + target: test drop index without connection + method: drop index, and check if drop successfully + expected: raise exception + ''' + with pytest.raises(Exception) as e: + status = dis_connect.drop_index(table) + + def test_drop_index_table_not_existed(self, connect): + ''' + target: test drop index interface when table name not existed + method: create table and add vectors in it, create index with an random table_name + , make sure the table name not in index, and then drop it + expected: return code not equals to 0, drop index failed + ''' + table_name = gen_unique_str(self.__class__.__name__) + status = connect.drop_index(table_name) + assert not status.OK() + + def test_drop_index_table_None(self, connect): + ''' + target: test drop index interface when table name is None + method: create table and add vectors in it, create index with an table_name: None + expected: return code not equals to 0, drop index failed + ''' + table_name = None + with pytest.raises(Exception) as e: + status = connect.drop_index(table_name) + + def 
test_drop_index_table_not_create(self, connect, table): + ''' + target: test drop index interface when index not created + method: create table and add vectors in it, create index + expected: return code not equals to 0, drop index failed + ''' + index_params = random.choice(gen_index_params()) + logging.getLogger().info(index_params) + status, ids = connect.add_vectors(table, vectors) + status, result = connect.describe_index(table) + logging.getLogger().info(result) + # no create index + status = connect.drop_index(table) + logging.getLogger().info(status) + assert status.OK() + + def test_create_drop_index_repeatly(self, connect, table, get_simple_index_params): + ''' + target: test create / drop index repeatly, use the same index params + method: create index, drop index, four times + expected: return code 0 + ''' + index_params = get_simple_index_params + status, ids = connect.add_vectors(table, vectors) + for i in range(2): + status = connect.create_index(table, index_params) + assert status.OK() + status, result = connect.describe_index(table) + logging.getLogger().info(result) + status = connect.drop_index(table) + assert status.OK() + status, result = connect.describe_index(table) + logging.getLogger().info(result) + assert result._nlist == 16384 + assert result._table_name == table + assert result._index_type == IndexType.FLAT + + def test_create_drop_index_repeatly_different_index_params(self, connect, table): + ''' + target: test create / drop index repeatly, use the different index params + method: create index, drop index, four times, each tme use different index_params to create index + expected: return code 0 + ''' + index_params = random.sample(gen_index_params(), 2) + status, ids = connect.add_vectors(table, vectors) + for i in range(2): + status = connect.create_index(table, index_params[i]) + assert status.OK() + status, result = connect.describe_index(table) + logging.getLogger().info(result) + status = connect.drop_index(table) + assert 
status.OK() + status, result = connect.describe_index(table) + logging.getLogger().info(result) + assert result._nlist == 16384 + assert result._table_name == table + assert result._index_type == IndexType.FLAT + + +class TestIndexIP: + @pytest.fixture( + scope="function", + params=gen_index_params() + ) + def get_index_params(self, request): + yield request.param + + @pytest.fixture( + scope="function", + params=gen_simple_index_params() + ) + def get_simple_index_params(self, request): + yield request.param + + """ + ****************************************************************** + The following cases are used to test `create_index` function + ****************************************************************** + """ + + @pytest.mark.timeout(BUILD_TIMEOUT) + def test_create_index(self, connect, ip_table, get_index_params): + ''' + target: test create index interface + method: create table and add vectors in it, create index + expected: return code equals to 0, and search success + ''' + index_params = get_index_params + logging.getLogger().info(index_params) + status, ids = connect.add_vectors(ip_table, vectors) + status = connect.create_index(ip_table, index_params) + assert status.OK() + + @pytest.mark.level(2) + def test_create_index_without_connect(self, dis_connect, ip_table): + ''' + target: test create index without connection + method: create table and add vectors in it, check if added successfully + expected: raise exception + ''' + with pytest.raises(Exception) as e: + status = dis_connect.create_index(ip_table, random.choice(gen_index_params())) + + @pytest.mark.timeout(BUILD_TIMEOUT) + def test_create_index_search_with_query_vectors(self, connect, ip_table, get_index_params): + ''' + target: test create index interface, search with more query vectors + method: create table and add vectors in it, create index + expected: return code equals to 0, and search success + ''' + index_params = get_index_params + logging.getLogger().info(index_params) + 
status, ids = connect.add_vectors(ip_table, vectors) + status = connect.create_index(ip_table, index_params) + logging.getLogger().info(connect.describe_index(ip_table)) + query_vecs = [vectors[0], vectors[1], vectors[2]] + top_k = 5 + status, result = connect.search_vectors(ip_table, top_k, nprobe, query_vecs) + assert status.OK() + assert len(result) == len(query_vecs) + # logging.getLogger().info(result) + + # TODO: enable + @pytest.mark.timeout(BUILD_TIMEOUT) + @pytest.mark.level(2) + def _test_create_index_multiprocessing(self, connect, ip_table, args): + ''' + target: test create index interface with multiprocess + method: create table and add vectors in it, create index + expected: return code equals to 0, and search success + ''' + status, ids = connect.add_vectors(ip_table, vectors) + + def build(connect): + status = connect.create_index(ip_table) + assert status.OK() + + process_num = 8 + processes = [] + uri = "tcp://%s:%s" % (args["ip"], args["port"]) + + for i in range(process_num): + m = Milvus() + m.connect(uri=uri) + p = Process(target=build, args=(m,)) + processes.append(p) + p.start() + time.sleep(0.2) + for p in processes: + p.join() + + query_vec = [vectors[0]] + top_k = 1 + status, result = connect.search_vectors(ip_table, top_k, nprobe, query_vec) + assert len(result) == 1 + assert len(result[0]) == top_k + assert result[0][0].distance == 0.0 + + # TODO: enable + @pytest.mark.timeout(BUILD_TIMEOUT) + def _test_create_index_multiprocessing_multitable(self, connect, args): + ''' + target: test create index interface with multiprocess + method: create table and add vectors in it, create index + expected: return code equals to 0, and search success + ''' + process_num = 8 + loop_num = 8 + processes = [] + + table = [] + j = 0 + while j < (process_num*loop_num): + table_name = gen_unique_str("test_create_index_multiprocessing") + table.append(table_name) + param = {'table_name': table_name, + 'dimension': dim} + connect.create_table(param) + j = j 
+ 1 + + def create_index(): + i = 0 + while i < loop_num: + # assert connect.has_table(table[ids*process_num+i]) + status, ids = connect.add_vectors(table[ids*process_num+i], vectors) + + status = connect.create_index(table[ids*process_num+i]) + assert status.OK() + query_vec = [vectors[0]] + top_k = 1 + status, result = connect.search_vectors(table[ids*process_num+i], top_k, nprobe, query_vec) + assert len(result) == 1 + assert len(result[0]) == top_k + assert result[0][0].distance == 0.0 + i = i + 1 + + uri = "tcp://%s:%s" % (args["ip"], args["port"]) + + for i in range(process_num): + m = Milvus() + m.connect(uri=uri) + ids = i + p = Process(target=create_index, args=(m,ids)) + processes.append(p) + p.start() + time.sleep(0.2) + for p in processes: + p.join() + + def test_create_index_no_vectors(self, connect, ip_table): + ''' + target: test create index interface when there is no vectors in table + method: create table and add no vectors in it, and then create index + expected: return code equals to 0 + ''' + status = connect.create_index(ip_table, random.choice(gen_index_params())) + assert status.OK() + + @pytest.mark.timeout(BUILD_TIMEOUT) + def test_create_index_no_vectors_then_add_vectors(self, connect, ip_table): + ''' + target: test create index interface when there is no vectors in table, and does not affect the subsequent process + method: create table and add no vectors in it, and then create index, add vectors in it + expected: return code equals to 0 + ''' + status = connect.create_index(ip_table, random.choice(gen_index_params())) + status, ids = connect.add_vectors(ip_table, vectors) + assert status.OK() + + @pytest.mark.timeout(BUILD_TIMEOUT) + def test_create_same_index_repeatedly(self, connect, ip_table): + ''' + target: check if index can be created repeatedly, with the same create_index params + method: create index after index have been built + expected: return code success, and search ok + ''' + status, ids = connect.add_vectors(ip_table, 
vectors) + index_params = random.choice(gen_index_params()) + # index_params = get_index_params + status = connect.create_index(ip_table, index_params) + status = connect.create_index(ip_table, index_params) + assert status.OK() + query_vec = [vectors[0]] + top_k = 1 + status, result = connect.search_vectors(ip_table, top_k, nprobe, query_vec) + assert len(result) == 1 + assert len(result[0]) == top_k + + @pytest.mark.timeout(BUILD_TIMEOUT) + def test_create_different_index_repeatedly(self, connect, ip_table): + ''' + target: check if index can be created repeatedly, with the different create_index params + method: create another index with different index_params after index have been built + expected: return code 0, and describe index result equals with the second index params + ''' + status, ids = connect.add_vectors(ip_table, vectors) + index_params = random.sample(gen_index_params(), 2) + logging.getLogger().info(index_params) + status = connect.create_index(ip_table, index_params[0]) + status = connect.create_index(ip_table, index_params[1]) + assert status.OK() + status, result = connect.describe_index(ip_table) + assert result._nlist == index_params[1]["nlist"] + assert result._table_name == ip_table + assert result._index_type == index_params[1]["index_type"] + + """ + ****************************************************************** + The following cases are used to test `describe_index` function + ****************************************************************** + """ + + def test_describe_index(self, connect, ip_table, get_index_params): + ''' + target: test describe index interface + method: create table and add vectors in it, create index, call describe index + expected: return code 0, and index instructure + ''' + index_params = get_index_params + logging.getLogger().info(index_params) + status, ids = connect.add_vectors(ip_table, vectors) + status = connect.create_index(ip_table, index_params) + status, result = connect.describe_index(ip_table) + 
logging.getLogger().info(result) + assert result._nlist == index_params["nlist"] + assert result._table_name == ip_table + assert result._index_type == index_params["index_type"] + + def test_describe_and_drop_index_multi_tables(self, connect, get_simple_index_params): + ''' + target: test create, describe and drop index interface with multiple tables of IP + method: create tables and add vectors in it, create index, call describe index + expected: return code 0, and index instructure + ''' + nq = 100 + vectors = gen_vectors(nq, dim) + table_list = [] + for i in range(10): + table_name = gen_unique_str('test_create_index_multi_tables') + table_list.append(table_name) + param = {'table_name': table_name, + 'dimension': dim, + 'index_file_size': index_file_size, + 'metric_type': MetricType.IP} + connect.create_table(param) + index_params = get_simple_index_params + logging.getLogger().info(index_params) + status, ids = connect.add_vectors(table_name=table_name, records=vectors) + status = connect.create_index(table_name, index_params) + assert status.OK() + + for i in range(10): + status, result = connect.describe_index(table_list[i]) + logging.getLogger().info(result) + assert result._nlist == index_params["nlist"] + assert result._table_name == table_list[i] + assert result._index_type == index_params["index_type"] + + for i in range(10): + status = connect.drop_index(table_list[i]) + assert status.OK() + status, result = connect.describe_index(table_list[i]) + logging.getLogger().info(result) + assert result._nlist == 16384 + assert result._table_name == table_list[i] + assert result._index_type == IndexType.FLAT + + @pytest.mark.level(2) + def test_describe_index_without_connect(self, dis_connect, ip_table): + ''' + target: test describe index without connection + method: describe index, and check if describe successfully + expected: raise exception + ''' + with pytest.raises(Exception) as e: + status = dis_connect.describe_index(ip_table) + + def 
test_describe_index_not_create(self, connect, ip_table): + ''' + target: test describe index interface when index not created + method: create table and add vectors in it, create index with an random table_name + , make sure the table name not in index + expected: return code not equals to 0, describe index failed + ''' + status, ids = connect.add_vectors(ip_table, vectors) + status, result = connect.describe_index(ip_table) + logging.getLogger().info(result) + assert status.OK() + # assert result._nlist == index_params["nlist"] + # assert result._table_name == table + # assert result._index_type == index_params["index_type"] + + """ + ****************************************************************** + The following cases are used to test `drop_index` function + ****************************************************************** + """ + + def test_drop_index(self, connect, ip_table, get_index_params): + ''' + target: test drop index interface + method: create table and add vectors in it, create index, call drop index + expected: return code 0, and default index param + ''' + index_params = get_index_params + status, ids = connect.add_vectors(ip_table, vectors) + status = connect.create_index(ip_table, index_params) + assert status.OK() + status, result = connect.describe_index(ip_table) + logging.getLogger().info(result) + status = connect.drop_index(ip_table) + assert status.OK() + status, result = connect.describe_index(ip_table) + logging.getLogger().info(result) + assert result._nlist == 16384 + assert result._table_name == ip_table + assert result._index_type == IndexType.FLAT + + def test_drop_index_repeatly(self, connect, ip_table, get_simple_index_params): + ''' + target: test drop index repeatly + method: create index, call drop index, and drop again + expected: return code 0 + ''' + index_params = get_simple_index_params + status, ids = connect.add_vectors(ip_table, vectors) + status = connect.create_index(ip_table, index_params) + assert status.OK() + 
status, result = connect.describe_index(ip_table) + logging.getLogger().info(result) + status = connect.drop_index(ip_table) + assert status.OK() + status = connect.drop_index(ip_table) + assert status.OK() + status, result = connect.describe_index(ip_table) + logging.getLogger().info(result) + assert result._nlist == 16384 + assert result._table_name == ip_table + assert result._index_type == IndexType.FLAT + + @pytest.mark.level(2) + def test_drop_index_without_connect(self, dis_connect, ip_table): + ''' + target: test drop index without connection + method: drop index, and check if drop successfully + expected: raise exception + ''' + with pytest.raises(Exception) as e: + status = dis_connect.drop_index(ip_table, random.choice(gen_index_params())) + + def test_drop_index_table_not_create(self, connect, ip_table): + ''' + target: test drop index interface when index not created + method: create table and add vectors in it, create index + expected: return code not equals to 0, drop index failed + ''' + index_params = random.choice(gen_index_params()) + logging.getLogger().info(index_params) + status, ids = connect.add_vectors(ip_table, vectors) + status, result = connect.describe_index(ip_table) + logging.getLogger().info(result) + # no create index + status = connect.drop_index(ip_table) + logging.getLogger().info(status) + assert status.OK() + + def test_create_drop_index_repeatly(self, connect, ip_table, get_simple_index_params): + ''' + target: test create / drop index repeatly, use the same index params + method: create index, drop index, four times + expected: return code 0 + ''' + index_params = get_simple_index_params + status, ids = connect.add_vectors(ip_table, vectors) + for i in range(2): + status = connect.create_index(ip_table, index_params) + assert status.OK() + status, result = connect.describe_index(ip_table) + logging.getLogger().info(result) + status = connect.drop_index(ip_table) + assert status.OK() + status, result = 
connect.describe_index(ip_table) + logging.getLogger().info(result) + assert result._nlist == 16384 + assert result._table_name == ip_table + assert result._index_type == IndexType.FLAT + + def test_create_drop_index_repeatly_different_index_params(self, connect, ip_table): + ''' + target: test create / drop index repeatly, use the different index params + method: create index, drop index, four times, each tme use different index_params to create index + expected: return code 0 + ''' + index_params = random.sample(gen_index_params(), 2) + status, ids = connect.add_vectors(ip_table, vectors) + for i in range(2): + status = connect.create_index(ip_table, index_params[i]) + assert status.OK() + status, result = connect.describe_index(ip_table) + assert result._nlist == index_params[i]["nlist"] + assert result._table_name == ip_table + assert result._index_type == index_params[i]["index_type"] + status, result = connect.describe_index(ip_table) + logging.getLogger().info(result) + status = connect.drop_index(ip_table) + assert status.OK() + status, result = connect.describe_index(ip_table) + logging.getLogger().info(result) + assert result._nlist == 16384 + assert result._table_name == ip_table + assert result._index_type == IndexType.FLAT + + +class TestIndexTableInvalid(object): + """ + Test create / describe / drop index interfaces with invalid table names + """ + @pytest.fixture( + scope="function", + params=gen_invalid_table_names() + ) + def get_table_name(self, request): + yield request.param + + # @pytest.mark.level(1) + def test_create_index_with_invalid_tablename(self, connect, get_table_name): + table_name = get_table_name + status = connect.create_index(table_name, random.choice(gen_index_params())) + assert not status.OK() + + # @pytest.mark.level(1) + def test_describe_index_with_invalid_tablename(self, connect, get_table_name): + table_name = get_table_name + status, result = connect.describe_index(table_name) + assert not status.OK() + + # 
@pytest.mark.level(1) + def test_drop_index_with_invalid_tablename(self, connect, get_table_name): + table_name = get_table_name + status = connect.drop_index(table_name) + assert not status.OK() + + +class TestCreateIndexParamsInvalid(object): + """ + Test Building index with invalid table names, table names not in db + """ + @pytest.fixture( + scope="function", + params=gen_invalid_index_params() + ) + def get_index_params(self, request): + yield request.param + + @pytest.mark.level(2) + def test_create_index_with_invalid_index_params(self, connect, table, get_index_params): + index_params = get_index_params + index_type = index_params["index_type"] + nlist = index_params["nlist"] + logging.getLogger().info(index_params) + status, ids = connect.add_vectors(table, vectors) + # if not isinstance(index_type, int) or not isinstance(nlist, int): + with pytest.raises(Exception) as e: + status = connect.create_index(table, index_params) + # else: + # status = connect.create_index(table, index_params) + # assert not status.OK() diff --git a/tests/milvus_python_test/test_mix.py b/tests/milvus_python_test/test_mix.py new file mode 100644 index 0000000000..4578e330b3 --- /dev/null +++ b/tests/milvus_python_test/test_mix.py @@ -0,0 +1,180 @@ +import pdb +import copy +import pytest +import threading +import datetime +import logging +from time import sleep +from multiprocessing import Process +import numpy +from milvus import Milvus, IndexType, MetricType +from utils import * + +dim = 128 +index_file_size = 10 +table_id = "test_mix" +add_interval_time = 2 +vectors = gen_vectors(100000, dim) +vectors /= numpy.linalg.norm(vectors) +vectors = vectors.tolist() +top_k = 1 +nprobe = 1 +epsilon = 0.0001 +index_params = {'index_type': IndexType.IVFLAT, 'nlist': 16384} + + +class TestMixBase: + + # TODO: enable + def _test_search_during_createIndex(self, args): + loops = 100000 + table = "test_search_during_createIndex" + query_vecs = [vectors[0], vectors[1]] + uri = "tcp://%s:%s" % 
(args["ip"], args["port"]) + id_0 = 0; id_1 = 0 + milvus_instance = Milvus() + milvus_instance.connect(uri=uri) + milvus_instance.create_table({'table_name': table, + 'dimension': dim, + 'index_file_size': index_file_size, + 'metric_type': MetricType.L2}) + for i in range(10): + status, ids = milvus_instance.add_vectors(table, vectors) + # logging.getLogger().info(ids) + if i == 0: + id_0 = ids[0]; id_1 = ids[1] + def create_index(milvus_instance): + logging.getLogger().info("In create index") + status = milvus_instance.create_index(table, index_params) + logging.getLogger().info(status) + status, result = milvus_instance.describe_index(table) + logging.getLogger().info(result) + def add_vectors(milvus_instance): + logging.getLogger().info("In add vectors") + status, ids = milvus_instance.add_vectors(table, vectors) + logging.getLogger().info(status) + def search(milvus_instance): + for i in range(loops): + status, result = milvus_instance.search_vectors(table, top_k, nprobe, query_vecs) + logging.getLogger().info(status) + assert result[0][0].id == id_0 + assert result[1][0].id == id_1 + milvus_instance = Milvus() + milvus_instance.connect(uri=uri) + p_search = Process(target=search, args=(milvus_instance, )) + p_search.start() + milvus_instance = Milvus() + milvus_instance.connect(uri=uri) + p_create = Process(target=add_vectors, args=(milvus_instance, )) + p_create.start() + p_create.join() + + def test_mix_multi_tables(self, connect): + ''' + target: test functions with multiple tables of different metric_types and index_types + method: create 60 tables which 30 are L2 and the other are IP, add vectors into them + and test describe index and search + expected: status ok + ''' + nq = 10000 + vectors = gen_vectors(nq, dim) + table_list = [] + idx = [] + + #create table and add vectors + for i in range(30): + table_name = gen_unique_str('test_mix_multi_tables') + table_list.append(table_name) + param = {'table_name': table_name, + 'dimension': dim, + 
'index_file_size': index_file_size, + 'metric_type': MetricType.L2} + connect.create_table(param) + status, ids = connect.add_vectors(table_name=table_name, records=vectors) + idx.append(ids[0]) + idx.append(ids[10]) + idx.append(ids[20]) + assert status.OK() + for i in range(30): + table_name = gen_unique_str('test_mix_multi_tables') + table_list.append(table_name) + param = {'table_name': table_name, + 'dimension': dim, + 'index_file_size': index_file_size, + 'metric_type': MetricType.IP} + connect.create_table(param) + status, ids = connect.add_vectors(table_name=table_name, records=vectors) + idx.append(ids[0]) + idx.append(ids[10]) + idx.append(ids[20]) + assert status.OK() + time.sleep(2) + + #create index + for i in range(10): + index_params = {'index_type': IndexType.FLAT, 'nlist': 16384} + status = connect.create_index(table_list[i], index_params) + assert status.OK() + status = connect.create_index(table_list[30 + i], index_params) + assert status.OK() + index_params = {'index_type': IndexType.IVFLAT, 'nlist': 16384} + status = connect.create_index(table_list[10 + i], index_params) + assert status.OK() + status = connect.create_index(table_list[40 + i], index_params) + assert status.OK() + index_params = {'index_type': IndexType.IVF_SQ8, 'nlist': 16384} + status = connect.create_index(table_list[20 + i], index_params) + assert status.OK() + status = connect.create_index(table_list[50 + i], index_params) + assert status.OK() + + #describe index + for i in range(10): + status, result = connect.describe_index(table_list[i]) + logging.getLogger().info(result) + assert result._nlist == 16384 + assert result._table_name == table_list[i] + assert result._index_type == IndexType.FLAT + status, result = connect.describe_index(table_list[10 + i]) + logging.getLogger().info(result) + assert result._nlist == 16384 + assert result._table_name == table_list[10 + i] + assert result._index_type == IndexType.IVFLAT + status, result = connect.describe_index(table_list[20 + 
i]) + logging.getLogger().info(result) + assert result._nlist == 16384 + assert result._table_name == table_list[20 + i] + assert result._index_type == IndexType.IVF_SQ8 + status, result = connect.describe_index(table_list[30 + i]) + logging.getLogger().info(result) + assert result._nlist == 16384 + assert result._table_name == table_list[30 + i] + assert result._index_type == IndexType.FLAT + status, result = connect.describe_index(table_list[40 + i]) + logging.getLogger().info(result) + assert result._nlist == 16384 + assert result._table_name == table_list[40 + i] + assert result._index_type == IndexType.IVFLAT + status, result = connect.describe_index(table_list[50 + i]) + logging.getLogger().info(result) + assert result._nlist == 16384 + assert result._table_name == table_list[50 + i] + assert result._index_type == IndexType.IVF_SQ8 + + #search + query_vecs = [vectors[0], vectors[10], vectors[20]] + for i in range(60): + table = table_list[i] + status, result = connect.search_vectors(table, top_k, nprobe, query_vecs) + assert status.OK() + assert len(result) == len(query_vecs) + for j in range(len(query_vecs)): + assert len(result[j]) == top_k + for j in range(len(query_vecs)): + assert check_result(result[j], idx[3 * i + j]) + +def check_result(result, id): + if len(result) >= 5: + return id in [result[0].id, result[1].id, result[2].id, result[3].id, result[4].id] + else: + return id in (i.id for i in result) \ No newline at end of file diff --git a/tests/milvus_python_test/test_ping.py b/tests/milvus_python_test/test_ping.py new file mode 100644 index 0000000000..a55559bc63 --- /dev/null +++ b/tests/milvus_python_test/test_ping.py @@ -0,0 +1,77 @@ +import logging +import pytest + +__version__ = '0.5.0' + + +class TestPing: + def test_server_version(self, connect): + ''' + target: test get the server version + method: call the server_version method after connected + expected: version should be the pymilvus version + ''' + status, res = 
connect.server_version() + assert res == __version__ + + def test_server_status(self, connect): + ''' + target: test get the server status + method: call the server_status method after connected + expected: status returned should be ok + ''' + status, msg = connect.server_status() + assert status.OK() + + def _test_server_cmd_with_params_version(self, connect): + ''' + target: test cmd: version + method: cmd = "version" ... + expected: when cmd = 'version', return version of server; + ''' + cmd = "version" + status, msg = connect.cmd(cmd) + logging.getLogger().info(status) + logging.getLogger().info(msg) + assert status.OK() + assert msg == __version__ + + def _test_server_cmd_with_params_others(self, connect): + ''' + target: test cmd: lalala + method: cmd = "lalala" ... + expected: when cmd = 'version', return version of server; + ''' + cmd = "rm -rf test" + status, msg = connect.cmd(cmd) + logging.getLogger().info(status) + logging.getLogger().info(msg) + assert status.OK() + # assert msg == __version__ + + def test_connected(self, connect): + assert connect.connected() + + +class TestPingDisconnect: + def test_server_version(self, dis_connect): + ''' + target: test get the server version, after disconnect + method: call the server_version method after connected + expected: version should not be the pymilvus version + ''' + res = None + with pytest.raises(Exception) as e: + status, res = connect.server_version() + assert res is None + + def test_server_status(self, dis_connect): + ''' + target: test get the server status, after disconnect + method: call the server_status method after connected + expected: status returned should be not ok + ''' + status = None + with pytest.raises(Exception) as e: + status, msg = connect.server_status() + assert status is None diff --git a/tests/milvus_python_test/test_search_vectors.py b/tests/milvus_python_test/test_search_vectors.py new file mode 100644 index 0000000000..4f02ae0d37 --- /dev/null +++ 
b/tests/milvus_python_test/test_search_vectors.py @@ -0,0 +1,650 @@ +import pdb +import copy +import pytest +import threading +import datetime +import logging +from time import sleep +from multiprocessing import Process +import numpy +from milvus import Milvus, IndexType, MetricType +from utils import * + +dim = 128 +table_id = "test_search" +add_interval_time = 2 +vectors = gen_vectors(100, dim) +# vectors /= numpy.linalg.norm(vectors) +# vectors = vectors.tolist() +nrpobe = 1 +epsilon = 0.001 + + +class TestSearchBase: + def init_data(self, connect, table, nb=100): + ''' + Generate vectors and add it in table, before search vectors + ''' + global vectors + if nb == 100: + add_vectors = vectors + else: + add_vectors = gen_vectors(nb, dim) + # add_vectors /= numpy.linalg.norm(add_vectors) + # add_vectors = add_vectors.tolist() + status, ids = connect.add_vectors(table, add_vectors) + sleep(add_interval_time) + return add_vectors, ids + + """ + generate valid create_index params + """ + @pytest.fixture( + scope="function", + params=gen_index_params() + ) + def get_index_params(self, request): + yield request.param + + """ + generate top-k params + """ + @pytest.fixture( + scope="function", + params=[1, 99, 101, 1024, 2048, 2049] + ) + def get_top_k(self, request): + yield request.param + + + def test_search_top_k_flat_index(self, connect, table, get_top_k): + ''' + target: test basic search fuction, all the search params is corrent, change top-k value + method: search with the given vectors, check the result + expected: search status ok, and the length of the result is top_k + ''' + vectors, ids = self.init_data(connect, table) + query_vec = [vectors[0]] + top_k = get_top_k + nprobe = 1 + status, result = connect.search_vectors(table, top_k, nrpobe, query_vec) + if top_k <= 2048: + assert status.OK() + assert len(result[0]) == min(len(vectors), top_k) + assert result[0][0].distance <= epsilon + assert check_result(result[0], ids[0]) + else: + assert not status.OK() 
+ + def test_search_l2_index_params(self, connect, table, get_index_params): + ''' + target: test basic search fuction, all the search params is corrent, test all index params, and build + method: search with the given vectors, check the result + expected: search status ok, and the length of the result is top_k + ''' + + index_params = get_index_params + logging.getLogger().info(index_params) + vectors, ids = self.init_data(connect, table) + status = connect.create_index(table, index_params) + query_vec = [vectors[0]] + top_k = 10 + nprobe = 1 + status, result = connect.search_vectors(table, top_k, nrpobe, query_vec) + logging.getLogger().info(result) + if top_k <= 1024: + assert status.OK() + assert len(result[0]) == min(len(vectors), top_k) + assert check_result(result[0], ids[0]) + assert result[0][0].distance <= epsilon + else: + assert not status.OK() + + def test_search_ip_index_params(self, connect, ip_table, get_index_params): + ''' + target: test basic search fuction, all the search params is corrent, test all index params, and build + method: search with the given vectors, check the result + expected: search status ok, and the length of the result is top_k + ''' + + index_params = get_index_params + logging.getLogger().info(index_params) + vectors, ids = self.init_data(connect, ip_table) + status = connect.create_index(ip_table, index_params) + query_vec = [vectors[0]] + top_k = 10 + nprobe = 1 + status, result = connect.search_vectors(ip_table, top_k, nrpobe, query_vec) + logging.getLogger().info(result) + + if top_k <= 1024: + assert status.OK() + assert len(result[0]) == min(len(vectors), top_k) + assert check_result(result[0], ids[0]) + assert abs(result[0][0].distance - numpy.inner(numpy.array(query_vec[0]), numpy.array(query_vec[0]))) <= gen_inaccuracy(result[0][0].distance) + else: + assert not status.OK() + + @pytest.mark.level(2) + def test_search_vectors_without_connect(self, dis_connect, table): + ''' + target: test search vectors without 
connection + method: use dis connected instance, call search method and check if search successfully + expected: raise exception + ''' + query_vectors = [vectors[0]] + top_k = 1 + nprobe = 1 + with pytest.raises(Exception) as e: + status, ids = dis_connect.search_vectors(table, top_k, nprobe, query_vectors) + + def test_search_table_name_not_existed(self, connect, table): + ''' + target: search table not existed + method: search with the random table_name, which is not in db + expected: status not ok + ''' + table_name = gen_unique_str("not_existed_table") + top_k = 1 + nprobe = 1 + query_vecs = [vectors[0]] + status, result = connect.search_vectors(table_name, top_k, nprobe, query_vecs) + assert not status.OK() + + def test_search_table_name_None(self, connect, table): + ''' + target: search table that table name is None + method: search with the table_name: None + expected: status not ok + ''' + table_name = None + top_k = 1 + nprobe = 1 + query_vecs = [vectors[0]] + with pytest.raises(Exception) as e: + status, result = connect.search_vectors(table_name, top_k, nprobe, query_vecs) + + def test_search_top_k_query_records(self, connect, table): + ''' + target: test search fuction, with search params: query_records + method: search with the given query_records, which are subarrays of the inserted vectors + expected: status ok and the returned vectors should be query_records + ''' + top_k = 10 + nprobe = 1 + vectors, ids = self.init_data(connect, table) + query_vecs = [vectors[0],vectors[55],vectors[99]] + status, result = connect.search_vectors(table, top_k, nprobe, query_vecs) + assert status.OK() + assert len(result) == len(query_vecs) + for i in range(len(query_vecs)): + assert len(result[i]) == top_k + assert result[i][0].distance <= epsilon + + """ + generate invalid query range params + """ + @pytest.fixture( + scope="function", + params=[ + (get_current_day(), get_current_day()), + (get_last_day(1), get_last_day(1)), + (get_next_day(1), get_next_day(1)) + ] 
+ ) + def get_invalid_range(self, request): + yield request.param + + def test_search_invalid_query_ranges(self, connect, table, get_invalid_range): + ''' + target: search table with query ranges + method: search with the same query ranges + expected: status not ok + ''' + top_k = 2 + nprobe = 1 + vectors, ids = self.init_data(connect, table) + query_vecs = [vectors[0]] + query_ranges = [get_invalid_range] + status, result = connect.search_vectors(table, top_k, nprobe, query_vecs, query_ranges=query_ranges) + assert not status.OK() + assert len(result) == 0 + + """ + generate valid query range params, no search result + """ + @pytest.fixture( + scope="function", + params=[ + (get_last_day(2), get_last_day(1)), + (get_last_day(2), get_current_day()), + (get_next_day(1), get_next_day(2)) + ] + ) + def get_valid_range_no_result(self, request): + yield request.param + + def test_search_valid_query_ranges_no_result(self, connect, table, get_valid_range_no_result): + ''' + target: search table with normal query ranges, but no data in db + method: search with query ranges (low, low) + expected: length of result is 0 + ''' + top_k = 2 + nprobe = 1 + vectors, ids = self.init_data(connect, table) + query_vecs = [vectors[0]] + query_ranges = [get_valid_range_no_result] + status, result = connect.search_vectors(table, top_k, nprobe, query_vecs, query_ranges=query_ranges) + assert status.OK() + assert len(result) == 0 + + """ + generate valid query range params, no search result + """ + @pytest.fixture( + scope="function", + params=[ + (get_last_day(2), get_next_day(2)), + (get_current_day(), get_next_day(2)), + ] + ) + def get_valid_range(self, request): + yield request.param + + def test_search_valid_query_ranges(self, connect, table, get_valid_range): + ''' + target: search table with normal query ranges, but no data in db + method: search with query ranges (low, normal) + expected: length of result is 0 + ''' + top_k = 2 + nprobe = 1 + vectors, ids = self.init_data(connect, 
table) + query_vecs = [vectors[0]] + query_ranges = [get_valid_range] + status, result = connect.search_vectors(table, top_k, nprobe, query_vecs, query_ranges=query_ranges) + assert status.OK() + assert len(result) == 1 + assert result[0][0].distance <= epsilon + + def test_search_distance_l2_flat_index(self, connect, table): + ''' + target: search table, and check the result: distance + method: compare the return distance value with value computed with Euclidean + expected: the return distance equals to the computed value + ''' + nb = 2 + top_k = 1 + nprobe = 1 + vectors, ids = self.init_data(connect, table, nb=nb) + query_vecs = [[0.50 for i in range(dim)]] + distance_0 = numpy.linalg.norm(numpy.array(query_vecs[0]) - numpy.array(vectors[0])) + distance_1 = numpy.linalg.norm(numpy.array(query_vecs[0]) - numpy.array(vectors[1])) + status, result = connect.search_vectors(table, top_k, nprobe, query_vecs) + assert abs(numpy.sqrt(result[0][0].distance) - min(distance_0, distance_1)) <= gen_inaccuracy(result[0][0].distance) + + def test_search_distance_ip_flat_index(self, connect, ip_table): + ''' + target: search ip_table, and check the result: distance + method: compare the return distance value with value computed with Inner product + expected: the return distance equals to the computed value + ''' + nb = 2 + top_k = 1 + nprobe = 1 + vectors, ids = self.init_data(connect, ip_table, nb=nb) + index_params = { + "index_type": IndexType.FLAT, + "nlist": 16384 + } + connect.create_index(ip_table, index_params) + logging.getLogger().info(connect.describe_index(ip_table)) + query_vecs = [[0.50 for i in range(dim)]] + distance_0 = numpy.inner(numpy.array(query_vecs[0]), numpy.array(vectors[0])) + distance_1 = numpy.inner(numpy.array(query_vecs[0]), numpy.array(vectors[1])) + status, result = connect.search_vectors(ip_table, top_k, nprobe, query_vecs) + assert abs(result[0][0].distance - max(distance_0, distance_1)) <= gen_inaccuracy(result[0][0].distance) + + def 
test_search_distance_ip_index_params(self, connect, ip_table, get_index_params): + ''' + target: search table, and check the result: distance + method: compare the return distance value with value computed with Inner product + expected: the return distance equals to the computed value + ''' + top_k = 2 + nprobe = 1 + vectors, ids = self.init_data(connect, ip_table, nb=2) + index_params = get_index_params + connect.create_index(ip_table, index_params) + logging.getLogger().info(connect.describe_index(ip_table)) + query_vecs = [[0.50 for i in range(dim)]] + status, result = connect.search_vectors(ip_table, top_k, nprobe, query_vecs) + distance_0 = numpy.inner(numpy.array(query_vecs[0]), numpy.array(vectors[0])) + distance_1 = numpy.inner(numpy.array(query_vecs[0]), numpy.array(vectors[1])) + assert abs(result[0][0].distance - max(distance_0, distance_1)) <= gen_inaccuracy(result[0][0].distance) + + # TODO: enable + # @pytest.mark.repeat(5) + @pytest.mark.timeout(30) + def _test_search_concurrent(self, connect, table): + vectors, ids = self.init_data(connect, table) + thread_num = 10 + nb = 100 + top_k = 10 + threads = [] + query_vecs = vectors[nb//2:nb] + def search(): + status, result = connect.search_vectors(table, top_k, query_vecs) + assert len(result) == len(query_vecs) + for i in range(len(query_vecs)): + assert result[i][0].id in ids + assert result[i][0].distance == 0.0 + for i in range(thread_num): + x = threading.Thread(target=search, args=()) + threads.append(x) + x.start() + for th in threads: + th.join() + + # TODO: enable + @pytest.mark.timeout(30) + def _test_search_concurrent_multiprocessing(self, args): + ''' + target: test concurrent search with multiprocessess + method: search with 10 processes, each process uses dependent connection + expected: status ok and the returned vectors should be query_records + ''' + nb = 100 + top_k = 10 + process_num = 4 + processes = [] + table = gen_unique_str("test_search_concurrent_multiprocessing") + uri = 
"tcp://%s:%s" % (args["ip"], args["port"]) + param = {'table_name': table, + 'dimension': dim, + 'index_type': IndexType.FLAT, + 'store_raw_vector': False} + # create table + milvus = Milvus() + milvus.connect(uri=uri) + milvus.create_table(param) + vectors, ids = self.init_data(milvus, table, nb=nb) + query_vecs = vectors[nb//2:nb] + def search(milvus): + status, result = milvus.search_vectors(table, top_k, query_vecs) + assert len(result) == len(query_vecs) + for i in range(len(query_vecs)): + assert result[i][0].id in ids + assert result[i][0].distance == 0.0 + + for i in range(process_num): + milvus = Milvus() + milvus.connect(uri=uri) + p = Process(target=search, args=(milvus, )) + processes.append(p) + p.start() + time.sleep(0.2) + for p in processes: + p.join() + + def test_search_multi_table_L2(search, args): + ''' + target: test search multi tables of L2 + method: add vectors into 10 tables, and search + expected: search status ok, the length of result + ''' + num = 10 + top_k = 10 + nprobe = 1 + tables = [] + idx = [] + for i in range(num): + table = gen_unique_str("test_add_multitable_%d" % i) + uri = "tcp://%s:%s" % (args["ip"], args["port"]) + param = {'table_name': table, + 'dimension': dim, + 'index_file_size': 10, + 'metric_type': MetricType.L2} + # create table + milvus = Milvus() + milvus.connect(uri=uri) + milvus.create_table(param) + status, ids = milvus.add_vectors(table, vectors) + assert status.OK() + assert len(ids) == len(vectors) + tables.append(table) + idx.append(ids[0]) + idx.append(ids[10]) + idx.append(ids[20]) + time.sleep(6) + query_vecs = [vectors[0], vectors[10], vectors[20]] + # start query from random table + for i in range(num): + table = tables[i] + status, result = milvus.search_vectors(table, top_k, nprobe, query_vecs) + assert status.OK() + assert len(result) == len(query_vecs) + for j in range(len(query_vecs)): + assert len(result[j]) == top_k + for j in range(len(query_vecs)): + assert check_result(result[j], idx[3 * i + 
j]) + + def test_search_multi_table_IP(search, args): + ''' + target: test search multi tables of IP + method: add vectors into 10 tables, and search + expected: search status ok, the length of result + ''' + num = 10 + top_k = 10 + nprobe = 1 + tables = [] + idx = [] + for i in range(num): + table = gen_unique_str("test_add_multitable_%d" % i) + uri = "tcp://%s:%s" % (args["ip"], args["port"]) + param = {'table_name': table, + 'dimension': dim, + 'index_file_size': 10, + 'metric_type': MetricType.L2} + # create table + milvus = Milvus() + milvus.connect(uri=uri) + milvus.create_table(param) + status, ids = milvus.add_vectors(table, vectors) + assert status.OK() + assert len(ids) == len(vectors) + tables.append(table) + idx.append(ids[0]) + idx.append(ids[10]) + idx.append(ids[20]) + time.sleep(6) + query_vecs = [vectors[0], vectors[10], vectors[20]] + # start query from random table + for i in range(num): + table = tables[i] + status, result = milvus.search_vectors(table, top_k, nprobe, query_vecs) + assert status.OK() + assert len(result) == len(query_vecs) + for j in range(len(query_vecs)): + assert len(result[j]) == top_k + for j in range(len(query_vecs)): + assert check_result(result[j], idx[3 * i + j]) +""" +****************************************************************** +# The following cases are used to test `search_vectors` function +# with invalid table_name top-k / nprobe / query_range +****************************************************************** +""" + +class TestSearchParamsInvalid(object): + index_params = random.choice(gen_index_params()) + logging.getLogger().info(index_params) + + def init_data(self, connect, table, nb=100): + ''' + Generate vectors and add it in table, before search vectors + ''' + global vectors + if nb == 100: + add_vectors = vectors + else: + add_vectors = gen_vectors(nb, dim) + status, ids = connect.add_vectors(table, add_vectors) + sleep(add_interval_time) + return add_vectors, ids + + """ + Test search table with 
invalid table names + """ + @pytest.fixture( + scope="function", + params=gen_invalid_table_names() + ) + def get_table_name(self, request): + yield request.param + + @pytest.mark.level(2) + def test_search_with_invalid_tablename(self, connect, get_table_name): + table_name = get_table_name + logging.getLogger().info(table_name) + top_k = 1 + nprobe = 1 + query_vecs = gen_vectors(1, dim) + status, result = connect.search_vectors(table_name, top_k, nprobe, query_vecs) + assert not status.OK() + + """ + Test search table with invalid top-k + """ + @pytest.fixture( + scope="function", + params=gen_invalid_top_ks() + ) + def get_top_k(self, request): + yield request.param + + @pytest.mark.level(2) + def test_search_with_invalid_top_k(self, connect, table, get_top_k): + ''' + target: test search fuction, with the wrong top_k + method: search with top_k + expected: raise an error, and the connection is normal + ''' + top_k = get_top_k + logging.getLogger().info(top_k) + nprobe = 1 + query_vecs = gen_vectors(1, dim) + with pytest.raises(Exception) as e: + status, result = connect.search_vectors(table, top_k, nprobe, query_vecs) + res = connect.server_version() + + @pytest.mark.level(2) + def test_search_with_invalid_top_k_ip(self, connect, ip_table, get_top_k): + ''' + target: test search fuction, with the wrong top_k + method: search with top_k + expected: raise an error, and the connection is normal + ''' + top_k = get_top_k + logging.getLogger().info(top_k) + nprobe = 1 + query_vecs = gen_vectors(1, dim) + with pytest.raises(Exception) as e: + status, result = connect.search_vectors(ip_table, top_k, nprobe, query_vecs) + res = connect.server_version() + + """ + Test search table with invalid nprobe + """ + @pytest.fixture( + scope="function", + params=gen_invalid_nprobes() + ) + def get_nprobes(self, request): + yield request.param + + @pytest.mark.level(2) + def test_search_with_invalid_nrpobe(self, connect, table, get_nprobes): + ''' + target: test search fuction, 
with the wrong top_k + method: search with top_k + expected: raise an error, and the connection is normal + ''' + top_k = 1 + nprobe = get_nprobes + logging.getLogger().info(nprobe) + query_vecs = gen_vectors(1, dim) + if isinstance(nprobe, int) and nprobe > 0: + status, result = connect.search_vectors(table, top_k, nprobe, query_vecs) + assert not status.OK() + else: + with pytest.raises(Exception) as e: + status, result = connect.search_vectors(table, top_k, nprobe, query_vecs) + + @pytest.mark.level(2) + def test_search_with_invalid_nrpobe_ip(self, connect, ip_table, get_nprobes): + ''' + target: test search fuction, with the wrong top_k + method: search with top_k + expected: raise an error, and the connection is normal + ''' + top_k = 1 + nprobe = get_nprobes + logging.getLogger().info(nprobe) + query_vecs = gen_vectors(1, dim) + if isinstance(nprobe, int) and nprobe > 0: + status, result = connect.search_vectors(ip_table, top_k, nprobe, query_vecs) + assert not status.OK() + else: + with pytest.raises(Exception) as e: + status, result = connect.search_vectors(ip_table, top_k, nprobe, query_vecs) + + """ + Test search table with invalid query ranges + """ + @pytest.fixture( + scope="function", + params=gen_invalid_query_ranges() + ) + def get_query_ranges(self, request): + yield request.param + + @pytest.mark.level(2) + def test_search_flat_with_invalid_query_range(self, connect, table, get_query_ranges): + ''' + target: test search fuction, with the wrong query_range + method: search with query_range + expected: raise an error, and the connection is normal + ''' + top_k = 1 + nprobe = 1 + query_vecs = [vectors[0]] + query_ranges = get_query_ranges + logging.getLogger().info(query_ranges) + with pytest.raises(Exception) as e: + status, result = connect.search_vectors(table, 1, nprobe, query_vecs, query_ranges=query_ranges) + + + @pytest.mark.level(2) + def test_search_flat_with_invalid_query_range_ip(self, connect, ip_table, get_query_ranges): + ''' + target: 
test search fuction, with the wrong query_range + method: search with query_range + expected: raise an error, and the connection is normal + ''' + top_k = 1 + nprobe = 1 + query_vecs = [vectors[0]] + query_ranges = get_query_ranges + logging.getLogger().info(query_ranges) + with pytest.raises(Exception) as e: + status, result = connect.search_vectors(ip_table, 1, nprobe, query_vecs, query_ranges=query_ranges) + + +def check_result(result, id): + if len(result) >= 5: + return id in [result[0].id, result[1].id, result[2].id, result[3].id, result[4].id] + else: + return id in (i.id for i in result) \ No newline at end of file diff --git a/tests/milvus_python_test/test_table.py b/tests/milvus_python_test/test_table.py new file mode 100644 index 0000000000..c92afb023e --- /dev/null +++ b/tests/milvus_python_test/test_table.py @@ -0,0 +1,883 @@ +import random +import pdb +import pytest +import logging +import itertools + +from time import sleep +from multiprocessing import Process +import numpy +from milvus import Milvus +from milvus import IndexType, MetricType +from utils import * + +dim = 128 +delete_table_interval_time = 3 +index_file_size = 10 +vectors = gen_vectors(100, dim) + + +class TestTable: + + """ + ****************************************************************** + The following cases are used to test `create_table` function + ****************************************************************** + """ + + def test_create_table(self, connect): + ''' + target: test create normal table + method: create table with corrent params + expected: create status return ok + ''' + table_name = gen_unique_str("test_table") + param = {'table_name': table_name, + 'dimension': dim, + 'index_file_size': index_file_size, + 'metric_type': MetricType.L2} + status = connect.create_table(param) + assert status.OK() + + def test_create_table_ip(self, connect): + ''' + target: test create normal table + method: create table with corrent params + expected: create status return ok + 
''' + table_name = gen_unique_str("test_table") + param = {'table_name': table_name, + 'dimension': dim, + 'index_file_size': index_file_size, + 'metric_type': MetricType.IP} + status = connect.create_table(param) + assert status.OK() + + @pytest.mark.level(2) + def test_create_table_without_connection(self, dis_connect): + ''' + target: test create table, without connection + method: create table with correct params, with a disconnected instance + expected: create raise exception + ''' + table_name = gen_unique_str("test_table") + param = {'table_name': table_name, + 'dimension': dim, + 'index_file_size': index_file_size, + 'metric_type': MetricType.L2} + with pytest.raises(Exception) as e: + status = dis_connect.create_table(param) + + def test_create_table_existed(self, connect): + ''' + target: test create table but the table name have already existed + method: create table with the same table_name + expected: create status return not ok + ''' + table_name = gen_unique_str("test_table") + param = {'table_name': table_name, + 'dimension': dim, + 'index_file_size': index_file_size, + 'metric_type': MetricType.L2} + status = connect.create_table(param) + status = connect.create_table(param) + assert not status.OK() + + @pytest.mark.level(2) + def test_create_table_existed_ip(self, connect): + ''' + target: test create table but the table name have already existed + method: create table with the same table_name + expected: create status return not ok + ''' + table_name = gen_unique_str("test_table") + param = {'table_name': table_name, + 'dimension': dim, + 'index_file_size': index_file_size, + 'metric_type': MetricType.IP} + status = connect.create_table(param) + status = connect.create_table(param) + assert not status.OK() + + def test_create_table_None(self, connect): + ''' + target: test create table but the table name is None + method: create table, param table_name is None + expected: create raise error + ''' + param = {'table_name': None, + 'dimension': dim, 
+ 'index_file_size': index_file_size, + 'metric_type': MetricType.L2} + with pytest.raises(Exception) as e: + status = connect.create_table(param) + + def test_create_table_no_dimension(self, connect): + ''' + target: test create table with no dimension params + method: create table with corrent params + expected: create status return ok + ''' + table_name = gen_unique_str("test_table") + param = {'table_name': table_name, + 'index_file_size': index_file_size, + 'metric_type': MetricType.L2} + with pytest.raises(Exception) as e: + status = connect.create_table(param) + + def test_create_table_no_file_size(self, connect): + ''' + target: test create table with no index_file_size params + method: create table with corrent params + expected: create status return ok, use default 1024 + ''' + table_name = gen_unique_str("test_table") + param = {'table_name': table_name, + 'dimension': dim, + 'metric_type': MetricType.L2} + status = connect.create_table(param) + logging.getLogger().info(status) + status, result = connect.describe_table(table_name) + logging.getLogger().info(result) + assert result.index_file_size == 1024 + + def test_create_table_no_metric_type(self, connect): + ''' + target: test create table with no metric_type params + method: create table with corrent params + expected: create status return ok, use default L2 + ''' + table_name = gen_unique_str("test_table") + param = {'table_name': table_name, + 'dimension': dim, + 'index_file_size': index_file_size} + status = connect.create_table(param) + status, result = connect.describe_table(table_name) + logging.getLogger().info(result) + assert result.metric_type == MetricType.L2 + + """ + ****************************************************************** + The following cases are used to test `describe_table` function + ****************************************************************** + """ + + def test_table_describe_result(self, connect): + ''' + target: test describe table created with correct params + 
method: create table, assert the value returned by describe method + expected: table_name equals with the table name created + ''' + table_name = gen_unique_str("test_table") + param = {'table_name': table_name, + 'dimension': dim, + 'index_file_size': index_file_size, + 'metric_type': MetricType.L2} + connect.create_table(param) + status, res = connect.describe_table(table_name) + assert res.table_name == table_name + assert res.metric_type == MetricType.L2 + + def test_table_describe_table_name_ip(self, connect): + ''' + target: test describe table created with correct params + method: create table, assert the value returned by describe method + expected: table_name equals with the table name created + ''' + table_name = gen_unique_str("test_table") + param = {'table_name': table_name, + 'dimension': dim, + 'index_file_size': index_file_size, + 'metric_type': MetricType.IP} + connect.create_table(param) + status, res = connect.describe_table(table_name) + assert res.table_name == table_name + assert res.metric_type == MetricType.IP + + # TODO: enable + @pytest.mark.level(2) + def _test_table_describe_table_name_multiprocessing(self, connect, args): + ''' + target: test describe table created with multiprocess + method: create table, assert the value returned by describe method + expected: table_name equals with the table name created + ''' + table_name = gen_unique_str("test_table") + uri = "tcp://%s:%s" % (args["ip"], args["port"]) + param = {'table_name': table_name, + 'dimension': dim, + 'index_file_size': index_file_size, + 'metric_type': MetricType.L2} + connect.create_table(param) + + def describetable(milvus): + status, res = milvus.describe_table(table_name) + assert res.table_name == table_name + + process_num = 4 + processes = [] + for i in range(process_num): + milvus = Milvus() + milvus.connect(uri=uri) + p = Process(target=describetable, args=(milvus,)) + processes.append(p) + p.start() + for p in processes: + p.join() + + @pytest.mark.level(2) + def 
test_table_describe_without_connection(self, table, dis_connect): + ''' + target: test describe table, without connection + method: describe table with correct params, with a disconnected instance + expected: describe raise exception + ''' + with pytest.raises(Exception) as e: + status = dis_connect.describe_table(table) + + def test_table_describe_dimension(self, connect): + ''' + target: test describe table created with correct params + method: create table, assert the dimention value returned by describe method + expected: dimention equals with dimention when created + ''' + table_name = gen_unique_str("test_table") + param = {'table_name': table_name, + 'dimension': dim+1, + 'index_file_size': index_file_size, + 'metric_type': MetricType.L2} + connect.create_table(param) + status, res = connect.describe_table(table_name) + assert res.dimension == dim+1 + + """ + ****************************************************************** + The following cases are used to test `delete_table` function + ****************************************************************** + """ + + def test_delete_table(self, connect, table): + ''' + target: test delete table created with correct params + method: create table and then delete, + assert the value returned by delete method + expected: status ok, and no table in tables + ''' + status = connect.delete_table(table) + assert not connect.has_table(table) + + def test_delete_table_ip(self, connect, ip_table): + ''' + target: test delete table created with correct params + method: create table and then delete, + assert the value returned by delete method + expected: status ok, and no table in tables + ''' + status = connect.delete_table(ip_table) + assert not connect.has_table(ip_table) + + @pytest.mark.level(2) + def test_table_delete_without_connection(self, table, dis_connect): + ''' + target: test describe table, without connection + method: describe table with correct params, with a disconnected instance + expected: describe raise 
exception + ''' + with pytest.raises(Exception) as e: + status = dis_connect.delete_table(table) + + def test_delete_table_not_existed(self, connect): + ''' + target: test delete table not in index + method: delete all tables, and delete table again, + assert the value returned by delete method + expected: status not ok + ''' + table_name = gen_unique_str("test_table") + status = connect.delete_table(table_name) + assert not status.code==0 + + def test_delete_table_repeatedly(self, connect): + ''' + target: test delete table created with correct params + method: create table and delete new table repeatedly, + assert the value returned by delete method + expected: create ok and delete ok + ''' + loops = 1 + for i in range(loops): + table_name = gen_unique_str("test_table") + param = {'table_name': table_name, + 'dimension': dim, + 'index_file_size': index_file_size, + 'metric_type': MetricType.L2} + connect.create_table(param) + status = connect.delete_table(table_name) + time.sleep(1) + assert not connect.has_table(table_name) + + def test_delete_create_table_repeatedly(self, connect): + ''' + target: test delete and create the same table repeatedly + method: try to create the same table and delete repeatedly, + assert the value returned by delete method + expected: create ok and delete ok + ''' + loops = 5 + for i in range(loops): + table_name = "test_table" + param = {'table_name': table_name, + 'dimension': dim, + 'index_file_size': index_file_size, + 'metric_type': MetricType.L2} + connect.create_table(param) + status = connect.delete_table(table_name) + time.sleep(2) + assert status.OK() + + @pytest.mark.level(2) + def test_delete_create_table_repeatedly_ip(self, connect): + ''' + target: test delete and create the same table repeatedly + method: try to create the same table and delete repeatedly, + assert the value returned by delete method + expected: create ok and delete ok + ''' + loops = 5 + for i in range(loops): + table_name = "test_table" + param = 
{'table_name': table_name, + 'dimension': dim, + 'index_file_size': index_file_size, + 'metric_type': MetricType.IP} + connect.create_table(param) + status = connect.delete_table(table_name) + time.sleep(2) + assert status.OK() + + # TODO: enable + @pytest.mark.level(2) + def _test_delete_table_multiprocessing(self, args): + ''' + target: test delete table with multiprocess + method: create table and then delete, + assert the value returned by delete method + expected: status ok, and no table in tables + ''' + process_num = 6 + processes = [] + uri = "tcp://%s:%s" % (args["ip"], args["port"]) + + def deletetable(milvus): + status = milvus.delete_table(table) + # assert not status.code==0 + assert milvus.has_table(table) + assert status.OK() + + for i in range(process_num): + milvus = Milvus() + milvus.connect(uri=uri) + p = Process(target=deletetable, args=(milvus,)) + processes.append(p) + p.start() + for p in processes: + p.join() + + # TODO: enable + @pytest.mark.level(2) + def _test_delete_table_multiprocessing_multitable(self, connect): + ''' + target: test delete table with multiprocess + method: create table and then delete, + assert the value returned by delete method + expected: status ok, and no table in tables + ''' + process_num = 5 + loop_num = 2 + processes = [] + + table = [] + j = 0 + while j < (process_num*loop_num): + table_name = gen_unique_str("test_delete_table_with_multiprocessing") + table.append(table_name) + param = {'table_name': table_name, + 'dimension': dim, + 'index_file_size': index_file_size, + 'metric_type': MetricType.L2} + connect.create_table(param) + j = j + 1 + + def delete(connect,ids): + i = 0 + while i < loop_num: + # assert connect.has_table(table[ids*8+i]) + status = connect.delete_table(table[ids*process_num+i]) + time.sleep(2) + assert status.OK() + assert not connect.has_table(table[ids*process_num+i]) + i = i + 1 + + for i in range(process_num): + ids = i + p = Process(target=delete, args=(connect,ids)) + 
processes.append(p) + p.start() + for p in processes: + p.join() + + """ + ****************************************************************** + The following cases are used to test `has_table` function + ****************************************************************** + """ + + def test_has_table(self, connect): + ''' + target: test if the created table existed + method: create table, assert the value returned by has_table method + expected: True + ''' + table_name = gen_unique_str("test_table") + param = {'table_name': table_name, + 'dimension': dim, + 'index_file_size': index_file_size, + 'metric_type': MetricType.L2} + connect.create_table(param) + assert connect.has_table(table_name) + + def test_has_table_ip(self, connect): + ''' + target: test if the created table existed + method: create table, assert the value returned by has_table method + expected: True + ''' + table_name = gen_unique_str("test_table") + param = {'table_name': table_name, + 'dimension': dim, + 'index_file_size': index_file_size, + 'metric_type': MetricType.IP} + connect.create_table(param) + assert connect.has_table(table_name) + + @pytest.mark.level(2) + def test_has_table_without_connection(self, table, dis_connect): + ''' + target: test has table, without connection + method: calling has table with correct params, with a disconnected instance + expected: has table raise exception + ''' + with pytest.raises(Exception) as e: + status = dis_connect.has_table(table) + + def test_has_table_not_existed(self, connect): + ''' + target: test if table not created + method: random a table name, which not existed in db, + assert the value returned by has_table method + expected: False + ''' + table_name = gen_unique_str("test_table") + assert not connect.has_table(table_name) + + """ + ****************************************************************** + The following cases are used to test `show_tables` function + ****************************************************************** + """ + + def 
test_show_tables(self, connect): + ''' + target: test show tables is correct or not, if table created + method: create table, assert the value returned by show_tables method is equal to 0 + expected: table_name in show tables + ''' + table_name = gen_unique_str("test_table") + param = {'table_name': table_name, + 'dimension': dim, + 'index_file_size': index_file_size, + 'metric_type': MetricType.L2} + connect.create_table(param) + status, result = connect.show_tables() + assert status.OK() + assert table_name in result + + def test_show_tables_ip(self, connect): + ''' + target: test show tables is correct or not, if table created + method: create table, assert the value returned by show_tables method is equal to 0 + expected: table_name in show tables + ''' + table_name = gen_unique_str("test_table") + param = {'table_name': table_name, + 'dimension': dim, + 'index_file_size': index_file_size, + 'metric_type': MetricType.IP} + connect.create_table(param) + status, result = connect.show_tables() + assert status.OK() + assert table_name in result + + @pytest.mark.level(2) + def test_show_tables_without_connection(self, dis_connect): + ''' + target: test show_tables, without connection + method: calling show_tables with correct params, with a disconnected instance + expected: show_tables raise exception + ''' + with pytest.raises(Exception) as e: + status = dis_connect.show_tables() + + def test_show_tables_no_table(self, connect): + ''' + target: test show tables is correct or not, if no table in db + method: delete all tables, + assert the value returned by show_tables method is equal to [] + expected: the status is ok, and the result is equal to [] + ''' + status, result = connect.show_tables() + if result: + for table_name in result: + connect.delete_table(table_name) + time.sleep(delete_table_interval_time) + status, result = connect.show_tables() + assert status.OK() + assert len(result) == 0 + + # TODO: enable + @pytest.mark.level(2) + def 
_test_show_tables_multiprocessing(self, connect, args): + ''' + target: test show tables is correct or not with processes + method: create table, assert the value returned by show_tables method is equal to 0 + expected: table_name in show tables + ''' + table_name = gen_unique_str("test_table") + uri = "tcp://%s:%s" % (args["ip"], args["port"]) + param = {'table_name': table_name, + 'dimension': dim, + 'index_file_size': index_file_size, + 'metric_type': MetricType.L2} + connect.create_table(param) + + def showtables(milvus): + status, result = milvus.show_tables() + assert status.OK() + assert table_name in result + + process_num = 8 + processes = [] + + for i in range(process_num): + milvus = Milvus() + milvus.connect(uri=uri) + p = Process(target=showtables, args=(milvus,)) + processes.append(p) + p.start() + for p in processes: + p.join() + + """ + ****************************************************************** + The following cases are used to test `preload_table` function + ****************************************************************** + """ + + """ + generate valid create_index params + """ + @pytest.fixture( + scope="function", + params=gen_index_params() + ) + def get_index_params(self, request): + yield request.param + + @pytest.mark.level(1) + def test_preload_table(self, connect, table, get_index_params): + index_params = get_index_params + status, ids = connect.add_vectors(table, vectors) + status = connect.create_index(table, index_params) + status = connect.preload_table(table) + assert status.OK() + + @pytest.mark.level(1) + def test_preload_table_ip(self, connect, ip_table, get_index_params): + index_params = get_index_params + status, ids = connect.add_vectors(ip_table, vectors) + status = connect.create_index(ip_table, index_params) + status = connect.preload_table(ip_table) + assert status.OK() + + @pytest.mark.level(1) + def test_preload_table_not_existed(self, connect, table): + table_name = 
gen_unique_str("test_preload_table_not_existed") + index_params = random.choice(gen_index_params()) + status, ids = connect.add_vectors(table, vectors) + status = connect.create_index(table, index_params) + status = connect.preload_table(table_name) + assert not status.OK() + + @pytest.mark.level(1) + def test_preload_table_not_existed_ip(self, connect, ip_table): + table_name = gen_unique_str("test_preload_table_not_existed") + index_params = random.choice(gen_index_params()) + status, ids = connect.add_vectors(ip_table, vectors) + status = connect.create_index(ip_table, index_params) + status = connect.preload_table(table_name) + assert not status.OK() + + @pytest.mark.level(1) + def test_preload_table_no_vectors(self, connect, table): + status = connect.preload_table(table) + assert status.OK() + + @pytest.mark.level(1) + def test_preload_table_no_vectors_ip(self, connect, ip_table): + status = connect.preload_table(ip_table) + assert status.OK() + + # TODO: psutils get memory usage + @pytest.mark.level(1) + def test_preload_table_memory_usage(self, connect, table): + pass + + +class TestTableInvalid(object): + """ + Test creating table with invalid table names + """ + @pytest.fixture( + scope="function", + params=gen_invalid_table_names() + ) + def get_table_name(self, request): + yield request.param + + def test_create_table_with_invalid_tablename(self, connect, get_table_name): + table_name = get_table_name + param = {'table_name': table_name, + 'dimension': dim, + 'index_file_size': index_file_size, + 'metric_type': MetricType.L2} + status = connect.create_table(param) + assert not status.OK() + + def test_create_table_with_empty_tablename(self, connect): + table_name = '' + param = {'table_name': table_name, + 'dimension': dim, + 'index_file_size': index_file_size, + 'metric_type': MetricType.L2} + with pytest.raises(Exception) as e: + status = connect.create_table(param) + + def test_preload_table_with_invalid_tablename(self, connect): + table_name = '' + 
with pytest.raises(Exception) as e: + status = connect.preload_table(table_name) + + +class TestCreateTableDimInvalid(object): + """ + Test creating table with invalid dimension + """ + @pytest.fixture( + scope="function", + params=gen_invalid_dims() + ) + def get_dim(self, request): + yield request.param + + @pytest.mark.timeout(5) + def test_create_table_with_invalid_dimension(self, connect, get_dim): + dimension = get_dim + table = gen_unique_str("test_create_table_with_invalid_dimension") + param = {'table_name': table, + 'dimension': dimension, + 'index_file_size': index_file_size, + 'metric_type': MetricType.L2} + if isinstance(dimension, int) and dimension > 0: + status = connect.create_table(param) + assert not status.OK() + else: + with pytest.raises(Exception) as e: + status = connect.create_table(param) + + +# TODO: max / min index file size +class TestCreateTableIndexSizeInvalid(object): + """ + Test creating tables with invalid index_file_size + """ + @pytest.fixture( + scope="function", + params=gen_invalid_file_sizes() + ) + def get_file_size(self, request): + yield request.param + + @pytest.mark.level(2) + def test_create_table_with_invalid_file_size(self, connect, table, get_file_size): + file_size = get_file_size + param = {'table_name': table, + 'dimension': dim, + 'index_file_size': file_size, + 'metric_type': MetricType.L2} + if isinstance(file_size, int) and file_size > 0: + status = connect.create_table(param) + assert not status.OK() + else: + with pytest.raises(Exception) as e: + status = connect.create_table(param) + + +class TestCreateMetricTypeInvalid(object): + """ + Test creating tables with invalid metric_type + """ + @pytest.fixture( + scope="function", + params=gen_invalid_metric_types() + ) + def get_metric_type(self, request): + yield request.param + + @pytest.mark.level(2) + def test_create_table_with_invalid_file_size(self, connect, table, get_metric_type): + metric_type = get_metric_type + param = {'table_name': table, + 
'dimension': dim, + 'index_file_size': 10, + 'metric_type': metric_type} + with pytest.raises(Exception) as e: + status = connect.create_table(param) + + +def create_table(connect, **params): + param = {'table_name': params["table_name"], + 'dimension': params["dimension"], + 'index_file_size': index_file_size, + 'metric_type': MetricType.L2} + status = connect.create_table(param) + return status + +def search_table(connect, **params): + status, result = connect.search_vectors( + params["table_name"], + params["top_k"], + params["nprobe"], + params["query_vectors"]) + return status + +def preload_table(connect, **params): + status = connect.preload_table(params["table_name"]) + return status + +def has(connect, **params): + status = connect.has_table(params["table_name"]) + return status + +def show(connect, **params): + status, result = connect.show_tables() + return status + +def delete(connect, **params): + status = connect.delete_table(params["table_name"]) + return status + +def describe(connect, **params): + status, result = connect.describe_table(params["table_name"]) + return status + +def rowcount(connect, **params): + status, result = connect.get_table_row_count(params["table_name"]) + return status + +def create_index(connect, **params): + status = connect.create_index(params["table_name"], params["index_params"]) + return status + +func_map = { + # 0:has, + 1:show, + 10:create_table, + 11:describe, + 12:rowcount, + 13:search_table, + 14:preload_table, + 15:create_index, + 30:delete +} + +def gen_sequence(): + raw_seq = func_map.keys() + result = itertools.permutations(raw_seq) + for x in result: + yield x + +class TestTableLogic(object): + + @pytest.mark.parametrize("logic_seq", gen_sequence()) + @pytest.mark.level(2) + def test_logic(self, connect, logic_seq): + if self.is_right(logic_seq): + self.execute(logic_seq, connect) + else: + self.execute_with_error(logic_seq, connect) + + def is_right(self, seq): + if sorted(seq) == seq: + return True + + 
not_created = True + has_deleted = False + for i in range(len(seq)): + if seq[i] > 10 and not_created: + return False + elif seq [i] > 10 and has_deleted: + return False + elif seq[i] == 10: + not_created = False + elif seq[i] == 30: + has_deleted = True + + return True + + def execute(self, logic_seq, connect): + basic_params = self.gen_params() + for i in range(len(logic_seq)): + # logging.getLogger().info(logic_seq[i]) + f = func_map[logic_seq[i]] + status = f(connect, **basic_params) + assert status.OK() + + def execute_with_error(self, logic_seq, connect): + basic_params = self.gen_params() + + error_flag = False + for i in range(len(logic_seq)): + f = func_map[logic_seq[i]] + status = f(connect, **basic_params) + if not status.OK(): + # logging.getLogger().info(logic_seq[i]) + error_flag = True + break + assert error_flag == True + + def gen_params(self): + table_name = gen_unique_str("test_table") + top_k = 1 + vectors = gen_vectors(2, dim) + param = {'table_name': table_name, + 'dimension': dim, + 'index_type': IndexType.IVFLAT, + 'metric_type': MetricType.L2, + 'nprobe': 1, + 'top_k': top_k, + 'index_params': { + 'index_type': IndexType.IVF_SQ8, + 'nlist': 16384 + }, + 'query_vectors': vectors} + return param diff --git a/tests/milvus_python_test/test_table_count.py b/tests/milvus_python_test/test_table_count.py new file mode 100644 index 0000000000..af9ae185ab --- /dev/null +++ b/tests/milvus_python_test/test_table_count.py @@ -0,0 +1,296 @@ +import random +import pdb + +import pytest +import logging +import itertools + +from time import sleep +from multiprocessing import Process +from milvus import Milvus +from utils import * +from milvus import IndexType, MetricType + +dim = 128 +index_file_size = 10 +add_time_interval = 5 + + +class TestTableCount: + """ + params means different nb, the nb value may trigger merge, or not + """ + @pytest.fixture( + scope="function", + params=[ + 100, + 5000, + 100000, + ], + ) + def add_vectors_nb(self, request): + 
yield request.param + + """ + generate valid create_index params + """ + @pytest.fixture( + scope="function", + params=gen_index_params() + ) + def get_index_params(self, request): + yield request.param + + def test_table_rows_count(self, connect, table, add_vectors_nb): + ''' + target: test table rows_count is correct or not + method: create table and add vectors in it, + assert the value returned by get_table_row_count method is equal to length of vectors + expected: the count is equal to the length of vectors + ''' + nb = add_vectors_nb + vectors = gen_vectors(nb, dim) + res = connect.add_vectors(table_name=table, records=vectors) + time.sleep(add_time_interval) + status, res = connect.get_table_row_count(table) + assert res == nb + + def test_table_rows_count_after_index_created(self, connect, table, get_index_params): + ''' + target: test get_table_row_count, after index have been created + method: add vectors in db, and create index, then calling get_table_row_count with correct params + expected: get_table_row_count raise exception + ''' + nb = 100 + index_params = get_index_params + vectors = gen_vectors(nb, dim) + res = connect.add_vectors(table_name=table, records=vectors) + time.sleep(add_time_interval) + # logging.getLogger().info(index_params) + connect.create_index(table, index_params) + status, res = connect.get_table_row_count(table) + assert res == nb + + @pytest.mark.level(2) + def test_count_without_connection(self, table, dis_connect): + ''' + target: test get_table_row_count, without connection + method: calling get_table_row_count with correct params, with a disconnected instance + expected: get_table_row_count raise exception + ''' + with pytest.raises(Exception) as e: + status = dis_connect.get_table_row_count(table) + + def test_table_rows_count_no_vectors(self, connect, table): + ''' + target: test table rows_count is correct or not, if table is empty + method: create table and no vectors in it, + assert the value returned by 
get_table_row_count method is equal to 0 + expected: the count is equal to 0 + ''' + table_name = gen_unique_str("test_table") + param = {'table_name': table_name, + 'dimension': dim, + 'index_file_size': index_file_size} + connect.create_table(param) + status, res = connect.get_table_row_count(table) + assert res == 0 + + # TODO: enable + @pytest.mark.level(2) + @pytest.mark.timeout(20) + def _test_table_rows_count_multiprocessing(self, connect, table, args): + ''' + target: test table rows_count is correct or not with multiprocess + method: create table and add vectors in it, + assert the value returned by get_table_row_count method is equal to length of vectors + expected: the count is equal to the length of vectors + ''' + nq = 2 + uri = "tcp://%s:%s" % (args["ip"], args["port"]) + vectors = gen_vectors(nq, dim) + res = connect.add_vectors(table_name=table, records=vectors) + time.sleep(add_time_interval) + + def rows_count(milvus): + status, res = milvus.get_table_row_count(table) + logging.getLogger().info(status) + assert res == nq + + process_num = 8 + processes = [] + for i in range(process_num): + milvus = Milvus() + milvus.connect(uri=uri) + p = Process(target=rows_count, args=(milvus, )) + processes.append(p) + p.start() + logging.getLogger().info(p) + for p in processes: + p.join() + + def test_table_rows_count_multi_tables(self, connect): + ''' + target: test table rows_count is correct or not with multiple tables of L2 + method: create table and add vectors in it, + assert the value returned by get_table_row_count method is equal to length of vectors + expected: the count is equal to the length of vectors + ''' + nq = 100 + vectors = gen_vectors(nq, dim) + table_list = [] + for i in range(50): + table_name = gen_unique_str('test_table_rows_count_multi_tables') + table_list.append(table_name) + param = {'table_name': table_name, + 'dimension': dim, + 'index_file_size': index_file_size, + 'metric_type': MetricType.L2} + connect.create_table(param) + 
res = connect.add_vectors(table_name=table_name, records=vectors) + time.sleep(2) + for i in range(50): + status, res = connect.get_table_row_count(table_list[i]) + assert status.OK() + assert res == nq + + +class TestTableCountIP: + """ + params means different nb, the nb value may trigger merge, or not + """ + + @pytest.fixture( + scope="function", + params=[ + 100, + 5000, + 100000, + ], + ) + def add_vectors_nb(self, request): + yield request.param + + """ + generate valid create_index params + """ + + @pytest.fixture( + scope="function", + params=gen_index_params() + ) + def get_index_params(self, request): + yield request.param + + def test_table_rows_count(self, connect, ip_table, add_vectors_nb): + ''' + target: test table rows_count is correct or not + method: create table and add vectors in it, + assert the value returned by get_table_row_count method is equal to length of vectors + expected: the count is equal to the length of vectors + ''' + nb = add_vectors_nb + vectors = gen_vectors(nb, dim) + res = connect.add_vectors(table_name=ip_table, records=vectors) + time.sleep(add_time_interval) + status, res = connect.get_table_row_count(ip_table) + assert res == nb + + def test_table_rows_count_after_index_created(self, connect, ip_table, get_index_params): + ''' + target: test get_table_row_count, after index have been created + method: add vectors in db, and create index, then calling get_table_row_count with correct params + expected: get_table_row_count raise exception + ''' + nb = 100 + index_params = get_index_params + vectors = gen_vectors(nb, dim) + res = connect.add_vectors(table_name=ip_table, records=vectors) + time.sleep(add_time_interval) + # logging.getLogger().info(index_params) + connect.create_index(ip_table, index_params) + status, res = connect.get_table_row_count(ip_table) + assert res == nb + + @pytest.mark.level(2) + def test_count_without_connection(self, ip_table, dis_connect): + ''' + target: test get_table_row_count, without 
connection + method: calling get_table_row_count with correct params, with a disconnected instance + expected: get_table_row_count raise exception + ''' + with pytest.raises(Exception) as e: + status = dis_connect.get_table_row_count(ip_table) + + def test_table_rows_count_no_vectors(self, connect, ip_table): + ''' + target: test table rows_count is correct or not, if table is empty + method: create table and no vectors in it, + assert the value returned by get_table_row_count method is equal to 0 + expected: the count is equal to 0 + ''' + table_name = gen_unique_str("test_table") + param = {'table_name': table_name, + 'dimension': dim, + 'index_file_size': index_file_size} + connect.create_table(param) + status, res = connect.get_table_row_count(ip_table) + assert res == 0 + + # TODO: enable + @pytest.mark.level(2) + @pytest.mark.timeout(20) + def _test_table_rows_count_multiprocessing(self, connect, ip_table, args): + ''' + target: test table rows_count is correct or not with multiprocess + method: create table and add vectors in it, + assert the value returned by get_table_row_count method is equal to length of vectors + expected: the count is equal to the length of vectors + ''' + nq = 2 + uri = "tcp://%s:%s" % (args["ip"], args["port"]) + vectors = gen_vectors(nq, dim) + res = connect.add_vectors(table_name=ip_table, records=vectors) + time.sleep(add_time_interval) + + def rows_count(milvus): + status, res = milvus.get_table_row_count(ip_table) + logging.getLogger().info(status) + assert res == nq + + process_num = 8 + processes = [] + for i in range(process_num): + milvus = Milvus() + milvus.connect(uri=uri) + p = Process(target=rows_count, args=(milvus,)) + processes.append(p) + p.start() + logging.getLogger().info(p) + for p in processes: + p.join() + + def test_table_rows_count_multi_tables(self, connect): + ''' + target: test table rows_count is correct or not with multiple tables of IP + method: create table and add vectors in it, + assert the value 
returned by get_table_row_count method is equal to length of vectors + expected: the count is equal to the length of vectors + ''' + nq = 100 + vectors = gen_vectors(nq, dim) + table_list = [] + for i in range(50): + table_name = gen_unique_str('test_table_rows_count_multi_tables') + table_list.append(table_name) + param = {'table_name': table_name, + 'dimension': dim, + 'index_file_size': index_file_size, + 'metric_type': MetricType.IP} + connect.create_table(param) + res = connect.add_vectors(table_name=table_name, records=vectors) + time.sleep(2) + for i in range(50): + status, res = connect.get_table_row_count(table_list[i]) + assert status.OK() + assert res == nq \ No newline at end of file diff --git a/tests/milvus_python_test/utils.py b/tests/milvus_python_test/utils.py new file mode 100644 index 0000000000..6d2c56dbed --- /dev/null +++ b/tests/milvus_python_test/utils.py @@ -0,0 +1,545 @@ +# STL imports +import random +import string +import struct +import sys +import time, datetime +import copy +import numpy as np +from utils import * +from milvus import Milvus, IndexType, MetricType + + +def gen_inaccuracy(num): + return num/255.0 + +def gen_vectors(num, dim): + return [[random.random() for _ in range(dim)] for _ in range(num)] + + +def gen_single_vector(dim): + return [[random.random() for _ in range(dim)]] + + +def gen_vector(nb, d, seed=np.random.RandomState(1234)): + xb = seed.rand(nb, d).astype("float32") + return xb.tolist() + + +def gen_unique_str(str=None): + prefix = "".join(random.choice(string.ascii_letters + string.digits) for _ in range(8)) + return prefix if str is None else str + "_" + prefix + + +def get_current_day(): + return time.strftime('%Y-%m-%d', time.localtime()) + + +def get_last_day(day): + tmp = datetime.datetime.now()-datetime.timedelta(days=day) + return tmp.strftime('%Y-%m-%d') + + +def get_next_day(day): + tmp = datetime.datetime.now()+datetime.timedelta(days=day) + return tmp.strftime('%Y-%m-%d') + + +def gen_long_str(num): 
+ string = '' + for _ in range(num): + char = random.choice('tomorrow') + string += char + + +def gen_invalid_ips(): + ips = [ + "255.0.0.0", + "255.255.0.0", + "255.255.255.0", + "255.255.255.255", + "127.0.0", + "123.0.0.2", + "12-s", + " ", + "12 s", + "BB。A", + " siede ", + "(mn)", + "\n", + "\t", + "中文", + "a".join("a" for i in range(256)) + ] + return ips + + +def gen_invalid_ports(): + ports = [ + # empty + " ", + -1, + # too big port + 100000, + # not correct port + 39540, + "BB。A", + " siede ", + "(mn)", + "\n", + "\t", + "中文" + ] + return ports + + +def gen_invalid_uris(): + ip = None + port = 19530 + + uris = [ + " ", + "中文", + + # invalid protocol + # "tc://%s:%s" % (ip, port), + # "tcp%s:%s" % (ip, port), + + # # invalid port + # "tcp://%s:100000" % ip, + # "tcp://%s: " % ip, + # "tcp://%s:19540" % ip, + # "tcp://%s:-1" % ip, + # "tcp://%s:string" % ip, + + # invalid ip + "tcp:// :%s" % port, + "tcp://123.0.0.1:%s" % port, + "tcp://127.0.0:%s" % port, + "tcp://255.0.0.0:%s" % port, + "tcp://255.255.0.0:%s" % port, + "tcp://255.255.255.0:%s" % port, + "tcp://255.255.255.255:%s" % port, + "tcp://\n:%s" % port, + + ] + return uris + + +def gen_invalid_table_names(): + table_names = [ + "12-s", + "12/s", + " ", + # "", + # None, + "12 s", + "BB。A", + "c|c", + " siede ", + "(mn)", + "#12s", + "pip+", + "=c", + "\n", + "\t", + "中文", + "a".join("a" for i in range(256)) + ] + return table_names + + +def gen_invalid_top_ks(): + top_ks = [ + 0, + -1, + None, + [1,2,3], + (1,2), + {"a": 1}, + " ", + "", + "String", + "12-s", + "BB。A", + " siede ", + "(mn)", + "#12s", + "pip+", + "=c", + "\n", + "\t", + "中文", + "a".join("a" for i in range(256)) + ] + return top_ks + + +def gen_invalid_dims(): + dims = [ + 0, + -1, + 100001, + 1000000000000001, + None, + False, + [1,2,3], + (1,2), + {"a": 1}, + " ", + "", + "String", + "12-s", + "BB。A", + " siede ", + "(mn)", + "#12s", + "pip+", + "=c", + "\n", + "\t", + "中文", + "a".join("a" for i in range(256)) + ] + return dims + 
def gen_invalid_file_sizes():
    """Return index_file_size values the server is expected to reject.

    Mix of out-of-range ints, wrong types, and malformed strings
    (including non-ASCII and a 511-char string).
    """
    return [
        0,
        -1,
        1000000000000001,
        None,
        False,
        [1, 2, 3],
        (1, 2),
        {"a": 1},
        " ",
        "",
        "String",
        "12-s",
        "BB。A",
        " siede ",
        "(mn)",
        "#12s",
        "pip+",
        "=c",
        "\n",
        "\t",
        "中文",
        "a".join("a" for i in range(256)),  # 511-char string
    ]


def gen_invalid_index_types():
    """Return index_type values the server is expected to reject.

    None is deliberately excluded from this list.
    """
    return [
        0,
        -1,
        100,
        1000000000000001,
        False,
        [1, 2, 3],
        (1, 2),
        {"a": 1},
        " ",
        "",
        "String",
        "12-s",
        "BB。A",
        " siede ",
        "(mn)",
        "#12s",
        "pip+",
        "=c",
        "\n",
        "\t",
        "中文",
        "a".join("a" for i in range(256)),
    ]


def gen_invalid_nlists():
    """Return nlist values the server is expected to reject.

    None is deliberately excluded from this list.
    """
    return [
        0,
        -1,
        1000000000000001,
        [1, 2, 3],
        (1, 2),
        {"a": 1},
        " ",
        "",
        "String",
        "12-s",
        "BB。A",
        " siede ",
        "(mn)",
        "#12s",
        "pip+",
        "=c",
        "\n",
        "\t",
        "中文",
    ]


def gen_invalid_nprobes():
    """Return nprobe values the server is expected to reject (None included)."""
    return [
        0,
        -1,
        1000000000000001,
        None,
        [1, 2, 3],
        (1, 2),
        {"a": 1},
        " ",
        "",
        "String",
        "12-s",
        "BB。A",
        " siede ",
        "(mn)",
        "#12s",
        "pip+",
        "=c",
        "\n",
        "\t",
        "中文",
    ]


def gen_invalid_metric_types():
    """Return metric_type values the server is expected to reject.

    None is deliberately excluded from this list.
    """
    return [
        0,
        -1,
        1000000000000001,
        [1, 2, 3],
        (1, 2),
        {"a": 1},
        " ",
        "",
        "String",
        "12-s",
        "BB。A",
        " siede ",
        "(mn)",
        "#12s",
        "pip+",
        "=c",
        "\n",
        "\t",
        "中文",
    ]


def gen_invalid_vectors():
    """Return vector payloads the server is expected to reject.

    Covers empty/short/wrong-element-type lists as well as non-list values.
    """
    return [
        "1*2",
        [],
        [1],
        [1, 2],
        [" "],
        ['a'],
        [None],
        None,
        (1, 2),
        {"a": 1},
        " ",
        "",
        "String",
        "12-s",
        "BB。A",
        " siede ",
        "(mn)",
        "#12s",
        "pip+",
        "=c",
        "\n",
        "\t",
        "中文",
        "a".join("a" for i in range(256)),
    ]


def gen_invalid_vector_ids():
    """Return vector-id values the server is expected to reject.

    Includes floats, None, an integer far beyond int64 range, and
    malformed strings.
    """
    return [
        1.0,
        -1.0,
        None,
        # well beyond int64 range
        10000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000,
        " ",
        "",
        "String",
        "BB。A",
        " siede ",
        "(mn)",
        "#12s",
        "=c",
        "\n",
        "中文",
    ]


def gen_invalid_query_ranges():
    """Return (start, end) query-range pairs the server is expected to reject.

    First group: empty/reversed date ranges built from today's date;
    second group: a valid end date paired with every kind of invalid start.
    """
    return [
        [(get_last_day(1), "")],
        [(get_current_day(), "")],
        [(get_next_day(1), "")],
        [(get_current_day(), get_last_day(1))],
        [(get_next_day(1), get_last_day(1))],
        [(get_next_day(1), get_current_day())],
        [(0, get_next_day(1))],
        [(-1, get_next_day(1))],
        [(1, get_next_day(1))],
        [(100001, get_next_day(1))],
        [(1000000000000001, get_next_day(1))],
        [(None, get_next_day(1))],
        [([1, 2, 3], get_next_day(1))],
        [((1, 2), get_next_day(1))],
        [({"a": 1}, get_next_day(1))],
        [(" ", get_next_day(1))],
        [("", get_next_day(1))],
        [("String", get_next_day(1))],
        [("12-s", get_next_day(1))],
        [("BB。A", get_next_day(1))],
        [(" siede ", get_next_day(1))],
        [("(mn)", get_next_day(1))],
        [("#12s", get_next_day(1))],
        [("pip+", get_next_day(1))],
        [("=c", get_next_day(1))],
        [("\n", get_next_day(1))],
        [("\t", get_next_day(1))],
        [("中文", get_next_day(1))],
        [("a".join("a" for i in range(256)), get_next_day(1))],
    ]


def gen_invalid_index_params():
    """Return create_index param dicts the server is expected to reject.

    Pairs every invalid index_type with a valid nlist, then a valid
    index_type (IVFLAT) with every invalid nlist.
    """
    params = [{"index_type": index_type, "nlist": 16384}
              for index_type in gen_invalid_index_types()]
    params += [{"index_type": IndexType.IVFLAT, "nlist": nlist}
               for nlist in gen_invalid_nlists()]
    return params


def _gen_index_param_combinations(index_types, nlists):
    # Cartesian product of index types and nlist values as param dicts.
    return [{"index_type": index_type, "nlist": nlist}
            for index_type in index_types
            for nlist in nlists]


def gen_index_params():
    """Return valid create_index param dicts covering several nlist sizes."""
    index_types = [IndexType.FLAT, IndexType.IVFLAT,
                   IndexType.IVF_SQ8, IndexType.IVF_SQ8H]
    nlists = [1, 16384, 50000]
    return _gen_index_param_combinations(index_types, nlists)


def gen_simple_index_params():
    """Return valid create_index param dicts with a single default nlist."""
    index_types = [IndexType.FLAT, IndexType.IVFLAT,
                   IndexType.IVF_SQ8, IndexType.IVF_SQ8H]
    nlists = [16384]
    return _gen_index_param_combinations(index_types, nlists)


if __name__ == "__main__":
    # Manual smoke test against a local Milvus server: recreate a table,
    # bulk-load query vectors from disk, build an IVFLAT index, then run
    # a couple of searches.  Requires a running server and the query file
    # below; not executed by the test suite.
    dim = 128
    nq = 10000
    table = "test"

    file_name = '/poc/yuncong/ann_1000m/query.npy'
    data = np.load(file_name)
    vectors = data[0:nq].tolist()

    connect = Milvus()
    connect.connect(host="127.0.0.1")
    # Drop any leftover table from a previous run.
    connect.delete_table(table)
    print(connect.get_table_row_count(table))
    param = {'table_name': table,
             'dimension': dim,
             'metric_type': MetricType.L2,
             'index_file_size': 10}
    status = connect.create_table(param)
    print(status)
    print(connect.get_table_row_count(table))
    # Add the same batch ten times to accumulate rows.
    for i in range(10):
        status, ids = connect.add_vectors(table, vectors)
        print(status)
        print(ids[0])
    index_params = {"index_type": IndexType.IVFLAT, "nlist": 16384}
    status = connect.create_index(table, index_params)
    print(status)
    query_vec = [vectors[0]]
    top_k = 12
    nprobe = 1
    # Search twice: second call exercises any warmed caches.
    for i in range(2):
        result = connect.search_vectors(table, top_k, nprobe, query_vec)
        print(result)
    sys.exit()