diff --git a/mishards/grpc_utils/test_grpc.py b/mishards/grpc_utils/test_grpc.py index d8511c8d6c..314fccfe00 100644 --- a/mishards/grpc_utils/test_grpc.py +++ b/mishards/grpc_utils/test_grpc.py @@ -7,12 +7,12 @@ from milvus.grpc_gen import status_pb2, milvus_pb2 logger = logging.getLogger(__name__) -class TestTracer(opentracing.Tracer): +class FakeTracer(opentracing.Tracer): pass -class TestSpan(opentracing.Span): +class FakeSpan(opentracing.Span): def __init__(self, context, tracer, **kwargs): - super(TestSpan, self).__init__(tracer, context) + super(FakeSpan, self).__init__(tracer, context) self.reset() def set_tag(self, key, value): @@ -26,7 +26,7 @@ class TestSpan(opentracing.Span): self.logs = [] -class TestRpcInfo: +class FakeRpcInfo: def __init__(self, request, response): self.request = request self.response = response @@ -37,32 +37,32 @@ class TestGrpcUtils: request = 'request' OK = status_pb2.Status(error_code=status_pb2.SUCCESS, reason='Success') response = OK - rpc_info = TestRpcInfo(request=request, response=response) - span = TestSpan(context=None, tracer=TestTracer()) + rpc_info = FakeRpcInfo(request=request, response=response) + span = FakeSpan(context=None, tracer=FakeTracer()) span_deco = GrpcSpanDecorator() span_deco(span, rpc_info) assert len(span.logs) == 0 assert len(span.tags) == 0 response = milvus_pb2.BoolReply(status=OK, bool_reply=False) - rpc_info = TestRpcInfo(request=request, response=response) - span = TestSpan(context=None, tracer=TestTracer()) + rpc_info = FakeRpcInfo(request=request, response=response) + span = FakeSpan(context=None, tracer=FakeTracer()) span_deco = GrpcSpanDecorator() span_deco(span, rpc_info) assert len(span.logs) == 0 assert len(span.tags) == 0 response = 1 - rpc_info = TestRpcInfo(request=request, response=response) - span = TestSpan(context=None, tracer=TestTracer()) + rpc_info = FakeRpcInfo(request=request, response=response) + span = FakeSpan(context=None, tracer=FakeTracer()) span_deco = GrpcSpanDecorator() span_deco(span, rpc_info) assert len(span.logs) == 1 assert len(span.tags) == 1 response = 0 - rpc_info = TestRpcInfo(request=request, response=response) - span = TestSpan(context=None, tracer=TestTracer()) + rpc_info = FakeRpcInfo(request=request, response=response) + span = FakeSpan(context=None, tracer=FakeTracer()) span_deco = GrpcSpanDecorator() span_deco(span, rpc_info) assert len(span.logs) == 0 diff --git a/mishards/service_handler.py b/mishards/service_handler.py index 9d851ecfcb..113ec3ca20 100644 --- a/mishards/service_handler.py +++ b/mishards/service_handler.py @@ -237,13 +237,15 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer): return status_pb2.Status(error_code=_status.code, reason=_status.message) + def _add_vectors(self, param, metadata=None): + return self.connection(metadata=metadata).add_vectors(None, None, insert_param=param) + @mark_grpc_method def Insert(self, request, context): logger.info('Insert') # TODO: Ths SDK interface add_vectors() could update, add a key 'row_id_array' - _status, _ids = self.connection(metadata={ - 'resp_class': milvus_pb2.VectorIds - }).add_vectors(None, None, insert_param=request) + _status, _ids = self._add_vectors(metadata={ + 'resp_class': milvus_pb2.VectorIds}, param=request) return milvus_pb2.VectorIds( status=status_pb2.Status(error_code=_status.code, reason=_status.message), vector_id_array=_ids @@ -305,6 +307,9 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer): def SearchInFiles(self, request, context): raise NotImplemented() + def _describe_table(self, table_name, metadata=None): + return self.connection(metadata=metadata).describe_table(table_name) + @mark_grpc_method def DescribeTable(self, request, context): _status, _table_name = Parser.parse_proto_TableName(request) @@ -319,7 +324,7 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer): } logger.info('DescribeTable {}'.format(_table_name)) - _status, _table = self.connection(metadata=metadata).describe_table(_table_name) + _status, _table = self._describe_table(metadata=metadata, table_name=_table_name) if _status.OK(): return milvus_pb2.TableSchema( @@ -335,6 +340,9 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer): status=status_pb2.Status(error_code=_status.code, reason=_status.message), ) + def _count_table(self, table_name, metadata=None): + return self.connection(metadata=metadata).get_table_row_count(table_name) + @mark_grpc_method def CountTable(self, request, context): _status, _table_name = Parser.parse_proto_TableName(request) @@ -351,12 +359,16 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer): metadata = { 'resp_class': milvus_pb2.TableRowCount } - _status, _count = self.connection(metadata=metadata).get_table_row_count(_table_name) + _status, _count = self._count_table(_table_name, metadata=metadata) return milvus_pb2.TableRowCount( status=status_pb2.Status(error_code=_status.code, reason=_status.message), table_row_count=_count if isinstance(_count, int) else -1) + + def _get_server_version(self, metadata=None): + return self.connection(metadata=metadata).server_version() + @mark_grpc_method def Cmd(self, request, context): _status, _cmd = Parser.parse_proto_Command(request) @@ -364,7 +376,7 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer): if not _status.OK(): return milvus_pb2.StringReply( - status_pb2.Status(error_code=_status.code, reason=_status.message) + status=status_pb2.Status(error_code=_status.code, reason=_status.message) ) metadata = { @@ -372,7 +384,7 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer): } if _cmd == 'version': - _status, _reply = self.connection(metadata=metadata).server_version() + _status, _reply = self._get_server_version(metadata=metadata) else: _status, _reply = self.connection(metadata=metadata).server_status() @@ -381,19 +393,25 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer): string_reply=_reply ) + def _show_tables(self): + return self.connection(metadata=metadata).show_tables() + @mark_grpc_method def ShowTables(self, request, context): logger.info('ShowTables') metadata = { 'resp_class': milvus_pb2.TableName } - _status, _results = self.connection(metadata=metadata).show_tables() + _status, _results = self._show_tables() return milvus_pb2.TableNameList( status=status_pb2.Status(error_code=_status.code, reason=_status.message), table_names=_results ) + def _delete_by_range(self, table_name, start_date, end_date): + return self.connection().delete_vectors_by_range(table_name, start_date, end_date) + @mark_grpc_method def DeleteByRange(self, request, context): _status, unpacks = \ @@ -405,9 +423,12 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer): _table_name, _start_date, _end_date = unpacks logger.info('DeleteByRange {}: {} {}'.format(_table_name, _start_date, _end_date)) - _status = self.connection().delete_vectors_by_range(_table_name, _start_date, _end_date) + _status = self._delete_by_range(_table_name, _start_date, _end_date) return status_pb2.Status(error_code=_status.code, reason=_status.message) + def _preload_table(self, table_name): + return self.connection().preload_table(table_name) + @mark_grpc_method def PreloadTable(self, request, context): _status, _table_name = Parser.parse_proto_TableName(request) @@ -416,9 +437,12 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer): return status_pb2.Status(error_code=_status.code, reason=_status.message) logger.info('PreloadTable {}'.format(_table_name)) - _status = self.connection().preload_table(_table_name) + _status = self._preload_table(_table_name) return status_pb2.Status(error_code=_status.code, reason=_status.message) + def _describe_index(self, table_name, metadata=None): + return self.connection(metadata=metadata).describe_index(table_name) + @mark_grpc_method def DescribeIndex(self, request, context): _status, _table_name = Parser.parse_proto_TableName(request) @@ -433,13 +457,16 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer): } logger.info('DescribeIndex {}'.format(_table_name)) - _status, _index_param = self.connection(metadata=metadata).describe_index(_table_name) + _status, _index_param = self._describe_index(table_name=_table_name, metadata=metadata) _index = milvus_pb2.Index(index_type=_index_param._index_type, nlist=_index_param._nlist) return milvus_pb2.IndexParam(status=status_pb2.Status(error_code=_status.code, reason=_status.message), table_name=_table_name, index=_index) + def _drop_index(self, table_name): + return self.connection().drop_index(table_name) + @mark_grpc_method def DropIndex(self, request, context): _status, _table_name = Parser.parse_proto_TableName(request) @@ -448,5 +475,5 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer): return status_pb2.Status(error_code=_status.code, reason=_status.message) logger.info('DropIndex {}'.format(_table_name)) - _status = self.connection().drop_index(_table_name) + _status = self._drop_index(_table_name) return status_pb2.Status(error_code=_status.code, reason=_status.message) diff --git a/tracing/factory.py b/tracing/factory.py index 0c14d9d536..61cd75fcd6 100644 --- a/tracing/factory.py +++ b/tracing/factory.py @@ -12,12 +12,14 @@ logger = logging.getLogger(__name__) class TracerFactory: @classmethod def new_tracer(cls, tracer_type, tracer_config, span_decorator=None, **kwargs): + if not tracer_type: + return Tracer() config = tracer_config.TRACING_CONFIG service_name = tracer_config.TRACING_SERVICE_NAME validate=tracer_config.TRACING_VALIDATE - if not tracer_type: - tracer_type = 'jaeger' - config = tracer_config.DEFAULT_TRACING_CONFIG + # if not tracer_type: + # tracer_type = 'jaeger' + # config = tracer_config.DEFAULT_TRACING_CONFIG if tracer_type.lower() == 'jaeger': config = Config(config=config,