!28 dev to main

Merge pull request !28 from xgc/dev
This commit is contained in:
xgc 2024-07-25 03:21:04 +00:00 committed by Gitee
commit 09b3cfb73a
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 105 additions and 26 deletions

View File

@ -4,11 +4,10 @@ import com.alibaba.fastjson2.JSON;
import io.milvus.exception.MilvusException;
import io.milvus.v2.client.MilvusClientV2;
import io.milvus.v2.common.ConsistencyLevel;
import io.milvus.v2.service.vector.request.GetReq;
import io.milvus.v2.service.vector.request.QueryReq;
import io.milvus.v2.service.vector.request.SearchReq;
import io.milvus.v2.service.vector.request.*;
import io.milvus.v2.service.vector.request.data.BaseVector;
import io.milvus.v2.service.vector.request.data.FloatVec;
import io.milvus.v2.service.vector.request.ranker.BaseRanker;
import io.milvus.v2.service.vector.response.GetResp;
import io.milvus.v2.service.vector.response.QueryResp;
import io.milvus.v2.service.vector.response.SearchResp;
@ -53,6 +52,10 @@ public class LambdaQueryWrapper<T> extends AbstractChainWrapper<T> implements Wr
private MilvusClientV2 client;
private Map<String, Object> searchParams = new HashMap<>(16);
private List<LambdaQueryWrapper<T>> hybridWrapper=new ArrayList<>();
private BaseRanker ranker;
public LambdaQueryWrapper(String collectionName, MilvusClientV2 client, ConversionCache conversionCache, Class<T> entityType) {
this.collectionName = collectionName;
this.client = client;
@ -63,6 +66,10 @@ public class LambdaQueryWrapper<T> extends AbstractChainWrapper<T> implements Wr
public LambdaQueryWrapper() {
}
public LambdaQueryWrapper<T> hybrid(LambdaQueryWrapper<T> wrapper) {
this.hybridWrapper.add(wrapper);
return this;
}
/**
* 添加集合别名
@ -408,6 +415,10 @@ public class LambdaQueryWrapper<T> extends AbstractChainWrapper<T> implements Wr
this.annsField=annsField.getFieldName(annsField);
return this;
}
public LambdaQueryWrapper<T> ranker(BaseRanker ranker){
this.ranker=ranker;
return this;
}
public LambdaQueryWrapper<T> vector(List<? extends Float> vector) {
BaseVector baseVector = new FloatVec((List<Float>) vector);
vectors.add(baseVector);
@ -426,6 +437,7 @@ public class LambdaQueryWrapper<T> extends AbstractChainWrapper<T> implements Wr
return this;
}
public LambdaQueryWrapper<T> vector(BaseVector vector) {
vectors.add(vector);
return this;
@ -444,6 +456,11 @@ public class LambdaQueryWrapper<T> extends AbstractChainWrapper<T> implements Wr
this.setLimit(limit);
return this;
}
public LambdaQueryWrapper<T> offset(Long offset) {
this.setOffset(offset);
return this;
}
public LambdaQueryWrapper<T> topK(Integer topK) {
this.setTopK(topK);
return this;
@ -474,6 +491,9 @@ public class LambdaQueryWrapper<T> extends AbstractChainWrapper<T> implements Wr
if (limit > 0) {
builder.limit(limit);
}
if(offset > 0){
builder.offset(offset);
}
if (!CollectionUtils.isEmpty(partitionNames)) {
builder.partitionNames(partitionNames);
}
@ -493,7 +513,7 @@ public class LambdaQueryWrapper<T> extends AbstractChainWrapper<T> implements Wr
return builder.build();
}
public QueryReq buildQuery() {
private QueryReq buildQuery() {
QueryReq.QueryReqBuilder<?, ?> builder = QueryReq.builder()
.collectionName(StringUtils.isNotBlank(collectionAlias) ? collectionAlias : collectionName);
String filterStr = buildFilters();
@ -506,6 +526,12 @@ public class LambdaQueryWrapper<T> extends AbstractChainWrapper<T> implements Wr
if (limit > 0L) {
builder.limit(limit);
}
if(offset > 0){
builder.offset(offset);
}
if(consistencyLevel!=null){
builder.consistencyLevel(consistencyLevel);
}
if (!CollectionUtils.isEmpty(partitionNames)) {
builder.partitionNames(partitionNames);
}
@ -517,6 +543,42 @@ public class LambdaQueryWrapper<T> extends AbstractChainWrapper<T> implements Wr
}
return builder.build();
}
private HybridSearchReq buildHybrid(){
//混合查询
List<AnnSearchReq> searchRequests = hybridWrapper.stream().filter(v -> StringUtils.isNotEmpty(v.getAnnsField()) && !v.getVectors().isEmpty()).map(
v -> {
AnnSearchReq.AnnSearchReqBuilder<?, ?> annBuilder = AnnSearchReq.builder()
.vectorFieldName(v.getAnnsField())
.vectors(v.getVectors());
if (v.getTopK() > 0) {
annBuilder.topK(v.getTopK());
}
String expr = v.buildFilters();
if (StringUtils.isNotEmpty(expr)) {
annBuilder.expr(expr);
}
Map<String, Object> params = v.searchParams;
if (!params.isEmpty()) {
annBuilder.params(JSON.toJSONString(params));
}
return annBuilder.build();
}
).collect(Collectors.toList());
HybridSearchReq.HybridSearchReqBuilder<?, ?> reqBuilder = HybridSearchReq.builder()
.collectionName(collectionName)
.searchRequests(searchRequests);
if(ranker!=null){
reqBuilder.ranker(ranker);
}
if(topK>0){
reqBuilder.topK(topK);
}
if(consistencyLevel!=null){
reqBuilder.consistencyLevel(consistencyLevel);
}
HybridSearchReq hybridSearchReq= reqBuilder.build();
return hybridSearchReq;
}
/**
* 执行搜索
@ -526,6 +588,11 @@ public class LambdaQueryWrapper<T> extends AbstractChainWrapper<T> implements Wr
public MilvusResp<List<MilvusResult<T>>> query() throws MilvusException{
return executeWithRetry(
() -> {
if(hybridWrapper.size()>0){
HybridSearchReq hybridSearchReq = buildHybrid();
SearchResp searchResp = client.hybridSearch(hybridSearchReq);
return SearchRespConverter.convertSearchRespToMilvusResp(searchResp, entityType);
}
if (!vectors.isEmpty()) {
SearchReq searchReq = buildSearch();
log.info("Build search param--> {}", JSON.toJSONString(searchReq));

View File

@ -1,15 +1,15 @@
package org.dromara.milvus.demo;
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
import com.google.common.collect.Lists;
import io.milvus.v2.service.vector.request.data.FloatVec;
import io.milvus.v2.service.vector.request.ranker.RRFRanker;
import io.milvus.v2.service.vector.response.InsertResp;
import io.milvus.v2.service.vector.response.UpsertResp;
import lombok.extern.slf4j.Slf4j;
import org.dromara.milvus.demo.model.Face;
import org.dromara.milvus.demo.model.FaceConstants;
import org.dromara.milvus.demo.model.FaceMilvusMapper;
import org.dromara.milvus.demo.model.Person;
import org.dromara.milvus.plus.core.conditions.LambdaQueryWrapper;
import org.dromara.milvus.plus.model.vo.MilvusResp;
import org.dromara.milvus.plus.model.vo.MilvusResult;
import org.springframework.boot.ApplicationArguments;
@ -32,10 +32,10 @@ public class ApplicationRunnerTest implements ApplicationRunner {
@Override
public void run(ApplicationArguments args) {
// insertFace();
insertFace();
// getByIdTest();
// vectorQuery();
scalarQuery();
vectorQuery();
// scalarQuery();
// update();
}
@ -75,12 +75,24 @@ public class ApplicationRunnerTest implements ApplicationRunner {
List<Float> vector = IntStream.range(0, 128)
.mapToObj(i -> (float) (Math.random() * 100))
.collect(Collectors.toList());
MilvusResp<List<MilvusResult<Face>>> query1 = mapper.queryWrapper()
.vector(Face::getFaceVector, vector)
.like(Face::getPersonName, "张三")
.topK(3)
.query();
log.info("向量查询 query--queryWrapper---{}", JSONObject.toJSONString(query1));
List<Float> vector1 = IntStream.range(0, 128)
.mapToObj(i -> (float) (Math.random() * 100))
.collect(Collectors.toList());
// MilvusResp<List<MilvusResult<Face>>> query1 = mapper.queryWrapper()
// .vector(Face::getFaceVector, new FloatVec(vector))
// .like(Face::getPersonName, "张三")
// .topK(3)
// .query();
// log.info("向量查询 query--queryWrapper---{}", JSONObject.toJSONString(query1));
MilvusResp<List<MilvusResult<Face>>> query = mapper.queryWrapper().
hybrid(new LambdaQueryWrapper<Face>().vector(Face::getFaceVector, new FloatVec(vector)).topK(2)).
hybrid(new LambdaQueryWrapper<Face>().vector(Face::getFaceVector, new FloatVec(vector1)).topK(4)).
ranker(new RRFRanker(20)).
topK(2).
query();
log.info("向量混合查询 query--queryWrapper---{}", JSONObject.toJSONString(query));
}
public void scalarQuery() {
@ -92,14 +104,14 @@ public class ApplicationRunnerTest implements ApplicationRunner {
.query();
log.info("标量查询 query--queryWrapper---{}", JSONObject.toJSONString(query2));
}
public void update(){
Face faceTmp = new Face();
List<Float> vectorTmp = IntStream.range(0, 128)
.mapToObj(j -> (float) (Math.random() * 100))
.collect(Collectors.toList());
faceTmp.setFaceVector(vectorTmp);
faceTmp.setPersonName("赵六");
MilvusResp<UpsertResp> resp = mapper.updateWrapper().eq(FaceConstants.PERSON_NAME,"张三").update(faceTmp);
System.out.printf("===="+ JSON.toJSONString(resp));
}
// public void update(){
// Face faceTmp = new Face();
// List<Float> vectorTmp = IntStream.range(0, 128)
// .mapToObj(j -> (float) (Math.random() * 100))
// .collect(Collectors.toList());
// faceTmp.setFaceVector(vectorTmp);
// faceTmp.setPersonName("赵六");
// MilvusResp<UpsertResp> resp = mapper.updateWrapper().eq(FaceConstants.PERSON_NAME,"张三").update(faceTmp);
// System.out.printf("===="+ JSON.toJSONString(resp));
// }
}