构建MilvusMapper及使用demo

This commit is contained in:
xgc 2024-05-10 11:11:55 +08:00
parent 3a16cd3098
commit 0ad0915ab9
14 changed files with 651 additions and 428 deletions

View File

@ -82,15 +82,25 @@ public class Face {
}
```
```
public static void main(String[] args) {
MilvusWrapper<Face> wrapper=new MilvusWrapper();
@Component
public class FaceMilvusMapper extends MilvusMapper<Face> {
}
@Component
public class ApplicationRunnerTest implements ApplicationRunner {
@Autowired
private FaceMilvusMapper mapper;
@Override
public void run(ApplicationArguments args) throws Exception {
List<Float> vector = Lists.newArrayList(0.1f,0.2f,0.3f);
MilvusResp<Face> resp = wrapper.lambda()
MilvusResp<Face> resp = mapper.lambda()
.eq(Face::getPersonId,1l)
.addVector(vector)
.vector(vector)
.limit(100l)
.query();
}
}
```

View File

@ -82,15 +82,25 @@ public class Face {
}
```
```
public static void main(String[] args) {
MilvusWrapper<Face> wrapper=new MilvusWrapper();
@Component
public class FaceMilvusMapper extends MilvusMapper<Face> {
}
@Component
public class ApplicationRunnerTest implements ApplicationRunner {
@Autowired
private FaceMilvusMapper mapper;
@Override
public void run(ApplicationArguments args) throws Exception {
List<Float> vector = Lists.newArrayList(0.1f,0.2f,0.3f);
MilvusResp<Face> resp = wrapper.lambda()
MilvusResp<Face> resp = mapper.lambda()
.eq(Face::getPersonId,1l)
.addVector(vector)
.vector(vector)
.limit(100l)
.query();
}
}
```

View File

@ -0,0 +1,27 @@
package io.github.javpower.milvus.demo;
import com.google.common.collect.Lists;
import io.github.javpower.milvus.demo.model.Face;
import io.github.javpower.milvus.demo.test.FaceMilvusMapper;
import io.github.javpower.milvus.plus.model.MilvusResp;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.ApplicationArguments;
import org.springframework.boot.ApplicationRunner;
import org.springframework.stereotype.Component;
import java.util.List;
@Component
public class ApplicationRunnerTest implements ApplicationRunner {
@Autowired
private FaceMilvusMapper mapper;
@Override
public void run(ApplicationArguments args) throws Exception {
List<Float> vector = Lists.newArrayList(0.1f,0.2f,0.3f);
MilvusResp<Face> resp = mapper.lambda()
.eq(Face::getPersonId,1l)
.vector(vector)
.limit(100l)
.query();
}
}

View File

@ -0,0 +1,10 @@
package io.github.javpower.milvus.demo.test;
import io.github.javpower.milvus.demo.model.Face;
import io.github.javpower.milvus.plus.core.mapper.MilvusMapper;
import org.springframework.stereotype.Component;
@Component
public class FaceMilvusMapper extends MilvusMapper<Face> {
}

View File

@ -1,21 +0,0 @@
//package io.github.javpower.milvus.demo.test;
//
//import com.google.common.collect.Lists;
//import io.github.javpower.milvus.plus.core.conditions.MilvusWrapper;
//import io.github.javpower.milvus.plus.model.MilvusResp;
//
//import java.util.List;
//
//public class TestWrapper {
// public static void main(String[] args) {
// MilvusWrapper<Face> wrapper=new MilvusWrapper();
// List<Float> vector = Lists.newArrayList(0.1f,0.2f,0.3f);
// MilvusResp<Face> resp = wrapper.lambda()
// .eq(Face::getPersonId,1l)
// .addVector(vector)
// .query();
// }
//
//
//
//}

View File

@ -2,4 +2,5 @@ server:
port: 8131
milvus:
uri: localhost:8999
token: sss
token: sss
enable: false

View File

@ -19,7 +19,7 @@ public class MilvusCollectionConfig implements ApplicationRunner {
private final ApplicationContext applicationContext;
private final MilvusCollectionService milvusCollectionService;
@Autowired
@Autowired(required = false)
public MilvusCollectionConfig(ApplicationContext applicationContext, MilvusCollectionService milvusCollectionService) {
this.applicationContext = applicationContext;
this.milvusCollectionService = milvusCollectionService;
@ -33,6 +33,8 @@ public class MilvusCollectionConfig implements ApplicationRunner {
.map(applicationContext::getType)
.toArray(Class<?>[]::new);
// 调用业务处理服务
milvusCollectionService.performBusinessLogic(Arrays.asList(annotatedClasses));
if(milvusCollectionService!=null){
milvusCollectionService.performBusinessLogic(Arrays.asList(annotatedClasses));
}
}
}

View File

@ -2,7 +2,6 @@ package io.github.javpower.milvus.plus.config;
import io.milvus.v2.client.ConnectConfig;
import io.milvus.v2.client.MilvusClientV2;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
/**
@ -10,10 +9,17 @@ import org.springframework.context.annotation.Configuration;
**/
@Configuration
public class MilvusConfig {
@Autowired
private MilvusProperties properties;
private final MilvusProperties properties;
public MilvusConfig(MilvusProperties properties) {
this.properties = properties;
}
@Bean
public MilvusClientV2 milvusClientV2() {
if(!properties.getEnable()){
return null;
}
ConnectConfig connectConfig = ConnectConfig.builder()
.uri(properties.getUri())
.token(properties.getToken())

View File

@ -10,6 +10,8 @@ import org.springframework.stereotype.Component;
@ConfigurationProperties(prefix = "milvus")
@Component
public class MilvusProperties {
private Boolean enable;
private String uri;
private String token;
}

View File

@ -1,8 +1,171 @@
package io.github.javpower.milvus.plus.core;
/**
* @author xgc
**/
import io.github.javpower.milvus.plus.annotation.MilvusField;
import org.apache.commons.lang3.StringUtils;
import java.io.Serializable;
import java.lang.invoke.SerializedLambda;
import java.lang.reflect.Field;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.function.Function;
@FunctionalInterface
public interface FieldFunction<T, R> {
R apply(T entity);
public interface FieldFunction<T,R> extends Function<T,R>, Serializable {
//默认配置
String defaultSplit = "";
Integer defaultToType = 0;
default String getFieldName() {
String methodName = getMethodName();
if (methodName.startsWith("get")) {
methodName = methodName.substring(3);
}
return changeFirstCharCase(methodName,false);
}
/**
* 获取实体类的字段名称(实体声明的字段名称)
*/
default String getFieldNameLine() {
return getFieldName(this, defaultSplit);
}
/**
* 获取实体类的字段名称
*/
default String getFieldName(FieldFunction<T, ?> fn) {
return getFieldName(fn, defaultSplit, defaultToType);
}
/**
* 获取实体类的字段名称
*
* @param split 分隔符多个字母自定义分隔符
*/
default String getFieldName(FieldFunction<T, ?> fn, String split) {
return getFieldName(fn, split, defaultToType);
}
/**
* 获取实体类的字段名称
*
* @param split 分隔符多个字母自定义分隔符
* @param toType 转换方式多个字母以大小写方式返回 0.不做转换 1.大写 2.小写
*/
default String getFieldName(FieldFunction<T, ?> fn, String split, Integer toType) {
SerializedLambda serializedLambda = getSerializedLambdaOne(fn);
// 从lambda信息取出methodfieldclass等
String fieldName = serializedLambda.getImplMethodName().substring("get".length());
fieldName = fieldName.replaceFirst(fieldName.charAt(0) + "", (fieldName.charAt(0) + "").toLowerCase());
Field field;
try {
field = Class.forName(serializedLambda.getImplClass().replace("/", ".")).getDeclaredField(fieldName);
} catch (ClassNotFoundException | NoSuchFieldException e) {
throw new RuntimeException(e);
}
// 从field取出字段名
MilvusField collectionField = field.getAnnotation(MilvusField.class);
if (collectionField != null && StringUtils.isNotBlank(collectionField.name())) {
return collectionField.name();
}else {
//0.不做转换 1.大写 2.小写
switch (toType) {
case 1:
return fieldName.replaceAll("[A-Z]", split + "$0").toUpperCase();
case 2:
return fieldName.replaceAll("[A-Z]", split + "$0").toLowerCase();
default:
return fieldName.replaceAll("[A-Z]", split + "$0");
}
}
}
default String getMethodName() {
return getSerializedLambda().getImplMethodName();
}
default Class<?> getFieldClass() {
return getReturnType();
}
default SerializedLambda getSerializedLambda() {
Method method;
try {
method = getClass().getDeclaredMethod("writeReplace");
} catch (NoSuchMethodException e) {
throw new RuntimeException(e);
}
method.setAccessible(true);
try {
return (SerializedLambda) method.invoke(this);
} catch (IllegalAccessException | InvocationTargetException e) {
throw new RuntimeException(e);
}
}
default SerializedLambda getSerializedLambdaOne(FieldFunction<T, ?> fn) {
// 从function取出序列化方法
Method writeReplaceMethod;
try {
writeReplaceMethod = fn.getClass().getDeclaredMethod("writeReplace");
} catch (NoSuchMethodException e) {
throw new RuntimeException(e);
}
// 从序列化方法取出序列化的lambda信息
boolean isAccessible = writeReplaceMethod.isAccessible();
writeReplaceMethod.setAccessible(true);
SerializedLambda serializedLambda;
try {
serializedLambda = (SerializedLambda) writeReplaceMethod.invoke(fn);
} catch (IllegalAccessException | InvocationTargetException e) {
throw new RuntimeException(e);
}
writeReplaceMethod.setAccessible(isAccessible);
return serializedLambda;
}
default Class<?> getReturnType() {
SerializedLambda lambda = getSerializedLambda();
Class<?> className;
try {
className = Class.forName(lambda.getImplClass().replace("/", "."));
} catch (ClassNotFoundException e) {
throw new RuntimeException(e);
}
Method method;
try {
method = className.getMethod(getMethodName());
} catch (NoSuchMethodException e) {
throw new RuntimeException(e);
}
return method.getReturnType();
}
default String changeFirstCharCase(String str, boolean capitalize) {
if (str == null || str.isEmpty()) {
return str;
}
char baseChar = str.charAt(0);
char updatedChar;
if (capitalize) {
updatedChar = Character.toUpperCase(baseChar);
} else {
updatedChar = Character.toLowerCase(baseChar);
}
if (baseChar == updatedChar) {
return str;
} else {
char[] chars = str.toCharArray();
chars[0] = updatedChar;
return new String(chars);
}
}
}

View File

@ -0,0 +1,357 @@
package io.github.javpower.milvus.plus.core.conditions;
import com.alibaba.fastjson.JSON;
import io.github.javpower.milvus.plus.cache.ConversionCache;
import io.github.javpower.milvus.plus.converter.SearchRespConverter;
import io.github.javpower.milvus.plus.core.FieldFunction;
import io.github.javpower.milvus.plus.model.MilvusResp;
import io.github.javpower.milvus.plus.service.MilvusClient;
import io.milvus.exception.MilvusException;
import io.milvus.v2.common.ConsistencyLevel;
import io.milvus.v2.service.vector.request.SearchReq;
import io.milvus.v2.service.vector.response.SearchResp;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
/**
* 搜索构建器内部类用于构建搜索请求
*/
@Data
@Slf4j
public class LambdaSearchWrapper<T> {
private ConversionCache<?, ?> conversionCache;
private Class<T> entityType;
private String collectionName;
private String annsField;
private int topK;
private List<String> filters = new ArrayList<>();
private List<List<Float>> vectors = new ArrayList<>();
private long offset;
private long limit;
private int roundDecimal;
private String searchParams;
private long guaranteeTimestamp;
private ConsistencyLevel consistencyLevel;
private boolean ignoreGrowing;
private MilvusClient client;
public LambdaSearchWrapper(String collectionName, MilvusClient client,ConversionCache<?, ?> conversionCache,Class<T> entityType) {
this.collectionName = collectionName;
this.client = client;
this.conversionCache=conversionCache;
this.entityType=entityType;
}
public LambdaSearchWrapper() {
}
// addVector
public LambdaSearchWrapper<T> addVector(List<Float> vector) {
vectors.add(vector);
return this;
}
public LambdaSearchWrapper<T> vector(List<Float> vector) {
vectors.add(vector);
return this;
}
// Common comparison operations
public LambdaSearchWrapper<T> eq(String fieldName, Object value) {
return addFilter(fieldName, "==", value);
}
public LambdaSearchWrapper<T> ne(String fieldName, Object value) {
return addFilter(fieldName, "!=", value);
}
public LambdaSearchWrapper<T> gt(String fieldName, Object value) {
return addFilter(fieldName, ">", value);
}
public LambdaSearchWrapper<T> ge(String fieldName, Object value) {
return addFilter(fieldName, ">=", value);
}
public LambdaSearchWrapper<T> lt(String fieldName, Object value) {
return addFilter(fieldName, "<", value);
}
public LambdaSearchWrapper<T> le(String fieldName, Object value) {
return addFilter(fieldName, "<=", value);
}
// Range operation
public LambdaSearchWrapper<T> between(String fieldName, Object start, Object end) {
String filter = String.format("%s >= %s && %s <= %s", fieldName, convertValue(start), fieldName, convertValue(end));
filters.add(filter);
return this;
}
// Null check
public LambdaSearchWrapper<T> isNull(String fieldName) {
filters.add(fieldName + " == null");
return this;
}
public LambdaSearchWrapper<T> isNotNull(String fieldName) {
filters.add(fieldName + " != null");
return this;
}
// In operator
public LambdaSearchWrapper<T> in(String fieldName, List<?> values) {
String valueList = values.stream()
.map(this::convertValue)
.collect(Collectors.joining(", ", "[", "]"));
filters.add(fieldName + " in " + valueList);
return this;
}
// Like operator
public LambdaSearchWrapper<T> like(String fieldName, String value) {
filters.add(fieldName + " like '%" + value + "%'");
return this;
}
// JSON array operations
public LambdaSearchWrapper<T> jsonContains(String fieldName, Object value) {
filters.add("JSON_CONTAINS(" + fieldName + ", " + convertValue(value) + ")");
return this;
}
public LambdaSearchWrapper<T> jsonContainsAll(String fieldName, List<?> values) {
String valueList = convertValues(values);
filters.add("JSON_CONTAINS_ALL(" + fieldName + ", " + valueList + ")");
return this;
}
public LambdaSearchWrapper<T> jsonContainsAny(String fieldName, List<?> values) {
String valueList = convertValues(values);
filters.add("JSON_CONTAINS_ANY(" + fieldName + ", " + valueList + ")");
return this;
}
// Array operations
public LambdaSearchWrapper<T> arrayContains(String fieldName, Object value) {
filters.add("ARRAY_CONTAINS(" + fieldName + ", " + convertValue(value) + ")");
return this;
}
public LambdaSearchWrapper<T> arrayContainsAll(String fieldName, List<?> values) {
String valueList = convertValues(values);
filters.add("ARRAY_CONTAINS_ALL(" + fieldName + ", " + valueList + ")");
return this;
}
public LambdaSearchWrapper<T> arrayContainsAny(String fieldName, List<?> values) {
String valueList = convertValues(values);
filters.add("ARRAY_CONTAINS_ANY(" + fieldName + ", " + valueList + ")");
return this;
}
public LambdaSearchWrapper<T> arrayLength(String fieldName, int length) {
filters.add(fieldName + ".length() == " + length);
return this;
}
public LambdaSearchWrapper<T> eq(FieldFunction<T,?> fieldName, Object value) {
return addFilter(fieldName, "==", value);
}
public LambdaSearchWrapper<T> ne(FieldFunction<T,?> fieldName, Object value) {
return addFilter(fieldName, "!=", value);
}
public LambdaSearchWrapper<T> gt(FieldFunction<T,?> fieldName, Object value) {
return addFilter(fieldName, ">", value);
}
public LambdaSearchWrapper<T> ge(FieldFunction<T,?> fieldName, Object value) {
return addFilter(fieldName, ">=", value);
}
public LambdaSearchWrapper<T> lt(FieldFunction<T,?> fieldName, Object value) {
return addFilter(fieldName, "<", value);
}
public LambdaSearchWrapper<T> le(FieldFunction<T,?> fieldName, Object value) {
return addFilter(fieldName, "<=", value);
}
// Range operation
public LambdaSearchWrapper<T> between(FieldFunction<T,?> fieldName, Object start, Object end) {
String fn = getFieldName(fieldName);
String filter = String.format("%s >= %s && %s <= %s", fn, convertValue(start), fn, convertValue(end));
filters.add(filter);
return this;
}
// Null check
public LambdaSearchWrapper<T> isNull(FieldFunction<T,?> fieldName) {
String fn = getFieldName(fieldName);
filters.add(fn + " == null");
return this;
}
public LambdaSearchWrapper<T> isNotNull(FieldFunction<T,?> fieldName) {
String fn = getFieldName(fieldName);
filters.add(fn + " != null");
return this;
}
// In operator
public LambdaSearchWrapper<T> in(FieldFunction<T,?> fieldName, List<?> values) {
String fn = getFieldName(fieldName);
String valueList = values.stream()
.map(this::convertValue)
.collect(Collectors.joining(", ", "[", "]"));
filters.add(fn + " in " + valueList);
return this;
}
// Like operator
public LambdaSearchWrapper<T> like(FieldFunction<T,?> fieldName, String value) {
String fn = getFieldName(fieldName);
filters.add(fn + " like '%" + value + "%'");
return this;
}
// JSON array operations
public LambdaSearchWrapper<T> jsonContains(FieldFunction<T,?> fieldName, Object value) {
String fn = getFieldName(fieldName);
filters.add("JSON_CONTAINS(" + fn + ", " + convertValue(value) + ")");
return this;
}
public LambdaSearchWrapper<T> jsonContainsAll(FieldFunction<T,?> fieldName, List<?> values) {
String fn = getFieldName(fieldName);
String valueList = convertValues(values);
filters.add("JSON_CONTAINS_ALL(" + fn + ", " + valueList + ")");
return this;
}
public LambdaSearchWrapper<T> jsonContainsAny(FieldFunction<T,?> fieldName, List<?> values) {
String fn = getFieldName(fieldName);
String valueList = convertValues(values);
filters.add("JSON_CONTAINS_ANY(" + fn + ", " + valueList + ")");
return this;
}
// Array operations
public LambdaSearchWrapper<T> arrayContains(FieldFunction<T,?> fieldName, Object value) {
String fn = getFieldName(fieldName);
filters.add("ARRAY_CONTAINS(" + fn + ", " + convertValue(value) + ")");
return this;
}
public LambdaSearchWrapper<T> arrayContainsAll(FieldFunction<T,?> fieldName, List<?> values) {
String fn = getFieldName(fieldName);
String valueList = convertValues(values);
filters.add("ARRAY_CONTAINS_ALL(" + fn + ", " + valueList + ")");
return this;
}
public LambdaSearchWrapper<T> arrayContainsAny(FieldFunction<T,?> fieldName, List<?> values) {
String fn = getFieldName(fieldName);
String valueList = convertValues(values);
filters.add("ARRAY_CONTAINS_ANY(" + fn + ", " + valueList + ")");
return this;
}
public LambdaSearchWrapper<T> arrayLength(FieldFunction<T,?> fieldName, int length) {
String fn = getFieldName(fieldName);
filters.add(fn + ".length() == " + length);
return this;
}
// Logic operations
public LambdaSearchWrapper<T> and(LambdaSearchWrapper<T> other) {
filters.add("(" + String.join(" && ", other.filters) + ")");
return this;
}
public LambdaSearchWrapper<T> or(LambdaSearchWrapper<T> other) {
filters.add("(" + String.join(" || ", other.filters) + ")");
return this;
}
public LambdaSearchWrapper<T> not() {
filters.add("not (" + String.join(" && ", filters) + ")");
return this;
}
public LambdaSearchWrapper<T> limit(Long limit) {
this.setLimit(limit);
return this;
}
public LambdaSearchWrapper<T> topK(Integer topK) {
this.setTopK(topK);
return this;
}
// Helper methods
private String convertValue(Object value) {
if (value instanceof String) {
return "'" + value.toString().replace("'", "\\'") + "'";
}
return value.toString();
}
private String convertValues(List<?> values) {
return values.stream()
.map(this::convertValue)
.collect(Collectors.joining(", ", "[", "]"));
}
private LambdaSearchWrapper<T> addFilter(String fieldName, String op, Object value) {
filters.add(fieldName + " " + op + " " + convertValue(value));
return this;
}
private LambdaSearchWrapper<T> addFilter(FieldFunction<T, ?> fieldFunction, String op, Object value) {
String fieldName = getFieldName(fieldFunction);
filters.add(fieldName + " " + op + " " + convertValue(value));
return this;
}
private String getFieldName(FieldFunction<T, ?> fieldFunction) {
return fieldFunction.getFieldName(fieldFunction);
}
/**
* 构建完整的搜索请求
* @return 搜索请求对象
*/
private SearchReq build() {
SearchReq.SearchReqBuilder<?, ?> builder = SearchReq.builder()
.collectionName(collectionName)
.annsField(annsField)
.topK(topK);
if (!vectors.isEmpty()) {
builder.data(vectors);
}
String filterStr = filters.stream().collect(Collectors.joining(" && "));
if (filterStr != null && !filterStr.isEmpty()) {
builder.filter(filterStr);
}
// Set other parameters as needed
return builder.build();
}
/**
* 执行搜索
* @return 搜索响应对象
*/
public MilvusResp<T> query() throws MilvusException {
SearchReq searchReq = build();
log.info("build query param-->{}", JSON.toJSONString(searchReq));
SearchResp search = client.client.search(searchReq);
MilvusResp<T> tMilvusResp = SearchRespConverter.convertSearchRespToMilvusResp(search, entityType);
return tMilvusResp;
}
}

View File

@ -1,380 +0,0 @@
package io.github.javpower.milvus.plus.core.conditions;
import io.github.javpower.milvus.plus.annotation.MilvusCollection;
import io.github.javpower.milvus.plus.cache.ConversionCache;
import io.github.javpower.milvus.plus.cache.FieldFunctionCache;
import io.github.javpower.milvus.plus.cache.MilvusCache;
import io.github.javpower.milvus.plus.cache.PropertyCache;
import io.github.javpower.milvus.plus.converter.SearchRespConverter;
import io.github.javpower.milvus.plus.core.FieldFunction;
import io.github.javpower.milvus.plus.model.MilvusResp;
import io.github.javpower.milvus.plus.service.MilvusClient;
import io.github.javpower.milvus.plus.util.SpringUtils;
import io.milvus.exception.MilvusException;
import io.milvus.v2.common.ConsistencyLevel;
import io.milvus.v2.service.vector.request.SearchReq;
import io.milvus.v2.service.vector.response.SearchResp;
import lombok.Data;
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
/**
* @author xgc
**/
public class MilvusWrapper<T> {
/**
* 创建搜索构建器实例
* @param collectionName 集合名称
* @return 返回搜索构建器
*/
public LambdaSearchWrapper<T> lambda() {
// 获取实例化的类的类型参数T
Type type = ((ParameterizedType) this.getClass().getGenericSuperclass()).getActualTypeArguments()[0];
Class<T> entityType = (Class<T>) type;
// 从实体类上获取@MilvusCollection注解
MilvusCollection collectionAnnotation = entityType.getAnnotation(MilvusCollection.class);
if (collectionAnnotation == null) {
throw new IllegalStateException("Entity type " + entityType.getName() + " is not annotated with @MilvusCollection.");
}
ConversionCache<?, ?> conversionCache = MilvusCache.milvusCache.get(entityType);
String collectionName = conversionCache.getCollectionName();
// 使用SpringUtil获取MilvusClient实例
MilvusClient client = SpringUtils.getBean(MilvusClient.class);
// 使用注解中的集合名称创建LambdaSearchWrapper实例
return new LambdaSearchWrapper<>(collectionName, client,conversionCache,entityType);
}
/**
* 搜索构建器内部类用于构建搜索请求
*/
@Data
public static class LambdaSearchWrapper<T> {
private ConversionCache<?, ?> conversionCache;
private Class<T> entityType;
private String collectionName;
private String annsField;
private int topK;
private List<String> filters = new ArrayList<>();
private List<List<Float>> vectors = new ArrayList<>();
private long offset;
private long limit;
private int roundDecimal;
private String searchParams;
private long guaranteeTimestamp;
private ConsistencyLevel consistencyLevel;
private boolean ignoreGrowing;
private MilvusClient client;
public LambdaSearchWrapper(String collectionName, MilvusClient client,ConversionCache<?, ?> conversionCache,Class<T> entityType) {
this.collectionName = collectionName;
this.client = client;
this.conversionCache=conversionCache;
this.entityType=entityType;
}
public LambdaSearchWrapper() {
}
// addVector
public LambdaSearchWrapper<T> addVector(List<Float> vector) {
vectors.add(vector);
return this;
}
// Common comparison operations
public LambdaSearchWrapper<T> eq(String fieldName, Object value) {
return addFilter(fieldName, "==", value);
}
public LambdaSearchWrapper<T> ne(String fieldName, Object value) {
return addFilter(fieldName, "!=", value);
}
public LambdaSearchWrapper<T> gt(String fieldName, Object value) {
return addFilter(fieldName, ">", value);
}
public LambdaSearchWrapper<T> ge(String fieldName, Object value) {
return addFilter(fieldName, ">=", value);
}
public LambdaSearchWrapper<T> lt(String fieldName, Object value) {
return addFilter(fieldName, "<", value);
}
public LambdaSearchWrapper<T> le(String fieldName, Object value) {
return addFilter(fieldName, "<=", value);
}
// Range operation
public LambdaSearchWrapper<T> between(String fieldName, Object start, Object end) {
String filter = String.format("%s >= %s && %s <= %s", fieldName, convertValue(start), fieldName, convertValue(end));
filters.add(filter);
return this;
}
// Null check
public LambdaSearchWrapper<T> isNull(String fieldName) {
filters.add(fieldName + " == null");
return this;
}
public LambdaSearchWrapper<T> isNotNull(String fieldName) {
filters.add(fieldName + " != null");
return this;
}
// In operator
public LambdaSearchWrapper<T> in(String fieldName, List<?> values) {
String valueList = values.stream()
.map(this::convertValue)
.collect(Collectors.joining(", ", "[", "]"));
filters.add(fieldName + " in " + valueList);
return this;
}
// Like operator
public LambdaSearchWrapper<T> like(String fieldName, String value) {
filters.add(fieldName + " like '%" + value + "%'");
return this;
}
// JSON array operations
public LambdaSearchWrapper<T> jsonContains(String fieldName, Object value) {
filters.add("JSON_CONTAINS(" + fieldName + ", " + convertValue(value) + ")");
return this;
}
public LambdaSearchWrapper<T> jsonContainsAll(String fieldName, List<?> values) {
String valueList = convertValues(values);
filters.add("JSON_CONTAINS_ALL(" + fieldName + ", " + valueList + ")");
return this;
}
public LambdaSearchWrapper<T> jsonContainsAny(String fieldName, List<?> values) {
String valueList = convertValues(values);
filters.add("JSON_CONTAINS_ANY(" + fieldName + ", " + valueList + ")");
return this;
}
// Array operations
public LambdaSearchWrapper<T> arrayContains(String fieldName, Object value) {
filters.add("ARRAY_CONTAINS(" + fieldName + ", " + convertValue(value) + ")");
return this;
}
public LambdaSearchWrapper<T> arrayContainsAll(String fieldName, List<?> values) {
String valueList = convertValues(values);
filters.add("ARRAY_CONTAINS_ALL(" + fieldName + ", " + valueList + ")");
return this;
}
public LambdaSearchWrapper<T> arrayContainsAny(String fieldName, List<?> values) {
String valueList = convertValues(values);
filters.add("ARRAY_CONTAINS_ANY(" + fieldName + ", " + valueList + ")");
return this;
}
public LambdaSearchWrapper<T> arrayLength(String fieldName, int length) {
filters.add(fieldName + ".length() == " + length);
return this;
}
public LambdaSearchWrapper<T> eq(FieldFunction<T,?> fieldName, Object value) {
return addFilter(fieldName, "==", value);
}
public LambdaSearchWrapper<T> ne(FieldFunction<T,?> fieldName, Object value) {
return addFilter(fieldName, "!=", value);
}
public LambdaSearchWrapper<T> gt(FieldFunction<T,?> fieldName, Object value) {
return addFilter(fieldName, ">", value);
}
public LambdaSearchWrapper<T> ge(FieldFunction<T,?> fieldName, Object value) {
return addFilter(fieldName, ">=", value);
}
public LambdaSearchWrapper<T> lt(FieldFunction<T,?> fieldName, Object value) {
return addFilter(fieldName, "<", value);
}
public LambdaSearchWrapper<T> le(FieldFunction<T,?> fieldName, Object value) {
return addFilter(fieldName, "<=", value);
}
// Range operation
public LambdaSearchWrapper<T> between(FieldFunction<T,?> fieldName, Object start, Object end) {
String fn = getFieldName(fieldName);
String filter = String.format("%s >= %s && %s <= %s", fn, convertValue(start), fn, convertValue(end));
filters.add(filter);
return this;
}
// Null check
public LambdaSearchWrapper<T> isNull(FieldFunction<T,?> fieldName) {
String fn = getFieldName(fieldName);
filters.add(fn + " == null");
return this;
}
public LambdaSearchWrapper<T> isNotNull(FieldFunction<T,?> fieldName) {
String fn = getFieldName(fieldName);
filters.add(fn + " != null");
return this;
}
// In operator
public LambdaSearchWrapper<T> in(FieldFunction<T,?> fieldName, List<?> values) {
String fn = getFieldName(fieldName);
String valueList = values.stream()
.map(this::convertValue)
.collect(Collectors.joining(", ", "[", "]"));
filters.add(fn + " in " + valueList);
return this;
}
// Like operator
public LambdaSearchWrapper<T> like(FieldFunction<T,?> fieldName, String value) {
String fn = getFieldName(fieldName);
filters.add(fn + " like '%" + value + "%'");
return this;
}
// JSON array operations
public LambdaSearchWrapper<T> jsonContains(FieldFunction<T,?> fieldName, Object value) {
String fn = getFieldName(fieldName);
filters.add("JSON_CONTAINS(" + fn + ", " + convertValue(value) + ")");
return this;
}
public LambdaSearchWrapper<T> jsonContainsAll(FieldFunction<T,?> fieldName, List<?> values) {
String fn = getFieldName(fieldName);
String valueList = convertValues(values);
filters.add("JSON_CONTAINS_ALL(" + fn + ", " + valueList + ")");
return this;
}
public LambdaSearchWrapper<T> jsonContainsAny(FieldFunction<T,?> fieldName, List<?> values) {
String fn = getFieldName(fieldName);
String valueList = convertValues(values);
filters.add("JSON_CONTAINS_ANY(" + fn + ", " + valueList + ")");
return this;
}
// Array operations
public LambdaSearchWrapper<T> arrayContains(FieldFunction<T,?> fieldName, Object value) {
String fn = getFieldName(fieldName);
filters.add("ARRAY_CONTAINS(" + fn + ", " + convertValue(value) + ")");
return this;
}
public LambdaSearchWrapper<T> arrayContainsAll(FieldFunction<T,?> fieldName, List<?> values) {
String fn = getFieldName(fieldName);
String valueList = convertValues(values);
filters.add("ARRAY_CONTAINS_ALL(" + fn + ", " + valueList + ")");
return this;
}
public LambdaSearchWrapper<T> arrayContainsAny(FieldFunction<T,?> fieldName, List<?> values) {
String fn = getFieldName(fieldName);
String valueList = convertValues(values);
filters.add("ARRAY_CONTAINS_ANY(" + fn + ", " + valueList + ")");
return this;
}
public LambdaSearchWrapper<T> arrayLength(FieldFunction<T,?> fieldName, int length) {
String fn = getFieldName(fieldName);
filters.add(fn + ".length() == " + length);
return this;
}
// Logic operations
public LambdaSearchWrapper<T> and(LambdaSearchWrapper<T> other) {
filters.add("(" + String.join(" && ", other.filters) + ")");
return this;
}
public LambdaSearchWrapper<T> or(LambdaSearchWrapper<T> other) {
filters.add("(" + String.join(" || ", other.filters) + ")");
return this;
}
public LambdaSearchWrapper<T> not() {
filters.add("not (" + String.join(" && ", filters) + ")");
return this;
}
// Helper methods
private String convertValue(Object value) {
if (value instanceof String) {
return "'" + value.toString().replace("'", "\\'") + "'";
}
return value.toString();
}
private String convertValues(List<?> values) {
return values.stream()
.map(this::convertValue)
.collect(Collectors.joining(", ", "[", "]"));
}
private LambdaSearchWrapper<T> addFilter(String fieldName, String op, Object value) {
filters.add(fieldName + " " + op + " " + convertValue(value));
return this;
}
private LambdaSearchWrapper<T> addFilter(FieldFunction<T, ?> fieldFunction, String op, Object value) {
String fieldName = getFieldName(fieldFunction);
filters.add(fieldName + " " + op + " " + convertValue(value));
return this;
}
private String getFieldName(FieldFunction<T, ?> fieldFunction) {
FieldFunctionCache<?, ?> fieldFunctionCache = conversionCache.getFieldFunctionCache();
String fn = fieldFunctionCache.getFieldName(fieldFunction);
PropertyCache propertyCache = conversionCache.getPropertyCache();
String fieldName = propertyCache.functionToPropertyMap.get(fn);
return fieldName;
}
/**
* 构建完整的搜索请求
* @return 搜索请求对象
*/
private SearchReq build() {
SearchReq.SearchReqBuilder<?, ?> builder = SearchReq.builder()
.collectionName(collectionName)
.annsField(annsField)
.topK(topK);
if (!vectors.isEmpty()) {
builder.data(vectors);
}
String filterStr = filters.stream().collect(Collectors.joining(" && "));
if (filterStr != null && !filterStr.isEmpty()) {
builder.filter(filterStr);
}
// Set other parameters as needed
return builder.build();
}
/**
* 执行搜索
* @return 搜索响应对象
* @throws MilvusException 如果搜索执行失败
*/
public MilvusResp<T> query() throws MilvusException {
SearchReq searchReq = build();
SearchResp search = client.client.search(searchReq);
MilvusResp<T> tMilvusResp = SearchRespConverter.convertSearchRespToMilvusResp(search, entityType);
return tMilvusResp;
}
}
}

View File

@ -0,0 +1,39 @@
package io.github.javpower.milvus.plus.core.mapper;
import io.github.javpower.milvus.plus.annotation.MilvusCollection;
import io.github.javpower.milvus.plus.cache.ConversionCache;
import io.github.javpower.milvus.plus.cache.MilvusCache;
import io.github.javpower.milvus.plus.core.conditions.LambdaSearchWrapper;
import io.github.javpower.milvus.plus.service.MilvusClient;
import io.github.javpower.milvus.plus.util.SpringUtils;
import lombok.extern.slf4j.Slf4j;
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
/**
* @author xgc
**/
@Slf4j
public class MilvusMapper<T> {
/**
* 创建搜索构建器实例
* @return 返回搜索构建器
*/
public LambdaSearchWrapper<T> lambda() {
// 获取实例化的类的类型参数T
Type type = ((ParameterizedType) this.getClass().getGenericSuperclass()).getActualTypeArguments()[0];
Class<T> entityType = (Class<T>) type;
// 从实体类上获取@MilvusCollection注解
MilvusCollection collectionAnnotation = entityType.getAnnotation(MilvusCollection.class);
if (collectionAnnotation == null) {
throw new IllegalStateException("Entity type " + entityType.getName() + " is not annotated with @MilvusCollection.");
}
ConversionCache<?, ?> conversionCache = MilvusCache.milvusCache.get(entityType);
String collectionName = conversionCache==null?null:conversionCache.getCollectionName();
// 使用SpringUtil获取MilvusClient实例
MilvusClient client = SpringUtils.getBean(MilvusClient.class);
// 使用注解中的集合名称创建LambdaSearchWrapper实例
return new LambdaSearchWrapper<>(collectionName,client,conversionCache,entityType);
}
}

View File

@ -2,18 +2,15 @@ package io.github.javpower.milvus.plus.service;
import io.milvus.v2.client.MilvusClientV2;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
/**
* @author xgc
**/
@Service
public class MilvusClient implements AutoCloseable {
public final MilvusClientV2 client;
public MilvusClient(MilvusClientV2 client) {
this.client = client;
}
@Autowired(required = false)
public MilvusClientV2 client;
@Override
public void close() throws InterruptedException {