[skip ci] use List<List<Byte>> for binary searchParam (#3622)

Signed-off-by: yudong.cai <yudong.cai@zilliz.com>
This commit is contained in:
Cai Yudong 2020-09-05 19:08:29 +08:00 committed by GitHub
parent 4b004737c2
commit ed50cb303f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 41 additions and 67 deletions

View File

@ -1,9 +1,5 @@
package com;
import com.alibaba.fastjson.JSONObject;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
@ -47,7 +43,7 @@ public final class Constants {
public static final List<List<Float>> vectors = Utils.genVectors(nb, dimension, true);
public static final List<ByteBuffer> vectorsBinary = Utils.genBinaryVectors(nb, dimension);
public static final List<List<Byte>> vectorsBinary = Utils.genBinaryVectors(nb, dimension);
public static final List<Map<String,Object>> defaultFields = Utils.genDefaultFields(dimension,false);

View File

@ -3,6 +3,7 @@ package com;
import io.milvus.client.*;
import org.testng.Assert;
import org.testng.annotations.Test;
import java.util.List;
public class TestCollectionCount {

View File

@ -3,15 +3,15 @@ package com;
import io.milvus.client.*;
import org.testng.Assert;
import org.testng.annotations.Test;
import java.util.List;
import java.nio.ByteBuffer;
public class TestCollectionCount_v2 {
int segmentRowCount = 5000;
int dimension = Constants.dimension;
int nb = Constants.nb;
List<List<Float>> vectors = Constants.vectors;
List<ByteBuffer> vectorsBinary = Constants.vectorsBinary;
List<List<Byte>> vectorsBinary = Constants.vectorsBinary;
// case-04
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)

View File

@ -6,14 +6,13 @@ import org.testng.annotations.Test;
import java.util.Collections;
import java.util.List;
import java.nio.ByteBuffer;
public class TestCollectionInfo_v2 {
int nb = Constants.nb;
int dimension = Constants.dimension;
int n_list = Constants.n_list;
List<List<Float>> vectors = Constants.vectors;
List<ByteBuffer> vectorsBinary = Constants.vectorsBinary;
List<List<Byte>> vectorsBinary = Constants.vectorsBinary;
String indexType = Constants.indexType;
String metricType = Constants.defaultMetricType;
String floatFieldName = Constants.floatFieldName;

View File

@ -1,12 +1,10 @@
package com;
import com.alibaba.fastjson.JSONObject;
import io.milvus.client.*;
import org.testng.Assert;
import org.testng.annotations.*;
import java.util.List;
import java.util.Map;
public class TestCollection_v2 {
int segmentRowCount = 5000;

View File

@ -1,6 +1,5 @@
package com;
import com.alibaba.fastjson.JSONArray;
import com.alibaba.fastjson.JSONObject;
import io.milvus.client.*;
import org.testng.Assert;

View File

@ -7,13 +7,12 @@ import org.testng.annotations.Test;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.nio.ByteBuffer;
public class TestDeleteEntities_v2 {
int dimension = Constants.dimension;
int nb = Constants.nb;
List<List<Float>> vectors = Constants.vectors;
List<ByteBuffer> vectorsBinary = Constants.vectorsBinary;
List<List<Byte>> vectorsBinary = Constants.vectorsBinary;
// case-01
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)

View File

@ -6,6 +6,7 @@ import io.milvus.client.*;
import org.apache.commons.lang3.RandomStringUtils;
import org.testng.Assert;
import org.testng.annotations.Test;
import java.util.ArrayList;
import java.util.List;

View File

@ -58,12 +58,16 @@ public class TestGetEntityByID {
// Binary tests
@Test(dataProvider = "BinaryCollection", dataProviderClass = MainClass.class)
public void testGetEntityByIdValidBinary(MilvusClient client, String collectionName) {
int get_length = 20;
InsertParam insertParam = new InsertParam.Builder(collectionName).withFields(Constants.defaultBinaryEntities).build();
InsertResponse resInsert = client.insert(insertParam);
List<Long> ids = resInsert.getEntityIds();
client.flush(collectionName);
GetEntityByIDResponse res = client.getEntityByID(collectionName, ids.subList(0, 1));
assert res.getFieldsMap().get(0).get(Constants.binaryFieldName).equals(Constants.vectorsBinary.get(0).rewind());
GetEntityByIDResponse res = client.getEntityByID(collectionName, ids.subList(0, get_length));
for (int i = 0; i < get_length; i++) {
List<Map<String,Object>> fieldsMap = res.getFieldsMap();
assert (fieldsMap.get(i).get("binary_vector").equals(Constants.vectorsBinary.get(i)));
}
}
@Test(dataProvider = "BinaryCollection", dataProviderClass = MainClass.class)

View File

@ -6,7 +6,6 @@ import io.milvus.client.*;
import org.testng.Assert;
import org.testng.annotations.*;
import java.nio.ByteBuffer;
import java.util.List;
public class TestIndex_v2 {
@ -19,7 +18,7 @@ public class TestIndex_v2 {
String defaultBinaryIndexType = Constants.defaultBinaryIndexType;
String defaultMetricType = Constants.defaultMetricType;
List<List<Float>> vectors = Constants.vectors;
List<ByteBuffer> vectorsBinary = Constants.vectorsBinary;
List<List<Byte>> vectorsBinary = Constants.vectorsBinary;
// case-01

View File

@ -7,7 +7,6 @@ import org.apache.commons.lang3.RandomStringUtils;
import org.testng.Assert;
import org.testng.annotations.Test;
import java.nio.ByteBuffer;
import java.util.*;
import java.util.stream.Collectors;
import java.util.stream.LongStream;
@ -188,7 +187,7 @@ public class TestInsertEntities {
// case-14
@Test(dataProvider = "BinaryCollection", dataProviderClass = MainClass.class)
public void testInsertBinaryEntityWithInvalidDimension(MilvusClient client, String collectionName) {
List<ByteBuffer> vectorsBinary = Utils.genBinaryVectors(nb, dimension-1);
List<List<Byte>> vectorsBinary = Utils.genBinaryVectors(nb, dimension-1);
List<Map<String,Object>> binaryEntities = Utils.genDefaultBinaryEntities(dimension-1,nb,vectorsBinary);
InsertParam insertParam = new InsertParam.Builder(collectionName).withFields(binaryEntities).build();
InsertResponse res = client.insert(insertParam);

View File

@ -10,7 +10,6 @@ import org.apache.commons.lang3.RandomStringUtils;
import org.testng.Assert;
import org.testng.annotations.Test;
import java.nio.ByteBuffer;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.LongStream;
@ -20,7 +19,7 @@ public class TestInsertEntities_v2 {
String tag = "tag";
int nb = Constants.nb;
List<List<Float>> vectors = Constants.vectors;
List<ByteBuffer> vectorsBinary = Constants.vectorsBinary;
List<List<Byte>> vectorsBinary = Constants.vectorsBinary;
// case-01
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)

View File

@ -6,7 +6,6 @@ import io.milvus.client.*;
import org.testng.Assert;
import org.testng.annotations.Test;
import java.nio.ByteBuffer;
import java.util.*;
public class TestSearchEntities {
@ -16,21 +15,22 @@ public class TestSearchEntities {
int nq = Constants.nq;
List<List<Float>> queryVectors = Constants.vectors.subList(0, nq);
List<ByteBuffer> queryVectorsBinary = Constants.vectorsBinary.subList(0, nq);
List<List<Byte>> queryVectorsBinary = Constants.vectorsBinary.subList(0, nq);
public String dsl = Constants.searchParam;
public String floatDsl = Constants.searchParam;
public String binaryDsl = Constants.binarySearchParam;
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void testSearchCollectionNotExisted(MilvusClient client, String collectionName) {
String collectionNameNew = Utils.genUniqueStr(collectionName);
SearchParam searchParam = new SearchParam.Builder(collectionNameNew).withDSL(dsl).build();
SearchParam searchParam = new SearchParam.Builder(collectionNameNew).withDSL(floatDsl).build();
SearchResponse res_search = client.search(searchParam);
assert (!res_search.getResponse().ok());
}
@Test(dataProvider = "Collection", dataProviderClass = MainClass.class)
public void testSearchCollectionEmpty(MilvusClient client, String collectionName) {
SearchParam searchParam = new SearchParam.Builder(collectionName).withDSL(dsl).build();
SearchParam searchParam = new SearchParam.Builder(collectionName).withDSL(floatDsl).build();
SearchResponse res_search = client.search(searchParam);
assert (res_search.getResponse().ok());
Assert.assertEquals(res_search.getResultIdsList().size(), 0);
@ -44,7 +44,7 @@ public class TestSearchEntities {
assert(res.getResponse().ok());
List<Long> ids = res.getEntityIds();
client.flush(collectionName);
SearchParam searchParam = new SearchParam.Builder(collectionName).withDSL(dsl).build();
SearchParam searchParam = new SearchParam.Builder(collectionName).withDSL(floatDsl).build();
SearchResponse res_search = client.search(searchParam);
Assert.assertEquals(res_search.getResultIdsList().size(), Constants.nq);
Assert.assertEquals(res_search.getResultDistancesList().size(), Constants.nq);
@ -60,7 +60,7 @@ public class TestSearchEntities {
assert(res.getResponse().ok());
List<Long> ids = res.getEntityIds();
client.flush(collectionName);
SearchParam searchParam = new SearchParam.Builder(collectionName).withDSL(dsl).build();
SearchParam searchParam = new SearchParam.Builder(collectionName).withDSL(floatDsl).build();
SearchResponse res_search = client.search(searchParam);
for (int i = 0; i < Constants.nq; i++) {
double distance = res_search.getResultDistancesList().get(i).get(0);
@ -94,7 +94,7 @@ public class TestSearchEntities {
InsertResponse res = client.insert(insertParam);
assert(res.getResponse().ok());
client.flush(collectionName);
SearchParam searchParam = new SearchParam.Builder(collectionName).withDSL(dsl).withPartitionTags(queryTags).build();
SearchParam searchParam = new SearchParam.Builder(collectionName).withDSL(floatDsl).withPartitionTags(queryTags).build();
SearchResponse res_search = client.search(searchParam);
Assert.assertEquals(res_search.getResultDistancesList().size(), 0);
}
@ -109,7 +109,7 @@ public class TestSearchEntities {
InsertResponse res = client.insert(insertParam);
assert (res.getResponse().ok());
client.flush(collectionName);
SearchParam searchParam = new SearchParam.Builder(collectionName).withDSL(dsl).withPartitionTags(queryTags).build();
SearchParam searchParam = new SearchParam.Builder(collectionName).withDSL(floatDsl).withPartitionTags(queryTags).build();
SearchResponse res_search = client.search(searchParam);
Assert.assertEquals(res_search.getResultDistancesList().size(), 0);
}
@ -289,18 +289,7 @@ public class TestSearchEntities {
@Test(dataProvider = "BinaryCollection", dataProviderClass = MainClass.class)
public void testSearchCollectionNotExistedBinary(MilvusClient client, String collectionName) {
String collectionNameNew = Utils.genUniqueStr(collectionName);
String queryKey = "placeholder";
Map<String, List<ByteBuffer>> binaryQueryEntities = new HashMap<>();
binaryQueryEntities.put(queryKey, queryVectorsBinary);
JSONObject binaryVectorParam = Utils.genBinaryVectorParam(Constants.defaultBinaryMetricType, queryKey, top_k, n_probe);
List<JSONObject> leafParams = new ArrayList<>();
leafParams.add(binaryVectorParam);
String dsl = Utils.genDefaultSearchParam(leafParams);
System.out.println(dsl);
SearchParam searchParam = new SearchParam.Builder(collectionNameNew)
.withBinaryEntities(binaryQueryEntities)
.withDSL(dsl)
.build();
SearchParam searchParam = new SearchParam.Builder(collectionNameNew).withDSL(binaryDsl).build();
SearchResponse resSearch = client.search(searchParam);
Assert.assertFalse(resSearch.getResponse().ok());
}
@ -311,18 +300,7 @@ public class TestSearchEntities {
InsertParam insertParam = new InsertParam.Builder(collectionName).withFields(Constants.defaultBinaryEntities).build();
InsertResponse res = client.insert(insertParam);
assert(res.getResponse().ok());
String queryKey = "placeholder";
Map<String, List<ByteBuffer>> binaryQueryEntities = new HashMap<>();
binaryQueryEntities.put(queryKey, queryVectorsBinary);
JSONObject binaryVectorParam = Utils.genBinaryVectorParam(Constants.defaultBinaryMetricType, queryKey, top_k, n_probe);
List<JSONObject> leafParams = new ArrayList<>();
leafParams.add(binaryVectorParam);
String dsl = Utils.genDefaultSearchParam(leafParams);
System.out.println(dsl);
SearchParam searchParam = new SearchParam.Builder(collectionName)
.withDSL(dsl)
.withBinaryEntities(binaryQueryEntities)
.build();
SearchParam searchParam = new SearchParam.Builder(collectionName).withDSL(binaryDsl).build();
SearchResponse resSearch = client.search(searchParam);
Assert.assertTrue(resSearch.getResponse().ok());
Assert.assertEquals(resSearch.getResultIdsList().size(), nq);

View File

@ -27,8 +27,8 @@ public class Utils {
}
public static List<List<Float>> genVectors(int vectorCount, int dimension, boolean norm) {
List<List<Float>> vectors = new ArrayList<>();
Random random = new Random();
List<List<Float>> vectors = new ArrayList<>();
for (int i = 0; i < vectorCount; ++i) {
List<Float> vector = new ArrayList<>();
for (int j = 0; j < dimension; ++j) {
@ -42,14 +42,16 @@ public class Utils {
return vectors;
}
static List<ByteBuffer> genBinaryVectors(long vectorCount, long dimension) {
static List<List<Byte>> genBinaryVectors(long vectorCount, long dimension) {
Random random = new Random();
List<ByteBuffer> vectors = new ArrayList<>();
List<List<Byte>> vectors = new ArrayList<>();
final long dimensionInByte = dimension / 8;
for (long i = 0; i < vectorCount; ++i) {
ByteBuffer byteBuffer = ByteBuffer.allocate((int) dimensionInByte);
random.nextBytes(byteBuffer.array());
vectors.add(byteBuffer);
List<Byte> vector = new ArrayList<>();
for (int j = 0; j < dimensionInByte; ++j) {
vector.add((byte) random.nextInt());
}
vectors.add(vector);
}
return vectors;
}
@ -111,7 +113,7 @@ public class Utils {
return fieldsMap;
}
public static List<Map<String,Object>> genDefaultBinaryEntities(int dimension, int vectorCount, List<ByteBuffer> vectorsBinary){
public static List<Map<String,Object>> genDefaultBinaryEntities(int dimension, int vectorCount, List<List<Byte>> vectorsBinary){
List<Map<String,Object>> binaryFieldsMap = genDefaultFields(dimension, true);
List<Long> intValues = new ArrayList<>(vectorCount);
List<Float> floatValues = new ArrayList<>(vectorCount);
@ -163,7 +165,7 @@ public class Utils {
return searchParam;
}
static JSONObject genBinaryVectorParam(String metricType, String queryVectors, int topk, int nprobe) {
static JSONObject genBinaryVectorParam(String metricType, List<List<Byte>> queryVectors, int topk, int nprobe) {
JSONObject searchParam = new JSONObject();
JSONObject fieldParam = new JSONObject();
fieldParam.put("topk", topk);
@ -200,7 +202,7 @@ public class Utils {
return JSONObject.toJSONString(boolParam);
}
public static String setBinarySearchParam(String metricType, List<ByteBuffer> queryVectors, int topk, int nprobe) {
public static String setBinarySearchParam(String metricType, List<List<Byte>> queryVectors, int topk, int nprobe) {
JSONObject searchParam = new JSONObject();
JSONObject fieldParam = new JSONObject();
fieldParam.put("topk", topk);
@ -367,7 +369,7 @@ public class Utils {
}
public static InsertParam genDefaultBinaryInsertParam(String collectionName, int dimension, int vectorCount,
List<ByteBuffer> vectorsBinary) {
List<List<Byte>> vectorsBinary) {
List<Long> intValues = new ArrayList<>(vectorCount);
List<Float> floatValues = new ArrayList<>(vectorCount);
for (int i = 0; i < vectorCount; ++i) {
@ -390,7 +392,7 @@ public class Utils {
}
public static InsertParam genDefaultBinaryInsertParam(String collectionName, int dimension, int vectorCount,
List<ByteBuffer> vectorsBinary, List<Long> entityIds) {
List<List<Byte>> vectorsBinary, List<Long> entityIds) {
List<Long> intValues = new ArrayList<>(vectorCount);
List<Float> floatValues = new ArrayList<>(vectorCount);
for (int i = 0; i < vectorCount; ++i) {
@ -414,7 +416,7 @@ public class Utils {
}
public static InsertParam genDefaultBinaryInsertParam(String collectionName, int dimension, int vectorCount,
List<ByteBuffer> vectorsBinary, String tag) {
List<List<Byte>> vectorsBinary, String tag) {
List<Long> intValues = new ArrayList<>(vectorCount);
List<Float> floatValues = new ArrayList<>(vectorCount);
for (int i = 0; i < vectorCount; ++i) {