diff --git a/mishards/factories.py b/mishards/factories.py index c4037fe2d7..52c0253b39 100644 --- a/mishards/factories.py +++ b/mishards/factories.py @@ -6,6 +6,7 @@ from factory.alchemy import SQLAlchemyModelFactory from faker import Faker from faker.providers import BaseProvider +from milvus.client.types import MetricType from mishards import db from mishards.models import Tables, TableFiles @@ -27,12 +28,12 @@ class TablesFactory(SQLAlchemyModelFactory): id = factory.Faker('random_number', digits=16, fix_len=True) table_id = factory.Faker('uuid4') - state = factory.Faker('random_element', elements=(0, 1, 2, 3)) + state = factory.Faker('random_element', elements=(0, 1)) dimension = factory.Faker('random_element', elements=(256, 512)) created_on = int(time.time()) index_file_size = 0 engine_type = factory.Faker('random_element', elements=(0, 1, 2, 3)) - metric_type = factory.Faker('random_element', elements=(0, 1)) + metric_type = factory.Faker('random_element', elements=(MetricType.L2, MetricType.IP)) nlist = 16384 diff --git a/mishards/service_handler.py b/mishards/service_handler.py index 113ec3ca20..e04965c12a 100644 --- a/mishards/service_handler.py +++ b/mishards/service_handler.py @@ -125,8 +125,9 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer): range_array = [self._range_to_date(r, metadata=metadata) for r in range_array] if range_array else None routing = {} + p_span = None if self.tracer.empty else context.get_active_span().context with self.tracer.start_span('get_routing', - child_of=context.get_active_span().context): + child_of=p_span): routing = self._get_routing_file_ids(table_id, range_array, metadata=metadata) logger.info('Routing: {}'.format(routing)) @@ -145,9 +146,10 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer): conn = self.query_conn(addr, metadata=metadata) start = time.time() span = kwargs.get('span', None) - span = span if span else context.get_active_span().context + span = span if span else (None if self.tracer.empty else context.get_active_span().context) + with self.tracer.start_span('search_{}'.format(addr), - child_of=context.get_active_span().context): + child_of=span): ret = conn.search_vectors_in_files(table_name=query_params['table_id'], file_ids=query_params['file_ids'], query_records=vectors, @@ -160,7 +162,7 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer): all_topk_results.append(ret) with self.tracer.start_span('do_search', - child_of=context.get_active_span().context) as span: + child_of=p_span) as span: with ThreadPoolExecutor(max_workers=workers) as pool: for addr, params in routing.items(): res = pool.submit(search, addr, params, vectors, topk, nprobe, span=span) @@ -171,9 +173,12 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer): reverse = table_meta.metric_type == Types.MetricType.IP with self.tracer.start_span('do_merge', - child_of=context.get_active_span().context): + child_of=p_span): return self._do_merge(all_topk_results, topk, reverse=reverse, metadata=metadata) + def _create_table(self, table_schema): + return self.connection().create_table(table_schema) + @mark_grpc_method def CreateTable(self, request, context): _status, _table_schema = Parser.parse_proto_TableSchema(request) @@ -183,10 +188,13 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer): logger.info('CreateTable {}'.format(_table_schema['table_name'])) - _status = self.connection().create_table(_table_schema) + _status = self._create_table(_table_schema) return status_pb2.Status(error_code=_status.code, reason=_status.message) + def _has_table(self, table_name, metadata=None): + return self.connection(metadata=metadata).has_table(table_name) + @mark_grpc_method def HasTable(self, request, context): _status, _table_name = Parser.parse_proto_TableName(request) @@ -199,15 +207,17 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer): logger.info('HasTable {}'.format(_table_name)) - _bool = self.connection(metadata={ - 'resp_class': milvus_pb2.BoolReply - }).has_table(_table_name) + _bool = self._has_table(_table_name, metadata={ + 'resp_class': milvus_pb2.BoolReply}) return milvus_pb2.BoolReply( status=status_pb2.Status(error_code=status_pb2.SUCCESS, reason="OK"), bool_reply=_bool ) + def _delete_table(self, table_name): + return self.connection().delete_table(table_name) + @mark_grpc_method def DropTable(self, request, context): _status, _table_name = Parser.parse_proto_TableName(request) @@ -217,10 +227,13 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer): logger.info('DropTable {}'.format(_table_name)) - _status = self.connection().delete_table(_table_name) + _status = self._delete_table(_table_name) return status_pb2.Status(error_code=_status.code, reason=_status.message) + def _create_index(self, table_name, index): + return self.connection().create_index(table_name, index) + @mark_grpc_method def CreateIndex(self, request, context): _status, unpacks = Parser.parse_proto_IndexParam(request) @@ -233,7 +246,7 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer): logger.info('CreateIndex {}'.format(_table_name)) # TODO: interface create_table incompleted - _status = self.connection().create_index(_table_name, _index) + _status = self._create_index(_table_name, _index) return status_pb2.Status(error_code=_status.code, reason=_status.message) @@ -298,7 +311,7 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer): logger.info('SearchVector takes: {}'.format(now - start)) topk_result_list = milvus_pb2.TopKQueryResultList( - status=status, + status=status_pb2.Status(error_code=status.error_code, reason=status.reason), topk_query_result=results ) return topk_result_list diff --git a/mishards/test_server.py b/mishards/test_server.py new file mode 100644 index 0000000000..e9a7c0d878 --- /dev/null +++ b/mishards/test_server.py @@ -0,0 +1,279 @@ +import logging +import pytest +import mock +import datetime +import random +import faker +import inspect +from milvus import Milvus +from milvus.client.types import Status, IndexType, MetricType +from milvus.client.Abstract import IndexParam, TableSchema +from milvus.grpc_gen import status_pb2, milvus_pb2 +from mishards import db, create_app, settings +from mishards.service_handler import ServiceHandler +from mishards.grpc_utils.grpc_args_parser import GrpcArgsParser as Parser +from mishards.factories import TableFilesFactory, TablesFactory, TableFiles, Tables + +logger = logging.getLogger(__name__) + +OK = Status(code=Status.SUCCESS, message='Success') +BAD = Status(code=Status.PERMISSION_DENIED, message='Fail') + + +@pytest.mark.usefixtures('started_app') +class TestServer: + def client(self, port): + m = Milvus() + m.connect(host='localhost', port=port) + return m + + def test_server_start(self, started_app): + assert started_app.conn_mgr.metas.get('WOSERVER') == settings.WOSERVER + + def test_cmd(self, started_app): + ServiceHandler._get_server_version = mock.MagicMock(return_value=(OK, + '')) + status, _ = self.client(started_app.port).server_version() + assert status.OK() + + Parser.parse_proto_Command = mock.MagicMock(return_value=(BAD, 'cmd')) + status, _ = self.client(started_app.port).server_version() + assert not status.OK() + + def test_drop_index(self, started_app): + table_name = inspect.currentframe().f_code.co_name + ServiceHandler._drop_index = mock.MagicMock(return_value=OK) + status = self.client(started_app.port).drop_index(table_name) + assert status.OK() + + Parser.parse_proto_TableName = mock.MagicMock( + return_value=(BAD, table_name)) + status = self.client(started_app.port).drop_index(table_name) + assert not status.OK() + + def test_describe_index(self, started_app): + table_name = inspect.currentframe().f_code.co_name + index_type = IndexType.FLAT + nlist = 1 + index_param = IndexParam(table_name=table_name, + index_type=index_type, + nlist=nlist) + Parser.parse_proto_TableName = mock.MagicMock( + return_value=(OK, table_name)) + ServiceHandler._describe_index = mock.MagicMock( + return_value=(OK, index_param)) + status, ret = self.client(started_app.port).describe_index(table_name) + assert status.OK() + assert ret._table_name == index_param._table_name + + Parser.parse_proto_TableName = mock.MagicMock( + return_value=(BAD, table_name)) + status, _ = self.client(started_app.port).describe_index(table_name) + assert not status.OK() + + def test_preload(self, started_app): + table_name = inspect.currentframe().f_code.co_name + + Parser.parse_proto_TableName = mock.MagicMock( + return_value=(OK, table_name)) + ServiceHandler._preload_table = mock.MagicMock(return_value=OK) + status = self.client(started_app.port).preload_table(table_name) + assert status.OK() + + Parser.parse_proto_TableName = mock.MagicMock( + return_value=(BAD, table_name)) + status = self.client(started_app.port).preload_table(table_name) + assert not status.OK() + + def test_delete_by_range(self, started_app): + table_name = inspect.currentframe().f_code.co_name + + unpacked = table_name, datetime.datetime.today( + ), datetime.datetime.today() + + Parser.parse_proto_DeleteByRangeParam = mock.MagicMock( + return_value=(OK, unpacked)) + ServiceHandler._delete_by_range = mock.MagicMock(return_value=OK) + status = self.client(started_app.port).delete_vectors_by_range( + *unpacked) + assert status.OK() + + Parser.parse_proto_DeleteByRangeParam = mock.MagicMock( + return_value=(BAD, unpacked)) + status = self.client(started_app.port).delete_vectors_by_range( + *unpacked) + assert not status.OK() + + def test_count_table(self, started_app): + table_name = inspect.currentframe().f_code.co_name + count = random.randint(100, 200) + + Parser.parse_proto_TableName = mock.MagicMock( + return_value=(OK, table_name)) + ServiceHandler._count_table = mock.MagicMock(return_value=(OK, count)) + status, ret = self.client( + started_app.port).get_table_row_count(table_name) + assert status.OK() + assert ret == count + + Parser.parse_proto_TableName = mock.MagicMock( + return_value=(BAD, table_name)) + status, _ = self.client( + started_app.port).get_table_row_count(table_name) + assert not status.OK() + + def test_show_tables(self, started_app): + tables = ['t1', 't2'] + ServiceHandler._show_tables = mock.MagicMock(return_value=(OK, tables)) + status, ret = self.client(started_app.port).show_tables() + assert status.OK() + assert ret == tables + + def test_describe_table(self, started_app): + table_name = inspect.currentframe().f_code.co_name + dimension = 128 + nlist = 1 + table_schema = TableSchema(table_name=table_name, + index_file_size=100, + metric_type=MetricType.L2, + dimension=dimension) + Parser.parse_proto_TableName = mock.MagicMock( + return_value=(OK, table_schema.table_name)) + ServiceHandler._describe_table = mock.MagicMock( + return_value=(OK, table_schema)) + status, _ = self.client(started_app.port).describe_table(table_name) + assert status.OK() + + ServiceHandler._describe_table = mock.MagicMock( + return_value=(BAD, table_schema)) + status, _ = self.client(started_app.port).describe_table(table_name) + assert not status.OK() + + Parser.parse_proto_TableName = mock.MagicMock(return_value=(BAD, + 'cmd')) + status, ret = self.client(started_app.port).describe_table(table_name) + assert not status.OK() + + def test_insert(self, started_app): + table_name = inspect.currentframe().f_code.co_name + vectors = [[random.random() for _ in range(16)] for _ in range(10)] + ids = [random.randint(1000000, 20000000) for _ in range(10)] + ServiceHandler._add_vectors = mock.MagicMock(return_value=(OK, ids)) + status, ret = self.client(started_app.port).add_vectors( + table_name=table_name, records=vectors) + assert status.OK() + assert ids == ret + + def test_create_index(self, started_app): + table_name = inspect.currentframe().f_code.co_name + unpacks = table_name, None + Parser.parse_proto_IndexParam = mock.MagicMock(return_value=(OK, + unpacks)) + ServiceHandler._create_index = mock.MagicMock(return_value=OK) + status = self.client( + started_app.port).create_index(table_name=table_name) + assert status.OK() + + Parser.parse_proto_IndexParam = mock.MagicMock(return_value=(BAD, + None)) + status = self.client( + started_app.port).create_index(table_name=table_name) + assert not status.OK() + + def test_drop_table(self, started_app): + table_name = inspect.currentframe().f_code.co_name + + Parser.parse_proto_TableName = mock.MagicMock( + return_value=(OK, table_name)) + ServiceHandler._delete_table = mock.MagicMock(return_value=OK) + status = self.client( + started_app.port).delete_table(table_name=table_name) + assert status.OK() + + Parser.parse_proto_TableName = mock.MagicMock( + return_value=(BAD, table_name)) + status = self.client( + started_app.port).delete_table(table_name=table_name) + assert not status.OK() + + def test_has_table(self, started_app): + table_name = inspect.currentframe().f_code.co_name + + Parser.parse_proto_TableName = mock.MagicMock( + return_value=(OK, table_name)) + ServiceHandler._has_table = mock.MagicMock(return_value=True) + has = self.client(started_app.port).has_table(table_name=table_name) + assert has + + Parser.parse_proto_TableName = mock.MagicMock( + return_value=(BAD, table_name)) + has = self.client(started_app.port).has_table(table_name=table_name) + assert not has + + def test_create_table(self, started_app): + table_name = inspect.currentframe().f_code.co_name + dimension = 128 + table_schema = dict(table_name=table_name, + index_file_size=100, + metric_type=MetricType.L2, + dimension=dimension) + + ServiceHandler._create_table = mock.MagicMock(return_value=OK) + status = self.client(started_app.port).create_table(table_schema) + assert status.OK() + + Parser.parse_proto_TableSchema = mock.MagicMock(return_value=(BAD, + None)) + status = self.client(started_app.port).create_table(table_schema) + assert not status.OK() + + def random_data(self, n, dimension): + return [[random.random() for _ in range(dimension)] for _ in range(n)] + + def test_search(self, started_app): + table_name = inspect.currentframe().f_code.co_name + to_index_cnt = random.randint(10, 20) + table = TablesFactory(table_id=table_name, state=Tables.NORMAL) + to_index_files = TableFilesFactory.create_batch( + to_index_cnt, table=table, file_type=TableFiles.FILE_TYPE_TO_INDEX) + topk = random.randint(5, 10) + nq = random.randint(5, 10) + param = { + 'table_name': table_name, + 'query_records': self.random_data(nq, table.dimension), + 'top_k': topk, + 'nprobe': 2049 + } + + result = [ + milvus_pb2.TopKQueryResult(query_result_arrays=[ + milvus_pb2.QueryResult(id=i, distance=random.random()) + for i in range(topk) + ]) for i in range(nq) + ] + + mock_results = milvus_pb2.TopKQueryResultList(status=status_pb2.Status( + error_code=status_pb2.SUCCESS, reason="Success"), + topk_query_result=result) + + table_schema = TableSchema(table_name=table_name, + index_file_size=table.index_file_size, + metric_type=table.metric_type, + dimension=table.dimension) + + status, _ = self.client(started_app.port).search_vectors(**param) + assert status.code == Status.ILLEGAL_ARGUMENT + + param['nprobe'] = 2048 + Milvus.describe_table = mock.MagicMock(return_value=(BAD, + table_schema)) + status, ret = self.client(started_app.port).search_vectors(**param) + assert status.code == Status.TABLE_NOT_EXISTS + + Milvus.describe_table = mock.MagicMock(return_value=(OK, table_schema)) + Milvus.search_vectors_in_files = mock.MagicMock( + return_value=mock_results) + + status, ret = self.client(started_app.port).search_vectors(**param) + assert status.OK() + assert len(ret) == nq diff --git a/tracing/__init__.py b/tracing/__init__.py index 5014309a52..a1974e2204 100644 --- a/tracing/__init__.py +++ b/tracing/__init__.py @@ -1,6 +1,13 @@ +from contextlib import contextmanager + def empty_server_interceptor_decorator(target_server, interceptor): return target_server +@contextmanager +def EmptySpan(*args, **kwargs): + yield None + return + class Tracer: def __init__(self, tracer=None, @@ -13,11 +20,17 @@ class Tracer: def decorate(self, server): return self.server_decorator(server, self.interceptor) + @property + def empty(self): + return self.tracer is None + def close(self): self.tracer and self.tracer.close() def start_span(self, operation_name=None, child_of=None, references=None, tags=None, start_time=None, ignore_active_span=False): + if self.empty: + return EmptySpan() return self.tracer.start_span(operation_name, child_of, references, tags, start_time, ignore_active_span)