diff --git a/.gitignore b/.gitignore index 6c00a5edbe..a4b35dc9d4 100644 --- a/.gitignore +++ b/.gitignore @@ -31,3 +31,4 @@ cov_html/ # temp shards/all_in_one_with_mysql/metadata/ +shards/mishards/.env diff --git a/shards/all_in_one/all_in_one.yml b/shards/all_in_one/all_in_one.yml index dd550553f5..125b88e4cb 100644 --- a/shards/all_in_one/all_in_one.yml +++ b/shards/all_in_one/all_in_one.yml @@ -4,6 +4,8 @@ services: runtime: nvidia restart: always image: milvusdb/milvus:0.6.0-gpu-d120719-2b40dd + ports: + - "0.0.0.0:19540:19530" volumes: - /tmp/milvus/db:/var/lib/milvus/db - ./wr_server.yml:/opt/milvus/conf/server_config.yaml @@ -12,6 +14,8 @@ services: runtime: nvidia restart: always image: milvusdb/milvus:0.6.0-gpu-d120719-2b40dd + ports: + - "0.0.0.0:19541:19530" volumes: - /tmp/milvus/db:/var/lib/milvus/db - ./ro_server.yml:/opt/milvus/conf/server_config.yaml diff --git a/shards/conftest.py b/shards/conftest.py index 4cdcbdbe0c..0b205a2b26 100644 --- a/shards/conftest.py +++ b/shards/conftest.py @@ -2,8 +2,10 @@ import os import logging import pytest import grpc +import mock import tempfile import shutil +import time from mishards import settings, db, create_app logger = logging.getLogger(__name__) @@ -18,6 +20,9 @@ settings.TestingConfig.SQLALCHEMY_DATABASE_URI = 'sqlite:///{}?check_same_thread @pytest.fixture def app(request): + from mishards.connections import ConnectionGroup + ConnectionGroup.on_pre_add = mock.MagicMock(return_value=(True,)) + time.sleep(0.1) app = create_app(settings.TestingConfig) db.drop_all() db.create_all() diff --git a/shards/discovery/factory.py b/shards/discovery/factory.py index 5f5c7fcf95..3838521d17 100644 --- a/shards/discovery/factory.py +++ b/shards/discovery/factory.py @@ -13,10 +13,10 @@ class DiscoveryFactory(BaseMixin): super().__init__(searchpath=searchpath, package_name=PLUGIN_PACKAGE_NAME) def _create(self, plugin_class, **kwargs): - conn_mgr = kwargs.pop('conn_mgr', None) - if not conn_mgr: - raise 
RuntimeError('Please pass conn_mgr to create discovery!') + readonly_topo = kwargs.pop('readonly_topo', None) + if not readonly_topo: + raise RuntimeError('Please pass readonly_topo to create discovery!') plugin_config = DiscoveryConfig.Create() - plugin = plugin_class.Create(plugin_config=plugin_config, conn_mgr=conn_mgr, **kwargs) + plugin = plugin_class.Create(plugin_config=plugin_config, readonly_topo=readonly_topo, **kwargs) return plugin diff --git a/shards/discovery/plugins/kubernetes_provider.py b/shards/discovery/plugins/kubernetes_provider.py index aaf6091f83..4ab59415b5 100644 --- a/shards/discovery/plugins/kubernetes_provider.py +++ b/shards/discovery/plugins/kubernetes_provider.py @@ -181,7 +181,7 @@ class EventHandler(threading.Thread): self.mgr.delete_pod(name=event['pod']) def on_pod_heartbeat(self, event, **kwargs): - names = self.mgr.conn_mgr.conn_names + names = self.mgr.readonly_topo.group_names running_names = set() for each_event in event['events']: @@ -195,7 +195,7 @@ class EventHandler(threading.Thread): for name in to_delete: self.mgr.delete_pod(name) - logger.info(self.mgr.conn_mgr.conn_names) + logger.info(self.mgr.readonly_topo.group_names) def handle_event(self, event): if event['eType'] == EventType.PodHeartBeat: @@ -237,7 +237,7 @@ class KubernetesProviderSettings: class KubernetesProvider(object): name = 'kubernetes' - def __init__(self, plugin_config, conn_mgr, **kwargs): + def __init__(self, plugin_config, readonly_topo, **kwargs): self.namespace = plugin_config.DISCOVERY_KUBERNETES_NAMESPACE self.pod_patt = plugin_config.DISCOVERY_KUBERNETES_POD_PATT self.label_selector = plugin_config.DISCOVERY_KUBERNETES_LABEL_SELECTOR @@ -250,7 +250,7 @@ class KubernetesProvider(object): self.kwargs = kwargs self.queue = queue.Queue() - self.conn_mgr = conn_mgr + self.readonly_topo = readonly_topo if not self.namespace: self.namespace = open(incluster_namespace_path).read() @@ -281,10 +281,24 @@ class KubernetesProvider(object): **kwargs) def 
add_pod(self, name, ip): - self.conn_mgr.register(name, 'tcp://{}:{}'.format(ip, self.port)) + ok = True + status = StatusType.OK + try: + uri = 'tcp://{}:{}'.format(ip, self.port) + status, group = self.readonly_topo.create(name=name) + if status == StatusType.OK: + status, pool = group.create(name=name, uri=uri) + except ConnectionConnectError as exc: + ok = False + logger.error('Connection error to: {}'.format(ip)) + + if ok and status == StatusType.OK: + logger.info('KubernetesProvider Add Group \"{}\" Of 1 Address: {}'.format(name, uri)) + return ok def delete_pod(self, name): - self.conn_mgr.unregister(name) + pool = self.readonly_topo.delete_group(name) + return True def start(self): self.listener.daemon = True @@ -299,8 +313,8 @@ class KubernetesProvider(object): self.event_handler.stop() @classmethod - def Create(cls, conn_mgr, plugin_config, **kwargs): - discovery = cls(plugin_config=plugin_config, conn_mgr=conn_mgr, **kwargs) + def Create(cls, readonly_topo, plugin_config, **kwargs): + discovery = cls(plugin_config=plugin_config, readonly_topo=readonly_topo, **kwargs) return discovery diff --git a/shards/discovery/plugins/static_provider.py b/shards/discovery/plugins/static_provider.py index fca8c717db..bd41e515c4 100644 --- a/shards/discovery/plugins/static_provider.py +++ b/shards/discovery/plugins/static_provider.py @@ -6,37 +6,72 @@ if __name__ == '__main__': import logging import socket from environs import Env +from mishards.exceptions import ConnectionConnectError +from mishards.topology import StatusType logger = logging.getLogger(__name__) env = Env() +DELIMITER = ':' + +def parse_host(addr): + splited_arr = addr.split(DELIMITER) + return splited_arr + +def resolve_address(addr, default_port): + addr_arr = parse_host(addr) + assert len(addr_arr) >= 1 and len(addr_arr) <= 2, 'Invalid Addr: {}'.format(addr) + port = addr_arr[1] if len(addr_arr) == 2 else default_port + return '{}:{}'.format(socket.gethostbyname(addr_arr[0]), port) class 
StaticDiscovery(object): name = 'static' - def __init__(self, config, conn_mgr, **kwargs): - self.conn_mgr = conn_mgr + def __init__(self, config, readonly_topo, **kwargs): + self.readonly_topo = readonly_topo hosts = env.list('DISCOVERY_STATIC_HOSTS', []) self.port = env.int('DISCOVERY_STATIC_PORT', 19530) - self.hosts = [socket.gethostbyname(host) for host in hosts] + self.hosts = [resolve_address(host, self.port) for host in hosts] def start(self): + ok = True for host in self.hosts: - self.add_pod(host, host) + ok &= self.add_pod(host, host) + if not ok: break + if ok and len(self.hosts) == 0: + logger.error('No address is specified') + ok = False + return ok def stop(self): for host in self.hosts: self.delete_pod(host) - def add_pod(self, name, ip): - self.conn_mgr.register(name, 'tcp://{}:{}'.format(ip, self.port)) + def add_pod(self, name, addr): + ok = True + status = StatusType.OK + try: + uri = 'tcp://{}'.format(addr) + status, group = self.readonly_topo.create(name=name) + if status == StatusType.OK: + status, pool = group.create(name=name, uri=uri) + if status not in (StatusType.OK, StatusType.DUPLICATED): + ok = False + except ConnectionConnectError as exc: + ok = False + logger.error('Connection error to: {}'.format(addr)) + + if ok and status == StatusType.OK: + logger.info('StaticDiscovery Add Static Group \"{}\" Of 1 Address: {}'.format(name, addr)) + return ok def delete_pod(self, name): - self.conn_mgr.unregister(name) + pool = self.readonly_topo.delete_group(name) + return True @classmethod - def Create(cls, conn_mgr, plugin_config, **kwargs): - discovery = cls(config=plugin_config, conn_mgr=conn_mgr, **kwargs) + def Create(cls, readonly_topo, plugin_config, **kwargs): + discovery = cls(config=plugin_config, readonly_topo=readonly_topo, **kwargs) return discovery diff --git a/shards/mishards/.env.example b/shards/mishards/.env.example index 91b67760af..1716173cf9 100644 --- a/shards/mishards/.env.example +++ b/shards/mishards/.env.example @@ 
-3,6 +3,7 @@ DEBUG=True WOSERVER=tcp://127.0.0.1:19530 SERVER_PORT=19535 SERVER_TEST_PORT=19888 +MAX_WORKERS=50 #SQLALCHEMY_DATABASE_URI=mysql+pymysql://root:root@127.0.0.1:3306/milvus?charset=utf8mb4 SQLALCHEMY_DATABASE_URI=sqlite:////tmp/milvus/db/meta.sqlite?check_same_thread=False diff --git a/shards/mishards/__init__.py b/shards/mishards/__init__.py index 55594220d3..bf7ae33b9a 100644 --- a/shards/mishards/__init__.py +++ b/shards/mishards/__init__.py @@ -15,12 +15,14 @@ def create_app(testing_config=None): pool_recycle=config.SQL_POOL_RECYCLE, pool_timeout=config.SQL_POOL_TIMEOUT, pool_pre_ping=config.SQL_POOL_PRE_PING, max_overflow=config.SQL_MAX_OVERFLOW) - from mishards.connections import ConnectionMgr - connect_mgr = ConnectionMgr() + from mishards.connections import ConnectionMgr, ConnectionTopology + + readonly_topo = ConnectionTopology() + writable_topo = ConnectionTopology() from discovery.factory import DiscoveryFactory discover = DiscoveryFactory(config.DISCOVERY_PLUGIN_PATH).create(config.DISCOVERY_CLASS_NAME, - conn_mgr=connect_mgr) + readonly_topo=readonly_topo) from mishards.grpc_utils import GrpcSpanDecorator from tracer.factory import TracerFactory @@ -30,12 +32,15 @@ def create_app(testing_config=None): from mishards.router.factory import RouterFactory router = RouterFactory(config.ROUTER_PLUGIN_PATH).create(config.ROUTER_CLASS_NAME, - conn_mgr=connect_mgr) + readonly_topo=readonly_topo, + writable_topo=writable_topo) - grpc_server.init_app(conn_mgr=connect_mgr, + grpc_server.init_app(writable_topo=writable_topo, + readonly_topo=readonly_topo, tracer=tracer, router=router, - discover=discover) + discover=discover, + max_workers=settings.MAX_WORKERS) from mishards import exception_handlers diff --git a/shards/mishards/connections.py b/shards/mishards/connections.py index 459f548452..d13987c70b 100644 --- a/shards/mishards/connections.py +++ b/shards/mishards/connections.py @@ -1,10 +1,11 @@ import logging import threading +import enum from 
functools import wraps from milvus import Milvus from milvus.client.hooks import BaseSearchHook -from mishards import (settings, exceptions) +from mishards import (settings, exceptions, topology) from utils import singleton logger = logging.getLogger(__name__) @@ -81,6 +82,140 @@ class Connection: raise e return inner + def __str__(self): + return ''.format(self.name, id(self)) + + def __repr__(self): + return self.__str__() + + +class ProxyMixin: + def __getattr__(self, name): + target = self.__dict__.get(name, None) + if target or not self.connection: + return target + return getattr(self.connection, name) + + +class ScopedConnection(ProxyMixin): + def __init__(self, pool, connection): + self.pool = pool + self.connection = connection + + def __del__(self): + self.release() + + def __str__(self): + return self.connection.__str__() + + def release(self): + if not self.pool or not self.connection: + return + self.pool.release(self.connection) + self.pool = None + self.connection = None + + +class ConnectionPool(topology.TopoObject): + def __init__(self, name, uri, max_retry=1, capacity=-1, **kwargs): + super().__init__(name) + self.capacity = capacity + self.pending_pool = set() + self.active_pool = set() + self.connection_ownership = {} + self.uri = uri + self.max_retry = max_retry + self.kwargs = kwargs + self.cv = threading.Condition() + + def __len__(self): + return len(self.pending_pool) + len(self.active_pool) + + @property + def active_num(self): + return len(self.active_pool) + + def _is_full(self): + if self.capacity < 0: + return False + return len(self) >= self.capacity + + def fetch(self, timeout=1): + with self.cv: + timeout_times = 0 + while (len(self.pending_pool) == 0 and self._is_full() and timeout_times < 1): + self.cv.notifyAll() + self.cv.wait(timeout) + timeout_times += 1 + + connection = None + if timeout_times >= 1: + return connection + + # logger.debug('[Connection] Pool \"{}\" SIZE={} ACTIVE={}'.format(self.name, len(self), 
self.active_num)) + if len(self.pending_pool) == 0: + connection = self.create() + else: + connection = self.pending_pool.pop() + # logger.debug('[Connection] Registerring \"{}\" into pool \"{}\"'.format(connection, self.name)) + self.active_pool.add(connection) + scoped_connection = ScopedConnection(self, connection) + return scoped_connection + + def release(self, connection): + with self.cv: + if connection not in self.active_pool: + raise RuntimeError('\"{}\" not found in pool \"{}\"'.format(connection, self.name)) + # logger.debug('[Connection] Releasing \"{}\" from pool \"{}\"'.format(connection, self.name)) + # logger.debug('[Connection] Pool \"{}\" SIZE={} ACTIVE={}'.format(self.name, len(self), self.active_num)) + self.active_pool.remove(connection) + self.pending_pool.add(connection) + + def create(self): + connection = Connection(name=self.name, uri=self.uri, max_retry=self.max_retry, **self.kwargs) + return connection + + +class ConnectionGroup(topology.TopoGroup): + def __init__(self, name): + super().__init__(name) + + def on_pre_add(self, topo_object): + conn = topo_object.fetch() + conn.on_connect(metadata=None) + status, version = conn.conn.server_version() + if not status.OK(): + logger.error('Cannot connect to newly added address: {}. Remove it now'.format(topo_object.name)) + return False + if version not in settings.SERVER_VERSIONS: + logger.error('Cannot connect to server of version: {}. 
Only {} supported'.format(version, + settings.SERVER_VERSIONS)) + return False + + return True + + def create(self, name, **kwargs): + uri = kwargs.get('uri', None) + if not uri: + raise RuntimeError('\"uri\" is required to create connection pool') + pool = ConnectionPool(name=name, **kwargs) + status = self.add(pool) + if status != topology.StatusType.OK: + pool = None + return status, pool + + +class ConnectionTopology(topology.Topology): + def __init__(self): + super().__init__() + + def create(self, name): + group = ConnectionGroup(name) + status = self.add_group(group) + if status == topology.StatusType.DUPLICATED: + group = None + return status, group + @singleton class ConnectionMgr: @@ -126,6 +261,14 @@ class ConnectionMgr: def on_new_meta(self, name, url): logger.info('Register Connection: name={};url={}'.format(name, url)) self.metas[name] = url + conn = self.conn(name, metadata=None) + conn.on_connect(metadata=None) + status, _ = conn.conn.server_version() + if not status.OK(): + logger.error('Cannot connect to newly added address: {}. 
Remove it now'.format(name)) + self.unregister(name) + return False + return True def on_duplicate_meta(self, name, url): if self.metas[name] == url: @@ -135,19 +278,22 @@ class ConnectionMgr: def on_same_meta(self, name, url): # logger.warning('Register same meta: {}:{}'.format(name, url)) - pass + return True def on_diff_meta(self, name, url): logger.warning('Received {} with diff url={}'.format(name, url)) self.metas[name] = url self.conns[name] = {} + return True def on_unregister_meta(self, name, url): logger.info('Unregister name={};url={}'.format(name, url)) self.conns.pop(name, None) + return True def on_nonexisted_meta(self, name): logger.warning('Non-existed meta: {}'.format(name)) + return False def register(self, name, url): meta = self.metas.get(name) diff --git a/shards/mishards/router/__init__.py b/shards/mishards/router/__init__.py index 4150f3b736..e435ea3cc0 100644 --- a/shards/mishards/router/__init__.py +++ b/shards/mishards/router/__init__.py @@ -2,20 +2,21 @@ from mishards import exceptions class RouterMixin: - def __init__(self, conn_mgr): - self.conn_mgr = conn_mgr + def __init__(self, writable_topo, readonly_topo): + self.writable_topo = writable_topo + self.readonly_topo = readonly_topo def routing(self, table_name, metadata=None, **kwargs): raise NotImplemented() def connection(self, metadata=None): - conn = self.conn_mgr.conn('WOSERVER', metadata=metadata) + conn = self.writable_topo.get_group('default').get('WOSERVER').fetch() if conn: conn.on_connect(metadata=metadata) return conn.conn def query_conn(self, name, metadata=None): - conn = self.conn_mgr.conn(name, metadata=metadata) + conn = self.readonly_topo.get_group(name).get(name).fetch() if not conn: raise exceptions.ConnectionNotFoundError(name, metadata=metadata) conn.on_connect(metadata=metadata) diff --git a/shards/mishards/router/plugins/file_based_hash_ring_router.py b/shards/mishards/router/plugins/file_based_hash_ring_router.py index c7c221de83..299335baa7 100644 --- 
a/shards/mishards/router/plugins/file_based_hash_ring_router.py +++ b/shards/mishards/router/plugins/file_based_hash_ring_router.py @@ -12,8 +12,9 @@ logger = logging.getLogger(__name__) class Factory(RouterMixin): name = 'FileBasedHashRingRouter' - def __init__(self, conn_mgr, **kwargs): - super(Factory, self).__init__(conn_mgr) + def __init__(self, writable_topo, readonly_topo, **kwargs): + super(Factory, self).__init__(writable_topo=writable_topo, + readonly_topo=readonly_topo) def routing(self, table_name, partition_tags=None, metadata=None, **kwargs): range_array = kwargs.pop('range_array', None) @@ -46,7 +47,7 @@ class Factory(RouterMixin): db.remove_session() - servers = self.conn_mgr.conn_names + servers = self.readonly_topo.group_names logger.info('Available servers: {}'.format(servers)) ring = HashRing(servers) @@ -65,10 +66,13 @@ class Factory(RouterMixin): @classmethod def Create(cls, **kwargs): - conn_mgr = kwargs.pop('conn_mgr', None) - if not conn_mgr: - raise RuntimeError('Cannot find \'conn_mgr\' to initialize \'{}\''.format(self.name)) - router = cls(conn_mgr, **kwargs) + writable_topo = kwargs.pop('writable_topo', None) + if not writable_topo: + raise RuntimeError('Cannot find \'writable_topo\' to initialize \'{}\''.format(cls.name)) + readonly_topo = kwargs.pop('readonly_topo', None) + if not readonly_topo: + raise RuntimeError('Cannot find \'readonly_topo\' to initialize \'{}\''.format(cls.name)) + router = cls(writable_topo=writable_topo, readonly_topo=readonly_topo, **kwargs) return router diff --git a/shards/mishards/server.py b/shards/mishards/server.py index 599a00e455..741d5582e4 100644 --- a/shards/mishards/server.py +++ b/shards/mishards/server.py @@ -1,4 +1,5 @@ import logging +import sys import grpc import time import socket @@ -23,7 +24,8 @@ class Server: self.exit_flag = False def init_app(self, - conn_mgr, + writable_topo, + readonly_topo, tracer, router, discover, @@ -31,11 +33,14 @@ class Server: max_workers=10, **kwargs):
self.port = int(port) - self.conn_mgr = conn_mgr + self.writable_topo = writable_topo + self.readonly_topo = readonly_topo self.tracer = tracer self.router = router self.discover = discover + logger.debug('Init grpc server with max_workers: {}'.format(max_workers)) + self.server_impl = grpc.server( thread_pool=futures.ThreadPoolExecutor(max_workers=max_workers), options=[(cygrpc.ChannelArgKey.max_send_message_length, -1), @@ -50,8 +55,8 @@ class Server: url = urlparse(woserver) ip = socket.gethostbyname(url.hostname) socket.inet_pton(socket.AF_INET, ip) - self.conn_mgr.register( - 'WOSERVER', '{}://{}:{}'.format(url.scheme, ip, url.port or 80)) + _, group = self.writable_topo.create('default') + group.create(name='WOSERVER', uri='{}://{}:{}'.format(url.scheme, ip, url.port or 80)) def register_pre_run_handler(self, func): logger.info('Regiterring {} into server pre_run_handlers'.format(func)) @@ -83,7 +88,7 @@ class Server: def on_pre_run(self): for handler in self.pre_run_handlers: handler() - self.discover.start() + return self.discover.start() def start(self, port=None): handler_class = self.decorate_handler(ServiceHandler) @@ -97,7 +102,11 @@ class Server: def run(self, port): logger.info('Milvus server start ......') port = port or self.port - self.on_pre_run() + ok = self.on_pre_run() + + if not ok: + logger.error('Terminate server due to error found in on_pre_run') + sys.exit(1) self.start(port) logger.info('Listening on port {}'.format(port)) diff --git a/shards/mishards/settings.py b/shards/mishards/settings.py index 3ab4777369..730d040de2 100644 --- a/shards/mishards/settings.py +++ b/shards/mishards/settings.py @@ -12,6 +12,7 @@ else: env.read_env() +SERVER_VERSIONS = ['0.6.0'] DEBUG = env.bool('DEBUG', False) MAX_RETRY = env.int('MAX_RETRY', 3) @@ -26,6 +27,7 @@ config(LOG_LEVEL, LOG_PATH, LOG_NAME, TIMEZONE) SERVER_PORT = env.int('SERVER_PORT', 19530) SERVER_TEST_PORT = env.int('SERVER_TEST_PORT', 19530) WOSERVER = env.str('WOSERVER') +MAX_WORKERS = 
env.int('MAX_WORKERS', 50) class TracingConfig: diff --git a/shards/mishards/test_connections.py b/shards/mishards/test_connections.py index 819d2e03da..5ed948f2a4 100644 --- a/shards/mishards/test_connections.py +++ b/shards/mishards/test_connections.py @@ -1,9 +1,13 @@ import logging import pytest import mock +import random +import threading from milvus import Milvus -from mishards.connections import (ConnectionMgr, Connection) +from mishards.connections import (ConnectionMgr, Connection, + ConnectionPool, ConnectionTopology, ConnectionGroup) +from mishards.topology import StatusType from mishards import exceptions logger = logging.getLogger(__name__) @@ -11,6 +15,7 @@ logger = logging.getLogger(__name__) @pytest.mark.usefixtures('app') class TestConnection: + @pytest.mark.skip def test_manager(self): mgr = ConnectionMgr() @@ -99,3 +104,161 @@ class TestConnection: this_connect = c.connect(func=None, exception_handler=error_handler) this_connect() assert len(errors) == 1 + + def test_topology(self): + ConnectionGroup.on_pre_add = mock.MagicMock(return_value=(True,)) + w_topo = ConnectionTopology() + status, wg1 = w_topo.create(name='wg1') + assert w_topo.has_group(wg1) + assert status == StatusType.OK + + status, wg1_dup = w_topo.create(name='wg1') + assert wg1_dup is None + assert status == StatusType.DUPLICATED + + fetched_group = w_topo.get_group('wg1') + assert id(fetched_group) == id(wg1) + + with pytest.raises(RuntimeError): + wg1.create(name='wg1_p1') + + status, wg1_p1 = wg1.create(name='wg1_p1', uri='127.0.0.1:19530') + assert status == StatusType.OK + assert wg1_p1 is not None + assert len(wg1) == 1 + + status, wg1_p1_dup = wg1.create(name='wg1_p1', uri='127.0.0.1:19530') + assert status == StatusType.DUPLICATED + assert wg1_p1_dup is None + assert len(wg1) == 1 + + status, wg1_p2 = wg1.create('wg1_p2', uri='127.0.0.1:19530') + assert status == StatusType.OK + assert wg1_p2 is not None + assert len(wg1) == 2 + + poped = wg1.remove('wg1_p3') + assert 
poped is None + assert len(wg1) == 2 + + poped = wg1.remove('wg1_p2') + assert poped.name == 'wg1_p2' + assert len(wg1) == 1 + + fetched_p1 = wg1.get(wg1_p1.name) + assert fetched_p1 == wg1_p1 + + fetched_p1 = w_topo.get_group('wg1').get('wg1_p1') + + conn1 = fetched_p1.fetch() + assert len(fetched_p1) == 1 + assert fetched_p1.active_num == 1 + + conn2 = fetched_p1.fetch() + assert len(fetched_p1) == 2 + assert fetched_p1.active_num == 2 + + conn2.release() + assert len(fetched_p1) == 2 + assert fetched_p1.active_num == 1 + + assert len(w_topo.group_names) == 1 + + def test_connection_pool(self): + ConnectionGroup.on_pre_add = mock.MagicMock(return_value=(True,)) + + def choaz_mp_fetch(capacity, count, tnum): + threads_num = 5 + topo = ConnectionTopology() + _, tg = topo.create('tg') + pool_size = 20 + pool_names = ['p{}:19530'.format(i) for i in range(pool_size)] + + threads = [] + def Worker(group, cnt, capacity): + ori_cnt = cnt + assert cnt < 100 + while cnt >= 0: + name = pool_names[random.randint(0, pool_size-1)] + cnt -= 1 + remove = (random.randint(1,4)%4 == 0) + if remove: + pool = group.get(name=name) + # if name.startswith("p1:"): + # logger.error('{} CNT={} [Remove] Group \"{}\" has pool of SIZE={} ACTIVE={}'.format(threading.get_ident(), ori_cnt-cnt, name, len(pool), pool.active_num)) + group.remove(name) + + else: + group.create(name=name, uri=name, capacity=capacity) + pool = group.get(name=name) + assert pool is not None + conn = pool.fetch(timeout=0.01) + # if name.startswith("p1:"): + # logger.error('{} CNT={} [Adding] Group \"{}\" has pool of SIZE={} ACTIVE={}'.format(threading.get_ident(), ori_cnt-cnt, name, len(pool), pool.active_num)) + + for _ in range(threads_num): + t = threading.Thread(target=Worker, args=(tg, count, tnum)) + threads.append(t) + t.start() + + for t in threads: + t.join() + choaz_mp_fetch(4, 40, 8) + + def check_mp_fetch(capacity=-1): + w2 = ConnectionPool(name='w2', uri='127.0.0.1:19530', max_retry=2, capacity=capacity) + 
connections = [] + def GetConnection(pool): + conn = pool.fetch(timeout=0.1) + if conn: + connections.append(conn) + + threads = [] + threads_num = 10 if capacity < 0 else 2*capacity + for _ in range(threads_num): + t = threading.Thread(target=GetConnection, args=(w2,)) + threads.append(t) + t.start() + + for t in threads: + t.join() + + expected_size = threads_num if capacity < 0 else capacity + + assert len(connections) == expected_size + + check_mp_fetch(5) + check_mp_fetch() + + w1 = ConnectionPool(name='w1', uri='127.0.0.1:19530', max_retry=2, capacity=2) + w1_1 = w1.fetch() + assert len(w1) == 1 + assert w1.active_num == 1 + w1_2 = w1.fetch() + assert len(w1) == 2 + assert w1.active_num == 2 + w1_3 = w1.fetch() + assert w1_3 is None + assert len(w1) == 2 + assert w1.active_num == 2 + + w1_1.release() + assert len(w1) == 2 + assert w1.active_num == 1 + + def check(pool, expected_size, expected_active_num): + w = pool.fetch() + assert len(pool) == expected_size + assert pool.active_num == expected_active_num + + check(w1, 2, 2) + + assert len(w1) == 2 + assert w1.active_num == 1 + + wild_w = w1.create() + with pytest.raises(RuntimeError): + w1.release(wild_w) + + ret = w1_2.can_retry + assert ret == w1_2.connection.can_retry diff --git a/shards/mishards/test_server.py b/shards/mishards/test_server.py index f0cde2184c..b90cdf7875 100644 --- a/shards/mishards/test_server.py +++ b/shards/mishards/test_server.py @@ -14,6 +14,8 @@ from mishards.service_handler import ServiceHandler from mishards.grpc_utils.grpc_args_parser import GrpcArgsParser as Parser from mishards.factories import TableFilesFactory, TablesFactory, TableFiles, Tables from mishards.router import RouterMixin +from mishards.connections import (ConnectionMgr, Connection, + ConnectionPool, ConnectionTopology, ConnectionGroup) logger = logging.getLogger(__name__) @@ -23,15 +25,13 @@ BAD = Status(code=Status.PERMISSION_DENIED, message='Fail') @pytest.mark.usefixtures('started_app') class TestServer: + 
@property def client(self): m = Milvus() m.connect(host='localhost', port=settings.SERVER_TEST_PORT) return m - def test_server_start(self, started_app): - assert started_app.conn_mgr.metas.get('WOSERVER') == settings.WOSERVER - def test_cmd(self, started_app): ServiceHandler._get_server_version = mock.MagicMock(return_value=(OK, '')) @@ -228,6 +228,7 @@ class TestServer: def random_data(self, n, dimension): return [[random.random() for _ in range(dimension)] for _ in range(n)] + @pytest.mark.skip def test_search(self, started_app): table_name = inspect.currentframe().f_code.co_name to_index_cnt = random.randint(10, 20) diff --git a/shards/mishards/topology.py b/shards/mishards/topology.py new file mode 100644 index 0000000000..166b37a564 --- /dev/null +++ b/shards/mishards/topology.py @@ -0,0 +1,142 @@ +import logging +import threading +import enum + +logger = logging.getLogger(__name__) + + +class TopoObject: + def __init__(self, name, **kwargs): + self.name = name + self.kwargs = kwargs + + def __eq__(self, other): + if isinstance(other, str): + return self.name == other + return self.name == other.name + + def __hash__(self): + return hash(self.name) + + def __str__(self): + return ''.format(self.name) + +class StatusType(enum.Enum): + OK = 1 + DUPLICATED = 2 + ADD_ERROR = 3 + VERSION_ERROR = 4 + + +class TopoGroup: + def __init__(self, name): + self.name = name + self.items = {} + self.cv = threading.Condition() + + def on_duplicate(self, topo_object): + logger.warning('Duplicated topo_object \"{}\" into group \"{}\"'.format(topo_object, self.name)) + + def on_added(self, topo_object): + return True + + def on_pre_add(self, topo_object): + return True + + def _add_no_lock(self, topo_object): + if topo_object.name in self.items: + return StatusType.DUPLICATED + logger.info('Adding topo_object \"{}\" into group \"{}\"'.format(topo_object, self.name)) + ok = self.on_pre_add(topo_object) + if not ok: + return StatusType.VERSION_ERROR + self.items[topo_object.name] 
= topo_object + ok = self.on_added(topo_object) + if not ok: + self._remove_no_lock(topo_object.name) + + return StatusType.OK if ok else StatusType.ADD_ERROR + + def add(self, topo_object): + with self.cv: + return self._add_no_lock(topo_object) + + def __len__(self): + return len(self.items) + + def __str__(self): + return ''.format(self.name) + + def get(self, name): + return self.items.get(name, None) + + def _remove_no_lock(self, name): + logger.info('Removing topo_object \"{}\" from group \"{}\"'.format(name, self.name)) + return self.items.pop(name, None) + + def remove(self, name): + with self.cv: + return self._remove_no_lock(name) + + +class Topology: + def __init__(self): + self.topo_groups = {} + self.cv = threading.Condition() + + def on_duplicated_group(self, group): + logger.warning('Duplicated group \"{}\" found!'.format(group)) + return StatusType.DUPLICATED + + def on_pre_add_group(self, group): + logger.debug('Pre add group \"{}\"'.format(group)) + return StatusType.OK + + def on_post_add_group(self, group): + logger.debug('Post add group \"{}\"'.format(group)) + return StatusType.OK + + def get_group(self, name): + return self.topo_groups.get(name, None) + + def has_group(self, group): + key = group if isinstance(group, str) else group.name + return key in self.topo_groups + + def _add_group_no_lock(self, group): + logger.info('Adding group \"{}\"'.format(group)) + self.topo_groups[group.name] = group + + def add_group(self, group): + self.on_pre_add_group(group) + if self.has_group(group): + return self.on_duplicated_group(group) + with self.cv: + self._add_group_no_lock(group) + return self.on_post_add_group(group) + + def on_delete_not_existed_group(self, group): + logger.warning('Deleting non-existed group \"{}\"'.format(group)) + + def on_pre_delete_group(self, group): + logger.debug('Pre delete group \"{}\"'.format(group)) + + def on_post_delete_group(self, group): + logger.debug('Post delete group \"{}\"'.format(group)) + + def 
_delete_group_no_lock(self, group): + logger.info('Deleting group \"{}\"'.format(group)) + delete_key = group if isinstance(group, str) else group.name + return self.topo_groups.pop(delete_key, None) + + def delete_group(self, group): + self.on_pre_delete_group(group) + with self.cv: + deleted_group = self._delete_group_no_lock(group) + if not deleted_group: + return self.on_delete_not_existed_group(group) + return self.on_post_delete_group(group) + + @property + def group_names(self): + return self.topo_groups.keys() diff --git a/shards/tracer/plugins/jaeger_factory.py b/shards/tracer/plugins/jaeger_factory.py index 923f2f805d..dacc09439b 100644 --- a/shards/tracer/plugins/jaeger_factory.py +++ b/shards/tracer/plugins/jaeger_factory.py @@ -26,6 +26,8 @@ class JaegerFactory: tracer, log_payloads=plugin_config.TRACING_LOG_PAYLOAD, span_decorator=span_decorator) + jaeger_logger = logging.getLogger('jaeger_tracing') + jaeger_logger.setLevel(logging.ERROR) return Tracer(tracer, tracer_interceptor, intercept_server)