diff --git a/milvus-plus-core/src/main/java/org/dromara/milvus/plus/core/conditions/LambdaQueryWrapper.java b/milvus-plus-core/src/main/java/org/dromara/milvus/plus/core/conditions/LambdaQueryWrapper.java index 6722b20..6063ba4 100644 --- a/milvus-plus-core/src/main/java/org/dromara/milvus/plus/core/conditions/LambdaQueryWrapper.java +++ b/milvus-plus-core/src/main/java/org/dromara/milvus/plus/core/conditions/LambdaQueryWrapper.java @@ -4,11 +4,10 @@ import com.alibaba.fastjson2.JSON; import io.milvus.exception.MilvusException; import io.milvus.v2.client.MilvusClientV2; import io.milvus.v2.common.ConsistencyLevel; -import io.milvus.v2.service.vector.request.GetReq; -import io.milvus.v2.service.vector.request.QueryReq; -import io.milvus.v2.service.vector.request.SearchReq; +import io.milvus.v2.service.vector.request.*; import io.milvus.v2.service.vector.request.data.BaseVector; import io.milvus.v2.service.vector.request.data.FloatVec; +import io.milvus.v2.service.vector.request.ranker.BaseRanker; import io.milvus.v2.service.vector.response.GetResp; import io.milvus.v2.service.vector.response.QueryResp; import io.milvus.v2.service.vector.response.SearchResp; @@ -53,6 +52,10 @@ public class LambdaQueryWrapper extends AbstractChainWrapper implements Wr private MilvusClientV2 client; private Map searchParams = new HashMap<>(16); + private List> hybridWrapper=new ArrayList<>(); + + private BaseRanker ranker; + public LambdaQueryWrapper(String collectionName, MilvusClientV2 client, ConversionCache conversionCache, Class entityType) { this.collectionName = collectionName; this.client = client; @@ -63,6 +66,10 @@ public class LambdaQueryWrapper extends AbstractChainWrapper implements Wr public LambdaQueryWrapper() { } + public LambdaQueryWrapper hybrid(LambdaQueryWrapper wrapper) { + this.hybridWrapper.add(wrapper); + return this; + } /** * 添加集合别名 @@ -408,6 +415,10 @@ public class LambdaQueryWrapper extends AbstractChainWrapper implements Wr this.annsField=annsField.getFieldName(annsField); return this; } + public LambdaQueryWrapper ranker(BaseRanker ranker){ + this.ranker=ranker; + return this; + } public LambdaQueryWrapper vector(List vector) { BaseVector baseVector = new FloatVec((List) vector); vectors.add(baseVector); @@ -426,6 +437,7 @@ public class LambdaQueryWrapper extends AbstractChainWrapper implements Wr return this; } + public LambdaQueryWrapper vector(BaseVector vector) { vectors.add(vector); return this; @@ -444,6 +456,11 @@ public class LambdaQueryWrapper extends AbstractChainWrapper implements Wr this.setLimit(limit); return this; } + public LambdaQueryWrapper offset(Long offset) { + this.setOffset(offset); + return this; + } + public LambdaQueryWrapper topK(Integer topK) { this.setTopK(topK); return this; @@ -474,6 +491,9 @@ public class LambdaQueryWrapper extends AbstractChainWrapper implements Wr if (limit > 0) { builder.limit(limit); } + if(offset > 0){ + builder.offset(offset); + } if (!CollectionUtils.isEmpty(partitionNames)) { builder.partitionNames(partitionNames); } @@ -493,7 +513,7 @@ public class LambdaQueryWrapper extends AbstractChainWrapper implements Wr return builder.build(); } - public QueryReq buildQuery() { + private QueryReq buildQuery() { QueryReq.QueryReqBuilder builder = QueryReq.builder() .collectionName(StringUtils.isNotBlank(collectionAlias) ? collectionAlias : collectionName); String filterStr = buildFilters(); @@ -506,6 +526,12 @@ public class LambdaQueryWrapper extends AbstractChainWrapper implements Wr if (limit > 0L) { builder.limit(limit); } + if(offset > 0){ + builder.offset(offset); + } + if(consistencyLevel!=null){ + builder.consistencyLevel(consistencyLevel); + } if (!CollectionUtils.isEmpty(partitionNames)) { builder.partitionNames(partitionNames); } @@ -517,6 +543,42 @@ public class LambdaQueryWrapper extends AbstractChainWrapper implements Wr } return builder.build(); } + private HybridSearchReq buildHybrid(){ + //混合查询 + List searchRequests = hybridWrapper.stream().filter(v -> StringUtils.isNotEmpty(v.getAnnsField()) && !v.getVectors().isEmpty()).map( + v -> { + AnnSearchReq.AnnSearchReqBuilder annBuilder = AnnSearchReq.builder() + .vectorFieldName(v.getAnnsField()) + .vectors(v.getVectors()); + if (v.getTopK() > 0) { + annBuilder.topK(v.getTopK()); + } + String expr = v.buildFilters(); + if (StringUtils.isNotEmpty(expr)) { + annBuilder.expr(expr); + } + Map params = v.searchParams; + if (!params.isEmpty()) { + annBuilder.params(JSON.toJSONString(params)); + } + return annBuilder.build(); + } + ).collect(Collectors.toList()); + HybridSearchReq.HybridSearchReqBuilder reqBuilder = HybridSearchReq.builder() + .collectionName(collectionName) + .searchRequests(searchRequests); + if(ranker!=null){ + reqBuilder.ranker(ranker); + } + if(topK>0){ + reqBuilder.topK(topK); + } + if(consistencyLevel!=null){ + reqBuilder.consistencyLevel(consistencyLevel); + } + HybridSearchReq hybridSearchReq= reqBuilder.build(); + return hybridSearchReq; + } /** * 执行搜索 @@ -526,6 +588,11 @@ public class LambdaQueryWrapper extends AbstractChainWrapper implements Wr public MilvusResp>> query() throws MilvusException{ return executeWithRetry( () -> { + if(hybridWrapper.size()>0){ + HybridSearchReq hybridSearchReq = buildHybrid(); + SearchResp searchResp = client.hybridSearch(hybridSearchReq); + return SearchRespConverter.convertSearchRespToMilvusResp(searchResp, entityType); + } if (!vectors.isEmpty()) { SearchReq searchReq = buildSearch(); log.info("Build search param--> {}", JSON.toJSONString(searchReq)); diff --git a/milvus-spring-demo/src/main/java/org/dromara/milvus/demo/ApplicationRunnerTest.java b/milvus-spring-demo/src/main/java/org/dromara/milvus/demo/ApplicationRunnerTest.java index f7d630a..1c5e6f8 100644 --- a/milvus-spring-demo/src/main/java/org/dromara/milvus/demo/ApplicationRunnerTest.java +++ b/milvus-spring-demo/src/main/java/org/dromara/milvus/demo/ApplicationRunnerTest.java @@ -1,15 +1,15 @@ package org.dromara.milvus.demo; -import com.alibaba.fastjson.JSON; import com.alibaba.fastjson.JSONObject; import com.google.common.collect.Lists; +import io.milvus.v2.service.vector.request.data.FloatVec; +import io.milvus.v2.service.vector.request.ranker.RRFRanker; import io.milvus.v2.service.vector.response.InsertResp; -import io.milvus.v2.service.vector.response.UpsertResp; import lombok.extern.slf4j.Slf4j; import org.dromara.milvus.demo.model.Face; -import org.dromara.milvus.demo.model.FaceConstants; import org.dromara.milvus.demo.model.FaceMilvusMapper; import org.dromara.milvus.demo.model.Person; +import org.dromara.milvus.plus.core.conditions.LambdaQueryWrapper; import org.dromara.milvus.plus.model.vo.MilvusResp; import org.dromara.milvus.plus.model.vo.MilvusResult; import org.springframework.boot.ApplicationArguments; @@ -32,10 +32,10 @@ public class ApplicationRunnerTest implements ApplicationRunner { @Override public void run(ApplicationArguments args) { - // insertFace(); + insertFace(); // getByIdTest(); -// vectorQuery(); - scalarQuery(); + vectorQuery(); + // scalarQuery(); // update(); } @@ -75,12 +75,24 @@ public class ApplicationRunnerTest implements ApplicationRunner { List vector = IntStream.range(0, 128) .mapToObj(i -> (float) (Math.random() * 100)) .collect(Collectors.toList()); - MilvusResp>> query1 = mapper.queryWrapper() - .vector(Face::getFaceVector, vector) - .like(Face::getPersonName, "张三") - .topK(3) - .query(); - log.info("向量查询 query--queryWrapper---{}", JSONObject.toJSONString(query1)); + List vector1 = IntStream.range(0, 128) + .mapToObj(i -> (float) (Math.random() * 100)) + .collect(Collectors.toList()); +// MilvusResp>> query1 = mapper.queryWrapper() +// .vector(Face::getFaceVector, new FloatVec(vector)) +// .like(Face::getPersonName, "张三") +// .topK(3) +// .query(); +// log.info("向量查询 query--queryWrapper---{}", JSONObject.toJSONString(query1)); + + MilvusResp>> query = mapper.queryWrapper(). + hybrid(new LambdaQueryWrapper().vector(Face::getFaceVector, new FloatVec(vector)).topK(2)). + hybrid(new LambdaQueryWrapper().vector(Face::getFaceVector, new FloatVec(vector1)).topK(4)). + ranker(new RRFRanker(20)). + topK(2). + query(); + log.info("向量混合查询 query--queryWrapper---{}", JSONObject.toJSONString(query)); + } public void scalarQuery() { @@ -92,14 +104,14 @@ public class ApplicationRunnerTest implements ApplicationRunner { .query(); log.info("标量查询 query--queryWrapper---{}", JSONObject.toJSONString(query2)); } - public void update(){ - Face faceTmp = new Face(); - List vectorTmp = IntStream.range(0, 128) - .mapToObj(j -> (float) (Math.random() * 100)) - .collect(Collectors.toList()); - faceTmp.setFaceVector(vectorTmp); - faceTmp.setPersonName("赵六"); - MilvusResp resp = mapper.updateWrapper().eq(FaceConstants.PERSON_NAME,"张三").update(faceTmp); - System.out.printf("===="+ JSON.toJSONString(resp)); - } +// public void update(){ +// Face faceTmp = new Face(); +// List vectorTmp = IntStream.range(0, 128) +// .mapToObj(j -> (float) (Math.random() * 100)) +// .collect(Collectors.toList()); +// faceTmp.setFaceVector(vectorTmp); +// faceTmp.setPersonName("赵六"); +// MilvusResp resp = mapper.updateWrapper().eq(FaceConstants.PERSON_NAME,"张三").update(faceTmp); +// System.out.printf("===="+ JSON.toJSONString(resp)); +// } } \ No newline at end of file