mirror of
https://gitee.com/dromara/MilvusPlus.git
synced 2025-12-07 01:18:23 +08:00
commit
09b3cfb73a
@ -4,11 +4,10 @@ import com.alibaba.fastjson2.JSON;
|
|||||||
import io.milvus.exception.MilvusException;
|
import io.milvus.exception.MilvusException;
|
||||||
import io.milvus.v2.client.MilvusClientV2;
|
import io.milvus.v2.client.MilvusClientV2;
|
||||||
import io.milvus.v2.common.ConsistencyLevel;
|
import io.milvus.v2.common.ConsistencyLevel;
|
||||||
import io.milvus.v2.service.vector.request.GetReq;
|
import io.milvus.v2.service.vector.request.*;
|
||||||
import io.milvus.v2.service.vector.request.QueryReq;
|
|
||||||
import io.milvus.v2.service.vector.request.SearchReq;
|
|
||||||
import io.milvus.v2.service.vector.request.data.BaseVector;
|
import io.milvus.v2.service.vector.request.data.BaseVector;
|
||||||
import io.milvus.v2.service.vector.request.data.FloatVec;
|
import io.milvus.v2.service.vector.request.data.FloatVec;
|
||||||
|
import io.milvus.v2.service.vector.request.ranker.BaseRanker;
|
||||||
import io.milvus.v2.service.vector.response.GetResp;
|
import io.milvus.v2.service.vector.response.GetResp;
|
||||||
import io.milvus.v2.service.vector.response.QueryResp;
|
import io.milvus.v2.service.vector.response.QueryResp;
|
||||||
import io.milvus.v2.service.vector.response.SearchResp;
|
import io.milvus.v2.service.vector.response.SearchResp;
|
||||||
@ -53,6 +52,10 @@ public class LambdaQueryWrapper<T> extends AbstractChainWrapper<T> implements Wr
|
|||||||
private MilvusClientV2 client;
|
private MilvusClientV2 client;
|
||||||
private Map<String, Object> searchParams = new HashMap<>(16);
|
private Map<String, Object> searchParams = new HashMap<>(16);
|
||||||
|
|
||||||
|
private List<LambdaQueryWrapper<T>> hybridWrapper=new ArrayList<>();
|
||||||
|
|
||||||
|
private BaseRanker ranker;
|
||||||
|
|
||||||
public LambdaQueryWrapper(String collectionName, MilvusClientV2 client, ConversionCache conversionCache, Class<T> entityType) {
|
public LambdaQueryWrapper(String collectionName, MilvusClientV2 client, ConversionCache conversionCache, Class<T> entityType) {
|
||||||
this.collectionName = collectionName;
|
this.collectionName = collectionName;
|
||||||
this.client = client;
|
this.client = client;
|
||||||
@ -63,6 +66,10 @@ public class LambdaQueryWrapper<T> extends AbstractChainWrapper<T> implements Wr
|
|||||||
public LambdaQueryWrapper() {
|
public LambdaQueryWrapper() {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
public LambdaQueryWrapper<T> hybrid(LambdaQueryWrapper<T> wrapper) {
|
||||||
|
this.hybridWrapper.add(wrapper);
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 添加集合别名
|
* 添加集合别名
|
||||||
@ -408,6 +415,10 @@ public class LambdaQueryWrapper<T> extends AbstractChainWrapper<T> implements Wr
|
|||||||
this.annsField=annsField.getFieldName(annsField);
|
this.annsField=annsField.getFieldName(annsField);
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
public LambdaQueryWrapper<T> ranker(BaseRanker ranker){
|
||||||
|
this.ranker=ranker;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
public LambdaQueryWrapper<T> vector(List<? extends Float> vector) {
|
public LambdaQueryWrapper<T> vector(List<? extends Float> vector) {
|
||||||
BaseVector baseVector = new FloatVec((List<Float>) vector);
|
BaseVector baseVector = new FloatVec((List<Float>) vector);
|
||||||
vectors.add(baseVector);
|
vectors.add(baseVector);
|
||||||
@ -426,6 +437,7 @@ public class LambdaQueryWrapper<T> extends AbstractChainWrapper<T> implements Wr
|
|||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
public LambdaQueryWrapper<T> vector(BaseVector vector) {
|
public LambdaQueryWrapper<T> vector(BaseVector vector) {
|
||||||
vectors.add(vector);
|
vectors.add(vector);
|
||||||
return this;
|
return this;
|
||||||
@ -444,6 +456,11 @@ public class LambdaQueryWrapper<T> extends AbstractChainWrapper<T> implements Wr
|
|||||||
this.setLimit(limit);
|
this.setLimit(limit);
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
public LambdaQueryWrapper<T> offset(Long offset) {
|
||||||
|
this.setOffset(offset);
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
public LambdaQueryWrapper<T> topK(Integer topK) {
|
public LambdaQueryWrapper<T> topK(Integer topK) {
|
||||||
this.setTopK(topK);
|
this.setTopK(topK);
|
||||||
return this;
|
return this;
|
||||||
@ -474,6 +491,9 @@ public class LambdaQueryWrapper<T> extends AbstractChainWrapper<T> implements Wr
|
|||||||
if (limit > 0) {
|
if (limit > 0) {
|
||||||
builder.limit(limit);
|
builder.limit(limit);
|
||||||
}
|
}
|
||||||
|
if(offset > 0){
|
||||||
|
builder.offset(offset);
|
||||||
|
}
|
||||||
if (!CollectionUtils.isEmpty(partitionNames)) {
|
if (!CollectionUtils.isEmpty(partitionNames)) {
|
||||||
builder.partitionNames(partitionNames);
|
builder.partitionNames(partitionNames);
|
||||||
}
|
}
|
||||||
@ -493,7 +513,7 @@ public class LambdaQueryWrapper<T> extends AbstractChainWrapper<T> implements Wr
|
|||||||
return builder.build();
|
return builder.build();
|
||||||
}
|
}
|
||||||
|
|
||||||
public QueryReq buildQuery() {
|
private QueryReq buildQuery() {
|
||||||
QueryReq.QueryReqBuilder<?, ?> builder = QueryReq.builder()
|
QueryReq.QueryReqBuilder<?, ?> builder = QueryReq.builder()
|
||||||
.collectionName(StringUtils.isNotBlank(collectionAlias) ? collectionAlias : collectionName);
|
.collectionName(StringUtils.isNotBlank(collectionAlias) ? collectionAlias : collectionName);
|
||||||
String filterStr = buildFilters();
|
String filterStr = buildFilters();
|
||||||
@ -506,6 +526,12 @@ public class LambdaQueryWrapper<T> extends AbstractChainWrapper<T> implements Wr
|
|||||||
if (limit > 0L) {
|
if (limit > 0L) {
|
||||||
builder.limit(limit);
|
builder.limit(limit);
|
||||||
}
|
}
|
||||||
|
if(offset > 0){
|
||||||
|
builder.offset(offset);
|
||||||
|
}
|
||||||
|
if(consistencyLevel!=null){
|
||||||
|
builder.consistencyLevel(consistencyLevel);
|
||||||
|
}
|
||||||
if (!CollectionUtils.isEmpty(partitionNames)) {
|
if (!CollectionUtils.isEmpty(partitionNames)) {
|
||||||
builder.partitionNames(partitionNames);
|
builder.partitionNames(partitionNames);
|
||||||
}
|
}
|
||||||
@ -517,6 +543,42 @@ public class LambdaQueryWrapper<T> extends AbstractChainWrapper<T> implements Wr
|
|||||||
}
|
}
|
||||||
return builder.build();
|
return builder.build();
|
||||||
}
|
}
|
||||||
|
private HybridSearchReq buildHybrid(){
|
||||||
|
//混合查询
|
||||||
|
List<AnnSearchReq> searchRequests = hybridWrapper.stream().filter(v -> StringUtils.isNotEmpty(v.getAnnsField()) && !v.getVectors().isEmpty()).map(
|
||||||
|
v -> {
|
||||||
|
AnnSearchReq.AnnSearchReqBuilder<?, ?> annBuilder = AnnSearchReq.builder()
|
||||||
|
.vectorFieldName(v.getAnnsField())
|
||||||
|
.vectors(v.getVectors());
|
||||||
|
if (v.getTopK() > 0) {
|
||||||
|
annBuilder.topK(v.getTopK());
|
||||||
|
}
|
||||||
|
String expr = v.buildFilters();
|
||||||
|
if (StringUtils.isNotEmpty(expr)) {
|
||||||
|
annBuilder.expr(expr);
|
||||||
|
}
|
||||||
|
Map<String, Object> params = v.searchParams;
|
||||||
|
if (!params.isEmpty()) {
|
||||||
|
annBuilder.params(JSON.toJSONString(params));
|
||||||
|
}
|
||||||
|
return annBuilder.build();
|
||||||
|
}
|
||||||
|
).collect(Collectors.toList());
|
||||||
|
HybridSearchReq.HybridSearchReqBuilder<?, ?> reqBuilder = HybridSearchReq.builder()
|
||||||
|
.collectionName(collectionName)
|
||||||
|
.searchRequests(searchRequests);
|
||||||
|
if(ranker!=null){
|
||||||
|
reqBuilder.ranker(ranker);
|
||||||
|
}
|
||||||
|
if(topK>0){
|
||||||
|
reqBuilder.topK(topK);
|
||||||
|
}
|
||||||
|
if(consistencyLevel!=null){
|
||||||
|
reqBuilder.consistencyLevel(consistencyLevel);
|
||||||
|
}
|
||||||
|
HybridSearchReq hybridSearchReq= reqBuilder.build();
|
||||||
|
return hybridSearchReq;
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 执行搜索
|
* 执行搜索
|
||||||
@ -526,6 +588,11 @@ public class LambdaQueryWrapper<T> extends AbstractChainWrapper<T> implements Wr
|
|||||||
public MilvusResp<List<MilvusResult<T>>> query() throws MilvusException{
|
public MilvusResp<List<MilvusResult<T>>> query() throws MilvusException{
|
||||||
return executeWithRetry(
|
return executeWithRetry(
|
||||||
() -> {
|
() -> {
|
||||||
|
if(hybridWrapper.size()>0){
|
||||||
|
HybridSearchReq hybridSearchReq = buildHybrid();
|
||||||
|
SearchResp searchResp = client.hybridSearch(hybridSearchReq);
|
||||||
|
return SearchRespConverter.convertSearchRespToMilvusResp(searchResp, entityType);
|
||||||
|
}
|
||||||
if (!vectors.isEmpty()) {
|
if (!vectors.isEmpty()) {
|
||||||
SearchReq searchReq = buildSearch();
|
SearchReq searchReq = buildSearch();
|
||||||
log.info("Build search param--> {}", JSON.toJSONString(searchReq));
|
log.info("Build search param--> {}", JSON.toJSONString(searchReq));
|
||||||
|
|||||||
@ -1,15 +1,15 @@
|
|||||||
package org.dromara.milvus.demo;
|
package org.dromara.milvus.demo;
|
||||||
|
|
||||||
import com.alibaba.fastjson.JSON;
|
|
||||||
import com.alibaba.fastjson.JSONObject;
|
import com.alibaba.fastjson.JSONObject;
|
||||||
import com.google.common.collect.Lists;
|
import com.google.common.collect.Lists;
|
||||||
|
import io.milvus.v2.service.vector.request.data.FloatVec;
|
||||||
|
import io.milvus.v2.service.vector.request.ranker.RRFRanker;
|
||||||
import io.milvus.v2.service.vector.response.InsertResp;
|
import io.milvus.v2.service.vector.response.InsertResp;
|
||||||
import io.milvus.v2.service.vector.response.UpsertResp;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.dromara.milvus.demo.model.Face;
|
import org.dromara.milvus.demo.model.Face;
|
||||||
import org.dromara.milvus.demo.model.FaceConstants;
|
|
||||||
import org.dromara.milvus.demo.model.FaceMilvusMapper;
|
import org.dromara.milvus.demo.model.FaceMilvusMapper;
|
||||||
import org.dromara.milvus.demo.model.Person;
|
import org.dromara.milvus.demo.model.Person;
|
||||||
|
import org.dromara.milvus.plus.core.conditions.LambdaQueryWrapper;
|
||||||
import org.dromara.milvus.plus.model.vo.MilvusResp;
|
import org.dromara.milvus.plus.model.vo.MilvusResp;
|
||||||
import org.dromara.milvus.plus.model.vo.MilvusResult;
|
import org.dromara.milvus.plus.model.vo.MilvusResult;
|
||||||
import org.springframework.boot.ApplicationArguments;
|
import org.springframework.boot.ApplicationArguments;
|
||||||
@ -32,10 +32,10 @@ public class ApplicationRunnerTest implements ApplicationRunner {
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void run(ApplicationArguments args) {
|
public void run(ApplicationArguments args) {
|
||||||
// insertFace();
|
insertFace();
|
||||||
// getByIdTest();
|
// getByIdTest();
|
||||||
// vectorQuery();
|
vectorQuery();
|
||||||
scalarQuery();
|
// scalarQuery();
|
||||||
// update();
|
// update();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -75,12 +75,24 @@ public class ApplicationRunnerTest implements ApplicationRunner {
|
|||||||
List<Float> vector = IntStream.range(0, 128)
|
List<Float> vector = IntStream.range(0, 128)
|
||||||
.mapToObj(i -> (float) (Math.random() * 100))
|
.mapToObj(i -> (float) (Math.random() * 100))
|
||||||
.collect(Collectors.toList());
|
.collect(Collectors.toList());
|
||||||
MilvusResp<List<MilvusResult<Face>>> query1 = mapper.queryWrapper()
|
List<Float> vector1 = IntStream.range(0, 128)
|
||||||
.vector(Face::getFaceVector, vector)
|
.mapToObj(i -> (float) (Math.random() * 100))
|
||||||
.like(Face::getPersonName, "张三")
|
.collect(Collectors.toList());
|
||||||
.topK(3)
|
// MilvusResp<List<MilvusResult<Face>>> query1 = mapper.queryWrapper()
|
||||||
.query();
|
// .vector(Face::getFaceVector, new FloatVec(vector))
|
||||||
log.info("向量查询 query--queryWrapper---{}", JSONObject.toJSONString(query1));
|
// .like(Face::getPersonName, "张三")
|
||||||
|
// .topK(3)
|
||||||
|
// .query();
|
||||||
|
// log.info("向量查询 query--queryWrapper---{}", JSONObject.toJSONString(query1));
|
||||||
|
|
||||||
|
MilvusResp<List<MilvusResult<Face>>> query = mapper.queryWrapper().
|
||||||
|
hybrid(new LambdaQueryWrapper<Face>().vector(Face::getFaceVector, new FloatVec(vector)).topK(2)).
|
||||||
|
hybrid(new LambdaQueryWrapper<Face>().vector(Face::getFaceVector, new FloatVec(vector1)).topK(4)).
|
||||||
|
ranker(new RRFRanker(20)).
|
||||||
|
topK(2).
|
||||||
|
query();
|
||||||
|
log.info("向量混合查询 query--queryWrapper---{}", JSONObject.toJSONString(query));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public void scalarQuery() {
|
public void scalarQuery() {
|
||||||
@ -92,14 +104,14 @@ public class ApplicationRunnerTest implements ApplicationRunner {
|
|||||||
.query();
|
.query();
|
||||||
log.info("标量查询 query--queryWrapper---{}", JSONObject.toJSONString(query2));
|
log.info("标量查询 query--queryWrapper---{}", JSONObject.toJSONString(query2));
|
||||||
}
|
}
|
||||||
public void update(){
|
// public void update(){
|
||||||
Face faceTmp = new Face();
|
// Face faceTmp = new Face();
|
||||||
List<Float> vectorTmp = IntStream.range(0, 128)
|
// List<Float> vectorTmp = IntStream.range(0, 128)
|
||||||
.mapToObj(j -> (float) (Math.random() * 100))
|
// .mapToObj(j -> (float) (Math.random() * 100))
|
||||||
.collect(Collectors.toList());
|
// .collect(Collectors.toList());
|
||||||
faceTmp.setFaceVector(vectorTmp);
|
// faceTmp.setFaceVector(vectorTmp);
|
||||||
faceTmp.setPersonName("赵六");
|
// faceTmp.setPersonName("赵六");
|
||||||
MilvusResp<UpsertResp> resp = mapper.updateWrapper().eq(FaceConstants.PERSON_NAME,"张三").update(faceTmp);
|
// MilvusResp<UpsertResp> resp = mapper.updateWrapper().eq(FaceConstants.PERSON_NAME,"张三").update(faceTmp);
|
||||||
System.out.printf("===="+ JSON.toJSONString(resp));
|
// System.out.printf("===="+ JSON.toJSONString(resp));
|
||||||
}
|
// }
|
||||||
}
|
}
|
||||||
Loading…
x
Reference in New Issue
Block a user