!10 优化初始化model对象,路径扫描判空逻辑

Merge pull request !10 from xgc/dev
This commit is contained in:
xgc 2024-06-03 17:05:52 +00:00 committed by Gitee
commit 6db4e35fb9
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
7 changed files with 151 additions and 98 deletions

View File

@ -14,10 +14,7 @@ import lombok.Data;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import java.util.ArrayList; import java.util.*;
import java.util.Collections;
import java.util.List;
import java.util.Map;
/** /**
* 构建器内部类用于构建insert请求 * 构建器内部类用于构建insert请求
@ -87,10 +84,11 @@ public class LambdaInsertWrapper<T> extends AbstractChainWrapper<T> implements
return resp; return resp;
} }
public MilvusResp<InsertResp> insert(T ...t) throws MilvusException { public MilvusResp<InsertResp> insert(Iterator<T> iterator) throws MilvusException {
PropertyCache propertyCache = conversionCache.getPropertyCache(); PropertyCache propertyCache = conversionCache.getPropertyCache();
List<JSONObject> jsonObjects=new ArrayList<>(); List<JSONObject> jsonObjects=new ArrayList<>();
for (T t1 : t) { while (iterator.hasNext()) {
T t1 = iterator.next();
Map<String, Object> propertiesMap = getPropertiesMap(t1); Map<String, Object> propertiesMap = getPropertiesMap(t1);
JSONObject jsonObject=new JSONObject(); JSONObject jsonObject=new JSONObject();
for (Map.Entry<String, Object> entry : propertiesMap.entrySet()) { for (Map.Entry<String, Object> entry : propertiesMap.entrySet()) {
@ -100,7 +98,6 @@ public class LambdaInsertWrapper<T> extends AbstractChainWrapper<T> implements
jsonObject.put(tk,value); jsonObject.put(tk,value);
} }
jsonObjects.add(jsonObject); jsonObjects.add(jsonObject);
} }
return insert(jsonObjects); return insert(jsonObjects);
} }

View File

@ -18,6 +18,7 @@ import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Iterator;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
@ -404,26 +405,29 @@ public class LambdaUpdateWrapper<T> extends AbstractChainWrapper<T> implements
return resp; return resp;
} }
public MilvusResp<UpsertResp> updateById(T ...t) throws MilvusException { public MilvusResp<UpsertResp> updateById(Iterator<T> iterator) throws MilvusException {
PropertyCache propertyCache = conversionCache.getPropertyCache(); PropertyCache propertyCache = conversionCache.getPropertyCache();
String pk = CollectionToPrimaryCache.collectionToPrimary.get(collectionName); String pk = CollectionToPrimaryCache.collectionToPrimary.get(collectionName);
List<JSONObject> jsonObjects=new ArrayList<>(); List<JSONObject> jsonObjects = new ArrayList<>();
for (T t1 : t) { // 使用迭代器遍历可变参数
while (iterator.hasNext()) {
T t1 = iterator.next();
Map<String, Object> propertiesMap = getPropertiesMap(t1); Map<String, Object> propertiesMap = getPropertiesMap(t1);
JSONObject jsonObject=new JSONObject(); JSONObject jsonObject = new JSONObject();
for (Map.Entry<String, Object> entry : propertiesMap.entrySet()) { for (Map.Entry<String, Object> entry : propertiesMap.entrySet()) {
String key = entry.getKey(); String key = entry.getKey();
Object value = entry.getValue(); Object value = entry.getValue();
// 根据PropertyCache转换属性名
String tk = propertyCache.functionToPropertyMap.get(key); String tk = propertyCache.functionToPropertyMap.get(key);
jsonObject.put(tk,value); jsonObject.put(tk, value);
} }
// 检查是否包含主键
if (!jsonObject.containsKey(pk)) { if (!jsonObject.containsKey(pk)) {
throw new MilvusException("not find primary key",400); throw new MilvusException("not find primary key", 400);
} }
jsonObjects.add(jsonObject); jsonObjects.add(jsonObject);
} }
return update(jsonObjects); return update(jsonObjects);
} }
@Override @Override

View File

@ -15,7 +15,10 @@ import org.dromara.milvus.plus.model.vo.MilvusResult;
import java.io.Serializable; import java.io.Serializable;
import java.lang.reflect.ParameterizedType; import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type; import java.lang.reflect.Type;
import java.util.Collection;
import java.util.Iterator;
import java.util.List; import java.util.List;
import java.util.NoSuchElementException;
/** /**
* @author xgc * @author xgc
@ -58,6 +61,25 @@ public abstract class BaseMilvusMapper<T>{
public LambdaInsertWrapper<T> insertWrapper() { public LambdaInsertWrapper<T> insertWrapper() {
return lambda(new LambdaInsertWrapper<>()); return lambda(new LambdaInsertWrapper<>());
} }
private static class ArrayIterator<T> implements Iterator<T> {
private final T[] array;
private int index = 0;
public ArrayIterator(T[] array) {
this.array = array;
}
@Override
public boolean hasNext() {
return index < array.length;
}
@Override
public T next() {
if (!hasNext()) throw new NoSuchElementException();
return array[index++];
}
}
public MilvusResp<List<MilvusResult<T>>> getById(Serializable ... ids) { public MilvusResp<List<MilvusResult<T>>> getById(Serializable ... ids) {
@ -68,15 +90,25 @@ public abstract class BaseMilvusMapper<T>{
LambdaDeleteWrapper<T> lambda = deleteWrapper(); LambdaDeleteWrapper<T> lambda = deleteWrapper();
return lambda.removeById(ids); return lambda.removeById(ids);
} }
public MilvusResp<UpsertResp> updateById(T ... entity){
LambdaUpdateWrapper<T> lambda = updateWrapper();
return lambda.updateById(entity);
}
public MilvusResp<InsertResp> insert(T ... entity){ public MilvusResp<InsertResp> insert(T ... entity){
LambdaInsertWrapper<T> lambda = insertWrapper(); LambdaInsertWrapper<T> lambda = insertWrapper();
return lambda.insert(entity); Iterator<T> iterator = new ArrayIterator<>(entity);
return lambda.insert(iterator);
}
public MilvusResp<InsertResp> insert(Collection<T> entity){
LambdaInsertWrapper<T> lambda = insertWrapper();
return lambda.insert(entity.iterator());
} }
public MilvusResp<UpsertResp> updateById(T... entity) {
LambdaUpdateWrapper<T> lambda = updateWrapper();
Iterator<T> iterator = new ArrayIterator<>(entity);
return lambda.updateById(iterator);
}
public MilvusResp<UpsertResp> updateById(Collection<T> entity) {
LambdaUpdateWrapper<T> lambda = updateWrapper();
return lambda.updateById(entity.iterator());
}
/** /**
* 创建通用构建器实例 * 创建通用构建器实例
@ -99,7 +131,4 @@ public abstract class BaseMilvusMapper<T>{
wrapper.init(collectionName,client, conversionCache, entityType); wrapper.init(collectionName,client, conversionCache, entityType);
return wrapper.wrapper(); return wrapper.wrapper();
} }
} }

View File

@ -1,12 +1,13 @@
package org.dromara.milvus.plus.service; package org.dromara.milvus.plus.service;
import org.dromara.milvus.plus.annotation.MilvusCollection;
import org.dromara.milvus.plus.cache.CollectionToPrimaryCache;
import org.dromara.milvus.plus.model.MilvusProperties;
import io.milvus.v2.client.ConnectConfig; import io.milvus.v2.client.ConnectConfig;
import io.milvus.v2.client.MilvusClientV2; import io.milvus.v2.client.MilvusClientV2;
import io.milvus.v2.service.collection.request.ReleaseCollectionReq; import io.milvus.v2.service.collection.request.ReleaseCollectionReq;
import lombok.Setter;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.dromara.milvus.plus.annotation.MilvusCollection;
import org.dromara.milvus.plus.cache.CollectionToPrimaryCache;
import org.dromara.milvus.plus.model.MilvusProperties;
import org.springframework.core.io.Resource; import org.springframework.core.io.Resource;
import org.springframework.core.io.support.PathMatchingResourcePatternResolver; import org.springframework.core.io.support.PathMatchingResourcePatternResolver;
import org.springframework.core.io.support.ResourcePatternResolver; import org.springframework.core.io.support.ResourcePatternResolver;
@ -15,23 +16,19 @@ import org.springframework.core.type.classreading.MetadataReader;
import org.springframework.core.type.classreading.MetadataReaderFactory; import org.springframework.core.type.classreading.MetadataReaderFactory;
import org.springframework.util.ClassUtils; import org.springframework.util.ClassUtils;
import java.util.ArrayList; import java.util.*;
import java.util.List; import java.util.stream.Collectors;
import java.util.Set;
@Slf4j @Slf4j
public abstract class AbstractMilvusClientBuilder implements MilvusClientBuilder, ICMService { public abstract class AbstractMilvusClientBuilder implements MilvusClientBuilder, ICMService {
@Setter
protected MilvusProperties properties; protected MilvusProperties properties;
protected MilvusClientV2 client; protected MilvusClientV2 client;
private final static String CLASS="*.class"; private final static String CLASS = "*.class";
public void setProperties(MilvusProperties properties) {
this.properties = properties;
}
@Override @Override
public void initialize() { public void initialize() {
if (properties.isEnable()) { if (properties.isEnable()) {
@ -53,7 +50,7 @@ public abstract class AbstractMilvusClientBuilder implements MilvusClientBuilder
if (client != null) { if (client != null) {
//释放集合+释放client //释放集合+释放client
Set<String> co = CollectionToPrimaryCache.collectionToPrimary.keySet(); Set<String> co = CollectionToPrimaryCache.collectionToPrimary.keySet();
if(co.size()>0){ if (!co.isEmpty()) {
for (String name : co) { for (String name : co) {
ReleaseCollectionReq releaseCollectionReq = ReleaseCollectionReq.builder() ReleaseCollectionReq releaseCollectionReq = ReleaseCollectionReq.builder()
.collectionName(name) .collectionName(name)
@ -67,11 +64,12 @@ public abstract class AbstractMilvusClientBuilder implements MilvusClientBuilder
} }
public void handler(){ public void handler() {
if(client!=null){ if (Objects.isNull(client)) {
List<Class<?>> classes = getClass(properties.getPackages()); log.warn("initialize handler over!");
performBusinessLogic(classes);
} }
List<Class<?>> classes = getClass(properties.getPackages());
performBusinessLogic(classes);
} }
@Override @Override
@ -80,31 +78,36 @@ public abstract class AbstractMilvusClientBuilder implements MilvusClientBuilder
} }
//获取指定包下实体类 //获取指定包下实体类
private List<Class<?>> getClass(List<String> packages){ private List<Class<?>> getClass(List<String> packages) {
List<Class<?>> res=new ArrayList<>();
ResourcePatternResolver resourcePatternResolver = new PathMatchingResourcePatternResolver(); ResourcePatternResolver resourcePatternResolver = new PathMatchingResourcePatternResolver();
for (String pg : packages) { return Optional.ofNullable(packages)
String pattern = ResourcePatternResolver.CLASSPATH_ALL_URL_PREFIX + .orElseThrow(() -> new RuntimeException("model package is null, please configure the [packages] parameter"))
ClassUtils.convertClassNameToResourcePath(pg+".") + CLASS; .stream()
try { .map(pg -> {
Resource[] resources = resourcePatternResolver.getResources(pattern); List<Class<?>> res = new ArrayList<>();
MetadataReaderFactory readerFactory = new CachingMetadataReaderFactory(resourcePatternResolver); String pattern = ResourcePatternResolver.CLASSPATH_ALL_URL_PREFIX
for (Resource resource : resources) { + ClassUtils.convertClassNameToResourcePath(pg + ".") + CLASS;
MetadataReader reader = readerFactory.getMetadataReader(resource); try {
String classname = reader.getClassMetadata().getClassName(); Resource[] resources = resourcePatternResolver.getResources(pattern);
Class<?> clazz = Class.forName(classname); MetadataReaderFactory readerFactory = new CachingMetadataReaderFactory(resourcePatternResolver);
MilvusCollection annotation = clazz.getAnnotation(MilvusCollection.class); for (Resource resource : resources) {
if(annotation!=null){ MetadataReader reader = readerFactory.getMetadataReader(resource);
res.add(clazz); String classname = reader.getClassMetadata().getClassName();
Class<?> clazz = Class.forName(classname);
MilvusCollection annotation = clazz.getAnnotation(MilvusCollection.class);
if (annotation != null) {
res.add(clazz);
}
}
} catch (Exception e) {
throw new RuntimeException(e);
} }
} return res;
}catch (Exception e){ }).flatMap(Collection::stream)
e.printStackTrace(); .collect(Collectors.toList());
}
}
return res;
} }
//缓存+是否构建集合
//缓存 + 是否构建集合
public void performBusinessLogic(List<Class<?>> annotatedClasses) { public void performBusinessLogic(List<Class<?>> annotatedClasses) {
for (Class<?> milvusClass : annotatedClasses) { for (Class<?> milvusClass : annotatedClasses) {
createCollection(milvusClass); createCollection(milvusClass);

View File

@ -3,7 +3,20 @@ package org.dromara.milvus.plus.service;
import io.milvus.v2.client.MilvusClientV2; import io.milvus.v2.client.MilvusClientV2;
public interface MilvusClientBuilder { public interface MilvusClientBuilder {
/**
* 初始化
*/
void initialize(); void initialize();
/**
* 关闭客户端
*/
void close() throws InterruptedException; void close() throws InterruptedException;
/**
* 获取milvus客户端
*
* @return MilvusClientV2
*/
MilvusClientV2 getClient(); MilvusClientV2 getClient();
} }

View File

@ -1,18 +1,20 @@
package org.dromara.milvus.demo; package org.dromara.milvus.demo;
import com.alibaba.fastjson.JSONObject; import com.alibaba.fastjson.JSONObject;
import org.dromara.milvus.demo.model.Face;
import org.dromara.milvus.demo.test.FaceMilvusMapper;
import org.dromara.milvus.plus.model.vo.MilvusResp;
import org.dromara.milvus.plus.model.vo.MilvusResult;
import io.milvus.v2.service.vector.response.InsertResp; import io.milvus.v2.service.vector.response.InsertResp;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.dromara.milvus.demo.mapper.FaceMilvusMapper;
import org.dromara.milvus.demo.model.Face;
import org.dromara.milvus.plus.model.vo.MilvusResp;
import org.dromara.milvus.plus.model.vo.MilvusResult;
import org.springframework.boot.ApplicationArguments; import org.springframework.boot.ApplicationArguments;
import org.springframework.boot.ApplicationRunner; import org.springframework.boot.ApplicationRunner;
import org.springframework.stereotype.Component; import org.springframework.stereotype.Component;
import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.LongStream;
@Component @Component
@Slf4j @Slf4j
@ -24,53 +26,58 @@ public class ApplicationRunnerTest implements ApplicationRunner {
} }
@Override @Override
public void run(ApplicationArguments args){ public void run(ApplicationArguments args) {
face(); insertFace();
getByIdTest();
vectorQuery();
scalarQuery();
} }
private void face(){
Face face=new Face(); private void insertFace() {
List<Float> vector = new ArrayList<>(); List<Face> faces = LongStream.range(1, 10)
for (int i = 0; i < 128; i++) { .mapToObj(i -> {
vector.add((float) (Math.random() * 100)); // 这里仅作为示例使用随机数 Face faceTmp = new Face();
} faceTmp.setPersonId(i);
face.setPersonId(1l); List<Float> vectorTmp = IntStream.range(0, 128)
face.setFaceVector(vector); .mapToObj(j -> (float) (Math.random() * 100))
.collect(Collectors.toList());
faceTmp.setFaceVector(vectorTmp);
faceTmp.setPersonName(i % 2 == 0 ? "张三" + i : "李四" + i);
return faceTmp;
})
.collect(Collectors.toList());
//新增 //新增
List<Face> faces=new ArrayList<>(); MilvusResp<InsertResp> insert = mapper.insertWrapper()
for (int i = 1; i < 10 ;i++){ .partition("face_001")
Face face1=new Face(); .insert(faces.iterator());
face1.setPersonId(Long.valueOf(i));
List<Float> vector1 = new ArrayList<>();
for (int j = 0; j < 128; j++) {
vector1.add((float) (Math.random() * 100)); // 这里仅作为示例使用随机数
}
face1.setFaceVector(vector1);
if(i%2==0){
face1.setPersonName("张三"+i);
}else {
face1.setPersonName("李四"+i);
}
faces.add(face1);
}
MilvusResp<InsertResp> insert = mapper.insertWrapper().partition("face_001").insert(faces.toArray(new Face[0]));
log.info("insert--{}", JSONObject.toJSONString(insert)); log.info("insert--{}", JSONObject.toJSONString(insert));
//MilvusResp<InsertResp> insert = mapper.insert(faces.toArray(faces.toArray(new Face[0]))); log.info("insert--{}", JSONObject.toJSONString(insert)); }
public void getByIdTest() {
//id查询 //id查询
MilvusResp<List<MilvusResult<Face>>> query = mapper.getById(9l); MilvusResp<List<MilvusResult<Face>>> query = mapper.getById(9L);
log.info("query--getById---{}", JSONObject.toJSONString(query)); log.info("query--getById---{}", JSONObject.toJSONString(query));
}
public void vectorQuery() {
//向量查询 //向量查询
List<Float> vector = IntStream.range(0, 128)
.mapToObj(i -> (float) (Math.random() * 100))
.collect(Collectors.toList());
MilvusResp<List<MilvusResult<Face>>> query1 = mapper.queryWrapper() MilvusResp<List<MilvusResult<Face>>> query1 = mapper.queryWrapper()
.vector(Face::getFaceVector, vector) .vector(Face::getFaceVector, vector)
.like(Face::getPersonName, "张三") .like(Face::getPersonName, "张三")
.topK(3) .topK(3)
.query(); .query();
log.info("向量查询 query--queryWrapper---{}", JSONObject.toJSONString(query1)); log.info("向量查询 query--queryWrapper---{}", JSONObject.toJSONString(query1));
}
public void scalarQuery() {
//标量查询 //标量查询
MilvusResp<List<MilvusResult<Face>>> query2 = mapper.queryWrapper() MilvusResp<List<MilvusResult<Face>>> query2 = mapper.queryWrapper()
.eq(Face::getPersonId, 2L) .eq(Face::getPersonId, 2L)
.limit(3l) .limit(3L)
.query(); .query();
log.info("标量查询 query--queryWrapper---{}", JSONObject.toJSONString(query2)); log.info("标量查询 query--queryWrapper---{}", JSONObject.toJSONString(query2));
} }
} }

View File

@ -1,4 +1,4 @@
package org.dromara.milvus.demo.test; package org.dromara.milvus.demo.mapper;
import org.dromara.milvus.demo.model.Face; import org.dromara.milvus.demo.model.Face;
import org.dromara.milvus.plus.mapper.MilvusMapper; import org.dromara.milvus.plus.mapper.MilvusMapper;