diff --git a/conftest.py b/conftest.py index d6c9f3acc7..1aba5b32cf 100644 --- a/conftest.py +++ b/conftest.py @@ -1,5 +1,6 @@ import logging import pytest +import grpc from mishards import settings, db, create_app logger = logging.getLogger(__name__) @@ -14,3 +15,12 @@ def app(request): yield app db.drop_all() + +@pytest.fixture +def started_app(app): + app.on_pre_run() + app.start(app.port) + + yield app + + app.stop() diff --git a/mishards/__init__.py b/mishards/__init__.py index 47d8adb6e3..4bd77d8c60 100644 --- a/mishards/__init__.py +++ b/mishards/__init__.py @@ -24,7 +24,7 @@ def create_app(testing_config=None): from tracing.factory import TracerFactory from mishards.grpc_utils import GrpcSpanDecorator - tracer = TracerFactory.new_tracer(settings.TRACING_TYPE, settings.TracingConfig, + tracer = TracerFactory.new_tracer(config.TRACING_TYPE, settings.TracingConfig, span_decorator=GrpcSpanDecorator()) grpc_server.init_app(conn_mgr=connect_mgr, tracer=tracer, discover=discover) diff --git a/mishards/connections.py b/mishards/connections.py index 22263e9e7e..7db271381c 100644 --- a/mishards/connections.py +++ b/mishards/connections.py @@ -18,7 +18,7 @@ class Connection: self.conn = Milvus() self.error_handlers = [] if not error_handlers else error_handlers self.on_retry_func = kwargs.get('on_retry_func', None) - self._connect() + # self._connect() def __str__(self): return 'Connection:name=\"{}\";uri=\"{}\"'.format(self.name, self.uri) diff --git a/mishards/grpc_utils/test_grpc.py b/mishards/grpc_utils/test_grpc.py index 068ee391e7..d8511c8d6c 100644 --- a/mishards/grpc_utils/test_grpc.py +++ b/mishards/grpc_utils/test_grpc.py @@ -57,7 +57,6 @@ class TestGrpcUtils: span = TestSpan(context=None, tracer=TestTracer()) span_deco = GrpcSpanDecorator() span_deco(span, rpc_info) - logger.error(span.logs) assert len(span.logs) == 1 assert len(span.tags) == 1 @@ -66,7 +65,6 @@ class TestGrpcUtils: span = TestSpan(context=None, tracer=TestTracer()) span_deco = GrpcSpanDecorator() span_deco(span, rpc_info) - logger.error(span.logs) assert len(span.logs) == 0 assert len(span.tags) == 0 diff --git a/mishards/main.py b/mishards/main.py index 3f69484ee4..c0d142607b 100644 --- a/mishards/main.py +++ b/mishards/main.py @@ -6,8 +6,7 @@ from mishards import (settings, create_app) def main(): - server = create_app( - settings.TestingConfig if settings.TESTING else settings.DefaultConfig) + server = create_app(settings.DefaultConfig) server.run(port=settings.SERVER_PORT) return 0 diff --git a/mishards/server.py b/mishards/server.py index feb2176e86..dcaacd0fbc 100644 --- a/mishards/server.py +++ b/mishards/server.py @@ -39,7 +39,7 @@ class Server: self.register_pre_run_handler(self.pre_run_handler) def pre_run_handler(self): - woserver = settings.WOSERVER if not settings.TESTING else settings.TESTING_WOSERVER + woserver = settings.WOSERVER url = urlparse(woserver) ip = socket.gethostbyname(url.hostname) socket.inet_pton(socket.AF_INET, ip) diff --git a/mishards/settings.py b/mishards/settings.py index 1982a508e7..c9b62717d4 100644 --- a/mishards/settings.py +++ b/mishards/settings.py @@ -43,10 +43,7 @@ elif SD_PROVIDER == 'Static': SD_PROVIDER_SETTINGS = StaticProviderSettings( hosts=env.list('SD_STATIC_HOSTS', [])) -TESTING = env.bool('TESTING', False) -TESTING_WOSERVER = env.str('TESTING_WOSERVER', 'tcp://127.0.0.1:19530') - -TRACING_TYPE = env.str('TRACING_TYPE', '') +# TESTING_WOSERVER = env.str('TESTING_WOSERVER', 'tcp://127.0.0.1:19530') class TracingConfig: @@ -64,19 +61,24 @@ class TracingConfig: }, 'logging': env.bool('TRACING_LOGGING', True) } + DEFAULT_TRACING_CONFIG = { + 'sampler': { + 'type': env.str('TRACING_SAMPLER_TYPE', 'const'), + 'param': env.str('TRACING_SAMPLER_PARAM', "0"), + } + } class DefaultConfig: SQLALCHEMY_DATABASE_URI = env.str('SQLALCHEMY_DATABASE_URI') SQL_ECHO = env.bool('SQL_ECHO', False) + TRACING_TYPE = env.str('TRACING_TYPE', '') -TESTING = env.bool('TESTING', False) -if TESTING: - - class TestingConfig(DefaultConfig): - SQLALCHEMY_DATABASE_URI = env.str('SQLALCHEMY_DATABASE_TEST_URI') - SQL_ECHO = env.bool('SQL_TEST_ECHO', False) +class TestingConfig(DefaultConfig): + SQLALCHEMY_DATABASE_URI = env.str('SQLALCHEMY_DATABASE_TEST_URI') + SQL_ECHO = env.bool('SQL_TEST_ECHO', False) + TRACING_TYPE = env.str('TRACING_TEST_TYPE', '') if __name__ == '__main__': diff --git a/mishards/test_connections.py b/mishards/test_connections.py index f1c54f0c61..819d2e03da 100644 --- a/mishards/test_connections.py +++ b/mishards/test_connections.py @@ -1,6 +1,8 @@ import logging import pytest +import mock +from milvus import Milvus from mishards.connections import (ConnectionMgr, Connection) from mishards import exceptions @@ -27,6 +29,12 @@ class TestConnection: mgr.register('WOSERVER', 'xxxx') assert len(mgr.conn_names) == 0 + assert not mgr.conn('XXXX', None) + with pytest.raises(exceptions.ConnectionNotFoundError): + mgr.conn('XXXX', None, True) + + mgr.conn('WOSERVER', None) + def test_connection(self): class Conn: def __init__(self, state): @@ -37,6 +45,7 @@ class TestConnection: def connected(self): return self.state + FAIL_CONN = Conn(False) PASS_CONN = Conn(True) @@ -58,7 +67,9 @@ class TestConnection: max_retry = 3 RetryObj = Retry() - c = Connection('client', uri='', + + c = Connection('client', + uri='xx', max_retry=max_retry, on_retry_func=RetryObj) c.conn = FAIL_CONN @@ -75,3 +86,16 @@ class TestConnection: this_connect() assert ff.executed assert RetryObj.times == 0 + + this_connect = c.connect(func=None) + with pytest.raises(TypeError): + this_connect() + + errors = [] + + def error_handler(err): + errors.append(err) + + this_connect = c.connect(func=None, exception_handler=error_handler) + this_connect() + assert len(errors) == 1 diff --git a/requirements.txt b/requirements.txt index ea338d0723..133cfac8ab 100644 --- a/requirements.txt +++ b/requirements.txt @@ -33,3 +33,4 @@ SQLAlchemy==1.3.5 urllib3==1.25.3 jaeger-client>=3.4.0 grpcio-opentracing>=1.0 +mock==2.0.0 diff --git a/tracing/factory.py b/tracing/factory.py index 648dfa291e..0c14d9d536 100644 --- a/tracing/factory.py +++ b/tracing/factory.py @@ -12,13 +12,17 @@ logger = logging.getLogger(__name__) class TracerFactory: @classmethod def new_tracer(cls, tracer_type, tracer_config, span_decorator=None, **kwargs): + config = tracer_config.TRACING_CONFIG + service_name = tracer_config.TRACING_SERVICE_NAME + validate=tracer_config.TRACING_VALIDATE if not tracer_type: - return Tracer() + tracer_type = 'jaeger' + config = tracer_config.DEFAULT_TRACING_CONFIG if tracer_type.lower() == 'jaeger': - config = Config(config=tracer_config.TRACING_CONFIG, - service_name=tracer_config.TRACING_SERVICE_NAME, - validate=tracer_config.TRACING_VALIDATE + config = Config(config=config, + service_name=service_name, + validate=validate ) tracer = config.initialize_tracer()