package com;

import io.milvus.client.*;
import org.apache.commons.lang3.RandomStringUtils;
import org.testng.Assert;
import org.testng.annotations.Test;

import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinTask;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;

/**
 * Multi-threaded tests that drive a Milvus server through {@link MilvusClient}: concurrent
 * connect, insert, index creation, search and table create/drop.
 *
 * <p>Every test runs the same task several times on a {@link ForkJoinPool}. Failures raised
 * inside a task are propagated back to the TestNG thread by {@link #runConcurrently}, so a
 * failed assertion in a worker fails the test instead of being silently recorded on the task.
 */
public class TestMix {
    private int dimension = 128;
    private int nb = 100000;
    int n_list = 8192;
    int n_probe = 20;
    int top_k = 10;
    double epsilon = 0.001;
    int index_file_size = 20;

    /**
     * Scales a vector by the inverse of its Euclidean norm.
     *
     * @param w2v the vector to normalize
     * @return a new list holding each component divided by the norm of {@code w2v}
     */
    public List<Float> normalize(List<Float> w2v) {
        float squareSum = w2v.stream().map(x -> x * x).reduce((float) 0, Float::sum);
        final float norm = (float) Math.sqrt(squareSum);
        return w2v.stream().map(x -> x / norm).collect(Collectors.toList());
    }

    /**
     * Generates random vectors of {@code dimension} components each.
     *
     * @param nb   number of vectors to generate
     * @param norm whether each vector is passed through {@link #normalize(List)}
     * @return the generated vectors
     */
    public List<List<Float>> gen_vectors(int nb, boolean norm) {
        List<List<Float>> xb = new ArrayList<>(nb);
        Random random = new Random();
        for (int i = 0; i < nb; ++i) {
            List<Float> vector = new ArrayList<>(dimension);
            for (int j = 0; j < dimension; j++) {
                vector.add(random.nextFloat());
            }
            if (norm) {
                vector = normalize(vector);
            }
            xb.add(vector);
        }
        return xb;
    }

    /**
     * Runs {@code task} {@code threadCount} times on a fresh {@link ForkJoinPool}, waits for the
     * pool to become quiescent and rethrows the first failure raised by any task.
     *
     * @param threadCount    number of times the task is submitted
     * @param timeoutSeconds maximum time to wait for the pool to become quiescent
     * @param task           the work each submission performs
     */
    private static void runConcurrently(int threadCount, long timeoutSeconds, Runnable task) {
        ForkJoinPool executor = new ForkJoinPool();
        try {
            List<ForkJoinTask<?>> pending = new ArrayList<>(threadCount);
            for (int i = 0; i < threadCount; i++) {
                // submit() (unlike execute()) hands back the task so its outcome can be inspected.
                pending.add(executor.submit(task));
            }
            boolean quiescent = executor.awaitQuiescence(timeoutSeconds, TimeUnit.SECONDS);
            Assert.assertTrue(quiescent,
                    "Concurrent tasks did not finish within " + timeoutSeconds + " seconds");
            for (ForkJoinTask<?> finished : pending) {
                // All tasks are complete here; join() rethrows whatever a task threw.
                finished.join();
            }
        } finally {
            executor.shutdown();
        }
    }

    /**
     * Disconnects {@code client} if it reports being connected. An interrupt during the
     * disconnect is not treated as a failure, but the thread's interrupt flag is restored.
     *
     * @param client the client to disconnect
     */
    private static void disconnectQuietly(MilvusClient client) {
        if (!client.connected()) {
            return;
        }
        try {
            client.disconnect();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
    }

    /** Inserts vectors, builds an IVF_SQ8 index, then searches the table from several threads. */
    @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
    public void test_search_vectors_threads(MilvusClient client, String tableName) throws InterruptedException {
        int thread_num = 10;
        int nq = 5;
        List<List<Float>> vectors = gen_vectors(nb, false);
        List<List<Float>> queryVectors = vectors.subList(0, nq);
        InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
        client.insert(insertParam);
        Index index = new Index.Builder().withIndexType(IndexType.IVF_SQ8)
                .withNList(n_list)
                .build();
        CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build();
        client.createIndex(createIndexParam);
        runConcurrently(thread_num, 100, () -> {
            SearchParam searchParam = new SearchParam.Builder(tableName, queryVectors)
                    .withNProbe(n_probe)
                    .withTopK(top_k)
                    .build();
            SearchResponse res_search = client.search(searchParam);
            Assert.assertTrue(res_search.getResponse().ok());
        });
    }

    /** Opens and closes a separate client connection from each of many threads. */
    @Test(dataProvider = "DefaultConnectArgs", dataProviderClass = MainClass.class)
    public void test_connect_threads(String host, String port) throws InterruptedException {
        int thread_num = 100;
        runConcurrently(thread_num, 100, () -> {
            MilvusClient client = new MilvusGrpcClient();
            ConnectParam connectParam = new ConnectParam.Builder()
                    .withHost(host)
                    .withPort(port)
                    .build();
            client.connect(connectParam);
            try {
                Assert.assertTrue(client.connected());
            } finally {
                disconnectQuietly(client);
            }
        });
    }

    /** Inserts the same batch from several threads and checks the resulting row count. */
    @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
    public void test_add_vectors_threads(MilvusClient client, String tableName) throws InterruptedException {
        int thread_num = 10;
        List<List<Float>> vectors = gen_vectors(nb, false);
        InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
        runConcurrently(thread_num, 100, () -> {
            InsertResponse res_insert = client.insert(insertParam);
            Assert.assertTrue(res_insert.getResponse().ok());
        });
        // Pause before reading the row count, as the original test did.
        Thread.sleep(2000);
        TableParam tableParam = new TableParam.Builder(tableName).build();
        GetTableRowCountResponse getTableRowCountResponse = client.getTableRowCount(tableParam);
        Assert.assertEquals(getTableRowCountResponse.getTableRowCount(), thread_num * nb);
    }

    /** Inserts and requests index creation from several threads, then checks the row count. */
    @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
    public void test_add_index_vectors_threads(MilvusClient client, String tableName) throws InterruptedException {
        int thread_num = 50;
        List<List<Float>> vectors = gen_vectors(nb, false);
        InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
        runConcurrently(thread_num, 300, () -> {
            InsertResponse res_insert = client.insert(insertParam);
            Index index = new Index.Builder().withIndexType(IndexType.IVF_SQ8)
                    .withNList(n_list)
                    .build();
            CreateIndexParam createIndexParam = new CreateIndexParam.Builder(tableName).withIndex(index).build();
            client.createIndex(createIndexParam);
            Assert.assertTrue(res_insert.getResponse().ok());
        });
        // Pause before reading the row count, as the original test did.
        Thread.sleep(2000);
        TableParam tableParam = new TableParam.Builder(tableName).build();
        GetTableRowCountResponse getTableRowCountResponse = client.getTableRowCount(tableParam);
        Assert.assertEquals(getTableRowCountResponse.getTableRowCount(), thread_num * nb);
    }

    /**
     * Inserts normalized vectors and searches for the first few of them from several threads,
     * checking the top hit's distance for tables whose name starts with "L2" or "IP".
     */
    @Test(dataProvider = "Table", dataProviderClass = MainClass.class)
    public void test_add_search_vectors_threads(MilvusClient client, String tableName) throws InterruptedException {
        int thread_num = 50;
        int nq = 5;
        List<List<Float>> vectors = gen_vectors(nb, true);
        List<List<Float>> queryVectors = vectors.subList(0, nq);
        InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
        runConcurrently(thread_num, 300, () -> {
            InsertResponse res_insert = client.insert(insertParam);
            Assert.assertTrue(res_insert.getResponse().ok());
            try {
                TimeUnit.SECONDS.sleep(1);
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new IllegalStateException("Interrupted while waiting before search", e);
            }
            SearchParam searchParam = new SearchParam.Builder(tableName, queryVectors)
                    .withNProbe(n_probe)
                    .withTopK(top_k)
                    .build();
            SearchResponse res_search = client.search(searchParam);
            Assert.assertTrue(res_search.getResponse().ok());
            // Reuse the response that was just validated instead of searching a second time.
            List<List<SearchResponse.QueryResult>> res = res_search.getQueryResultsList();
            double distance = res.get(0).get(0).getDistance();
            if (tableName.startsWith("L2")) {
                Assert.assertEquals(distance, 0.0, epsilon);
            } else if (tableName.startsWith("IP")) {
                Assert.assertEquals(distance, 1.0, epsilon);
            }
        });
        // Pause before reading the row count, as the original test did.
        Thread.sleep(2000);
        TableParam tableParam = new TableParam.Builder(tableName).build();
        GetTableRowCountResponse getTableRowCountResponse = client.getTableRowCount(tableParam);
        Assert.assertEquals(getTableRowCountResponse.getTableRowCount(), thread_num * nb);
    }

    /**
     * From each of many threads: connects a new client, creates a randomly named table, inserts
     * vectors into it and drops it again.
     */
    @Test(dataProvider = "DefaultConnectArgs", dataProviderClass = MainClass.class)
    public void test_create_insert_delete_threads(String host, String port) throws InterruptedException {
        int thread_num = 100;
        List<List<Float>> vectors = gen_vectors(nb, false);
        runConcurrently(thread_num, 100, () -> {
            MilvusClient client = new MilvusGrpcClient();
            ConnectParam connectParam = new ConnectParam.Builder()
                    .withHost(host)
                    .withPort(port)
                    .build();
            client.connect(connectParam);
            try {
                String tableName = RandomStringUtils.randomAlphabetic(10);
                TableSchema tableSchema = new TableSchema.Builder(tableName, dimension)
                        .withIndexFileSize(index_file_size)
                        .withMetricType(MetricType.IP)
                        .build();
                TableSchemaParam tableSchemaParam = new TableSchemaParam.Builder(tableSchema).build();
                client.createTable(tableSchemaParam);
                InsertParam insertParam = new InsertParam.Builder(tableName, vectors).build();
                client.insert(insertParam);
                TableParam tableParam = new TableParam.Builder(tableName).build();
                Response response = client.dropTable(tableParam);
                Assert.assertTrue(response.ok());
            } finally {
                // Release the connection even when a step above fails.
                disconnectQuietly(client);
            }
        });
    }
}