mirror of
https://gitee.com/milvus-io/milvus.git
synced 2026-01-07 19:31:51 +08:00
unittest pass
This commit is contained in:
parent
0de0921ad8
commit
06ca60aa8d
@ -18,7 +18,6 @@ class Scheduler(metaclass=Singleton):
|
||||
assert k != 0
|
||||
|
||||
query_vectors = serialize.to_array(vectors)
|
||||
|
||||
return self.__scheduler(index_file_key, query_vectors, k)
|
||||
|
||||
|
||||
@ -33,11 +32,12 @@ class Scheduler(metaclass=Singleton):
|
||||
searcher = search_index.FaissSearch(index)
|
||||
result_list.append(searcher.search_by_vectors(vectors, k))
|
||||
|
||||
index_data_list = index_data_key['index']
|
||||
for key in index_data_list:
|
||||
index = GetIndexData(key)
|
||||
searcher = search_index.FaissSearch(index)
|
||||
result_list.append(searcher.search_by_vectors(vectors, k))
|
||||
if 'index' in index_data_key:
|
||||
index_data_list = index_data_key['index']
|
||||
for key in index_data_list:
|
||||
index = GetIndexData(key)
|
||||
searcher = search_index.FaissSearch(index)
|
||||
result_list.append(searcher.search_by_vectors(vectors, k))
|
||||
|
||||
if len(result_list) == 1:
|
||||
return result_list[0].vectors
|
||||
|
||||
@ -12,8 +12,6 @@ logger = logging.getLogger(__name__)
|
||||
class TestVectorEngine:
|
||||
def setup_class(self):
|
||||
self.__vector = [1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8]
|
||||
self.__vector_2 = [1.2, 2.2, 3.3, 4.5, 5.5, 6.6, 7.8, 8.8]
|
||||
self.__query_vector = [[1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8],[1.2, 2.2, 3.3, 4.5, 5.5, 6.6, 7.8, 8.8]]
|
||||
self.__limit = 1
|
||||
|
||||
|
||||
@ -50,23 +48,23 @@ class TestVectorEngine:
|
||||
assert code == VectorEngine.GROUP_NOT_EXIST
|
||||
|
||||
# Add vector for exist group
|
||||
code = VectorEngine.AddVector('test_group', self.__vector_2)
|
||||
code = VectorEngine.AddVector('test_group', self.__vector)
|
||||
assert code == VectorEngine.SUCCESS_CODE
|
||||
|
||||
# Add vector for exist group
|
||||
code = VectorEngine.AddVector('test_group', self.__vector_2)
|
||||
code = VectorEngine.AddVector('test_group', self.__vector)
|
||||
assert code == VectorEngine.SUCCESS_CODE
|
||||
|
||||
# Add vector for exist group
|
||||
code = VectorEngine.AddVector('test_group', self.__vector_2)
|
||||
code = VectorEngine.AddVector('test_group', self.__vector)
|
||||
assert code == VectorEngine.SUCCESS_CODE
|
||||
|
||||
# Add vector for exist group
|
||||
code = VectorEngine.AddVector('test_group', self.__vector_2)
|
||||
code = VectorEngine.AddVector('test_group', self.__vector)
|
||||
assert code == VectorEngine.SUCCESS_CODE
|
||||
|
||||
# Check search vector interface
|
||||
code, vector_id = VectorEngine.SearchVector('test_group', self.__query_vector, self.__limit)
|
||||
code, vector_id = VectorEngine.SearchVector('test_group', self.__vector, self.__limit)
|
||||
assert code == VectorEngine.SUCCESS_CODE
|
||||
assert vector_id == 0
|
||||
|
||||
|
||||
@ -43,8 +43,8 @@ class TestViews:
|
||||
|
||||
|
||||
def test_vector(self, test_client):
|
||||
dimension = {"dimension": 10}
|
||||
resp = test_client.post('/vector/group/6', data=json.dumps(dimension))
|
||||
dimension = {"dimension": 8}
|
||||
resp = test_client.post('/vector/group/6', data=json.dumps(dimension), headers = TestViews.HEADERS)
|
||||
assert resp.status_code == 200
|
||||
assert self.loads(resp)['code'] == 0
|
||||
|
||||
|
||||
@ -148,7 +148,9 @@ class VectorEngine(object):
|
||||
index_map['dimension'] = group.dimension
|
||||
|
||||
scheduler_instance = Scheduler()
|
||||
result = scheduler_instance.Search(index_map, vector, limit)
|
||||
vectors = []
|
||||
vectors.append(vector)
|
||||
result = scheduler_instance.Search(index_map, vectors, limit)
|
||||
|
||||
vector_id = 0
|
||||
|
||||
|
||||
@ -26,7 +26,7 @@ class VectorSearch(Resource):
|
||||
def __init__(self):
|
||||
self.__parser = reqparse.RequestParser()
|
||||
self.__parser.add_argument('vector', type=float, action='append', location=['json'])
|
||||
self.__parser.add_argument('limit', type=int, action='append', location=['json'])
|
||||
self.__parser.add_argument('limit', type=int, location=['json'])
|
||||
|
||||
def get(self, group_id):
|
||||
args = self.__parser.parse_args()
|
||||
@ -51,7 +51,7 @@ class Group(Resource):
|
||||
def __init__(self):
|
||||
self.__parser = reqparse.RequestParser()
|
||||
self.__parser.add_argument('group_id', type=str)
|
||||
self.__parser.add_argument('dimension', type=int, action='append', location=['json'])
|
||||
self.__parser.add_argument('dimension', type=int, location=['json'])
|
||||
|
||||
def post(self, group_id):
|
||||
args = self.__parser.parse_args()
|
||||
|
||||
@ -1,3 +0,0 @@
|
||||
from engine.controller import scheduler
|
||||
|
||||
# scheduler.Scheduler.Search()
|
||||
Loading…
x
Reference in New Issue
Block a user