diff --git a/mishards/__init__.py b/mishards/__init__.py index 8682b6eba6..b351986cba 100644 --- a/mishards/__init__.py +++ b/mishards/__init__.py @@ -1,4 +1,6 @@ +import logging from mishards import settings +logger = logging.getLogger() from mishards.db_base import DB db = DB() @@ -7,9 +9,6 @@ from mishards.server import Server grpc_server = Server() def create_app(testing_config=None): - import logging - logger = logging.getLogger() - config = testing_config if testing_config else settings.DefaultConfig db.init_db(uri=config.SQLALCHEMY_DATABASE_URI, echo=config.SQL_ECHO) logger.info(db) @@ -23,7 +22,7 @@ def create_app(testing_config=None): discover = sd_proiver_class(settings=settings.SD_PROVIDER_SETTINGS, conn_mgr=connect_mgr) from tracing.factory import TracerFactory - from grpc_utils import GrpcSpanDecorator + from mishards.grpc_utils import GrpcSpanDecorator tracer = TracerFactory.new_tracer(settings.TRACING_TYPE, settings.TracingConfig, span_decorator=GrpcSpanDecorator()) diff --git a/mishards/db_base.py b/mishards/db_base.py index 1006f21f55..b1492aa8f5 100644 --- a/mishards/db_base.py +++ b/mishards/db_base.py @@ -3,14 +3,23 @@ from sqlalchemy import create_engine from sqlalchemy.engine.url import make_url from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import sessionmaker, scoped_session +from sqlalchemy.orm.session import Session as SessionBase logger = logging.getLogger(__name__) + +class LocalSession(SessionBase): + def __init__(self, db, autocommit=False, autoflush=True, **options): + self.db = db + bind = options.pop('bind', None) or db.engine + SessionBase.__init__(self, autocommit=autocommit, autoflush=autoflush, bind=bind, **options) + class DB: Model = declarative_base() def __init__(self, uri=None, echo=False): self.echo = echo uri and self.init_db(uri, echo) + self.session_factory = scoped_session(sessionmaker(class_=LocalSession, db=self)) def init_db(self, uri, echo=False): url = make_url(uri) @@ -22,8 +31,6 @@ class DB: echo=echo, max_overflow=0) self.uri = uri - self.session = sessionmaker() - self.session.configure(bind=self.engine) self.url = url def __str__(self): @@ -31,7 +38,7 @@ class DB: @property def Session(self): - return self.session() + return self.session_factory() def drop_all(self): self.Model.metadata.drop_all(self.engine) diff --git a/mishards/factories.py b/mishards/factories.py index 5bd059654a..26e9ab2619 100644 --- a/mishards/factories.py +++ b/mishards/factories.py @@ -19,7 +19,7 @@ factory.Faker.add_provider(FakerProvider) class TablesFactory(SQLAlchemyModelFactory): class Meta: model = Tables - sqlalchemy_session = db.Session + sqlalchemy_session = db.session_factory sqlalchemy_session_persistence = 'commit' id = factory.Faker('random_number', digits=16, fix_len=True) @@ -35,7 +35,7 @@ class TablesFactory(SQLAlchemyModelFactory): class TableFilesFactory(SQLAlchemyModelFactory): class Meta: model = TableFiles - sqlalchemy_session = db.Session + sqlalchemy_session = db.session_factory sqlalchemy_session_persistence = 'commit' id = factory.Faker('random_number', digits=16, fix_len=True)