unittest pass

This commit is contained in:
xj.lin 2019-03-25 19:20:35 +08:00
parent 0de0921ad8
commit 06ca60aa8d
6 changed files with 18 additions and 21 deletions

View File

@ -18,7 +18,6 @@ class Scheduler(metaclass=Singleton):
assert k != 0
query_vectors = serialize.to_array(vectors)
return self.__scheduler(index_file_key, query_vectors, k)
@ -33,11 +32,12 @@ class Scheduler(metaclass=Singleton):
searcher = search_index.FaissSearch(index)
result_list.append(searcher.search_by_vectors(vectors, k))
index_data_list = index_data_key['index']
for key in index_data_list:
index = GetIndexData(key)
searcher = search_index.FaissSearch(index)
result_list.append(searcher.search_by_vectors(vectors, k))
if 'index' in index_data_key:
index_data_list = index_data_key['index']
for key in index_data_list:
index = GetIndexData(key)
searcher = search_index.FaissSearch(index)
result_list.append(searcher.search_by_vectors(vectors, k))
if len(result_list) == 1:
return result_list[0].vectors

View File

@ -12,8 +12,6 @@ logger = logging.getLogger(__name__)
class TestVectorEngine:
def setup_class(self):
self.__vector = [1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8]
self.__vector_2 = [1.2, 2.2, 3.3, 4.5, 5.5, 6.6, 7.8, 8.8]
self.__query_vector = [[1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8],[1.2, 2.2, 3.3, 4.5, 5.5, 6.6, 7.8, 8.8]]
self.__limit = 1
@ -50,23 +48,23 @@ class TestVectorEngine:
assert code == VectorEngine.GROUP_NOT_EXIST
# Add vector for exist group
code = VectorEngine.AddVector('test_group', self.__vector_2)
code = VectorEngine.AddVector('test_group', self.__vector)
assert code == VectorEngine.SUCCESS_CODE
# Add vector for exist group
code = VectorEngine.AddVector('test_group', self.__vector_2)
code = VectorEngine.AddVector('test_group', self.__vector)
assert code == VectorEngine.SUCCESS_CODE
# Add vector for exist group
code = VectorEngine.AddVector('test_group', self.__vector_2)
code = VectorEngine.AddVector('test_group', self.__vector)
assert code == VectorEngine.SUCCESS_CODE
# Add vector for exist group
code = VectorEngine.AddVector('test_group', self.__vector_2)
code = VectorEngine.AddVector('test_group', self.__vector)
assert code == VectorEngine.SUCCESS_CODE
# Check search vector interface
code, vector_id = VectorEngine.SearchVector('test_group', self.__query_vector, self.__limit)
code, vector_id = VectorEngine.SearchVector('test_group', self.__vector, self.__limit)
assert code == VectorEngine.SUCCESS_CODE
assert vector_id == 0

View File

@ -43,8 +43,8 @@ class TestViews:
def test_vector(self, test_client):
dimension = {"dimension": 10}
resp = test_client.post('/vector/group/6', data=json.dumps(dimension))
dimension = {"dimension": 8}
resp = test_client.post('/vector/group/6', data=json.dumps(dimension), headers = TestViews.HEADERS)
assert resp.status_code == 200
assert self.loads(resp)['code'] == 0

View File

@ -148,7 +148,9 @@ class VectorEngine(object):
index_map['dimension'] = group.dimension
scheduler_instance = Scheduler()
result = scheduler_instance.Search(index_map, vector, limit)
vectors = []
vectors.append(vector)
result = scheduler_instance.Search(index_map, vectors, limit)
vector_id = 0

View File

@ -26,7 +26,7 @@ class VectorSearch(Resource):
def __init__(self):
self.__parser = reqparse.RequestParser()
self.__parser.add_argument('vector', type=float, action='append', location=['json'])
self.__parser.add_argument('limit', type=int, action='append', location=['json'])
self.__parser.add_argument('limit', type=int, location=['json'])
def get(self, group_id):
args = self.__parser.parse_args()
@ -51,7 +51,7 @@ class Group(Resource):
def __init__(self):
self.__parser = reqparse.RequestParser()
self.__parser.add_argument('group_id', type=str)
self.__parser.add_argument('dimension', type=int, action='append', location=['json'])
self.__parser.add_argument('dimension', type=int, location=['json'])
def post(self, group_id):
args = self.__parser.parse_args()

View File

@ -1,3 +0,0 @@
from engine.controller import scheduler
# scheduler.Scheduler.Search()