diff --git a/README.cn.md b/README.cn.md index 1863ada..5c67388 100644 --- a/README.cn.md +++ b/README.cn.md @@ -82,15 +82,25 @@ public class Face { } ``` ``` -public static void main(String[] args) { - MilvusWrapper wrapper=new MilvusWrapper(); +@Component +public class FaceMilvusMapper extends MilvusMapper { + +} + +@Component +public class ApplicationRunnerTest implements ApplicationRunner { + @Autowired + private FaceMilvusMapper mapper; + @Override + public void run(ApplicationArguments args) throws Exception { List vector = Lists.newArrayList(0.1f,0.2f,0.3f); - MilvusResp resp = wrapper.lambda() + MilvusResp resp = mapper.lambda() .eq(Face::getPersonId,1l) - .addVector(vector) + .vector(vector) + .limit(100l) .query(); } - +} ``` diff --git a/README.md b/README.md index 5571796..c2ad08f 100644 --- a/README.md +++ b/README.md @@ -82,15 +82,25 @@ public class Face { } ``` ``` -public static void main(String[] args) { - MilvusWrapper wrapper=new MilvusWrapper(); +@Component +public class FaceMilvusMapper extends MilvusMapper { + +} + +@Component +public class ApplicationRunnerTest implements ApplicationRunner { + @Autowired + private FaceMilvusMapper mapper; + @Override + public void run(ApplicationArguments args) throws Exception { List vector = Lists.newArrayList(0.1f,0.2f,0.3f); - MilvusResp resp = wrapper.lambda() + MilvusResp resp = mapper.lambda() .eq(Face::getPersonId,1l) - .addVector(vector) + .vector(vector) + .limit(100l) .query(); } - +} ``` diff --git a/milvus-demo/src/main/java/io/github/javpower/milvus/demo/ApplicationRunnerTest.java b/milvus-demo/src/main/java/io/github/javpower/milvus/demo/ApplicationRunnerTest.java new file mode 100644 index 0000000..6366922 --- /dev/null +++ b/milvus-demo/src/main/java/io/github/javpower/milvus/demo/ApplicationRunnerTest.java @@ -0,0 +1,27 @@ +package io.github.javpower.milvus.demo; + +import com.google.common.collect.Lists; +import io.github.javpower.milvus.demo.model.Face; +import io.github.javpower.milvus.demo.test.FaceMilvusMapper; +import io.github.javpower.milvus.plus.model.MilvusResp; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.ApplicationArguments; +import org.springframework.boot.ApplicationRunner; +import org.springframework.stereotype.Component; + +import java.util.List; + +@Component +public class ApplicationRunnerTest implements ApplicationRunner { + @Autowired + private FaceMilvusMapper mapper; + @Override + public void run(ApplicationArguments args) throws Exception { + List vector = Lists.newArrayList(0.1f,0.2f,0.3f); + MilvusResp resp = mapper.lambda() + .eq(Face::getPersonId,1l) + .vector(vector) + .limit(100l) + .query(); + } +} \ No newline at end of file diff --git a/milvus-demo/src/main/java/io/github/javpower/milvus/demo/test/FaceMilvusMapper.java b/milvus-demo/src/main/java/io/github/javpower/milvus/demo/test/FaceMilvusMapper.java new file mode 100644 index 0000000..570ac0e --- /dev/null +++ b/milvus-demo/src/main/java/io/github/javpower/milvus/demo/test/FaceMilvusMapper.java @@ -0,0 +1,10 @@ +package io.github.javpower.milvus.demo.test; + +import io.github.javpower.milvus.demo.model.Face; +import io.github.javpower.milvus.plus.core.mapper.MilvusMapper; +import org.springframework.stereotype.Component; + +@Component +public class FaceMilvusMapper extends MilvusMapper { + +} diff --git a/milvus-demo/src/main/java/io/github/javpower/milvus/demo/test/TestWrapper.java b/milvus-demo/src/main/java/io/github/javpower/milvus/demo/test/TestWrapper.java deleted file mode 100644 index 6c931ef..0000000 --- a/milvus-demo/src/main/java/io/github/javpower/milvus/demo/test/TestWrapper.java +++ /dev/null @@ -1,21 +0,0 @@ -//package io.github.javpower.milvus.demo.test; -// -//import com.google.common.collect.Lists; -//import io.github.javpower.milvus.plus.core.conditions.MilvusWrapper; -//import io.github.javpower.milvus.plus.model.MilvusResp; -// -//import java.util.List; -// -//public class TestWrapper { -// public static void main(String[] args) { -// MilvusWrapper wrapper=new MilvusWrapper(); -// List vector = Lists.newArrayList(0.1f,0.2f,0.3f); -// MilvusResp resp = wrapper.lambda() -// .eq(Face::getPersonId,1l) -// .addVector(vector) -// .query(); -// } -// -// -// -//} diff --git a/milvus-demo/src/main/resources/application.yml b/milvus-demo/src/main/resources/application.yml index 3d53b32..6520f6b 100644 --- a/milvus-demo/src/main/resources/application.yml +++ b/milvus-demo/src/main/resources/application.yml @@ -2,4 +2,5 @@ server: port: 8131 milvus: uri: localhost:8999 - token: sss \ No newline at end of file + token: sss + enable: false \ No newline at end of file diff --git a/milvus-plus-boot-starter/src/main/java/io/github/javpower/milvus/plus/config/MilvusCollectionConfig.java b/milvus-plus-boot-starter/src/main/java/io/github/javpower/milvus/plus/config/MilvusCollectionConfig.java index c06ede3..4a7646d 100644 --- a/milvus-plus-boot-starter/src/main/java/io/github/javpower/milvus/plus/config/MilvusCollectionConfig.java +++ b/milvus-plus-boot-starter/src/main/java/io/github/javpower/milvus/plus/config/MilvusCollectionConfig.java @@ -19,7 +19,7 @@ public class MilvusCollectionConfig implements ApplicationRunner { private final ApplicationContext applicationContext; private final MilvusCollectionService milvusCollectionService; - @Autowired + @Autowired(required = false) public MilvusCollectionConfig(ApplicationContext applicationContext, MilvusCollectionService milvusCollectionService) { this.applicationContext = applicationContext; this.milvusCollectionService = milvusCollectionService; @@ -33,6 +33,8 @@ public class MilvusCollectionConfig implements ApplicationRunner { .map(applicationContext::getType) .toArray(Class[]::new); // 调用业务处理服务 - milvusCollectionService.performBusinessLogic(Arrays.asList(annotatedClasses)); + if(milvusCollectionService!=null){ + milvusCollectionService.performBusinessLogic(Arrays.asList(annotatedClasses)); + } } } \ No newline at end of file diff --git a/milvus-plus-boot-starter/src/main/java/io/github/javpower/milvus/plus/config/MilvusConfig.java b/milvus-plus-boot-starter/src/main/java/io/github/javpower/milvus/plus/config/MilvusConfig.java index e5671d4..fab88be 100644 --- a/milvus-plus-boot-starter/src/main/java/io/github/javpower/milvus/plus/config/MilvusConfig.java +++ b/milvus-plus-boot-starter/src/main/java/io/github/javpower/milvus/plus/config/MilvusConfig.java @@ -2,7 +2,6 @@ package io.github.javpower.milvus.plus.config; import io.milvus.v2.client.ConnectConfig; import io.milvus.v2.client.MilvusClientV2; -import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; /** @@ -10,10 +9,17 @@ import org.springframework.context.annotation.Configuration; **/ @Configuration public class MilvusConfig { - @Autowired - private MilvusProperties properties; + private final MilvusProperties properties; + + public MilvusConfig(MilvusProperties properties) { + this.properties = properties; + } + @Bean public MilvusClientV2 milvusClientV2() { + if(!properties.getEnable()){ + return null; + } ConnectConfig connectConfig = ConnectConfig.builder() .uri(properties.getUri()) .token(properties.getToken()) diff --git a/milvus-plus-boot-starter/src/main/java/io/github/javpower/milvus/plus/config/MilvusProperties.java b/milvus-plus-boot-starter/src/main/java/io/github/javpower/milvus/plus/config/MilvusProperties.java index 991a75b..c690096 100644 --- a/milvus-plus-boot-starter/src/main/java/io/github/javpower/milvus/plus/config/MilvusProperties.java +++ b/milvus-plus-boot-starter/src/main/java/io/github/javpower/milvus/plus/config/MilvusProperties.java @@ -10,6 +10,8 @@ import org.springframework.stereotype.Component; @ConfigurationProperties(prefix = "milvus") @Component public class MilvusProperties { + + private Boolean enable; private String uri; private String token; } \ No newline at end of file diff --git a/milvus-plus-boot-starter/src/main/java/io/github/javpower/milvus/plus/core/FieldFunction.java b/milvus-plus-boot-starter/src/main/java/io/github/javpower/milvus/plus/core/FieldFunction.java index a2f64c9..8b8777f 100644 --- a/milvus-plus-boot-starter/src/main/java/io/github/javpower/milvus/plus/core/FieldFunction.java +++ b/milvus-plus-boot-starter/src/main/java/io/github/javpower/milvus/plus/core/FieldFunction.java @@ -1,8 +1,171 @@ package io.github.javpower.milvus.plus.core; -/** - * @author xgc - **/ + +import io.github.javpower.milvus.plus.annotation.MilvusField; +import org.apache.commons.lang3.StringUtils; + +import java.io.Serializable; +import java.lang.invoke.SerializedLambda; +import java.lang.reflect.Field; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.util.function.Function; + @FunctionalInterface -public interface FieldFunction { - R apply(T entity); +public interface FieldFunction extends Function, Serializable { + + //默认配置 + String defaultSplit = ""; + Integer defaultToType = 0; + + default String getFieldName() { + String methodName = getMethodName(); + if (methodName.startsWith("get")) { + methodName = methodName.substring(3); + } + return changeFirstCharCase(methodName,false); + } + + /** + * 获取实体类的字段名称(实体声明的字段名称) + */ + default String getFieldNameLine() { + return getFieldName(this, defaultSplit); + } + + /** + * 获取实体类的字段名称 + */ + + default String getFieldName(FieldFunction fn) { + return getFieldName(fn, defaultSplit, defaultToType); + } + /** + * 获取实体类的字段名称 + * + * @param split 分隔符,多个字母自定义分隔符 + */ + default String getFieldName(FieldFunction fn, String split) { + return getFieldName(fn, split, defaultToType); + } + + /** + * 获取实体类的字段名称 + * + * @param split 分隔符,多个字母自定义分隔符 + * @param toType 转换方式,多个字母以大小写方式返回 0.不做转换 1.大写 2.小写 + */ + default String getFieldName(FieldFunction fn, String split, Integer toType) { + SerializedLambda serializedLambda = getSerializedLambdaOne(fn); + + // 从lambda信息取出method、field、class等 + String fieldName = serializedLambda.getImplMethodName().substring("get".length()); + fieldName = fieldName.replaceFirst(fieldName.charAt(0) + "", (fieldName.charAt(0) + "").toLowerCase()); + Field field; + try { + field = Class.forName(serializedLambda.getImplClass().replace("/", ".")).getDeclaredField(fieldName); + } catch (ClassNotFoundException | NoSuchFieldException e) { + throw new RuntimeException(e); + } + // 从field取出字段名 + MilvusField collectionField = field.getAnnotation(MilvusField.class); + if (collectionField != null && StringUtils.isNotBlank(collectionField.name())) { + return collectionField.name(); + }else { + //0.不做转换 1.大写 2.小写 + switch (toType) { + case 1: + return fieldName.replaceAll("[A-Z]", split + "$0").toUpperCase(); + case 2: + return fieldName.replaceAll("[A-Z]", split + "$0").toLowerCase(); + default: + return fieldName.replaceAll("[A-Z]", split + "$0"); + } + + } + + } + + + default String getMethodName() { + return getSerializedLambda().getImplMethodName(); + } + + default Class getFieldClass() { + return getReturnType(); + } + + default SerializedLambda getSerializedLambda() { + Method method; + try { + method = getClass().getDeclaredMethod("writeReplace"); + } catch (NoSuchMethodException e) { + throw new RuntimeException(e); + } + method.setAccessible(true); + try { + return (SerializedLambda) method.invoke(this); + } catch (IllegalAccessException | InvocationTargetException e) { + throw new RuntimeException(e); + } + } + + default SerializedLambda getSerializedLambdaOne(FieldFunction fn) { + // 从function取出序列化方法 + Method writeReplaceMethod; + try { + writeReplaceMethod = fn.getClass().getDeclaredMethod("writeReplace"); + } catch (NoSuchMethodException e) { + throw new RuntimeException(e); + } + + // 从序列化方法取出序列化的lambda信息 + boolean isAccessible = writeReplaceMethod.isAccessible(); + writeReplaceMethod.setAccessible(true); + SerializedLambda serializedLambda; + try { + serializedLambda = (SerializedLambda) writeReplaceMethod.invoke(fn); + } catch (IllegalAccessException | InvocationTargetException e) { + throw new RuntimeException(e); + } + writeReplaceMethod.setAccessible(isAccessible); + return serializedLambda; + } + + default Class getReturnType() { + SerializedLambda lambda = getSerializedLambda(); + Class className; + try { + className = Class.forName(lambda.getImplClass().replace("/", ".")); + } catch (ClassNotFoundException e) { + throw new RuntimeException(e); + } + Method method; + try { + method = className.getMethod(getMethodName()); + } catch (NoSuchMethodException e) { + throw new RuntimeException(e); + } + return method.getReturnType(); + } + default String changeFirstCharCase(String str, boolean capitalize) { + if (str == null || str.isEmpty()) { + return str; + } + char baseChar = str.charAt(0); + char updatedChar; + if (capitalize) { + updatedChar = Character.toUpperCase(baseChar); + } else { + updatedChar = Character.toLowerCase(baseChar); + } + + if (baseChar == updatedChar) { + return str; + } else { + char[] chars = str.toCharArray(); + chars[0] = updatedChar; + return new String(chars); + } + } + } \ No newline at end of file diff --git a/milvus-plus-boot-starter/src/main/java/io/github/javpower/milvus/plus/core/conditions/LambdaSearchWrapper.java b/milvus-plus-boot-starter/src/main/java/io/github/javpower/milvus/plus/core/conditions/LambdaSearchWrapper.java new file mode 100644 index 0000000..2587553 --- /dev/null +++ b/milvus-plus-boot-starter/src/main/java/io/github/javpower/milvus/plus/core/conditions/LambdaSearchWrapper.java @@ -0,0 +1,357 @@ +package io.github.javpower.milvus.plus.core.conditions; + +import com.alibaba.fastjson.JSON; +import io.github.javpower.milvus.plus.cache.ConversionCache; +import io.github.javpower.milvus.plus.converter.SearchRespConverter; +import io.github.javpower.milvus.plus.core.FieldFunction; +import io.github.javpower.milvus.plus.model.MilvusResp; +import io.github.javpower.milvus.plus.service.MilvusClient; +import io.milvus.exception.MilvusException; +import io.milvus.v2.common.ConsistencyLevel; +import io.milvus.v2.service.vector.request.SearchReq; +import io.milvus.v2.service.vector.response.SearchResp; +import lombok.Data; +import lombok.extern.slf4j.Slf4j; + +import java.util.ArrayList; +import java.util.List; +import java.util.stream.Collectors; + +/** + * 搜索构建器内部类,用于构建搜索请求 + */ +@Data +@Slf4j +public class LambdaSearchWrapper { + private ConversionCache conversionCache; + private Class entityType; + private String collectionName; + private String annsField; + private int topK; + private List filters = new ArrayList<>(); + private List> vectors = new ArrayList<>(); + private long offset; + private long limit; + private int roundDecimal; + private String searchParams; + private long guaranteeTimestamp; + private ConsistencyLevel consistencyLevel; + private boolean ignoreGrowing; + private MilvusClient client; + + public LambdaSearchWrapper(String collectionName, MilvusClient client,ConversionCache conversionCache,Class entityType) { + this.collectionName = collectionName; + this.client = client; + this.conversionCache=conversionCache; + this.entityType=entityType; + } + + public LambdaSearchWrapper() { + + } + + // addVector + public LambdaSearchWrapper addVector(List vector) { + vectors.add(vector); + return this; + } + public LambdaSearchWrapper vector(List vector) { + vectors.add(vector); + return this; + } + + // Common comparison operations + public LambdaSearchWrapper eq(String fieldName, Object value) { + return addFilter(fieldName, "==", value); + } + + public LambdaSearchWrapper ne(String fieldName, Object value) { + return addFilter(fieldName, "!=", value); + } + + public LambdaSearchWrapper gt(String fieldName, Object value) { + return addFilter(fieldName, ">", value); + } + + public LambdaSearchWrapper ge(String fieldName, Object value) { + return addFilter(fieldName, ">=", value); + } + + public LambdaSearchWrapper lt(String fieldName, Object value) { + return addFilter(fieldName, "<", value); + } + + public LambdaSearchWrapper le(String fieldName, Object value) { + return addFilter(fieldName, "<=", value); + } + + // Range operation + public LambdaSearchWrapper between(String fieldName, Object start, Object end) { + String filter = String.format("%s >= %s && %s <= %s", fieldName, convertValue(start), fieldName, convertValue(end)); + filters.add(filter); + return this; + } + + // Null check + public LambdaSearchWrapper isNull(String fieldName) { + filters.add(fieldName + " == null"); + return this; + } + + public LambdaSearchWrapper isNotNull(String fieldName) { + filters.add(fieldName + " != null"); + return this; + } + + // In operator + public LambdaSearchWrapper in(String fieldName, List values) { + String valueList = values.stream() + .map(this::convertValue) + .collect(Collectors.joining(", ", "[", "]")); + filters.add(fieldName + " in " + valueList); + return this; + } + + // Like operator + public LambdaSearchWrapper like(String fieldName, String value) { + filters.add(fieldName + " like '%" + value + "%'"); + return this; + } + + // JSON array operations + public LambdaSearchWrapper jsonContains(String fieldName, Object value) { + filters.add("JSON_CONTAINS(" + fieldName + ", " + convertValue(value) + ")"); + return this; + } + + public LambdaSearchWrapper jsonContainsAll(String fieldName, List values) { + String valueList = convertValues(values); + filters.add("JSON_CONTAINS_ALL(" + fieldName + ", " + valueList + ")"); + return this; + } + + public LambdaSearchWrapper jsonContainsAny(String fieldName, List values) { + String valueList = convertValues(values); + filters.add("JSON_CONTAINS_ANY(" + fieldName + ", " + valueList + ")"); + return this; + } + + // Array operations + public LambdaSearchWrapper arrayContains(String fieldName, Object value) { + filters.add("ARRAY_CONTAINS(" + fieldName + ", " + convertValue(value) + ")"); + return this; + } + + public LambdaSearchWrapper arrayContainsAll(String fieldName, List values) { + String valueList = convertValues(values); + filters.add("ARRAY_CONTAINS_ALL(" + fieldName + ", " + valueList + ")"); + return this; + } + + public LambdaSearchWrapper arrayContainsAny(String fieldName, List values) { + String valueList = convertValues(values); + filters.add("ARRAY_CONTAINS_ANY(" + fieldName + ", " + valueList + ")"); + return this; + } + + public LambdaSearchWrapper arrayLength(String fieldName, int length) { + filters.add(fieldName + ".length() == " + length); + return this; + } + + + + public LambdaSearchWrapper eq(FieldFunction fieldName, Object value) { + return addFilter(fieldName, "==", value); + } + + public LambdaSearchWrapper ne(FieldFunction fieldName, Object value) { + return addFilter(fieldName, "!=", value); + } + + public LambdaSearchWrapper gt(FieldFunction fieldName, Object value) { + return addFilter(fieldName, ">", value); + } + + public LambdaSearchWrapper ge(FieldFunction fieldName, Object value) { + return addFilter(fieldName, ">=", value); + } + + public LambdaSearchWrapper lt(FieldFunction fieldName, Object value) { + return addFilter(fieldName, "<", value); + } + + public LambdaSearchWrapper le(FieldFunction fieldName, Object value) { + return addFilter(fieldName, "<=", value); + } + + // Range operation + public LambdaSearchWrapper between(FieldFunction fieldName, Object start, Object end) { + String fn = getFieldName(fieldName); + String filter = String.format("%s >= %s && %s <= %s", fn, convertValue(start), fn, convertValue(end)); + filters.add(filter); + return this; + } + + // Null check + public LambdaSearchWrapper isNull(FieldFunction fieldName) { + String fn = getFieldName(fieldName); + filters.add(fn + " == null"); + return this; + } + + public LambdaSearchWrapper isNotNull(FieldFunction fieldName) { + String fn = getFieldName(fieldName); + filters.add(fn + " != null"); + return this; + } + + // In operator + public LambdaSearchWrapper in(FieldFunction fieldName, List values) { + String fn = getFieldName(fieldName); + String valueList = values.stream() + .map(this::convertValue) + .collect(Collectors.joining(", ", "[", "]")); + filters.add(fn + " in " + valueList); + return this; + } + + // Like operator + public LambdaSearchWrapper like(FieldFunction fieldName, String value) { + String fn = getFieldName(fieldName); + filters.add(fn + " like '%" + value + "%'"); + return this; + } + + // JSON array operations + public LambdaSearchWrapper jsonContains(FieldFunction fieldName, Object value) { + String fn = getFieldName(fieldName); + filters.add("JSON_CONTAINS(" + fn + ", " + convertValue(value) + ")"); + return this; + } + + public LambdaSearchWrapper jsonContainsAll(FieldFunction fieldName, List values) { + String fn = getFieldName(fieldName); + String valueList = convertValues(values); + filters.add("JSON_CONTAINS_ALL(" + fn + ", " + valueList + ")"); + return this; + } + + public LambdaSearchWrapper jsonContainsAny(FieldFunction fieldName, List values) { + String fn = getFieldName(fieldName); + String valueList = convertValues(values); + filters.add("JSON_CONTAINS_ANY(" + fn + ", " + valueList + ")"); + return this; + } + + // Array operations + public LambdaSearchWrapper arrayContains(FieldFunction fieldName, Object value) { + String fn = getFieldName(fieldName); + filters.add("ARRAY_CONTAINS(" + fn + ", " + convertValue(value) + ")"); + return this; + } + + public LambdaSearchWrapper arrayContainsAll(FieldFunction fieldName, List values) { + String fn = getFieldName(fieldName); + String valueList = convertValues(values); + filters.add("ARRAY_CONTAINS_ALL(" + fn + ", " + valueList + ")"); + return this; + } + public LambdaSearchWrapper arrayContainsAny(FieldFunction fieldName, List values) { + String fn = getFieldName(fieldName); + String valueList = convertValues(values); + filters.add("ARRAY_CONTAINS_ANY(" + fn + ", " + valueList + ")"); + return this; + } + + public LambdaSearchWrapper arrayLength(FieldFunction fieldName, int length) { + String fn = getFieldName(fieldName); + filters.add(fn + ".length() == " + length); + return this; + } + + // Logic operations + public LambdaSearchWrapper and(LambdaSearchWrapper other) { + filters.add("(" + String.join(" && ", other.filters) + ")"); + return this; + } + + public LambdaSearchWrapper or(LambdaSearchWrapper other) { + filters.add("(" + String.join(" || ", other.filters) + ")"); + return this; + } + + public LambdaSearchWrapper not() { + filters.add("not (" + String.join(" && ", filters) + ")"); + return this; + } + + public LambdaSearchWrapper limit(Long limit) { + this.setLimit(limit); + return this; + } + public LambdaSearchWrapper topK(Integer topK) { + this.setTopK(topK); + return this; + } + + + // Helper methods + private String convertValue(Object value) { + if (value instanceof String) { + return "'" + value.toString().replace("'", "\\'") + "'"; + } + return value.toString(); + } + + private String convertValues(List values) { + return values.stream() + .map(this::convertValue) + .collect(Collectors.joining(", ", "[", "]")); + } + + private LambdaSearchWrapper addFilter(String fieldName, String op, Object value) { + filters.add(fieldName + " " + op + " " + convertValue(value)); + return this; + } + private LambdaSearchWrapper addFilter(FieldFunction fieldFunction, String op, Object value) { + String fieldName = getFieldName(fieldFunction); + filters.add(fieldName + " " + op + " " + convertValue(value)); + return this; + } + private String getFieldName(FieldFunction fieldFunction) { + return fieldFunction.getFieldName(fieldFunction); + } + + /** + * 构建完整的搜索请求 + * @return 搜索请求对象 + */ + private SearchReq build() { + SearchReq.SearchReqBuilder builder = SearchReq.builder() + .collectionName(collectionName) + .annsField(annsField) + .topK(topK); + if (!vectors.isEmpty()) { + builder.data(vectors); + } + String filterStr = filters.stream().collect(Collectors.joining(" && ")); + if (filterStr != null && !filterStr.isEmpty()) { + builder.filter(filterStr); + } + // Set other parameters as needed + return builder.build(); + } + + /** + * 执行搜索 + * @return 搜索响应对象 + */ + public MilvusResp query() throws MilvusException { + SearchReq searchReq = build(); + log.info("build query param-->{}", JSON.toJSONString(searchReq)); + SearchResp search = client.client.search(searchReq); + MilvusResp tMilvusResp = SearchRespConverter.convertSearchRespToMilvusResp(search, entityType); + return tMilvusResp; + } +} \ No newline at end of file diff --git a/milvus-plus-boot-starter/src/main/java/io/github/javpower/milvus/plus/core/conditions/MilvusWrapper.java b/milvus-plus-boot-starter/src/main/java/io/github/javpower/milvus/plus/core/conditions/MilvusWrapper.java deleted file mode 100644 index 179717e..0000000 --- a/milvus-plus-boot-starter/src/main/java/io/github/javpower/milvus/plus/core/conditions/MilvusWrapper.java +++ /dev/null @@ -1,380 +0,0 @@ -package io.github.javpower.milvus.plus.core.conditions; - -import io.github.javpower.milvus.plus.annotation.MilvusCollection; -import io.github.javpower.milvus.plus.cache.ConversionCache; -import io.github.javpower.milvus.plus.cache.FieldFunctionCache; -import io.github.javpower.milvus.plus.cache.MilvusCache; -import io.github.javpower.milvus.plus.cache.PropertyCache; -import io.github.javpower.milvus.plus.converter.SearchRespConverter; -import io.github.javpower.milvus.plus.core.FieldFunction; -import io.github.javpower.milvus.plus.model.MilvusResp; -import io.github.javpower.milvus.plus.service.MilvusClient; -import io.github.javpower.milvus.plus.util.SpringUtils; -import io.milvus.exception.MilvusException; -import io.milvus.v2.common.ConsistencyLevel; -import io.milvus.v2.service.vector.request.SearchReq; -import io.milvus.v2.service.vector.response.SearchResp; -import lombok.Data; - -import java.lang.reflect.ParameterizedType; -import java.lang.reflect.Type; -import java.util.ArrayList; -import java.util.List; -import java.util.stream.Collectors; -/** - * @author xgc - **/ -public class MilvusWrapper { - - - /** - * 创建搜索构建器实例 - * @param collectionName 集合名称 - * @return 返回搜索构建器 - */ - public LambdaSearchWrapper lambda() { - // 获取实例化的类的类型参数T - Type type = ((ParameterizedType) this.getClass().getGenericSuperclass()).getActualTypeArguments()[0]; - Class entityType = (Class) type; - // 从实体类上获取@MilvusCollection注解 - MilvusCollection collectionAnnotation = entityType.getAnnotation(MilvusCollection.class); - if (collectionAnnotation == null) { - throw new IllegalStateException("Entity type " + entityType.getName() + " is not annotated with @MilvusCollection."); - } - ConversionCache conversionCache = MilvusCache.milvusCache.get(entityType); - String collectionName = conversionCache.getCollectionName(); - // 使用SpringUtil获取MilvusClient实例 - MilvusClient client = SpringUtils.getBean(MilvusClient.class); - // 使用注解中的集合名称创建LambdaSearchWrapper实例 - return new LambdaSearchWrapper<>(collectionName, client,conversionCache,entityType); - } - - /** - * 搜索构建器内部类,用于构建搜索请求 - */ - @Data - public static class LambdaSearchWrapper { - private ConversionCache conversionCache; - private Class entityType; - private String collectionName; - private String annsField; - private int topK; - private List filters = new ArrayList<>(); - private List> vectors = new ArrayList<>(); - private long offset; - private long limit; - private int roundDecimal; - private String searchParams; - private long guaranteeTimestamp; - private ConsistencyLevel consistencyLevel; - private boolean ignoreGrowing; - private MilvusClient client; - - public LambdaSearchWrapper(String collectionName, MilvusClient client,ConversionCache conversionCache,Class entityType) { - this.collectionName = collectionName; - this.client = client; - this.conversionCache=conversionCache; - this.entityType=entityType; - } - - public LambdaSearchWrapper() { - - } - - // addVector - - public LambdaSearchWrapper addVector(List vector) { - vectors.add(vector); - return this; - } - - // Common comparison operations - public LambdaSearchWrapper eq(String fieldName, Object value) { - return addFilter(fieldName, "==", value); - } - - public LambdaSearchWrapper ne(String fieldName, Object value) { - return addFilter(fieldName, "!=", value); - } - - public LambdaSearchWrapper gt(String fieldName, Object value) { - return addFilter(fieldName, ">", value); - } - - public LambdaSearchWrapper ge(String fieldName, Object value) { - return addFilter(fieldName, ">=", value); - } - - public LambdaSearchWrapper lt(String fieldName, Object value) { - return addFilter(fieldName, "<", value); - } - - public LambdaSearchWrapper le(String fieldName, Object value) { - return addFilter(fieldName, "<=", value); - } - - // Range operation - public LambdaSearchWrapper between(String fieldName, Object start, Object end) { - String filter = String.format("%s >= %s && %s <= %s", fieldName, convertValue(start), fieldName, convertValue(end)); - filters.add(filter); - return this; - } - - // Null check - public LambdaSearchWrapper isNull(String fieldName) { - filters.add(fieldName + " == null"); - return this; - } - - public LambdaSearchWrapper isNotNull(String fieldName) { - filters.add(fieldName + " != null"); - return this; - } - - // In operator - public LambdaSearchWrapper in(String fieldName, List values) { - String valueList = values.stream() - .map(this::convertValue) - .collect(Collectors.joining(", ", "[", "]")); - filters.add(fieldName + " in " + valueList); - return this; - } - - // Like operator - public LambdaSearchWrapper like(String fieldName, String value) { - filters.add(fieldName + " like '%" + value + "%'"); - return this; - } - - // JSON array operations - public LambdaSearchWrapper jsonContains(String fieldName, Object value) { - filters.add("JSON_CONTAINS(" + fieldName + ", " + convertValue(value) + ")"); - return this; - } - - public LambdaSearchWrapper jsonContainsAll(String fieldName, List values) { - String valueList = convertValues(values); - filters.add("JSON_CONTAINS_ALL(" + fieldName + ", " + valueList + ")"); - return this; - } - - public LambdaSearchWrapper jsonContainsAny(String fieldName, List values) { - String valueList = convertValues(values); - filters.add("JSON_CONTAINS_ANY(" + fieldName + ", " + valueList + ")"); - return this; - } - - // Array operations - public LambdaSearchWrapper arrayContains(String fieldName, Object value) { - filters.add("ARRAY_CONTAINS(" + fieldName + ", " + convertValue(value) + ")"); - return this; - } - - public LambdaSearchWrapper arrayContainsAll(String fieldName, List values) { - String valueList = convertValues(values); - filters.add("ARRAY_CONTAINS_ALL(" + fieldName + ", " + valueList + ")"); - return this; - } - - public LambdaSearchWrapper arrayContainsAny(String fieldName, List values) { - String valueList = convertValues(values); - filters.add("ARRAY_CONTAINS_ANY(" + fieldName + ", " + valueList + ")"); - return this; - } - - public LambdaSearchWrapper arrayLength(String fieldName, int length) { - filters.add(fieldName + ".length() == " + length); - return this; - } - - - - public LambdaSearchWrapper eq(FieldFunction fieldName, Object value) { - return addFilter(fieldName, "==", value); - } - - public LambdaSearchWrapper ne(FieldFunction fieldName, Object value) { - return addFilter(fieldName, "!=", value); - } - - public LambdaSearchWrapper gt(FieldFunction fieldName, Object value) { - return addFilter(fieldName, ">", value); - } - - public LambdaSearchWrapper ge(FieldFunction fieldName, Object value) { - return addFilter(fieldName, ">=", value); - } - - public LambdaSearchWrapper lt(FieldFunction fieldName, Object value) { - return addFilter(fieldName, "<", value); - } - - public LambdaSearchWrapper le(FieldFunction fieldName, Object value) { - return addFilter(fieldName, "<=", value); - } - - // Range operation - public LambdaSearchWrapper between(FieldFunction fieldName, Object start, Object end) { - String fn = getFieldName(fieldName); - String filter = String.format("%s >= %s && %s <= %s", fn, convertValue(start), fn, convertValue(end)); - filters.add(filter); - return this; - } - - // Null check - public LambdaSearchWrapper isNull(FieldFunction fieldName) { - String fn = getFieldName(fieldName); - filters.add(fn + " == null"); - return this; - } - - public LambdaSearchWrapper isNotNull(FieldFunction fieldName) { - String fn = getFieldName(fieldName); - filters.add(fn + " != null"); - return this; - } - - // In operator - public LambdaSearchWrapper in(FieldFunction fieldName, List values) { - String fn = getFieldName(fieldName); - String valueList = values.stream() - .map(this::convertValue) - .collect(Collectors.joining(", ", "[", "]")); - filters.add(fn + " in " + valueList); - return this; - } - - // Like operator - public LambdaSearchWrapper like(FieldFunction fieldName, String value) { - String fn = getFieldName(fieldName); - filters.add(fn + " like '%" + value + "%'"); - return this; - } - - // JSON array operations - public LambdaSearchWrapper jsonContains(FieldFunction fieldName, Object value) { - String fn = getFieldName(fieldName); - filters.add("JSON_CONTAINS(" + fn + ", " + convertValue(value) + ")"); - return this; - } - - public LambdaSearchWrapper jsonContainsAll(FieldFunction fieldName, List values) { - String fn = getFieldName(fieldName); - String valueList = convertValues(values); - filters.add("JSON_CONTAINS_ALL(" + fn + ", " + valueList + ")"); - return this; - } - - public LambdaSearchWrapper jsonContainsAny(FieldFunction fieldName, List values) { - String fn = getFieldName(fieldName); - String valueList = convertValues(values); - filters.add("JSON_CONTAINS_ANY(" + fn + ", " + valueList + ")"); - return this; - } - - // Array operations - public LambdaSearchWrapper arrayContains(FieldFunction fieldName, Object value) { - String fn = getFieldName(fieldName); - filters.add("ARRAY_CONTAINS(" + fn + ", " + convertValue(value) + ")"); - return this; - } - - public LambdaSearchWrapper arrayContainsAll(FieldFunction fieldName, List values) { - String fn = getFieldName(fieldName); - String valueList = convertValues(values); - filters.add("ARRAY_CONTAINS_ALL(" + fn + ", " + valueList + ")"); - return this; - } - public LambdaSearchWrapper arrayContainsAny(FieldFunction fieldName, List values) { - String fn = getFieldName(fieldName); - String valueList = convertValues(values); - filters.add("ARRAY_CONTAINS_ANY(" + fn + ", " + valueList + ")"); - return this; - } - - public LambdaSearchWrapper arrayLength(FieldFunction fieldName, int length) { - String fn = getFieldName(fieldName); - filters.add(fn + ".length() == " + length); - return this; - } - - // Logic operations - public LambdaSearchWrapper and(LambdaSearchWrapper other) { - filters.add("(" + String.join(" && ", other.filters) + ")"); - return this; - } - - public LambdaSearchWrapper or(LambdaSearchWrapper other) { - filters.add("(" + String.join(" || ", other.filters) + ")"); - return this; - } - - public LambdaSearchWrapper not() { - filters.add("not (" + String.join(" && ", filters) + ")"); - return this; - } - - // Helper methods - private String convertValue(Object value) { - if (value instanceof String) { - return "'" + value.toString().replace("'", "\\'") + "'"; - } - return value.toString(); - } - - private String convertValues(List values) { - return values.stream() - .map(this::convertValue) - .collect(Collectors.joining(", ", "[", "]")); - } - - private LambdaSearchWrapper addFilter(String fieldName, String op, Object value) { - filters.add(fieldName + " " + op + " " + convertValue(value)); - return this; - } - private LambdaSearchWrapper addFilter(FieldFunction fieldFunction, String op, Object value) { - String fieldName = getFieldName(fieldFunction); - filters.add(fieldName + " " + op + " " + convertValue(value)); - return this; - } - private String getFieldName(FieldFunction fieldFunction) { - FieldFunctionCache fieldFunctionCache = conversionCache.getFieldFunctionCache(); - String fn = fieldFunctionCache.getFieldName(fieldFunction); - PropertyCache propertyCache = conversionCache.getPropertyCache(); - String fieldName = propertyCache.functionToPropertyMap.get(fn); - return fieldName; - } - - /** - * 构建完整的搜索请求 - * @return 搜索请求对象 - */ - private SearchReq build() { - SearchReq.SearchReqBuilder builder = SearchReq.builder() - .collectionName(collectionName) - .annsField(annsField) - .topK(topK); - if (!vectors.isEmpty()) { - builder.data(vectors); - } - String filterStr = filters.stream().collect(Collectors.joining(" && ")); - if (filterStr != null && !filterStr.isEmpty()) { - builder.filter(filterStr); - } - // Set other parameters as needed - return builder.build(); - } - - /** - * 执行搜索 - * @return 搜索响应对象 - * @throws MilvusException 如果搜索执行失败 - */ - public MilvusResp query() throws MilvusException { - SearchReq searchReq = build(); - SearchResp search = client.client.search(searchReq); - MilvusResp tMilvusResp = SearchRespConverter.convertSearchRespToMilvusResp(search, entityType); - return tMilvusResp; - } - } -} \ No newline at end of file diff --git a/milvus-plus-boot-starter/src/main/java/io/github/javpower/milvus/plus/core/mapper/MilvusMapper.java b/milvus-plus-boot-starter/src/main/java/io/github/javpower/milvus/plus/core/mapper/MilvusMapper.java new file mode 100644 index 0000000..6406834 --- /dev/null +++ b/milvus-plus-boot-starter/src/main/java/io/github/javpower/milvus/plus/core/mapper/MilvusMapper.java @@ -0,0 +1,39 @@ +package io.github.javpower.milvus.plus.core.mapper; + +import io.github.javpower.milvus.plus.annotation.MilvusCollection; +import io.github.javpower.milvus.plus.cache.ConversionCache; +import io.github.javpower.milvus.plus.cache.MilvusCache; +import io.github.javpower.milvus.plus.core.conditions.LambdaSearchWrapper; +import io.github.javpower.milvus.plus.service.MilvusClient; +import io.github.javpower.milvus.plus.util.SpringUtils; +import lombok.extern.slf4j.Slf4j; + +import java.lang.reflect.ParameterizedType; +import java.lang.reflect.Type; + +/** + * @author xgc + **/ +@Slf4j +public class MilvusMapper { + /** + * 创建搜索构建器实例 + * @return 返回搜索构建器 + */ + public LambdaSearchWrapper lambda() { + // 获取实例化的类的类型参数T + Type type = ((ParameterizedType) this.getClass().getGenericSuperclass()).getActualTypeArguments()[0]; + Class entityType = (Class) type; + // 从实体类上获取@MilvusCollection注解 + MilvusCollection collectionAnnotation = entityType.getAnnotation(MilvusCollection.class); + if (collectionAnnotation == null) { + throw new IllegalStateException("Entity type " + entityType.getName() + " is not annotated with @MilvusCollection."); + } + ConversionCache conversionCache = MilvusCache.milvusCache.get(entityType); + String collectionName = conversionCache==null?null:conversionCache.getCollectionName(); + // 使用SpringUtil获取MilvusClient实例 + MilvusClient client = SpringUtils.getBean(MilvusClient.class); + // 使用注解中的集合名称创建LambdaSearchWrapper实例 + return new LambdaSearchWrapper<>(collectionName,client,conversionCache,entityType); + } +} \ No newline at end of file diff --git a/milvus-plus-boot-starter/src/main/java/io/github/javpower/milvus/plus/service/MilvusClient.java b/milvus-plus-boot-starter/src/main/java/io/github/javpower/milvus/plus/service/MilvusClient.java index a326738..478860a 100644 --- a/milvus-plus-boot-starter/src/main/java/io/github/javpower/milvus/plus/service/MilvusClient.java +++ b/milvus-plus-boot-starter/src/main/java/io/github/javpower/milvus/plus/service/MilvusClient.java @@ -2,18 +2,15 @@ package io.github.javpower.milvus.plus.service; import io.milvus.v2.client.MilvusClientV2; +import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Service; /** * @author xgc **/ @Service public class MilvusClient implements AutoCloseable { - - public final MilvusClientV2 client; - - public MilvusClient(MilvusClientV2 client) { - this.client = client; - } + @Autowired(required = false) + public MilvusClientV2 client; @Override public void close() throws InterruptedException {