update for service discovery

This commit is contained in:
peng.xu 2019-09-18 14:50:36 +08:00
parent 86a893cb04
commit deb4a5fb62
7 changed files with 315 additions and 11 deletions

View File

@ -2,5 +2,13 @@ import settings
# Package-level wiring: the objects below are created once at import time and
# are imported elsewhere by name (connect_mgr, discover, grpc_server).
from connections import ConnectionMgr
# Connection manager instance handed to both the discovery service and the
# gRPC server below.
connect_mgr = ConnectionMgr()
from service_founder import ServiceFounder
# Service discovery configured from the SD_* settings; it is given
# connect_mgr so it can register/unregister connections on it.
discover = ServiceFounder(namespace=settings.SD_NAMESPACE,
conn_mgr=connect_mgr,
pod_patt=settings.SD_ROSERVER_POD_PATT,
label_selector=settings.SD_LABEL_SELECTOR,
in_cluster=settings.SD_IN_CLUSTER,
poll_interval=settings.SD_POLL_INTERVAL)
from server import Server
# gRPC server built on the same connection manager instance.
grpc_server = Server(conn_mgr=connect_mgr)

View File

@ -29,7 +29,7 @@ class Connection:
self.conn.connect(uri=self.uri)
except Exception as e:
if not self.error_handlers:
raise exceptions.ConnectionConnectError()
raise exceptions.ConnectionConnectError(e)
for handler in self.error_handlers:
handler(e)
@ -77,6 +77,10 @@ class ConnectionMgr:
self.metas = {}
self.conns = {}
@property
def conn_names(self):
    """Return the set of registered connection names, excluding 'WOSERVER'."""
    return {name for name in self.metas if name != 'WOSERVER'}
def conn(self, name, throw=False):
c = self.conns.get(name, None)
if not c:
@ -116,7 +120,8 @@ class ConnectionMgr:
return self.on_diff_meta(name, url)
def on_same_meta(self, name, url):
logger.warn('Register same meta: {}:{}'.format(name, url))
# logger.warn('Register same meta: {}:{}'.format(name, url))
pass
def on_diff_meta(self, name, url):
logger.warn('Received {} with diff url={}'.format(name, url))

View File

@ -3,13 +3,19 @@ import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import settings
from mishards import connect_mgr, grpc_server as server
from mishards import (connect_mgr,
discover,
grpc_server as server)
def main():
connect_mgr.register('WOSERVER', settings.WOSERVER)
connect_mgr.register('TEST', 'tcp://127.0.0.1:19530')
server.run(port=settings.SERVER_PORT)
return 0
try:
discover.start()
connect_mgr.register('WOSERVER', settings.WOSERVER if not settings.TESTING else settings.TESTING_WOSERVER)
server.run(port=settings.SERVER_PORT)
return 0
except Exception as e:
logger.error(e)
return 1
if __name__ == '__main__':
sys.exit(main())

View File

@ -43,5 +43,5 @@ class Server:
def stop(self):
logger.info('Server is shuting down ......')
self.exit_flag = True
self.server.stop(0)
self.server_impl.stop(0)
logger.info('Server is closed')

273
mishards/service_founder.py Normal file
View File

@ -0,0 +1,273 @@
import os, sys
if __name__ == '__main__':
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import re
import logging
import time
import copy
import threading
import queue
from functools import wraps
from kubernetes import client, config, watch
from mishards.utils import singleton
logger = logging.getLogger(__name__)
incluster_namespace_path = '/var/run/secrets/kubernetes.io/serviceaccount/namespace'
class K8SMixin:
    """Shared setup for objects that talk to the Kubernetes core API.

    Stores the target namespace and a ``CoreV1Api`` client on the instance.
    When no namespace is given it is read from ``incluster_namespace_path``;
    when no client is passed via the ``v1`` keyword, the kube configuration
    is loaded and a new client is created.
    """

    def __init__(self, namespace, in_cluster=False, **kwargs):
        """Resolve the namespace and the API client.

        Args:
            namespace: Namespace to operate in. A falsy value means the
                namespace is read from ``incluster_namespace_path``.
            in_cluster: If true, ``config.load_incluster_config()`` is used,
                otherwise ``config.load_kube_config()``. Only consulted when
                no ``v1`` client is supplied.
            **kwargs: Stored on ``self.kwargs``. ``v1`` may carry an existing
                API client to reuse.
        """
        self.namespace = namespace
        self.in_cluster = in_cluster
        self.kwargs = kwargs
        self.v1 = kwargs.get('v1', None)
        if not self.namespace:
            # Close the file deterministically and drop surrounding
            # whitespace so a trailing newline never ends up in API calls.
            with open(incluster_namespace_path) as namespace_file:
                self.namespace = namespace_file.read().strip()
        if not self.v1:
            if self.in_cluster:
                config.load_incluster_config()
            else:
                config.load_kube_config()
            self.v1 = client.CoreV1Api()
class K8SServiceDiscover(threading.Thread, K8SMixin):
    """Thread that periodically publishes the state of label-selected pods.

    Every ``poll_interval`` seconds the pods matching ``label_selector`` in
    the namespace are listed and a single ``PodHeartBeat`` message is put on
    the message queue. Each entry of its ``events`` list describes one pod
    (``pod``, ``ip``, ``ready``, ``reason``, ``message``); a pod counts as
    ready when its phase is ``'Running'``.
    """

    def __init__(self, message_queue, namespace, label_selector, in_cluster=False, **kwargs):
        """Set up the poller.

        Args:
            message_queue: Queue that receives the heartbeat messages.
            namespace: Namespace to list pods in (see ``K8SMixin``).
            label_selector: Label selector passed to the pod list call.
            in_cluster: Forwarded to ``K8SMixin``.
            **kwargs: Forwarded to ``K8SMixin``; ``poll_interval`` (seconds,
                default 5) sets the pause between polls.
        """
        K8SMixin.__init__(self, namespace=namespace, in_cluster=in_cluster, **kwargs)
        threading.Thread.__init__(self)
        self.queue = message_queue
        self.terminate = False
        # Lets stop() wake the thread immediately instead of waiting for a
        # full poll interval to elapse.
        self._stop_event = threading.Event()
        self.label_selector = label_selector
        self.poll_interval = kwargs.get('poll_interval', 5)

    def run(self):
        """Poll the pod list until ``stop()`` is called."""
        while not self.terminate:
            try:
                pods = self.v1.list_namespaced_pod(namespace=self.namespace,
                                                   label_selector=self.label_selector)
                # The list response already carries metadata and status for
                # every pod, so no per-pod read is needed. This also keeps a
                # pod that vanishes mid-poll from aborting the whole heartbeat.
                pod_events = [
                    dict(
                        pod=pod.metadata.name,
                        ip=pod.status.pod_ip,
                        ready=pod.status.phase == 'Running',
                        reason=pod.status.reason,
                        message=pod.status.message
                    )
                    for pod in pods.items
                ]
                self.queue.put({
                    'eType': 'PodHeartBeat',
                    'events': pod_events
                })
            except Exception:
                # Keep polling after a failed round, but keep the traceback.
                logger.exception('Failed to poll pods in namespace %s', self.namespace)

            # Interruptible replacement for time.sleep(self.poll_interval).
            self._stop_event.wait(self.poll_interval)

    def stop(self):
        """Ask the polling loop to exit and wake it if it is waiting."""
        self.terminate = True
        self._stop_event.set()
class K8SEventListener(threading.Thread, K8SMixin):
    """Thread that forwards Kubernetes pod events to the message queue.

    Watches the namespace's events whose involved object is a Pod and puts
    one ``WatchEvent`` dict per event on the queue. The first forwarded
    message has ``start_up=True``; ``at_start_up`` flips to ``False`` once
    that first event has been seen.
    """

    def __init__(self, message_queue, namespace, in_cluster=False, **kwargs):
        """Set up the listener.

        Args:
            message_queue: Queue that receives the ``WatchEvent`` messages.
            namespace: Namespace to watch (see ``K8SMixin``).
            in_cluster: Forwarded to ``K8SMixin``.
            **kwargs: Forwarded to ``K8SMixin``.
        """
        K8SMixin.__init__(self, namespace=namespace, in_cluster=in_cluster, **kwargs)
        threading.Thread.__init__(self)
        self.queue = message_queue
        self.terminate = False
        self.at_start_up = True
        self._stop_event = threading.Event()

    def stop(self):
        """Request termination.

        The flag is checked when the next event arrives from the watch
        stream, so the thread does not exit until then.
        """
        self.terminate = True
        self._stop_event.set()

    def run(self):
        """Stream pod events and enqueue them until termination is requested."""
        w = watch.Watch()
        for event in w.stream(self.v1.list_namespaced_event, namespace=self.namespace,
                              field_selector='involvedObject.kind=Pod'):
            if self.terminate:
                # Tell the watch to stop as well, instead of only leaving
                # the loop.
                w.stop()
                break

            # The resource version was previously parsed with int() and then
            # never used; it is an opaque string, so the parse could only
            # fail. It has been removed.
            k8s_event = event['object']
            info = dict(
                eType='WatchEvent',
                pod=k8s_event.involved_object.name,
                reason=k8s_event.reason,
                message=k8s_event.message,
                start_up=self.at_start_up,
            )
            self.at_start_up = False
            self.queue.put(info)
class EventHandler(threading.Thread):
    """Thread that consumes discovery messages and updates the manager.

    Messages are taken from the queue and dispatched by ``handle_event``:
    ``PodHeartBeat`` messages reconcile the registered pods with the polled
    state, while watch events with reason ``Started`` / ``Killing`` whose pod
    name matches ``pod_patt`` register / unregister a single pod. Everything
    else is passed to ``on_drop``.
    """

    def __init__(self, mgr, message_queue, namespace, pod_patt, **kwargs):
        """Set up the handler.

        Args:
            mgr: Object providing ``add_pod``, ``delete_pod``, ``v1`` and
                ``conn_mgr`` (used as shown in the methods below).
            message_queue: Queue to read messages from.
            namespace: Namespace used when reading a started pod.
            pod_patt: Regular expression a pod name must match (from its
                start) for watch events to be handled.
            **kwargs: Stored on ``self.kwargs``.
        """
        threading.Thread.__init__(self)
        self.mgr = mgr
        self.queue = message_queue
        self.kwargs = kwargs
        self.terminate = False
        self.pod_patt = re.compile(pod_patt)
        self.namespace = namespace

    def stop(self):
        """Ask the consuming loop to exit after its current iteration."""
        self.terminate = True

    def on_drop(self, event, **kwargs):
        """Hook for ignored events; does nothing."""
        pass

    def on_pod_started(self, event, **kwargs):
        """Register a started pod once its IP can be read.

        The pod is read up to three times, pausing 0.5s after an API error
        or when no IP is assigned yet. Nothing is registered if the pod
        cannot be read or never reports an IP.
        """
        pod = None
        for _ in range(3):
            try:
                pod = self.mgr.v1.read_namespaced_pod(name=event['pod'], namespace=self.namespace)
            except client.rest.ApiException:
                time.sleep(0.5)
                continue
            if pod.status.pod_ip:
                break
            time.sleep(0.5)

        if pod is None:
            # Events replayed at start-up may refer to pods that are gone,
            # so only report the failure for live events.
            if not event.get('start_up'):
                logger.error('Pod {} is started but cannot read pod'.format(event['pod']))
            return
        if not pod.status.pod_ip:
            logger.warning('NoPodIPFoundError')
            return

        logger.info('Register POD {} with IP {}'.format(pod.metadata.name, pod.status.pod_ip))
        self.mgr.add_pod(name=pod.metadata.name, ip=pod.status.pod_ip)

    def on_pod_killing(self, event, **kwargs):
        """Unregister the pod named in the event."""
        logger.info('Unregister POD {}'.format(event['pod']))
        self.mgr.delete_pod(name=event['pod'])

    def on_pod_heartbeat(self, event, **kwargs):
        """Reconcile registered pods with a ``PodHeartBeat`` message.

        Ready pods are (re-)registered, not-ready pods are unregistered, and
        every previously known name missing from the message is unregistered
        too. Each name is unregistered at most once.
        """
        known_names = self.mgr.conn_mgr.conn_names

        running_names = set()
        stopped_names = set()
        for each_event in event['events']:
            name = each_event['pod']
            if each_event['ready']:
                self.mgr.add_pod(name=name, ip=each_event['ip'])
                running_names.add(name)
            else:
                self.mgr.delete_pod(name=name)
                stopped_names.add(name)

        # Not-ready pods were already removed above; skip them here so they
        # are not unregistered a second time.
        for name in known_names - running_names - stopped_names:
            self.mgr.delete_pod(name)

        logger.info(self.mgr.conn_mgr.conn_names)

    def handle_event(self, event):
        """Dispatch one message to the matching ``on_*`` handler."""
        # Guard first: an empty/None message must not be indexed.
        if not event:
            return self.on_drop(event)

        if event.get('eType') == 'PodHeartBeat':
            return self.on_pod_heartbeat(event)

        if event.get('reason') not in ('Started', 'Killing'):
            return self.on_drop(event)

        if not self.pod_patt.match(event['pod']):
            return self.on_drop(event)

        logger.info('Handling event: {}'.format(event))

        if event['reason'] == 'Started':
            return self.on_pod_started(event)

        return self.on_pod_killing(event)

    def run(self):
        """Consume messages until ``stop()`` is called."""
        while not self.terminate:
            try:
                event = self.queue.get(timeout=1)
            except queue.Empty:
                continue
            try:
                self.handle_event(event)
            except Exception:
                # One bad message must not kill the handler thread.
                logger.exception('Failed to handle event: %s', event)
@singleton
class ServiceFounder(object):
    """Facade that wires Kubernetes discovery to the connection manager.

    Owns one message queue and three threads: a ``K8SEventListener`` (watch
    events), a ``K8SServiceDiscover`` (periodic pod heartbeat) and an
    ``EventHandler`` that turns queued messages into ``add_pod`` /
    ``delete_pod`` calls on this object.
    """

    def __init__(self, conn_mgr, namespace, pod_patt, label_selector, in_cluster=False, **kwargs):
        """Create the API client and the worker threads (not yet started).

        Args:
            conn_mgr: Connection manager; pods are registered on it by name.
            namespace: Namespace to observe. A falsy value means it is read
                from ``incluster_namespace_path``.
            pod_patt: Regular expression for pod names handled from watch
                events.
            label_selector: Label selector used by the heartbeat poller.
            in_cluster: If true, load the in-cluster configuration, otherwise
                the local kube config.
            **kwargs: Stored on ``self.kwargs`` and forwarded to the worker
                threads (e.g. ``poll_interval``).
        """
        self.namespace = namespace
        self.kwargs = kwargs
        self.queue = queue.Queue()
        self.in_cluster = in_cluster

        self.conn_mgr = conn_mgr

        if not self.namespace:
            # Close the file deterministically and drop surrounding
            # whitespace from the stored namespace.
            with open(incluster_namespace_path) as namespace_file:
                self.namespace = namespace_file.read().strip()

        if self.in_cluster:
            config.load_incluster_config()
        else:
            config.load_kube_config()
        self.v1 = client.CoreV1Api()

        self.listener = K8SEventListener(
            message_queue=self.queue,
            namespace=self.namespace,
            in_cluster=self.in_cluster,
            v1=self.v1,
            **kwargs
        )

        # Pass the resolved namespace (not the raw argument) so all three
        # workers agree on it.
        self.pod_heartbeater = K8SServiceDiscover(
            message_queue=self.queue,
            namespace=self.namespace,
            label_selector=label_selector,
            in_cluster=self.in_cluster,
            v1=self.v1,
            **kwargs
        )

        self.event_handler = EventHandler(mgr=self,
                                          message_queue=self.queue,
                                          namespace=self.namespace,
                                          pod_patt=pod_patt, **kwargs)

    def add_pod(self, name, ip):
        """Register the pod on the connection manager as ``tcp://<ip>:19530``."""
        self.conn_mgr.register(name, 'tcp://{}:19530'.format(ip))

    def delete_pod(self, name):
        """Unregister the pod from the connection manager."""
        self.conn_mgr.unregister(name)

    def start(self):
        """Start the worker threads.

        The listener (as a daemon thread) and the handler start first. The
        heartbeat poller starts once the listener has seen its first event,
        or once the listener thread is no longer alive, so a failed listener
        cannot block start-up forever.
        """
        self.listener.daemon = True
        self.listener.start()
        self.event_handler.start()
        while self.listener.at_start_up and self.listener.is_alive():
            time.sleep(1)
        if self.listener.at_start_up:
            logger.warning('Event listener exited before receiving any event')

        self.pod_heartbeater.start()

    def stop(self):
        """Signal all worker threads to stop."""
        self.listener.stop()
        self.pod_heartbeater.stop()
        self.event_handler.stop()
if __name__ == '__main__':
    # Manual smoke run: start discovery against the 'xp' namespace using the
    # local kube config, let it work for about four seconds, then stop it.
    from mishards import connect_mgr

    logging.basicConfig(level=logging.INFO)

    founder = ServiceFounder(
        namespace='xp',
        conn_mgr=connect_mgr,
        pod_patt=".*-ro-servers-.*",
        label_selector='tier=ro-servers',
        in_cluster=False,
    )
    founder.start()

    # Two pauses of two seconds each.
    for _ in range(2):
        time.sleep(2)

    founder.stop()

View File

@ -11,6 +11,7 @@ from milvus.client import types
import settings
from grpc_utils.grpc_args_parser import GrpcArgsParser as Parser
import exceptions
logger = logging.getLogger(__name__)
@ -30,7 +31,9 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
def query_conn(self, name):
conn = self.conn_mgr.conn(name)
conn and conn.on_connect()
if not conn:
raise exceptions.ConnectionNotFoundError(name)
conn.on_connect()
return conn.conn
def _format_date(self, start, end):
@ -51,7 +54,7 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
def _get_routing_file_ids(self, table_id, range_array):
return {
'TEST': {
'milvus-ro-servers-0': {
'table_id': table_id,
'file_ids': [123]
}

View File

@ -7,7 +7,6 @@ env = Env()
env.read_env()
DEBUG = env.bool('DEBUG', False)
TESTING = env.bool('TESTING', False)
METADATA_URI = env.str('METADATA_URI', '')
@ -26,6 +25,16 @@ SEARCH_WORKER_SIZE = env.int('SEARCH_WORKER_SIZE', 10)
SERVER_PORT = env.int('SERVER_PORT', 19530)
WOSERVER = env.str('WOSERVER')
SD_NAMESPACE = env.str('SD_NAMESPACE', '')
SD_IN_CLUSTER = env.bool('SD_IN_CLUSTER', False)
SD_POLL_INTERVAL = env.int('SD_POLL_INTERVAL', 5)
SD_ROSERVER_POD_PATT = env.str('SD_ROSERVER_POD_PATT', '')
SD_LABEL_SELECTOR = env.str('SD_LABEL_SELECTOR', '')
TESTING = env.bool('TESTING', False)
TESTING_WOSERVER = env.str('TESTING_WOSERVER', 'tcp://127.0.0.1:19530')
if __name__ == '__main__':
import logging
logger = logging.getLogger(__name__)