feat: 支持别名查询

1、查询添加别名配置
2、格式化相关类的代码
This commit is contained in:
code2tan 2024-06-06 23:27:00 +08:00
parent 9e154449e5
commit dfcd93c8a3
3 changed files with 33 additions and 14 deletions

View File

@ -1,10 +1,12 @@
package org.dromara.milvus.plus.core.conditions; package org.dromara.milvus.plus.core.conditions;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode;
import java.util.Iterator; import java.util.Iterator;
import java.util.NoSuchElementException; import java.util.NoSuchElementException;
@EqualsAndHashCode(callSuper = true)
@Data @Data
public abstract class AbstractChainWrapper<T> extends ConditionBuilder<T>{ public abstract class AbstractChainWrapper<T> extends ConditionBuilder<T>{
protected static class ArrayIterator<T> implements Iterator<T> { protected static class ArrayIterator<T> implements Iterator<T> {

View File

@ -13,6 +13,7 @@ import io.milvus.v2.service.vector.response.SearchResp;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode; import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.dromara.milvus.plus.cache.ConversionCache; import org.dromara.milvus.plus.cache.ConversionCache;
import org.dromara.milvus.plus.converter.SearchRespConverter; import org.dromara.milvus.plus.converter.SearchRespConverter;
import org.dromara.milvus.plus.core.FieldFunction; import org.dromara.milvus.plus.core.FieldFunction;
@ -35,6 +36,7 @@ public class LambdaQueryWrapper<T> extends AbstractChainWrapper<T> implements Wr
private List<String> outputFields; private List<String> outputFields;
private Class<T> entityType; private Class<T> entityType;
private String collectionName; private String collectionName;
private String collectionAlias;
private List<String> partitionNames = new ArrayList<>(); private List<String> partitionNames = new ArrayList<>();
private String annsField; private String annsField;
@ -60,6 +62,17 @@ public class LambdaQueryWrapper<T> extends AbstractChainWrapper<T> implements Wr
} }
/**
* 添加集合别名
*
* @param collectionAlias 别名
* @return this
*/
public LambdaQueryWrapper<T> alias(String collectionAlias) {
this.collectionAlias = collectionAlias;
return this;
}
public LambdaQueryWrapper<T> partition(String... partitionName) { public LambdaQueryWrapper<T> partition(String... partitionName) {
this.partitionNames.addAll(Arrays.asList(partitionName)); this.partitionNames.addAll(Arrays.asList(partitionName));
return this; return this;
@ -417,7 +430,7 @@ public class LambdaQueryWrapper<T> extends AbstractChainWrapper<T> implements Wr
*/ */
private SearchReq buildSearch() { private SearchReq buildSearch() {
SearchReq.SearchReqBuilder<?, ?> builder = SearchReq.builder() SearchReq.SearchReqBuilder<?, ?> builder = SearchReq.builder()
.collectionName(collectionName); .collectionName(StringUtils.isNotBlank(collectionAlias) ? collectionAlias : collectionName);
if (annsField != null && !annsField.isEmpty()) { if (annsField != null && !annsField.isEmpty()) {
builder.annsField(annsField); builder.annsField(annsField);
@ -456,15 +469,15 @@ public class LambdaQueryWrapper<T> extends AbstractChainWrapper<T> implements Wr
public QueryReq buildQuery() { public QueryReq buildQuery() {
QueryReq.QueryReqBuilder<?, ?> builder = QueryReq.builder() QueryReq.QueryReqBuilder<?, ?> builder = QueryReq.builder()
.collectionName(collectionName); .collectionName(StringUtils.isNotBlank(collectionAlias) ? collectionAlias : collectionName);
String filterStr = buildFilters(); String filterStr = buildFilters();
if (filterStr != null && !filterStr.isEmpty()) { if (StringUtils.isNotBlank(filterStr)) {
builder.filter(filterStr); builder.filter(filterStr);
} }
if (topK > 0) { if (topK > 0) {
builder.limit(topK); builder.limit(topK);
} }
if (limit > 0) { if (limit > 0L) {
builder.limit(limit); builder.limit(limit);
} }
if (!CollectionUtils.isEmpty(partitionNames)) { if (!CollectionUtils.isEmpty(partitionNames)) {
@ -481,6 +494,7 @@ public class LambdaQueryWrapper<T> extends AbstractChainWrapper<T> implements Wr
/** /**
* 执行搜索 * 执行搜索
*
* @return 搜索响应对象 * @return 搜索响应对象
*/ */
public MilvusResp<List<MilvusResult<T>>> query() throws MilvusException { public MilvusResp<List<MilvusResult<T>>> query() throws MilvusException {
@ -496,31 +510,34 @@ public class LambdaQueryWrapper<T> extends AbstractChainWrapper<T> implements Wr
return SearchRespConverter.convertGetRespToMilvusResp(query, entityType); return SearchRespConverter.convertGetRespToMilvusResp(query, entityType);
} }
} }
public MilvusResp<List<MilvusResult<T>>> query(FieldFunction<T,?> ... outputFields) throws MilvusException{
List<String> otf=new ArrayList<>(); public MilvusResp<List<MilvusResult<T>>> query(FieldFunction<T, ?>... outputFields) throws MilvusException {
List<String> otf = new ArrayList<>();
for (FieldFunction<T, ?> outputField : outputFields) { for (FieldFunction<T, ?> outputField : outputFields) {
otf.add(outputField.getFieldName(outputField)); otf.add(outputField.getFieldName(outputField));
} }
this.outputFields=otf; this.outputFields = otf;
return query(); return query();
} }
public MilvusResp<List<MilvusResult<T>>> query(String ... outputFields) throws MilvusException{
this.outputFields=Arrays.stream(outputFields).collect(Collectors.toList()); public MilvusResp<List<MilvusResult<T>>> query(String... outputFields) throws MilvusException {
this.outputFields = Arrays.stream(outputFields).collect(Collectors.toList());
return query(); return query();
} }
public MilvusResp<List<MilvusResult<T>>> getById(Serializable ... ids){
public MilvusResp<List<MilvusResult<T>>> getById(Serializable... ids) {
GetReq.GetReqBuilder<?, ?> builder = GetReq.builder() GetReq.GetReqBuilder<?, ?> builder = GetReq.builder()
.collectionName(collectionName) .collectionName(collectionName)
.ids(Arrays.asList(ids)); .ids(Arrays.asList(ids));
if(!CollectionUtils.isEmpty(partitionNames)){ if (!CollectionUtils.isEmpty(partitionNames)) {
builder.partitionName(partitionNames.get(0)); builder.partitionName(partitionNames.get(0));
} }
GetReq getReq = builder GetReq getReq = builder.build();
.build();
GetResp getResp = client.get(getReq); GetResp getResp = client.get(getReq);
return SearchRespConverter.convertGetRespToMilvusResp(getResp, entityType); return SearchRespConverter.convertGetRespToMilvusResp(getResp, entityType);
} }
@Override @Override
public void init(String collectionName, MilvusClientV2 client, ConversionCache conversionCache, Class<T> entityType) { public void init(String collectionName, MilvusClientV2 client, ConversionCache conversionCache, Class<T> entityType) {
setClient(client); setClient(client);

View File

@ -64,7 +64,7 @@ public class ApplicationRunnerTest implements ApplicationRunner {
List<Float> vector = IntStream.range(0, 128) List<Float> vector = IntStream.range(0, 128)
.mapToObj(i -> (float) (Math.random() * 100)) .mapToObj(i -> (float) (Math.random() * 100))
.collect(Collectors.toList()); .collect(Collectors.toList());
MilvusResp<List<MilvusResult<Face>>> query1 = mapper.queryWrapper() MilvusResp<List<MilvusResult<Face>>> query1 = mapper.queryWrapper().alias("alias_face")
.vector(Face::getFaceVector, vector) .vector(Face::getFaceVector, vector)
.like(Face::getPersonName, "张三") .like(Face::getPersonName, "张三")
.topK(3) .topK(3)