diff --git a/mishards/__init__.py b/mishards/__init__.py index c1cea84861..76f3168b51 100644 --- a/mishards/__init__.py +++ b/mishards/__init__.py @@ -2,22 +2,29 @@ from mishards import settings from mishards.db_base import DB db = DB() -db.init_db(uri=settings.SQLALCHEMY_DATABASE_URI, echo=settings.SQL_ECHO) - -from mishards.connections import ConnectionMgr -connect_mgr = ConnectionMgr() - -from sd import ProviderManager - -sd_proiver_class = ProviderManager.get_provider(settings.SD_PROVIDER) -discover = sd_proiver_class(settings=settings.SD_PROVIDER_SETTINGS, conn_mgr=connect_mgr) - -from tracing.factory import TracerFactory -from grpc_utils import GrpcSpanDecorator -tracer = TracerFactory.new_tracer(settings.TRACING_TYPE, settings.TracingConfig, - span_decorator=GrpcSpanDecorator()) from mishards.server import Server -grpc_server = Server(conn_mgr=connect_mgr, tracer=tracer) +grpc_server = Server() -from mishards import exception_handlers +def create_app(testing_config=None): + config = testing_config if testing_config else settings.DefaultConfig + db.init_db(uri=config.SQLALCHEMY_DATABASE_URI, echo=config.SQL_ECHO) + + from mishards.connections import ConnectionMgr + connect_mgr = ConnectionMgr() + + from sd import ProviderManager + + sd_proiver_class = ProviderManager.get_provider(settings.SD_PROVIDER) + discover = sd_proiver_class(settings=settings.SD_PROVIDER_SETTINGS, conn_mgr=connect_mgr) + + from tracing.factory import TracerFactory + from grpc_utils import GrpcSpanDecorator + tracer = TracerFactory.new_tracer(settings.TRACING_TYPE, settings.TracingConfig, + span_decorator=GrpcSpanDecorator()) + + grpc_server.init_app(conn_mgr=connect_mgr, tracer=tracer, discover=discover) + + from mishards import exception_handlers + + return grpc_server diff --git a/mishards/main.py b/mishards/main.py index 7fac55dfa2..9197fbf598 100644 --- a/mishards/main.py +++ b/mishards/main.py @@ -2,10 +2,10 @@ import os, sys sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from mishards import ( - settings, - grpc_server as server) + settings, create_app) def main(): + server = create_app() server.run(port=settings.SERVER_PORT) return 0 diff --git a/mishards/server.py b/mishards/server.py index 1f72a8812d..0ca4a8f866 100644 --- a/mishards/server.py +++ b/mishards/server.py @@ -12,20 +12,23 @@ from jaeger_client import Config from milvus.grpc_gen.milvus_pb2_grpc import add_MilvusServiceServicer_to_server from mishards.grpc_utils import is_grpc_method from mishards.service_handler import ServiceHandler -from mishards import settings, discover +from mishards import settings logger = logging.getLogger(__name__) class Server: - def __init__(self, conn_mgr, tracer, port=19530, max_workers=10, **kwargs): + def __init__(self): self.pre_run_handlers = set() self.grpc_methods = set() self.error_handlers = {} self.exit_flag = False + + def init_app(self, conn_mgr, tracer, discover, port=19530, max_workers=10, **kwargs): self.port = int(port) self.conn_mgr = conn_mgr self.tracer = tracer + self.discover = discover self.server_impl = grpc.server( thread_pool=futures.ThreadPoolExecutor(max_workers=max_workers), @@ -73,7 +76,7 @@ class Server: def on_pre_run(self): for handler in self.pre_run_handlers: handler() - discover.start() + self.discover.start() def start(self, port=None): handler_class = self.decorate_handler(ServiceHandler) diff --git a/mishards/settings.py b/mishards/settings.py index 4a70d44561..b42cb791f6 100644 --- a/mishards/settings.py +++ b/mishards/settings.py @@ -16,9 +16,6 @@ TIMEZONE = env.str('TIMEZONE', 'UTC') from utils.logger_helper import config config(LOG_LEVEL, LOG_PATH, LOG_NAME, TIMEZONE) -SQLALCHEMY_DATABASE_URI = env.str('SQLALCHEMY_DATABASE_URI') -SQL_ECHO = env.bool('SQL_ECHO', False) - TIMEOUT = env.int('TIMEOUT', 60) MAX_RETRY = env.int('MAX_RETRY', 3) SEARCH_WORKER_SIZE = env.int('SEARCH_WORKER_SIZE', 10) @@ -63,6 +60,15 @@ class TracingConfig: 'logging': env.bool('TRACING_LOGGING', True) } +class DefaultConfig: + SQLALCHEMY_DATABASE_URI = env.str('SQLALCHEMY_DATABASE_URI') + SQL_ECHO = env.bool('SQL_ECHO', False) + +# class TestingConfig(DefaultConfig): +# SQLALCHEMY_DATABASE_URI = env.str('SQLALCHEMY_DATABASE_TEST_URI') +# SQL_ECHO = env.bool('SQL_TEST_ECHO', False) + + if __name__ == '__main__': import logging logger = logging.getLogger(__name__)