!28 dev to main

Merge pull request !28 from xgc/dev
This commit is contained in:
xgc 2024-07-25 03:21:04 +00:00 committed by Gitee
commit 09b3cfb73a
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 105 additions and 26 deletions

View File

@ -4,11 +4,10 @@ import com.alibaba.fastjson2.JSON;
import io.milvus.exception.MilvusException; import io.milvus.exception.MilvusException;
import io.milvus.v2.client.MilvusClientV2; import io.milvus.v2.client.MilvusClientV2;
import io.milvus.v2.common.ConsistencyLevel; import io.milvus.v2.common.ConsistencyLevel;
import io.milvus.v2.service.vector.request.GetReq; import io.milvus.v2.service.vector.request.*;
import io.milvus.v2.service.vector.request.QueryReq;
import io.milvus.v2.service.vector.request.SearchReq;
import io.milvus.v2.service.vector.request.data.BaseVector; import io.milvus.v2.service.vector.request.data.BaseVector;
import io.milvus.v2.service.vector.request.data.FloatVec; import io.milvus.v2.service.vector.request.data.FloatVec;
import io.milvus.v2.service.vector.request.ranker.BaseRanker;
import io.milvus.v2.service.vector.response.GetResp; import io.milvus.v2.service.vector.response.GetResp;
import io.milvus.v2.service.vector.response.QueryResp; import io.milvus.v2.service.vector.response.QueryResp;
import io.milvus.v2.service.vector.response.SearchResp; import io.milvus.v2.service.vector.response.SearchResp;
@ -53,6 +52,10 @@ public class LambdaQueryWrapper<T> extends AbstractChainWrapper<T> implements Wr
private MilvusClientV2 client; private MilvusClientV2 client;
private Map<String, Object> searchParams = new HashMap<>(16); private Map<String, Object> searchParams = new HashMap<>(16);
private List<LambdaQueryWrapper<T>> hybridWrapper=new ArrayList<>();
private BaseRanker ranker;
public LambdaQueryWrapper(String collectionName, MilvusClientV2 client, ConversionCache conversionCache, Class<T> entityType) { public LambdaQueryWrapper(String collectionName, MilvusClientV2 client, ConversionCache conversionCache, Class<T> entityType) {
this.collectionName = collectionName; this.collectionName = collectionName;
this.client = client; this.client = client;
@ -63,6 +66,10 @@ public class LambdaQueryWrapper<T> extends AbstractChainWrapper<T> implements Wr
public LambdaQueryWrapper() { public LambdaQueryWrapper() {
} }
public LambdaQueryWrapper<T> hybrid(LambdaQueryWrapper<T> wrapper) {
this.hybridWrapper.add(wrapper);
return this;
}
/** /**
* 添加集合别名 * 添加集合别名
@ -408,6 +415,10 @@ public class LambdaQueryWrapper<T> extends AbstractChainWrapper<T> implements Wr
this.annsField=annsField.getFieldName(annsField); this.annsField=annsField.getFieldName(annsField);
return this; return this;
} }
public LambdaQueryWrapper<T> ranker(BaseRanker ranker){
this.ranker=ranker;
return this;
}
public LambdaQueryWrapper<T> vector(List<? extends Float> vector) { public LambdaQueryWrapper<T> vector(List<? extends Float> vector) {
BaseVector baseVector = new FloatVec((List<Float>) vector); BaseVector baseVector = new FloatVec((List<Float>) vector);
vectors.add(baseVector); vectors.add(baseVector);
@ -426,6 +437,7 @@ public class LambdaQueryWrapper<T> extends AbstractChainWrapper<T> implements Wr
return this; return this;
} }
public LambdaQueryWrapper<T> vector(BaseVector vector) { public LambdaQueryWrapper<T> vector(BaseVector vector) {
vectors.add(vector); vectors.add(vector);
return this; return this;
@ -444,6 +456,11 @@ public class LambdaQueryWrapper<T> extends AbstractChainWrapper<T> implements Wr
this.setLimit(limit); this.setLimit(limit);
return this; return this;
} }
public LambdaQueryWrapper<T> offset(Long offset) {
this.setOffset(offset);
return this;
}
public LambdaQueryWrapper<T> topK(Integer topK) { public LambdaQueryWrapper<T> topK(Integer topK) {
this.setTopK(topK); this.setTopK(topK);
return this; return this;
@ -474,6 +491,9 @@ public class LambdaQueryWrapper<T> extends AbstractChainWrapper<T> implements Wr
if (limit > 0) { if (limit > 0) {
builder.limit(limit); builder.limit(limit);
} }
if(offset > 0){
builder.offset(offset);
}
if (!CollectionUtils.isEmpty(partitionNames)) { if (!CollectionUtils.isEmpty(partitionNames)) {
builder.partitionNames(partitionNames); builder.partitionNames(partitionNames);
} }
@ -493,7 +513,7 @@ public class LambdaQueryWrapper<T> extends AbstractChainWrapper<T> implements Wr
return builder.build(); return builder.build();
} }
public QueryReq buildQuery() { private QueryReq buildQuery() {
QueryReq.QueryReqBuilder<?, ?> builder = QueryReq.builder() QueryReq.QueryReqBuilder<?, ?> builder = QueryReq.builder()
.collectionName(StringUtils.isNotBlank(collectionAlias) ? collectionAlias : collectionName); .collectionName(StringUtils.isNotBlank(collectionAlias) ? collectionAlias : collectionName);
String filterStr = buildFilters(); String filterStr = buildFilters();
@ -506,6 +526,12 @@ public class LambdaQueryWrapper<T> extends AbstractChainWrapper<T> implements Wr
if (limit > 0L) { if (limit > 0L) {
builder.limit(limit); builder.limit(limit);
} }
if(offset > 0){
builder.offset(offset);
}
if(consistencyLevel!=null){
builder.consistencyLevel(consistencyLevel);
}
if (!CollectionUtils.isEmpty(partitionNames)) { if (!CollectionUtils.isEmpty(partitionNames)) {
builder.partitionNames(partitionNames); builder.partitionNames(partitionNames);
} }
@ -517,6 +543,42 @@ public class LambdaQueryWrapper<T> extends AbstractChainWrapper<T> implements Wr
} }
return builder.build(); return builder.build();
} }
private HybridSearchReq buildHybrid(){
//混合查询
List<AnnSearchReq> searchRequests = hybridWrapper.stream().filter(v -> StringUtils.isNotEmpty(v.getAnnsField()) && !v.getVectors().isEmpty()).map(
v -> {
AnnSearchReq.AnnSearchReqBuilder<?, ?> annBuilder = AnnSearchReq.builder()
.vectorFieldName(v.getAnnsField())
.vectors(v.getVectors());
if (v.getTopK() > 0) {
annBuilder.topK(v.getTopK());
}
String expr = v.buildFilters();
if (StringUtils.isNotEmpty(expr)) {
annBuilder.expr(expr);
}
Map<String, Object> params = v.searchParams;
if (!params.isEmpty()) {
annBuilder.params(JSON.toJSONString(params));
}
return annBuilder.build();
}
).collect(Collectors.toList());
HybridSearchReq.HybridSearchReqBuilder<?, ?> reqBuilder = HybridSearchReq.builder()
.collectionName(collectionName)
.searchRequests(searchRequests);
if(ranker!=null){
reqBuilder.ranker(ranker);
}
if(topK>0){
reqBuilder.topK(topK);
}
if(consistencyLevel!=null){
reqBuilder.consistencyLevel(consistencyLevel);
}
HybridSearchReq hybridSearchReq= reqBuilder.build();
return hybridSearchReq;
}
/** /**
* 执行搜索 * 执行搜索
@ -526,6 +588,11 @@ public class LambdaQueryWrapper<T> extends AbstractChainWrapper<T> implements Wr
public MilvusResp<List<MilvusResult<T>>> query() throws MilvusException{ public MilvusResp<List<MilvusResult<T>>> query() throws MilvusException{
return executeWithRetry( return executeWithRetry(
() -> { () -> {
if(hybridWrapper.size()>0){
HybridSearchReq hybridSearchReq = buildHybrid();
SearchResp searchResp = client.hybridSearch(hybridSearchReq);
return SearchRespConverter.convertSearchRespToMilvusResp(searchResp, entityType);
}
if (!vectors.isEmpty()) { if (!vectors.isEmpty()) {
SearchReq searchReq = buildSearch(); SearchReq searchReq = buildSearch();
log.info("Build search param--> {}", JSON.toJSONString(searchReq)); log.info("Build search param--> {}", JSON.toJSONString(searchReq));

View File

@ -1,15 +1,15 @@
package org.dromara.milvus.demo; package org.dromara.milvus.demo;
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject; import com.alibaba.fastjson.JSONObject;
import com.google.common.collect.Lists; import com.google.common.collect.Lists;
import io.milvus.v2.service.vector.request.data.FloatVec;
import io.milvus.v2.service.vector.request.ranker.RRFRanker;
import io.milvus.v2.service.vector.response.InsertResp; import io.milvus.v2.service.vector.response.InsertResp;
import io.milvus.v2.service.vector.response.UpsertResp;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.dromara.milvus.demo.model.Face; import org.dromara.milvus.demo.model.Face;
import org.dromara.milvus.demo.model.FaceConstants;
import org.dromara.milvus.demo.model.FaceMilvusMapper; import org.dromara.milvus.demo.model.FaceMilvusMapper;
import org.dromara.milvus.demo.model.Person; import org.dromara.milvus.demo.model.Person;
import org.dromara.milvus.plus.core.conditions.LambdaQueryWrapper;
import org.dromara.milvus.plus.model.vo.MilvusResp; import org.dromara.milvus.plus.model.vo.MilvusResp;
import org.dromara.milvus.plus.model.vo.MilvusResult; import org.dromara.milvus.plus.model.vo.MilvusResult;
import org.springframework.boot.ApplicationArguments; import org.springframework.boot.ApplicationArguments;
@ -32,10 +32,10 @@ public class ApplicationRunnerTest implements ApplicationRunner {
@Override @Override
public void run(ApplicationArguments args) { public void run(ApplicationArguments args) {
// insertFace(); insertFace();
// getByIdTest(); // getByIdTest();
// vectorQuery(); vectorQuery();
scalarQuery(); // scalarQuery();
// update(); // update();
} }
@ -75,12 +75,24 @@ public class ApplicationRunnerTest implements ApplicationRunner {
List<Float> vector = IntStream.range(0, 128) List<Float> vector = IntStream.range(0, 128)
.mapToObj(i -> (float) (Math.random() * 100)) .mapToObj(i -> (float) (Math.random() * 100))
.collect(Collectors.toList()); .collect(Collectors.toList());
MilvusResp<List<MilvusResult<Face>>> query1 = mapper.queryWrapper() List<Float> vector1 = IntStream.range(0, 128)
.vector(Face::getFaceVector, vector) .mapToObj(i -> (float) (Math.random() * 100))
.like(Face::getPersonName, "张三") .collect(Collectors.toList());
.topK(3) // MilvusResp<List<MilvusResult<Face>>> query1 = mapper.queryWrapper()
.query(); // .vector(Face::getFaceVector, new FloatVec(vector))
log.info("向量查询 query--queryWrapper---{}", JSONObject.toJSONString(query1)); // .like(Face::getPersonName, "张三")
// .topK(3)
// .query();
// log.info("向量查询 query--queryWrapper---{}", JSONObject.toJSONString(query1));
MilvusResp<List<MilvusResult<Face>>> query = mapper.queryWrapper().
hybrid(new LambdaQueryWrapper<Face>().vector(Face::getFaceVector, new FloatVec(vector)).topK(2)).
hybrid(new LambdaQueryWrapper<Face>().vector(Face::getFaceVector, new FloatVec(vector1)).topK(4)).
ranker(new RRFRanker(20)).
topK(2).
query();
log.info("向量混合查询 query--queryWrapper---{}", JSONObject.toJSONString(query));
} }
public void scalarQuery() { public void scalarQuery() {
@ -92,14 +104,14 @@ public class ApplicationRunnerTest implements ApplicationRunner {
.query(); .query();
log.info("标量查询 query--queryWrapper---{}", JSONObject.toJSONString(query2)); log.info("标量查询 query--queryWrapper---{}", JSONObject.toJSONString(query2));
} }
public void update(){ // public void update(){
Face faceTmp = new Face(); // Face faceTmp = new Face();
List<Float> vectorTmp = IntStream.range(0, 128) // List<Float> vectorTmp = IntStream.range(0, 128)
.mapToObj(j -> (float) (Math.random() * 100)) // .mapToObj(j -> (float) (Math.random() * 100))
.collect(Collectors.toList()); // .collect(Collectors.toList());
faceTmp.setFaceVector(vectorTmp); // faceTmp.setFaceVector(vectorTmp);
faceTmp.setPersonName("赵六"); // faceTmp.setPersonName("赵六");
MilvusResp<UpsertResp> resp = mapper.updateWrapper().eq(FaceConstants.PERSON_NAME,"张三").update(faceTmp); // MilvusResp<UpsertResp> resp = mapper.updateWrapper().eq(FaceConstants.PERSON_NAME,"张三").update(faceTmp);
System.out.printf("===="+ JSON.toJSONString(resp)); // System.out.printf("===="+ JSON.toJSONString(resp));
} // }
} }