diff --git a/CHANGELOG.md b/CHANGELOG.md
index 0deb2dbeba..886db2ecc1 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -69,6 +69,7 @@ Please mark all change in change log and use the ticket from JIRA.
- MS-624 - Re-organize project directory for open-source
- MS-635 - Add compile option to support customized faiss
- MS-660 - add ubuntu_build_deps.sh
+- \#18 - Add all test cases
# Milvus 0.4.0 (2019-09-12)
diff --git a/tests/milvus-java-test/.gitignore b/tests/milvus-java-test/.gitignore
new file mode 100644
index 0000000000..3a553813a8
--- /dev/null
+++ b/tests/milvus-java-test/.gitignore
@@ -0,0 +1,4 @@
+target/
+.idea/
+test-output/
+lib/*
diff --git a/tests/milvus-java-test/bin/run.sh b/tests/milvus-java-test/bin/run.sh
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/milvus-java-test/ci/function/file_transfer.groovy b/tests/milvus-java-test/ci/function/file_transfer.groovy
new file mode 100644
index 0000000000..bebae14832
--- /dev/null
+++ b/tests/milvus-java-test/ci/function/file_transfer.groovy
@@ -0,0 +1,10 @@
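+// Shared helper: uploads build artifacts to a remote host with the Jenkins publish-over-FTP
+// step. "remoteIP" must match a server configName from the global Jenkins configuration;
+// only the default "ftp" protocol branch is implemented here.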
+def FileTransfer (sourceFiles, remoteDirectory, remoteIP, protocol = "ftp", makeEmptyDirs = true) {
+ if (protocol == "ftp") {
+ ftpPublisher masterNodeName: '', paramPublish: [parameterName: ''], alwaysPublishFromMaster: false, continueOnError: false, failOnError: true, publishers: [
+ [configName: "${remoteIP}", transfers: [
+ [asciiMode: false, cleanRemote: false, excludes: '', flatten: false, makeEmptyDirs: "${makeEmptyDirs}", noDefaultExcludes: false, patternSeparator: '[, ]+', remoteDirectory: "${remoteDirectory}", remoteDirectorySDF: false, removePrefix: '', sourceFiles: "${sourceFiles}"]], usePromotionTimestamp: true, useWorkspaceInPromotion: false, verbose: true
+ ]
+ ]
+ }
+}
+return this
diff --git a/tests/milvus-java-test/ci/jenkinsfile/cleanup.groovy b/tests/milvus-java-test/ci/jenkinsfile/cleanup.groovy
new file mode 100644
index 0000000000..2e9332fa6e
--- /dev/null
+++ b/tests/milvus-java-test/ci/jenkinsfile/cleanup.groovy
@@ -0,0 +1,13 @@
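+// Removes the per-build Helm release "<JOB_NAME>-<BUILD_NUMBER>". "helm status" exits 0 when
+// the release exists, so a zero returnStatus triggers the purge; the catch block repeats the
+// same check so cleanup still happens before the original error is re-thrown.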
+try {
+ def result = sh script: "helm status ${env.JOB_NAME}-${env.BUILD_NUMBER}", returnStatus: true
+ if (!result) {
+ sh "helm del --purge ${env.JOB_NAME}-${env.BUILD_NUMBER}"
+ }
+} catch (exc) {
+ def result = sh script: "helm status ${env.JOB_NAME}-${env.BUILD_NUMBER}", returnStatus: true
+ if (!result) {
+ sh "helm del --purge ${env.JOB_NAME}-${env.BUILD_NUMBER}"
+ }
+ throw exc
+}
+
diff --git a/tests/milvus-java-test/ci/jenkinsfile/deploy_server.groovy b/tests/milvus-java-test/ci/jenkinsfile/deploy_server.groovy
new file mode 100644
index 0000000000..6650b7c219
--- /dev/null
+++ b/tests/milvus-java-test/ci/jenkinsfile/deploy_server.groovy
@@ -0,0 +1,16 @@
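+// Deploys a Milvus GPU server for this build: initialise the Helm client, check out the
+// milvus-helm chart branch matching HELM_BRANCH, then install the chart into the
+// milvus-sdk-test namespace under a per-build release name. On failure the half-installed
+// release is purged before the error is re-thrown.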
+try {
+ sh 'helm init --client-only --skip-refresh --stable-repo-url https://kubernetes.oss-cn-hangzhou.aliyuncs.com/charts'
+ sh 'helm repo add milvus https://registry.zilliz.com/chartrepo/milvus'
+ sh 'helm repo update'
+ dir ("milvus-helm") {
+ checkout([$class: 'GitSCM', branches: [[name: "${HELM_BRANCH}"]], doGenerateSubmoduleConfigurations: false, extensions: [], submoduleCfg: [], userRemoteConfigs: [[credentialsId: "${params.GIT_USER}", url: "git@192.168.1.105:megasearch/milvus-helm.git", name: 'origin', refspec: "+refs/heads/${HELM_BRANCH}:refs/remotes/origin/${HELM_BRANCH}"]]])
+ dir ("milvus/milvus-gpu") {
+ sh "helm install --wait --timeout 300 --set engine.image.tag=${IMAGE_TAG} --set expose.type=clusterIP --name ${env.JOB_NAME}-${env.BUILD_NUMBER} -f ci/values.yaml --namespace milvus-sdk-test --version 0.3.1 ."
+ }
+ }
+} catch (exc) {
+ echo 'Helm running failed!'
+ sh "helm del --purge ${env.JOB_NAME}-${env.BUILD_NUMBER}"
+ throw exc
+}
+
diff --git a/tests/milvus-java-test/ci/jenkinsfile/integration_test.groovy b/tests/milvus-java-test/ci/jenkinsfile/integration_test.groovy
new file mode 100644
index 0000000000..662c93bc77
--- /dev/null
+++ b/tests/milvus-java-test/ci/jenkinsfile/integration_test.groovy
@@ -0,0 +1,13 @@
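+// Builds the test jar with Maven and runs the TestNG suites in com.MainClass against the
+// Milvus service deployed for this build, addressed by its in-cluster DNS name.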
+timeout(time: 30, unit: 'MINUTES') {
+ try {
+ dir ("milvus-java-test") {
+ sh "mvn clean install"
+ sh "java -cp \"target/MilvusSDkJavaTest-1.0-SNAPSHOT.jar:lib/*\" com.MainClass -h ${env.JOB_NAME}-${env.BUILD_NUMBER}-milvus-gpu-engine.milvus-sdk-test.svc.cluster.local"
+ }
+
+ } catch (exc) {
+ echo 'Milvus-SDK-Java Integration Test Failed !'
+ throw exc
+ }
+}
+
diff --git a/tests/milvus-java-test/ci/jenkinsfile/notify.groovy b/tests/milvus-java-test/ci/jenkinsfile/notify.groovy
new file mode 100644
index 0000000000..0a257b8cd8
--- /dev/null
+++ b/tests/milvus-java-test/ci/jenkinsfile/notify.groovy
@@ -0,0 +1,15 @@
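+// Email notification helper for the pipeline. The single-quoted '$DEFAULT_*' values are left
+// for the email-ext plugin to expand from the global Jenkins mail configuration.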
+def notify() {
+ if (!currentBuild.resultIsBetterOrEqualTo('SUCCESS')) {
+ // Send a notification email whenever the build result is worse than SUCCESS
+ emailext subject: '$DEFAULT_SUBJECT',
+ body: '$DEFAULT_CONTENT',
+ recipientProviders: [
+ [$class: 'DevelopersRecipientProvider'],
+ [$class: 'RequesterRecipientProvider']
+ ],
+ replyTo: '$DEFAULT_REPLYTO',
+ to: '$DEFAULT_RECIPIENTS'
+ }
+}
+return this
+
diff --git a/tests/milvus-java-test/ci/jenkinsfile/upload_unit_test_out.groovy b/tests/milvus-java-test/ci/jenkinsfile/upload_unit_test_out.groovy
new file mode 100644
index 0000000000..7e106c0296
--- /dev/null
+++ b/tests/milvus-java-test/ci/jenkinsfile/upload_unit_test_out.groovy
@@ -0,0 +1,13 @@
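+// Publishes the test_out directory to NAS storage over FTP via the shared FileTransfer helper
+// and echoes the FTP URL where the results can be browsed.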
+timeout(time: 5, unit: 'MINUTES') {
+ dir ("${PROJECT_NAME}_test") {
+ if (fileExists('test_out')) {
+ def fileTransfer = load "${env.WORKSPACE}/ci/function/file_transfer.groovy"
+ fileTransfer.FileTransfer("test_out/", "${PROJECT_NAME}/test/${JOB_NAME}-${BUILD_ID}", 'nas storage')
+ if (currentBuild.resultIsBetterOrEqualTo('SUCCESS')) {
+ echo "Milvus Dev Test Out Viewer \"ftp://192.168.1.126/data/${PROJECT_NAME}/test/${JOB_NAME}-${BUILD_ID}\""
+ }
+ } else {
+ error("Milvus Dev Test Out directory don't exists!")
+ }
+ }
+}
diff --git a/tests/milvus-java-test/ci/main_jenkinsfile b/tests/milvus-java-test/ci/main_jenkinsfile
new file mode 100644
index 0000000000..5df9d61ccb
--- /dev/null
+++ b/tests/milvus-java-test/ci/main_jenkinsfile
@@ -0,0 +1,110 @@
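+// Declarative pipeline for the Milvus Java SDK tests: run on a Kubernetes agent pod, deploy a
+// Milvus server with Helm, execute the TestNG integration suite, and always clean up the Helm
+// release in the post section.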
+pipeline {
+ agent none
+
+ options {
+ timestamps()
+ }
+
+ environment {
+ SRC_BRANCH = "master"
+ IMAGE_TAG = "${params.IMAGE_TAG}-release"
+ HELM_BRANCH = "${params.IMAGE_TAG}"
+ TEST_URL = "git@192.168.1.105:Test/milvus-java-test.git"
+ TEST_BRANCH = "${params.IMAGE_TAG}"
+ }
+
+ stages {
+ stage("Setup env") {
+ agent {
+ kubernetes {
+ label 'dev-test'
+ defaultContainer 'jnlp'
+ yaml """
+apiVersion: v1
+kind: Pod
+metadata:
+  labels:
+    app: milvus
+    component: test
+spec:
+  containers:
+  - name: milvus-testframework-java
+    image: registry.zilliz.com/milvus/milvus-java-test:v0.1
+    command:
+    - cat
+    tty: true
+    volumeMounts:
+    - name: kubeconf
+      mountPath: /root/.kube/
+      readOnly: true
+  volumes:
+  - name: kubeconf
+    secret:
+      secretName: test-cluster-config
+ """
+ }
+ }
+
+ stages {
+ stage("Deploy Server") {
+ steps {
+ gitlabCommitStatus(name: 'Deploy Server') {
+ container('milvus-testframework-java') {
+ script {
+ load "${env.WORKSPACE}/milvus-java-test/ci/jenkinsfile/deploy_server.groovy"
+ }
+ }
+ }
+ }
+ }
+ stage("Integration Test") {
+ steps {
+ gitlabCommitStatus(name: 'Integration Test') {
+ container('milvus-testframework-java') {
+ script {
+ print "In integration test stage"
+ load "${env.WORKSPACE}/milvus-java-test/ci/jenkinsfile/integration_test.groovy"
+ }
+ }
+ }
+ }
+ }
+ stage ("Cleanup Env") {
+ steps {
+ gitlabCommitStatus(name: 'Cleanup Env') {
+ container('milvus-testframework-java') {
+ script {
+ load "${env.WORKSPACE}/milvus-java-test/ci/jenkinsfile/cleanup.groovy"
+ }
+ }
+ }
+ }
+ }
+ }
+ post {
+ always {
+ container('milvus-testframework-java') {
+ script {
+ load "${env.WORKSPACE}/milvus-java-test/ci/jenkinsfile/cleanup.groovy"
+ }
+ }
+ }
+ success {
+ script {
+ echo "Milvus java-sdk test success !"
+ }
+ }
+ aborted {
+ script {
+ echo "Milvus java-sdk test aborted !"
+ }
+ }
+ failure {
+ script {
+ echo "Milvus java-sdk test failed !"
+ }
+ }
+ }
+ }
+ }
+}
diff --git a/tests/milvus-java-test/ci/pod_containers/milvus-testframework.yaml b/tests/milvus-java-test/ci/pod_containers/milvus-testframework.yaml
new file mode 100644
index 0000000000..1381e1454f
--- /dev/null
+++ b/tests/milvus-java-test/ci/pod_containers/milvus-testframework.yaml
@@ -0,0 +1,13 @@
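+# Jenkins agent pod template: a Maven 3.6 / JDK 8 container kept alive with `cat` so pipeline
+# steps can exec into it.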
+apiVersion: v1
+kind: Pod
+metadata:
+  labels:
+    app: milvus
+    component: testframework-java
+spec:
+  containers:
+  - name: milvus-testframework-java
+    image: maven:3.6.2-jdk-8
+    command:
+    - cat
+    tty: true
diff --git a/tests/milvus-java-test/milvus-java-test.iml b/tests/milvus-java-test/milvus-java-test.iml
new file mode 100644
index 0000000000..78b2cc53b2
--- /dev/null
+++ b/tests/milvus-java-test/milvus-java-test.iml
@@ -0,0 +1,2 @@
+
+
\ No newline at end of file
diff --git a/tests/milvus-java-test/pom.xml b/tests/milvus-java-test/pom.xml
new file mode 100644
index 0000000000..db02ff2c00
--- /dev/null
+++ b/tests/milvus-java-test/pom.xml
@@ -0,0 +1,137 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project xmlns="http://maven.apache.org/POM/4.0.0"
+         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+    <modelVersion>4.0.0</modelVersion>
+
+    <groupId>milvus</groupId>
+    <artifactId>MilvusSDkJavaTest</artifactId>
+    <version>1.0-SNAPSHOT</version>
+
+    <build>
+        <plugins>
+            <plugin>
+                <groupId>org.apache.maven.plugins</groupId>
+                <artifactId>maven-dependency-plugin</artifactId>
+                <executions>
+                    <execution>
+                        <id>copy-dependencies</id>
+                        <phase>package</phase>
+                        <goals>
+                            <goal>copy-dependencies</goal>
+                        </goals>
+                        <configuration>
+                            <outputDirectory>lib</outputDirectory>
+                            <!-- The values "false" and "true" are original; the two option
+                                 names below are assumed. -->
+                            <overWriteReleases>false</overWriteReleases>
+                            <overWriteSnapshots>true</overWriteSnapshots>
+                        </configuration>
+                    </execution>
+                </executions>
+            </plugin>
+        </plugins>
+    </build>
+
+    <properties>
+        <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
+        <!-- Only the values in this block are original; the property names are assumed
+             (standard Maven compiler names plus grpc/protobuf guesses for the version
+             numbers). -->
+        <grpc.version>1.23.0</grpc.version>
+        <protobuf.version>3.9.0</protobuf.version>
+        <protoc.version>3.9.0</protoc.version>
+        <maven.compiler.source>1.8</maven.compiler.source>
+        <maven.compiler.target>1.8</maven.compiler.target>
+    </properties>
+
+    <repositories>
+        <repository>
+            <id>oss.sonatype.org-snapshot</id>
+            <url>http://oss.sonatype.org/content/repositories/snapshots</url>
+            <releases>
+                <enabled>false</enabled>
+            </releases>
+            <snapshots>
+                <enabled>true</enabled>
+            </snapshots>
+        </repository>
+    </repositories>
+
+    <dependencies>
+        <dependency>
+            <groupId>org.apache.commons</groupId>
+            <artifactId>commons-lang3</artifactId>
+            <version>3.4</version>
+        </dependency>
+
+        <dependency>
+            <groupId>commons-cli</groupId>
+            <artifactId>commons-cli</artifactId>
+            <version>1.3</version>
+        </dependency>
+
+        <dependency>
+            <groupId>org.testng</groupId>
+            <artifactId>testng</artifactId>
+            <version>6.10</version>
+        </dependency>
+
+        <dependency>
+            <groupId>junit</groupId>
+            <artifactId>junit</artifactId>
+            <version>4.9</version>
+        </dependency>
+
+        <dependency>
+            <groupId>io.milvus</groupId>
+            <artifactId>milvus-sdk-java</artifactId>
+            <version>0.1.1-SNAPSHOT</version>
+        </dependency>
+    </dependencies>
+</project>
\ No newline at end of file
diff --git a/tests/milvus-java-test/src/main/java/com/MainClass.java b/tests/milvus-java-test/src/main/java/com/MainClass.java
new file mode 100644
index 0000000000..0e61e5f3d0
--- /dev/null
+++ b/tests/milvus-java-test/src/main/java/com/MainClass.java
@@ -0,0 +1,147 @@
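+// Entry point of the test suite: parses the Milvus host/port from the command line, exposes
+// shared TestNG data providers (connected client, disconnected client, and one table per
+// metric type), and assembles the TestNG suite programmatically.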
+package com;
+
+import io.milvus.client.*;
+import org.apache.commons.cli.*;
+import org.apache.commons.lang3.RandomStringUtils;
+import org.testng.SkipException;
+import org.testng.TestNG;
+import org.testng.annotations.DataProvider;
+import org.testng.xml.XmlClass;
+import org.testng.xml.XmlSuite;
+import org.testng.xml.XmlTest;
+
+import java.util.ArrayList;
+import java.util.List;
+
+public class MainClass {
+ private static String host = "127.0.0.1";
+ private static String port = "19530";
+ public Integer index_file_size = 50;
+ public Integer dimension = 128;
+
+ public static void setHost(String host) {
+ MainClass.host = host;
+ }
+
+ public static void setPort(String port) {
+ MainClass.port = port;
+ }
+
+ @DataProvider(name="DefaultConnectArgs")
+ public static Object[][] defaultConnectArgs(){
+ return new Object[][]{{host, port}};
+ }
+
+ @DataProvider(name="ConnectInstance")
+ public Object[][] connectInstance(){
+ MilvusClient client = new MilvusGrpcClient();
+ ConnectParam connectParam = new ConnectParam.Builder()
+ .withHost(host)
+ .withPort(port)
+ .build();
+ client.connect(connectParam);
+ String tableName = RandomStringUtils.randomAlphabetic(10);
+ return new Object[][]{{client, tableName}};
+ }
+
+ @DataProvider(name="DisConnectInstance")
+ public Object[][] disConnectInstance(){
+ // Generate connection instance
+ MilvusClient client = new MilvusGrpcClient();
+ ConnectParam connectParam = new ConnectParam.Builder()
+ .withHost(host)
+ .withPort(port)
+ .build();
+ client.connect(connectParam);
+ try {
+ client.disconnect();
+ } catch (InterruptedException e) {
+ e.printStackTrace();
+ }
+ String tableName = RandomStringUtils.randomAlphabetic(10);
+ return new Object[][]{{client, tableName}};
+ }
+
+ @DataProvider(name="Table")
+ public Object[][] provideTable(){
+ Object[][] tables = new Object[2][2];
+ MetricType metricTypes[] = { MetricType.L2, MetricType.IP };
+ for (Integer i = 0; i < metricTypes.length; ++i) {
+ String tableName = metricTypes[i].toString()+"_"+RandomStringUtils.randomAlphabetic(10);
+ // Generate connection instance
+ MilvusClient client = new MilvusGrpcClient();
+ ConnectParam connectParam = new ConnectParam.Builder()
+ .withHost(host)
+ .withPort(port)
+ .build();
+ client.connect(connectParam);
+ TableSchema tableSchema = new TableSchema.Builder(tableName, dimension)
+ .withIndexFileSize(index_file_size)
+ .withMetricType(metricTypes[i])
+ .build();
+ TableSchemaParam tableSchemaParam = new TableSchemaParam.Builder(tableSchema).build();
+ Response res = client.createTable(tableSchemaParam);
+ if (!res.ok()) {
+ System.out.println(res.getMessage());
+ throw new SkipException("Table creation failed");
+ }
+ tables[i] = new Object[]{client, tableName};
+ }
+ return tables;
+ }
+
+ public static void main(String[] args) {
+ CommandLineParser parser = new DefaultParser();
+ Options options = new Options();
+ options.addOption("h", "host", true, "milvus-server hostname/ip");
+ options.addOption("p", "port", true, "milvus-server port");
+ try {
+ CommandLine cmd = parser.parse(options, args);
+ String host = cmd.getOptionValue("host");
+ if (host != null) {
+ setHost(host);
+ }
+ String port = cmd.getOptionValue("port");
+ if (port != null) {
+ setPort(port);
+ }
+ System.out.println("Host: "+host+", Port: "+port);
+ }
+ catch(ParseException exp) {
+ System.err.println("Parsing failed. Reason: " + exp.getMessage() );
+ }
+
+// TestListenerAdapter tla = new TestListenerAdapter();
+// TestNG testng = new TestNG();
+// testng.setTestClasses(new Class[] { TestPing.class });
+// testng.setTestClasses(new Class[] { TestConnect.class });
+// testng.addListener(tla);
+// testng.run();
+
+ XmlSuite suite = new XmlSuite();
+ suite.setName("TmpSuite");
+
+ XmlTest test = new XmlTest(suite);
+ test.setName("TmpTest");
+ List<XmlClass> classes = new ArrayList<>();
+
+ classes.add(new XmlClass("com.TestPing"));
+ classes.add(new XmlClass("com.TestAddVectors"));
+ classes.add(new XmlClass("com.TestConnect"));
+ classes.add(new XmlClass("com.TestDeleteVectors"));
+ classes.add(new XmlClass("com.TestIndex"));
+ classes.add(new XmlClass("com.TestSearchVectors"));
+ classes.add(new XmlClass("com.TestTable"));
+ classes.add(new XmlClass("com.TestTableCount"));
+
+ test.setXmlClasses(classes);
+
+ List<XmlSuite> suites = new ArrayList<>();
+ suites.add(suite);
+ TestNG tng = new TestNG();
+ tng.setXmlSuites(suites);
+ tng.run();
+
+ }
+
+}
diff --git a/tests/milvus-java-test/src/main/java/com/TestAddVectors.java b/tests/milvus-java-test/src/main/java/com/TestAddVectors.java
new file mode 100644
index 0000000000..038b8d2a8d
--- /dev/null
+++ b/tests/milvus-java-test/src/main/java/com/TestAddVectors.java
@@ -0,0 +1,154 @@
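+// Insert (add vectors) cases: happy path, missing table, disconnected client, timeouts,
+// explicit and mismatched vector IDs, wrong dimensions, and repeated inserts with a
+// row-count check afterwards.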
+package com;
+
+import io.milvus.client.InsertParam;
+import io.milvus.client.InsertResponse;
+import io.milvus.client.MilvusClient;
+import io.milvus.client.TableParam;
+import org.testng.Assert;
+import org.testng.annotations.Test;
+
+import java.util.*;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+
+public class TestAddVectors {
+ int dimension = 128;
+
+ public List<List<Float>> gen_vectors(Integer nb) {
+ List<List<Float>> xb = new LinkedList<>();
+ Random random = new Random();
+ for (int i = 0; i < nb; ++i) {
+ LinkedList<Float> vector = new LinkedList<>();
+ for (int j = 0; j < dimension; j++) {
+ vector.add(random.nextFloat());
+ }
+ xb.add(vector);
+ }
+ return xb;
+ }
+
+ @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
+ public void test_add_vectors_table_not_existed(MilvusClient client, String tableName) throws InterruptedException {
+ int nb = 10000;
+ List<List<Float>> vectors = gen_vectors(nb);
+ String tableNameNew = tableName + "_";
+ InsertParam insertParam = new InsertParam.Builder(tableNameNew, vectors).build();
+ InsertResponse res = client.insert(insertParam);
+ assert(!res.getResponse().ok());
+ }
+
+ @Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class)
+ public void test_add_vectors_without_connect(MilvusClient client, String tableName) throws InterruptedException {
+ int nb = 100;
+ List<List<Float>> vectors = gen_vectors(nb);
+ InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
+ InsertResponse res = client.insert(insertParam);
+ assert(!res.getResponse().ok());
+ }
+
+ @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
+ public void test_add_vectors(MilvusClient client, String tableName) throws InterruptedException {
+ int nb = 10000;
+ List<List<Float>> vectors = gen_vectors(nb);
+ InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
+ InsertResponse res = client.insert(insertParam);
+ assert(res.getResponse().ok());
+ Thread.currentThread().sleep(1000);
+ // Assert table row count
+ TableParam tableParam = new TableParam.Builder(tableName).build();
+ Assert.assertEquals(client.getTableRowCount(tableParam).getTableRowCount(), nb);
+ }
+
+ @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
+ public void test_add_vectors_timeout(MilvusClient client, String tableName) throws InterruptedException {
+ int nb = 200000;
+ List<List<Float>> vectors = gen_vectors(nb);
+ System.out.println(new Date());
+ InsertParam insertParam = new InsertParam.Builder(tableName, vectors).withTimeout(1).build();
+ InsertResponse res = client.insert(insertParam);
+ assert(!res.getResponse().ok());
+ }
+
+ @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
+ public void test_add_vectors_big_data(MilvusClient client, String tableName) throws InterruptedException {
+ int nb = 500000;
+ List<List<Float>> vectors = gen_vectors(nb);
+ System.out.println(new Date());
+ InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
+ InsertResponse res = client.insert(insertParam);
+ assert(res.getResponse().ok());
+ }
+
+ @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
+ public void test_add_vectors_with_ids(MilvusClient client, String tableName) throws InterruptedException {
+ int nb = 10000;
+ List<List<Float>> vectors = gen_vectors(nb);
+ // Add vectors with ids
+ List<Long> vectorIds;
+ vectorIds = Stream.iterate(0L, n -> n)
+ .limit(nb)
+ .collect(Collectors.toList());
+ InsertParam insertParam = new InsertParam.Builder(tableName, vectors).withVectorIds(vectorIds).build();
+ InsertResponse res = client.insert(insertParam);
+ assert(res.getResponse().ok());
+ Thread.currentThread().sleep(1000);
+ // Assert table row count
+ TableParam tableParam = new TableParam.Builder(tableName).build();
+ Assert.assertEquals(client.getTableRowCount(tableParam).getTableRowCount(), nb);
+ }
+
+ // TODO: MS-628
+ @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
+ public void test_add_vectors_with_invalid_ids(MilvusClient client, String tableName) {
+ int nb = 10;
+ List<List<Float>> vectors = gen_vectors(nb);
+ // Add vectors with ids
+ List<Long> vectorIds;
+ vectorIds = Stream.iterate(0L, n -> n)
+ .limit(nb+1)
+ .collect(Collectors.toList());
+ InsertParam insertParam = new InsertParam.Builder(tableName, vectors).withVectorIds(vectorIds).build();
+ InsertResponse res = client.insert(insertParam);
+ assert(!res.getResponse().ok());
+ }
+
+ @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
+ public void test_add_vectors_with_invalid_dimension(MilvusClient client, String tableName) {
+ int nb = 10000;
+ List<List<Float>> vectors = gen_vectors(nb);
+ vectors.get(0).add((float) 0);
+ InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
+ InsertResponse res = client.insert(insertParam);
+ assert(!res.getResponse().ok());
+ }
+
+ @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
+ public void test_add_vectors_with_invalid_vectors(MilvusClient client, String tableName) {
+ int nb = 10000;
+ List<List<Float>> vectors = gen_vectors(nb);
+ vectors.set(0, new ArrayList<>());
+ InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
+ InsertResponse res = client.insert(insertParam);
+ assert(!res.getResponse().ok());
+ }
+
+ @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
+ public void test_add_vectors_repeatably(MilvusClient client, String tableName) throws InterruptedException {
+ int nb = 100000;
+ int loops = 10;
+ List<List<Float>> vectors = gen_vectors(nb);
+ InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
+ InsertResponse res = null;
+ for (int i = 0; i < loops; ++i ) {
+ long startTime = System.currentTimeMillis();
+ res = client.insert(insertParam);
+ long endTime = System.currentTimeMillis();
+ System.out.println("Total execution time: " + (endTime-startTime) + "ms");
+ }
+ Thread.currentThread().sleep(1000);
+ // Assert table row count
+ TableParam tableParam = new TableParam.Builder(tableName).build();
+ Assert.assertEquals(client.getTableRowCount(tableParam).getTableRowCount(), nb * loops);
+ }
+
+}
diff --git a/tests/milvus-java-test/src/main/java/com/TestConnect.java b/tests/milvus-java-test/src/main/java/com/TestConnect.java
new file mode 100644
index 0000000000..77b6fe6a33
--- /dev/null
+++ b/tests/milvus-java-test/src/main/java/com/TestConnect.java
@@ -0,0 +1,80 @@
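+// Connect/disconnect cases: repeated connects, invalid hosts and ports, and the state of the
+// client after disconnecting.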
+package com;
+
+import io.milvus.client.ConnectParam;
+import io.milvus.client.MilvusClient;
+import io.milvus.client.MilvusGrpcClient;
+import io.milvus.client.Response;
+import org.testng.annotations.DataProvider;
+import org.testng.annotations.Test;
+
+public class TestConnect {
+ @Test(dataProvider = "DefaultConnectArgs", dataProviderClass = MainClass.class)
+ public void test_connect(String host, String port){
+ System.out.println("Host: "+host+", Port: "+port);
+ MilvusClient client = new MilvusGrpcClient();
+ ConnectParam connectParam = new ConnectParam.Builder()
+ .withHost(host)
+ .withPort(port)
+ .build();
+ Response res = client.connect(connectParam);
+ assert(res.ok());
+ assert(client.connected());
+ }
+
+ @Test(dataProvider = "DefaultConnectArgs", dataProviderClass = MainClass.class)
+ public void test_connect_repeat(String host, String port){
+ MilvusGrpcClient client = new MilvusGrpcClient();
+ ConnectParam connectParam = new ConnectParam.Builder()
+ .withHost(host)
+ .withPort(port)
+ .build();
+ client.connect(connectParam);
+ Response res = client.connect(connectParam);
+ assert(!res.ok());
+ assert(client.connected());
+ }
+
+ @Test(dataProvider="InvalidConnectArgs")
+ public void test_connect_invalid_connect_args(String ip, String port) throws InterruptedException {
+ MilvusClient client = new MilvusGrpcClient();
+ ConnectParam connectParam = new ConnectParam.Builder()
+ .withHost(ip)
+ .withPort(port)
+ .build();
+ client.connect(connectParam);
+ assert(!client.connected());
+ }
+
+ // TODO: MS-615
+ @DataProvider(name="InvalidConnectArgs")
+ public Object[][] generate_invalid_connect_args() {
+ String port = "19530";
+ String ip = "";
+ return new Object[][]{
+ {"1.1.1.1", port},
+ {"255.255.0.0", port},
+ {"1.2.2", port},
+ {"中文", port},
+ {"www.baidu.com", "100000"},
+ {"127.0.0.1", "100000"},
+ {"www.baidu.com", "80"},
+ };
+ }
+
+ @Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class)
+ public void test_disconnect(MilvusClient client, String tableName){
+ assert(!client.connected());
+ }
+
+ @Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class)
+ public void test_disconnect_repeatably(MilvusClient client, String tableName){
+ Response res = null;
+ try {
+ res = client.disconnect();
+ } catch (InterruptedException e) {
+ e.printStackTrace();
+ }
+ assert(res.ok());
+ assert(!client.connected());
+ }
+}
diff --git a/tests/milvus-java-test/src/main/java/com/TestDeleteVectors.java b/tests/milvus-java-test/src/main/java/com/TestDeleteVectors.java
new file mode 100644
index 0000000000..69b5e41434
--- /dev/null
+++ b/tests/milvus-java-test/src/main/java/com/TestDeleteVectors.java
@@ -0,0 +1,122 @@
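+// deleteByRange cases: valid ranges, missing table, disconnected client, empty table,
+// inverted date ranges, and ranges that match no rows.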
+package com;
+
+import io.milvus.client.*;
+import org.testng.Assert;
+import org.testng.annotations.Test;
+
+import java.util.*;
+
+public class TestDeleteVectors {
+ int index_file_size = 50;
+ int dimension = 128;
+
+ public List<List<Float>> gen_vectors(Integer nb) {
+ List<List<Float>> xb = new LinkedList<>();
+ Random random = new Random();
+ for (int i = 0; i < nb; ++i) {
+ LinkedList<Float> vector = new LinkedList<>();
+ for (int j = 0; j < dimension; j++) {
+ vector.add(random.nextFloat());
+ }
+ xb.add(vector);
+ }
+ return xb;
+ }
+
+ public static Date getDeltaDate(int delta) {
+ Date today = new Date();
+ Calendar c = Calendar.getInstance();
+ c.setTime(today);
+ c.add(Calendar.DAY_OF_MONTH, delta);
+ return c.getTime();
+ }
+
+ @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
+ public void test_delete_vectors(MilvusClient client, String tableName) throws InterruptedException {
+ int nb = 10000;
+ List<List<Float>> vectors = gen_vectors(nb);
+ // Add vectors
+ InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
+ InsertResponse res = client.insert(insertParam);
+ assert(res.getResponse().ok());
+ Thread.sleep(1000);
+ DateRange dateRange = new DateRange(getDeltaDate(-1), getDeltaDate(1));
+ DeleteByRangeParam param = new DeleteByRangeParam.Builder(dateRange, tableName).build();
+ Response res_delete = client.deleteByRange(param);
+ assert(res_delete.ok());
+ Thread.sleep(1000);
+ // Assert table row count
+ TableParam tableParam = new TableParam.Builder(tableName).build();
+ Assert.assertEquals(client.getTableRowCount(tableParam).getTableRowCount(), 0);
+ }
+
+ @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
+ public void test_delete_vectors_table_not_existed(MilvusClient client, String tableName) throws InterruptedException {
+ String tableNameNew = tableName + "_";
+ DateRange dateRange = new DateRange(getDeltaDate(-1), getDeltaDate(1));
+ DeleteByRangeParam param = new DeleteByRangeParam.Builder(dateRange, tableNameNew).build();
+ Response res_delete = client.deleteByRange(param);
+ assert(!res_delete.ok());
+ }
+
+ @Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class)
+ public void test_delete_vectors_without_connect(MilvusClient client, String tableName) throws InterruptedException {
+ DateRange dateRange = new DateRange(getDeltaDate(-1), getDeltaDate(1));
+ DeleteByRangeParam param = new DeleteByRangeParam.Builder(dateRange, tableName).build();
+ Response res_delete = client.deleteByRange(param);
+ assert(!res_delete.ok());
+ }
+
+ @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
+ public void test_delete_vectors_table_empty(MilvusClient client, String tableName) throws InterruptedException {
+ DateRange dateRange = new DateRange(getDeltaDate(-1), getDeltaDate(1));
+ DeleteByRangeParam param = new DeleteByRangeParam.Builder(dateRange, tableName).build();
+ Response res_delete = client.deleteByRange(param);
+ assert(res_delete.ok());
+ }
+
+ @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
+ public void test_delete_vectors_invalid_date_range(MilvusClient client, String tableName) throws InterruptedException {
+ int nb = 100;
+ List<List<Float>> vectors = gen_vectors(nb);
+ // Add vectors
+ InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
+ InsertResponse res = client.insert(insertParam);
+ assert(res.getResponse().ok());
+ Thread.sleep(1000);
+ DateRange dateRange = new DateRange(getDeltaDate(1), getDeltaDate(0));
+ DeleteByRangeParam param = new DeleteByRangeParam.Builder(dateRange, tableName).build();
+ Response res_delete = client.deleteByRange(param);
+ assert(!res_delete.ok());
+ }
+
+ @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
+ public void test_delete_vectors_invalid_date_range_1(MilvusClient client, String tableName) throws InterruptedException {
+ int nb = 100;
+ List<List<Float>> vectors = gen_vectors(nb);
+ InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
+ InsertResponse res = client.insert(insertParam);
+ assert(res.getResponse().ok());
+ DateRange dateRange = new DateRange(getDeltaDate(2), getDeltaDate(-1));
+ DeleteByRangeParam param = new DeleteByRangeParam.Builder(dateRange, tableName).build();
+ Response res_delete = client.deleteByRange(param);
+ assert(!res_delete.ok());
+ }
+
+ @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
+ public void test_delete_vectors_no_result(MilvusClient client, String tableName) throws InterruptedException {
+ int nb = 100;
+ List<List<Float>> vectors = gen_vectors(nb);
+ InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
+ InsertResponse res = client.insert(insertParam);
+ assert(res.getResponse().ok());
+ Thread.sleep(1000);
+ DateRange dateRange = new DateRange(getDeltaDate(-3), getDeltaDate(-2));
+ DeleteByRangeParam param = new DeleteByRangeParam.Builder(dateRange, tableName).build();
+ Response res_delete = client.deleteByRange(param);
+ assert(res_delete.ok());
+ TableParam tableParam = new TableParam.Builder(tableName).build();
+ Assert.assertEquals(client.getTableRowCount(tableParam).getTableRowCount(), nb);
+ }
+
+}
diff --git a/tests/milvus-java-test/src/main/java/com/TestIndex.java b/tests/milvus-java-test/src/main/java/com/TestIndex.java
new file mode 100644
index 0000000000..d003771b0b
--- /dev/null
+++ b/tests/milvus-java-test/src/main/java/com/TestIndex.java
@@ -0,0 +1,340 @@
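+// Index cases: create/describe/drop for FLAT, IVFLAT, IVF_SQ8 and IVF_SQ8_H, plus timeout,
+// missing-table, disconnected-client and invalid nlist paths.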
+package com;
+
+import io.milvus.client.*;
+import org.testng.Assert;
+import org.testng.annotations.Test;
+
+import java.util.Date;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Random;
+
+public class TestIndex {
+ int index_file_size = 10;
+ int dimension = 128;
+ int n_list = 1024;
+ int default_n_list = 16384;
+ int nb = 100000;
+ IndexType indexType = IndexType.IVF_SQ8;
+ IndexType defaultIndexType = IndexType.FLAT;
+
+ public List<List<Float>> gen_vectors(Integer nb) {
+ List<List<Float>> xb = new LinkedList<>();
+ Random random = new Random();
+ for (int i = 0; i < nb; ++i) {
+ LinkedList<Float> vector = new LinkedList<>();
+ for (int j = 0; j < dimension; j++) {
+ vector.add(random.nextFloat());
+ }
+ xb.add(vector);
+ }
+ return xb;
+ }
+
+ @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
+ public void test_create_index(MilvusClient client, String tableName) throws InterruptedException {
+ List<List<Float>> vectors = gen_vectors(nb);
+ InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
+ client.insert(insertParam);
+ Index index = new Index.Builder().withIndexType(indexType)
+ .withNList(n_list)
+ .build();
+ CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build();
+ Response res_create = client.createIndex(createIndexParam);
+ assert(res_create.ok());
+ }
+
+ @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
+ public void test_create_index_repeatably(MilvusClient client, String tableName) throws InterruptedException {
+ List<List<Float>> vectors = gen_vectors(nb);
+ InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
+ client.insert(insertParam);
+ Index index = new Index.Builder().withIndexType(indexType)
+ .withNList(n_list)
+ .build();
+ CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build();
+ Response res_create = client.createIndex(createIndexParam);
+ res_create = client.createIndex(createIndexParam);
+ assert(res_create.ok());
+ TableParam tableParam = new TableParam.Builder(tableName).build();
+ DescribeIndexResponse res = client.describeIndex(tableParam);
+ assert(res.getResponse().ok());
+ Index index1 = res.getIndex().get();
+ Assert.assertEquals(index1.getNList(), n_list);
+ Assert.assertEquals(index1.getIndexType(), indexType);
+ }
+
+ @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
+ public void test_create_index_FLAT(MilvusClient client, String tableName) throws InterruptedException {
+ IndexType indexType = IndexType.FLAT;
+ List<List<Float>> vectors = gen_vectors(nb);
+ InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
+ client.insert(insertParam);
+ Index index = new Index.Builder().withIndexType(indexType)
+ .withNList(n_list)
+ .build();
+ CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build();
+ Response res_create = client.createIndex(createIndexParam);
+ assert(res_create.ok());
+ TableParam tableParam = new TableParam.Builder(tableName).build();
+ DescribeIndexResponse res = client.describeIndex(tableParam);
+ assert(res.getResponse().ok());
+ Index index1 = res.getIndex().get();
+ Assert.assertEquals(index1.getIndexType(), indexType);
+ }
+
+
+ @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
+ public void test_create_index_FLAT_timeout(MilvusClient client, String tableName) throws InterruptedException {
+ int nb = 500000;
+ IndexType indexType = IndexType.IVF_SQ8;
+ List<List<Float>> vectors = gen_vectors(nb);
+ InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
+ client.insert(insertParam);
+ Index index = new Index.Builder().withIndexType(indexType)
+ .withNList(n_list)
+ .build();
+ System.out.println(new Date());
+ CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).withTimeout(1).build();
+ Response res_create = client.createIndex(createIndexParam);
+ assert(!res_create.ok());
+ }
+
+ @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
+ public void test_create_index_IVFLAT(MilvusClient client, String tableName) throws InterruptedException {
+ IndexType indexType = IndexType.IVFLAT;
+ List<List<Float>> vectors = gen_vectors(nb);
+ InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
+ client.insert(insertParam);
+ Index index = new Index.Builder().withIndexType(indexType)
+ .withNList(n_list)
+ .build();
+ CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build();
+ Response res_create = client.createIndex(createIndexParam);
+ assert(res_create.ok());
+ TableParam tableParam = new TableParam.Builder(tableName).build();
+ DescribeIndexResponse res = client.describeIndex(tableParam);
+ assert(res.getResponse().ok());
+ Index index1 = res.getIndex().get();
+ Assert.assertEquals(index1.getIndexType(), indexType);
+ }
+
+ @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
+ public void test_create_index_IVFSQ8(MilvusClient client, String tableName) throws InterruptedException {
+ IndexType indexType = IndexType.IVF_SQ8;
+ List<List<Float>> vectors = gen_vectors(nb);
+ InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
+ client.insert(insertParam);
+ Index index = new Index.Builder().withIndexType(indexType)
+ .withNList(n_list)
+ .build();
+ CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build();
+ Response res_create = client.createIndex(createIndexParam);
+ assert(res_create.ok());
+ TableParam tableParam = new TableParam.Builder(tableName).build();
+ DescribeIndexResponse res = client.describeIndex(tableParam);
+ assert(res.getResponse().ok());
+ Index index1 = res.getIndex().get();
+ Assert.assertEquals(index1.getIndexType(), indexType);
+ }
+
+ @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
+ public void test_create_index_IVFSQ8H(MilvusClient client, String tableName) throws InterruptedException {
+ IndexType indexType = IndexType.IVF_SQ8_H;
+ List<List<Float>> vectors = gen_vectors(nb);
+ InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
+ client.insert(insertParam);
+ Index index = new Index.Builder().withIndexType(indexType)
+ .withNList(n_list)
+ .build();
+ CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build();
+ Response res_create = client.createIndex(createIndexParam);
+ assert(res_create.ok());
+ TableParam tableParam = new TableParam.Builder(tableName).build();
+ DescribeIndexResponse res = client.describeIndex(tableParam);
+ assert(res.getResponse().ok());
+ Index index1 = res.getIndex().get();
+ Assert.assertEquals(index1.getIndexType(), indexType);
+ }
+
+ @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
+ public void test_create_index_with_no_vector(MilvusClient client, String tableName) {
+ Index index = new Index.Builder().withIndexType(indexType)
+ .withNList(n_list)
+ .build();
+ CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build();
+ Response res_create = client.createIndex(createIndexParam);
+ assert(res_create.ok());
+ }
+
+ @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
+ public void test_create_index_table_not_existed(MilvusClient client, String tableName) throws InterruptedException {
+ String tableNameNew = tableName + "_";
+ Index index = new Index.Builder().withIndexType(indexType)
+ .withNList(n_list)
+ .build();
+ CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableNameNew).withIndex(index).build();
+ Response res_create = client.createIndex(createIndexParam);
+ assert(!res_create.ok());
+ }
+
+ @Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class)
+ public void test_create_index_without_connect(MilvusClient client, String tableName) throws InterruptedException {
+ Index index = new Index.Builder().withIndexType(indexType)
+ .withNList(n_list)
+ .build();
+ CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build();
+ Response res_create = client.createIndex(createIndexParam);
+ assert(!res_create.ok());
+ }
+
+ @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
+ public void test_create_index_invalid_n_list(MilvusClient client, String tableName) throws InterruptedException {
+ int n_list = 0;
+ List<List<Float>> vectors = gen_vectors(nb);
+ InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
+ client.insert(insertParam);
+ Index index = new Index.Builder().withIndexType(indexType)
+ .withNList(n_list)
+ .build();
+ CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build();
+ Response res_create = client.createIndex(createIndexParam);
+ assert(!res_create.ok());
+ }
+
+ @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
+ public void test_describe_index(MilvusClient client, String tableName) throws InterruptedException {
+ List<List<Float>> vectors = gen_vectors(nb);
+ InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
+ client.insert(insertParam);
+ Index index = new Index.Builder().withIndexType(indexType)
+ .withNList(n_list)
+ .build();
+ CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build();
+ Response res_create = client.createIndex(createIndexParam);
+ assert(res_create.ok());
+ TableParam tableParam = new TableParam.Builder(tableName).build();
+ DescribeIndexResponse res = client.describeIndex(tableParam);
+ assert(res.getResponse().ok());
+ Index index1 = res.getIndex().get();
+ Assert.assertEquals(index1.getNList(), n_list);
+ Assert.assertEquals(index1.getIndexType(), indexType);
+ }
+
+ @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
+ public void test_alter_index(MilvusClient client, String tableName) throws InterruptedException {
+ List<List<Float>> vectors = gen_vectors(nb);
+ InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
+ client.insert(insertParam);
+ Index index = new Index.Builder().withIndexType(indexType)
+ .withNList(n_list)
+ .build();
+ CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build();
+ Response res_create = client.createIndex(createIndexParam);
+ assert(res_create.ok());
+ TableParam tableParam = new TableParam.Builder(tableName).build();
+ // Create another index
+ IndexType indexTypeNew = IndexType.IVFLAT;
+ int n_list_new = n_list + 1;
+ Index index_new = new Index.Builder().withIndexType(indexTypeNew)
+ .withNList(n_list_new)
+ .build();
+ CreateIndexParam createIndexParamNew = new CreateIndexParam.Builder(tableName).withIndex(index_new).build();
+ Response res_create_new = client.createIndex(createIndexParamNew);
+ assert(res_create_new.ok());
+ DescribeIndexResponse res = client.describeIndex(tableParam);
+ assert(res.getResponse().ok());
+ Index index1 = res.getIndex().get();
+ Assert.assertEquals(index1.getNList(), n_list_new);
+ Assert.assertEquals(index1.getIndexType(), indexTypeNew);
+ }
+
+ @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
+ public void test_describe_index_table_not_existed(MilvusClient client, String tableName) throws InterruptedException {
+ String tableNameNew = tableName + "_";
+ TableParam tableParam = new TableParam.Builder(tableNameNew).build();
+ DescribeIndexResponse res = client.describeIndex(tableParam);
+ assert(!res.getResponse().ok());
+ }
+
+ @Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class)
+ public void test_describe_index_without_connect(MilvusClient client, String tableName) throws InterruptedException {
+ TableParam tableParam = new TableParam.Builder(tableName).build();
+ DescribeIndexResponse res = client.describeIndex(tableParam);
+ assert(!res.getResponse().ok());
+ }
+
+ @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
+ public void test_drop_index(MilvusClient client, String tableName) throws InterruptedException {
+ List<List<Float>> vectors = gen_vectors(nb);
+ InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
+ client.insert(insertParam);
+ Index index = new Index.Builder().withIndexType(defaultIndexType)
+ .withNList(n_list)
+ .build();
+ CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build();
+ Response res_create = client.createIndex(createIndexParam);
+ assert(res_create.ok());
+ TableParam tableParam = new TableParam.Builder(tableName).build();
+ Response res_drop = client.dropIndex(tableParam);
+ assert(res_drop.ok());
+ DescribeIndexResponse res = client.describeIndex(tableParam);
+ assert(res.getResponse().ok());
+ Index index1 = res.getIndex().get();
+ Assert.assertEquals(index1.getNList(), default_n_list);
+ Assert.assertEquals(index1.getIndexType(), defaultIndexType);
+ }
+
+ @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
+ public void test_drop_index_repeatably(MilvusClient client, String tableName) throws InterruptedException {
+ List<List<Float>> vectors = gen_vectors(nb);
+ InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
+ client.insert(insertParam);
+ Index index = new Index.Builder().withIndexType(defaultIndexType)
+ .withNList(n_list)
+ .build();
+ CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build();
+ Response res_create = client.createIndex(createIndexParam);
+ assert(res_create.ok());
+ TableParam tableParam = new TableParam.Builder(tableName).build();
+ Response res_drop = client.dropIndex(tableParam);
+ res_drop = client.dropIndex(tableParam);
+ assert(res_drop.ok());
+ DescribeIndexResponse res = client.describeIndex(tableParam);
+ assert(res.getResponse().ok());
+ Index index1 = res.getIndex().get();
+ Assert.assertEquals(index1.getNList(), default_n_list);
+ Assert.assertEquals(index1.getIndexType(), defaultIndexType);
+ }
+
+ @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
+ public void test_drop_index_table_not_existed(MilvusClient client, String tableName) throws InterruptedException {
+ String tableNameNew = tableName + "_";
+ TableParam tableParam = new TableParam.Builder(tableNameNew).build();
+ Response res_drop = client.dropIndex(tableParam);
+ assert(!res_drop.ok());
+ }
+
+ @Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class)
+ public void test_drop_index_without_connect(MilvusClient client, String tableName) throws InterruptedException {
+ TableParam tableParam = new TableParam.Builder(tableName).build();
+ Response res_drop = client.dropIndex(tableParam);
+ assert(!res_drop.ok());
+ }
+
+ @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
+ public void test_drop_index_no_index_created(MilvusClient client, String tableName) throws InterruptedException {
+ List<List<Float>> vectors = gen_vectors(nb);
+ InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
+ client.insert(insertParam);
+ TableParam tableParam = new TableParam.Builder(tableName).build();
+ Response res_drop = client.dropIndex(tableParam);
+ assert(res_drop.ok());
+ DescribeIndexResponse res = client.describeIndex(tableParam);
+ assert(res.getResponse().ok());
+ Index index1 = res.getIndex().get();
+ Assert.assertEquals(index1.getNList(), default_n_list);
+ Assert.assertEquals(index1.getIndexType(), defaultIndexType);
+ }
+
+}
diff --git a/tests/milvus-java-test/src/main/java/com/TestMix.java b/tests/milvus-java-test/src/main/java/com/TestMix.java
new file mode 100644
index 0000000000..795ab8630e
--- /dev/null
+++ b/tests/milvus-java-test/src/main/java/com/TestMix.java
@@ -0,0 +1,221 @@
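+// Concurrency cases: parallel search, connect, insert, index-build and create/insert/drop
+// flows driven through a ForkJoinPool.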
+package com;
+
+import io.milvus.client.*;
+import org.apache.commons.lang3.RandomStringUtils;
+import org.testng.Assert;
+import org.testng.annotations.Test;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Random;
+import java.util.concurrent.ForkJoinPool;
+import java.util.concurrent.TimeUnit;
+import java.util.stream.Collectors;
+
+public class TestMix {
+ private int dimension = 128;
+ private int nb = 100000;
+ int n_list = 8192;
+ int n_probe = 20;
+ int top_k = 10;
+ double epsilon = 0.001;
+ int index_file_size = 20;
+
+ public List<Float> normalize(List<Float> w2v){
+ float squareSum = w2v.stream().map(x -> x * x).reduce((float) 0, Float::sum);
+ final float norm = (float) Math.sqrt(squareSum);
+ w2v = w2v.stream().map(x -> x / norm).collect(Collectors.toList());
+ return w2v;
+ }
+
+ public List<List<Float>> gen_vectors(int nb, boolean norm) {
+ List<List<Float>> xb = new ArrayList<>();
+ Random random = new Random();
+ for (int i = 0; i < nb; ++i) {
+ List<Float> vector = new ArrayList<>();
+ for (int j = 0; j < dimension; j++) {
+ vector.add(random.nextFloat());
+ }
+ if (norm == true) {
+ vector = normalize(vector);
+ }
+ xb.add(vector);
+ }
+ return xb;
+ }
+
+ @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
+ public void test_search_vectors_threads(MilvusClient client, String tableName) throws InterruptedException {
+ int thread_num = 10;
+ int nq = 5;
+ List<List<Float>> vectors = gen_vectors(nb, false);
+ List<List<Float>> queryVectors = vectors.subList(0,nq);
+ InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
+ client.insert(insertParam);
+ Index index = new Index.Builder().withIndexType(IndexType.IVF_SQ8)
+ .withNList(n_list)
+ .build();
+ CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build();
+ client.createIndex(createIndexParam);
+ ForkJoinPool executor = new ForkJoinPool();
+ for (int i = 0; i < thread_num; i++) {
+ executor.execute(
+ () -> {
+ SearchParam searchParam = new SearchParam.Builder(tableName, queryVectors).withNProbe(n_probe).withTopK(top_k).build();
+ SearchResponse res_search = client.search(searchParam);
+ assert (res_search.getResponse().ok());
+ });
+ }
+ executor.awaitQuiescence(100, TimeUnit.SECONDS);
+ executor.shutdown();
+ }
+
+ @Test(dataProvider = "DefaultConnectArgs", dataProviderClass = MainClass.class)
+ public void test_connect_threads(String host, String port) throws InterruptedException {
+ int thread_num = 100;
+ ForkJoinPool executor = new ForkJoinPool();
+ for (int i = 0; i < thread_num; i++) {
+ executor.execute(
+ () -> {
+ MilvusClient client = new MilvusGrpcClient();
+ ConnectParam connectParam = new ConnectParam.Builder()
+ .withHost(host)
+ .withPort(port)
+ .build();
+ client.connect(connectParam);
+ assert(client.connected());
+ try {
+ client.disconnect();
+ } catch (InterruptedException e) {
+ e.printStackTrace();
+ }
+ });
+ }
+ executor.awaitQuiescence(100, TimeUnit.SECONDS);
+ executor.shutdown();
+ }
+
+ @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
+ public void test_add_vectors_threads(MilvusClient client, String tableName) throws InterruptedException {
+ int thread_num = 10;
+ List<List<Float>> vectors = gen_vectors(nb,false);
+ InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
+ ForkJoinPool executor = new ForkJoinPool();
+ for (int i = 0; i < thread_num; i++) {
+ executor.execute(
+ () -> {
+ InsertResponse res_insert = client.insert(insertParam);
+ assert (res_insert.getResponse().ok());
+ });
+ }
+ executor.awaitQuiescence(100, TimeUnit.SECONDS);
+ executor.shutdown();
+
+ Thread.sleep(2000);
+ TableParam tableParam = new TableParam.Builder(tableName).build();
+ GetTableRowCountResponse getTableRowCountResponse = client.getTableRowCount(tableParam);
+ Assert.assertEquals(getTableRowCountResponse.getTableRowCount(), thread_num * nb);
+ }
+
+ @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
+ public void test_add_index_vectors_threads(MilvusClient client, String tableName) throws InterruptedException {
+ int thread_num = 50;
+ List<List<Float>> vectors = gen_vectors(nb,false);
+ InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
+ ForkJoinPool executor = new ForkJoinPool();
+ for (int i = 0; i < thread_num; i++) {
+ executor.execute(
+ () -> {
+ InsertResponse res_insert = client.insert(insertParam);
+ Index index = new Index.Builder().withIndexType(IndexType.IVF_SQ8)
+ .withNList(n_list)
+ .build();
+ CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build();
+ client.createIndex(createIndexParam);
+ assert (res_insert.getResponse().ok());
+ });
+ }
+ executor.awaitQuiescence(300, TimeUnit.SECONDS);
+ executor.shutdown();
+ Thread.sleep(2000);
+ TableParam tableParam = new TableParam.Builder(tableName).build();
+ GetTableRowCountResponse getTableRowCountResponse = client.getTableRowCount(tableParam);
+ Assert.assertEquals(getTableRowCountResponse.getTableRowCount(), thread_num * nb);
+ }
+
+ @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
+ public void test_add_search_vectors_threads(MilvusClient client, String tableName) throws InterruptedException {
+ int thread_num = 50;
+ int nq = 5;
+ List<List<Float>> vectors = gen_vectors(nb, true);
+ List<List<Float>> queryVectors = vectors.subList(0,nq);
+ InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
+ ForkJoinPool executor = new ForkJoinPool();
+ for (int i = 0; i < thread_num; i++) {
+ executor.execute(
+ () -> {
+ InsertResponse res_insert = client.insert(insertParam);
+ assert (res_insert.getResponse().ok());
+ try {
+ TimeUnit.SECONDS.sleep(1);
+ } catch (InterruptedException e) {
+ e.printStackTrace();
+ }
+ SearchParam searchParam = new SearchParam.Builder(tableName, queryVectors).withNProbe(n_probe).withTopK(top_k).build();
+ SearchResponse res_search = client.search(searchParam);
+ assert (res_search.getResponse().ok());
+ List<List<SearchResponse.QueryResult>> res = client.search(searchParam).getQueryResultsList();
+ double distance = res.get(0).get(0).getDistance();
+ if (tableName.startsWith("L2")) {
+ Assert.assertEquals(distance, 0.0, epsilon);
+ }else if (tableName.startsWith("IP")) {
+ Assert.assertEquals(distance, 1.0, epsilon);
+ }
+ });
+ }
+ executor.awaitQuiescence(300, TimeUnit.SECONDS);
+ executor.shutdown();
+ Thread.sleep(2000);
+ TableParam tableParam = new TableParam.Builder(tableName).build();
+ GetTableRowCountResponse getTableRowCountResponse = client.getTableRowCount(tableParam);
+ Assert.assertEquals(getTableRowCountResponse.getTableRowCount(), thread_num * nb);
+ }
+
+ @Test(dataProvider = "DefaultConnectArgs", dataProviderClass = MainClass.class)
+ public void test_create_insert_delete_threads(String host, String port) throws InterruptedException {
+ int thread_num = 100;
+ List<List<Float>> vectors = gen_vectors(nb,false);
+ ForkJoinPool executor = new ForkJoinPool();
+ for (int i = 0; i < thread_num; i++) {
+ executor.execute(
+ () -> {
+ MilvusClient client = new MilvusGrpcClient();
+ ConnectParam connectParam = new ConnectParam.Builder()
+ .withHost(host)
+ .withPort(port)
+ .build();
+ client.connect(connectParam);
+ String tableName = RandomStringUtils.randomAlphabetic(10);
+ TableSchema tableSchema = new TableSchema.Builder(tableName, dimension)
+ .withIndexFileSize(index_file_size)
+ .withMetricType(MetricType.IP)
+ .build();
+ TableSchemaParam tableSchemaParam = new TableSchemaParam.Builder(tableSchema).build();
+ client.createTable(tableSchemaParam);
+ InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
+ client.insert(insertParam);
+ TableParam tableParam = new TableParam.Builder(tableName).build();
+ Response response = client.dropTable(tableParam);
+ Assert.assertTrue(response.ok());
+ try {
+ client.disconnect();
+ } catch (InterruptedException e) {
+ e.printStackTrace();
+ }
+ });
+ }
+ executor.awaitQuiescence(100, TimeUnit.SECONDS);
+ executor.shutdown();
+ }
+
+}
diff --git a/tests/milvus-java-test/src/main/java/com/TestPing.java b/tests/milvus-java-test/src/main/java/com/TestPing.java
new file mode 100644
index 0000000000..46850f4a17
--- /dev/null
+++ b/tests/milvus-java-test/src/main/java/com/TestPing.java
@@ -0,0 +1,28 @@
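+// serverStatus cases: with an active connection and after disconnecting.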
+package com;
+
+import io.milvus.client.ConnectParam;
+import io.milvus.client.MilvusClient;
+import io.milvus.client.MilvusGrpcClient;
+import io.milvus.client.Response;
+import org.testng.annotations.Test;
+
+public class TestPing {
+ @Test(dataProvider = "DefaultConnectArgs", dataProviderClass = MainClass.class)
+ public void test_server_status(String host, String port){
+ System.out.println("Host: "+host+", Port: "+port);
+ MilvusClient client = new MilvusGrpcClient();
+ ConnectParam connectParam = new ConnectParam.Builder()
+ .withHost(host)
+ .withPort(port)
+ .build();
+ client.connect(connectParam);
+ Response res = client.serverStatus();
+ assert (res.ok());
+ }
+
+ @Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class)
+ public void test_server_status_without_connected(MilvusGrpcClient client, String tableName){
+ Response res = client.serverStatus();
+ assert (!res.ok());
+ }
+}
\ No newline at end of file
diff --git a/tests/milvus-java-test/src/main/java/com/TestSearchVectors.java b/tests/milvus-java-test/src/main/java/com/TestSearchVectors.java
new file mode 100644
index 0000000000..c574298652
--- /dev/null
+++ b/tests/milvus-java-test/src/main/java/com/TestSearchVectors.java
@@ -0,0 +1,480 @@
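+// Search cases: per-index-type searches, result sizes, and returned IDs and distances for L2
+// and IP tables.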
+package com;
+
+import io.milvus.client.*;
+import org.testng.Assert;
+import org.testng.annotations.Test;
+
+import java.util.*;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+
+public class TestSearchVectors {
+ int index_file_size = 10;
+ int dimension = 128;
+ int n_list = 1024;
+ int default_n_list = 16384;
+ int nb = 100000;
+ int n_probe = 20;
+ int top_k = 10;
+ double epsilon = 0.001;
+ IndexType indexType = IndexType.IVF_SQ8;
+ IndexType defaultIndexType = IndexType.FLAT;
+
+
+ public List<Float> normalize(List<Float> w2v){
+ float squareSum = w2v.stream().map(x -> x * x).reduce((float) 0, Float::sum);
+ final float norm = (float) Math.sqrt(squareSum);
+ w2v = w2v.stream().map(x -> x / norm).collect(Collectors.toList());
+ return w2v;
+ }
+
+ public List<List<Float>> gen_vectors(int nb, boolean norm) {
+ List<List<Float>> xb = new ArrayList<>();
+ Random random = new Random();
+ for (int i = 0; i < nb; ++i) {
+ List<Float> vector = new ArrayList<>();
+ for (int j = 0; j < dimension; j++) {
+ vector.add(random.nextFloat());
+ }
+ if (norm == true) {
+ vector = normalize(vector);
+ }
+ xb.add(vector);
+ }
+ return xb;
+ }
+
+ public static Date getDeltaDate(int delta) {
+ Date today = new Date();
+ Calendar c = Calendar.getInstance();
+ c.setTime(today);
+ c.add(Calendar.DAY_OF_MONTH, delta);
+ return c.getTime();
+ }
+
+ @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
+ public void test_search_table_not_existed(MilvusClient client, String tableName) throws InterruptedException {
+ String tableNameNew = tableName + "_";
+ int nq = 5;
+ int nb = 100;
+ List<List<Float>> vectors = gen_vectors(nb, false);
+ List<List<Float>> queryVectors = vectors.subList(0,nq);
+ SearchParam searchParam = new SearchParam.Builder(tableNameNew, queryVectors).withNProbe(n_probe).withTopK(top_k).build();
+ SearchResponse res_search = client.search(searchParam);
+ assert (!res_search.getResponse().ok());
+ }
+
+ @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
+ public void test_search_index_IVFLAT(MilvusClient client, String tableName) throws InterruptedException {
+ IndexType indexType = IndexType.IVFLAT;
+ int nq = 5;
+ List<List<Float>> vectors = gen_vectors(nb, false);
+ List<List<Float>> queryVectors = vectors.subList(0,nq);
+ InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
+ client.insert(insertParam);
+ Index index = new Index.Builder().withIndexType(indexType)
+ .withNList(n_list)
+ .build();
+ CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build();
+ client.createIndex(createIndexParam);
+ TableParam tableParam = new TableParam.Builder(tableName).build();
+ SearchParam searchParam = new SearchParam.Builder(tableName, queryVectors).withNProbe(n_probe).withTopK(top_k).build();
+ List<List<SearchResponse.QueryResult>> res_search = client.search(searchParam).getQueryResultsList();
+ Assert.assertEquals(res_search.size(), nq);
+ Assert.assertEquals(res_search.get(0).size(), top_k);
+ }
+
+ @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
+ public void test_search_ids_IVFLAT(MilvusClient client, String tableName) throws InterruptedException {
+ IndexType indexType = IndexType.IVFLAT;
+ int nq = 5;
+ List<List<Float>> vectors = gen_vectors(nb, true);
+ List<List<Float>> queryVectors = vectors.subList(0, nq);
+ List<Long> vectorIds;
+ vectorIds = Stream.iterate(0L, n -> n)
+ .limit(nb)
+ .collect(Collectors.toList());
+ InsertParam insertParam = new InsertParam.Builder(tableName, vectors).withVectorIds(vectorIds).build();
+ client.insert(insertParam);
+ Index index = new Index.Builder().withIndexType(indexType)
+ .withNList(n_list)
+ .build();
+ CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build();
+ client.createIndex(createIndexParam);
+ TableParam tableParam = new TableParam.Builder(tableName).build();
+ SearchParam searchParam = new SearchParam.Builder(tableName, queryVectors).withNProbe(n_probe).withTopK(top_k).build();
+ List<List<SearchResponse.QueryResult>> res_search = client.search(searchParam).getQueryResultsList();
+ Assert.assertEquals(res_search.get(0).get(0).getVectorId(), 0L);
+ }
+
+ @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
+ public void test_search_IVFLAT(MilvusClient client, String tableName) throws InterruptedException {
+ IndexType indexType = IndexType.IVFLAT;
+ int nq = 5;
+ List<List<Float>> vectors = gen_vectors(nb, false);
+ List<List<Float>> queryVectors = vectors.subList(0, nq);
+ InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
+ client.insert(insertParam);
+ Thread.sleep(1000);
+ SearchParam searchParam = new SearchParam.Builder(tableName, queryVectors).withNProbe(n_probe).withTopK(top_k).build();
+ List<List<SearchResponse.QueryResult>> res_search = client.search(searchParam).getQueryResultsList();
+ Assert.assertEquals(res_search.size(), nq);
+ Assert.assertEquals(res_search.get(0).size(), top_k);
+ }
+
+ @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
+ public void test_search_distance_IVFLAT(MilvusClient client, String tableName) throws InterruptedException {
+ IndexType indexType = IndexType.IVFLAT;
+ int nq = 5;
+ List<List<Float>> vectors = gen_vectors(nb, true);
+ List<List<Float>> queryVectors = vectors.subList(0, nq);
+ InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
+ client.insert(insertParam);
+ Index index = new Index.Builder().withIndexType(indexType)
+ .withNList(n_list)
+ .build();
+ CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build();
+ client.createIndex(createIndexParam);
+ TableParam tableParam = new TableParam.Builder(tableName).build();
+ SearchParam searchParam = new SearchParam.Builder(tableName, queryVectors).withNProbe(n_probe).withTopK(top_k).build();
+ List<List<SearchResponse.QueryResult>> res_search = client.search(searchParam).getQueryResultsList();
+ double distance = res_search.get(0).get(0).getDistance();
+ if (tableName.startsWith("L2")) {
+ Assert.assertEquals(distance, 0.0, epsilon);
+ }else if (tableName.startsWith("IP")) {
+ Assert.assertEquals(distance, 1.0, epsilon);
+ }
+ }
+
+ @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
+ public void test_search_index_IVFSQ8(MilvusClient client, String tableName) throws InterruptedException {
+ IndexType indexType = IndexType.IVF_SQ8;
+ int nq = 5;
+ List<List<Float>> vectors = gen_vectors(nb, false);
+ List<List<Float>> queryVectors = vectors.subList(0, nq);
+ InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
+ client.insert(insertParam);
+ Index index = new Index.Builder().withIndexType(indexType)
+ .withNList(n_list)
+ .build();
+ CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build();
+ client.createIndex(createIndexParam);
+ TableParam tableParam = new TableParam.Builder(tableName).build();
+ SearchParam searchParam = new SearchParam.Builder(tableName, queryVectors).withNProbe(n_probe).withTopK(top_k).build();
+ List<List<SearchResponse.QueryResult>> res_search = client.search(searchParam).getQueryResultsList();
+ Assert.assertEquals(res_search.size(), nq);
+ Assert.assertEquals(res_search.get(0).size(), top_k);
+ }
+
+ @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
+ public void test_search_IVFSQ8(MilvusClient client, String tableName) throws InterruptedException {
+ IndexType indexType = IndexType.IVF_SQ8;
+ int nq = 5;
+ List<List<Float>> vectors = gen_vectors(nb, false);
+ List<List<Float>> queryVectors = vectors.subList(0, nq);
+ InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
+ client.insert(insertParam);
+ Thread.sleep(1000);
+ SearchParam searchParam = new SearchParam.Builder(tableName, queryVectors).withNProbe(n_probe).withTopK(top_k).build();
+ List<List<SearchResponse.QueryResult>> res_search = client.search(searchParam).getQueryResultsList();
+ Assert.assertEquals(res_search.size(), nq);
+ Assert.assertEquals(res_search.get(0).size(), top_k);
+ }
+
+ @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
+ public void test_search_distance_IVFSQ8(MilvusClient client, String tableName) throws InterruptedException {
+ IndexType indexType = IndexType.IVF_SQ8;
+ int nq = 5;
+ int nb = 1000;
+ List<List<Float>> vectors = gen_vectors(nb, true);
+ List<List<Float>> queryVectors = vectors.subList(0, nq);
+ InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
+ client.insert(insertParam);
+ Index index = new Index.Builder().withIndexType(indexType)
+ .withNList(default_n_list)
+ .build();
+ CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build();
+ client.createIndex(createIndexParam);
+ TableParam tableParam = new TableParam.Builder(tableName).build();
+ SearchParam searchParam = new SearchParam.Builder(tableName, queryVectors).withNProbe(n_probe).withTopK(top_k).build();
+ List<List<Float>> res_search = client.search(searchParam).getResultDistancesList();
+ for (int i = 0; i < nq; i++) {
+ double distance = res_search.get(i).get(0);
+ System.out.println(distance);
+ if (tableName.startsWith("L2")) {
+ Assert.assertEquals(distance, 0.0, epsilon);
+ }else if (tableName.startsWith("IP")) {
+ Assert.assertEquals(distance, 1.0, epsilon);
+ }
+ }
+ }
+
+ @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
+ public void test_search_index_FLAT(MilvusClient client, String tableName) throws InterruptedException {
+ IndexType indexType = IndexType.FLAT;
+ int nq = 5;
+ List<List<Float>> vectors = gen_vectors(nb, false);
+ List<List<Float>> queryVectors = vectors.subList(0, nq);
+ InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
+ client.insert(insertParam);
+ Index index = new Index.Builder().withIndexType(indexType)
+ .withNList(n_list)
+ .build();
+ CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build();
+ client.createIndex(createIndexParam);
+ TableParam tableParam = new TableParam.Builder(tableName).build();
+ SearchParam searchParam = new SearchParam.Builder(tableName, queryVectors).withNProbe(n_probe).withTopK(top_k).build();
+ List<List<SearchResponse.QueryResult>> res_search = client.search(searchParam).getQueryResultsList();
+ Assert.assertEquals(res_search.size(), nq);
+ Assert.assertEquals(res_search.get(0).size(), top_k);
+ }
+
+ @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
+ public void test_search_FLAT(MilvusClient client, String tableName) throws InterruptedException {
+ IndexType indexType = IndexType.FLAT;
+ int nq = 5;
+ List<List<Float>> vectors = gen_vectors(nb, false);
+ List<List<Float>> queryVectors = vectors.subList(0, nq);
+ InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
+ client.insert(insertParam);
+ Thread.sleep(1000);
+ SearchParam searchParam = new SearchParam.Builder(tableName, queryVectors).withNProbe(n_probe).withTopK(top_k).build();
+ List<List<SearchResponse.QueryResult>> res_search = client.search(searchParam).getQueryResultsList();
+ Assert.assertEquals(res_search.size(), nq);
+ Assert.assertEquals(res_search.get(0).size(), top_k);
+ }
+
+ @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
+ public void test_search_FLAT_timeout(MilvusClient client, String tableName) throws InterruptedException {
+ IndexType indexType = IndexType.FLAT;
+ int nb = 100000;
+ int nq = 1000;
+ int top_k = 2048;
+ List<List<Float>> vectors = gen_vectors(nb, false);
+ List<List<Float>> queryVectors = vectors.subList(0, nq);
+ InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
+ client.insert(insertParam);
+ Thread.sleep(1000);
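+ // withTimeout(1) below sets a very short client-side timeout (assumed to be one second),
+ // which an nq=1000, top_k=2048 search over 100k vectors is expected to exceed, hence the
+ // not-ok assertion at the end of this test.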
+ SearchParam searchParam = new SearchParam.Builder(tableName, queryVectors).withNProbe(n_probe).withTopK(top_k).withTimeout(1).build();
+ System.out.println(new Date());
+ SearchResponse res_search = client.search(searchParam);
+ assert (!res_search.getResponse().ok());
+ }
+
+ @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
+ public void test_search_FLAT_big_data_size(MilvusClient client, String tableName) throws InterruptedException {
+ IndexType indexType = IndexType.FLAT;
+ int nb = 100000;
+ int nq = 2000;
+ int top_k = 2048;
+ List<List<Float>> vectors = gen_vectors(nb, false);
+ List<List<Float>> queryVectors = vectors.subList(0, nq);
+ InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
+ client.insert(insertParam);
+ Thread.sleep(1000);
+ SearchParam searchParam = new SearchParam.Builder(tableName, queryVectors).withNProbe(n_probe).withTopK(top_k).build();
+ System.out.println(new Date());
+ SearchResponse res_search = client.search(searchParam);
+ assert (res_search.getResponse().ok());
+ }
+
+ @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
+ public void test_search_distance_FLAT(MilvusClient client, String tableName) throws InterruptedException {
+ IndexType indexType = IndexType.FLAT;
+ int nq = 5;
+ List<List<Float>> vectors = gen_vectors(nb, true);
+ List<List<Float>> queryVectors = vectors.subList(0, nq);
+ InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
+ client.insert(insertParam);
+ Index index = new Index.Builder().withIndexType(indexType)
+ .withNList(n_list)
+ .build();
+ CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build();
+ client.createIndex(createIndexParam);
+ SearchParam searchParam = new SearchParam.Builder(tableName, queryVectors).withNProbe(n_probe).withTopK(top_k).build();
+ List<List<SearchResponse.QueryResult>> res_search = client.search(searchParam).getQueryResultsList();
+ double distance = res_search.get(0).get(0).getDistance();
+ if (tableName.startsWith("L2")) {
+ Assert.assertEquals(distance, 0.0, epsilon);
+ }else if (tableName.startsWith("IP")) {
+ Assert.assertEquals(distance, 1.0, epsilon);
+ }
+ }
+
+ @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
+ public void test_search_invalid_n_probe(MilvusClient client, String tableName) throws InterruptedException {
+ IndexType indexType = IndexType.IVF_SQ8;
+ int nq = 5;
+ int n_probe_new = 0;
+ List<List<Float>> vectors = gen_vectors(nb, false);
+ List<List<Float>> queryVectors = vectors.subList(0, nq);
+ InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
+ client.insert(insertParam);
+ Index index = new Index.Builder().withIndexType(indexType)
+ .withNList(n_list)
+ .build();
+ CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build();
+ client.createIndex(createIndexParam);
+ TableParam tableParam = new TableParam.Builder(tableName).build();
+ SearchParam searchParam = new SearchParam.Builder(tableName, queryVectors).withNProbe(n_probe_new).withTopK(top_k).build();
+ SearchResponse res_search = client.search(searchParam);
+ assert (!res_search.getResponse().ok());
+ }
+
+ @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
+ public void test_search_invalid_top_k(MilvusClient client, String tableName) throws InterruptedException {
+ IndexType indexType = IndexType.IVF_SQ8;
+ int nq = 5;
+ int top_k_new = 0;
+ List<List<Float>> vectors = gen_vectors(nb, false);
+ List<List<Float>> queryVectors = vectors.subList(0, nq);
+ InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
+ client.insert(insertParam);
+ Index index = new Index.Builder().withIndexType(indexType)
+ .withNList(n_list)
+ .build();
+ CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build();
+ client.createIndex(createIndexParam);
+ TableParam tableParam = new TableParam.Builder(tableName).build();
+ SearchParam searchParam = new SearchParam.Builder(tableName, queryVectors).withNProbe(n_probe).withTopK(top_k_new).build();
+ SearchResponse res_search = client.search(searchParam);
+ assert (!res_search.getResponse().ok());
+ }
+
+// @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
+// public void test_search_invalid_query_vectors(MilvusClient client, String tableName) throws InterruptedException {
+// IndexType indexType = IndexType.IVF_SQ8;
+// int nq = 5;
+// List<List<Float>> vectors = gen_vectors(nb, false);
+// InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
+// client.insert(insertParam);
+// Index index = new Index.Builder().withIndexType(indexType)
+// .withNList(n_list)
+// .build();
+// CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build();
+// client.createIndex(createIndexParam);
+// TableParam tableParam = new TableParam.Builder(tableName).build();
+// SearchParam searchParam = new SearchParam.Builder(tableName, null).withNProbe(n_probe).withTopK(top_k).build();
+// SearchResponse res_search = client.search(searchParam);
+// assert (!res_search.getResponse().ok());
+// }
+
+ @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
+ public void test_search_index_range(MilvusClient client, String tableName) throws InterruptedException {
+ IndexType indexType = IndexType.IVF_SQ8;
+ int nq = 5;
+ List<List<Float>> vectors = gen_vectors(nb, false);
+ List<List<Float>> queryVectors = vectors.subList(0, nq);
+ List<DateRange> dateRange = new ArrayList<>();
+ dateRange.add(new DateRange(getDeltaDate(-1), getDeltaDate(1)));
+ InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
+ client.insert(insertParam);
+ Index index = new Index.Builder().withIndexType(indexType)
+ .withNList(n_list)
+ .build();
+ CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build();
+ client.createIndex(createIndexParam);
+ TableParam tableParam = new TableParam.Builder(tableName).build();
+ SearchParam searchParam = new SearchParam.Builder(tableName, queryVectors).withNProbe(n_probe).withTopK(top_k).withDateRanges(dateRange).build();
+ SearchResponse res_search = client.search(searchParam);
+ assert (res_search.getResponse().ok());
+ List<List<SearchResponse.QueryResult>> res = client.search(searchParam).getQueryResultsList();
+ Assert.assertEquals(res.size(), nq);
+ Assert.assertEquals(res.get(0).size(), top_k);
+ }
+
+ @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
+ public void test_search_range(MilvusClient client, String tableName) throws InterruptedException {
+ int nq = 5;
+ List<List<Float>> vectors = gen_vectors(nb, false);
+ List<List<Float>> queryVectors = vectors.subList(0, nq);
+ List<DateRange> dateRange = new ArrayList<>();
+ dateRange.add(new DateRange(getDeltaDate(-1), getDeltaDate(1)));
+ InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
+ client.insert(insertParam);
+ Thread.sleep(1000);
+ SearchParam searchParam = new SearchParam.Builder(tableName, queryVectors).withNProbe(n_probe).withTopK(top_k).withDateRanges(dateRange).build();
+ SearchResponse res_search = client.search(searchParam);
+ assert (res_search.getResponse().ok());
+ List<List<SearchResponse.QueryResult>> res = client.search(searchParam).getQueryResultsList();
+ Assert.assertEquals(res.size(), nq);
+ Assert.assertEquals(res.get(0).size(), top_k);
+ }
+
+ @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
+ public void test_search_index_range_no_result(MilvusClient client, String tableName) throws InterruptedException {
+ IndexType indexType = IndexType.IVF_SQ8;
+ int nq = 5;
+ List<List<Float>> vectors = gen_vectors(nb, false);
+ List<List<Float>> queryVectors = vectors.subList(0, nq);
+ List<DateRange> dateRange = new ArrayList<>();
+ dateRange.add(new DateRange(getDeltaDate(-3), getDeltaDate(-1)));
+ InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
+ client.insert(insertParam);
+ Index index = new Index.Builder().withIndexType(indexType)
+ .withNList(n_list)
+ .build();
+ CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build();
+ client.createIndex(createIndexParam);
+ TableParam tableParam = new TableParam.Builder(tableName).build();
+ SearchParam searchParam = new SearchParam.Builder(tableName, queryVectors).withNProbe(n_probe).withTopK(top_k).withDateRanges(dateRange).build();
+ SearchResponse res_search = client.search(searchParam);
+ assert (res_search.getResponse().ok());
+ List<List<SearchResponse.QueryResult>> res = client.search(searchParam).getQueryResultsList();
+ Assert.assertEquals(res.size(), 0);
+ }
+
+ @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
+ public void test_search_range_no_result(MilvusClient client, String tableName) throws InterruptedException {
+ int nq = 5;
+ List<List<Float>> vectors = gen_vectors(nb, false);
+ List<List<Float>> queryVectors = vectors.subList(0, nq);
+ List<DateRange> dateRange = new ArrayList<>();
+ dateRange.add(new DateRange(getDeltaDate(-3), getDeltaDate(-1)));
+ InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
+ client.insert(insertParam);
+ Thread.sleep(1000);
+ SearchParam searchParam = new SearchParam.Builder(tableName, queryVectors).withNProbe(n_probe).withTopK(top_k).withDateRanges(dateRange).build();
+ SearchResponse res_search = client.search(searchParam);
+ assert (res_search.getResponse().ok());
+ List<List<SearchResponse.QueryResult>> res = client.search(searchParam).getQueryResultsList();
+ Assert.assertEquals(res.size(), 0);
+ }
+
+ @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
+ public void test_search_index_range_invalid(MilvusClient client, String tableName) throws InterruptedException {
+ IndexType indexType = IndexType.IVF_SQ8;
+ int nq = 5;
+ List<List<Float>> vectors = gen_vectors(nb, false);
+ List<List<Float>> queryVectors = vectors.subList(0, nq);
+ List<DateRange> dateRange = new ArrayList<>();
+ dateRange.add(new DateRange(getDeltaDate(2), getDeltaDate(-1)));
+ InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
+ client.insert(insertParam);
+ Index index = new Index.Builder().withIndexType(indexType)
+ .withNList(n_list)
+ .build();
+ CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build();
+ client.createIndex(createIndexParam);
+ TableParam tableParam = new TableParam.Builder(tableName).build();
+ SearchParam searchParam = new SearchParam.Builder(tableName, queryVectors).withNProbe(n_probe).withTopK(top_k).withDateRanges(dateRange).build();
+ SearchResponse res_search = client.search(searchParam);
+ assert (!res_search.getResponse().ok());
+ }
+
+ @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
+ public void test_search_range_invalid(MilvusClient client, String tableName) throws InterruptedException {
+ int nq = 5;
+ List<List<Float>> vectors = gen_vectors(nb, false);
+ List<List<Float>> queryVectors = vectors.subList(0, nq);
+ List<DateRange> dateRange = new ArrayList<>();
+ dateRange.add(new DateRange(getDeltaDate(2), getDeltaDate(-1)));
+ InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
+ client.insert(insertParam);
+ Thread.sleep(1000);
+ SearchParam searchParam = new SearchParam.Builder(tableName, queryVectors).withNProbe(n_probe).withTopK(top_k).withDateRanges(dateRange).build();
+ SearchResponse res_search = client.search(searchParam);
+ assert (!res_search.getResponse().ok());
+ }
+
+}
diff --git a/tests/milvus-java-test/src/main/java/com/TestTable.java b/tests/milvus-java-test/src/main/java/com/TestTable.java
new file mode 100644
index 0000000000..12b80c654b
--- /dev/null
+++ b/tests/milvus-java-test/src/main/java/com/TestTable.java
@@ -0,0 +1,155 @@
+package com;
+
+
+import io.milvus.client.*;
+import org.testng.Assert;
+import org.testng.annotations.Test;
+
+import java.util.List;
+
+public class TestTable {
+ int index_file_size = 50;
+ int dimension = 128;
+
+ @Test(dataProvider = "ConnectInstance", dataProviderClass = MainClass.class)
+ public void test_create_table(MilvusClient client, String tableName){
+ TableSchema tableSchema = new TableSchema.Builder(tableName, dimension)
+ .withIndexFileSize(index_file_size)
+ .withMetricType(MetricType.L2)
+ .build();
+ TableSchemaParam tableSchemaParam = new TableSchemaParam.Builder(tableSchema).build();
+ Response res = client.createTable(tableSchemaParam);
+ Assert.assertTrue(res.ok());
+ }
+
+ @Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class)
+ public void test_create_table_disconnect(MilvusClient client, String tableName){
+ TableSchema tableSchema = new TableSchema.Builder(tableName, dimension)
+ .withIndexFileSize(index_file_size)
+ .withMetricType(MetricType.L2)
+ .build();
+ TableSchemaParam tableSchemaParam = new TableSchemaParam.Builder(tableSchema).build();
+ Response res = client.createTable(tableSchemaParam);
+ assert(!res.ok());
+ }
+
+ @Test(dataProvider = "ConnectInstance", dataProviderClass = MainClass.class)
+ public void test_create_table_repeatably(MilvusClient client, String tableName){
+ TableSchema tableSchema = new TableSchema.Builder(tableName, dimension)
+ .withIndexFileSize(index_file_size)
+ .withMetricType(MetricType.L2)
+ .build();
+ TableSchemaParam tableSchemaParam = new TableSchemaParam.Builder(tableSchema).build();
+ Response res = client.createTable(tableSchemaParam);
+ Assert.assertEquals(res.ok(), true);
+ Response res_new = client.createTable(tableSchemaParam);
+ Assert.assertEquals(res_new.ok(), false);
+ }
+
+ @Test(dataProvider = "ConnectInstance", dataProviderClass = MainClass.class)
+ public void test_create_table_wrong_params(MilvusClient client, String tableName){
+ Integer dimension = 0;
+ TableSchema tableSchema = new TableSchema.Builder(tableName, dimension)
+ .withIndexFileSize(index_file_size)
+ .withMetricType(MetricType.L2)
+ .build();
+ TableSchemaParam tableSchemaParam = new TableSchemaParam.Builder(tableSchema).build();
+ Response res = client.createTable(tableSchemaParam);
+ System.out.println(res.toString());
+ Assert.assertEquals(res.ok(), false);
+ }
+
+ @Test(dataProvider = "ConnectInstance", dataProviderClass = MainClass.class)
+ public void test_show_tables(MilvusClient client, String tableName){
+ Integer tableNum = 10;
+ ShowTablesResponse res = null;
+ for (int i = 0; i < tableNum; ++i) {
+ String tableNameNew = tableName+"_"+Integer.toString(i);
+ TableSchema tableSchema = new TableSchema.Builder(tableNameNew, dimension)
+ .withIndexFileSize(index_file_size)
+ .withMetricType(MetricType.L2)
+ .build();
+ TableSchemaParam tableSchemaParam = new TableSchemaParam.Builder(tableSchema).build();
+ client.createTable(tableSchemaParam);
+ List<String> tableNames = client.showTables().getTableNames();
+ Assert.assertTrue(tableNames.contains(tableNameNew));
+ }
+ }
+
+ @Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class)
+ public void test_show_tables_without_connect(MilvusClient client, String tableName){
+ ShowTablesResponse res = client.showTables();
+ assert(!res.getResponse().ok());
+ }
+
+ @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
+ public void test_drop_table(MilvusClient client, String tableName) throws InterruptedException {
+ TableParam tableParam = new TableParam.Builder(tableName).build();
+ Response res = client.dropTable(tableParam);
+ assert(res.ok());
+ Thread.sleep(1000);
+ List<String> tableNames = client.showTables().getTableNames();
+ Assert.assertFalse(tableNames.contains(tableName));
+ }
+
+ @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
+ public void test_drop_table_not_existed(MilvusClient client, String tableName) throws InterruptedException {
+ TableParam tableParam = new TableParam.Builder(tableName+"_").build();
+ Response res = client.dropTable(tableParam);
+ assert(!res.ok());
+ List<String> tableNames = client.showTables().getTableNames();
+ Assert.assertTrue(tableNames.contains(tableName));
+ }
+
+ @Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class)
+ public void test_drop_table_without_connect(MilvusClient client, String tableName) throws InterruptedException {
+ TableParam tableParam = new TableParam.Builder(tableName).build();
+ Response res = client.dropTable(tableParam);
+ assert(!res.ok());
+ }
+
+ @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
+ public void test_describe_table(MilvusClient client, String tableName) throws InterruptedException {
+ TableParam tableParam = new TableParam.Builder(tableName).build();
+ DescribeTableResponse res = client.describeTable(tableParam);
+ assert(res.getResponse().ok());
+ TableSchema tableSchema = res.getTableSchema().get();
+ Assert.assertEquals(tableSchema.getDimension(), dimension);
+ Assert.assertEquals(tableSchema.getTableName(), tableName);
+ Assert.assertEquals(tableSchema.getIndexFileSize(), index_file_size);
+ Assert.assertEquals(tableSchema.getMetricType().name(), tableName.substring(0,2));
+ }
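+
+ // Note: the "Table" data provider is assumed to prefix table names with their metric type
+ // ("L2..." / "IP..."), which is why the schema's metric type can be compared with
+ // tableName.substring(0, 2) above.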
+
+ @Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class)
+ public void test_describe_table_without_connect(MilvusClient client, String tableName) throws InterruptedException {
+ TableParam tableParam = new TableParam.Builder(tableName).build();
+ DescribeTableResponse res = client.describeTable(tableParam);
+ assert(!res.getResponse().ok());
+ }
+
+ @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
+ public void test_has_table_not_existed(MilvusClient client, String tableName) throws InterruptedException {
+ TableParam tableParam = new TableParam.Builder(tableName+"_").build();
+ HasTableResponse res = client.hasTable(tableParam);
+ assert(res.getResponse().ok());
+ Assert.assertFalse(res.hasTable());
+ }
+
+ @Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class)
+ public void test_has_table_without_connect(MilvusClient client, String tableName) throws InterruptedException {
+ TableParam tableParam = new TableParam.Builder(tableName).build();
+ HasTableResponse res = client.hasTable(tableParam);
+ assert(!res.getResponse().ok());
+ }
+
+ @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
+ public void test_has_table(MilvusClient client, String tableName) throws InterruptedException {
+ TableParam tableParam = new TableParam.Builder(tableName).build();
+ HasTableResponse res = client.hasTable(tableParam);
+ assert(res.getResponse().ok());
+ Assert.assertTrue(res.hasTable());
+ }
+
+
+}
diff --git a/tests/milvus-java-test/src/main/java/com/TestTableCount.java b/tests/milvus-java-test/src/main/java/com/TestTableCount.java
new file mode 100644
index 0000000000..afb16c471a
--- /dev/null
+++ b/tests/milvus-java-test/src/main/java/com/TestTableCount.java
@@ -0,0 +1,89 @@
+package com;
+
+import io.milvus.client.*;
+import org.testng.Assert;
+import org.testng.annotations.Test;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Random;
+
+public class TestTableCount {
+ int index_file_size = 50;
+ int dimension = 128;
+
+ public List<List<Float>> gen_vectors(Integer nb) {
+ List<List<Float>> xb = new ArrayList<>();
+ Random random = new Random();
+ for (int i = 0; i < nb; ++i) {
+ ArrayList<Float> vector = new ArrayList<>();
+ for (int j = 0; j < dimension; j++) {
+ vector.add(random.nextFloat());
+ }
+ xb.add(vector);
+ }
+ return xb;
+ }
+
+ @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
+ public void test_table_count_no_vectors(MilvusClient client, String tableName) {
+ TableParam tableParam = new TableParam.Builder(tableName).build();
+ Assert.assertEquals(client.getTableRowCount(tableParam).getTableRowCount(), 0);
+ }
+
+ @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
+ public void test_table_count_table_not_existed(MilvusClient client, String tableName) {
+ TableParam tableParam = new TableParam.Builder(tableName+"_").build();
+ GetTableRowCountResponse res = client.getTableRowCount(tableParam);
+ assert(!res.getResponse().ok());
+ }
+
+ @Test(dataProvider = "DisConnectInstance", dataProviderClass = MainClass.class)
+ public void test_table_count_without_connect(MilvusClient client, String tableName) {
+ TableParam tableParam = new TableParam.Builder(tableName+"_").build();
+ GetTableRowCountResponse res = client.getTableRowCount(tableParam);
+ assert(!res.getResponse().ok());
+ }
+
+ @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
+ public void test_table_count(MilvusClient client, String tableName) throws InterruptedException {
+ int nb = 10000;
+ List<List<Float>> vectors = gen_vectors(nb);
+ // Add vectors
+ InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
+ client.insert(insertParam);
+ Thread.sleep(1000);
+ TableParam tableParam = new TableParam.Builder(tableName).build();
+ Assert.assertEquals(client.getTableRowCount(tableParam).getTableRowCount(), nb);
+ }
+
+ @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
+ public void test_table_count_multi_tables(MilvusClient client, String tableName) throws InterruptedException {
+ int nb = 10000;
+ List<List<Float>> vectors = gen_vectors(nb);
+ Integer tableNum = 10;
+ GetTableRowCountResponse res = null;
+ for (int i = 0; i < tableNum; ++i) {
+ String tableNameNew = tableName + "_" + Integer.toString(i);
+ TableSchema tableSchema = new TableSchema.Builder(tableNameNew, dimension)
+ .withIndexFileSize(index_file_size)
+ .withMetricType(MetricType.L2)
+ .build();
+ TableSchemaParam tableSchemaParam = new TableSchemaParam.Builder(tableSchema).build();
+ client.createTable(tableSchemaParam);
+ // Add vectors
+ InsertParam insertParam = new InsertParam.Builder(tableNameNew, vectors).build();
+ client.insert(insertParam);
+ }
+ Thread.sleep(1000);
+ for (int i = 0; i < tableNum; ++i) {
+ String tableNameNew = tableName + "_" + Integer.toString(i);
+ TableParam tableParam = new TableParam.Builder(tableNameNew).build();
+ res = client.getTableRowCount(tableParam);
+ Assert.assertEquals(res.getTableRowCount(), nb);
+ }
+ }
+
+}
+
+
diff --git a/tests/milvus-java-test/testng.xml b/tests/milvus-java-test/testng.xml
new file mode 100644
index 0000000000..19520f7eab
--- /dev/null
+++ b/tests/milvus-java-test/testng.xml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/tests/milvus_ann_acc/.gitignore b/tests/milvus_ann_acc/.gitignore
new file mode 100644
index 0000000000..f250cab9fe
--- /dev/null
+++ b/tests/milvus_ann_acc/.gitignore
@@ -0,0 +1,2 @@
+__pycache__/
+logs/
diff --git a/tests/milvus_ann_acc/client.py b/tests/milvus_ann_acc/client.py
new file mode 100644
index 0000000000..de4ef17cb6
--- /dev/null
+++ b/tests/milvus_ann_acc/client.py
@@ -0,0 +1,149 @@
+import pdb
+import random
+import logging
+import json
+import time, datetime
+from multiprocessing import Process
+import numpy
+import sklearn.preprocessing
+from milvus import Milvus, IndexType, MetricType
+
+logger = logging.getLogger("milvus_ann_acc.client")
+
+SERVER_HOST_DEFAULT = "127.0.0.1"
+SERVER_PORT_DEFAULT = 19530
+
+
+def time_wrapper(func):
+ """
+ This decorator prints the execution time for the decorated function.
+ """
+ def wrapper(*args, **kwargs):
+ start = time.time()
+ result = func(*args, **kwargs)
+ end = time.time()
+ logger.info("Milvus {} run in {}s".format(func.__name__, round(end - start, 2)))
+ return result
+ return wrapper
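+
+# Illustrative note: any method decorated with @time_wrapper logs a line such as
+# "Milvus insert run in 1.23s" through this module's logger after the call returns.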
+
+
+class MilvusClient(object):
+ def __init__(self, table_name=None, ip=None, port=None):
+ self._milvus = Milvus()
+ self._table_name = table_name
+ try:
+ if not ip:
+ self._milvus.connect(
+ host = SERVER_HOST_DEFAULT,
+ port = SERVER_PORT_DEFAULT)
+ else:
+ self._milvus.connect(
+ host = ip,
+ port = port)
+ except Exception as e:
+ raise e
+
+ def __str__(self):
+ return 'Milvus table %s' % self._table_name
+
+ def check_status(self, status):
+ if not status.OK():
+ logger.error(status.message)
+ raise Exception("Status not ok")
+
+ def create_table(self, table_name, dimension, index_file_size, metric_type):
+ if not self._table_name:
+ self._table_name = table_name
+ if metric_type == "l2":
+ metric_type = MetricType.L2
+ elif metric_type == "ip":
+ metric_type = MetricType.IP
+ else:
+ logger.error("Not supported metric_type: %s" % metric_type)
+ self._metric_type = metric_type
+ create_param = {'table_name': table_name,
+ 'dimension': dimension,
+ 'index_file_size': index_file_size,
+ "metric_type": metric_type}
+ status = self._milvus.create_table(create_param)
+ self.check_status(status)
+
+ @time_wrapper
+ def insert(self, X, ids):
+ if self._metric_type == MetricType.IP:
+ logger.info("Set normalize for metric_type: Inner Product")
+ X = sklearn.preprocessing.normalize(X, axis=1, norm='l2')
+ X = X.astype(numpy.float32)
+ status, result = self._milvus.add_vectors(self._table_name, X.tolist(), ids=ids)
+ self.check_status(status)
+ return status, result
+
+ @time_wrapper
+ def create_index(self, index_type, nlist):
+ if index_type == "flat":
+ index_type = IndexType.FLAT
+ elif index_type == "ivf_flat":
+ index_type = IndexType.IVFLAT
+ elif index_type == "ivf_sq8":
+ index_type = IndexType.IVF_SQ8
+ elif index_type == "ivf_sq8h":
+ index_type = IndexType.IVF_SQ8H
+ elif index_type == "mix_nsg":
+ index_type = IndexType.MIX_NSG
+ index_params = {
+ "index_type": index_type,
+ "nlist": nlist,
+ }
+ logger.info("Building index start, table_name: %s, index_params: %s" % (self._table_name, json.dumps(index_params)))
+ status = self._milvus.create_index(self._table_name, index=index_params, timeout=6*3600)
+ self.check_status(status)
+
+ def describe_index(self):
+ return self._milvus.describe_index(self._table_name)
+
+ def drop_index(self):
+ logger.info("Drop index: %s" % self._table_name)
+ return self._milvus.drop_index(self._table_name)
+
+ @time_wrapper
+ def query(self, X, top_k, nprobe):
+ if self._metric_type == MetricType.IP:
+ logger.info("Set normalize for metric_type: Inner Product")
+ X = sklearn.preprocessing.normalize(X, axis=1, norm='l2')
+ X = X.astype(numpy.float32)
+ status, results = self._milvus.search_vectors(self._table_name, top_k, nprobe, X.tolist())
+ self.check_status(status)
+ # logger.info(results[0])
+ ids = []
+ for result in results:
+ tmp_ids = []
+ for item in result:
+ tmp_ids.append(item.id)
+ ids.append(tmp_ids)
+ return ids
+
+ def count(self):
+ return self._milvus.get_table_row_count(self._table_name)[1]
+
+ def delete(self, timeout=60):
+ logger.info("Start delete table: %s" % self._table_name)
+ self._milvus.delete_table(self._table_name)
+ i = 0
+ while i < timeout:
+ if self.count():
+ time.sleep(1)
+ i = i + 1
+ else:
+ break
+ if i >= timeout:
+ logger.error("Delete table timeout")
+
+ def describe(self):
+ return self._milvus.describe_table(self._table_name)
+
+ def exists_table(self):
+ return self._milvus.has_table(self._table_name)
+
+ @time_wrapper
+ def preload_table(self):
+ return self._milvus.preload_table(self._table_name)
diff --git a/tests/milvus_ann_acc/config.yaml b/tests/milvus_ann_acc/config.yaml
new file mode 100644
index 0000000000..e2ac2c1bfb
--- /dev/null
+++ b/tests/milvus_ann_acc/config.yaml
@@ -0,0 +1,17 @@
+datasets:
+ sift-128-euclidean:
+ cpu_cache_size: 16
+ gpu_cache_size: 5
+ index_file_size: [1024]
+ nytimes-16-angular:
+ cpu_cache_size: 16
+ gpu_cache_size: 5
+ index_file_size: [1024]
+
+index:
+ index_types: ['flat', 'ivf_flat', 'ivf_sq8']
+ nlists: [8092, 16384]
+
+search:
+ nprobes: [1, 8, 32]
+ top_ks: [10]
diff --git a/tests/milvus_ann_acc/main.py b/tests/milvus_ann_acc/main.py
new file mode 100644
index 0000000000..308e8246c7
--- /dev/null
+++ b/tests/milvus_ann_acc/main.py
@@ -0,0 +1,26 @@
+
+import argparse
+
+def main():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+ parser.add_argument(
+ '--dataset',
+ metavar='NAME',
+ help='the dataset to load training points from',
+ default='glove-100-angular',
+ choices=DATASETS.keys())
+ parser.add_argument(
+ "-k", "--count",
+ default=10,
+ type=positive_int,
+ help="the number of near neighbours to search for")
+ parser.add_argument(
+ '--definitions',
+ metavar='FILE',
+ help='load algorithm definitions from FILE',
+ default='algos.yaml')
+ parser.add_argument(
+ '--image-tag',
+ default=None,
+ help='pull image first')
\ No newline at end of file
diff --git a/tests/milvus_ann_acc/test.py b/tests/milvus_ann_acc/test.py
new file mode 100644
index 0000000000..c4fbc33195
--- /dev/null
+++ b/tests/milvus_ann_acc/test.py
@@ -0,0 +1,132 @@
+import os
+import pdb
+import time
+import random
+import sys
+import h5py
+import numpy
+import logging
+from logging import handlers
+
+from client import MilvusClient
+
+LOG_FOLDER = "logs"
+logger = logging.getLogger("milvus_ann_acc")
+
+formatter = logging.Formatter('[%(asctime)s] [%(levelname)-4s] [%(pathname)s:%(lineno)d] %(message)s')
+if not os.path.exists(LOG_FOLDER):
+ os.system('mkdir -p %s' % LOG_FOLDER)
+fileTimeHandler = handlers.TimedRotatingFileHandler(os.path.join(LOG_FOLDER, 'acc'), "D", 1, 10)
+fileTimeHandler.suffix = "%Y%m%d.log"
+fileTimeHandler.setFormatter(formatter)
+logging.basicConfig(level=logging.DEBUG)
+logger.addHandler(fileTimeHandler)
+
+
+def get_dataset_fn(dataset_name):
+ file_path = "/test/milvus/ann_hdf5/"
+ if not os.path.exists(file_path):
+ raise Exception("%s not exists" % file_path)
+ return os.path.join(file_path, '%s.hdf5' % dataset_name)
+
+
+def get_dataset(dataset_name):
+ hdf5_fn = get_dataset_fn(dataset_name)
+ hdf5_f = h5py.File(hdf5_fn)
+ return hdf5_f
+
+
+def parse_dataset_name(dataset_name):
+ data_type = dataset_name.split("-")[0]
+ dimension = int(dataset_name.split("-")[1])
+ metric = dataset_name.split("-")[-1]
+ # metric = dataset.attrs['distance']
+ # dimension = len(dataset["train"][0])
+ if metric == "euclidean":
+ metric_type = "l2"
+ elif metric == "angular":
+ metric_type = "ip"
+ return ("ann"+data_type, dimension, metric_type)
+
+
+def get_table_name(dataset_name, index_file_size):
+ data_type, dimension, metric_type = parse_dataset_name(dataset_name)
+ dataset = get_dataset(dataset_name)
+ table_size = len(dataset["train"])
+ table_size = str(table_size // 1000000)+"m"
+ table_name = data_type+'_'+table_size+'_'+str(index_file_size)+'_'+str(dimension)+'_'+metric_type
+ return table_name
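+
+# For example, "sift-128-euclidean" with index_file_size=1024 yields the table name
+# "annsift_1m_1024_128_l2" (data type, size in millions, index file size, dimension, metric).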
+
+
+def main(dataset_name, index_file_size, nlist=16384, force=False):
+ top_k = 10
+ nprobes = [32, 128]
+
+ dataset = get_dataset(dataset_name)
+ table_name = get_table_name(dataset_name, index_file_size)
+ m = MilvusClient(table_name)
+ if m.exists_table():
+ if force is True:
+ logger.info("Re-create table: %s" % table_name)
+ m.delete()
+ time.sleep(10)
+ else:
+ logger.info("Table name: %s existed" % table_name)
+ return
+ data_type, dimension, metric_type = parse_dataset_name(dataset_name)
+ m.create_table(table_name, dimension, index_file_size, metric_type)
+ print(m.describe())
+ vectors = numpy.array(dataset["train"])
+ query_vectors = numpy.array(dataset["test"])
+ # m.insert(vectors)
+
+ interval = 100000
+ loops = len(vectors) // interval + 1
+
+ for i in range(loops):
+ start = i*interval
+ end = min((i+1)*interval, len(vectors))
+ tmp_vectors = vectors[start:end]
+ if start < end:
+ m.insert(tmp_vectors, ids=[i for i in range(start, end)])
+
+ time.sleep(60)
+ print(m.count())
+
+ for index_type in ["ivf_flat", "ivf_sq8", "ivf_sq8h"]:
+ m.create_index(index_type, nlist)
+ print(m.describe_index())
+ if m.count() != len(vectors):
+ return
+ m.preload_table()
+ true_ids = numpy.array(dataset["neighbors"])
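+ # "neighbors" holds the ground-truth nearest-neighbor ids from the hdf5 dataset; the loop
+ # below computes average recall ("radio", i.e. ratio) as the mean overlap between the true
+ # top_k ids and the ids returned by Milvus.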
+ for nprobe in nprobes:
+ print("nprobe: %s" % nprobe)
+ sum_radio = 0.0
+ avg_radio = 0.0
+ result_ids = m.query(query_vectors, top_k, nprobe)
+ # print(result_ids[:10])
+ for index, result_item in enumerate(result_ids):
+ if len(set(true_ids[index][:top_k])) != len(set(result_item)):
+ logger.info("Error happened")
+ # logger.info(query_vectors[index])
+ # logger.info(true_ids[index][:top_k], result_item)
+ tmp = set(true_ids[index][:top_k]).intersection(set(result_item))
+ sum_radio = sum_radio + (len(tmp) / top_k)
+ avg_radio = round(sum_radio / len(result_ids), 4)
+ logger.info(avg_radio)
+ m.drop_index()
+
+
+if __name__ == "__main__":
+ print("glove-25-angular")
+ # main("sift-128-euclidean", 1024, force=True)
+ for index_file_size in [50, 1024]:
+ print("Index file size: %d" % index_file_size)
+ main("glove-25-angular", index_file_size, force=True)
+
+ print("sift-128-euclidean")
+ for index_file_size in [50, 1024]:
+ print("Index file size: %d" % index_file_size)
+ main("sift-128-euclidean", index_file_size, force=True)
+ # m = MilvusClient()
\ No newline at end of file
diff --git a/tests/milvus_benchmark/.gitignore b/tests/milvus_benchmark/.gitignore
new file mode 100644
index 0000000000..70af07bba8
--- /dev/null
+++ b/tests/milvus_benchmark/.gitignore
@@ -0,0 +1,8 @@
+random_data
+benchmark_logs/
+db/
+logs/
+*idmap*.txt
+__pycache__/
+venv
+.idea
\ No newline at end of file
diff --git a/tests/milvus_benchmark/README.md b/tests/milvus_benchmark/README.md
new file mode 100644
index 0000000000..72abd1264a
--- /dev/null
+++ b/tests/milvus_benchmark/README.md
@@ -0,0 +1,57 @@
+# Quick start
+
+## Running
+
+### Example:
+
+`python3 main.py --image=registry.zilliz.com/milvus/engine:branch-0.3.1-release --run-count=2 --run-type=performance`
+
+### Arguments:
+
+--image: container mode. Pass an image name; when given, the test run pulls the image first and starts a Milvus server container from it.
+
+--local: mutually exclusive with --image. Local mode; the test connects to a locally started Milvus server.
+
+--run-count: number of repeated runs.
+
+--suites: test-suite configuration file, suites.yaml by default.
+
+--run-type: test type: performance, accuracy or stability.
+
+### Test-suite configuration file:
+
+`operations:
+
+ insert:
+
+ [
+ {"table.index_type": "ivf_flat", "server.index_building_threshold": 300, "table.size": 2000000, "table.ni": 100000, "table.dim": 512},
+ ]
+
+ build: []
+
+ query:
+
+ [
+ {"dataset": "ip_ivfsq8_1000", "top_ks": [10], "nqs": [10, 100], "server.nprobe": 1, "server.use_blas_threshold": 800},
+ {"dataset": "ip_ivfsq8_1000", "top_ks": [10], "nqs": [10, 100], "server.nprobe": 10, "server.use_blas_threshold": 20},
+ ]`
+
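+In each operation entry, keys prefixed with `server.` (for example `server.nprobe`, `server.use_blas_threshold`) are applied to the server configuration before the container starts, as handled in docker_runner.py; the remaining keys describe the table and the query workload (dataset, top_ks, nqs, and so on).
+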
+## Test results:
+
+Performance:
+
+`INFO:milvus_benchmark.runner:Start warm query, query params: top-k: 1, nq: 1
+
+INFO:milvus_benchmark.client:query run in 19.19s
+INFO:milvus_benchmark.runner:Start query, query params: top-k: 64, nq: 10, actually length of vectors: 10
+INFO:milvus_benchmark.runner:Start run query, run 1 of 1
+INFO:milvus_benchmark.client:query run in 0.2s
+INFO:milvus_benchmark.runner:Avarage query time: 0.20
+INFO:milvus_benchmark.runner:[[0.2]]`
+
+**│ 10 │ 0.2 │**
+
+Accuracy:
+
+`INFO:milvus_benchmark.runner:Avarage accuracy: 1.0`
\ No newline at end of file
diff --git a/tests/milvus_benchmark/__init__.py b/tests/milvus_benchmark/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/milvus_benchmark/client.py b/tests/milvus_benchmark/client.py
new file mode 100644
index 0000000000..4744c10854
--- /dev/null
+++ b/tests/milvus_benchmark/client.py
@@ -0,0 +1,244 @@
+import pdb
+import random
+import logging
+import json
+import sys
+import time, datetime
+from multiprocessing import Process
+from milvus import Milvus, IndexType, MetricType
+
+logger = logging.getLogger("milvus_benchmark.client")
+
+SERVER_HOST_DEFAULT = "127.0.0.1"
+SERVER_PORT_DEFAULT = 19530
+
+
+def time_wrapper(func):
+ """
+ This decorator prints the execution time for the decorated function.
+ """
+ def wrapper(*args, **kwargs):
+ start = time.time()
+ result = func(*args, **kwargs)
+ end = time.time()
+ logger.info("Milvus {} run in {}s".format(func.__name__, round(end - start, 2)))
+ return result
+ return wrapper
+
+
+class MilvusClient(object):
+ def __init__(self, table_name=None, ip=None, port=None):
+ self._milvus = Milvus()
+ self._table_name = table_name
+ try:
+ if not ip:
+ self._milvus.connect(
+ host = SERVER_HOST_DEFAULT,
+ port = SERVER_PORT_DEFAULT)
+ else:
+ self._milvus.connect(
+ host = ip,
+ port = port)
+ except Exception as e:
+ raise e
+
+ def __str__(self):
+ return 'Milvus table %s' % self._table_name
+
+ def check_status(self, status):
+ if not status.OK():
+ logger.error(status.message)
+ raise Exception("Status not ok")
+
+ def create_table(self, table_name, dimension, index_file_size, metric_type):
+ if not self._table_name:
+ self._table_name = table_name
+ if metric_type == "l2":
+ metric_type = MetricType.L2
+ elif metric_type == "ip":
+ metric_type = MetricType.IP
+ else:
+ logger.error("Not supported metric_type: %s" % metric_type)
+ create_param = {'table_name': table_name,
+ 'dimension': dimension,
+ 'index_file_size': index_file_size,
+ "metric_type": metric_type}
+ status = self._milvus.create_table(create_param)
+ self.check_status(status)
+
+ @time_wrapper
+ def insert(self, X, ids=None):
+ status, result = self._milvus.add_vectors(self._table_name, X, ids)
+ self.check_status(status)
+ return status, result
+
+ @time_wrapper
+ def create_index(self, index_type, nlist):
+ if index_type == "flat":
+ index_type = IndexType.FLAT
+ elif index_type == "ivf_flat":
+ index_type = IndexType.IVFLAT
+ elif index_type == "ivf_sq8":
+ index_type = IndexType.IVF_SQ8
+ elif index_type == "mix_nsg":
+ index_type = IndexType.MIX_NSG
+ elif index_type == "ivf_sq8h":
+ index_type = IndexType.IVF_SQ8H
+ index_params = {
+ "index_type": index_type,
+ "nlist": nlist,
+ }
+ logger.info("Building index start, table_name: %s, index_params: %s" % (self._table_name, json.dumps(index_params)))
+ status = self._milvus.create_index(self._table_name, index=index_params, timeout=6*3600)
+ self.check_status(status)
+
+ def describe_index(self):
+ return self._milvus.describe_index(self._table_name)
+
+ def drop_index(self):
+ logger.info("Drop index: %s" % self._table_name)
+ return self._milvus.drop_index(self._table_name)
+
+ @time_wrapper
+ def query(self, X, top_k, nprobe):
+ status, result = self._milvus.search_vectors(self._table_name, top_k, nprobe, X)
+ self.check_status(status)
+ return status, result
+
+ def count(self):
+ return self._milvus.get_table_row_count(self._table_name)[1]
+
+ def delete(self, timeout=60):
+ logger.info("Start delete table: %s" % self._table_name)
+ self._milvus.delete_table(self._table_name)
+ i = 0
+ while i < timeout:
+ if self.count():
+ time.sleep(1)
+ i = i + 1
+ continue
+ else:
+ break
+ if i >= timeout:
+ logger.error("Delete table timeout")
+
+ def describe(self):
+ return self._milvus.describe_table(self._table_name)
+
+ def exists_table(self):
+ return self._milvus.has_table(self._table_name)
+
+ @time_wrapper
+ def preload_table(self):
+ return self._milvus.preload_table(self._table_name, timeout=3000)
+
+
+def fit(table_name, X):
+ milvus = Milvus()
+ milvus.connect(host = SERVER_HOST_DEFAULT, port = SERVER_PORT_DEFAULT)
+ start = time.time()
+ status, ids = milvus.add_vectors(table_name, X)
+ end = time.time()
+ logger.info("add vectors status: %s, run in %ss" % (status, round(end - start, 2)))
+
+
+def fit_concurrent(table_name, process_num, vectors):
+ processes = []
+
+ for i in range(process_num):
+ p = Process(target=fit, args=(table_name, vectors, ))
+ processes.append(p)
+ p.start()
+ for p in processes:
+ p.join()
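+
+# fit_concurrent starts process_num separate processes, each inserting the same vectors into
+# table_name through its own client connection, to exercise concurrent inserts.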
+
+
+if __name__ == "__main__":
+
+ # table_name = "sift_2m_20_128_l2"
+ table_name = "test_tset1"
+ m = MilvusClient(table_name)
+ # m.create_table(table_name, 128, 50, "l2")
+
+ print(m.describe())
+ # print(m.count())
+ # print(m.describe_index())
+ insert_vectors = [[random.random() for _ in range(128)] for _ in range(10000)]
+ for i in range(5):
+ m.insert(insert_vectors)
+ print(m.create_index("ivf_sq8h", 16384))
+ X = [insert_vectors[0]]
+ top_k = 10
+ nprobe = 10
+ print(m.query(X, top_k, nprobe))
+
+ # # # print(m.drop_index())
+ # # print(m.describe_index())
+ # # sys.exit()
+ # # # insert_vectors = [[random.random() for _ in range(128)] for _ in range(100000)]
+ # # # for i in range(100):
+ # # # m.insert(insert_vectors)
+ # # # time.sleep(5)
+ # # # print(m.describe_index())
+ # # # print(m.drop_index())
+ # # m.create_index("ivf_sq8h", 16384)
+ # print(m.count())
+ # print(m.describe_index())
+
+
+
+ # sys.exit()
+ # print(m.create_index("ivf_sq8h", 16384))
+ # print(m.count())
+ # print(m.describe_index())
+ import numpy as np
+
+ def mmap_fvecs(fname):
+ x = np.memmap(fname, dtype='int32', mode='r')
+ d = x[0]
+ return x.view('float32').reshape(-1, d + 1)[:, 1:]
+
+ print(mmap_fvecs("/poc/deep1b/deep1B_queries.fvecs"))
+ # SIFT_SRC_QUERY_DATA_DIR = '/poc/yuncong/ann_1000m'
+ # file_name = SIFT_SRC_QUERY_DATA_DIR+'/'+'query.npy'
+ # data = numpy.load(file_name)
+ # query_vectors = data[0:2].tolist()
+ # print(len(query_vectors))
+ # results = m.query(query_vectors, 10, 10)
+ # result_ids = []
+ # for result in results[1]:
+ # tmp = []
+ # for item in result:
+ # tmp.append(item.id)
+ # result_ids.append(tmp)
+ # print(result_ids[0][:10])
+ # # gt
+ # file_name = SIFT_SRC_QUERY_DATA_DIR+"/gnd/"+"idx_1M.ivecs"
+ # a = numpy.fromfile(file_name, dtype='int32')
+ # d = a[0]
+ # true_ids = a.reshape(-1, d + 1)[:, 1:].copy()
+ # print(true_ids[:3, :2])
+
+ # print(len(true_ids[0]))
+ # import numpy as np
+ # import sklearn.preprocessing
+
+ # def mmap_fvecs(fname):
+ # x = np.memmap(fname, dtype='int32', mode='r')
+ # d = x[0]
+ # return x.view('float32').reshape(-1, d + 1)[:, 1:]
+
+ # data = mmap_fvecs("/poc/deep1b/deep1B_queries.fvecs")
+ # print(data[0], len(data[0]), len(data))
+
+ # total_size = 10000
+ # # total_size = 1000000000
+ # file_size = 1000
+ # # file_size = 100000
+ # file_num = total_size // file_size
+ # for i in range(file_num):
+ # fname = "/test/milvus/raw_data/deep1b/binary_96_%05d" % i
+ # print(fname, i*file_size, (i+1)*file_size)
+ # single_data = data[i*file_size : (i+1)*file_size]
+ # single_data = sklearn.preprocessing.normalize(single_data, axis=1, norm='l2')
+ # np.save(fname, single_data)
diff --git a/tests/milvus_benchmark/conf/log_config.conf b/tests/milvus_benchmark/conf/log_config.conf
new file mode 100644
index 0000000000..c9c14d57f4
--- /dev/null
+++ b/tests/milvus_benchmark/conf/log_config.conf
@@ -0,0 +1,28 @@
+* GLOBAL:
+ FORMAT = "%datetime | %level | %logger | %msg"
+ FILENAME = "/opt/milvus/logs/milvus-%datetime{%H:%m}-global.log"
+ ENABLED = true
+ TO_FILE = true
+ TO_STANDARD_OUTPUT = false
+ SUBSECOND_PRECISION = 3
+ PERFORMANCE_TRACKING = false
+ MAX_LOG_FILE_SIZE = 2097152 ## Throw log files away after 2MB
+* DEBUG:
+ FILENAME = "/opt/milvus/logs/milvus-%datetime{%H:%m}-debug.log"
+ ENABLED = true
+* WARNING:
+ FILENAME = "/opt/milvus/logs/milvus-%datetime{%H:%m}-warning.log"
+* TRACE:
+ FILENAME = "/opt/milvus/logs/milvus-%datetime{%H:%m}-trace.log"
+* VERBOSE:
+ FORMAT = "%datetime{%d/%M/%y} | %level-%vlevel | %msg"
+ TO_FILE = false
+ TO_STANDARD_OUTPUT = false
+## Error logs
+* ERROR:
+ ENABLED = true
+ FILENAME = "/opt/milvus/logs/milvus-%datetime{%H:%m}-error.log"
+* FATAL:
+ ENABLED = true
+ FILENAME = "/opt/milvus/logs/milvus-%datetime{%H:%m}-fatal.log"
+
diff --git a/tests/milvus_benchmark/conf/server_config.yaml b/tests/milvus_benchmark/conf/server_config.yaml
new file mode 100644
index 0000000000..a5d2081b8a
--- /dev/null
+++ b/tests/milvus_benchmark/conf/server_config.yaml
@@ -0,0 +1,28 @@
+cache_config:
+ cache_insert_data: false
+ cpu_cache_capacity: 16
+ gpu_cache_capacity: 6
+ cpu_cache_threshold: 0.85
+db_config:
+ backend_url: sqlite://:@:/
+ build_index_gpu: 0
+ insert_buffer_size: 4
+ preload_table: null
+ primary_path: /opt/milvus
+ secondary_path: null
+engine_config:
+ use_blas_threshold: 20
+metric_config:
+ collector: prometheus
+ enable_monitor: true
+ prometheus_config:
+ port: 8080
+resource_config:
+ resource_pool:
+ - cpu
+ - gpu0
+server_config:
+ address: 0.0.0.0
+ deploy_mode: single
+ port: 19530
+ time_zone: UTC+8
diff --git a/tests/milvus_benchmark/conf/server_config.yaml.cpu b/tests/milvus_benchmark/conf/server_config.yaml.cpu
new file mode 100644
index 0000000000..95ab5f5343
--- /dev/null
+++ b/tests/milvus_benchmark/conf/server_config.yaml.cpu
@@ -0,0 +1,31 @@
+server_config:
+ address: 0.0.0.0
+ port: 19530
+ deploy_mode: single
+ time_zone: UTC+8
+
+db_config:
+ primary_path: /opt/milvus
+ secondary_path:
+ backend_url: sqlite://:@:/
+ insert_buffer_size: 4
+ build_index_gpu: 0
+ preload_table:
+
+metric_config:
+ enable_monitor: false
+ collector: prometheus
+ prometheus_config:
+ port: 8080
+
+cache_config:
+ cpu_cache_capacity: 16
+ cpu_cache_threshold: 0.85
+ cache_insert_data: false
+
+engine_config:
+ use_blas_threshold: 20
+
+resource_config:
+ resource_pool:
+ - cpu
\ No newline at end of file
diff --git a/tests/milvus_benchmark/conf/server_config.yaml.multi b/tests/milvus_benchmark/conf/server_config.yaml.multi
new file mode 100644
index 0000000000..002d3bd2a6
--- /dev/null
+++ b/tests/milvus_benchmark/conf/server_config.yaml.multi
@@ -0,0 +1,33 @@
+server_config:
+ address: 0.0.0.0
+ port: 19530
+ deploy_mode: single
+ time_zone: UTC+8
+
+db_config:
+ primary_path: /opt/milvus
+ secondary_path:
+ backend_url: sqlite://:@:/
+ insert_buffer_size: 4
+ build_index_gpu: 0
+ preload_table:
+
+metric_config:
+ enable_monitor: false
+ collector: prometheus
+ prometheus_config:
+ port: 8080
+
+cache_config:
+ cpu_cache_capacity: 16
+ cpu_cache_threshold: 0.85
+ cache_insert_data: false
+
+engine_config:
+ use_blas_threshold: 20
+
+resource_config:
+ resource_pool:
+ - cpu
+ - gpu0
+ - gpu1
\ No newline at end of file
diff --git a/tests/milvus_benchmark/conf/server_config.yaml.single b/tests/milvus_benchmark/conf/server_config.yaml.single
new file mode 100644
index 0000000000..033d8868d1
--- /dev/null
+++ b/tests/milvus_benchmark/conf/server_config.yaml.single
@@ -0,0 +1,32 @@
+server_config:
+ address: 0.0.0.0
+ port: 19530
+ deploy_mode: single
+ time_zone: UTC+8
+
+db_config:
+ primary_path: /opt/milvus
+ secondary_path:
+ backend_url: sqlite://:@:/
+ insert_buffer_size: 4
+ build_index_gpu: 0
+ preload_table:
+
+metric_config:
+ enable_monitor: false
+ collector: prometheus
+ prometheus_config:
+ port: 8080
+
+cache_config:
+ cpu_cache_capacity: 16
+ cpu_cache_threshold: 0.85
+ cache_insert_data: false
+
+engine_config:
+ use_blas_threshold: 20
+
+resource_config:
+ resource_pool:
+ - cpu
+ - gpu0
\ No newline at end of file
diff --git a/tests/milvus_benchmark/demo.py b/tests/milvus_benchmark/demo.py
new file mode 100644
index 0000000000..27152e0980
--- /dev/null
+++ b/tests/milvus_benchmark/demo.py
@@ -0,0 +1,51 @@
+import os
+import logging
+import pdb
+import time
+import random
+from multiprocessing import Process
+import numpy as np
+from client import MilvusClient
+
+logger = logging.getLogger("milvus_benchmark.demo")
+
+nq = 100000
+dimension = 128
+run_count = 1
+table_name = "sift_10m_1024_128_ip"
+insert_vectors = [[random.random() for _ in range(dimension)] for _ in range(nq)]
+
+def do_query(milvus, table_name, top_ks, nqs, nprobe, run_count):
+ bi_res = []
+ for index, nq in enumerate(nqs):
+ tmp_res = []
+ for top_k in top_ks:
+ avg_query_time = 0.0
+ total_query_time = 0.0
+ vectors = insert_vectors[0:nq]
+ for i in range(run_count):
+ start_time = time.time()
+ status, query_res = milvus.query(vectors, top_k, nprobe)
+ total_query_time = total_query_time + (time.time() - start_time)
+ if status.code:
+ print(status.message)
+ avg_query_time = round(total_query_time / run_count, 2)
+ tmp_res.append(avg_query_time)
+ bi_res.append(tmp_res)
+ return bi_res
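+
+# do_query returns a matrix of average query latencies in seconds: one row per nq value and
+# one column per top_k value, each averaged over run_count runs.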
+
+while 1:
+ milvus_instance = MilvusClient(table_name, ip="192.168.1.197", port=19530)
+ top_ks = random.sample([x for x in range(1, 100)], 4)
+ nqs = random.sample([x for x in range(1, 1000)], 3)
+ nprobe = random.choice([x for x in range(1, 500)])
+ res = do_query(milvus_instance, table_name, top_ks, nqs, nprobe, run_count)
+ status, res = milvus_instance.insert(insert_vectors, ids=[x for x in range(len(insert_vectors))])
+ if not status.OK():
+ logger.error(status.message)
+
+ # status = milvus_instance.drop_index()
+ if not status.OK():
+ print(status.message)
+ index_type = "ivf_sq8"
+ status = milvus_instance.create_index(index_type, 16384)
+ if not status.OK():
+ print(status.message)
\ No newline at end of file
diff --git a/tests/milvus_benchmark/docker_runner.py b/tests/milvus_benchmark/docker_runner.py
new file mode 100644
index 0000000000..008c9866b4
--- /dev/null
+++ b/tests/milvus_benchmark/docker_runner.py
@@ -0,0 +1,261 @@
+import os
+import logging
+import pdb
+import time
+import random
+from multiprocessing import Process
+import numpy as np
+from client import MilvusClient
+import utils
+import parser
+from runner import Runner
+
+logger = logging.getLogger("milvus_benchmark.docker")
+
+
+class DockerRunner(Runner):
+ """run docker mode"""
+ def __init__(self, image):
+ super(DockerRunner, self).__init__()
+ self.image = image
+
+ def run(self, definition, run_type=None):
+ if run_type == "performance":
+ for op_type, op_value in definition.items():
+ # run docker mode
+ run_count = op_value["run_count"]
+ run_params = op_value["params"]
+ container = None
+
+ if op_type == "insert":
+ for index, param in enumerate(run_params):
+ logger.info("Definition param: %s" % str(param))
+ table_name = param["table_name"]
+ volume_name = param["db_path_prefix"]
+ print(table_name)
+ (data_type, table_size, index_file_size, dimension, metric_type) = parser.table_parser(table_name)
+ for k, v in param.items():
+ if k.startswith("server."):
+ # Update server config
+ utils.modify_config(k, v, type="server", db_slave=None)
+ container = utils.run_server(self.image, test_type="remote", volume_name=volume_name, db_slave=None)
+ time.sleep(2)
+ milvus = MilvusClient(table_name)
+ # Check has table or not
+ if milvus.exists_table():
+ milvus.delete()
+ time.sleep(10)
+ milvus.create_table(table_name, dimension, index_file_size, metric_type)
+ res = self.do_insert(milvus, table_name, data_type, dimension, table_size, param["ni_per"])
+ logger.info(res)
+
+ # wait for file merge
+ time.sleep(6 * (table_size / 500000))
+ # Clear up
+ utils.remove_container(container)
+
+ elif op_type == "query":
+ for index, param in enumerate(run_params):
+ logger.info("Definition param: %s" % str(param))
+ table_name = param["dataset"]
+ volume_name = param["db_path_prefix"]
+ (data_type, table_size, index_file_size, dimension, metric_type) = parser.table_parser(table_name)
+ for k, v in param.items():
+ if k.startswith("server."):
+ utils.modify_config(k, v, type="server")
+ container = utils.run_server(self.image, test_type="remote", volume_name=volume_name, db_slave=None)
+ time.sleep(2)
+ milvus = MilvusClient(table_name)
+ logger.debug(milvus._milvus.show_tables())
+ # Check has table or not
+ if not milvus.exists_table():
+                        logger.warning("Table %s does not exist, skipping to the next params ..." % table_name)
+ continue
+ # parse index info
+ index_types = param["index.index_types"]
+ nlists = param["index.nlists"]
+ # parse top-k, nq, nprobe
+ top_ks, nqs, nprobes = parser.search_params_parser(param)
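+                    # for each (index_type, nlist): build the index, preload it, run a
+                    # warm-up query, then sweep nprobe and report the average query
+                    # latency for every (top-k, nq) combination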
+ for index_type in index_types:
+ for nlist in nlists:
+ result = milvus.describe_index()
+ logger.info(result)
+ milvus.create_index(index_type, nlist)
+ result = milvus.describe_index()
+ logger.info(result)
+ # preload index
+ milvus.preload_table()
+ logger.info("Start warm up query")
+ res = self.do_query(milvus, table_name, [1], [1], 1, 1)
+ logger.info("End warm up query")
+ # Run query test
+ for nprobe in nprobes:
+ logger.info("index_type: %s, nlist: %s, metric_type: %s, nprobe: %s" % (index_type, nlist, metric_type, nprobe))
+ res = self.do_query(milvus, table_name, top_ks, nqs, nprobe, run_count)
+ headers = ["Nprobe/Top-k"]
+ headers.extend([str(top_k) for top_k in top_ks])
+ utils.print_table(headers, nqs, res)
+ utils.remove_container(container)
+
+ elif run_type == "accuracy":
+ """
+ {
+ "dataset": "random_50m_1024_512",
+                "index.index_types": ["flat", "ivf_flat", "ivf_sq8"],
+ "index.nlists": [16384],
+ "nprobes": [1, 32, 128],
+ "nqs": [100],
+ "top_ks": [1, 64],
+ "server.use_blas_threshold": 1100,
+ "server.cpu_cache_capacity": 256
+ }
+ """
+ for op_type, op_value in definition.items():
+ if op_type != "query":
+                    logger.warning("invalid operation: %s in accuracy test, only the query operation is supported" % op_type)
+ break
+ run_count = op_value["run_count"]
+ run_params = op_value["params"]
+ container = None
+
+ for index, param in enumerate(run_params):
+ logger.info("Definition param: %s" % str(param))
+ table_name = param["dataset"]
+ sift_acc = False
+ if "sift_acc" in param:
+ sift_acc = param["sift_acc"]
+ (data_type, table_size, index_file_size, dimension, metric_type) = parser.table_parser(table_name)
+ for k, v in param.items():
+ if k.startswith("server."):
+ utils.modify_config(k, v, type="server")
+ volume_name = param["db_path_prefix"]
+ container = utils.run_server(self.image, test_type="remote", volume_name=volume_name, db_slave=None)
+ time.sleep(2)
+ milvus = MilvusClient(table_name)
+ # Check has table or not
+ if not milvus.exists_table():
+                    logger.warning("Table %s does not exist, skipping to the next params ..." % table_name)
+ continue
+
+ # parse index info
+ index_types = param["index.index_types"]
+ nlists = param["index.nlists"]
+ # parse top-k, nq, nprobe
+ top_ks, nqs, nprobes = parser.search_params_parser(param)
+
+ if sift_acc is True:
+ # preload groundtruth data
+ true_ids_all = self.get_groundtruth_ids(table_size)
+
+ acc_dict = {}
+ for index_type in index_types:
+ for nlist in nlists:
+ result = milvus.describe_index()
+ logger.info(result)
+ milvus.create_index(index_type, nlist)
+ # preload index
+ milvus.preload_table()
+ # Run query test
+ for nprobe in nprobes:
+ logger.info("index_type: %s, nlist: %s, metric_type: %s, nprobe: %s" % (index_type, nlist, metric_type, nprobe))
+ for top_k in top_ks:
+ for nq in nqs:
+ result_ids = []
+ id_prefix = "%s_index_%s_nlist_%s_metric_type_%s_nprobe_%s_top_k_%s_nq_%s" % \
+ (table_name, index_type, nlist, metric_type, nprobe, top_k, nq)
+ if sift_acc is False:
+ self.do_query_acc(milvus, table_name, top_k, nq, nprobe, id_prefix)
+ if index_type != "flat":
+ # Compute accuracy
+ base_name = "%s_index_flat_nlist_%s_metric_type_%s_nprobe_%s_top_k_%s_nq_%s" % \
+ (table_name, nlist, metric_type, nprobe, top_k, nq)
+ avg_acc = self.compute_accuracy(base_name, id_prefix)
+ logger.info("Query: <%s> accuracy: %s" % (id_prefix, avg_acc))
+ else:
+ result_ids = self.do_query_ids(milvus, table_name, top_k, nq, nprobe)
+ acc_value = self.get_recall_value(true_ids_all[:nq, :top_k].tolist(), result_ids)
+ logger.info("Query: <%s> accuracy: %s" % (id_prefix, acc_value))
+ # # print accuracy table
+ # headers = [table_name]
+ # headers.extend([str(top_k) for top_k in top_ks])
+ # utils.print_table(headers, nqs, res)
+
+ # remove container, and run next definition
+ logger.info("remove container, and run next definition")
+ utils.remove_container(container)
+
+ elif run_type == "stability":
+ for op_type, op_value in definition.items():
+ if op_type != "query":
+                    logger.warning("invalid operation: %s in stability test, only the query operation is supported" % op_type)
+ break
+ run_count = op_value["run_count"]
+ run_params = op_value["params"]
+ container = None
+ for index, param in enumerate(run_params):
+ logger.info("Definition param: %s" % str(param))
+ table_name = param["dataset"]
+ volume_name = param["db_path_prefix"]
+ (data_type, table_size, index_file_size, dimension, metric_type) = parser.table_parser(table_name)
+
+ # set default test time
+ if "during_time" not in param:
+ during_time = 100 # seconds
+ else:
+ during_time = int(param["during_time"]) * 60
+ # set default query process num
+ if "query_process_num" not in param:
+ query_process_num = 10
+ else:
+ query_process_num = int(param["query_process_num"])
+
+ for k, v in param.items():
+ if k.startswith("server."):
+ utils.modify_config(k, v, type="server")
+
+ container = utils.run_server(self.image, test_type="remote", volume_name=volume_name, db_slave=None)
+ time.sleep(2)
+ milvus = MilvusClient(table_name)
+ # Check has table or not
+ if not milvus.exists_table():
+                    logger.warning("Table %s does not exist, skipping to the next params ..." % table_name)
+ continue
+
+ start_time = time.time()
+ insert_vectors = [[random.random() for _ in range(dimension)] for _ in range(10000)]
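+                # mixed stability workload: keep issuing queries with random
+                # top-k/nq/nprobe and rebuilding the ivf_sq8 index, re-inserting the
+                # 10k random vectors roughly every two minutes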
+ while time.time() < start_time + during_time:
+ processes = []
+ # do query
+ # for i in range(query_process_num):
+ # milvus_instance = MilvusClient(table_name)
+ # top_k = random.choice([x for x in range(1, 100)])
+ # nq = random.choice([x for x in range(1, 100)])
+ # nprobe = random.choice([x for x in range(1, 1000)])
+ # # logger.info("index_type: %s, nlist: %s, metric_type: %s, nprobe: %s" % (index_type, nlist, metric_type, nprobe))
+ # p = Process(target=self.do_query, args=(milvus_instance, table_name, [top_k], [nq], [nprobe], run_count, ))
+ # processes.append(p)
+ # p.start()
+ # time.sleep(0.1)
+ # for p in processes:
+ # p.join()
+ milvus_instance = MilvusClient(table_name)
+ top_ks = random.sample([x for x in range(1, 100)], 3)
+ nqs = random.sample([x for x in range(1, 1000)], 3)
+ nprobe = random.choice([x for x in range(1, 500)])
+ res = self.do_query(milvus, table_name, top_ks, nqs, nprobe, run_count)
+ if int(time.time() - start_time) % 120 == 0:
+ status, res = milvus_instance.insert(insert_vectors, ids=[x for x in range(len(insert_vectors))])
+ if not status.OK():
+ logger.error(status)
+ # status = milvus_instance.drop_index()
+ # if not status.OK():
+ # logger.error(status)
+ # index_type = random.choice(["flat", "ivf_flat", "ivf_sq8"])
+ result = milvus.describe_index()
+ logger.info(result)
+ milvus_instance.create_index("ivf_sq8", 16384)
+ utils.remove_container(container)
+
+ else:
+ logger.warning("Run type: %s not supported" % run_type)
+
diff --git a/tests/milvus_benchmark/local_runner.py b/tests/milvus_benchmark/local_runner.py
new file mode 100644
index 0000000000..1067f500bc
--- /dev/null
+++ b/tests/milvus_benchmark/local_runner.py
@@ -0,0 +1,132 @@
+import os
+import logging
+import pdb
+import time
+import random
+from multiprocessing import Process
+import numpy as np
+from client import MilvusClient
+import utils
+import parser
+from runner import Runner
+
+logger = logging.getLogger("milvus_benchmark.local_runner")
+
+
+class LocalRunner(Runner):
+ """run local mode"""
+ def __init__(self, ip, port):
+ super(LocalRunner, self).__init__()
+ self.ip = ip
+ self.port = port
+
+ def run(self, definition, run_type=None):
+ if run_type == "performance":
+ for op_type, op_value in definition.items():
+ run_count = op_value["run_count"]
+ run_params = op_value["params"]
+
+ if op_type == "insert":
+ for index, param in enumerate(run_params):
+ table_name = param["table_name"]
+ # random_1m_100_512
+ (data_type, table_size, index_file_size, dimension, metric_type) = parser.table_parser(table_name)
+ milvus = MilvusClient(table_name, ip=self.ip, port=self.port)
+ # Check has table or not
+ if milvus.exists_table():
+ milvus.delete()
+ time.sleep(10)
+ milvus.create_table(table_name, dimension, index_file_size, metric_type)
+ res = self.do_insert(milvus, table_name, data_type, dimension, table_size, param["ni_per"])
+ logger.info(res)
+
+ elif op_type == "query":
+ for index, param in enumerate(run_params):
+ logger.info("Definition param: %s" % str(param))
+ table_name = param["dataset"]
+ (data_type, table_size, index_file_size, dimension, metric_type) = parser.table_parser(table_name)
+
+ milvus = MilvusClient(table_name, ip=self.ip, port=self.port)
+ # parse index info
+ index_types = param["index.index_types"]
+ nlists = param["index.nlists"]
+ # parse top-k, nq, nprobe
+ top_ks, nqs, nprobes = parser.search_params_parser(param)
+
+ for index_type in index_types:
+ for nlist in nlists:
+ milvus.create_index(index_type, nlist)
+ # preload index
+ milvus.preload_table()
+ # Run query test
+ for nprobe in nprobes:
+ logger.info("index_type: %s, nlist: %s, metric_type: %s, nprobe: %s" % (index_type, nlist, metric_type, nprobe))
+ res = self.do_query(milvus, table_name, top_ks, nqs, nprobe, run_count)
+ headers = [param["dataset"]]
+ headers.extend([str(top_k) for top_k in top_ks])
+ utils.print_table(headers, nqs, res)
+
+ elif run_type == "stability":
+ for op_type, op_value in definition.items():
+ if op_type != "query":
+                    logger.warning("invalid operation: %s in stability test, only the query operation is supported" % op_type)
+ break
+ run_count = op_value["run_count"]
+ run_params = op_value["params"]
+ nq = 10000
+
+ for index, param in enumerate(run_params):
+ logger.info("Definition param: %s" % str(param))
+ table_name = param["dataset"]
+ (data_type, table_size, index_file_size, dimension, metric_type) = parser.table_parser(table_name)
+
+ # set default test time
+ if "during_time" not in param:
+ during_time = 100 # seconds
+ else:
+ during_time = int(param["during_time"]) * 60
+ # set default query process num
+ if "query_process_num" not in param:
+ query_process_num = 10
+ else:
+ query_process_num = int(param["query_process_num"])
+ milvus = MilvusClient(table_name)
+ # Check has table or not
+ if not milvus.exists_table():
+                    logger.warning("Table %s does not exist, skipping to the next params ..." % table_name)
+ continue
+
+ start_time = time.time()
+ insert_vectors = [[random.random() for _ in range(dimension)] for _ in range(nq)]
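+                # stability loop: random queries mixed with inserts of the generated
+                # vectors; periodically the index is dropped and rebuilt with a
+                # randomly chosen index type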
+ while time.time() < start_time + during_time:
+ processes = []
+ # # do query
+ # for i in range(query_process_num):
+ # milvus_instance = MilvusClient(table_name)
+ # top_k = random.choice([x for x in range(1, 100)])
+ # nq = random.choice([x for x in range(1, 1000)])
+ # nprobe = random.choice([x for x in range(1, 500)])
+ # logger.info(nprobe)
+ # p = Process(target=self.do_query, args=(milvus_instance, table_name, [top_k], [nq], 64, run_count, ))
+ # processes.append(p)
+ # p.start()
+ # time.sleep(0.1)
+ # for p in processes:
+ # p.join()
+ milvus_instance = MilvusClient(table_name)
+ top_ks = random.sample([x for x in range(1, 100)], 4)
+ nqs = random.sample([x for x in range(1, 1000)], 3)
+ nprobe = random.choice([x for x in range(1, 500)])
+ res = self.do_query(milvus, table_name, top_ks, nqs, nprobe, run_count)
+ # milvus_instance = MilvusClient(table_name)
+ status, res = milvus_instance.insert(insert_vectors, ids=[x for x in range(len(insert_vectors))])
+ if not status.OK():
+ logger.error(status.message)
+ if (time.time() - start_time) % 300 == 0:
+ status = milvus_instance.drop_index()
+ if not status.OK():
+ logger.error(status.message)
+ index_type = random.choice(["flat", "ivf_flat", "ivf_sq8"])
+ status = milvus_instance.create_index(index_type, 16384)
+ if not status.OK():
+ logger.error(status.message)
diff --git a/tests/milvus_benchmark/main.py b/tests/milvus_benchmark/main.py
new file mode 100644
index 0000000000..c11237c5e7
--- /dev/null
+++ b/tests/milvus_benchmark/main.py
@@ -0,0 +1,131 @@
+import os
+import sys
+import time
+import pdb
+import argparse
+import logging
+import utils
+from yaml import load, dump
+from logging import handlers
+from parser import operations_parser
+from local_runner import LocalRunner
+from docker_runner import DockerRunner
+
+DEFAULT_IMAGE = "milvusdb/milvus:latest"
+LOG_FOLDER = "benchmark_logs"
+logger = logging.getLogger("milvus_benchmark")
+
+formatter = logging.Formatter('[%(asctime)s] [%(levelname)-4s] [%(pathname)s:%(lineno)d] %(message)s')
+if not os.path.exists(LOG_FOLDER):
+ os.system('mkdir -p %s' % LOG_FOLDER)
+fileTimeHandler = handlers.TimedRotatingFileHandler(os.path.join(LOG_FOLDER, 'milvus_benchmark'), "D", 1, 10)
+fileTimeHandler.suffix = "%Y%m%d.log"
+fileTimeHandler.setFormatter(formatter)
+logging.basicConfig(level=logging.DEBUG)
+logger.addHandler(fileTimeHandler)
+
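+# Example invocations (assuming the suite files shipped in this directory):
+#   python main.py --image milvusdb/milvus:latest --run-type performance --suites suites.yaml
+#   python main.py --local --ip 127.0.0.1 --port 19530 --run-type stability --suites suites_stability.yaml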
+
+def positive_int(s):
+ i = None
+ try:
+ i = int(s)
+ except ValueError:
+ pass
+ if not i or i < 1:
+ raise argparse.ArgumentTypeError("%r is not a positive integer" % s)
+ return i
+
+
+# # link random_data if not exists
+# def init_env():
+# if not os.path.islink(BINARY_DATA_FOLDER):
+# try:
+# os.symlink(SRC_BINARY_DATA_FOLDER, BINARY_DATA_FOLDER)
+# except Exception as e:
+# logger.error("Create link failed: %s" % str(e))
+# sys.exit()
+
+
+def main():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+ parser.add_argument(
+ '--image',
+ help='use the given image')
+ parser.add_argument(
+ '--local',
+ action='store_true',
+ help='use local milvus server')
+ parser.add_argument(
+ "--run-count",
+ default=1,
+ type=positive_int,
+ help="run each db operation times")
+ # performance / stability / accuracy test
+ parser.add_argument(
+ "--run-type",
+ default="performance",
+ help="run type, default performance")
+ parser.add_argument(
+ '--suites',
+ metavar='FILE',
+ help='load test suites from FILE',
+ default='suites.yaml')
+ parser.add_argument(
+ '--ip',
+ help='server ip param for local mode',
+ default='127.0.0.1')
+ parser.add_argument(
+ '--port',
+ help='server port param for local mode',
+ default='19530')
+
+ args = parser.parse_args()
+
+ operations = None
+ # Get all benchmark test suites
+ if args.suites:
+ with open(args.suites) as f:
+ suites_dict = load(f)
+ f.close()
+ # With definition order
+ operations = operations_parser(suites_dict, run_type=args.run_type)
+
+ # init_env()
+ run_params = {"run_count": args.run_count}
+
+ if args.image:
+ # for docker mode
+ if args.local:
+ logger.error("Local mode and docker mode are incompatible arguments")
+ sys.exit(-1)
+ # Docker pull image
+ if not utils.pull_image(args.image):
+            raise Exception('Image %s pull failed' % args.image)
+
+ # TODO: Check milvus server port is available
+ logger.info("Init: remove all containers created with image: %s" % args.image)
+ utils.remove_all_containers(args.image)
+ runner = DockerRunner(args.image)
+ for operation_type in operations:
+ logger.info("Start run test, test type: %s" % operation_type)
+ run_params["params"] = operations[operation_type]
+ runner.run({operation_type: run_params}, run_type=args.run_type)
+ logger.info("Run params: %s" % str(run_params))
+
+ if args.local:
+ # for local mode
+ ip = args.ip
+ port = args.port
+
+ runner = LocalRunner(ip, port)
+ for operation_type in operations:
+ logger.info("Start run local mode test, test type: %s" % operation_type)
+ run_params["params"] = operations[operation_type]
+ runner.run({operation_type: run_params}, run_type=args.run_type)
+ logger.info("Run params: %s" % str(run_params))
+
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/tests/milvus_benchmark/operation.py b/tests/milvus_benchmark/operation.py
new file mode 100644
index 0000000000..348fa47f4a
--- /dev/null
+++ b/tests/milvus_benchmark/operation.py
@@ -0,0 +1,10 @@
+from __future__ import absolute_import
+import pdb
+import time
+
+class Base(object):
+ pass
+
+
+class Insert(Base):
+ pass
\ No newline at end of file
diff --git a/tests/milvus_benchmark/parser.py b/tests/milvus_benchmark/parser.py
new file mode 100644
index 0000000000..1c1d2605f2
--- /dev/null
+++ b/tests/milvus_benchmark/parser.py
@@ -0,0 +1,66 @@
+import pdb
+import logging
+
+logger = logging.getLogger("milvus_benchmark.parser")
+
+
+def operations_parser(operations, run_type="performance"):
+ definitions = operations[run_type]
+ return definitions
+
+
+def table_parser(table_name):
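+    # expected name convention: <data_type>_<size><m|b>_<index_file_size>_<dimension>_<metric>,
+    # e.g. "sift_10m_1024_128_l2" -> ("sift", 10000000, 1024, 128, "l2")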
+ tmp = table_name.split("_")
+ # if len(tmp) != 5:
+ # return None
+ data_type = tmp[0]
+ table_size_unit = tmp[1][-1]
+ table_size = tmp[1][0:-1]
+ if table_size_unit == "m":
+ table_size = int(table_size) * 1000000
+ elif table_size_unit == "b":
+ table_size = int(table_size) * 1000000000
+ index_file_size = int(tmp[2])
+ dimension = int(tmp[3])
+ metric_type = str(tmp[4])
+ return (data_type, table_size, index_file_size, dimension, metric_type)
+
+
+def search_params_parser(param):
+ # parse top-k, set default value if top-k not in param
+ if "top_ks" not in param:
+ top_ks = [10]
+ else:
+ top_ks = param["top_ks"]
+ if isinstance(top_ks, int):
+ top_ks = [top_ks]
+ elif isinstance(top_ks, list):
+ top_ks = list(top_ks)
+ else:
+ logger.warning("Invalid format top-ks: %s" % str(top_ks))
+
+ # parse nqs, set default value if nq not in param
+ if "nqs" not in param:
+ nqs = [10]
+ else:
+ nqs = param["nqs"]
+ if isinstance(nqs, int):
+ nqs = [nqs]
+ elif isinstance(nqs, list):
+ nqs = list(nqs)
+ else:
+ logger.warning("Invalid format nqs: %s" % str(nqs))
+
+ # parse nprobes
+ if "nprobes" not in param:
+ nprobes = [1]
+ else:
+ nprobes = param["nprobes"]
+ if isinstance(nprobes, int):
+ nprobes = [nprobes]
+ elif isinstance(nprobes, list):
+ nprobes = list(nprobes)
+ else:
+ logger.warning("Invalid format nprobes: %s" % str(nprobes))
+
+ return top_ks, nqs, nprobes
\ No newline at end of file
diff --git a/tests/milvus_benchmark/report.py b/tests/milvus_benchmark/report.py
new file mode 100644
index 0000000000..6041311513
--- /dev/null
+++ b/tests/milvus_benchmark/report.py
@@ -0,0 +1,10 @@
+# from tablereport import Table
+# from tablereport.shortcut import write_to_excel
+
+# RESULT_FOLDER = "results"
+
+# def create_table(headers, bodys, table_name):
+# table = Table(header=[headers],
+# body=[bodys])
+
+# write_to_excel('%s/%s.xlsx' % (RESULT_FOLDER, table_name), table)
\ No newline at end of file
diff --git a/tests/milvus_benchmark/requirements.txt b/tests/milvus_benchmark/requirements.txt
new file mode 100644
index 0000000000..1285b4d2ba
--- /dev/null
+++ b/tests/milvus_benchmark/requirements.txt
@@ -0,0 +1,6 @@
+numpy==1.16.3
+pymilvus>=0.1.18
+pyyaml==3.12
+docker==4.0.2
+tableprint==0.8.0
+ansicolors==1.1.8
\ No newline at end of file
diff --git a/tests/milvus_benchmark/runner.py b/tests/milvus_benchmark/runner.py
new file mode 100644
index 0000000000..d3ad345da9
--- /dev/null
+++ b/tests/milvus_benchmark/runner.py
@@ -0,0 +1,219 @@
+import os
+import logging
+import pdb
+import time
+import random
+from multiprocessing import Process
+import numpy as np
+from client import MilvusClient
+import utils
+import parser
+
+logger = logging.getLogger("milvus_benchmark.runner")
+
+SERVER_HOST_DEFAULT = "127.0.0.1"
+SERVER_PORT_DEFAULT = 19530
+VECTORS_PER_FILE = 1000000
+SIFT_VECTORS_PER_FILE = 100000
+MAX_NQ = 10001
+FILE_PREFIX = "binary_"
+
+RANDOM_SRC_BINARY_DATA_DIR = '/tmp/random/binary_data'
+SIFT_SRC_DATA_DIR = '/tmp/sift1b/query'
+SIFT_SRC_BINARY_DATA_DIR = '/tmp/sift1b/binary_data'
+SIFT_SRC_GROUNDTRUTH_DATA_DIR = '/tmp/sift1b/groundtruth'
+
+WARM_TOP_K = 1
+WARM_NQ = 1
+DEFAULT_DIM = 512
+
+GROUNDTRUTH_MAP = {
+ "1000000": "idx_1M.ivecs",
+ "2000000": "idx_2M.ivecs",
+ "5000000": "idx_5M.ivecs",
+ "10000000": "idx_10M.ivecs",
+ "20000000": "idx_20M.ivecs",
+ "50000000": "idx_50M.ivecs",
+ "100000000": "idx_100M.ivecs",
+ "200000000": "idx_200M.ivecs",
+ "500000000": "idx_500M.ivecs",
+ "1000000000": "idx_1000M.ivecs",
+}
+
+
+def gen_file_name(idx, table_dimension, data_type):
+ s = "%05d" % idx
+ fname = FILE_PREFIX + str(table_dimension) + "d_" + s + ".npy"
+ if data_type == "random":
+ fname = RANDOM_SRC_BINARY_DATA_DIR+'/'+fname
+ elif data_type == "sift":
+ fname = SIFT_SRC_BINARY_DATA_DIR+'/'+fname
+ return fname
+
+
+def get_vectors_from_binary(nq, dimension, data_type):
+    # use the first file only; nq must not exceed MAX_NQ
+    if nq > MAX_NQ:
+        raise Exception("nq %d exceeds the maximum supported nq %d" % (nq, MAX_NQ))
+ if data_type == "random":
+ file_name = gen_file_name(0, dimension, data_type)
+ elif data_type == "sift":
+ file_name = SIFT_SRC_DATA_DIR+'/'+'query.npy'
+ data = np.load(file_name)
+ vectors = data[0:nq].tolist()
+ return vectors
+
+
+class Runner(object):
+ def __init__(self):
+ pass
+
+ def do_insert(self, milvus, table_name, data_type, dimension, size, ni):
+        '''
+        @params:
+            milvus: server connection instance
+            dimension: table dimension
+            # index_file_size: size that triggers file merge
+            size: total row count of vectors to be inserted
+            ni: row count of vectors inserted per batch
+            # store_id: whether to store the ids returned by add_vectors
+        @return:
+            total_time: total time for all insert operations
+            qps: vectors added per second
+            ni_time: average time per insert operation
+        '''
+ bi_res = {}
+ total_time = 0.0
+ qps = 0.0
+ ni_time = 0.0
+ if data_type == "random":
+ vectors_per_file = VECTORS_PER_FILE
+ elif data_type == "sift":
+ vectors_per_file = SIFT_VECTORS_PER_FILE
+ if size % vectors_per_file or ni > vectors_per_file:
+            raise Exception("Invalid table size or ni: size must be a multiple of the file size and ni must not exceed it")
+ file_num = size // vectors_per_file
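+        # insert file by file: each pre-generated .npy file holds vectors_per_file rows,
+        # split into batches of ni rows per insert call, with ids derived from the row offset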
+ for i in range(file_num):
+ file_name = gen_file_name(i, dimension, data_type)
+ logger.info("Load npy file: %s start" % file_name)
+ data = np.load(file_name)
+ logger.info("Load npy file: %s end" % file_name)
+ loops = vectors_per_file // ni
+ for j in range(loops):
+ vectors = data[j*ni:(j+1)*ni].tolist()
+ ni_start_time = time.time()
+ # start insert vectors
+ start_id = i * vectors_per_file + j * ni
+ end_id = start_id + len(vectors)
+ logger.info("Start id: %s, end id: %s" % (start_id, end_id))
+ ids = [k for k in range(start_id, end_id)]
+ status, ids = milvus.insert(vectors, ids=ids)
+ ni_end_time = time.time()
+ total_time = total_time + ni_end_time - ni_start_time
+
+ qps = round(size / total_time, 2)
+ ni_time = round(total_time / (loops * file_num), 2)
+ bi_res["total_time"] = round(total_time, 2)
+ bi_res["qps"] = qps
+ bi_res["ni_time"] = ni_time
+ return bi_res
+
+ def do_query(self, milvus, table_name, top_ks, nqs, nprobe, run_count):
+ (data_type, table_size, index_file_size, dimension, metric_type) = parser.table_parser(table_name)
+ base_query_vectors = get_vectors_from_binary(MAX_NQ, dimension, data_type)
+
+ bi_res = []
+ for index, nq in enumerate(nqs):
+ tmp_res = []
+ for top_k in top_ks:
+ avg_query_time = 0.0
+ total_query_time = 0.0
+ vectors = base_query_vectors[0:nq]
+                logger.info("Start query, query params: top-k: {}, nq: {}, actual length of vectors: {}".format(top_k, nq, len(vectors)))
+ for i in range(run_count):
+ logger.info("Start run query, run %d of %s" % (i+1, run_count))
+ start_time = time.time()
+ status, query_res = milvus.query(vectors, top_k, nprobe)
+ total_query_time = total_query_time + (time.time() - start_time)
+ if status.code:
+ logger.error("Query failed with message: %s" % status.message)
+ avg_query_time = round(total_query_time / run_count, 2)
+                logger.info("Average query time: %.2f" % avg_query_time)
+ tmp_res.append(avg_query_time)
+ bi_res.append(tmp_res)
+ return bi_res
+
+ def do_query_ids(self, milvus, table_name, top_k, nq, nprobe):
+ (data_type, table_size, index_file_size, dimension, metric_type) = parser.table_parser(table_name)
+ base_query_vectors = get_vectors_from_binary(MAX_NQ, dimension, data_type)
+ vectors = base_query_vectors[0:nq]
+        logger.info("Start query, query params: top-k: {}, nq: {}, actual length of vectors: {}".format(top_k, nq, len(vectors)))
+ status, query_res = milvus.query(vectors, top_k, nprobe)
+ if not status.OK():
+ msg = "Query failed with message: %s" % status.message
+ raise Exception(msg)
+ result_ids = []
+ for result in query_res:
+ tmp = []
+ for item in result:
+ tmp.append(item.id)
+ result_ids.append(tmp)
+ return result_ids
+
+ def do_query_acc(self, milvus, table_name, top_k, nq, nprobe, id_store_name):
+ (data_type, table_size, index_file_size, dimension, metric_type) = parser.table_parser(table_name)
+ base_query_vectors = get_vectors_from_binary(MAX_NQ, dimension, data_type)
+ vectors = base_query_vectors[0:nq]
+        logger.info("Start query, query params: top-k: {}, nq: {}, actual length of vectors: {}".format(top_k, nq, len(vectors)))
+ status, query_res = milvus.query(vectors, top_k, nprobe)
+ if not status.OK():
+ msg = "Query failed with message: %s" % status.message
+ raise Exception(msg)
+        # if the file already exists, overwrite it
+ if os.path.isfile(id_store_name):
+ os.remove(id_store_name)
+ with open(id_store_name, 'a+') as fd:
+ for nq_item in query_res:
+ for item in nq_item:
+ fd.write(str(item.id)+'\t')
+ fd.write('\n')
+
+ # compute and print accuracy
+ def compute_accuracy(self, flat_file_name, index_file_name):
+ flat_id_list = []; index_id_list = []
+ logger.info("Loading flat id file: %s" % flat_file_name)
+ with open(flat_file_name, 'r') as flat_id_fd:
+ for line in flat_id_fd:
+ tmp_list = line.strip("\n").strip().split("\t")
+ flat_id_list.append(tmp_list)
+ logger.info("Loading index id file: %s" % index_file_name)
+ with open(index_file_name) as index_id_fd:
+ for line in index_id_fd:
+ tmp_list = line.strip("\n").strip().split("\t")
+ index_id_list.append(tmp_list)
+ if len(flat_id_list) != len(index_id_list):
+            raise Exception("Flat result length %d and index result length %d do not match, accuracy compute exiting ..." % (len(flat_id_list), len(index_id_list)))
+ # get the accuracy
+ return self.get_recall_value(flat_id_list, index_id_list)
+
+ def get_recall_value(self, flat_id_list, index_id_list):
+ """
+ Use the intersection length
+ """
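+        # recall per query = |index result ids ∩ flat (ground truth) ids| / top_k;
+        # return the mean over all queries, rounded to 3 decimals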
+        sum_ratio = 0.0
+        for index, item in enumerate(index_id_list):
+            tmp = set(item).intersection(set(flat_id_list[index]))
+            sum_ratio = sum_ratio + len(tmp) / len(item)
+        return round(sum_ratio / len(index_id_list), 3)
+
+ """
+ Implementation based on:
+ https://github.com/facebookresearch/faiss/blob/master/benchs/datasets.py
+ """
+ def get_groundtruth_ids(self, table_size):
+ fname = GROUNDTRUTH_MAP[str(table_size)]
+ fname = SIFT_SRC_GROUNDTRUTH_DATA_DIR + "/" + fname
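+        # .ivecs layout: every record is an int32 count d followed by d int32 neighbor ids;
+        # reshape to (-1, d + 1) and drop the first column to keep only the ids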
+ a = np.fromfile(fname, dtype='int32')
+ d = a[0]
+ true_ids = a.reshape(-1, d + 1)[:, 1:].copy()
+ return true_ids
diff --git a/tests/milvus_benchmark/suites.yaml b/tests/milvus_benchmark/suites.yaml
new file mode 100644
index 0000000000..f30b963c03
--- /dev/null
+++ b/tests/milvus_benchmark/suites.yaml
@@ -0,0 +1,38 @@
+# data sets
+datasets:
+  hdf5:
+ gist-960,sift-128
+ npy:
+ 50000000-512, 100000000-512
+
+operations:
+ # interface: search_vectors
+ query:
+ # dataset: table name you have already created
+    # keys starting with "server." require reconfiguring and restarting the server, e.g. nprobe/nlist/use_blas_threshold/...
+ [
+ # debug
+ # {"dataset": "ip_ivfsq8_1000", "top_ks": [16], "nqs": [1], "server.nprobe": 1, "server.use_blas_threshold": 800, "server.cpu_cache_capacity": 110},
+
+ {"dataset": "ip_ivfsq8_1000", "top_ks": [1, 2, 4, 8, 16, 32, 64, 128, 256, 512], "nqs": [1, 10, 100, 500, 800, 1000], "server.nprobe": 1, "server.use_blas_threshold": 800, "server.cpu_cache_capacity": 110},
+ {"dataset": "ip_ivfsq8_1000", "top_ks": [1, 2, 4, 8, 16, 32, 64, 128, 256, 512], "nqs": [1, 10, 100, 500, 800, 1000], "server.nprobe": 10, "server.use_blas_threshold": 20, "server.cpu_cache_capacity": 110},
+ {"dataset": "ip_ivfsq8_5000", "top_ks": [1, 2, 4, 8, 16, 32, 64, 128, 256, 512], "nqs": [1, 10, 100, 500, 800, 1000], "server.nprobe": 1, "server.use_blas_threshold": 800, "server.cpu_cache_capacity": 110},
+ {"dataset": "ip_ivfsq8_5000", "top_ks": [1, 2, 4, 8, 16, 32, 64, 128, 256, 512], "nqs": [1, 10, 100, 500, 800, 1000], "server.nprobe": 10, "server.use_blas_threshold": 20, "server.cpu_cache_capacity": 110},
+ {"dataset": "ip_ivfsq8_40000", "top_ks": [1, 2, 4, 8, 16, 32, 64, 128, 256, 512], "nqs": [1, 10, 100, 500, 800, 1000], "server.nprobe": 1, "server.use_blas_threshold": 800, "server.cpu_cache_capacity": 110},
+ # {"dataset": "ip_ivfsq8_40000", "top_ks": [1, 2, 4, 8, 16, 32, 64, 128, 256], "nqs": [1, 10, 100, 1000], "server.nprobe": 10, "server.use_blas_threshold": 20, "server.cpu_cache_capacity": 110},
+ ]
+
+ # interface: add_vectors
+ insert:
+ # index_type: flat/ivf_flat/ivf_sq8
+ [
+ # debug
+
+
+ {"table_name": "ip_ivf_flat_20m_1024", "table.index_type": "ivf_flat", "server.index_building_threshold": 1024, "table.size": 20000000, "table.ni": 100000, "table.dim": 512, "server.cpu_cache_capacity": 110},
+ {"table_name": "ip_ivf_sq8_50m_1024", "table.index_type": "ivf_sq8", "server.index_building_threshold": 1024, "table.size": 50000000, "table.ni": 100000, "table.dim": 512, "server.cpu_cache_capacity": 110},
+ ]
+
+ # TODO: interface: build_index
+ build: []
+
diff --git a/tests/milvus_benchmark/suites_accuracy.yaml b/tests/milvus_benchmark/suites_accuracy.yaml
new file mode 100644
index 0000000000..fd8a5904fa
--- /dev/null
+++ b/tests/milvus_benchmark/suites_accuracy.yaml
@@ -0,0 +1,121 @@
+
+accuracy:
+ # interface: search_vectors
+ query:
+ [
+ {
+ "dataset": "random_20m_1024_512_ip",
+ # index info
+ "index.index_types": ["flat", "ivf_sq8"],
+ "index.nlists": [16384],
+ "index.metric_types": ["ip"],
+ "nprobes": [1, 16, 64],
+ "top_ks": [64],
+ "nqs": [100],
+ "server.cpu_cache_capacity": 100,
+ "server.resources": ["cpu", "gpu0"],
+ "db_path_prefix": "/test/milvus/db_data/random_20m_1024_512_ip",
+ },
+ # {
+ # "dataset": "sift_50m_1024_128_l2",
+ # # index info
+ # "index.index_types": ["ivf_sq8h"],
+ # "index.nlists": [16384],
+ # "index.metric_types": ["l2"],
+ # "nprobes": [1, 16, 64],
+ # "top_ks": [64],
+ # "nqs": [100],
+ # "server.cpu_cache_capacity": 160,
+ # "server.resources": ["cpu", "gpu0"],
+ # "db_path_prefix": "/test/milvus/db_data/sift_50m_1024_128_l2",
+ # "sift_acc": true
+ # },
+ # {
+ # "dataset": "sift_50m_1024_128_l2",
+ # # index info
+ # "index.index_types": ["ivf_sq8"],
+ # "index.nlists": [16384],
+ # "index.metric_types": ["l2"],
+ # "nprobes": [1, 16, 64],
+ # "top_ks": [64],
+ # "nqs": [100],
+ # "server.cpu_cache_capacity": 160,
+ # "server.resources": ["cpu", "gpu0"],
+ # "db_path_prefix": "/test/milvus/db_data/sift_50m_1024_128_l2_sq8",
+ # "sift_acc": true
+ # },
+ # {
+ # "dataset": "sift_1b_2048_128_l2",
+ # # index info
+ # "index.index_types": ["ivf_sq8h"],
+ # "index.nlists": [16384],
+ # "index.metric_types": ["l2"],
+ # "nprobes": [1, 16, 64, 128],
+ # "top_ks": [64],
+ # "nqs": [100],
+ # "server.cpu_cache_capacity": 200,
+ # "server.resources": ["cpu"],
+ # "db_path_prefix": "/test/milvus/db_data/sift_1b_2048_128_l2_sq8h",
+ # "sift_acc": true
+ # },
+ # {
+ # "dataset": "sift_1b_2048_128_l2",
+ # # index info
+ # "index.index_types": ["ivf_sq8h"],
+ # "index.nlists": [16384],
+ # "index.metric_types": ["l2"],
+ # "nprobes": [1, 16, 64, 128],
+ # "top_ks": [64],
+ # "nqs": [100],
+ # "server.cpu_cache_capacity": 200,
+ # "server.resources": ["cpu", "gpu0"],
+ # "db_path_prefix": "/test/milvus/db_data/sift_1b_2048_128_l2_sq8h",
+ # "sift_acc": true
+ # },
+ # {
+ # "dataset": "sift_1b_2048_128_l2",
+ # # index info
+ # "index.index_types": ["ivf_sq8h"],
+ # "index.nlists": [16384],
+ # "index.metric_types": ["l2"],
+ # "nprobes": [1, 16, 64, 128],
+ # "top_ks": [64],
+ # "nqs": [100],
+ # "server.cpu_cache_capacity": 200,
+ # "server.resources": ["cpu", "gpu0", "gpu1"],
+ # "db_path_prefix": "/test/milvus/db_data/sift_1b_2048_128_l2_sq8h",
+ # "sift_acc": true
+ # },
+ # {
+ # "dataset": "sift_1m_1024_128_l2",
+ # "index.index_types": ["flat", "ivf_sq8"],
+ # "index.nlists": [16384],
+ # "nprobes": [1, 32, 128, 256, 512],
+ # "nqs": 10,
+ # "top_ks": 10,
+ # "server.use_blas_threshold": 1100,
+ # "server.cpu_cache_capacity": 16,
+ # },
+ # {
+ # "dataset": "sift_10m_1024_128_l2",
+ # "index.index_types": ["flat", "ivf_sq8"],
+ # "index.nlists": [16384],
+ # "nprobes": [1, 32, 128, 256, 512],
+ # "nqs": 10,
+ # "top_ks": 10,
+ # "server.use_blas_threshold": 1100,
+ # "server.cpu_cache_capacity": 32,
+ # },
+ # {
+ # "dataset": "sift_50m_1024_128_l2",
+ # "index.index_types": ["flat", "ivf_sq8"],
+ # "index.nlists": [16384],
+ # "nprobes": [1, 32, 128, 256, 512],
+ # "nqs": 10,
+ # "top_ks": 10,
+ # "server.use_blas_threshold": 1100,
+ # "server.cpu_cache_capacity": 64,
+ # }
+
+
+ ]
\ No newline at end of file
diff --git a/tests/milvus_benchmark/suites_performance.yaml b/tests/milvus_benchmark/suites_performance.yaml
new file mode 100644
index 0000000000..52d457a400
--- /dev/null
+++ b/tests/milvus_benchmark/suites_performance.yaml
@@ -0,0 +1,258 @@
+performance:
+
+ # interface: add_vectors
+ insert:
+ # index_type: flat/ivf_flat/ivf_sq8/mix_nsg
+ [
+ # debug
+ # data_type / data_size / index_file_size / dimension
+ # data_type: random / ann_sift
+ # data_size: 10m / 1b
+ # {
+ # "table_name": "random_50m_1024_512_ip",
+ # "ni_per": 100000,
+ # "processes": 5, # multiprocessing
+ # "server.cpu_cache_capacity": 16,
+ # # "server.resources": ["gpu0", "gpu1"],
+ # "db_path_prefix": "/test/milvus/db_data"
+ # },
+ # {
+ # "table_name": "random_5m_1024_512_ip",
+ # "ni_per": 100000,
+ # "processes": 5, # multiprocessing
+ # "server.cpu_cache_capacity": 16,
+ # "server.resources": ["gpu0", "gpu1"],
+ # "db_path_prefix": "/test/milvus/db_data/random_5m_1024_512_ip"
+ # },
+ # {
+ # "table_name": "sift_1m_50_128_l2",
+ # "ni_per": 100000,
+ # "processes": 5, # multiprocessing
+ # # "server.cpu_cache_capacity": 16,
+ # "db_path_prefix": "/test/milvus/db_data"
+ # },
+ # {
+ # "table_name": "sift_1m_256_128_l2",
+ # "ni_per": 100000,
+ # "processes": 5, # multiprocessing
+ # # "server.cpu_cache_capacity": 16,
+ # "db_path_prefix": "/test/milvus/db_data"
+ # }
+ # {
+ # "table_name": "sift_50m_1024_128_l2",
+ # "ni_per": 100000,
+ # "processes": 5, # multiprocessing
+ # # "server.cpu_cache_capacity": 16,
+ # },
+ # {
+ # "table_name": "sift_100m_1024_128_l2",
+ # "ni_per": 100000,
+ # "processes": 5, # multiprocessing
+ # },
+ # {
+ # "table_name": "sift_1b_2048_128_l2",
+ # "ni_per": 100000,
+ # "processes": 5, # multiprocessing
+ # "server.cpu_cache_capacity": 16,
+ # }
+ ]
+
+ # interface: search_vectors
+ query:
+ # dataset: table name you have already created
+    # keys starting with "server." require reconfiguring and restarting the server, e.g. use_blas_threshold/cpu_cache_capacity ...
+ [
+ # {
+ # "dataset": "sift_1b_2048_128_l2",
+ # # index info
+ # "index.index_types": ["ivf_sq8h"],
+ # "index.nlists": [16384],
+ # "nprobes": [8, 32],
+ # "top_ks": [1, 8, 16, 32, 64, 128, 256, 512, 1000],
+ # "nqs": [1, 10, 100, 500, 1000],
+ # "processes": 1, # multiprocessing
+ # "server.use_blas_threshold": 1100,
+ # "server.cpu_cache_capacity": 200,
+ # "server.resources": ["cpu", "gpu0"],
+ # "db_path_prefix": "/test/milvus/db_data/sift_1b_2048_128_l2_sq8h"
+ # },
+ # {
+ # "dataset": "sift_1b_2048_128_l2",
+ # # index info
+ # "index.index_types": ["ivf_sq8"],
+ # "index.nlists": [16384],
+ # "nprobes": [8, 32],
+ # "top_ks": [1, 8, 16, 32, 64, 128, 256, 512, 1000],
+ # "nqs": [1, 10, 100, 500, 1000],
+ # "processes": 1, # multiprocessing
+ # "server.use_blas_threshold": 1100,
+ # "server.cpu_cache_capacity": 200,
+ # "server.resources": ["cpu", "gpu0"],
+ # "db_path_prefix": "/test/milvus/db_data/sift_1b_2048_128_l2"
+ # },
+ # {
+ # "dataset": "sift_1b_2048_128_l2",
+ # # index info
+ # "index.index_types": ["ivf_sq8h"],
+ # "index.nlists": [16384],
+ # "nprobes": [8, 32],
+ # "top_ks": [1, 8, 16, 32, 64, 128, 256, 512, 1000],
+ # "nqs": [1, 10, 100, 500, 1000],
+ # "processes": 1, # multiprocessing
+ # "server.use_blas_threshold": 1100,
+ # "server.cpu_cache_capacity": 200,
+ # "server.resources": ["cpu"],
+ # "db_path_prefix": "/test/milvus/db_data"
+ # },
+ {
+ "dataset": "random_50m_1024_512_ip",
+ "index.index_types": ["ivf_sq8h"],
+ "index.nlists": [16384],
+ "nprobes": [8],
+ # "top_ks": [1, 8, 16, 32, 64, 128, 256, 512, 1000],
+ "top_ks": [512],
+ # "nqs": [1, 10, 100, 500, 1000],
+ "nqs": [500],
+ "server.use_blas_threshold": 1100,
+ "server.cpu_cache_capacity": 150,
+ "server.gpu_cache_capacity": 6,
+ "server.resources": ["cpu", "gpu0", "gpu1"],
+ "db_path_prefix": "/test/milvus/db_data/random_50m_1024_512_ip"
+ },
+ # {
+ # "dataset": "random_50m_1024_512_ip",
+ # "index.index_types": ["ivf_sq8"],
+ # "index.nlists": [16384],
+ # "nprobes": [8, 32],
+ # "top_ks": [1, 8, 16, 32, 64, 128, 256, 512, 1000],
+ # "nqs": [1, 10, 100, 500, 1000],
+ # "server.use_blas_threshold": 1100,
+ # "server.cpu_cache_capacity": 150,
+ # "server.resources": ["cpu", "gpu0", "gpu1"],
+ # "db_path_prefix": "/test/milvus/db_data/random_50m_1024_512_ip_sq8"
+ # },
+ # {
+ # "dataset": "random_20m_1024_512_ip",
+ # "index.index_types": ["flat"],
+ # "index.nlists": [16384],
+ # "nprobes": [50],
+ # "top_ks": [64],
+ # "nqs": [10],
+ # "server.use_blas_threshold": 1100,
+ # "server.cpu_cache_capacity": 100,
+ # "server.resources": ["cpu", "gpu0", "gpu1"],
+ # "db_path_prefix": "/test/milvus/db_data/random_20m_1024_512_ip"
+ # },
+ # {
+ # "dataset": "random_100m_1024_512_ip",
+ # # index info
+ # "index.index_types": ["ivf_sq8"],
+ # "index.nlists": [16384],
+ # "nprobes": [8, 32],
+ # "top_ks": [1, 8, 16, 32, 64, 128, 256, 512, 1000],
+ # "nqs": [1, 10, 100, 500, 1000],
+ # "processes": 1, # multiprocessing
+ # "server.use_blas_threshold": 1100,
+ # "server.cpu_cache_capacity": 250,
+ # "server.resources": ["cpu", "gpu0"],
+ # "db_path_prefix": "/test/milvus/db_data"
+ # },
+ # {
+ # "dataset": "random_100m_1024_512_ip",
+ # # index info
+ # "index.index_types": ["ivf_sq8"],
+ # "index.nlists": [16384],
+ # "nprobes": [8, 32],
+ # "top_ks": [1, 8, 16, 32, 64, 128, 256, 512, 1000],
+ # "nqs": [1, 10, 100, 500, 1000],
+ # "processes": 1, # multiprocessing
+ # "server.use_blas_threshold": 1100,
+ # "server.cpu_cache_capacity": 250,
+ # "server.resources": ["cpu"],
+ # "db_path_prefix": "/test/milvus/db_data"
+ # },
+ # {
+ # "dataset": "random_10m_1024_512_ip",
+ # # index info
+ # "index.index_types": ["ivf_sq8"],
+ # "index.nlists": [16384],
+ # "nprobes": [1],
+ # "top_ks": [1, 2, 4, 8, 16, 32, 64, 128, 256],
+ # "nqs": [1, 10, 100, 500, 800],
+ # "processes": 1, # multiprocessing
+ # # "server.use_blas_threshold": 1100,
+ # "server.cpu_cache_capacity": 16,
+ # },
+ # {
+ # "dataset": "random_10m_1024_512_l2",
+ # # index info
+ # "index.index_types": ["ivf_sq8"],
+ # "index.nlists": [16384],
+ # "nprobes": [1],
+ # "top_ks": [1, 2, 4, 8, 16, 32, 64, 128, 256],
+ # "nqs": [1, 10, 100, 500, 800],
+ # "processes": 1, # multiprocessing
+ # "server.use_blas_threshold": 1100,
+ # "server.cpu_cache_capacity": 64
+ # },
+ # {
+ # "dataset": "sift_500m_1024_128_l2",
+ # # index info
+ # "index.index_types": ["ivf_sq8"],
+ # "index.nlists": [16384],
+ # "nprobes": [1],
+ # "top_ks": [1, 8, 16, 64, 256, 512, 1000],
+ # "nqs": [1, 100, 500, 800, 1000, 1500],
+ # # "top_ks": [256],
+ # # "nqs": [800],
+ # "processes": 1, # multiprocessing
+ # # "server.use_blas_threshold": 1100,
+ # "server.cpu_cache_capacity": 120,
+ # "server.resources": ["gpu0", "gpu1"],
+ # "db_path_prefix": "/test/milvus/db_data"
+ # },
+ # {
+ # "dataset": "sift_1b_2048_128_l2",
+ # # index info
+ # "index.index_types": ["ivf_sq8h"],
+ # "index.nlists": [16384],
+ # "nprobes": [1],
+ # # "top_ks": [1],
+ # # "nqs": [1],
+ # "top_ks": [256],
+ # "nqs": [800],
+ # "processes": 1, # multiprocessing
+ # "server.use_blas_threshold": 1100,
+ # "server.cpu_cache_capacity": 110,
+ # "server.resources": ["cpu", "gpu0"],
+ # "db_path_prefix": "/test/milvus/db_data"
+ # },
+ # {
+ # "dataset": "random_50m_1024_512_l2",
+ # # index info
+ # "index.index_types": ["ivf_sq8"],
+ # "index.nlists": [16384],
+ # "nprobes": [1],
+ # "top_ks": [1, 2, 4, 8, 16, 32, 64, 128, 256],
+ # "nqs": [1, 10, 100, 500, 800],
+ # # "top_ks": [256],
+ # # "nqs": [800],
+ # "processes": 1, # multiprocessing
+ # "server.use_blas_threshold": 1100,
+ # "server.cpu_cache_capacity": 128
+ # },
+ # [
+ # {
+ # "dataset": "sift_1m_50_128_l2",
+ # # index info
+ # "index.index_types": ["ivf_sq8"],
+ # "index.nlists": [16384],
+ # "nprobes": [1],
+ # "top_ks": [1],
+ # "nqs": [1],
+ # "db_path_prefix": "/test/milvus/db_data"
+ # # "processes": 1, # multiprocessing
+ # # "server.use_blas_threshold": 1100,
+ # # "server.cpu_cache_capacity": 256
+ # }
+ ]
\ No newline at end of file
diff --git a/tests/milvus_benchmark/suites_stability.yaml b/tests/milvus_benchmark/suites_stability.yaml
new file mode 100644
index 0000000000..408221079e
--- /dev/null
+++ b/tests/milvus_benchmark/suites_stability.yaml
@@ -0,0 +1,17 @@
+
+stability:
+ # interface: search_vectors / add_vectors mix operation
+ query:
+ [
+ {
+ "dataset": "random_20m_1024_512_ip",
+ # "nqs": [1, 10, 100, 1000, 10000],
+ # "pds": [0.1, 0.44, 0.44, 0.02],
+ "query_process_num": 10,
+ # each 10s, do an insertion
+ # "insert_interval": 1,
+ # minutes
+ "during_time": 360,
+ "server.cpu_cache_capacity": 100
+ },
+ ]
\ No newline at end of file
diff --git a/tests/milvus_benchmark/suites_yzb.yaml b/tests/milvus_benchmark/suites_yzb.yaml
new file mode 100644
index 0000000000..59efcb37e4
--- /dev/null
+++ b/tests/milvus_benchmark/suites_yzb.yaml
@@ -0,0 +1,171 @@
+#"server.resources": ["gpu0", "gpu1"]
+
+performance:
+ # interface: search_vectors
+ query:
+ # dataset: table name you have already created
+    # keys starting with "server." require reconfiguring and restarting the server, e.g. use_blas_threshold/cpu_cache_capacity ...
+ [
+ # debug
+ # {
+ # "dataset": "random_10m_1024_512_ip",
+ # # index info
+ # "index.index_types": ["ivf_sq8"],
+ # "index.nlists": [16384],
+ # "nprobes": [1],
+ # "top_ks": [1, 2, 4, 8, 16, 32, 64, 128, 256],
+ # "nqs": [1, 10, 100, 500, 800],
+ # "processes": 1, # multiprocessing
+ # "server.use_blas_threshold": 1100,
+ # "server.cpu_cache_capacity": 16,
+ # },
+ # {
+ # "dataset": "random_10m_1024_512_ip",
+ # # index info
+ # "index.index_types": ["ivf_sq8"],
+ # "index.nlists": [16384],
+ # "nprobes": [1],
+ # "top_ks": [1, 2, 4, 8, 16, 32, 64, 128, 256],
+ # "nqs": [1, 10, 100, 500, 800],
+ # "processes": 1, # multiprocessing
+ # "server.use_blas_threshold": 1100,
+ # "server.cpu_cache_capacity": 16,
+ # },
+ # {
+ # "dataset": "random_10m_1024_512_ip",
+ # # index info
+ # "index.index_types": ["ivf_sq8"],
+ # "index.nlists": [16384],
+ # "nprobes": [1],
+ # "top_ks": [1, 2, 4, 8, 16, 32, 64, 128, 256],
+ # "nqs": [1, 10, 100, 500, 800],
+ # "processes": 1, # multiprocessing
+ # # "server.use_blas_threshold": 1100,
+ # "server.cpu_cache_capacity": 16,
+ # },
+ # {
+ # "dataset": "random_10m_1024_512_ip",
+ # # index info
+ # "index.index_types": ["ivf_sq8"],
+ # "index.nlists": [16384],
+ # "nprobes": [1],
+ # "top_ks": [1, 2, 4, 8, 16, 32, 64, 128, 256],
+ # "nqs": [1, 10, 100, 500, 800],
+ # "processes": 1, # multiprocessing
+ # # "server.use_blas_threshold": 1100,
+ # "server.cpu_cache_capacity": 16,
+ # },
+ # {
+ # "dataset": "random_10m_1024_512_l2",
+ # # index info
+ # "index.index_types": ["ivf_sq8"],
+ # "index.nlists": [16384],
+ # "nprobes": [1],
+ # "top_ks": [1, 2, 4, 8, 16, 32, 64, 128, 256],
+ # "nqs": [1, 10, 100, 500, 800],
+ # "processes": 1, # multiprocessing
+ # "server.use_blas_threshold": 1100,
+ # "server.cpu_cache_capacity": 64
+ # },
+# {
+# "dataset": "sift_50m_1024_128_l2",
+# # index info
+# "index.index_types": ["ivf_sq8"],
+# "index.nlists": [16384],
+# "nprobes": [1, 32, 128],
+# "top_ks": [1, 2, 4, 8, 16, 32, 64, 128, 256],
+# "nqs": [1, 10, 100, 500, 800],
+# # "top_ks": [256],
+# # "nqs": [800],
+# "processes": 1, # multiprocessing
+# "server.use_blas_threshold": 1100,
+# "server.cpu_cache_capacity": 310,
+# "server.resources": ["gpu0", "gpu1"]
+# },
+ {
+ "dataset": "sift_1m_1024_128_l2",
+ # index info
+ "index.index_types": ["ivf_sq8"],
+ "index.nlists": [16384],
+ "nprobes": [32],
+ "top_ks": [10],
+ "nqs": [100],
+ # "top_ks": [256],
+ # "nqs": [800],
+ "processes": 1, # multiprocessing
+ "server.use_blas_threshold": 1100,
+ "server.cpu_cache_capacity": 310,
+ "server.resources": ["cpu"]
+ },
+ {
+ "dataset": "sift_1m_1024_128_l2",
+ # index info
+ "index.index_types": ["ivf_sq8"],
+ "index.nlists": [16384],
+ "nprobes": [32],
+ "top_ks": [10],
+ "nqs": [100],
+ # "top_ks": [256],
+ # "nqs": [800],
+ "processes": 1, # multiprocessing
+ "server.use_blas_threshold": 1100,
+ "server.cpu_cache_capacity": 310,
+ "server.resources": ["gpu0"]
+ },
+ {
+ "dataset": "sift_1m_1024_128_l2",
+ # index info
+ "index.index_types": ["ivf_sq8"],
+ "index.nlists": [16384],
+ "nprobes": [32],
+ "top_ks": [10],
+ "nqs": [100],
+ # "top_ks": [256],
+ # "nqs": [800],
+ "processes": 1, # multiprocessing
+ "server.use_blas_threshold": 1100,
+ "server.cpu_cache_capacity": 310,
+ "server.resources": ["gpu0", "gpu1"]
+ },
+ # {
+ # "dataset": "sift_1b_2048_128_l2",
+ # # index info
+ # "index.index_types": ["ivf_sq8"],
+ # "index.nlists": [16384],
+ # "nprobes": [1],
+ # "top_ks": [1, 2, 4, 8, 16, 32, 64, 128, 256],
+ # "nqs": [1, 10, 100, 500, 800],
+ # # "top_ks": [256],
+ # # "nqs": [800],
+ # "processes": 1, # multiprocessing
+ # "server.use_blas_threshold": 1100,
+ # "server.cpu_cache_capacity": 310
+ # },
+# {
+# "dataset": "random_50m_1024_512_l2",
+# # index info
+# "index.index_types": ["ivf_sq8"],
+# "index.nlists": [16384],
+# "nprobes": [1],
+# "top_ks": [1, 2, 4, 8, 16, 32, 64, 128, 256],
+# "nqs": [1, 10, 100, 500, 800],
+# # "top_ks": [256],
+# # "nqs": [800],
+# "processes": 1, # multiprocessing
+# "server.use_blas_threshold": 1100,
+# "server.cpu_cache_capacity": 128,
+# "server.resources": ["gpu0", "gpu1"]
+# },
+ # {
+ # "dataset": "random_100m_1024_512_ip",
+ # # index info
+ # "index.index_types": ["ivf_sq8"],
+ # "index.nlists": [16384],
+ # "nprobes": [1],
+ # "top_ks": [1, 2, 4, 8, 16, 32, 64, 128, 256],
+ # "nqs": [1, 10, 100, 500, 800],
+ # "processes": 1, # multiprocessing
+ # "server.use_blas_threshold": 1100,
+ # "server.cpu_cache_capacity": 256
+ # },
+ ]
\ No newline at end of file
diff --git a/tests/milvus_benchmark/utils.py b/tests/milvus_benchmark/utils.py
new file mode 100644
index 0000000000..f6522578ad
--- /dev/null
+++ b/tests/milvus_benchmark/utils.py
@@ -0,0 +1,194 @@
+# -*- coding: utf-8 -*-
+from __future__ import print_function
+
+__true_print = print # noqa
+
+import os
+import sys
+import pdb
+import time
+import datetime
+import argparse
+import threading
+import logging
+import docker
+import multiprocessing
+import numpy
+# import psutil
+from yaml import load, dump
+import tableprint as tp
+
+logger = logging.getLogger("milvus_benchmark.utils")
+
+MULTI_DB_SLAVE_PATH = "/opt/milvus/data2;/opt/milvus/data3"
+
+
+def get_current_time():
+ return time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())
+
+
+def print_table(headers, columns, data):
+ bodys = []
+ for index, value in enumerate(columns):
+ tmp = [value]
+ tmp.extend(data[index])
+ bodys.append(tmp)
+ tp.table(bodys, headers)
+
+
+def modify_config(k, v, type=None, file_path="conf/server_config.yaml", db_slave=None):
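+    # maps suite keys such as "server.use_blas_threshold" or "server.cpu_cache_capacity"
+    # onto the matching entries of conf/server_config.yaml, e.g.
+    # modify_config("server.cpu_cache_capacity", 16, type="server")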
+ if not os.path.isfile(file_path):
+ raise Exception('File: %s not found' % file_path)
+ with open(file_path) as f:
+ config_dict = load(f)
+ f.close()
+ if config_dict:
+ if k.find("use_blas_threshold") != -1:
+ config_dict['engine_config']['use_blas_threshold'] = int(v)
+ elif k.find("cpu_cache_capacity") != -1:
+ config_dict['cache_config']['cpu_cache_capacity'] = int(v)
+ elif k.find("gpu_cache_capacity") != -1:
+ config_dict['cache_config']['gpu_cache_capacity'] = int(v)
+ elif k.find("resource_pool") != -1:
+ config_dict['resource_config']['resource_pool'] = v
+
+ if db_slave:
+ config_dict['db_config']['db_slave_path'] = MULTI_DB_SLAVE_PATH
+ with open(file_path, 'w') as f:
+ dump(config_dict, f, default_flow_style=False)
+ f.close()
+ else:
+ raise Exception('Load file:%s error' % file_path)
+
+
+def pull_image(image):
+ registry = image.split(":")[0]
+ image_tag = image.split(":")[1]
+ client = docker.APIClient(base_url='unix://var/run/docker.sock')
+ logger.info("Start pulling image: %s" % image)
+ return client.pull(registry, image_tag)
+
+
+def run_server(image, mem_limit=None, timeout=30, test_type="local", volume_name=None, db_slave=None):
+ import colors
+
+ client = docker.from_env()
+ # if mem_limit is None:
+ # mem_limit = psutil.virtual_memory().available
+ # logger.info('Memory limit:', mem_limit)
+ # cpu_limit = "0-%d" % (multiprocessing.cpu_count() - 1)
+ # logger.info('Running on CPUs:', cpu_limit)
+ for dir_item in ['logs', 'db']:
+ try:
+ os.mkdir(os.path.abspath(dir_item))
+ except Exception as e:
+ pass
+
+ if test_type == "local":
+ volumes = {
+ os.path.abspath('conf'):
+ {'bind': '/opt/milvus/conf', 'mode': 'ro'},
+ os.path.abspath('logs'):
+ {'bind': '/opt/milvus/logs', 'mode': 'rw'},
+ os.path.abspath('db'):
+ {'bind': '/opt/milvus/db', 'mode': 'rw'},
+ }
+ elif test_type == "remote":
+ if volume_name is None:
+ raise Exception("No volume name")
+ remote_log_dir = volume_name+'/logs'
+ remote_db_dir = volume_name+'/db'
+
+ for dir_item in [remote_log_dir, remote_db_dir]:
+ if not os.path.isdir(dir_item):
+ os.makedirs(dir_item, exist_ok=True)
+ volumes = {
+ os.path.abspath('conf'):
+ {'bind': '/opt/milvus/conf', 'mode': 'ro'},
+ remote_log_dir:
+ {'bind': '/opt/milvus/logs', 'mode': 'rw'},
+ remote_db_dir:
+ {'bind': '/opt/milvus/db', 'mode': 'rw'}
+ }
+ # add volumes
+ if db_slave and isinstance(db_slave, int):
+ for i in range(2, db_slave+1):
+ remote_db_dir = volume_name+'/data'+str(i)
+ if not os.path.isdir(remote_db_dir):
+ os.makedirs(remote_db_dir, exist_ok=True)
+ volumes[remote_db_dir] = {'bind': '/opt/milvus/data'+str(i), 'mode': 'rw'}
+
+ container = client.containers.run(
+ image,
+ volumes=volumes,
+ runtime="nvidia",
+ ports={'19530/tcp': 19530, '8080/tcp': 8080},
+ environment=["OMP_NUM_THREADS=48"],
+ # cpuset_cpus=cpu_limit,
+ # mem_limit=mem_limit,
+ # environment=[""],
+ detach=True)
+
+ def stream_logs():
+ for line in container.logs(stream=True):
+ logger.info(colors.color(line.decode().rstrip(), fg='blue'))
+
+ if sys.version_info >= (3, 0):
+ t = threading.Thread(target=stream_logs, daemon=True)
+ else:
+ t = threading.Thread(target=stream_logs)
+ t.daemon = True
+ t.start()
+
+ logger.info('Container: %s started' % container)
+ return container
+ # exit_code = container.wait(timeout=timeout)
+ # # Exit if exit code
+ # if exit_code == 0:
+ # return container
+ # elif exit_code is not None:
+ # print(colors.color(container.logs().decode(), fg='red'))
+ # raise Exception('Child process raised exception %s' % str(exit_code))
+
+def restart_server(container):
+ client = docker.APIClient(base_url='unix://var/run/docker.sock')
+
+ client.restart(container.name)
+ logger.info('Container: %s restarted' % container.name)
+ return container
+
+
+def remove_container(container):
+ container.remove(force=True)
+ logger.info('Container: %s removed' % container)
+
+
+def remove_all_containers(image):
+ client = docker.from_env()
+ try:
+ for container in client.containers.list():
+ if image in container.image.tags:
+ container.stop(timeout=30)
+ container.remove(force=True)
+ except Exception as e:
+        logger.error("Removing containers failed: %s" % str(e))
+
+
+def container_exists(image):
+    '''
+    Check whether a container created from the given image is running
+    @params: image name
+    @return: the container if one exists, otherwise False
+    '''
+ res = False
+ client = docker.from_env()
+ for container in client.containers.list():
+ if image in container.image.tags:
+ # True
+ res = container
+ return res
+
+
+if __name__ == '__main__':
+    # print(pull_image('branch-0.3.1-debug'))
+    pass
\ No newline at end of file
diff --git a/tests/milvus_python_test/.dockerignore b/tests/milvus_python_test/.dockerignore
new file mode 100644
index 0000000000..c97d9d043c
--- /dev/null
+++ b/tests/milvus_python_test/.dockerignore
@@ -0,0 +1,14 @@
+node_modules
+npm-debug.log
+Dockerfile*
+docker-compose*
+.dockerignore
+.git
+.gitignore
+.env
+*/bin
+*/obj
+README.md
+LICENSE
+.vscode
+__pycache__
\ No newline at end of file
diff --git a/tests/milvus_python_test/.gitignore b/tests/milvus_python_test/.gitignore
new file mode 100644
index 0000000000..9bd7345e51
--- /dev/null
+++ b/tests/milvus_python_test/.gitignore
@@ -0,0 +1,13 @@
+.python-version
+.pytest_cache
+__pycache__
+.vscode
+.idea
+
+test_out/
+*.pyc
+
+db/
+logs/
+
+.coverage
diff --git a/tests/milvus_python_test/Dockerfile b/tests/milvus_python_test/Dockerfile
new file mode 100644
index 0000000000..ec78b943dc
--- /dev/null
+++ b/tests/milvus_python_test/Dockerfile
@@ -0,0 +1,14 @@
+FROM python:3.6.8-jessie
+
+LABEL Name=megasearch_engine_test Version=0.0.1
+
+WORKDIR /app
+ADD . /app
+
+RUN apt-get update && apt-get install -y --no-install-recommends \
+ libc-dev build-essential && \
+ python3 -m pip install -r requirements.txt && \
+ apt-get remove --purge -y
+
+ENTRYPOINT [ "/app/docker-entrypoint.sh" ]
+CMD [ "start" ]
\ No newline at end of file
diff --git a/tests/milvus_python_test/MilvusCases.md b/tests/milvus_python_test/MilvusCases.md
new file mode 100644
index 0000000000..ea6372c373
--- /dev/null
+++ b/tests/milvus_python_test/MilvusCases.md
@@ -0,0 +1,143 @@
+# Milvus test cases
+
+## Interfaces test
+
+### 1. Connection tests
+
+#### 1.1 Connect
+
+| cases | expected |
+| ----------------------------- | -------------------------------------------- |
+| invalid IP 123.0.0.2 | method: connect raise error in given timeout |
+| valid uri | attr: connected assert true |
+| invalid uri | method: connect raise error in given timeout |
+| maximum number of connections | all connection attrs: connected assert true |
+| | |
+
+#### 1.2 Disconnect
+
+| cases | expected |
+| ------------------------------------------ | ------------------- |
+| disconnect an active connection | connect raise error |
+| disconnect an active connection repeatedly | connect raise error |
+
+### 2. Table operation
+
+#### 2.1 Create table
+
+##### 2.1.1 Table name
+
+| cases | expected |
+| ---------------------------------------------- | ----------- |
+| basic creation with valid parameters | status pass |
+| table name already exists | status fail |
+| table name in Chinese: "中文" | status pass |
+| table name with special characters: "-39fsd-" | status pass |
+| table name with spaces: "test1 2" | status pass |
+| invalid dim: 0 | raise error |
+| invalid dim: -1 | raise error |
+| invalid dim: 100000000 | raise error |
+| invalid dim: "string" | raise error |
+| index_type: 0 | status pass |
+| index_type: 1 | status pass |
+| index_type: 2 | status pass |
+| index_type: string | raise error |
+| | |
+
+##### 2.1.2 Dimension support
+
+| cases                          | expected    |
+| ------------------------------ | ----------- |
+| dimension: 0                   | raise error |
+| negative dimension: -1         | raise error |
+| maximum dimension: 100000000   | raise error |
+| dimension as string: "string"  | raise error |
+|                                |             |
+
+##### 2.1.3 Index type support
+
+| cases              | expected    |
+| ------------------ | ----------- |
+| index type: 0      | status pass |
+| index type: 1      | status pass |
+| index type: 2      | status pass |
+| index type: string | raise error |
+|                    |             |
+
+#### 2.2 Describe table
+
+| cases                                 | expected                                                 |
+| ------------------------------------- | -------------------------------------------------------- |
+| call describe after creating a table  | returned struct matches the parameters used at creation  |
+|                                       |                                                          |
+
+#### 2.3 Delete table
+
+| cases                        | expected                |
+| ---------------------------- | ----------------------- |
+| delete an existing table     | has_table returns False |
+| delete a non-existent table  | status fail             |
+|                              |                         |
+
+#### 2.4 Table existence
+
+| cases                                    | expected     |
+| ---------------------------------------- | ------------ |
+| call has_table on an existing table      | assert true  |
+| call has_table on a non-existent table   | assert false |
+|                                          |              |
+
+#### 2.5 Get table row count
+
+| cases                                         | expected                 |
+| --------------------------------------------- | ------------------------ |
+| empty table                                   | 0                        |
+| empty table after inserting a single vector   | 1                        |
+| empty table after inserting multiple vectors  | assert length of vectors |
+
+#### 2.6 Show tables
+
+| cases                                                      | expected                         |
+| ---------------------------------------------------------- | -------------------------------- |
+| two tables, one empty and one with data: call show tables  | assert length of table list == 2 |
+|                                                            |                                  |
+
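+A minimal sketch of the table operations above, assuming the same pymilvus client API exercised by `conftest.py` and the test scripts (`create_table` takes a parameter dict, `has_table` returns a bool, `get_table_row_count` returns a status and a count); the table name is a placeholder:
+
+```python
+from milvus import Milvus, MetricType
+
+milvus = Milvus()
+milvus.connect(host='127.0.0.1', port='19530')
+
+param = {'table_name': 'demo_table',
+         'dimension': 128,
+         'index_file_size': 10,
+         'metric_type': MetricType.L2}
+assert milvus.create_table(param).OK()              # 2.1: create with valid parameters
+assert milvus.has_table('demo_table')               # 2.4: table exists
+status, count = milvus.get_table_row_count('demo_table')
+assert count == 0                                   # 2.5: empty table
+assert milvus.delete_table('demo_table').OK()       # 2.3: delete an existing table
+```
+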
+### 3. Add vectors
+
+| interfaces | cases | expected |
+| ----------- | --------------------------------------------------------- | ------------------------------------ |
+| add_vectors | add basic | assert length of ids == nq |
+| | add vectors into table not existed | status fail |
+| | dim not match: single vector | status fail |
+| | dim not match: vector list | status fail |
+| | single vector element empty | status fail |
+| | vector list element empty | status fail |
+| | query immediately after adding | status pass |
+| | query immediately after sleep 6s | status pass && length of result == 1 |
+|             | concurrent add with multiple threads (shared connection)          | status pass                          |
+|             | concurrent add with multiple threads (independent connections)    | status pass                          |
+|             | concurrent add with multiple processes (independent connections)  | status pass                          |
+| | index_type: 2 | status pass |
+| | index_type: string | raise error |
+| | | |
+
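+A minimal sketch of the basic `add_vectors` case, assuming the same client API as the `test_add_vectors.py` script below; the table name and vector data are placeholders:
+
+```python
+import random
+
+from milvus import Milvus, MetricType
+
+milvus = Milvus()
+milvus.connect(host='127.0.0.1', port='19530')
+
+param = {'table_name': 'demo_add', 'dimension': 128,
+         'index_file_size': 10, 'metric_type': MetricType.L2}
+milvus.create_table(param)
+
+# "add basic": the number of returned ids should equal the number of inserted vectors.
+nq, dim = 5, 128
+vectors = [[random.random() for _ in range(dim)] for _ in range(nq)]
+status, ids = milvus.add_vectors('demo_add', vectors)
+assert status.OK()
+assert len(ids) == nq
+```
+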
+### 4. Search vectors
+
+| interfaces | cases | expected |
+| -------------- | ------------------------------------------------- | -------------------------------- |
+| search_vectors | search basic (query vectors taken from the added vectors, top-k = nq) | assert length of result == nq |
+| | concurrent search | status pass |
+| | query_range(get_current_day(), get_current_day()) | assert length of result == nq |
+| | invalid query_range: "" | raise error |
+| | query_range(get_last_day(2), get_last_day(1)) | assert length of result == 0 |
+| | query_range(get_last_day(2), get_current_day()) | assert length of result == nq |
+|                | query_range(get_last_day(2), get_next_day(2))     | assert length of result == nq    |
+|                | query_range(get_current_day(), get_next_day(2))   | assert length of result == nq    |
+|                | query_range(get_next_day(1), get_next_day(2))     | assert length of result == 0     |
+|                | score: vector[i] = vector[i] ± 0.01                | score > 99.9                     |
\ No newline at end of file
diff --git a/tests/milvus_python_test/README.md b/tests/milvus_python_test/README.md
new file mode 100644
index 0000000000..69b4384d4c
--- /dev/null
+++ b/tests/milvus_python_test/README.md
@@ -0,0 +1,14 @@
+# Requirements
+* python 3.6.8
+
+# How to use this Test Project
+```shell
+pytest . -q -v
+```
+
+With allure test report:
+
+```shell
+pytest --alluredir=test_out . -q -v
+allure serve test_out
+```
+# Contribution getting started
+* Follow PEP-8 for naming and black for formatting.
\ No newline at end of file
diff --git a/tests/milvus_python_test/conf/log_config.conf b/tests/milvus_python_test/conf/log_config.conf
new file mode 100644
index 0000000000..c530fa4c60
--- /dev/null
+++ b/tests/milvus_python_test/conf/log_config.conf
@@ -0,0 +1,27 @@
+* GLOBAL:
+ FORMAT = "%datetime | %level | %logger | %msg"
+ FILENAME = "/opt/milvus/logs/milvus-%datetime{%H:%m}-global.log"
+ ENABLED = true
+ TO_FILE = true
+ TO_STANDARD_OUTPUT = false
+ SUBSECOND_PRECISION = 3
+ PERFORMANCE_TRACKING = false
+ MAX_LOG_FILE_SIZE = 209715200 ## Throw log files away after 200MB
+* DEBUG:
+ FILENAME = "/opt/milvus/logs/milvus-%datetime{%H:%m}-debug.log"
+ ENABLED = true
+* WARNING:
+ FILENAME = "/opt/milvus/logs/milvus-%datetime{%H:%m}-warning.log"
+* TRACE:
+ FILENAME = "/opt/milvus/logs/milvus-%datetime{%H:%m}-trace.log"
+* VERBOSE:
+ FORMAT = "%datetime{%d/%M/%y} | %level-%vlevel | %msg"
+ TO_FILE = false
+ TO_STANDARD_OUTPUT = false
+## Error logs
+* ERROR:
+ ENABLED = true
+ FILENAME = "/opt/milvus/logs/milvus-%datetime{%H:%m}-error.log"
+* FATAL:
+ ENABLED = true
+ FILENAME = "/opt/milvus/logs/milvus-%datetime{%H:%m}-fatal.log"
diff --git a/tests/milvus_python_test/conf/server_config.yaml b/tests/milvus_python_test/conf/server_config.yaml
new file mode 100644
index 0000000000..6fe7e05791
--- /dev/null
+++ b/tests/milvus_python_test/conf/server_config.yaml
@@ -0,0 +1,32 @@
+server_config:
+ address: 0.0.0.0
+ port: 19530
+ deploy_mode: single
+ time_zone: UTC+8
+
+db_config:
+ primary_path: /opt/milvus
+ secondary_path:
+ backend_url: sqlite://:@:/
+ insert_buffer_size: 4
+ build_index_gpu: 0
+ preload_table:
+
+metric_config:
+ enable_monitor: true
+ collector: prometheus
+ prometheus_config:
+ port: 8080
+
+cache_config:
+ cpu_cache_capacity: 8
+ cpu_cache_threshold: 0.85
+ cache_insert_data: false
+
+engine_config:
+ use_blas_threshold: 20
+
+resource_config:
+ resource_pool:
+ - cpu
+ - gpu0
diff --git a/tests/milvus_python_test/conftest.py b/tests/milvus_python_test/conftest.py
new file mode 100644
index 0000000000..c6046ed56f
--- /dev/null
+++ b/tests/milvus_python_test/conftest.py
@@ -0,0 +1,128 @@
+import socket
+import pdb
+import logging
+
+import pytest
+from utils import gen_unique_str
+from milvus import Milvus, IndexType, MetricType
+
+index_file_size = 10
+
+
+def pytest_addoption(parser):
+ parser.addoption("--ip", action="store", default="localhost")
+ parser.addoption("--port", action="store", default=19530)
+
+
+def check_server_connection(request):
+ ip = request.config.getoption("--ip")
+ port = request.config.getoption("--port")
+ connected = True
+ if ip and (ip not in ['localhost', '127.0.0.1']):
+ try:
+ socket.getaddrinfo(ip, port, 0, 0, socket.IPPROTO_TCP)
+ except Exception as e:
+            print("Socket connect failed: %s" % str(e))
+ connected = False
+ return connected
+
+
+def get_args(request):
+ args = {
+ "ip": request.config.getoption("--ip"),
+ "port": request.config.getoption("--port")
+ }
+ return args
+
+
+@pytest.fixture(scope="module")
+def connect(request):
+ ip = request.config.getoption("--ip")
+ port = request.config.getoption("--port")
+ milvus = Milvus()
+ try:
+ milvus.connect(host=ip, port=port)
+ except:
+        pytest.exit("Cannot connect to the Milvus server, exiting pytest ...")
+
+ def fin():
+ try:
+ milvus.disconnect()
+ except:
+ pass
+
+ request.addfinalizer(fin)
+ return milvus
+
+
+@pytest.fixture(scope="module")
+def dis_connect(request):
+ ip = request.config.getoption("--ip")
+ port = request.config.getoption("--port")
+ milvus = Milvus()
+ milvus.connect(host=ip, port=port)
+ milvus.disconnect()
+ def fin():
+ try:
+ milvus.disconnect()
+ except:
+ pass
+
+ request.addfinalizer(fin)
+ return milvus
+
+
+@pytest.fixture(scope="module")
+def args(request):
+ ip = request.config.getoption("--ip")
+ port = request.config.getoption("--port")
+ args = {"ip": ip, "port": port}
+ return args
+
+
+@pytest.fixture(scope="function")
+def table(request, connect):
+ ori_table_name = getattr(request.module, "table_id", "test")
+ table_name = gen_unique_str(ori_table_name)
+    dim = getattr(request.module, "dim", 128)
+ param = {'table_name': table_name,
+ 'dimension': dim,
+ 'index_file_size': index_file_size,
+ 'metric_type': MetricType.L2}
+ status = connect.create_table(param)
+ # logging.getLogger().info(status)
+ if not status.OK():
+        pytest.exit("Table cannot be created, exiting pytest ...")
+
+ def teardown():
+ status, table_names = connect.show_tables()
+ for table_name in table_names:
+ connect.delete_table(table_name)
+
+ request.addfinalizer(teardown)
+
+ return table_name
+
+
+@pytest.fixture(scope="function")
+def ip_table(request, connect):
+ ori_table_name = getattr(request.module, "table_id", "test")
+ table_name = gen_unique_str(ori_table_name)
+    dim = getattr(request.module, "dim", 128)
+ param = {'table_name': table_name,
+ 'dimension': dim,
+ 'index_file_size': index_file_size,
+ 'metric_type': MetricType.IP}
+ status = connect.create_table(param)
+ # logging.getLogger().info(status)
+ if not status.OK():
+        pytest.exit("Table cannot be created, exiting pytest ...")
+
+ def teardown():
+ status, table_names = connect.show_tables()
+ for table_name in table_names:
+ connect.delete_table(table_name)
+
+ request.addfinalizer(teardown)
+
+ return table_name
\ No newline at end of file
diff --git a/tests/milvus_python_test/docker-entrypoint.sh b/tests/milvus_python_test/docker-entrypoint.sh
new file mode 100755
index 0000000000..af9ba0ba66
--- /dev/null
+++ b/tests/milvus_python_test/docker-entrypoint.sh
@@ -0,0 +1,9 @@
+#!/bin/bash
+
+set -e
+
+if [ "$1" = 'start' ]; then
+ tail -f /dev/null
+fi
+
+exec "$@"
\ No newline at end of file
diff --git a/tests/milvus_python_test/pytest.ini b/tests/milvus_python_test/pytest.ini
new file mode 100644
index 0000000000..3f95dc29b8
--- /dev/null
+++ b/tests/milvus_python_test/pytest.ini
@@ -0,0 +1,9 @@
+[pytest]
+log_format = [%(asctime)s-%(levelname)s-%(name)s]: %(message)s (%(filename)s:%(lineno)s)
+
+log_cli = true
+log_level = 20
+
+timeout = 300
+
+level = 1
\ No newline at end of file
diff --git a/tests/milvus_python_test/requirements.txt b/tests/milvus_python_test/requirements.txt
new file mode 100644
index 0000000000..4bdecd6033
--- /dev/null
+++ b/tests/milvus_python_test/requirements.txt
@@ -0,0 +1,25 @@
+astroid==2.2.5
+atomicwrites==1.3.0
+attrs==19.1.0
+importlib-metadata==0.15
+isort==4.3.20
+lazy-object-proxy==1.4.1
+mccabe==0.6.1
+more-itertools==7.0.0
+numpy==1.16.3
+pluggy==0.12.0
+py==1.8.0
+pylint==2.3.1
+pytest==4.5.0
+pytest-timeout==1.3.3
+pytest-repeat==0.8.0
+allure-pytest==2.7.0
+pytest-print==0.1.2
+pytest-level==0.1.1
+six==1.12.0
+thrift==0.11.0
+typed-ast==1.3.5
+wcwidth==0.1.7
+wrapt==1.11.1
+zipp==0.5.1
+pymilvus-test>=0.2.0
diff --git a/tests/milvus_python_test/requirements_cluster.txt b/tests/milvus_python_test/requirements_cluster.txt
new file mode 100644
index 0000000000..a1f3b69be9
--- /dev/null
+++ b/tests/milvus_python_test/requirements_cluster.txt
@@ -0,0 +1,25 @@
+astroid==2.2.5
+atomicwrites==1.3.0
+attrs==19.1.0
+importlib-metadata==0.15
+isort==4.3.20
+lazy-object-proxy==1.4.1
+mccabe==0.6.1
+more-itertools==7.0.0
+numpy==1.16.3
+pluggy==0.12.0
+py==1.8.0
+pylint==2.3.1
+pytest==4.5.0
+pytest-timeout==1.3.3
+pytest-repeat==0.8.0
+allure-pytest==2.7.0
+pytest-print==0.1.2
+pytest-level==0.1.1
+six==1.12.0
+thrift==0.11.0
+typed-ast==1.3.5
+wcwidth==0.1.7
+wrapt==1.11.1
+zipp==0.5.1
+pymilvus>=0.1.24
diff --git a/tests/milvus_python_test/requirements_no_pymilvus.txt b/tests/milvus_python_test/requirements_no_pymilvus.txt
new file mode 100644
index 0000000000..45884c0c71
--- /dev/null
+++ b/tests/milvus_python_test/requirements_no_pymilvus.txt
@@ -0,0 +1,24 @@
+astroid==2.2.5
+atomicwrites==1.3.0
+attrs==19.1.0
+importlib-metadata==0.15
+isort==4.3.20
+lazy-object-proxy==1.4.1
+mccabe==0.6.1
+more-itertools==7.0.0
+numpy==1.16.3
+pluggy==0.12.0
+py==1.8.0
+pylint==2.3.1
+pytest==4.5.0
+pytest-timeout==1.3.3
+pytest-repeat==0.8.0
+allure-pytest==2.7.0
+pytest-print==0.1.2
+pytest-level==0.1.1
+six==1.12.0
+thrift==0.11.0
+typed-ast==1.3.5
+wcwidth==0.1.7
+wrapt==1.11.1
+zipp==0.5.1
diff --git a/tests/milvus_python_test/run.sh b/tests/milvus_python_test/run.sh
new file mode 100644
index 0000000000..cee5b061f5
--- /dev/null
+++ b/tests/milvus_python_test/run.sh
@@ -0,0 +1,4 @@
+#!/bin/bash
+
+
+pytest . "$@"
\ No newline at end of file
diff --git a/tests/milvus_python_test/test.template b/tests/milvus_python_test/test.template
new file mode 100644
index 0000000000..0403d72860
--- /dev/null
+++ b/tests/milvus_python_test/test.template
@@ -0,0 +1,41 @@
+'''
+Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
+Unauthorized copying of this file, via any medium is strictly prohibited.
+Proprietary and confidential.
+'''
+
+'''
+Test Description:
+
+This document is only a template that shows how to write an auto-test script.
+
+'''
+
+import pytest
+from milvus import Milvus
+
+
+class TestConnection:
+ def test_connect_localhost(self):
+
+        """
+        TestCase1.1
+        Test target: check whether the client can connect to the server.
+        Test method: call milvus.connect to connect to the local Milvus server (ip: 127.0.0.1, port: 19530) and check the returned status.
+        Expectation: the return status is OK and milvus.connected is True.
+        """
+
+ milvus = Milvus()
+ milvus.connect(host='127.0.0.1', port='19530')
+ assert milvus.connected
+
+
+
+
diff --git a/tests/milvus_python_test/test_add_vectors.py b/tests/milvus_python_test/test_add_vectors.py
new file mode 100644
index 0000000000..d9faeede33
--- /dev/null
+++ b/tests/milvus_python_test/test_add_vectors.py
@@ -0,0 +1,1233 @@
+import time
+import random
+import pdb
+import threading
+import logging
+from multiprocessing import Pool, Process
+import pytest
+from milvus import Milvus, IndexType, MetricType
+from utils import *
+
+
+dim = 128
+index_file_size = 10
+table_id = "test_add"
+ADD_TIMEOUT = 60
+nprobe = 1
+epsilon = 0.0001
+
+index_params = random.choice(gen_index_params())
+logging.getLogger().info(index_params)
+
+
+class TestAddBase:
+ """
+ ******************************************************************
+ The following cases are used to test `add_vectors / index / search / delete` mixed function
+ ******************************************************************
+ """
+
+ def test_add_vector_create_table(self, connect, table):
+ '''
+ target: test add vector, then create table again
+ method: add vector and create table
+ expected: status not ok
+ '''
+ vector = gen_single_vector(dim)
+ status, ids = connect.add_vectors(table, vector)
+ param = {'table_name': table,
+ 'dimension': dim,
+ 'index_file_size': index_file_size,
+ 'metric_type': MetricType.L2}
+ status = connect.create_table(param)
+ assert not status.OK()
+
+ def test_add_vector_has_table(self, connect, table):
+ '''
+ target: test add vector, then check table existence
+ method: add vector and call HasTable
+ expected: table exists, status ok
+ '''
+ vector = gen_single_vector(dim)
+ status, ids = connect.add_vectors(table, vector)
+ ret = connect.has_table(table)
+ assert ret == True
+
+ @pytest.mark.timeout(ADD_TIMEOUT)
+ def test_delete_table_add_vector(self, connect, table):
+ '''
+ target: test add vector after table deleted
+ method: delete table and add vector
+ expected: status not ok
+ '''
+ status = connect.delete_table(table)
+ vector = gen_single_vector(dim)
+ status, ids = connect.add_vectors(table, vector)
+ assert not status.OK()
+
+ @pytest.mark.timeout(ADD_TIMEOUT)
+ def test_delete_table_add_vector_another(self, connect, table):
+ '''
+ target: test add vector to table_1 after table_2 deleted
+ method: delete table_2 and add vector to table_1
+ expected: status ok
+ '''
+ param = {'table_name': 'test_delete_table_add_vector_another',
+ 'dimension': dim,
+ 'index_file_size': index_file_size,
+ 'metric_type': MetricType.L2}
+ status = connect.create_table(param)
+ status = connect.delete_table(table)
+ vector = gen_single_vector(dim)
+ status, ids = connect.add_vectors(param['table_name'], vector)
+ connect.delete_table(param['table_name'])
+ assert status.OK()
+
+ @pytest.mark.timeout(ADD_TIMEOUT)
+ def test_add_vector_delete_table(self, connect, table):
+ '''
+ target: test delete table after add vector
+ method: add vector and delete table
+ expected: status ok
+ '''
+ vector = gen_single_vector(dim)
+ status, ids = connect.add_vectors(table, vector)
+ status = connect.delete_table(table)
+ assert status.OK()
+
+ @pytest.mark.timeout(ADD_TIMEOUT)
+ def test_add_vector_delete_another_table(self, connect, table):
+ '''
+        target: test delete table_1 after adding vector to table_2
+ method: add vector and delete table
+ expected: status ok
+ '''
+ param = {'table_name': 'test_add_vector_delete_another_table',
+ 'dimension': dim,
+ 'index_file_size': index_file_size,
+ 'metric_type': MetricType.L2}
+ status = connect.create_table(param)
+ vector = gen_single_vector(dim)
+ status, ids = connect.add_vectors(table, vector)
+ status = connect.delete_table(param['table_name'])
+ assert status.OK()
+
+ @pytest.mark.timeout(ADD_TIMEOUT)
+ def test_add_vector_sleep_delete_table(self, connect, table):
+ '''
+ target: test delete table after add vector for a while
+ method: add vector, sleep, and delete table
+ expected: status ok
+ '''
+ vector = gen_single_vector(dim)
+ status, ids = connect.add_vectors(table, vector)
+ time.sleep(1)
+ status = connect.delete_table(table)
+ assert status.OK()
+
+ @pytest.mark.timeout(ADD_TIMEOUT)
+ def test_add_vector_sleep_delete_another_table(self, connect, table):
+ '''
+        target: test delete table_1 after adding vector to table_2 for a while
+ method: add vector , sleep, and delete table
+ expected: status ok
+ '''
+ param = {'table_name': 'test_add_vector_sleep_delete_another_table',
+ 'dimension': dim,
+ 'index_file_size': index_file_size,
+ 'metric_type': MetricType.L2}
+ status = connect.create_table(param)
+ vector = gen_single_vector(dim)
+ status, ids = connect.add_vectors(table, vector)
+ time.sleep(1)
+ status = connect.delete_table(param['table_name'])
+ assert status.OK()
+
+ @pytest.mark.timeout(ADD_TIMEOUT)
+ def test_create_index_add_vector(self, connect, table):
+ '''
+ target: test add vector after build index
+ method: build index and add vector
+ expected: status ok
+ '''
+ status = connect.create_index(table, index_params)
+ vector = gen_single_vector(dim)
+ status, ids = connect.add_vectors(table, vector)
+ assert status.OK()
+
+ @pytest.mark.timeout(ADD_TIMEOUT)
+ def test_create_index_add_vector_another(self, connect, table):
+ '''
+ target: test add vector to table_2 after build index for table_1
+ method: build index and add vector
+ expected: status ok
+ '''
+ param = {'table_name': 'test_create_index_add_vector_another',
+ 'dimension': dim,
+ 'index_file_size': index_file_size,
+ 'metric_type': MetricType.L2}
+ status = connect.create_table(param)
+ status = connect.create_index(table, index_params)
+ vector = gen_single_vector(dim)
+ status, ids = connect.add_vectors(table, vector)
+ connect.delete_table(param['table_name'])
+ assert status.OK()
+
+ @pytest.mark.timeout(ADD_TIMEOUT)
+ def test_add_vector_create_index(self, connect, table):
+ '''
+        target: test building index after adding a vector
+ method: add vector and build index
+ expected: status ok
+ '''
+ vector = gen_single_vector(dim)
+ status, ids = connect.add_vectors(table, vector)
+ status = connect.create_index(table, index_params)
+ assert status.OK()
+
+ @pytest.mark.timeout(ADD_TIMEOUT)
+ def test_add_vector_create_index_another(self, connect, table):
+ '''
+ target: test add vector to table_2 after build index for table_1
+ method: build index and add vector
+ expected: status ok
+ '''
+ param = {'table_name': 'test_add_vector_create_index_another',
+ 'dimension': dim,
+ 'index_file_size': index_file_size,
+ 'metric_type': MetricType.L2}
+ status = connect.create_table(param)
+ vector = gen_single_vector(dim)
+ status, ids = connect.add_vectors(table, vector)
+ status = connect.create_index(param['table_name'], index_params)
+ connect.delete_table(param['table_name'])
+ assert status.OK()
+
+ @pytest.mark.timeout(ADD_TIMEOUT)
+ def test_add_vector_sleep_create_index(self, connect, table):
+ '''
+        target: test building index a while after adding a vector
+ method: add vector and build index
+ expected: status ok
+ '''
+ vector = gen_single_vector(dim)
+ status, ids = connect.add_vectors(table, vector)
+ time.sleep(1)
+ status = connect.create_index(table, index_params)
+ assert status.OK()
+
+ @pytest.mark.timeout(ADD_TIMEOUT)
+ def test_add_vector_sleep_create_index_another(self, connect, table):
+ '''
+ target: test add vector to table_2 after build index for table_1 for a while
+ method: build index and add vector
+ expected: status ok
+ '''
+ param = {'table_name': 'test_add_vector_sleep_create_index_another',
+ 'dimension': dim,
+ 'index_file_size': index_file_size,
+ 'metric_type': MetricType.L2}
+ status = connect.create_table(param)
+ vector = gen_single_vector(dim)
+ status, ids = connect.add_vectors(table, vector)
+ time.sleep(1)
+ status = connect.create_index(param['table_name'], index_params)
+ connect.delete_table(param['table_name'])
+ assert status.OK()
+
+ @pytest.mark.timeout(ADD_TIMEOUT)
+ def test_search_vector_add_vector(self, connect, table):
+ '''
+ target: test add vector after search table
+ method: search table and add vector
+ expected: status ok
+ '''
+ vector = gen_single_vector(dim)
+ status, result = connect.search_vectors(table, 1, nprobe, vector)
+ status, ids = connect.add_vectors(table, vector)
+ assert status.OK()
+
+ @pytest.mark.timeout(ADD_TIMEOUT)
+ def test_search_vector_add_vector_another(self, connect, table):
+ '''
+ target: test add vector to table_1 after search table_2
+ method: search table and add vector
+ expected: status ok
+ '''
+ param = {'table_name': 'test_search_vector_add_vector_another',
+ 'dimension': dim,
+ 'index_file_size': index_file_size,
+ 'metric_type': MetricType.L2}
+ status = connect.create_table(param)
+ vector = gen_single_vector(dim)
+ status, result = connect.search_vectors(table, 1, nprobe, vector)
+ status, ids = connect.add_vectors(param['table_name'], vector)
+ connect.delete_table(param['table_name'])
+ assert status.OK()
+
+ @pytest.mark.timeout(ADD_TIMEOUT)
+ def test_add_vector_search_vector(self, connect, table):
+ '''
+ target: test search vector after add vector
+ method: add vector and search table
+ expected: status ok
+ '''
+ vector = gen_single_vector(dim)
+ status, ids = connect.add_vectors(table, vector)
+ status, result = connect.search_vectors(table, 1, nprobe, vector)
+ assert status.OK()
+
+ @pytest.mark.timeout(ADD_TIMEOUT)
+ def test_add_vector_search_vector_another(self, connect, table):
+ '''
+        target: test search table_2 after adding vector to table_1
+        method: add vector and search table
+ expected: status ok
+ '''
+ param = {'table_name': 'test_add_vector_search_vector_another',
+ 'dimension': dim,
+ 'index_file_size': index_file_size,
+ 'metric_type': MetricType.L2}
+ status = connect.create_table(param)
+ vector = gen_single_vector(dim)
+ status, ids = connect.add_vectors(table, vector)
+ status, result = connect.search_vectors(param['table_name'], 1, nprobe, vector)
+ connect.delete_table(param['table_name'])
+ assert status.OK()
+
+ @pytest.mark.timeout(ADD_TIMEOUT)
+ def test_add_vector_sleep_search_vector(self, connect, table):
+ '''
+ target: test search vector after add vector after a while
+ method: add vector, sleep, and search table
+ expected: status ok
+ '''
+ vector = gen_single_vector(dim)
+ status, ids = connect.add_vectors(table, vector)
+ time.sleep(1)
+ status, result = connect.search_vectors(table, 1, nprobe, vector)
+ assert status.OK()
+
+ @pytest.mark.timeout(ADD_TIMEOUT)
+ def test_add_vector_sleep_search_vector_another(self, connect, table):
+ '''
+        target: test search table_2 a while after adding vector to table_1
+        method: add vector, sleep, and search table
+ expected: status ok
+ '''
+ param = {'table_name': 'test_add_vector_sleep_search_vector_another',
+ 'dimension': dim,
+ 'index_file_size': index_file_size,
+ 'metric_type': MetricType.L2}
+ status = connect.create_table(param)
+ vector = gen_single_vector(dim)
+ status, ids = connect.add_vectors(table, vector)
+ time.sleep(1)
+ status, result = connect.search_vectors(param['table_name'], 1, nprobe, vector)
+ connect.delete_table(param['table_name'])
+ assert status.OK()
+
+ """
+ ******************************************************************
+ The following cases are used to test `add_vectors` function
+ ******************************************************************
+ """
+
+ @pytest.mark.timeout(ADD_TIMEOUT)
+ def test_add_vectors_ids(self, connect, table):
+ '''
+ target: test add vectors in table, use customize ids
+ method: create table and add vectors in it, check the ids returned and the table length after vectors added
+        expected: the length of the returned ids equals nq and each search result matches its custom id
+ '''
+ nq = 5; top_k = 1; nprobe = 1
+ vectors = gen_vectors(nq, dim)
+ ids = [i for i in range(nq)]
+ status, ids = connect.add_vectors(table, vectors, ids)
+ time.sleep(2)
+ assert status.OK()
+ assert len(ids) == nq
+ # check search result
+ status, result = connect.search_vectors(table, top_k, nprobe, vectors)
+ logging.getLogger().info(result)
+ assert len(result) == nq
+ for i in range(nq):
+ assert result[i][0].id == i
+
+ @pytest.mark.timeout(ADD_TIMEOUT)
+ def test_add_vectors_twice_ids_no_ids(self, connect, table):
+ '''
+ target: check the result of add_vectors, with params ids and no ids
+        method: add vectors twice, first with custom ids, then without ids
+ expected: status not OK
+ '''
+ nq = 5; top_k = 1; nprobe = 1
+ vectors = gen_vectors(nq, dim)
+ ids = [i for i in range(nq)]
+ status, ids = connect.add_vectors(table, vectors, ids)
+ assert status.OK()
+ status, ids = connect.add_vectors(table, vectors)
+ logging.getLogger().info(status)
+ logging.getLogger().info(ids)
+ assert not status.OK()
+
+ @pytest.mark.timeout(ADD_TIMEOUT)
+ def test_add_vectors_twice_not_ids_ids(self, connect, table):
+ '''
+ target: check the result of add_vectors, with params ids and no ids
+        method: add vectors twice, first without ids, then with custom ids
+ expected: status not OK
+ '''
+ nq = 5; top_k = 1; nprobe = 1
+ vectors = gen_vectors(nq, dim)
+ ids = [i for i in range(nq)]
+ status, ids = connect.add_vectors(table, vectors)
+ assert status.OK()
+ status, ids = connect.add_vectors(table, vectors, ids)
+ logging.getLogger().info(status)
+ logging.getLogger().info(ids)
+ assert not status.OK()
+
+ @pytest.mark.timeout(ADD_TIMEOUT)
+ def test_add_vectors_ids_length_not_match(self, connect, table):
+ '''
+ target: test add vectors in table, use customize ids, len(ids) != len(vectors)
+ method: create table and add vectors in it
+ expected: raise an exception
+ '''
+ nq = 5
+ vectors = gen_vectors(nq, dim)
+ ids = [i for i in range(1, nq)]
+ with pytest.raises(Exception) as e:
+ status, ids = connect.add_vectors(table, vectors, ids)
+
+ @pytest.fixture(
+ scope="function",
+ params=gen_invalid_vector_ids()
+ )
+ def get_vector_id(self, request):
+ yield request.param
+
+ def test_add_vectors_ids_invalid(self, connect, table, get_vector_id):
+ '''
+ target: test add vectors in table, use customize ids, which are not int64
+ method: create table and add vectors in it
+ expected: raise an exception
+ '''
+ nq = 5
+ vectors = gen_vectors(nq, dim)
+ vector_id = get_vector_id
+ ids = [vector_id for i in range(nq)]
+ with pytest.raises(Exception) as e:
+ status, ids = connect.add_vectors(table, vectors, ids)
+
+ @pytest.mark.timeout(ADD_TIMEOUT)
+ def test_add_vectors(self, connect, table):
+ '''
+ target: test add vectors in table created before
+ method: create table and add vectors in it, check the ids returned and the table length after vectors added
+        expected: the length of the returned ids equals nq
+ '''
+ nq = 5
+ vectors = gen_vectors(nq, dim)
+ status, ids = connect.add_vectors(table, vectors)
+ assert status.OK()
+ assert len(ids) == nq
+
+ @pytest.mark.level(2)
+ def test_add_vectors_without_connect(self, dis_connect, table):
+ '''
+ target: test add vectors without connection
+ method: create table and add vectors in it, check if added successfully
+ expected: raise exception
+ '''
+ nq = 5
+ vectors = gen_vectors(nq, dim)
+ with pytest.raises(Exception) as e:
+ status, ids = dis_connect.add_vectors(table, vectors)
+
+ def test_add_table_not_existed(self, connect):
+ '''
+ target: test add vectors in table, which not existed before
+ method: add vectors table not existed, check the status
+ expected: status not ok
+ '''
+ nq = 5
+ vector = gen_single_vector(dim)
+ status, ids = connect.add_vectors(gen_unique_str("not_exist_table"), vector)
+ assert not status.OK()
+ assert not ids
+
+ def test_add_vector_dim_not_matched(self, connect, table):
+ '''
+ target: test add vector, the vector dimension is not equal to the table dimension
+ method: the vector dimension is half of the table dimension, check the status
+ expected: status not ok
+ '''
+ vector = gen_single_vector(int(dim)//2)
+ status, ids = connect.add_vectors(table, vector)
+ assert not status.OK()
+
+ def test_add_vectors_dim_not_matched(self, connect, table):
+ '''
+ target: test add vectors, the vector dimension is not equal to the table dimension
+ method: the vectors dimension is half of the table dimension, check the status
+ expected: status not ok
+ '''
+ nq = 5
+ vectors = gen_vectors(nq, int(dim)//2)
+ status, ids = connect.add_vectors(table, vectors)
+ assert not status.OK()
+
+ def test_add_vector_query_after_sleep(self, connect, table):
+ '''
+        target: test add vectors, then search after a short sleep
+        method: use vectors[0] as the query vector
+ expected: status ok and result length is 1
+ '''
+ nq = 5
+ vectors = gen_vectors(nq, dim)
+ status, ids = connect.add_vectors(table, vectors)
+ time.sleep(3)
+ status, result = connect.search_vectors(table, 1, nprobe, [vectors[0]])
+ assert status.OK()
+ assert len(result) == 1
+
+ # @pytest.mark.repeat(5)
+ @pytest.mark.timeout(ADD_TIMEOUT)
+ def _test_add_vector_multi_threading(self, connect, table):
+ '''
+ target: test add vectors, with multi threading
+        method: multiple threads add vectors concurrently
+        expected: status ok and the number of collected ids equals the number of added vectors
+ '''
+ thread_num = 4
+ loops = 100
+ threads = []
+ total_ids = []
+ vector = gen_single_vector(dim)
+ def add():
+ i = 0
+ while i < loops:
+ status, ids = connect.add_vectors(table, vector)
+ total_ids.append(ids[0])
+ i = i + 1
+ for i in range(thread_num):
+ x = threading.Thread(target=add, args=())
+ threads.append(x)
+ x.start()
+ time.sleep(0.2)
+ for th in threads:
+ th.join()
+ assert len(total_ids) == thread_num * loops
+ # make sure ids not the same
+ assert len(set(total_ids)) == thread_num * loops
+
+ # TODO: enable
+ # @pytest.mark.repeat(5)
+ @pytest.mark.timeout(ADD_TIMEOUT)
+ def _test_add_vector_with_multiprocessing(self, args):
+ '''
+ target: test add vectors, with multi processes
+        method: multiple processes add vectors concurrently
+        expected: status ok and the table row count equals the number of added vectors
+ '''
+ table = gen_unique_str("test_add_vector_with_multiprocessing")
+ uri = "tcp://%s:%s" % (args["ip"], args["port"])
+ param = {'table_name': table,
+ 'dimension': dim,
+ 'index_file_size': index_file_size}
+ # create table
+ milvus = Milvus()
+ milvus.connect(uri=uri)
+ milvus.create_table(param)
+ vector = gen_single_vector(dim)
+
+ process_num = 4
+ loop_num = 10
+ processes = []
+ # with dependent connection
+ def add(milvus):
+ i = 0
+ while i < loop_num:
+ status, ids = milvus.add_vectors(table, vector)
+ i = i + 1
+ for i in range(process_num):
+ milvus = Milvus()
+ milvus.connect(uri=uri)
+ p = Process(target=add, args=(milvus,))
+ processes.append(p)
+ p.start()
+ time.sleep(0.2)
+ for p in processes:
+ p.join()
+ time.sleep(3)
+ status, count = milvus.get_table_row_count(table)
+ assert count == process_num * loop_num
+
+ def test_add_vector_multi_tables(self, connect):
+ '''
+ target: test add vectors is correct or not with multiple tables of L2
+ method: create 50 tables and add vectors into them in turn
+ expected: status ok
+ '''
+ nq = 100
+ vectors = gen_vectors(nq, dim)
+ table_list = []
+ for i in range(50):
+ table_name = gen_unique_str('test_add_vector_multi_tables')
+ table_list.append(table_name)
+ param = {'table_name': table_name,
+ 'dimension': dim,
+ 'index_file_size': index_file_size,
+ 'metric_type': MetricType.L2}
+ connect.create_table(param)
+ time.sleep(2)
+ for j in range(10):
+ for i in range(50):
+ status, ids = connect.add_vectors(table_name=table_list[i], records=vectors)
+ assert status.OK()
+
+class TestAddIP:
+ """
+ ******************************************************************
+ The following cases are used to test `add_vectors / index / search / delete` mixed function
+ ******************************************************************
+ """
+
+ def test_add_vector_create_table(self, connect, ip_table):
+ '''
+ target: test add vector, then create table again
+ method: add vector and create table
+ expected: status not ok
+ '''
+ vector = gen_single_vector(dim)
+ status, ids = connect.add_vectors(ip_table, vector)
+ param = {'table_name': ip_table,
+ 'dimension': dim,
+ 'index_file_size': index_file_size,
+ 'metric_type': MetricType.L2}
+ status = connect.create_table(param)
+ assert not status.OK()
+
+ def test_add_vector_has_table(self, connect, ip_table):
+ '''
+ target: test add vector, then check table existence
+ method: add vector and call HasTable
+ expected: table exists, status ok
+ '''
+ vector = gen_single_vector(dim)
+ status, ids = connect.add_vectors(ip_table, vector)
+ ret = connect.has_table(ip_table)
+ assert ret == True
+
+ @pytest.mark.timeout(ADD_TIMEOUT)
+ def test_delete_table_add_vector(self, connect, ip_table):
+ '''
+ target: test add vector after table deleted
+ method: delete table and add vector
+ expected: status not ok
+ '''
+ status = connect.delete_table(ip_table)
+ vector = gen_single_vector(dim)
+ status, ids = connect.add_vectors(ip_table, vector)
+ assert not status.OK()
+
+ @pytest.mark.timeout(ADD_TIMEOUT)
+ def test_delete_table_add_vector_another(self, connect, ip_table):
+ '''
+ target: test add vector to table_1 after table_2 deleted
+ method: delete table_2 and add vector to table_1
+ expected: status ok
+ '''
+ param = {'table_name': 'test_delete_table_add_vector_another',
+ 'dimension': dim,
+ 'index_file_size': index_file_size,
+ 'metric_type': MetricType.L2}
+ status = connect.create_table(param)
+ status = connect.delete_table(ip_table)
+ vector = gen_single_vector(dim)
+ status, ids = connect.add_vectors(param['table_name'], vector)
+ connect.delete_table(param['table_name'])
+ assert status.OK()
+
+ @pytest.mark.timeout(ADD_TIMEOUT)
+ def test_add_vector_delete_table(self, connect, ip_table):
+ '''
+ target: test delete table after add vector
+ method: add vector and delete table
+ expected: status ok
+ '''
+ vector = gen_single_vector(dim)
+ status, ids = connect.add_vectors(ip_table, vector)
+ status = connect.delete_table(ip_table)
+ assert status.OK()
+
+ @pytest.mark.timeout(ADD_TIMEOUT)
+ def test_add_vector_delete_another_table(self, connect, ip_table):
+ '''
+        target: test delete table_1 after adding vector to table_2
+ method: add vector and delete table
+ expected: status ok
+ '''
+ param = {'table_name': 'test_add_vector_delete_another_table',
+ 'dimension': dim,
+ 'index_file_size': index_file_size,
+ 'metric_type': MetricType.L2}
+ status = connect.create_table(param)
+ vector = gen_single_vector(dim)
+ status, ids = connect.add_vectors(ip_table, vector)
+ status = connect.delete_table(param['table_name'])
+ assert status.OK()
+
+ @pytest.mark.timeout(ADD_TIMEOUT)
+ def test_add_vector_sleep_delete_table(self, connect, ip_table):
+ '''
+ target: test delete table after add vector for a while
+ method: add vector, sleep, and delete table
+ expected: status ok
+ '''
+ vector = gen_single_vector(dim)
+ status, ids = connect.add_vectors(ip_table, vector)
+ time.sleep(1)
+ status = connect.delete_table(ip_table)
+ assert status.OK()
+
+ @pytest.mark.timeout(ADD_TIMEOUT)
+ def test_add_vector_sleep_delete_another_table(self, connect, ip_table):
+ '''
+        target: test delete table_1 after adding vector to table_2 for a while
+ method: add vector , sleep, and delete table
+ expected: status ok
+ '''
+ param = {'table_name': 'test_add_vector_sleep_delete_another_table',
+ 'dimension': dim,
+ 'index_file_size': index_file_size,
+ 'metric_type': MetricType.L2}
+ status = connect.create_table(param)
+ vector = gen_single_vector(dim)
+ status, ids = connect.add_vectors(ip_table, vector)
+ time.sleep(1)
+ status = connect.delete_table(param['table_name'])
+ assert status.OK()
+
+ @pytest.mark.timeout(ADD_TIMEOUT)
+ def test_create_index_add_vector(self, connect, ip_table):
+ '''
+ target: test add vector after build index
+ method: build index and add vector
+ expected: status ok
+ '''
+ status = connect.create_index(ip_table, index_params)
+ vector = gen_single_vector(dim)
+ status, ids = connect.add_vectors(ip_table, vector)
+ assert status.OK()
+
+ @pytest.mark.timeout(ADD_TIMEOUT)
+ def test_create_index_add_vector_another(self, connect, ip_table):
+ '''
+ target: test add vector to table_2 after build index for table_1
+ method: build index and add vector
+ expected: status ok
+ '''
+ param = {'table_name': 'test_create_index_add_vector_another',
+ 'dimension': dim,
+ 'index_file_size': index_file_size,
+ 'metric_type': MetricType.L2}
+ status = connect.create_table(param)
+ status = connect.create_index(ip_table, index_params)
+ vector = gen_single_vector(dim)
+ status, ids = connect.add_vectors(ip_table, vector)
+ connect.delete_table(param['table_name'])
+ assert status.OK()
+
+ @pytest.mark.timeout(ADD_TIMEOUT)
+ def test_add_vector_create_index(self, connect, ip_table):
+ '''
+        target: test building index after adding a vector
+ method: add vector and build index
+ expected: status ok
+ '''
+ vector = gen_single_vector(dim)
+ status, ids = connect.add_vectors(ip_table, vector)
+ status = connect.create_index(ip_table, index_params)
+ assert status.OK()
+
+ @pytest.mark.timeout(ADD_TIMEOUT)
+ def test_add_vector_create_index_another(self, connect, ip_table):
+ '''
+ target: test add vector to table_2 after build index for table_1
+ method: build index and add vector
+ expected: status ok
+ '''
+ param = {'table_name': 'test_add_vector_create_index_another',
+ 'dimension': dim,
+ 'index_file_size': index_file_size,
+ 'metric_type': MetricType.L2}
+ status = connect.create_table(param)
+ vector = gen_single_vector(dim)
+ status, ids = connect.add_vectors(ip_table, vector)
+ status = connect.create_index(param['table_name'], index_params)
+ connect.delete_table(param['table_name'])
+ assert status.OK()
+
+ @pytest.mark.timeout(ADD_TIMEOUT)
+ def test_add_vector_sleep_create_index(self, connect, ip_table):
+ '''
+        target: test building index a while after adding a vector
+ method: add vector and build index
+ expected: status ok
+ '''
+ vector = gen_single_vector(dim)
+ status, ids = connect.add_vectors(ip_table, vector)
+ time.sleep(1)
+ status = connect.create_index(ip_table, index_params)
+ assert status.OK()
+
+ @pytest.mark.timeout(ADD_TIMEOUT)
+ def test_add_vector_sleep_create_index_another(self, connect, ip_table):
+ '''
+ target: test add vector to table_2 after build index for table_1 for a while
+ method: build index and add vector
+ expected: status ok
+ '''
+ param = {'table_name': 'test_add_vector_sleep_create_index_another',
+ 'dimension': dim,
+ 'index_file_size': index_file_size,
+ 'metric_type': MetricType.L2}
+ status = connect.create_table(param)
+ vector = gen_single_vector(dim)
+ status, ids = connect.add_vectors(ip_table, vector)
+ time.sleep(1)
+ status = connect.create_index(param['table_name'], index_params)
+ connect.delete_table(param['table_name'])
+ assert status.OK()
+
+ @pytest.mark.timeout(ADD_TIMEOUT)
+ def test_search_vector_add_vector(self, connect, ip_table):
+ '''
+ target: test add vector after search table
+ method: search table and add vector
+ expected: status ok
+ '''
+ vector = gen_single_vector(dim)
+ status, result = connect.search_vectors(ip_table, 1, nprobe, vector)
+ status, ids = connect.add_vectors(ip_table, vector)
+ assert status.OK()
+
+ @pytest.mark.timeout(ADD_TIMEOUT)
+ def test_search_vector_add_vector_another(self, connect, ip_table):
+ '''
+ target: test add vector to table_1 after search table_2
+ method: search table and add vector
+ expected: status ok
+ '''
+ param = {'table_name': 'test_search_vector_add_vector_another',
+ 'dimension': dim,
+ 'index_file_size': index_file_size,
+ 'metric_type': MetricType.L2}
+ status = connect.create_table(param)
+ vector = gen_single_vector(dim)
+ status, result = connect.search_vectors(ip_table, 1, nprobe, vector)
+ status, ids = connect.add_vectors(param['table_name'], vector)
+ connect.delete_table(param['table_name'])
+ assert status.OK()
+
+ @pytest.mark.timeout(ADD_TIMEOUT)
+ def test_add_vector_search_vector(self, connect, ip_table):
+ '''
+ target: test search vector after add vector
+ method: add vector and search table
+ expected: status ok
+ '''
+ vector = gen_single_vector(dim)
+ status, ids = connect.add_vectors(ip_table, vector)
+ status, result = connect.search_vectors(ip_table, 1, nprobe, vector)
+ assert status.OK()
+
+ @pytest.mark.timeout(ADD_TIMEOUT)
+ def test_add_vector_search_vector_another(self, connect, ip_table):
+ '''
+        target: test search table_2 after adding vector to table_1
+        method: add vector and search table
+ expected: status ok
+ '''
+ param = {'table_name': 'test_add_vector_search_vector_another',
+ 'dimension': dim,
+ 'index_file_size': index_file_size,
+ 'metric_type': MetricType.L2}
+ status = connect.create_table(param)
+ vector = gen_single_vector(dim)
+ status, ids = connect.add_vectors(ip_table, vector)
+ status, result = connect.search_vectors(param['table_name'], 1, nprobe, vector)
+ connect.delete_table(param['table_name'])
+ assert status.OK()
+
+ @pytest.mark.timeout(ADD_TIMEOUT)
+ def test_add_vector_sleep_search_vector(self, connect, ip_table):
+ '''
+ target: test search vector after add vector after a while
+ method: add vector, sleep, and search table
+ expected: status ok
+ '''
+ vector = gen_single_vector(dim)
+ status, ids = connect.add_vectors(ip_table, vector)
+ time.sleep(1)
+ status, result = connect.search_vectors(ip_table, 1, nprobe, vector)
+ assert status.OK()
+
+ @pytest.mark.timeout(ADD_TIMEOUT)
+ def test_add_vector_sleep_search_vector_another(self, connect, ip_table):
+ '''
+        target: test search table_2 a while after adding vector to table_1
+        method: add vector, sleep, and search table
+ expected: status ok
+ '''
+ param = {'table_name': 'test_add_vector_sleep_search_vector_another',
+ 'dimension': dim,
+ 'index_file_size': index_file_size,
+ 'metric_type': MetricType.L2}
+ status = connect.create_table(param)
+ vector = gen_single_vector(dim)
+ status, ids = connect.add_vectors(ip_table, vector)
+ time.sleep(1)
+ status, result = connect.search_vectors(param['table_name'], 1, nprobe, vector)
+ connect.delete_table(param['table_name'])
+ assert status.OK()
+
+ """
+ ******************************************************************
+ The following cases are used to test `add_vectors` function
+ ******************************************************************
+ """
+
+ @pytest.mark.timeout(ADD_TIMEOUT)
+ def test_add_vectors_ids(self, connect, ip_table):
+ '''
+ target: test add vectors in table, use customize ids
+ method: create table and add vectors in it, check the ids returned and the table length after vectors added
+        expected: the length of the returned ids equals nq and each search result matches its custom id
+ '''
+ nq = 5; top_k = 1; nprobe = 1
+ vectors = gen_vectors(nq, dim)
+ ids = [i for i in range(nq)]
+ status, ids = connect.add_vectors(ip_table, vectors, ids)
+ time.sleep(2)
+ assert status.OK()
+ assert len(ids) == nq
+ # check search result
+ status, result = connect.search_vectors(ip_table, top_k, nprobe, vectors)
+ logging.getLogger().info(result)
+ assert len(result) == nq
+ for i in range(nq):
+ assert result[i][0].id == i
+
+ @pytest.mark.timeout(ADD_TIMEOUT)
+ def test_add_vectors_twice_ids_no_ids(self, connect, ip_table):
+ '''
+ target: check the result of add_vectors, with params ids and no ids
+        method: add vectors twice, first with custom ids, then without ids
+ expected: status not OK
+ '''
+ nq = 5; top_k = 1; nprobe = 1
+ vectors = gen_vectors(nq, dim)
+ ids = [i for i in range(nq)]
+ status, ids = connect.add_vectors(ip_table, vectors, ids)
+ assert status.OK()
+ status, ids = connect.add_vectors(ip_table, vectors)
+ logging.getLogger().info(status)
+ logging.getLogger().info(ids)
+ assert not status.OK()
+
+ @pytest.mark.timeout(ADD_TIMEOUT)
+ def test_add_vectors_twice_not_ids_ids(self, connect, ip_table):
+ '''
+ target: check the result of add_vectors, with params ids and no ids
+        method: add vectors twice, first without ids, then with custom ids
+ expected: status not OK
+ '''
+ nq = 5; top_k = 1; nprobe = 1
+ vectors = gen_vectors(nq, dim)
+ ids = [i for i in range(nq)]
+ status, ids = connect.add_vectors(ip_table, vectors)
+ assert status.OK()
+ status, ids = connect.add_vectors(ip_table, vectors, ids)
+ logging.getLogger().info(status)
+ logging.getLogger().info(ids)
+ assert not status.OK()
+
+ @pytest.mark.timeout(ADD_TIMEOUT)
+ def test_add_vectors_ids_length_not_match(self, connect, ip_table):
+ '''
+ target: test add vectors in table, use customize ids, len(ids) != len(vectors)
+ method: create table and add vectors in it
+ expected: raise an exception
+ '''
+ nq = 5
+ vectors = gen_vectors(nq, dim)
+ ids = [i for i in range(1, nq)]
+ with pytest.raises(Exception) as e:
+ status, ids = connect.add_vectors(ip_table, vectors, ids)
+
+ @pytest.fixture(
+ scope="function",
+ params=gen_invalid_vector_ids()
+ )
+ def get_vector_id(self, request):
+ yield request.param
+
+ def test_add_vectors_ids_invalid(self, connect, ip_table, get_vector_id):
+ '''
+ target: test add vectors in table, use customize ids, which are not int64
+ method: create table and add vectors in it
+ expected: raise an exception
+ '''
+ nq = 5
+ vectors = gen_vectors(nq, dim)
+ vector_id = get_vector_id
+ ids = [vector_id for i in range(nq)]
+ with pytest.raises(Exception) as e:
+ status, ids = connect.add_vectors(ip_table, vectors, ids)
+
+ @pytest.mark.timeout(ADD_TIMEOUT)
+ def test_add_vectors(self, connect, ip_table):
+ '''
+ target: test add vectors in table created before
+ method: create table and add vectors in it, check the ids returned and the table length after vectors added
+        expected: the length of the returned ids equals nq
+ '''
+ nq = 5
+ vectors = gen_vectors(nq, dim)
+ status, ids = connect.add_vectors(ip_table, vectors)
+ assert status.OK()
+ assert len(ids) == nq
+
+ @pytest.mark.level(2)
+ def test_add_vectors_without_connect(self, dis_connect, ip_table):
+ '''
+ target: test add vectors without connection
+ method: create table and add vectors in it, check if added successfully
+ expected: raise exception
+ '''
+ nq = 5
+ vectors = gen_vectors(nq, dim)
+ with pytest.raises(Exception) as e:
+ status, ids = dis_connect.add_vectors(ip_table, vectors)
+
+ def test_add_vector_dim_not_matched(self, connect, ip_table):
+ '''
+ target: test add vector, the vector dimension is not equal to the table dimension
+ method: the vector dimension is half of the table dimension, check the status
+ expected: status not ok
+ '''
+ vector = gen_single_vector(int(dim)//2)
+ status, ids = connect.add_vectors(ip_table, vector)
+ assert not status.OK()
+
+ def test_add_vectors_dim_not_matched(self, connect, ip_table):
+ '''
+ target: test add vectors, the vector dimension is not equal to the table dimension
+ method: the vectors dimension is half of the table dimension, check the status
+ expected: status not ok
+ '''
+ nq = 5
+ vectors = gen_vectors(nq, int(dim)//2)
+ status, ids = connect.add_vectors(ip_table, vectors)
+ assert not status.OK()
+
+ def test_add_vector_query_after_sleep(self, connect, ip_table):
+ '''
+        target: test add vectors, then search after a short sleep
+        method: use vectors[0] as the query vector
+ expected: status ok and result length is 1
+ '''
+ nq = 5
+ vectors = gen_vectors(nq, dim)
+ status, ids = connect.add_vectors(ip_table, vectors)
+ time.sleep(3)
+ status, result = connect.search_vectors(ip_table, 1, nprobe, [vectors[0]])
+ assert status.OK()
+ assert len(result) == 1
+
+ # @pytest.mark.repeat(5)
+ @pytest.mark.timeout(ADD_TIMEOUT)
+ def _test_add_vector_multi_threading(self, connect, ip_table):
+ '''
+ target: test add vectors, with multi threading
+        method: multiple threads add vectors concurrently
+        expected: status ok and the number of collected ids equals the number of added vectors
+ '''
+ thread_num = 4
+ loops = 100
+ threads = []
+ total_ids = []
+ vector = gen_single_vector(dim)
+ def add():
+ i = 0
+ while i < loops:
+ status, ids = connect.add_vectors(ip_table, vector)
+ total_ids.append(ids[0])
+ i = i + 1
+ for i in range(thread_num):
+ x = threading.Thread(target=add, args=())
+ threads.append(x)
+ x.start()
+ time.sleep(0.2)
+ for th in threads:
+ th.join()
+ assert len(total_ids) == thread_num * loops
+ # make sure ids not the same
+ assert len(set(total_ids)) == thread_num * loops
+
+ # TODO: enable
+ # @pytest.mark.repeat(5)
+ @pytest.mark.timeout(ADD_TIMEOUT)
+ def _test_add_vector_with_multiprocessing(self, args):
+ '''
+ target: test add vectors, with multi processes
+        method: multiple processes add vectors concurrently
+        expected: status ok and the table row count equals the number of added vectors
+ '''
+ table = gen_unique_str("test_add_vector_with_multiprocessing")
+ uri = "tcp://%s:%s" % (args["ip"], args["port"])
+ param = {'table_name': table,
+ 'dimension': dim,
+ 'index_file_size': index_file_size}
+ # create table
+ milvus = Milvus()
+ milvus.connect(uri=uri)
+ milvus.create_table(param)
+ vector = gen_single_vector(dim)
+
+ process_num = 4
+ loop_num = 10
+ processes = []
+ # with dependent connection
+ def add(milvus):
+ i = 0
+ while i < loop_num:
+ status, ids = milvus.add_vectors(table, vector)
+ i = i + 1
+ for i in range(process_num):
+ milvus = Milvus()
+ milvus.connect(uri=uri)
+ p = Process(target=add, args=(milvus,))
+ processes.append(p)
+ p.start()
+ time.sleep(0.2)
+ for p in processes:
+ p.join()
+ time.sleep(3)
+ status, count = milvus.get_table_row_count(table)
+ assert count == process_num * loop_num
+
+ def test_add_vector_multi_tables(self, connect):
+ '''
+ target: test add vectors is correct or not with multiple tables of IP
+ method: create 50 tables and add vectors into them in turn
+ expected: status ok
+ '''
+ nq = 100
+ vectors = gen_vectors(nq, dim)
+ table_list = []
+ for i in range(50):
+ table_name = gen_unique_str('test_add_vector_multi_tables')
+ table_list.append(table_name)
+ param = {'table_name': table_name,
+ 'dimension': dim,
+ 'index_file_size': index_file_size,
+ 'metric_type': MetricType.IP}
+ connect.create_table(param)
+ time.sleep(2)
+ for j in range(10):
+ for i in range(50):
+ status, ids = connect.add_vectors(table_name=table_list[i], records=vectors)
+ assert status.OK()
+
+class TestAddAdvance:
+
+ @pytest.fixture(
+ scope="function",
+ params=[
+ 1,
+ 10,
+ 100,
+ 1000,
+ pytest.param(5000 - 1, marks=pytest.mark.xfail),
+ pytest.param(5000, marks=pytest.mark.xfail),
+ pytest.param(5000 + 1, marks=pytest.mark.xfail),
+ ],
+ )
+ def insert_count(self, request):
+ yield request.param
+
+ def test_insert_much(self, connect, table, insert_count):
+ '''
+ target: test add vectors with different length of vectors
+ method: set different vectors as add method params
+ expected: length of ids is equal to the length of vectors
+ '''
+ nb = insert_count
+ insert_vec_list = gen_vectors(nb, dim)
+ status, ids = connect.add_vectors(table, insert_vec_list)
+ assert len(ids) == nb
+ assert status.OK()
+
+ def test_insert_much_ip(self, connect, ip_table, insert_count):
+ '''
+ target: test add vectors with different length of vectors
+ method: set different vectors as add method params
+ expected: length of ids is equal to the length of vectors
+ '''
+ nb = insert_count
+ insert_vec_list = gen_vectors(nb, dim)
+ status, ids = connect.add_vectors(ip_table, insert_vec_list)
+ assert len(ids) == nb
+ assert status.OK()
+
+class TestAddTableNameInvalid(object):
+ """
+ Test adding vectors with invalid table names
+ """
+ @pytest.fixture(
+ scope="function",
+ params=gen_invalid_table_names()
+ )
+ def get_table_name(self, request):
+ yield request.param
+
+ @pytest.mark.level(2)
+ def test_add_vectors_with_invalid_tablename(self, connect, get_table_name):
+ table_name = get_table_name
+ vectors = gen_vectors(1, dim)
+ status, result = connect.add_vectors(table_name, vectors)
+ assert not status.OK()
+
+
+class TestAddTableVectorsInvalid(object):
+ single_vector = gen_single_vector(dim)
+ vectors = gen_vectors(2, dim)
+
+ """
+ Test adding vectors with invalid vectors
+ """
+ @pytest.fixture(
+ scope="function",
+ params=gen_invalid_vectors()
+ )
+ def gen_vector(self, request):
+ yield request.param
+
+ @pytest.mark.level(2)
+ def test_add_vector_with_invalid_vectors(self, connect, table, gen_vector):
+ tmp_single_vector = copy.deepcopy(self.single_vector)
+ tmp_single_vector[0][1] = gen_vector
+ with pytest.raises(Exception) as e:
+ status, result = connect.add_vectors(table, tmp_single_vector)
+
+ @pytest.mark.level(1)
+ def test_add_vectors_with_invalid_vectors(self, connect, table, gen_vector):
+ tmp_vectors = copy.deepcopy(self.vectors)
+ tmp_vectors[1][1] = gen_vector
+ with pytest.raises(Exception) as e:
+ status, result = connect.add_vectors(table, tmp_vectors)
\ No newline at end of file
diff --git a/tests/milvus_python_test/test_connect.py b/tests/milvus_python_test/test_connect.py
new file mode 100644
index 0000000000..5ec9539011
--- /dev/null
+++ b/tests/milvus_python_test/test_connect.py
@@ -0,0 +1,386 @@
+import pytest
+from milvus import Milvus
+import pdb
+import threading
+from multiprocessing import Process
+from utils import *
+
+__version__ = '0.5.0'
+CONNECT_TIMEOUT = 12
+
+
+class TestConnect:
+
+ def local_ip(self, args):
+ '''
+ check if ip is localhost or not
+ '''
+ if not args["ip"] or args["ip"] == 'localhost' or args["ip"] == "127.0.0.1":
+ return True
+ else:
+ return False
+
+ def test_disconnect(self, connect):
+ '''
+ target: test disconnect
+ method: disconnect a connected client
+ expected: connect failed after disconnected
+ '''
+ res = connect.disconnect()
+ assert res.OK()
+ with pytest.raises(Exception) as e:
+ res = connect.server_version()
+
+ def test_disconnect_repeatedly(self, connect, args):
+ '''
+ target: test disconnect repeatedly
+ method: disconnect a connected client, disconnect again
+ expected: raise an error after disconnected
+ '''
+ if not connect.connected():
+ milvus = Milvus()
+ uri_value = "tcp://%s:%s" % (args["ip"], args["port"])
+ milvus.connect(uri=uri_value)
+ res = milvus.disconnect()
+ with pytest.raises(Exception) as e:
+ res = milvus.disconnect()
+ else:
+ res = connect.disconnect()
+ with pytest.raises(Exception) as e:
+ res = connect.disconnect()
+
+ def test_connect_correct_ip_port(self, args):
+ '''
+        target: test connect with correct ip and port value
+ method: set correct ip and port
+ expected: connected is True
+ '''
+ milvus = Milvus()
+ milvus.connect(host=args["ip"], port=args["port"])
+ assert milvus.connected()
+
+ def test_connect_connected(self, args):
+ '''
+        target: test connect and disconnect with correct ip and port value, assert connected value
+ method: set correct ip and port
+ expected: connected is False
+ '''
+ milvus = Milvus()
+ milvus.connect(host=args["ip"], port=args["port"])
+ milvus.disconnect()
+ assert not milvus.connected()
+
+ # TODO: Currently we test with remote IP, localhost testing need to add
+ def _test_connect_ip_localhost(self, args):
+ '''
+ target: test connect with ip value: localhost
+ method: set host localhost
+ expected: connected is True
+ '''
+ milvus = Milvus()
+ milvus.connect(host='localhost', port=args["port"])
+ assert milvus.connected()
+
+ @pytest.mark.timeout(CONNECT_TIMEOUT)
+ def test_connect_wrong_ip_null(self, args):
+ '''
+ target: test connect with wrong ip value
+ method: set host null
+ expected: not use default ip, connected is False
+ '''
+ milvus = Milvus()
+ ip = ""
+ with pytest.raises(Exception) as e:
+ milvus.connect(host=ip, port=args["port"], timeout=1)
+ assert not milvus.connected()
+
+ def test_connect_uri(self, args):
+ '''
+ target: test connect with correct uri
+ method: uri format and value are both correct
+ expected: connected is True
+ '''
+ milvus = Milvus()
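+        # build the URI form of the connect parameters: "tcp://<host>:<port>"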
+ uri_value = "tcp://%s:%s" % (args["ip"], args["port"])
+ milvus.connect(uri=uri_value)
+ assert milvus.connected()
+
+ def test_connect_uri_null(self, args):
+ '''
+ target: test connect with null uri
+ method: uri set null
+        expected: connected is True if the server is local, otherwise an exception is raised
+ '''
+ milvus = Milvus()
+ uri_value = ""
+
+ if self.local_ip(args):
+ milvus.connect(uri=uri_value, timeout=1)
+ assert milvus.connected()
+ else:
+ with pytest.raises(Exception) as e:
+ milvus.connect(uri=uri_value, timeout=1)
+ assert not milvus.connected()
+
+ @pytest.mark.level(2)
+ @pytest.mark.timeout(CONNECT_TIMEOUT)
+ def test_connect_wrong_uri_wrong_port_null(self, args):
+ '''
+        target: test uri connect with null port value
+        method: set uri port null
+        expected: connect raises an exception
+ '''
+ milvus = Milvus()
+ uri_value = "tcp://%s:" % args["ip"]
+ with pytest.raises(Exception) as e:
+ milvus.connect(uri=uri_value, timeout=1)
+
+ @pytest.mark.level(2)
+ @pytest.mark.timeout(CONNECT_TIMEOUT)
+ def test_connect_wrong_uri_wrong_ip_null(self, args):
+ '''
+        target: test uri connect with null ip value
+        method: set uri ip null
+        expected: connect raises an exception and connected is False
+ '''
+ milvus = Milvus()
+ uri_value = "tcp://:%s" % args["port"]
+
+ with pytest.raises(Exception) as e:
+ milvus.connect(uri=uri_value, timeout=1)
+ assert not milvus.connected()
+
+ # TODO: enable
+ def _test_connect_with_multiprocess(self, args):
+ '''
+ target: test uri connect with multiprocess
+ method: set correct uri, test with multiprocessing connecting
+        expected: all connections are connected
+ '''
+ uri_value = "tcp://%s:%s" % (args["ip"], args["port"])
+ process_num = 4
+ processes = []
+
+ def connect(milvus):
+ milvus.connect(uri=uri_value)
+ with pytest.raises(Exception) as e:
+ milvus.connect(uri=uri_value)
+ assert milvus.connected()
+
+ for i in range(process_num):
+ milvus = Milvus()
+ p = Process(target=connect, args=(milvus, ))
+ processes.append(p)
+ p.start()
+ for p in processes:
+ p.join()
+
+ def test_connect_repeatedly(self, args):
+ '''
+ target: test connect repeatedly
+ method: connect again
+        expected: status.code is 0, and status.message indicates the client is already connected
+ '''
+ milvus = Milvus()
+ uri_value = "tcp://%s:%s" % (args["ip"], args["port"])
+ milvus.connect(uri=uri_value)
+
+ milvus.connect(uri=uri_value)
+ assert milvus.connected()
+
+ def test_connect_disconnect_repeatedly_once(self, args):
+ '''
+ target: test connect and disconnect repeatedly
+ method: disconnect, and then connect, assert connect status
+ expected: status.code is 0
+ '''
+ milvus = Milvus()
+ uri_value = "tcp://%s:%s" % (args["ip"], args["port"])
+ milvus.connect(uri=uri_value)
+
+ milvus.disconnect()
+ milvus.connect(uri=uri_value)
+ assert milvus.connected()
+
+ def test_connect_disconnect_repeatedly_times(self, args):
+ '''
+ target: test connect and disconnect for 10 times repeatedly
+ method: disconnect, and then connect, assert connect status
+ expected: status.code is 0
+ '''
+ times = 10
+ milvus = Milvus()
+ uri_value = "tcp://%s:%s" % (args["ip"], args["port"])
+ milvus.connect(uri=uri_value)
+ for i in range(times):
+ milvus.disconnect()
+ milvus.connect(uri=uri_value)
+ assert milvus.connected()
+
+ # TODO: enable
+ def _test_connect_disconnect_with_multiprocess(self, args):
+ '''
+        target: test uri connect and disconnect repeatedly with multiprocess
+        method: set correct uri, test with multiprocessing connecting and disconnecting
+        expected: all connections are connected after the operations
+ '''
+ uri_value = "tcp://%s:%s" % (args["ip"], args["port"])
+ process_num = 4
+ processes = []
+
+ def connect(milvus):
+ milvus.connect(uri=uri_value)
+ milvus.disconnect()
+ milvus.connect(uri=uri_value)
+ assert milvus.connected()
+
+ for i in range(process_num):
+ milvus = Milvus()
+ p = Process(target=connect, args=(milvus, ))
+ processes.append(p)
+ p.start()
+ for p in processes:
+ p.join()
+
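+    # The next three cases exercise how the client resolves host/port vs. uri when
+    # both are supplied; the expected priority is inferred from the assertions below
+    # rather than from documented client behaviour.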
+ def test_connect_param_priority_no_port(self, args):
+ '''
+        target: host/port and uri are both given, with port set to null
+        method: pass port="" together with a uri that points at a different port, check whether the connection still succeeds
+        expected: connected is True
+ '''
+ milvus = Milvus()
+ uri_value = "tcp://%s:19540" % args["ip"]
+ milvus.connect(host=args["ip"], port="", uri=uri_value)
+ assert milvus.connected()
+
+ def test_connect_param_priority_uri(self, args):
+ '''
+        target: host/port and uri are both given, with host set to null
+        method: pass host="" together with a correct uri, check whether the connection succeeds
+        expected: connect raises an exception and connected is False
+ '''
+ milvus = Milvus()
+ uri_value = "tcp://%s:%s" % (args["ip"], args["port"])
+ with pytest.raises(Exception) as e:
+ milvus.connect(host="", port=args["port"], uri=uri_value, timeout=1)
+ assert not milvus.connected()
+
+ def test_connect_param_priority_both_hostip_uri(self, args):
+ '''
+        target: host/port and uri are all given and not null
+        method: pass a wrong port together with a correct uri, check whether the connection succeeds
+        expected: connect raises an exception and connected is False
+ '''
+ milvus = Milvus()
+ uri_value = "tcp://%s:%s" % (args["ip"], args["port"])
+ with pytest.raises(Exception) as e:
+ milvus.connect(host=args["ip"], port=19540, uri=uri_value, timeout=1)
+ assert not milvus.connected()
+
+ def _test_add_vector_and_disconnect_concurrently(self):
+ '''
+ Target: test disconnect in the middle of add vectors
+ Method:
+ a. use coroutine or multi-processing, to simulate network crashing
+        b. data_set not too large, in case disconnection happens while data is still being prepared
+        c. data_set not too small, in case disconnection happens after data has already been transferred
+ d. make sure disconnection happens when data is in-transport
+ Expected: Failure, get_table_row_count == 0
+
+ '''
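+        # A rough sketch of one possible implementation (assumes a `table` fixture,
+        # a `milvus` client instance and a `reconnect()` helper, all hypothetical):
+        #   vectors = gen_vectors(5000, 128)                    # mid-sized payload
+        #   p = Process(target=milvus.add_vectors, args=(table, vectors))
+        #   p.start()
+        #   milvus.disconnect()                                 # cut the link mid-transfer
+        #   p.join()
+        #   assert reconnect().get_table_row_count(table)[1] == 0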
+ pass
+
+ def _test_search_vector_and_disconnect_concurrently(self):
+ '''
+        Target: Test disconnect in the middle of search vectors (with large nq and topk) multiple times, and check that search/add vectors still work
+ Method:
+ a. coroutine or multi-processing, to simulate network crashing
+ b. connect, search and disconnect, repeating many times
+ c. connect and search, add vectors
+ Expected: Successfully searched back, successfully added
+
+ '''
+ pass
+
+ def _test_thread_safe_with_one_connection_shared_in_multi_threads(self):
+ '''
+ Target: test 1 connection thread safe
+        Method: share one connection across multiple threads, all adding vectors or issuing other requests
+ Expected: Functional as one thread
+
+ '''
+ pass
+
+
+class TestConnectIPInvalid(object):
+ """
+ Test connect server with invalid ip
+ """
+ @pytest.fixture(
+ scope="function",
+ params=gen_invalid_ips()
+ )
+ def get_invalid_ip(self, request):
+ yield request.param
+
+ @pytest.mark.level(2)
+ @pytest.mark.timeout(CONNECT_TIMEOUT)
+ def test_connect_with_invalid_ip(self, args, get_invalid_ip):
+ milvus = Milvus()
+ ip = get_invalid_ip
+ with pytest.raises(Exception) as e:
+ milvus.connect(host=ip, port=args["port"], timeout=1)
+ assert not milvus.connected()
+
+
+class TestConnectPortInvalid(object):
+ """
+    Test connect server with invalid port
+ """
+
+ @pytest.fixture(
+ scope="function",
+ params=gen_invalid_ports()
+ )
+ def get_invalid_port(self, request):
+ yield request.param
+
+ @pytest.mark.level(2)
+ @pytest.mark.timeout(CONNECT_TIMEOUT)
+ def test_connect_with_invalid_port(self, args, get_invalid_port):
+ '''
+ target: test ip:port connect with invalid port value
+ method: set port in gen_invalid_ports
+ expected: connected is False
+ '''
+ milvus = Milvus()
+ port = get_invalid_port
+ with pytest.raises(Exception) as e:
+ milvus.connect(host=args["ip"], port=port, timeout=1)
+ assert not milvus.connected()
+
+
+class TestConnectURIInvalid(object):
+ """
+ Test connect server with invalid uri
+ """
+ @pytest.fixture(
+ scope="function",
+ params=gen_invalid_uris()
+ )
+ def get_invalid_uri(self, request):
+ yield request.param
+
+ @pytest.mark.level(2)
+ @pytest.mark.timeout(CONNECT_TIMEOUT)
+ def test_connect_with_invalid_uri(self, get_invalid_uri):
+ '''
+ target: test uri connect with invalid uri value
+        method: set uri in gen_invalid_uris
+ expected: connected is False
+ '''
+ milvus = Milvus()
+ uri_value = get_invalid_uri
+ with pytest.raises(Exception) as e:
+ milvus.connect(uri=uri_value, timeout=1)
+ assert not milvus.connected()
diff --git a/tests/milvus_python_test/test_delete_vectors.py b/tests/milvus_python_test/test_delete_vectors.py
new file mode 100644
index 0000000000..cb450838a2
--- /dev/null
+++ b/tests/milvus_python_test/test_delete_vectors.py
@@ -0,0 +1,419 @@
+import time
+import random
+import pdb
+import logging
+import threading
+from builtins import Exception
+from multiprocessing import Pool, Process
+import pytest
+
+from milvus import Milvus, IndexType
+from utils import *
+
+
+dim = 128
+index_file_size = 10
+table_id = "test_delete"
+DELETE_TIMEOUT = 60
+vectors = gen_vectors(100, dim)
+
+class TestDeleteVectorsBase:
+ """
+ generate invalid query range params
+ """
+ @pytest.fixture(
+ scope="function",
+ params=[
+ (get_current_day(), get_current_day()),
+ (get_last_day(1), get_last_day(1)),
+ (get_next_day(1), get_next_day(1))
+ ]
+ )
+ def get_invalid_range(self, request):
+ yield request.param
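+    # the ranges above all have start == end; such ranges are assumed to be rejected
+    # by `delete_vectors_by_range`, which is what these cases assert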
+
+ @pytest.mark.timeout(DELETE_TIMEOUT)
+ def test_delete_vectors_invalid_range(self, connect, table, get_invalid_range):
+ '''
+ target: test delete vectors, no index created
+ method: call `delete_vectors_by_range`, with invalid date params
+        expected: return code not equals to 0
+ '''
+ start_date = get_invalid_range[0]
+ end_date = get_invalid_range[1]
+ status, ids = connect.add_vectors(table, vectors)
+ status = connect.delete_vectors_by_range(table, start_date, end_date)
+ assert not status.OK()
+
+ """
+ generate valid query range params, no search result
+ """
+ @pytest.fixture(
+ scope="function",
+ params=[
+ (get_last_day(2), get_last_day(1)),
+ (get_last_day(2), get_current_day()),
+ (get_next_day(1), get_next_day(2))
+ ]
+ )
+ def get_valid_range_no_result(self, request):
+ yield request.param
+
+ @pytest.mark.timeout(DELETE_TIMEOUT)
+ def test_delete_vectors_valid_range_no_result(self, connect, table, get_valid_range_no_result):
+ '''
+ target: test delete vectors, no index created
+        method: call `delete_vectors_by_range` with a valid date range that does not cover the inserted vectors
+        expected: return code 0, and no vectors deleted
+ '''
+ start_date = get_valid_range_no_result[0]
+ end_date = get_valid_range_no_result[1]
+ status, ids = connect.add_vectors(table, vectors)
+ time.sleep(2)
+ status = connect.delete_vectors_by_range(table, start_date, end_date)
+ assert status.OK()
+ status, result = connect.get_table_row_count(table)
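+        # the deleted range does not cover the insertion date, so the full batch should remain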
+ assert result == 100
+
+ """
+ generate valid query range params, no search result
+ """
+ @pytest.fixture(
+ scope="function",
+ params=[
+ (get_last_day(2), get_next_day(2)),
+ (get_current_day(), get_next_day(2)),
+ ]
+ )
+ def get_valid_range(self, request):
+ yield request.param
+
+ @pytest.mark.timeout(DELETE_TIMEOUT)
+ def test_delete_vectors_valid_range(self, connect, table, get_valid_range):
+ '''
+ target: test delete vectors, no index created
+        method: call `delete_vectors_by_range` with a valid date range covering the inserted vectors
+        expected: return code 0, and all vectors deleted
+ '''
+ start_date = get_valid_range[0]
+ end_date = get_valid_range[1]
+ status, ids = connect.add_vectors(table, vectors)
+ time.sleep(2)
+ status = connect.delete_vectors_by_range(table, start_date, end_date)
+ assert status.OK()
+ status, result = connect.get_table_row_count(table)
+ assert result == 0
+
+ @pytest.fixture(
+ scope="function",
+ params=gen_index_params()
+ )
+ def get_index_params(self, request):
+ yield request.param
+
+ @pytest.mark.timeout(DELETE_TIMEOUT)
+ def test_delete_vectors_valid_range_index_created(self, connect, table, get_index_params):
+ '''
+        target: test delete vectors after index is created
+ method: call `delete_vectors_by_range`, with valid date params
+ expected: return code 0
+ '''
+ start_date = get_current_day()
+ end_date = get_next_day(2)
+ index_params = get_index_params
+ logging.getLogger().info(index_params)
+ status, ids = connect.add_vectors(table, vectors)
+ status = connect.create_index(table, index_params)
+ logging.getLogger().info(status)
+ logging.getLogger().info("Start delete vectors by range: %s:%s" % (start_date, end_date))
+ status = connect.delete_vectors_by_range(table, start_date, end_date)
+ assert status.OK()
+ status, result = connect.get_table_row_count(table)
+ assert result == 0
+
+ @pytest.mark.timeout(DELETE_TIMEOUT)
+ def test_delete_vectors_no_data(self, connect, table):
+ '''
+        target: test delete vectors when there is no data in the table
+ method: call `delete_vectors_by_range`, with valid date params, and no data in db
+ expected: return code 0
+ '''
+ start_date = get_current_day()
+ end_date = get_next_day(2)
+ # status, ids = connect.add_vectors(table, vectors)
+ status = connect.delete_vectors_by_range(table, start_date, end_date)
+ assert status.OK()
+
+ @pytest.mark.timeout(DELETE_TIMEOUT)
+ def test_delete_vectors_table_not_existed(self, connect):
+ '''
+ target: test delete vectors, table not existed in db
+ method: call `delete_vectors_by_range`, with table not existed
+ expected: return code not 0
+ '''
+ start_date = get_current_day()
+ end_date = get_next_day(2)
+ table_name = gen_unique_str("not_existed_table")
+ status = connect.delete_vectors_by_range(table_name, start_date, end_date)
+ assert not status.OK()
+
+ @pytest.mark.timeout(DELETE_TIMEOUT)
+ def test_delete_vectors_table_None(self, connect, table):
+ '''
+        target: test delete vectors, table set to None
+ method: call `delete_vectors_by_range`, with table value is None
+ expected: return code not 0
+ '''
+ start_date = get_current_day()
+ end_date = get_next_day(2)
+ table_name = None
+ with pytest.raises(Exception) as e:
+ status = connect.delete_vectors_by_range(table_name, start_date, end_date)
+
+ @pytest.mark.timeout(DELETE_TIMEOUT)
+ def test_delete_vectors_valid_range_multi_tables(self, connect, get_valid_range):
+ '''
+        target: test whether delete vectors works correctly with multiple tables of L2 metric
+        method: create 50 tables and add vectors into them, then delete vectors
+                in a valid range
+ expected: return code 0
+ '''
+ nq = 100
+ vectors = gen_vectors(nq, dim)
+ table_list = []
+ for i in range(50):
+ table_name = gen_unique_str('test_delete_vectors_valid_range_multi_tables')
+ table_list.append(table_name)
+ param = {'table_name': table_name,
+ 'dimension': dim,
+ 'index_file_size': index_file_size,
+ 'metric_type': MetricType.L2}
+ connect.create_table(param)
+ status, ids = connect.add_vectors(table_name=table_name, records=vectors)
+ time.sleep(2)
+ start_date = get_valid_range[0]
+ end_date = get_valid_range[1]
+ for i in range(50):
+ status = connect.delete_vectors_by_range(table_list[i], start_date, end_date)
+ assert status.OK()
+ status, result = connect.get_table_row_count(table_list[i])
+ assert result == 0
+
+
+class TestDeleteVectorsIP:
+ """
+ generate invalid query range params
+ """
+ @pytest.fixture(
+ scope="function",
+ params=[
+ (get_current_day(), get_current_day()),
+ (get_last_day(1), get_last_day(1)),
+ (get_next_day(1), get_next_day(1))
+ ]
+ )
+ def get_invalid_range(self, request):
+ yield request.param
+
+ @pytest.mark.timeout(DELETE_TIMEOUT)
+ def test_delete_vectors_invalid_range(self, connect, ip_table, get_invalid_range):
+ '''
+ target: test delete vectors, no index created
+ method: call `delete_vectors_by_range`, with invalid date params
+        expected: return code not equals to 0
+ '''
+ start_date = get_invalid_range[0]
+ end_date = get_invalid_range[1]
+ status, ids = connect.add_vectors(ip_table, vectors)
+ status = connect.delete_vectors_by_range(ip_table, start_date, end_date)
+ assert not status.OK()
+
+ """
+ generate valid query range params, no search result
+ """
+ @pytest.fixture(
+ scope="function",
+ params=[
+ (get_last_day(2), get_last_day(1)),
+ (get_last_day(2), get_current_day()),
+ (get_next_day(1), get_next_day(2))
+ ]
+ )
+ def get_valid_range_no_result(self, request):
+ yield request.param
+
+ @pytest.mark.timeout(DELETE_TIMEOUT)
+ def test_delete_vectors_valid_range_no_result(self, connect, ip_table, get_valid_range_no_result):
+ '''
+ target: test delete vectors, no index created
+        method: call `delete_vectors_by_range` with a valid date range that does not cover the inserted vectors
+        expected: return code 0, and no vectors deleted
+ '''
+ start_date = get_valid_range_no_result[0]
+ end_date = get_valid_range_no_result[1]
+ status, ids = connect.add_vectors(ip_table, vectors)
+ time.sleep(2)
+ status = connect.delete_vectors_by_range(ip_table, start_date, end_date)
+ assert status.OK()
+ status, result = connect.get_table_row_count(ip_table)
+ assert result == 100
+
+ """
+ generate valid query range params, no search result
+ """
+ @pytest.fixture(
+ scope="function",
+ params=[
+ (get_last_day(2), get_next_day(2)),
+ (get_current_day(), get_next_day(2)),
+ ]
+ )
+ def get_valid_range(self, request):
+ yield request.param
+
+ @pytest.mark.timeout(DELETE_TIMEOUT)
+ def test_delete_vectors_valid_range(self, connect, ip_table, get_valid_range):
+ '''
+ target: test delete vectors, no index created
+        method: call `delete_vectors_by_range` with a valid date range covering the inserted vectors
+        expected: return code 0, and all vectors deleted
+ '''
+ start_date = get_valid_range[0]
+ end_date = get_valid_range[1]
+ status, ids = connect.add_vectors(ip_table, vectors)
+ time.sleep(2)
+ status = connect.delete_vectors_by_range(ip_table, start_date, end_date)
+ assert status.OK()
+ status, result = connect.get_table_row_count(ip_table)
+ assert result == 0
+
+ @pytest.fixture(
+ scope="function",
+ params=gen_index_params()
+ )
+ def get_index_params(self, request):
+ yield request.param
+
+ @pytest.mark.timeout(DELETE_TIMEOUT)
+ def test_delete_vectors_valid_range_index_created(self, connect, ip_table, get_index_params):
+ '''
+        target: test delete vectors after index is created
+ method: call `delete_vectors_by_range`, with valid date params
+ expected: return code 0
+ '''
+ start_date = get_current_day()
+ end_date = get_next_day(2)
+ index_params = get_index_params
+ logging.getLogger().info(index_params)
+ status, ids = connect.add_vectors(ip_table, vectors)
+ status = connect.create_index(ip_table, index_params)
+ logging.getLogger().info(status)
+ logging.getLogger().info("Start delete vectors by range: %s:%s" % (start_date, end_date))
+ status = connect.delete_vectors_by_range(ip_table, start_date, end_date)
+ assert status.OK()
+ status, result = connect.get_table_row_count(ip_table)
+ assert result == 0
+
+ @pytest.mark.timeout(DELETE_TIMEOUT)
+ def test_delete_vectors_no_data(self, connect, ip_table):
+ '''
+        target: test delete vectors when there is no data in the table
+ method: call `delete_vectors_by_range`, with valid date params, and no data in db
+ expected: return code 0
+ '''
+ start_date = get_current_day()
+ end_date = get_next_day(2)
+ # status, ids = connect.add_vectors(table, vectors)
+ status = connect.delete_vectors_by_range(ip_table, start_date, end_date)
+ assert status.OK()
+
+ @pytest.mark.timeout(DELETE_TIMEOUT)
+ def test_delete_vectors_table_None(self, connect, ip_table):
+ '''
+        target: test delete vectors, table set to None
+ method: call `delete_vectors_by_range`, with table value is None
+ expected: return code not 0
+ '''
+ start_date = get_current_day()
+ end_date = get_next_day(2)
+ table_name = None
+ with pytest.raises(Exception) as e:
+ status = connect.delete_vectors_by_range(table_name, start_date, end_date)
+
+ @pytest.mark.timeout(DELETE_TIMEOUT)
+ def test_delete_vectors_valid_range_multi_tables(self, connect, get_valid_range):
+ '''
+        target: test whether delete vectors works correctly with multiple tables of IP metric
+        method: create 50 tables and add vectors into them, then delete vectors
+                in a valid range
+ expected: return code 0
+ '''
+ nq = 100
+ vectors = gen_vectors(nq, dim)
+ table_list = []
+ for i in range(50):
+ table_name = gen_unique_str('test_delete_vectors_valid_range_multi_tables')
+ table_list.append(table_name)
+ param = {'table_name': table_name,
+ 'dimension': dim,
+ 'index_file_size': index_file_size,
+ 'metric_type': MetricType.IP}
+ connect.create_table(param)
+ status, ids = connect.add_vectors(table_name=table_name, records=vectors)
+ time.sleep(2)
+ start_date = get_valid_range[0]
+ end_date = get_valid_range[1]
+ for i in range(50):
+ status = connect.delete_vectors_by_range(table_list[i], start_date, end_date)
+ assert status.OK()
+ status, result = connect.get_table_row_count(table_list[i])
+ assert result == 0
+
+class TestDeleteVectorsParamsInvalid:
+
+ """
+    Test delete vectors with invalid table names
+ """
+ @pytest.fixture(
+ scope="function",
+ params=gen_invalid_table_names()
+ )
+ def get_table_name(self, request):
+ yield request.param
+
+ @pytest.mark.level(2)
+ def test_delete_vectors_table_invalid_name(self, connect, get_table_name):
+        '''
+        target: test delete vectors with invalid table names
+        expected: return code not equals to 0
+        '''
+ start_date = get_current_day()
+ end_date = get_next_day(2)
+ table_name = get_table_name
+ logging.getLogger().info(table_name)
+ top_k = 1
+ nprobe = 1
+ status = connect.delete_vectors_by_range(table_name, start_date, end_date)
+ assert not status.OK()
+
+ """
+    Test delete vectors with invalid query ranges
+ """
+ @pytest.fixture(
+ scope="function",
+ params=gen_invalid_query_ranges()
+ )
+ def get_query_ranges(self, request):
+ yield request.param
+
+ @pytest.mark.timeout(DELETE_TIMEOUT)
+ def test_delete_vectors_range_invalid(self, connect, table, get_query_ranges):
+ '''
+        target: test delete function with invalid query ranges
+        method: call `delete_vectors_by_range` with the wrong query_range
+        expected: raise an error, and the connection is normal
+ '''
+ start_date = get_query_ranges[0][0]
+ end_date = get_query_ranges[0][1]
+ status, ids = connect.add_vectors(table, vectors)
+ logging.getLogger().info(get_query_ranges)
+ with pytest.raises(Exception) as e:
+ status = connect.delete_vectors_by_range(table, start_date, end_date)
\ No newline at end of file
diff --git a/tests/milvus_python_test/test_index.py b/tests/milvus_python_test/test_index.py
new file mode 100644
index 0000000000..53a5fe16d2
--- /dev/null
+++ b/tests/milvus_python_test/test_index.py
@@ -0,0 +1,966 @@
+"""
+ For testing index operations, including `create_index`, `describe_index` and `drop_index` interfaces
+"""
+import logging
+import pytest
+import time
+import pdb
+import threading
+from multiprocessing import Pool, Process
+import random
+import numpy
+from milvus import Milvus, IndexType, MetricType
+from utils import *
+
+nb = 100000
+dim = 128
+index_file_size = 10
+vectors = gen_vectors(nb, dim)
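+# scale the batch by its overall norm (note: numpy.linalg.norm of the 2-D array is a
+# single scalar, so this is not a per-vector normalization); presumably done so the
+# same data works for both the L2 and IP tables exercised below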
+vectors /= numpy.linalg.norm(vectors)
+vectors = vectors.tolist()
+BUILD_TIMEOUT = 60
+nprobe = 1
+
+
+class TestIndexBase:
+ @pytest.fixture(
+ scope="function",
+ params=gen_index_params()
+ )
+ def get_index_params(self, request):
+ yield request.param
+
+ @pytest.fixture(
+ scope="function",
+ params=gen_simple_index_params()
+ )
+ def get_simple_index_params(self, request):
+ yield request.param
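+    # gen_index_params() is assumed to cover the full matrix of index_type/nlist
+    # combinations from utils, while gen_simple_index_params() is a reduced set used
+    # for the heavier multi-table cases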
+
+ """
+ ******************************************************************
+ The following cases are used to test `create_index` function
+ ******************************************************************
+ """
+
+ @pytest.mark.timeout(BUILD_TIMEOUT)
+ def test_create_index(self, connect, table, get_index_params):
+ '''
+ target: test create index interface
+ method: create table and add vectors in it, create index
+ expected: return code equals to 0, and search success
+ '''
+ index_params = get_index_params
+ logging.getLogger().info(index_params)
+ status, ids = connect.add_vectors(table, vectors)
+ status = connect.create_index(table, index_params)
+ assert status.OK()
+
+ @pytest.mark.level(2)
+ def test_create_index_without_connect(self, dis_connect, table):
+ '''
+ target: test create index without connection
+        method: create index on a table while the client is disconnected
+ expected: raise exception
+ '''
+ with pytest.raises(Exception) as e:
+ status = dis_connect.create_index(table, random.choice(gen_index_params()))
+
+ @pytest.mark.timeout(BUILD_TIMEOUT)
+ def test_create_index_search_with_query_vectors(self, connect, table, get_index_params):
+ '''
+ target: test create index interface, search with more query vectors
+ method: create table and add vectors in it, create index
+ expected: return code equals to 0, and search success
+ '''
+ index_params = get_index_params
+ logging.getLogger().info(index_params)
+ status, ids = connect.add_vectors(table, vectors)
+ status = connect.create_index(table, index_params)
+ logging.getLogger().info(connect.describe_index(table))
+ query_vecs = [vectors[0], vectors[1], vectors[2]]
+ top_k = 5
+ status, result = connect.search_vectors(table, top_k, nprobe, query_vecs)
+ assert status.OK()
+ assert len(result) == len(query_vecs)
+ logging.getLogger().info(result)
+
+ # TODO: enable
+ @pytest.mark.timeout(BUILD_TIMEOUT)
+ @pytest.mark.level(2)
+ def _test_create_index_multiprocessing(self, connect, table, args):
+ '''
+ target: test create index interface with multiprocess
+ method: create table and add vectors in it, create index
+ expected: return code equals to 0, and search success
+ '''
+ status, ids = connect.add_vectors(table, vectors)
+
+ def build(connect):
+ status = connect.create_index(table)
+ assert status.OK()
+
+ process_num = 8
+ processes = []
+ uri = "tcp://%s:%s" % (args["ip"], args["port"])
+
+ for i in range(process_num):
+ m = Milvus()
+ m.connect(uri=uri)
+ p = Process(target=build, args=(m,))
+ processes.append(p)
+ p.start()
+ time.sleep(0.2)
+ for p in processes:
+ p.join()
+
+ query_vec = [vectors[0]]
+ top_k = 1
+ status, result = connect.search_vectors(table, top_k, nprobe, query_vec)
+ assert len(result) == 1
+ assert len(result[0]) == top_k
+ assert result[0][0].distance == 0.0
+
+ # TODO: enable
+ @pytest.mark.timeout(BUILD_TIMEOUT)
+ def _test_create_index_multiprocessing_multitable(self, connect, args):
+ '''
+ target: test create index interface with multiprocess
+ method: create table and add vectors in it, create index
+ expected: return code equals to 0, and search success
+ '''
+ process_num = 8
+ loop_num = 8
+ processes = []
+
+ table = []
+ j = 0
+ while j < (process_num*loop_num):
+ table_name = gen_unique_str("test_create_index_multiprocessing")
+ table.append(table_name)
+ param = {'table_name': table_name,
+ 'dimension': dim,
+ 'index_type': IndexType.FLAT,
+ 'store_raw_vector': False}
+ connect.create_table(param)
+ j = j + 1
+
+        def create_index(connect, proc_id):
+            i = 0
+            while i < loop_num:
+                # assert connect.has_table(table[proc_id*loop_num+i])
+                status, ids = connect.add_vectors(table[proc_id*loop_num+i], vectors)
+
+                status = connect.create_index(table[proc_id*loop_num+i])
+                assert status.OK()
+                query_vec = [vectors[0]]
+                top_k = 1
+                status, result = connect.search_vectors(table[proc_id*loop_num+i], top_k, nprobe, query_vec)
+                assert len(result) == 1
+                assert len(result[0]) == top_k
+                assert result[0][0].distance == 0.0
+                i = i + 1
+
+ uri = "tcp://%s:%s" % (args["ip"], args["port"])
+
+ for i in range(process_num):
+ m = Milvus()
+ m.connect(uri=uri)
+            p = Process(target=create_index, args=(m, i))
+ processes.append(p)
+ p.start()
+ time.sleep(0.2)
+ for p in processes:
+ p.join()
+
+ def test_create_index_table_not_existed(self, connect):
+ '''
+ target: test create index interface when table name not existed
+        method: create index with a random table_name that does not exist in db
+ expected: return code not equals to 0, create index failed
+ '''
+ table_name = gen_unique_str(self.__class__.__name__)
+ status = connect.create_index(table_name, random.choice(gen_index_params()))
+ assert not status.OK()
+
+ def test_create_index_table_None(self, connect):
+ '''
+ target: test create index interface when table name is None
+        method: create index with table_name set to None
+ expected: return code not equals to 0, create index failed
+ '''
+ table_name = None
+ with pytest.raises(Exception) as e:
+ status = connect.create_index(table_name, random.choice(gen_index_params()))
+
+ def test_create_index_no_vectors(self, connect, table):
+ '''
+        target: test create index interface when there are no vectors in the table
+ method: create table and add no vectors in it, and then create index
+ expected: return code equals to 0
+ '''
+ status = connect.create_index(table, random.choice(gen_index_params()))
+ assert status.OK()
+
+ @pytest.mark.timeout(BUILD_TIMEOUT)
+ def test_create_index_no_vectors_then_add_vectors(self, connect, table):
+ '''
+        target: test create index interface when there are no vectors in the table, and that it does not affect the subsequent process
+ method: create table and add no vectors in it, and then create index, add vectors in it
+ expected: return code equals to 0
+ '''
+ status = connect.create_index(table, random.choice(gen_index_params()))
+ status, ids = connect.add_vectors(table, vectors)
+ assert status.OK()
+
+ @pytest.mark.timeout(BUILD_TIMEOUT)
+ def test_create_same_index_repeatedly(self, connect, table):
+ '''
+ target: check if index can be created repeatedly, with the same create_index params
+        method: create index again after index has been built
+ expected: return code success, and search ok
+ '''
+ status, ids = connect.add_vectors(table, vectors)
+ index_params = random.choice(gen_index_params())
+ # index_params = get_index_params
+ status = connect.create_index(table, index_params)
+ status = connect.create_index(table, index_params)
+ assert status.OK()
+ query_vec = [vectors[0]]
+ top_k = 1
+ status, result = connect.search_vectors(table, top_k, nprobe, query_vec)
+ assert len(result) == 1
+ assert len(result[0]) == top_k
+
+ @pytest.mark.timeout(BUILD_TIMEOUT)
+ def test_create_different_index_repeatedly(self, connect, table):
+ '''
+ target: check if index can be created repeatedly, with the different create_index params
+        method: create another index with different index_params after index has been built
+ expected: return code 0, and describe index result equals with the second index params
+ '''
+ status, ids = connect.add_vectors(table, vectors)
+ index_params = random.sample(gen_index_params(), 2)
+ logging.getLogger().info(index_params)
+ status = connect.create_index(table, index_params[0])
+ status = connect.create_index(table, index_params[1])
+ assert status.OK()
+ status, result = connect.describe_index(table)
+ assert result._nlist == index_params[1]["nlist"]
+ assert result._table_name == table
+ assert result._index_type == index_params[1]["index_type"]
+
+ """
+ ******************************************************************
+ The following cases are used to test `describe_index` function
+ ******************************************************************
+ """
+
+ def test_describe_index(self, connect, table, get_index_params):
+ '''
+ target: test describe index interface
+ method: create table and add vectors in it, create index, call describe index
+        expected: return code 0, and index info returned
+ '''
+ index_params = get_index_params
+ logging.getLogger().info(index_params)
+ status, ids = connect.add_vectors(table, vectors)
+ status = connect.create_index(table, index_params)
+ status, result = connect.describe_index(table)
+ logging.getLogger().info(result)
+ assert result._nlist == index_params["nlist"]
+ assert result._table_name == table
+ assert result._index_type == index_params["index_type"]
+
+ def test_describe_and_drop_index_multi_tables(self, connect, get_simple_index_params):
+ '''
+ target: test create, describe and drop index interface with multiple tables of L2
+        method: create tables and add vectors into them, create index, call describe index
+        expected: return code 0, and index info returned
+ '''
+ nq = 100
+ vectors = gen_vectors(nq, dim)
+ table_list = []
+ for i in range(10):
+ table_name = gen_unique_str('test_create_index_multi_tables')
+ table_list.append(table_name)
+ param = {'table_name': table_name,
+ 'dimension': dim,
+ 'index_file_size': index_file_size,
+ 'metric_type': MetricType.L2}
+ connect.create_table(param)
+ index_params = get_simple_index_params
+ logging.getLogger().info(index_params)
+ status, ids = connect.add_vectors(table_name=table_name, records=vectors)
+ status = connect.create_index(table_name, index_params)
+ assert status.OK()
+
+ for i in range(10):
+ status, result = connect.describe_index(table_list[i])
+ logging.getLogger().info(result)
+ assert result._nlist == index_params["nlist"]
+ assert result._table_name == table_list[i]
+ assert result._index_type == index_params["index_type"]
+
+ for i in range(10):
+ status = connect.drop_index(table_list[i])
+ assert status.OK()
+ status, result = connect.describe_index(table_list[i])
+ logging.getLogger().info(result)
+ assert result._nlist == 16384
+ assert result._table_name == table_list[i]
+ assert result._index_type == IndexType.FLAT
+
+ @pytest.mark.level(2)
+ def test_describe_index_without_connect(self, dis_connect, table):
+ '''
+ target: test describe index without connection
+ method: describe index, and check if describe successfully
+ expected: raise exception
+ '''
+ with pytest.raises(Exception) as e:
+ status = dis_connect.describe_index(table)
+
+ def test_describe_index_table_not_existed(self, connect):
+ '''
+ target: test describe index interface when table name not existed
+        method: describe index with a random table_name that does not exist in db
+ expected: return code not equals to 0, describe index failed
+ '''
+ table_name = gen_unique_str(self.__class__.__name__)
+ status, result = connect.describe_index(table_name)
+ assert not status.OK()
+
+ def test_describe_index_table_None(self, connect):
+ '''
+ target: test describe index interface when table name is None
+        method: describe index with table_name set to None
+ expected: return code not equals to 0, describe index failed
+ '''
+ table_name = None
+ with pytest.raises(Exception) as e:
+ status = connect.describe_index(table_name)
+
+ def test_describe_index_not_create(self, connect, table):
+ '''
+ target: test describe index interface when index not created
+        method: create table and add vectors in it, then describe index without creating one
+        expected: return code 0, and the default index info is returned
+ '''
+ status, ids = connect.add_vectors(table, vectors)
+ status, result = connect.describe_index(table)
+ logging.getLogger().info(result)
+ assert status.OK()
+ # assert result._nlist == index_params["nlist"]
+ # assert result._table_name == table
+ # assert result._index_type == index_params["index_type"]
+
+ """
+ ******************************************************************
+ The following cases are used to test `drop_index` function
+ ******************************************************************
+ """
+
+ def test_drop_index(self, connect, table, get_index_params):
+ '''
+ target: test drop index interface
+ method: create table and add vectors in it, create index, call drop index
+ expected: return code 0, and default index param
+ '''
+ index_params = get_index_params
+ status, ids = connect.add_vectors(table, vectors)
+ status = connect.create_index(table, index_params)
+ assert status.OK()
+ status, result = connect.describe_index(table)
+ logging.getLogger().info(result)
+ status = connect.drop_index(table)
+ assert status.OK()
+ status, result = connect.describe_index(table)
+ logging.getLogger().info(result)
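+        # after drop_index the table is expected to report the defaults asserted below:
+        # a FLAT index with nlist 16384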
+ assert result._nlist == 16384
+ assert result._table_name == table
+ assert result._index_type == IndexType.FLAT
+
+ def test_drop_index_repeatly(self, connect, table, get_simple_index_params):
+ '''
+        target: test drop index repeatedly
+ method: create index, call drop index, and drop again
+ expected: return code 0
+ '''
+ index_params = get_simple_index_params
+ status, ids = connect.add_vectors(table, vectors)
+ status = connect.create_index(table, index_params)
+ assert status.OK()
+ status, result = connect.describe_index(table)
+ logging.getLogger().info(result)
+ status = connect.drop_index(table)
+ assert status.OK()
+ status = connect.drop_index(table)
+ assert status.OK()
+ status, result = connect.describe_index(table)
+ logging.getLogger().info(result)
+ assert result._nlist == 16384
+ assert result._table_name == table
+ assert result._index_type == IndexType.FLAT
+
+ @pytest.mark.level(2)
+ def test_drop_index_without_connect(self, dis_connect, table):
+ '''
+ target: test drop index without connection
+ method: drop index, and check if drop successfully
+ expected: raise exception
+ '''
+ with pytest.raises(Exception) as e:
+ status = dis_connect.drop_index(table)
+
+ def test_drop_index_table_not_existed(self, connect):
+ '''
+ target: test drop index interface when table name not existed
+        method: drop index with a random table_name that does not exist in db
+ expected: return code not equals to 0, drop index failed
+ '''
+ table_name = gen_unique_str(self.__class__.__name__)
+ status = connect.drop_index(table_name)
+ assert not status.OK()
+
+ def test_drop_index_table_None(self, connect):
+ '''
+ target: test drop index interface when table name is None
+        method: drop index with table_name set to None
+ expected: return code not equals to 0, drop index failed
+ '''
+ table_name = None
+ with pytest.raises(Exception) as e:
+ status = connect.drop_index(table_name)
+
+ def test_drop_index_table_not_create(self, connect, table):
+ '''
+ target: test drop index interface when index not created
+        method: create table and add vectors in it, then drop index without creating one
+        expected: return code 0
+ '''
+ index_params = random.choice(gen_index_params())
+ logging.getLogger().info(index_params)
+ status, ids = connect.add_vectors(table, vectors)
+ status, result = connect.describe_index(table)
+ logging.getLogger().info(result)
+ # no create index
+ status = connect.drop_index(table)
+ logging.getLogger().info(status)
+ assert status.OK()
+
+ def test_create_drop_index_repeatly(self, connect, table, get_simple_index_params):
+ '''
+        target: test create / drop index repeatedly, using the same index params
+        method: create index then drop index, repeated twice
+ expected: return code 0
+ '''
+ index_params = get_simple_index_params
+ status, ids = connect.add_vectors(table, vectors)
+ for i in range(2):
+ status = connect.create_index(table, index_params)
+ assert status.OK()
+ status, result = connect.describe_index(table)
+ logging.getLogger().info(result)
+ status = connect.drop_index(table)
+ assert status.OK()
+ status, result = connect.describe_index(table)
+ logging.getLogger().info(result)
+ assert result._nlist == 16384
+ assert result._table_name == table
+ assert result._index_type == IndexType.FLAT
+
+ def test_create_drop_index_repeatly_different_index_params(self, connect, table):
+ '''
+        target: test create / drop index repeatedly, using different index params
+        method: create index then drop index, repeated twice, each time using different index_params to create index
+ expected: return code 0
+ '''
+ index_params = random.sample(gen_index_params(), 2)
+ status, ids = connect.add_vectors(table, vectors)
+ for i in range(2):
+ status = connect.create_index(table, index_params[i])
+ assert status.OK()
+ status, result = connect.describe_index(table)
+ logging.getLogger().info(result)
+ status = connect.drop_index(table)
+ assert status.OK()
+ status, result = connect.describe_index(table)
+ logging.getLogger().info(result)
+ assert result._nlist == 16384
+ assert result._table_name == table
+ assert result._index_type == IndexType.FLAT
+
+
+class TestIndexIP:
+ @pytest.fixture(
+ scope="function",
+ params=gen_index_params()
+ )
+ def get_index_params(self, request):
+ yield request.param
+
+ @pytest.fixture(
+ scope="function",
+ params=gen_simple_index_params()
+ )
+ def get_simple_index_params(self, request):
+ yield request.param
+
+ """
+ ******************************************************************
+ The following cases are used to test `create_index` function
+ ******************************************************************
+ """
+
+ @pytest.mark.timeout(BUILD_TIMEOUT)
+ def test_create_index(self, connect, ip_table, get_index_params):
+ '''
+ target: test create index interface
+ method: create table and add vectors in it, create index
+ expected: return code equals to 0, and search success
+ '''
+ index_params = get_index_params
+ logging.getLogger().info(index_params)
+ status, ids = connect.add_vectors(ip_table, vectors)
+ status = connect.create_index(ip_table, index_params)
+ assert status.OK()
+
+ @pytest.mark.level(2)
+ def test_create_index_without_connect(self, dis_connect, ip_table):
+ '''
+ target: test create index without connection
+        method: create index on a table while the client is disconnected
+ expected: raise exception
+ '''
+ with pytest.raises(Exception) as e:
+ status = dis_connect.create_index(ip_table, random.choice(gen_index_params()))
+
+ @pytest.mark.timeout(BUILD_TIMEOUT)
+ def test_create_index_search_with_query_vectors(self, connect, ip_table, get_index_params):
+ '''
+ target: test create index interface, search with more query vectors
+ method: create table and add vectors in it, create index
+ expected: return code equals to 0, and search success
+ '''
+ index_params = get_index_params
+ logging.getLogger().info(index_params)
+ status, ids = connect.add_vectors(ip_table, vectors)
+ status = connect.create_index(ip_table, index_params)
+ logging.getLogger().info(connect.describe_index(ip_table))
+ query_vecs = [vectors[0], vectors[1], vectors[2]]
+ top_k = 5
+ status, result = connect.search_vectors(ip_table, top_k, nprobe, query_vecs)
+ assert status.OK()
+ assert len(result) == len(query_vecs)
+ # logging.getLogger().info(result)
+
+ # TODO: enable
+ @pytest.mark.timeout(BUILD_TIMEOUT)
+ @pytest.mark.level(2)
+ def _test_create_index_multiprocessing(self, connect, ip_table, args):
+ '''
+ target: test create index interface with multiprocess
+ method: create table and add vectors in it, create index
+ expected: return code equals to 0, and search success
+ '''
+ status, ids = connect.add_vectors(ip_table, vectors)
+
+ def build(connect):
+ status = connect.create_index(ip_table)
+ assert status.OK()
+
+ process_num = 8
+ processes = []
+ uri = "tcp://%s:%s" % (args["ip"], args["port"])
+
+ for i in range(process_num):
+ m = Milvus()
+ m.connect(uri=uri)
+ p = Process(target=build, args=(m,))
+ processes.append(p)
+ p.start()
+ time.sleep(0.2)
+ for p in processes:
+ p.join()
+
+ query_vec = [vectors[0]]
+ top_k = 1
+ status, result = connect.search_vectors(ip_table, top_k, nprobe, query_vec)
+ assert len(result) == 1
+ assert len(result[0]) == top_k
+ assert result[0][0].distance == 0.0
+
+ # TODO: enable
+ @pytest.mark.timeout(BUILD_TIMEOUT)
+ def _test_create_index_multiprocessing_multitable(self, connect, args):
+ '''
+ target: test create index interface with multiprocess
+ method: create table and add vectors in it, create index
+ expected: return code equals to 0, and search success
+ '''
+ process_num = 8
+ loop_num = 8
+ processes = []
+
+ table = []
+ j = 0
+ while j < (process_num*loop_num):
+ table_name = gen_unique_str("test_create_index_multiprocessing")
+ table.append(table_name)
+ param = {'table_name': table_name,
+ 'dimension': dim}
+ connect.create_table(param)
+ j = j + 1
+
+        def create_index(connect, proc_id):
+            i = 0
+            while i < loop_num:
+                # assert connect.has_table(table[proc_id*loop_num+i])
+                status, ids = connect.add_vectors(table[proc_id*loop_num+i], vectors)
+
+                status = connect.create_index(table[proc_id*loop_num+i])
+                assert status.OK()
+                query_vec = [vectors[0]]
+                top_k = 1
+                status, result = connect.search_vectors(table[proc_id*loop_num+i], top_k, nprobe, query_vec)
+                assert len(result) == 1
+                assert len(result[0]) == top_k
+                assert result[0][0].distance == 0.0
+                i = i + 1
+
+ uri = "tcp://%s:%s" % (args["ip"], args["port"])
+
+ for i in range(process_num):
+ m = Milvus()
+ m.connect(uri=uri)
+            p = Process(target=create_index, args=(m, i))
+ processes.append(p)
+ p.start()
+ time.sleep(0.2)
+ for p in processes:
+ p.join()
+
+ def test_create_index_no_vectors(self, connect, ip_table):
+ '''
+        target: test create index interface when there are no vectors in the table
+ method: create table and add no vectors in it, and then create index
+ expected: return code equals to 0
+ '''
+ status = connect.create_index(ip_table, random.choice(gen_index_params()))
+ assert status.OK()
+
+ @pytest.mark.timeout(BUILD_TIMEOUT)
+ def test_create_index_no_vectors_then_add_vectors(self, connect, ip_table):
+ '''
+        target: test create index interface when there are no vectors in the table, and that it does not affect the subsequent process
+ method: create table and add no vectors in it, and then create index, add vectors in it
+ expected: return code equals to 0
+ '''
+ status = connect.create_index(ip_table, random.choice(gen_index_params()))
+ status, ids = connect.add_vectors(ip_table, vectors)
+ assert status.OK()
+
+ @pytest.mark.timeout(BUILD_TIMEOUT)
+ def test_create_same_index_repeatedly(self, connect, ip_table):
+ '''
+ target: check if index can be created repeatedly, with the same create_index params
+        method: create index again after index has been built
+ expected: return code success, and search ok
+ '''
+ status, ids = connect.add_vectors(ip_table, vectors)
+ index_params = random.choice(gen_index_params())
+ # index_params = get_index_params
+ status = connect.create_index(ip_table, index_params)
+ status = connect.create_index(ip_table, index_params)
+ assert status.OK()
+ query_vec = [vectors[0]]
+ top_k = 1
+ status, result = connect.search_vectors(ip_table, top_k, nprobe, query_vec)
+ assert len(result) == 1
+ assert len(result[0]) == top_k
+
+ @pytest.mark.timeout(BUILD_TIMEOUT)
+ def test_create_different_index_repeatedly(self, connect, ip_table):
+ '''
+ target: check if index can be created repeatedly, with the different create_index params
+        method: create another index with different index_params after index has been built
+ expected: return code 0, and describe index result equals with the second index params
+ '''
+ status, ids = connect.add_vectors(ip_table, vectors)
+ index_params = random.sample(gen_index_params(), 2)
+ logging.getLogger().info(index_params)
+ status = connect.create_index(ip_table, index_params[0])
+ status = connect.create_index(ip_table, index_params[1])
+ assert status.OK()
+ status, result = connect.describe_index(ip_table)
+ assert result._nlist == index_params[1]["nlist"]
+ assert result._table_name == ip_table
+ assert result._index_type == index_params[1]["index_type"]
+
+ """
+ ******************************************************************
+ The following cases are used to test `describe_index` function
+ ******************************************************************
+ """
+
+ def test_describe_index(self, connect, ip_table, get_index_params):
+ '''
+ target: test describe index interface
+ method: create table and add vectors in it, create index, call describe index
+        expected: return code 0, and index info returned
+ '''
+ index_params = get_index_params
+ logging.getLogger().info(index_params)
+ status, ids = connect.add_vectors(ip_table, vectors)
+ status = connect.create_index(ip_table, index_params)
+ status, result = connect.describe_index(ip_table)
+ logging.getLogger().info(result)
+ assert result._nlist == index_params["nlist"]
+ assert result._table_name == ip_table
+ assert result._index_type == index_params["index_type"]
+
+ def test_describe_and_drop_index_multi_tables(self, connect, get_simple_index_params):
+ '''
+ target: test create, describe and drop index interface with multiple tables of IP
+        method: create tables and add vectors into them, create index, call describe index
+        expected: return code 0, and index info returned
+ '''
+ nq = 100
+ vectors = gen_vectors(nq, dim)
+ table_list = []
+ for i in range(10):
+ table_name = gen_unique_str('test_create_index_multi_tables')
+ table_list.append(table_name)
+ param = {'table_name': table_name,
+ 'dimension': dim,
+ 'index_file_size': index_file_size,
+ 'metric_type': MetricType.IP}
+ connect.create_table(param)
+ index_params = get_simple_index_params
+ logging.getLogger().info(index_params)
+ status, ids = connect.add_vectors(table_name=table_name, records=vectors)
+ status = connect.create_index(table_name, index_params)
+ assert status.OK()
+
+ for i in range(10):
+ status, result = connect.describe_index(table_list[i])
+ logging.getLogger().info(result)
+ assert result._nlist == index_params["nlist"]
+ assert result._table_name == table_list[i]
+ assert result._index_type == index_params["index_type"]
+
+ for i in range(10):
+ status = connect.drop_index(table_list[i])
+ assert status.OK()
+ status, result = connect.describe_index(table_list[i])
+ logging.getLogger().info(result)
+ assert result._nlist == 16384
+ assert result._table_name == table_list[i]
+ assert result._index_type == IndexType.FLAT
+
+ @pytest.mark.level(2)
+ def test_describe_index_without_connect(self, dis_connect, ip_table):
+ '''
+ target: test describe index without connection
+ method: describe index, and check if describe successfully
+ expected: raise exception
+ '''
+ with pytest.raises(Exception) as e:
+ status = dis_connect.describe_index(ip_table)
+
+ def test_describe_index_not_create(self, connect, ip_table):
+ '''
+ target: test describe index interface when index not created
+        method: create table and add vectors in it, then describe index without creating one
+        expected: return code 0, and the default index info is returned
+ '''
+ status, ids = connect.add_vectors(ip_table, vectors)
+ status, result = connect.describe_index(ip_table)
+ logging.getLogger().info(result)
+ assert status.OK()
+ # assert result._nlist == index_params["nlist"]
+ # assert result._table_name == table
+ # assert result._index_type == index_params["index_type"]
+
+ """
+ ******************************************************************
+ The following cases are used to test `drop_index` function
+ ******************************************************************
+ """
+
+ def test_drop_index(self, connect, ip_table, get_index_params):
+ '''
+ target: test drop index interface
+ method: create table and add vectors in it, create index, call drop index
+ expected: return code 0, and default index param
+ '''
+ index_params = get_index_params
+ status, ids = connect.add_vectors(ip_table, vectors)
+ status = connect.create_index(ip_table, index_params)
+ assert status.OK()
+ status, result = connect.describe_index(ip_table)
+ logging.getLogger().info(result)
+ status = connect.drop_index(ip_table)
+ assert status.OK()
+ status, result = connect.describe_index(ip_table)
+ logging.getLogger().info(result)
+ assert result._nlist == 16384
+ assert result._table_name == ip_table
+ assert result._index_type == IndexType.FLAT
+
+ def test_drop_index_repeatly(self, connect, ip_table, get_simple_index_params):
+ '''
+        target: test drop index repeatedly
+ method: create index, call drop index, and drop again
+ expected: return code 0
+ '''
+ index_params = get_simple_index_params
+ status, ids = connect.add_vectors(ip_table, vectors)
+ status = connect.create_index(ip_table, index_params)
+ assert status.OK()
+ status, result = connect.describe_index(ip_table)
+ logging.getLogger().info(result)
+ status = connect.drop_index(ip_table)
+ assert status.OK()
+ status = connect.drop_index(ip_table)
+ assert status.OK()
+ status, result = connect.describe_index(ip_table)
+ logging.getLogger().info(result)
+ assert result._nlist == 16384
+ assert result._table_name == ip_table
+ assert result._index_type == IndexType.FLAT
+
+ @pytest.mark.level(2)
+ def test_drop_index_without_connect(self, dis_connect, ip_table):
+ '''
+ target: test drop index without connection
+ method: drop index, and check if drop successfully
+ expected: raise exception
+ '''
+ with pytest.raises(Exception) as e:
+            status = dis_connect.drop_index(ip_table)
+
+ def test_drop_index_table_not_create(self, connect, ip_table):
+ '''
+ target: test drop index interface when index not created
+        method: create table and add vectors in it, then drop index without creating one
+        expected: return code 0
+ '''
+ index_params = random.choice(gen_index_params())
+ logging.getLogger().info(index_params)
+ status, ids = connect.add_vectors(ip_table, vectors)
+ status, result = connect.describe_index(ip_table)
+ logging.getLogger().info(result)
+ # no create index
+ status = connect.drop_index(ip_table)
+ logging.getLogger().info(status)
+ assert status.OK()
+
+ def test_create_drop_index_repeatly(self, connect, ip_table, get_simple_index_params):
+ '''
+        target: test create / drop index repeatedly, using the same index params
+        method: create index then drop index, repeated twice
+ expected: return code 0
+ '''
+ index_params = get_simple_index_params
+ status, ids = connect.add_vectors(ip_table, vectors)
+ for i in range(2):
+ status = connect.create_index(ip_table, index_params)
+ assert status.OK()
+ status, result = connect.describe_index(ip_table)
+ logging.getLogger().info(result)
+ status = connect.drop_index(ip_table)
+ assert status.OK()
+ status, result = connect.describe_index(ip_table)
+ logging.getLogger().info(result)
+ assert result._nlist == 16384
+ assert result._table_name == ip_table
+ assert result._index_type == IndexType.FLAT
+
+ def test_create_drop_index_repeatly_different_index_params(self, connect, ip_table):
+ '''
+        target: test create / drop index repeatedly, using different index params
+        method: create index then drop index, repeated twice, each time using different index_params to create index
+ expected: return code 0
+ '''
+ index_params = random.sample(gen_index_params(), 2)
+ status, ids = connect.add_vectors(ip_table, vectors)
+ for i in range(2):
+ status = connect.create_index(ip_table, index_params[i])
+ assert status.OK()
+ status, result = connect.describe_index(ip_table)
+ assert result._nlist == index_params[i]["nlist"]
+ assert result._table_name == ip_table
+ assert result._index_type == index_params[i]["index_type"]
+ status, result = connect.describe_index(ip_table)
+ logging.getLogger().info(result)
+ status = connect.drop_index(ip_table)
+ assert status.OK()
+ status, result = connect.describe_index(ip_table)
+ logging.getLogger().info(result)
+ assert result._nlist == 16384
+ assert result._table_name == ip_table
+ assert result._index_type == IndexType.FLAT
+
+
+class TestIndexTableInvalid(object):
+ """
+ Test create / describe / drop index interfaces with invalid table names
+ """
+ @pytest.fixture(
+ scope="function",
+ params=gen_invalid_table_names()
+ )
+ def get_table_name(self, request):
+ yield request.param
+
+ # @pytest.mark.level(1)
+ def test_create_index_with_invalid_tablename(self, connect, get_table_name):
+ table_name = get_table_name
+ status = connect.create_index(table_name, random.choice(gen_index_params()))
+ assert not status.OK()
+
+ # @pytest.mark.level(1)
+ def test_describe_index_with_invalid_tablename(self, connect, get_table_name):
+ table_name = get_table_name
+ status, result = connect.describe_index(table_name)
+ assert not status.OK()
+
+ # @pytest.mark.level(1)
+ def test_drop_index_with_invalid_tablename(self, connect, get_table_name):
+ table_name = get_table_name
+ status = connect.drop_index(table_name)
+ assert not status.OK()
+
+
+class TestCreateIndexParamsInvalid(object):
+ """
+    Test building index with invalid index params
+ """
+ @pytest.fixture(
+ scope="function",
+ params=gen_invalid_index_params()
+ )
+ def get_index_params(self, request):
+ yield request.param
+
+ @pytest.mark.level(2)
+ def test_create_index_with_invalid_index_params(self, connect, table, get_index_params):
+ index_params = get_index_params
+ index_type = index_params["index_type"]
+ nlist = index_params["nlist"]
+ logging.getLogger().info(index_params)
+ status, ids = connect.add_vectors(table, vectors)
+ # if not isinstance(index_type, int) or not isinstance(nlist, int):
+ with pytest.raises(Exception) as e:
+ status = connect.create_index(table, index_params)
+ # else:
+ # status = connect.create_index(table, index_params)
+ # assert not status.OK()
diff --git a/tests/milvus_python_test/test_mix.py b/tests/milvus_python_test/test_mix.py
new file mode 100644
index 0000000000..4578e330b3
--- /dev/null
+++ b/tests/milvus_python_test/test_mix.py
@@ -0,0 +1,180 @@
+import pdb
+import copy
+import pytest
+import threading
+import datetime
+import logging
+from time import sleep
+from multiprocessing import Process
+import numpy
+from milvus import Milvus, IndexType, MetricType
+from utils import *
+
+dim = 128
+index_file_size = 10
+table_id = "test_mix"
+add_interval_time = 2
+vectors = gen_vectors(100000, dim)
+vectors /= numpy.linalg.norm(vectors)
+vectors = vectors.tolist()
+top_k = 1
+nprobe = 1
+epsilon = 0.0001
+index_params = {'index_type': IndexType.IVFLAT, 'nlist': 16384}
+
+
+class TestMixBase:
+
+ # TODO: enable
+ def _test_search_during_createIndex(self, args):
+ loops = 100000
+ table = "test_search_during_createIndex"
+ query_vecs = [vectors[0], vectors[1]]
+ uri = "tcp://%s:%s" % (args["ip"], args["port"])
+ id_0 = 0; id_1 = 0
+ milvus_instance = Milvus()
+ milvus_instance.connect(uri=uri)
+ milvus_instance.create_table({'table_name': table,
+ 'dimension': dim,
+ 'index_file_size': index_file_size,
+ 'metric_type': MetricType.L2})
+ for i in range(10):
+ status, ids = milvus_instance.add_vectors(table, vectors)
+ # logging.getLogger().info(ids)
+ if i == 0:
+ id_0 = ids[0]; id_1 = ids[1]
+ def create_index(milvus_instance):
+ logging.getLogger().info("In create index")
+ status = milvus_instance.create_index(table, index_params)
+ logging.getLogger().info(status)
+ status, result = milvus_instance.describe_index(table)
+ logging.getLogger().info(result)
+ def add_vectors(milvus_instance):
+ logging.getLogger().info("In add vectors")
+ status, ids = milvus_instance.add_vectors(table, vectors)
+ logging.getLogger().info(status)
+ def search(milvus_instance):
+ for i in range(loops):
+ status, result = milvus_instance.search_vectors(table, top_k, nprobe, query_vecs)
+ logging.getLogger().info(status)
+ assert result[0][0].id == id_0
+ assert result[1][0].id == id_1
+ milvus_instance = Milvus()
+ milvus_instance.connect(uri=uri)
+ p_search = Process(target=search, args=(milvus_instance, ))
+ p_search.start()
+ milvus_instance = Milvus()
+ milvus_instance.connect(uri=uri)
+ p_create = Process(target=add_vectors, args=(milvus_instance, ))
+ p_create.start()
+ p_create.join()
+
+ def test_mix_multi_tables(self, connect):
+ '''
+ target: test functions with multiple tables of different metric_types and index_types
+ method: create 60 tables, 30 with L2 and 30 with IP metric_type, add vectors into them,
+ then test describe index and search
+ expected: status ok
+ '''
+ nq = 10000
+ vectors = gen_vectors(nq, dim)
+ table_list = []
+ idx = []
+
+ # create tables and add vectors
+ for i in range(30):
+ table_name = gen_unique_str('test_mix_multi_tables')
+ table_list.append(table_name)
+ param = {'table_name': table_name,
+ 'dimension': dim,
+ 'index_file_size': index_file_size,
+ 'metric_type': MetricType.L2}
+ connect.create_table(param)
+ status, ids = connect.add_vectors(table_name=table_name, records=vectors)
+ idx.append(ids[0])
+ idx.append(ids[10])
+ idx.append(ids[20])
+ assert status.OK()
+ for i in range(30):
+ table_name = gen_unique_str('test_mix_multi_tables')
+ table_list.append(table_name)
+ param = {'table_name': table_name,
+ 'dimension': dim,
+ 'index_file_size': index_file_size,
+ 'metric_type': MetricType.IP}
+ connect.create_table(param)
+ status, ids = connect.add_vectors(table_name=table_name, records=vectors)
+ idx.append(ids[0])
+ idx.append(ids[10])
+ idx.append(ids[20])
+ assert status.OK()
+ time.sleep(2)
+
+ #create index
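+ # table_list[0:30] use metric_type L2 and table_list[30:60] use IP; within each
+ # half, consecutive groups of 10 tables get FLAT, IVFLAT and IVF_SQ8 indexes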
+ for i in range(10):
+ index_params = {'index_type': IndexType.FLAT, 'nlist': 16384}
+ status = connect.create_index(table_list[i], index_params)
+ assert status.OK()
+ status = connect.create_index(table_list[30 + i], index_params)
+ assert status.OK()
+ index_params = {'index_type': IndexType.IVFLAT, 'nlist': 16384}
+ status = connect.create_index(table_list[10 + i], index_params)
+ assert status.OK()
+ status = connect.create_index(table_list[40 + i], index_params)
+ assert status.OK()
+ index_params = {'index_type': IndexType.IVF_SQ8, 'nlist': 16384}
+ status = connect.create_index(table_list[20 + i], index_params)
+ assert status.OK()
+ status = connect.create_index(table_list[50 + i], index_params)
+ assert status.OK()
+
+ #describe index
+ for i in range(10):
+ status, result = connect.describe_index(table_list[i])
+ logging.getLogger().info(result)
+ assert result._nlist == 16384
+ assert result._table_name == table_list[i]
+ assert result._index_type == IndexType.FLAT
+ status, result = connect.describe_index(table_list[10 + i])
+ logging.getLogger().info(result)
+ assert result._nlist == 16384
+ assert result._table_name == table_list[10 + i]
+ assert result._index_type == IndexType.IVFLAT
+ status, result = connect.describe_index(table_list[20 + i])
+ logging.getLogger().info(result)
+ assert result._nlist == 16384
+ assert result._table_name == table_list[20 + i]
+ assert result._index_type == IndexType.IVF_SQ8
+ status, result = connect.describe_index(table_list[30 + i])
+ logging.getLogger().info(result)
+ assert result._nlist == 16384
+ assert result._table_name == table_list[30 + i]
+ assert result._index_type == IndexType.FLAT
+ status, result = connect.describe_index(table_list[40 + i])
+ logging.getLogger().info(result)
+ assert result._nlist == 16384
+ assert result._table_name == table_list[40 + i]
+ assert result._index_type == IndexType.IVFLAT
+ status, result = connect.describe_index(table_list[50 + i])
+ logging.getLogger().info(result)
+ assert result._nlist == 16384
+ assert result._table_name == table_list[50 + i]
+ assert result._index_type == IndexType.IVF_SQ8
+
+ #search
+ query_vecs = [vectors[0], vectors[10], vectors[20]]
+ for i in range(60):
+ table = table_list[i]
+ status, result = connect.search_vectors(table, top_k, nprobe, query_vecs)
+ assert status.OK()
+ assert len(result) == len(query_vecs)
+ for j in range(len(query_vecs)):
+ assert len(result[j]) == top_k
+ for j in range(len(query_vecs)):
+ assert check_result(result[j], idx[3 * i + j])
+
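+# A result counts as correct when the expected id appears among the top-5
+# returned candidates (or among all candidates when fewer than 5 are returned),
+# which tolerates reordering of near-duplicate distances.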
+def check_result(result, id):
+ if len(result) >= 5:
+ return id in [result[0].id, result[1].id, result[2].id, result[3].id, result[4].id]
+ else:
+ return id in (i.id for i in result)
\ No newline at end of file
diff --git a/tests/milvus_python_test/test_ping.py b/tests/milvus_python_test/test_ping.py
new file mode 100644
index 0000000000..a55559bc63
--- /dev/null
+++ b/tests/milvus_python_test/test_ping.py
@@ -0,0 +1,77 @@
+import logging
+import pytest
+
+__version__ = '0.5.0'
+
+
+class TestPing:
+ def test_server_version(self, connect):
+ '''
+ target: test get the server version
+ method: call the server_version method after connected
+ expected: version should be the pymilvus version
+ '''
+ status, res = connect.server_version()
+ assert res == __version__
+
+ def test_server_status(self, connect):
+ '''
+ target: test get the server status
+ method: call the server_status method after connected
+ expected: status returned should be ok
+ '''
+ status, msg = connect.server_status()
+ assert status.OK()
+
+ def _test_server_cmd_with_params_version(self, connect):
+ '''
+ target: test cmd: version
+ method: cmd = "version" ...
+ expected: when cmd = 'version', return version of server;
+ '''
+ cmd = "version"
+ status, msg = connect.cmd(cmd)
+ logging.getLogger().info(status)
+ logging.getLogger().info(msg)
+ assert status.OK()
+ assert msg == __version__
+
+ def _test_server_cmd_with_params_others(self, connect):
+ '''
+ target: test cmd with an arbitrary command string
+ method: cmd = "rm -rf test" ...
+ expected: status ok
+ '''
+ cmd = "rm -rf test"
+ status, msg = connect.cmd(cmd)
+ logging.getLogger().info(status)
+ logging.getLogger().info(msg)
+ assert status.OK()
+ # assert msg == __version__
+
+ def test_connected(self, connect):
+ assert connect.connected()
+
+
+class TestPingDisconnect:
+ def test_server_version(self, dis_connect):
+ '''
+ target: test get the server version, after disconnect
+ method: call the server_version method after disconnect
+ expected: raise an exception, and no version is returned
+ '''
+ res = None
+ with pytest.raises(Exception) as e:
+ status, res = dis_connect.server_version()
+ assert res is None
+
+ def test_server_status(self, dis_connect):
+ '''
+ target: test get the server status, after disconnect
+ method: call the server_status method after disconnect
+ expected: raise an exception, and no status is returned
+ '''
+ status = None
+ with pytest.raises(Exception) as e:
+ status, msg = dis_connect.server_status()
+ assert status is None
diff --git a/tests/milvus_python_test/test_search_vectors.py b/tests/milvus_python_test/test_search_vectors.py
new file mode 100644
index 0000000000..4f02ae0d37
--- /dev/null
+++ b/tests/milvus_python_test/test_search_vectors.py
@@ -0,0 +1,650 @@
+import pdb
+import copy
+import pytest
+import threading
+import datetime
+import logging
+import time
+from time import sleep
+from multiprocessing import Process
+import numpy
+from milvus import Milvus, IndexType, MetricType
+from utils import *
+
+dim = 128
+table_id = "test_search"
+add_interval_time = 2
+vectors = gen_vectors(100, dim)
+# vectors /= numpy.linalg.norm(vectors)
+# vectors = vectors.tolist()
+nprobe = 1
+epsilon = 0.001
+
+
+class TestSearchBase:
+ def init_data(self, connect, table, nb=100):
+ '''
+ Generate vectors and add it in table, before search vectors
+ '''
+ global vectors
+ if nb == 100:
+ add_vectors = vectors
+ else:
+ add_vectors = gen_vectors(nb, dim)
+ # add_vectors /= numpy.linalg.norm(add_vectors)
+ # add_vectors = add_vectors.tolist()
+ status, ids = connect.add_vectors(table, add_vectors)
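+ # give the server a moment so the newly added vectors become searchable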
+ sleep(add_interval_time)
+ return add_vectors, ids
+
+ """
+ generate valid create_index params
+ """
+ @pytest.fixture(
+ scope="function",
+ params=gen_index_params()
+ )
+ def get_index_params(self, request):
+ yield request.param
+
+ """
+ generate top-k params
+ """
+ @pytest.fixture(
+ scope="function",
+ params=[1, 99, 101, 1024, 2048, 2049]
+ )
+ def get_top_k(self, request):
+ yield request.param
+
+
+ def test_search_top_k_flat_index(self, connect, table, get_top_k):
+ '''
+ target: test basic search function, all the search params are correct, change the top-k value
+ method: search with the given vectors, check the result
+ expected: search status ok, and the length of the result is top_k
+ '''
+ vectors, ids = self.init_data(connect, table)
+ query_vec = [vectors[0]]
+ top_k = get_top_k
+ nprobe = 1
+ status, result = connect.search_vectors(table, top_k, nprobe, query_vec)
+ if top_k <= 2048:
+ assert status.OK()
+ assert len(result[0]) == min(len(vectors), top_k)
+ assert result[0][0].distance <= epsilon
+ assert check_result(result[0], ids[0])
+ else:
+ assert not status.OK()
+
+ def test_search_l2_index_params(self, connect, table, get_index_params):
+ '''
+ target: test basic search function, all the search params are correct, test all index params, and build index
+ method: search with the given vectors, check the result
+ expected: search status ok, and the length of the result is top_k
+ '''
+
+ index_params = get_index_params
+ logging.getLogger().info(index_params)
+ vectors, ids = self.init_data(connect, table)
+ status = connect.create_index(table, index_params)
+ query_vec = [vectors[0]]
+ top_k = 10
+ nprobe = 1
+ status, result = connect.search_vectors(table, top_k, nprobe, query_vec)
+ logging.getLogger().info(result)
+ if top_k <= 1024:
+ assert status.OK()
+ assert len(result[0]) == min(len(vectors), top_k)
+ assert check_result(result[0], ids[0])
+ assert result[0][0].distance <= epsilon
+ else:
+ assert not status.OK()
+
+ def test_search_ip_index_params(self, connect, ip_table, get_index_params):
+ '''
+ target: test basic search function, all the search params are correct, test all index params, and build index
+ method: search with the given vectors, check the result
+ expected: search status ok, and the length of the result is top_k
+ '''
+
+ index_params = get_index_params
+ logging.getLogger().info(index_params)
+ vectors, ids = self.init_data(connect, ip_table)
+ status = connect.create_index(ip_table, index_params)
+ query_vec = [vectors[0]]
+ top_k = 10
+ nprobe = 1
+ status, result = connect.search_vectors(ip_table, top_k, nprobe, query_vec)
+ logging.getLogger().info(result)
+
+ if top_k <= 1024:
+ assert status.OK()
+ assert len(result[0]) == min(len(vectors), top_k)
+ assert check_result(result[0], ids[0])
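+ # the query vector equals the inserted vector, so the best IP score should
+ # match the inner product of the query vector with itself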
+ assert abs(result[0][0].distance - numpy.inner(numpy.array(query_vec[0]), numpy.array(query_vec[0]))) <= gen_inaccuracy(result[0][0].distance)
+ else:
+ assert not status.OK()
+
+ @pytest.mark.level(2)
+ def test_search_vectors_without_connect(self, dis_connect, table):
+ '''
+ target: test search vectors without connection
+ method: use a disconnected instance to call the search method
+ expected: raise exception
+ '''
+ query_vectors = [vectors[0]]
+ top_k = 1
+ nprobe = 1
+ with pytest.raises(Exception) as e:
+ status, ids = dis_connect.search_vectors(table, top_k, nprobe, query_vectors)
+
+ def test_search_table_name_not_existed(self, connect, table):
+ '''
+ target: search a table that does not exist
+ method: search with a random table_name which is not in db
+ expected: status not ok
+ '''
+ table_name = gen_unique_str("not_existed_table")
+ top_k = 1
+ nprobe = 1
+ query_vecs = [vectors[0]]
+ status, result = connect.search_vectors(table_name, top_k, nprobe, query_vecs)
+ assert not status.OK()
+
+ def test_search_table_name_None(self, connect, table):
+ '''
+ target: search a table whose name is None
+ method: search with the table_name: None
+ expected: status not ok
+ '''
+ table_name = None
+ top_k = 1
+ nprobe = 1
+ query_vecs = [vectors[0]]
+ with pytest.raises(Exception) as e:
+ status, result = connect.search_vectors(table_name, top_k, nprobe, query_vecs)
+
+ def test_search_top_k_query_records(self, connect, table):
+ '''
+ target: test search function, with search params: query_records
+ method: search with the given query_records, which are subarrays of the inserted vectors
+ expected: status ok and the returned vectors should be query_records
+ '''
+ top_k = 10
+ nprobe = 1
+ vectors, ids = self.init_data(connect, table)
+ query_vecs = [vectors[0],vectors[55],vectors[99]]
+ status, result = connect.search_vectors(table, top_k, nprobe, query_vecs)
+ assert status.OK()
+ assert len(result) == len(query_vecs)
+ for i in range(len(query_vecs)):
+ assert len(result[i]) == top_k
+ assert result[i][0].distance <= epsilon
+
+ """
+ generate invalid query range params
+ """
+ @pytest.fixture(
+ scope="function",
+ params=[
+ (get_current_day(), get_current_day()),
+ (get_last_day(1), get_last_day(1)),
+ (get_next_day(1), get_next_day(1))
+ ]
+ )
+ def get_invalid_range(self, request):
+ yield request.param
+
+ def test_search_invalid_query_ranges(self, connect, table, get_invalid_range):
+ '''
+ target: search table with invalid query ranges
+ method: search with query ranges whose start equals end
+ expected: status not ok
+ '''
+ top_k = 2
+ nprobe = 1
+ vectors, ids = self.init_data(connect, table)
+ query_vecs = [vectors[0]]
+ query_ranges = [get_invalid_range]
+ status, result = connect.search_vectors(table, top_k, nprobe, query_vecs, query_ranges=query_ranges)
+ assert not status.OK()
+ assert len(result) == 0
+
+ """
+ generate valid query range params, no search result
+ """
+ @pytest.fixture(
+ scope="function",
+ params=[
+ (get_last_day(2), get_last_day(1)),
+ (get_last_day(2), get_current_day()),
+ (get_next_day(1), get_next_day(2))
+ ]
+ )
+ def get_valid_range_no_result(self, request):
+ yield request.param
+
+ def test_search_valid_query_ranges_no_result(self, connect, table, get_valid_range_no_result):
+ '''
+ target: search table with valid query ranges that do not match the inserted data
+ method: search with query ranges that exclude the insert time
+ expected: length of result is 0
+ '''
+ top_k = 2
+ nprobe = 1
+ vectors, ids = self.init_data(connect, table)
+ query_vecs = [vectors[0]]
+ query_ranges = [get_valid_range_no_result]
+ status, result = connect.search_vectors(table, top_k, nprobe, query_vecs, query_ranges=query_ranges)
+ assert status.OK()
+ assert len(result) == 0
+
+ """
+ generate valid query range params, no search result
+ """
+ @pytest.fixture(
+ scope="function",
+ params=[
+ (get_last_day(2), get_next_day(2)),
+ (get_current_day(), get_next_day(2)),
+ ]
+ )
+ def get_valid_range(self, request):
+ yield request.param
+
+ def test_search_valid_query_ranges(self, connect, table, get_valid_range):
+ '''
+ target: search table with valid query ranges covering the insert time
+ method: search with query ranges that include the current day
+ expected: status ok, and the inserted vector is returned
+ '''
+ top_k = 2
+ nprobe = 1
+ vectors, ids = self.init_data(connect, table)
+ query_vecs = [vectors[0]]
+ query_ranges = [get_valid_range]
+ status, result = connect.search_vectors(table, top_k, nprobe, query_vecs, query_ranges=query_ranges)
+ assert status.OK()
+ assert len(result) == 1
+ assert result[0][0].distance <= epsilon
+
+ def test_search_distance_l2_flat_index(self, connect, table):
+ '''
+ target: search table, and check the result: distance
+ method: compare the return distance value with value computed with Euclidean
+ expected: the return distance equals to the computed value
+ '''
+ nb = 2
+ top_k = 1
+ nprobe = 1
+ vectors, ids = self.init_data(connect, table, nb=nb)
+ query_vecs = [[0.50 for i in range(dim)]]
+ distance_0 = numpy.linalg.norm(numpy.array(query_vecs[0]) - numpy.array(vectors[0]))
+ distance_1 = numpy.linalg.norm(numpy.array(query_vecs[0]) - numpy.array(vectors[1]))
+ status, result = connect.search_vectors(table, top_k, nprobe, query_vecs)
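+ # the returned distance appears to be squared L2, hence the square root
+ # before comparing against the Euclidean distances computed with numpy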
+ assert abs(numpy.sqrt(result[0][0].distance) - min(distance_0, distance_1)) <= gen_inaccuracy(result[0][0].distance)
+
+ def test_search_distance_ip_flat_index(self, connect, ip_table):
+ '''
+ target: search ip_table, and check the result: distance
+ method: compare the return distance value with value computed with Inner product
+ expected: the return distance equals to the computed value
+ '''
+ nb = 2
+ top_k = 1
+ nprobe = 1
+ vectors, ids = self.init_data(connect, ip_table, nb=nb)
+ index_params = {
+ "index_type": IndexType.FLAT,
+ "nlist": 16384
+ }
+ connect.create_index(ip_table, index_params)
+ logging.getLogger().info(connect.describe_index(ip_table))
+ query_vecs = [[0.50 for i in range(dim)]]
+ distance_0 = numpy.inner(numpy.array(query_vecs[0]), numpy.array(vectors[0]))
+ distance_1 = numpy.inner(numpy.array(query_vecs[0]), numpy.array(vectors[1]))
+ status, result = connect.search_vectors(ip_table, top_k, nprobe, query_vecs)
+ assert abs(result[0][0].distance - max(distance_0, distance_1)) <= gen_inaccuracy(result[0][0].distance)
+
+ def test_search_distance_ip_index_params(self, connect, ip_table, get_index_params):
+ '''
+ target: search table, and check the result: distance
+ method: compare the return distance value with value computed with Inner product
+ expected: the return distance equals to the computed value
+ '''
+ top_k = 2
+ nprobe = 1
+ vectors, ids = self.init_data(connect, ip_table, nb=2)
+ index_params = get_index_params
+ connect.create_index(ip_table, index_params)
+ logging.getLogger().info(connect.describe_index(ip_table))
+ query_vecs = [[0.50 for i in range(dim)]]
+ status, result = connect.search_vectors(ip_table, top_k, nprobe, query_vecs)
+ distance_0 = numpy.inner(numpy.array(query_vecs[0]), numpy.array(vectors[0]))
+ distance_1 = numpy.inner(numpy.array(query_vecs[0]), numpy.array(vectors[1]))
+ assert abs(result[0][0].distance - max(distance_0, distance_1)) <= gen_inaccuracy(result[0][0].distance)
+
+ # TODO: enable
+ # @pytest.mark.repeat(5)
+ @pytest.mark.timeout(30)
+ def _test_search_concurrent(self, connect, table):
+ vectors, ids = self.init_data(connect, table)
+ thread_num = 10
+ nb = 100
+ top_k = 10
+ nprobe = 1
+ threads = []
+ query_vecs = vectors[nb//2:nb]
+ def search():
+ status, result = connect.search_vectors(table, top_k, nprobe, query_vecs)
+ assert len(result) == len(query_vecs)
+ for i in range(len(query_vecs)):
+ assert result[i][0].id in ids
+ assert result[i][0].distance == 0.0
+ for i in range(thread_num):
+ x = threading.Thread(target=search, args=())
+ threads.append(x)
+ x.start()
+ for th in threads:
+ th.join()
+
+ # TODO: enable
+ @pytest.mark.timeout(30)
+ def _test_search_concurrent_multiprocessing(self, args):
+ '''
+ target: test concurrent search with multiple processes
+ method: search with several processes, each process using an independent connection
+ expected: status ok and the returned vectors should be query_records
+ '''
+ nb = 100
+ top_k = 10
+ nprobe = 1
+ process_num = 4
+ processes = []
+ table = gen_unique_str("test_search_concurrent_multiprocessing")
+ uri = "tcp://%s:%s" % (args["ip"], args["port"])
+ param = {'table_name': table,
+ 'dimension': dim,
+ 'index_type': IndexType.FLAT,
+ 'store_raw_vector': False}
+ # create table
+ milvus = Milvus()
+ milvus.connect(uri=uri)
+ milvus.create_table(param)
+ vectors, ids = self.init_data(milvus, table, nb=nb)
+ query_vecs = vectors[nb//2:nb]
+ def search(milvus):
+ status, result = milvus.search_vectors(table, top_k, nprobe, query_vecs)
+ assert len(result) == len(query_vecs)
+ for i in range(len(query_vecs)):
+ assert result[i][0].id in ids
+ assert result[i][0].distance == 0.0
+
+ for i in range(process_num):
+ milvus = Milvus()
+ milvus.connect(uri=uri)
+ p = Process(target=search, args=(milvus, ))
+ processes.append(p)
+ p.start()
+ time.sleep(0.2)
+ for p in processes:
+ p.join()
+
+ def test_search_multi_table_L2(self, args):
+ '''
+ target: test search multi tables of L2
+ method: add vectors into 10 tables, and search
+ expected: search status ok, and the length of each result equals top_k
+ '''
+ num = 10
+ top_k = 10
+ nprobe = 1
+ tables = []
+ idx = []
+ for i in range(num):
+ table = gen_unique_str("test_add_multitable_%d" % i)
+ uri = "tcp://%s:%s" % (args["ip"], args["port"])
+ param = {'table_name': table,
+ 'dimension': dim,
+ 'index_file_size': 10,
+ 'metric_type': MetricType.L2}
+ # create table
+ milvus = Milvus()
+ milvus.connect(uri=uri)
+ milvus.create_table(param)
+ status, ids = milvus.add_vectors(table, vectors)
+ assert status.OK()
+ assert len(ids) == len(vectors)
+ tables.append(table)
+ idx.append(ids[0])
+ idx.append(ids[10])
+ idx.append(ids[20])
+ time.sleep(6)
+ query_vecs = [vectors[0], vectors[10], vectors[20]]
+ # start query from random table
+ for i in range(num):
+ table = tables[i]
+ status, result = milvus.search_vectors(table, top_k, nprobe, query_vecs)
+ assert status.OK()
+ assert len(result) == len(query_vecs)
+ for j in range(len(query_vecs)):
+ assert len(result[j]) == top_k
+ for j in range(len(query_vecs)):
+ assert check_result(result[j], idx[3 * i + j])
+
+ def test_search_multi_table_IP(self, args):
+ '''
+ target: test search multi tables of IP
+ method: add vectors into 10 tables, and search
+ expected: search status ok, and the length of each result equals top_k
+ '''
+ num = 10
+ top_k = 10
+ nprobe = 1
+ tables = []
+ idx = []
+ for i in range(num):
+ table = gen_unique_str("test_add_multitable_%d" % i)
+ uri = "tcp://%s:%s" % (args["ip"], args["port"])
+ param = {'table_name': table,
+ 'dimension': dim,
+ 'index_file_size': 10,
+ 'metric_type': MetricType.IP}
+ # create table
+ milvus = Milvus()
+ milvus.connect(uri=uri)
+ milvus.create_table(param)
+ status, ids = milvus.add_vectors(table, vectors)
+ assert status.OK()
+ assert len(ids) == len(vectors)
+ tables.append(table)
+ idx.append(ids[0])
+ idx.append(ids[10])
+ idx.append(ids[20])
+ time.sleep(6)
+ query_vecs = [vectors[0], vectors[10], vectors[20]]
+ # start query from random table
+ for i in range(num):
+ table = tables[i]
+ status, result = milvus.search_vectors(table, top_k, nprobe, query_vecs)
+ assert status.OK()
+ assert len(result) == len(query_vecs)
+ for j in range(len(query_vecs)):
+ assert len(result[j]) == top_k
+ for j in range(len(query_vecs)):
+ assert check_result(result[j], idx[3 * i + j])
+"""
+******************************************************************
+# The following cases are used to test `search_vectors` function
+# with invalid table_name / top-k / nprobe / query_range
+******************************************************************
+"""
+
+class TestSearchParamsInvalid(object):
+ index_params = random.choice(gen_index_params())
+ logging.getLogger().info(index_params)
+
+ def init_data(self, connect, table, nb=100):
+ '''
+ Generate vectors and add it in table, before search vectors
+ '''
+ global vectors
+ if nb == 100:
+ add_vectors = vectors
+ else:
+ add_vectors = gen_vectors(nb, dim)
+ status, ids = connect.add_vectors(table, add_vectors)
+ sleep(add_interval_time)
+ return add_vectors, ids
+
+ """
+ Test search table with invalid table names
+ """
+ @pytest.fixture(
+ scope="function",
+ params=gen_invalid_table_names()
+ )
+ def get_table_name(self, request):
+ yield request.param
+
+ @pytest.mark.level(2)
+ def test_search_with_invalid_tablename(self, connect, get_table_name):
+ table_name = get_table_name
+ logging.getLogger().info(table_name)
+ top_k = 1
+ nprobe = 1
+ query_vecs = gen_vectors(1, dim)
+ status, result = connect.search_vectors(table_name, top_k, nprobe, query_vecs)
+ assert not status.OK()
+
+ """
+ Test search table with invalid top-k
+ """
+ @pytest.fixture(
+ scope="function",
+ params=gen_invalid_top_ks()
+ )
+ def get_top_k(self, request):
+ yield request.param
+
+ @pytest.mark.level(2)
+ def test_search_with_invalid_top_k(self, connect, table, get_top_k):
+ '''
+ target: test search function with an invalid top_k
+ method: search with invalid top_k values
+ expected: raise an error, and the connection is normal
+ '''
+ top_k = get_top_k
+ logging.getLogger().info(top_k)
+ nprobe = 1
+ query_vecs = gen_vectors(1, dim)
+ with pytest.raises(Exception) as e:
+ status, result = connect.search_vectors(table, top_k, nprobe, query_vecs)
+ res = connect.server_version()
+
+ @pytest.mark.level(2)
+ def test_search_with_invalid_top_k_ip(self, connect, ip_table, get_top_k):
+ '''
+ target: test search function with an invalid top_k
+ method: search with invalid top_k values
+ expected: raise an error, and the connection is normal
+ '''
+ top_k = get_top_k
+ logging.getLogger().info(top_k)
+ nprobe = 1
+ query_vecs = gen_vectors(1, dim)
+ with pytest.raises(Exception) as e:
+ status, result = connect.search_vectors(ip_table, top_k, nprobe, query_vecs)
+ res = connect.server_version()
+
+ """
+ Test search table with invalid nprobe
+ """
+ @pytest.fixture(
+ scope="function",
+ params=gen_invalid_nprobes()
+ )
+ def get_nprobes(self, request):
+ yield request.param
+
+ @pytest.mark.level(2)
+ def test_search_with_invalid_nrpobe(self, connect, table, get_nprobes):
+ '''
+ target: test search function with an invalid nprobe
+ method: search with invalid nprobe values
+ expected: raise an error, and the connection is normal
+ '''
+ top_k = 1
+ nprobe = get_nprobes
+ logging.getLogger().info(nprobe)
+ query_vecs = gen_vectors(1, dim)
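+ # positive integer nprobe values that are otherwise invalid are expected to be
+ # rejected by the server (status not OK); other values should fail client-side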
+ if isinstance(nprobe, int) and nprobe > 0:
+ status, result = connect.search_vectors(table, top_k, nprobe, query_vecs)
+ assert not status.OK()
+ else:
+ with pytest.raises(Exception) as e:
+ status, result = connect.search_vectors(table, top_k, nprobe, query_vecs)
+
+ @pytest.mark.level(2)
+ def test_search_with_invalid_nrpobe_ip(self, connect, ip_table, get_nprobes):
+ '''
+ target: test search function with an invalid nprobe
+ method: search with invalid nprobe values
+ expected: raise an error, and the connection is normal
+ '''
+ top_k = 1
+ nprobe = get_nprobes
+ logging.getLogger().info(nprobe)
+ query_vecs = gen_vectors(1, dim)
+ if isinstance(nprobe, int) and nprobe > 0:
+ status, result = connect.search_vectors(ip_table, top_k, nprobe, query_vecs)
+ assert not status.OK()
+ else:
+ with pytest.raises(Exception) as e:
+ status, result = connect.search_vectors(ip_table, top_k, nprobe, query_vecs)
+
+ """
+ Test search table with invalid query ranges
+ """
+ @pytest.fixture(
+ scope="function",
+ params=gen_invalid_query_ranges()
+ )
+ def get_query_ranges(self, request):
+ yield request.param
+
+ @pytest.mark.level(2)
+ def test_search_flat_with_invalid_query_range(self, connect, table, get_query_ranges):
+ '''
+ target: test search function with an invalid query_range
+ method: search with query_range
+ expected: raise an error, and the connection is normal
+ '''
+ top_k = 1
+ nprobe = 1
+ query_vecs = [vectors[0]]
+ query_ranges = get_query_ranges
+ logging.getLogger().info(query_ranges)
+ with pytest.raises(Exception) as e:
+ status, result = connect.search_vectors(table, 1, nprobe, query_vecs, query_ranges=query_ranges)
+
+
+ @pytest.mark.level(2)
+ def test_search_flat_with_invalid_query_range_ip(self, connect, ip_table, get_query_ranges):
+ '''
+ target: test search function with an invalid query_range
+ method: search with query_range
+ expected: raise an error, and the connection is normal
+ '''
+ top_k = 1
+ nprobe = 1
+ query_vecs = [vectors[0]]
+ query_ranges = get_query_ranges
+ logging.getLogger().info(query_ranges)
+ with pytest.raises(Exception) as e:
+ status, result = connect.search_vectors(ip_table, 1, nprobe, query_vecs, query_ranges=query_ranges)
+
+
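+# As in test_mix.py: a hit counts as correct when the expected id appears in the
+# top-5 results, or anywhere in the results when fewer than 5 are returned.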
+def check_result(result, id):
+ if len(result) >= 5:
+ return id in [result[0].id, result[1].id, result[2].id, result[3].id, result[4].id]
+ else:
+ return id in (i.id for i in result)
\ No newline at end of file
diff --git a/tests/milvus_python_test/test_table.py b/tests/milvus_python_test/test_table.py
new file mode 100644
index 0000000000..c92afb023e
--- /dev/null
+++ b/tests/milvus_python_test/test_table.py
@@ -0,0 +1,883 @@
+import random
+import pdb
+import pytest
+import logging
+import itertools
+
+import time
+from time import sleep
+from multiprocessing import Process
+import numpy
+from milvus import Milvus
+from milvus import IndexType, MetricType
+from utils import *
+
+dim = 128
+delete_table_interval_time = 3
+index_file_size = 10
+vectors = gen_vectors(100, dim)
+
+
+class TestTable:
+
+ """
+ ******************************************************************
+ The following cases are used to test `create_table` function
+ ******************************************************************
+ """
+
+ def test_create_table(self, connect):
+ '''
+ target: test create normal table
+ method: create table with correct params
+ expected: create status return ok
+ '''
+ table_name = gen_unique_str("test_table")
+ param = {'table_name': table_name,
+ 'dimension': dim,
+ 'index_file_size': index_file_size,
+ 'metric_type': MetricType.L2}
+ status = connect.create_table(param)
+ assert status.OK()
+
+ def test_create_table_ip(self, connect):
+ '''
+ target: test create normal table
+ method: create table with correct params
+ expected: create status return ok
+ '''
+ table_name = gen_unique_str("test_table")
+ param = {'table_name': table_name,
+ 'dimension': dim,
+ 'index_file_size': index_file_size,
+ 'metric_type': MetricType.IP}
+ status = connect.create_table(param)
+ assert status.OK()
+
+ @pytest.mark.level(2)
+ def test_create_table_without_connection(self, dis_connect):
+ '''
+ target: test create table, without connection
+ method: create table with correct params, with a disconnected instance
+ expected: create raise exception
+ '''
+ table_name = gen_unique_str("test_table")
+ param = {'table_name': table_name,
+ 'dimension': dim,
+ 'index_file_size': index_file_size,
+ 'metric_type': MetricType.L2}
+ with pytest.raises(Exception) as e:
+ status = dis_connect.create_table(param)
+
+ def test_create_table_existed(self, connect):
+ '''
+ target: test create table when the table name already exists
+ method: create table with the same table_name
+ expected: create status return not ok
+ '''
+ table_name = gen_unique_str("test_table")
+ param = {'table_name': table_name,
+ 'dimension': dim,
+ 'index_file_size': index_file_size,
+ 'metric_type': MetricType.L2}
+ status = connect.create_table(param)
+ status = connect.create_table(param)
+ assert not status.OK()
+
+ @pytest.mark.level(2)
+ def test_create_table_existed_ip(self, connect):
+ '''
+ target: test create table when the table name already exists
+ method: create table with the same table_name
+ expected: create status return not ok
+ '''
+ table_name = gen_unique_str("test_table")
+ param = {'table_name': table_name,
+ 'dimension': dim,
+ 'index_file_size': index_file_size,
+ 'metric_type': MetricType.IP}
+ status = connect.create_table(param)
+ status = connect.create_table(param)
+ assert not status.OK()
+
+ def test_create_table_None(self, connect):
+ '''
+ target: test create table but the table name is None
+ method: create table, param table_name is None
+ expected: create raise error
+ '''
+ param = {'table_name': None,
+ 'dimension': dim,
+ 'index_file_size': index_file_size,
+ 'metric_type': MetricType.L2}
+ with pytest.raises(Exception) as e:
+ status = connect.create_table(param)
+
+ def test_create_table_no_dimension(self, connect):
+ '''
+ target: test create table with no dimension params
+ method: create table without the dimension param
+ expected: raise an exception
+ '''
+ table_name = gen_unique_str("test_table")
+ param = {'table_name': table_name,
+ 'index_file_size': index_file_size,
+ 'metric_type': MetricType.L2}
+ with pytest.raises(Exception) as e:
+ status = connect.create_table(param)
+
+ def test_create_table_no_file_size(self, connect):
+ '''
+ target: test create table with no index_file_size params
+ method: create table without the index_file_size param
+ expected: create status return ok, use default 1024
+ '''
+ table_name = gen_unique_str("test_table")
+ param = {'table_name': table_name,
+ 'dimension': dim,
+ 'metric_type': MetricType.L2}
+ status = connect.create_table(param)
+ logging.getLogger().info(status)
+ status, result = connect.describe_table(table_name)
+ logging.getLogger().info(result)
+ assert result.index_file_size == 1024
+
+ def test_create_table_no_metric_type(self, connect):
+ '''
+ target: test create table with no metric_type params
+ method: create table without the metric_type param
+ expected: create status return ok, use default L2
+ '''
+ table_name = gen_unique_str("test_table")
+ param = {'table_name': table_name,
+ 'dimension': dim,
+ 'index_file_size': index_file_size}
+ status = connect.create_table(param)
+ status, result = connect.describe_table(table_name)
+ logging.getLogger().info(result)
+ assert result.metric_type == MetricType.L2
+
+ """
+ ******************************************************************
+ The following cases are used to test `describe_table` function
+ ******************************************************************
+ """
+
+ def test_table_describe_result(self, connect):
+ '''
+ target: test describe table created with correct params
+ method: create table, assert the value returned by describe method
+ expected: table_name equals the table name created
+ '''
+ table_name = gen_unique_str("test_table")
+ param = {'table_name': table_name,
+ 'dimension': dim,
+ 'index_file_size': index_file_size,
+ 'metric_type': MetricType.L2}
+ connect.create_table(param)
+ status, res = connect.describe_table(table_name)
+ assert res.table_name == table_name
+ assert res.metric_type == MetricType.L2
+
+ def test_table_describe_table_name_ip(self, connect):
+ '''
+ target: test describe table created with correct params
+ method: create table, assert the value returned by describe method
+ expected: table_name equals the table name created
+ '''
+ table_name = gen_unique_str("test_table")
+ param = {'table_name': table_name,
+ 'dimension': dim,
+ 'index_file_size': index_file_size,
+ 'metric_type': MetricType.IP}
+ connect.create_table(param)
+ status, res = connect.describe_table(table_name)
+ assert res.table_name == table_name
+ assert res.metric_type == MetricType.IP
+
+ # TODO: enable
+ @pytest.mark.level(2)
+ def _test_table_describe_table_name_multiprocessing(self, connect, args):
+ '''
+ target: test describe table created with multiprocess
+ method: create table, assert the value returned by describe method
+ expected: table_name equals the table name created
+ '''
+ table_name = gen_unique_str("test_table")
+ uri = "tcp://%s:%s" % (args["ip"], args["port"])
+ param = {'table_name': table_name,
+ 'dimension': dim,
+ 'index_file_size': index_file_size,
+ 'metric_type': MetricType.L2}
+ connect.create_table(param)
+
+ def describetable(milvus):
+ status, res = milvus.describe_table(table_name)
+ assert res.table_name == table_name
+
+ process_num = 4
+ processes = []
+ for i in range(process_num):
+ milvus = Milvus()
+ milvus.connect(uri=uri)
+ p = Process(target=describetable, args=(milvus,))
+ processes.append(p)
+ p.start()
+ for p in processes:
+ p.join()
+
+ @pytest.mark.level(2)
+ def test_table_describe_without_connection(self, table, dis_connect):
+ '''
+ target: test describe table, without connection
+ method: describe table with correct params, with a disconnected instance
+ expected: describe raise exception
+ '''
+ with pytest.raises(Exception) as e:
+ status = dis_connect.describe_table(table)
+
+ def test_table_describe_dimension(self, connect):
+ '''
+ target: test describe table created with correct params
+ method: create table, assert the dimension value returned by describe method
+ expected: dimension equals the dimension used at creation
+ '''
+ table_name = gen_unique_str("test_table")
+ param = {'table_name': table_name,
+ 'dimension': dim+1,
+ 'index_file_size': index_file_size,
+ 'metric_type': MetricType.L2}
+ connect.create_table(param)
+ status, res = connect.describe_table(table_name)
+ assert res.dimension == dim+1
+
+ """
+ ******************************************************************
+ The following cases are used to test `delete_table` function
+ ******************************************************************
+ """
+
+ def test_delete_table(self, connect, table):
+ '''
+ target: test delete table created with correct params
+ method: create table and then delete,
+ assert the value returned by delete method
+ expected: status ok, and no table in tables
+ '''
+ status = connect.delete_table(table)
+ assert not connect.has_table(table)
+
+ def test_delete_table_ip(self, connect, ip_table):
+ '''
+ target: test delete table created with correct params
+ method: create table and then delete,
+ assert the value returned by delete method
+ expected: status ok, and no table in tables
+ '''
+ status = connect.delete_table(ip_table)
+ assert not connect.has_table(ip_table)
+
+ @pytest.mark.level(2)
+ def test_table_delete_without_connection(self, table, dis_connect):
+ '''
+ target: test delete table, without connection
+ method: delete table with correct params, with a disconnected instance
+ expected: delete raise exception
+ '''
+ with pytest.raises(Exception) as e:
+ status = dis_connect.delete_table(table)
+
+ def test_delete_table_not_existed(self, connect):
+ '''
+ target: test delete a table that does not exist
+ method: delete a table with a random name not in db,
+ assert the value returned by delete method
+ expected: status not ok
+ '''
+ table_name = gen_unique_str("test_table")
+ status = connect.delete_table(table_name)
+ assert status.code != 0
+
+ def test_delete_table_repeatedly(self, connect):
+ '''
+ target: test delete table created with correct params
+ method: create table and delete new table repeatedly,
+ assert the value returned by delete method
+ expected: create ok and delete ok
+ '''
+ loops = 1
+ for i in range(loops):
+ table_name = gen_unique_str("test_table")
+ param = {'table_name': table_name,
+ 'dimension': dim,
+ 'index_file_size': index_file_size,
+ 'metric_type': MetricType.L2}
+ connect.create_table(param)
+ status = connect.delete_table(table_name)
+ time.sleep(1)
+ assert not connect.has_table(table_name)
+
+ def test_delete_create_table_repeatedly(self, connect):
+ '''
+ target: test delete and create the same table repeatedly
+ method: try to create the same table and delete repeatedly,
+ assert the value returned by delete method
+ expected: create ok and delete ok
+ '''
+ loops = 5
+ for i in range(loops):
+ table_name = "test_table"
+ param = {'table_name': table_name,
+ 'dimension': dim,
+ 'index_file_size': index_file_size,
+ 'metric_type': MetricType.L2}
+ connect.create_table(param)
+ status = connect.delete_table(table_name)
+ time.sleep(2)
+ assert status.OK()
+
+ @pytest.mark.level(2)
+ def test_delete_create_table_repeatedly_ip(self, connect):
+ '''
+ target: test delete and create the same table repeatedly
+ method: try to create the same table and delete repeatedly,
+ assert the value returned by delete method
+ expected: create ok and delete ok
+ '''
+ loops = 5
+ for i in range(loops):
+ table_name = "test_table"
+ param = {'table_name': table_name,
+ 'dimension': dim,
+ 'index_file_size': index_file_size,
+ 'metric_type': MetricType.IP}
+ connect.create_table(param)
+ status = connect.delete_table(table_name)
+ time.sleep(2)
+ assert status.OK()
+
+ # TODO: enable
+ @pytest.mark.level(2)
+ def _test_delete_table_multiprocessing(self, args):
+ '''
+ target: test delete table with multiprocess
+ method: create table and then delete,
+ assert the value returned by delete method
+ expected: status ok, and no table in tables
+ '''
+ process_num = 6
+ processes = []
+ uri = "tcp://%s:%s" % (args["ip"], args["port"])
+
+ def deletetable(milvus):
+ status = milvus.delete_table(table)
+ # assert not status.code==0
+ assert milvus.has_table(table)
+ assert status.OK()
+
+ for i in range(process_num):
+ milvus = Milvus()
+ milvus.connect(uri=uri)
+ p = Process(target=deletetable, args=(milvus,))
+ processes.append(p)
+ p.start()
+ for p in processes:
+ p.join()
+
+ # TODO: enable
+ @pytest.mark.level(2)
+ def _test_delete_table_multiprocessing_multitable(self, connect):
+ '''
+ target: test delete table with multiprocess
+ method: create table and then delete,
+ assert the value returned by delete method
+ expected: status ok, and no table in tables
+ '''
+ process_num = 5
+ loop_num = 2
+ processes = []
+
+ table = []
+ j = 0
+ while j < (process_num*loop_num):
+ table_name = gen_unique_str("test_delete_table_with_multiprocessing")
+ table.append(table_name)
+ param = {'table_name': table_name,
+ 'dimension': dim,
+ 'index_file_size': index_file_size,
+ 'metric_type': MetricType.L2}
+ connect.create_table(param)
+ j = j + 1
+
+ def delete(connect,ids):
+ i = 0
+ while i < loop_num:
+ # assert connect.has_table(table[ids*8+i])
+ status = connect.delete_table(table[ids*process_num+i])
+ time.sleep(2)
+ assert status.OK()
+ assert not connect.has_table(table[ids*process_num+i])
+ i = i + 1
+
+ for i in range(process_num):
+ ids = i
+ p = Process(target=delete, args=(connect,ids))
+ processes.append(p)
+ p.start()
+ for p in processes:
+ p.join()
+
+ """
+ ******************************************************************
+ The following cases are used to test `has_table` function
+ ******************************************************************
+ """
+
+ def test_has_table(self, connect):
+ '''
+ target: test if the created table existed
+ method: create table, assert the value returned by has_table method
+ expected: True
+ '''
+ table_name = gen_unique_str("test_table")
+ param = {'table_name': table_name,
+ 'dimension': dim,
+ 'index_file_size': index_file_size,
+ 'metric_type': MetricType.L2}
+ connect.create_table(param)
+ assert connect.has_table(table_name)
+
+ def test_has_table_ip(self, connect):
+ '''
+ target: test if the created table existed
+ method: create table, assert the value returned by has_table method
+ expected: True
+ '''
+ table_name = gen_unique_str("test_table")
+ param = {'table_name': table_name,
+ 'dimension': dim,
+ 'index_file_size': index_file_size,
+ 'metric_type': MetricType.IP}
+ connect.create_table(param)
+ assert connect.has_table(table_name)
+
+ @pytest.mark.level(2)
+ def test_has_table_without_connection(self, table, dis_connect):
+ '''
+ target: test has table, without connection
+ method: calling has table with correct params, with a disconnected instance
+ expected: has table raise exception
+ '''
+ with pytest.raises(Exception) as e:
+ status = dis_connect.has_table(table)
+
+ def test_has_table_not_existed(self, connect):
+ '''
+ target: test if table not created
+ method: generate a random table name which does not exist in db,
+ assert the value returned by has_table method
+ expected: False
+ '''
+ table_name = gen_unique_str("test_table")
+ assert not connect.has_table(table_name)
+
+ """
+ ******************************************************************
+ The following cases are used to test `show_tables` function
+ ******************************************************************
+ """
+
+ def test_show_tables(self, connect):
+ '''
+ target: test show tables is correct or not, if table created
+ method: create table, assert the table name appears in the result of show_tables
+ expected: table_name in show tables
+ '''
+ table_name = gen_unique_str("test_table")
+ param = {'table_name': table_name,
+ 'dimension': dim,
+ 'index_file_size': index_file_size,
+ 'metric_type': MetricType.L2}
+ connect.create_table(param)
+ status, result = connect.show_tables()
+ assert status.OK()
+ assert table_name in result
+
+ def test_show_tables_ip(self, connect):
+ '''
+ target: test show tables is correct or not, if table created
+ method: create table, assert the table name appears in the result of show_tables
+ expected: table_name in show tables
+ '''
+ table_name = gen_unique_str("test_table")
+ param = {'table_name': table_name,
+ 'dimension': dim,
+ 'index_file_size': index_file_size,
+ 'metric_type': MetricType.IP}
+ connect.create_table(param)
+ status, result = connect.show_tables()
+ assert status.OK()
+ assert table_name in result
+
+ @pytest.mark.level(2)
+ def test_show_tables_without_connection(self, dis_connect):
+ '''
+ target: test show_tables, without connection
+ method: calling show_tables with correct params, with a disconnected instance
+ expected: show_tables raise exception
+ '''
+ with pytest.raises(Exception) as e:
+ status = dis_connect.show_tables()
+
+ def test_show_tables_no_table(self, connect):
+ '''
+ target: test show tables is correct or not, if no table in db
+ method: delete all tables,
+ assert the value returned by show_tables method is equal to []
+ expected: the status is ok, and the result is equal to []
+ '''
+ status, result = connect.show_tables()
+ if result:
+ for table_name in result:
+ connect.delete_table(table_name)
+ time.sleep(delete_table_interval_time)
+ status, result = connect.show_tables()
+ assert status.OK()
+ assert len(result) == 0
+
+ # TODO: enable
+ @pytest.mark.level(2)
+ def _test_show_tables_multiprocessing(self, connect, args):
+ '''
+ target: test show tables is correct or not with processes
+ method: create table, assert the table name appears in the result of show_tables
+ expected: table_name in show tables
+ '''
+ table_name = gen_unique_str("test_table")
+ uri = "tcp://%s:%s" % (args["ip"], args["port"])
+ param = {'table_name': table_name,
+ 'dimension': dim,
+ 'index_file_size': index_file_size,
+ 'metric_type': MetricType.L2}
+ connect.create_table(param)
+
+ def showtables(milvus):
+ status, result = milvus.show_tables()
+ assert status.OK()
+ assert table_name in result
+
+ process_num = 8
+ processes = []
+
+ for i in range(process_num):
+ milvus = Milvus()
+ milvus.connect(uri=uri)
+ p = Process(target=showtables, args=(milvus,))
+ processes.append(p)
+ p.start()
+ for p in processes:
+ p.join()
+
+ """
+ ******************************************************************
+ The following cases are used to test `preload_table` function
+ ******************************************************************
+ """
+
+ """
+ generate valid create_index params
+ """
+ @pytest.fixture(
+ scope="function",
+ params=gen_index_params()
+ )
+ def get_index_params(self, request):
+ yield request.param
+
+ @pytest.mark.level(1)
+ def test_preload_table(self, connect, table, get_index_params):
+ index_params = get_index_params
+ status, ids = connect.add_vectors(table, vectors)
+ status = connect.create_index(table, index_params)
+ status = connect.preload_table(table)
+ assert status.OK()
+
+ @pytest.mark.level(1)
+ def test_preload_table_ip(self, connect, ip_table, get_index_params):
+ index_params = get_index_params
+ status, ids = connect.add_vectors(ip_table, vectors)
+ status = connect.create_index(ip_table, index_params)
+ status = connect.preload_table(ip_table)
+ assert status.OK()
+
+ @pytest.mark.level(1)
+ def test_preload_table_not_existed(self, connect, table):
+ table_name = gen_unique_str("test_preload_table_not_existed")
+ index_params = random.choice(gen_index_params())
+ status, ids = connect.add_vectors(table, vectors)
+ status = connect.create_index(table, index_params)
+ status = connect.preload_table(table_name)
+ assert not status.OK()
+
+ @pytest.mark.level(1)
+ def test_preload_table_not_existed_ip(self, connect, ip_table):
+ table_name = gen_unique_str("test_preload_table_not_existed")
+ index_params = random.choice(gen_index_params())
+ status, ids = connect.add_vectors(ip_table, vectors)
+ status = connect.create_index(ip_table, index_params)
+ status = connect.preload_table(table_name)
+ assert not status.OK()
+
+ @pytest.mark.level(1)
+ def test_preload_table_no_vectors(self, connect, table):
+ status = connect.preload_table(table)
+ assert status.OK()
+
+ @pytest.mark.level(1)
+ def test_preload_table_no_vectors_ip(self, connect, ip_table):
+ status = connect.preload_table(ip_table)
+ assert status.OK()
+
+ # TODO: psutils get memory usage
+ @pytest.mark.level(1)
+ def test_preload_table_memory_usage(self, connect, table):
+ pass
+
+
+class TestTableInvalid(object):
+ """
+ Test creating table with invalid table names
+ """
+ @pytest.fixture(
+ scope="function",
+ params=gen_invalid_table_names()
+ )
+ def get_table_name(self, request):
+ yield request.param
+
+ def test_create_table_with_invalid_tablename(self, connect, get_table_name):
+ table_name = get_table_name
+ param = {'table_name': table_name,
+ 'dimension': dim,
+ 'index_file_size': index_file_size,
+ 'metric_type': MetricType.L2}
+ status = connect.create_table(param)
+ assert not status.OK()
+
+ def test_create_table_with_empty_tablename(self, connect):
+ table_name = ''
+ param = {'table_name': table_name,
+ 'dimension': dim,
+ 'index_file_size': index_file_size,
+ 'metric_type': MetricType.L2}
+ with pytest.raises(Exception) as e:
+ status = connect.create_table(param)
+
+ def test_preload_table_with_invalid_tablename(self, connect):
+ table_name = ''
+ with pytest.raises(Exception) as e:
+ status = connect.preload_table(table_name)
+
+
+class TestCreateTableDimInvalid(object):
+ """
+ Test creating table with invalid dimension
+ """
+ @pytest.fixture(
+ scope="function",
+ params=gen_invalid_dims()
+ )
+ def get_dim(self, request):
+ yield request.param
+
+ @pytest.mark.timeout(5)
+ def test_create_table_with_invalid_dimension(self, connect, get_dim):
+ dimension = get_dim
+ table = gen_unique_str("test_create_table_with_invalid_dimension")
+ param = {'table_name': table,
+ 'dimension': dimension,
+ 'index_file_size': index_file_size,
+ 'metric_type': MetricType.L2}
+ if isinstance(dimension, int) and dimension > 0:
+ status = connect.create_table(param)
+ assert not status.OK()
+ else:
+ with pytest.raises(Exception) as e:
+ status = connect.create_table(param)
+
+
+# TODO: max / min index file size
+class TestCreateTableIndexSizeInvalid(object):
+ """
+ Test creating tables with invalid index_file_size
+ """
+ @pytest.fixture(
+ scope="function",
+ params=gen_invalid_file_sizes()
+ )
+ def get_file_size(self, request):
+ yield request.param
+
+ @pytest.mark.level(2)
+ def test_create_table_with_invalid_file_size(self, connect, table, get_file_size):
+ file_size = get_file_size
+ param = {'table_name': table,
+ 'dimension': dim,
+ 'index_file_size': file_size,
+ 'metric_type': MetricType.L2}
+ if isinstance(file_size, int) and file_size > 0:
+ status = connect.create_table(param)
+ assert not status.OK()
+ else:
+ with pytest.raises(Exception) as e:
+ status = connect.create_table(param)
+
+
+class TestCreateMetricTypeInvalid(object):
+ """
+ Test creating tables with invalid metric_type
+ """
+ @pytest.fixture(
+ scope="function",
+ params=gen_invalid_metric_types()
+ )
+ def get_metric_type(self, request):
+ yield request.param
+
+ @pytest.mark.level(2)
+ def test_create_table_with_invalid_metric_type(self, connect, table, get_metric_type):
+ metric_type = get_metric_type
+ param = {'table_name': table,
+ 'dimension': dim,
+ 'index_file_size': 10,
+ 'metric_type': metric_type}
+ with pytest.raises(Exception) as e:
+ status = connect.create_table(param)
+
+
+def create_table(connect, **params):
+ param = {'table_name': params["table_name"],
+ 'dimension': params["dimension"],
+ 'index_file_size': index_file_size,
+ 'metric_type': MetricType.L2}
+ status = connect.create_table(param)
+ return status
+
+def search_table(connect, **params):
+ status, result = connect.search_vectors(
+ params["table_name"],
+ params["top_k"],
+ params["nprobe"],
+ params["query_vectors"])
+ return status
+
+def preload_table(connect, **params):
+ status = connect.preload_table(params["table_name"])
+ return status
+
+def has(connect, **params):
+ status = connect.has_table(params["table_name"])
+ return status
+
+def show(connect, **params):
+ status, result = connect.show_tables()
+ return status
+
+def delete(connect, **params):
+ status = connect.delete_table(params["table_name"])
+ return status
+
+def describe(connect, **params):
+ status, result = connect.describe_table(params["table_name"])
+ return status
+
+def rowcount(connect, **params):
+ status, result = connect.get_table_row_count(params["table_name"])
+ return status
+
+def create_index(connect, **params):
+ status = connect.create_index(params["table_name"], params["index_params"])
+ return status
+
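+# The numeric keys encode the legal ordering of table operations: show (1) may
+# run at any time, create_table is 10, operations that require an existing
+# table use 11-15, and delete is 30. is_right below relies on this encoding.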
+func_map = {
+ # 0:has,
+ 1:show,
+ 10:create_table,
+ 11:describe,
+ 12:rowcount,
+ 13:search_table,
+ 14:preload_table,
+ 15:create_index,
+ 30:delete
+}
+
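+# Yield every permutation of the 8 operation keys (8! = 40320 sequences); each
+# permutation becomes one parametrized case of TestTableLogic.test_logic.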
+def gen_sequence():
+ raw_seq = func_map.keys()
+ result = itertools.permutations(raw_seq)
+ for x in result:
+ yield x
+
+class TestTableLogic(object):
+
+ @pytest.mark.parametrize("logic_seq", gen_sequence())
+ @pytest.mark.level(2)
+ def test_logic(self, connect, logic_seq):
+ if self.is_right(logic_seq):
+ self.execute(logic_seq, connect)
+ else:
+ self.execute_with_error(logic_seq, connect)
+
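+ # A sequence is valid when every operation with a key greater than 10 runs
+ # after create_table (10) and before delete (30).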
+ def is_right(self, seq):
+ if sorted(seq) == seq:
+ return True
+
+ not_created = True
+ has_deleted = False
+ for i in range(len(seq)):
+ if seq[i] > 10 and not_created:
+ return False
+ elif seq[i] > 10 and has_deleted:
+ return False
+ elif seq[i] == 10:
+ not_created = False
+ elif seq[i] == 30:
+ has_deleted = True
+
+ return True
+
+ def execute(self, logic_seq, connect):
+ basic_params = self.gen_params()
+ for i in range(len(logic_seq)):
+ # logging.getLogger().info(logic_seq[i])
+ f = func_map[logic_seq[i]]
+ status = f(connect, **basic_params)
+ assert status.OK()
+
+ def execute_with_error(self, logic_seq, connect):
+ basic_params = self.gen_params()
+
+ error_flag = False
+ for i in range(len(logic_seq)):
+ f = func_map[logic_seq[i]]
+ status = f(connect, **basic_params)
+ if not status.OK():
+ # logging.getLogger().info(logic_seq[i])
+ error_flag = True
+ break
+ assert error_flag == True
+
+ def gen_params(self):
+ table_name = gen_unique_str("test_table")
+ top_k = 1
+ vectors = gen_vectors(2, dim)
+ param = {'table_name': table_name,
+ 'dimension': dim,
+ 'index_type': IndexType.IVFLAT,
+ 'metric_type': MetricType.L2,
+ 'nprobe': 1,
+ 'top_k': top_k,
+ 'index_params': {
+ 'index_type': IndexType.IVF_SQ8,
+ 'nlist': 16384
+ },
+ 'query_vectors': vectors}
+ return param
diff --git a/tests/milvus_python_test/test_table_count.py b/tests/milvus_python_test/test_table_count.py
new file mode 100644
index 0000000000..af9ae185ab
--- /dev/null
+++ b/tests/milvus_python_test/test_table_count.py
@@ -0,0 +1,296 @@
+import random
+import pdb
+
+import pytest
+import logging
+import itertools
+
+import time
+from time import sleep
+from multiprocessing import Process
+from milvus import Milvus
+from utils import *
+from milvus import IndexType, MetricType
+
+dim = 128
+index_file_size = 10
+add_time_interval = 5
+
+
+class TestTableCount:
+ """
+ params means different nb, the nb value may trigger merge, or not
+ """
+ @pytest.fixture(
+ scope="function",
+ params=[
+ 100,
+ 5000,
+ 100000,
+ ],
+ )
+ def add_vectors_nb(self, request):
+ yield request.param
+
+ """
+ generate valid create_index params
+ """
+ @pytest.fixture(
+ scope="function",
+ params=gen_index_params()
+ )
+ def get_index_params(self, request):
+ yield request.param
+
+ def test_table_rows_count(self, connect, table, add_vectors_nb):
+ '''
+ target: test table rows_count is correct or not
+ method: create table and add vectors in it,
+ assert the value returned by get_table_row_count method is equal to length of vectors
+ expected: the count is equal to the length of vectors
+ '''
+ nb = add_vectors_nb
+ vectors = gen_vectors(nb, dim)
+ res = connect.add_vectors(table_name=table, records=vectors)
+ time.sleep(add_time_interval)
+ status, res = connect.get_table_row_count(table)
+ assert res == nb
+
+ def test_table_rows_count_after_index_created(self, connect, table, get_index_params):
+ '''
+        target: test get_table_row_count after an index has been created
+        method: add vectors, create an index, then call get_table_row_count with correct params
+        expected: the count equals the number of vectors added
+ '''
+ nb = 100
+ index_params = get_index_params
+ vectors = gen_vectors(nb, dim)
+ res = connect.add_vectors(table_name=table, records=vectors)
+ time.sleep(add_time_interval)
+ # logging.getLogger().info(index_params)
+ connect.create_index(table, index_params)
+ status, res = connect.get_table_row_count(table)
+ assert res == nb
+
+ @pytest.mark.level(2)
+ def test_count_without_connection(self, table, dis_connect):
+ '''
+ target: test get_table_row_count, without connection
+ method: calling get_table_row_count with correct params, with a disconnected instance
+ expected: get_table_row_count raise exception
+ '''
+ with pytest.raises(Exception) as e:
+ status = dis_connect.get_table_row_count(table)
+
+ def test_table_rows_count_no_vectors(self, connect, table):
+ '''
+ target: test table rows_count is correct or not, if table is empty
+ method: create table and no vectors in it,
+ assert the value returned by get_table_row_count method is equal to 0
+ expected: the count is equal to 0
+ '''
+ table_name = gen_unique_str("test_table")
+ param = {'table_name': table_name,
+ 'dimension': dim,
+ 'index_file_size': index_file_size}
+ connect.create_table(param)
+        status, res = connect.get_table_row_count(table_name)
+ assert res == 0
+
+ # TODO: enable
+ @pytest.mark.level(2)
+ @pytest.mark.timeout(20)
+ def _test_table_rows_count_multiprocessing(self, connect, table, args):
+ '''
+ target: test table rows_count is correct or not with multiprocess
+ method: create table and add vectors in it,
+ assert the value returned by get_table_row_count method is equal to length of vectors
+ expected: the count is equal to the length of vectors
+ '''
+ nq = 2
+ uri = "tcp://%s:%s" % (args["ip"], args["port"])
+ vectors = gen_vectors(nq, dim)
+ res = connect.add_vectors(table_name=table, records=vectors)
+ time.sleep(add_time_interval)
+
+ def rows_count(milvus):
+ status, res = milvus.get_table_row_count(table)
+ logging.getLogger().info(status)
+ assert res == nq
+
+ process_num = 8
+ processes = []
+ for i in range(process_num):
+ milvus = Milvus()
+ milvus.connect(uri=uri)
+ p = Process(target=rows_count, args=(milvus, ))
+ processes.append(p)
+ p.start()
+ logging.getLogger().info(p)
+ for p in processes:
+ p.join()
+
+ def test_table_rows_count_multi_tables(self, connect):
+ '''
+ target: test table rows_count is correct or not with multiple tables of L2
+ method: create table and add vectors in it,
+ assert the value returned by get_table_row_count method is equal to length of vectors
+ expected: the count is equal to the length of vectors
+ '''
+ nq = 100
+ vectors = gen_vectors(nq, dim)
+ table_list = []
+ for i in range(50):
+ table_name = gen_unique_str('test_table_rows_count_multi_tables')
+ table_list.append(table_name)
+ param = {'table_name': table_name,
+ 'dimension': dim,
+ 'index_file_size': index_file_size,
+ 'metric_type': MetricType.L2}
+ connect.create_table(param)
+ res = connect.add_vectors(table_name=table_name, records=vectors)
+ time.sleep(2)
+ for i in range(50):
+ status, res = connect.get_table_row_count(table_list[i])
+ assert status.OK()
+ assert res == nq
+
+
+class TestTableCountIP:
+ """
+    The fixture params supply different nb values; larger values may or may not trigger a merge.
+ """
+
+ @pytest.fixture(
+ scope="function",
+ params=[
+ 100,
+ 5000,
+ 100000,
+ ],
+ )
+ def add_vectors_nb(self, request):
+ yield request.param
+
+ """
+ generate valid create_index params
+ """
+
+ @pytest.fixture(
+ scope="function",
+ params=gen_index_params()
+ )
+ def get_index_params(self, request):
+ yield request.param
+
+ def test_table_rows_count(self, connect, ip_table, add_vectors_nb):
+ '''
+ target: test table rows_count is correct or not
+ method: create table and add vectors in it,
+ assert the value returned by get_table_row_count method is equal to length of vectors
+ expected: the count is equal to the length of vectors
+ '''
+ nb = add_vectors_nb
+ vectors = gen_vectors(nb, dim)
+ res = connect.add_vectors(table_name=ip_table, records=vectors)
+ time.sleep(add_time_interval)
+ status, res = connect.get_table_row_count(ip_table)
+ assert res == nb
+
+ def test_table_rows_count_after_index_created(self, connect, ip_table, get_index_params):
+ '''
+        target: test get_table_row_count after an index has been created
+        method: add vectors, create an index, then call get_table_row_count with correct params
+        expected: the count equals the number of vectors added
+ '''
+ nb = 100
+ index_params = get_index_params
+ vectors = gen_vectors(nb, dim)
+ res = connect.add_vectors(table_name=ip_table, records=vectors)
+ time.sleep(add_time_interval)
+ # logging.getLogger().info(index_params)
+ connect.create_index(ip_table, index_params)
+ status, res = connect.get_table_row_count(ip_table)
+ assert res == nb
+
+ @pytest.mark.level(2)
+ def test_count_without_connection(self, ip_table, dis_connect):
+ '''
+ target: test get_table_row_count, without connection
+ method: calling get_table_row_count with correct params, with a disconnected instance
+ expected: get_table_row_count raise exception
+ '''
+ with pytest.raises(Exception) as e:
+ status = dis_connect.get_table_row_count(ip_table)
+
+ def test_table_rows_count_no_vectors(self, connect, ip_table):
+ '''
+ target: test table rows_count is correct or not, if table is empty
+ method: create table and no vectors in it,
+ assert the value returned by get_table_row_count method is equal to 0
+ expected: the count is equal to 0
+ '''
+ table_name = gen_unique_str("test_table")
+ param = {'table_name': table_name,
+ 'dimension': dim,
+ 'index_file_size': index_file_size}
+ connect.create_table(param)
+        status, res = connect.get_table_row_count(table_name)
+ assert res == 0
+
+ # TODO: enable
+ @pytest.mark.level(2)
+ @pytest.mark.timeout(20)
+ def _test_table_rows_count_multiprocessing(self, connect, ip_table, args):
+ '''
+ target: test table rows_count is correct or not with multiprocess
+ method: create table and add vectors in it,
+ assert the value returned by get_table_row_count method is equal to length of vectors
+ expected: the count is equal to the length of vectors
+ '''
+ nq = 2
+ uri = "tcp://%s:%s" % (args["ip"], args["port"])
+ vectors = gen_vectors(nq, dim)
+ res = connect.add_vectors(table_name=ip_table, records=vectors)
+ time.sleep(add_time_interval)
+
+ def rows_count(milvus):
+ status, res = milvus.get_table_row_count(ip_table)
+ logging.getLogger().info(status)
+ assert res == nq
+
+ process_num = 8
+ processes = []
+ for i in range(process_num):
+ milvus = Milvus()
+ milvus.connect(uri=uri)
+ p = Process(target=rows_count, args=(milvus,))
+ processes.append(p)
+ p.start()
+ logging.getLogger().info(p)
+ for p in processes:
+ p.join()
+
+ def test_table_rows_count_multi_tables(self, connect):
+ '''
+ target: test table rows_count is correct or not with multiple tables of IP
+ method: create table and add vectors in it,
+ assert the value returned by get_table_row_count method is equal to length of vectors
+ expected: the count is equal to the length of vectors
+ '''
+ nq = 100
+ vectors = gen_vectors(nq, dim)
+ table_list = []
+ for i in range(50):
+ table_name = gen_unique_str('test_table_rows_count_multi_tables')
+ table_list.append(table_name)
+ param = {'table_name': table_name,
+ 'dimension': dim,
+ 'index_file_size': index_file_size,
+ 'metric_type': MetricType.IP}
+ connect.create_table(param)
+ res = connect.add_vectors(table_name=table_name, records=vectors)
+ time.sleep(2)
+ for i in range(50):
+ status, res = connect.get_table_row_count(table_list[i])
+ assert status.OK()
+ assert res == nq
\ No newline at end of file
diff --git a/tests/milvus_python_test/utils.py b/tests/milvus_python_test/utils.py
new file mode 100644
index 0000000000..6d2c56dbed
--- /dev/null
+++ b/tests/milvus_python_test/utils.py
@@ -0,0 +1,545 @@
+# STL imports
+import random
+import string
+import struct
+import sys
+import time, datetime
+import copy
+import numpy as np
+from milvus import Milvus, IndexType, MetricType
+
+
+def gen_inaccuracy(num):
+ return num/255.0
+
+def gen_vectors(num, dim):
+ return [[random.random() for _ in range(dim)] for _ in range(num)]
+
+
+def gen_single_vector(dim):
+ return [[random.random() for _ in range(dim)]]
+
+
+def gen_vector(nb, d, seed=np.random.RandomState(1234)):
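+    # deterministic vectors: a seeded RandomState yields reproducible float32 data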
+ xb = seed.rand(nb, d).astype("float32")
+ return xb.tolist()
+
+
+def gen_unique_str(str_value=None):
+    prefix = "".join(random.choice(string.ascii_letters + string.digits) for _ in range(8))
+    return prefix if str_value is None else str_value + "_" + prefix
+
+
+def get_current_day():
+ return time.strftime('%Y-%m-%d', time.localtime())
+
+
+def get_last_day(day):
+ tmp = datetime.datetime.now()-datetime.timedelta(days=day)
+ return tmp.strftime('%Y-%m-%d')
+
+
+def get_next_day(day):
+ tmp = datetime.datetime.now()+datetime.timedelta(days=day)
+ return tmp.strftime('%Y-%m-%d')
+
+
+def gen_long_str(num):
+    # build and return a random string of length num
+    chars = ''
+    for _ in range(num):
+        chars += random.choice('tomorrow')
+    return chars
+
+
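+# The gen_invalid_* helpers below return values the server is expected to reject;
+# negative tests parametrize over them to exercise input validation.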
+def gen_invalid_ips():
+ ips = [
+ "255.0.0.0",
+ "255.255.0.0",
+ "255.255.255.0",
+ "255.255.255.255",
+ "127.0.0",
+ "123.0.0.2",
+ "12-s",
+ " ",
+ "12 s",
+ "BB。A",
+ " siede ",
+ "(mn)",
+ "\n",
+ "\t",
+ "中文",
+ "a".join("a" for i in range(256))
+ ]
+ return ips
+
+
+def gen_invalid_ports():
+ ports = [
+ # empty
+ " ",
+ -1,
+ # too big port
+ 100000,
+ # not correct port
+ 39540,
+ "BB。A",
+ " siede ",
+ "(mn)",
+ "\n",
+ "\t",
+ "中文"
+ ]
+ return ports
+
+
+def gen_invalid_uris():
+ ip = None
+ port = 19530
+
+ uris = [
+ " ",
+ "中文",
+
+ # invalid protocol
+ # "tc://%s:%s" % (ip, port),
+ # "tcp%s:%s" % (ip, port),
+
+ # # invalid port
+ # "tcp://%s:100000" % ip,
+ # "tcp://%s: " % ip,
+ # "tcp://%s:19540" % ip,
+ # "tcp://%s:-1" % ip,
+ # "tcp://%s:string" % ip,
+
+ # invalid ip
+ "tcp:// :%s" % port,
+ "tcp://123.0.0.1:%s" % port,
+ "tcp://127.0.0:%s" % port,
+ "tcp://255.0.0.0:%s" % port,
+ "tcp://255.255.0.0:%s" % port,
+ "tcp://255.255.255.0:%s" % port,
+ "tcp://255.255.255.255:%s" % port,
+ "tcp://\n:%s" % port,
+
+ ]
+ return uris
+
+
+def gen_invalid_table_names():
+ table_names = [
+ "12-s",
+ "12/s",
+ " ",
+ # "",
+ # None,
+ "12 s",
+ "BB。A",
+ "c|c",
+ " siede ",
+ "(mn)",
+ "#12s",
+ "pip+",
+ "=c",
+ "\n",
+ "\t",
+ "中文",
+ "a".join("a" for i in range(256))
+ ]
+ return table_names
+
+
+def gen_invalid_top_ks():
+ top_ks = [
+ 0,
+ -1,
+ None,
+ [1,2,3],
+ (1,2),
+ {"a": 1},
+ " ",
+ "",
+ "String",
+ "12-s",
+ "BB。A",
+ " siede ",
+ "(mn)",
+ "#12s",
+ "pip+",
+ "=c",
+ "\n",
+ "\t",
+ "中文",
+ "a".join("a" for i in range(256))
+ ]
+ return top_ks
+
+
+def gen_invalid_dims():
+ dims = [
+ 0,
+ -1,
+ 100001,
+ 1000000000000001,
+ None,
+ False,
+ [1,2,3],
+ (1,2),
+ {"a": 1},
+ " ",
+ "",
+ "String",
+ "12-s",
+ "BB。A",
+ " siede ",
+ "(mn)",
+ "#12s",
+ "pip+",
+ "=c",
+ "\n",
+ "\t",
+ "中文",
+ "a".join("a" for i in range(256))
+ ]
+ return dims
+
+
+def gen_invalid_file_sizes():
+ file_sizes = [
+ 0,
+ -1,
+ 1000000000000001,
+ None,
+ False,
+ [1,2,3],
+ (1,2),
+ {"a": 1},
+ " ",
+ "",
+ "String",
+ "12-s",
+ "BB。A",
+ " siede ",
+ "(mn)",
+ "#12s",
+ "pip+",
+ "=c",
+ "\n",
+ "\t",
+ "中文",
+ "a".join("a" for i in range(256))
+ ]
+ return file_sizes
+
+
+def gen_invalid_index_types():
+ invalid_types = [
+ 0,
+ -1,
+ 100,
+ 1000000000000001,
+ # None,
+ False,
+ [1,2,3],
+ (1,2),
+ {"a": 1},
+ " ",
+ "",
+ "String",
+ "12-s",
+ "BB。A",
+ " siede ",
+ "(mn)",
+ "#12s",
+ "pip+",
+ "=c",
+ "\n",
+ "\t",
+ "中文",
+ "a".join("a" for i in range(256))
+ ]
+ return invalid_types
+
+
+def gen_invalid_nlists():
+ nlists = [
+ 0,
+ -1,
+ 1000000000000001,
+ # None,
+ [1,2,3],
+ (1,2),
+ {"a": 1},
+ " ",
+ "",
+ "String",
+ "12-s",
+ "BB。A",
+ " siede ",
+ "(mn)",
+ "#12s",
+ "pip+",
+ "=c",
+ "\n",
+ "\t",
+ "中文"
+ ]
+ return nlists
+
+
+def gen_invalid_nprobes():
+ nprobes = [
+ 0,
+ -1,
+ 1000000000000001,
+ None,
+ [1,2,3],
+ (1,2),
+ {"a": 1},
+ " ",
+ "",
+ "String",
+ "12-s",
+ "BB。A",
+ " siede ",
+ "(mn)",
+ "#12s",
+ "pip+",
+ "=c",
+ "\n",
+ "\t",
+ "中文"
+ ]
+ return nprobes
+
+
+def gen_invalid_metric_types():
+ metric_types = [
+ 0,
+ -1,
+ 1000000000000001,
+ # None,
+ [1,2,3],
+ (1,2),
+ {"a": 1},
+ " ",
+ "",
+ "String",
+ "12-s",
+ "BB。A",
+ " siede ",
+ "(mn)",
+ "#12s",
+ "pip+",
+ "=c",
+ "\n",
+ "\t",
+ "中文"
+ ]
+ return metric_types
+
+
+def gen_invalid_vectors():
+ invalid_vectors = [
+ "1*2",
+ [],
+ [1],
+ [1,2],
+ [" "],
+ ['a'],
+ [None],
+ None,
+ (1,2),
+ {"a": 1},
+ " ",
+ "",
+ "String",
+ "12-s",
+ "BB。A",
+ " siede ",
+ "(mn)",
+ "#12s",
+ "pip+",
+ "=c",
+ "\n",
+ "\t",
+ "中文",
+ "a".join("a" for i in range(256))
+ ]
+ return invalid_vectors
+
+
+def gen_invalid_vector_ids():
+ invalid_vector_ids = [
+ 1.0,
+ -1.0,
+ None,
+        # exceeds the int64 range
+ 10000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000,
+ " ",
+ "",
+ "String",
+ "BB。A",
+ " siede ",
+ "(mn)",
+ "#12s",
+ "=c",
+ "\n",
+ "中文",
+ ]
+ return invalid_vector_ids
+
+
+
+def gen_invalid_query_ranges():
+ query_ranges = [
+ [(get_last_day(1), "")],
+ [(get_current_day(), "")],
+ [(get_next_day(1), "")],
+ [(get_current_day(), get_last_day(1))],
+ [(get_next_day(1), get_last_day(1))],
+ [(get_next_day(1), get_current_day())],
+ [(0, get_next_day(1))],
+ [(-1, get_next_day(1))],
+ [(1, get_next_day(1))],
+ [(100001, get_next_day(1))],
+ [(1000000000000001, get_next_day(1))],
+ [(None, get_next_day(1))],
+ [([1,2,3], get_next_day(1))],
+ [((1,2), get_next_day(1))],
+ [({"a": 1}, get_next_day(1))],
+ [(" ", get_next_day(1))],
+ [("", get_next_day(1))],
+ [("String", get_next_day(1))],
+ [("12-s", get_next_day(1))],
+ [("BB。A", get_next_day(1))],
+ [(" siede ", get_next_day(1))],
+ [("(mn)", get_next_day(1))],
+ [("#12s", get_next_day(1))],
+ [("pip+", get_next_day(1))],
+ [("=c", get_next_day(1))],
+ [("\n", get_next_day(1))],
+ [("\t", get_next_day(1))],
+ [("中文", get_next_day(1))],
+ [("a".join("a" for i in range(256)), get_next_day(1))]
+ ]
+ return query_ranges
+
+
+def gen_invalid_index_params():
+ index_params = []
+ for index_type in gen_invalid_index_types():
+ index_param = {"index_type": index_type, "nlist": 16384}
+ index_params.append(index_param)
+ for nlist in gen_invalid_nlists():
+ index_param = {"index_type": IndexType.IVFLAT, "nlist": nlist}
+ index_params.append(index_param)
+ return index_params
+
+
+def gen_index_params():
+ index_types = [IndexType.FLAT, IndexType.IVFLAT, IndexType.IVF_SQ8, IndexType.IVF_SQ8H]
+ nlists = [1, 16384, 50000]
+
+ def gen_params(index_types, nlists):
+ return [ {"index_type": index_type, "nlist": nlist} \
+ for index_type in index_types \
+ for nlist in nlists]
+
+ return gen_params(index_types, nlists)
+
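+# e.g. gen_index_params()[0] == {"index_type": IndexType.FLAT, "nlist": 1}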
+def gen_simple_index_params():
+ index_types = [IndexType.FLAT, IndexType.IVFLAT, IndexType.IVF_SQ8, IndexType.IVF_SQ8H]
+ nlists = [16384]
+
+ def gen_params(index_types, nlists):
+ return [ {"index_type": index_type, "nlist": nlist} \
+ for index_type in index_types \
+ for nlist in nlists]
+
+ return gen_params(index_types, nlists)
+
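+# Manual smoke test below: assumes a Milvus server on 127.0.0.1 and the query.npy
+# dataset at /poc/yuncong/ann_1000m; not executed by pytest.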
+if __name__ == "__main__":
+ import numpy
+
+ dim = 128
+ nq = 10000
+ table = "test"
+
+
+ file_name = '/poc/yuncong/ann_1000m/query.npy'
+ data = np.load(file_name)
+ vectors = data[0:nq].tolist()
+ # print(vectors)
+
+ connect = Milvus()
+ # connect.connect(host="192.168.1.27")
+ # print(connect.show_tables())
+ # print(connect.get_table_row_count(table))
+ # sys.exit()
+ connect.connect(host="127.0.0.1")
+ connect.delete_table(table)
+ # sys.exit()
+ # time.sleep(2)
+ print(connect.get_table_row_count(table))
+ param = {'table_name': table,
+ 'dimension': dim,
+ 'metric_type': MetricType.L2,
+ 'index_file_size': 10}
+ status = connect.create_table(param)
+ print(status)
+ print(connect.get_table_row_count(table))
+ # add vectors
+ for i in range(10):
+ status, ids = connect.add_vectors(table, vectors)
+ print(status)
+ print(ids[0])
+ # print(ids[0])
+ index_params = {"index_type": IndexType.IVFLAT, "nlist": 16384}
+ status = connect.create_index(table, index_params)
+ print(status)
+ # sys.exit()
+ query_vec = [vectors[0]]
+ # print(numpy.inner(numpy.array(query_vec[0]), numpy.array(query_vec[0])))
+ top_k = 12
+ nprobe = 1
+ for i in range(2):
+ result = connect.search_vectors(table, top_k, nprobe, query_vec)
+ print(result)
+ sys.exit()
+
+
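+    # Unreachable scratch code (sys.exit() above): a manual multiprocessing add
+    # scenario that assumes `args`, `index_file_size`, and `Process` are defined.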
+ table = gen_unique_str("test_add_vector_with_multiprocessing")
+ uri = "tcp://%s:%s" % (args["ip"], args["port"])
+ param = {'table_name': table,
+ 'dimension': dim,
+ 'index_file_size': index_file_size}
+ # create table
+ milvus = Milvus()
+ milvus.connect(uri=uri)
+ milvus.create_table(param)
+ vector = gen_single_vector(dim)
+
+ process_num = 4
+ loop_num = 10
+ processes = []
+ # with dependent connection
+ def add(milvus):
+ i = 0
+ while i < loop_num:
+ status, ids = milvus.add_vectors(table, vector)
+ i = i + 1
+ for i in range(process_num):
+ milvus = Milvus()
+ milvus.connect(uri=uri)
+ p = Process(target=add, args=(milvus,))
+ processes.append(p)
+ p.start()
+ time.sleep(0.2)
+ for p in processes:
+ p.join()
+ time.sleep(3)
+ status, count = milvus.get_table_row_count(table)
+ assert count == process_num * loop_num
\ No newline at end of file