diff --git a/pyengine/engine/controller/scheduler.py b/pyengine/engine/controller/scheduler.py index 3eab1b49b3..275284da6b 100644 --- a/pyengine/engine/controller/scheduler.py +++ b/pyengine/engine/controller/scheduler.py @@ -18,7 +18,6 @@ class Scheduler(metaclass=Singleton): assert k != 0 query_vectors = serialize.to_array(vectors) - return self.__scheduler(index_file_key, query_vectors, k) @@ -33,11 +32,12 @@ class Scheduler(metaclass=Singleton): searcher = search_index.FaissSearch(index) result_list.append(searcher.search_by_vectors(vectors, k)) - index_data_list = index_data_key['index'] - for key in index_data_list: - index = GetIndexData(key) - searcher = search_index.FaissSearch(index) - result_list.append(searcher.search_by_vectors(vectors, k)) + if 'index' in index_data_key: + index_data_list = index_data_key['index'] + for key in index_data_list: + index = GetIndexData(key) + searcher = search_index.FaissSearch(index) + result_list.append(searcher.search_by_vectors(vectors, k)) if len(result_list) == 1: return result_list[0].vectors diff --git a/pyengine/engine/controller/tests/test_vector_engine.py b/pyengine/engine/controller/tests/test_vector_engine.py index 25eb5a9ed6..5bd957ecb5 100644 --- a/pyengine/engine/controller/tests/test_vector_engine.py +++ b/pyengine/engine/controller/tests/test_vector_engine.py @@ -12,8 +12,6 @@ logger = logging.getLogger(__name__) class TestVectorEngine: def setup_class(self): self.__vector = [1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8] - self.__vector_2 = [1.2, 2.2, 3.3, 4.5, 5.5, 6.6, 7.8, 8.8] - self.__query_vector = [[1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8],[1.2, 2.2, 3.3, 4.5, 5.5, 6.6, 7.8, 8.8]] self.__limit = 1 @@ -50,23 +48,23 @@ class TestVectorEngine: assert code == VectorEngine.GROUP_NOT_EXIST # Add vector for exist group - code = VectorEngine.AddVector('test_group', self.__vector_2) + code = VectorEngine.AddVector('test_group', self.__vector) assert code == VectorEngine.SUCCESS_CODE # Add vector for exist group - code = VectorEngine.AddVector('test_group', self.__vector_2) + code = VectorEngine.AddVector('test_group', self.__vector) assert code == VectorEngine.SUCCESS_CODE # Add vector for exist group - code = VectorEngine.AddVector('test_group', self.__vector_2) + code = VectorEngine.AddVector('test_group', self.__vector) assert code == VectorEngine.SUCCESS_CODE # Add vector for exist group - code = VectorEngine.AddVector('test_group', self.__vector_2) + code = VectorEngine.AddVector('test_group', self.__vector) assert code == VectorEngine.SUCCESS_CODE # Check search vector interface - code, vector_id = VectorEngine.SearchVector('test_group', self.__query_vector, self.__limit) + code, vector_id = VectorEngine.SearchVector('test_group', self.__vector, self.__limit) assert code == VectorEngine.SUCCESS_CODE assert vector_id == 0 diff --git a/pyengine/engine/controller/tests/test_views.py b/pyengine/engine/controller/tests/test_views.py index eb66454676..1f28cb1b94 100644 --- a/pyengine/engine/controller/tests/test_views.py +++ b/pyengine/engine/controller/tests/test_views.py @@ -43,8 +43,8 @@ class TestViews: def test_vector(self, test_client): - dimension = {"dimension": 10} - resp = test_client.post('/vector/group/6', data=json.dumps(dimension)) + dimension = {"dimension": 8} + resp = test_client.post('/vector/group/6', data=json.dumps(dimension), headers = TestViews.HEADERS) assert resp.status_code == 200 assert self.loads(resp)['code'] == 0 diff --git a/pyengine/engine/controller/vector_engine.py b/pyengine/engine/controller/vector_engine.py index 9a44b2c02b..50a7e98046 100644 --- a/pyengine/engine/controller/vector_engine.py +++ b/pyengine/engine/controller/vector_engine.py @@ -148,7 +148,9 @@ class VectorEngine(object): index_map['dimension'] = group.dimension scheduler_instance = Scheduler() - result = scheduler_instance.Search(index_map, vector, limit) + vectors = [] + vectors.append(vector) + result = scheduler_instance.Search(index_map, vectors, limit) vector_id = 0 diff --git a/pyengine/engine/controller/views.py b/pyengine/engine/controller/views.py index 3d596fb6d5..425b337d3c 100644 --- a/pyengine/engine/controller/views.py +++ b/pyengine/engine/controller/views.py @@ -26,7 +26,7 @@ class VectorSearch(Resource): def __init__(self): self.__parser = reqparse.RequestParser() self.__parser.add_argument('vector', type=float, action='append', location=['json']) - self.__parser.add_argument('limit', type=int, action='append', location=['json']) + self.__parser.add_argument('limit', type=int, location=['json']) def get(self, group_id): args = self.__parser.parse_args() @@ -51,7 +51,7 @@ class Group(Resource): def __init__(self): self.__parser = reqparse.RequestParser() self.__parser.add_argument('group_id', type=str) - self.__parser.add_argument('dimension', type=int, action='append', location=['json']) + self.__parser.add_argument('dimension', type=int, location=['json']) def post(self, group_id): args = self.__parser.parse_args() diff --git a/pyengine/engine/retrieval/tests/scheduler_test.py b/pyengine/engine/retrieval/tests/scheduler_test.py deleted file mode 100644 index 2644cbec02..0000000000 --- a/pyengine/engine/retrieval/tests/scheduler_test.py +++ /dev/null @@ -1,3 +0,0 @@ -from engine.controller import scheduler - -# scheduler.Scheduler.Search() \ No newline at end of file