From deb4a5fb62ff540eb06003d9b2940d09b8aeeb16 Mon Sep 17 00:00:00 2001
From: "peng.xu"
Date: Wed, 18 Sep 2019 14:50:36 +0800
Subject: [PATCH] update for service discovery

---
 mishards/__init__.py        |   8 ++
 mishards/connections.py     |   9 +-
 mishards/main.py            |  19 ++-
 mishards/server.py          |   2 +-
 mishards/service_founder.py | 300 ++++++++++++++++++++++++++++++++++++
 mishards/service_handler.py |   7 +-
 mishards/settings.py        |  11 +-
 7 files changed, 345 insertions(+), 11 deletions(-)
 create mode 100644 mishards/service_founder.py

diff --git a/mishards/__init__.py b/mishards/__init__.py
index 700dd4238c..b3a14cf7e3 100644
--- a/mishards/__init__.py
+++ b/mishards/__init__.py
@@ -2,5 +2,13 @@ import settings
 from connections import ConnectionMgr
 connect_mgr = ConnectionMgr()
 
+from service_founder import ServiceFounder
+discover = ServiceFounder(namespace=settings.SD_NAMESPACE,
+                          conn_mgr=connect_mgr,
+                          pod_patt=settings.SD_ROSERVER_POD_PATT,
+                          label_selector=settings.SD_LABEL_SELECTOR,
+                          in_cluster=settings.SD_IN_CLUSTER,
+                          poll_interval=settings.SD_POLL_INTERVAL)
+
 from server import Server
 grpc_server = Server(conn_mgr=connect_mgr)
diff --git a/mishards/connections.py b/mishards/connections.py
index 06d5f3ff16..82dd082eac 100644
--- a/mishards/connections.py
+++ b/mishards/connections.py
@@ -29,7 +29,7 @@ class Connection:
             self.conn.connect(uri=self.uri)
         except Exception as e:
             if not self.error_handlers:
-                raise exceptions.ConnectionConnectError()
+                raise exceptions.ConnectionConnectError(e)
             for handler in self.error_handlers:
                 handler(e)
 
@@ -77,6 +77,10 @@ class ConnectionMgr:
         self.metas = {}
         self.conns = {}
 
+    @property
+    def conn_names(self):
+        return set(self.metas.keys()) - set(['WOSERVER'])
+
     def conn(self, name, throw=False):
         c = self.conns.get(name, None)
         if not c:
@@ -116,7 +120,8 @@ class ConnectionMgr:
         return self.on_diff_meta(name, url)
 
     def on_same_meta(self, name, url):
-        logger.warn('Register same meta: {}:{}'.format(name, url))
+        # logger.warn('Register same meta: {}:{}'.format(name, url))
+        pass
 
     def on_diff_meta(self, name, url):
         logger.warn('Received {} with diff url={}'.format(name, url))
diff --git a/mishards/main.py b/mishards/main.py
index 2ba3f14697..0526f87ff8 100644
--- a/mishards/main.py
+++ b/mishards/main.py
@@ -3,13 +3,22 @@ import os
 sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 
+import logging
 import settings
-from mishards import connect_mgr, grpc_server as server
+from mishards import (connect_mgr,
+                      discover,
+                      grpc_server as server)
+
+logger = logging.getLogger(__name__)
 
 def main():
-    connect_mgr.register('WOSERVER', settings.WOSERVER)
-    connect_mgr.register('TEST', 'tcp://127.0.0.1:19530')
-    server.run(port=settings.SERVER_PORT)
-    return 0
+    try:
+        discover.start()
+        connect_mgr.register('WOSERVER', settings.WOSERVER if not settings.TESTING else settings.TESTING_WOSERVER)
+        server.run(port=settings.SERVER_PORT)
+        return 0
+    except Exception as e:
+        logger.error(e)
+        return 1
 
 if __name__ == '__main__':
     sys.exit(main())
diff --git a/mishards/server.py b/mishards/server.py
index 59ea7db46b..d2f88cf592 100644
--- a/mishards/server.py
+++ b/mishards/server.py
@@ -43,5 +43,5 @@ class Server:
     def stop(self):
         logger.info('Server is shuting down ......')
         self.exit_flag = True
-        self.server.stop(0)
+        self.server_impl.stop(0)
         logger.info('Server is closed')
diff --git a/mishards/service_founder.py b/mishards/service_founder.py
new file mode 100644
index 0000000000..7fc47639e7
--- /dev/null
+++ b/mishards/service_founder.py
@@ -0,0 +1,300 @@
+import os, sys
+if __name__ == '__main__':
+    sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+import re
+import logging
+import time
+import copy
+import threading
+import queue
+from functools import wraps
+from kubernetes import client, config, watch
+
+from mishards.utils import singleton
+
+logger = logging.getLogger(__name__)
+
+incluster_namespace_path = '/var/run/secrets/kubernetes.io/serviceaccount/namespace'
+
+
+def read_incluster_namespace(path=incluster_namespace_path):
+    """Return the namespace stored in the service account file at ``path``."""
+    # Close the file deterministically and strip surrounding whitespace so
+    # the value can be handed to the Kubernetes API as is.
+    with open(path) as namespace_file:
+        return namespace_file.read().strip()
+
+
+class K8SMixin:
+    """Hold the namespace and CoreV1Api client shared by the K8S threads."""
+
+    def __init__(self, namespace, in_cluster=False, **kwargs):
+        self.namespace = namespace
+        self.in_cluster = in_cluster
+        self.kwargs = kwargs
+        self.v1 = kwargs.get('v1', None)
+        if not self.namespace:
+            self.namespace = read_incluster_namespace()
+
+        if not self.v1:
+            # Load a kube config only when no client was injected.
+            if self.in_cluster:
+                config.load_incluster_config()
+            else:
+                config.load_kube_config()
+            self.v1 = client.CoreV1Api()
+
+
+class K8SServiceDiscover(threading.Thread, K8SMixin):
+    """Poll the labelled pods and queue one PodHeartBeat message per poll."""
+
+    def __init__(self, message_queue, namespace, label_selector, in_cluster=False, **kwargs):
+        K8SMixin.__init__(self, namespace=namespace, in_cluster=in_cluster, **kwargs)
+        threading.Thread.__init__(self)
+        self.queue = message_queue
+        self.terminate = False
+        self.label_selector = label_selector
+        self.poll_interval = kwargs.get('poll_interval', 5)
+
+    def run(self):
+        while not self.terminate:
+            try:
+                pods = self.v1.list_namespaced_pod(namespace=self.namespace,
+                                                   label_selector=self.label_selector)
+                # The listing already carries metadata and status of each
+                # pod, so no extra read per pod is needed.
+                events = [
+                    dict(
+                        pod=pod.metadata.name,
+                        ip=pod.status.pod_ip,
+                        ready=(pod.status.phase == 'Running'),
+                        reason=pod.status.reason,
+                        message=pod.status.message
+                    )
+                    for pod in pods.items
+                ]
+                self.queue.put({'eType': 'PodHeartBeat', 'events': events})
+            except Exception as exc:
+                # One failed poll must not end the polling thread.
+                logger.error(exc)
+
+            time.sleep(self.poll_interval)
+
+    def stop(self):
+        self.terminate = True
+
+
+class K8SEventListener(threading.Thread, K8SMixin):
+    """Watch pod events of the namespace and queue them as WatchEvent."""
+
+    def __init__(self, message_queue, namespace, in_cluster=False, **kwargs):
+        K8SMixin.__init__(self, namespace=namespace, in_cluster=in_cluster, **kwargs)
+        threading.Thread.__init__(self)
+        self.queue = message_queue
+        self.terminate = False
+        self.at_start_up = True
+        self._stop_event = threading.Event()
+
+    def stop(self):
+        self.terminate = True
+        self._stop_event.set()
+
+    def run(self):
+        w = watch.Watch()
+        for event in w.stream(self.v1.list_namespaced_event, namespace=self.namespace,
+                              field_selector='involvedObject.kind=Pod'):
+            if self.terminate:
+                # Ask the watch to release its connection before leaving.
+                w.stop()
+                break
+
+            info = dict(
+                eType='WatchEvent',
+                pod=event['object'].involved_object.name,
+                reason=event['object'].reason,
+                message=event['object'].message,
+                start_up=self.at_start_up,
+            )
+            self.at_start_up = False
+            self.queue.put(info)
+
+
+class EventHandler(threading.Thread):
+    """Consume queued messages and mirror pod changes into the manager."""
+
+    def __init__(self, mgr, message_queue, namespace, pod_patt, **kwargs):
+        threading.Thread.__init__(self)
+        self.mgr = mgr
+        self.queue = message_queue
+        self.kwargs = kwargs
+        self.terminate = False
+        self.pod_patt = re.compile(pod_patt)
+        self.namespace = namespace
+
+    def stop(self):
+        self.terminate = True
+
+    def on_drop(self, event, **kwargs):
+        pass
+
+    def on_pod_started(self, event, **kwargs):
+        # Retry a few times: the pod may not have an IP assigned yet.
+        try_cnt = 3
+        pod = None
+        while try_cnt > 0:
+            try_cnt -= 1
+            try:
+                pod = self.mgr.v1.read_namespaced_pod(name=event['pod'], namespace=self.namespace)
+                if not pod.status.pod_ip:
+                    time.sleep(0.5)
+                    continue
+                break
+            except client.rest.ApiException:
+                time.sleep(0.5)
+
+        if not pod:
+            if not event['start_up']:
+                logger.error('Pod {} is started but cannot read pod'.format(event['pod']))
+            return
+
+        if not pod.status.pod_ip:
+            logger.warning('NoPodIPFoundError')
+            return
+
+        logger.info('Register POD {} with IP {}'.format(pod.metadata.name, pod.status.pod_ip))
+        self.mgr.add_pod(name=pod.metadata.name, ip=pod.status.pod_ip)
+
+    def on_pod_killing(self, event, **kwargs):
+        logger.info('Unregister POD {}'.format(event['pod']))
+        self.mgr.delete_pod(name=event['pod'])
+
+    def on_pod_heartbeat(self, event, **kwargs):
+        names = self.mgr.conn_mgr.conn_names
+
+        running_names = set()
+        for each_event in event['events']:
+            if each_event['ready']:
+                self.mgr.add_pod(name=each_event['pod'], ip=each_event['ip'])
+                running_names.add(each_event['pod'])
+            else:
+                self.mgr.delete_pod(name=each_event['pod'])
+
+        # Drop connections whose pod no longer shows up as running.
+        to_delete = names - running_names
+        for name in to_delete:
+            self.mgr.delete_pod(name)
+
+        logger.info(self.mgr.conn_mgr.conn_names)
+
+    def handle_event(self, event):
+        # Check for an empty message first: it has no 'eType' to look up.
+        if not event:
+            return self.on_drop(event)
+
+        if event['eType'] == 'PodHeartBeat':
+            return self.on_pod_heartbeat(event)
+
+        if event['reason'] not in ('Started', 'Killing'):
+            return self.on_drop(event)
+
+        if not self.pod_patt.match(event['pod']):
+            return self.on_drop(event)
+
+        logger.info('Handling event: {}'.format(event))
+
+        if event['reason'] == 'Started':
+            return self.on_pod_started(event)
+
+        return self.on_pod_killing(event)
+
+    def run(self):
+        while not self.terminate:
+            try:
+                event = self.queue.get(timeout=1)
+            except queue.Empty:
+                continue
+
+            try:
+                self.handle_event(event)
+            except Exception:
+                # A single bad event must not end the handler thread.
+                logger.exception('Failed to handle event: {}'.format(event))
+
+
+@singleton
+class ServiceFounder(object):
+    """Start and stop the listener, pod poller and event handler threads."""
+
+    def __init__(self, conn_mgr, namespace, pod_patt, label_selector, in_cluster=False, **kwargs):
+        self.namespace = namespace
+        self.kwargs = kwargs
+        self.queue = queue.Queue()
+        self.in_cluster = in_cluster
+
+        self.conn_mgr = conn_mgr
+
+        if not self.namespace:
+            self.namespace = read_incluster_namespace()
+
+        if self.in_cluster:
+            config.load_incluster_config()
+        else:
+            config.load_kube_config()
+        self.v1 = client.CoreV1Api()
+
+        self.listener = K8SEventListener(
+            message_queue=self.queue,
+            namespace=self.namespace,
+            in_cluster=self.in_cluster,
+            v1=self.v1,
+            **kwargs
+        )
+
+        self.pod_heartbeater = K8SServiceDiscover(
+            message_queue=self.queue,
+            # Pass the resolved namespace, as for the listener and handler.
+            namespace=self.namespace,
+            label_selector=label_selector,
+            in_cluster=self.in_cluster,
+            v1=self.v1,
+            **kwargs
+        )
+
+        self.event_handler = EventHandler(mgr=self,
+                                          message_queue=self.queue,
+                                          namespace=self.namespace,
+                                          pod_patt=pod_patt, **kwargs)
+
+    def add_pod(self, name, ip):
+        self.conn_mgr.register(name, 'tcp://{}:19530'.format(ip))
+
+    def delete_pod(self, name):
+        self.conn_mgr.unregister(name)
+
+    def start(self):
+        self.listener.daemon = True
+        self.listener.start()
+        self.event_handler.start()
+        # Wait until the listener has seen its first watch event.
+        while self.listener.at_start_up:
+            time.sleep(1)
+
+        self.pod_heartbeater.start()
+
+    def stop(self):
+        self.listener.stop()
+        self.pod_heartbeater.stop()
+        self.event_handler.stop()
+
+
+if __name__ == '__main__':
+    from mishards import connect_mgr
+    logging.basicConfig(level=logging.INFO)
+    t = ServiceFounder(namespace='xp', conn_mgr=connect_mgr, pod_patt=".*-ro-servers-.*", label_selector='tier=ro-servers', in_cluster=False)
+    t.start()
+    cnt = 2
+    while cnt > 0:
+        time.sleep(2)
+        cnt -= 1
+    t.stop()
diff --git a/mishards/service_handler.py b/mishards/service_handler.py
index 89ae2cd36c..516359f27c 100644
--- a/mishards/service_handler.py
+++ b/mishards/service_handler.py
@@ -11,6 +11,7 @@ from milvus.client import types
 
 import settings
 from grpc_utils.grpc_args_parser import GrpcArgsParser as Parser
+import exceptions
 
 logger = logging.getLogger(__name__)
 
@@ -30,7 +31,9 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
 
     def query_conn(self, name):
         conn = self.conn_mgr.conn(name)
-        conn and conn.on_connect()
+        if not conn:
+            raise exceptions.ConnectionNotFoundError(name)
+        conn.on_connect()
         return conn.conn
 
     def _format_date(self, start, end):
@@ -51,7 +54,7 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
 
     def _get_routing_file_ids(self, table_id, range_array):
         return {
-            'TEST': {
+            'milvus-ro-servers-0': {
                 'table_id': table_id,
                 'file_ids': [123]
             }
diff --git a/mishards/settings.py b/mishards/settings.py
index 4d87e69fe3..c4466da6ec 100644
--- a/mishards/settings.py
+++ b/mishards/settings.py
@@ -7,7 +7,6 @@ env = Env()
 env.read_env()
 
 DEBUG = env.bool('DEBUG', False)
-TESTING = env.bool('TESTING', False)
 
 METADATA_URI = env.str('METADATA_URI', '')
 
@@ -26,6 +25,16 @@ SEARCH_WORKER_SIZE = env.int('SEARCH_WORKER_SIZE', 10)
 SERVER_PORT = env.int('SERVER_PORT', 19530)
 WOSERVER = env.str('WOSERVER')
 
+SD_NAMESPACE = env.str('SD_NAMESPACE', '')
+SD_IN_CLUSTER = env.bool('SD_IN_CLUSTER', False)
+SD_POLL_INTERVAL = env.int('SD_POLL_INTERVAL', 5)
+SD_ROSERVER_POD_PATT = env.str('SD_ROSERVER_POD_PATT', '')
+SD_LABEL_SELECTOR = env.str('SD_LABEL_SELECTOR', '')
+
+TESTING = env.bool('TESTING', False)
+TESTING_WOSERVER = env.str('TESTING_WOSERVER', 'tcp://127.0.0.1:19530')
+
+
 if __name__ == '__main__':
     import logging
     logger = logging.getLogger(__name__)