From ed50cb303f0a1b16a7a1ab0a41b190f48434db2f Mon Sep 17 00:00:00 2001 From: Cai Yudong Date: Sat, 5 Sep 2020 19:08:29 +0800 Subject: [PATCH] [skip ci] use List> for binary searchParam (#3622) Signed-off-by: yudong.cai --- .../src/main/java/com/Constants.java | 6 +-- .../main/java/com/TestCollectionCount.java | 1 + .../main/java/com/TestCollectionCount_v2.java | 4 +- .../main/java/com/TestCollectionInfo_v2.java | 3 +- .../src/main/java/com/TestCollection_v2.java | 2 - .../src/main/java/com/TestCompact.java | 1 - .../main/java/com/TestDeleteEntities_v2.java | 3 +- .../src/main/java/com/TestFlush.java | 1 + .../src/main/java/com/TestGetEntityByID.java | 8 +++- .../src/main/java/com/TestIndex_v2.java | 3 +- .../src/main/java/com/TestInsertEntities.java | 3 +- .../main/java/com/TestInsertEntities_v2.java | 3 +- .../src/main/java/com/TestSearchEntities.java | 44 +++++-------------- .../src/main/java/com/Utils.java | 26 ++++++----- 14 files changed, 41 insertions(+), 67 deletions(-) diff --git a/tests/milvus-java-test/src/main/java/com/Constants.java b/tests/milvus-java-test/src/main/java/com/Constants.java index a635b21e87..d74508ab01 100644 --- a/tests/milvus-java-test/src/main/java/com/Constants.java +++ b/tests/milvus-java-test/src/main/java/com/Constants.java @@ -1,9 +1,5 @@ package com; -import com.alibaba.fastjson.JSONObject; - -import java.nio.ByteBuffer; -import java.util.ArrayList; import java.util.List; import java.util.Map; @@ -47,7 +43,7 @@ public final class Constants { public static final List> vectors = Utils.genVectors(nb, dimension, true); - public static final List vectorsBinary = Utils.genBinaryVectors(nb, dimension); + public static final List> vectorsBinary = Utils.genBinaryVectors(nb, dimension); public static final List> defaultFields = Utils.genDefaultFields(dimension,false); diff --git a/tests/milvus-java-test/src/main/java/com/TestCollectionCount.java b/tests/milvus-java-test/src/main/java/com/TestCollectionCount.java index b5bc7ad4eb..de278ec27a 100644 --- a/tests/milvus-java-test/src/main/java/com/TestCollectionCount.java +++ b/tests/milvus-java-test/src/main/java/com/TestCollectionCount.java @@ -3,6 +3,7 @@ package com; import io.milvus.client.*; import org.testng.Assert; import org.testng.annotations.Test; + import java.util.List; public class TestCollectionCount { diff --git a/tests/milvus-java-test/src/main/java/com/TestCollectionCount_v2.java b/tests/milvus-java-test/src/main/java/com/TestCollectionCount_v2.java index 89119445f8..748d3232e0 100644 --- a/tests/milvus-java-test/src/main/java/com/TestCollectionCount_v2.java +++ b/tests/milvus-java-test/src/main/java/com/TestCollectionCount_v2.java @@ -3,15 +3,15 @@ package com; import io.milvus.client.*; import org.testng.Assert; import org.testng.annotations.Test; + import java.util.List; -import java.nio.ByteBuffer; public class TestCollectionCount_v2 { int segmentRowCount = 5000; int dimension = Constants.dimension; int nb = Constants.nb; List> vectors = Constants.vectors; - List vectorsBinary = Constants.vectorsBinary; + List> vectorsBinary = Constants.vectorsBinary; // case-04 @Test(dataProvider = "Collection", dataProviderClass = MainClass.class) diff --git a/tests/milvus-java-test/src/main/java/com/TestCollectionInfo_v2.java b/tests/milvus-java-test/src/main/java/com/TestCollectionInfo_v2.java index d5bbe500e5..b6c647ce44 100644 --- a/tests/milvus-java-test/src/main/java/com/TestCollectionInfo_v2.java +++ b/tests/milvus-java-test/src/main/java/com/TestCollectionInfo_v2.java @@ -6,14 +6,13 @@ import org.testng.annotations.Test; import java.util.Collections; import java.util.List; -import java.nio.ByteBuffer; public class TestCollectionInfo_v2 { int nb = Constants.nb; int dimension = Constants.dimension; int n_list = Constants.n_list; List> vectors = Constants.vectors; - List vectorsBinary = Constants.vectorsBinary; + List> vectorsBinary = Constants.vectorsBinary; String indexType = Constants.indexType; String metricType = Constants.defaultMetricType; String floatFieldName = Constants.floatFieldName; diff --git a/tests/milvus-java-test/src/main/java/com/TestCollection_v2.java b/tests/milvus-java-test/src/main/java/com/TestCollection_v2.java index 84f1398b4c..ba8f840778 100644 --- a/tests/milvus-java-test/src/main/java/com/TestCollection_v2.java +++ b/tests/milvus-java-test/src/main/java/com/TestCollection_v2.java @@ -1,12 +1,10 @@ package com; -import com.alibaba.fastjson.JSONObject; import io.milvus.client.*; import org.testng.Assert; import org.testng.annotations.*; import java.util.List; -import java.util.Map; public class TestCollection_v2 { int segmentRowCount = 5000; diff --git a/tests/milvus-java-test/src/main/java/com/TestCompact.java b/tests/milvus-java-test/src/main/java/com/TestCompact.java index b997b1fd0c..670d1da6b3 100644 --- a/tests/milvus-java-test/src/main/java/com/TestCompact.java +++ b/tests/milvus-java-test/src/main/java/com/TestCompact.java @@ -1,6 +1,5 @@ package com; -import com.alibaba.fastjson.JSONArray; import com.alibaba.fastjson.JSONObject; import io.milvus.client.*; import org.testng.Assert; diff --git a/tests/milvus-java-test/src/main/java/com/TestDeleteEntities_v2.java b/tests/milvus-java-test/src/main/java/com/TestDeleteEntities_v2.java index c86011ed69..e2a480672b 100644 --- a/tests/milvus-java-test/src/main/java/com/TestDeleteEntities_v2.java +++ b/tests/milvus-java-test/src/main/java/com/TestDeleteEntities_v2.java @@ -7,13 +7,12 @@ import org.testng.annotations.Test; import java.util.ArrayList; import java.util.Collections; import java.util.List; -import java.nio.ByteBuffer; public class TestDeleteEntities_v2 { int dimension = Constants.dimension; int nb = Constants.nb; List> vectors = Constants.vectors; - List vectorsBinary = Constants.vectorsBinary; + List> vectorsBinary = Constants.vectorsBinary; // case-01 @Test(dataProvider = "Collection", dataProviderClass = MainClass.class) diff --git a/tests/milvus-java-test/src/main/java/com/TestFlush.java b/tests/milvus-java-test/src/main/java/com/TestFlush.java index c356fd4753..713b2dfd68 100644 --- a/tests/milvus-java-test/src/main/java/com/TestFlush.java +++ b/tests/milvus-java-test/src/main/java/com/TestFlush.java @@ -6,6 +6,7 @@ import io.milvus.client.*; import org.apache.commons.lang3.RandomStringUtils; import org.testng.Assert; import org.testng.annotations.Test; + import java.util.ArrayList; import java.util.List; diff --git a/tests/milvus-java-test/src/main/java/com/TestGetEntityByID.java b/tests/milvus-java-test/src/main/java/com/TestGetEntityByID.java index 6956ed175f..ca6c7724a9 100644 --- a/tests/milvus-java-test/src/main/java/com/TestGetEntityByID.java +++ b/tests/milvus-java-test/src/main/java/com/TestGetEntityByID.java @@ -58,12 +58,16 @@ public class TestGetEntityByID { // Binary tests @Test(dataProvider = "BinaryCollection", dataProviderClass = MainClass.class) public void testGetEntityByIdValidBinary(MilvusClient client, String collectionName) { + int get_length = 20; InsertParam insertParam = new InsertParam.Builder(collectionName).withFields(Constants.defaultBinaryEntities).build(); InsertResponse resInsert = client.insert(insertParam); List ids = resInsert.getEntityIds(); client.flush(collectionName); - GetEntityByIDResponse res = client.getEntityByID(collectionName, ids.subList(0, 1)); - assert res.getFieldsMap().get(0).get(Constants.binaryFieldName).equals(Constants.vectorsBinary.get(0).rewind()); + GetEntityByIDResponse res = client.getEntityByID(collectionName, ids.subList(0, get_length)); + for (int i = 0; i < get_length; i++) { + List> fieldsMap = res.getFieldsMap(); + assert (fieldsMap.get(i).get("binary_vector").equals(Constants.vectorsBinary.get(i))); + } } @Test(dataProvider = "BinaryCollection", dataProviderClass = MainClass.class) diff --git a/tests/milvus-java-test/src/main/java/com/TestIndex_v2.java b/tests/milvus-java-test/src/main/java/com/TestIndex_v2.java index 46e1cd6ef9..ffcbcb88da 100644 --- a/tests/milvus-java-test/src/main/java/com/TestIndex_v2.java +++ b/tests/milvus-java-test/src/main/java/com/TestIndex_v2.java @@ -6,7 +6,6 @@ import io.milvus.client.*; import org.testng.Assert; import org.testng.annotations.*; -import java.nio.ByteBuffer; import java.util.List; public class TestIndex_v2 { @@ -19,7 +18,7 @@ public class TestIndex_v2 { String defaultBinaryIndexType = Constants.defaultBinaryIndexType; String defaultMetricType = Constants.defaultMetricType; List> vectors = Constants.vectors; - List vectorsBinary = Constants.vectorsBinary; + List> vectorsBinary = Constants.vectorsBinary; // case-01 diff --git a/tests/milvus-java-test/src/main/java/com/TestInsertEntities.java b/tests/milvus-java-test/src/main/java/com/TestInsertEntities.java index 6f6aa17da3..9f0cc8e7eb 100644 --- a/tests/milvus-java-test/src/main/java/com/TestInsertEntities.java +++ b/tests/milvus-java-test/src/main/java/com/TestInsertEntities.java @@ -7,7 +7,6 @@ import org.apache.commons.lang3.RandomStringUtils; import org.testng.Assert; import org.testng.annotations.Test; -import java.nio.ByteBuffer; import java.util.*; import java.util.stream.Collectors; import java.util.stream.LongStream; @@ -188,7 +187,7 @@ public class TestInsertEntities { // case-14 @Test(dataProvider = "BinaryCollection", dataProviderClass = MainClass.class) public void testInsertBinaryEntityWithInvalidDimension(MilvusClient client, String collectionName) { - List vectorsBinary = Utils.genBinaryVectors(nb, dimension-1); + List> vectorsBinary = Utils.genBinaryVectors(nb, dimension-1); List> binaryEntities = Utils.genDefaultBinaryEntities(dimension-1,nb,vectorsBinary); InsertParam insertParam = new InsertParam.Builder(collectionName).withFields(binaryEntities).build(); InsertResponse res = client.insert(insertParam); diff --git a/tests/milvus-java-test/src/main/java/com/TestInsertEntities_v2.java b/tests/milvus-java-test/src/main/java/com/TestInsertEntities_v2.java index f11a6b7c81..6f1b3e7e9f 100644 --- a/tests/milvus-java-test/src/main/java/com/TestInsertEntities_v2.java +++ b/tests/milvus-java-test/src/main/java/com/TestInsertEntities_v2.java @@ -10,7 +10,6 @@ import org.apache.commons.lang3.RandomStringUtils; import org.testng.Assert; import org.testng.annotations.Test; -import java.nio.ByteBuffer; import java.util.List; import java.util.stream.Collectors; import java.util.stream.LongStream; @@ -20,7 +19,7 @@ public class TestInsertEntities_v2 { String tag = "tag"; int nb = Constants.nb; List> vectors = Constants.vectors; - List vectorsBinary = Constants.vectorsBinary; + List> vectorsBinary = Constants.vectorsBinary; // case-01 @Test(dataProvider = "Collection", dataProviderClass = MainClass.class) diff --git a/tests/milvus-java-test/src/main/java/com/TestSearchEntities.java b/tests/milvus-java-test/src/main/java/com/TestSearchEntities.java index 21de271f47..a922594ef6 100644 --- a/tests/milvus-java-test/src/main/java/com/TestSearchEntities.java +++ b/tests/milvus-java-test/src/main/java/com/TestSearchEntities.java @@ -6,7 +6,6 @@ import io.milvus.client.*; import org.testng.Assert; import org.testng.annotations.Test; -import java.nio.ByteBuffer; import java.util.*; public class TestSearchEntities { @@ -16,21 +15,22 @@ public class TestSearchEntities { int nq = Constants.nq; List> queryVectors = Constants.vectors.subList(0, nq); - List queryVectorsBinary = Constants.vectorsBinary.subList(0, nq); + List> queryVectorsBinary = Constants.vectorsBinary.subList(0, nq); - public String dsl = Constants.searchParam; + public String floatDsl = Constants.searchParam; + public String binaryDsl = Constants.binarySearchParam; @Test(dataProvider = "Collection", dataProviderClass = MainClass.class) public void testSearchCollectionNotExisted(MilvusClient client, String collectionName) { String collectionNameNew = Utils.genUniqueStr(collectionName); - SearchParam searchParam = new SearchParam.Builder(collectionNameNew).withDSL(dsl).build(); + SearchParam searchParam = new SearchParam.Builder(collectionNameNew).withDSL(floatDsl).build(); SearchResponse res_search = client.search(searchParam); assert (!res_search.getResponse().ok()); } @Test(dataProvider = "Collection", dataProviderClass = MainClass.class) public void testSearchCollectionEmpty(MilvusClient client, String collectionName) { - SearchParam searchParam = new SearchParam.Builder(collectionName).withDSL(dsl).build(); + SearchParam searchParam = new SearchParam.Builder(collectionName).withDSL(floatDsl).build(); SearchResponse res_search = client.search(searchParam); assert (res_search.getResponse().ok()); Assert.assertEquals(res_search.getResultIdsList().size(), 0); @@ -44,7 +44,7 @@ public class TestSearchEntities { assert(res.getResponse().ok()); List ids = res.getEntityIds(); client.flush(collectionName); - SearchParam searchParam = new SearchParam.Builder(collectionName).withDSL(dsl).build(); + SearchParam searchParam = new SearchParam.Builder(collectionName).withDSL(floatDsl).build(); SearchResponse res_search = client.search(searchParam); Assert.assertEquals(res_search.getResultIdsList().size(), Constants.nq); Assert.assertEquals(res_search.getResultDistancesList().size(), Constants.nq); @@ -60,7 +60,7 @@ public class TestSearchEntities { assert(res.getResponse().ok()); List ids = res.getEntityIds(); client.flush(collectionName); - SearchParam searchParam = new SearchParam.Builder(collectionName).withDSL(dsl).build(); + SearchParam searchParam = new SearchParam.Builder(collectionName).withDSL(floatDsl).build(); SearchResponse res_search = client.search(searchParam); for (int i = 0; i < Constants.nq; i++) { double distance = res_search.getResultDistancesList().get(i).get(0); @@ -94,7 +94,7 @@ public class TestSearchEntities { InsertResponse res = client.insert(insertParam); assert(res.getResponse().ok()); client.flush(collectionName); - SearchParam searchParam = new SearchParam.Builder(collectionName).withDSL(dsl).withPartitionTags(queryTags).build(); + SearchParam searchParam = new SearchParam.Builder(collectionName).withDSL(floatDsl).withPartitionTags(queryTags).build(); SearchResponse res_search = client.search(searchParam); Assert.assertEquals(res_search.getResultDistancesList().size(), 0); } @@ -109,7 +109,7 @@ public class TestSearchEntities { InsertResponse res = client.insert(insertParam); assert (res.getResponse().ok()); client.flush(collectionName); - SearchParam searchParam = new SearchParam.Builder(collectionName).withDSL(dsl).withPartitionTags(queryTags).build(); + SearchParam searchParam = new SearchParam.Builder(collectionName).withDSL(floatDsl).withPartitionTags(queryTags).build(); SearchResponse res_search = client.search(searchParam); Assert.assertEquals(res_search.getResultDistancesList().size(), 0); } @@ -289,18 +289,7 @@ public class TestSearchEntities { @Test(dataProvider = "BinaryCollection", dataProviderClass = MainClass.class) public void testSearchCollectionNotExistedBinary(MilvusClient client, String collectionName) { String collectionNameNew = Utils.genUniqueStr(collectionName); - String queryKey = "placeholder"; - Map> binaryQueryEntities = new HashMap<>(); - binaryQueryEntities.put(queryKey, queryVectorsBinary); - JSONObject binaryVectorParam = Utils.genBinaryVectorParam(Constants.defaultBinaryMetricType, queryKey, top_k, n_probe); - List leafParams = new ArrayList<>(); - leafParams.add(binaryVectorParam); - String dsl = Utils.genDefaultSearchParam(leafParams); - System.out.println(dsl); - SearchParam searchParam = new SearchParam.Builder(collectionNameNew) - .withBinaryEntities(binaryQueryEntities) - .withDSL(dsl) - .build(); + SearchParam searchParam = new SearchParam.Builder(collectionNameNew).withDSL(binaryDsl).build(); SearchResponse resSearch = client.search(searchParam); Assert.assertFalse(resSearch.getResponse().ok()); } @@ -311,18 +300,7 @@ public class TestSearchEntities { InsertParam insertParam = new InsertParam.Builder(collectionName).withFields(Constants.defaultBinaryEntities).build(); InsertResponse res = client.insert(insertParam); assert(res.getResponse().ok()); - String queryKey = "placeholder"; - Map> binaryQueryEntities = new HashMap<>(); - binaryQueryEntities.put(queryKey, queryVectorsBinary); - JSONObject binaryVectorParam = Utils.genBinaryVectorParam(Constants.defaultBinaryMetricType, queryKey, top_k, n_probe); - List leafParams = new ArrayList<>(); - leafParams.add(binaryVectorParam); - String dsl = Utils.genDefaultSearchParam(leafParams); - System.out.println(dsl); - SearchParam searchParam = new SearchParam.Builder(collectionName) - .withDSL(dsl) - .withBinaryEntities(binaryQueryEntities) - .build(); + SearchParam searchParam = new SearchParam.Builder(collectionName).withDSL(binaryDsl).build(); SearchResponse resSearch = client.search(searchParam); Assert.assertTrue(resSearch.getResponse().ok()); Assert.assertEquals(resSearch.getResultIdsList().size(), nq); diff --git a/tests/milvus-java-test/src/main/java/com/Utils.java b/tests/milvus-java-test/src/main/java/com/Utils.java index bf231286df..d7789163f6 100644 --- a/tests/milvus-java-test/src/main/java/com/Utils.java +++ b/tests/milvus-java-test/src/main/java/com/Utils.java @@ -27,8 +27,8 @@ public class Utils { } public static List> genVectors(int vectorCount, int dimension, boolean norm) { - List> vectors = new ArrayList<>(); Random random = new Random(); + List> vectors = new ArrayList<>(); for (int i = 0; i < vectorCount; ++i) { List vector = new ArrayList<>(); for (int j = 0; j < dimension; ++j) { @@ -42,14 +42,16 @@ public class Utils { return vectors; } - static List genBinaryVectors(long vectorCount, long dimension) { + static List> genBinaryVectors(long vectorCount, long dimension) { Random random = new Random(); - List vectors = new ArrayList<>(); + List> vectors = new ArrayList<>(); final long dimensionInByte = dimension / 8; for (long i = 0; i < vectorCount; ++i) { - ByteBuffer byteBuffer = ByteBuffer.allocate((int) dimensionInByte); - random.nextBytes(byteBuffer.array()); - vectors.add(byteBuffer); + List vector = new ArrayList<>(); + for (int j = 0; j < dimensionInByte; ++j) { + vector.add((byte) random.nextInt()); + } + vectors.add(vector); } return vectors; } @@ -111,7 +113,7 @@ public class Utils { return fieldsMap; } - public static List> genDefaultBinaryEntities(int dimension, int vectorCount, List vectorsBinary){ + public static List> genDefaultBinaryEntities(int dimension, int vectorCount, List> vectorsBinary){ List> binaryFieldsMap = genDefaultFields(dimension, true); List intValues = new ArrayList<>(vectorCount); List floatValues = new ArrayList<>(vectorCount); @@ -163,7 +165,7 @@ public class Utils { return searchParam; } - static JSONObject genBinaryVectorParam(String metricType, String queryVectors, int topk, int nprobe) { + static JSONObject genBinaryVectorParam(String metricType, List> queryVectors, int topk, int nprobe) { JSONObject searchParam = new JSONObject(); JSONObject fieldParam = new JSONObject(); fieldParam.put("topk", topk); @@ -200,7 +202,7 @@ public class Utils { return JSONObject.toJSONString(boolParam); } - public static String setBinarySearchParam(String metricType, List queryVectors, int topk, int nprobe) { + public static String setBinarySearchParam(String metricType, List> queryVectors, int topk, int nprobe) { JSONObject searchParam = new JSONObject(); JSONObject fieldParam = new JSONObject(); fieldParam.put("topk", topk); @@ -367,7 +369,7 @@ public class Utils { } public static InsertParam genDefaultBinaryInsertParam(String collectionName, int dimension, int vectorCount, - List vectorsBinary) { + List> vectorsBinary) { List intValues = new ArrayList<>(vectorCount); List floatValues = new ArrayList<>(vectorCount); for (int i = 0; i < vectorCount; ++i) { @@ -390,7 +392,7 @@ public class Utils { } public static InsertParam genDefaultBinaryInsertParam(String collectionName, int dimension, int vectorCount, - List vectorsBinary, List entityIds) { + List> vectorsBinary, List entityIds) { List intValues = new ArrayList<>(vectorCount); List floatValues = new ArrayList<>(vectorCount); for (int i = 0; i < vectorCount; ++i) { @@ -414,7 +416,7 @@ public class Utils { } public static InsertParam genDefaultBinaryInsertParam(String collectionName, int dimension, int vectorCount, - List vectorsBinary, String tag) { + List> vectorsBinary, String tag) { List intValues = new ArrayList<>(vectorCount); List floatValues = new ArrayList<>(vectorCount); for (int i = 0; i < vectorCount; ++i) {